In [1]:
%load_ext autoreload
%autoreload 1
%config InlineBackend.print_figure_kwargs={'facecolor' : "w"}

In [2]:
import pandas as pd
import numpy as np
from torch.utils.data import DataLoader
import timm
from timm import optim, scheduler
import torch
from torchvision import transforms as tfm
from sklearn import metrics as skmet
import matplotlib.pyplot as plt
import json
import transforms as my_transforms

%aimport dataset
from models import MultiTaskFrameClassifier
ImageData = dataset.ImageData

In [3]:
artifact_folder = '/zfs/wficai/pda/model_run_artifacts/20220818_multitask_224x224'

with open(artifact_folder + '/config.json', 'r') as f: 
    cfg = json.load(f)

# put all config variables in scope to avoid the need to laboriously index cfg
for k, v in cfg.items():
    v = f"'{v}'" if type(v)==str else v
    exec(f"{k}={v}")
del cfg

In [4]:
# optionally override settings
view_filter = ['pdaView', 'pdaRelatedView', 'nonPDAView']
mode_filter = ['2d', 'color', 'color_compare']
device = torch.device('cuda:1')  # you may need 'cuda:0'

In [5]:
tfms = my_transforms.ImageTransforms(res)
tfms_test = tfms.get_transforms(transforms['test'])

In [6]:
df_test = pd.read_csv(f'{artifact_folder}/{out_paths["test"]}')
d_test = ImageData(df_test, transforms = tfms_test, mode_filter = mode_filter, view_filter = view_filter)
dl_test = DataLoader(d_test, batch_size=bs_test, num_workers=num_workers)

print("Number of frames after filtering:", len(d_test.data))

Number of frames after filtering: 44961


In [7]:
# create model
encoder = timm.create_model(model, pretrained=pretrained, num_classes=1, in_chans=3, drop_rate=dropout)
clf = MultiTaskFrameClassifier(encoder).to(device)    
clf.load_state_dict(torch.load(f"{artifact_folder}/model_checkpoint.ckpt"))
clf.eval()
loss_function = MultiTaskFrameClassifier.multi_task_loss

target_ls = []
output_ls = []
study_ls = []
video_ls = []
view_ls = []
mode_ls = []
losses = []

for ix, batch in enumerate(dl_test):
    print(f"Batch {ix+1}", end = "\r")
    inputs = batch['img'].to(device)
    targets = {k: batch[k].to(device).type(torch.float32) for k in ['trg_type', 'trg_mode', 'trg_view']}
    
    target_ls.append(batch['trg_type'].numpy())
    view_ls.append(batch['trg_view'].numpy())
    mode_ls.append(batch['trg_mode'].numpy())
    study_ls += batch['study']
    video_ls += batch['video']

    with torch.no_grad():
        outputs = clf(inputs)
        output_ls.append(outputs)
        loss = loss_function(outputs, targets, weights)
        losses.append(loss)

Batch 90

# Compute Metrics

In [8]:
def compute_metrics(y_true, y_pred, thresh=0.5):
    mets = dict()
    is_multiclass = (len(y_pred.shape)==2) & (y_pred.shape[-1]>1)
    
    if not is_multiclass:
        y_pred_cls = (y_pred>thresh).astype(int)
    else:
        y_pred_cls = np.argmax(y_pred, axis=-1)
    
    mets['num_samples'] = len(y_true)
    mets['roc_auc'] = skmet.roc_auc_score(y_true, y_pred, multi_class='ovr')
    mets['accuracy'] = skmet.accuracy_score(y_true, y_pred_cls)
    mets['sensitivity'] = skmet.recall_score(y_true, y_pred_cls, average='micro')
    
    if not is_multiclass:
        mets['specificity'] = skmet.recall_score(y_true, y_pred_cls, pos_label=0)
    
    return mets

In [9]:
pred_type = np.concatenate([out['type'].cpu().squeeze() for out in output_ls])
pred_type = 1/(1+np.exp(-pred_type))
trg_type = np.concatenate([trg.squeeze() for trg in target_ls])

pred_view = np.concatenate([out['view'].cpu().squeeze() for out in output_ls])
pred_view = np.exp(pred_view) / np.exp(pred_view).sum(axis=-1, keepdims=True)
trg_view = np.concatenate([trg.squeeze() for trg in view_ls])

pred_mode = np.concatenate([out['mode'].cpu().squeeze() for out in output_ls])
pred_mode = np.exp(pred_mode) / np.exp(pred_mode).sum(axis=-1, keepdims=True)
trg_mode = np.concatenate([trg.squeeze() for trg in mode_ls])

### PDA

In [20]:
df_pda = pd.DataFrame({'type': trg_type, 'pred': pred_type, 'video': video_ls, 'mode': trg_mode, 'view': trg_view})
df_pda_unmapped = df_pda.copy()
df_pda['mode'] = df_pda['mode'].map(ImageData.inv_mode_map)
df_pda['view'] = df_pda['view'].map(ImageData.inv_view_map)
df_pda.head()

Unnamed: 0,type,pred,video,mode,view
0,0,0.37201,study37_dicom89,color,pdaRelatedView
1,0,0.164154,study37_dicom89,color,pdaRelatedView
2,0,0.160626,study37_dicom89,color,pdaRelatedView
3,0,0.743612,study37_dicom89,color,pdaRelatedView
4,0,0.104012,study37_dicom89,color,pdaRelatedView


In [11]:
print("frame-level scores:")
compute_metrics(df_pda['type'], df_pda['pred'])

frame-level scores:


{'num_samples': 44961,
 'roc_auc': 0.7078433334050096,
 'accuracy': 0.6567024754787483,
 'sensitivity': 0.6567024754787483,
 'specificity': 0.7392204857842214}

In [12]:
grouped_results = df_pda.groupby(['view', 'mode']).apply(lambda x: compute_metrics(x['type'], x['pred']))
grouped_results = pd.DataFrame(grouped_results.tolist(), index=grouped_results.index)
grouped_results

Unnamed: 0_level_0,Unnamed: 1_level_0,num_samples,roc_auc,accuracy,sensitivity,specificity
view,mode,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
nonPDAView,2d,11631,0.70104,0.63838,0.63838,0.944585
nonPDAView,color,12073,0.68347,0.631078,0.631078,0.589045
nonPDAView,color_compare,6800,0.867283,0.680147,0.680147,0.641061
pdaRelatedView,2d,3107,0.776489,0.516897,0.516897,0.989407
pdaRelatedView,color,2250,0.935112,0.822222,0.822222,0.7225
pdaRelatedView,color_compare,2091,0.865418,0.687709,0.687709,0.55814
pdaView,2d,1643,0.672757,0.468655,0.468655,0.991379
pdaView,color,1374,0.912451,0.808588,0.808588,0.598326
pdaView,color_compare,3992,0.871356,0.772044,0.772044,0.701433


In [14]:
print('video-level-scores')
df_pda_vid = df_pda.groupby(['type', 'video', 'mode', 'view'], as_index=False).agg('mean')
compute_metrics(df_pda_vid['type'], df_pda_vid['pred'])

video-level-scores


{'num_samples': 478,
 'roc_auc': 0.7195093810462041,
 'accuracy': 0.6694560669456067,
 'sensitivity': 0.6694560669456067,
 'specificity': 0.7715355805243446}

In [15]:
grouped_results = df_pda_vid.groupby(['view', 'mode']).apply(lambda x: compute_metrics(x['type'], x['pred']))
grouped_results = pd.DataFrame(grouped_results.tolist(), index=grouped_results.index)
grouped_results

Unnamed: 0_level_0,Unnamed: 1_level_0,num_samples,roc_auc,accuracy,sensitivity,specificity
view,mode,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
nonPDAView,2d,139,0.743721,0.611511,0.611511,0.975309
nonPDAView,color,128,0.718988,0.632812,0.632812,0.492754
nonPDAView,color_compare,57,0.951389,0.754386,0.754386,0.729167
pdaRelatedView,2d,36,0.86875,0.5,0.5,1.0
pdaRelatedView,color,28,0.964103,0.892857,0.892857,0.8
pdaRelatedView,color_compare,20,0.947917,0.75,0.75,0.583333
pdaView,2d,19,0.72619,0.368421,0.368421,1.0
pdaView,color,17,1.0,1.0,1.0,1.0
pdaView,color_compare,34,0.917857,0.852941,0.852941,0.785714


In [16]:
df_pda_vid_goodviews = df_pda_vid.query('mode!="2d" and view!="nonPDAView"')
compute_metrics(df_pda_vid_goodviews['type'], df_pda_vid_goodviews['pred'])

{'num_samples': 99,
 'roc_auc': 0.9454470877768664,
 'accuracy': 0.8686868686868687,
 'sensitivity': 0.8686868686868687,
 'specificity': 0.7608695652173914}

In [17]:
print("View prediction")
compute_metrics(trg_view, pred_view)

View prediction


{'num_samples': 44961,
 'roc_auc': 0.8871884989028015,
 'accuracy': 0.7831453926736505,
 'sensitivity': 0.7831453926736505}

In [19]:
df_pda_vid.view

0          nonPDAView
1          nonPDAView
2          nonPDAView
3      pdaRelatedView
4      pdaRelatedView
            ...      
473        nonPDAView
474        nonPDAView
475        nonPDAView
476    pdaRelatedView
477        nonPDAView
Name: view, Length: 478, dtype: object

In [18]:
print("Mode prediction")
compute_metrics(trg_mode, pred_mode)

Mode prediction


{'num_samples': 44961,
 'roc_auc': 0.9929284210019661,
 'accuracy': 0.9867218255821713,
 'sensitivity': 0.9867218255821713}