In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm

from datautility import *
from dataset import *
from vnet import *
from training import *

%matplotlib inline
%load_ext autoreload
%autoreload 2

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim

from torch.utils.data import Dataset, DataLoader, sampler, SubsetRandomSampler
from torchvision import transforms, utils

import torch.nn.functional as F  # useful stateless functions
import torchvision.transforms as T

#------------------------------- GLOBAL VARIABLES -------------------------------------#

USE_GPU = True
BATCH_SIZE = 8
NUM_WORKERS = 6
NUM_TRAIN = 80 # 80 training sample and 37 validation sample
LEARNING_RATE = 1e-2

dtype = torch.float32 # we will be using float throughout this tutorial

if USE_GPU and torch.cuda.is_available():
    device = torch.device('cuda')
    print('using GPU for training')
else:
    device = torch.device('cpu')
    print('using CPU for training')

using GPU for training


## If Training new data only
* Positive: 67 samples
* Negative: 50 samples

In [3]:
positive_idx = np.arange(73) + 1
positive_idx = np.delete(positive_idx, [8, 15, 16, 17, 29, 33])

negative_idx = np.arange(53) + 1
negative_idx = np.delete(negative_idx, [14, 15, 16])

image_dict = np.concatenate((positive_idx, negative_idx))

metric = ['ad', 'ak', 'awf', 'eas_De_par', 'eas_De_perp', 'eas_tort', 'FA', 'ias_Da', 'md', 'mk', 'rd', 'rk']
print('{} subjects with {} metrics each'.format(len(image_dict), len(metric)))

117 subjects with 12 metrics each


In [4]:
regen = False

if regen:
    data_index = np.arange(107)
    data_idnex = np.random.shuffle(data_index)
    print(list(data_index))
    
else:
    image_index = np.array ([40, 64, 58, 103, 19, 5, 68, 56, 66, 10, 75, 43, 1, 81, 83, 49, 11, 80, 102,\
                             82, 69, 13, 4, 61, 70, 100, 23, 72, 55, 16, 90, 53, 78, 21, 39, 25, 74, 42, 22,\
                             79, 48, 24, 2, 8, 9, 59, 0, 3, 91, 84, 15, 95, 106, 27, 94, 65, 96, 63, 7, 71,\
                             57, 30, 86, 62, 31, 93, 99, 104, 51, 50, 26, 17, 46, 35, 38, 60, 87, 20, 67, 77,\
                             45, 34, 44, 54, 41, 105, 88, 98, 85, 97, 6, 29, 101, 73, 28, 36, 76, 18, 89, 52,\
                             32, 14, 33, 47, 92, 37, 12]
                            )

dataset_image = MTBIDataset(image_dict, 
                            metric, 
                            transform=transforms.Compose([
                                 downSample(2),
                                 RandomAffine(15, 10)
                            ]),
                     )

#-------------------------CREATE DATA LOADER FOR TRAIN AND VAL------------------------#

data_size = len(dataset_image)

train_loader = DataLoader(dataset_image, batch_size=BATCH_SIZE, \
                    sampler=sampler.SubsetRandomSampler(range(NUM_TRAIN)),\
                    num_workers=NUM_WORKERS)
validation_loader = DataLoader(dataset_image, batch_size=BATCH_SIZE,
                    sampler=sampler.SubsetRandomSampler(range(NUM_TRAIN,data_size)),\
                    num_workers=NUM_WORKERS)

# data_set = MTBIDataset(image_dict, metric, transform=None, mode='new')

In [None]:
for i_batch, sample_batched in enumerate(train_loader):
    print(i_batch, sample_batched['image'].size(), \
          sample_batched['label'].size())
    # observe 4th batch and stop.
    if i_batch == 3:
        # show_batch_image(sample_batched['image'],sample_batched['label'],BATCH_SIZE)
        break

In [5]:
#-------------------------NEW MODEL INIT WEIGHT--------------------------------------#

LoadCKP = False
CKPPath = 'checkpoint2019-03-31 13:33:50.772063.pth'

model = LNet(img_size=(64, 96, 64))
model.apply(weights_init)
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
scheduler = ReduceLROnPlateau(optimizer, mode='min', patience=50, verbose=True)

if LoadCKP:
    model, optimizer, scheduler = loadckp(model, optimizer, scheduler, CKPPath, device=device)

In [None]:
loss = nn.BCELoss()

train(model, train_loader, validation_loader, optimizer, scheduler, device, dtype, lossFun=loss, epochs=500, streopch=0)

Epoch 0 finished ! Training Loss: 0.6226080473926332
     validation loss = 27.6310
Checkpoint 1 saved !
Epoch 1 finished ! Training Loss: 0.6891201039155325
     validation loss = 27.3388
Epoch 2 finished ! Training Loss: 0.49541984995206195
     validation loss = 2.7888
Epoch 3 finished ! Training Loss: 0.5574507349067264
     validation loss = 2.8100
Epoch 4 finished ! Training Loss: 0.5734136700630188
     validation loss = 1.8156
Epoch 5 finished ! Training Loss: 0.6257185108131833
     validation loss = 1.2421
Epoch 6 finished ! Training Loss: 0.5665172735850016
     validation loss = 3.0602
Epoch 7 finished ! Training Loss: 0.5774756454759173
     validation loss = 2.5544
Epoch 8 finished ! Training Loss: 0.5412749151388804
     validation loss = 1.4302
Epoch 9 finished ! Training Loss: 0.6114688714345297
     validation loss = 2.2342
Epoch 10 finished ! Training Loss: 0.5701064434316423
     validation loss = 1.9906
Epoch 11 finished ! Training Loss: 0.5396849281258054
     val