In [None]:
import os

import numpy as np
import torch
import torchvision.transforms as tr
from torch.utils.data import ConcatDataset, DataLoader, random_split

import DG
import Oper
from Stats import print_data
from UNet import UNet

In [None]:
# ---- Data / output locations ------------------------------------------------
train_path = 'dataset1/train/'
gt_path = 'dataset1/train_GT/SEG'
test_path = 'dataset1/test/'
result_path = 'dataset1/test_RES/'
save_path = 'saved_models/'   # base directory for checkpoints
stats_path = 'stats/'         # base directory for training statistics

# ---- UNet architecture --------------------------------------------------------
in_channels, n_classes = 1, 2
depth, wf = 4, 6
padding, batch_norm = True, False
up_mode = 'upconv'

# ---- Training schedule --------------------------------------------------------
epochs = 16
pad = 6            # border padding in pixels (used by tr.Pad and passed to run_model)
train_ratio = 0.9  # fraction of samples used for training; the remainder validates

# ---- Optimizer ------------------------------------------------------------------
optim_name = 'Adam'
lr = 1e-5
momentum = 0.99    # only meaningful for SGD
betas = (0.9, 0.999)
eps = 1e-8
weight_decay = 0

# ---- Loss function --------------------------------------------------------------
loss_func = 'cross_entropy'
gamma, alpha = 0, 0.75  # presumably focal-loss parameters consumed by Oper.run_model — confirm

In [None]:
# Per-experiment overrides layered on top of the config cell.
run_name = '46_all_fl_best_0.99'

# Rebuild each path from its directory part instead of appending to the current
# value, so re-running this cell does not keep growing save_path / stats_path.
model_path = os.path.join(os.path.dirname(save_path), run_name + '.tar')
# NOTE(review): model_path and save_path are the same file, so whatever run_model
# writes to save_path presumably overwrites the checkpoint being resumed — confirm intended.
save_path = model_path
stats_path = os.path.join(os.path.dirname(stats_path), run_name)

depth = 4
wf = 6
loss_func = 'focal_loss'

In [None]:
# Build the UNet from the architecture settings in the config / override cells.
best_model = UNet(
    in_channels=in_channels,
    n_classes=n_classes,
    depth=depth,
    wf=wf,
    padding=padding,
    batch_norm=batch_norm,
    up_mode=up_mode,
)

# Honour `optim_name` instead of hard-coding Adam, so the SGD `momentum` setting
# actually takes effect when selected. The checkpoint's optimizer state must match
# the optimizer chosen here.
if optim_name == 'Adam':
    optim = torch.optim.Adam(best_model.parameters(), lr=lr, betas=betas,
                             eps=eps, weight_decay=weight_decay)
elif optim_name == 'SGD':
    optim = torch.optim.SGD(best_model.parameters(), lr=lr, momentum=momentum,
                            weight_decay=weight_decay)
else:
    raise ValueError(f'Unsupported optim_name: {optim_name!r}')

In [None]:
# Resume model and optimizer from the saved checkpoint.
# map_location='cpu' lets a checkpoint written on a GPU machine load on a CPU-only
# one; the model has not been moved to `device` yet at this point.
checkpoint = torch.load(model_path, map_location='cpu')
best_model.load_state_dict(checkpoint['model_state_dict'])
# Optimizer.load_state_dict mutates the optimizer in place and returns None, so its
# result must NOT be assigned back to `optim`.
optim.load_state_dict(checkpoint['optim_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

# NOTE(review): duplicates (and overrides) the config-cell value; keep it in one place.
epochs = 16
print(epoch)
print(loss)

In [None]:
# Pick CUDA when available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    print('GPU is available.')
    device = torch.device('cuda')
else:
    print('GPU is not available. Use CPU instead.')
    device = torch.device('cpu')

best_model = best_model.to(device)

# Moving the model does not move optimizer state (e.g. Adam's running moments) that
# was loaded while the parameters were still on the CPU; a device mismatch would
# surface on the first optimizer step. Round-tripping through load_state_dict
# re-casts that state onto the parameters' current device using torch's own policy.
# Only done when `optim` actually holds an optimizer instance.
if isinstance(optim, torch.optim.Optimizer):
    optim.load_state_dict(optim.state_dict())

In [None]:
# Input pipeline: -> PIL image -> single-channel grayscale -> pad `pad` px -> tensor.
tr_ori = tr.Compose([
    tr.ToPILImage(),
    tr.Grayscale(1),
    tr.Pad(pad),
    tr.ToTensor(),
])

# Base dataset plus the DG variants built from the same training folder
# (presumably flips / rotations / another augmentation — see DG).
dataset_ori = DG.DatasetGen(train_path, gt_path, tr_ori)

dataset_h = DG.DatasetHGen(train_path, gt_path, tr_ori)
dataset_v = DG.DatasetVGen(train_path, gt_path, tr_ori)
dataset_hv = DG.DatasetHVGen(train_path, gt_path, tr_ori)

dataset_r90 = DG.DatasetR90Gen(train_path, gt_path, tr_ori)
dataset_r270 = DG.DatasetR270Gen(train_path, gt_path, tr_ori)
# NOTE(review): the meaning of (10, 3, [3, 0]) lives in DG.DatasetEDGen — document it
# or lift it into the config cell.
dataset_ed = DG.DatasetEDGen(train_path, gt_path, tr_ori, 10, 3, [3, 0])

dataset = ConcatDataset([dataset_ori, dataset_h, dataset_v,
                         dataset_hv, dataset_r90, dataset_r270, dataset_ed])

# NOTE(review): if every DG dataset yields a transformed view of the same source
# images, a random split over the concatenation puts views of one image in both
# train and val, inflating validation scores. Consider splitting by source image.
train_size = int(np.floor(train_ratio * len(dataset)))
val_size = len(dataset) - train_size

# Seeded so the train/val split is identical across kernel restarts and resumed runs;
# an unseeded split would move previously-trained samples into validation.
split_seed = 0
train_dataset, val_dataset = random_split(
    dataset, [train_size, val_size],
    generator=torch.Generator().manual_seed(split_seed))
train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
# Validation order does not affect the result, so keep it deterministic.
val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)

stats = Oper.run_model(
    model=best_model,
    optim=optim,
    train_loader=train_loader,
    val_loader=val_loader,
    device=device,
    save_path=save_path,
    train_size=train_size,
    val_size=val_size,
    epochs=epochs,
    pad=pad,
    lr=lr,
    betas=betas,
    eps=eps,
    weight_decay=weight_decay,
    loss_func=loss_func,
    gamma=gamma,
    alpha=alpha,
)

print_data(epochs=epochs, stats=stats, stats_path=stats_path)