### Path Setup

In [None]:
# Change these paths to match your environment.
# Root directory under which each model's checkpoints and trainer logs are stored.
MODEL_SAVE_PATH = "models"

# Data sources handed to ImageSampler below.
DATA = [
    # TFRecord file or dataset directory
]

### Includes

In [None]:
import multiprocessing
import os
import pathlib
from collections import OrderedDict

import cv2
import numpy as np
import torch
from matplotlib import pyplot as plt
from numpy import random
from scipy import signal
from torch.optim import Adam
from torch.utils.data import DataLoader, random_split
from tqdm.auto import tqdm

from models import get_model
from sampler import ImageSampler

# Prefer the GPU when one is available, otherwise fall back to the CPU so that
# `device` is always defined for the `.to(device)` calls in later cells.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

### Initialize Data Sampler

In [None]:
# Build the sampler over every configured data source and run its prepare() step.
dataset = ImageSampler(DATA)
dataset.prepare()

# Fraction of samples used for training, and the batch size used by the loaders.
train_test_split = 0.8
batch_size = 256

# Random split: the test set receives whatever the (floored) training share leaves over.
total_size = len(dataset)
train_size = int(total_size * train_test_split)
trainset, testset = random_split(dataset, [train_size, total_size - train_size])
print(f"Training: {len(trainset)}, Testing {len(testset)}")

In [None]:
# Preview twelve randomly drawn samples, each titled with its steering and throttle labels.
plt.style.use("classic")
fig, axes = plt.subplots(2, 6, figsize=(25, 8))
axes = axes.flatten()
for ax in axes:
    # numpy's randint excludes the upper bound, so the index is always valid.
    img, steering, throttle = dataset[random.randint(0, len(dataset))]
    # imshow expects RGB; the conversion implies the sampler returns BGR images.
    ax.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
    ax.set_title(f"{steering: .4f}, {throttle: .4f}")
    ax.axis("off")

In [None]:
def direction_metric(pred, act):
    """Count predictions whose steering sign matches the label on turning samples.

    Only rows whose true steering (column 0 of ``act``) exceeds 0.1 in
    magnitude are considered.

    Returns:
        A pair ``(matches, considered)``: a scalar tensor with the number of
        sign matches, and the number of turning rows as an int.
    """
    true_steer = act[:, 0]
    pred_steer = pred[:, 0]
    is_turn = true_steer.abs() > 0.1
    same_sign = torch.sign(pred_steer[is_turn]) == torch.sign(true_steer[is_turn])
    return same_sign.float().sum(), len(same_sign)

def angle_metric(pred, act):
    """Return the fraction of rows whose predicted steering is within 0.1 of the label.

    Steering is read from column 0 of both ``pred`` and ``act``; the result is
    a scalar tensor.
    """
    steering_error = torch.abs(act[:, 0] - pred[:, 0])
    within_tolerance = steering_error < 0.1
    return within_tolerance.float().mean()

def loss_fn(steering, throttle, steering_pred, throttle_pred, throttle_weight):
    """Return steering MSE plus ``throttle_weight`` times throttle MSE."""
    steering_error = steering - steering_pred
    throttle_error = throttle - throttle_pred
    steering_mse = (steering_error ** 2).mean()
    throttle_mse = (throttle_error ** 2).mean()
    return steering_mse + throttle_weight * throttle_mse

In [None]:
class Trainer:
    """Run the train/validate loop for one model and keep its metric history.

    Every training batch appends ``(loss, angle_metric, direction_metric)`` to
    ``train_log``; every epoch appends the averaged validation values to
    ``validation_log``. Checkpoints are written below ``save_dir/model.NAME``.

    The class relies on the notebook-level names ``device``, ``loss_fn``,
    ``angle_metric``, ``direction_metric`` and ``tqdm``.
    """

    # A new best validation value must also pass these limits before a
    # checkpoint is written for it.
    SAVE_LOSS_BELOW = 0.02
    SAVE_ANGLE_ABOVE = 0.6
    SAVE_DIRECTION_ABOVE = 0.8

    def __init__(self, save_dir, model: torch.nn.Module, optim: torch.optim.Optimizer, turning_weight=5, epochs=200, throttle_weight=0.2):
        """Set up bookkeeping and create the model's output directory.

        Args:
            save_dir: Root directory; files go into ``save_dir/model.NAME``.
            model: Network to train; must expose a ``NAME`` attribute.
            optim: Optimizer already bound to the model's parameters.
            turning_weight: Stored on the instance but not used by the loss.
            epochs: Epoch count at which ``train`` stops.
            throttle_weight: Weight of the throttle term passed to ``loss_fn``.
        """
        self.model = model
        self.optim = optim
        self.turning_weight = turning_weight
        self.throttle_weight = throttle_weight
        self.epochs = epochs

        self.save_dir = pathlib.Path(save_dir).joinpath(model.NAME)
        self.save_dir.mkdir(parents=True, exist_ok=True)

        self.train_log = []  # loss, angle, direction
        self.validation_log = []  # loss, angle, direction

        self.best_loss = np.inf
        self.best_angle_metric = 0
        self.best_direction_metric = 0

        # Number of completed epochs; lets training resume where it stopped.
        self.i = 0

    def load(self, fname):
        """Restore logs, best metrics and the epoch counter written by ``save``."""
        # The context manager closes the archive; scalars are converted to
        # plain Python numbers instead of being kept as 0-d arrays.
        with np.load(fname) as data:
            self.i = int(data["i"])
            self.train_log = data["train_log"].tolist()
            self.validation_log = data["validation_log"].tolist()
            self.best_loss = float(data["best_loss"])
            self.best_angle_metric = float(data["best_angle_metric"])
            self.best_direction_metric = float(data["best_direction_metric"])

    def save(self, fname):
        """Write ``last.pth`` into the model directory and the trainer state to ``fname``."""
        self._save_checkpoint("last.pth")

        np.savez_compressed(
            fname,
            train_log=self.train_log,
            validation_log=self.validation_log,
            i=self.i,
            best_loss=self.best_loss,
            best_angle_metric=self.best_angle_metric,
            best_direction_metric=self.best_direction_metric
        )

    def train(self, sampler_train, sampler_test):
        """Train until ``self.i`` reaches ``self.epochs``, validating after every epoch.

        Args:
            sampler_train: Iterable of ``(img, steering, throttle)`` training batches.
            sampler_test: Iterable of ``(img, steering, throttle)`` validation batches.
        """
        epochs = self.epochs
        batches_train = len(sampler_train)
        batches_test = len(sampler_test)

        epochs_bar = tqdm(total=epochs)
        epochs_bar.set_description("Epochs")
        batch_bar = tqdm(total=batches_train)

        # Show the epochs already completed when resuming.
        epochs_bar.update(self.i)
        epochs_bar.refresh()
        try:
            while self.i < epochs:
                self._train_epoch(sampler_train, batch_bar, batches_train)
                averages = self._validate_epoch(sampler_test, batch_bar, batches_test)

                self.validation_log.append(averages)
                self._update_best(*averages)

                # Slow for large models, so last.pth is only written by save().
                batch_bar.refresh()
                epochs_bar.update()
                self.i += 1
        finally:
            # Release the progress bars even when training is interrupted.
            batch_bar.close()
            epochs_bar.close()

    def _save_checkpoint(self, name):
        """Write the model and optimizer state dicts to ``save_dir/name``."""
        torch.save({
            "state": self.model.state_dict(),
            "optim": self.optim.state_dict(),
        }, os.path.join(self.save_dir, name))

    def _prepare_batch(self, img, steering, throttle):
        """Move one batch to ``device`` and return ``(X, Y)``."""
        Y = torch.stack([steering, throttle], dim=1).type(torch.float32).to(device)
        # Move the last axis to position 1 (presumably N,H,W,C -> N,C,H,W) and
        # scale by 1/256 (presumably 8-bit pixel values -- confirm against the sampler).
        X = img.to(device).permute(0, 3, 1, 2) / 256
        return X, Y

    def _postfix(self, loss, angle, direction):
        """Build the progress-bar postfix from current and best metric values."""
        return OrderedDict(
            Loss=f"{loss: .3f}",
            Best_loss=f"{self.best_loss: .3f}",
            Angle=f"{angle: .3f}",
            Best_angle=f"{self.best_angle_metric: .3f}",
            Direction=f"{direction: .3f}",
            Best_Dir=f"{self.best_direction_metric: .3f}"
        )

    def _train_epoch(self, sampler_train, batch_bar, batches_train):
        """Run one optimization pass over ``sampler_train`` and log every batch."""
        # Validation leaves the model in eval mode, so switch back each epoch.
        self.model.train()
        batch_bar.set_description("Training")
        batch_bar.reset(batches_train)
        for img, steering, throttle in sampler_train:
            X, Y = self._prepare_batch(img, steering, throttle)

            self.optim.zero_grad()
            Y_pred = self.model(X)

            # Loss calculation and backpropagation
            loss = loss_fn(Y[:, 0], Y[:, 1], Y_pred[:, 0], Y_pred[:, 1], throttle_weight=self.throttle_weight)
            loss.backward()
            self.optim.step()

            # Extra metrics to grade performance by
            loss = loss.item()
            ang_metric = angle_metric(Y_pred, Y).item()
            dir_metric, num = direction_metric(Y_pred, Y)
            # A batch without turning samples has no defined direction score.
            dir_metric = (dir_metric / num).item() if num > 0 else np.nan

            self.train_log.append((loss, ang_metric, dir_metric))

            batch_bar.set_postfix(ordered_dict=self._postfix(loss, ang_metric, dir_metric))
            batch_bar.update()

    def _validate_epoch(self, sampler_test, batch_bar, batches_test):
        """Evaluate on ``sampler_test`` and return ``[loss, angle, direction]`` averages."""
        # Evaluate with inference behaviour (e.g. BatchNorm/Dropout layers, if any).
        self.model.eval()
        batch_bar.set_description("Validation")
        batch_bar.reset(batches_test)
        validation_avg = torch.zeros(batches_test, 3)  # loss, angle, direction_sum
        direction_num = 0

        with torch.no_grad():
            for j, (img, steering, throttle) in enumerate(sampler_test):
                X, Y = self._prepare_batch(img, steering, throttle)
                Y_pred = self.model(X)

                val_loss = loss_fn(Y[:, 0], Y[:, 1], Y_pred[:, 0], Y_pred[:, 1], throttle_weight=self.throttle_weight)
                ang_metric = angle_metric(Y_pred, Y)
                dir_metric, num = direction_metric(Y_pred, Y)
                validation_avg[j] = torch.stack([val_loss, ang_metric, dir_metric])
                direction_num += num

                batch_bar.set_postfix(ordered_dict=self._postfix(
                    val_loss.item(),
                    ang_metric.item(),
                    dir_metric.item() / num if num > 0 else np.nan,
                ))
                batch_bar.update()

        # Loss and angle are averaged per batch; direction is averaged over
        # every turning sample seen during the epoch.
        totals = validation_avg.sum(dim=0)
        totals[:2] /= batches_test
        totals[2] = totals[2] / direction_num if direction_num > 0 else float("nan")
        return totals.tolist()

    def _update_best(self, val_loss, ang_metric, dir_metric):
        """Record new best validation values and checkpoint those that pass the limits."""
        if val_loss < self.best_loss:
            self.best_loss = val_loss
            if self.best_loss < self.SAVE_LOSS_BELOW:
                self._save_checkpoint("best_loss.pth")
        if ang_metric > self.best_angle_metric:
            self.best_angle_metric = ang_metric
            if self.best_angle_metric > self.SAVE_ANGLE_ABOVE:
                self._save_checkpoint("best_angle.pth")
        if dir_metric > self.best_direction_metric:
            self.best_direction_metric = dir_metric
            if self.best_direction_metric > self.SAVE_DIRECTION_ABOVE:
                self._save_checkpoint("best_dir.pth")

### Train

In [None]:
# Cache of Trainer objects keyed by model name; re-running this cell empties it.
trainers = dict()

In [None]:
all_models = [
    # Pytorch Hub
    # "alexnet",
    # "vgg16_bn",
    "resnet34",
    "googlenet",
    # Custom
    "cnn"
]

save_dir = pathlib.Path(MODEL_SAVE_PATH)
load_trainer = True


def _shutdown_workers(loader):
    """Stop the persistent worker processes of ``loader`` if it started any."""
    # `_iterator` / `_shutdown_workers` are private DataLoader details, so look
    # them up defensively instead of assuming they exist.
    iterator = getattr(loader, "_iterator", None)
    shutdown = getattr(iterator, "_shutdown_workers", None)
    if callable(shutdown):
        shutdown()


def _move_trainer(trainer, target):
    """Move the trainer's model to ``target`` and re-sync the optimizer state."""
    trainer.model = trainer.model.to(target)
    # Reloading the optimizer's own state dict is relied on to place its state
    # tensors on the same device as the (just moved) parameters.
    trainer.optim.load_state_dict(trainer.optim.state_dict())


for model_name in all_models:
    print(f"Training {model_name}")

    if model_name in trainers:
        trainer = trainers[model_name]
        # move the model back to device
        _move_trainer(trainer, device)

    else:
        model = get_model(model_name)().to(device)
        optimizer = Adam(model.parameters(), lr=1e-4)
        save_model = save_dir.joinpath(model.NAME).joinpath("last.pth")
        if save_model.exists():
            print(f"Loading model from {save_model}")
            # map_location keeps the load working when the checkpoint was
            # written on a device that is not available now.
            state = torch.load(save_model, map_location=device)
            model.load_state_dict(state["state"])
            optimizer.load_state_dict(state["optim"])

        trainer = Trainer(save_dir, model, optimizer, turning_weight=5, epochs=1000)
        save_trainer = save_dir.joinpath(model.NAME).joinpath("trainer_log.npz")
        if load_trainer and save_trainer.exists():
            print(f"Loading trainer from {save_trainer}")
            trainer.load(save_trainer)
        trainers[model_name] = trainer
        del model, optimizer

    if trainer.i < trainer.epochs:
        # Defined up front so the cleanup below works even if building a
        # DataLoader fails.
        sampler_train = sampler_test = None
        num_workers = multiprocessing.cpu_count()
        try:
            sampler_train = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=num_workers, persistent_workers=True)
            sampler_test = DataLoader(testset, batch_size=batch_size, shuffle=True, num_workers=num_workers, persistent_workers=True)
            trainer.train(sampler_train, sampler_test)
        finally:
            # Move the model to CPU so the next model has the device to itself,
            # then persist progress (this also runs after an interrupt).
            _move_trainer(trainer, 'cpu')
            trainer.save(save_dir.joinpath(trainer.model.NAME).joinpath("trainer_log.npz"))

            # Best-effort worker shutdown: report problems instead of hiding them.
            for loader in (sampler_train, sampler_test):
                try:
                    _shutdown_workers(loader)
                except Exception as exc:
                    print(f"Could not shut down DataLoader workers: {exc}")
    else:
        # Nothing left to train: do not keep this model on the device.
        _move_trainer(trainer, 'cpu')

In [None]:
# Reload
# Re-import the local `models` module so edits made to it on disk take effect
# without restarting the kernel, then rebind `get_model` to the reloaded version.
import importlib
import models
importlib.reload(models)
from models import get_model

In [None]:
# Clean cache
import gc

# Drop notebook-level references that can keep models alive: the working
# variables plus the `_N` / `_iN` names for N below 1000.
g = globals()
del_list = ["model", "trainer", "optimizer"]
del_list += [f"_{i}" for i in range(1000)]
del_list += [f"_i{i}" for i in range(1000)]
for i in del_list:
    g.pop(i, None)

# Collect twice, then release cached CUDA memory.
gc.collect()
gc.collect()
torch.cuda.empty_cache()

### Test

In [None]:
# Shuffled batches of 12 drawn from the full dataset -- one sample per axis of
# the 2x6 preview grid used below.
test_sampler = DataLoader(dataset=dataset, batch_size=12, shuffle=True)
test_iterator = iter(test_sampler)

In [None]:
# Pick the trainer to inspect, put its model on the device, and turn both logs
# into arrays with columns (loss, angle, direction) for the plots below.
trainer = trainers["cnn"]
model = trainer.model.to(device)
train_log, validation_log = (
    np.array(history) for history in (trainer.train_log, trainer.validation_log)
)

In [None]:
# Draw the next batch (raises StopIteration once test_iterator is exhausted).
img, steering, throttle = next(test_iterator)
Y = torch.stack([steering, throttle], dim=1).type(torch.float32).to(device)
X = img.to(device).permute(0, 3, 1, 2) / 256

# Predict with inference behaviour, then restore whatever mode the model was
# in so a later training run is not affected.
was_training = model.training
model.eval()
try:
    with torch.no_grad():
        Y_pred = model(X)
finally:
    model.train(was_training)

val_loss = loss_fn(Y[:, 0], Y[:, 1], Y_pred[:, 0], Y_pred[:, 1], throttle_weight=0.2)
print(val_loss)

plt.style.use("classic")
fig, axes = plt.subplots(2, 6, figsize=(40, 8))
axes = axes.flatten()


def p_fn(x):
    """Format a sequence of numbers as comma-separated 4-decimal values."""
    return ','.join([f'{i: .4f}' for i in x])


for i, ax in enumerate(axes):
    ax.axis("off")
    # The final batch of the loader may hold fewer samples than there are axes.
    if i >= len(img):
        continue
    # Reverse the channel axis for display (presumably BGR -> RGB).
    ax.imshow(img[i].cpu().numpy()[..., ::-1])
    ax.set_title(f"Model:{p_fn(Y_pred[i].tolist())}\nTrue: {p_fn([steering[i].item(), throttle[i].item()])}")

### Export

In [None]:
# Build the OpenBot variant of the selected model and export it to ONNX.
# NOTE(review): get_model_openbot is not imported or defined in any visible
# cell -- presumably it comes from `models`; confirm, otherwise this raises NameError.
model_openbot = get_model_openbot(model.NAME)()
# Copy the trained weights into the OpenBot variant.
model_openbot.load_state_dict(model.state_dict())
# Example input handed to the exporter.
# NOTE(review): shape (1, 224, 224, 3) looks channels-last -- confirm it matches
# what the OpenBot model expects.
dummy_input = torch.randn(1, 224, 224, 3)
# The file is written as "<model.NAME>.onnx" relative to the working directory.
torch.onnx.export(model_openbot, dummy_input, model.NAME + ".onnx", verbose=False, input_names=["img_input"])

### Plot Loss

In [None]:
# Loss curves: per-batch training loss, per-epoch validation loss, and a
# 100-batch moving average of the training loss.
plt.style.use("classic")
plt.figure(figsize=(25, 4))

ax = plt.subplot(1, 3, 1)
ax.plot(train_log[:, 0], '.', markersize=1, color="black")
ax.set_yscale('log')

ax = plt.subplot(1, 3, 2)
ax.plot(validation_log[:, 0], '-', markersize=3, color="black")
ax.set_yscale("log")

ax = plt.subplot(1, 3, 3)
ax.plot(signal.convolve(train_log[:, 0], np.ones(100) / 100, 'valid'), '.', markersize=1, color="red")

### Plot Angle Metric

In [None]:
# Angle metric: per-batch training values, per-epoch validation values, and a
# 100-batch moving average of the training values.
plt.style.use("classic")
plt.figure(figsize=(25, 4))

ax = plt.subplot(1, 3, 1)
ax.plot(train_log[:, 1], '.', markersize=1, color="black")

ax = plt.subplot(1, 3, 2)
ax.plot(validation_log[:, 1], '-', markersize=3, color="black")

ax = plt.subplot(1, 3, 3)
ax.plot(signal.convolve(train_log[:, 1], np.ones(100) / 100, 'valid'), '.', markersize=1, color="red")

### Plot Direction Metric

In [None]:
# Direction metric: per-batch training values, per-epoch validation values, and
# a 100-batch moving average of the training values.
plt.style.use("classic")
plt.figure(figsize=(25, 4))
plt.subplot(1, 3, 1)
plt.plot(train_log[:, 2], '.', markersize=1, color="black")

plt.subplot(1, 3, 2)
plt.plot(validation_log[:, 2], '-', markersize=3, color="black")

plt.subplot(1, 3, 3)
# Batches without turning samples are logged as NaN. Average only over the
# defined entries of each window so one NaN cannot wipe out the whole curve
# (an FFT-based convolution would spread it to every output value).
direction = train_log[:, 2].astype(float)
has_value = ~np.isnan(direction)
window = np.ones(100)
window_sum = signal.convolve(np.where(has_value, direction, 0.0), window, 'valid', method='direct')
window_count = signal.convolve(has_value.astype(float), window, 'valid', method='direct')
with np.errstate(invalid="ignore", divide="ignore"):
    # A window with no defined entry stays NaN and is simply not drawn.
    smoothed = window_sum / window_count
plt.plot(smoothed, '.', markersize=1, color="red")