In [69]:
import torch
from torch.utils import data
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import sklearn.datasets
import matplotlib.pyplot as plt
import numpy as np
import wandb
import quant_lib.MNIST_MLP as mlp
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import copy
%matplotlib inline
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [70]:
def MLPMNISTNet(bit_width=1):
    return mlp.MLP(input_class=784, num_classes=10, bit_width=bit_width)

In [71]:
def load_array(features, labels, batch_size, is_train=True):
    dataset = data.TensorDataset(features, labels)
    return data.DataLoader(dataset, batch_size, shuffle=is_train)

In [None]:
# Check GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Set hyperparameter
bit_width = 1
EPOCH = 50
ANNEAL_EPOCH = 20
pre_epoch = 0
BATCH_SIZE = 200
LR = 0.06
lr_decay_epochs = [7, 15, 25, 30]
# ASkewSGD
DECAY_CONST = 0.88
alpha = 0.004
# ProxQuant
reg_lambda = 4e-6


# Generate training and testing dataset
transform_train = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ]
)

transform_test = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ]
)

trainset = torchvision.datasets.MNIST(
    root="../data", train=True, download=True, transform=transform_train
)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True)

testset = torchvision.datasets.MNIST(
    root="../data", train=False, download=True, transform=transform_test
)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False)
# Define base net
base_net = MLPMNISTNet(bit_width=bit_width).to(device)
# Define loss
criterion = nn.CrossEntropyLoss()

In [73]:
def init(project_name, opt_name, batch_size, architecture, dataset_name, lr, alpha=None, reg_lambda=None):
    wandb.init(
        # set the wandb project where this run will be logged
        project=project_name,
        name=opt_name,
        # track hyperparameters and run metadata
        config={
            "batch_size": batch_size,
            "architecture": architecture,
            "dataset": dataset_name,
            "lr": lr,
            "alpha": alpha, 
            "reg_lambda": reg_lambda,
            "bit_width": base_net.bit_width
        },
    )
    net = copy.deepcopy(base_net)
    net.to(device)
    weights = [p for name, p in net.named_parameters()]
    # bias = [p for name, p in net.named_parameters() if 'bias' in name]
    parameters = [{"params": weights, "tag": "weights"}, 
                #   {"params": bias, "tag": "bias"}
                ]
    optimizer = optim.SGD(parameters, lr=lr)
    return net, optimizer

In [None]:
# SGD
net, optimizer = init(project_name="MNIST_binary", opt_name="SGD", batch_size=BATCH_SIZE, architecture="MLP", dataset_name="MNIST", lr=LR)

best_acc=0

for decay_epoch in lr_decay_epochs:
    if pre_epoch > decay_epoch:
        for param_group in optimizer.param_groups:
            param_group["lr"] *= 0.5

# Train
for epoch in range(pre_epoch, EPOCH):
    print("\nEpoch: %d" % (epoch + 1))
    net.train()
    sum_loss = 0.0
    correct = 0.0
    total = 0.0

    if epoch in lr_decay_epochs:
        for param_group in optimizer.param_groups:
            param_group["lr"] *= 0.5

    for i, data in enumerate(trainloader, 0):
        # prepare dataset
        length = len(trainloader)
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)
        # forward & backward
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        sum_loss += loss.item()
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += predicted.eq(labels.data).cpu().sum()
        print(
            "[Epoch:%d, Iter:%d] Loss: %.03f | Acc: %.3f%% "
            % (
                epoch + 1,
                (i + 1 + (epoch) * length),
                sum_loss / (i + 1),
                100.0 * correct / total,
            )
        )
    print("Waiting Test...")
    with torch.no_grad():
        train_loss, train_acc = net.evaluate(trainloader, criterion, device, eval=False, qt=False)
        qtrain_loss, qtrain_acc = net.evaluate(trainloader, criterion, device, eval=False, qt=True)
        test_loss, test_acc=net.evaluate(testloader, criterion, device, eval=True, qt=False)
        qtest_loss, qtest_acc=net.evaluate(testloader, criterion, device, eval=True, qt=True)
        wandb.log(
            {
                "test_loss": test_loss,
                "quantized_test_loss": qtest_loss,
                "accuracy": train_acc,
                "quantized_accuracy": qtrain_acc,
                "test_accuracy": test_acc,
                "quantized_test_accuracy": qtest_acc
            }
        )
        print(
            "Train Loss: %.03f | Train Acc: %.3f%% | Test Loss: %.03f | Test Acc: %.3f%% "
            % (
                train_loss,
                train_acc,
                test_loss,
                test_acc,
            )
        )
        print(
            "Quantized Train Loss: %.03f | Quantized Train Acc: %.3f%% | Quantized Test Loss: %.03f | Quantized Test Acc: %.3f%% "
            % (
                qtrain_loss,
                qtrain_acc,
                qtest_loss,
                qtest_acc,
            )
        )
        FILE = "MNIST_SGD.pt"
        if qtest_acc>best_acc:
            torch.save(
                {
                    "model_state_dict": net.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                },
                FILE,
            )
            best_acc=qtest_acc
wandb.finish()

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011111111111111112, max=1.0…


Epoch: 1
[Epoch:1, Iter:1] Loss: 2.310 | Acc: 8.000% 
[Epoch:1, Iter:2] Loss: 2.298 | Acc: 9.750% 
[Epoch:1, Iter:3] Loss: 2.285 | Acc: 14.000% 
[Epoch:1, Iter:4] Loss: 2.270 | Acc: 20.000% 
[Epoch:1, Iter:5] Loss: 2.257 | Acc: 22.800% 
[Epoch:1, Iter:6] Loss: 2.244 | Acc: 26.000% 
[Epoch:1, Iter:7] Loss: 2.231 | Acc: 29.286% 
[Epoch:1, Iter:8] Loss: 2.217 | Acc: 32.562% 
[Epoch:1, Iter:9] Loss: 2.205 | Acc: 34.556% 
[Epoch:1, Iter:10] Loss: 2.191 | Acc: 36.500% 
[Epoch:1, Iter:11] Loss: 2.176 | Acc: 38.727% 
[Epoch:1, Iter:12] Loss: 2.162 | Acc: 40.292% 
[Epoch:1, Iter:13] Loss: 2.144 | Acc: 42.385% 
[Epoch:1, Iter:14] Loss: 2.130 | Acc: 43.536% 
[Epoch:1, Iter:15] Loss: 2.112 | Acc: 44.733% 
[Epoch:1, Iter:16] Loss: 2.095 | Acc: 45.531% 
[Epoch:1, Iter:17] Loss: 2.077 | Acc: 46.853% 
[Epoch:1, Iter:18] Loss: 2.059 | Acc: 47.861% 
[Epoch:1, Iter:19] Loss: 2.040 | Acc: 48.658% 
[Epoch:1, Iter:20] Loss: 2.023 | Acc: 49.675% 
[Epoch:1, Iter:21] Loss: 2.003 | Acc: 50.810% 
[Epoch:1, Iter

KeyboardInterrupt: 

In [None]:
# ProxQuant
net, optimizer = init(project_name="MNIST_binary", opt_name="ProxQuant", batch_size=BATCH_SIZE, architecture="MLP", dataset_name="MNIST", lr=LR, reg_lambda=reg_lambda)

best_acc=0

lr=LR
for decay_epoch in lr_decay_epochs:
    if pre_epoch > decay_epoch:
        for param_group in optimizer.param_groups:
            param_group["lr"] *= 0.5
            lr*=0.5

# Train
it=-1
total_it=(EPOCH-ANNEAL_EPOCH)*len(trainloader)
for epoch in range(0, EPOCH):
    print("\nEpoch: %d" % (epoch + 1))
    sum_loss = 0.0
    correct = 0.0
    total = 0.0
    if epoch in lr_decay_epochs:
        for param_group in optimizer.param_groups:
            param_group["lr"] *= 0.5
            lr*=0.5

    for i, data in enumerate(trainloader, 0):
        # prepare dataset
        if epoch < ANNEAL_EPOCH:
            epsilon=0
        else:
            it+=1
            epsilon = reg_lambda*it/total_it
        net.train()
        length = len(trainloader)
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)
        # forward & backward
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        sum_loss += loss.item()
        optimizer.step()
        with torch.no_grad():
            for name, param in net.named_parameters():
                # if not name.endswith(".bias"):
                    # Prox Step
                param.data=(param.data+epsilon*torch.sign(param.data))/(1+epsilon)
        optimizer.zero_grad()
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += predicted.eq(labels.data).cpu().sum()
        print(
            "[Epoch:%d, Iter:%d] Loss: %.03f | Acc: %.3f%% "
            % (
                epoch + 1,
                (i + 1 + (epoch) * length),
                sum_loss / (i + 1),
                100.0 * correct / total,
            )
        )
        
    print("Waiting Test...")
    with torch.no_grad():
        train_loss, train_acc = net.evaluate(trainloader, criterion, device, eval=False, qt=False)
        qtrain_loss, qtrain_acc = net.evaluate(trainloader, criterion, device, eval=False, qt=True)
        test_loss, test_acc=net.evaluate(testloader, criterion, device, eval=True, qt=False)
        qtest_loss, qtest_acc=net.evaluate(testloader, criterion, device, eval=True, qt=True)
        wandb.log(
            {
                "test_loss": test_loss,
                "quantized_test_loss": qtest_loss,
                "accuracy": train_acc,
                "quantized_accuracy": qtrain_acc,
                "test_accuracy": test_acc,
                "quantized_test_accuracy": qtest_acc
            }
        )
        print(
            "Train Loss: %.03f | Train Acc: %.3f%% | Test Loss: %.03f | Test Acc: %.3f%% "
            % (
                train_loss,
                train_acc,
                test_loss,
                test_acc,
            )
        )
        print(
            "Quantized Train Loss: %.03f | Quantized Train Acc: %.3f%% | Quantized Test Loss: %.03f | Quantized Test Acc: %.3f%% "
            % (
                qtrain_loss,
                qtrain_acc,
                qtest_loss,
                qtest_acc,
            )
        )
        FILE = "MNIST_ProxQuant_W1A32.pt"
        if qtest_acc>best_acc:
            torch.save(
                {
                    "model_state_dict": net.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                },
                FILE,
            )
            best_acc=qtest_acc
wandb.finish()


Epoch: 1
[Epoch:1, Iter:1] Loss: 2.321 | Acc: 12.000% 
[Epoch:1, Iter:2] Loss: 2.302 | Acc: 17.500% 
[Epoch:1, Iter:3] Loss: 2.289 | Acc: 22.167% 
[Epoch:1, Iter:4] Loss: 2.277 | Acc: 24.000% 
[Epoch:1, Iter:5] Loss: 2.267 | Acc: 25.500% 
[Epoch:1, Iter:6] Loss: 2.256 | Acc: 28.083% 
[Epoch:1, Iter:7] Loss: 2.242 | Acc: 31.786% 
[Epoch:1, Iter:8] Loss: 2.230 | Acc: 34.312% 
[Epoch:1, Iter:9] Loss: 2.215 | Acc: 37.389% 
[Epoch:1, Iter:10] Loss: 2.204 | Acc: 39.400% 
[Epoch:1, Iter:11] Loss: 2.188 | Acc: 41.545% 
[Epoch:1, Iter:12] Loss: 2.173 | Acc: 43.417% 
[Epoch:1, Iter:13] Loss: 2.157 | Acc: 44.308% 
[Epoch:1, Iter:14] Loss: 2.140 | Acc: 45.679% 
[Epoch:1, Iter:15] Loss: 2.124 | Acc: 46.567% 
[Epoch:1, Iter:16] Loss: 2.111 | Acc: 46.906% 
[Epoch:1, Iter:17] Loss: 2.095 | Acc: 47.882% 
[Epoch:1, Iter:18] Loss: 2.078 | Acc: 48.972% 
[Epoch:1, Iter:19] Loss: 2.058 | Acc: 50.158% 
[Epoch:1, Iter:20] Loss: 2.035 | Acc: 51.150% 
[Epoch:1, Iter:21] Loss: 2.015 | Acc: 52.190% 
[Epoch:1, It

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
accuracy,▁▃▅▅▆▇▇▇▇▇▇████████████████████████████▇
quantized_accuracy,▁▄▅▆▇▇▇█▇▇▇▇▇▇▇▇▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆
quantized_test_accuracy,▁▃▅▆▇▇▇█▇▇▆▆▆▇▆▇▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅
quantized_test_loss,█▅▄▂▂▁▁▁▂▂▂▂▂▁▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂
test_accuracy,▁▄▅▆▇▇█▇████████████████████▇▇▇▇█▇▇▇▇▇▇▇
test_loss,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▂▂▂▃▃▃▄▅▆█

0,1
accuracy,98.86167
quantized_accuracy,94.51
quantized_test_accuracy,93.77
quantized_test_loss,899.62406
test_accuracy,96.89
test_loss,10.66552


In [None]:
# ASkewSGD
net, optimizer = init(project_name="MNIST_binary", opt_name="ASkewSGD", batch_size=BATCH_SIZE, architecture="MLP", dataset_name="MNIST", lr=LR, alpha=alpha)

best_acc=0

for decay_epoch in lr_decay_epochs:
    if pre_epoch > decay_epoch:
        for param_group in optimizer.param_groups:
            param_group["lr"] *= 0.5

# Train
for epoch in range(0, EPOCH):
    print("\nEpoch: %d" % (epoch + 1))
    net.train()
    sum_loss = 0.0
    correct = 0.0
    total = 0.0

    if epoch in lr_decay_epochs:
        for param_group in optimizer.param_groups:
            param_group["lr"] *= 0.5

    if epoch < ANNEAL_EPOCH:
        epsilon = 1
    else:
        epsilon =  DECAY_CONST ** ((epoch - ANNEAL_EPOCH)+(i/len(trainloader)))
    
    for i, data in enumerate(trainloader, 0):
        # prepare dataset
        length = len(trainloader)
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)
        # forward & backward
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        
        for param_group in optimizer.param_groups:
            if param_group["tag"] == "weights":
                for idx, p in enumerate(param_group["params"]):
                    constr = epsilon - (p.data**2 - 1) ** 2
                    Kx = -4 * (p.data**2 - 1) * p.data
                    direct_grad = torch.logical_or(
                        torch.logical_or(constr >= 0, Kx == 0),
                        torch.logical_and(
                            constr < 0, (-Kx * p.grad.data) >= -alpha * constr
                        ),
                    )
                    p.grad.data[direct_grad] = p.grad.data[direct_grad]
                    p.grad.data[~direct_grad] = torch.clip(
                        alpha * constr / Kx,
                        -2 / param_group["lr"],
                        2 / param_group["lr"],
                    )[~direct_grad]
        
        optimizer.step()
        optimizer.zero_grad()
        with torch.no_grad():
            for name, param in net.named_parameters():
                torch.clamp_(param.data, -1.2, 1.2)
        sum_loss += loss.item()
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += predicted.eq(labels.data).cpu().sum()
        print(
            "[Epoch:%d, Iter:%d] Loss: %.03f | Acc: %.3f%% "
            % (
                epoch + 1,
                (i + 1 + (epoch) * length),
                sum_loss / (i + 1),
                100.0 * correct / total,
            )
        )
        
    print("Waiting Test...")
    with torch.no_grad():
        train_loss, train_acc = net.evaluate(trainloader, criterion, device, eval=False, qt=False)
        qtrain_loss, qtrain_acc = net.evaluate(trainloader, criterion, device, eval=False, qt=True)
        test_loss, test_acc=net.evaluate(testloader, criterion, device, eval=True, qt=False)
        qtest_loss, qtest_acc=net.evaluate(testloader, criterion, device, eval=True, qt=True)
        wandb.log(
            {
                "test_loss": test_loss,
                "quantized_test_loss": qtest_loss,
                "accuracy": train_acc,
                "quantized_accuracy": qtrain_acc,
                "test_accuracy": test_acc,
                "quantized_test_accuracy": qtest_acc
            }
        )
        print(
            "Train Loss: %.03f | Train Acc: %.3f%% | Test Loss: %.03f | Test Acc: %.3f%% "
            % (
                train_loss,
                train_acc,
                test_loss,
                test_acc,
            )
        )
        print(
            "Quantized Train Loss: %.03f | Quantized Train Acc: %.3f%% | Quantized Test Loss: %.03f | Quantized Test Acc: %.3f%% "
            % (
                qtrain_loss,
                qtrain_acc,
                qtest_loss,
                qtest_acc,
            )
        )
        FILE = "MNIST_ASkewSGD_W1A32.pt"
        if qtest_acc>best_acc:
            torch.save(
                {
                    "model_state_dict": net.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                },
                FILE,
            )
            best_acc=qtest_acc
wandb.finish()


Epoch: 1
[Epoch:1, Iter:1] Loss: 2.313 | Acc: 17.000% 
[Epoch:1, Iter:2] Loss: 2.298 | Acc: 17.250% 
[Epoch:1, Iter:3] Loss: 2.287 | Acc: 22.500% 
[Epoch:1, Iter:4] Loss: 2.274 | Acc: 26.250% 
[Epoch:1, Iter:5] Loss: 2.262 | Acc: 29.700% 
[Epoch:1, Iter:6] Loss: 2.250 | Acc: 32.667% 
[Epoch:1, Iter:7] Loss: 2.236 | Acc: 35.286% 
[Epoch:1, Iter:8] Loss: 2.220 | Acc: 38.000% 
[Epoch:1, Iter:9] Loss: 2.207 | Acc: 39.444% 
[Epoch:1, Iter:10] Loss: 2.195 | Acc: 40.600% 
[Epoch:1, Iter:11] Loss: 2.183 | Acc: 41.955% 
[Epoch:1, Iter:12] Loss: 2.168 | Acc: 44.042% 
[Epoch:1, Iter:13] Loss: 2.152 | Acc: 45.269% 
[Epoch:1, Iter:14] Loss: 2.136 | Acc: 46.214% 
[Epoch:1, Iter:15] Loss: 2.119 | Acc: 47.400% 
[Epoch:1, Iter:16] Loss: 2.101 | Acc: 48.312% 
[Epoch:1, Iter:17] Loss: 2.080 | Acc: 49.588% 
[Epoch:1, Iter:18] Loss: 2.058 | Acc: 50.556% 
[Epoch:1, Iter:19] Loss: 2.039 | Acc: 51.000% 
[Epoch:1, Iter:20] Loss: 2.017 | Acc: 51.975% 
[Epoch:1, Iter:21] Loss: 1.994 | Acc: 52.738% 
[Epoch:1, It

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
accuracy,▁▄▅▅▇▇▇▇▇███████▇▇▇▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇
quantized_accuracy,▆▇▇█████████████████▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
quantized_test_accuracy,▆▇▇█████████████████▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
quantized_test_loss,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇██
test_accuracy,▁▄▅▆▇▇▇█████████▇▇▇▆▆▇▆▆▇▇▆▇▇▇▇▇▇▇▇▇▆▇▇▇
test_loss,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▃▃▄▄▄▄▅▅▅▅▆▆▇▆▇▆▇▇▇█

0,1
accuracy,98.08667
quantized_accuracy,78.11667
quantized_test_accuracy,77.79
quantized_test_loss,20995.81511
test_accuracy,96.84
test_loss,240.43269


In [None]:
# Deterministic BinaryConnect
net, optimizer = init(project_name="MNIST_binary", opt_name="Deterministic BinaryConnect", batch_size=BATCH_SIZE, architecture="MLP", dataset_name="MNIST", lr=LR/300)

best_acc=0

model_copy = copy.deepcopy(net)

for decay_epoch in lr_decay_epochs:
    if pre_epoch > decay_epoch:
        for param_group in optimizer.param_groups:
            param_group["lr"] *= 0.5

# Train
for epoch in range(pre_epoch, EPOCH):
    print("\nEpoch: %d" % (epoch + 1))
    net.train()
    model_copy.train()
    sum_loss = 0.0
    correct = 0.0
    total = 0.0

    if epoch in lr_decay_epochs:
        for param_group in optimizer.param_groups:
            param_group["lr"] *= 0.5

    for i, data in enumerate(trainloader, 0):
        # prepare dataset
        length = len(trainloader)
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)
        # forward & backward
        with torch.no_grad():
            outputs = model_copy(inputs)
            loss2 = criterion(outputs, labels)
            for net_name, net_param in net.named_parameters():
                # if not net_name.endswith(".bias"):
                net_param.data=net.quantize(net_param.data, net.bit_width)

        outputs2 = net(inputs)
        loss = criterion(outputs2, labels)
        loss.backward()
        optimizer.step()
        for (net_name, net_param), (model_copy_name, model_copy_param) in zip(
            net.named_parameters(), model_copy.named_parameters()
        ):
            # if not net_name.endswith(".bias"):
            delta = net_param.data - model_copy.quantize(model_copy_param.data, model_copy.bit_width)
            net_param.data = torch.clamp(model_copy_param.data + delta, -1, 1)
            model_copy_param.data = torch.clamp(model_copy_param.data + delta, -1, 1)
            # else:
            # model_copy_param.data = net_param.data
        sum_loss += loss.item()
        optimizer.zero_grad()
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += predicted.eq(labels.data).cpu().sum()
        print(
            "[epoch:%d, iter:%d] Loss: %.03f | Acc: %.3f%% "
            % (
                epoch + 1,
                (i + 1 + (epoch) * length),
                sum_loss / (i + 1),
                100.0 * correct / total,
            )
        )
    print("Waiting Test...")
    with torch.no_grad():
        train_loss, train_acc = net.evaluate(trainloader, criterion, device, eval=False, qt=False)
        qtrain_loss, qtrain_acc = net.evaluate(trainloader, criterion, device, eval=False, qt=True)
        test_loss, test_acc=net.evaluate(testloader, criterion, device, eval=True, qt=False)
        qtest_loss, qtest_acc=net.evaluate(testloader, criterion, device, eval=True, qt=True)
        wandb.log(
            {
                "test_loss": test_loss,
                "quantized_test_loss": qtest_loss,
                "accuracy": train_acc,
                "quantized_accuracy": qtrain_acc,
                "test_accuracy": test_acc,
                "quantized_test_accuracy": qtest_acc
            }
        )
        print(
            "Train Loss: %.03f | Train Acc: %.3f%% | Test Loss: %.03f | Test Acc: %.3f%% "
            % (
                train_loss,
                train_acc,
                test_loss,
                test_acc,
            )
        )
        print(
            "Quantized Train Loss: %.03f | Quantized Train Acc: %.3f%% | Quantized Test Loss: %.03f | Quantized Test Acc: %.3f%% "
            % (
                qtrain_loss,
                qtrain_acc,
                qtest_loss,
                qtest_acc,
            )
        )
        FILE = "MNIST_ASkewSGD_W1A32.pt"
        if qtest_acc>best_acc:
            torch.save(
                {
                    "model_state_dict": net.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                },
                FILE,
            )
            best_acc=qtest_acc
wandb.finish()


Epoch: 1
[epoch:1, iter:1] Loss: 11323.986 | Acc: 15.000% 
[epoch:1, iter:2] Loss: 18845.202 | Acc: 14.250% 
[epoch:1, iter:3] Loss: 20819.773 | Acc: 17.167% 
[epoch:1, iter:4] Loss: 19717.407 | Acc: 21.000% 
[epoch:1, iter:5] Loss: 17563.655 | Acc: 22.400% 
[epoch:1, iter:6] Loss: 15302.045 | Acc: 22.833% 
[epoch:1, iter:7] Loss: 13361.488 | Acc: 26.500% 
[epoch:1, iter:8] Loss: 11849.768 | Acc: 28.812% 
[epoch:1, iter:9] Loss: 10681.779 | Acc: 31.389% 
[epoch:1, iter:10] Loss: 9698.418 | Acc: 32.750% 
[epoch:1, iter:11] Loss: 8910.463 | Acc: 33.773% 
[epoch:1, iter:12] Loss: 8246.743 | Acc: 34.750% 
[epoch:1, iter:13] Loss: 7690.311 | Acc: 35.308% 
[epoch:1, iter:14] Loss: 7194.838 | Acc: 36.393% 
[epoch:1, iter:15] Loss: 6773.544 | Acc: 36.867% 
[epoch:1, iter:16] Loss: 6396.477 | Acc: 37.156% 
[epoch:1, iter:17] Loss: 6059.356 | Acc: 37.824% 
[epoch:1, iter:18] Loss: 5758.129 | Acc: 38.472% 
[epoch:1, iter:19] Loss: 5483.001 | Acc: 38.711% 
[epoch:1, iter:20] Loss: 5251.104 | Acc:

KeyboardInterrupt: 