In [1]:
import torch
from torch.utils import data
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import sklearn.datasets
import time
import matplotlib.pyplot as plt
import numpy as np
import wandb
import copy
%matplotlib inline
%load_ext autoreload
%autoreload 2

In [2]:
class Timer:
    def __init__(self):
        self.times = []
        self.start()

    def start(self):
        self.tik = time.time()

    def stop(self):
        self.times.append(time.time() - self.tik)
        return self.times[-1]

    def avg(self):
        return sum(self.times) / len(self.times)

    def sum(self):
        return sum(self.times)

In [3]:
class MLP(nn.Module):
    def __init__(self, num_classes=10):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(784, 1024, bias=False)
        self.fc2 = nn.Linear(1024, 1024, bias=False)
        self.fc3 = nn.Linear(1024, 1024, bias=False)
        self.fc4 = nn.Linear(1024, num_classes, bias=False)

    def forward(self, x):
        out = x.view(x.size(0), -1)
        out = self.fc1(out)
        out = F.relu(out)
        out = self.fc2(out)
        out = F.relu(out)
        out = self.fc3(out)
        out = F.relu(out)
        out = self.fc4(out)
        out = F.log_softmax(out, dim=1)
        return out


class QuantizeLayer(nn.Module):
    def __init__(self, layer):
        super(QuantizeLayer, self).__init__()
        self.layer = layer

    def forward(self, x):
        if isinstance(self.layer, nn.Conv2d):
            weight_q = torch.sign(self.layer.weight)
            if self.layer.bias is not None:
                bias_q = torch.sign(self.layer.bias)
                return nn.functional.conv2d(
                    x,
                    weight_q,
                    bias_q,
                    self.layer.stride,
                    self.layer.padding,
                    self.layer.dilation,
                    self.layer.groups,
                )
            else:
                return nn.functional.conv2d(
                    x,
                    weight_q,
                    None,
                    self.layer.stride,
                    self.layer.padding,
                    self.layer.dilation,
                    self.layer.groups,
                )
        elif isinstance(self.layer, nn.Linear):
            weight_q = torch.sign(self.layer.weight)
            if self.layer.bias is not None:
                bias_q = torch.sign(self.layer.bias)
                return nn.functional.linear(x, weight_q, bias_q)
            else:
                return nn.functional.linear(x, weight_q, None)
        else:
            return self.layer(x)

In [4]:
def MLPMNISTNet():
    return MLP(num_classes=10)

In [5]:
def load_array(features, labels, batch_size, is_train=True):
    dataset = data.TensorDataset(features, labels)
    return data.DataLoader(dataset, batch_size, shuffle=is_train)

In [6]:
def set_size(width, fraction=1):
    """Set figure dimensions to avoid scaling in LaTeX.

    Parameters
    ----------
    width: float
            Document textwidth or columnwidth in pts
    fraction: float, optional
            Fraction of the width which you wish the figure to occupy

    Returns
    -------
    fig_dim: tuple
            Dimensions of figure in inches
    """
    # Width of figure (in pts)
    fig_width_pt = width * fraction

    # Convert from pt to inches
    inches_per_pt = 1 / 72.27

    # Golden ratio to set aesthetic figure height
    # https://disq.us/p/2940ij3
    golden_ratio = (5**0.5 - 1) / 2

    # Figure width in inches
    fig_width_in = fig_width_pt * inches_per_pt
    # Figure height in inches
    fig_height_in = fig_width_in * golden_ratio

    fig_dim = (fig_width_in, fig_height_in)

    return fig_dim

In [7]:
def init(project_name, opt_name, batch_size, architecture, dataset_name, lr, alpha):
    wandb.init(
        # set the wandb project where this run will be logged
        project=project_name,
        name=opt_name,
        # track hyperparameters and run metadata
        config={
            "batch_size": batch_size,
            "architecture": architecture,
            "dataset": dataset_name,
            "lr": lr,
        },
    )

In [8]:
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import argparse
import os

# Check GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  
# Set hyperparameter
EPOCH = 50
ANNEAL_EPOCH = 10
pre_epoch = 0
BATCH_SIZE = 200
LR = 0.001
alpha = 0.005

# Generate training and testing dataset
# prepare dataset and preprocessing
transform_train = transforms.Compose(
    [
        transforms.ToTensor(),
    ]
)

transform_test = transforms.Compose(
    [
        transforms.ToTensor(),
    ]
)

trainset = torchvision.datasets.MNIST(
    root="../data", train=True, download=True, transform=transform_train
)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=BATCH_SIZE, shuffle=True
)

testset = torchvision.datasets.MNIST(
    root="../data", train=False, download=True, transform=transform_test
)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=100, shuffle=False
)
# Define base net
base_net = MLPMNISTNet().to(device)

In [21]:
# SGD

init(project_name="exp4", opt_name="SGD", batch_size=200, architecture="MLP", dataset_name="MNIST", lr=LR, alpha=0)

# Define neural network
net = copy.deepcopy(base_net)

# Define loss funtion & optimizer
criterion = nn.NLLLoss()
weights = [p for name, p in net.named_parameters() if 'bias' not in name]
bias = [p for name, p in net.named_parameters() if 'bias' in name]
parameters = [{"params": weights, "tag": "weights"}, {"params": bias, "tag": "bias"}]
optimizer = optim.Adam(parameters, lr=LR, betas=(0.9, 0.999))
lr_decay_epochs = [5, 10, 15, 20, 25]
for decay_epoch in lr_decay_epochs:
    if pre_epoch > decay_epoch:
        for param_group in optimizer.param_groups:
            param_group["lr"] *= 0.5

# Train
for epoch in range(pre_epoch, EPOCH):
    print("\nEpoch: %d" % (epoch + 1))
    net.train()
    sum_loss = 0.0
    correct = 0.0
    total = 0.0
    correct2 = 0.0
    total2 = 0.0

    if epoch in lr_decay_epochs:
        for param_group in optimizer.param_groups:
            param_group["lr"] *= 0.5

    for i, data in enumerate(trainloader, 0):
        # prepare dataset
        length = len(trainloader)
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)
        # forward & backward
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        sum_loss += loss.item()
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += predicted.eq(labels.data).cpu().sum()
        print(
            "[epoch:%d, iter:%d] Loss: %.03f | Acc: %.3f%% "
            % (
                epoch + 1,
                (i + 1 + (epoch) * length),
                sum_loss / (i + 1),
                100.0 * correct / total,
            )
        )

        model_copy = copy.deepcopy(net)
        with torch.no_grad():
            for name, param in model_copy.named_parameters():
                if not name.endswith(".bias"):
                    param.data = torch.sign(param.data)
        outputs = model_copy(inputs)
        loss2 = criterion(outputs, labels)
        _, predicted = torch.max(outputs.data, 1)
        total2 += labels.size(0)
        correct2 += predicted.eq(labels.data).cpu().sum()
        wandb.log(
            {
                "quantized loss": loss2,
                "loss": loss,
                "accuracy": (100 * correct / total),
                "quantized_accuracy": (100 * correct2 / total2),
            }
        )
        
    print("Waiting Test...")
    with torch.no_grad():
        correct = 0
        total = 0
        for data in testloader:
            net.eval()
            images, labels = data
            images, labels = images.to(device), labels.to(device)
            outputs = net(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum()
        print(
            "Test's accuracy (before quantization) is: %.3f%%" % (100 * correct / total)
        )
    model_copy = copy.deepcopy(net)
    with torch.no_grad():
        for name, param in model_copy.named_parameters():
            if not name.endswith(".bias"):
                param.data = torch.sign(param.data)
        correct = 0
        total = 0
        for data in testloader:
            model_copy.eval()
            images, labels = data
            images, labels = images.to(device), labels.to(device)
            outputs = model_copy(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum()
        print(
            "Test's accuracy (after quantization) is: %.3f%%" % (100 * correct / total)
        )
wandb.finish()

0,1
accuracy,▁▃▅▅▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇███████████████████
loss,█▅▄▄▄▂▂▂▂▂▂▂▂▂▂▂▂▂▁▂▂▂▁▂▂▂▂▂▂▁▁▂▂▂▂▂▁▁▁▁
quantized loss,▇█▄▅▅▃▃▃▂▂▃▂▃▂▃▂▂▃▂▂▂▂▂▂▂▃▃▂▂▂▂▂▃▄▂▂▂▁▁▂
quantized_accuracy,▁▃▅▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇███████████████████

0,1
accuracy,97.37674
loss,0.07268
quantized loss,73148.63281
quantized_accuracy,95.42882



Epoch: 1
[epoch:1, iter:1] Loss: 2.302 | Acc: 10.500% 
[epoch:1, iter:2] Loss: 2.277 | Acc: 9.750% 
[epoch:1, iter:3] Loss: 2.235 | Acc: 21.667% 
[epoch:1, iter:4] Loss: 2.162 | Acc: 27.750% 
[epoch:1, iter:5] Loss: 2.054 | Acc: 33.800% 
[epoch:1, iter:6] Loss: 1.939 | Acc: 39.833% 
[epoch:1, iter:7] Loss: 1.836 | Acc: 42.286% 
[epoch:1, iter:8] Loss: 1.735 | Acc: 45.812% 
[epoch:1, iter:9] Loss: 1.647 | Acc: 48.556% 
[epoch:1, iter:10] Loss: 1.568 | Acc: 50.650% 
[epoch:1, iter:11] Loss: 1.497 | Acc: 53.000% 
[epoch:1, iter:12] Loss: 1.433 | Acc: 54.750% 
[epoch:1, iter:13] Loss: 1.368 | Acc: 57.038% 
[epoch:1, iter:14] Loss: 1.325 | Acc: 58.500% 
[epoch:1, iter:15] Loss: 1.274 | Acc: 60.000% 
[epoch:1, iter:16] Loss: 1.231 | Acc: 61.156% 
[epoch:1, iter:17] Loss: 1.197 | Acc: 62.412% 
[epoch:1, iter:18] Loss: 1.168 | Acc: 63.472% 
[epoch:1, iter:19] Loss: 1.141 | Acc: 64.342% 
[epoch:1, iter:20] Loss: 1.107 | Acc: 65.525% 
[epoch:1, iter:21] Loss: 1.075 | Acc: 66.548% 
[epoch:1, ite

0,1
accuracy,▁▆▇▇████████████████████████████████████
loss,█▃▁▂▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
quantized loss,█▃▁▃▃▃▂▄▂▃▃▅▇▄▃▃▅▂▄▄▄▇▃▄▄▄▄▃▆▄▆▇▃▆▂▅▄▅▅▄
quantized_accuracy,▁██▇▆█▇▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▇▇▇▇▆▆▆▇▇▇▆▇▇▇▇▇▇

0,1
accuracy,100.0
loss,0.0
quantized loss,137386.5625
quantized_accuracy,94.845


In [12]:
# ASkewSGD
init(project_name="exp4", opt_name="ASkewSGD", batch_size=200, architecture="MLP", dataset_name="MNIST", lr=LR/10, alpha=alpha)

# Define neural network
net = copy.deepcopy(base_net)
# Define loss funtion & optimizer
criterion = nn.NLLLoss()
weights = [p for name, p in net.named_parameters() if 'bias' not in name]
bias = [p for name, p in net.named_parameters() if 'bias' in name]
parameters = [{"params": weights, "tag": "weights"}, {"params": bias, "tag": "bias"}]
optimizer = optim.Adam(parameters, lr=LR/10, betas=(0.9, 0.999))
lr_decay_epochs = [5, 10, 15, 20, 25, 35]
for decay_epoch in lr_decay_epochs:
    if pre_epoch > decay_epoch:
        for param_group in optimizer.param_groups:
            param_group["lr"] *= 0.5

# Train
for epoch in range(0, EPOCH):
    print("\nEpoch: %d" % (epoch + 1))
    net.train()
    sum_loss = 0.0
    correct = 0.0
    total = 0.0
    correct2 = 0.0
    total2 = 0.0
    if epoch < 25:
        epsilon = 1
    else:
        epsilon = 0.9 ** (epoch - 25)

    if epoch in lr_decay_epochs:
        for param_group in optimizer.param_groups:
            param_group["lr"] *= 0.5

    for i, data in enumerate(trainloader, 0):
        # prepare dataset
        length = len(trainloader)
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)
        # forward & backward
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        model_copy = copy.deepcopy(net)
        with torch.no_grad():
            for name, param in model_copy.named_parameters():
                if not name.endswith(".bias"):
                    param.data = torch.sign(param.data)
            outputs2 = model_copy(inputs)
            loss2 = criterion(outputs2, labels)
        
        for param_group in optimizer.param_groups:
            if param_group["tag"] == "weights":
                for idx, p in enumerate(param_group["params"]):
                    constr = epsilon - (p.data**2 - 1) ** 2
                    Kx = -4 * (p.data**2 - 1) * p.data
                    direct_grad = torch.logical_or(
                        torch.logical_or(constr >= 0, Kx == 0),
                        torch.logical_and(
                            constr < 0, (-Kx * p.grad.data) >= -alpha * constr
                        ),
                    )
                    p.grad.data[direct_grad] = p.grad.data[direct_grad]
                    p.grad.data[~direct_grad] = (torch.clip(
                        alpha * constr / Kx,
                        -1,
                        1,
                    ))[~direct_grad]
        
        optimizer.step()
        optimizer.zero_grad()
        sum_loss += loss.item()
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += predicted.eq(labels.data).cpu().sum()
        _, predicted = torch.max(outputs2.data, 1)
        total2 += labels.size(0)
        correct2 += predicted.eq(labels.data).cpu().sum()
        wandb.log(
            {
                "quantized loss": loss2,
                "loss": loss,
                "accuracy": (100 * correct / total),
                "quantized_accuracy": (100 * correct2 / total2),
            }
        )
        print(
            "[epoch:%d, iter:%d] Loss: %.03f | Acc: %.3f%% "
            % (
                epoch + 1,
                (i + 1 + (epoch) * length),
                sum_loss / (i + 1),
                100.0 * correct / total,
            )
        )
    print("Waiting Test...")
    with torch.no_grad():
        correct = 0
        total = 0
        for data in testloader:
            net.eval()
            images, labels = data
            images, labels = images.to(device), labels.to(device)
            outputs = net(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum()
        print(
            "Test's accuracy (before quantization) is: %.3f%%" % (100 * correct / total)
        )
    model_copy = copy.deepcopy(net)
    with torch.no_grad():
        for name, param in model_copy.named_parameters():
            if not name.endswith(".bias"):
                param.data = torch.sign(param.data)
        correct = 0
        total = 0
        for data in testloader:
            model_copy.eval()
            images, labels = data
            images, labels = images.to(device), labels.to(device)
            outputs = model_copy(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum()
        print(
            "Test's accuracy (after quantization) is: %.3f%%" % (100 * correct / total)
        )
wandb.finish()


Epoch: 1
[epoch:1, iter:1] Loss: 2.301 | Acc: 12.500% 
[epoch:1, iter:2] Loss: 2.298 | Acc: 19.000% 
[epoch:1, iter:3] Loss: 2.293 | Acc: 26.500% 
[epoch:1, iter:4] Loss: 2.290 | Acc: 29.125% 
[epoch:1, iter:5] Loss: 2.286 | Acc: 31.800% 
[epoch:1, iter:6] Loss: 2.282 | Acc: 34.167% 
[epoch:1, iter:7] Loss: 2.278 | Acc: 35.429% 
[epoch:1, iter:8] Loss: 2.274 | Acc: 38.000% 
[epoch:1, iter:9] Loss: 2.268 | Acc: 40.278% 
[epoch:1, iter:10] Loss: 2.264 | Acc: 42.000% 
[epoch:1, iter:11] Loss: 2.258 | Acc: 44.227% 
[epoch:1, iter:12] Loss: 2.252 | Acc: 46.000% 
[epoch:1, iter:13] Loss: 2.246 | Acc: 47.192% 
[epoch:1, iter:14] Loss: 2.239 | Acc: 48.393% 
[epoch:1, iter:15] Loss: 2.232 | Acc: 49.233% 
[epoch:1, iter:16] Loss: 2.224 | Acc: 50.156% 
[epoch:1, iter:17] Loss: 2.216 | Acc: 50.882% 
[epoch:1, iter:18] Loss: 2.208 | Acc: 51.667% 
[epoch:1, iter:19] Loss: 2.200 | Acc: 52.368% 
[epoch:1, iter:20] Loss: 2.191 | Acc: 53.025% 
[epoch:1, iter:21] Loss: 2.180 | Acc: 53.643% 
[epoch:1, it

0,1
accuracy,▁▆▇▇▇███████████████████████████████████
loss,█▅▅▄▂▂▃▂▁▁▂▁▁▁▁▁▂▁▁▁▁▁▁▁▁▁▂▁▁▂▁▁▁▂▂▁▁▁▁▁
quantized loss,█▆█▆▃▃▄▃▄▃▅▂▅▁▅▂▅▁▄▃▃▄▂▂▁▂▃▃▂▅▂▄▃▁▄▄▂▃▄▃
quantized_accuracy,▁▆▇█████████████████████████████████████

0,1
accuracy,99.575
loss,1e-05
quantized loss,35855.30859
quantized_accuracy,96.09834


In [10]:
# Deterministic BinaryConnect
init(project_name="exp4", opt_name="Deterministic BinaryConnect", batch_size=200, architecture="MLP", dataset_name="MNIST", lr=LR/100, alpha=0)

# Define neural network
net = copy.deepcopy(base_net)
model_copy = copy.deepcopy(net)

# Define loss funtion & optimizer
criterion = nn.NLLLoss()
weights = [p for name, p in net.named_parameters() if "bias" not in name]
bias = [p for name, p in net.named_parameters() if "bias" in name]
parameters = [{"params": weights, "tag": "weights"}, {"params": bias, "tag": "bias"}]
optimizer = optim.SGD(parameters, lr=LR/100)
lr_decay_epochs = [30, 60]
for decay_epoch in lr_decay_epochs:
    if pre_epoch > decay_epoch:
        for param_group in optimizer.param_groups:
            param_group["lr"] *= 0.5

# Train
for epoch in range(pre_epoch, EPOCH):
    print("\nEpoch: %d" % (epoch + 1))
    net.train()
    model_copy.train()
    sum_loss = 0.0
    correct = 0.0
    total = 0.0
    correct2 = 0.0
    total2 = 0.0
    if epoch in lr_decay_epochs:
        for param_group in optimizer.param_groups:
            param_group["lr"] *= 0.5

    for i, data in enumerate(trainloader, 0):
        # prepare dataset
        length = len(trainloader)
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)
        # forward & backward
        with torch.no_grad():
            outputs = model_copy(inputs)
            loss2 = criterion(outputs, labels)
            for net_name, net_param in net.named_parameters():
                if not net_name.endswith(".bias"):
                    net_param.data = torch.sign(net_param.data)

        outputs2 = net(inputs)
        loss = criterion(outputs2, labels)
        loss.backward()
        optimizer.step()
        for (net_name, net_param), (model_copy_name, model_copy_param) in zip(
            net.named_parameters(), model_copy.named_parameters()
        ):
            if not net_name.endswith(".bias"):
                delta = net_param.data - torch.sign(model_copy_param.data)
                net_param.data = torch.clamp(model_copy_param.data + delta, -1, 1)
            model_copy_param.data = torch.clamp(model_copy_param.data + delta, -1, 1)
        sum_loss += loss.item()
        optimizer.zero_grad()
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += predicted.eq(labels.data).cpu().sum()
        _, predicted = torch.max(outputs2.data, 1)
        total2 += labels.size(0)
        correct2 += predicted.eq(labels.data).cpu().sum()
        print(
            "[epoch:%d, iter:%d] Loss: %.03f | Acc: %.3f%% "
            % (
                epoch + 1,
                (i + 1 + (epoch) * length),
                sum_loss / (i + 1),
                100.0 * correct2 / total2,
            )
        )
        wandb.log(
            {
                "quantized loss": loss2,
                "loss": loss,
                "accuracy": (100 * correct / total),
                "quantized_accuracy": (100 * correct2 / total2),
            }
        )
    print("Waiting Test...")
    with torch.no_grad():
        correct = 0
        total = 0
        for data in testloader:
            net.eval()
            images, labels = data
            images, labels = images.to(device), labels.to(device)
            outputs = net(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum()
        print(
            "Test's accuracy (before quantization) is: %.3f%%" % (100 * correct / total)
        )
    net_copy = copy.deepcopy(net)
    with torch.no_grad():
        for name, param in net_copy.named_parameters():
            if not name.endswith(".bias"):
                param.data = torch.sign(param.data)
        correct = 0
        total = 0
        for data in testloader:
            net_copy.eval()
            images, labels = data
            images, labels = images.to(device), labels.to(device)
            outputs = net_copy(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum()
        print(
            "Test's accuracy (after quantization) is: %.3f%%" % (100 * correct / total)
        )
wandb.finish()


Epoch: 1
[epoch:1, iter:1] Loss: 176453.453 | Acc: 8.000% 
[epoch:1, iter:2] Loss: 367691.945 | Acc: 9.250% 
[epoch:1, iter:3] Loss: 445886.672 | Acc: 14.833% 
[epoch:1, iter:4] Loss: 431857.348 | Acc: 14.875% 
[epoch:1, iter:5] Loss: 385032.631 | Acc: 17.200% 
[epoch:1, iter:6] Loss: 337245.411 | Acc: 19.250% 
[epoch:1, iter:7] Loss: 297296.792 | Acc: 22.714% 
[epoch:1, iter:8] Loss: 264714.652 | Acc: 26.500% 
[epoch:1, iter:9] Loss: 238561.191 | Acc: 29.722% 
[epoch:1, iter:10] Loss: 216284.119 | Acc: 33.200% 
[epoch:1, iter:11] Loss: 197671.639 | Acc: 36.227% 
[epoch:1, iter:12] Loss: 182356.678 | Acc: 38.750% 
[epoch:1, iter:13] Loss: 169356.297 | Acc: 40.962% 
[epoch:1, iter:14] Loss: 158187.429 | Acc: 42.714% 
[epoch:1, iter:15] Loss: 148102.978 | Acc: 44.800% 
[epoch:1, iter:16] Loss: 139412.595 | Acc: 46.594% 
[epoch:1, iter:17] Loss: 131708.076 | Acc: 48.000% 
[epoch:1, iter:18] Loss: 124823.063 | Acc: 49.528% 
[epoch:1, iter:19] Loss: 118660.743 | Acc: 50.737% 
[epoch:1, ite

0,1
accuracy,▁▁▁▂▄▅▅▆▅▅▅▅▇▆▆▆▆▇▇█▇▇▇▇▇▇▆▇▇▇▇▇▇▇▇▇▇█▇▇
loss,█▅▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
quantized loss,▁▅▅██▇█▅▇▇▇█▇▆▆▅▆▆▅▄▆▆▅▅▅▆▅▄▄▆▆▆▆▅▅▅▅▅▄▅
quantized_accuracy,▁▅▆▆▇▇▇▇▇▇▇█▇███████████████████████████

0,1
accuracy,50.86167
loss,3.80773
quantized loss,2.29695
quantized_accuracy,99.59167


In [12]:
# ProxQuant
init(project_name="exp4", opt_name="ProxQuant", batch_size=200, architecture="MLP", dataset_name="MNIST", lr=LR, alpha=alpha)

# Define neural network
net = copy.deepcopy(base_net)
# Define loss funtion & optimizer
criterion = nn.NLLLoss()
weights = [p for name, p in net.named_parameters() if 'bias' not in name]
bias = [p for name, p in net.named_parameters() if 'bias' in name]
parameters = [{"params": weights, "tag": "weights"}, {"params": bias, "tag": "bias"}]
optimizer = optim.Adam(parameters, lr=LR, betas=(0.9, 0.999))
lr_decay_epochs = [5, 10, 15, 25, 35]
for decay_epoch in lr_decay_epochs:
    if pre_epoch > decay_epoch:
        for param_group in optimizer.param_groups:
            param_group["lr"] *= 0.5

# Train
for epoch in range(0, EPOCH):
    print("\nEpoch: %d" % (epoch + 1))
    net.train()
    sum_loss = 0.0
    correct = 0.0
    total = 0.0
    correct2 = 0.0
    total2 = 0.0

    if epoch < 15:
        epsilon = 1
    else:
        epsilon = (0.88)**(epoch-15)

    if epoch in lr_decay_epochs:
        for param_group in optimizer.param_groups:
            param_group["lr"] *= 0.5

    for i, data in enumerate(trainloader, 0):
        # prepare dataset
        length = len(trainloader)
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)
        # forward & backward
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        model_copy = copy.deepcopy(net)
        with torch.no_grad():
            for name, param in model_copy.named_parameters():
                if not name.endswith(".bias"):
                    param.data = torch.sign(param.data)
            outputs2 = model_copy(inputs)
            loss2 = criterion(outputs2, labels)
        
        optimizer.step()
        optimizer.zero_grad()
        with torch.no_grad():
            for name, param in model_copy.named_parameters():
                if not name.endswith(".bias"):
                    project = torch.logical_and(-1+(param.data+1)/(2*alpha/epsilon+1)>0, param.data>0)
                    param.data[project] = 1+(param.data[project]-1)/(2*alpha/epsilon+1)
                    param.data[~project] = -1+(param.data[~project]+1)/(2*alpha/epsilon+1)
        sum_loss += loss.item()
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += predicted.eq(labels.data).cpu().sum()
        _, predicted = torch.max(outputs2.data, 1)
        total2 += labels.size(0)
        correct2 += predicted.eq(labels.data).cpu().sum()
        wandb.log(
            {
                "quantized loss": loss2,
                "loss": loss,
                "accuracy": (100 * correct / total),
                "quantized_accuracy": (100 * correct2 / total2),
            }
        )
        print(
            "[epoch:%d, iter:%d] Loss: %.03f | Acc: %.3f%% "
            % (
                epoch + 1,
                (i + 1 + (epoch) * length),
                sum_loss / (i + 1),
                100.0 * correct / total,
            )
        )
    print("Waiting Test...")
    with torch.no_grad():
        correct = 0
        total = 0
        for data in testloader:
            net.eval()
            images, labels = data
            images, labels = images.to(device), labels.to(device)
            outputs = net(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum()
        print(
            "Test's accuracy (before quantization) is: %.3f%%" % (100 * correct / total)
        )
    model_copy = copy.deepcopy(net)
    with torch.no_grad():
        for name, param in model_copy.named_parameters():
            if not name.endswith(".bias"):
                param.data = torch.sign(param.data)
        correct = 0
        total = 0
        for data in testloader:
            model_copy.eval()
            images, labels = data
            images, labels = images.to(device), labels.to(device)
            outputs = model_copy(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum()
        print(
            "Test's accuracy (after quantization) is: %.3f%%" % (100 * correct / total)
        )
wandb.finish()


Epoch: 1
[epoch:1, iter:1] Loss: 2.302 | Acc: 13.000% 
[epoch:1, iter:2] Loss: 2.274 | Acc: 24.000% 
[epoch:1, iter:3] Loss: 2.224 | Acc: 29.833% 
[epoch:1, iter:4] Loss: 2.157 | Acc: 33.875% 
[epoch:1, iter:5] Loss: 2.053 | Acc: 40.300% 
[epoch:1, iter:6] Loss: 1.933 | Acc: 44.250% 
[epoch:1, iter:7] Loss: 1.810 | Acc: 48.214% 
[epoch:1, iter:8] Loss: 1.681 | Acc: 51.500% 
[epoch:1, iter:9] Loss: 1.582 | Acc: 53.667% 
[epoch:1, iter:10] Loss: 1.493 | Acc: 55.550% 
[epoch:1, iter:11] Loss: 1.426 | Acc: 57.636% 
[epoch:1, iter:12] Loss: 1.380 | Acc: 58.917% 
[epoch:1, iter:13] Loss: 1.323 | Acc: 60.692% 
[epoch:1, iter:14] Loss: 1.282 | Acc: 61.679% 
[epoch:1, iter:15] Loss: 1.242 | Acc: 62.667% 
[epoch:1, iter:16] Loss: 1.197 | Acc: 64.000% 
[epoch:1, iter:17] Loss: 1.164 | Acc: 65.029% 
[epoch:1, iter:18] Loss: 1.135 | Acc: 65.667% 
[epoch:1, iter:19] Loss: 1.097 | Acc: 66.789% 
[epoch:1, iter:20] Loss: 1.066 | Acc: 67.650% 
[epoch:1, iter:21] Loss: 1.037 | Acc: 68.524% 
[epoch:1, it

0,1
accuracy,▁▆▇▇████████████████████████████████████
loss,█▂▃▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
quantized loss,█▁▄▁▆▃▂▄▇▄▃▃▅▂▃▃█▅▃▄▃▅▇▄▂▅▅▅▁▇▄▅▃▇▄▅▄▂▄▄
quantized_accuracy,▁█▇▇▆▇▇▇▅▇▇▆▆▆▆▇▇▇▆▇▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆

0,1
accuracy,100.0
loss,0.0
quantized loss,126143.0625
quantized_accuracy,93.24333
