In [1]:
import sys
# Use this block to download extra dependencies


In [2]:
import torchvision
import torch
import pydicom as dicom

import numpy as np
import pandas as pd
from functools import partial
import rnsa

from sklearn.model_selection import train_test_split
from sklearn import metrics

import matplotlib.pyplot as plt
import seaborn as sns
%matplotlib inline

In [3]:
# %load_ext autoreload
# %autoreload 1

In [4]:
# Compute device: use the GPU when CUDA is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Input data: image directory and the zipped CSV of labels.
train_img_dir = '/home/tstrebel/assets/rnsa-pneumonia/train-images'
annotations_file_path = '/home/tstrebel/assets/rnsa-pneumonia/stage_2_train_labels.csv.zip'

# Output: path handed to the training / checkpoint helpers below.
model_save_path = '/home/tstrebel/models/rnsa-densenet.pt'

In [5]:
# One row per patient: keep the first annotation row of every patientId.
label_df = (
    pd.read_csv(annotations_file_path)
    .groupby('patientId')
    .first()
    .reset_index()
)

# Stratified split on Target: 80% train, then the remaining 20% is divided
# 60/40 into validation and test. Both splits use the same fixed seed.
train_part, holdout_part = train_test_split(
    label_df, test_size=.2, stratify=label_df.Target, random_state=99
)
val_part, test_part = train_test_split(
    holdout_part, test_size=.4, stratify=holdout_part.Target, random_state=99
)

# Only the row indices are kept; the intermediate frames are dropped.
train_ix = train_part.index.tolist()
val_ix = val_part.index.tolist()
test_ix = test_part.index.tolist()
del train_part, holdout_part, val_part, test_part

print('train {:,} - validate {:,} - test {:,}'.format(*map(len, (train_ix, val_ix, test_ix))))

train 21,347 - validate 3,202 - test 2,135


In [6]:
# One-element normalisation statistics passed to Normalize in every image pipeline.
mean = [0.5]
std = [0.225]


def _resize_crop_normalize():
    """Return fresh instances of the deterministic steps shared by all pipelines.

    Steps, in order: resize to 512, centre-crop to 448, convert to tensor,
    normalise with the module-level ``mean`` / ``std``.
    """
    return [
        torchvision.transforms.Resize(512),
        torchvision.transforms.CenterCrop(448),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean, std),
    ]


# Training pipeline: random horizontal flip first, then the shared steps.
# (A RandomRotation((-5, 5)) step is currently disabled.)
train_transform = torchvision.transforms.Compose(
    [torchvision.transforms.RandomHorizontalFlip(), *_resize_crop_normalize()]
)

# Validation pipeline: the shared deterministic steps only.
val_transform = torchvision.transforms.Compose(_resize_crop_normalize())

# Labels: convert to a float tensor, then add a leading dimension.
label_transform = torchvision.transforms.Compose(
    [
        partial(torch.tensor, dtype=torch.float),
        partial(torch.unsqueeze, dim=0),
    ]
)

# Both datasets read from the same image directory and annotation file;
# they differ only in the row indices and the image transform.
_make_dataset = partial(rnsa.RNSADataset, train_img_dir, annotations_file_path)
train_dataset = _make_dataset(train_ix, train_transform, label_transform)
val_dataset = _make_dataset(val_ix, val_transform, label_transform)

In [7]:
# Wrap torchvision's DenseNet-121 (default pretrained weights) in the project
# model class and move it to the selected device.
# To resume from a previously saved model instead: model = torch.load(model_save_path)
model = rnsa.Densenet121(torchvision.models.densenet121(weights='DEFAULT')).to(device)

# Freeze the feature extractor so that only the classifier is trained here.
model.features.requires_grad_(False)

criterion = torch.nn.BCEWithLogitsLoss()

# The optimizer only receives the classifier's parameters.
optimizer = torch.optim.SGD(
    model.classifier.parameters(),
    lr=1e-3,
    momentum=.9,
    weight_decay=1e-4,
)
lr_scheduler = rnsa.LRScheduler(optimizer)

In [None]:
# First training run: uses the optimizer built above, which only holds the
# classifier parameters (the feature extractor was frozen in the setup cell).
rnsa.train_model(
    model, model_save_path, train_dataset, val_dataset, optimizer, criterion,
    batch_size=32, num_epochs=10,
)

epoch 1/10
----------
train Loss: 	0.5230 Acc: 0.7394 LR: 0.001000 Time elapsed: 5m 50s
validation Loss: 0.4631 Acc: 0.7833 LR: 0.001000 Time elapsed: 6m 46s
epoch 2/10
----------
train Loss: 	0.4864 Acc: 0.7672 LR: 0.001000 Time elapsed: 12m 34s
validation Loss: 0.4523 Acc: 0.7811 LR: 0.001000 Time elapsed: 13m 23s
epoch 3/10
----------
train Loss: 	0.4758 Acc: 0.7728 LR: 0.001000 Time elapsed: 19m 13s
validation Loss: 0.4518 Acc: 0.7848 LR: 0.001000 Time elapsed: 20m 2s
epoch 4/10
----------
train Loss: 	0.4690 Acc: 0.7756 LR: 0.001000 Time elapsed: 25m 53s
validation Loss: 0.4692 Acc: 0.7814 LR: 0.001000 Time elapsed: 26m 44s
epoch 5/10
----------
train Loss: 	0.4646 Acc: 0.7802 LR: 0.001000 Time elapsed: 32m 31s
validation Loss: 0.4381 Acc: 0.7936 LR: 0.001000 Time elapsed: 33m 20s
epoch 6/10
----------
train Loss: 	0.4636 Acc: 0.7813 LR: 0.001000 Time elapsed: 39m 14s
validation Loss: 0.4597 Acc: 0.7720 LR: 0.001000 Time elapsed: 40m 3s
epoch 7/10
----------


In [None]:
# Reload the saved checkpoint together with the loss / accuracy stored with it.
model, best_loss, best_acc = rnsa.load_checkpoint(model_save_path, device)

# Make every feature-extractor parameter trainable again.
model.features.requires_grad_(True)

# New optimizer over ALL model parameters (not just the classifier).
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=1e-3,
    momentum=.9,
    weight_decay=1e-4,
)
lr_scheduler = rnsa.LRScheduler(optimizer)

In [None]:
# Second training run on the reloaded model with all parameters trainable.
# The best loss / accuracy returned by load_checkpoint are passed in as the
# starting values.
# Fix: the original was missing the comma after `init_best_loss=best_loss`,
# which made this cell a SyntaxError.
rnsa.train_model(model,
                 model_save_path,
                 train_dataset,
                 val_dataset,
                 optimizer,
                 criterion,
                 batch_size=32,
                 num_epochs=10,
                 init_best_loss=best_loss,
                 init_best_acc=best_acc)

In [None]:
# Run the model over the held-out test indices and collect raw outputs.
#
# Fix: the test dataset must use the deterministic `val_transform`. The
# original passed `train_transform`, whose RandomHorizontalFlip randomly
# augments test images and makes the reported metrics vary between runs.
test_dataset = rnsa.RNSADataset(train_img_dir, annotations_file_path, test_ix, val_transform, label_transform)
test_data_loader = rnsa.get_data_loader(test_dataset, batch_size=32)

# NOTE(review): `model` is whatever the preceding training cell left in
# memory; if train_model does not restore the best weights, reload them with
# rnsa.load_checkpoint(model_save_path, device) first -- confirm.

# Collect per-batch tensors and stack once at the end; growing a tensor with
# vstack on every iteration re-copies all previous rows each time.
target_batches, output_batches = [], []

model.eval()
with torch.no_grad():
    for inputs, targets in test_data_loader:
        inputs = inputs.to(device)
        targets = targets.to(device)

        target_batches.append(targets)
        output_batches.append(model(inputs))

if not output_batches:
    raise RuntimeError('test_data_loader yielded no batches; nothing to evaluate')

# Row-stacked targets and raw model outputs (sigmoid is applied in the metrics cell).
running_targets = torch.vstack(target_batches)
running_outputs = torch.vstack(output_batches)

In [None]:
# Threshold selection and test-set metrics.
# Flatten to 1-D vectors so sklearn receives plain label / score arrays.
y_true = running_targets.cpu().numpy().ravel()
y_proba = torch.nn.Sigmoid()(running_outputs).cpu().numpy().ravel()

curve_precision, curve_recall, thresholds = metrics.precision_recall_curve(y_true, y_proba)

# precision_recall_curve returns one more precision/recall value than
# thresholds (a final precision=1, recall=0 point with no threshold). Drop it
# so that every F1 candidate maps to a real threshold and `thresholds[ix]`
# cannot go out of range.
curve_precision = curve_precision[:-1]
curve_recall = curve_recall[:-1]

# F1 is 0/0 where precision + recall == 0. Define it as 0 there instead of
# NaN; np.argmax would otherwise select the NaN entry as the "best" score.
pr_sum = curve_precision + curve_recall
f1_scores = np.divide(
    2 * curve_precision * curve_recall,
    pr_sum,
    out=np.zeros_like(pr_sum),
    where=pr_sum > 0,
)
ix = int(np.argmax(f1_scores))
best_thresh = thresholds[ix]

# Binarise the probabilities at the F1-maximising threshold.
y_pred = y_proba >= best_thresh

precision = metrics.precision_score(y_true, y_pred)
recall = metrics.recall_score(y_true, y_pred)
f1 = metrics.f1_score(y_true, y_pred)

print('best threshold\t{:.4f}'.format(best_thresh))
print('precision:\t{:.4f}'.format(precision))
print('recall:\t\t{:.4f}'.format(recall))
print('f1:\t\t{:.4f}'.format(f1))
print()

# Confusion matrix: rows are true labels, columns are predicted labels.
fig = plt.figure(dpi=75)
ax = plt.gca()
sns.heatmap(metrics.confusion_matrix(y_true, y_pred), annot=True, fmt=',', cmap='Blues', ax=ax)
ax.set(xlabel='predicted', ylabel='actual')
fig.tight_layout()
plt.show()