## this sample script takes continuous variable as outcome, as example

In [1]:
import os
import pandas as pd

import torch
from torch.utils.data import DataLoader

from gliopath.train.task.general import seed_torch, train, EmbeddingDataset, TaskHead
from gliopath.utils.proces import split_dataset
from gliopath.train.gadget import get_sampler

os.chdir('F:/workspace/pathology/gigapath')

In [6]:
seed = 42
dataset_df = pd.read_table('data\\metadata.tbl', sep='\t')
embed_path = 'output/all_slides_embeds.pt'
z_score = True
num_col = ['duration']
num_classes = len(num_col)
batch_size = 4
num_workers = 2
embed_dim = 1536
weighted_sampler = True

splits = ['train', 'val', 'test']
split_col = 'split_col'
id_col = 'id'
params = {
    'lr': 0.001,
    'min_lr': 0.0,
    'train_iters': 20,
    'eval_interval': 10,
    'output_dir': 'output/models/life',
    'optim': 'sgd',
    'weight_decay': 0.01,
    'outcome_type': 'cat',
    'gc_step': 10,
}

In [7]:
# set the random seed
seed_torch(torch.device('cuda'), 0)
# read the metadata
dataset_df = split_dataset(dataset_df, id_col='id', type_col='tumour_type', val_split=0.2, test_split=0.1, in_df=True, split_col='split_col')

# load the dataset
train_dataset, val_dataset, test_dataset = [EmbeddingDataset(dataset_df, embed_path, split_col=split_col, split=split, id_col=id_col, type_col=num_col, z_score=z_score) for split in splits]
# set num_classes
print(f'Sample size:\nTrain: {len(train_dataset)}\tVal: {len(val_dataset)}\tTest: {len(test_dataset)}')

Sample size:
Train: 68	Val: 20	Test: 12


  collated_dict = torch.load(self.embed_path)
  collated_dict = torch.load(self.embed_path)
  collated_dict = torch.load(self.embed_path)


In [8]:
# infinite sampler for training
# not sure if change shuffle to TRUE? (*)
train_sampler = torch.utils.data.sampler.RandomSampler(train_dataset, replacement=True)
train_sampler = get_sampler(train_sampler)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, sampler=train_sampler, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True)

# Load the model
model = TaskHead(embed_dim, num_classes)

In [9]:
# Train the model
pred_gather, target_gather = train(model, train_loader, val_loader, test_loader, **params)

Set the optimizer as sgd
Start training
Iteration [9/20]	Loss: 294.0078430175781	LR: 0.0005
Start evaluating ...
Val [9/20] MAE: 32.124 RMSE: 41.014 R²: -1.435
Best MAE decrease from inf to 32.124
Iteration [19/20]	Loss: 87.50031280517578	LR: 0.0
Start evaluating ...
Val [19/20] MAE: 32.138 RMSE: 41.094 R²: -1.444


  model.load_state_dict(torch.load(f'{output_dir}/best_model.pth'))


Test MAE: 43.802 RMSE: 57.761 R²: -1.442
