In [1]:
%load_ext autoreload
%autoreload 1
%config InlineBackend.print_figure_kwargs={'facecolor' : "w"}

In [2]:
import pandas as pd
import numpy as np
from torch.utils.data import DataLoader
import timm
from timm import optim, scheduler
import torch
from torchvision import transforms as tfm
from sklearn import metrics as skmet
import matplotlib.pyplot as plt
import json
import transforms as my_transforms

%aimport dataset
ImageData = dataset.ImageData

In [3]:
# Alternative runs evaluated with this notebook (uncomment exactly one):
# artifact_folder = '/mnt/data/pda/model_run_artifacts/20220818_just2d_64x64'
# artifact_folder = '/mnt/data/pda/model_run_artifacts/20220818_justcolor_64x64'
# artifact_folder = '/mnt/data/pda/model_run_artifacts/20220818_justcolornonpda_64x64'
# artifact_folder = '/mnt/data/pda/model_run_artifacts/20220818_justcolorjustpdaview_128x128'
# artifact_folder = '/mnt/data/pda/model_run_artifacts/20220818_no2d_224x224'
artifact_folder = '/mnt/data/pda/model_run_artifacts/20220818_justcolor_224x224'

with open(f'{artifact_folder}/config.json', 'r') as f:
    cfg = json.load(f)

# Put every config entry in the notebook namespace so later cells can use names
# such as `res`, `transforms`, `out_paths`, `bs_test`, `num_workers`, `model`,
# `pretrained`, `num_classes` and `device` (none of which are defined elsewhere
# in this notebook -- presumably all come from config.json) without indexing cfg.
# globals().update binds the decoded JSON values directly. The previous
# exec(f"{k}={v}") round-tripped them through source text, which broke on strings
# containing quotes/backslashes/newlines and on NaN, and executed config text as code.
globals().update(cfg)
del cfg

In [4]:
# Optional override of the view/mode filters; comment these two assignments out
# to keep whatever values the run config provided (if it defines them).
view_filter = [
    'pdaView',
    'pdaRelatedView',
    'nonPDAView',
]
mode_filter = [
    '2d',
    'color',
    'color_compare',
]

In [5]:
# Test-time preprocessing pipeline; `res` and `transforms` are presumably
# provided by the run config loaded above.
tfms = my_transforms.Transforms(res)
tfms_test = tfms.get_transforms(
    transforms["test"]
)

In [6]:
# Test split listed in the run's artifacts, filtered to the requested views/modes.
df_test = pd.read_csv(artifact_folder + "/" + out_paths["test"])
d_test = ImageData(
    df_test,
    transforms=tfms_test,
    mode_filter=mode_filter,
    view_filter=view_filter,
)
# shuffle=False is the DataLoader default; spelled out because frame order must
# stay aligned with the metadata lists collected during inference.
dl_test = DataLoader(d_test, batch_size=bs_test, num_workers=num_workers, shuffle=False)

print(f"Number of frames after filtering: {len(d_test.data)}")

Number of frames after filtering: 54582


In [7]:
# Build the architecture and load this run's trained weights from its checkpoint.
# NOTE(review): with pretrained=True timm may first fetch pretrained weights that the
# checkpoint then overwrites; consider pretrained=False here -- confirm.
m = timm.create_model(
    model,
    pretrained=pretrained,
    checkpoint_path=f"{artifact_folder}/model_checkpoint.ckpt",
    num_classes=num_classes,
    in_chans=3,
)
m.to(device)
m.eval()

# Public functional API (`torch.functional.F` only resolved through an internal import).
loss_function = torch.nn.functional.binary_cross_entropy_with_logits

target_ls = []
output_ls = []
study_ls = []
video_ls = []
view_ls = []
mode_ls = []
losses = []

# Inference only: disable autograd for the whole loop.
with torch.no_grad():
    for ix, batch in enumerate(dl_test):
        print(f"Batch {ix+1}", end="\r")
        inputs = batch['img'].to(device)
        targets = batch['trg_type'].to(device, dtype=torch.float32)
        target_ls.append(targets.cpu().numpy())
        view_ls.append(batch['trg_view'].numpy())
        mode_ls.append(batch['trg_mode'].numpy())
        study_ls += batch['study']
        video_ls += batch['video']

        outputs = m(inputs)
        output_ls.append(outputs.cpu().numpy())
        # Drop only the trailing logit dim (assumes num_classes == 1, as the BCE loss
        # implies). A bare .squeeze() turns a final batch of size 1 into a 0-d tensor,
        # which no longer matches the (1,)-shaped targets and makes BCE raise.
        loss = loss_function(outputs.squeeze(-1), targets)
        losses.append(loss.item())

Batch 110

In [8]:
# One row per test frame. reshape(-1) flattens the logits (presumably (N, 1) for
# the single-logit model) to (N,) for any N, whereas .squeeze() yields a 0-d array
# when N == 1.
df_results = pd.DataFrame(dict(
    study=study_ls,
    video=video_ls,
    predicted=np.concatenate(output_ls).reshape(-1),
    target=np.concatenate(target_ls),
    mode=np.concatenate(mode_ls),
    view=np.concatenate(view_ls),
))

# Decode integer labels back to names. Use bracket access for both columns:
# `df_results.mode` is the DataFrame.mode() method, not the column.
df_results['mode'] = df_results['mode'].map(ImageData.inv_mode_map)
df_results['view'] = df_results['view'].map(ImageData.inv_view_map)

# Convert logits to probabilities with the logistic sigmoid.
df_results['predicted'] = 1 / (1 + np.exp(-df_results['predicted']))

df_results.head(20)

Unnamed: 0,study,video,predicted,target,mode,view
0,study19,study19_dicom52,0.461446,0.0,2d,nonPDAView
1,study19,study19_dicom52,0.432965,0.0,2d,nonPDAView
2,study19,study19_dicom52,0.460939,0.0,2d,nonPDAView
3,study19,study19_dicom52,0.38377,0.0,2d,nonPDAView
4,study19,study19_dicom52,0.454097,0.0,2d,nonPDAView
5,study19,study19_dicom52,0.441345,0.0,2d,nonPDAView
6,study19,study19_dicom52,0.474623,0.0,2d,nonPDAView
7,study19,study19_dicom52,0.429579,0.0,2d,nonPDAView
8,study19,study19_dicom52,0.433079,0.0,2d,nonPDAView
9,study19,study19_dicom52,0.463418,0.0,2d,nonPDAView


In [9]:
def compute_metrics(y_true, y_pred, thresh=0.5):
    """Compute binary-classification metrics for one set of predictions.

    Args:
        y_true: Ground-truth labels, 1 for the positive class and 0 otherwise
            (array-like or pandas Series).
        y_pred: Predicted probability of the positive class, same length as y_true.
        thresh: Predictions strictly greater than this are classed as positive.

    Returns:
        dict with keys num_samples, roc_auc, average_precision, accuracy,
        sensitivity (recall of class 1) and specificity (recall of class 0).
        roc_auc is NaN when y_true contains fewer than two classes, because
        sklearn's roc_auc_score raises in that case (which would otherwise abort
        a groupby(...).apply over subgroups).
    """
    y_pred_cls = (y_pred > thresh).astype(int)

    # ROC AUC is undefined unless both classes are present in y_true.
    has_both_classes = np.unique(y_true).size >= 2

    mets = {}
    mets['num_samples'] = len(y_true)
    mets['roc_auc'] = skmet.roc_auc_score(y_true, y_pred) if has_both_classes else np.nan
    mets['average_precision'] = skmet.average_precision_score(y_true, y_pred)
    mets['accuracy'] = skmet.accuracy_score(y_true, y_pred_cls)
    mets['sensitivity'] = skmet.recall_score(y_true, y_pred_cls)
    mets['specificity'] = skmet.recall_score(y_true, y_pred_cls, pos_label=0)

    return mets

# Frame-level results

In [10]:
compute_metrics(y_true=df_results['target'], y_pred=df_results['predicted'])

{'num_samples': 54582,
 'roc_auc': 0.574522790223555,
 'average_precision': 0.5519508444731642,
 'accuracy': 0.5437323659814591,
 'sensitivity': 0.4877895360023681,
 'specificity': 0.5985992161416751}

In [11]:
# Frame-level metrics for every (view, mode) combination: apply() yields one
# metrics dict per group, which is then expanded into columns.
grouped_results = (
    df_results
    .groupby(['view', 'mode'])
    .apply(lambda grp: compute_metrics(grp['target'], grp['predicted']))
)
grouped_results = pd.DataFrame(list(grouped_results), index=grouped_results.index)
grouped_results

Unnamed: 0_level_0,Unnamed: 1_level_0,num_samples,roc_auc,average_precision,accuracy,sensitivity,specificity
view,mode,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
nonPDAView,2d,15735,0.654647,0.680514,0.507976,0.060714,0.988921
nonPDAView,color,11095,0.543686,0.431466,0.51023,0.679347,0.396167
nonPDAView,color_compare,10052,0.48503,0.373204,0.476821,0.666091,0.349
pdaRelatedView,2d,3137,0.581526,0.591499,0.478801,0.057716,0.943662
pdaRelatedView,color,2269,0.695216,0.697527,0.650948,0.800499,0.482176
pdaRelatedView,color_compare,3576,0.725591,0.802779,0.671421,0.815777,0.452498
pdaView,2d,1122,0.705883,0.720279,0.524064,0.058511,0.994624
pdaView,color,2381,0.755488,0.728158,0.640487,0.880365,0.436236
pdaView,color_compare,5215,0.712495,0.853899,0.716779,0.851179,0.391874


# Clip-level results

### Clip score: average predicted probability over frames

In [12]:
# Clip score = mean of its per-frame probabilities. The other columns are
# presumably constant within a clip, so they are all kept as grouping keys.
df_results_clip_avg = (
    df_results
    .groupby(['study', 'video', 'target', 'view', 'mode'], as_index=False)['predicted']
    .mean()
)
display(df_results_clip_avg.head(10))

compute_metrics(df_results_clip_avg.target, df_results_clip_avg.predicted)

Unnamed: 0,study,video,target,view,mode,predicted
0,study11,study11_dicom100,1.0,nonPDAView,color,0.53933
1,study11,study11_dicom40,1.0,nonPDAView,2d,0.422019
2,study11,study11_dicom41,1.0,nonPDAView,2d,0.441315
3,study11,study11_dicom42,1.0,nonPDAView,color,0.479398
4,study11,study11_dicom43,1.0,nonPDAView,2d,0.424171
5,study11,study11_dicom44,1.0,nonPDAView,color,0.446651
6,study11,study11_dicom45,1.0,nonPDAView,color,0.474796
7,study11,study11_dicom46,1.0,nonPDAView,2d,0.454322
8,study11,study11_dicom47,1.0,nonPDAView,color,0.489328
9,study11,study11_dicom49,1.0,nonPDAView,2d,0.457639


{'num_samples': 546,
 'roc_auc': 0.5635840868478685,
 'average_precision': 0.5508111326169072,
 'accuracy': 0.5238095238095238,
 'sensitivity': 0.4828897338403042,
 'specificity': 0.5618374558303887}

In [24]:
# Per-(view, mode) metrics on clip-averaged scores.
# NOTE: this uses a 0.55 decision threshold, unlike the 0.5 default used for the
# overall clip-level metrics.
CLIP_THRESH = 0.55

grouped_results_clip_avg = (
    df_results_clip_avg
    .groupby(['view', 'mode'])
    .apply(lambda dat: compute_metrics(dat.target, dat.predicted, thresh=CLIP_THRESH))
)
# Index must come from this groupby, not from the frame-level `grouped_results`:
# the two only line up when both contain exactly the same groups in the same order.
grouped_results_clip_avg = pd.DataFrame(
    grouped_results_clip_avg.tolist(), index=grouped_results_clip_avg.index
)
grouped_results_clip_avg

Unnamed: 0_level_0,Unnamed: 1_level_0,num_samples,roc_auc,average_precision,accuracy,sensitivity,specificity
view,mode,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
nonPDAView,2d,179,0.710428,0.745993,0.486034,0.010753,1.0
nonPDAView,color,122,0.548225,0.454327,0.581967,0.326531,0.753425
nonPDAView,color_compare,78,0.490489,0.383704,0.474359,0.28125,0.608696
pdaRelatedView,2d,38,0.530556,0.596702,0.473684,0.0,1.0
pdaRelatedView,color,29,0.861905,0.83988,0.758621,0.571429,0.933333
pdaRelatedView,color_compare,30,0.873303,0.909921,0.7,0.529412,0.923077
pdaView,2d,13,0.928571,0.930556,0.538462,0.0,1.0
pdaView,color,24,0.964286,0.943012,0.875,1.0,0.785714
pdaView,color_compare,33,0.855372,0.941341,0.818182,0.727273,1.0


### Misclassified clips (threshold 0.5)

In [14]:
# Clips whose 0.5-thresholded prediction disagrees with the label.
true = df_results_clip_avg.target
pred_cls = df_results_clip_avg.predicted.gt(0.5).astype(int)

df_results_clip_avg[true.ne(pred_cls)]

Unnamed: 0,study,video,target,view,mode,predicted
1,study11,study11_dicom40,1.0,nonPDAView,2d,0.422019
2,study11,study11_dicom41,1.0,nonPDAView,2d,0.441315
3,study11,study11_dicom42,1.0,nonPDAView,color,0.479398
4,study11,study11_dicom43,1.0,nonPDAView,2d,0.424171
5,study11,study11_dicom44,1.0,nonPDAView,color,0.446651
...,...,...,...,...,...,...
540,study8,study8_dicom87,0.0,nonPDAView,color,0.555935
541,study8,study8_dicom89,0.0,nonPDAView,color_compare,0.603827
542,study8,study8_dicom91,0.0,nonPDAView,color,0.519730
543,study8,study8_dicom92,0.0,nonPDAView,color_compare,0.551433


# Study-level results

In [25]:
# Study score = unweighted mean of that study's clip scores (each clip counts once,
# regardless of its frame count).
# Select 'predicted' explicitly: the remaining non-key columns (video, view, mode)
# are strings, and pandas >= 2.0 raises on mean() over them instead of silently
# dropping them.
# NOTE(review): grouping by target splits a study whose clips carry different
# labels into separate rows (e.g. study13 appears with target 0.0 and 1.0) -- confirm
# this is intended.
df_results_study_avg = (
    df_results_clip_avg
    .groupby(['study', 'target'], as_index=False)['predicted']
    .mean()
)
display(df_results_study_avg)

compute_metrics(df_results_study_avg.target, df_results_study_avg.predicted)

Unnamed: 0,study,target,predicted
0,study11,1.0,0.491512
1,study12,1.0,0.493167
2,study13,0.0,0.477396
3,study13,1.0,0.4906
4,study14,1.0,0.519486
5,study19,0.0,0.476614
6,study2,0.0,0.501576
7,study21,1.0,0.456932
8,study23,1.0,0.512727
9,study25,0.0,0.504997


{'num_samples': 27,
 'roc_auc': 0.7235294117647059,
 'average_precision': 0.5566394716394717,
 'accuracy': 0.6666666666666666,
 'sensitivity': 0.6,
 'specificity': 0.7058823529411765}