In [None]:
import numpy as np
from umap import UMAP
from sklearn.decomposition import PCA, FastICA
import matplotlib.pyplot as plt
import seaborn as sns
import os
import pandas as pd

In [None]:
# Load the saved embedding array that the UMAP model is fitted on.
all_embeddings_path = "/data/msaqib3/nnUNet_embeddings_script/all_embeddings.npy"
all_embeddings = np.load(all_embeddings_path)

In [None]:
# Fit a 2-D UMAP projection on the full embedding set.
#
# UMAP is stochastic: without a fixed seed every kernel restart yields a
# different layout, which silently invalidates the coordinate thresholds that
# are hard-coded further down in this notebook. Pinning `random_state` makes
# the projection reproducible (at the cost of UMAP's multi-threaded speed-up).
# NOTE(review): thresholds picked by eye against an earlier, unseeded run need
# to be re-checked once against this seeded embedding.
UMAP_RANDOM_STATE = 42
umap = UMAP(random_state=UMAP_RANDOM_STATE)
all_embeddings_umap = umap.fit_transform(all_embeddings)

In [None]:
# Scatter the 2-D projection of every embedding; the very low alpha keeps
# heavily overlapping points readable as density.
plt.figure(figsize=(20, 20))
umap_axis_0 = all_embeddings_umap[:, 0]
umap_axis_1 = all_embeddings_umap[:, 1]
sns.scatterplot(x=umap_axis_0, y=umap_axis_1, alpha=0.01)

In [None]:
# Names of the per-patch embedding files inside the local "embeddings" folder.
# os.listdir returns entries in arbitrary order, so sort them to make the row
# order of everything derived from this list deterministic. Hidden entries
# (e.g. ".ipynb_checkpoints", ".DS_Store") are skipped because they are not
# loadable arrays and would break the np.load calls that consume this list.
all_patch_files = sorted(
    entry for entry in os.listdir("embeddings") if not entry.startswith(".")
)

In [None]:
# Build one row per patch file: for each of the 320 embedding channels, the
# median over the middle slice (along axis 2) of that patch's embedding array.
#
# Rows are collected in a list and the DataFrame is created once at the end.
# Growing an initially empty frame with `.loc[...] = row` re-allocates on every
# insertion and, depending on the pandas version, can leave the columns with
# object dtype; building from a stacked float matrix avoids both.
n_embedding_channels = 320
patch_medians = []

for patch_file in all_patch_files:
    patch_embeddings = np.load(os.path.join("embeddings", patch_file))
    patch_embedding_shape = patch_embeddings.shape
    mid_shape_axis_2 = patch_embedding_shape[2] // 2
    # Take the central slice along axis 2, flatten everything but the leading
    # (channel) axis, and reduce to one median value per channel.
    mid_slice = patch_embeddings[:, :, mid_shape_axis_2, :]
    patch_medians.append(
        np.median(mid_slice.reshape(n_embedding_channels, -1), axis=1)
    )

# np.vstack rejects an empty list, so fall back to a 0-row matrix when the
# folder contained no patch files.
if patch_medians:
    embedding_matrix = np.vstack(patch_medians)
else:
    embedding_matrix = np.empty((0, n_embedding_channels))

embedding_df = pd.DataFrame(
    embedding_matrix,
    index=list(all_patch_files),
    columns=list(range(n_embedding_channels)),
)

In [None]:
# Project the per-patch median embeddings with the already-fitted UMAP model.
umapped_patches = umap.transform(X=embedding_df)

In [None]:
# Wrap the projected coordinates in a DataFrame that shares embedding_df's
# index, so each row can be traced back to its patch file name.
umap_column_names = ['UMAP1', 'UMAP2']
umapped_patches = pd.DataFrame(
    umapped_patches,
    index=embedding_df.index,
    columns=umap_column_names,
)

In [None]:
# Scatter the projected patch embeddings; low alpha exposes point density.
plt.figure(figsize=(20, 20))
sns.scatterplot(data=umapped_patches, x='UMAP1', y='UMAP2', alpha=0.02)


In [None]:
# Boolean mask of patches lying outside the box
#   -1 <= UMAP1 <= 1  and  8 <= UMAP2 <= 12.5.
# The bounds are hard-coded for one particular UMAP layout.
umap1_coords = umapped_patches['UMAP1']
umap2_coords = umapped_patches['UMAP2']
outlier_index = (
    umap1_coords.lt(-1)
    | umap1_coords.gt(1)
    | umap2_coords.lt(8)
    | umap2_coords.gt(12.5)
)

In [None]:
# Display the projected coordinates of the patches flagged as outliers.
umapped_patches.loc[outlier_index]

In [None]:
# Number of patches flagged as outliers (count of True entries in the mask).
outlier_index.sum()

In [None]:
# Scatter only the patches that were not flagged as outliers.
plt.figure(figsize=(8, 8))
inlier_patches = umapped_patches.loc[~outlier_index]
sns.scatterplot(data=inlier_patches, x='UMAP1', y='UMAP2')


In [None]:
# Show the patches whose UMAP2 coordinate lies strictly between 10.4 and 11.
in_umap2_band = umapped_patches['UMAP2'].gt(10.4) & umapped_patches['UMAP2'].lt(11)
umapped_patches.loc[in_umap2_band]

In [None]:
# Show the patches inside the open window
#   0.2 < UMAP1 < 0.5  and  10 < UMAP2 < 10.5.
in_umap1_window = umapped_patches['UMAP1'].gt(0.2) & umapped_patches['UMAP1'].lt(0.5)
in_umap2_window = umapped_patches['UMAP2'].gt(10) & umapped_patches['UMAP2'].lt(10.5)
umapped_patches.loc[in_umap1_window & in_umap2_window]

In [None]:
from tqdm import tqdm

In [None]:
# Grid of example images: one row per embedding dimension, one column per
# quantile. Each cell shows the middle slice (along axis 2) of channel 0 of
# the preprocessed volume whose per-patch median for that dimension is the
# nearest observed value to the requested quantile.
max_embedding_dim = 320
quantiles = [0, 0.01, 0.05, 0.1, 0.15, 0.25, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6, 0.65, 0.75, 0.8, 0.85, 0.9, 0.95, 0.99, 1]
# NOTE(review): absolute, machine-specific path; consider moving it to a
# configurable data directory.
preprocessed_dir = "/data/msaqib3/nnUNet_preprocessed/Dataset111_7TExVivoFlash/nnUNetPlans_3d_fullres"

# Create the whole axes grid up front instead of adding subplots one by one
# through the pyplot state machine; squeeze=False keeps `axes` 2-D.
fig, axes = plt.subplots(
    max_embedding_dim,
    len(quantiles),
    figsize=(30, int(max_embedding_dim/len(quantiles)*30+30)),
    squeeze=False,
)

for embedding_dim in tqdm(range(max_embedding_dim)):
    for j, quantile in enumerate(quantiles):
        ax = axes[embedding_dim, j]
        # interpolation='nearest' guarantees the quantile is an observed value,
        # so the equality lookup below always finds at least one row.
        quantile_value = embedding_df[embedding_dim].quantile(quantile, interpolation='nearest')
        index_name = embedding_df[embedding_df[embedding_dim] == quantile_value].index[0]
        # Memory-map the volume so only the displayed slice is read from disk,
        # then copy that slice so the mapping can be released immediately.
        npy_file = np.load(f"{preprocessed_dir}/{index_name}", mmap_mode="r")
        mid_shape_axis_2 = npy_file.shape[2] // 2
        mid_slice = np.array(npy_file[0, :, mid_shape_axis_2, :])
        ax.imshow(mid_slice, cmap='gray')
        if embedding_dim == 0:
            # Column headers only on the first row.
            ax.set_title(f"Quantile {quantile}", fontsize=16)
        if j == 0:
            # Row labels only on the first column.
            ax.set_ylabel(f"Emb Dim {embedding_dim}", fontsize=16)
        ax.set_xticks([])
        ax.set_yticks([])

# Lay the figure out once, after every axes has been populated; running
# tight_layout on each outer-loop iteration repeated this expensive pass
# hundreds of times with no effect on the final result.
fig.tight_layout()
fig.savefig("embedding_dim_quantiles.png", dpi=300)