In [1]:
%load_ext autoreload
%autoreload 1
%config InlineBackend.print_figure_kwargs={'facecolor' : "w"}

In [2]:
import pandas as pd
import numpy as np
from sklearn.model_selection import RepeatedStratifiedKFold
from torch.utils.data import DataLoader
import timm
from timm import optim, scheduler
import torch
from torch import nn
from torch.optim.lr_scheduler import ExponentialLR

from sklearn.model_selection import train_test_split
from sklearn import metrics as skmet
from jupyterplot import ProgressPlot
import matplotlib.pyplot as plt
import os
import json

import transforms as my_transforms
%aimport dataset
from models import MultiTaskFrameClassifier
ImageData = dataset.ImageData

In [3]:
artifact_folder = '/zfs/wficai/pda/model_run_artifacts/20220818_multitask_224x224'
# artifact_folder = '/zfs/wficai/pda/model_run_artifacts/20220818_all_224x224'
# artifact_folder = '/zfs/wficai/pda/model_run_artifacts/20220818_justcolor_224x224'
os.makedirs(artifact_folder, exist_ok=True)

datestamp = '20220901'

# Note: all configurations are packaged as dict for easy saving
cfg = dict(
    sanity_check = False,
    sanity_check_frac = 0.1,
    mode_filter =  ['2d', 'color', 'color_compare'],
    view_filter = ['pdaView', 'pdaRelatedView', 'nonPDAView'],
    test_frac = 0.25,
    bs_train = 256,  # batch size for training
    bs_test = 500,  # batch size for testing
    num_workers = 10,  # number of parallel data loading workers
    res = 224, # pixel size along height and width
    device = 'cuda:0',
    model = 'resnet50d',
    weights = {'type': 1.0, 'mode': 0.1, 'view': 0.1},
    num_epochs=12,
    lr = 0.001,
    lr_gamma = 0.92,
    dropout = 0.3,
    weight_decay = 0.001,
    pretrained=True,
    unfreeze_after_n=2,
    lr_unfrozen = 0.00001,
    in_paths = dict(
        frame = f'/zfs/wficai/pda/model_data/{datestamp}_frame.csv',
        video = f'/zfs/wficai/pda/model_data/{datestamp}_video.csv',
        study = f'/zfs/wficai/pda/model_data/{datestamp}_study.csv',
        patient_study = f'/zfs/wficai/pda/model_data/{datestamp}_patient_study.csv',
        patient = f'/zfs/wficai/pda/model_data/{datestamp}_patient.csv'
    ),
    out_paths = dict(
        train = 'train.csv',
        test = 'test.csv'
    ),
    transforms = dict(
        train = 'train',
        test = 'test'
    )
)

with open(artifact_folder + '/config.json', 'w') as f: 
    json.dump(cfg, f, indent=4)
    
# put all config variables in scope to avoid the need to laboriously index cfg
for k, v in cfg.items():
    v = f"'{v}'" if type(v)==str else v
    exec(f"{k}={v}")
del cfg

In [4]:
device = torch.device(device)

In [5]:
df_frame = pd.read_csv(in_paths['frame'])
df_video = pd.read_csv(in_paths['video'])
df_study = pd.read_csv(in_paths['study'])
df_patient_study = pd.read_csv(in_paths['patient_study'])
df_patient = pd.read_csv(in_paths['patient'])

In [6]:
df_study.patient_type.value_counts()

nopda    76
pda      45
Name: patient_type, dtype: int64

In [7]:
df_patient_train, df_patient_test = train_test_split(df_patient, test_size=test_frac, shuffle=True)
df_train = df_patient_train.merge(df_patient_study).merge(df_study, on=['patient_type', 'study']).merge(df_video, on=['patient_type', 'study']).merge(df_frame, on=['patient_type', 'external_id'])
df_test = df_patient_test.merge(df_patient_study).merge(df_study, on=['patient_type', 'study']).merge(df_video, on=['patient_type', 'study']).merge(df_frame, on=['patient_type', 'external_id'])


df_train.to_csv(f"{artifact_folder}/{out_paths['train']}", index=False)
df_test.to_csv(f"{artifact_folder}/{out_paths['test']}", index=False)

if sanity_check: 
    df_train = df_train.sample(frac=sanity_check_frac)
    df_test = df_test.sample(frac=sanity_check_frac)
df_train.shape, df_test.shape

((181809, 14), (44961, 14))

In [8]:
# ensure that patients are disjoint
train_patient = set(df_train.patient_id)
test_patient = set(df_test.patient_id)
assert train_patient.isdisjoint(test_patient), 'Set of train patients and set of test patients are not disjoint!'

# ensure that studies are disjoint
train_study = set(df_train.study + df_train.patient_type)
test_study = set(df_test.study + df_test.patient_type)
assert train_study.isdisjoint(test_study), 'Set of train studies and set of test studies are not disjoint!'

# ensure that videos are disjoint
train_vids = set(df_train.external_id + df_train.patient_type)
test_vids = set(df_test.external_id + df_test.patient_type)
assert train_vids.isdisjoint(test_vids), 'Set of train videos and set of test videos are not disjoint!'

# ensure that frames are disjoint
train_frames = set(df_train.png_path)
test_frames = set(df_test.png_path)
assert train_frames.isdisjoint(test_frames), 'Set of train frames and set of test frames are not disjoint!'

print("All disjoint checks passed")

All disjoint checks passed


In [9]:
tfms = my_transforms.ImageTransforms(res)
tfms_train = tfms.get_transforms(transforms['train'])
tfms_test = tfms.get_transforms(transforms['test'])

In [10]:
# create datasets
d_train = ImageData(df_train, transforms = tfms_train, mode_filter = mode_filter, view_filter = view_filter)
dl_train = DataLoader(d_train, batch_size=bs_train, num_workers=num_workers, shuffle=True)

d_test = ImageData(df_test, transforms = tfms_test, mode_filter = mode_filter, view_filter = view_filter)
dl_test = DataLoader(d_test, batch_size=bs_test, num_workers=num_workers)

print("Train data size after filtering:", len(d_train))
print("Test data size after filtering:", len(d_test))

Train data size after filtering: 181809
Test data size after filtering: 44961


In [11]:
test_batch = next(iter(dl_train))

In [18]:
def train_one_epoch(model, train_dataloader, loss_function, device):
    model.train()

    num_steps_per_epoch = len(train_dataloader)

    losses = []
    for ix, batch in enumerate(train_dataloader):
        inputs = batch['img'].to(device)
        targets = {k: batch[k].to(device).type(torch.float32) for k in ['trg_type', 'trg_mode', 'trg_view']}
        
        predictions = model(inputs)

        loss = loss_function(predictions, targets, weights)
        loss['total'].backward()
        optimizer.step()
        optimizer.zero_grad()

        losses.append({k: v.detach().item() for k, v in loss.items()})
        print(f"\tBatch {ix+1} of {num_steps_per_epoch}. Loss={loss['total'].detach().item():0.3f}", end='\r')
    
    print(' '*100, end='\r')
        
    losses = pd.DataFrame(losses).mean()
    return losses
            
def evaluate(model, test_dataloader, loss_function, device):
    model.eval()

    num_steps_per_epoch = len(test_dataloader)

    patient_ls = []
    target_ls = []
    output_ls = []
    losses = []
    for ix, batch in enumerate(test_dataloader):
        inputs = batch['img'].to(device)
        targets = {k: batch[k].to(device).type(torch.float32) for k in ['trg_type', 'trg_mode', 'trg_view']}
        target_ls.append(targets)
        
        with torch.no_grad():
            predictions = model(inputs)
            output_ls.append(predictions)
            loss = loss_function(predictions, targets, weights)
            
        losses.append(loss)
        
    #compute metrics
    
    metrics = compute_metrics(target_ls, output_ls)
    
    #average loss
    avg_losses = pd.DataFrame(losses).mean()
    
    return avg_losses, metrics

In [19]:
def compute_metrics(target_ls, output_ls):
    y_true = torch.concat([trg['trg_type'] for trg in target_ls]).detach().cpu().numpy()
    y_pred = torch.concat([out['type'] for out in output_ls]).detach().cpu().numpy().squeeze()
    
    # filter out nonPDAViews and 2d images when computing type prediction metrics
    trg_mode = torch.concat([trg['trg_mode'] for trg in target_ls]).detach().cpu().numpy()
    trg_view = torch.concat([trg['trg_view'] for trg in target_ls]).detach().cpu().numpy()
    type_filter = (trg_view==0) | (trg_mode==0)
    y_true = y_true[~type_filter]
    y_pred = y_pred[~type_filter]
    
    y_pred = 1/(1+np.exp(-y_pred))
    y_pred_cls = (y_pred>0.5).astype(int)
    
    mets = dict()    
    mets['roc_auc'] = skmet.roc_auc_score(y_true, y_pred)
    mets['average_precision'] = skmet.average_precision_score(y_true, y_pred)
    mets['accuracy'] = skmet.accuracy_score(y_true, y_pred_cls)
    mets['sensitivity'] = skmet.recall_score(y_true, y_pred_cls)
    mets['specificity'] = skmet.recall_score(y_true, y_pred_cls, pos_label=0)
    
    return mets

In [20]:
# create model
is_encoder_frozen = True if unfreeze_after_n>0 else False    
encoder = timm.create_model(model, pretrained=pretrained, num_classes=1, in_chans=3, drop_rate=dropout)
clf = MultiTaskFrameClassifier(encoder, encoder_frozen=is_encoder_frozen).to(device)
loss_func = MultiTaskFrameClassifier.multi_task_loss

In [21]:
outputs = clf(test_batch['img'].to(device))
targets = {k: test_batch[k].to(device) for k in ['trg_type', 'trg_mode', 'trg_view']}
loss_dict = loss_func(outputs, targets, weights=weights)
loss_dict

{'total': tensor(0.4156, device='cuda:0', grad_fn=<MeanBackward1>),
 'type': tensor(0.1845, device='cuda:0', grad_fn=<MeanBackward1>),
 'type_filtered': tensor(0.7155, device='cuda:0', grad_fn=<MeanBackward1>),
 'mode': tensor(1.1537, device='cuda:0', grad_fn=<MeanBackward1>),
 'view': tensor(1.1581, device='cuda:0', grad_fn=<MeanBackward1>)}

In [22]:
evaluate(clf, dl_test, loss_func, device)

(total            0.386604
 type             0.153986
 type_filtered         NaN
 mode             1.168871
 view             1.157308
 dtype: float64,
 {'roc_auc': 0.40852027752546866,
  'average_precision': 0.44109200866136505,
  'accuracy': 0.47419388070464613,
  'sensitivity': 0.029085033326600687,
  'specificity': 0.9375525651808242})

In [23]:
# fit
optimizer = optim.AdamP(clf.parameters(), lr=lr, weight_decay=weight_decay)
scheduler = ExponentialLR(optimizer, gamma=lr_gamma)

train_loss_ls = []
test_loss_ls = []
metrics_ls = []

best_test_loss = 1000
for epoch in range(num_epochs):
    print("-"*40)
    print(f"Epoch {epoch+1} of {num_epochs}:")

    # maybe unfreeze 
    if epoch >= unfreeze_after_n and is_encoder_frozen:
        print("Unfreezing model encoder.")
        is_encoder_frozen=False
        for p in clf.encoder.parameters():
            p.requires_grad = True
            
        for g in optimizer.param_groups:
            g['lr'] = lr_unfrozen


    # train for a single epoch
    train_loss = train_one_epoch(clf, dl_train, loss_func, device)
    train_loss_ls.append(train_loss)
    print(f"Training:")
    print("\tcross_entropy:")
    for k, v in train_loss.items():
          print(f"\t\t{k} = {v:0.3f}") 

    # evaluate
    test_loss, metrics = evaluate(clf, dl_test, loss_func, device)
    test_loss_ls.append(test_loss)
    # metrics_ls.append(metrics)
    print(f"Test:")
    print("\tcross_entropy:")
    for k, v in test_loss.items():
          print(f"\t\t{k} = {v:0.3f}")
    print(f"\tmetrics (type):")
    for k, v in metrics.items():
        print(f"\t\t{k} = {v:0.3f}")

    # select models with the best type loss
    if test_loss['type'] < best_test_loss:
        torch.save(clf.state_dict(), f"{artifact_folder}/model_checkpoint.ckpt")
        best_test_loss = test_loss['type']
        
    scheduler.step()

----------------------------------------
Epoch 1 of 12:
Training:                                                                                           
	cross_entropy:
		total = 0.259
		type = 0.147
		type_filtered = 0.602
		mode = 0.336
		view = 0.785
Test:
	cross_entropy:
		total = 0.202
		type = 0.124
		type_filtered = nan
		mode = 0.135
		view = 0.648
	metrics (type):
		roc_auc = 0.850
		average_precision = 0.855
		accuracy = 0.687
		sensitivity = 0.961
		specificity = 0.401
----------------------------------------
Epoch 2 of 12:
Training:                                                                                           
	cross_entropy:
		total = 0.240
		type = 0.141
		type_filtered = 0.575
		mode = 0.236
		view = 0.755
Test:
	cross_entropy:
		total = 0.204
		type = 0.127
		type_filtered = nan
		mode = 0.127
		view = 0.638
	metrics (type):
		roc_auc = 0.857
		average_precision = 0.866
		accuracy = 0.663
		sensitivity = 0.969
		specificity = 0.344
----------------------

KeyboardInterrupt: 