In [None]:
import pathlib
import os
import random
import time

import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import torchvision
from torchvision.datasets import CIFAR100

from torch.utils.data import DataLoader

In [None]:


def _weights_init(m):
    """He (Kaiming normal) initialization of the weights of Linear and Conv2d layers.

    Meant to be passed to ``nn.Module.apply``. Any other module type is ignored, and
    biases are left at their default initialization.

    Args:
        m: the module currently visited by ``apply``.
    """
    if isinstance(m, (nn.Linear, nn.Conv2d)):
        init.kaiming_normal_(m.weight)


class LambdaLayer(nn.Module):
    """Adapter that lets an arbitrary callable be used as an ``nn.Module``.

    ``forward`` simply returns ``lambd(x)``.
    """

    def __init__(self, lambd):
        super().__init__()
        # The wrapped callable; invoked on every forward pass.
        self.lambd = lambd

    def forward(self, x):
        fn = self.lambd
        return fn(x)


class BasicBlock(nn.Module):
    """Two-convolution residual block: conv3x3-BN-ReLU, conv3x3-BN, add shortcut, ReLU.

    Args:
        in_planes: number of input channels.
        planes: number of output channels.
        stride: stride of the first 3x3 conv (values != 1 downsample spatially).
        option: shortcut used when the shape changes (``stride != 1`` or
            ``in_planes != planes``):
              'A' - parameter-free: takes every 2nd row/column and zero-pads
                    ``planes // 4`` channels on each side (assumes
                    ``planes == 2 * in_planes`` so channel counts match -
                    TODO confirm for other configurations);
              'B' - 1x1 convolution + BatchNorm projection.

    Raises:
        ValueError: if a shortcut is needed and ``option`` is neither 'A' nor 'B'.
            (Previously an identity shortcut was kept silently, which only failed
            later with a shape mismatch in ``forward``.)
    """
    expansion = 1

    def __init__(self, in_planes, planes, stride=1, option='A'):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        # Identity shortcut when input and output shapes already agree.
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != planes:
            if option == 'A':
                # For CIFAR10 the ResNet paper uses option A.
                self.shortcut = LambdaLayer(lambda x:
                                            F.pad(x[:, :, ::2, ::2], (0, 0, 0, 0, planes//4, planes//4), "constant", 0))
            elif option == 'B':
                self.shortcut = nn.Sequential(
                     nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
                     nn.BatchNorm2d(self.expansion * planes)
                )
            else:
                raise ValueError(f"unknown shortcut option {option!r}; expected 'A' or 'B'")

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class Cutr32(nn.Module):
    """ResNet-32 style classifier.

    The network is a 3x3 stem conv (3 -> 16 channels) followed by three stages of
    5 ``BasicBlock``s with 16, 32 and 64 channels. Stages 2 and 3 downsample with
    stride 2. A global average pool and a linear head complete the model.

    Args:
        batch_size: accepted for backward compatibility; it is not used by the model.
        num_classes: number of output logits (default 100, unchanged from before).
    """

    def __init__(self, batch_size, num_classes=100):
        super().__init__()
        block = BasicBlock
        # Running channel count; _make_layer reads and advances it stage by stage.
        self.in_planes = 16

        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(16)
        self.layer1 = self._make_layer(block, 16, 5, stride=1)
        self.layer2 = self._make_layer(block, 32, 5, stride=2)
        self.layer3 = self._make_layer(block, 64, 5, stride=2)
        self.linear = nn.Linear(64 * block.expansion, num_classes)

        self.apply(_weights_init)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack ``num_blocks`` blocks.

        Only the first block uses ``stride``; the rest use stride 1. Updates
        ``self.in_planes`` so the next stage starts from this stage's output width.
        """
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for block_stride in strides:
            layers.append(block(self.in_planes, planes, block_stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return logits of shape (N, num_classes); assumes x is (N, 3, H, W) - TODO confirm."""
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer3(self.layer2(self.layer1(out)))
        # Global average pool to 1x1 (also correct for non-square feature maps), then flatten.
        out = F.adaptive_avg_pool2d(out, 1).flatten(1)
        return self.linear(out)



def seed_everything(seed):
    '''
    Seed every global RNG this notebook relies on.

    Seeds Python's ``random``, NumPy's global generator and PyTorch (CPU and all
    CUDA devices), and exports PYTHONHASHSEED. This fixes the class-to-task
    assignments and most other randomness.

    Deliberate exception: when a GPU is present, cuDNN autotuning
    (``cudnn.benchmark``) is switched on for speed, so CUDA training is not
    bit-for-bit reproducible.

    NOTE: Python reads PYTHONHASHSEED at interpreter start-up, so setting it here
    presumably only affects processes launched afterwards.
    '''
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    for seeder in (random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
    if torch.cuda.is_available():
        torch.backends.cudnn.benchmark = True  # speed over determinism


def rand_bbox(size, lam):
    """Sample a random box covering roughly ``1 - lam`` of the area of a 4-D tensor.

    Args:
        size: a 4-D shape such as ``tensor.size()``. The returned ``bbx*`` bounds
            index ``size[2]`` and ``bby*`` index ``size[3]``. Despite the ``x``/``y``
            naming, for an (N, C, H, W) tensor that means ``bbx`` runs along height -
            assumes NCHW, TODO confirm.
        lam: mixing ratio in [0, 1]; each box side is ``sqrt(1 - lam)`` of the extent.

    Returns:
        ``(bbx1, bby1, bbx2, bby2)``. The box centre is uniform, and the box is
        clipped to the tensor, so the real area can be smaller than requested.
    """
    extent_a = size[2]
    extent_b = size[3]
    side_ratio = np.sqrt(1. - lam)
    half_a = int(extent_a * side_ratio) // 2
    half_b = int(extent_b * side_ratio) // 2

    # Centre drawn uniformly; dimension-2 coordinate first, then dimension-3 (RNG order matters).
    centre_a = np.random.randint(extent_a)
    centre_b = np.random.randint(extent_b)

    def _clipped_span(centre, half, limit):
        return np.clip(centre - half, 0, limit), np.clip(centre + half, 0, limit)

    lo_a, hi_a = _clipped_span(centre_a, half_a, extent_a)
    lo_b, hi_b = _clipped_span(centre_b, half_b, extent_b)
    return lo_a, lo_b, hi_a, hi_b


def cutmix_data(x, y, alpha=1.0, cutmix_prob=0.5):
    """Apply CutMix to a batch.

    A single random box (from ``rand_bbox``) is copied into every sample from a
    randomly permuted partner sample of the same batch.

    Args:
        x: 4-D image batch; it is modified IN PLACE and also returned. Presumably
            (N, C, H, W) - TODO confirm.
        y: labels aligned with the first dimension of ``x``.
        alpha: parameter of the Beta(alpha, alpha) prior on the mixing ratio; must be > 0.
        cutmix_prob: unused here (the caller decides whether to apply CutMix);
            kept only for backward compatibility.

    Returns:
        ``(x, y_a, y_b, lam)``. ``y_a`` are the original labels and ``y_b`` the
        partner labels. ``lam`` is the fraction of pixels still belonging to the
        original sample, recomputed from the clipped box.

    Raises:
        AssertionError: if ``alpha <= 0``.
    """
    assert alpha > 0, "alpha must be positive"
    lam = np.random.beta(alpha, alpha)

    batch_size = x.size(0)
    # Draw the permutation from the CPU generator (same random stream as before), then move it
    # to where the data actually lives. Moving it to CUDA merely because CUDA is *available*
    # breaks indexing when x / y are CPU tensors.
    index = torch.randperm(batch_size)

    y_a, y_b = y, y[index.to(y.device)]
    bbx1, bby1, bbx2, bby2 = rand_bbox(x.size(), lam)
    # The right-hand side uses advanced indexing, so it is a copy and the in-place write is safe.
    x[:, :, bbx1:bbx2, bby1:bby2] = x[index.to(x.device), :, bbx1:bbx2, bby1:bby2]

    # adjust lambda to exactly match pixel ratio
    lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (x.size()[-1] * x.size()[-2]))
    return x, y_a, y_b, lam

In [None]:
# cutmix_prob: chance that a training batch gets CutMix applied.
# encoding_block: value handed to Cutr32's (unused) constructor argument.
cutmix_prob, encoding_block = 0.5, 2

In [None]:
EPOCHS = 512
# Fall back to the CPU so the notebook still runs (slowly) on machines without a GPU.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
DATA_DIR = pathlib.Path('.')
BATCH_SIZE = 256
# CIFAR-100 labels used for pretraining (0..49).
pretraining_classes = range(50)
SEED = 0
seed_everything(SEED)

In [None]:
# Per-channel normalization constants shared by the train and test pipelines.
_NORM_MEAN = (0.5071, 0.4867, 0.4408)
_NORM_STD = (0.2675, 0.2565, 0.2761)


def _tensor_and_normalize():
    """Build a fresh ToTensor + Normalize pipeline (no augmentation)."""
    return torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(_NORM_MEAN, _NORM_STD),
    ])


def _indices_of_classes(dataset, classes):
    """Return the indices of samples in ``dataset`` whose label is in ``classes``."""
    return [idx for idx, label in enumerate(dataset.targets) if label in classes]


train_transforms = _tensor_and_normalize()
test_transforms = _tensor_and_normalize()

# Training split restricted to the pretraining classes.
trainset_full = CIFAR100(root=DATA_DIR, train=True, transform=train_transforms, download=True)
train_subset_idxs = _indices_of_classes(trainset_full, pretraining_classes)
trainset = torch.utils.data.Subset(trainset_full, train_subset_idxs)

# Test split restricted to the same classes.
testset_full = CIFAR100(root=DATA_DIR, train=False, transform=test_transforms, download=True)
test_subset_idxs = _indices_of_classes(testset_full, pretraining_classes)
testset = torch.utils.data.Subset(testset_full, test_subset_idxs)

trainloader = DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)
testloader = DataLoader(testset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)

In [None]:
# Cutr32's first constructor argument (batch_size) is not used by its visible definition.
model = Cutr32(batch_size=encoding_block).to(DEVICE)

# Plain SGD with momentum and L2 weight decay; no LR schedule is configured.
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4)


In [None]:
def test_model(model, dataloader, criterion, device):
    with torch.no_grad():
        running_loss = 0
        running_acc = 0
        for data, target in dataloader:
            data = data.to(device)
            target = target.to(device)
            output = model(data)
            loss = criterion(output, target)

            prob = torch.nn.functional.softmax(output, dim=1)
            pred = torch.argmax(prob, dim=1)
            corr = torch.eq(pred, target)
            acc = (corr*1.0).mean()

            running_acc += acc.item()
            running_loss += loss.item()

    return (
        running_acc/len(dataloader),
        running_loss/len(dataloader),
    )


Train the model


In [None]:

# Sanity check: the class-filtered subsets must only yield labels from pretraining_classes
# (tied to the config instead of a hard-coded 50, with a message naming the offending labels).
_allowed_labels = set(pretraining_classes)
for _split, _loader in (('train', trainloader), ('test', testloader)):
    for _batch_data, _batch_target in _loader:
        _unexpected = set(_batch_target.tolist()) - _allowed_labels
        assert not _unexpected, (
            f"{_split} loader yielded labels outside pretraining_classes: {sorted(_unexpected)}"
        )


In [None]:

print('starting testing')
model.eval()
tstart = time.time()
acc, loss = test_model(
    model=model,
    dataloader=testloader,
    criterion=criterion,
    device=DEVICE
)
testtimes = [time.time()-tstart]
testlosses = [loss]
testaccs = [acc]
# Measure the untrained model on the training split too. Previously the epoch-0 "train" point
# was the TEST loss and a hard-coded accuracy of 0, which mislabelled the learning curve.
init_train_acc, init_train_loss = test_model(
    model=model,
    dataloader=trainloader,
    criterion=criterion,
    device=DEVICE
)
trainlosses = [init_train_loss]
trainaccs = [init_train_acc]
traintimes = [0]
checkpoints = []
best_acc = 0
print(f"epoch {0}\t testloss {testlosses[-1]:.4f} \t testacc {testaccs[-1]:.4f}\t time {testtimes[-1]:.0f} s")

# training
for epoch in range(EPOCHS):
    running_loss = 0.0
    running_correct = 0
    n_seen = 0
    tstart = time.time()
    model.train()
    for data, target in trainloader:
        data = data.to(DEVICE)
        target = target.to(DEVICE)

        # Apply CutMix to roughly a cutmix_prob fraction of batches.
        do_cutmix = np.random.rand() < cutmix_prob
        if do_cutmix:
            data, labels_a, labels_b, lam = cutmix_data(x=data, y=target, alpha=1)

        output = model(data)

        if do_cutmix:
            # Weight the two label sets by the pixel-area ratio returned by cutmix_data.
            loss = lam * criterion(output, labels_a) + (1 - lam) * criterion(output, labels_b)
        else:
            loss = criterion(output, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Per-sample statistics (weighted by batch size); on CutMix batches accuracy is measured
        # against the original labels.
        batch_n = target.size(0)
        running_correct += (output.argmax(dim=1) == target).sum().item()
        running_loss += loss.item() * batch_n
        n_seen += batch_n

    traintimes.append(time.time() - tstart)
    trainlosses.append(running_loss / n_seen)
    trainaccs.append(running_correct / n_seen)

    print(f"epoch {epoch+1}\t trainloss {trainlosses[-1]:.4f} \t trainacc {trainaccs[-1]:.4f}\t time {traintimes[-1]:.0f} s")

    # testing
    model.eval()
    tstart = time.time()
    acc, loss = test_model(
        model=model,
        dataloader=testloader,
        criterion=criterion,
        device=DEVICE
    )
    testtimes.append(time.time()-tstart)
    testlosses.append(loss)
    testaccs.append(acc)

    # Keep a checkpoint whenever test accuracy improves.
    if acc > best_acc:
        path = f"resnet32_cifar100_classes0to49_e{epoch+1}.pt"
        torch.save(model.state_dict(), path)
        checkpoints.append(epoch+1)
        best_acc = acc
        print(f"saved checkpoint: {path}")

    print(f"epoch {epoch+1}\t testloss {testlosses[-1]:.4f} \t testacc {testaccs[-1]:.4f}\t time {testtimes[-1]:.0f} s")


In [None]:
import matplotlib.pyplot as plt

fig, (ax1, ax2) = plt.subplots(nrows=2, ncols=1, figsize=(8, 8))

# x-values come from each history's own length rather than EPOCHS + 1. The plot then still
# works after an interrupted training run, where the histories can be shorter than EPOCHS + 1
# and the train history may be one entry longer than the test history.
ax1.plot(range(len(trainlosses)), trainlosses, label='train')
ax1.plot(range(len(testlosses)), testlosses, label='test')
ax1.set_title('Loss')
ax1.set_xlabel('epoch')
ax1.legend()

ax2.plot(range(len(trainaccs)), trainaccs, label='train')
ax2.plot(range(len(testaccs)), testaccs, label='test')
ax2.set_title('Accuracy')
ax2.set_xlabel('epoch')
ax2.legend()

fig.tight_layout()
fig.savefig('lc.svg')