In [1]:
import torch.nn as nn
import torch.nn.functional as F
import torch
import numpy as np
import onnx
import sys
import re
import torch.optim as optim
import time
from torch.utils.data import DataLoader
from scipy import ndimage
from tqdm import tqdm
from datetime import datetime

In [2]:
sys.path.append('..')
from classes.dataset_utils.toTorchDataset import ProcessedKit23TorchDataset
from classes.models import resnet_model_generator
from classes.config_class import ProjectModelResnetConfig
from classes.epoch_results import EpochResult

In [3]:
training_data = ProcessedKit23TorchDataset(train_data=True, test_size=0.25, dataset_dir ="./dataset/affine_transformed")
test_data = ProcessedKit23TorchDataset(train_data=False, test_size=0.25, dataset_dir ="./dataset/affine_transformed")

In [4]:
proj_config = ProjectModelResnetConfig(model_depth=10)
proj_resnet_model, _ = resnet_model_generator.generate_model(proj_config)

In [5]:
proj_config.set_net_model(proj_resnet_model)

In [6]:
criterion = nn.CrossEntropyLoss(ignore_index=-1)
optimizer = optim.SGD(proj_resnet_model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-3)
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.99)
if not proj_config.no_cuda:
    criterion = criterion.cuda()

In [7]:
train_from_pretrained = False
epoch_res = EpochResult()
epoch_start = 0
if train_from_pretrained:
    print("loading from pretrained Med3D model")
    if proj_config.model_depth == 10:
        proj_config.load_med3d_pretrain_weigth("../pretrainedModel/resnet_10_23dataset.pth")
    elif proj_config.model_depth == 50:
        proj_config.load_med3d_pretrain_weigth("../pretrainedModel/resnet_50_23dataset.pth")
    else:
        raise Exception("Only depth 10 and 50 are used for now.")
else:
    # this continues from certain training points
    checkpoint, epoch_res = proj_config.load_weight_from_epoch("./training_checkpoints/Model_resnet_10_epoch0.pth.tar")
    optimizer.load_state_dict(checkpoint['optimizer'])
    epoch_start = epoch_res.epoch_list[-1] + 1
    
 

./training_checkpoints/Model_resnet_10_epoch0.pth.tar


In [8]:
data_loader = DataLoader(training_data, batch_size=proj_config.batch_size, shuffle=True, num_workers=proj_config.num_workers, pin_memory=proj_config.pin_memory)

In [None]:
train_time_start = time.time()
batches_per_epoch = len(data_loader)

for epoch in range(epoch_start, proj_config.max_epoch):
    current_lr = scheduler.get_last_lr()
    running_loss = None
    print("current epoch={:5d} Learning Rate={}".format(epoch, current_lr))
    
    for batch_idx, batch_data  in enumerate(data_loader):
        imgs, segs = batch_data

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        y_preds = proj_resnet_model(imgs.float())

        [n, _, z_size, y_size, x_size] = y_preds.shape

        resized_segs = np.zeros([n, z_size, y_size, x_size])
        for idx in range(n):
            seg = segs[idx][0]
            [ori_z, ori_y, ori_x] = seg.shape 
            scale = [z_size/ori_z, y_size/ori_y, x_size/ori_x]
            this_affine = np.array([[scale[0], 0, 0],[0, scale[1], 0],[0, 0, scale[2]]])
            resized_segs[idx] = ndimage.affine_transform(seg, this_affine, output_shape=resized_segs[idx].shape, cval=0)

        resized_segs = torch.tensor(resized_segs).to(torch.int64)
        loss = criterion(y_preds, resized_segs)
        running_loss = loss.item()
        loss.backward()                
        optimizer.step()
        
        
        total_processed_batches = epoch * batches_per_epoch + 1 + batch_idx
        avg_batch_time = (time.time() - train_time_start) / total_processed_batches
        if batch_idx % 50 == 0:
            print("Epoch:{} Batch:{} loss = {:.3f}, avg_batch_time = {:.3f}".format(epoch, batch_idx, running_loss, avg_batch_time))
    scheduler.step()
    epoch_res.append_result(epoch, running_loss, current_lr)
    model_checkpoint_path = proj_config.save_checkpoint_pathname(epoch, with_Datetime=False)
    torch.save({'epoch_list': epoch_res.epoch_list, 'loss_list': epoch_res.loss_list, 'lr_list': epoch_res.lr_list, 
                'state_dict': proj_resnet_model.state_dict(),'optimizer': optimizer.state_dict()},model_checkpoint_path, _use_new_zipfile_serialization=True)

print('Finished Training')

current epoch=    1 Learning Rate=[0.001]
Epoch:1 Batch:0 loss = 0.014, avg_batch_time = 0.018
Epoch:1 Batch:50 loss = 0.017, avg_batch_time = 0.733
Epoch:1 Batch:100 loss = 0.016, avg_batch_time = 1.844
Epoch:1 Batch:150 loss = 0.010, avg_batch_time = 2.269
Epoch:1 Batch:200 loss = 0.012, avg_batch_time = 2.579
Epoch:1 Batch:250 loss = 0.011, avg_batch_time = 2.833
Epoch:1 Batch:300 loss = 0.008, avg_batch_time = 3.037
Epoch:1 Batch:350 loss = 0.010, avg_batch_time = 3.209
current epoch=    2 Learning Rate=[0.0009801]
Epoch:2 Batch:0 loss = 0.007, avg_batch_time = 3.271
Epoch:2 Batch:50 loss = 0.008, avg_batch_time = 3.410
Epoch:2 Batch:100 loss = 0.006, avg_batch_time = 3.535
Epoch:2 Batch:150 loss = 0.006, avg_batch_time = 3.645
Epoch:2 Batch:200 loss = 0.007, avg_batch_time = 3.741
Epoch:2 Batch:250 loss = 0.008, avg_batch_time = 3.828
Epoch:2 Batch:300 loss = 0.006, avg_batch_time = 3.903
Epoch:2 Batch:350 loss = 0.005, avg_batch_time = 3.971
current epoch=    3 Learning Rate=[0.0