# Linear probing demo 
In this notebook, you can evaluate TITAN slide embeddings using linear probing.

In [None]:
%load_ext autoreload
%autoreload 2
import numpy as np
import pandas as pd
import torch
import yaml

from transformers import AutoModel
from titan.eval_linear_probe import train_and_evaluate_logistic_regression_with_val
from titan.utils import bootstrap

import os
os.environ["OMP_NUM_THREADS"] = "8"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Load TITAN from the Hugging Face Hub and place it on the selected device.
# trust_remote_code=True lets transformers execute the modeling code shipped in
# the MahmoodLab/TITAN repository.
# Note: the linear-probe cells below work on pre-extracted slide embeddings and
# do not reference `model`.
model = AutoModel.from_pretrained('MahmoodLab/TITAN', trust_remote_code=True).to(device)

In [None]:
# Load the task config for TCGA-OT.
# yaml.safe_load resolves only standard YAML tags. That is all a plain config
# mapping needs, and it avoids constructing Python objects from the file.
with open('../datasets/config_tcga-ot.yaml', 'r', encoding='utf-8') as file:
    task_config = yaml.safe_load(file)
# `target` is used below as a column name of the merged split tables (df[target]).
target = task_config['target']
# `label_dict` maps each value of that column to the encoded training label
# (presumably an integer class index; confirm against the config file).
label_dict = task_config['label_dict']

In [None]:
# Load the pre-extracted TITAN slide embeddings for TCGA, published in the
# MahmoodLab/TITAN Hub repository, into a DataFrame with one row per slide.
import pickle
from huggingface_hub import hf_hub_download

slide_feature_path = hf_hub_download("MahmoodLab/TITAN", filename="TCGA_TITAN_features.pkl")

# NOTE: unpickling can execute arbitrary code; only load files from a trusted source.
with open(slide_feature_path, 'rb') as feature_file:
    data = pickle.load(feature_file)

slide_ids = data['filenames']
# one entry per element of data['embeddings'] (iterates along its first axis)
slide_embeddings = list(data['embeddings'])
embeddings_df = pd.DataFrame({'slide_id': slide_ids, 'embeddings': slide_embeddings})

In [None]:
# Load the predefined train / val / test splits and attach each slide's embedding.
# The merge is an inner join on 'slide_id', so only slides present in both the
# split CSV and embeddings_df are kept.
train_split = pd.read_csv('../datasets/tcga-ot_train.csv')
val_split = pd.read_csv('../datasets/tcga-ot_val.csv')
test_split = pd.read_csv('../datasets/tcga-ot_test.csv')

train_df, val_df, test_df = (
    embeddings_df.merge(split_table, on='slide_id')
    for split_table in (train_split, val_split, test_split)
)

In [None]:
def _features_and_labels(split_df):
    """Return (stacked embeddings, encoded labels) for one merged split.

    The embeddings are stacked with np.stack along a new leading axis (one row
    per slide). The labels come from the ``target`` column, each value mapped
    through ``label_dict``. A value missing from ``label_dict`` raises KeyError.
    """
    features = np.stack(split_df['embeddings'].values)
    encoded = split_df[target].apply(lambda raw_label: label_dict[raw_label]).values
    return features, encoded


train_data, train_labels = _features_and_labels(train_df)
val_data, val_labels = _features_and_labels(val_df)
test_data, test_labels = _features_and_labels(test_df)

In [None]:
# Linear probe. log_spaced_values is the candidate grid for the regularization
# strength C (presumably selected on the validation split, per the function name).
# For a quick demo the grid has only 3 values: 0.1, 10 and 1000.
# To use the default setting from the paper, set USE_PAPER_C_GRID = True.
# log_spaced_values is then not passed, so the function's default C grid is used
# (documented as np.logspace(np.log10(10e-6), np.log10(10e5), num=45)).
USE_PAPER_C_GRID = False

log_spaced_values = np.logspace(np.log10(10e-2), np.log10(10e2), num=3)
probe_kwargs = {} if USE_PAPER_C_GRID else {'log_spaced_values': log_spaced_values}

# The function returns a (results, outputs) pair in both settings.
# `outputs` is consumed by the bootstrap cell below.
results, outputs = train_and_evaluate_logistic_regression_with_val(
    train_data, train_labels, val_data, val_labels, test_data, test_labels, **probe_kwargs
)
for key, value in results.items():
    # metric keys may be path-like ("a/b"); print only the last component
    print(f"{key.split('/')[-1]: <12}: {value:.4f}")

In [None]:
# Bootstrap the `outputs` returned by the linear-probe run to get mean ± std per
# metric. This takes a while, as 46 imbalanced classes are bootstrapped.
bootstrap_kwargs = dict(n=1000, alpha=0.95)
results_mean, results_std = bootstrap(results_dict=outputs, **bootstrap_kwargs)
for metric_name, mean_value in results_mean.items():
    short_name = metric_name.split('/')[-1]
    print(f"{short_name: <12}: {mean_value:.4f} ± {results_std[metric_name]:.4f}")