## Import Modules

In [1]:
import pandas as pd
import numpy as np
import scipy.io as sio
import os
from os import listdir
from os.path import isfile, join
import time
import math
from IPython.display import clear_output
import glob

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim

import matplotlib.pyplot as plt
import torchvision.transforms as T
#from torchsampler import ImbalancedDatasetSampler
import os
#os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
print('cuda available: '+ str(torch.cuda.is_available()))
#from skimage.morphology import disk, binary_dilation

from models_avm import NestedUNet
from loss_fun_avm import compute_per_channel_dice, DiceLoss, FocalLoss
from tra_val_avm import train, validation
from data_loader_avm import Dataset

cuda available: True


## Data Path

In [2]:
# Each split directory is expanded to a list of full sample paths. The entries
# are sorted because os.listdir returns them in arbitrary, filesystem-dependent
# order, which would make the (unshuffled) validation/test order irreproducible.
def _list_dir_paths(directory):
    """Return ``directory/<entry>`` for every entry in *directory*, sorted by name."""
    return [os.path.join(directory, name) for name in sorted(os.listdir(directory))]


# Training samples
path_tra1 = '/work/samhong833/Data_AVM/forUNetpp/1_tra/1'
list_tra1 = _list_dir_paths(path_tra1)
list_tra = list_tra1

# Validation samples
path_val1 = '/work/samhong833/Data_AVM/forUNetpp/2_val/1'
list_val1 = _list_dir_paths(path_val1)
list_val = list_val1

# Test samples
path_ts1 = '/work/samhong833/Data_AVM/forUNetpp/3_ts/1'
list_ts1 = _list_dir_paths(path_ts1)
list_ts = list_ts1

# Label .txt directories for each split (presumably YOLO-format box labels,
# judging by the folder name — confirm against data_loader_avm.Dataset).
path_tra_lab_txt = '/work/samhong833/Data_AVM/forYOLOv5/labels/1_tra'
path_val_lab_txt = '/work/samhong833/Data_AVM/forYOLOv5/labels/2_val'
path_ts_lab_txt = '/work/samhong833/Data_AVM/forYOLOv5/labels/3_ts'

## Data Loader

In [3]:
# Dataset options shared by the training and validation splits.
_DILATE_OPTS = {"rand_dilate": True, "max_dilate_factor": 10}

# Training split: reshuffled every epoch, 12 samples per batch.
train_data = Dataset(list_tra, path_tra_lab_txt, **_DILATE_OPTS)
train_loader = torch.utils.data.DataLoader(
    dataset=train_data,
    batch_size=12,
    shuffle=True,
)

# Validation split: fixed order, 4 samples per batch.
# NOTE(review): rand_dilate=True is used here as well; if Dataset applies a
# random dilation per sample, the validation loss (and therefore the best-model
# selection in the training loop) will vary between epochs — confirm intent.
val_data = Dataset(list_val, path_val_lab_txt, **_DILATE_OPTS)
val_loader = torch.utils.data.DataLoader(
    dataset=val_data,
    batch_size=4,
    shuffle=False,
)

## Initialize Model and Optimization Parameters

In [4]:
# call model cuda for gpu
model = NestedUNet().cuda()

# define optimizer
optimizer = optim.Adam(model.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 10, gamma=0.95)

# loss function
kwargs = {"alpha": 0.5, "gamma": 3, "reduction": 'mean'}
criterion_FL = FocalLoss(**kwargs)
criterion_DICE = DiceLoss()
loss = [criterion_FL,criterion_DICE]

## Start Training

In [None]:
# Create a fresh experiment directory: <root>/train/exp<N>
def _next_exp_dir(train_dir):
    """Return ``train_dir/exp<N>`` where N is one past the largest existing run number.

    Counting entries (``len(glob('exp*')) + 1``) collides with an existing
    directory as soon as an earlier run has been deleted, so parse the numbers.
    """
    run_ids = [
        int(name[3:])
        for name in os.listdir(train_dir)
        if name.startswith("exp") and name[3:].isdigit()
    ]
    return os.path.join(train_dir, "exp" + str(max(run_ids, default=0) + 1))


path = '/work/samhong833/Models/Seg_dia10'
path = os.path.join(path, "train")
os.makedirs(path, exist_ok=True)  # creates the root and train/ if missing
path = _next_exp_dir(path)
os.mkdir(path)  # fails loudly rather than reusing an existing run directory

# Train the Model
epochs = 500  # The number of epochs

valloss = math.inf  # lowest validation loss seen so far
for epoch in range(1, epochs + 1):
    epoch_start_time = time.time()
    train(path=path,
          model=model,
          loss=loss,
          optimizer=optimizer,
          dataloader=train_loader,
          epoch=epoch,
          scheduler=scheduler)
    print('-' * 89)
    vallossnew = validation(path=path,
                            model=model,
                            loss=loss,
                            dataloader=val_loader,
                            epoch=epoch)
    print('-' * 89)
    print('| end of epoch {:3d} | time: {:5.2f}s'.format(epoch, (time.time() - epoch_start_time)))
    print('-' * 89)

    # NOTE(review): the scheduler is also passed to train(); if train() steps it
    # as well, the learning rate decays twice per epoch — confirm.
    scheduler.step()

    # Keep only the checkpoint with the lowest validation loss (epoch 1 always
    # saves) and append a line to the run's log each time it improves.
    if epoch == 1 or vallossnew < valloss:
        fname = path + '/best_val' + '.tar'
        torch.save(model.state_dict(), fname)
        valloss = vallossnew
        with open(path + '/model_info.txt', 'a') as f2:
            f2.write('| best_val | epoch {:3d}| '.format(epoch) + '\r\n')
    

# Log ^