In [None]:
## Lines for Google Colab to import Drive repository and configure GitHub
from google.colab import drive
drive.mount('/content/gdrive')
%cd /content/gdrive/My Drive/capita_selecta_cvbm/notebooks

In [None]:
## Lines for Google Colab to push and pull form GitHub repository
# %cd /content/gdrive/My Drive/capita_selecta_cvbm
# !git pull origin master

# !git remote rm origin
# !git remote add origin https://Beerend:XXXXX@github.com/Beerend/TReNDS.git

# !git pull origin master
# !git status
# !git add train_TReNDS.ipynb
# !git commit -m 'Added MAE loss'
# !git push origin master

In [13]:
import os
import time
import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader
from torch.nn import MSELoss, L1Loss
from datasets import TReNDS
from datasets.TReNDS import TReNDSDataset
from models import resnet, deeplight, resnet_4d
# from google.colab import output
from importlib import reload

# reload(TReNDS)
reload(resnet)
# from datasets.TReNDS import TReNDSDataset
from models import resnet


#ResNet3D is from SeuTao (https://github.com/SeuTao/RSNA2019_Intracranial-Hemorrhage-Detection/tree/bbdb1e1d645953ef4b2f23c87b6fba44aff023ea/3DNet)

In [2]:
# Home PC
data_path = '/Volumes/External Hard Drive/Documents/University of Twente/Computer Science/Capita Selecta'
root = '../'

# Google Colab
# data_path = '/content/gdrive/My Drive/capita_selecta_cvbm'
# root = '/content/gdrive/My Drive/capita_selecta_cvbm'

available_models = ['deeplight',
                    'deeplight_tempframe_26',
                    'deeplight_resnet10',
                    'resnet10',
                    'resnet10_4d']

model_name = 'resnet10'
fold_index = 0

# Options
opts = {
    'rand_seed'  : 1,
    'no_cuda'    : True,
    'temp_mean'  : False,
    'preprocess' : False,
    'scale_norm' : False,
    'lr'         : 1e-4,
    'train_bs'   : 1,
    'test_bs'    : 1,
    'epochs'     : 1,
    'fold_index' : fold_index,
    'n_splits'   : 5,
    'model_name' : model_name,
    'save_at_eps': list(range(1,61)),
    'test_at_eps': list(range(1,61)),
    'save_dir'   : os.path.join(root, 'results/%s/%s'%(model_name, str(fold_index))),
    'resume'     : None, #os.path.join(root, 'results/deeplight/0/epoch_5.pth.tar'),
    'pretrain'   : None,
}

if not os.path.exists(opts['save_dir']):
    os.makedirs(opts['save_dir'])
    
torch.manual_seed(opts['rand_seed'])
earlier_epochs = 0

In [14]:
# Generate model
assert model_name in available_models
if model_name=='deeplight':
    model = deeplight.original()
elif model_name=='deeplight_tempframe_26':
    model = deeplight.original(temp_frame=26)
elif model_name=='resnet10':
    model = resnet.resnet10(shortcut_type='B', no_cuda=opts['no_cuda'], num_class=1)
elif model_name=='resnet10_4d':
    model = resnet_4d.resnet10_4d(shortcut_type='B', no_cuda=opts['no_cuda'], num_class=1)
    
optim = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=opts['lr']) #, betas=(.9,.999), eps=1e-08)
mse   = MSELoss()
mae   = L1Loss()

if not opts['no_cuda']:
    model.cuda()
    
num_params    = sum(p.numel() for p in model.parameters())
num_tr_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print('Loaded %s (param: %d, trainable: %d, GPU: %s)'%(model_name, num_params, num_tr_params, not opts['no_cuda']))

Loaded resnet10 (param: 15497537, trainable: 15497537, GPU: False)


In [7]:
# Train from checkpoint
if opts['resume']:
    if os.path.isfile(opts['resume']):
        print('Loading checkpoint from:', opts['resume'])
        load_dict = torch.load(opts['resume']) #, map_location=torch.device('cpu'))
        model.load_state_dict(load_dict['state_dict'])
        optim.load_state_dict(load_dict['optim'])
        earlier_epochs = load_dict['epoch']
        
# Train from pre-trained model
elif opts['pretrain']:
    if os.path.isfile(opts['pretrain']):
        print('Loading pre-trained weights from:', opts['pretrain'])
        model_dict = model.state_dict()
        pretrain   = torch.load(opts['pretrain'])
        pretr_dict = {k:v for k,v in pretrain['state_dict'].items() if k in model_dict.keys() and 'conv1' not in k}
        model_dict.update(pretr_dict)
        model.load_state_dict(model_dict)

Loading checkpoint from: /content/gdrive/My Drive/capita_selecta_cvbm/results/deeplight/0/epoch_5.pth.tar


In [4]:
# Get dataset
train_set    = TReNDSDataset(data_path, 'train', n_splits=opts['n_splits'], fold=fold_index,
                             preprocess=opts['preprocess'], norm=opts['scale_norm'],
                             temp_mean=opts['temp_mean'])
train_loader = DataLoader(train_set, batch_size=opts['train_bs'], shuffle=True, pin_memory=True)
test_set     = TReNDSDataset(data_path, 'test', n_splits=opts['n_splits'], fold=fold_index,
                             preprocess=opts['preprocess'], norm=opts['scale_norm'],
                             temp_mean=opts['temp_mean'])
test_loader  = DataLoader(test_set, batch_size=opts['test_bs'], shuffle=False, pin_memory=True)

Loaded dataset with 4701 train samples in fold 0.
Loaded dataset with 1176 test samples in fold 0.


In [15]:
# Train model
if not opts['resume']:
    log_file = open(os.path.join(opts['save_dir'], 'log.txt'), 'w')
    log_file.write('Epoch,set,MSE,MAE,time\n')
    log_file.flush()
    log_file.close()

for epoch in range(1+earlier_epochs, opts['epochs']+1+earlier_epochs):
    # TODO: adjust learning rate

    model.train()
    start_time = time.time()
    batches    = len(train_loader)

    batch_id = 1
    tot_mae  = 0.
    tot_mse  = 0.
    
    for batch_data in train_loader:
        imgs, lbls   = batch_data
        batch_id_tot = batch_id+(epoch-earlier_epochs-1)*batches

        if not opts['no_cuda']:
            imgs = imgs.cuda()
            lbls = lbls.cuda()
        
        optim.zero_grad()
        preds    = model(imgs)
        mae_loss = mae(preds, lbls)
        mse_loss = mse(preds, lbls)
        mse_loss.backward()
        if model_name=='deeplight':
            torch.nn.utils.clip_grad_norm_(model.parameters(), 5)
        optim.step()
        
        avg_batch_time = (time.time()-start_time)/batch_id_tot
#         output.clear('batch_inf')
#         with output.use_tags('batch_inf'):
#             print('Training model => Epoch: %d/%d - batch: %d/%d - loss: %.5f (MSE) %.5f (MAE) - time: %.3f'%(epoch,
#                 opts['epochs']+earlier_epochs, batch_id, batches, mse_loss.item(), mae_loss.item(), avg_batch_time))
        
#         print('Training model => Epoch: %d/%d - batch: %d/%d - loss: %.5f (MSE) %.5f (MAE) - time: %.3f'%(epoch,
#             opts['epochs']+earlier_epochs, batch_id, batches, mse_loss.item(), mae_loss.item(), avg_batch_time), end='\r')
        
        tot_mae+=mae_loss.item()
        tot_mse+=mse_loss.item()
        batch_id+=1
    
    avg_mae  = tot_mae/batches
    avg_mse  = tot_mse/batches
    tot_time = time.time()-start_time
    
    log_file = open(os.path.join(opts['save_dir'], 'log.txt'), 'a')
    log_file.write('%d,train,%.5f,%.5f,%.1f\n'%(epoch, avg_mse, avg_mae, tot_time))
    log_file.flush()
    log_file.close()
            
    if epoch==opts['epochs'] or epoch in opts['save_at_eps']:
        filename = os.path.join(opts['save_dir'], 'epoch_%d.pth.tar'%(epoch))
        torch.save({'epoch':epoch, 'state_dict':model.state_dict(),
            'optim':optim.state_dict()}, filename)
        
    if epoch in opts['test_at_eps']:
        results = evaluate_model(test_loader, model, mae, mse, epoch, opts)

Input: torch.Size([1, 53, 52, 63, 53])
After conv1: torch.Size([1, 64, 26, 32, 27])
After maxpool: torch.Size([1, 64, 13, 16, 14])
--- shape x: torch.Size([1, 64, 13, 16, 14])
--- shape r: torch.Size([1, 64, 13, 16, 14])
After layer1: torch.Size([1, 64, 13, 16, 14])
--- Downsample residual
--- shape x: torch.Size([1, 128, 7, 8, 7])
--- shape r: torch.Size([1, 128, 7, 8, 7])
After layer2: torch.Size([1, 128, 7, 8, 7])
--- Downsample residual
--- shape x: torch.Size([1, 256, 7, 8, 7])
--- shape r: torch.Size([1, 256, 7, 8, 7])
After layer3: torch.Size([1, 256, 7, 8, 7])
--- Downsample residual
--- shape x: torch.Size([1, 512, 7, 8, 7])
--- shape r: torch.Size([1, 512, 7, 8, 7])
After layer4: torch.Size([1, 512, 7, 8, 7])


KeyboardInterrupt: 

In [15]:
def evaluate_model(test_loader, model, mae, mse, epoch, opts):
    model.eval()
    start_time = time.time()
    batches    = len(test_loader)
    all_preds  = []
    all_labls  = []
    tot_mae    = 0.
    tot_mse    = 0.

    with torch.no_grad():
        batch_id = 1
        for batch_data in test_loader:
            imgs, lbls  = batch_data

            if not opts['no_cuda']:
                imgs = imgs.cuda()
                lbls = lbls.cuda()
            
            preds    = model(imgs)
            mae_loss = mae(preds, lbls)
            mse_loss = mse(preds, lbls)

            output.clear('batch_inf')
            with output.use_tags('batch_inf'):
                print('Evaluating model => Batch: %d/%d - loss: %.5f (MSE) %.5f (MAE)'%(batch_id,
                    batches, mse_loss.item(), mae_loss.item()))
            
            tot_mae+=mae_loss.item()
            tot_mse+=mse_loss.item()
            all_preds.append(preds.data.cpu().numpy().flatten())
            all_labls.append(lbls.data.cpu().numpy().flatten())
            batch_id+=1

    avg_mae  = tot_mae/batches
    avg_mse  = tot_mse/batches
    tot_time = time.time()-start_time
    
    log_file = open(os.path.join(opts['save_dir'], 'log.txt'), 'a')
    log_file.write('%d,test,%.5f,%.5f,%.1f\n'%(epoch, avg_mse, avg_mae, tot_time))
    log_file.flush()
    log_file.close()

    all_preds = np.concatenate(all_preds, axis=0)
    all_labls = np.concatenate(all_labls, axis=0)
    filename  = os.path.join(opts['save_dir'], 'preds_epoch_%d.csv'%(epoch))
    results   = pd.DataFrame(data={'Pred':all_preds, 'Label':all_labls})
    results.to_csv(filename, index=False)

    output.clear('batch_inf')
    print('Average loss: %.3f (MSE) %.3f (MAE)'%(avg_mae,avg_mse))

    return results

In [10]:
evaluate_model(test_loader, model, mae, mse, 5, opts)

Average loss: 11.102 (MSE) 188.586 (MAE)


Unnamed: 0,Pred,Label
0,52.161644,38.617382
1,52.289677,35.326580
2,52.289486,35.326580
3,52.289654,64.203110
4,52.289528,66.532631
...,...,...
1171,52.289722,57.436077
1172,52.289742,48.948757
1173,52.289402,42.941154
1174,52.289536,14.257265
