In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import copy
import wandb
import math

import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import argparse
import os

import quant_lib.ResNet18 as ResNet18


In [2]:
class QuantizeLayer(nn.Module):
    """Wrap a layer and run it with sign-quantized parameters.

    For ``nn.Conv2d`` and ``nn.Linear`` layers the forward pass uses
    ``torch.sign`` of the stored weight (and of the bias, when the layer has
    one) instead of the raw parameters. Any other layer type is simply called
    on the input unchanged.
    """

    def __init__(self, layer):
        super().__init__()
        self.layer = layer

    def forward(self, x):
        wrapped = self.layer

        is_conv = isinstance(wrapped, nn.Conv2d)
        if not is_conv and not isinstance(wrapped, nn.Linear):
            # Unsupported layer type: delegate without touching its parameters.
            return wrapped(x)

        signed_weight = torch.sign(wrapped.weight)
        # The bias is optional on both layer types; keep None when it is absent.
        signed_bias = None if wrapped.bias is None else torch.sign(wrapped.bias)

        if is_conv:
            return F.conv2d(
                x,
                signed_weight,
                signed_bias,
                wrapped.stride,
                wrapped.padding,
                wrapped.dilation,
                wrapped.groups,
            )
        return F.linear(x, signed_weight, signed_bias)

In [3]:
def ResNet(bit_width=1):
    """Construct the network defined in ``quant_lib.ResNet18``.

    Args:
        bit_width: Forwarded unchanged as the ``bit_width`` keyword argument
            of ``ResNet18.ResNet``.

    Returns:
        The object produced by ``ResNet18.ResNet`` when given
        ``ResNet18.ResidualBlock`` as its first positional argument.
    """
    block_cls = ResNet18.ResidualBlock
    model = ResNet18.ResNet(block_cls, bit_width=bit_width)
    return model

In [4]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hyperparameters read by the training cells below.
bit_width = 1
EPOCH = 150          # upper bound of every training loop's epoch range
ANNEAL_EPOCH = 90    # epoch from which the ProxQuant / ASkewSGD schedules start changing epsilon
pre_epoch = 0        # first epoch of the training loops (non-zero when resuming)
BATCH_SIZE = 100
LR = 0.06

# ASkewSGD
DECAY_CONST = 0.88
alpha = 0.2
# ProxQuant
reg_lambda = 4e-4

# Dataset location and the per-channel statistics handed to Normalize;
# both the train and the test pipeline use the same values.
DATA_ROOT = "../data"
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2023, 0.1994, 0.2010)

# Training pipeline: upscale, random square crop back to 32x32, flip,
# convert to tensor, random erasing, then normalize.
transform_train = transforms.Compose(
    [
        transforms.Resize(40),
        transforms.RandomResizedCrop(32, scale=(0.64, 1.0), ratio=(1.0, 1.0)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.RandomErasing(),
        transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
    ]
)

# Test pipeline: no augmentation, only tensor conversion and normalization.
transform_test = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
    ]
)

trainset = torchvision.datasets.CIFAR10(
    root=DATA_ROOT, train=True, download=True, transform=transform_train
)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=BATCH_SIZE, shuffle=True, num_workers=3
)

testset = torchvision.datasets.CIFAR10(
    root=DATA_ROOT, train=False, download=True, transform=transform_test
)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=100, shuffle=False, num_workers=3
)

# Labels in CIFAR10
classes = ("plane", "car", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck")

# Reference model; init() deep-copies it so every optimizer starts from the
# same initial parameters.
base_net = ResNet(bit_width=bit_width).to(device)

# Define loss function
criterion = nn.CrossEntropyLoss()


Files already downloaded and verified
Files already downloaded and verified


In [5]:
# FILE = 'resnet18_original.pt'

# checkpoint = torch.load(FILE)
# net.load_state_dict(checkpoint['model_state_dict'])
# optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

# torch.save({'model_state_dict': net.state_dict(), 'optimizer_state_dict': optimizer.state_dict()}, FILE)

In [6]:
def init(project_name, opt_name, batch_size, architecture, dataset_name, lr, alpha=None, reg_lambda=None):
    """Start a wandb run and build a fresh model copy with an SGD optimizer.

    Note: this function's name rebinds ``init``, which the import cell also
    uses as the alias for ``torch.nn.init``.

    Args:
        project_name: wandb project the run is logged to.
        opt_name: Name given to the wandb run.
        batch_size, architecture, dataset_name, alpha, reg_lambda: Recorded in
            the run config only; they do not affect the returned objects.
        lr: Recorded in the run config and used as the SGD learning rate.

    Returns:
        ``(net, optimizer)`` where ``net`` is a deep copy of the global
        ``base_net`` moved to the global ``device`` and ``optimizer`` is an
        ``optim.SGD`` with two parameter groups tagged ``"weights"`` and
        ``"bias"``.
    """
    run_config = {
        "batch_size": batch_size,
        "architecture": architecture,
        "dataset": dataset_name,
        "lr": lr,
        "alpha": alpha,
        "reg_lambda": reg_lambda,
        "bit_width": base_net.bit_width,
    }
    wandb.init(project=project_name, name=opt_name, config=run_config)

    net = copy.deepcopy(base_net)
    net.to(device)

    # Split parameters by whether their qualified name mentions "bias";
    # the extra "tag" key lets training cells pick a group by role.
    weights, bias = [], []
    for name, p in net.named_parameters():
        target = bias if 'bias' in name else weights
        target.append(p)

    parameters = [
        {"params": weights, "tag": "weights"},
        {"params": bias, "tag": "bias"},
    ]
    optimizer = optim.SGD(parameters, lr=lr)
    return net, optimizer

In [None]:
# SGD
# Baseline: plain SGD on the real-valued parameters; after every step each
# non-bias parameter is clamped to [-1, 1].
# The wandb run is named "SGD" so it matches this cell (it was previously
# logged under the ProxQuant label).
net, optimizer = init(project_name="CIFAR10_binary", opt_name="SGD", batch_size=BATCH_SIZE, architecture="ResNet-18", dataset_name="CIFAR10", lr=LR)

lr_decay_epochs = [20, 40, 60, 95]

best_acc = 0

# NOTE(review): the checkpoint name says MNIST although this notebook trains on
# CIFAR10 (the ASkewSGD / BinaryConnect cells use a CIFAR10_ prefix). Left
# unchanged in case something outside this notebook loads it - confirm and rename.
FILE = "MNIST_SGD.pt"

# Number of batches per epoch; used to print a global iteration counter.
length = len(trainloader)

# When resuming (pre_epoch > 0), re-apply the halvings that the skipped epochs
# would already have performed.
for decay_epoch in lr_decay_epochs:
    if pre_epoch > decay_epoch:
        for param_group in optimizer.param_groups:
            param_group["lr"] *= 0.5

# Train
for epoch in range(pre_epoch, EPOCH):
    print("\nEpoch: %d" % (epoch + 1))
    net.train()
    sum_loss = 0.0
    correct = 0.0
    total = 0.0

    # Halve the learning rate at each milestone epoch.
    if epoch in lr_decay_epochs:
        for param_group in optimizer.param_groups:
            param_group["lr"] *= 0.5

    for i, data in enumerate(trainloader, 0):
        # prepare dataset
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)
        # forward & backward
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        # Keep every non-bias parameter inside [-1, 1].
        with torch.no_grad():
            for name, param in net.named_parameters():
                if not name.endswith(".bias"):
                    param.clamp_(-1.0, 1.0)
        sum_loss += loss.item()
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        # .item() keeps the running count a plain Python number.
        correct += predicted.eq(labels).sum().item()
        print(
            "[epoch:%d, iter:%d] Loss: %.03f | Acc: %.3f%% "
            % (
                epoch + 1,
                (i + 1 + (epoch) * length),
                sum_loss / (i + 1),
                100.0 * correct / total,
            )
        )

    print("Waiting Test...")
    with torch.no_grad():
        # train_loss, train_acc = net.evaluate(trainloader, criterion, device, eval=False, qt=False)
        # qtrain_loss, qtrain_acc = net.evaluate(trainloader, criterion, device, eval=False, qt=True)
        # net.evaluate is defined in quant_lib; presumably qt=True evaluates
        # with quantized weights - verify against quant_lib.ResNet18.
        test_loss, test_acc = net.evaluate(testloader, criterion, device, eval=True, qt=False)
        qtest_loss, qtest_acc = net.evaluate(testloader, criterion, device, eval=True, qt=True)
        wandb.log(
            {
                "test_loss": test_loss,
                "quantized_test_loss": qtest_loss,
                "test_accuracy": test_acc,
                "quantized_test_accuracy": qtest_acc
            }
        )
        print(
            "Test Loss: %.03f | Test Acc: %.3f%% "
            % (
                test_loss,
                test_acc,
            )
        )
        print(
            "Quantized Test Loss: %.03f | Quantized Test Acc: %.3f%% "
            % (
                qtest_loss,
                qtest_acc,
            )
        )
        # Checkpoint only when the quantized test accuracy improves.
        if qtest_acc > best_acc:
            torch.save(
                {
                    "model_state_dict": net.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                },
                FILE,
            )
            best_acc = qtest_acc
print("Best Test Accuracy after Quantization: %.3f%%" % best_acc)
wandb.finish()
# if epoch + 1 == ANNEAL_EPOCH:
#     FILE = 'resnet18_original_new.pt'
#     torch.save({'model_state_dict': net.state_dict(), 'optimizer_state_dict': optimizer.state_dict()}, FILE)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
wandb: Currently logged in as: hokfong (hokfong-the-chinese-university-of-hong-kong). Use `wandb login --relogin` to force relogin



Epoch: 1
[epoch:1, iter:1] Loss: 2.465 | Acc: 10.000% 
[epoch:1, iter:2] Loss: 2.770 | Acc: 9.000% 
[epoch:1, iter:3] Loss: 2.947 | Acc: 10.000% 
[epoch:1, iter:4] Loss: 3.042 | Acc: 11.750% 
[epoch:1, iter:5] Loss: 3.075 | Acc: 11.400% 
[epoch:1, iter:6] Loss: 3.150 | Acc: 10.667% 
[epoch:1, iter:7] Loss: 3.112 | Acc: 11.143% 
[epoch:1, iter:8] Loss: 3.043 | Acc: 11.750% 
[epoch:1, iter:9] Loss: 3.023 | Acc: 11.778% 
[epoch:1, iter:10] Loss: 2.954 | Acc: 12.400% 
[epoch:1, iter:11] Loss: 2.910 | Acc: 12.727% 
[epoch:1, iter:12] Loss: 2.841 | Acc: 13.833% 
[epoch:1, iter:13] Loss: 2.806 | Acc: 14.231% 
[epoch:1, iter:14] Loss: 2.758 | Acc: 15.071% 
[epoch:1, iter:15] Loss: 2.715 | Acc: 15.200% 
[epoch:1, iter:16] Loss: 2.683 | Acc: 15.438% 
[epoch:1, iter:17] Loss: 2.657 | Acc: 15.529% 
[epoch:1, iter:18] Loss: 2.624 | Acc: 16.111% 
[epoch:1, iter:19] Loss: 2.600 | Acc: 16.316% 
[epoch:1, iter:20] Loss: 2.579 | Acc: 16.250% 
[epoch:1, iter:21] Loss: 2.568 | Acc: 16.238% 
[epoch:1, ite

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
quantized_test_accuracy,▁▁▁▁▃▁▁▁▁▁▁▁▁▁▁▁▁▂▁▃▄▃▁▅▁▄▁▄▄▅▃▆▄▃▂▁▂▆██
quantized_test_loss,█▇▅▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
test_accuracy,▁▄▅▆▆▇▇▇▇▇▇██▇▇█████████████████████████
test_loss,█▄▄▂▃▁▁▃▂▂▂▁▁▂▂▂▁▁▁▂▂▁▁▁▂▂▁▁▁▁▂▂▁▁▁▁▁▁▁▁

0,1
quantized_test_accuracy,13.37
quantized_test_loss,4.565953337513891e+26
test_accuracy,88.49
test_loss,0.48455


In [9]:
# ProxQuant
# Adam training followed, from ANNEAL_EPOCH on, by a prox step that pulls every
# non-bias parameter towards sign(param) with a linearly growing strength.

# Regularization strength actually used by the prox step. It is assigned
# before init() so the value logged to wandb is the value trained with
# (previously the config recorded the earlier global value instead).
reg_lambda = 0.0001

# The wandb run is named "ProxQuant" so it matches this cell (it was
# previously logged under the SGD label).
net, optimizer = init(project_name="CIFAR10_binary", opt_name="ProxQuant", batch_size=BATCH_SIZE, architecture="ResNet-18", dataset_name="CIFAR10", lr=LR, reg_lambda=reg_lambda)
# The SGD optimizer returned by init() is replaced: this method uses Adam.
# NOTE(review): the recorded output of this cell shows the running loss jumping
# from ~2.4 to ~12 within three iterations; LR may be too large for Adam - confirm.
optimizer = optim.Adam(net.parameters(), lr=LR, betas=(0.9, 0.999))

best_acc = 0

# NOTE(review): the checkpoint name says MNIST although this notebook trains on
# CIFAR10. Left unchanged in case something outside this notebook loads it.
FILE = "MNIST_ProxQuant.pt"

lr_decay_epochs = [20, 40, 60, 95]

# `lr` mirrors the optimizer's learning rate for use in the prox step. It is
# halved once per milestone (not once per parameter group), so it stays
# correct regardless of how many groups the optimizer has.
lr = LR
for decay_epoch in lr_decay_epochs:
    if pre_epoch > decay_epoch:
        for param_group in optimizer.param_groups:
            param_group["lr"] *= 0.5
        lr *= 0.5

# Train
length = len(trainloader)
# Number of iterations over which epsilon ramps from 0 up to reg_lambda.
total_it = (EPOCH - ANNEAL_EPOCH) * length
# Start at pre_epoch so the catch-up decay above is not applied a second time
# by the milestone check inside the loop.
for epoch in range(pre_epoch, EPOCH):
    print("\nEpoch: %d" % (epoch + 1))
    net.train()
    sum_loss = 0.0
    correct = 0.0
    total = 0.0
    if epoch in lr_decay_epochs:
        for param_group in optimizer.param_groups:
            param_group["lr"] *= 0.5
        lr *= 0.5

    for i, data in enumerate(trainloader, 0):
        # Annealing schedule: no regularization before ANNEAL_EPOCH, then a
        # linear ramp. `it` counts iterations since annealing began and is
        # derived from (epoch, i) so it is also right when resuming.
        if epoch < ANNEAL_EPOCH:
            epsilon = 0
        else:
            it = (epoch - ANNEAL_EPOCH) * length + i
            epsilon = reg_lambda * it / total_it
        # prepare dataset
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)
        # forward & backward
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        sum_loss += loss.item()
        optimizer.step()
        # Prox Step. With a zero step size the update is the identity, so it
        # is skipped entirely.
        prox_step = epsilon * lr
        if prox_step > 0:
            with torch.no_grad():
                for name, param in net.named_parameters():
                    if not name.endswith(".bias"):
                        param.data = (param.data + prox_step * torch.sign(param.data)) / (1 + prox_step)
        optimizer.zero_grad()
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        # .item() keeps the running count a plain Python number.
        correct += predicted.eq(labels).sum().item()
        print(
            "[Epoch:%d, Iter:%d] Loss: %.03f | Acc: %.3f%% "
            % (
                epoch + 1,
                (i + 1 + (epoch) * length),
                sum_loss / (i + 1),
                100.0 * correct / total,
            )
        )
    print("Waiting Test...")
    with torch.no_grad():
        # train_loss, train_acc = net.evaluate(trainloader, criterion, device, eval=False, qt=False)
        # qtrain_loss, qtrain_acc = net.evaluate(trainloader, criterion, device, eval=False, qt=True)
        # net.evaluate is defined in quant_lib; presumably qt=True evaluates
        # with quantized weights - verify against quant_lib.ResNet18.
        test_loss, test_acc = net.evaluate(testloader, criterion, device, eval=True, qt=False)
        qtest_loss, qtest_acc = net.evaluate(testloader, criterion, device, eval=True, qt=True)
        wandb.log(
            {
                "test_loss": test_loss,
                "quantized_test_loss": qtest_loss,
                "test_accuracy": test_acc,
                "quantized_test_accuracy": qtest_acc
            }
        )
        print(
            "Test Loss: %.03f | Test Acc: %.3f%% "
            % (
                test_loss,
                test_acc,
            )
        )
        print(
            "Quantized Test Loss: %.03f | Quantized Test Acc: %.3f%% "
            % (
                qtest_loss,
                qtest_acc,
            )
        )
        # Checkpoint only when the quantized test accuracy improves.
        if qtest_acc > best_acc:
            torch.save(
                {
                    "model_state_dict": net.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                },
                FILE,
            )
            best_acc = qtest_acc
print("Best Test Accuracy after Quantization: %.3f%%" % best_acc)
wandb.finish()


Epoch: 1
[Epoch:1, Iter:1] Loss: 2.407 | Acc: 11.000% 
[Epoch:1, Iter:2] Loss: 8.962 | Acc: 9.500% 
[Epoch:1, Iter:3] Loss: 12.123 | Acc: 9.667% 
[Epoch:1, Iter:4] Loss: 11.780 | Acc: 10.250% 
[Epoch:1, Iter:5] Loss: 11.644 | Acc: 9.200% 
[Epoch:1, Iter:6] Loss: 11.065 | Acc: 9.167% 
[Epoch:1, Iter:7] Loss: 10.017 | Acc: 10.000% 
[Epoch:1, Iter:8] Loss: 9.686 | Acc: 10.750% 
[Epoch:1, Iter:9] Loss: 9.120 | Acc: 10.222% 
[Epoch:1, Iter:10] Loss: 8.590 | Acc: 10.900% 
[Epoch:1, Iter:11] Loss: 8.594 | Acc: 10.364% 
[Epoch:1, Iter:12] Loss: 8.848 | Acc: 10.500% 
[Epoch:1, Iter:13] Loss: 8.481 | Acc: 11.154% 
[Epoch:1, Iter:14] Loss: 8.418 | Acc: 12.000% 
[Epoch:1, Iter:15] Loss: 8.033 | Acc: 11.867% 
[Epoch:1, Iter:16] Loss: 8.214 | Acc: 11.875% 
[Epoch:1, Iter:17] Loss: 8.279 | Acc: 11.412% 
[Epoch:1, Iter:18] Loss: 8.210 | Acc: 11.611% 
[Epoch:1, Iter:19] Loss: 7.954 | Acc: 12.000% 
[Epoch:1, Iter:20] Loss: 7.820 | Acc: 11.950% 
[Epoch:1, Iter:21] Loss: 7.641 | Acc: 12.000% 
[Epoch:1, I

KeyboardInterrupt: 

In [None]:
# ASkewSGD
# SGD whose gradient is rewritten, for the "weights" parameter group, wherever
# the constraint  epsilon - (w^2 - 1)^2 >= 0  is violated and the raw gradient
# does not already satisfy the inequality checked below.
net, optimizer = init(project_name="CIFAR10_binary", opt_name="ASkewSGD", batch_size=BATCH_SIZE, architecture="ResNet-18", dataset_name="CIFAR10", lr=LR)

best_acc = 0

FILE = "CIFAR10_ASkewSGD.pt"

lr_decay_epochs = [20, 40, 60, 95]

# When resuming (pre_epoch > 0), re-apply the halvings that the skipped epochs
# would already have performed.
for decay_epoch in lr_decay_epochs:
    if pre_epoch > decay_epoch:
        for param_group in optimizer.param_groups:
            param_group["lr"] *= 0.5

# Number of batches per epoch.
length = len(trainloader)

# Train
for epoch in range(pre_epoch, EPOCH):
    print("\nEpoch: %d" % (epoch + 1))
    net.train()
    sum_loss = 0.0
    correct = 0.0
    total = 0.0

    if epoch in lr_decay_epochs:
        for param_group in optimizer.param_groups:
            param_group["lr"] *= 0.5

    for i, data in enumerate(trainloader, 0):
        # Constraint slack: 1 until ANNEAL_EPOCH, then decays geometrically
        # with the (fractional) number of epochs since annealing began.
        if epoch < ANNEAL_EPOCH:
            epsilon = 1
        else:
            epsilon = DECAY_CONST ** ((epoch - ANNEAL_EPOCH) + (i / length))
        # prepare dataset
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)
        # forward & backward
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        # Rewrite gradients of the "weights" group in place before the step.
        with torch.no_grad():
            for param_group in optimizer.param_groups:
                if param_group.get("tag") != "weights":
                    continue
                for p in param_group["params"]:
                    if p.grad is None:
                        # Parameter did not take part in this backward pass.
                        continue
                    constr = epsilon - (p**2 - 1) ** 2
                    Kx = -4 * (p**2 - 1) * p
                    # Entries that keep their raw gradient: constraint is
                    # satisfied, Kx vanishes, or the gradient already meets
                    # -Kx * grad >= -alpha * constr.
                    direct_grad = torch.logical_or(
                        torch.logical_or(constr >= 0, Kx == 0),
                        torch.logical_and(
                            constr < 0, (-Kx * p.grad) >= -alpha * constr
                        ),
                    )
                    # Replacement for the remaining entries. Where Kx == 0 the
                    # division yields inf/nan, but those entries are always in
                    # direct_grad and therefore never selected.
                    corrected = torch.clip(alpha * constr / Kx, -1, 1)
                    p.grad.copy_(torch.where(direct_grad, p.grad, corrected))
        optimizer.step()
        optimizer.zero_grad()
        sum_loss += loss.item()
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        # .item() keeps the running count a plain Python number.
        correct += predicted.eq(labels).sum().item()
        print(
            "[epoch:%d, iter:%d] Loss: %.03f | Acc: %.3f%% "
            % (
                epoch + 1,
                (i + 1 + (epoch) * length),
                sum_loss / (i + 1),
                100.0 * correct / total,
            )
        )

    print("Waiting Test...")
    with torch.no_grad():
        # train_loss, train_acc = net.evaluate(trainloader, criterion, device, eval=False, qt=False)
        # qtrain_loss, qtrain_acc = net.evaluate(trainloader, criterion, device, eval=False, qt=True)
        # net.evaluate is defined in quant_lib; presumably qt=True evaluates
        # with quantized weights - verify against quant_lib.ResNet18.
        test_loss, test_acc = net.evaluate(testloader, criterion, device, eval=True, qt=False)
        qtest_loss, qtest_acc = net.evaluate(testloader, criterion, device, eval=True, qt=True)
        wandb.log(
            {
                "test_loss": test_loss,
                "quantized_test_loss": qtest_loss,
                "test_accuracy": test_acc,
                "quantized_test_accuracy": qtest_acc
            }
        )
        print(
            "Test Loss: %.03f | Test Acc: %.3f%% "
            % (
                test_loss,
                test_acc,
            )
        )
        print(
            "Quantized Test Loss: %.03f | Quantized Test Acc: %.3f%% "
            % (
                qtest_loss,
                qtest_acc,
            )
        )
        # Checkpoint only when the quantized test accuracy improves.
        if qtest_acc > best_acc:
            torch.save(
                {
                    "model_state_dict": net.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                },
                FILE,
            )
            best_acc = qtest_acc
print("Best Test Accuracy after Quantization: %.3f%%" % best_acc)
wandb.finish()



Epoch: 1
[epoch:1, iter:1] Loss: 2.383 | Acc: 8.000% 
[epoch:1, iter:2] Loss: 2.445 | Acc: 11.000% 
[epoch:1, iter:3] Loss: 2.427 | Acc: 13.333% 
[epoch:1, iter:4] Loss: 2.466 | Acc: 13.500% 
[epoch:1, iter:5] Loss: 2.486 | Acc: 13.800% 
[epoch:1, iter:6] Loss: 2.553 | Acc: 14.667% 
[epoch:1, iter:7] Loss: 2.627 | Acc: 14.286% 
[epoch:1, iter:8] Loss: 2.649 | Acc: 15.000% 
[epoch:1, iter:9] Loss: 2.685 | Acc: 14.556% 
[epoch:1, iter:10] Loss: 2.714 | Acc: 14.700% 
[epoch:1, iter:11] Loss: 2.720 | Acc: 14.727% 
[epoch:1, iter:12] Loss: 2.714 | Acc: 14.750% 
[epoch:1, iter:13] Loss: 2.718 | Acc: 15.077% 
[epoch:1, iter:14] Loss: 2.714 | Acc: 15.357% 
[epoch:1, iter:15] Loss: 2.685 | Acc: 15.133% 
[epoch:1, iter:16] Loss: 2.675 | Acc: 15.750% 
[epoch:1, iter:17] Loss: 2.642 | Acc: 16.059% 
[epoch:1, iter:18] Loss: 2.637 | Acc: 16.111% 
[epoch:1, iter:19] Loss: 2.620 | Acc: 16.421% 
[epoch:1, iter:20] Loss: 2.610 | Acc: 16.500% 
[epoch:1, iter:21] Loss: 2.591 | Acc: 16.905% 
[epoch:1, ite

In [None]:
# Deterministic BinaryConnect
# `net` is quantized before each forward/backward pass; `model_copy` holds the
# real-valued (latent) parameters that accumulate the resulting updates.
net, optimizer = init(project_name="CIFAR10_binary", opt_name="Deterministic BinaryConnect", batch_size=BATCH_SIZE, architecture="ResNet-18", dataset_name="CIFAR10", lr=LR*4)

best_acc = 0
model_copy = copy.deepcopy(net)

FILE = "CIFAR10_Deterministic_BinaryConnect.pt"

lr_decay_epochs = [60, 80, 100, 115, 130]

# When resuming (pre_epoch > 0), re-apply the halvings that the skipped epochs
# would already have performed.
for decay_epoch in lr_decay_epochs:
    if pre_epoch > decay_epoch:
        for param_group in optimizer.param_groups:
            param_group["lr"] *= 0.5

# Number of batches per epoch.
length = len(trainloader)

# Train
for epoch in range(pre_epoch, EPOCH):
    print("\nEpoch: %d" % (epoch + 1))
    net.train()
    sum_loss = 0.0
    correct = 0.0
    total = 0.0

    if epoch in lr_decay_epochs:
        for param_group in optimizer.param_groups:
            param_group["lr"] *= 0.5

    for i, data in enumerate(trainloader, 0):
        # prepare dataset
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)
        # forward & backward
        with torch.no_grad():
            # Predictions of the latent model; only used for the running
            # training accuracy printed below.
            outputs = model_copy(inputs)
            # Quantize every non-bias parameter of `net` before its forward pass.
            for net_name, net_param in net.named_parameters():
                if not net_name.endswith(".bias"):
                    net_param.data = net.quantize(net_param.data, net.bit_width)

        outputs2 = net(inputs)
        loss = criterion(outputs2, labels)
        loss.backward()
        optimizer.step()
        # Transfer the optimizer's update onto the latent parameters.
        with torch.no_grad():
            for (net_name, net_param), (model_copy_name, model_copy_param) in zip(
                net.named_parameters(), model_copy.named_parameters()
            ):
                if net_name.endswith(".bias"):
                    # Biases are never quantized, so the optimizer already
                    # updated them in `net`; mirror that value. Previously the
                    # latent bias was overwritten using the `delta` left over
                    # from the preceding non-bias parameter.
                    # NOTE(review): assumes the latent bias should track
                    # net's bias - confirm.
                    model_copy_param.data = net_param.data.clone()
                else:
                    # Change applied by the step relative to the quantized
                    # latent value, added to the latent value and clamped.
                    delta = net_param.data - model_copy.quantize(model_copy_param.data, model_copy.bit_width)
                    updated = torch.clamp(model_copy_param.data + delta, -1, 1)
                    net_param.data = updated
                    # Separate storage so the two models never alias.
                    model_copy_param.data = updated.clone()
        sum_loss += loss.item()
        optimizer.zero_grad()
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        # .item() keeps the running count a plain Python number.
        correct += predicted.eq(labels).sum().item()
        print(
            "[epoch:%d, iter:%d] Loss: %.03f | Acc: %.3f%% "
            % (
                epoch + 1,
                (i + 1 + (epoch) * length),
                sum_loss / (i + 1),
                100.0 * correct / total,
            )
        )

    print("Waiting Test...")
    with torch.no_grad():
        # train_loss, train_acc = net.evaluate(trainloader, criterion, device, eval=False, qt=False)
        # qtrain_loss, qtrain_acc = net.evaluate(trainloader, criterion, device, eval=False, qt=True)
        # net.evaluate is defined in quant_lib; presumably qt=True evaluates
        # with quantized weights - verify against quant_lib.ResNet18.
        test_loss, test_acc = net.evaluate(testloader, criterion, device, eval=True, qt=False)
        qtest_loss, qtest_acc = net.evaluate(testloader, criterion, device, eval=True, qt=True)
        wandb.log(
            {
                "test_loss": test_loss,
                "quantized_test_loss": qtest_loss,
                "test_accuracy": test_acc,
                "quantized_test_accuracy": qtest_acc
            }
        )
        print(
            "Test Loss: %.03f | Test Acc: %.3f%% "
            % (
                test_loss,
                test_acc,
            )
        )
        print(
            "Quantized Test Loss: %.03f | Quantized Test Acc: %.3f%% "
            % (
                qtest_loss,
                qtest_acc,
            )
        )
        # Checkpoint only when the quantized test accuracy improves.
        if qtest_acc > best_acc:
            torch.save(
                {
                    "model_state_dict": net.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                },
                FILE,
            )
            best_acc = qtest_acc
print("Best Test Accuracy after Quantization: %.3f%%" % best_acc)
wandb.finish()



In [None]:
# FILE = "resnet18_qt_new.pt"

# torch.save(
#     {
#         "model_state_dict": net.state_dict(),
#         "optimizer_state_dict": optimizer.state_dict(),
#     },
#     FILE,
# )

In [None]:
# FILE = "resnet18_original_new.pt"
# checkpoint = torch.load(FILE)
# net.load_state_dict(checkpoint["model_state_dict"])
# with torch.no_grad():
#     correct = 0
#     total = 0
#     for data in testloader:
#         net.eval()
#         images, labels = data
#         images, labels = images.to(device), labels.to(device)
#         outputs = net(images)
#         _, predicted = torch.max(outputs.data, 1)
#         total += labels.size(0)
#         correct += (predicted == labels).sum()
#     print("Test's ac is: %.3f%%" % (100 * correct / total))

# FILE = "resnet18_qt_new.pt"
# checkpoint = torch.load(FILE)
# net.load_state_dict(checkpoint["model_state_dict"])
# model_copy = copy.deepcopy(net)
# with torch.no_grad():
#     for name, param in model_copy.named_parameters():
#         param.data = torch.sign(param.data)
#     correct = 0
#     total = 0
#     for data in testloader:
#         model_copy.eval()
#         images, labels = data
#         images, labels = images.to(device), labels.to(device)
#         outputs = model_copy(images)
#         _, predicted = torch.max(outputs.data, 1)
#         total += labels.size(0)
#         correct += (predicted == labels).sum()
#     print("Test's ac is: %.3f%%" % (100 * correct / total))