In [None]:
import torch
from torch.utils import data
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import sklearn.datasets
import matplotlib.pyplot as plt
import numpy as np
import wandb
import quant_lib.MNIST_MLP as mlp
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import copy
%matplotlib inline
%load_ext autoreload
%autoreload 2

In [None]:
def MLPMNISTNet(bit_width=1):
    """Construct the ``quant_lib`` MLP configured for MNIST.

    Args:
        bit_width: weight bit width forwarded to ``mlp.MLP`` (default 1).

    Returns:
        An ``mlp.MLP`` built with ``input_class=784`` (28x28 MNIST pixels,
        presumably flattened inside the model) and 10 output classes.
    """
    n_pixels = 28 * 28
    n_digits = 10
    return mlp.MLP(input_class=n_pixels, num_classes=n_digits, bit_width=bit_width)

In [None]:
def load_array(features, labels, batch_size, is_train=True):
    """Wrap in-memory tensors in a DataLoader.

    Args:
        features: tensor of inputs; its first dimension indexes samples.
        labels: tensor of targets with the same first dimension as ``features``.
        batch_size: number of samples per batch.
        is_train: reshuffle every epoch when True; keep the original order otherwise.

    Returns:
        A ``DataLoader`` over ``TensorDataset(features, labels)``.
    """
    # Import locally instead of relying on the global alias `data`: training loops
    # written as `for i, data in enumerate(loader)` rebind that global name to a
    # batch tuple, which would otherwise break this helper after such a loop runs.
    from torch.utils.data import DataLoader, TensorDataset

    dataset = TensorDataset(features, labels)
    return DataLoader(dataset, batch_size, shuffle=is_train)

In [None]:
# Use the GPU when torch can see one.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ---- Hyperparameters -------------------------------------------------------
SEED = 0                       # seeds torch/numpy so weight init and shuffling are reproducible
bit_width = 2                  # bit width passed to the quantized MLP
EPOCH = 50                     # total number of training epochs
ANNEAL_EPOCH = 20              # ASkewSGD: epsilon stays 1 before this epoch, then decays
pre_epoch = 0                  # epoch to start (resume) from; 0 = from scratch
BATCH_SIZE = 50
LR = 0.06                      # base learning rate; each optimizer cell scales it
lr_decay_epochs = [7, 15, 30]  # the training cells halve the LR when reaching these epochs
# ASkewSGD
DECAY_CONST = 0.88             # per-epoch decay factor of epsilon after ANNEAL_EPOCH
alpha = 0.02
# ProxQuant
reg_lambda = 4e-5

# Seed before building the network and the shuffling DataLoader so runs are repeatable.
torch.manual_seed(SEED)
np.random.seed(SEED)

# ---- Data ------------------------------------------------------------------
# No augmentation: images are only converted to tensors.
transform_train = transforms.Compose([transforms.ToTensor()])
transform_test = transforms.Compose([transforms.ToTensor()])

DATA_ROOT = "../data"
trainset = torchvision.datasets.MNIST(
    root=DATA_ROOT, train=True, download=True, transform=transform_train
)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=BATCH_SIZE, shuffle=True
)

testset = torchvision.datasets.MNIST(
    root=DATA_ROOT, train=False, download=True, transform=transform_test
)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=100, shuffle=False
)

# ---- Model & loss ----------------------------------------------------------
# Every optimizer cell deep-copies this network, so all runs share the same initial weights.
base_net = MLPMNISTNet(bit_width=bit_width).to(device)
criterion = nn.CrossEntropyLoss()

In [None]:
def init(project_name, opt_name, batch_size, architecture, dataset_name, lr, alpha=None, reg_lambda=None):
    """Start a wandb run and return a fresh network plus a plain SGD optimizer.

    NOTE: this name shadows ``torch.nn.init`` (imported as ``init`` in the first
    cell); it is kept because every training cell calls ``init(...)``.

    Args:
        project_name: wandb project the run is logged under.
        opt_name: display name of the run (the optimizer being compared).
        batch_size, architecture, dataset_name, lr, alpha, reg_lambda:
            recorded in the run config; ``lr`` is also the SGD learning rate.

    Returns:
        ``(net, optimizer)``: a deep copy of the global ``base_net`` on ``device``,
        and an ``optim.SGD`` with a single param group tagged ``"weights"``.
    """
    run_config = {
        "batch_size": batch_size,
        "architecture": architecture,
        "dataset": dataset_name,
        "lr": lr,
        "alpha": alpha,
        "reg_lambda": reg_lambda,
        "bit_width": base_net.bit_width,
    }
    wandb.init(project=project_name, name=opt_name, config=run_config)

    # Copy so every optimizer starts from the same initial weights.
    net = copy.deepcopy(base_net)
    net.to(device)

    # One group holding every parameter; the ASkewSGD cell selects groups by this tag.
    param_groups = [{"params": list(net.parameters()), "tag": "weights"}]
    optimizer = optim.SGD(param_groups, lr=lr)
    return net, optimizer

In [None]:
# SGD baseline: full-precision training with no quantization in the loop.
# The qt=True evaluations report how the learned weights fare when quantized.
net, optimizer = init(project_name="MNIST_multibit", opt_name="SGD", batch_size=BATCH_SIZE, architecture="MLP", dataset_name="MNIST", lr=LR/10)

FILE = "MNIST_SGD.pt"
print_every = 100  # print running train metrics every N batches instead of every batch
num_batches = len(trainloader)
best_acc = 0

# When resuming at pre_epoch, apply the LR halvings of decay epochs already passed
# (a decay at epoch == pre_epoch is applied inside the loop below).
for decay_epoch in lr_decay_epochs:
    if pre_epoch > decay_epoch:
        for param_group in optimizer.param_groups:
            param_group["lr"] *= 0.5

try:
    for epoch in range(pre_epoch, EPOCH):
        print("\nEpoch: %d" % (epoch + 1))
        net.train()
        sum_loss = 0.0
        correct = 0.0
        total = 0.0

        if epoch in lr_decay_epochs:
            for param_group in optimizer.param_groups:
                param_group["lr"] *= 0.5

        # Unpack the batch directly so the loop does not rebind the global `data`
        # (the `torch.utils.data` alias used by `load_array`).
        for i, (inputs, labels) in enumerate(trainloader):
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            sum_loss += loss.item()
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
            if (i + 1) % print_every == 0 or i + 1 == num_batches:
                print(
                    "[Epoch:%d, Iter:%d] Loss: %.03f | Acc: %.3f%% "
                    % (
                        epoch + 1,
                        (i + 1 + epoch * num_batches),
                        sum_loss / (i + 1),
                        100.0 * correct / total,
                    )
                )

        print("Waiting Test...")
        with torch.no_grad():
            train_loss, train_acc = net.evaluate(trainloader, criterion, device, eval=False, qt=False)
            qtrain_loss, qtrain_acc = net.evaluate(trainloader, criterion, device, eval=False, qt=True)
            test_loss, test_acc = net.evaluate(testloader, criterion, device, eval=True, qt=False)
            qtest_loss, qtest_acc = net.evaluate(testloader, criterion, device, eval=True, qt=True)
            wandb.log(
                {
                    "test_loss": test_loss,
                    "quantized_test_loss": qtest_loss,
                    "accuracy": train_acc,
                    "quantized_accuracy": qtrain_acc,
                    "test_accuracy": test_acc,
                    "quantized_test_accuracy": qtest_acc
                }
            )
            print(
                "Train Loss: %.03f | Train Acc: %.3f%% | Test Loss: %.03f | Test Acc: %.3f%% "
                % (train_loss, train_acc, test_loss, test_acc)
            )
            print(
                "Quantized Train Loss: %.03f | Quantized Train Acc: %.3f%% | Quantized Test Loss: %.03f | Quantized Test Acc: %.3f%% "
                % (qtrain_loss, qtrain_acc, qtest_loss, qtest_acc)
            )
            # Checkpoint whenever the quantized test accuracy improves.
            if qtest_acc > best_acc:
                torch.save(
                    {
                        "model_state_dict": net.state_dict(),
                        "optimizer_state_dict": optimizer.state_dict(),
                    },
                    FILE,
                )
                best_acc = qtest_acc
    print("Best Test Accuracy after Quantization: %.3f%%" % best_acc)
finally:
    # Close the wandb run even if training is interrupted or raises.
    wandb.finish()

In [None]:
# ProxQuant: after every SGD step, a proximal step pulls each weight toward its
# nearest quantization level. The pull strength epsilon grows linearly with the
# global iteration count, from 0 toward reg_lambda at the end of training.
base_lr = LR / 8
net, optimizer = init(project_name="MNIST_multibit", opt_name="ProxQuant", batch_size=BATCH_SIZE, architecture="MLP", dataset_name="MNIST", lr=base_lr, reg_lambda=reg_lambda)

FILE = "MNIST_ProxQuant_W2A32.pt"
print_every = 100  # print running train metrics every N batches instead of every batch
num_batches = len(trainloader)
total_it = EPOCH * num_batches

# Integer quantization grid, e.g. {-2, -1, 0, 1} for 2 bits; computed once.
q_min = -2 ** (net.bit_width - 1)
q_max = 2 ** (net.bit_width - 1) - 1
levels = torch.arange(q_min, q_max + 1, device=device)

# When resuming at pre_epoch, apply the LR halvings of decay epochs already passed.
for decay_epoch in lr_decay_epochs:
    if pre_epoch > decay_epoch:
        for param_group in optimizer.param_groups:
            param_group["lr"] *= 0.5
# Read the current LR from the optimizer rather than tracking a separate copy
# (a manual copy would be halved once per param group).
lr = optimizer.param_groups[0]["lr"]

# Global iteration counter driving epsilon; starts at the resumed position.
it = pre_epoch * num_batches - 1
best_acc = 0

try:
    for epoch in range(pre_epoch, EPOCH):
        print("\nEpoch: %d" % (epoch + 1))
        net.train()
        sum_loss = 0.0
        correct = 0.0
        total = 0.0
        if epoch in lr_decay_epochs:
            for param_group in optimizer.param_groups:
                param_group["lr"] *= 0.5
        lr = optimizer.param_groups[0]["lr"]

        # Unpack the batch directly so the loop does not rebind the global `data`.
        for i, (inputs, labels) in enumerate(trainloader):
            it += 1
            epsilon = reg_lambda * it / total_it
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            sum_loss += loss.item()
            optimizer.step()

            # Proximal step: w <- (w + eps*lr*q(w)) / (1 + eps*lr), q(w) = nearest grid level.
            with torch.no_grad():
                prox = epsilon * lr
                for param in net.parameters():
                    # Broadcasting against the grid works for parameters of any rank.
                    nearest = levels[(param.unsqueeze(-1) - levels).abs().argmin(dim=-1)]
                    param.copy_((param + prox * nearest) / (1 + prox))
                    # Keep weights within half a step beyond the outermost grid levels.
                    param.clamp_(q_min - 0.5, q_max + 0.5)
            optimizer.zero_grad()

            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
            if (i + 1) % print_every == 0 or i + 1 == num_batches:
                print(
                    "[Epoch:%d, Iter:%d] Loss: %.03f | Acc: %.3f%% "
                    % (
                        epoch + 1,
                        (i + 1 + epoch * num_batches),
                        sum_loss / (i + 1),
                        100.0 * correct / total,
                    )
                )

        print("Waiting Test...")
        with torch.no_grad():
            train_loss, train_acc = net.evaluate(trainloader, criterion, device, eval=False, qt=False)
            qtrain_loss, qtrain_acc = net.evaluate(trainloader, criterion, device, eval=False, qt=True)
            test_loss, test_acc = net.evaluate(testloader, criterion, device, eval=True, qt=False)
            qtest_loss, qtest_acc = net.evaluate(testloader, criterion, device, eval=True, qt=True)
            wandb.log(
                {
                    "test_loss": test_loss,
                    "quantized_test_loss": qtest_loss,
                    "accuracy": train_acc,
                    "quantized_accuracy": qtrain_acc,
                    "test_accuracy": test_acc,
                    "quantized_test_accuracy": qtest_acc
                }
            )
            print(
                "Train Loss: %.03f | Train Acc: %.3f%% | Test Loss: %.03f | Test Acc: %.3f%% "
                % (train_loss, train_acc, test_loss, test_acc)
            )
            print(
                "Quantized Train Loss: %.03f | Quantized Train Acc: %.3f%% | Quantized Test Loss: %.03f | Quantized Test Acc: %.3f%% "
                % (qtrain_loss, qtrain_acc, qtest_loss, qtest_acc)
            )
            # Checkpoint whenever the quantized test accuracy improves.
            if qtest_acc > best_acc:
                torch.save(
                    {
                        "model_state_dict": net.state_dict(),
                        "optimizer_state_dict": optimizer.state_dict(),
                    },
                    FILE,
                )
                best_acc = qtest_acc
    print("Best Test Accuracy after Quantization: %.3f%%" % best_acc)
finally:
    # Close the wandb run even if training is interrupted or raises.
    wandb.finish()

In [None]:
# ASkewSGD: before each SGD step, each weight's gradient is either kept or
# replaced by -Kx. The choice depends on a constraint built from the weight's
# distances to its two nearest quantization levels a and b. The tolerance
# epsilon is 1 until ANNEAL_EPOCH and then decays geometrically (continuously
# within an epoch).
net, optimizer = init(project_name="MNIST_multibit", opt_name="ASkewSGD", batch_size=BATCH_SIZE, architecture="MLP", dataset_name="MNIST", lr=LR, alpha=alpha)

FILE = "MNIST_ASkewSGD_W2A32.pt"
print_every = 100  # print running train metrics every N batches instead of every batch
num_batches = len(trainloader)

# Integer quantization grid, e.g. {-2, -1, 0, 1} for 2 bits; computed once.
q_min = -2 ** (net.bit_width - 1)
q_max = 2 ** (net.bit_width - 1) - 1
levels = torch.arange(q_min, q_max + 1, device=device)

# When resuming at pre_epoch, apply the LR halvings of decay epochs already passed.
for decay_epoch in lr_decay_epochs:
    if pre_epoch > decay_epoch:
        for param_group in optimizer.param_groups:
            param_group["lr"] *= 0.5

best_acc = 0

try:
    for epoch in range(pre_epoch, EPOCH):
        print("\nEpoch: %d" % (epoch + 1))
        net.train()
        sum_loss = 0.0
        correct = 0.0
        total = 0.0

        if epoch in lr_decay_epochs:
            for param_group in optimizer.param_groups:
                param_group["lr"] *= 0.5

        # Unpack the batch directly so the loop does not rebind the global `data`.
        for i, (inputs, labels) in enumerate(trainloader):
            if epoch < ANNEAL_EPOCH:
                epsilon = 1
            else:
                epsilon = DECAY_CONST ** ((epoch - ANNEAL_EPOCH) + (i / num_batches))
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()

            with torch.no_grad():
                for param_group in optimizer.param_groups:
                    if param_group["tag"] != "weights":
                        continue
                    kx_limit = 1 / (4 * param_group["lr"])
                    for p in param_group["params"]:
                        # a / b: nearest and second-nearest grid levels (any parameter rank).
                        order = (p.unsqueeze(-1) - levels).abs().argsort(dim=-1)
                        a = levels[order[..., 0]]
                        b = levels[order[..., 1]]
                        dist_a = p - a
                        dist_b = p - b
                        constr = epsilon - dist_a ** 2 * dist_b ** 2
                        Kx = alpha * constr / (2 * dist_a * dist_b * (0.000001 + dist_b + dist_a))
                        # Keep the raw gradient on the grid, inside the constraint, or
                        # when it already pushes further than Kx; decided before clamping Kx.
                        direct_grad = (dist_a * dist_b == 0) | (constr >= 0) | ((-p.grad) * Kx > Kx ** 2)
                        Kx.clamp_(-kx_limit, kx_limit)
                        p.grad[~direct_grad] = -Kx[~direct_grad]
                        p.clamp_(q_min, q_max)
            optimizer.step()
            optimizer.zero_grad()

            sum_loss += loss.item()
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
            if (i + 1) % print_every == 0 or i + 1 == num_batches:
                print(
                    "[Epoch:%d, Iter:%d] Loss: %.03f | Acc: %.3f%% "
                    % (
                        epoch + 1,
                        (i + 1 + epoch * num_batches),
                        sum_loss / (i + 1),
                        100.0 * correct / total,
                    )
                )

        print("Waiting Test...")
        with torch.no_grad():
            train_loss, train_acc = net.evaluate(trainloader, criterion, device, eval=False, qt=False)
            qtrain_loss, qtrain_acc = net.evaluate(trainloader, criterion, device, eval=False, qt=True)
            test_loss, test_acc = net.evaluate(testloader, criterion, device, eval=True, qt=False)
            qtest_loss, qtest_acc = net.evaluate(testloader, criterion, device, eval=True, qt=True)
            wandb.log(
                {
                    "test_loss": test_loss,
                    "quantized_test_loss": qtest_loss,
                    "accuracy": train_acc,
                    "quantized_accuracy": qtrain_acc,
                    "test_accuracy": test_acc,
                    "quantized_test_accuracy": qtest_acc
                }
            )
            print(
                "Train Loss: %.03f | Train Acc: %.3f%% | Test Loss: %.03f | Test Acc: %.3f%% "
                % (train_loss, train_acc, test_loss, test_acc)
            )
            print(
                "Quantized Train Loss: %.03f | Quantized Train Acc: %.3f%% | Quantized Test Loss: %.03f | Quantized Test Acc: %.3f%% "
                % (qtrain_loss, qtrain_acc, qtest_loss, qtest_acc)
            )
            # Checkpoint whenever the quantized test accuracy improves.
            if qtest_acc > best_acc:
                torch.save(
                    {
                        "model_state_dict": net.state_dict(),
                        "optimizer_state_dict": optimizer.state_dict(),
                    },
                    FILE,
                )
                best_acc = qtest_acc
    print("Best Test Accuracy after Quantization: %.3f%%" % best_acc)
finally:
    # Close the wandb run even if training is interrupted or raises.
    wandb.finish()

In [None]:
# Deterministic BinaryConnect generalised to several bits (MultiBitConnect):
# `model_copy` holds the full-precision weights, and `net` is quantized before
# every forward/backward pass. The SGD update taken on the quantized weights
# (delta = stepped weight - quantized full-precision weight) is then added back
# onto the full-precision copy and clamped.
net, optimizer = init(project_name="MNIST_multibit", opt_name="Deterministic MultiBitConnect", batch_size=BATCH_SIZE, architecture="MLP", dataset_name="MNIST", lr=LR/300, alpha=0)

model_copy = copy.deepcopy(net)

FILE = "MNIST_MultiBitConnect_W2A32.pt"
print_every = 100  # print running train metrics every N batches instead of every batch
num_batches = len(trainloader)

# Clamp range for the full-precision weights (computed once, not per parameter).
if net.bit_width == 1:
    w_min, w_max = -1, 1
else:
    w_min = -(2 ** (model_copy.bit_width - 1)) - 0.5
    w_max = (2 ** (model_copy.bit_width - 1)) - 0.5

# When resuming at pre_epoch, apply the LR halvings of decay epochs already passed.
for decay_epoch in lr_decay_epochs:
    if pre_epoch > decay_epoch:
        for param_group in optimizer.param_groups:
            param_group["lr"] *= 0.5

best_acc = 0

try:
    for epoch in range(pre_epoch, EPOCH):
        print("\nEpoch: %d" % (epoch + 1))
        net.train()
        model_copy.train()
        sum_loss = 0.0
        correct = 0.0
        total = 0.0

        if epoch in lr_decay_epochs:
            for param_group in optimizer.param_groups:
                param_group["lr"] *= 0.5

        # Unpack the batch directly so the loop does not rebind the global `data`.
        for i, (inputs, labels) in enumerate(trainloader):
            inputs, labels = inputs.to(device), labels.to(device)

            # Replace the working weights with their quantized values.
            with torch.no_grad():
                for net_param in net.parameters():
                    net_param.data = net.quantize(net_param.data, net.bit_width)

            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # Transfer the step onto the full-precision weights, then restore them into `net`.
            with torch.no_grad():
                for net_param, fp_param in zip(net.parameters(), model_copy.parameters()):
                    delta = net_param.data - model_copy.quantize(fp_param.data, model_copy.bit_width)
                    updated = torch.clamp(fp_param.data + delta, w_min, w_max)
                    fp_param.data = updated
                    # Clone so `net` and `model_copy` never share storage.
                    net_param.data = updated.clone()
            sum_loss += loss.item()
            optimizer.zero_grad()

            # Accuracy from the same (quantized) forward pass as the reported loss.
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
            if (i + 1) % print_every == 0 or i + 1 == num_batches:
                print(
                    "[epoch:%d, iter:%d] Loss: %.03f | Acc: %.3f%% "
                    % (
                        epoch + 1,
                        (i + 1 + epoch * num_batches),
                        sum_loss / (i + 1),
                        100.0 * correct / total,
                    )
                )

        print("Waiting Test...")
        with torch.no_grad():
            train_loss, train_acc = net.evaluate(trainloader, criterion, device, eval=False, qt=False)
            qtrain_loss, qtrain_acc = net.evaluate(trainloader, criterion, device, eval=False, qt=True)
            test_loss, test_acc = net.evaluate(testloader, criterion, device, eval=True, qt=False)
            qtest_loss, qtest_acc = net.evaluate(testloader, criterion, device, eval=True, qt=True)
            wandb.log(
                {
                    "test_loss": test_loss,
                    "quantized_test_loss": qtest_loss,
                    "accuracy": train_acc,
                    "quantized_accuracy": qtrain_acc,
                    "test_accuracy": test_acc,
                    "quantized_test_accuracy": qtest_acc
                }
            )
            print(
                "Train Loss: %.03f | Train Acc: %.3f%% | Test Loss: %.03f | Test Acc: %.3f%% "
                % (train_loss, train_acc, test_loss, test_acc)
            )
            print(
                "Quantized Train Loss: %.03f | Quantized Train Acc: %.3f%% | Quantized Test Loss: %.03f | Quantized Test Acc: %.3f%% "
                % (qtrain_loss, qtrain_acc, qtest_loss, qtest_acc)
            )
            # Checkpoint whenever the quantized test accuracy improves.
            if qtest_acc > best_acc:
                torch.save(
                    {
                        "model_state_dict": net.state_dict(),
                        "optimizer_state_dict": optimizer.state_dict(),
                    },
                    FILE,
                )
                best_acc = qtest_acc
    print("Best Test Accuracy after Quantization: %.3f%%" % best_acc)
finally:
    # Close the wandb run even if training is interrupted or raises.
    wandb.finish()