In [None]:
import os
import pandas as pd

import torch
from torch.utils.data import DataLoader

from gliopath.train.task.gene import seed_torch, train, EmbeddingDataset, TaskHead
from gliopath.utils.proces import split_dataset

os.chdir('F:/workspace/pathology/gigapath')

In [2]:
seed = 42
dataset_df = pd.read_table('data\\metadata.tbl', sep='\t')
embed_path = 'output/all_slides_embeds.pt'
z_score = False
gene_col = ['IDH1','TP53','ATRX','PTEN','EGFR','TERT']
num_classes = len(gene_col)
batch_size = 4
num_workers = 4
embed_dim = 1536

splits = ['train', 'val', 'test']
split_col = 'split_col'
id_col = 'id'
params = {
    'lr': 0.02,
    'min_lr': 0.0,
    'train_iters': 4000,
    'eval_interval': 100,
    'output_dir': 'output/models/gene',
    'optim': 'sgd',
    'weight_decay': 0.01,
}

In [3]:
# set the random seed
seed_torch(torch.device('cuda'), 0)
# read the metadata
dataset_df = split_dataset(dataset_df, id_col='id', type_col='tumour_type', val_split=0.2, test_split=0.1, in_df=True, split_col='split_col')

# load the dataset
train_dataset, val_dataset, test_dataset = [EmbeddingDataset(dataset_df, embed_path, split_col=split_col, split=split, id_col=id_col, type_col=gene_col, z_score=z_score) for split in splits]
# set num_classes
print(f'Sample size:\nTrain: {len(train_dataset)}\tVal: {len(val_dataset)}\tTest: {len(test_dataset)}')

Sample size:
Train: 68	Val: 20	Test: 12


  collated_dict = torch.load(self.embed_path)
  collated_dict = torch.load(self.embed_path)
  collated_dict = torch.load(self.embed_path)


In [4]:
# infinite sampler for training
train_sampler = torch.utils.data.sampler.RandomSampler(train_dataset, replacement=True)
train_loader = DataLoader(train_dataset, batch_size=batch_size, num_workers=num_workers, sampler=train_sampler, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True)

# Load the model
model = TaskHead(embed_dim, num_classes)

In [5]:
# Train the model
train(model, train_loader, val_loader, test_loader, **params)

Set the optimizer as sgd
Start training
Iteration [9/4000]	Loss: 0.6599343419075012	LR: 0.0199996915764479
Iteration [19/4000]	Loss: 0.6316266655921936	LR: 0.01999876632481661
Iteration [29/4000]	Loss: 0.6054642200469971	LR: 0.01999722430218001
Iteration [39/4000]	Loss: 0.5812366008758545	LR: 0.01999506560365732
Iteration [49/4000]	Loss: 0.5587648749351501	LR: 0.019992290362407236
Iteration [59/4000]	Loss: 0.5378857851028442	LR: 0.01998889874961971
Iteration [69/4000]	Loss: 0.5184516310691833	LR: 0.01998489097450538
Iteration [79/4000]	Loss: 0.5003296136856079	LR: 0.019980267284282715
Iteration [89/4000]	Loss: 0.48340025544166565	LR: 0.019975027964162704
Iteration [99/4000]	Loss: 0.46755653619766235	LR: 0.019969173337331274
Start evaluating ...
Val [99/4000] Accuracy: 0.5 f1: 0.21364522417153997 Precision: 0.36904761904761907 Recall: 0.18686868686868688 AUROC: 0.45749579124579126 AUPRC: 0.5809939472608442
Best f1 increase from 0 to 0.21364522417153997
Iteration [109/4000]	Loss: 0.45270

  model.load_state_dict(torch.load(f'{output_dir}/best_model.pth'))


Test Accuracy: 0.5 f1: 0.33649732620320855 Precision: 0.3518518518518518 Recall: 0.4172619047619048 AUROC: 0.6525628306878306 AUPRC: 0.6921046204677156
