In [1]:
### Utilities
import os
import json
import random 
import argparse
import warnings
import numpy as np
from tqdm import tqdm
import copy

### Torch
import torch
from torch import nn
from torch.optim import AdamW
from torchvision import transforms
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from torch.optim.lr_scheduler import MultiStepLR

### Custom
from src.model import model
from src.train import train, validate
from src.utils import Accuracy, AverageMeter

In [2]:
# Seed Python's and NumPy's global generators as well as torch's, so that a
# Restart & Run All is repeatable no matter which of the imported RNGs the
# data pipeline or the src.* helpers draw from.
# NOTE(review): which generators are actually consumed downstream is not
# visible from this cell — seeding all three is the safe default.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
# Kept last so the cell still displays the torch Generator it returns.
torch.manual_seed(SEED)

<torch._C.Generator at 0x7fd0eceb8bd0>

### Configuration

In [3]:
# Read the experiment configuration from disk.
# The top-level values of the JSON object are unpacked by position, so the
# file must hold exactly five sections, in the order listed below.
with open("src/cfg/cfg_64_200.json") as cfg_file:
    cfg_sections = json.load(cfg_file)

(cfg,
 cfg_CIFAR,
 cfg_dataloader_train,
 cfg_dataloader_test,
 cfg_train) = cfg_sections.values()

### Data

In [4]:
# Per-channel normalisation applied identically to the train and test images.
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])

# Training pipeline: random flip + random 32x32 crop (second argument 4 is the
# padding), then tensor conversion and normalisation.
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, 4),
    transforms.ToTensor(),
    normalize,
])
trainset = CIFAR10(train=True, transform=train_transform, **cfg_CIFAR)
trainloader = DataLoader(trainset, **cfg_dataloader_train)

# Test pipeline: no augmentation, only tensor conversion and normalisation.
test_transform = transforms.Compose([
    transforms.ToTensor(),
    normalize,
])
testset = CIFAR10(train=False, transform=test_transform, **cfg_CIFAR)
testloader = DataLoader(testset, **cfg_dataloader_test)

### Model

In [5]:
# Use CUDA when it is available, otherwise fall back to the CPU. The chosen
# device is stored in cfg so later cells can reuse it.
is_cuda = torch.cuda.is_available()
from_checkpoint = os.path.exists(cfg["checkpoint_path"])
cfg["device"] = torch.device("cuda" if is_cuda else "cpu")

# Network, optimiser and LR schedule (LR is multiplied by 0.1 at epochs 100
# and 150).
ResNet20 = model().to(cfg["device"])
optimResNet20 = AdamW(ResNet20.parameters(), lr=1e-2)
schedResNet20 = MultiStepLR(optimResNet20, last_epoch=-1,
                            milestones=[100, 150], gamma=0.1)

# Defaults for a run without a checkpoint, so these names always exist after
# this cell instead of only when a checkpoint file happens to be present.
# NOTE(review): presumably consumed by the training loop — confirm the
# expected starting values there.
last_epoch = 0
best_acc = 0.0

if from_checkpoint:
    # map_location remaps the stored tensors onto the active device, so a
    # checkpoint written on a GPU machine still loads on a CPU-only host.
    # torch.load unpickles the file: only load checkpoints from a trusted source.
    checkpoint = torch.load(cfg["checkpoint_path"],
                            map_location=cfg["device"])
    last_epoch = checkpoint["epoch"] + 1  # resume after the saved epoch
    best_acc = checkpoint["best_acc"]
    ResNet20.load_state_dict(checkpoint["state_dict"])
    optimResNet20.load_state_dict(checkpoint["optimizer"])
    schedResNet20.load_state_dict(checkpoint["scheduler"])

# Both metrics are built with reduction="sum".
CELoss = nn.CrossEntropyLoss(reduction="sum")
Acc = Accuracy(reduction="sum")

In [6]:
# Evaluate the (possibly checkpoint-restored) network on the test loader and
# report the averaged loss and accuracy held by the two returned meters.
ce, acc = validate(
    testloader, ResNet20, CELoss, Acc, cfg["device"],
    verbose=True,
)
report_template = "Cross Entropy: {:.3f}, Accuracy: {:.3f}"
print(report_template.format(ce.avg, acc.avg))

100%|██████████| 10/10 [00:02<00:00,  3.98it/s]

Cross Entropy: 0.374, Accuracy: 0.926





In [13]:
# Build a separate CPU copy of the network and attach the default 'fbgemm'
# quantization config to it.
# NOTE: .to() is applied to ResNet20 itself before copying; assuming ResNet20
# is an nn.Module, this moves the original model to the CPU in place as well,
# so it no longer lives on cfg["device"] after this cell.
cpu_device = torch.device("cpu")
ResNet20.to(cpu_device)
ResNet20PTQ = copy.deepcopy(ResNet20)
ResNet20PTQ.qconfig = torch.quantization.get_default_qconfig('fbgemm')