In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import copy
import wandb
import math

In [None]:
class ResidualBlock(nn.Module):
    """Basic two-convolution residual block with parameter-free BatchNorm.

    Main path (``self.left``): 3x3 conv (stride ``stride``) -> BN -> ReLU ->
    3x3 conv -> BN. Shortcut (``self.shortcut``): identity, or a 1x1 conv
    (stride ``stride``) -> BN whenever the stride is not 1 or the channel
    count changes. Both paths are summed and passed through a ReLU.

    Args:
        inchannel: number of input channels.
        outchannel: number of output channels.
        stride: stride of the first 3x3 conv and of the projection shortcut.
    """

    def __init__(self, inchannel, outchannel, stride=1):
        super().__init__()
        # Every conv is bias-free and every BatchNorm has affine=False, so the
        # conv kernels are the only trainable tensors in the block.
        main_path = [
            nn.Conv2d(
                inchannel,
                outchannel,
                kernel_size=3,
                stride=stride,
                padding=1,
                bias=False,
            ),
            nn.BatchNorm2d(outchannel, affine=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(
                outchannel, outchannel, kernel_size=3, stride=1, padding=1, bias=False
            ),
            nn.BatchNorm2d(outchannel, affine=False),
        ]
        self.left = nn.Sequential(*main_path)

        needs_projection = stride != 1 or inchannel != outchannel
        if needs_projection:
            self.shortcut = nn.Sequential(
                nn.Conv2d(
                    inchannel, outchannel, kernel_size=1, stride=stride, bias=False
                ),
                nn.BatchNorm2d(outchannel, affine=False),
            )
        else:
            # An empty Sequential returns its input unchanged (identity).
            self.shortcut = nn.Sequential()

    def forward(self, x):
        main = self.left(x)
        return F.relu(main + self.shortcut(x))


class ResNet(nn.Module):
    """ResNet backbone built from a pluggable block type.

    Stem: 3x3 conv (3 -> 64) -> BN (affine=False) -> ReLU. Then four stages
    (``layer1``..``layer4``) of two blocks each, with widths 64/128/256/512;
    the first block of stages 2-4 is built with stride 2. The final feature
    map is average-pooled with a 4x4 window, flattened and fed to a bias-free
    ``nn.Linear(512, num_classes)`` — so the pooled map must be 1x1, which
    presumably means 32x32 inputs (e.g. CIFAR-10); confirm for other sizes.

    Args:
        ResidualBlock: block factory called as ``block(in_ch, out_ch, stride)``.
        num_classes: size of the classifier output.
    """

    def __init__(self, ResidualBlock, num_classes=10):
        super().__init__()
        # Running input width; make_layer advances it stage by stage.
        self.inchannel = 64
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(64, affine=False),
            nn.ReLU(),
        )
        self.layer1 = self.make_layer(ResidualBlock, 64, 2, stride=1)
        self.layer2 = self.make_layer(ResidualBlock, 128, 2, stride=2)
        self.layer3 = self.make_layer(ResidualBlock, 256, 2, stride=2)
        self.layer4 = self.make_layer(ResidualBlock, 512, 2, stride=2)
        self.fc = nn.Linear(512, num_classes, bias=False)

    def make_layer(self, block, channels, num_blocks, stride):
        """Build one stage: the first block uses ``stride``, the rest stride 1.

        Side effect: sets ``self.inchannel`` to ``channels`` so the next stage
        starts from this stage's width. Note that ``num_blocks <= 0`` still
        yields a single block.
        """
        block_strides = (stride,) + (1,) * (num_blocks - 1)
        stage = []
        for block_stride in block_strides:
            stage.append(block(self.inchannel, channels, block_stride))
            self.inchannel = channels
        return nn.Sequential(*stage)

    def forward(self, x):
        features = self.conv1(x)
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            features = stage(features)
        pooled = F.avg_pool2d(features, 4)
        flat = pooled.view(pooled.size(0), -1)
        return self.fc(flat)


class QuantizeLayer(nn.Module):
    """Wrap a layer so its forward pass uses sign-binarised parameters.

    For ``nn.Conv2d`` and ``nn.Linear`` the forward computation uses
    ``torch.sign`` of the wrapped layer's weight (and of its bias, when it
    has one); the stored parameters themselves are not modified. Any other
    layer type is simply called unchanged.
    """

    def __init__(self, layer):
        super().__init__()
        self.layer = layer

    def forward(self, x):
        layer = self.layer
        is_conv = isinstance(layer, nn.Conv2d)
        if not (is_conv or isinstance(layer, nn.Linear)):
            return layer(x)

        weight_q = torch.sign(layer.weight)
        bias_q = None if layer.bias is None else torch.sign(layer.bias)
        if is_conv:
            return nn.functional.conv2d(
                x,
                weight_q,
                bias_q,
                layer.stride,
                layer.padding,
                layer.dilation,
                layer.groups,
            )
        return nn.functional.linear(x, weight_q, bias_q)

In [None]:
def ResNet18():
    """Build the ResNet-18 used in this notebook.

    Four stages of two ``ResidualBlock``s each, with a 10-way classifier.
    """
    model = ResNet(ResidualBlock, num_classes=10)
    return model

In [None]:
# Use the ResNet18 on CIFAR-10
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import argparse
import os

# Check GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed torch's global RNG so weight initialisation, DataLoader shuffling and
# the random augmentations below are reproducible from a fresh kernel.
SEED = 0
torch.manual_seed(SEED)

# Set hyperparameters
EPOCH = 150  # total number of training epochs per run
ANNEAL_EPOCH = 90  # epoch from which epsilon is annealed in the training cells
pre_epoch = 0  # first epoch to run; > 0 when resuming from a checkpoint
BATCH_SIZE = 100  # training batch size
LR = 0.06  # initial learning rate (halved at the decay epochs of each run)
alpha = 0.2  # weight of the quantisation constraint / proximal term

# Normalisation constants shared by the train and test pipelines.
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2023, 0.1994, 0.2010)

# prepare dataset and preprocessing
transform_train = transforms.Compose(
    [
        transforms.Resize(40),
        # Square random crop covering 64%-100% of the 40x40 image, resized to 32x32.
        torchvision.transforms.RandomResizedCrop(
            32, scale=(0.64, 1.0), ratio=(1.0, 1.0)
        ),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.RandomErasing(),
        transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
    ]
)

transform_test = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
    ]
)

trainset = torchvision.datasets.CIFAR10(
    root="../data", train=True, download=True, transform=transform_train
)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=BATCH_SIZE, shuffle=True, num_workers=3
)

testset = torchvision.datasets.CIFAR10(
    root="../data", train=False, download=True, transform=transform_test
)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=100, shuffle=False, num_workers=3
)

# Labels in CIFAR10
classes = (
    "plane",
    "car",
    "bird",
    "cat",
    "deer",
    "dog",
    "frog",
    "horse",
    "ship",
    "truck",
)

# Define ResNet18; every training run below starts from a deep copy of it.
base_net = ResNet18().to(device)

# Define loss function
criterion = nn.CrossEntropyLoss()


In [None]:
# FILE = 'resnet18_original.pt'

# checkpoint = torch.load(FILE)
# net.load_state_dict(checkpoint['model_state_dict'])
# optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

# torch.save({'model_state_dict': net.state_dict(), 'optimizer_state_dict': optimizer.state_dict()}, FILE)

In [None]:
# Train with ASkewSGD: Adam updates the full-precision weights, but each
# gradient entry is first replaced ("skewed") wherever following it would not
# restore the constraint (w^2 - 1)^2 <= epsilon fast enough.
net = copy.deepcopy(base_net)
optimizer = optim.Adam(net.parameters(), lr=LR, betas=(0.9, 0.999))
wandb.init(
    # set the wandb project where this run will be logged
    project="exp3",
    name="ASkewSGD",
    # track hyperparameters and run metadata
    config={
        "batch_size": BATCH_SIZE,
        "architecture": "ResNet-18",
        "dataset": "CIFAR10",
        "lr": LR,
        "alpha": alpha,
    },
)
lr_decay_epochs = [20, 40, 60, 95]
# When resuming from pre_epoch > 0, replay the LR halvings already passed.
for decay_epoch in lr_decay_epochs:
    if pre_epoch > decay_epoch:
        for param_group in optimizer.param_groups:
            param_group["lr"] *= 0.5

length = len(trainloader)  # iterations per epoch, used for the global step
for epoch in range(pre_epoch, EPOCH):
    print("\nEpoch: %d" % (epoch + 1))
    net.train()
    sum_loss = 0.0
    correct = 0.0
    total = 0.0
    correct2 = 0.0
    total2 = 0.0
    # Constraint slack: 1 until ANNEAL_EPOCH, then decays geometrically.
    if epoch < ANNEAL_EPOCH:
        epsilon = 1
    else:
        epsilon = 0.88 ** (epoch - ANNEAL_EPOCH)

    if epoch in lr_decay_epochs:
        for param_group in optimizer.param_groups:
            param_group["lr"] *= 0.5

    for i, (inputs, labels) in enumerate(trainloader):
        inputs, labels = inputs.to(device), labels.to(device)
        # forward & backward on the full-precision network
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        # Monitor the sign-binarised network on the same batch (no gradients).
        model_copy = copy.deepcopy(net)
        with torch.no_grad():
            for name, param in model_copy.named_parameters():
                if not name.endswith(".bias"):
                    param.data = torch.sign(param.data)
            outputs2 = model_copy(inputs)
            loss2 = criterion(outputs2, labels)

        for param_group in optimizer.param_groups:
            for p in param_group["params"]:
                constr = epsilon - (p.data**2 - 1) ** 2
                Kx = -4 * (p.data**2 - 1) * p.data  # d(constr)/dw
                # Keep the raw gradient where the constraint holds, where Kx
                # vanishes, or where stepping along -grad raises constr at a
                # rate of at least alpha * |constr|.
                direct_grad = torch.logical_or(
                    torch.logical_or(constr >= 0, Kx == 0),
                    torch.logical_and(
                        constr < 0, (-Kx * p.grad.data) >= -alpha * constr
                    ),
                )
                # Everywhere else use alpha * constr / Kx clipped to +-2/lr.
                # Entries with Kx == 0 give inf/nan here, but they are always
                # in direct_grad, so they are never written.
                p.grad.data[~direct_grad] = torch.clip(
                    alpha * constr / Kx,
                    -2 / param_group["lr"],
                    2 / param_group["lr"],
                )[~direct_grad]
        optimizer.step()
        optimizer.zero_grad()
        if epoch >= ANNEAL_EPOCH:
            # During annealing keep every weight inside [-1, 1].
            with torch.no_grad():
                for param in net.parameters():
                    torch.clamp_(param.data, -1, 1)
        sum_loss += loss.item()
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += predicted.eq(labels.data).sum().item()
        _, predicted = torch.max(outputs2.data, 1)
        total2 += labels.size(0)
        correct2 += predicted.eq(labels.data).sum().item()
        print(
            "[epoch:%d, iter:%d] Loss: %.03f | Acc: %.3f%% "
            % (
                epoch + 1,
                (i + 1 + epoch * length),
                sum_loss / (i + 1),
                100.0 * correct / total,
            )
        )
        # Log plain Python numbers rather than (grad-tracking) tensors.
        wandb.log(
            {
                "loss": loss.item(),
                "quantized loss": loss2.item(),
                "accuracy": 100.0 * correct / total,
                "quantized_accuracy": 100.0 * correct2 / total2,
            }
        )
    print("Waiting Test...")
    net.eval()
    with torch.no_grad():
        correct = 0
        total = 0
        for images, labels in testloader:
            images, labels = images.to(device), labels.to(device)
            outputs = net(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        print(
            "Test's accuracy (before quantization) is: %.3f%%" % (100 * correct / total)
        )
    # Binarise exactly the parameters binarised during training (everything
    # except biases) and evaluate that network.
    model_copy = copy.deepcopy(net)
    model_copy.eval()
    with torch.no_grad():
        for name, param in model_copy.named_parameters():
            if not name.endswith(".bias"):
                param.data = torch.sign(param.data)
        correct = 0
        total = 0
        for images, labels in testloader:
            images, labels = images.to(device), labels.to(device)
            outputs = model_copy(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        print(
            "Test's accuracy (after quantization) is: %.3f%%" % (100 * correct / total)
        )
    # if epoch + 1 == ANNEAL_EPOCH:
    #     FILE = 'resnet18_original_new.pt'
    #     torch.save({'model_state_dict': net.state_dict(), 'optimizer_state_dict': optimizer.state_dict()}, FILE)

In [None]:
# FILE = "resnet18_qt_new.pt"

# torch.save(
#     {
#         "model_state_dict": net.state_dict(),
#         "optimizer_state_dict": optimizer.state_dict(),
#     },
#     FILE,
# )

In [None]:
# FILE = "resnet18_original_new.pt"
# checkpoint = torch.load(FILE)
# net.load_state_dict(checkpoint["model_state_dict"])
# with torch.no_grad():
#     correct = 0
#     total = 0
#     for data in testloader:
#         net.eval()
#         images, labels = data
#         images, labels = images.to(device), labels.to(device)
#         outputs = net(images)
#         _, predicted = torch.max(outputs.data, 1)
#         total += labels.size(0)
#         correct += (predicted == labels).sum()
#     print("Test's ac is: %.3f%%" % (100 * correct / total))

# FILE = "resnet18_qt_new.pt"
# checkpoint = torch.load(FILE)
# net.load_state_dict(checkpoint["model_state_dict"])
# model_copy = copy.deepcopy(net)
# with torch.no_grad():
#     for name, param in model_copy.named_parameters():
#         param.data = torch.sign(param.data)
#     correct = 0
#     total = 0
#     for data in testloader:
#         model_copy.eval()
#         images, labels = data
#         images, labels = images.to(device), labels.to(device)
#         outputs = model_copy(images)
#         _, predicted = torch.max(outputs.data, 1)
#         total += labels.size(0)
#         correct += (predicted == labels).sum()
#     print("Test's ac is: %.3f%%" % (100 * correct / total))

In [None]:
# Close the ASkewSGD wandb run so the next training cell can start its own run.
wandb.finish()

In [None]:
# Train with deterministic BinaryConnect. `latent_net` holds the real-valued
# weights (kept in [-1, 1]); `net` is run with their signs, and the SGD step
# computed on `net` is transferred back onto the latent weights.
net = copy.deepcopy(base_net)
latent_net = copy.deepcopy(net)

optimizer = optim.SGD(net.parameters(), lr=LR)
wandb.init(
    # set the wandb project where this run will be logged
    project="exp3",
    name="Deterministic BinaryConnect",
    # track hyperparameters and run metadata
    config={
        "batch_size": BATCH_SIZE,
        "architecture": "ResNet-18",
        "dataset": "CIFAR10",
        "lr": LR,
        "alpha": 0,
    },
)
lr_decay_epochs = [20, 40, 60, 95]
# When resuming from pre_epoch > 0, replay the LR halvings already passed.
for decay_epoch in lr_decay_epochs:
    if pre_epoch > decay_epoch:
        for param_group in optimizer.param_groups:
            param_group["lr"] *= 0.5

length = len(trainloader)  # iterations per epoch, used for the global step
for epoch in range(pre_epoch, EPOCH):
    print("\nEpoch: %d" % (epoch + 1))
    net.train()
    latent_net.train()
    sum_loss = 0.0
    correct = 0.0
    total = 0.0
    correct2 = 0.0
    total2 = 0.0

    if epoch in lr_decay_epochs:
        for param_group in optimizer.param_groups:
            param_group["lr"] *= 0.5

    for i, (inputs, labels) in enumerate(trainloader):
        inputs, labels = inputs.to(device), labels.to(device)
        with torch.no_grad():
            # Monitoring only: loss / predictions of the latent weights.
            latent_outputs = latent_net(inputs)
            latent_loss = criterion(latent_outputs, labels)
            # Load sign(latent weights) into the network that gets trained.
            for net_param, latent_param in zip(
                net.parameters(), latent_net.parameters()
            ):
                net_param.data = torch.sign(latent_param.data)

        # forward & backward through the binarised network
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        with torch.no_grad():
            for net_param, latent_param in zip(
                net.parameters(), latent_net.parameters()
            ):
                # SGD moved the binary weights by `delta`; apply the same move
                # to the latent weights, clip them to [-1, 1] and mirror the
                # result into `net`.
                delta = net_param.data - torch.sign(latent_param.data)
                latent_param.data = torch.clamp(latent_param.data + delta, -1, 1)
                net_param.data = latent_param.data.clone()
        sum_loss += loss.item()
        _, predicted = torch.max(latent_outputs.data, 1)
        total += labels.size(0)
        correct += predicted.eq(labels.data).sum().item()
        _, predicted = torch.max(outputs.data, 1)
        total2 += labels.size(0)
        correct2 += predicted.eq(labels.data).sum().item()
        print(
            "[epoch:%d, iter:%d] Loss: %.03f | Acc: %.3f%% "
            % (
                epoch + 1,
                (i + 1 + epoch * length),
                sum_loss / (i + 1),
                100.0 * correct2 / total2,
            )
        )
        # Same key convention as the other runs: unprefixed keys describe the
        # real-valued (latent) weights, "quantized" keys the binarised net.
        wandb.log(
            {
                "loss": latent_loss.item(),
                "quantized loss": loss.item(),
                "accuracy": 100.0 * correct / total,
                "quantized_accuracy": 100.0 * correct2 / total2,
            }
        )
    print("Waiting Test...")
    # `net` currently holds the latent weights (mirrored after every step).
    net.eval()
    with torch.no_grad():
        correct = 0
        total = 0
        for images, labels in testloader:
            images, labels = images.to(device), labels.to(device)
            outputs = net(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        print(
            "Test's accuracy (before quantization) is: %.3f%%" % (100 * correct / total)
        )
    # Evaluate sign(latent weights) in a separate copy so that `latent_net`
    # keeps its real-valued weights for the next epoch.
    quantized_net = copy.deepcopy(net)
    quantized_net.eval()
    with torch.no_grad():
        for param in quantized_net.parameters():
            param.data = torch.sign(param.data)
        correct = 0
        total = 0
        for images, labels in testloader:
            images, labels = images.to(device), labels.to(device)
            outputs = quantized_net(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        print(
            "Test's accuracy (after quantization) is: %.3f%%" % (100 * correct / total)
        )

wandb.finish()

In [None]:
# Train with ProxQuant: an Adam step on the full-precision weights, followed
# by a proximal step that pulls every non-bias weight of `net` towards sign(w).
net = copy.deepcopy(base_net)

optimizer = optim.Adam(net.parameters(), lr=LR, betas=(0.9, 0.999))
wandb.init(
    # set the wandb project where this run will be logged
    project="exp3",
    name="ProxQuant",
    # track hyperparameters and run metadata
    config={
        "batch_size": BATCH_SIZE,
        "architecture": "ResNet-18",
        "dataset": "CIFAR10",
        "lr": LR,
        "alpha": alpha,
    },
)
lr_decay_epochs = [20, 40, 60, 95]
# When resuming from pre_epoch > 0, replay the LR halvings already passed.
for decay_epoch in lr_decay_epochs:
    if pre_epoch > decay_epoch:
        for param_group in optimizer.param_groups:
            param_group["lr"] *= 0.5

length = len(trainloader)  # iterations per epoch, used for the global step
for epoch in range(pre_epoch, EPOCH):
    print("\nEpoch: %d" % (epoch + 1))
    net.train()
    sum_loss = 0.0
    correct = 0.0
    total = 0.0
    correct2 = 0.0
    total2 = 0.0
    # epsilon is defined in this cell (mirroring the ASkewSGD schedule) so the
    # cell does not depend on leftover kernel state. The proximal strength
    # alpha / epsilon stays constant until ANNEAL_EPOCH and then grows.
    if epoch < ANNEAL_EPOCH:
        epsilon = 1
    else:
        epsilon = 0.88 ** (epoch - ANNEAL_EPOCH)
    # Each proximal step divides the distance to sign(w) by this factor.
    prox_scale = 2 * alpha / epsilon + 1

    if epoch in lr_decay_epochs:
        for param_group in optimizer.param_groups:
            param_group["lr"] *= 0.5

    for i, (inputs, labels) in enumerate(trainloader):
        inputs, labels = inputs.to(device), labels.to(device)
        # forward & backward on the full-precision network
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        # Monitor the sign-binarised network on the same batch (no gradients).
        model_copy = copy.deepcopy(net)
        with torch.no_grad():
            for name, param in model_copy.named_parameters():
                if not name.endswith(".bias"):
                    param.data = torch.sign(param.data)
            outputs2 = model_copy(inputs)
            loss2 = criterion(outputs2, labels)

        optimizer.step()
        optimizer.zero_grad()
        # Proximal step on `net` itself: the prox of
        # (alpha / epsilon) * min((w - 1)^2, (w + 1)^2) moves each weight
        # towards sign(w), i.e. w -> s + (w - s) / prox_scale with s = +-1.
        with torch.no_grad():
            for name, param in net.named_parameters():
                if not name.endswith(".bias"):
                    positive = param.data > 0
                    param.data[positive] = 1 + (param.data[positive] - 1) / prox_scale
                    param.data[~positive] = (
                        -1 + (param.data[~positive] + 1) / prox_scale
                    )
        sum_loss += loss.item()
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += predicted.eq(labels.data).sum().item()
        _, predicted = torch.max(outputs2.data, 1)
        total2 += labels.size(0)
        correct2 += predicted.eq(labels.data).sum().item()
        # Log plain Python numbers rather than (grad-tracking) tensors.
        wandb.log(
            {
                "quantized loss": loss2.item(),
                "loss": loss.item(),
                "accuracy": 100.0 * correct / total,
                "quantized_accuracy": 100.0 * correct2 / total2,
            }
        )
        print(
            "[epoch:%d, iter:%d] Loss: %.03f | Acc: %.3f%% "
            % (
                epoch + 1,
                (i + 1 + epoch * length),
                sum_loss / (i + 1),
                100.0 * correct / total,
            )
        )
    print("Waiting Test...")
    net.eval()
    with torch.no_grad():
        correct = 0
        total = 0
        for images, labels in testloader:
            images, labels = images.to(device), labels.to(device)
            outputs = net(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        print(
            "Test's accuracy (before quantization) is: %.3f%%" % (100 * correct / total)
        )
    # Binarise exactly the parameters binarised during training (everything
    # except biases) and evaluate that network.
    quantized_net = copy.deepcopy(net)
    quantized_net.eval()
    with torch.no_grad():
        for name, param in quantized_net.named_parameters():
            if not name.endswith(".bias"):
                param.data = torch.sign(param.data)
        correct = 0
        total = 0
        for images, labels in testloader:
            images, labels = images.to(device), labels.to(device)
            outputs = quantized_net(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        print(
            "Test's accuracy (after quantization) is: %.3f%%" % (100 * correct / total)
        )

wandb.finish()