### Dependencies

In [11]:
# Base Dependencies
import os
import pickle
import sys
j_ = os.path.join

# LinAlg / Stats / Plotting Dependencies
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from PIL import Image
from tqdm import tqdm

# Scikit-Learn Imports
import sklearn
from sklearn.linear_model import LogisticRegression
from sklearn.neighbors import KNeighborsClassifier
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import cross_val_score, StratifiedKFold

#Torch Imports
import torch
import torch.nn as nn
from torch.utils.data.dataset import Dataset
torch.multiprocessing.set_sharing_strategy('file_system')

# Utils
from slide_extraction_utils import create_slide_embeddings
from slide_evaluation_utils import get_knn_classification_results

### Saving Each WSI Bag as it's Mean Instance-Level Embedding

In [7]:
r"""
Script for saving mean WSI features for each feature type in each task
"""

dataroot = './embeddings_slide_lib/'
saveroot = './embeddings_slide_lib/knn-subtyping/'
os.makedirs(saveroot, exist_ok=True)

for enc_name in ['resnet50mean', 'vit16mean', 'vit256mean']:
    for study in ['tcga_brca', 'tcga_kidney', 'tcga_lung']:
        print(f'Extracting {enc_name} embedddings for {study}')
        create_slide_embeddings(dataroot=dataroot, saveroot=saveroot,
                                enc_name=enc_name, study=study)

Extracting resnet50mean embedddings for tcga_brca


100%|██████████| 10/10 [00:00<00:00, 52891.60it/s]


Extracting resnet50mean embedddings for tcga_kidney


100%|██████████| 10/10 [00:00<00:00, 116185.71it/s]


Extracting resnet50mean embedddings for tcga_lung


100%|██████████| 10/10 [00:00<00:00, 63358.07it/s]


Extracting vit16mean embedddings for tcga_brca


100%|██████████| 10/10 [00:00<00:00, 47180.02it/s]


Extracting vit16mean embedddings for tcga_kidney


100%|██████████| 10/10 [00:00<00:00, 66052.03it/s]

Extracting vit16mean embedddings for tcga_lung



100%|██████████| 10/10 [00:00<00:00, 248183.67it/s]


Extracting vit256mean embedddings for tcga_brca


100%|██████████| 10/10 [00:00<00:00, 115545.56it/s]


Extracting vit256mean embedddings for tcga_kidney


100%|██████████| 10/10 [00:00<00:00, 161942.24it/s]


Extracting vit256mean embedddings for tcga_lung


100%|██████████| 10/10 [00:00<00:00, 86659.17it/s]


### 10-Fold CV Evaluation of Mean WSI Embeddings

In [6]:
r"""
Script for runnign 10-fold CV for each feature type for each TCGA study.
"""
    
results_all = []
dataroot = './embeddings_slide_lib/knn-subtyping/'

for enc_name in tqdm(['resnet50mean', 'vit16mean', 'vit256mean']):
    results_row = []
    for study in ['tcga_brca', 'tcga_lung', 'tcga_kidney']:
        for prop in [0.25, 1.0]:
            aucs = get_knn_classification_results(dataroot, study, enc_name, prop)
            aucs = '%0.3f +/- %0.3f' % (aucs.mean(), aucs.std())
            results_row.append([aucs])
    
    results_all.append(pd.DataFrame(results_row).T)
    
results_df = pd.concat(results_all)
results_df.index = ['resnet50mean', 'vit16mean', 'vit256mean']
results_df.columns = [0.25, 1.0, 0.25, 1.0, 0.25, 1.0]
results_df.index = ['', '', '']
results_df.insert(0, 'Pretrain', ['ImageNet', 'DINO', 'DINO'])
results_df.insert(1, 'Arch', ['ResNet-50','ViT-16', 'ViT-256'])
print(results_df.to_latex())
results_df

100%|█████████████████████████████████████████████| 3/3 [00:03<00:00,  1.21s/it]

\begin{tabular}{lllllllll}
\toprule
{} &  Pretrain &       Arch &             0.25 &              1.0 &             0.25 &              1.0 &             0.25 &              1.0 \\
\midrule
{} &  ImageNet &  ResNet-50 &  0.638 +/- 0.089 &  0.667 +/- 0.070 &  0.696 +/- 0.055 &  0.794 +/- 0.035 &  0.862 +/- 0.030 &  0.951 +/- 0.016 \\
{} &      DINO &     ViT-16 &  0.605 +/- 0.092 &  0.725 +/- 0.083 &  0.622 +/- 0.067 &  0.742 +/- 0.045 &  0.848 +/- 0.032 &  0.899 +/- 0.027 \\
{} &      DINO &    ViT-256 &  0.682 +/- 0.055 &  0.775 +/- 0.042 &  0.773 +/- 0.048 &  0.889 +/- 0.027 &  0.916 +/- 0.022 &  0.974 +/- 0.016 \\
\bottomrule
\end{tabular}






Unnamed: 0,Pretrain,Arch,0.25,1.0,0.25.1,1.0.1,0.25.2,1.0.2
,ImageNet,ResNet-50,0.638 +/- 0.089,0.667 +/- 0.070,0.696 +/- 0.055,0.794 +/- 0.035,0.862 +/- 0.030,0.951 +/- 0.016
,DINO,ViT-16,0.605 +/- 0.092,0.725 +/- 0.083,0.622 +/- 0.067,0.742 +/- 0.045,0.848 +/- 0.032,0.899 +/- 0.027
,DINO,ViT-256,0.682 +/- 0.055,0.775 +/- 0.042,0.773 +/- 0.048,0.889 +/- 0.027,0.916 +/- 0.022,0.974 +/- 0.016


### 1st Fold Evaluation of Mean WSI Embeddings

In [7]:
r"""
Script for running single-fold CV for each feature type for each TCGA study.
"""
    
results_all = []
dataroot = './embeddings_slide_lib/knn-subtyping/'

for enc_name in tqdm(['resnet50mean', 'vit16mean', 'vit256mean']):
    results_row = []
    for study in ['tcga_brca', 'tcga_lung', 'tcga_kidney']:
        for prop in [0.25, 1.0]:
            aucs = get_knn_classification_results(dataroot, study, enc_name, prop)
            aucs = '%0.3f' % (aucs.iloc[0][0])
            results_row.append([aucs])
    
    results_all.append(pd.DataFrame(results_row).T)
    
results_df = pd.concat(results_all)
results_df.index = ['resnet50mean', 'vit16mean', 'vit256mean']
results_df.columns = [0.25, 1.0, 0.25, 1.0, 0.25, 1.0]
results_df.index = ['', '', '']
results_df.insert(0, 'Pretrain', ['ImageNet', 'DINO', 'DINO'])
results_df.insert(1, 'Arch', ['ResNet-50','ViT-16', 'ViT-256'])
print(results_df.to_latex())
results_df

100%|██████████| 3/3 [00:01<00:00,  2.43it/s]

\begin{tabular}{lllllllll}
\toprule
{} &  Pretrain &       Arch &   0.25 &    1.0 &   0.25 &    1.0 &   0.25 &    1.0 \\
\midrule
{} &  ImageNet &  ResNet-50 &  0.706 &  0.566 &  0.681 &  0.789 &  0.867 &  0.947 \\
{} &      DINO &     ViT-16 &  0.719 &  0.833 &  0.586 &  0.668 &  0.855 &  0.892 \\
{} &      DINO &    ViT-256 &  0.711 &  0.808 &  0.728 &  0.947 &  0.929 &  0.979 \\
\bottomrule
\end{tabular}






Unnamed: 0,Pretrain,Arch,0.25,1.0,0.25.1,1.0.1,0.25.2,1.0.2
,ImageNet,ResNet-50,0.706,0.566,0.681,0.789,0.867,0.947
,DINO,ViT-16,0.719,0.833,0.586,0.668,0.855,0.892
,DINO,ViT-256,0.711,0.808,0.728,0.947,0.929,0.979


### Faster Sanity-Check that the Results are Correct

In [12]:
r"""
Script for running single-fold CV for each feature type for each TCGA study.
"""

dataroot = './embeddings_slide_lib/vit256mean_tcga_slide_embeddings/'
available_vit256_features = os.listdir(dataroot)

for study in ['tcga_brca', 'tcga_kidney', 'tcga_lung']:
    path2csv = '../Weakly-Supervised-Subtyping/dataset_csv/'
    dataset = pd.read_csv(j_(path2csv, f'{study}_subset.csv.zip'), index_col=2)
    dataset.index = dataset.index.str[:-4]
    embeddings_all, labels_all = [], []
    slide_ids = []
    
    if study == 'tcga_brca':
        label_dict={'IDC':0, 'ILC':1}
    elif study == 'tcga_kidney':
        label_dict={'CCRCC':0, 'PRCC':1, 'CHRCC': 2}
    elif study == 'tcga_lung':
        label_dict={'LUSC':1, 'LUAD':0}
                          

    for slide_id in tqdm(dataset.index):
        pt_fname, label = slide_id+'.pt', dataset.loc[slide_id]['oncotree_code']
        if (pt_fname in available_vit256_features) and (label in label_dict.keys()):        
            vit256_features = torch.load(os.path.join(dataroot, pt_fname)).mean(axis=0)
            embeddings_all.append(vit256_features)
            labels_all.append(label_dict[label])
            slide_ids.append(slide_id)

    embeddings_all = torch.stack(embeddings_all).numpy()
    labels_all = np.array(labels_all)             
                          
    clf = KNeighborsClassifier()
    skf = StratifiedKFold(n_splits=10, shuffle=True, random_state=0)
    
    if len(label_dict.keys()) > 2:
        scores = cross_val_score(clf, embeddings_all, labels_all, cv=skf, scoring='roc_auc_ovr')
    else:
        scores = cross_val_score(clf, embeddings_all, labels_all, cv=skf, scoring='roc_auc')
        
    print(study, scores.mean())

100%|████████████████████████████████████████| 937/937 [00:01<00:00, 878.24it/s]


tcga_brca 0.7738300898746104


100%|████████████████████████████████████████| 905/905 [00:01<00:00, 663.50it/s]


tcga_kidney 0.9739114652064474


100%|████████████████████████████████████████| 958/958 [00:01<00:00, 754.76it/s]


tcga_lung 0.8935382487870264
