In [1]:
import torch.nn as nn
import torch.nn.functional as F
import torch
import numpy as np
import onnx
import sys
import re
import torch.optim as optim
import time
from torch.utils.data import DataLoader
from scipy import ndimage
from tqdm import tqdm
from datetime import datetime

In [2]:
sys.path.append('..')
from classes.dataset_utils.toTorchDataset import ProcessedKit23TorchDataset
from classes.models import resnet_model_generator


In [3]:
from classes.config_class import ProjectModelResnetConfig

In [4]:
training_data = ProcessedKit23TorchDataset(train_data=True, test_size=0.25, dataset_dir ="./dataset/affine_transformed")
test_data = ProcessedKit23TorchDataset(train_data=False, test_size=0.25, dataset_dir ="./dataset/affine_transformed")

In [5]:
resume_path = "../pretrainedModel/resnet_10_23dataset.pth"
checkpoint = []
if torch.cuda.is_available():
    checkpoint = torch.load(resume_path)
else:
    checkpoint = torch.load(resume_path, map_location=torch.device('cpu'))


In [6]:
proj_resnet_model, pro_resnet_params = resnet_model_generator.generate_model(ProjectModelResnetConfig)

In [7]:
net_dict = proj_resnet_model.state_dict()
pretrain_dict = {}
for k,v in checkpoint['state_dict'].items():
    has_match = False
    for model_key in net_dict.keys():
        k_len = len(model_key)
        c_key = k[7:]
        if c_key == model_key:
            pretrain_dict[model_key] = v
            print("Checkpoint Key: {:60s} model key: {:60s}  {}  {}".format(k, model_key, v.size(), type(v)))
            has_match = True
            break
    if not has_match:
        print("Checkpoint Key: {:60s} has no match".format(k))
    
net_dict.update(pretrain_dict)


Checkpoint Key: module.conv1.weight                                          model key: conv1.weight                                                  torch.Size([64, 1, 7, 7, 7])  <class 'torch.Tensor'>
Checkpoint Key: module.bn1.weight                                            model key: bn1.weight                                                    torch.Size([64])  <class 'torch.Tensor'>
Checkpoint Key: module.bn1.bias                                              model key: bn1.bias                                                      torch.Size([64])  <class 'torch.Tensor'>
Checkpoint Key: module.bn1.running_mean                                      model key: bn1.running_mean                                              torch.Size([64])  <class 'torch.Tensor'>
Checkpoint Key: module.bn1.running_var                                       model key: bn1.running_var                                               torch.Size([64])  <class 'torch.Tensor'>
Checkpoint Key: module.bn1.num_ba

In [8]:
proj_resnet_model.load_state_dict(net_dict)

<All keys matched successfully>

In [9]:
data_loader = DataLoader(training_data, batch_size=ProjectModelResnetConfig.batch_size, shuffle=True, num_workers=ProjectModelResnetConfig.num_workers, pin_memory=ProjectModelResnetConfig.pin_memory)

In [10]:

criterion = nn.CrossEntropyLoss(ignore_index=-1)
optimizer = optim.SGD(proj_resnet_model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-3)
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.99)
if not ProjectModelResnetConfig.no_cuda:
    ProjectModelResnetConfig.pin_memory = True
    criterion = criterion.cuda()

In [11]:
train_time_start = time.time()
batches_per_epoch = len(data_loader)
epoch_list = []
loss_list = []
lr_list = []
for epoch in range(ProjectModelResnetConfig.max_epoch):
    current_lr = scheduler.get_last_lr()
    print("current epoch={:5d} Learning Rate={}".format(epoch, current_lr))
    lr_list.append(current_lr)
    
    for batch_idx, batch_data  in enumerate(data_loader):
        imgs, segs = batch_data

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        y_preds = proj_resnet_model(imgs.float())

        [n, _, z_size, y_size, x_size] = y_preds.shape

        resized_segs = np.zeros([n, z_size, y_size, x_size])
        for idx in range(n):
            seg = segs[idx][0]
            [ori_z, ori_y, ori_x] = seg.shape 
            scale = [z_size/ori_z, y_size/ori_y, x_size/ori_x]
            this_affine = np.array([[scale[0], 0, 0],[0, scale[1], 0],[0, 0, scale[2]]])
            resized_segs[idx] = ndimage.affine_transform(seg, this_affine, output_shape=resized_segs[idx].shape, cval=0)

        resized_segs = torch.tensor(resized_segs).to(torch.int64)
        loss = criterion(y_preds, resized_segs)
        running_loss = loss.item()
        loss.backward()                
        optimizer.step()
        
        
        total_processed_batches = epoch * batches_per_epoch + 1 + batch_idx
        avg_batch_time = (time.time() - train_time_start) / total_processed_batches
        if batch_idx % 50 == 0:
            print("Epoch:{} Batch:{} loss = {:.3f}, avg_batch_time = {:.3f}".format(epoch, batch_idx, running_loss, avg_batch_time))
    scheduler.step()
    epoch_list.append(epoch)
    loss_list.append(loss.item())
    model_checkpoint_path = ProjectModelResnetConfig.save_checkpoint_pathname(epoch)
    torch.save({'epoch_list': epoch_list, 'loss_list': loss_list, 'lr_list': lr_list, 'state_dict': proj_resnet_model.state_dict(),'optimizer': optimizer.state_dict()},model_checkpoint_path)

print('Finished Training')

current epoch=    0 Learning Rate=[0.001]
Epoch:0 Batch:0 loss = 1.275, avg_batch_time = 7.245
Epoch:0 Batch:50 loss = 0.114, avg_batch_time = 7.853
Epoch:0 Batch:100 loss = 0.057, avg_batch_time = 8.485
Epoch:0 Batch:150 loss = 0.031, avg_batch_time = 8.825
Epoch:0 Batch:200 loss = 0.030, avg_batch_time = 9.017
Epoch:0 Batch:250 loss = 0.022, avg_batch_time = 9.066
Epoch:0 Batch:300 loss = 0.020, avg_batch_time = 9.035
Epoch:0 Batch:350 loss = 0.015, avg_batch_time = 8.980
current epoch=    1 Learning Rate=[0.00099]
Epoch:1 Batch:0 loss = 0.013, avg_batch_time = 8.966
Epoch:1 Batch:50 loss = 0.015, avg_batch_time = 8.826
Epoch:1 Batch:100 loss = 0.014, avg_batch_time = 8.771
Epoch:1 Batch:150 loss = 0.010, avg_batch_time = 8.696
Epoch:1 Batch:200 loss = 0.008, avg_batch_time = 8.627
Epoch:1 Batch:250 loss = 0.009, avg_batch_time = 8.570
Epoch:1 Batch:300 loss = 0.007, avg_batch_time = 8.523
Epoch:1 Batch:350 loss = 0.006, avg_batch_time = 8.479
current epoch=    2 Learning Rate=[0.000

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x109eb6290>
Traceback (most recent call last):
  File "/usr/local/opt/pyenv/versions/3.10.9/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1477, in __del__
    def __del__(self):
  File "/usr/local/opt/pyenv/versions/3.10.9/lib/python3.10/site-packages/torch/utils/data/_utils/signal_handling.py", line 66, in handler
    _error_if_any_worker_fails()
RuntimeError: DataLoader worker (pid 69809) is killed by signal: Interrupt: 2. 


KeyboardInterrupt: 