In [None]:
# Turn the working directory into a checkout of the project repository's `main`
# branch; presumably this provides the local `models`, `training` and `utils`
# packages and `configs/model.json` used below — TODO confirm.
!git init
# NOTE(review): on a re-run this command fails because the `origin` remote already
# exists; each `!` line runs independently, so the commands below still execute.
!git remote add origin https://github.com/amin-rami/computer-vision-playground.git
!git fetch
!git checkout main
!git pull

In [None]:
import os
from functools import partial

import torch
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms

import models
import training
from utils.data import FastDataset

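# Data

CIFAR-10 train and test splits, normalized per channel, with random-crop and horizontal-flip augmentation applied to the training split only. Labels are converted to one-hot vectors.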

In [None]:
num_classes = 10  # number of CIFAR-10 categories; also the one-hot target length
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Seed torch's global RNG so data shuffling and random augmentation (and, presumably,
# weight initialisation in models.from_config_file) repeat across Restart & Run All.
# GPU kernels may still be nondeterministic; see the PyTorch reproducibility notes.
SEED = 0
torch.manual_seed(SEED)

def one_hot(label, n_classes):
    """Encode a class index as a one-hot vector.

    Args:
        label: class index to mark (anything usable as a tensor index).
        n_classes: length of the output vector.

    Returns:
        A 1-D tensor of length ``n_classes`` in torch's default float dtype,
        all zeros except a 1 at position ``label``.
    """
    encoded = torch.zeros(n_classes)
    encoded[label] = 1.0
    return encoded

# Per-channel statistics of the CIFAR-10 training split, shared by both splits.
# The std previously used here, (0.2023, 0.1994, 0.2010), is a widely copied value
# that is not the dataset's per-channel std; the correct per-channel values are below.
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2470, 0.2435, 0.2616)

# Deterministic preprocessing for the training split: image -> float tensor -> normalized.
train_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
    ]
)

# Random augmentation for the training split only, passed to FastDataset as
# `post_transform`. NOTE(review): presumably applied per sample after `transform`
# (i.e. on the normalized tensor) — TODO confirm in utils.data.FastDataset.
train_post_transform = transforms.Compose(
    [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
    ]
)

# Evaluation preprocessing: identical normalization to training, no augmentation.
test_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
    ]
)

# Class index -> one-hot float vector of length num_classes, used as the loss target.
target_transform = transforms.Compose(
    [
        partial(one_hot, n_classes=num_classes),
    ]
)

# Wrap each raw CIFAR-10 split (downloaded into ./data if missing) directly in a
# FastDataset, so `train_data` / `test_data` only ever name the wrapped datasets.
# Only the training split gets the random `post_transform` augmentation.
train_data = FastDataset(
    datasets.CIFAR10(root="data", train=True, download=True),
    transform=train_transform,
    post_transform=train_post_transform,
    target_transform=target_transform,
)
test_data = FastDataset(
    datasets.CIFAR10(root="data", train=False, download=True),
    transform=test_transform,
    target_transform=target_transform,
)

In [None]:
# Use up to 12 loader workers, but never more than the machine has CPUs: extra
# workers only add process overhead, and PyTorch's DataLoader warns about it.
# `os.cpu_count()` may return None, hence the fallback to 1.
num_workers = min(12, os.cpu_count() or 1)

# Training batches are reshuffled every epoch; evaluation uses larger, ordered batches.
train_dataloader = DataLoader(train_data, batch_size=1024, shuffle=True, num_workers=num_workers)
test_dataloader = DataLoader(test_data, batch_size=2048, num_workers=num_workers)

# Models

Both networks are built from `configs/model.json` and moved to the selected device.

In [None]:
# Build both networks from the shared JSON config and move each onto `device`;
# the config is expected to yield exactly two models, ResNet first then VGGNet.
resnet, vggnet = [net.to(device) for net in models.from_config_file("configs/model.json")]

# ResNet

In [None]:
# Adam with PyTorch's default hyper-parameters; StepLR halves the learning rate every
# 30 scheduler steps (presumably stepped once per epoch inside TrainLoop — TODO confirm).
resnet_optimizer = torch.optim.Adam(resnet.parameters())
resnet_lr_scheduler = torch.optim.lr_scheduler.StepLR(resnet_optimizer, step_size=30, gamma=0.5)
resnet_loop = training.TrainLoop(
    model=resnet,
    optimizer=resnet_optimizer,
    # Targets are one-hot float vectors (see target_transform); CrossEntropyLoss
    # accepts class-probability targets in recent PyTorch releases.
    loss_fn=torch.nn.CrossEntropyLoss(),
    train_dataloader=train_dataloader,
    epoches=120,
    device=device,
    lr_scheduler=resnet_lr_scheduler,
    test_every=5,
    val_dataloader=test_dataloader,
    save_every=5,
    # Relative to the working directory, like "data" and "configs/model.json";
    # an absolute "/results/..." path needs write access to the filesystem root.
    root="results/resnet",
)

In [None]:
resnet_loop.train()

# VGGNet

In [None]:
# Adam with PyTorch's default hyper-parameters; StepLR halves the learning rate every
# 30 scheduler steps (presumably stepped once per epoch inside TrainLoop — TODO confirm).
vggnet_optimizer = torch.optim.Adam(vggnet.parameters())
vggnet_lr_scheduler = torch.optim.lr_scheduler.StepLR(vggnet_optimizer, step_size=30, gamma=0.5)
vggnet_loop = training.TrainLoop(
    model=vggnet,
    optimizer=vggnet_optimizer,
    # Targets are one-hot float vectors (see target_transform); CrossEntropyLoss
    # accepts class-probability targets in recent PyTorch releases.
    loss_fn=torch.nn.CrossEntropyLoss(),
    train_dataloader=train_dataloader,
    epoches=120,
    device=device,
    lr_scheduler=vggnet_lr_scheduler,
    test_every=5,
    val_dataloader=test_dataloader,
    save_every=10,
    # Relative to the working directory, like "data" and "configs/model.json";
    # an absolute "/results/..." path needs write access to the filesystem root.
    root="results/vggnet",
)


In [None]:
vggnet_loop.train()