In [None]:
import os
import random

import torch
import torchvision.transforms as tr

from torch.utils.data import DataLoader, random_split, ConcatDataset
import numpy as np

import DG
from UNet import UNet
import Oper
from Stats import print_data
import Output

In [None]:
### Paths
# Input data locations.
train_path = 'dataset1/train/'
gt_path = 'dataset1/train_GT/SEG'
test_path = 'dataset1/test/'
result_path = 'dataset1/test_RES/'
# Base locations for checkpoints and training statistics.
save_path = 'saved_models/'
stats_path = 'stats/'

### U-Net params (passed to UNet(...) in the training cell)
in_channels = 1      # number of input channels
n_classes = 2        # number of output channels
depth = 3            # depth of the network
wf = 6               # number of filters in the first layer is 2**wf
padding = True       # if True, pad so the output shape equals the input shape
                     # (this may introduce artifacts)
batch_norm = False   # use BatchNorm after layers with an activation function
up_mode = 'upconv'   # 'upconv': transposed convolutions (learned upsampling);
                     # 'upsample': bilinear upsampling

### Model running params
epochs = 16
pad = 6              # padding used by tr.Pad(pad) in the training transform
train_ratio = 0.8    # fraction of the concatenated dataset used for training

### Optimizer params
optim_name = 'Adam'      # 'Adam' or 'SGD'
lr = 1e-5
momentum = 0.99          # only used when optim_name == 'SGD'
betas = (0.9, 0.999)     # only used when optim_name == 'Adam'
eps = 1e-8               # only used when optim_name == 'Adam'
weight_decay = 0

### Loss function params
loss_func = 'cross_entropy'
# gamma/alpha are forwarded to Oper.run_model; presumably focal-loss
# parameters (only relevant when loss_func == 'focal_loss') -- TODO confirm.
gamma = 0
alpha = 0.75

In [None]:
### exp specific params
# Rebuild the experiment paths from their parent directory instead of
# appending to the current value: `save_path = save_path + ...` keeps growing
# the path every time this cell is re-run. dirname() of both 'saved_models/'
# and 'saved_models/<exp_name>.tar' is 'saved_models', so this cell is
# idempotent and gives the same result on the first and every later run.
exp_name = '46_all_fl_best_0.99'
save_path = os.path.join(os.path.dirname(save_path), exp_name + '.tar')
stats_path = os.path.join(os.path.dirname(stats_path), exp_name)

# Per-experiment overrides of the defaults set in the config cell.
loss_func = 'focal_loss'
depth = 4
wf = 6

train_ratio = 0.99

In [None]:
# Fix the backend: seed every RNG this notebook may touch so that
# Restart & Run All reproduces the same data split, shuffling order and
# model initialisation.
seed = 10
random.seed(seed)                 # Python's built-in RNG, in case any helper module uses it
np.random.seed(seed)              # NumPy's legacy global RNG
torch.manual_seed(seed)           # PyTorch default generator
torch.cuda.manual_seed_all(seed)  # all CUDA devices; silently ignored without a GPU
# Force deterministic cuDNN kernels and disable autotuning, which can pick
# different algorithms from run to run.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [None]:
if __name__ == "__main__":
    # Pick the compute device: CUDA when available, otherwise CPU.
    if torch.cuda.is_available():
        print('GPU is available.')
        device = torch.device('cuda')
    else:
        print('GPU is not available. Use CPU instead.')
        device = torch.device('cpu')

    # Data transform: to PIL image -> single-channel grayscale ->
    # pad by `pad` pixels -> tensor.
    tr_ori = tr.Compose([
        tr.ToPILImage(),
        tr.Grayscale(1),
        tr.Pad(pad),
        tr.ToTensor()
    ])
    # Offline augmentation: each DG generator presumably yields a transformed
    # copy of the training set (see DG for the exact behaviour).
    # Original data
    dataset_ori = DG.DatasetGen(train_path, gt_path, tr_ori)
    # Horizontal flip
    dataset_h = DG.DatasetHGen(train_path, gt_path, tr_ori)
    # Vertical flip
    dataset_v = DG.DatasetVGen(train_path, gt_path, tr_ori)
    # Horizontal + vertical flip (i.e. 180-degree rotation)
    dataset_hv = DG.DatasetHVGen(train_path, gt_path, tr_ori)
    # 90-degree counter-clockwise rotation
    dataset_r90 = DG.DatasetR90Gen(train_path, gt_path, tr_ori)
    # 270-degree counter-clockwise rotation
    dataset_r270 = DG.DatasetR270Gen(train_path, gt_path, tr_ori)
    # Deformation, using the parameters from the original U-Net paper
    dataset_ed = DG.DatasetEDGen(train_path, gt_path, tr_ori, 10, 3, [3, 0])
    # Transpose
    dataset_tp = DG.DatasetTPGen(train_path, gt_path, tr_ori)
    # Alternative transpose
    dataset_sktp = DG.DatasetSTPGen(train_path, gt_path, tr_ori)
    # Concatenate all variants into a single dataset
    dataset = ConcatDataset([dataset_ori, dataset_h, dataset_v,
                             dataset_hv, dataset_r90, dataset_r270,
                             dataset_ed, dataset_tp, dataset_sktp])

    # Train/validation split: train gets floor(train_ratio * N) samples,
    # validation gets the remainder.
    n_samples = len(dataset)
    train_size = int(np.floor(train_ratio * n_samples))
    val_size = n_samples - train_size
    # An empty split yields an empty DataLoader; fail here with a clear
    # message rather than somewhere inside the training loop.
    if train_size == 0 or val_size == 0:
        raise ValueError(
            f'train_ratio={train_ratio} on {n_samples} samples gives '
            f'train_size={train_size}, val_size={val_size}; both must be > 0.')

    # random_split draws from torch's default generator, seeded in the seed cell.
    train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
    train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
    # NOTE(review): shuffling the validation set is usually unnecessary and
    # consumes the global RNG; kept as-is to preserve earlier run results.
    val_loader = DataLoader(val_dataset, batch_size=1, shuffle=True)

    model = UNet(
                in_channels=in_channels,
                n_classes=n_classes,
                depth=depth,
                wf=wf,
                padding=padding,
                batch_norm=batch_norm,
                up_mode=up_mode).to(device)

    if optim_name == 'Adam':
        optim = torch.optim.Adam(model.parameters(), lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
    elif optim_name == 'SGD':
        optim = torch.optim.SGD(model.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay)
    else:
        # Previously `optim` was silently left undefined, surfacing later
        # as a confusing NameError.
        raise ValueError(f"Unsupported optim_name {optim_name!r}; expected 'Adam' or 'SGD'.")

    stats = Oper.run_model(
                model = model,
                optim = optim,
                train_loader = train_loader,
                val_loader = val_loader,
                device = device,
                save_path = save_path,
                train_size = train_size,
                val_size = val_size,
                epochs = epochs,
                pad = pad,
                lr = lr,
                betas = betas,
                eps = eps,
                weight_decay = weight_decay,
                loss_func=loss_func,
                gamma = gamma,
                alpha = alpha)

    print_data(epochs = epochs, stats = stats, stats_path = stats_path)

In [None]:
# Hand unused cached blocks from PyTorch's CUDA caching allocator back to the
# device once training is done; nothing to release on a CPU-only machine.
if torch.cuda.is_available():
    torch.cuda.empty_cache()