In [2]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm

from datautility import *
from dataset import *
from vnet import *
from training import *

%matplotlib inline
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [3]:
import torch
import torch.nn as nn
import torch.optim as optim

from torch.utils.data import Dataset, DataLoader, sampler, SubsetRandomSampler
from torchvision import transforms, utils

import torch.nn.functional as F  # useful stateless functions
import torchvision.transforms as T

#------------------------------- GLOBAL VARIABLES -------------------------------------#

USE_GPU = True
BATCH_SIZE = 16
NUM_WORKERS = 6
NUM_TRAIN = 80  # dataset positions [0, 80) train; the remaining 37 of 117 validate
LEARNING_RATE = 1e-3
SEED = 0  # seed for the global NumPy / PyTorch RNGs (change to get a different run)

dtype = torch.float32  # tensor dtype handed to train()

# Seed the global RNGs so that sampler order, optional split regeneration and
# (presumably) weight initialisation repeat across Restart & Run All.
# torch.manual_seed also seeds every CUDA device. cuDNN kernels may still be
# nondeterministic, so GPU runs can differ slightly.
np.random.seed(SEED)
torch.manual_seed(SEED)

if USE_GPU and torch.cuda.is_available():
    device = torch.device('cuda')
    print('using GPU for training')
else:
    device = torch.device('cpu')
    print('using CPU for training')

using GPU for training


## Training on the new data only
* Positive: 67 samples
* Negative: 50 samples
* Total: 117 subjects (80 for training, 37 for validation)

In [4]:
# Subject IDs are 1-based. `np.delete` removes entries by *position*, not by
# value. The lists below therefore drop positive IDs 9, 16, 17, 18, 30, 34 and
# negative IDs 15, 16, 17.
# NOTE(review): confirm these are the intended exclusions (positions vs. IDs).
positive_idx = np.arange(73) + 1
positive_idx = np.delete(positive_idx, [8, 15, 16, 17, 29, 33])

negative_idx = np.arange(53) + 1
negative_idx = np.delete(negative_idx, [14, 15, 16])

# Guard the counts stated in the markdown: 67 positive + 50 negative subjects.
assert len(positive_idx) == 67 and len(negative_idx) == 50

# Despite its name, `image_dict` is a 1-D array of subject IDs: all positives
# first, then all negatives. Numeric IDs overlap between the two groups, so the
# group is presumably inferred from position inside MTBIDataset. TODO: verify.
image_dict = np.concatenate((positive_idx, negative_idx))

# 4 metrics -> 4 input channels per subject.
metric = ['ak', 'md', 'ad', 'FA']

# Full 8-channel set selected from the paper:
# metric = ['ak', 'awf', 'eas_De_par', 'eas_De_perp', 'FA', 'ias_Da', 'md', 'mk']
print('{} subjects with {} metrics each'.format(len(image_dict), len(metric)))

117 subjects with 4 metrics each


In [5]:
# Subject ordering for the train/validation split.
# - regen = True draws and prints a fresh random permutation.
# - regen = False reuses the frozen permutation below, so the split is
#   identical across runs.
regen = False

if regen:
    data_index = np.arange(len(image_dict))
    # np.random.shuffle permutes in place and returns None. Its return value
    # must not be assigned, or the index array becomes None.
    np.random.shuffle(data_index)
    print(list(data_index))
else:
    data_index = np.array([10, 22, 1, 86, 74, 109, 54, 59, 75, 4, 26, 5, 79, 40, 39, 112,
                           104, 51, 116, 80, 77, 57, 3, 85, 29, 55, 18, 87, 50, 27, 0, 108,
                           45, 34, 46, 15, 6, 102, 68, 64, 32, 89, 2, 100, 25, 67, 70, 113,
                           81, 56, 8, 110, 35, 47, 115, 23, 88, 63, 17, 99, 49, 30, 92, 78,
                           37, 19, 48, 24, 28, 103, 73, 95, 96, 20, 11, 9, 33, 60, 7, 76,
                           71, 93, 65, 106, 53, 13, 111, 101, 12, 42, 16, 91, 43, 84, 97,
                           72, 41, 82, 83, 66, 94, 31, 90, 98, 114, 107, 44, 58, 52, 61, 69,
                           38, 36, 62, 105, 14, 21])

# The ordering must cover every subject position in `image_dict` exactly once.
assert np.array_equal(np.sort(data_index), np.arange(len(image_dict))), \
    'data_index must be a permutation of range(len(image_dict))'

# Historical (misspelled) name, kept in case later cells still refer to it.
data_idnex = data_index

dataset_image = MTBIDataset(data_idnex,
                            image_dict,
                            metric,
                            transform=transforms.Compose([
                                # RandomAffine(0, 1)
                            ]),
                            )

#-------------------------CREATE DATA LOADER FOR TRAIN AND VAL------------------------#

data_size = len(dataset_image)
assert 0 < NUM_TRAIN < data_size, 'NUM_TRAIN must leave at least one validation sample'

# Dataset positions [0, NUM_TRAIN) train and [NUM_TRAIN, data_size) validate.
# SubsetRandomSampler draws a new order within each subset every time the
# loader is iterated.
train_loader = DataLoader(dataset_image, batch_size=BATCH_SIZE,
                          sampler=sampler.SubsetRandomSampler(range(NUM_TRAIN)),
                          num_workers=NUM_WORKERS)
validation_loader = DataLoader(dataset_image, batch_size=BATCH_SIZE,
                               sampler=sampler.SubsetRandomSampler(range(NUM_TRAIN, data_size)),
                               num_workers=NUM_WORKERS)

In [6]:
# Smoke-test the validation loader: print the image and label tensor shapes of
# each batch, stopping after N_PREVIEW_BATCHES batches.
N_PREVIEW_BATCHES = 4  # observe up to the 4th batch, then stop

for i_batch, sample_batched in enumerate(validation_loader):
    print(i_batch, sample_batched['image'].size(), sample_batched['label'].size())
    # i_batch is 0-based, so the 4th batch has index 3. With 37 validation
    # samples and BATCH_SIZE = 16 there are only 3 batches, so the loop usually
    # ends on its own.
    if i_batch + 1 >= N_PREVIEW_BATCHES:
        # show_batch_image(sample_batched['image'], sample_batched['label'], BATCH_SIZE)
        break

0 torch.Size([16, 4, 64, 96, 64]) torch.Size([16, 1])
1 torch.Size([16, 4, 64, 96, 64]) torch.Size([16, 1])
2 torch.Size([5, 4, 64, 96, 64]) torch.Size([5, 1])


In [7]:
#-------------------------NEW MODEL INIT WEIGHT--------------------------------------#

# Set LoadCKP = True to restore model/optimizer/scheduler state from CKPPath
# instead of starting from freshly initialised weights.
# NOTE(review): the filename contains ':' and spaces, which are invalid or
# awkward on some filesystems (e.g. Windows). Consider renaming saved checkpoints.
LoadCKP = False
CKPPath = 'checkpoint2019-03-31 13:33:50.772063.pth'

# img_size matches the (64, 96, 64) volumes produced by the data loaders.
model = LNet1(img_size=(64, 96, 64))
model.apply(weights_init)
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
# Referenced through `optim.lr_scheduler` explicitly, instead of relying on the
# name leaking in via `from training import *`. The scheduler lowers the LR
# once the value passed to step() has not decreased for 50 consecutive calls;
# which metric train() monitors is not visible here.
# Note that `verbose` is deprecated in recent PyTorch releases.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=50, verbose=True)

if LoadCKP:
    model, optimizer, scheduler = loadckp(model, optimizer, scheduler, CKPPath, device=device)

In [8]:
# Binary cross-entropy criterion, passed to train() as `lossFun`.
# NOTE(review): nn.BCELoss expects predictions already in [0, 1], so this
# assumes LNet1 ends in a sigmoid. Confirm in vnet.py.
# NOTE(review): the logged validation accuracies (e.g. 0.6042) are not
# multiples of 1/37. They match an unweighted mean of per-batch accuracies over
# the 16/16/5-sample batches, which over-weights the last 5 samples.
# TODO: fix the averaging in train().
# NOTE(review): in the log, training accuracy reaches 1.0 within ~80 epochs
# while validation loss keeps rising, i.e. heavy overfitting. 5000 epochs
# without early stopping is mostly wasted compute.
loss = nn.BCELoss()

train(
    model,
    train_loader,
    validation_loader,
    optimizer,
    scheduler,
    device,
    dtype,
    lossFun=loss,
    epochs=5000,
    streopch=0,  # keyword spelling must match train()'s signature
)

Epoch 0 finished ! Training Loss: 1.1974, acc: 0.4375
     validation loss = 5.7756, accuracy = 0.6042
Checkpoint 1 saved !
Epoch 1 finished ! Training Loss: 0.7502, acc: 0.5375
     validation loss = 2.3286, accuracy = 0.6458
Epoch 2 finished ! Training Loss: 0.6885, acc: 0.6000
     validation loss = 1.0935, accuracy = 0.4042
Epoch 3 finished ! Training Loss: 0.6827, acc: 0.5750
     validation loss = 1.0842, accuracy = 0.6000
Epoch 4 finished ! Training Loss: 0.7058, acc: 0.5250
     validation loss = 0.7943, accuracy = 0.6208
Epoch 5 finished ! Training Loss: 0.6936, acc: 0.5875
     validation loss = 0.8821, accuracy = 0.4417
Epoch 6 finished ! Training Loss: 0.6611, acc: 0.6500
     validation loss = 0.8460, accuracy = 0.4875
Epoch 7 finished ! Training Loss: 0.6667, acc: 0.6500
     validation loss = 0.7274, accuracy = 0.5333
Epoch 8 finished ! Training Loss: 0.7012, acc: 0.5625
     validation loss = 0.7159, accuracy = 0.5333
Epoch 9 finished ! Training Loss: 0.6912, acc: 0.612

     validation loss = 0.7759, accuracy = 0.5750
Epoch 79 finished ! Training Loss: 0.1807, acc: 0.9875
     validation loss = 0.8323, accuracy = 0.5708
Epoch 80 finished ! Training Loss: 0.1621, acc: 0.9875
     validation loss = 0.9162, accuracy = 0.5083
Epoch 81 finished ! Training Loss: 0.1271, acc: 0.9875
     validation loss = 0.7591, accuracy = 0.5542
Epoch 82 finished ! Training Loss: 0.0983, acc: 1.0000
     validation loss = 0.8315, accuracy = 0.6375
Epoch 83 finished ! Training Loss: 0.0893, acc: 1.0000
     validation loss = 0.7502, accuracy = 0.5958
Epoch 84 finished ! Training Loss: 0.0628, acc: 1.0000
     validation loss = 0.9988, accuracy = 0.4417
Epoch 85 finished ! Training Loss: 0.0631, acc: 1.0000
     validation loss = 0.9670, accuracy = 0.4417
Epoch 86 finished ! Training Loss: 0.0513, acc: 1.0000
     validation loss = 0.8179, accuracy = 0.5750
Epoch 87 finished ! Training Loss: 0.0375, acc: 1.0000
     validation loss = 0.8003, accuracy = 0.5750
Epoch 88 finish

     validation loss = 1.4223, accuracy = 0.5292
Epoch 157 finished ! Training Loss: 0.0108, acc: 1.0000
     validation loss = 1.0492, accuracy = 0.5542
Epoch 158 finished ! Training Loss: 0.0116, acc: 1.0000
     validation loss = 0.9844, accuracy = 0.6000
Epoch 159 finished ! Training Loss: 0.0061, acc: 1.0000
     validation loss = 0.9890, accuracy = 0.6000
Epoch 160 finished ! Training Loss: 0.0043, acc: 1.0000
     validation loss = 0.9275, accuracy = 0.6667
Epoch 161 finished ! Training Loss: 0.0084, acc: 1.0000
     validation loss = 1.1495, accuracy = 0.5542
Epoch 162 finished ! Training Loss: 0.0053, acc: 1.0000
     validation loss = 1.2487, accuracy = 0.4417
Epoch 163 finished ! Training Loss: 0.0040, acc: 1.0000
     validation loss = 0.9771, accuracy = 0.5792
Epoch 164 finished ! Training Loss: 0.0070, acc: 1.0000
     validation loss = 1.2104, accuracy = 0.5083
Epoch 165 finished ! Training Loss: 0.0078, acc: 1.0000
     validation loss = 1.1117, accuracy = 0.6000
Epoch 

     validation loss = 1.3025, accuracy = 0.5083
Epoch 235 finished ! Training Loss: 0.0021, acc: 1.0000
     validation loss = 1.2345, accuracy = 0.5083
Epoch 236 finished ! Training Loss: 0.0027, acc: 1.0000
     validation loss = 1.4760, accuracy = 0.5083
Epoch 237 finished ! Training Loss: 0.0020, acc: 1.0000
     validation loss = 1.2202, accuracy = 0.5542
Epoch 238 finished ! Training Loss: 0.0052, acc: 1.0000
     validation loss = 1.1447, accuracy = 0.5542
Epoch 239 finished ! Training Loss: 0.0036, acc: 1.0000
     validation loss = 1.2638, accuracy = 0.5542
Epoch 240 finished ! Training Loss: 0.0021, acc: 1.0000
     validation loss = 1.3317, accuracy = 0.5083
Epoch 241 finished ! Training Loss: 0.0020, acc: 1.0000
     validation loss = 1.5479, accuracy = 0.5083
Epoch 242 finished ! Training Loss: 0.0035, acc: 1.0000
     validation loss = 1.1989, accuracy = 0.5542
Epoch 243 finished ! Training Loss: 0.0031, acc: 1.0000
     validation loss = 1.5148, accuracy = 0.5292
Epoch 

     validation loss = 1.7204, accuracy = 0.5042
Epoch 313 finished ! Training Loss: 0.0007, acc: 1.0000
     validation loss = 1.3877, accuracy = 0.6417
Epoch 314 finished ! Training Loss: 0.0017, acc: 1.0000
     validation loss = 1.1467, accuracy = 0.6417
Epoch 315 finished ! Training Loss: 0.0007, acc: 1.0000
     validation loss = 1.6985, accuracy = 0.5708
Epoch 316 finished ! Training Loss: 0.0012, acc: 1.0000
     validation loss = 1.4565, accuracy = 0.6167
Epoch 317 finished ! Training Loss: 0.0011, acc: 1.0000
     validation loss = 1.3695, accuracy = 0.6625
Epoch 318 finished ! Training Loss: 0.0005, acc: 1.0000
     validation loss = 1.6774, accuracy = 0.6167
Epoch 319 finished ! Training Loss: 0.0013, acc: 1.0000
     validation loss = 1.3364, accuracy = 0.6167
Epoch 320 finished ! Training Loss: 0.0005, acc: 1.0000
     validation loss = 1.3623, accuracy = 0.6167
Epoch 321 finished ! Training Loss: 0.0010, acc: 1.0000
     validation loss = 1.6807, accuracy = 0.5708
Epoch 

     validation loss = 1.5677, accuracy = 0.5042
Epoch 391 finished ! Training Loss: 0.0005, acc: 1.0000
     validation loss = 1.3773, accuracy = 0.6417
Epoch 392 finished ! Training Loss: 0.0010, acc: 1.0000
     validation loss = 1.2464, accuracy = 0.6417
Epoch 393 finished ! Training Loss: 0.0008, acc: 1.0000
     validation loss = 1.3794, accuracy = 0.5750
Epoch 394 finished ! Training Loss: 0.0004, acc: 1.0000
     validation loss = 1.4050, accuracy = 0.5292
Epoch 395 finished ! Training Loss: 0.0011, acc: 1.0000
     validation loss = 1.6046, accuracy = 0.5292
Epoch 396 finished ! Training Loss: 0.0004, acc: 1.0000
     validation loss = 1.2086, accuracy = 0.6208
Epoch 397 finished ! Training Loss: 0.0007, acc: 1.0000
     validation loss = 1.3429, accuracy = 0.5750
Epoch 398 finished ! Training Loss: 0.0010, acc: 1.0000
     validation loss = 1.4741, accuracy = 0.5500
Epoch 399 finished ! Training Loss: 0.0004, acc: 1.0000
     validation loss = 1.5053, accuracy = 0.5958
Epoch 

     validation loss = 1.6419, accuracy = 0.5542
Epoch 469 finished ! Training Loss: 0.0004, acc: 1.0000
     validation loss = 1.5099, accuracy = 0.6208
Epoch 470 finished ! Training Loss: 0.0007, acc: 1.0000
     validation loss = 1.7237, accuracy = 0.6417
Epoch 471 finished ! Training Loss: 0.0019, acc: 1.0000
     validation loss = 1.9597, accuracy = 0.5917
Epoch 472 finished ! Training Loss: 0.0004, acc: 1.0000
     validation loss = 1.9321, accuracy = 0.5500
Epoch 473 finished ! Training Loss: 0.0005, acc: 1.0000
     validation loss = 1.7512, accuracy = 0.5500
Epoch 474 finished ! Training Loss: 0.0009, acc: 1.0000
     validation loss = 1.3808, accuracy = 0.5750
Epoch 475 finished ! Training Loss: 0.0002, acc: 1.0000
     validation loss = 1.6055, accuracy = 0.6208
Epoch 476 finished ! Training Loss: 0.0010, acc: 1.0000
     validation loss = 1.3972, accuracy = 0.5750
Epoch 477 finished ! Training Loss: 0.0002, acc: 1.0000
     validation loss = 1.5075, accuracy = 0.5750
Epoch 

     validation loss = 1.4249, accuracy = 0.5542
Epoch 547 finished ! Training Loss: 0.0227, acc: 0.9750
     validation loss = 1.8609, accuracy = 0.4875
Epoch 548 finished ! Training Loss: 0.0228, acc: 1.0000
     validation loss = 1.5252, accuracy = 0.5333
Epoch 549 finished ! Training Loss: 0.0034, acc: 1.0000
     validation loss = 1.1783, accuracy = 0.5792
Epoch 550 finished ! Training Loss: 0.0033, acc: 1.0000
     validation loss = 1.3733, accuracy = 0.5333
Checkpoint 551 saved !
Epoch 551 finished ! Training Loss: 0.0020, acc: 1.0000
     validation loss = 1.3279, accuracy = 0.4250
Epoch 552 finished ! Training Loss: 0.0035, acc: 1.0000
     validation loss = 1.7088, accuracy = 0.4000
Epoch 553 finished ! Training Loss: 0.0026, acc: 1.0000
     validation loss = 1.4274, accuracy = 0.5583
Epoch 554 finished ! Training Loss: 0.0020, acc: 1.0000
     validation loss = 1.8912, accuracy = 0.4667
Epoch 555 finished ! Training Loss: 0.0014, acc: 1.0000
     validation loss = 1.6230, a

     validation loss = 2.1568, accuracy = 0.6167
Epoch 625 finished ! Training Loss: 0.0030, acc: 1.0000
     validation loss = 3.0746, accuracy = 0.5458
Epoch 626 finished ! Training Loss: 0.0019, acc: 1.0000
     validation loss = 1.7468, accuracy = 0.7292
Epoch 627 finished ! Training Loss: 0.0040, acc: 1.0000
     validation loss = 1.6826, accuracy = 0.7292
Epoch 628 finished ! Training Loss: 0.0010, acc: 1.0000
     validation loss = 1.8524, accuracy = 0.6833
Epoch 629 finished ! Training Loss: 0.0022, acc: 1.0000
     validation loss = 1.7269, accuracy = 0.6833
Epoch 630 finished ! Training Loss: 0.0010, acc: 1.0000
     validation loss = 2.0850, accuracy = 0.6583
Epoch 631 finished ! Training Loss: 0.0013, acc: 1.0000
     validation loss = 1.4932, accuracy = 0.7042
Epoch 632 finished ! Training Loss: 0.0004, acc: 1.0000
     validation loss = 1.8149, accuracy = 0.6583
Epoch 633 finished ! Training Loss: 0.0003, acc: 1.0000
     validation loss = 2.0470, accuracy = 0.6583
Epoch 

     validation loss = 1.7817, accuracy = 0.6583
Epoch 703 finished ! Training Loss: 0.0005, acc: 1.0000
     validation loss = 1.4074, accuracy = 0.7500
Epoch 704 finished ! Training Loss: 0.0002, acc: 1.0000
     validation loss = 1.9387, accuracy = 0.6792
Epoch 705 finished ! Training Loss: 0.0001, acc: 1.0000
     validation loss = 2.6947, accuracy = 0.5667
Epoch 706 finished ! Training Loss: 0.0003, acc: 1.0000
     validation loss = 2.1497, accuracy = 0.6583
Epoch 707 finished ! Training Loss: 0.0003, acc: 1.0000
     validation loss = 1.7024, accuracy = 0.7250
Epoch 708 finished ! Training Loss: 0.0001, acc: 1.0000
     validation loss = 1.3995, accuracy = 0.7708
Epoch 709 finished ! Training Loss: 0.0004, acc: 1.0000
     validation loss = 1.7190, accuracy = 0.7042
Epoch 710 finished ! Training Loss: 0.0005, acc: 1.0000
     validation loss = 1.6932, accuracy = 0.6583
Epoch 711 finished ! Training Loss: 0.0003, acc: 1.0000
     validation loss = 1.5120, accuracy = 0.7250
Epoch 

KeyboardInterrupt: 