In [2]:
import torch.nn as nn
import torch.nn.functional as F
import torch
import numpy as np
import sys
import re
import torch.optim as optim
import time
from torch.utils.data import DataLoader
from scipy import ndimage
from tqdm import tqdm
from datetime import datetime

In [3]:
sys.path.append('..')
from classes.dataset_utils.toTorchDataset import ProcessedKit23TorchDataset
from classes.models.unet3d import UNet3D
from classes.config_class import ProjectModelResnetConfig
from classes.epoch_results import EpochResult

In [4]:
training_data = ProcessedKit23TorchDataset(train_data=True, test_size=0.25, dataset_dir ="./dataset/affine_transformed")
test_data = ProcessedKit23TorchDataset(train_data=False, test_size=0.25, dataset_dir ="./dataset/affine_transformed")

In [5]:
proj_config = ProjectModelResnetConfig(model_depth=50)
net = UNet3D(1, 3)

In [6]:
criterion = nn.CrossEntropyLoss(ignore_index=-1)
optimizer = optim.Adam(net.parameters(), lr=0.001)
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.99)
if not proj_config.no_cuda:
    criterion = criterion.cuda()

In [7]:
train_from_pretrained = False
epoch_res = EpochResult()
epoch_start = 0
if train_from_pretrained:
    print("loading from pretrained Med3D model")
    if proj_config.model_depth == 10:
        proj_config.load_med3d_pretrain_weigth("../pretrainedModel/resnet_10_23dataset.pth")
    elif proj_config.model_depth == 50:
        proj_config.load_med3d_pretrain_weigth("../pretrainedModel/resnet_50_23dataset.pth")
    else:
        raise Exception("Only depth 10 and 50 are used for now.")
else:
    # this continues from certain training points
    checkpoint, epoch_res = proj_config.load_weight_from_epoch("./training_checkpoints/Model_resnet_50_epoch5.pth.tar")
    optimizer.load_state_dict(checkpoint['optimizer'])
    epoch_start = epoch_res.epoch_list[-1] + 1

Exception: neural network model is not set

In [8]:
data_loader = DataLoader(training_data, batch_size=proj_config.batch_size, shuffle=True, num_workers=proj_config.num_workers, pin_memory=proj_config.pin_memory)

In [9]:
train_time_start = time.time()
batches_per_epoch = len(data_loader)

for epoch in range(epoch_start, proj_config.max_epoch):
    current_lr = scheduler.get_last_lr()
    running_loss = None
    print("current epoch={:5d} Learning Rate={}".format(epoch, current_lr))
    
    for batch_idx, batch_data  in enumerate(data_loader):
        imgs, segs = batch_data

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        y_preds = proj_config.nn_model(imgs.float())

        [n, _, z_size, y_size, x_size] = y_preds.shape

        resized_segs = np.zeros([n, z_size, y_size, x_size])
        for idx in range(n):
            seg = segs[idx][0]
            [ori_z, ori_y, ori_x] = seg.shape 
            scale = [z_size/ori_z, y_size/ori_y, x_size/ori_x]
            this_affine = np.array([[scale[0], 0, 0],[0, scale[1], 0],[0, 0, scale[2]]])
            resized_segs[idx] = ndimage.affine_transform(seg, this_affine, output_shape=resized_segs[idx].shape, cval=0)

        resized_segs = torch.tensor(resized_segs).to(torch.int64)
        loss = criterion(y_preds, resized_segs)
        running_loss = loss.item()
        loss.backward()                
        optimizer.step()
        
        
        total_processed_batches = (epoch - epoch_start) * batches_per_epoch + 1 + batch_idx
        avg_batch_time = (time.time() - train_time_start) / total_processed_batches
        if batch_idx % 50 == 0:
            print("Epoch:{} Batch:{} loss = {:.5f}, avg_batch_time = {:.5f}".format(epoch, batch_idx, running_loss, avg_batch_time))
    scheduler.step()
    epoch_res.append_result(epoch, running_loss, current_lr)
    model_checkpoint_path = proj_config.save_checkpoint_pathname(epoch, with_Datetime=False)
    torch.save({'epoch_list': epoch_res.epoch_list, 'loss_list': epoch_res.loss_list, 'lr_list': epoch_res.lr_list, 
                'state_dict': proj_config.nn_model.state_dict(),'optimizer': optimizer.state_dict()},model_checkpoint_path, _use_new_zipfile_serialization=True)

print('Finished Training')

current epoch=    0 Learning Rate=[0.001]


Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x107cb3e20>
Traceback (most recent call last):
  File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 1478, in __del__
    self._shutdown_workers()
  File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 1442, in _shutdown_workers
    w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
  File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/multiprocessing/process.py", line 149, in join
    res = self._popen.wait(timeout)
          ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/multiprocessing/popen_fork.py", line 40, in wait
    if not wait([self.sentinel], timeout):
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/multiprocessing/connection

TypeError: 'NoneType' object is not callable