# Train model

## References & related works
https://github.com/meetshah1995/pytorch-semseg/blob/master/train.py
https://github.com/bfortuner/pytorch_tiramisu/blob/master/camvid_dataset.py
https://github.com/bfortuner/pytorch_tiramisu/blob/master/tiramisu-pytorch.ipynb
https://github.com/bodokaiser/piwise/blob/master/piwise/transform.py
https://github.com/bodokaiser/piwise/blob/master/main.py    

Fix for slow loading under conda: https://github.com/pytorch/pytorch/issues/537

## Tensorboard loading

In [None]:
# get_ipython().system_raw('tensorboard --logdir /tmp/log --host 0.0.0.0 --port 6006 &')

In [None]:
# get_ipython().system_raw('lt --port 6006 >> url.txt 2>&1 &')

In [None]:
# !cat url.txt

In [None]:
from model.densenet import FCDenseNet103, FCDenseNet67, FCDenseNet57
from model.panelnet import PanelNet

from processing_utils import image_manipulations as i_manips
from processing_utils import runtime_logic
from processing_utils import analysis_utils
from processing_utils import load_data

import argparse
import os
import pickle
import random
from copy import deepcopy

import yaml
import torch
import numpy as np
from torch import nn
from torch import optim
from torch.utils.data import DataLoader
from torchvision import datasets, models
from torchvision.transforms import Compose, Normalize, ToTensor, ColorJitter, FiveCrop, CenterCrop, Lambda, RandomCrop, Resize

## Hyperparameters

In [None]:
# Seed every RNG the run may draw from so results are reproducible.
# torch.manual_seed seeds the CPU generator (used e.g. by DataLoader shuffling,
# and presumably by weight initialisation, which runs before the model is moved
# to the GPU) as well as all CUDA devices. numpy and `random` are seeded too in
# case data augmentation draws from them.
seed = 0
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

run_name = "semseg_test"
data_root = os.path.join(os.getcwd(), "dummy_data/dummy_tiles")
img_directory = os.path.join(data_root, "images/train/")
tile_directory = os.path.join(data_root, "labels/train/")

# Cache file for the per-channel normalisation statistics of this run.
normalize_params_file = f'outputs/norm_params/{run_name}_norm_params.yaml'
previous_model_state = None  # Path to a pickled model state dict to resume from, or None

# Can be set to None if TensorboardX isn't installed, or for a temporary log.
tensorboard_outpath = f'outputs/tensorboard/{run_name}_log'

# Colour per class; passed to SemSegAnalysis as `label_colors`.
colours = {'void': [255, 255, 255],
           'building': [255, 0, 0]}
classes = ['void', 'building']

n_epochs = 25
input_size = 400  # side length of the centre crop applied to inputs and labels
init_l_rate = 1e-4
lr_decay = 0.1
# lr_decay_patience = 3
lr_decay_epoch = [5, 10, 15]
w_decay = 1e-4

batch_size = 1
num_channels = 3
num_classes = len(classes)  # derived so it can never disagree with `classes`

# Per-class loss weights (uniform), sized from the class list.
class_weights = torch.ones(num_classes)

# NOTE(review): NLLLoss expects log-probabilities, so this assumes the model's
# forward output already ends in a log-softmax — confirm in model.densenet.
criterion = nn.NLLLoss(weight=class_weights)
if torch.cuda.is_available():
    criterion = criterion.cuda()

# Reporting interval in batches for each phase (passed as 'report_interval').
report_results_per_n_batches = {'train': 1, 'val': 1}
save_interval = 9999

# Passed to the trainer as 'shutdown'.
shutdown_after = False

In [None]:
# model = PanelNetTwo(num_classes,
#                 num_channels,
#                 lyr1_kernels=100,
#                 growth_rate=25).apply(analysis_utils.weights_init).cuda()

In [None]:
%%capture
# Output width follows the `num_classes` hyperparameter instead of a literal.
model = FCDenseNet57(n_classes=num_classes, in_channels=num_channels)
# nn.Module.apply calls the initialiser on every submodule (and the model
# itself) and returns the model; the cell's %%capture hides that repr.
model.apply(analysis_utils.weights_init)

In [None]:
# Optionally resume from a saved state dict configured via `previous_model_state`.
if previous_model_state is not None:
    # map_location='cpu' lets a checkpoint saved on a GPU load on a CPU-only
    # machine; the model is moved to the GPU just below when one is available.
    # torch.load unpickles the file, so only point this at trusted checkpoints.
    state_dict = torch.load(previous_model_state, map_location='cpu')
    model.load_state_dict(state_dict)

if torch.cuda.is_available():
    model = model.cuda()

## Data Loading

### Get normalization parameters

In [None]:
# Get normalisation parameters (per-channel means and standard deviations).
# They are cached in `normalize_params_file`; the training images are only
# loaded and processed when no cache exists, after which the cache is written.
if os.path.isfile(normalize_params_file):
    with open(normalize_params_file, 'r') as stream:
        # PyYAML >= 6 requires an explicit Loader; FullLoader matches the
        # implicit default of PyYAML 5.x. The file is produced by
        # analysis_utils.write_normalize_values below, not by external input.
        norm_params = yaml.load(stream, Loader=yaml.FullLoader)
else:
    all_imgs = i_manips.get_images(img_directory)
    norm_params = i_manips.get_normalize_params(all_imgs, num_bands=num_channels)
    analysis_utils.write_normalize_values(norm_params, normalize_params_file)

means = norm_params["means"]
sdevs = norm_params["sdevs"]

### Transformations

In [None]:
# Per-channel normalisation from the statistics loaded above. Slicing to
# num_channels keeps it in step with the model's input channel count instead
# of hard-coding three bands.
channel_means = list(means[:num_channels])
channel_sdevs = list(sdevs[:num_channels])

# Training inputs: light colour jitter as augmentation, then crop and normalise.
input_transforms = Compose([
    ColorJitter(0.05, 0.05, 0.05),
    CenterCrop(input_size),
    ToTensor(),
    Normalize(channel_means, channel_sdevs),
])

# Validation inputs: same crop and normalisation, without colour jitter.
val_transforms = Compose([
    CenterCrop(input_size),
    ToTensor(),
    Normalize(channel_means, channel_sdevs),
])

# Label images: same centre crop, then converted to a LongTensor (the integer
# target dtype NLLLoss requires). Presumably pixel values are class indices.
target_transforms = Compose([
    CenterCrop(input_size),
    Lambda(lambda x: torch.LongTensor(np.array(x))),
])

# Transforms applied jointly to input and label; interpreted by
# load_data.SemSegImageData.
joint_trans = ["hflip"]

In [None]:
## Load training and validation data
# Both splits share the data root, label transform and joint transforms;
# only the split name and the input transform differ.
shared_dataset_kwargs = dict(root_path=data_root,
                             target_transform=target_transforms,
                             joint_trans=joint_trans)

train_data = load_data.SemSegImageData(split="train",
                                       input_transform=input_transforms,
                                       **shared_dataset_kwargs)
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)

# NOTE(review): the validation split also receives joint_trans (["hflip"]) and
# a shuffling loader. If "hflip" is a random augmentation, validation metrics
# become non-deterministic — confirm this is intended.
val_data = load_data.SemSegImageData(split="val",
                                     input_transform=val_transforms,
                                     **shared_dataset_kwargs)
val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=True)

### Instantiate runtime

In [None]:
# RMSprop over all model parameters, with weight decay as an L2 penalty.
optimizer = optim.RMSprop(model.parameters(),
                          lr=init_l_rate,
                          weight_decay=w_decay)

# Analysis wrapper holding the model, class metadata, normalisation statistics
# and both loaders; its `train` method is invoked further down.
loaders = {'train_loader': train_loader, 'val_loader': val_loader}
train_analysis = runtime_logic.SemSegAnalysis(model, classes, means, sdevs,
                                              label_colors=colours,
                                              **loaders)

In [None]:
# Attach run bookkeeping to the analysis object: a loss tracker given the
# 'outputs' directory, output storage for this run name, and a visualiser
# for `tensorboard_outpath` (which may be set to None in the hyperparameters).
train_analysis.instantiate_loss_tracker('outputs')
train_analysis.loss_tracker.setup_output_storage(run_name)
train_analysis.instantiate_visualiser(tensorboard_outpath)

## Instantiate trainer

In [None]:
# Configuration bundle handed to SemSegAnalysis.train.
arguments = dict(
    # Model components
    run_name=run_name,
    optimizer=optimizer,
    criterion=criterion,

    # Hyperparameters
    n_epochs=n_epochs,
    batch_size=batch_size,
    lr_decay=lr_decay,
    lr_decay_epoch=lr_decay_epoch,
    # lr_decay_patience=lr_decay_patience,
    # class_weights=class_weights,

    # Saving & information retrieval
    report_interval=report_results_per_n_batches,
    save_interval=save_interval,

    shutdown=shutdown_after,
)

## Perform training

In [None]:
# Launch training with the configuration assembled in `arguments`.
train_analysis.train(arguments)

In [None]:
# train_analysis.loss_tracker.save_model(model, 0)