### Disclaimer
Distribution authorized to U.S. Government agencies and their contractors. Other requests for this document shall be referred to the MIT Lincoln Laboratory Technology Office.

This material is based upon work supported by the Under Secretary of Defense for Research and Engineering under Air Force Contract No. FA8702-15-D-0001. Any opinions, findings, conclusions or recommendations expressed in this material are those of the author(s) and do not necessarily reflect the views of the Under Secretary of Defense for Research and Engineering.

© 2019 Massachusetts Institute of Technology.

The software/firmware is provided to you on an As-Is basis

Delivered to the U.S. Government with Unlimited Rights, as defined in DFARS Part 252.227-7013 or 7014 (Feb 2014). Notwithstanding any copyright notice, U.S. Government rights in this work are defined by DFARS 252.227-7013 or DFARS 252.227-7014 as detailed above. Use of this work other than as specifically authorized by the U.S. Government may violate any copyrights that exist in this work.

# Train Segmentation Model

### Contents
- [Configuration](#Configuration)
- [Define Model](#Define-Model)
- [Train Model](#Train-Model)
- [Visualize Performance](#Visualize-Performance)
- [Export Model to ONNX](#Export-Model-to-ONNX)

In [None]:
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "1"

import cv2
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import numpy as np

%load_ext autoreload
%autoreload 2
from tesse_semantic_segmentation.data import TESSEDataset
from tesse_semantic_segmentation.utils import GOSEEK_CLASSES, CrossEntropy

import segmentation_models_pytorch as smp

# Configuration

In [None]:
# Training hyper-parameters.
BATCH_SIZE = 32
EPOCHS = 20

# One class index per entry of GOSEEK_CLASSES.
N_CLASSES = len(GOSEEK_CLASSES)

# Target image size as (height, width); reversed where OpenCV needs (width, height).
IMG_RESOLUTION = (256, 320)

# Input folders holding the RGB frames and the matching segmentation images.
RGB_IMG_DIR = "data/goseek-v0.1.0-v2/rgb/"
SEGMENTATION_IMG_DIR = "data/goseek-v0.1.0-v2/segmentation/"

# Output folder for model checkpoints; created (with parents) if it is missing.
log_dir = Path("./goseek-v0.1.0-weights-v2")
log_dir.mkdir(parents=True, exist_ok=True)

## Create Datasets

In [None]:
def _in_scenes(paths, scenes):
    """ Keep the paths whose file stem ends in one of the given scene digits. """
    return [path for path in paths if int(path.stem[-1]) in scenes]


# Gather every RGB frame and segmentation image, sorted by path.
imgs = sorted(Path(RGB_IMG_DIR).glob("*png"))
labels = sorted(Path(SEGMENTATION_IMG_DIR).glob("*png"))

# The scene id is taken from the last character of each file stem.
training_scenes = (1,)
validation_scenes = (2,)

train_imgs = _in_scenes(imgs, training_scenes)
train_labels = _in_scenes(labels, training_scenes)

val_imgs = _in_scenes(imgs, validation_scenes)
val_labels = _in_scenes(labels, validation_scenes)

# Each split must have exactly one label per image.
assert len(train_imgs) == len(train_labels)
assert len(val_imgs) == len(val_labels)

In [None]:
def preprocessor(image, label):
    """ Preprocessor to resize an image / label pair to `IMG_RESOLUTION`.

    Args:
        image: Image array; resized with bilinear interpolation.
        label: Label array; resized with nearest neighbor interpolation
            so that label values are not blurred together.

    Returns:
        Tuple of (resized image, resized label).
    """
    # opencv takes (width, height), so flip the (height, width) resolution
    target_size = IMG_RESOLUTION[::-1]

    image = cv2.resize(image, target_size, interpolation=cv2.INTER_LINEAR)
    label = cv2.resize(label, target_size, interpolation=cv2.INTER_NEAREST)
    return image, label

In [None]:
# Dataset for train images. Built from the training-scene split only, so that
# the validation scenes stay unseen during training.
train_dataset = TESSEDataset(
    train_imgs, train_labels, N_CLASSES, preprocessor=preprocessor
)

# Dataset for validation images
valid_dataset = TESSEDataset(val_imgs, val_labels, N_CLASSES, preprocessor=preprocessor)

# Define Model

In [None]:
# Encoder backbone, the pretrained weights it starts from, and the compute device.
ENCODER = "resnet18"
PRETRAINED_WEIGHTS = "imagenet"
DEVICE = "cuda"

# Preprocessing function registered for this encoder / weights pair.
# NOTE(review): not referenced by any other visible cell - confirm whether the
# dataset is expected to apply it.
preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER, PRETRAINED_WEIGHTS)

# U-Net with one output channel per class.
model = smp.Unet(
    classes=N_CLASSES,
    encoder_name=ENCODER,
    encoder_weights=PRETRAINED_WEIGHTS,
)

In [None]:
# Training batches are reshuffled every epoch.
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=12,
)

# Validation batches keep the dataset order.
valid_loader = DataLoader(
    dataset=valid_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=4,
)

In [None]:
# Per-class weights computed by the training dataset, moved to DEVICE for the loss.
inverse_class_frequecy = (
    train_dataset.calculate_inverse_class_frequency().to(DEVICE)
)
loss = CrossEntropy(weights=inverse_class_frequecy)

# IoU is the score used later to pick the best checkpoint.
metrics = [smp.utils.metrics.IoU(threshold=0.5)]

# A single parameter group covering every model parameter.
optimizer = torch.optim.Adam([{"params": model.parameters(), "lr": 1e-4}])

In [None]:
# Runner for one training pass; it is the only one given the optimizer.
train_epoch = smp.utils.train.TrainEpoch(
    model,
    loss=loss,
    metrics=metrics,
    optimizer=optimizer,
    device=DEVICE,
    verbose=True,
)

# Runner for one validation pass, using the same loss and metrics.
valid_epoch = smp.utils.train.ValidEpoch(
    model,
    loss=loss,
    metrics=metrics,
    device=DEVICE,
    verbose=True,
)

# Train Model

In [None]:
# Learning-rate schedule: at this (0-based) epoch index the learning rate of
# every optimizer parameter group is lowered to DECAYED_LR.
# NOTE(review): the loop runs epochs 0..EPOCHS-1, so with EPOCHS = 20 this
# step is never reached - raise EPOCHS or lower LR_DECAY_EPOCH if the decay
# is wanted.
LR_DECAY_EPOCH = 25
DECAYED_LR = 1e-5

# Best validation IoU seen so far.
max_score = 0

for i in range(EPOCHS):

    print(f"\nEpoch: {i}")
    train_logs = train_epoch.run(train_loader)
    validation_logs = valid_epoch.run(valid_loader)

    # save model if it's the current best (the file name uses a 1-based epoch)
    if max_score < validation_logs["iou_score"]:
        max_score = validation_logs["iou_score"]
        torch.save(model, f"{log_dir}/{ENCODER}-epoch-{i+1}.pth")

    if i == LR_DECAY_EPOCH:
        # The optimizer holds all model parameters, not only the decoder's,
        # so the message reports a plain learning-rate decrease.
        for param_group in optimizer.param_groups:
            param_group["lr"] = DECAYED_LR
        print("Decrease learning rate to 1e-5!")

# Visualize Performance

In [None]:
%matplotlib notebook

In [None]:
# Put the model into evaluation mode; the returned module is discarded.
_ = model.eval()

In [None]:
# Pick a random validation sample.
idx = np.random.randint(len(valid_dataset))
img, label = valid_dataset[idx]

# Inference only: no gradients are needed, so skip building the autograd graph.
with torch.no_grad():
    pred = model(torch.tensor(img[np.newaxis]).to(DEVICE))[0]

# Left: input image (channels moved last), middle: predicted class map,
# right: ground-truth class map.
fig, ax = plt.subplots(1, 3)
ax[0].imshow(img.transpose(1, 2, 0))
ax[1].imshow(pred.argmax(0).cpu().numpy())
ax[2].imshow(label.argmax(0))

# Export Model to ONNX

In [None]:
# Fill in both values before running the export cells.
MODEL_WEIGHT_PATH = ""
ONNX_FILE_NAME = ""

# `Path("")` is the current directory, which always exists, so require an
# actual file in order to reject the empty placeholder.
assert Path(MODEL_WEIGHT_PATH).is_file(), "Must give valid weight path"
assert ONNX_FILE_NAME.endswith(
    ".onnx"
), "Must give ONNX file name with extension `.onnx`"

In [None]:
class ArgMaxModel(nn.Module):
    """ Segmentation model wrapper to return 1 channel class prediction instead of
    per-class scores.

    Args:
        segmentation_model: Module to wrap. When omitted, the notebook-level
            `model` is wrapped, so `ArgMaxModel()` keeps working.
    """

    def __init__(self, segmentation_model=None):
        super().__init__()
        # Resolve the fallback at construction time, not at definition time.
        self.model = model if segmentation_model is None else segmentation_model

    def forward(self, x):
        """ Return, per position, the index of the largest value along dim 1
        of the wrapped model's output. """
        return self.model(x).argmax(dim=1)

In [None]:
# Wrap the in-memory model and move the wrapper to the CPU.
# NOTE(review): no visible cell loads MODEL_WEIGHT_PATH, so what gets wrapped
# is whatever `model` currently holds - confirm this is the intended checkpoint.
argmax_model = ArgMaxModel().to(torch.device("cpu"))

In [None]:
# Dummy all-ones input: batch of 1, 3 channels, IMG_RESOLUTION spatial size.
x_in = torch.ones(1, 3, *IMG_RESOLUTION, requires_grad=True)

In [None]:
# Export the argmax-wrapped model to ONNX_FILE_NAME using the dummy input.
torch.onnx.export(
    argmax_model,
    x_in,
    ONNX_FILE_NAME,
    verbose=True,
)