In [None]:
import pathlib
import torch
import time
from torch.utils.data import DataLoader, Dataset
from torchvision.models import resnet34
import torchvision
import PIL.Image
import pandas as pd
import random
import os
import numpy as np

In [None]:
# Per-batch probability of applying CutMix in the training loop.
cutmix_prob: float = 0.5

In [None]:

# copy the function to make a standalone notebook
def seed_everything(seed):
    '''
    Fixes the class-to-task assignments and most other sources of randomness, except CUDA training aspects.
    '''
    # Avoid all sorts of randomness for better replication
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    if torch.cuda.is_available():
        torch.backends.cudnn.benchmark = True # An exemption for speed :P


def rand_bbox(size, lam):
    W = size[2]
    H = size[3]
    cut_rat = np.sqrt(1. - lam)
    cut_w = int(W * cut_rat)
    cut_h = int(H * cut_rat)

    # uniform
    cx = np.random.randint(W)
    cy = np.random.randint(H)

    bbx1 = np.clip(cx - cut_w // 2, 0, W)
    bby1 = np.clip(cy - cut_h // 2, 0, H)
    bbx2 = np.clip(cx + cut_w // 2, 0, W)
    bby2 = np.clip(cy + cut_h // 2, 0, H)

    return bbx1, bby1, bbx2, bby2


def cutmix_data(x, y, alpha=1.0, cutmix_prob=0.5):
    assert(alpha > 0)
    # generate mixed sample
    lam = np.random.beta(alpha, alpha)

    batch_size = x.size()[0]
    index = torch.randperm(batch_size)

    if torch.cuda.is_available():
        index = index.cuda()

    y_a, y_b = y, y[index]
    bbx1, bby1, bbx2, bby2 = rand_bbox(x.size(), lam)
    x[:, :, bbx1:bbx2, bby1:bby2] = x[index, :, bbx1:bbx2, bby1:bby2]

    # adjust lambda to exactly match pixel ratio
    lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (x.size()[-1] * x.size()[-2]))
    return x, y_a, y_b, lam

In [None]:
from torchvision.datasets import CIFAR100
from torch.utils.data import DataLoader

In [None]:
EPOCHS = 128
DEVICE = 'cuda'
DATA_DIR = pathlib.Path('.')
BATCH_SIZE = 256
pretraining_classes = range(50)
seed_everything(0)

In [None]:
train_transforms = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761))
])

test_transforms = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.5071, 0.4867, 0.4408),(0.2675, 0.2565, 0.2761))
])


trainset_full = CIFAR100(
    train=True,
    transform=train_transforms,
    root=DATA_DIR,
    download=True
)
train_subset_idxs = [i for i in range(len(trainset_full)) if trainset_full.targets[i] in pretraining_classes]
trainset = torch.utils.data.Subset(trainset_full, train_subset_idxs)

testset_full = CIFAR100(
    train=False,
    transform=test_transforms,
    root=DATA_DIR,
    download=True
)
test_subset_idxs = [i for i in range(len(testset_full)) if testset_full.targets[i] in pretraining_classes]
testset = torch.utils.data.Subset(testset_full, test_subset_idxs)

trainloader = DataLoader(
    dataset=trainset,
    shuffle=True,
    batch_size=BATCH_SIZE,
    num_workers=4,
)
testloader = DataLoader(
    dataset=testset,
    shuffle=False,
    batch_size=BATCH_SIZE,
    num_workers=4,
)

In [None]:
model = resnet34().to(DEVICE)
model.fc = torch.nn.Linear(in_features=512, out_features=100, device=DEVICE)

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=0.01,
    momentum=0.9,
    weight_decay=1e-4,
)


In [None]:
def test_model(model, dataloader, criterion, device):
    with torch.no_grad():
        running_loss = 0
        running_acc = 0
        for data, target in dataloader:
            data = data.to(device)
            target = target.to(device)
            output = model(data)
            loss = criterion(output, target)

            prob = torch.nn.functional.softmax(output, dim=1)
            pred = torch.argmax(prob, dim=1)
            corr = torch.eq(pred, target)
            acc = (corr*1.0).mean()

            running_acc += acc.item()
            running_loss += loss.item()

    return (
        running_acc/len(dataloader),
        running_loss/len(dataloader),
    )

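## Train the model

First sanity-check that both data subsets only contain the pretraining classes, then train with CutMix, evaluating on the test subset after every epoch and checkpointing whenever test accuracy improves.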

In [None]:
for data, target in trainloader:
    assert (target<50).all().item()
for data, target in testloader:
    assert (target<50).all().item()

In [None]:

print('starting testing')
model.eval()
tstart = time.time()
acc, loss = test_model(
    model=model,
    dataloader=testloader,
    criterion=criterion,
    device=DEVICE
)
testtimes = [time.time()-tstart]
testlosses = [loss]
testaccs = [acc]
trainlosses = [loss]
trainaccs = [0]
traintimes = [0]
checkpoints = []
best_acc = 0
print(f"epoch {0}\t testloss {testlosses[-1]:.4f} \f testacc {testaccs[-1]:.4f}\t time {testtimes[-1]:.0f} s")

# training
for epoch in range(EPOCHS):
    running_loss = 0
    running_acc = 0
    tstart = time.time()
    model.train()
    for data, target in trainloader:
        data = data.to(DEVICE)
        target = target.to(DEVICE)

        do_cutmix =  np.random.rand(1) < cutmix_prob
        if do_cutmix:
            data, labels_a, labels_b, lam = cutmix_data(x=data, y=target, alpha=1)
            

        # output = model(embedding)
        output = model(data)
        
        if do_cutmix:
            loss = lam * criterion(output, labels_a) + (1 - lam) * criterion(output, labels_b)
        else:
            loss = criterion(output, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        prob = torch.nn.functional.softmax(output, dim=1)
        pred = torch.argmax(prob, dim=1)
        corr = torch.eq(pred, target)
        acc = (corr*1.0).mean()
        running_acc += acc.item()
        running_loss += loss.item()

    traintimes.append(time.time() - tstart)
    trainlosses.append(running_loss / len(trainloader))
    trainaccs.append(running_acc / len(trainloader))

    print(f"epoch {epoch+1}\t trainloss {trainlosses[-1]:.4f} \f trainacc {trainaccs[-1]:.4f}\t time {traintimes[-1]:.0f} s")

    # testing
    model.eval()
    running_loss = 0
    running_acc = 0
    tstart = time.time()
    acc, loss = test_model(
        model=model,
        dataloader=testloader,
        criterion=criterion,
        device=DEVICE
    )
    testtimes.append(time.time()-tstart)
    testlosses.append(loss)
    testaccs.append(acc)

    if acc > best_acc:
        path = f"resnet34_cifar100_classes0to49_e{epoch+1}.pt"
        torch.save(model.state_dict(), path)
        checkpoints.append(epoch+1)
        best_acc = acc
        print(f"saved checkpoint: {path}")

    print(f"epoch {epoch+1}\t testloss {testlosses[-1]:.4f} \t testacc {testaccs[-1]:.4f}\t time {testtimes[-1]:.0f} s")

In [None]:
import matplotlib.pyplot as plt

fig, (ax1, ax2) = plt.subplots(nrows=2, ncols=1, figsize=(8, 8))

ax1.plot(list(range(EPOCHS+1)), trainlosses, label='train')
ax1.plot(list(range(EPOCHS+1)), testlosses, label='test')
ax1.legend()
ax1.set_title('Loss')
ax2.plot(list(range(EPOCHS+1)), trainaccs, label='train')
ax2.plot(list(range(EPOCHS+1)), testaccs, label='test')
ax2.legend()
ax2.set_title('Accuracy')

In [None]:
fig.savefig('lc.svg')