In [1]:
# Reload imported modules automatically before each cell executes, so edits to
# the project's .py files are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Render inline figures with a white ("w") face colour.
%config InlineBackend.print_figure_kwargs={'facecolor' : "w"}

In [2]:
import pandas as pd
import numpy as np
from sklearn.model_selection import RepeatedStratifiedKFold
from torch.utils.data import DataLoader
import timm
from timm import optim, scheduler
import torch
from torch import nn
from torch.optim.lr_scheduler import ExponentialLR

from sklearn.model_selection import train_test_split
from sklearn import metrics as skmet
from jupyterplot import ProgressPlot
import matplotlib.pyplot as plt
import os
import json

import copy

import transforms as my_transforms
from dataset import VideoData
from models_multitask import FrameClassifier, VideoClassifier_PI_PI

# Settings

### Import Frame Model Settings

In [3]:
# Folder containing the frame model's run artifacts; its config.json is read below.
artifact_folder = '/zfs/wficai/pda/model_run_artifacts/20220818_multitask_224x224'

# Read the frame-model configuration into `cfg`.
with open(f'{artifact_folder}/config.json') as f:
    cfg = json.load(f)

### Set Video-level settings
Some of these values override the corresponding frame-model settings.

In [4]:
# Video-level settings. They are merged into the frame-model config below, so
# any key that already exists there is replaced by the value given here.
cfg_video = dict(
    bs_train = 6,  # batch size for training
    bs_test = 6,  # batch size for testing
    num_workers = 30,  # number of parallel data loading workers
    device = 'cuda:0',
    num_epochs=20,
    lr = 0.001,
    lr_unfrozen = 0.0001,
    lr_gamma = 0.92,
    time_downsample_factor = 8,
    time_downsample_method = 'random',
    dropout = 0.3,
    weight_decay = 0.001,
    pretrained=True,
    unfreeze_after_n=3,
    video_transforms = dict(
        train = 'train',
        test = 'test'
    )
)

cfg.update(cfg_video)

# Save the merged configuration alongside the frame-model artifacts.
with open(artifact_folder + '/config_video.json', 'w') as f:
    json.dump(cfg, f, indent=4)

# Put all config variables in scope to avoid the need to laboriously index cfg.
# The values are bound directly instead of being formatted into source text and
# exec-ed: that round-trip broke on strings containing quotes or backslashes
# and left the loop variables behind in the notebook namespace.
globals().update(cfg)

del cfg

In [5]:
# Replace the configured device name (a string such as 'cuda:0') with a torch.device.
device = torch.device(device)

In [6]:
# Reuse the train/test split that was written for the frame model; splitting
# again here would risk leaking frame-model training data into the test set.
df_train, df_test = (
    pd.read_csv(f'{artifact_folder}/{out_paths[split]}')
    for split in ('train', 'test')
)
df_train.shape, df_test.shape

((152630, 14), (43956, 14))

In [7]:
# Leakage guard: train and test must share nothing at the patient, study,
# video, or frame level. Each check stops the notebook if it fails.

# patients
train_patient, test_patient = (set(df.patient_id) for df in (df_train, df_test))
assert train_patient.isdisjoint(test_patient), 'Set of train patients and set of test patients are not disjoint!'

# studies (key = study column concatenated with patient_type)
train_study, test_study = (set(df.study + df.patient_type) for df in (df_train, df_test))
assert train_study.isdisjoint(test_study), 'Set of train studies and set of test studies are not disjoint!'

# videos (key = external_id column concatenated with patient_type)
train_vids, test_vids = (set(df.external_id + df.patient_type) for df in (df_train, df_test))
assert train_vids.isdisjoint(test_vids), 'Set of train videos and set of test videos are not disjoint!'

# frames (key = path of the frame image)
train_frames, test_frames = (set(df.png_path) for df in (df_train, df_test))
assert train_frames.isdisjoint(test_frames), 'Set of train frames and set of test frames are not disjoint!'

print("All disjoint checks passed")

All disjoint checks passed


In [8]:
# Build one transform pipeline per split from the shared VideoTransforms factory.
# NOTE(review): cfg_video defines `video_transforms`, but the lookup below reads
# `transforms` (presumably inherited from the frame-model config) — confirm
# which of the two is intended.
tfms = my_transforms.VideoTransforms(res, time_downsample_factor)
tfms_train, tfms_test = (
    tfms.get_transforms(transforms[split]) for split in ('train', 'test')
)
tfms_train, tfms_test

(Compose(
     RandomEqualize(p=0.5)
     RandAugment(num_ops=2, magnitude=9, num_magnitude_bins=31, interpolation=InterpolationMode.NEAREST, fill=None)
     DownsampleTime()
     ConvertImageDtype()
     UpsamplingBilinear2d(size=246, mode=bilinear)
     CenterCrop(size=(224, 224))
     RandomErasing(p=0.5, scale=(0.02, 0.1), ratio=(0.3, 3.3), value=0, inplace=False)
     Normalize(mean=tensor([0.4850, 0.4560, 0.4060]), std=tensor([0.2290, 0.2240, 0.2250]))
     RandomHorizontalFlip(p=0.5)
     RandomRotation(degrees=[-45.0, 45.0], interpolation=nearest, expand=False, fill=0)
     RandomInvert(p=0.5)
 ),
 Compose(
     ConvertImageDtype()
     UpsamplingBilinear2d(size=246, mode=bilinear)
     CenterCrop(size=(224, 224))
     Normalize(mean=tensor([0.4850, 0.4560, 0.4060]), std=tensor([0.2290, 0.2240, 0.2250]))
 ))

In [9]:
def train_one_epoch(model, train_dataloader, loss_function, device, optimizer=None):
    """Train ``model`` for one pass over ``train_dataloader``.

    Args:
        model: Module called as ``model(video, num_frames)``; it must return a
            pair whose first element is passed to ``loss_function``.
        train_dataloader: Sized iterable of batch dicts with the keys 'video',
            'num_frames', 'trg_type', 'trg_view' and 'trg_mode'.
        loss_function: Callable ``loss_function(outputs, targets)``. It may
            return a scalar tensor, or a dict of scalar tensors, in which case
            the components are summed into a single loss before backprop.
        device: Device the video tensor and targets are moved to.
        optimizer: Optimizer to step. When omitted, the module-level variable
            named ``optimizer`` is used, which is how this function found its
            optimizer before the parameter existed.

    Returns:
        Mean of the per-batch loss values as a float (NaN if there were no batches).

    Raises:
        ValueError: If no optimizer is passed and no module-level ``optimizer`` exists.
    """
    if optimizer is None:
        # Backward compatibility with callers that rely on the notebook-level optimizer.
        optimizer = globals().get('optimizer')
        if optimizer is None:
            raise ValueError("No optimizer given and no module-level 'optimizer' is defined.")

    model.train()

    num_steps_per_epoch = len(train_dataloader)

    losses = []
    for ix, batch in enumerate(train_dataloader):
        inputs = batch['video'].to(device)
        num_frames = batch['num_frames']
        targets = {k: batch[k].to(device).type(torch.float32) for k in ('trg_type', 'trg_view', 'trg_mode')}
        outputs, _ = model(inputs, num_frames)
        loss = loss_function(outputs, targets)
        if isinstance(loss, dict):
            # Multi-task losses may come back per head; optimise their sum.
            loss = sum(loss.values())

        # Clear stale gradients first so nothing left over from earlier
        # backward passes leaks into this step.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_loss = loss.detach().item()
        losses.append(batch_loss)
        print(f"\tBatch {ix+1} of {num_steps_per_epoch}. Loss={batch_loss:0.3f}", end='\r')

    # Overwrite the carriage-return progress line.
    print(' '*100, end='\r')

    return float(np.mean(losses)) if losses else float('nan')
            
            
def evaluate(model, test_dataloader, loss_function, device):
    """Evaluate ``model`` on ``test_dataloader`` without tracking gradients.

    Args:
        model: Module called as ``model(video, num_frames)``; it must return a
            pair whose first element is a dict of output tensors.
        test_dataloader: Iterable of batch dicts with the keys 'video',
            'num_frames', 'trg_type', 'trg_view' and 'trg_mode'.
        loss_function: Callable ``loss_function(outputs, targets)``. It may
            return a scalar tensor, or a dict of scalar tensors, in which case
            the components are summed into a single value per batch.
        device: Device the video tensor and targets are moved to.

    Returns:
        ``(mean_loss, metrics)`` where ``mean_loss`` is the mean per-batch loss
        as a float (NaN if there were no batches) and ``metrics`` is currently
        always ``None`` because metric computation is not implemented yet.
    """
    model.eval()

    # Per-batch numpy copies of targets and outputs, kept for the metric
    # computation that is still to be written.
    target_ls = []
    output_ls = []
    losses = []
    for batch in test_dataloader:
        inputs = batch['video'].to(device)
        num_frames = batch['num_frames']
        targets = {k: batch[k].to(device).type(torch.float32) for k in ('trg_type', 'trg_view', 'trg_mode')}

        with torch.no_grad():
            outputs, _ = model(inputs, num_frames)
            # The loss must see tensors, so it is computed before anything is
            # converted to numpy.
            loss = loss_function(outputs, targets)

        if isinstance(loss, dict):
            # Multi-task losses may come back per head; report their sum.
            loss = sum(loss.values())
        losses.append(float(loss))

        target_ls.append({k: v.cpu().numpy() for k, v in targets.items()})
        output_ls.append({k: v.cpu().numpy() for k, v in outputs.items()})

    mean_loss = float(np.mean(losses)) if losses else float('nan')

    # TODO: compute metrics from target_ls / output_ls once a multi-task
    # compute_metrics is available.
    return mean_loss, None

In [10]:
# Sumanth todo: check out 3a_model_frames-multitask 'compute_metrics' function
def compute_metrics(y_true, y_pred):
    """Score predictions against labels with five scikit-learn metrics.

    Args:
        y_true: Ground-truth labels, passed unchanged to the metric functions.
        y_pred: Predicted scores; must support ``>`` and ``.astype`` (e.g. a
            numpy array). Scores above 0.5 are mapped to class 1, the rest to 0.

    Returns:
        Dict with the keys 'roc_auc', 'average_precision', 'accuracy',
        'sensitivity' and 'specificity'. The first two use the raw scores, the
        last three use the thresholded labels; specificity is recall computed
        with class 0 as the positive label.
    """
    hard_labels = (y_pred > 0.5).astype(int)

    return {
        'roc_auc': skmet.roc_auc_score(y_true, y_pred),
        'average_precision': skmet.average_precision_score(y_true, y_pred),
        'accuracy': skmet.accuracy_score(y_true, hard_labels),
        'sensitivity': skmet.recall_score(y_true, hard_labels),
        'specificity': skmet.recall_score(y_true, hard_labels, pos_label=0),
    }

In [11]:
# Datasets and loaders for both splits; only the training loader shuffles.
d_train = VideoData(
    df_train,
    transforms=tfms_train,
    mode_filter=mode_filter,
    view_filter=view_filter,
)
dl_train = DataLoader(
    d_train,
    batch_size=bs_train,
    shuffle=True,
    num_workers=num_workers,
    collate_fn=VideoData.collate,
    pin_memory=True,
)

d_test = VideoData(
    df_test,
    transforms=tfms_test,
    mode_filter=mode_filter,
    view_filter=view_filter,
)
dl_test = DataLoader(
    d_test,
    batch_size=bs_test,
    num_workers=num_workers,
    collate_fn=VideoData.collate,
    pin_memory=True,
)

print("Train data size:", len(d_train))
print("Test data size:", len(d_test))

Train data size: 1692
Test data size: 457


In [12]:
# Draw one collated batch to sanity-check tensor shapes.
# Note: despite its name, `test_batch` comes from the *training* loader.
test_batch = next(iter(dl_train))
test_batch['video'].shape

torch.Size([86, 3, 224, 224])

In [13]:
# Presumably the number of frames contributed by each video in the batch
# (the values sum to the first dimension of the 'video' tensor) — confirm
# against VideoData.collate.
test_batch['num_frames']

[12, 10, 19, 18, 13, 14]

In [44]:
# Re-import the model classes in this cell. No `del` is needed first: the
# import rebinds both names, and deleting them would raise NameError whenever
# they are not currently defined.
from models_multitask import FrameClassifier, VideoClassifier_PI_PI

# Frame-level classifier built on a timm encoder.
encoder = timm.create_model(model, pretrained=pretrained, num_classes=1, in_chans=3, drop_rate=dropout)
clf_frames = FrameClassifier(encoder, encoder_frozen=True).to(device)

# load pretrained weights for frame classifier
# map_location places the checkpoint tensors on the configured device, even if
# the checkpoint was saved from a different one.
clf_frames.load_state_dict(torch.load(f"{artifact_folder}/model_checkpoint.ckpt", map_location=device))

loss_func = FrameClassifier.multi_task_loss

# create video model
m = VideoClassifier_PI_PI(clf_frames, encoder_frozen=True, frame_classifier_frozen=False).to(device)

In [45]:
# Smoke test: run one forward pass of the video model on the sample batch,
# without tracking gradients, and display the first returned value.
with torch.no_grad():
    y, attn = m(test_batch['video'].to(device), test_batch['num_frames'])
# y, test_batch['trg_type'], attn.shape

y

{'type': tensor([[ 0.4764],
         [-0.2844],
         [-0.0137],
         [ 0.2134],
         [ 0.5181],
         [ 0.2379]], device='cuda:0'),
 'mode': tensor([[ 1.7057, -1.9115, -1.9183],
         [ 2.6653, -0.3951, -4.6629],
         [-4.1133,  0.0144,  1.9066],
         [-1.0395, -3.2064,  2.0566],
         [ 1.6044, -0.8762, -2.7260],
         [ 0.2205, -0.6102, -1.7754]], device='cuda:0'),
 'view': tensor([[ 0.5037, -0.3100, -1.2585],
         [ 1.2725, -0.9084, -1.9620],
         [ 0.2403, -0.5814, -0.9293],
         [ 1.0006, -0.7483, -1.2813],
         [ 0.4912, -0.7867, -1.2312],
         [ 0.4742, -0.8788, -0.9785]], device='cuda:0')}

{'type': tensor([[ 0.5183],
         [-0.1875],
         [-0.0223],
         [ 0.2584],
         [ 0.5277],
         [ 0.2175]], device='cuda:0'),
 'mode': tensor([[ 1.7078, -1.9479, -1.8719],
         [ 2.6794, -0.4279, -4.6418],
         [-4.1952,  0.0840,  1.9263],
         [-1.0847, -3.1737,  2.0701],
         [ 1.5608, -0.9756, -2.6079],
         [ 0.1716, -0.5530, -1.7750]], device='cuda:0'),
 'view': tensor([[ 0.4843, -0.3117, -1.2315],
         [ 1.2347, -0.9009, -1.8831],
         [ 0.2223, -0.5681, -0.9079],
         [ 1.0061, -0.7631, -1.2782],
         [ 0.3817, -0.7295, -1.1251],
         [ 0.4934, -0.8746, -1.0141]], device='cuda:0')}

In [None]:
# fit
optimizer = optim.AdamP(m.parameters(), lr=lr, weight_decay=weight_decay)
# NOTE: this rebinds `scheduler`, which the import cell bound to timm's
# scheduler module; re-running the import cell restores the module.
scheduler = ExponentialLR(optimizer, gamma=lr_gamma)
# train_one_epoch and evaluate both pass dicts of per-task outputs and targets
# to the loss, which binary_cross_entropy cannot consume; use the frame
# classifier's multi-task loss instead.
# NOTE(review): assumes multi_task_loss accepts (outputs, targets) with the
# 'trg_*' target keys used in this notebook — confirm against models_multitask.
loss_function = FrameClassifier.multi_task_loss

In [None]:
# Baseline: evaluate the model on the test loader before the training loop runs.
evaluate(m, dl_test, loss_function, device)

In [None]:
# Per-epoch history.
train_loss_ls = []
test_loss_ls = []
metrics_ls = []
metrics_agg_ls = []

# progress plot
pp_metrics = ProgressPlot(x_lim=[1,num_epochs], y_lim=[0,1], plot_names = ['metrics'], x_label="Epoch", line_names=['AUROC', 'Avg. Prec.', 'Acc.', 'Sensitivity', 'Specificity'])

# Start from +inf so the first epoch always produces a checkpoint, whatever the
# scale of the loss.
best_test_loss = float('inf')
is_frozen = True
for epoch in range(num_epochs):
    print("-"*40)
    print(f"Epoch {epoch+1} of {num_epochs}:")
    
    # maybe unfreeze 
    if epoch >= unfreeze_after_n and is_frozen:
        print("Unfreezing model encoder.")
        is_frozen=False
        for p in m.encoder.parameters():
            p.requires_grad = True
            
        # set all learning rates to the lower lr_unfrozen learning rate
        for g in optimizer.param_groups:
            g['lr'] = lr_unfrozen

    # train for a single epoch
    train_loss = train_one_epoch(m, dl_train, loss_function, device)
    train_loss_ls.append(train_loss)
    print(f"Training:")
    print(f"\tcross_entropy = {train_loss:0.3f}")       

    # evaluate
    test_loss, metrics = evaluate(m, dl_test, loss_function, device)
    test_loss_ls.append(test_loss)
    metrics_ls.append(metrics)
    print(f"Test:")
    print(f"\tcross_entropy = {test_loss:0.3f}")
    # evaluate() may return None for metrics (metric computation is not wired
    # in yet), so only report and plot them when they are available.
    if metrics:
        print(f"\tmetrics:")
        for k, v in metrics.items():
            print(f"\t\t{k} = {v:0.3f}")

    # keep the checkpoint with the lowest test loss seen so far
    if test_loss < best_test_loss:
        torch.save(m.state_dict(), f"{artifact_folder}/model_checkpoint_video.ckpt")
        best_test_loss = test_loss
        
    scheduler.step()

    # TODO: use study-aggregated metrics
    if metrics:
        pp_metrics.update([[v for _,v in metrics.items()]])

pp_metrics.finalize()