In [2]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm

from datautility import *
from dataset import *
from vnet import *
from training import *

%matplotlib inline
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [3]:
import torch
import torch.nn as nn
import torch.optim as optim

from torch.utils.data import Dataset, DataLoader, sampler, SubsetRandomSampler
from torchvision import transforms, utils

import torch.nn.functional as F  # useful stateless functions
import torchvision.transforms as T

#------------------------------- GLOBAL VARIABLES -------------------------------------#

USE_GPU = True
BATCH_SIZE = 16
NUM_WORKERS = 6
NUM_TRAIN = 80  # 80 training samples; the remaining 37 of the 117 are used for validation
LEARNING_RATE = 1e-3
SEED = 0  # makes NumPy / torch random draws (weight init, sampler order, reshuffle) repeatable

dtype = torch.float32  # tensor dtype handed to `train` below

# Seed NumPy and PyTorch so a Restart & Run All reproduces the run.
# torch.manual_seed seeds the CPU and all CUDA devices; cuDNN kernels may still
# be non-deterministic on GPU.
np.random.seed(SEED)
torch.manual_seed(SEED)

if USE_GPU and torch.cuda.is_available():
    device = torch.device('cuda')
    print('using GPU for training')
else:
    device = torch.device('cpu')
    print('using CPU for training')

using GPU for training


## Training on the new data only
* Positive: 67 samples (subject IDs 1–73, 6 excluded)
* Negative: 50 samples (subject IDs 1–53, 3 excluded)

In [4]:
# Subject IDs are 1-based within each cohort. np.delete removes array
# *positions*, so the position lists below drop subject IDs 9, 16, 17, 18, 30, 34
# (positive) and 15, 16, 17 (negative). TODO: document why these are excluded.
positive_idx = np.arange(73) + 1
positive_idx = np.delete(positive_idx, [8, 15, 16, 17, 29, 33])

negative_idx = np.arange(53) + 1
negative_idx = np.delete(negative_idx, [14, 15, 16])

# Expected cohort sizes: 67 positive, 50 negative.
assert len(positive_idx) == 67 and len(negative_idx) == 50

# Despite its name, this is an array of subject IDs: positives first, then negatives.
# IDs restart at 1 in each cohort, so an ID alone is ambiguous; presumably
# MTBIDataset tells the cohorts apart by position (first 67 = positive) — confirm.
image_dict = np.concatenate((positive_idx, negative_idx))
metric = ['ak', 'md', 'ad', 'FA']  # 4 metrics -> 4 input channels per subject

# Alternative 8-channel set from the paper:
# ['ak', 'awf', 'eas_De_par', 'eas_De_perp', 'FA', 'ias_Da', 'md', 'mk']
print('{} subjects with {} metrics each'.format(len(image_dict), len(metric)))

117 subjects with 4 metrics each


In [5]:
# Fixed subject permutation that defines the train/validation split.
# Set `regen = True` to draw a new permutation; it is printed so it can be
# pasted into the `else` branch to freeze it.
regen = False

if regen:
    # np.random.shuffle shuffles in place and returns None, so assigning its
    # result would hand None to the dataset; permutation() returns a new array.
    data_idnex = np.random.permutation(len(image_dict))
    print(data_idnex.tolist())

else:
    data_idnex = np.array([10, 22, 1, 86, 74, 109, 54, 59, 75, 4, 26, 5, 79, 40, 39, 112,
                           104, 51, 116, 80, 77, 57, 3, 85, 29, 55, 18, 87, 50, 27, 0, 108,
                           45, 34, 46, 15, 6, 102, 68, 64, 32, 89, 2, 100, 25, 67, 70, 113,
                           81, 56, 8, 110, 35, 47, 115, 23, 88, 63, 17, 99, 49, 30, 92, 78,
                           37, 19, 48, 24, 28, 103, 73, 95, 96, 20, 11, 9, 33, 60, 7, 76,
                           71, 93, 65, 106, 53, 13, 111, 101, 12, 42, 16, 91, 43, 84, 97,
                           72, 41, 82, 83, 66, 94, 31, 90, 98, 114, 107, 44, 58, 52, 61, 69,
                           38, 36, 62, 105, 14, 21])

# Either branch must yield a permutation of 0..N-1.
assert np.array_equal(np.sort(data_idnex), np.arange(len(data_idnex)))

# Dataset over all subjects. Presumably MTBIDataset reorders subjects by
# `data_idnex`, so taking the first NUM_TRAIN positions gives a shuffled
# training split (confirm in dataset.py).
# NOTE(review): train and validation share this object and its transform, so
# enabling an augmentation here (e.g. RandomAffine) would also augment validation.
dataset_image = MTBIDataset(data_idnex,
                            image_dict,
                            metric,
                            transform=transforms.Compose([
                                # RandomAffine(0, 1)
                            ]),
                            )

#-------------------------CREATE DATA LOADER FOR TRAIN AND VAL------------------------#

data_size = len(dataset_image)
# Both splits must be non-empty, or one of the loaders yields no batches.
assert 0 < NUM_TRAIN < data_size, \
    f"NUM_TRAIN={NUM_TRAIN} must leave samples for both splits (dataset has {data_size})"

train_indices = range(NUM_TRAIN)
val_indices = range(NUM_TRAIN, data_size)

train_loader = DataLoader(dataset_image, batch_size=BATCH_SIZE,
                          sampler=sampler.SubsetRandomSampler(train_indices),
                          num_workers=NUM_WORKERS)
validation_loader = DataLoader(dataset_image, batch_size=BATCH_SIZE,
                               sampler=sampler.SubsetRandomSampler(val_indices),
                               num_workers=NUM_WORKERS)

In [6]:
# Sanity check of the validation loader: print each batch's image and label
# tensor shapes, stopping after batch index 8 (at most 9 batches are shown).
for batch_no, batch in enumerate(validation_loader):
    image_shape = batch['image'].size()
    label_shape = batch['label'].size()
    print(batch_no, image_shape, label_shape)
    if batch_no >= 8:
        break

0 torch.Size([16, 4, 64, 96, 64]) torch.Size([16, 1])
1 torch.Size([16, 4, 64, 96, 64]) torch.Size([16, 1])
2 torch.Size([5, 4, 64, 96, 64]) torch.Size([5, 1])


In [7]:
#-------------------------NEW MODEL INIT WEIGHT--------------------------------------#

# Set LoadCKP = True to restore model/optimizer/scheduler state from CKPPath.
LoadCKP = False
CKPPath = 'checkpoint2019-03-31 13:33:50.772063.pth'

# img_size matches the (64, 96, 64) volumes produced by the data loader.
model = LNet1(img_size=(64, 96, 64))
model.apply(weights_init)  # applied recursively to every submodule
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
# Qualified through the imported `optim` package so this cell does not rely on
# ReduceLROnPlateau leaking in via `from training import *`.
# The LR is reduced (default factor 0.1) after 50 scheduler steps without
# improvement of the stepped metric (presumably one step per epoch in `train`).
# NOTE(review): `verbose` is deprecated in recent PyTorch releases.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=50, verbose=True)

if LoadCKP:
    model, optimizer, scheduler = loadckp(model, optimizer, scheduler, CKPPath, device=device)

In [None]:
# Binary cross-entropy. nn.BCELoss expects probabilities in [0, 1], so this
# presumes LNet1 ends in a sigmoid (TODO confirm; otherwise use BCEWithLogitsLoss).
loss = nn.BCELoss()

train(model,
      train_loader, validation_loader,
      optimizer, scheduler,
      device, dtype,
      lossFun=loss,
      epochs=5000,
      streopch=0)  # presumably the starting epoch; keyword spelled as `train` expects it

Epoch 0 finished ! Training Loss: 1.1974, acc: 0.4375
     validation loss = 5.7756, accuracy = 0.6042
Checkpoint 1 saved !
Epoch 1 finished ! Training Loss: 0.7502, acc: 0.5375
     validation loss = 2.3286, accuracy = 0.6458
Epoch 2 finished ! Training Loss: 0.6885, acc: 0.6000
     validation loss = 1.0935, accuracy = 0.4042
Epoch 3 finished ! Training Loss: 0.6827, acc: 0.5750
     validation loss = 1.0842, accuracy = 0.6000
Epoch 4 finished ! Training Loss: 0.7058, acc: 0.5250
     validation loss = 0.7943, accuracy = 0.6208
Epoch 5 finished ! Training Loss: 0.6936, acc: 0.5875
     validation loss = 0.8821, accuracy = 0.4417
Epoch 6 finished ! Training Loss: 0.6611, acc: 0.6500
     validation loss = 0.8460, accuracy = 0.4875
Epoch 7 finished ! Training Loss: 0.6667, acc: 0.6500
     validation loss = 0.7274, accuracy = 0.5333
Epoch 8 finished ! Training Loss: 0.7012, acc: 0.5625
