In [None]:
import os
import numpy as np
import sys
import torch
import h5py
import openslide

import matplotlib.pyplot as plt

sys.path.append('../')
from embeddings.embeddings import get_mixture_params
from utils.visualization_utils import get_panther_encoder, get_mixture_plot_figure



In [None]:
def visualize_wsi_pt_assignment(slide_id, type, fold):
    """Visualize the mixture proportion distribution for a specific WSI.

    Args:
        slide_id: Slide identifier; used as the file stem of the ``.h5``
            feature file.
        type: Dataset suffix inserted into the ``tcga_{type}`` data directory
            (the name shadows the builtin inside this function).
        fold: Name of the split sub-folder passed to ``get_panther_encoder``.

    Returns:
        None. The figure produced by ``get_mixture_plot_figure`` is shown via
        IPython's ``display``.
    """
    n_prototypes = 16  # value passed as ``p`` to get_mixture_params

    # input paths
    h5_feats_fpath = f'../data/data_files/tcga_{type}/wsi/extracted_res0_5_patch256_uni/feats_h5/{slide_id}.h5'
    split_folder = f"../data/data_files/tcga_{type}/splits/{fold}"

    # Read the features inside a context manager so the HDF5 handle is
    # released even if a later step raises.
    with h5py.File(h5_feats_fpath, 'r') as h5:
        feats = torch.Tensor(h5['features'][:]).unsqueeze(0)

    # Get the PANTHER model used to obtain the embeddings
    panther_encoder = get_panther_encoder(split_folder=split_folder)

    # Get proportions of each mixture component
    with torch.inference_mode():
        out, _ = panther_encoder.representation(feats).values()
        pis, _ = get_mixture_params(out, p=n_prototypes)
        pis = pis[0].detach().cpu().numpy()

    # Plot the mixture proportion distribution
    display(get_mixture_plot_figure(pis))


def find_patch_size(all_cords, index):
    """Estimate the patch edge length from neighbouring coordinates.

    The first component of the coordinate at ``index`` is compared with the
    first component of the entries that follow it; the first strictly larger
    value determines the size. If no later entry is larger (the patch sits on
    the last edge), the entries before ``index`` are scanned backwards for the
    first strictly smaller value instead.

    Args:
        all_cords: Indexable, sliceable sequence of 2-element coordinates.
        index: Position of the patch of interest in ``all_cords``.

    Returns:
        The positive difference between the two first components. Patches are
        taken to be square, so one edge length is returned.

    Raises:
        AssertionError: if no differing first component is found in either
            direction.
    """
    following = all_cords[index + 1:]
    preceding = all_cords[:index]
    anchor = all_cords[index][0]

    # Forward search: first later entry lying beyond the anchor.
    edge = next(
        (first - anchor for first, _ in following if first > anchor),
        0,
    )

    # If the patch is at the last edge, fall back to a backward search.
    if edge == 0:
        edge = next(
            (anchor - first for first, _ in preceding[::-1] if first < anchor),
            0,
        )

    assert edge > 0, "patch size is 0. Somthing is going wrong!"

    # a patch is always square
    return edge



def visualize_pt(slide_id, case_id, type, fold):
    """Visualize prototypes using patches of a single WSI.

    For every prototype, the patch with the largest entry in the encoder's
    ``qq`` output is read from the slide and drawn in a single row of
    subplots. Each subplot is titled with the prototype index and its mixture
    proportion.

    Args:
        slide_id: Slide identifier; used as the file stem of both the ``.svs``
            image and the ``.h5`` feature file.
        case_id: Label drawn vertically next to each patch.
        type: Dataset suffix inserted into the ``tcga_{type}`` data directory
            (the name shadows the builtin inside this function).
        fold: Name of the split sub-folder passed to ``get_panther_encoder``.

    Returns:
        None. The figure is shown with ``plt.show()``.
    """
    n_prototypes = 16  # value passed as ``p`` to get_mixture_params

    # input paths
    slide_fpath = f'../data/data_files/tcga_{type}/wsi/images/{slide_id}.svs'
    h5_feats_fpath = f'../data/data_files/tcga_{type}/wsi/extracted_res0_5_patch256_uni/feats_h5/{slide_id}.h5'
    split_folder = f"../data/data_files/tcga_{type}/splits/{fold}"

    # Read features and coordinates once, then release the HDF5 handle.
    # Having the coordinates in memory also stops find_patch_size from
    # slicing the on-disk dataset on every loop iteration.
    with h5py.File(h5_feats_fpath, 'r') as h5:
        feats = torch.Tensor(h5['features'][:]).unsqueeze(0)
        coords = h5['coords'][:]

    # Get the PANTHER model used to obtain the embeddings
    panther_encoder = get_panther_encoder(split_folder=split_folder)

    # Get proportions of each mixture component; only the model call needs
    # inference mode, the plotting below does not touch tensors.
    with torch.inference_mode():
        out, qqs = panther_encoder.representation(feats).values()
        pis, _ = get_mixture_params(out, p=n_prototypes)
        pis = pis[0].detach().cpu().numpy()
        qq = qqs[0, :, :, 0].cpu().numpy()

    # squeeze=False keeps ``axes`` two-dimensional. With a single row the
    # default would return a 1-D array and ``axes[0, pt]`` would raise.
    fig, axes = plt.subplots(
        1, n_prototypes, figsize=(n_prototypes * 1.8, 2.5), squeeze=False
    )

    # Show closest patch for each prototype; the slide is closed on exit.
    with openslide.open_slide(slide_fpath) as wsi:
        for pt in range(n_prototypes):
            top_index = np.argsort(qq[:, pt])[-1]
            # Cast numpy scalars to plain ints before handing them to
            # OpenSlide.
            patch_size = int(find_patch_size(coords, top_index))
            location = (int(coords[top_index][0]), int(coords[top_index][1]))
            patch = wsi.read_region(
                location,
                level=0,
                size=(patch_size, patch_size)
            ).convert("RGB")

            ax = axes[0, pt]
            ax.imshow(patch)
            ax.axis("off")
            # Proportions that would print as 0.000 get an inequality label.
            if pis[pt] < 0.0005:
                ax.set_title(f"$\\mathbf{{W({pt})}}$, $\\pi$<0.001", fontsize=8)
            else:
                ax.set_title(f"$\\mathbf{{W({pt})}}$, $\\pi$={pis[pt]:.3f}", fontsize=8)
            ax.text(-0.05, 0.5, f"{case_id}", va='center', ha='right', rotation=90, fontsize=9, transform=ax.transAxes)

    plt.tight_layout()
    plt.show()


# Visualize the prototypes by showing the closest patches

In [None]:

# Define every input this cell needs so it runs on a fresh kernel
# (Restart & Run All) without relying on variables set by a later cell.
slide_id = "TCGA-A7-A13E-01Z-00-DX1.891954FF-316A-4562-AA14-429631944F22"
case_id = "-".join(slide_id.split('-')[:3])  # first three dash-separated fields
cancer_type = "brca"
fold = 2

visualize_pt(slide_id, case_id, cancer_type, fold)

# Visualize the mixture proportion distribution

In [None]:
# Slide, dataset suffix and split folder for the mixture-proportion plot.
slide_id = "TCGA-A7-A13E-01Z-00-DX1.891954FF-316A-4562-AA14-429631944F22"
type, fold = "brca", 2
visualize_wsi_pt_assignment(slide_id=slide_id, type=type, fold=fold)