# Fine-Tuning of Feature Encoders

A very common way of applying deep learning techniques in digital pathology is Multiple Instance Learning (MIL).
The gigapixel whole-slide image is cropped into a set of equally sized, non-overlapping tiles. The tiles are encoded with a pretrained feature extractor, and the resulting features are then aggregated in a second step by a trainable model that solves a specific task.
Typical feature encoders are pretrained either on ImageNet-21k or on even larger histological datasets. Their architectures range from smaller ones like ResNet-18 up to Swin Transformer-based architectures.
Since pretraining larger models from scratch requires huge amounts of resources, the focus of this work is on fine-tuning a pretrained model.
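The two-step MIL pipeline described above can be sketched in PyTorch as follows. This is a toy under stated assumptions: a random linear layer stands in for the pretrained tile encoder, and the aggregator is a simple attention-pooling module (`AttentionMIL` is hypothetical and all sizes are arbitrary). It is not the aggregator used in this project.

```python
import torch
from torch import nn


class AttentionMIL(nn.Module):
    """Toy attention-pooling aggregator over precomputed tile embeddings (hypothetical sizes)."""

    def __init__(self, in_dim=512, hidden_dim=128, n_out=4):
        super().__init__()
        self.attention = nn.Sequential(nn.Linear(in_dim, hidden_dim), nn.Tanh(), nn.Linear(hidden_dim, 1))
        self.head = nn.Linear(in_dim, n_out)

    def forward(self, bag):  # bag: [n_tiles, in_dim] for one slide
        weights = torch.softmax(self.attention(bag), dim=0)  # [n_tiles, 1], sums to 1
        slide_embedding = (weights * bag).sum(dim=0)          # [in_dim]
        return self.head(slide_embedding), weights


# Step 1: a frozen encoder maps every tile to a feature vector (random stand-in here)
encoder = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 512)).eval()
tiles = torch.rand(50, 3, 32, 32)  # 50 tiles from one slide
with torch.no_grad():
    features = encoder(tiles)

# Step 2: the trainable aggregator solves the slide-level task
logits, weights = AttentionMIL()(features)
```

Only the aggregator receives gradients in this setup; fine-tuning the encoder, as explored below, means letting gradients flow into step 1 as well.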

**The following experiments attempt to explore whether existing feature encoders can be fine-tuned for better survival analysis.**


The aim is to carry out the following experiments:

1. ResNet-18: train from scratch/fine-tune on survival analysis
2. ViT-Tiny: train from scratch/fine-tune on survival analysis (+multimodality)
3. ViT-Tiny MAE: fine-tune on survival analysis
4. ViT-Tiny MAE: fine-tune on survival analysis in a SupMAE fashion

To carry out these experiments, the following subtasks are needed:
1. Create a custom dataset (from a zipped CSV) that stratifies the train/test split on a patient level.
2. Create the models, find checkpoints, and load parameters such that DDP is applicable.
3. Create a training function that allows partially freezing weights, fine-tuning, and training from a checkpoint.
4. Create an encoding pipeline.
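For subtask 3, one common way to freeze weights partially is to toggle `requires_grad` by parameter name and pass only the trainable parameters to the optimizer. The helper `freeze_except` and the module names `backbone` and `head` below are hypothetical stand-ins, not the project's training code.

```python
from torch import nn


def freeze_except(model: nn.Module, trainable_prefixes=()):
    """Freeze every parameter whose name does not start with one of trainable_prefixes."""
    for name, param in model.named_parameters():
        param.requires_grad = any(name.startswith(p) for p in trainable_prefixes)
    # Return the parameters that remain trainable, ready for an optimizer
    return [p for p in model.parameters() if p.requires_grad]


# Toy model: a "backbone" plus a new survival "head" (names are hypothetical)
model = nn.Sequential()
model.add_module("backbone", nn.Linear(16, 8))
model.add_module("head", nn.Linear(8, 4))

trainable = freeze_except(model, trainable_prefixes=("head.",))
```

The returned list can be passed directly to an optimizer, e.g. `torch.optim.AdamW(trainable, lr=1e-4)` (the learning rate is only an example). Passing more prefixes, such as the last encoder block, unfreezes progressively more of the network.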


# DataLoader

The data consist of ~1000 patients. Each patient has exactly one genomic feature vector and can have multiple slides, i.e. multiple tile sets.

Idea: 

The custom dataset receives a dataframe that contains the metadata and genomic data (split into train/test on a patient level), as well as a path to the tiles. An os.walk over this path creates a second dataframe that contains each tile's file path and its slide_id.
This tile dataframe is then extended to contain the tile path, the metadata and the index of the respective row within the genomic tensor.

1. Import the dataframe; create the genomic tensor and the metadata dataframe.
2. Run os.walk on the tile path to create the tile dataframe.
3. Add the mapping from slide_id to row index.
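As a compact illustration of steps 1 to 3 on toy data, the sketch below uses pandas' `Index.get_indexer` instead of a dictionary plus `map`: `get_indexer` returns each tile's row position in `df_meta` and -1 for slides that are not part of the split. The toy frames and the five-column genomic tensor are made up for illustration.

```python
import pandas as pd
import torch

# Toy stand-ins for the real CSV (there, columns 11+ hold the genomic features)
df_meta = pd.DataFrame({"slide_id": ["A.svs", "B.svs"], "survival_months": [12.0, 30.5]})
genomics_tensor = torch.rand(len(df_meta), 5)  # one row per row of df_meta

df_tiles = pd.DataFrame({
    "tilepath": ["A/t0.jpg", "A/t1.jpg", "C/t0.jpg", "B/t0.jpg"],
    "slide_id": ["A.svs", "A.svs", "C.svs", "B.svs"],  # C.svs belongs to another split
})

# Row position of each tile's slide in df_meta; -1 marks slides outside this split
idx = pd.Index(df_meta["slide_id"]).get_indexer(df_tiles["slide_id"])
df_tiles = df_tiles[idx >= 0].assign(slideid_idx=idx[idx >= 0]).reset_index(drop=True)

# Genomic vector for every tile, looked up through slideid_idx
tile_genomics = genomics_tensor[torch.as_tensor(df_tiles["slideid_idx"].to_numpy())]
```

`get_indexer` requires the slide_ids in `df_meta` to be unique, which holds when each row corresponds to one slide.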


In [4]:
# 1.)
import torch
from utils.Aggregation_Utils import *
df_path = "/work4/seibel/PORPOISE/datasets_csv/tcga_brca_all_clean.csv.zip"
df_train,df_test,df_val = prepare_csv(df_path=df_path,split="traintestval",n_bins=4,save = False,frac_train=0.7,frac_val=0.1)
df_train = df_train[df_train["traintest"]==0] # keep only training rows (traintest: 0 = train)

# Columns 11 onward hold the genomic features
genomics_tensor = torch.Tensor(df_train[df_train.keys()[11:]].to_numpy()).to(torch.float32)
df_meta = df_train[["slide_id","survival_months_discretized","censorship","survival_months"]]
# Map each slide_id to its row index in genomics_tensor / df_meta
diction = dict([(name,idx) for idx,name in enumerate(df_meta["slide_id"]) ])
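`prepare_csv` lives in `utils.Aggregation_Utils` and is not shown here. To illustrate what a patient-level split means, here is a hypothetical sketch: patients (not slides) are shuffled and assigned to train (0), test (1) or val (2), so all slides of one patient land in the same split. It is a plain random group split; stratifying additionally by survival bin would need an extra step. The column name `case_id` is taken from the column list further below.

```python
import numpy as np
import pandas as pd


def patient_level_split(df, frac_train=0.7, frac_val=0.1, seed=0, patient_col="case_id"):
    """Assign each patient to train (0), test (1) or val (2); all its slides follow."""
    rng = np.random.default_rng(seed)
    patients = rng.permutation(df[patient_col].drop_duplicates().to_numpy())
    n_train = int(len(patients) * frac_train)
    n_val = int(len(patients) * frac_val)
    split = {p: 0 for p in patients[:n_train]}
    split.update({p: 2 for p in patients[n_train:n_train + n_val]})
    split.update({p: 1 for p in patients[n_train + n_val:]})
    out = df.copy()
    out["traintest"] = out[patient_col].map(split)
    return out


# Toy data: 10 patients with 2 slides each
df = pd.DataFrame({
    "case_id": [f"P{i}" for i in range(10) for _ in range(2)],
    "slide_id": [f"P{i}_S{j}.svs" for i in range(10) for j in range(2)],
})
split_df = patient_level_split(df)
```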

In [1]:
#2.)
import os
import pandas as pd
tile_path = "/globalwork/seibel/TCGA-BRCA-TILES-NORM/"
ext = "jpg"
file_list = []
root_list = []
for root, dirs, files in os.walk(tile_path, topdown=False):
    for name in files:
        file_list.append(os.path.join(root, name))
        root_list.append(root.split("/")[-1]+".svs")

df_tiles = pd.DataFrame({"tilepath":file_list,"slide_id":root_list},)
df_tiles = df_tiles[df_tiles["tilepath"].str.endswith(ext)] # Avoid having other files than .<ext> files in Dataframe


print(df_tiles.tilepath.iloc[0])
print(df_tiles.slide_id.iloc[0])
df_tiles.head()

/globalwork/seibel/TCGA-BRCA-TILES-NORM/TCGA-AN-A0XP-01Z-00-DX1.A4EE3970-5C1F-482E-9AED-F1E3C52A776F/TCGA-AN-A0XP-01Z-00-DX1_(15207,21291).jpg
TCGA-AN-A0XP-01Z-00-DX1.A4EE3970-5C1F-482E-9AED-F1E3C52A776F.svs


|   | tilepath | slide_id |
|---|---|---|
| 0 | /globalwork/seibel/TCGA-BRCA-TILES-NORM/TCGA-A... | TCGA-AN-A0XP-01Z-00-DX1.A4EE3970-5C1F-482E-9AE... |
| 1 | /globalwork/seibel/TCGA-BRCA-TILES-NORM/TCGA-A... | TCGA-AN-A0XP-01Z-00-DX1.A4EE3970-5C1F-482E-9AE... |
| 2 | /globalwork/seibel/TCGA-BRCA-TILES-NORM/TCGA-A... | TCGA-AN-A0XP-01Z-00-DX1.A4EE3970-5C1F-482E-9AE... |
| 3 | /globalwork/seibel/TCGA-BRCA-TILES-NORM/TCGA-A... | TCGA-AN-A0XP-01Z-00-DX1.A4EE3970-5C1F-482E-9AE... |
| 4 | /globalwork/seibel/TCGA-BRCA-TILES-NORM/TCGA-A... | TCGA-AN-A0XP-01Z-00-DX1.A4EE3970-5C1F-482E-9AE... |


## Save dataframe

In [3]:
#title = "df-TCGA-BRCA-TIILES-NORM.csv"
#df_tiles.to_csv(title,index=False)

In [8]:
# 3.)
df_tiles.insert(2,"slideid_idx",df_tiles["slide_id"].map(diction))
# Drop tiles whose slide is not in this split; .copy() avoids pandas' SettingWithCopyWarning
df_tiles = df_tiles.dropna().copy()
df_tiles["slideid_idx"] = df_tiles["slideid_idx"].astype(int)
df_tiles


|   | tilepath | slide_id | slideid_idx |
|---|---|---|---|
| 0 | /globalwork/seibel/TCGA-BRCA-TILES-NORM/TCGA-A... | TCGA-AN-A0XP-01Z-00-DX1.A4EE3970-5C1F-482E-9AE... | 289 |
| 1 | /globalwork/seibel/TCGA-BRCA-TILES-NORM/TCGA-A... | TCGA-AN-A0XP-01Z-00-DX1.A4EE3970-5C1F-482E-9AE... | 289 |
| 2 | /globalwork/seibel/TCGA-BRCA-TILES-NORM/TCGA-A... | TCGA-AN-A0XP-01Z-00-DX1.A4EE3970-5C1F-482E-9AE... | 289 |
| 3 | /globalwork/seibel/TCGA-BRCA-TILES-NORM/TCGA-A... | TCGA-AN-A0XP-01Z-00-DX1.A4EE3970-5C1F-482E-9AE... | 289 |
| 4 | /globalwork/seibel/TCGA-BRCA-TILES-NORM/TCGA-A... | TCGA-AN-A0XP-01Z-00-DX1.A4EE3970-5C1F-482E-9AE... | 289 |
| ... | ... | ... | ... |
| 2943700 | /globalwork/seibel/TCGA-BRCA-TILES-NORM/TCGA-A... | TCGA-A2-A0YD-01Z-00-DX1.B81FF541-F154-4C49-950... | 74 |
| 2943701 | /globalwork/seibel/TCGA-BRCA-TILES-NORM/TCGA-A... | TCGA-A2-A0YD-01Z-00-DX1.B81FF541-F154-4C49-950... | 74 |
| 2943702 | /globalwork/seibel/TCGA-BRCA-TILES-NORM/TCGA-A... | TCGA-A2-A0YD-01Z-00-DX1.B81FF541-F154-4C49-950... | 74 |
| 2943703 | /globalwork/seibel/TCGA-BRCA-TILES-NORM/TCGA-A... | TCGA-A2-A0YD-01Z-00-DX1.B81FF541-F154-4C49-950... | 74 |


## Complete Custom Dataset
It is better to create the tile dataframe once than to repeat the os.walk inside the data loader.
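Walking millions of tiles is slow, so one option is to run the walk once and cache the result as a CSV, in the spirit of the save-dataframe cell above. The helper `build_tile_df` and its `cache_csv` argument are hypothetical; the tile-to-slide naming (folder name plus `.svs`) follows the cells above.

```python
import os
import tempfile

import pandas as pd


def build_tile_df(tile_path, ext="jpg", cache_csv=None):
    """Walk tile_path once and optionally cache the tile dataframe as CSV."""
    if cache_csv is not None and os.path.exists(cache_csv):
        return pd.read_csv(cache_csv)  # reuse the cached walk
    rows = []
    for root, _, files in os.walk(tile_path):
        for name in files:
            if name.endswith(ext):
                # Folder name is the slide id; append .svs as in the cells above
                rows.append((os.path.join(root, name), os.path.basename(root) + ".svs"))
    df_tiles = pd.DataFrame(rows, columns=["tilepath", "slide_id"])
    if cache_csv is not None:
        df_tiles.to_csv(cache_csv, index=False)
    return df_tiles


# Usage on a throwaway directory with two tiles and one non-tile file
with tempfile.TemporaryDirectory() as tmp:
    slide_dir = os.path.join(tmp, "tiles", "SLIDE-1")
    os.makedirs(slide_dir)
    for fname in ["t0.jpg", "t1.jpg", "notes.txt"]:
        open(os.path.join(slide_dir, fname), "w").close()
    cache = os.path.join(tmp, "tiles.csv")
    first = build_tile_df(os.path.join(tmp, "tiles"), cache_csv=cache)   # walks and writes CSV
    second = build_tile_df(os.path.join(tmp, "tiles"), cache_csv=cache)  # reads CSV
```

The trade-off is that the cache must be deleted whenever tiles are added or removed.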

In [10]:
import os
import torch
import pandas as pd
from PIL import Image
from torch.utils.data import DataLoader,Dataset
from torchvision import transforms
from utils.Aggregation_Utils import *


class TileDataset(Dataset):
    def __init__(self,df_path,tile_path,ext,trainmode,transform):
        """Custom Dataset for Feature Extractor Finetuning for Survival Analysis

        Args:
            df_path (str): Path to dataframe which contains meta data and genomic data
            tile_path (str): Path to folder which contains subfolders with tiles (subfolder names must be slide ids)
            ext (str): File extension of tiles (e.g. jpg or png)
            trainmode (str): Which split to generate: "train", "test" or "val"
            transform (callable): Transform applied to each PIL tile
        """
        super(TileDataset,self).__init__()
        # Genomic tensor and meta dataframe
        df = pd.read_csv(df_path)

        assert trainmode in ["train","test","val"], "Dataset mode not known"
        # Keep only the rows of the requested split (traintest: 0 = train, 1 = test, 2 = val)
        split_code = {"train": 0, "test": 1, "val": 2}[trainmode]
        df = df[df["traintest"]==split_code].reset_index(drop=True)

        # Columns 11 onward hold the genomic features
        self.genomics_tensor = torch.Tensor(df[df.keys()[11:]].to_numpy()).to(torch.float32)
        self.df_meta = df[["slide_id","survival_months_discretized","censorship","survival_months"]]

        # Tile dataframe, built once here rather than in __getitem__
        file_list = []
        root_list = []
        for root, dirs, files in os.walk(tile_path, topdown=False):
            for name in files:
                file_list.append(os.path.join(root, name))
                root_list.append(os.path.basename(root)+".svs")

        df_tiles = pd.DataFrame({"tilepath":file_list,"slide_id":root_list})
        df_tiles = df_tiles[df_tiles["tilepath"].str.endswith(ext)]

        # Map slide_id to its row in genomics_tensor / df_meta; drop tiles of other splits
        diction = {name: idx for idx, name in enumerate(self.df_meta["slide_id"])}
        df_tiles = df_tiles.assign(slideid_idx=df_tiles["slide_id"].map(diction)).dropna().copy()
        df_tiles["slideid_idx"] = df_tiles["slideid_idx"].astype(int)
        self.df_tiles = df_tiles.reset_index(drop=True)

        # TODO transforms
        self.transforms = transform

    def __len__(self):
        return len(self.df_tiles)

    def __getitem__(self,idx):
        tile_path,_,slide_idx = self.df_tiles.iloc[idx]
        tile = Image.open(tile_path).convert("RGB")  # guard against RGBA/grayscale tiles
        tile = self.transforms(tile)

        label = torch.tensor(self.df_meta.iloc[slide_idx, 1]).type(torch.int64)
        censorship = torch.tensor(self.df_meta.iloc[slide_idx, 2]).type(torch.int64)
        # survival_months is continuous, so keep it as float instead of truncating to int
        label_cont = torch.tensor(self.df_meta.iloc[slide_idx, 3]).type(torch.float32)
        return tile, self.genomics_tensor[slide_idx], censorship, label, label_cont


df_path_train = "/work4/seibel/PORPOISE/datasets_csv/tcga_brca__4bins_trainsplit.csv"
df_path_test = "/work4/seibel/PORPOISE/datasets_csv/tcga_brca__4bins_testsplit.csv"

tilepath = "/work4/seibel/data/TCGA-BRCA-TILES/"
ext = "jpg"

trainmode="train"


DS = TileDataset(df_path_train,tilepath,ext,trainmode,transform=transforms.ToTensor())
DS[8]

(tensor([[[0.8392, 0.7569, 0.8000,  ..., 0.9373, 0.9373, 0.9412],
          [0.7843, 0.7412, 0.7765,  ..., 0.9373, 0.9373, 0.9412],
          [0.7373, 0.7412, 0.7608,  ..., 0.9373, 0.9373, 0.9412],
          ...,
          [0.9255, 0.9294, 0.9176,  ..., 0.9451, 0.9529, 0.9412],
          [0.9373, 0.9294, 0.9059,  ..., 0.9412, 0.9569, 0.9412],
          [0.9333, 0.9373, 0.9333,  ..., 0.9373, 0.9569, 0.9412]],
 
         [[0.7255, 0.6549, 0.7137,  ..., 0.9373, 0.9373, 0.9412],
          [0.6706, 0.6392, 0.6902,  ..., 0.9373, 0.9373, 0.9412],
          [0.6275, 0.6392, 0.6745,  ..., 0.9373, 0.9373, 0.9412],
          ...,
          [0.9294, 0.9333, 0.9216,  ..., 0.9373, 0.9451, 0.9333],
          [0.9529, 0.9451, 0.9216,  ..., 0.9333, 0.9490, 0.9333],
          [0.9490, 0.9529, 0.9490,  ..., 0.9294, 0.9490, 0.9333]],
 
         [[0.8353, 0.7529, 0.7961,  ..., 0.9451, 0.9451, 0.9490],
          [0.7804, 0.7373, 0.7725,  ..., 0.9451, 0.9451, 0.9490],
          [0.7255, 0.7373, 0.7569,  ...,

In [8]:
from datasets.Tile_Dataset import TileDataset
from torchvision import transforms
import os 
df_path_train ="/nodes/bevog/work4/seibel/PORPOISE/datasets_csv/tcga_brca__4bins_trainsplit.csv"
assert os.path.exists(df_path_train)
tile_path ="/nodes/bevog/work4/seibel/data/TCGA-BRCA-TILES/"
ext = "jpg"
batch_size = 32
transform_train=transforms.PILToTensor()
train_set = TileDataset(df_path=df_path_train,tile_path=tile_path,ext=ext,trainmode = "train",transform=transform_train)
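`batch_size` is defined above but no loader is created yet. Since subtask 2 targets DDP, a minimal sketch of wrapping a dataset in a `DataLoader` with a `DistributedSampler` follows. A `TensorDataset` stands in for `train_set`, and `rank` and `num_replicas` are hard-coded here; in a real DDP run they would come from the initialized process group.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

# Toy stand-in for train_set: 100 "tiles" of shape 3x8x8 with a bin label each
toy_set = TensorDataset(torch.rand(100, 3, 8, 8), torch.randint(0, 4, (100,)))

# Each of the 2 processes sees a disjoint half of the data; this is process 0
sampler = DistributedSampler(toy_set, num_replicas=2, rank=0, shuffle=True, seed=0)
loader = DataLoader(toy_set, batch_size=32, sampler=sampler, num_workers=0)

sampler.set_epoch(0)  # call once per epoch so the shuffle changes between epochs
batches = list(loader)
```

Note that `shuffle` must not be passed to the `DataLoader` itself when a sampler is given; the sampler handles shuffling.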

# Loss 
The loss from the previous Colab notebook can be reused, but it has to be adapted to run fully on the GPU.
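The implementation of `Survival_Loss` is not shown here. Assuming it is the discrete-time negative log-likelihood loss used in PORPOISE-style survival models (with 0.2 as the weight `alpha` on the uncensored term), a device-agnostic version can be written purely with tensor ops, so nothing leaves the GPU. `nll_survival_loss` is a hypothetical name; the convention `censorship = 1` means censored matches the c-index cell below.

```python
import torch


def nll_survival_loss(logits, censorship, label, alpha=0.2, eps=1e-7):
    """Discrete-time NLL survival loss; every tensor stays on logits' device."""
    label = label.reshape(-1, 1).long()
    c = censorship.reshape(-1, 1).float()                 # 1 = censored, 0 = event
    hazards = torch.sigmoid(logits)                       # h_k, shape [B, K]
    surv = torch.cumprod(1 - hazards, dim=1)              # S_k
    # Prepend S_{-1} = 1 so that column `label` holds S_{label-1}
    surv_padded = torch.cat([torch.ones_like(surv[:, :1]), surv], dim=1)
    s_prev = torch.gather(surv_padded, 1, label).clamp(min=eps)
    h_this = torch.gather(hazards, 1, label).clamp(min=eps)
    s_this = torch.gather(surv_padded, 1, label + 1).clamp(min=eps)
    uncensored = -(1 - c) * (torch.log(s_prev) + torch.log(h_this))
    censored = -c * torch.log(s_this)
    loss = (1 - alpha) * (censored + uncensored) + alpha * uncensored
    return loss.mean()


device = "cuda" if torch.cuda.is_available() else "cpu"
logits = torch.zeros(2, 4, device=device)      # all hazards = 0.5
c = torch.tensor([[0], [1]], device=device)    # first sample has an event, second is censored
l = torch.tensor([0, 1], device=device)
loss = nll_survival_loss(logits, c, l)
```

For these toy inputs the first sample contributes log 2 and the second 0.8 · log 4, so the mean is about 0.901.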


In [1]:
from utils.Aggregation_Utils import Survival_Loss
import torch
B = 10
nbins  = 4 
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)

#hist_tile,gen, censorship, label,label_cont = DS.__getitem__(8)

criterion = Survival_Loss(0.2)
prediction_logits = torch.rand(B,nbins,device=device)
c  = torch.randint(0,2,size=(B,1),device=device)
l = torch.randint(0,nbins,size=(B,),device=device)




criterion(prediction_logits,c,l)



cpu


tensor(1.8186)

# Survival 


In [None]:
import torch
from torch import nn
channels = 12
model = nn.Sequential(nn.Conv2d(3, channels, kernel_size=(3, 3), padding='same', bias=True),
                          nn.AdaptiveAvgPool2d(1),
                          nn.Flatten(1),
                          nn.Linear(channels,1),
                          nn.Flatten(0))



x = torch.rand((5,3,16,16))
model(x).size()

In [None]:
import argparse
import yaml
import os
f = "./encoder_configs/base.yaml"
assert os.path.exists(f), f"Config not found: {f}"
with open(f, 'r') as file:
    config = yaml.safe_load(file)

print(config["train_settings"]["checkpoint_path"])


In [None]:
DF1 = ['case_id', 'slide_id', 'site', 'traintest', 'metas','gendata']
DF2 = ['TILEPATH','SLIDE_ID']
#get one tensor for gendata, one df with tilepath, metadata, tensoridx

In [None]:
mode = "vali"
assert mode in ["train","test","val"], "Dataset mode not known"
df_train.survival_months_discretized[df_train.survival_months_discretized==( 0 if mode=="train" else 1 if mode=="test" else 2)]

In [None]:
test_len = 2
train_len = 3
diction = dict([(name,0) if idx<train_len else (name,1) if idx<train_len+test_len  else (name,2) for idx,name in enumerate(["a","b","c","d","e","f"])])
diction

In [6]:
import torch 
from torch import nn
bins = 4
B=12
out_all =[]
c_all =  []
l_all = []

for i in range(9):
    out = torch.rand((B,bins))
    l = torch.randint(0,bins,size=(B,))
    c = torch.randint(0,2,size=(B,))
    
    out_all.append(out)
    l_all.append(l)
    c_all.append(c)


In [9]:
torch.cat(out_all,dim=0).size()

torch.Size([108, 4])

In [12]:
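from sksurv.metrics import concordance_index_censored
h = nn.Sigmoid()(torch.cat(out_all,dim=0))   # per-bin hazards
S = torch.cumprod(1-h,dim = -1)              # survival probability after each bin
risk = -S.sum(dim=1)                         # higher risk = less expected survival
notc = (1-torch.cat(c_all,dim=0)).numpy().astype(bool)  # event indicator (True = not censored)
# concordance_index_censored returns a tuple; its first entry is the c-index
c_index = concordance_index_censored(notc, torch.cat(l_all,dim=0).numpy(), risk.numpy())[0]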
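The cell above turns the per-bin logits $z_k$ into a scalar risk score. Writing $K$ for the number of bins, $\delta_i = 1 - c_i$ for the event indicator (`notc`), $T_i$ for the (here discretized) survival time and $r_i$ for the risk, the quantities are:

$$h_k = \sigma(z_k), \qquad S_k = \prod_{j=1}^{k} (1 - h_j), \qquad r = -\sum_{k=1}^{K} S_k$$

Harrell's concordance index, which `concordance_index_censored` computes (up to its handling of tied times), then counts how often the patient with the earlier observed event also has the higher risk:

$$C = \frac{\sum_{i \neq j} \delta_i \, \mathbb{1}[T_i < T_j] \left( \mathbb{1}[r_i > r_j] + \tfrac{1}{2}\,\mathbb{1}[r_i = r_j] \right)}{\sum_{i \neq j} \delta_i \, \mathbb{1}[T_i < T_j]}$$

With random logits as above, $C$ should come out near 0.5.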

# Models




In [2]:
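import torch
from torch import nn

class MAE(nn.Module):
    """Sketch of a multimodal MAE with an optional survival head.
    All sizes (patch size, dims, depths, heads, n_bins) are placeholder assumptions."""
    def __init__(self,multimodal=True,supervised_surv=False,img_size=224,patch_size=16,
                 in_chans=3,enc_dim=192,dec_dim=128,gen_dim=192,n_bins=4,mask_ratio=0.75):
        super(MAE,self).__init__()
        self.multimodal,self.supervised_surv = multimodal,supervised_surv
        self.patch_size,self.mask_ratio = patch_size,mask_ratio
        num_patches = (img_size//patch_size)**2
        patch_dim = patch_size*patch_size*in_chans
        self.enc_emb = nn.Linear(patch_dim,enc_dim)                  # patch embedding
        self.enc_pos_emb = nn.Parameter(torch.zeros(1,num_patches,enc_dim))
        self.y_encoder = nn.Linear(gen_dim,enc_dim)                  # genomic vector -> one token
        self.encoder = nn.TransformerEncoder(nn.TransformerEncoderLayer(enc_dim,nhead=4,batch_first=True),num_layers=2)
        self.dec_emb = nn.Linear(enc_dim,dec_dim)
        self.dec_pos_emb = nn.Parameter(torch.zeros(1,num_patches,dec_dim))
        self.masktoken = nn.Parameter(torch.zeros(1,1,dec_dim))
        self.decoder = nn.TransformerEncoder(nn.TransformerEncoderLayer(dec_dim,nhead=4,batch_first=True),num_layers=1)
        self.dec_pred = nn.Linear(dec_dim,patch_dim)                 # predict the pixels of each patch
        self.surv_head = nn.Linear(enc_dim*(2 if multimodal else 1),n_bins)

    def forward(self,x,y=None):
        x_seq = self.img2seq(x)                                      # [B, L, p*p*C]
        unmasked,masked,unmasked_idx,masked_idx = self.masking(x_seq)
        n_keep = unmasked_idx.size(1)
        enc_pos = self.enc_pos_emb.expand(x.size(0),-1,-1)
        tokens = self.enc_emb(unmasked)+self.gather(enc_pos,unmasked_idx)
        if self.multimodal:
            tokens = torch.cat([tokens,self.y_encoder(y)],dim=1)     # y: [B, 1, gen_dim]
        encoded = self.encoder(tokens)
        unmasked_enc,y_enc = encoded[:,:n_keep],encoded[:,n_keep:]   # split off the genomic token
        dec_pos = self.dec_pos_emb.expand(x.size(0),-1,-1)
        decoded = torch.cat([self.dec_emb(unmasked_enc)+self.gather(dec_pos,unmasked_idx),
                             self.masktoken+self.gather(dec_pos,masked_idx)],dim=1)
        decoded_masked = self.dec_pred(self.decoder(decoded)[:,n_keep:])  # predictions for masked patches

        surv_logits = None
        if self.supervised_surv:
            feats = [unmasked_enc.mean(dim=1)] + ([y_enc[:,0]] if self.multimodal else [])
            surv_logits = self.surv_head(torch.cat(feats,dim=1))
        return masked,decoded_masked,surv_logits

    @staticmethod
    def gather(seq,idx):
        # Select tokens seq[b, idx[b, i], :] along the sequence dimension
        return torch.gather(seq,1,idx.unsqueeze(-1).expand(-1,-1,seq.size(-1)))

    def img2seq(self,x):
        B,C,H,W = x.shape
        p = self.patch_size
        x = x.reshape(B,C,H//p,p,W//p,p).permute(0,2,4,3,5,1)       # [B, h, w, p, p, C]
        return x.reshape(B,(H//p)*(W//p),p*p*C)

    def seq2img(self,seq,C=3):
        B,L,_ = seq.shape
        p,h = self.patch_size,int(L**0.5)
        x = seq.reshape(B,h,h,p,p,C).permute(0,5,1,3,2,4)           # [B, C, h, p, w, p]
        return x.reshape(B,C,h*p,h*p)

    def masking(self,x_seq):
        # Random masking: keep (1 - mask_ratio) of the patches of every sample
        B,L,_ = x_seq.shape
        n_keep = int(L*(1-self.mask_ratio))
        ids_shuffle = torch.rand(B,L,device=x_seq.device).argsort(dim=1)
        unmasked_idx,masked_idx = ids_shuffle[:,:n_keep],ids_shuffle[:,n_keep:]
        return self.gather(x_seq,unmasked_idx),self.gather(x_seq,masked_idx),unmasked_idx,masked_idx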
        

In [1]:
#!pip install timm
from models.mae_models.models_mae_modified import mae_vit_tiny_patch16
model = mae_vit_tiny_patch16()



In [8]:
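import torch
from torch import nn
B,C,H,W = 20,3,224,224
imgs =  torch.rand((B,C,H,W))
mask_ratio = 0.75
y = torch.rand((B,1,192))   # genomic input, one token per sample
ids_shuffle=None
surv = True
n_bins = 4

latent, mask, ids_restore, ids_shuffle = model.forward_encoder(imgs, mask_ratio,y, ids_shuffle)
# The last token of latent is the genomic token; the others are histology tokens
latent_hist,latent_gen = torch.split(latent,split_size_or_sections=[latent.size(1)-1,1],dim=1)
pred = model.forward_decoder(latent_hist, ids_restore)  # [N, L, p*p*3]
lossMAE = model.forward_loss(imgs, pred, mask)

if surv:
    # Mean-pooled histology tokens concatenated with the genomic token
    surv_in = torch.cat((torch.mean(latent_hist,dim=1),latent_gen.squeeze(1)),dim=1)
    survival_head = nn.Linear(surv_in.size(1), n_bins)  # hypothetical placeholder head
    surv_logits = survival_head(surv_in)                # [B, n_bins]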
 

In [20]:
ckpt_path = "/work4/seibel/data/mae_tiny_400e.pth.tar"
ckpt = torch.load(ckpt_path, map_location="cpu")
state_dict = {k.replace("module.model.", ""): v for k, v in ckpt["model"].items()}
model.load_state_dict(state_dict, strict=False)


<All keys matched successfully>
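With `strict=False`, `load_state_dict` silently skips mismatching keys, so it is worth inspecting the returned `missing_keys` and `unexpected_keys` (the output above already indicates both lists were empty for this checkpoint). The hypothetical helper below strips the wrapper prefix only at the start of each key, which is stricter than `str.replace`, and returns both lists; a toy `nn.Linear` round trip stands in for the MAE checkpoint.

```python
import torch
from torch import nn


def load_stripped(model, state_dict, prefix="module.model."):
    """Strip a wrapper prefix from checkpoint keys and report any key mismatches."""
    cleaned = {(k[len(prefix):] if k.startswith(prefix) else k): v for k, v in state_dict.items()}
    result = model.load_state_dict(cleaned, strict=False)
    return result.missing_keys, result.unexpected_keys


# Toy round trip: keys saved from a wrapped model carry the "module.model." prefix
net = nn.Linear(4, 2)
saved = {"module.model." + k: v.clone() for k, v in net.state_dict().items()}
target = nn.Linear(4, 2)
missing, unexpected = load_stripped(target, saved)
```

When fine-tuning with a new survival head, the head's parameters are expected to show up in `missing_keys`; anything else there usually points to a naming mismatch.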