In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import copy
import wandb
import math

import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import argparse
import os

import quant_lib.ResNet18 as ResNet18

In [None]:
def ResNet(bit_width=1):
    """Build the ResNet model defined in ``quant_lib.ResNet18``.

    Args:
        bit_width: forwarded unchanged as the ``bit_width`` keyword of
            ``ResNet18.ResNet`` (defaults to 1).

    Returns:
        Whatever ``ResNet18.ResNet`` returns when constructed with
        ``ResNet18.ResidualBlock`` as its first positional argument.
    """
    block_cls = ResNet18.ResidualBlock
    model = ResNet18.ResNet(block_cls, bit_width=bit_width)
    return model

In [None]:
# Check GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed PyTorch's RNG before anything random happens (model initialisation,
# DataLoader shuffling, random augmentations) so that a fresh
# "Restart & Run All" starts from the same random state every time.
SEED = 0
torch.manual_seed(SEED)

# Set hyperparameters
bit_width = 2  # forwarded to ResNet(bit_width=...)
EPOCH = 150  # the training loops iterate over range(pre_epoch, EPOCH)
ANNEAL_EPOCH_AS = 90  # ASkewSGD: epsilon stays at 1 before this epoch, then decays
ANNEAL_EPOCH_PQ = 50  # ProxQuant: the prox strength starts ramping up at this epoch
pre_epoch = 0  # first epoch index of every training loop
BATCH_SIZE = 100
LR = 0.06

# ASkewSGD
DECAY_CONST = 0.88  # epsilon = DECAY_CONST ** (epoch - ANNEAL_EPOCH_AS) after annealing starts
alpha = 0.2
# ProxQuant
reg_lambda = 0.001

# Per-channel mean / std handed to Normalize. Defined once so the train and
# test pipelines cannot drift apart.
NORM_MEAN = (0.4914, 0.4822, 0.4465)
NORM_STD = (0.2023, 0.1994, 0.2010)

# prepare dataset and preprocessing
transform_train = transforms.Compose(
    [
        transforms.Resize(40),
        transforms.RandomResizedCrop(32, scale=(0.64, 1.0), ratio=(1.0, 1.0)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.RandomErasing(),
        transforms.Normalize(NORM_MEAN, NORM_STD),
    ]
)

transform_test = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(NORM_MEAN, NORM_STD),
    ]
)

trainset = torchvision.datasets.CIFAR10(
    root="../data", train=True, download=True, transform=transform_train
)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=BATCH_SIZE, shuffle=True, num_workers=3
)

testset = torchvision.datasets.CIFAR10(
    root="../data", train=False, download=True, transform=transform_test
)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=100, shuffle=False, num_workers=3
)

# Labels in CIFAR10
classes = (
    "plane",
    "car",
    "bird",
    "cat",
    "deer",
    "dog",
    "frog",
    "horse",
    "ship",
    "truck",
)

# Define ResNet18. Every experiment below starts from a deep copy of this
# network (see init()), so all runs share the same initial weights.
base_net = ResNet(bit_width=bit_width).to(device)

# Define loss function
criterion = nn.CrossEntropyLoss()

In [None]:
def init(
    project_name,
    opt_name,
    batch_size,
    architecture,
    dataset_name,
    lr,
    alpha=None,
    reg_lambda=None,
):
    """Open a wandb run and return a fresh copy of ``base_net`` with its optimizer.

    NOTE: defining this function rebinds the module-level name ``init``, which
    the import cell had bound to ``torch.nn.init``; after this cell runs that
    alias is no longer reachable under the name ``init``.

    Args:
        project_name: wandb project the run is logged to.
        opt_name: wandb run name.
        batch_size, architecture, dataset_name, lr, alpha, reg_lambda:
            recorded verbatim in the wandb run config; ``lr`` is additionally
            used as the Adam learning rate.

    Returns:
        ``(net, optimizer)``: ``net`` is a deep copy of the global ``base_net``
        moved to the global ``device``; ``optimizer`` is an Adam optimizer with
        two parameter groups tagged ``"weights"`` and ``"bias"``.
    """
    run_config = {
        "batch_size": batch_size,
        "architecture": architecture,
        "dataset": dataset_name,
        "lr": lr,
        "alpha": alpha,
        "reg_lambda": reg_lambda,
        "bit_width": base_net.bit_width,
    }
    # project: where the run is logged; config: hyperparameters and run metadata
    wandb.init(project=project_name, name=opt_name, config=run_config)

    net = copy.deepcopy(base_net)
    net.to(device)

    # A parameter whose name contains any of these substrings goes to the
    # "bias" group; every other parameter goes to the "weights" group.
    bias_markers = ("fc", "left.1", "left.4", "shortcut.1")
    weights, bias = [], []
    for name, p in net.named_parameters():
        if any(marker in name for marker in bias_markers):
            bias.append(p)
        else:
            weights.append(p)

    parameters = [
        {"params": weights, "tag": "weights"},
        {"params": bias, "tag": "bias"},
    ]
    optimizer = optim.Adam(parameters, lr=lr, betas=(0.9, 0.999))
    return net, optimizer

In [None]:
# SGD
# Baseline run: forward / backward / optimizer step with no
# quantization-specific update. "SGD" is only the wandb run name; the
# optimizer handed back by init() is Adam.
net, optimizer = init(
    project_name="CIFAR10_multi",
    opt_name="SGD",
    batch_size=BATCH_SIZE,
    architecture="ResNet-18",
    dataset_name="CIFAR10",
    lr=LR,
)

best_acc = 0
lr_decay_epochs = [20, 40]

# If training starts beyond a decay epoch, apply the halvings that the loop
# below will never reach.
for decay_epoch in lr_decay_epochs:
    if pre_epoch <= decay_epoch:
        continue
    for param_group in optimizer.param_groups:
        param_group["lr"] *= 0.5

# Train
length = len(trainloader)  # number of batches per epoch
for epoch in range(pre_epoch, EPOCH):
    print("\nEpoch: %d" % (epoch + 1))
    net.train()
    sum_loss, correct, total = 0.0, 0.0, 0.0

    # Halve every group's learning rate on entering a decay epoch.
    if epoch in lr_decay_epochs:
        for param_group in optimizer.param_groups:
            param_group["lr"] *= 0.5

    for i, (inputs, labels) in enumerate(trainloader):
        # prepare dataset
        inputs = inputs.to(device)
        labels = labels.to(device)

        # forward & backward
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        # running statistics for the progress line
        sum_loss += loss.item()
        predicted = torch.max(outputs.data, 1)[1]
        total += labels.size(0)
        correct += predicted.eq(labels.data).cpu().sum()

        global_iter = i + 1 + epoch * length
        running_loss = sum_loss / (i + 1)
        running_acc = 100.0 * correct / total
        print(
            "[epoch:%d, iter:%d] Loss: %.03f | Acc: %.3f%% "
            % (epoch + 1, global_iter, running_loss, running_acc)
        )

    print("Waiting Test...")
    with torch.no_grad():
        # train_loss, train_acc = net.evaluate(trainloader, criterion, device, eval=False, qt=False)
        # qtrain_loss, qtrain_acc = net.evaluate(trainloader, criterion, device, eval=False, qt=True)
        test_loss, test_acc = net.evaluate(
            testloader, criterion, device, eval=True, qt=False
        )
        wandb.log({"test_loss": test_loss, "test_accuracy": test_acc})
        print("Test Loss: %.03f | Test Acc: %.3f%% " % (test_loss, test_acc))

        # Overwrite the checkpoint only when the test accuracy improves.
        FILE = "CIFAR10_SGD.pt"
        if test_acc > best_acc:
            checkpoint = {
                "model_state_dict": net.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
            }
            torch.save(checkpoint, FILE)
            best_acc = test_acc

print("Best Test Accuracy: %.3f%%" % best_acc)
wandb.finish()

In [None]:
# ProxQuant
# Each iteration performs a normal optimizer step and then pulls every
# quantized weight towards its nearest integer level (the "prox step").
# The pull strength `epsilon` is 1e-6 until ANNEAL_EPOCH_PQ and afterwards
# grows linearly with the iteration count, by reg_lambda * it / total_it.
net, optimizer = init(
    project_name="CIFAR10_multi",
    opt_name="ProxQuant",
    batch_size=BATCH_SIZE,
    architecture="ResNet-18",
    dataset_name="CIFAR10",
    lr=LR,
    reg_lambda=reg_lambda,
)

lr_decay_epochs = [20, 40]
best_acc = 0
lr = LR
# If training starts beyond a decay epoch, apply the halvings that the loop
# below will never reach.
for decay_epoch in lr_decay_epochs:
    if pre_epoch > decay_epoch:
        for param_group in optimizer.param_groups:
            param_group["lr"] *= 0.5
        # Halve the tracked value once per decay epoch. Doing it inside the
        # param-group loop would shrink it once per group and make it drift
        # away from the learning rate actually stored in the optimizer.
        lr *= 0.5

# Train
length = len(trainloader)  # number of batches per epoch
total_it = (EPOCH - ANNEAL_EPOCH_PQ) * length
# Annealing-step counter, incremented before each use. It is -1 for a run
# that starts before ANNEAL_EPOCH_PQ, and equals the number of annealing
# steps already elapsed (minus one) when starting from a later pre_epoch.
it = max(pre_epoch - ANNEAL_EPOCH_PQ, 0) * length - 1

# Integer levels -2^(b-1), ..., 2^(b-1)-1; loop-invariant, so built once.
half_range = 2 ** (net.bit_width - 1)
rang = torch.arange(-half_range, half_range, device=device)
# Parameters whose name contains one of these substrings skip the prox step.
skip_markers = ("fc", "left.1", "left.4", "shortcut.1")

for epoch in range(pre_epoch, EPOCH):
    print("\nEpoch: %d" % (epoch + 1))
    net.train()
    sum_loss = 0.0
    correct = 0.0
    total = 0.0
    if epoch in lr_decay_epochs:
        for param_group in optimizer.param_groups:
            param_group["lr"] *= 0.5
        lr *= 0.5

    for i, data in enumerate(trainloader, 0):
        if epoch < ANNEAL_EPOCH_PQ:
            epsilon = 0.000001
        else:
            it += 1
            epsilon = 0.000001 + reg_lambda * it / total_it
        # prepare dataset
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)
        # forward & backward
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        sum_loss += loss.item()
        optimizer.step()
        with torch.no_grad():
            for name, param in net.named_parameters():
                if any(marker in name for marker in skip_markers):
                    continue
                # Prox step. Broadcasting the weights against the level grid
                # along a new trailing axis handles tensors of any rank.
                distance = torch.abs(param.data.unsqueeze(-1) - rang)
                a = rang[torch.argmin(distance, dim=-1)]  # nearest level per weight
                param.data = (param.data + epsilon * a) / (1 + epsilon)
                # NOTE(review): the upper bound is half_range + 0.5 although
                # the largest level is half_range - 1; the BinaryConnect cell
                # clamps to half_range - 0.5. Kept as is - confirm intent.
                param.data = torch.clamp(
                    param.data, -half_range - 0.5, half_range + 0.5
                )
        optimizer.zero_grad()
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += predicted.eq(labels.data).cpu().sum()
        print(
            "[Epoch:%d, Iter:%d] Loss: %.03f | Acc: %.3f%% "
            % (
                epoch + 1,
                (i + 1 + (epoch) * length),
                sum_loss / (i + 1),
                100.0 * correct / total,
            )
        )
    print("Waiting Test...")
    with torch.no_grad():
        # train_loss, train_acc = net.evaluate(trainloader, criterion, device, eval=False, qt=False)
        # qtrain_loss, qtrain_acc = net.evaluate(trainloader, criterion, device, eval=False, qt=True)
        test_loss, test_acc = net.evaluate(
            testloader, criterion, device, eval=True, qt=False
        )
        qtest_loss, qtest_acc = net.evaluate(
            testloader, criterion, device, eval=True, qt=True
        )
        wandb.log(
            {
                "test_loss": test_loss,
                "quantized_test_loss": qtest_loss,
                "test_accuracy": test_acc,
                "quantized_test_accuracy": qtest_acc,
            }
        )
        print(
            "Test Loss: %.03f | Test Acc: %.3f%% "
            % (
                test_loss,
                test_acc,
            )
        )
        print(
            "Quantized Test Loss: %.03f | Quantized Test Acc: %.3f%% "
            % (
                qtest_loss,
                qtest_acc,
            )
        )
        # Overwrite the checkpoint only when the quantized accuracy improves.
        FILE = "CIFAR10_ProxQuant.pt"
        if qtest_acc > best_acc:
            torch.save(
                {
                    "model_state_dict": net.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                },
                FILE,
            )
            best_acc = qtest_acc
print("Best Test Accuracy after Quantization: %.3f%%" % best_acc)
wandb.finish()

In [None]:
# ASkewSGD
# Training run in which, between backward() and optimizer.step(), the
# gradient of every parameter in the "weights" group is conditionally
# replaced by -Kx, a term computed from the distance of each weight to its
# two closest integer levels. `epsilon` controls that term and is decayed
# per epoch once ANNEAL_EPOCH_AS is reached.
net, optimizer = init(
    project_name="CIFAR10_multi",
    opt_name="ASkewSGD",
    batch_size=BATCH_SIZE,
    architecture="ResNet-18",
    dataset_name="CIFAR10",
    lr=LR,
    alpha=alpha,
)

best_acc = 0
lr_decay_epochs = [20, 40]

# If training starts beyond a decay epoch, apply the halvings that the loop
# below will never reach.
for decay_epoch in lr_decay_epochs:
    if pre_epoch > decay_epoch:
        for param_group in optimizer.param_groups:
            param_group["lr"] *= 0.5

# Train
for epoch in range(pre_epoch, EPOCH):
    print("\nEpoch: %d" % (epoch + 1))
    net.train()
    sum_loss = 0.0
    correct = 0.0
    total = 0.0

    # epsilon is 1 until ANNEAL_EPOCH_AS, then DECAY_CONST ** (epochs since then).
    if epoch < ANNEAL_EPOCH_AS:
        epsilon = 1
    else:
        epsilon = DECAY_CONST ** (epoch - ANNEAL_EPOCH_AS)

    # Halve every group's learning rate on entering a decay epoch.
    if epoch in lr_decay_epochs:
        for param_group in optimizer.param_groups:
            param_group["lr"] *= 0.5
    for i, data in enumerate(trainloader, 0):

        # prepare dataset
        length = len(trainloader)
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)
        # forward & backward
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        # Gradient rewrite; only the group tagged "weights" by init() is touched.
        for param_group in optimizer.param_groups:
            if param_group["tag"] == "weights":
                for idx, p in enumerate(param_group["params"]):
                    # Clip threshold from the weight statistics:
                    # 12.1 * RMS - 12.2 * mean absolute value.
                    clip = (12.1*torch.sqrt(torch.mean(p.data**2))) - (12.2*torch.mean(p.data.abs()))
                    # Step size; the denominator equals 2**bit_width - 1.
                    scale = 2*clip / (2 ** (net.bit_width - 1) + 2 ** (net.bit_width - 1) - 1)
                    # Work in grid units: divide in place by scale and clamp to the
                    # level range. The matching mul_ is at the end of this loop body,
                    # so the clamp applied here persists in the stored weights.
                    p.data.div_(scale)
                    p.data.clamp_(-2**(net.bit_width-1), 2**(net.bit_width-1)-1)
                    # Integer levels -2^(b-1), ..., 2^(b-1)-1.
                    rang = torch.arange(-2**(net.bit_width-1), 2**(net.bit_width-1)).to(device)
                    # Sort the levels by distance to each weight:
                    # a = closest level, b = second closest level.
                    if len(p.data.shape) == 4:
                        _ , indices = torch.sort(torch.abs(torch.unsqueeze(p.data, len(p.data.size())).repeat(1, 1, 1, 1, len(rang))-rang))
                        a = rang[indices][:, :, :, :, 0]
                        b = rang[indices][:, :, :, :, 1]
                    else:
                        # The repeat/indexing pattern below fits a rank-1 tensor
                        # - presumably the only other rank in this group; TODO confirm.
                        _ , indices = torch.sort(torch.abs(torch.unsqueeze(p.data, len(p.data.size())).repeat(1, len(rang))-rang))
                        a = rang[indices][:, 0]
                        b = rang[indices][:, 1]
                    # constr = epsilon - (p-a)^2 (p-b)^2
                    constr = epsilon-((p.data-a)**2)*((p.data-b)**2)
                    # Kx = scale * alpha * constr / (2 (p-a)(p-b) ((p-a)+(p-b)+1e-6))
                    Kx = scale * alpha * (epsilon-(p.data-a)**2*(p.data-b)**2) / (2 * (p.data-a)*(p.data-b) * (0.000001+(p.data-b)+(p.data-a)))
                    # Entries that keep their original gradient: the weight sits exactly
                    # on a level, or constr >= 0, or (-grad) * Kx > Kx**2.
                    # The mask is evaluated with Kx before it is clamped on the next line.
                    direct_grad = torch.logical_or(torch.logical_or((p.data-a)*(p.data-b)==0, constr >= 0), (-p.grad.data)*Kx > Kx**2)
                    # Bound Kx to +/- scale / (4 * lr) of this parameter group.
                    Kx.clamp_(-scale/(4*param_group['lr']), scale/(4*param_group['lr']))
                    # Self-assignment: leaves the masked entries unchanged.
                    p.grad.data[direct_grad] = p.grad.data[direct_grad]
                    # Everywhere else the gradient is replaced by -Kx.
                    p.grad.data[~direct_grad] = -Kx[~direct_grad]
                    # Undo the division by scale performed above.
                    p.data.mul_(scale)
        optimizer.step()
        optimizer.zero_grad()
        with torch.no_grad():
            # After the step, clamp every parameter whose name contains none of
            # the listed substrings to [-1.2, 1.2].
            for name, param in net.named_parameters():
                if (
                    "fc" not in name
                    and "left.1" not in name
                    and "left.4" not in name
                    and "shortcut.1" not in name
                ):
                    torch.clamp_(param.data, -1.2, 1.2)
        # running statistics for the progress line
        sum_loss += loss.item()
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += predicted.eq(labels.data).cpu().sum()
        print(
            "[epoch:%d, iter:%d] Loss: %.03f | Acc: %.3f%% "
            % (
                epoch + 1,
                (i + 1 + (epoch) * length),
                sum_loss / (i + 1),
                100.0 * correct / total,
            )
        )

    print("Waiting Test...")
    with torch.no_grad():
        # train_loss, train_acc = net.evaluate(trainloader, criterion, device, eval=False, qt=False)
        # qtrain_loss, qtrain_acc = net.evaluate(trainloader, criterion, device, eval=False, qt=True)
        test_loss, test_acc = net.evaluate(
            testloader, criterion, device, eval=True, qt=False
        )
        qtest_loss, qtest_acc = net.evaluate(
            testloader, criterion, device, eval=True, qt=True
        )
        wandb.log(
            {
                "test_loss": test_loss,
                "quantized_test_loss": qtest_loss,
                "test_accuracy": test_acc,
                "quantized_test_accuracy": qtest_acc,
            }
        )
        print(
            "Test Loss: %.03f | Test Acc: %.3f%% "
            % (
                test_loss,
                test_acc,
            )
        )
        print(
            "Quantized Test Loss: %.03f | Quantized Test Acc: %.3f%% "
            % (
                qtest_loss,
                qtest_acc,
            )
        )
        # Overwrite the checkpoint only when the quantized accuracy improves.
        FILE = "CIFAR10_ASkewSGD.pt"
        if qtest_acc > best_acc:
            torch.save(
                {
                    "model_state_dict": net.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                },
                FILE,
            )
            best_acc = qtest_acc
print("Best Test Accuracy after Quantization: %.3f%%" % best_acc)
wandb.finish()

In [None]:
# Deterministic BinaryConnect
# Two copies of the model are kept: `net`, whose selected parameters are
# replaced by net.quantize(...) before every forward/backward pass, and
# `model_copy`, which holds the un-quantized values between iterations.
# After each optimizer step the change measured on `net` (delta) is added to
# the values stored in `model_copy`, and the clamped result is written to both.
net, optimizer = init(
    project_name="CIFAR10_multi",
    opt_name="Deterministic BinaryConnect",
    batch_size=BATCH_SIZE,
    architecture="ResNet-18",
    dataset_name="CIFAR10",
    lr=LR,
)

best_acc = 0
# Snapshot of the freshly initialised network; updated at the end of each iteration.
model_copy = copy.deepcopy(net)


lr_decay_epochs = [20, 40]

# If training starts beyond a decay epoch, apply the halvings that the loop
# below will never reach.
for decay_epoch in lr_decay_epochs:
    if pre_epoch > decay_epoch:
        for param_group in optimizer.param_groups:
            param_group["lr"] *= 0.5

# Train
for epoch in range(pre_epoch, EPOCH):
    print("\nEpoch: %d" % (epoch + 1))
    net.train()
    sum_loss = 0.0
    correct = 0.0
    total = 0.0

    # Halve every group's learning rate on entering a decay epoch.
    if epoch in lr_decay_epochs:
        for param_group in optimizer.param_groups:
            param_group["lr"] *= 0.5

    for i, data in enumerate(trainloader, 0):
        # prepare dataset
        length = len(trainloader)
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)
        # forward & backward
        with torch.no_grad():
            # Forward pass through model_copy; `outputs` feeds the training
            # accuracy printed below. loss2 is not read again in this cell.
            outputs = model_copy(inputs)
            loss2 = criterion(outputs, labels)
            # Replace the selected parameters of `net` by their quantized values
            # (names containing any of the listed substrings are left alone).
            for net_name, net_param in net.named_parameters():
                if (
                    "fc" not in net_name
                    and "left.1" not in net_name
                    and "left.4" not in net_name
                    and "shortcut.1" not in net_name
                ):
                    net_param.data = net.quantize(net_param.data, net.bit_width)

        # Gradient and optimizer step are taken on the quantized `net`.
        outputs2 = net(inputs)
        loss = criterion(outputs2, labels)
        loss.backward()
        optimizer.step()
        # Transfer the step to the stored values. Relies on both models
        # yielding their parameters in the same order (model_copy is a deep
        # copy of net).
        for (net_name, net_param), (model_copy_name, model_copy_param) in zip(
            net.named_parameters(), model_copy.named_parameters()
        ):
            if (
                "fc" not in net_name
                and "left.1" not in net_name
                and "left.4" not in net_name
                and "shortcut.1" not in net_name
            ):
                # delta = (net value after the step) - (quantized stored value).
                delta = net_param.data - model_copy.quantize(model_copy_param.data, model_copy.bit_width)
                # net receives stored value + delta, clamped to [-1, 1] for 1 bit,
                # otherwise to [-2^(b-1) - 0.5, 2^(b-1) - 0.5].
                if net.bit_width == 1:
                    net_param.data = torch.clamp(model_copy_param.data + delta, -1, 1)
                else:
                    net_param.data = torch.clamp(model_copy_param.data + delta, -(2 ** (model_copy.bit_width - 1))-0.5, (2 ** (model_copy.bit_width - 1))-0.5)
                # The same value is computed a second time and stored in model_copy,
                # so the two models get separate result tensors.
                if net.bit_width == 1:
                    model_copy_param.data = torch.clamp(model_copy_param.data + delta, -1, 1)
                else:
                    model_copy_param.data = torch.clamp(model_copy_param.data + delta, -(2 ** (model_copy.bit_width - 1))-0.5, (2 ** (model_copy.bit_width - 1))-0.5)
            else:
                # Remaining parameters are taken over from net as they are.
                # No clone is made - presumably both models share this tensor's
                # storage afterwards; TODO confirm this is intended.
                model_copy_param.data = net_param.data        
        sum_loss += loss.item()
        optimizer.zero_grad()
        # NOTE(review): accuracy is computed from `outputs` (model_copy forward
        # pass) while sum_loss accumulates `loss` from `net` (outputs2), so the
        # printed Loss and Acc come from different forward passes - confirm intent.
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += predicted.eq(labels.data).cpu().sum()
        print(
            "[epoch:%d, iter:%d] Loss: %.03f | Acc: %.3f%% "
            % (
                epoch + 1,
                (i + 1 + (epoch) * length),
                sum_loss / (i + 1),
                100.0 * correct / total,
            )
        )

    print("Waiting Test...")
    with torch.no_grad():
        # train_loss, train_acc = net.evaluate(trainloader, criterion, device, eval=False, qt=False)
        # qtrain_loss, qtrain_acc = net.evaluate(trainloader, criterion, device, eval=False, qt=True)
        test_loss, test_acc = net.evaluate(
            testloader, criterion, device, eval=True, qt=False
        )
        qtest_loss, qtest_acc = net.evaluate(
            testloader, criterion, device, eval=True, qt=True
        )
        wandb.log(
            {
                "test_loss": test_loss,
                "quantized_test_loss": qtest_loss,
                "test_accuracy": test_acc,
                "quantized_test_accuracy": qtest_acc,
            }
        )
        print(
            "Test Loss: %.03f | Test Acc: %.3f%% "
            % (
                test_loss,
                test_acc,
            ) 
        )
        print(
            "Quantized Test Loss: %.03f | Quantized Test Acc: %.3f%% "
            % (
                qtest_loss,
                qtest_acc,
            )
        )
        # Overwrite the checkpoint only when the quantized accuracy improves.
        FILE = "CIFAR10_Deterministic_BinaryConnect.pt"
        if qtest_acc > best_acc:
            torch.save(
                {
                    "model_state_dict": net.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                },
                FILE,
            )
            best_acc = qtest_acc
print("Best Test Accuracy after Quantization: %.3f%%" % best_acc)
wandb.finish()

In [None]:
# FILE = 'resnet18_original.pt'

# checkpoint = torch.load(FILE)
# net.load_state_dict(checkpoint['model_state_dict'])
# optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

# torch.save({'model_state_dict': net.state_dict(), 'optimizer_state_dict': optimizer.state_dict()}, FILE)