In [1]:
import torch
import torchvision
from torch import nn
from tqdm import tqdm
import torch.nn.init as init
from torch.optim import AdamW
from torchvision import transforms
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from torch.optim.lr_scheduler import ExponentialLR

In [10]:
class SkipConnection(torch.nn.Module):
    
    def __init__(self, f_m, f_s=None):
        """
        Description
        """
        super().__init__()
        self.f_m = f_m
        self.f_s = f_s
        self.relu = nn.ReLU()
        
    def forward(self, X):
        """
        Description
        """
        if self.f_s is not None:
            return self.relu(self.f_s(X) + self.f_m(X))
        else:
            return self.relu(X + self.f_m(X))
        
class AverageMeter(object):
    
    def __init__(self):
        """
        Description
        """
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

class Accuracy(object):
    
    def __init__(self, reduction="sum"):
        """
        Description
        """
        if reduction not in ["mean", "sum"]:
            raise AttributeError('The reduction can be either sum or mean')
            
        self.reduction = reduction
        
    @torch.no_grad()
    def __call__(self, x ,y):
        if self.reduction == "sum":
            return (x.argmax(1) == y).float().sum().item()
        else:
            return (x.argmax(1) == y).float().mean().item()

### Reproducability

In [11]:
import numpy as np
import random 

torch.manual_seed(0)
torch.backends.cudnn.deterministic = True
def seed_worker(worker_id):
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)

g = torch.Generator()
_ = g.manual_seed(0)

### Configuration

In [15]:
cfg = {
    "device": torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"),
    "checkpoint_path":"./chkp/model_checkpoint_64.pt"
}

cfg_CIFAR = {
    "root":"./data",
    "download":False
}

cfg_dataloader_train = {
    "batch_size":64,
    "shuffle":True,
    "num_workers":2,
    "pin_memory":True,
    "worker_init_fn":seed_worker,
    "generator":g,
}

cfg_dataloader_test = {
    "batch_size":1024,
    "shuffle":False,
    "num_workers":2,
    "pin_memory":True,
}

cfg_train = {
    "n_epoches":200,
}

### Data

In [13]:
trainset = CIFAR10(transform=transforms.Compose([
                        transforms.RandomHorizontalFlip(),
                        transforms.RandomCrop(32, 4),
                        transforms.ToTensor(),
                        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                             std=[0.229, 0.224, 0.225])]
                    ), **cfg_CIFAR, train=True)
trainloader = DataLoader(trainset, **cfg_dataloader_train)

testset = CIFAR10(transform=transforms.Compose([
                        transforms.ToTensor(),
                        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                             std=[0.229, 0.224, 0.225])]
                    ), **cfg_CIFAR, train=False)
testloader = DataLoader(testset, **cfg_dataloader_test)

### Model

In [14]:
ResNet20 = nn.Sequential(
    ### Initial Layer
    nn.Conv2d(3, 16, 3, padding=1, bias=False),
    nn.BatchNorm2d(16),
    nn.ReLU(),
    
    ### 16x16 Block of 3 Connections
    SkipConnection(
        nn.Sequential(
            nn.Conv2d(16, 16, 3, padding=1, bias=False),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.Conv2d(16, 16, 3, padding=1, bias=False),
            nn.BatchNorm2d(16),
        )
    ),
    SkipConnection(
        nn.Sequential(
            nn.Conv2d(16, 16, 3, padding=1, bias=False),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.Conv2d(16, 16, 3, padding=1, bias=False),
            nn.BatchNorm2d(16),
        )
    ),
    SkipConnection(
        nn.Sequential(
            nn.Conv2d(16, 16, 3, padding=1, bias=False),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.Conv2d(16, 16, 3, padding=1, bias=False),
            nn.BatchNorm2d(16),
        )
    ),
    
    ### Downsampling
    SkipConnection(
        nn.Sequential(
            nn.Conv2d(16, 32, 3, padding=1, stride=2, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 32, 3, padding=1, bias=False),
            nn.BatchNorm2d(32),
        ),
        nn.Sequential(
            nn.Conv2d(16, 32, 3, padding=1, stride=2, bias=False),
            nn.BatchNorm2d(32),
        ),
    ),
    
    ### 32x32 Block of 2 Connections
    SkipConnection(
        nn.Sequential(
            nn.Conv2d(32, 32, 3, padding=1, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 32, 3, padding=1, bias=False),
            nn.BatchNorm2d(32),
        )
    ),
    SkipConnection(
        nn.Sequential(
            nn.Conv2d(32, 32, 3, padding=1, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 32, 3, padding=1, bias=False),
            nn.BatchNorm2d(32),
        )
    ),
    
    ### Downsampling
    SkipConnection(
        nn.Sequential(
            nn.Conv2d(32, 64, 3, padding=1, stride=2, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 64, 3, padding=1, bias=False),
            nn.BatchNorm2d(64),
        ),
        nn.Sequential(
            nn.Conv2d(32, 64, 3, padding=1, stride=2, bias=False),
            nn.BatchNorm2d(64),
        ),
    ),
    
    ### 64x64 Block of 2 Connections
    SkipConnection(
        nn.Sequential(
            nn.Conv2d(64, 64, 3, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 64, 3, padding=1, bias=False),
            nn.BatchNorm2d(64),
        )
    ),
    SkipConnection(
        nn.Sequential(
            nn.Conv2d(64, 64, 3, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 64, 3, padding=1, bias=False),
            nn.BatchNorm2d(64),
        )
    ),
    
    ### Flattening
    nn.AvgPool2d(8),
    nn.Flatten(start_dim=1, end_dim=-1),
    
    ### Head Layer
    nn.Linear(64, 10)
).to(cfg["device"])

optimResNet20 = AdamW(ResNet20.parameters(), lr=1e-2)
schedResNet20 = ExponentialLR(optimResNet20, gamma=0.1)
schedule = [100, 150]
CELoss = nn.CrossEntropyLoss(reduction="sum")
Acc = Accuracy(reduction="sum")

In [18]:
checkpoint = torch.load(cfg["checkpoint_path"])
_ = ResNet20.load_state_dict(checkpoint['model_state_dict'])

### Evaluation

In [20]:
ResNet20.eval()
with torch.no_grad():
    err_CE = 0
    err_acc = 0
    n_elem = 0

    for X_batch, y_batch in (pbar := tqdm(testloader)):
        X_batch = X_batch.to(cfg["device"])
        y_batch = y_batch.to(cfg["device"])

        logits = ResNet20(X_batch)
        output = CELoss(logits, y_batch)
        accuracy = Acc(logits, y_batch)

        batch_shape = X_batch.shape[0]
        n_elem += batch_shape
        err_CE += output.item()
        err_acc += accuracy
        pbar.set_description("CE {:.3f}, Acc {:.3f}".format(err_CE/n_elem, err_acc/n_elem))

CE 0.346, Acc 0.928: 100%|██████████| 10/10 [00:01<00:00,  8.89it/s]


### Builtin PTQ

In [21]:
# set quantization config for server (x86)
deploymentmyModel.qconfig = torch.quantization.get_default_config('fbgemm')

# insert observers
ResNet20PTQ = torch.quantization.prepare(ResNet20, inplace=False)
# Calibrate the model and collect statistics

# convert to quantized version
ResNet20PTQ = torch.quantization.convert(ResNet20PTQ, inplace=False)

AttributeError: module 'torch.quantization' has no attribute 'get_default_config'