In [1]:
import torch
import torchvision
from torch import nn
from tqdm import tqdm
import torch.nn.init as init
from torch.optim import AdamW
from torchvision import transforms
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from torch.optim.lr_scheduler import ExponentialLR, MultiStepLR

In [2]:
class SkipConnection(torch.nn.Module):
    def __init__(self, f_m, f_s=None):
        """
        Description
        """
        super().__init__()
        self.f_m = f_m
        self.f_s = f_s
        self.relu = nn.ReLU()
        
    def forward(self, X):
        """
        Description
        """
        if self.f_s is not None:
            return self.relu(self.f_s(X) + self.f_m(X))
        else:
            return self.relu(X + self.f_m(X))
        
class AverageMeter(object):
        """
        Description
        """
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

class Accuracy(object):
        """
        Description
        """
    def __init__(self, reduction="sum"):
        if reduction not in ["mean", "sum"]:
            raise AttributeError('The reduction can be either sum or mean')
            
        self.reduction = reduction
        
    @torch.no_grad()
    def __call__(self, x ,y):
        if self.reduction == "sum":
            return (x.argmax(1) == y).float().sum().item()
        else:
            return (x.argmax(1) == y).float().mean().item()

### Helper functions from PyTorch

In [None]:
class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)


def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res


def evaluate(model, criterion, data_loader, neval_batches):
    model.eval()
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    cnt = 0
    with torch.no_grad():
        for image, target in data_loader:
            output = model(image)
            loss = criterion(output, target)
            cnt += 1
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            print('.', end = '')
            top1.update(acc1[0], image.size(0))
            top5.update(acc5[0], image.size(0))
            if cnt >= neval_batches:
                 return top1, top5

    return top1, top5

def load_model(model_file):
    model = MobileNetV2()
    state_dict = torch.load(model_file)
    model.load_state_dict(state_dict)
    model.to('cpu')
    return model

def print_size_of_model(model):
    torch.save(model.state_dict(), "temp.p")
    print('Size (MB):', os.path.getsize("temp.p")/1e6)
    os.remove('temp.p')

### Reproducability

In [3]:
import numpy as np
import random 

torch.manual_seed(0)
torch.backends.cudnn.deterministic = True
def seed_worker(worker_id):
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)

g = torch.Generator()
_ = g.manual_seed(0)

### Configuration

In [4]:
cfg = {
    "device": torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"),
    "checkpoint_path":"./chkp/model_checkpoint_64_check.pt"
}

cfg_CIFAR = {
    "root":"./data",
    "download":False
}

cfg_dataloader_train = {
    "batch_size":64,
    "shuffle":True,
    "num_workers":2,
    "pin_memory":True,
    "worker_init_fn":seed_worker,
    "generator":g,
}

cfg_dataloader_test = {
    "batch_size":1024,
    "shuffle":False,
    "num_workers":2,
    "pin_memory":True,
}

cfg_train = {
    "n_epoches":200,
}

### Data

In [5]:
trainset = CIFAR10(transform=transforms.Compose([
                        transforms.RandomHorizontalFlip(),
                        transforms.RandomCrop(32, 4),
                        transforms.ToTensor(),
                        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                             std=[0.229, 0.224, 0.225])]
                    ), **cfg_CIFAR, train=True)
trainloader = DataLoader(trainset, **cfg_dataloader_train)

testset = CIFAR10(transform=transforms.Compose([
                        transforms.ToTensor(),
                        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                             std=[0.229, 0.224, 0.225])]
                    ), **cfg_CIFAR, train=False)
testloader = DataLoader(testset, **cfg_dataloader_test)

### Model

In [6]:
ResNet20 = nn.Sequential(
    ### Initial Layer
    nn.Conv2d(3, 16, 3, padding=1, bias=False),
    nn.BatchNorm2d(16),
    nn.ReLU(),
    
    ### 16x16 Block of 3 Connections
    SkipConnection(
        nn.Sequential(
            nn.Conv2d(16, 16, 3, padding=1, bias=False),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.Conv2d(16, 16, 3, padding=1, bias=False),
            nn.BatchNorm2d(16),
        )
    ),
    SkipConnection(
        nn.Sequential(
            nn.Conv2d(16, 16, 3, padding=1, bias=False),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.Conv2d(16, 16, 3, padding=1, bias=False),
            nn.BatchNorm2d(16),
        )
    ),
    SkipConnection(
        nn.Sequential(
            nn.Conv2d(16, 16, 3, padding=1, bias=False),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.Conv2d(16, 16, 3, padding=1, bias=False),
            nn.BatchNorm2d(16),
        )
    ),
    
    ### Downsampling
    SkipConnection(
        nn.Sequential(
            nn.Conv2d(16, 32, 3, padding=1, stride=2, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 32, 3, padding=1, bias=False),
            nn.BatchNorm2d(32),
        ),
        nn.Sequential(
            nn.Conv2d(16, 32, 3, padding=1, stride=2, bias=False),
            nn.BatchNorm2d(32),
        ),
    ),
    
    ### 32x32 Block of 2 Connections
    SkipConnection(
        nn.Sequential(
            nn.Conv2d(32, 32, 3, padding=1, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 32, 3, padding=1, bias=False),
            nn.BatchNorm2d(32),
        )
    ),
    SkipConnection(
        nn.Sequential(
            nn.Conv2d(32, 32, 3, padding=1, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 32, 3, padding=1, bias=False),
            nn.BatchNorm2d(32),
        )
    ),
    
    ### Downsampling
    SkipConnection(
        nn.Sequential(
            nn.Conv2d(32, 64, 3, padding=1, stride=2, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 64, 3, padding=1, bias=False),
            nn.BatchNorm2d(64),
        ),
        nn.Sequential(
            nn.Conv2d(32, 64, 3, padding=1, stride=2, bias=False),
            nn.BatchNorm2d(64),
        ),
    ),
    
    ### 64x64 Block of 2 Connections
    SkipConnection(
        nn.Sequential(
            nn.Conv2d(64, 64, 3, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 64, 3, padding=1, bias=False),
            nn.BatchNorm2d(64),
        )
    ),
    SkipConnection(
        nn.Sequential(
            nn.Conv2d(64, 64, 3, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 64, 3, padding=1, bias=False),
            nn.BatchNorm2d(64),
        )
    ),
    
    ### Flattening
    nn.AvgPool2d(8),
    nn.Flatten(start_dim=1, end_dim=-1),
    
    ### Head Layer
    nn.Linear(64, 10)
).to(cfg["device"])

optimResNet20 = AdamW(ResNet20.parameters(), lr=1e-2)
schedResNet20 = ExponentialLR(optimResNet20, gamma=0.1)
schedule = [100, 150]
CELoss = nn.CrossEntropyLoss(reduction="sum")
Acc = Accuracy(reduction="sum")

### Training

In [7]:
best_metric_acc = None
best_metric_CE = None

train_meter_CE = []
train_meter_acc = []

test_meter_CE = []
test_meter_acc = []

for e in (pbar := tqdm(range(cfg_train["n_epoches"]))):
    err_CE = 0
    err_acc = 0
    n_elem = 0
    
    ResNet20.train()
    for X_batch, y_batch in trainloader:
        optimResNet20.zero_grad()
        X_batch = X_batch.to(cfg["device"])
        y_batch = y_batch.to(cfg["device"])
        
        logits = ResNet20(X_batch)
        output = CELoss(logits, y_batch)
        accuracy = Acc(logits, y_batch)
        
        output.backward()
        optimResNet20.step()
        
        with torch.no_grad():
            batch_shape = X_batch.shape[0]
            n_elem += batch_shape
            err_CE += output.item()
            err_acc += accuracy

    train_meter_CE.append(err_CE/n_elem)
    train_meter_acc.append(err_acc/n_elem)

    ResNet20.eval()
    with torch.no_grad():
        err_CE = 0
        err_acc = 0
        n_elem = 0

        for X_batch, y_batch in testloader:
            X_batch = X_batch.to(cfg["device"])
            y_batch = y_batch.to(cfg["device"])

            logits = ResNet20(X_batch)
            output = CELoss(logits, y_batch)
            accuracy = Acc(logits, y_batch)
            
            batch_shape = X_batch.shape[0]
            n_elem += batch_shape
            err_CE += output.item()
            err_acc += accuracy

        test_meter_CE.append(err_CE/n_elem)
        test_meter_acc.append(err_acc/n_elem)

    if best_metric_acc is None:
        best_metric_CE = test_meter_CE[-1]
        best_metric_acc = test_meter_acc[-1]
        
        torch.save({
            'model_state_dict': ResNet20.state_dict(),
            'optimizer_state_dict': optimResNet20.state_dict(),
            'CE': best_metric_CE,
            "Acc":best_metric_acc,
            }, cfg["checkpoint_path"])
        
    elif best_metric_acc < (LOSS := test_meter_acc[-1]):
        best_metric_CE = test_meter_CE[-1]
        best_metric_acc = LOSS
        
        torch.save({
            'model_state_dict': ResNet20.state_dict(),
            'optimizer_state_dict': optimResNet20.state_dict(),
            'CE': best_metric_CE,
            "Acc":best_metric_acc,
            }, cfg["checkpoint_path"])
        
    if (e+1) in schedule:
        schedResNet20.step()
        
    pbar.set_description("Train: CE {:.3f} Acc. {:.3f}| Test: CE {:.3f} Acc. {:.3f} | LR: {}".format(
        train_meter_CE[-1], train_meter_acc[-1], best_metric_CE, best_metric_acc, schedResNet20.get_last_lr()[0]
    ))

Train: CE 0.014 Acc. 0.996| Test: CE 0.346 Acc. 0.928 | LR: 0.0001: 100%|██████████| 200/200 [54:48<00:00, 16.44s/it]
