In [None]:
from __future__ import annotations

import json
import math
import os
import pprint
import random
import shutil
import time
import urllib.request
import warnings
import zipfile
from datetime import datetime
from pathlib import Path
from typing import List

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
import torchvision
import torchvision.transforms.v2.functional as VF
from sklearn.model_selection import train_test_split
from torch.optim import SGD
from torch.optim.lr_scheduler import ExponentialLR, StepLR
from torch.utils.data import DataLoader, Dataset, Subset
from torchvision.datasets import CIFAR100, ImageFolder
from torchvision.models import EfficientNet_B0_Weights, efficientnet_b0

# Suppress every Python warning to keep the (already verbose) training logs short.
# NOTE(review): this also hides potentially useful deprecation/runtime warnings.
warnings.filterwarnings("ignore")
%matplotlib inline

In [None]:
# Use the first available accelerator, checked in priority order: Apple MPS,
# then CUDA. Fall back to the CPU. Each probe runs only if the previous one failed.
# To force a backend, assign `device` manually right after this cell.
_backend_probes = (
    ("mps", torch.backends.mps.is_available),
    ("cuda", torch.cuda.is_available),
)
device = next((name for name, probe in _backend_probes if probe()), "cpu")
print(f"using {device=}")

In [None]:
# Single seed for every RNG the notebook touches. The train/val split and the
# DataLoader generators also derive from it.
SEED = 2147483647
torch.manual_seed(SEED)  # torch RNG on all devices
random.seed(SEED)  # Python's `random` module
np.random.seed(SEED)  # numpy is imported too; seed it so any np.random use is reproducible

In [None]:
# Which benchmark to fine-tune on. Supported values (see the data and optimizer
# cells): "cifar" -> CIFAR-100, "tiny" -> Tiny ImageNet-200.
# dataset = "cifar"
dataset = "tiny"

### Data

In [None]:
def _fetch_and_extract_tiny(root: Path, url: str) -> None:
    """Download the Tiny ImageNet zip into ``root`` (unless already there) and unpack it in place."""
    zip_path = root / "tiny-imagenet-200.zip"
    if not zip_path.exists():
        print("download start...")
        root.mkdir(parents=True, exist_ok=True)
        # Download under a temporary name and rename only when complete, so an
        # interrupted download is never mistaken for a valid archive next time.
        partial_path = zip_path.with_name(zip_path.name + ".part")
        urllib.request.urlretrieve(url, str(partial_path))
        partial_path.replace(zip_path)
    print("unzip start...")
    with zipfile.ZipFile(zip_path) as archive:
        archive.extractall(root)
    print("unzip end...")


def _reorganize_tiny_val(val_dir: Path) -> None:
    """Move ``val/images/<img>`` to ``val/<class>/<img>`` so ImageFolder can label the split.

    Labels come from the tab-separated ``val_annotations.txt`` (column 0: image
    file name, column 1: class id). Images that were already moved are skipped,
    so re-running after an interruption completes the job instead of crashing.
    """
    images_dir = val_dir / "images"

    im2class = {}
    with open(val_dir / "val_annotations.txt") as fin:
        for line in fin:
            words = line.rstrip("\r\n").split("\t")
            if len(words) < 2:
                continue  # blank / malformed line: skip instead of IndexError
            im2class[words[0]] = words[1]

    print("rename started...")
    for img, imgclass in im2class.items():
        src = images_dir / img
        dst = val_dir / imgclass / img
        if dst.exists() or not src.exists():
            continue  # already moved by an earlier (possibly interrupted) run
        dst.parent.mkdir(parents=True, exist_ok=True)
        shutil.move(str(src), str(dst))

    # Only drop the flat folder once it is empty. Otherwise rmdir() would raise
    # and hide which files were left behind.
    leftovers = list(images_dir.iterdir())
    if leftovers:
        print(f"warning: {len(leftovers)} unannotated file(s) left in {images_dir}")
    else:
        images_dir.rmdir()


def download_dataset_tiny(
    root="data/tiny",
    url="http://cs231n.stanford.edu/tiny-imagenet-200.zip",
):
    """Make Tiny ImageNet-200 available under ``root`` in an ImageFolder-compatible layout.

    The archive is downloaded and unpacked only if ``<root>/tiny-imagenet-200`` is
    missing. This uses pure Python, so no ``wget``/``unzip`` binaries are needed.
    The validation images ship in one flat ``val/images`` folder. They are sorted
    into per-class sub-folders whenever that flat folder is still present. Without
    this, ImageFolder would treat ``images`` as a single class.

    Args:
        root: directory holding (or receiving) the zip and the extracted data.
        url: download location of ``tiny-imagenet-200.zip``.

    Returns:
        tuple[Path, Path]: the ``train`` and ``val`` directories.
    """
    root = Path(root)
    data_dir = root / "tiny-imagenet-200"
    train_dir = data_dir / "train"
    val_dir = data_dir / "val"

    if not data_dir.exists():
        _fetch_and_extract_tiny(root, url)

    # Checked independently of the download so a half-finished earlier run is repaired.
    if (val_dir / "images").is_dir():
        _reorganize_tiny_val(val_dir)

    return train_dir, val_dir

In [None]:
# Preprocessing bundled with the ImageNet-1k EfficientNet-B0 weights. The very same
# pipeline is reused for the train, val and test splits, so no augmentation is applied.
train_test_transform = EfficientNet_B0_Weights.IMAGENET1K_V1.transforms() # shared, no augmentations used

In [None]:
# Build the full training set and the held-out evaluation set for the chosen
# benchmark. A validation split is carved out of `train_dataset_full` later.
if dataset == "tiny":
    train_dir, val_dir = download_dataset_tiny()
    train_dataset_full = ImageFolder(train_dir, transform=train_test_transform)
    # The labelled "val" folder of Tiny ImageNet serves as this notebook's test set.
    test_dataset = ImageFolder(val_dir, transform=train_test_transform)
    num_classes = 200
elif dataset == "cifar":
    path = "./data/cifar100"
    train_dataset_full = CIFAR100(
        root=path, train=True, transform=train_test_transform, download=True
    )
    test_dataset = CIFAR100(
        root=path, train=False, transform=train_test_transform, download=True
    )
    num_classes = 100
else:
    # Fail here with a clear message rather than a NameError several cells later.
    raise ValueError(f"no such dataset config: {dataset!r}")

# Guard against a broken on-disk layout. For example, a flat Tiny ImageNet
# val/images folder would make ImageFolder see one class, and labels would
# then silently disagree between train and test.
assert len(train_dataset_full.classes) == num_classes, train_dataset_full.classes[:5]
assert test_dataset.classes == train_dataset_full.classes, "train/test class lists differ"

In [None]:
# Carve a stratified validation split out of the full training set. The
# dataset's own held-out split (`test_dataset`) stays untouched for testing.
all_train_indices = range(len(train_dataset_full))
train_indices, val_indices = train_test_split(
    all_train_indices,
    train_size=0.9,  # quick debug runs used e.g. train_size=0.01 / 3e-3, test_size=0.01
    stratify=train_dataset_full.targets,
    random_state=SEED,
)

train_dataset, val_dataset = (
    Subset(train_dataset_full, split) for split in (train_indices, val_indices)
)
len(train_dataset), len(val_dataset), len(test_dataset)

In [None]:
def plot_class_histogram(targets, n_classes: int, title: str) -> None:
    """Show how many samples each class id (0 .. n_classes - 1) has in one split."""
    fig, ax = plt.subplots()
    # One bin centred on each integer class id keeps bars aligned with class ids,
    # even when a small split misses the highest ids.
    ax.hist(targets, bins=np.arange(n_classes + 1) - 0.5)
    ax.set(title=title, xlabel="class id", ylabel="# samples")
    plt.show()


# Stratified splitting should keep train and val class proportions equal to those
# of the full training set.
train_targets = [train_dataset_full.targets[i] for i in train_dataset.indices]
val_targets = [train_dataset_full.targets[i] for i in val_dataset.indices]
plot_class_histogram(train_targets, num_classes, "train split: samples per class")
plot_class_histogram(val_targets, num_classes, "val split: samples per class")
plot_class_histogram(test_dataset.targets, num_classes, "test split: samples per class")

In [None]:
bs = 64
num_workers = 2
# Pinned host memory only speeds up host->GPU copies for CUDA; skip it on MPS/CPU.
pin_memory = device == "cuda"


def make_loader(ds, *, shuffle: bool, drop_last: bool, seed: int) -> DataLoader:
    """Wrap ``ds`` in a DataLoader using the notebook-wide batch size and worker count.

    Each loader gets its own seeded ``torch.Generator``. Every new iterator draws a
    base seed from its loader's generator, and the shuffling sampler draws from it
    too. With one generator shared by all loaders, the train shuffle order would
    depend on how often val/test had been iterated. Separate generators keep the
    streams independent and reproducible.
    """
    return DataLoader(
        ds,
        batch_size=bs,
        num_workers=num_workers,
        shuffle=shuffle,
        drop_last=drop_last,
        pin_memory=pin_memory,
        generator=torch.Generator().manual_seed(seed),
    )


train_loader = make_loader(train_dataset, shuffle=True, drop_last=True, seed=SEED)
val_loader = make_loader(val_dataset, shuffle=False, drop_last=False, seed=SEED + 1)
test_loader = make_loader(test_dataset, shuffle=False, drop_last=False, seed=SEED + 2)

In [None]:
# Peek at the first few sampled indices of each split (one split per line).
print(train_dataset.indices[:10], val_dataset.indices[:10], sep="\n")

In [None]:
# Sanity-check the split: train and val must never share a sample. The displayed
# union size should equal len(train_dataset) + len(val_dataset).
train_index_set = set(train_dataset.indices)
val_index_set = set(val_dataset.indices)
assert train_index_set.isdisjoint(val_index_set), "train/val splits overlap"
len(train_index_set | val_index_set)

### Model definition

In [None]:
def get_model(num_classes):
    """Build an ImageNet-pretrained EfficientNet-B0 with a fresh ``num_classes``-way head.

    The backbone loads torchvision's ``IMAGENET1K_V1`` weights. Only the linear
    layer at ``model.classifier[1]`` is replaced, with a randomly initialised one.

    Args:
        num_classes: number of output logits of the new classification layer.

    Returns:
        torch.nn.Module: the model, on the CPU.
    """
    model = efficientnet_b0(weights=EfficientNet_B0_Weights.IMAGENET1K_V1)
    # Read the feature width from the existing head instead of hard-coding 1280.
    in_features = model.classifier[1].in_features
    model.classifier[1] = torch.nn.Linear(in_features, num_classes)
    return model


# Instantiate the classifier for the selected dataset and move it onto `device`.
model = get_model(num_classes).to(device)
print(model)

In [None]:
if dataset == "cifar":
    # https://github.com/pytorch/ignite/blob/master/examples/notebooks/EfficientNet_Cifar100_finetuning.ipynb
    # not sure where they got it from but works well, e5(0 indexed), 0.835 val acc, 0.8354 test
    # Discriminative learning rates: the stem (features[0]) gets 0.1*lr, the rest
    # of the backbone 0.2*lr, and the freshly initialised classifier the full lr.
    lr = 1e-2
    optimizer = SGD(
        params=[
            {"params": model.features[0].parameters(), "lr": lr * 0.1},
            {"params": model.features[1:].parameters(), "lr": lr * 0.2},
            {"params": model.classifier.parameters(), "lr": lr},
        ],
        # Each group sets its own lr. Passing it here too keeps optimizer.defaults
        # (logged in `config`) meaningful. On older torch releases SGD's lr
        # defaulted to a sentinel object that json cannot serialize.
        lr=lr,
        momentum=0.9,
        weight_decay=0.001,
        nesterov=True,
    )

    # Multiply every group's lr by 0.975 after each epoch.
    scheduler = ExponentialLR(optimizer, gamma=0.975)

elif dataset == "tiny":
    # https://arxiv.org/pdf/1912.08136v2 section 5.1.5
    lr = 1e-3
    optimizer = SGD(
        params=model.parameters(),
        lr=lr,
        momentum=0.9,
        weight_decay=0.0001,
        nesterov=True,  # not clear from paper, reference nesterov in other places as inspiration...
    )
    # Divide the lr by 10 every 10 epochs.
    scheduler = StepLR(optimizer, step_size=10, gamma=0.1)
else:
    raise ValueError("no such dataset config!")

### Training loop

In [None]:
def get_acc(y_pred: torch.Tensor, y_true: torch.Tensor, bs: int | None = None) -> float:
    """Fraction of entries where ``y_pred`` equals ``y_true``.

    Args:
        y_pred: predicted class ids (assumed to have the same shape as ``y_true``).
        y_true: ground-truth class ids.
        bs: denominator for the correct count. Defaults to the number of elements
            in ``y_true``. Passing a fixed batch size under-reports accuracy for a
            shorter final batch, so prefer leaving it unset.

    Returns:
        float: accuracy in [0, 1] when ``bs`` is the true element count.
    """
    if bs is None:
        bs = y_true.numel()
    return ((y_pred == y_true).sum() / bs).item()

In [None]:
def run_validation(
    model: torch.nn.Module,
    val_loader: DataLoader,
    device: torch.device | str | None = None,
) -> tuple[float, float]:
    """Evaluate ``model`` on every batch of ``val_loader`` without tracking gradients.

    Prints progress for the first 4 batches and every 20th after that, then a
    summary. The model's train/eval mode is restored afterwards.

    Args:
        model (torch.nn.Module): classifier producing logits of shape (batch, num_classes).
        val_loader (DataLoader): yields ``(inputs, integer class targets)`` batches.
        device: where to move each batch. Defaults to the device of the model's
            first parameter (CPU for a parameter-less model).

    Returns:
        tuple[float, float]: loss, accuracy. The loss is the mean cross-entropy per
            sample and the accuracy lies in [0, 1]. Both are NaN if the loader yields
            no samples.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    print("--------- running validation ---------")
    pre_validation_time = time.time()
    total_loss, total_correct, total_seen = 0.0, 0, 0
    was_training = model.training
    model.eval()

    # NOTE: time.time() is wall-clock. On CUDA/MPS kernels run asynchronously, so
    # the per-phase timings below are only approximate.
    try:
        end = time.time()
        with torch.no_grad():
            for batchi, (X, Y_true) in enumerate(val_loader, start=1):
                data_loading_duration_ms = (time.time() - end) * 1e3

                pre_forward_time = time.time()

                X, Y_true = X.to(device), Y_true.to(device)
                output = model(X)  # (bs, num_classes)
                loss = F.cross_entropy(output, Y_true)

                post_forward_time = time.time()

                loss = loss.item()
                Y_pred = output.argmax(dim=1)  # (bs, )

                batch_size = len(Y_true)
                correct = (Y_pred == Y_true).sum().item()
                batch_acc = correct / batch_size
                # Weight each batch mean by its size so a short final batch does
                # not skew the average.
                total_loss += loss * batch_size
                total_correct += correct
                total_seen += batch_size

                metric_calc_duration_ms = (time.time() - post_forward_time) * 1e3
                forward_duration_ms = (post_forward_time - pre_forward_time) * 1e3
                batch_time = (
                    data_loading_duration_ms + forward_duration_ms + metric_calc_duration_ms
                )
                if batchi < 5 or batchi % 20 == 0:
                    print(
                        "validation {:>4}/{} | val_loss={:>7.4f} | val_acc={:>7.3f} | total time (ms) {:>8.2f} | times (ms) :: data load {:>7.2f}, forward {:>7.2f}, metric calc {:>7.2f}".format(
                            batchi,
                            len(val_loader),
                            loss,
                            batch_acc,
                            batch_time,
                            data_loading_duration_ms,
                            forward_duration_ms,
                            metric_calc_duration_ms,
                        )
                    )
                end = time.time()
    finally:
        model.train(was_training)

    if total_seen == 0:
        print("Validation loader yielded no samples")
        return float("nan"), float("nan")

    total_acc = total_correct / total_seen
    total_loss /= total_seen
    print(f"Total validation set acc: {total_acc:.2f}")
    print(f"Val loss average per sample: {total_loss:.4f}")

    validation_time_sec = time.time() - pre_validation_time
    print(f"Total validation time: {validation_time_sec:.2f} sec")

    return total_loss, total_acc

In [None]:
# Maximum number of training epochs.
epochs = 30
# When True, training breaks out at the first epoch whose validation loss is higher
# than the best so far (no patience). When False, all `epochs` run. In both cases a
# checkpoint is saved every time the validation loss improves.
early_stopping = False

In [None]:
# Run configuration, printed now and written to graphdata.json after training.
# Records what is needed to reproduce/compare runs.
config = {
    "optimizer_defaults": optimizer.defaults,
    "scheduler": {
        "type": str(type(scheduler)),
        "base_lr": scheduler.base_lrs,  # initial lr of every param group
        "gamma": scheduler.gamma,
        # Only StepLR has a step size; None for ExponentialLR.
        "step_size": getattr(scheduler, "step_size", None),
    },
    "bs": bs,
    "epochs": epochs,
    "early_stopping": early_stopping,
    "seed": SEED,
    "data": {
        "dataset": dataset,
        "train": len(train_dataset),
        "val": len(val_dataset),
        "test": len(test_dataset),
    },
}
pprint.pprint(config)

In [None]:
# Colab alternative (Google Drive mounted):
# log_folder = Path(f"drive/MyDrive/colab_outputs/{dataset}")

log_folder = Path(f"checkpoints/models/{dataset}")
# parents=True: on a fresh checkout "checkpoints/models" does not exist yet, and a
# plain mkdir() would raise FileNotFoundError.
log_folder.mkdir(parents=True, exist_ok=True)

# One timestamped sub-folder per run for checkpoints, plots and graph data.
# No exist_ok, so two runs started in the same second are never silently mixed.
checkpoint_path = log_folder / datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
checkpoint_path.mkdir(parents=True)

print(f"Saving all logs to: {checkpoint_path}")


best_val_epoch, best_val_loss = -1, torch.inf
train_acc_buff, train_loss_buff, val_acc_buff, val_loss_buff = [], [], [], []

test_loss_buff, test_acc_buff = [], []


for epoch in range(epochs):
    epoch_start = time.time()

    print(f"\n====== epoch {epoch} =========\n")

    # run_validation() puts the model in eval mode. Switch back to training mode
    # once at the start of every epoch rather than on every batch.
    model.train()

    epoch_train_loss, epoch_train_correct, epoch_train_seen = 0.0, 0, 0

    # NOTE: time.time() is wall-clock. On CUDA/MPS kernels run asynchronously, so
    # the forward/backward/step split below is only approximate.
    batch_times = []
    end = time.time()
    for batchi, (X, Y_true) in enumerate(train_loader, start=1):
        data_loading_duration_ms = (time.time() - end) * 1e3
        pre_forward_time = time.time()

        X, Y_true = X.to(device), Y_true.to(device)

        output = model(X)  # (bs, classes)
        loss = F.cross_entropy(output, Y_true)

        post_forward_time = time.time()

        optimizer.zero_grad()
        loss.backward()

        post_backward_time = time.time()
        optimizer.step()
        post_step_time = time.time()

        forward_duration_ms = (post_forward_time - pre_forward_time) * 1e3
        backward_duration_ms = (post_backward_time - post_forward_time) * 1e3
        step_duration_ms = (post_step_time - post_backward_time) * 1e3

        with torch.no_grad():
            Y_pred = output.argmax(dim=1)  # (bs,)
            correct = (Y_pred == Y_true).sum().item()
            acc = correct / len(Y_pred)

            epoch_train_loss += loss.item()
            epoch_train_correct += correct
            epoch_train_seen += len(Y_pred)
            metric_update_duration_ms = (time.time() - post_step_time) * 1e3

            batch_time = (
                data_loading_duration_ms
                + forward_duration_ms
                + backward_duration_ms
                + step_duration_ms
                + metric_update_duration_ms
            )
            batch_times.append(batch_time)

            if batchi < 5 or batchi % 20 == 0:
                print(
                    "batch {:>4}/{} | loss = {:>7.4f} | acc = {:>7.4f} | total time (ms) {:>10.2f} | times (ms) :: data load {:>7.2f}, forward {:>7.2f}, backward {:>7.2f}, step {:>7.2f}, metrics update {:>7.2f} ".format(
                        batchi,
                        len(train_loader),
                        loss.item(),
                        acc,
                        batch_time,
                        data_loading_duration_ms,
                        forward_duration_ms,
                        backward_duration_ms,
                        step_duration_ms,
                        metric_update_duration_ms,
                    )
                )

        end = time.time()

    if scheduler:
        scheduler.step()

    # max(..., 1) guards against an empty train_loader. This can happen with a tiny
    # debug split, because drop_last=True discards a lone partial batch.
    num_train_batches = max(len(batch_times), 1)
    epoch_train_loss /= num_train_batches
    # drop_last=True skips the final partial batch, so normalise by the samples
    # actually trained on, not len(train_loader.dataset).
    epoch_train_acc = epoch_train_correct / max(epoch_train_seen, 1)
    avg_batch_time_ms = sum(batch_times) / num_train_batches

    val_loss, val_acc = run_validation(model, val_loader)

    if early_stopping and val_loss > best_val_loss:
        print(
            f"=== EARLY STOPPING === new loss: {val_loss}, previous best: {best_val_loss} at epoch {best_val_epoch}"
        )
        break

    model_save_time_ms = 0.0
    if val_loss < best_val_loss:
        print("New best result found!")
        best_val_loss = val_loss
        best_val_epoch = epoch

        pre_model_save_time = time.time()
        torch.save(
            model.state_dict(),
            checkpoint_path
            / f"model_e{epoch}_{datetime.now().strftime('%Y_%m_%d_%H_%M_%S')}_vacc_{str(round(val_acc, 2)).replace('.', 'p')}_vl_{str(round(val_loss, 2)).replace('.', 'p')}.pt",
        )
        model_save_time_ms = (time.time() - pre_model_save_time) * 1e3

    # update accumulators
    train_acc_buff.append(epoch_train_acc)
    train_loss_buff.append(epoch_train_loss)
    val_acc_buff.append(val_acc)
    val_loss_buff.append(val_loss)

    epoch_duration_sec = time.time() - epoch_start
    print(
        "Epoch summary: \ntrain :: loss {:>7.4f} |  acc {:>7.4f} \nval :: loss {:>7.4f} |  acc {:>7.4f} | best val loss so far {:>7.4f} at epoch {:>2}".format(
            epoch_train_loss,
            epoch_train_acc,
            val_loss,
            val_acc,
            best_val_loss,
            best_val_epoch,
        )
    )
    print(
        f"Epoch time (sec) {epoch_duration_sec:.2f} ({epoch_duration_sec / 60:.2f} min), model save time (ms): {model_save_time_ms:.2f} , average batch time (ms): {avg_batch_time_ms:.2f}"
    )

    # Test metrics are only recorded for the curves; model selection uses val loss.
    print("++++++ test set eval +++++")
    test_loss, test_acc = run_validation(model, test_loader)
    test_loss_buff.append(test_loss)
    test_acc_buff.append(test_acc)
    print("Test result: loss={:.3f}, acc={:.3f}".format(test_loss, test_acc))

In [None]:
fig, ax = plt.subplots()
ax.plot(train_loss_buff, label="train loss")
ax.plot(val_loss_buff, label="val loss")
ax.plot(test_loss_buff, label="test loss")
ax.set(title="Cross-entropy loss per epoch", xlabel="epoch", ylabel="loss")
ax.set_xticks(list(range(len(train_loss_buff))))
ax.legend()
# Save before plt.show(): with non-inline backends the figure may be closed or
# blank once show() returns.
fig.savefig(checkpoint_path / "losses.png")
plt.show()

In [None]:
fig, ax = plt.subplots()
ax.plot(train_acc_buff, label="train acc")
ax.plot(val_acc_buff, label="val acc")
ax.plot(test_acc_buff, label="test acc")
ax.set(title="Top-1 accuracy per epoch", xlabel="epoch", ylabel="accuracy")
ax.set_xticks(list(range(len(train_acc_buff))))
ax.legend()
# Save before plt.show(): with non-inline backends the figure may be closed or
# blank once show() returns.
fig.savefig(checkpoint_path / "acc.png")
plt.show()

In [None]:
# Persist the per-epoch curves together with the run configuration, so runs can be
# re-plotted and compared later without re-training.
graph_data = {
    "losses": {
        "train": train_loss_buff,
        "val": val_loss_buff,
        "test": test_loss_buff,
    },
    "acc": {
        "train": train_acc_buff,
        "val": val_acc_buff,
        "test": test_acc_buff,
    },
    **config,
}
with open(checkpoint_path / "graphdata.json", "w") as fout:
    # default=str: NOTE(review) depending on the torch version, optimizer.defaults
    # may contain values json cannot encode. Store their str() instead of crashing
    # at the very end of a long run.
    json.dump(graph_data, fout, indent=2, default=str)