In [None]:
import torch
import imgaug.augmenters as iaa
import numpy as np

from architectures import UNet
from dataset import dataset_gen
import gpu_manager as gm
import oper
from stats import print_data

## 1. Paths

In [None]:
# Input data: preprocessed RibFrac splits (relative to the notebook's directory).
train_path = '../RibFrac/preprocessed/train/'
val_path = '../RibFrac/preprocessed/val/'

# Output artefacts: `save_path` is handed to `oper.run_model`,
# `stats_path` is handed to `print_data` at the end of the notebook.
save_path = 'saved_models/baseline.tar'
stats_path = 'stats/baseline'

## 2. Fix Random Seeds and Backend (for reproducible results)

In [None]:
# Single seed shared by every random number generator used in this notebook.
seed = 12

# NumPy's global RNG was previously left unseeded even though the augmentation
# pipeline below is stochastic (random rotate / scale / translate).
# NOTE(review): imgaug keeps its own global RNG; if augmentations still differ
# between runs, additionally call `imgaug.seed(seed)` — TODO confirm.
np.random.seed(seed)

# PyTorch CPU generator and the generators of all visible CUDA devices.
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

# Ask cuDNN for deterministic kernels and disable its auto-tuner, which may
# otherwise pick different algorithms from run to run.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

## 3. GPU Test
- TODO(3): Multi-GPU training/load balancing

In [None]:
# Select the compute device: CPU fallback first, otherwise let the project's
# GPU manager choose which CUDA device to use.
if not torch.cuda.is_available():
    print('GPU is not available. Use CPU instead.')
    device = torch.device('cpu')
else:
    print('GPU is available.')
    manager = gm.GPUManager()
    # `auto_choice()` returns whatever device spec the manager selects.
    device = torch.device(manager.auto_choice())

## 4. Data Preprocessing and Augmentation

In [None]:
# Side length every sample is resized to at the end of each pipeline.
IMG_SIZE = 128


def _with_resize(*augmenters):
    """Build a pipeline that applies `augmenters` in order, then resizes.

    Args:
        *augmenters: zero or more imgaug augmenters to run before the resize.

    Returns:
        iaa.Sequential: the given augmenters followed by a nearest-neighbour
        resize to `IMG_SIZE`. A fresh `Resize` instance is created per call.
    """
    return iaa.Sequential([
        *augmenters,
        iaa.size.Resize(IMG_SIZE, interpolation='nearest'),
    ])


# One pipeline per augmentation variant; all of them are passed to `dataset_gen`.
# Every Affine uses order=0 (nearest neighbour) so that interpolation never
# blends neighbouring values. The translate variant previously fell back to
# imgaug's default interpolation order, unlike rotate/scale.
# NOTE(review): presumably the same pipeline is applied to label masks —
# confirm in `dataset_gen`.
seqs_train = [
    _with_resize(),                                                       # identity
    _with_resize(iaa.flip.Fliplr(1.0)),                                   # always mirror
    _with_resize(iaa.geometric.Affine(rotate=(-15, 15), order=0)),        # random rotation
    _with_resize(iaa.geometric.Affine(scale=(0.8, 1.2), order=0)),        # random zoom
    _with_resize(iaa.geometric.Affine(translate_px=(-32, 32), order=0)),  # random shift
]

# Validation data is only resized, never augmented.
seqs_val = [_with_resize()]

train_dataset = dataset_gen(train_path, seqs_train)
val_dataset = dataset_gen(val_path, seqs_val)

print('train size:', len(train_dataset))
print('val size:', len(val_dataset))

## 5. Parameter Selection
### 5.1 Model Selection
#### 5.1.1 U-Net

In [None]:
# U-Net hyper-parameters (descriptions follow the UNet constructor's argument list):
#   in_channels (int): number of input channels
#   n_classes (int):   number of output channels
#   depth (int):       depth of the network
#   wf (int):          number of filters in the first layer is 2**wf
#   padding (bool):    if True, apply padding such that the input shape is the
#                      same as the output. This may introduce artifacts
#   batch_norm (bool): use BatchNorm after layers with an activation function
#   up_mode (str):     one of 'upconv' or 'upsample'.
#                      'upconv' will use transposed convolutions for learned
#                      upsampling; 'upsample' will use bilinear upsampling.
in_channels = 1
n_classes = 2
depth = 5
wf = 6
padding = True
batch_norm = False
up_mode = 'upconv'

# Build the network, then move it onto the selected device.
model = UNet(
    in_channels=in_channels,
    n_classes=n_classes,
    depth=depth,
    wf=wf,
    padding=padding,
    batch_norm=batch_norm,
    up_mode=up_mode,
)
model = model.to(device)

#### 5.1.2 Other Architectures
- TODO(1): other nn architectures

### 5.2 Optimizer Selection
#### 5.2.1 `Adam`

In [None]:
# Adam hyper-parameters, kept as separate names so they are easy to tune.
lr = 1e-5
betas = (0.9, 0.999)
eps = 1e-08
weight_decay = 0

# Optimise every parameter of the model built above.
optim = torch.optim.Adam(
    model.parameters(),
    lr=lr,
    betas=betas,
    eps=eps,
    weight_decay=weight_decay,
)

#### 5.2.2 Other Optimizers
- TODO(1): other optimizers

### 5.3 Loss function
#### 5.3.1 BCE/Focal Loss
- TODO(3): loss function init

In [None]:
# Loss selection; all three values are forwarded unchanged to `oper.run_model`.
loss_func = 'focal_loss'  # name of the loss, passed as a string
# NOTE(review): presumably the focal-loss class weight — confirm in `oper`.
alpha = 0.75
# NOTE(review): presumably the focal-loss focusing exponent — confirm in `oper`.
gamma = 0

#### 5.3.2 Other Loss Functions
- TODO(2): other loss functions

## 6. Model Training and Evaluation

In [None]:
# Training-run settings; `epochs` is reused by the statistics cell below.
epochs = 4
pad = 0
batch_size = 64

# Train and evaluate; everything configured above is passed through by keyword.
# The returned `stats` object is consumed by `print_data`.
stats = oper.run_model(
    device=device,
    train_dataset=train_dataset,
    val_dataset=val_dataset,
    model=model,
    optim=optim,
    loss_func=loss_func,
    gamma=gamma,
    alpha=alpha,
    epochs=epochs,
    pad=pad,
    batch_size=batch_size,
    save_path=save_path,
)

## 7. Statistics
- TODO(3): Tensorboard monitor

In [None]:
# Report the statistics returned by `oper.run_model`; `stats_path` is the
# location configured in the Paths section.
print_data(epochs=epochs, stats=stats, stats_path=stats_path)