In [None]:
import os
import random

import torch

from dataset import get_loader
from architectures import UNet
from losses import FocalLoss
from operation import run

## 1. Global Variables and Reproducibility Settings

In [None]:
# Input data file, log root and checkpoint directory used by the cells below.
train_path = './_data/train.pkl'
log_dir = './_logs/'
save_dir = './_saved_models'

# exist_ok=True keeps this cell idempotent on re-run and avoids the
# check-then-create race of `if not os.path.exists(...): os.makedirs(...)`.
os.makedirs(log_dir, exist_ok=True)
os.makedirs(save_dir, exist_ok=True)

random_state = 42

# Seed every RNG the pipeline may draw from. Python's `random` is seeded too in
# case dataset/augmentation code uses it (NOTE(review): confirm in `dataset`).
random.seed(random_state)
torch.manual_seed(random_state)
torch.cuda.manual_seed_all(random_state)
# Trade cuDNN autotuning speed for run-to-run determinism.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

## 2. GPU Setup

In [None]:
def gpu_manager(model):
    """Place a model on the available GPU(s), or leave it on the CPU.

    Args:
        model: a ``torch.nn.Module``.

    Returns:
        The model unchanged when no GPU is visible; otherwise the model moved
        to CUDA, wrapped in ``torch.nn.DataParallel`` first when more than one
        GPU is visible. A short message about the available devices is printed.
    """
    n_gpus = torch.cuda.device_count()

    if n_gpus == 0:
        print('Only CPU is available.')
        return model

    if n_gpus == 1:
        print('Only 1 GPU is available.')
    else:
        print(f"{n_gpus} GPUs are available.")
        # Split each batch across all visible GPUs.
        model = torch.nn.DataParallel(model)

    return model.cuda()

## 3. Configuration

### 3.1. About Padding
With valid padding (`padding=False`), the convolutions shrink the feature maps, so the network output is smaller than its input. The input therefore has to be zero-padded beforehand.

The sizes must be calculated in advance from the network structure. In our case (default UNet):
- `input_size = 204`
- `pad_size = (input_size - output_size) // 2`

The corresponding network output is `116 × 116`, so it still has to be cropped to `112 × 112`.

With same padding (`padding=True`), no input padding is needed, i.e. `pad_size = 0`.

### 3.2. UNet Args:
- **in_channels (int)**: number of input channels
- **n_classes (int)**: number of output channels
- **depth (int)**: depth of the network
- **wf (int)**: number of filters in the first layer is `2**wf`
- **padding (bool)**: if True, apply padding so that the output has the same shape as the input. This may introduce artifacts
- **batch_norm (bool)**: if True, apply BatchNorm after each layer that has an activation function
- **up_mode (str)**: one of 'upconv' or 'upsample'. 'upconv' will use transposed convolutions for learned upsampling. 'upsample' will use bilinear upsampling.

In [None]:
# Input/output geometry. With valid padding the network output is smaller than the
# input (see 3.1), so the input is zero-padded by (204 - 112) // 2 = 46 per side.
# Use even, square output sizes only: odd sizes and rectangles are not handled.
resize = dict(
    pad_size=46,
    output_size=112,
)

# Random affine augmentation ranges.
# NOTE(review): units (degrees / fractions) are assumed; confirm in dataset.get_loader.
augs = dict(
    rotate=(-15, 15),
    scale=(0.85, 1.15),
    translate=(0.15, 0.15),
    shear=(-10, 10, -10, 10),
)

# Data loader settings.
# NOTE(review): whether val_size is a count or a fraction depends on get_loader.
data_params = dict(
    val_size=1,
    batch_size=64,
    num_workers=4,
)

# UNet constructor arguments (see 3.2). depth must be at most 4.
model_params = dict(
    in_channels=1,
    n_classes=2,
    depth=4,
    wf=5,
    padding=False,
    batch_norm=False,
    up_mode='upconv',
)

# Keyword arguments for torch.optim.Adam.
optim_params = dict(
    lr=3e-4,
    betas=(0.9, 0.999),
    eps=1e-08,
    weight_decay=1e-4,
)

# Keyword arguments for the loss. reduction='none' keeps per-element losses;
# presumably they are reduced inside `run` (confirm).
loss_params = dict(
    reduction='none',
)

In [None]:
# Resume model and optimizer from a saved checkpoint.
# map_location lets a checkpoint written on a GPU machine load on a CPU-only one.
# NOTE(review): the checkpoint holds a pickled scheduler object, not just state_dicts.
# On torch >= 2.6, torch.load defaults to weights_only=True and will refuse it; pass
# weights_only=False there, and only for trusted files (unpickling can execute code).
device = 'cuda' if torch.cuda.is_available() else 'cpu'
checkpoint = torch.load(os.path.join(save_dir, 'checkpoint0.tar.gz'), map_location=device)

model = gpu_manager(UNet(**model_params))

# The checkpoint may come from a DataParallel-wrapped model (keys prefixed with
# "module.") or from a bare one. Strip the prefix and load into the bare module,
# so the GPU count at save time and at load time need not match.
target = model.module if isinstance(model, torch.nn.DataParallel) else model
model_state = checkpoint['model_state_dict']
if model_state and all(key.startswith('module.') for key in model_state):
    model_state = {key[len('module.'):]: value for key, value in model_state.items()}
target.load_state_dict(model_state)

optim = torch.optim.Adam(model.parameters(), **optim_params)
optim.load_state_dict(checkpoint['optim_state_dict'])

# The unpickled scheduler still points at its own unpickled copy of the old
# optimizer, so its LR updates would never reach `optim`. Rebind it in place;
# checkpoint['scheduler'] is the same object that is passed to `run` later.
scheduler = checkpoint['scheduler']
if scheduler is not None and hasattr(scheduler, 'optimizer'):
    scheduler.optimizer = optim

# Epoch stored in the checkpoint (displayed as the cell output).
checkpoint['epoch']

In [None]:
# Arguments for `run`, built from the objects restored from the checkpoint.
# `epochs` is 300 minus the epoch stored in the checkpoint, i.e. the epochs left
# out of what is presumably a 300-epoch budget.
run_params = dict(
    dataloader=get_loader(train_path, resize, augs, **data_params),
    model=model,
    optim=optim,
    scheduler=checkpoint['scheduler'],
    criterion=torch.nn.CrossEntropyLoss(**loss_params),
    epochs=300 - checkpoint['epoch'],
    log_path=os.path.join(log_dir, 'log_33/'),
    save_path=os.path.join(save_dir, 'checkpoint33.tar.gz'),
)

In [None]:
# Start training with the data loader, model, optimizer, scheduler and loss
# collected in `run_params`; logs go to `log_path`, checkpoints to `save_path`.
run(**run_params)