# Notes
In this example, we demonstrate the `self-supervised` pretraining process using the TCGA cohort dataset. Once pretraining is complete, the model is saved as a feature extractor; that is, the pretrained model is exported so that it is ready for downstream applications. For downstream tasks, we use this pretrained model to extract latent features from the task-specific dataset. Using these extracted features, we then build a simple linear-probing model to make predictions.

In [None]:
import sys
# Put the local project checkout first on the module search path so the
# `responder` imports below can be resolved.
# NOTE(review): this is a machine-specific absolute path; adjust it to your own
# checkout location (or install the package) before running the notebook.
sys.path.insert(0, '/home/was966/Research/mims-responder/')
from responder.utils import plot_embed_with_label
from responder import PreTrainer, FineTuner

In [None]:
import os
from tqdm import tqdm
from itertools import chain
import pandas as pd
import numpy as np
import random, torch
import matplotlib.pyplot as plt
import seaborn as sns
sns.set(style = 'white', font_scale=1.5)

In [None]:
# Load the pickled TCGA tables stored under <data_path>/TCGA.
# NOTE: pd.read_pickle unpickles arbitrary objects; only load files you trust.
data_path = '../data/'
tcga_dir = os.path.join(data_path, 'TCGA')

df_tpm = pd.read_pickle(os.path.join(tcga_dir, 'TPM.TCGA.TABLE'))
df_label = pd.read_pickle(os.path.join(tcga_dir, 'PATIENT.PROCESSED.TCGA.TABLE'))
df_gene = pd.read_pickle(os.path.join(tcga_dir, 'GENE.TABLE'))

# Single-column frame with the cancer type; used below for the stratified split.
df_cancer = df_label.loc[:, ['cancer_type']]
n_samples = df_label.shape[0]

In [None]:
# Self-supervised learning (ssl) run with no joint task.
task_name = 'ssl_notask'

# A decoder task is still passed to the trainer, so build a random placeholder:
# two columns of random 0/1 integers per sample, indexed like df_cancer.
# (Only the ssl pretraining is being tested here.)
random_task = np.random.randint(0, 2, size=(n_samples, 2))
df_task = pd.DataFrame(random_task, index=df_cancer.index)

# Placeholder decoder task type and weight.
task_type = 'c'
task_loss_weight = 0.0  # MUST BE ZERO: the task labels above are random

# Hold out 10% of the samples of every cancer type as the test set. Each group
# is sampled with the same fixed random_state so the split is repeatable.
per_type_test = df_cancer.groupby('cancer_type').apply(
    lambda grp: grp.sample(frac=0.1, random_state=123).index.tolist()
)
test_idx = list(chain.from_iterable(per_type_test.tolist()))

# Everything that was not held out is used for training.
is_test = df_cancer.index.isin(test_idx)
train_idx = df_cancer.index[~is_test]

df_tpm_train = df_tpm.loc[train_idx]
df_task_train = df_task.loc[train_idx]

df_tpm_test = df_tpm.loc[test_idx]
df_task_test = df_task.loc[test_idx]

print(df_tpm_train.shape, len(df_tpm_test))

In [None]:
# Settings handed to the PreTrainer constructor, collected in one place.
pretrainer_config = dict(
    device='cuda',
    encoder='transformer',
    batch_size=128,
    epochs=100,
    patience=5,
    lr=1e-4,
    weight_decay=1e-6,
    K=0.3,
    task_loss_weight=task_loss_weight,
    with_wandb=False,
    work_dir='./results',
)
pretrainer = PreTrainer(**pretrainer_config)

# Augmentation options forwarded to train() as keyword arguments.
augmentation_config = dict(
    aug_method='mask',
    mask_probability=0.01,
    mask_self=True,
    mask_self_probability=0.5,
)

# Positional order: train data, train task, task name, task type,
# test data, test task.
pretrainer.train(
    df_tpm_train,
    df_task_train,
    task_name,
    task_type,
    df_tpm_test,
    df_task_test,
    **augmentation_config,
)

# NOTE(review): later cells still call pretrainer.predict() after this close();
# confirm that predict() remains usable once the trainer has been closed.
pretrainer.close()

In [None]:
# Plot the TCGA training-set output of pretrainer.predict(), coloured by
# cancer type and gender.
df_label_plot = df_label.loc[:, ['cancer_type', 'gender']]
label_type = ['c'] * 2  # one 'c' entry per label column (presumably "categorical" - confirm)

# predict() returns a pair; only the first element is used here.
df_tpm_train_emb, _ = pretrainer.predict(df_tpm_train, batch_size=256)
dfp = df_tpm_train_emb.join(df_label_plot)

plot_options = dict(s=5, figsize=(5, 5), min_dist=0.8, metric='correlation', n_epochs=50)
plot_embed_with_label(dfp, df_label_plot.columns, label_type, **plot_options)

In [None]:
from responder.utils import plot_embed_with_label, score

# Load the pickled ITRP tables stored under <data_path>/SKCM.
skcm_dir = os.path.join(data_path, 'SKCM')
itrp_meta = pd.read_pickle(os.path.join(skcm_dir, 'PATIENT.ITRP.TABLE'))
itrp_df_tpm = pd.read_pickle(os.path.join(skcm_dir, 'TPM.ITRP.TABLE'))

# NOTE(review): this rebinds `df_label`, replacing the TCGA label table loaded
# earlier. Re-running the TCGA plotting cell after this one will fail because
# the 'cancer_type' / 'gender' columns are no longer selected from that table.
df_label = itrp_meta.loc[:, ['cohort', 'response_label']]
label_type = ['c'] * 2

# Run the pretrained model on the ITRP table; only the first output is used.
itrp_df_tpm_emb, _ = pretrainer.predict(itrp_df_tpm, batch_size=256)
dfp = itrp_df_tpm_emb.join(df_label)

# Plotted without the extra overrides that were left disabled in the original
# cell (min_dist = 0.8, metric = 'correlation', n_epochs=100).
plot_embed_with_label(dfp, df_label.columns, label_type, s=20, figsize=(5, 5))