# Train Models

In [None]:
import subprocess
import json
import os
import torch
import copy
import onnx
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
import torchvision.transforms as tt
from tqdm.notebook import tqdm
from torch.optim import RMSprop, Adam, Adadelta
from functools import reduce

# Pick the compute device once for the whole notebook: the GPU when CUDA is
# usable, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using {device} device")

In [None]:
# Install or upgrade the wandb client quietly (-U upgrade, -q quiet). The
# leading "!" is an IPython shell escape, so this line only works in a notebook.
!pip install wandb -Uq
import wandb
# Log in to wandb for this session before any run is started.
wandb.login()

In [None]:
# Settings shared by the experiment configurations. ``architecture`` and
# ``target_crt_base_size`` are placeholders that every configuration fills in;
# they are listed here only to keep the key order of each resulting dict.
_BASE_CONFIG = dict(
    epochs=100,
    classes=10,
    batch_size=32,
    learning_rate=0.001,
    weight_decay=0,
    optimizer="Adam",
    dataset="MNIST",
    architecture=None,
    fake_quantization_const=0,
    fake_quantize_weights=False,
    target_crt_base_size=None,
    test_size=10000,
    quantize_set_size=100,
)


def _make_config(architecture, target_crt_base_size, **overrides):
    """Return a new config dict: the shared defaults plus per-model values."""
    return {
        **_BASE_CONFIG,
        "architecture": architecture,
        "target_crt_base_size": target_crt_base_size,
        **overrides,
    }


# MNIST models
MODEL_A_config = _make_config("MODEL_A", 8, quantize_set_size=10)
MODEL_B_POOL_REPL_config = _make_config("MODEL_B_POOL_REPL", 9, epochs=10)
MODEL_C_config = _make_config("MODEL_C", 9)
MODEL_D_POOL_REPL_config = _make_config("MODEL_D_POOL_REPL", 8)
MODEL_E_30_config = _make_config("MODEL_E_30", 5)
MODEL_E_100_config = _make_config("MODEL_E_100", 5)

# CIFAR10 models
MODEL_F_MINIONN_POOL_REPL_config = _make_config(
    "MODEL_F_MINIONN_POOL_REPL",
    9,
    batch_size=128,
    weight_decay=1e-5,
    dataset="CIFAR10",
    quantize_set_size=10,
)
MODEL_F_GNNP_POOL_REPL_config = _make_config(
    "MODEL_F_GNNP_POOL_REPL",
    9,
    batch_size=64,
    weight_decay=1e-5,
    dataset="CIFAR10",
    quantize_set_size=10,
)

# Every configuration, in the order the runs are executed.
configs = [
    MODEL_A_config,
    MODEL_B_POOL_REPL_config,
    MODEL_C_config,
    MODEL_D_POOL_REPL_config,
    MODEL_E_30_config,
    MODEL_E_100_config,
    MODEL_F_MINIONN_POOL_REPL_config,
    MODEL_F_GNNP_POOL_REPL_config,
]

In [None]:
def model_pipeline(hyperparameters):
    """Run one complete experiment inside its own wandb run.

    Starts a run in the "dash" project with ``hyperparameters`` as its
    config, builds the model, data loaders, loss and optimizer with
    ``make``, prints the model and trains it with ``train``.

    Args:
        hyperparameters: Mapping of run settings handed to ``wandb.init``.

    Returns:
        Whatever ``train`` returns (the best model of the run).
    """
    run = wandb.init(project="dash", config=hyperparameters, reinit=True)
    with run:
        # Read the settings back through wandb so the run config is the
        # single source of truth for everything built below.
        config = wandb.config

        # make() returns (model, train_loader, test_loader, val_loader,
        # criterion, optimizer), which is exactly train()'s leading arguments.
        components = make(config)
        print(components[0])
        best_model = train(*components, config)

    return best_model

In [None]:
def make(config):
    """Build the model, data loaders, loss and optimizer for one run.

    Args:
        config: Run configuration; reads ``dataset``, ``batch_size``,
            ``architecture``, ``fake_quantization_const``, ``optimizer``,
            ``learning_rate`` and ``weight_decay``.

    Returns:
        Tuple ``(model, train_loader, test_loader, val_loader, criterion,
        optimizer)``.
    """
    train_set, test_set, val_set = get_data(config.dataset)

    # All three loaders share these options; only the training loader shuffles.
    loader_options = dict(
        batch_size=config.batch_size,
        pin_memory=True,
        num_workers=8,
    )
    train_loader = DataLoader(train_set, shuffle=True, **loader_options)
    test_loader = DataLoader(test_set, **loader_options)
    val_loader = DataLoader(val_set, **loader_options)

    # The architecture and optimizer are named by string in the config and
    # resolved against this module's globals.
    model_cls = globals()[config.architecture]
    model = model_cls(config.fake_quantization_const)
    model = model.to(device)

    optimizer_cls = globals()[config.optimizer]
    optimizer = optimizer_cls(
        model.parameters(),
        lr=config.learning_rate,
        weight_decay=config.weight_decay,
    )
    criterion = nn.CrossEntropyLoss()

    return model, train_loader, test_loader, val_loader, criterion, optimizer

In [None]:
def get_data(dataset_name):
    """Load a dataset and split it into training, test and validation sets.

    The data is downloaded to ``../../data`` when missing. The validation
    set is a random 5000-sample split of the official training set; the
    remaining samples (55000 for MNIST, 45000 for CIFAR10) are used for
    training.

    Args:
        dataset_name: Either ``"MNIST"`` or ``"CIFAR10"``.

    Returns:
        Tuple ``(training_data, test_data, val_data)`` -- note the order:
        the test set comes second.

    Raises:
        ValueError: If ``dataset_name`` is not a supported dataset.
    """
    if dataset_name == "MNIST":
        training_data = datasets.MNIST(
            root="../../data",
            train=True,
            download=True,
            transform=tt.ToTensor(),
        )

        # Download test data from open datasets.
        test_data = datasets.MNIST(
            root="../../data",
            train=False,
            download=True,
            transform=tt.ToTensor(),
        )

        training_data, val_data = torch.utils.data.random_split(
            training_data, [55000, 5000]
        )

        return training_data, test_data, val_data

    if dataset_name == "CIFAR10":
        # Per-channel (mean, std) passed to Normalize.
        stats = ((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261))
        train_tfms = tt.Compose(
            [
                tt.RandomCrop(32, padding=4, padding_mode="reflect"),
                tt.RandomHorizontalFlip(),
                tt.ToTensor(),
                tt.Normalize(*stats, inplace=True),
            ]
        )
        test_tfms = tt.Compose([tt.ToTensor(), tt.Normalize(*stats)])

        training_data = datasets.CIFAR10(
            root="../../data",
            train=True,
            download=True,
            transform=train_tfms,
        )

        # Download test data from open datasets.
        test_data = datasets.CIFAR10(
            root="../../data",
            train=False,
            download=True,
            transform=test_tfms,
        )

        # NOTE(review): the validation split is carved out of the augmented
        # training dataset, so validation samples also go through the random
        # crop/flip transforms -- confirm this is intended.
        training_data, val_data = torch.utils.data.random_split(
            training_data, [45000, 5000]
        )

        return training_data, test_data, val_data

    # Fail loudly instead of implicitly returning None, which would only
    # surface later as an unpacking error in the caller.
    raise ValueError(
        f"Unsupported dataset {dataset_name!r}; expected 'MNIST' or 'CIFAR10'"
    )

In [None]:
class FakeQuantization(nn.Module):
    """Round activations to the nearest multiple of a fixed step size.

    The forward value is ``round(x / step) * step``. Gradients pass through
    unchanged (straight-through estimator); plain ``torch.round`` has a zero
    gradient, which would block learning through this layer.

    A step of 0 means "quantization disabled" and makes the layer an
    identity instead of dividing by zero.

    Args:
        fake_quantization_const: Quantization step size; 0 disables rounding.
    """

    def __init__(self, fake_quantization_const):
        super().__init__()
        self.fake_quantization_const = fake_quantization_const

    def forward(self, x):
        """Return ``x`` snapped to the quantization grid."""
        step = self.fake_quantization_const
        if step == 0:
            return x

        quantized = torch.round(x / step) * step
        # Straight-through estimator: ``x - x.detach()`` is exactly zero, so
        # the value is the quantized tensor while the gradient w.r.t. ``x``
        # is the identity.
        return quantized.detach() + (x - x.detach())

    def backward(self, x):
        """Return ``x`` unchanged.

        Kept for backward compatibility only: autograd does not call a
        ``backward`` method on ``nn.Module``; the pass-through gradient is
        implemented in ``forward``.
        """
        return x

In [None]:
class MODEL_A(nn.Module):
    """Fully connected classifier: 784 -> 128 -> 128 -> 10 with ReLU and dropout.

    Expects inputs that flatten to 784 features (e.g. 1x28x28 images).
    ``fake_quantization_const`` is accepted but not used by this model.
    """

    def __init__(self, fake_quantization_const=0):
        super().__init__()
        stack = [
            nn.Flatten(),
            nn.Linear(in_features=28 * 28, out_features=128),
            nn.ReLU(),
            nn.Dropout(p=0.2),
            nn.Linear(in_features=128, out_features=128),
            nn.ReLU(),
            nn.Dropout(p=0.2),
            nn.Linear(in_features=128, out_features=10),
        ]
        self.layers = nn.Sequential(*stack)

    def forward(self, x):
        """Return the 10 class logits for each input in the batch."""
        logits = self.layers(x)
        return logits


class MODEL_B_POOL_REPL(nn.Module):
    """Small CNN for 1x28x28 inputs that uses strided convolutions to downsample.

    Four convolutions reduce the input to a 10x2x2 feature map, which is
    classified by a 40 -> 100 -> 10 fully connected head.
    ``fake_quantization_const`` is accepted but not used by this model.
    """

    def __init__(self, fake_quantization_const=0):
        super().__init__()
        stack = [
            nn.Conv2d(in_channels=1, out_channels=5, kernel_size=5, stride=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=5, out_channels=5, kernel_size=3, stride=3),
            nn.ReLU(),
            nn.Conv2d(in_channels=5, out_channels=10, kernel_size=3, stride=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=10, out_channels=10, kernel_size=3, stride=3),
            nn.Dropout(p=0.25),
            nn.Flatten(),
            nn.Linear(in_features=2 * 2 * 10, out_features=100),  # upscaling?!
            nn.ReLU(),
            nn.Dropout(p=0.5),
            nn.Linear(in_features=100, out_features=10),
        ]
        self.layers = nn.Sequential(*stack)

    def forward(self, x):
        """Return the 10 class logits for each input in the batch."""
        logits = self.layers(x)
        return logits


class MODEL_C(nn.Module):
    """One-convolution CNN for 1x28x28 inputs with a 845 -> 100 -> 10 head.

    The stride-2 convolution yields a 5x13x13 feature map.
    ``fake_quantization_const`` is accepted but not used by this model.
    """

    def __init__(self, fake_quantization_const=0):
        super().__init__()
        stack = [
            nn.Conv2d(in_channels=1, out_channels=5, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(in_features=13 * 13 * 5, out_features=100),
            nn.ReLU(),
            nn.Dropout(p=0.5),
            nn.Linear(in_features=100, out_features=10),
        ]
        self.layers = nn.Sequential(*stack)

    def forward(self, x):
        """Return the 10 class logits for each input in the batch."""
        logits = self.layers(x)
        return logits


class MODEL_D_POOL_REPL(nn.Module):
    """Two-convolution CNN for 1x28x28 inputs with a 256 -> 100 -> 10 head.

    Both convolutions use kernel 6 and stride 2, giving a 16x4x4 feature map.
    ``fake_quantization_const`` is accepted but not used by this model.
    """

    def __init__(self, fake_quantization_const=0):
        super().__init__()
        stack = [
            nn.Conv2d(in_channels=1, out_channels=16, kernel_size=6, stride=2),
            nn.ReLU(),
            nn.Conv2d(in_channels=16, out_channels=16, kernel_size=6, stride=2),
            nn.ReLU(),
            nn.Dropout(p=0.25),
            nn.Flatten(),
            nn.Linear(in_features=256, out_features=100),
            nn.ReLU(),
            nn.Dropout(p=0.5),
            nn.Linear(in_features=100, out_features=10),
        ]
        self.layers = nn.Sequential(*stack)

    def forward(self, x):
        """Return the 10 class logits for each input in the batch."""
        logits = self.layers(x)
        return logits


class MODEL_E_30(nn.Module):
    """Single-hidden-layer classifier: 784 -> 30 (tanh, dropout) -> 10.

    Expects inputs that flatten to 784 features (e.g. 1x28x28 images).
    ``fake_quantization_const`` is accepted but not used by this model.
    """

    def __init__(self, fake_quantization_const=0):
        super().__init__()
        stack = [
            nn.Flatten(),
            nn.Linear(in_features=28 * 28, out_features=30),
            nn.Tanh(),
            nn.Dropout(p=0.2),
            nn.Linear(in_features=30, out_features=10),
        ]
        self.layers = nn.Sequential(*stack)

    def forward(self, x):
        """Return the 10 class logits for each input in the batch."""
        logits = self.layers(x)
        return logits


class MODEL_E_100(nn.Module):
    """Single-hidden-layer classifier: 784 -> 100 (tanh, dropout) -> 10.

    Expects inputs that flatten to 784 features (e.g. 1x28x28 images).
    ``fake_quantization_const`` is accepted but not used by this model.
    """

    def __init__(self, fake_quantization_const=0):
        super().__init__()
        stack = [
            nn.Flatten(),
            nn.Linear(in_features=28 * 28, out_features=100),
            nn.Tanh(),
            nn.Dropout(p=0.2),
            nn.Linear(in_features=100, out_features=10),
        ]
        self.layers = nn.Sequential(*stack)

    def forward(self, x):
        """Return the 10 class logits for each input in the batch."""
        logits = self.layers(x)
        return logits


class MODEL_F_MINIONN_POOL_REPL(nn.Module):
    """CNN for 3x32x32 inputs with max-pooling replaced by strided convolutions.

    Nine convolutions reduce the input to a 16x3x3 feature map, which a
    single linear layer maps to 10 logits. The trailing numbers in the layer
    comments are carried over from the original layer numbering.
    ``fake_quantization_const`` is accepted but not used by this model.
    """

    def __init__(self, fake_quantization_const=0):
        super().__init__()
        stack = [
            nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3),  # 1
            nn.ReLU(),  # 2
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3),  # 3
            nn.ReLU(),  # 4
            # replaces maxpool
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=2, stride=2),
            nn.Dropout(p=0.25),
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3),  # 6
            nn.ReLU(),  # 7
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3),  # 8
            nn.ReLU(),  # 9
            # replaces maxpool
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=2, stride=2),
            nn.Dropout(p=0.25),
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3),  # 11
            nn.ReLU(),  # 12
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=1),  # 13
            nn.ReLU(),  # 14
            nn.Conv2d(in_channels=64, out_channels=16, kernel_size=1),  # 15
            nn.ReLU(),  # 16
            nn.Flatten(),
            nn.Dropout(p=0.5),
            nn.Linear(in_features=16 * 3 * 3, out_features=10),  # 17
        ]
        self.layers = nn.Sequential(*stack)

    def forward(self, x):
        """Return the 10 class logits for each input in the batch."""
        logits = self.layers(x)
        return logits


class MODEL_F_GNNP_POOL_REPL(nn.Module):
    """CNN for 3x32x32 inputs with max-pooling replaced by strided convolutions.

    Channel width grows 32 -> 64 -> 128 while the spatial size shrinks to
    1x1, so the flattened 128 features feed the final linear layer directly.
    ``fake_quantization_const`` is accepted but not used by this model.
    """

    def __init__(self, fake_quantization_const=0):
        super().__init__()
        stack = [
            nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3),
            nn.ReLU(),
            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3),
            nn.ReLU(),
            # replaces maxpool
            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=2, stride=2),
            nn.Dropout(p=0.25),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3),
            nn.ReLU(),
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3),
            nn.ReLU(),
            # replaces maxpool
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=2, stride=2),
            nn.Dropout(p=0.25),
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3),
            nn.ReLU(),
            nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(in_features=128, out_features=10),
        ]
        self.layers = nn.Sequential(*stack)

    def forward(self, x):
        """Return the 10 class logits for each input in the batch."""
        logits = self.layers(x)
        return logits

In [None]:
def quantize_model(model, fake_quantization_const):
    """Round every parameter of ``model`` in place to a multiple of the step.

    Args:
        model: Module whose parameters are overwritten.
        fake_quantization_const: Quantization step size. A value of 0 means
            "quantization disabled" and leaves the parameters untouched
            (dividing by it would turn every weight into NaN).
    """
    if fake_quantization_const == 0:
        return

    # The rounding is a direct weight edit, not part of the computation
    # graph, so it is done without gradient tracking.
    with torch.no_grad():
        for p in model.parameters():
            p.copy_(
                torch.round(p / fake_quantization_const) * fake_quantization_const
            )

In [None]:
def train(model, train_loader, test_loader, val_loader, criterion, optimizer, config):
    """Train ``model`` and return the checkpoint with the lowest validation loss.

    The per-batch training loss and the per-epoch train/validation metrics
    are logged to the active wandb run. After the last epoch the best
    checkpoint is evaluated on ``test_loader`` and exported with
    ``serialize_model``.

    Args:
        model: Network to optimise in place.
        train_loader: Batches used for the gradient updates.
        test_loader: Batches used once, for the final test metrics.
        val_loader: Batches evaluated after every epoch for model selection.
        criterion: Loss function, called as ``criterion(pred, y)``.
        optimizer: Optimizer that updates ``model``'s parameters.
        config: Run configuration; reads ``epochs``,
            ``fake_quantization_const`` and ``fake_quantize_weights``.

    Returns:
        A deep copy of ``model`` from the epoch with the lowest validation
        loss, or a copy of the initial model if no epoch improved on it.
    """
    wandb.watch(model, criterion, log="all", log_freq=10)

    size = len(train_loader.dataset)

    # Weight quantization is active only with a non-zero step and the flag set.
    quantize_weights = (
        config.fake_quantization_const != 0 and config.fake_quantize_weights
    )

    best_model = copy.deepcopy(model)
    best_model_val_loss, best_model_val_acc = test(best_model, val_loader, criterion)
    example_cnt = 0

    for epoch in tqdm(range(config.epochs)):
        # test() switches the model to eval mode during validation; switch
        # back every epoch so dropout stays active while training.
        model.train()

        # Hard-coded step schedule that overrides the configured learning rate.
        # NOTE(review): the second step needs more than 101 epochs, which none
        # of the current configurations reach.
        if epoch > 100:
            scheduled_lr = 0.0003
        elif epoch > 75:
            scheduled_lr = 0.0005
        else:
            scheduled_lr = None
        if scheduled_lr is not None:
            for param_group in optimizer.param_groups:
                param_group["lr"] = scheduled_lr

        train_loss, correct = 0, 0
        for X, y in train_loader:
            X, y = X.to(device), y.to(device)

            example_cnt += len(X)

            # Quantize parameters during training
            if quantize_weights:
                quantize_model(model, config.fake_quantization_const)

            # Compute prediction error
            pred = model(X)
            loss = criterion(pred, y)
            # Loss of the current batch only (not an epoch average).
            train_loss = loss.item()

            _, pred_label = torch.max(pred.data, 1)
            correct += (pred_label == y).sum().item()

            # Backpropagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Report training loss
            wandb.log({"train": {"loss": train_loss}}, step=example_cnt)

        # Report train accuracy and validation metrics
        train_acc = correct / size
        if quantize_weights:
            quantize_model(model, config.fake_quantization_const)

        val_loss, val_acc = test(model, val_loader, criterion)
        wandb.log(
            {
                "train": {"acc": train_acc},
                "val": {"loss": val_loss, "acc": val_acc},
                "epoch": epoch,
            },
            step=example_cnt,
        )

        # Save best model with respect to validation loss
        if val_loss < best_model_val_loss:
            best_model = copy.deepcopy(model)
            best_model_val_loss = val_loss
            best_model_val_acc = val_acc
            # The training loss recorded here is the last batch's loss.
            best_model_train_loss = train_loss
            best_model_train_acc = train_acc
            wandb.run.summary["best_model_val_loss"] = best_model_val_loss
            wandb.run.summary["best_model_val_acc"] = best_model_val_acc
            # Training metrics get their own keys so they no longer overwrite
            # the validation metrics written just above.
            wandb.run.summary["best_model_train_loss"] = best_model_train_loss
            wandb.run.summary["best_model_train_acc"] = best_model_train_acc
            wandb.run.summary["best_epoch"] = epoch

    # Get test metrics
    test_loss, test_acc = test(best_model, test_loader, criterion)
    wandb.run.summary["best_model_test_loss"] = test_loss
    wandb.run.summary["best_model_test_acc"] = test_acc
    # Serialize model
    serialize_model(best_model, train_loader)

    return best_model

In [None]:
def test(model, test_loader, criterion):
    """Evaluate ``model`` on every batch of ``test_loader``.

    The model is switched to eval mode and is left in that mode on return.

    Args:
        model: Network to evaluate.
        test_loader: Loader yielding ``(inputs, labels)`` batches.
        criterion: Loss function, called as ``criterion(pred, y)``.

    Returns:
        Tuple ``(loss, accuracy)``: the mean of the per-batch losses and the
        fraction of correctly classified samples in the loader's dataset.
    """
    model.eval()

    sample_count = len(test_loader.dataset)
    batch_count = len(test_loader)
    loss_sum = 0
    hits = 0

    with torch.no_grad():
        for inputs, targets in test_loader:
            inputs = inputs.to(device)
            targets = targets.to(device)

            logits = model(inputs)
            loss_sum += criterion(logits, targets).item()

            # Index of the largest logit is the predicted class.
            predicted = torch.max(logits.data, 1)[1]
            hits += (predicted == targets).sum().item()

    mean_loss = loss_sum / batch_count
    accuracy = hits / sample_count
    return mean_loss, accuracy

In [None]:
def serialize_model(model, loader):
    """Export ``model`` to ONNX, upload it to wandb and copy it to the model dir.

    The file is named after the model's class (``<ClassName>.onnx``) and is
    written to the current working directory, saved to the active wandb run
    and then copied into ``../trained_models/``.

    Args:
        model: Trained network to export.
        loader: Data loader; its first batch of inputs is used as the
            example input for the export.
    """
    import shutil

    dir_path = "../trained_models/"
    file_name = model.__class__.__name__ + ".onnx"
    model_path = os.path.join(dir_path, file_name)

    # Only the inputs of the first batch are needed; labels are ignored.
    images, _ = next(iter(loader))
    images = images.to(device)

    input_names = ["actual_input_1"]
    output_names = ["output1"]
    torch.onnx.export(
        model,
        images,
        file_name,
        verbose=True,
        input_names=input_names,
        output_names=output_names,
        operator_export_type=torch.onnx.OperatorExportTypes.ONNX_FALLTHROUGH,
    )
    # NOTE: the exported file is not validated with onnx.checker here.
    wandb.save(file_name)

    # Copy synchronously with the standard library instead of spawning an
    # un-awaited shell `cp`, and create the target directory if it is missing.
    os.makedirs(dir_path, exist_ok=True)
    shutil.copy(file_name, model_path)

In [None]:
# Run the full pipeline for every configuration in `configs`, one after the
# other. `best_model` is rebound on each iteration, so after the loop it holds
# only the result of the last configuration.
for config in configs:
    best_model = model_pipeline(config)