In [None]:
import os
import json
import warnings
import numpy as np
from settings import IMAGE_HEIGHT, IMAGE_WIDTH, OUT_DIR

# NOTE(review): this silences every warning process-wide, which also hides useful
# library deprecation / API-misuse warnings; consider narrowing it to specific categories.
warnings.filterwarnings("ignore")

# Create the output directory (including any missing parents); no-op if it already exists.
# `exist_ok=True` avoids the check-then-create race of `os.path.exists` + `os.mkdir`.
os.makedirs(OUT_DIR, exist_ok=True)

### Define the train function

In [None]:
import glob
from time import time
import torch
from tqdm import tqdm
import torch.nn as nn
from torch.utils.data import DataLoader
from core.models.nts_net import NTSModel
from core.loss import list_loss, ranking_loss
from torch.optim.lr_scheduler import MultiStepLR
import shutil

def _make_optimizer(module, hr, prefix):
    """Build an SGD optimizer for `module` from the `{prefix}_lr`, `{prefix}_weight_decay`
    and `{prefix}_momentum` entries of the hyperparameter dict `hr`."""
    return torch.optim.SGD(
        module.parameters(),
        lr=hr[f"{prefix}_lr"],
        weight_decay=hr[f"{prefix}_weight_decay"],
        momentum=hr[f"{prefix}_momentum"],
    )


def _train_one_epoch(model, loader, optimizers, criterion, proposal_num, device):
    """Run one optimisation pass over `loader`.

    Returns:
        (mean concat loss, concat accuracy), both weighted by the number of samples seen.
    """
    model.train()
    loss_sum, n_correct, n_seen = 0.0, 0, 0
    with tqdm(loader, desc="train") as pbar:
        for inputs, labels in pbar:
            inputs, labels = inputs.to(device), labels.to(device)
            n = inputs.size(0)  # the last batch may be smaller than the loader's batch_size

            for optimizer in optimizers:
                optimizer.zero_grad()

            resnet_logits, concat_logits, part_logits, _, top_n_proba = model(inputs)

            # Repeat each label once per proposal so every region's logits get a target.
            part_logits_flat = part_logits.view(n * proposal_num, -1)
            part_labels = labels.unsqueeze(1).repeat(1, proposal_num).view(-1)

            resnet_loss = criterion(resnet_logits, labels)
            concat_loss = criterion(concat_logits, labels)
            partcls_loss = criterion(part_logits_flat, part_labels)
            # Per-proposal losses reshaped to (n, proposal_num); used as the ranking target.
            part_loss = list_loss(part_logits_flat, part_labels).view(n, proposal_num)
            rank_loss = ranking_loss(top_n_proba, part_loss, proposal_num=proposal_num)

            loss = resnet_loss + concat_loss + rank_loss + partcls_loss
            loss.backward()
            for optimizer in optimizers:
                optimizer.step()

            # Only the concat branch is reported, matching the validation metric.
            loss_sum += concat_loss.item() * n
            n_correct += (concat_logits.argmax(dim=1) == labels).sum().item()
            n_seen += n
            pbar.set_postfix_str("Train loss: {:.4f}, Train accuracy: {:.4f}".format(loss_sum / n_seen, n_correct / n_seen))

    denom = max(n_seen, 1)
    return loss_sum / denom, n_correct / denom


@torch.no_grad()
def _evaluate(model, loader, criterion, device):
    """Score `model` on `loader` in eval mode without tracking gradients.

    Returns:
        (mean concat loss, concat accuracy), both weighted by the number of samples seen.
    """
    # Eval mode stops BatchNorm from updating running stats on validation data and disables dropout.
    model.eval()
    loss_sum, n_correct, n_seen = 0.0, 0, 0
    with tqdm(loader, desc="val") as pbar:
        for inputs, labels in pbar:
            inputs, labels = inputs.to(device), labels.to(device)
            n = inputs.size(0)

            _, concat_logits, _, _, _ = model(inputs)

            loss_sum += criterion(concat_logits, labels).item() * n
            n_correct += (concat_logits.argmax(dim=1) == labels).sum().item()
            n_seen += n
            pbar.set_postfix_str("Val loss: {:.4f}, Val accuracy: {:.4f}".format(loss_sum / n_seen, n_correct / n_seen))

    denom = max(n_seen, 1)
    return loss_sum / denom, n_correct / denom


def train(train, val, n_classes, epochs, batch_size, hr, scheduler_gamma=0.5):
    """Train an NTSModel and keep the checkpoint with the best validation accuracy.

    Args:
        train: training dataset yielding (image tensor, label) pairs.
        val: validation dataset yielding (image tensor, label) pairs.
        n_classes: number of target classes.
        epochs: number of epochs to run (must be >= 1).
        batch_size: DataLoader batch size for both loaders.
        hr: hyperparameter dict. Reads "proposal_num" and, for each of
            "resnet", "navigator", "concat", "partcls", the keys
            "<name>_lr", "<name>_weight_decay", "<name>_momentum".
        scheduler_gamma: LR decay factor applied at 25%, 50% and 75% of `epochs`.

    Returns:
        dict with per-epoch lists "train_loss", "val_loss", "train_accuracy",
        "val_accuracy" (concat-branch loss/accuracy, sample-weighted).

    Side effects:
        Writes the best checkpoint to OUT_DIR/train_<unix time>/model.ckpt and
        copies it to OUT_DIR/latest_model.ckpt.

    Raises:
        ValueError: if `epochs` < 1.
    """
    if epochs < 1:
        raise ValueError(f"epochs must be >= 1, got {epochs}")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # One output folder per run, named by start time.
    out_path = os.path.join(OUT_DIR, "train_{}".format(int(time())))
    os.makedirs(out_path)
    best_ckpt_path = os.path.join(out_path, "model.ckpt")

    train_loader = DataLoader(train, batch_size=batch_size, shuffle=True, num_workers=2)
    # Evaluation is order-independent; keep it deterministic.
    val_loader = DataLoader(val, batch_size=batch_size, shuffle=False, num_workers=2)

    proposal_num = hr["proposal_num"]
    model = NTSModel(top_n=proposal_num, n_classes=n_classes, image_height=IMAGE_HEIGHT, image_width=IMAGE_WIDTH).to(device)
    criterion = torch.nn.CrossEntropyLoss()

    # Each sub-network has its own SGD optimizer and hyperparameters.
    optimizers = [
        _make_optimizer(model.resnet, hr, "resnet"),
        _make_optimizer(model.navigator, hr, "navigator"),
        _make_optimizer(model.concat_net, hr, "concat"),
        _make_optimizer(model.partcls_net, hr, "partcls"),
    ]
    milestones = [int(epochs * 0.25), int(epochs * 0.5), int(epochs * 0.75)]
    schedulers = [MultiStepLR(opt, milestones=milestones, gamma=scheduler_gamma) for opt in optimizers]

    # Wrapped after the optimizers are built; the wrapped module shares the same parameters.
    model = nn.DataParallel(model)

    history = {
        "train_loss": [],
        "val_loss": [],
        "train_accuracy": [],
        "val_accuracy": [],
    }
    best_epoch_idx, best_val_accuracy = -1, float("-inf")

    for epoch in range(epochs):
        train_loss, train_accuracy = _train_one_epoch(model, train_loader, optimizers, criterion, proposal_num, device)
        val_loss, val_accuracy = _evaluate(model, val_loader, criterion, device)

        # Since PyTorch 1.1 the scheduler must step *after* the optimizer; stepping at
        # the start of the epoch skipped the initial LR and shifted every milestone early.
        for scheduler in schedulers:
            scheduler.step()

        history["train_loss"].append(train_loss)
        history["val_loss"].append(val_loss)
        history["train_accuracy"].append(train_accuracy)
        history["val_accuracy"].append(val_accuracy)

        print(f"Epoch {epoch+1} - Loss: {train_loss:.4f} - Accuracy: {train_accuracy:.4f} - Val Loss: {val_loss:.4f} - Val Accuracy: {val_accuracy:.4f}")

        # Persist only improvements; strict ">" keeps the earliest epoch on ties (same as np.argmax).
        if val_accuracy > best_val_accuracy:
            best_epoch_idx, best_val_accuracy = epoch, val_accuracy
            torch.save({
                "train_accuracy": train_accuracy,
                "val_accuracy": val_accuracy,
                "proposal_num": proposal_num,
                "n_classes": n_classes,
                "state_dict": model.module.state_dict(),
            }, best_ckpt_path)

    # Update latest model weights
    shutil.copyfile(best_ckpt_path, os.path.join(OUT_DIR, "latest_model.ckpt"))

    # Report best
    print("Best epoch:", best_epoch_idx + 1)
    print("Best val accuracy:", history["val_accuracy"][best_epoch_idx])

    return history

### Load the training and validation data

In [None]:
from torchvision.transforms import Compose, Resize, ToTensor, RandomHorizontalFlip, RandomCrop, InterpolationMode
from torch.utils.data import ConcatDataset
from torchvision.datasets import FGVCAircraft
from PIL import Image

DATA_ROOT = "data"

# Deterministic resize used for the un-augmented training copy.
# InterpolationMode replaces the deprecated PIL integer constants accepted by Resize.
transform = Compose([
    Resize((IMAGE_HEIGHT, IMAGE_WIDTH), interpolation=InterpolationMode.BILINEAR),
    ToTensor(),
])

# Upscale by 1.5x, then take a random target-sized crop and a random horizontal flip.
augment_transform = Compose([
    Resize((int(IMAGE_HEIGHT * 1.5), int(IMAGE_WIDTH * 1.5)), interpolation=InterpolationMode.BILINEAR),
    RandomCrop((IMAGE_HEIGHT, IMAGE_WIDTH)),
    RandomHorizontalFlip(),
    ToTensor(),
])

# Validation uses the same deterministic pipeline as the clean training copy (no augmentation).
val_transform = transform

# Load data
train_clean = FGVCAircraft(root=DATA_ROOT, split="train", transform=transform, download=True)
n_classes = len(train_clean.classes)
augmented_data = FGVCAircraft(root=DATA_ROOT, split="train", transform=augment_transform, download=True)

# Each training image appears twice per epoch: once as-is and once randomly augmented.
train_data = ConcatDataset([train_clean, augmented_data])
val_data = FGVCAircraft(root=DATA_ROOT, split="val", transform=val_transform, download=True)

print("Train data size:", len(train_data))
print("Val data size:", len(val_data))

### Train the model

In [None]:
import random

NUM_EPOCHS = 20
BATCH_SIZE = 8
SEED = 42

# Seed Python, NumPy and PyTorch (all devices) so shuffling and random augmentations
# repeat across runs. GPU kernels may still be non-deterministic.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Load hyperparameters
with open("hyperparameters.json", "r", encoding="utf-8") as f:
    hr = json.load(f)

history = train(train_data, val_data, n_classes=n_classes, epochs=NUM_EPOCHS, batch_size=BATCH_SIZE, hr=hr)

### Plot training history

In [None]:
import matplotlib.pyplot as plt


def plot_history_panel(ax, train_values, val_values, ylabel, title):
    """Draw train vs. val curves for one metric against 1-based epoch numbers."""
    # Plotting against real epoch numbers avoids passing tick labels to set_xticks,
    # which only accepts a positional `labels` argument on matplotlib >= 3.5.
    epochs = np.arange(1, len(train_values) + 1)
    ax.plot(epochs, train_values, label="train")
    ax.plot(epochs, val_values, label="val")
    ax.set_xticks(epochs)
    ax.set_xlabel("Epoch")
    ax.set_ylabel(ylabel)
    ax.legend()
    ax.set_title(title)


fig, ax = plt.subplots(1, 2, figsize=(12, 6))
plot_history_panel(ax[0], history["train_accuracy"], history["val_accuracy"], "Accuracy", "Train vs. Val accuracy")
plot_history_panel(ax[1], history["train_loss"], history["val_loss"], "Loss", "Train vs. Val Loss")
plt.show()