In [1]:
import pandas as pd
import numpy as np
import os
from importlib import reload

import matplotlib.pyplot as plt
import matplotlib
import seaborn as sns

from sklearn.model_selection import train_test_split
import skimage
from skimage import io
from skimage.transform import resize

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import time
import os
import copy

# local imports
import model as _model
import utils as _tools

In [2]:
"""
control the generation of trainning and dev datasets
"""

train, dev, _ = _tools.make_trainning_data(sample=50000, 
                                           return_frames=True, 
                                           state=1729)

print()
print("Trainning label distribution")
print(train['Cardiomegaly'].value_counts(normalize=True, dropna=False))

print()
print("Development label distribution")
print(dev['Cardiomegaly'].value_counts(normalize=True, dropna=False))

sampling 20000 records
train.shape: (15000, 19)
dev.shape: (5000, 19)
valid.shape: (234, 19)
saved: ./train.csv
saved: ./dev.csv
saved: ./valid.csv

Trainning label distribution
 NaN    0.790267
 1.0    0.123733
 0.0    0.049333
-1.0    0.036667
Name: Cardiomegaly, dtype: float64

Development label distribution
 NaN    0.7884
 1.0    0.1248
 0.0    0.0492
-1.0    0.0376
Name: Cardiomegaly, dtype: float64


In [None]:
"""
build and train the model
"""

reload(_model)

# build the models
resnet = _model.TransferModel(use_cpu=False)

# train + evaluate the model
resnet.train()


-------------------------------
Cardiomegaly Model epoch 1/20


  _warn_prf(average, modifier, msg_start, len(result))


Trainning loss: 0.6604 accuracy: 80.77 %
Validation loss: 0.6497 accuracy: 82.14 %

-------------------------------
Cardiomegaly Model epoch 2/20
Trainning loss: 0.6269 accuracy: 81.39 %
Validation loss: 0.6345 accuracy: 77.96 %

-------------------------------
Cardiomegaly Model epoch 3/20
Trainning loss: 0.6037 accuracy: 81.76 %
Validation loss: 0.6199 accuracy: 80.54 %

-------------------------------
Cardiomegaly Model epoch 4/20
Trainning loss: 0.5816 accuracy: 82.04 %
Validation loss: 0.6107 accuracy: 82.42 %

-------------------------------
Cardiomegaly Model epoch 5/20
Trainning loss: 0.5628 accuracy: 82.64 %
Validation loss: 0.6029 accuracy: 83.88 %

-------------------------------
Cardiomegaly Model epoch 6/20
Trainning loss: 0.5462 accuracy: 83.02 %
Validation loss: 0.5889 accuracy: 80.52 %

-------------------------------
Cardiomegaly Model epoch 7/20
Trainning loss: 0.5276 accuracy: 83.54 %
Validation loss: 0.5815 accuracy: 78.60 %

-------------------------------
Cardiome

Process Process-742:
Process Process-753:
Process Process-727:
Process Process-738:
Process Process-752:
Process Process-751:
Process Process-747:
Process Process-731:
Process Process-745:
Process Process-724:
Process Process-741:
Process Process-744:
Process Process-726:
Process Process-754:
Process Process-746:
Process Process-743:
Process Process-728:
Process Process-740:
Process Process-755:
Process Process-721:
Process Process-750:
Process Process-748:
Process Process-723:
Traceback (most recent call last):
  File "/sw/arcts/centos7/python3.8-anaconda/2020.07/lib/python3.8/multiprocessing/process.py", line 318, in _bootstrap
    util._exit_function()
Traceback (most recent call last):
  File "/sw/arcts/centos7/python3.8-anaconda/2020.07/lib/python3.8/multiprocessing/util.py", line 360, in _exit_function
    _run_finalizers()
  File "/sw/arcts/centos7/python3.8-anaconda/2020.07/lib/python3.8/multiprocessing/process.py", line 318, in _bootstrap
    util._exit_function()
  File "/sw/

In [None]:
# Score the best model on the dev set, print the label distributions and
# write the per-record results to disk.

# get results on dev set
results = resnet.evaluate_model(resnet.best_model,
                                resnet.dataloader_dev,
                                resnet.dev_map)
print(results.shape)

# get distributions of true labels
print()
print(results['y_true'].value_counts(normalize=True))

# get distributions of pred labels
print()
print(results['y_pred'].value_counts(normalize=True))

# DataFrame.to_csv does not create missing parent directories, so make sure
# the output folder exists before writing.
outpath = "results/dev_results.csv"
os.makedirs(os.path.dirname(outpath), exist_ok=True)
results.to_csv(outpath, index=False)

In [None]:
# preview the first five rows of the current `results` frame
results.head(n=5)

In [None]:
# Histogram of the `y_prob` column of the current `results` frame.
matplotlib.rcParams['figure.dpi'] = 150

# Explicit figure/axes: the plot no longer depends on whichever pyplot figure
# happens to be current, and the axes get labels so the figure stands alone.
fig, ax = plt.subplots()
results['y_prob'].hist(edgecolor='black', bins=30, ax=ax)
ax.set_title('Distribution of Propensities')
ax.set_xlabel('y_prob')
ax.set_ylabel('Count')
plt.show()

In [None]:
# Plot the recorded training and dev loss histories (one point per entry,
# x axis labelled as epochs) and save the figure as a PNG.
_time = list(range(len(resnet.train_loss_history)))

matplotlib.rcParams['figure.dpi'] = 150

# Explicit figure/axes so the curves never land on a figure left open elsewhere.
fig, ax = plt.subplots()
ax.plot(_time, resnet.train_loss_history, c='black', label="Training")
# The dotted curve is the dev-set history, so label it as such (it was
# previously labelled "Testing").
ax.plot(_time, resnet.dev_loss_history, c='black', ls=":", label="Dev")
ax.set_title("Training Loss")
ax.set_xlabel("Epoch")
ax.set_ylabel("Cross Entropy Loss")
ax.legend()

# savefig does not create missing parent directories.
outpath = "results/training_loss.png"
os.makedirs(os.path.dirname(outpath), exist_ok=True)
fig.savefig(outpath, bbox_inches='tight')

In [None]:
# Plot the recorded training and dev accuracy histories (one point per entry,
# x axis labelled as epochs) and save the figure as a PNG.
_time = list(range(len(resnet.train_acc_history)))

matplotlib.rcParams['figure.dpi'] = 150

# Explicit figure/axes so the curves never land on a figure left open elsewhere.
fig, ax = plt.subplots()
ax.plot(_time, resnet.train_acc_history, c='black', label="Training")
# The dotted curve is the dev-set history, so label it as such (it was
# previously labelled "Testing").
ax.plot(_time, resnet.dev_acc_history, c='black', ls=":", label="Dev")
ax.set_title("Training Accuracy")
ax.set_xlabel("Epoch")
ax.set_ylabel("Accuracy")
ax.legend()

# savefig does not create missing parent directories.
outpath = "results/training_accuracy.png"
os.makedirs(os.path.dirname(outpath), exist_ok=True)
fig.savefig(outpath, bbox_inches='tight')

In [None]:
# Re-import utils.py so edits to the helpers take effect, then compute the
# classification metrics for the current `results` frame and display them.
reload(_tools)

_compute_metrics = _tools.get_classification_metrics
res = _compute_metrics(results)
res

In [None]:
"""
save the model
"""

outpath = f"models/{resnet.condition}_resnet18.pth"
torch.save(resnet.model.state_dict(), outpath)
print(f"saved: {outpath}")

In [None]:
# get results on valid set
results = resnet.evaluate_model(resnet.best_model, 
                                resnet.dataloader_valid, 
                                resnet.valid_map)
print(results.shape)

# get distributions of true labels
print()
print(results['y_true'].value_counts(normalize=True))

# get distributions of pred labels
print()
print(results['y_pred'].value_counts(normalize=True))

outpath = f"results/validation_results.csv"
results.to_csv(outpath, index=False)

print()
res = _tools.get_classification_metrics(results)
res

In [None]:
# Stage every new and modified file under the current directory.
# NOTE(review): this includes anything the notebook wrote to results/ and
# models/ unless those paths are git-ignored — confirm that is intended.
!git add .

In [None]:
# Commit whatever is currently staged, using a fixed commit message.
!git commit -m "model updates"