In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm

from datautility import *
from dataset import *
from vnet import *
from training import *

%matplotlib inline
%load_ext autoreload
%autoreload 2

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim

from torch.utils.data import Dataset, DataLoader, sampler, SubsetRandomSampler
from torchvision import transforms, utils

import torch.nn.functional as F  # useful stateless functions
import torchvision.transforms as T

#------------------------------- GLOBAL VARIABLES -------------------------------------#

USE_GPU = True
BATCH_SIZE = 4
NUM_WORKERS = 6
NUM_TRAIN = 80  # first 80 samples are used for training, the remaining 37 for validation
LEARNING_RATE = 1e-2
SEED = 42  # seeds numpy + torch so repeated runs start from the same random state

dtype = torch.float32  # tensor dtype handed to train()

# Seed every RNG the notebook relies on (np.random for index regeneration, torch for
# weight init / sampler order). torch.manual_seed also seeds all CUDA devices.
# NOTE(review): this improves repeatability but does not make GPU training fully
# deterministic (cuDNN kernels, DataLoader workers) — tighten further if required.
np.random.seed(SEED)
torch.manual_seed(SEED)

if USE_GPU and torch.cuda.is_available():
    device = torch.device('cuda')
    print('using GPU for training')
else:
    device = torch.device('cpu')
    print('using CPU for training')

using GPU for training


## Training on the new data only
* Positive: 67 samples
* Negative: 50 samples
* Total: 117 subjects (80 used for training, 37 for validation)

In [3]:
# Subject IDs are 1-based: each cohort starts from its full ID range and then drops
# a few excluded subjects.
# NOTE(review): np.delete removes by *position*, not by value, so these calls drop
# subject IDs 9, 16, 17, 18, 30, 34 (positive) and 15, 16, 17 (negative).
# Confirm these are the intended exclusions.
positive_idx = np.arange(73) + 1
positive_idx = np.delete(positive_idx, [8, 15, 16, 17, 29, 33])

negative_idx = np.arange(53) + 1
negative_idx = np.delete(negative_idx, [14, 15, 16])

# Guard the cohort sizes documented in the markdown cell (67 positive, 50 negative).
assert len(positive_idx) == 67, len(positive_idx)
assert len(negative_idx) == 50, len(negative_idx)

# Despite its name, `image_dict` is a flat array of subject IDs: all positive
# subjects first, followed by all negative subjects.
image_dict = np.concatenate((positive_idx, negative_idx))

metric = ['ak', 'awf', 'eas_De_par', 'eas_De_perp', 'FA', 'ias_Da', 'md', 'mk', 'rk'] # 9 Channels selected from Paper
print('{} subjects with {} metrics each'.format(len(image_dict), len(metric)))

117 subjects with 9 metrics each


In [4]:
# Fixed ordering of the subjects used to build the dataset. Set `regen = True` to
# draw a new random permutation, then paste the printed list into the `else` branch
# to freeze it. (`data_idnex` keeps its original spelling because later cells use it.)
regen = False

if regen:
    # np.random.shuffle shuffles in place and returns None, so assigning its result
    # would leave data_idnex = None. Draw the permutation directly instead, sized to
    # the actual number of subjects rather than a hard-coded count.
    data_index = np.random.permutation(len(image_dict))
    data_idnex = data_index
    print(data_index.tolist())  # plain ints -> can be pasted verbatim below

else:
    data_idnex = np.array([
        10, 22, 1, 86, 74, 109, 54, 59, 75, 4, 26, 5, 79, 40, 39, 112,
        104, 51, 116, 80, 77, 57, 3, 85, 29, 55, 18, 87, 50, 27, 0, 108,
        45, 34, 46, 15, 6, 102, 68, 64, 32, 89, 2, 100, 25, 67, 70, 113,
        81, 56, 8, 110, 35, 47, 115, 23, 88, 63, 17, 99, 49, 30, 92, 78,
        37, 19, 48, 24, 28, 103, 73, 95, 96, 20, 11, 9, 33, 60, 7, 76,
        71, 93, 65, 106, 53, 13, 111, 101, 12, 42, 16, 91, 43, 84, 97,
        72, 41, 82, 83, 66, 94, 31, 90, 98, 114, 107, 44, 58, 52, 61, 69,
        38, 36, 62, 105, 14, 21,
    ])

# Catch copy/paste mistakes in the frozen list: it must be a permutation of 0..N-1.
assert np.array_equal(np.sort(data_idnex), np.arange(len(data_idnex))), \
    'data_idnex must be a permutation of 0..N-1'

# Training dataset: down-sample, then apply random affine augmentation.
dataset_image = MTBIDataset(data_idnex,
                            image_dict,
                            metric,
                            transform=transforms.Compose([
                                downSample(2),
                                RandomAffine(15, 10),
                            ]),
                            )

# Validation dataset: built from the same index/subject arrays, but with only the
# deterministic down-sampling. Validation batches are therefore not randomly
# perturbed, which keeps validation metrics (and any LR-plateau logic) stable.
# NOTE(review): if MTBIDataset loads all volumes eagerly in __init__, this doubles
# memory; in that case consider a per-split transform switch instead.
dataset_val = MTBIDataset(data_idnex,
                          image_dict,
                          metric,
                          transform=transforms.Compose([
                              downSample(2),
                          ]),
                          )

#-------------------------CREATE DATA LOADER FOR TRAIN AND VAL------------------------#

data_size = len(dataset_image)
# Both splits must be non-empty: [0, NUM_TRAIN) trains, [NUM_TRAIN, data_size) validates.
assert 0 < NUM_TRAIN < data_size, (NUM_TRAIN, data_size)

train_loader = DataLoader(dataset_image, batch_size=BATCH_SIZE,
                          sampler=sampler.SubsetRandomSampler(range(NUM_TRAIN)),
                          num_workers=NUM_WORKERS)
validation_loader = DataLoader(dataset_val, batch_size=BATCH_SIZE,
                               sampler=sampler.SubsetRandomSampler(range(NUM_TRAIN, data_size)),
                               num_workers=NUM_WORKERS)

In [None]:
# Sanity check on the validation loader: print the image and label tensor sizes of
# each batch, stopping after batch index 8 (i.e. at most the first nine batches).
for i_batch, sample_batched in enumerate(validation_loader):
    print(i_batch,
          sample_batched['image'].size(),
          sample_batched['label'].size())
    if i_batch != 8:
        continue
    # Uncomment to visualise the last inspected batch:
    # show_batch_image(sample_batched['image'], sample_batched['label'], BATCH_SIZE)
    break

In [5]:
#-------------------------NEW MODEL INIT WEIGHT--------------------------------------#

# Set LoadCKP = True to pass model/optimizer/scheduler through loadckp() with the
# checkpoint file at CKPPath.
LoadCKP = False
CKPPath = 'checkpoint2019-03-31 13:33:50.772063.pth'

# NOTE(review): img_size presumably matches the volume shape after downSample(2) — confirm.
model = LNet(img_size=(64, 96, 64))
model.apply(weights_init)  # Module.apply calls weights_init on every submodule and the model itself

# Place the parameters on the training device *before* building the optimizer, as the
# torch.optim docs recommend, so the optimizer holds the device-resident parameters.
model = model.to(device)

optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
# mode='min': the LR is reduced once the value passed to scheduler.step() stops
# decreasing for longer than `patience` steps (how often step() is called is up to train()).
scheduler = ReduceLROnPlateau(optimizer, mode='min', patience=50, verbose=True)

if LoadCKP:
    model, optimizer, scheduler = loadckp(model, optimizer, scheduler, CKPPath, device=device)

In [None]:
# NOTE(review): nn.BCELoss expects probabilities in [0, 1], so this presumes LNet's
# output already passes through a sigmoid — confirm (otherwise BCEWithLogitsLoss fits).
loss = nn.BCELoss()

# epochs=500; `streopch` is train()'s own keyword (presumably the starting epoch).
train(
    model,
    train_loader,
    validation_loader,
    optimizer,
    scheduler,
    device,
    dtype,
    lossFun=loss,
    epochs=500,
    streopch=0,
)