In [2]:
import re
import sys
import time
from datetime import datetime
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from scipy import ndimage
from torch.utils.data import DataLoader
from tqdm import tqdm

In [16]:
# Make the parent directory importable so the project-local `classes` package
# resolves. The guard keeps the cell idempotent: re-running it no longer
# appends a duplicate '..' entry to sys.path on every execution.
# NOTE(review): '..' is resolved against the kernel's working directory, not
# the notebook file's location -- confirm the kernel is started from here.
if '..' not in sys.path:
    sys.path.append('..')
from classes.dataset_utils.toTorchDataset import ProcessedKit23TorchDataset
from classes.models.unet3d import UNet3D

In [4]:
# Both splits are built from the same directory with the same `test_size`;
# only the `train_data` flag differs between them.
# NOTE(review): presumably `test_size` is the held-out fraction and the
# dataset class performs the split itself -- confirm in toTorchDataset.
_split_kwargs = dict(
    test_size=0.25,
    dataset_dir="./dataset/affine_transformed",
)
training_data = ProcessedKit23TorchDataset(train_data=True, **_split_kwargs)
test_data = ProcessedKit23TorchDataset(train_data=False, **_split_kwargs)

In [11]:
# First epoch index for the training loop.
epoch_start = 0

# NOTE(review): the positional arguments (1, 4) are presumably the number of
# input channels and output classes -- confirm against UNet3D's signature.
net = UNet3D(1, 4)

# Voxel-wise classification loss; targets equal to -1 are excluded from it.
criterion = nn.CrossEntropyLoss(ignore_index=-1)

# SGD with momentum and L2 weight decay over all network parameters.
optimizer = optim.SGD(
    net.parameters(),
    lr=0.001,
    momentum=0.9,
    weight_decay=1e-3,
)

# Learning rate is multiplied by 0.99 on every scheduler.step() call.
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.99)

In [12]:
# Shuffled mini-batches of 4 samples, loaded by 2 worker processes.
data_loader = DataLoader(
    training_data,
    batch_size=4,
    shuffle=True,
    num_workers=2,
)

In [None]:
def _resize_label_volume(label_volume, out_shape):
    """Resample a 3-D label volume to ``out_shape`` with nearest-neighbour lookup.

    Args:
        label_volume: 3-D array-like of class ids (anything ``np.asarray`` accepts).
        out_shape: target ``(z, y, x)`` size.

    Returns:
        ``np.ndarray`` of shape ``out_shape`` with the same dtype as the input.
    """
    label_volume = np.asarray(label_volume)
    out_shape = tuple(out_shape)
    if label_volume.shape == out_shape:
        return label_volume
    # ndimage.affine_transform maps OUTPUT coordinates to INPUT coordinates
    # (input = matrix @ output + offset), so the diagonal must hold the
    # input/output size ratio, not output/input.
    out_to_in = np.diag([src / dst for src, dst in zip(label_volume.shape, out_shape)])
    # order=0 (nearest neighbour) keeps every voxel a valid class id; the
    # default cubic spline would blend neighbouring labels into values that
    # are not classes at all.
    return ndimage.affine_transform(
        label_volume, out_to_in, output_shape=out_shape, order=0, cval=0
    )


batches_per_epoch = len(data_loader)

# Wall-clock reference for the average-batch-time readout below.
train_time_start = time.time()

# Per-epoch history that is stored inside every checkpoint.
epoch_list, loss_list, lr_list = [], [], []

# NOTE(review): this cell previously relied on `epoch_res` and `proj_config`,
# which are not defined anywhere in this notebook. The history is now kept in
# the plain lists above and checkpoints go to ./checkpoints -- adjust the
# directory / file name if the project has its own convention.
checkpoint_dir = Path("./checkpoints")
checkpoint_dir.mkdir(parents=True, exist_ok=True)

net.train()

for epoch in range(epoch_start, 10):
    current_lr = scheduler.get_last_lr()
    running_loss = None  # loss of the most recent batch in this epoch
    print("current epoch={:5d} Learning Rate={}".format(epoch, current_lr))

    for batch_idx, (imgs, segs) in enumerate(data_loader):
        # zero the parameter gradients
        optimizer.zero_grad()

        # forward pass
        y_preds = net(imgs.float())
        n, _, z_size, y_size, x_size = y_preds.shape

        # Bring channel 0 of each segmentation onto the prediction grid so
        # the target matches the spatial size of the network output.
        resized_segs = np.stack([
            _resize_label_volume(segs[idx][0], (z_size, y_size, x_size))
            for idx in range(n)
        ]).astype(np.int64)
        resized_segs = torch.from_numpy(resized_segs)

        # backward + optimize
        loss = criterion(y_preds, resized_segs)
        running_loss = loss.item()
        loss.backward()
        optimizer.step()

        if batch_idx % 50 == 0:
            # Count only batches processed in this run, so the average stays
            # correct when epoch_start > 0.
            total_processed_batches = (epoch - epoch_start) * batches_per_epoch + 1 + batch_idx
            avg_batch_time = (time.time() - train_time_start) / total_processed_batches
            print("Epoch:{} Batch:{} loss = {:.3f}, avg_batch_time = {:.3f}".format(epoch, batch_idx, running_loss, avg_batch_time))

    scheduler.step()

    # Record the last batch loss and the learning rate used for this epoch.
    epoch_list.append(epoch)
    loss_list.append(running_loss)
    lr_list.append(current_lr)

    # Checkpoint the model that is actually being trained (`net`).
    model_checkpoint_path = checkpoint_dir / "unet3d_epoch_{:03d}.pth".format(epoch)
    torch.save(
        {
            'epoch_list': epoch_list,
            'loss_list': loss_list,
            'lr_list': lr_list,
            'state_dict': net.state_dict(),
            'optimizer': optimizer.state_dict(),
        },
        model_checkpoint_path,
        _use_new_zipfile_serialization=True,
    )

print('Finished Training')

current epoch=    0 Learning Rate=[0.001]
