In this notebook, we truely implement the Forward-mode autodiff instead of only use the simulate version

In [4]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
from torchvision.datasets import FashionMNIST, MNIST, CIFAR10, SVHN
from torch.autograd import Variable
from torch import autograd
import torch.autograd.forward_ad as fwAD
import torchvision
from torchvision import transforms
import torchvision.utils as vision_utils
import matplotlib.pyplot as plt
import random
import os
import time
import math


DEVICE = torch.device('cuda')
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

In [242]:
def switch_to_device(dataset,device=None):
    final_X, final_Y = [], []
    for x, y in dataset:
        final_X.append(x)
        final_Y.append(y)
    X = torch.stack(final_X)
    Y = torch.tensor(final_Y)
    if device is not None:
        X = X.to(device)
        Y = Y.to(device)
    return torch.utils.data.TensorDataset(X, Y)


def get_mnist_dl(batch_size_train=1024, batch_size_eval=1024, device=torch.device('cuda')):
    transform = transforms.Compose([transforms.ToTensor()])
    
    data_train = MNIST('./datasets', train=True, download=True, transform=transform)
    data_train = switch_to_device(data_train, device=device)
    data_train, data_valid = torch.utils.data.random_split(data_train, [55000,5000])
    
    data_test = MNIST('./datasets', train=False, download=True, transform=transform)
    data_test = switch_to_device(data_test, device=device)
    
    train_dl = DataLoader(data_train, batch_size=batch_size_train, shuffle=True)
    valid_dl = DataLoader(data_valid, batch_size=batch_size_eval, shuffle=False)
    test_dl = DataLoader(data_test, batch_size=batch_size_eval, shuffle=False)
    
    return train_dl, valid_dl, test_dl

In [243]:
def print_stats(stats):
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize = (6, 3), dpi = 110)
    ax1.grid()
    ax2.grid()

    ax1.set_title("ERM loss")
    ax2.set_title("Valid Acc")

    ax1.set_xlabel("iterations")
    ax2.set_xlabel("iterations")

    itrs = [x[0] for x in stats['train-loss']]
    loss = [x[1].cpu().detach().numpy() for x in stats['train-loss']]
    ax1.plot(itrs, loss)

    itrs = [x[0] for x in stats['valid-acc']]
    acc = [x[1] for x in stats['valid-acc']]
    ax2.plot(itrs, acc)


    ax1.set_ylim(0.0, np.max(loss))
    ax2.set_ylim(0.0, 1.05)


In [244]:
@torch.no_grad()
def get_acc(model, dl, num_dir, device = DEVICE):
    acc = []

    for X, y in dl:
        #one_hot_y = torch.zeros((X.shape[0], 10), device = DEVICE)
        #one_hot_y[[i for i in range(X.shape[0])], [k.item() for k in y]] = 1
        model.forward(X, y, num_dir)
        acc.append(torch.argmax(model.output, dim=1) == y)

    acc = torch.cat(acc)
    acc = torch.sum(acc)/len(acc)

    return acc.item()

In [245]:
#这里现在用的还是最简单的binarize的方法
def Binarize(x, quant_mode = 'det'):
    if quant_mode == 'det':
        return x.sign()
    else:
        return x.add_(1).div_(2).add_(torch.rand(x.size()).add(-0.5)).clamp_(0,1).round().mul_(2).add_(-1)

In [246]:
class self_BinarizeLinear(nn.Linear):

    def __init__(self, *kargs, **kwargs):
        super(self_BinarizeLinear, self).__init__(*kargs, **kwargs)
    
    def forward(self, num_dir, input, da = None):
        '''
        input: batchsize * input_size
        da: None or num_dir * batchsize * input_size
        '''

        if input.size(1) != 784:
            input.data = Binarize(input.data)
        
        #if torch.is_tensor(da):
            #da = da.sign()
        
        self.vector_w = torch.randn((num_dir, self.weight.data.shape[0], self.weight.data.shape[1]), device = DEVICE).sign()
        self.weight.data = Binarize(self.weight.data)
        out = nn.functional.linear(input, self.weight)
        if torch.is_tensor(da): 
            new_da = torch.matmul(da, self.weight.T) + torch.matmul(input, self.vector_w.transpose(1,2))
        else:
            new_da = torch.matmul(input, self.vector_w.transpose(1, 2))

        if not self.bias is None:
            self.vector_b = torch.randn((num_dir, 1, self.bias.shape[0]), device = DEVICE).sign()
            self.bias.data = Binarize(self.bias.data)
            out += self.bias.view(1,-1).expand_as(out)
            new_da += self.vector_b
        
        return out, new_da
    
    def update(self, da, lr):
        gw = da*self.vector_w
        gw = torch.mean(gw, dim = 0)
        self.weight.data -= lr*gw
        self.weight.data = self.weight.data.sign()
        if not self.bias is None:
            gb = da*self.vector_b
            gb = torch.mean(gb, dim = 0)
            self.bias.data -= (lr*gb).view(-1)
            self.bias.data = self.bias.data.sign()


In [247]:
class self_Hardtanh(nn.Hardtanh):
    def __init__(self, *kargs, **kwargs):
        super(self_Hardtanh, self).__init__(*kargs, **kwargs)
    
    def forward(self, input, da):
        new_da = da.clone()
        new_da[input.expand_as(da)>1] = 0
        new_da[input.expand_as(da)<-1] = 0
        out = input.clamp(-1, 1)
        return out, new_da

In [248]:
class self_batchnorm1d():
    def __init__(self, dim):
        self.dim = dim
        self.BN = nn.BatchNorm1d(self.dim).to(DEVICE)
    
    def forward(self, input, da):
        num_dir = da.shape[0]
        self.vector_w = torch.randn((num_dir, self.BN.weight.data.shape[0]), device = DEVICE).sign()
        self.vector_b = torch.randn((num_dir, self.BN.bias.data.shape[0]), device = DEVICE).sign()
        new_da = torch.zeros(da.shape, device = DEVICE)
        for i in range(num_dir):
            with fwAD.dual_level():
                dual_input = fwAD.make_dual(input, da[i])
                dual_output = self.BN(dual_input)
                jvp = fwAD.unpack_dual(dual_output).tangent
                mid = (dual_output - self.BN.bias)/self.BN.weight
                new_da[i] = jvp + mid*self.vector_w[i] + self.vector_b[i]
        out = self.BN(input)
        return out, new_da
    
    def update(self, da, lr):
        gw = da.view(-1, 1)*self.vector_w
        gw = torch.mean(gw, dim = 0)
        self.BN.weight.data -= lr*gw
        gb = da.view(-1, 1)*self.vector_b
        gb = torch.mean(gb, dim = 0)
        self.BN.bias.data -= lr*gb

In [249]:
class self_Softmax_CrossEntropy():
    def __init__(self):
        self.CE = nn.CrossEntropyLoss()
    
    def forward(self, input, da, y):
        num_dir = da.shape[0]
        new_da = torch.zeros(num_dir, device = DEVICE)
        for i in range(num_dir):
            with fwAD.dual_level():
                dual_input = fwAD.make_dual(input, da[i])
                dual_output = self.CE(dual_input, y)
                jvp = fwAD.unpack_dual(dual_output).tangent
                new_da[i] = jvp
        
        loss = self.CE(input, y)
        return loss, new_da

In [250]:
'''
class self_Softmax_CrossEntropy():
    def forward(self, input, da, labels):
        exp_z = torch.exp(input)
        sum_exp_z = torch.sum(exp_z, dim=1).reshape(input.shape[0], 1)
        softmax_z = exp_z/sum_exp_z
        loss = torch.sum(-(labels*torch.log(softmax_z))) / input.shape[0]
        #new_da = torch.sum(da*(softmax_z - labels))/input.shape[0]
        new_da = torch.zeros([da.shape[0], 1, 1], dtype = float, device=DEVICE)
        for i in range(da.shape[0]):
            new_da[i] = torch.sum(da[i]*(softmax_z - labels))/input.shape[0]
        return loss, new_da
'''

'\nclass self_Softmax_CrossEntropy():\n    def forward(self, input, da, labels):\n        exp_z = torch.exp(input)\n        sum_exp_z = torch.sum(exp_z, dim=1).reshape(input.shape[0], 1)\n        softmax_z = exp_z/sum_exp_z\n        loss = torch.sum(-(labels*torch.log(softmax_z))) / input.shape[0]\n        #new_da = torch.sum(da*(softmax_z - labels))/input.shape[0]\n        new_da = torch.zeros([da.shape[0], 1, 1], dtype = float, device=DEVICE)\n        for i in range(da.shape[0]):\n            new_da[i] = torch.sum(da[i]*(softmax_z - labels))/input.shape[0]\n        return loss, new_da\n'

In [251]:
class BMLP_1():
    def __init__(self, device = DEVICE):
        self.device = device

        self.fc_1 = self_BinarizeLinear(28*28, 1024, bias=True, device = DEVICE)
        self.htan_1 = self_Hardtanh()
        self.bn_1 = self_batchnorm1d(1024)

        self.fc_2 = self_BinarizeLinear(1024, 1024, bias=True, device = DEVICE)
        self.htan_2 = self_Hardtanh()
        self.bn_2 = self_batchnorm1d(1024)

        self.fc_3 = self_BinarizeLinear(1024, 1024, bias=True, device = DEVICE)
        self.htan_3 = self_Hardtanh()
        self.bn_3 = self_batchnorm1d(1024)

        self.fc_4 = self_BinarizeLinear(1024, 10, bias=True, device = DEVICE)
        self.CrossEntropy = self_Softmax_CrossEntropy()
    
    def forward(self, input, labels, num_dir):
        x = torch.reshape(input, (input.shape[0], 28*28))
        da = None
        x, da = self.fc_1(num_dir, x)
        x, da = self.htan_1(x, da)
        x, da = self.bn_1.forward(x, da)

        x, da = self.fc_2(num_dir, x, da)
        x, da = self.htan_2(x, da)
        x, da = self.bn_2.forward(x, da)

        x, da = self.fc_3(num_dir, x, da)
        x, da = self.htan_3(x, da)
        x, da = self.bn_3.forward(x, da)

        x, da = self.fc_4(num_dir, x, da)
        self.output = x
        loss, da = self.CrossEntropy.forward(x, da, labels)
        self.loss = loss
        self.da = da.view(-1, 1, 1)
    
    def update(self, lr):
        self.fc_1.update(self.da, lr)
        self.fc_2.update(self.da, lr)
        self.fc_3.update(self.da, lr)
        self.fc_4.update(self.da, lr)
        self.bn_1.update(self.da, lr)
        self.bn_2.update(self.da, lr)
        self.bn_3.update(self.da, lr)

        

In [252]:
def run_experiment(model, train_dl, valid_dl, test_dl, max_epochs = 20, lr = 1e-2, num_dir = 1):
    itr = -1
    stats = {'train-loss': [], 'valid-acc': []}

    for epoch in range(max_epochs):
        for X, y in train_dl:
            itr += 1
            #one_hot_y = torch.zeros(X.shape[0], 10).to(DEVICE)
            #one_hot_y[[i for i in range(X.shape[0])], [k.item() for k in y]] = 1
            model.forward(X, y, num_dir)
            model.update(lr)
            stats['train-loss'].append((itr, model.loss.item()))

            if itr % 20 == 0:
                valid_acc = get_acc(model, valid_dl, num_dir)
                stats['valid-acc'].append((itr, valid_acc))
                s = f"{epoch}:{itr} [train] loss:{model.loss.item():.3f}, [valid] acc:{valid_acc:.3f}"
                print(s)
    test_acc = get_acc(model, test_dl, num_dir)
    print(f"[test] acc:{test_acc:.3f}")
    return stats

In [253]:
train_batch = 1024
valid_batch = 1024

model = BMLP_1()
max_epochs = 200
lr = 1e-1
num_dir = 20

In [254]:
train_dl, valid_dl, test_dl = get_mnist_dl(batch_size_train=train_batch, batch_size_eval=valid_batch, device=DEVICE)
stats = run_experiment(model, train_dl, valid_dl, test_dl, max_epochs, lr, num_dir)
print_stats(stats)

0:0 [train] loss:51.492, [valid] acc:0.073
0:20 [train] loss:52.000, [valid] acc:0.076
0:40 [train] loss:52.041, [valid] acc:0.077


KeyboardInterrupt: 

In [5]:
import torch
import torch.autograd.forward_ad as fwAD

primal_input = torch.tensor([[1.,2.,3.], [4.,5.,6.]])
primal_weight = torch.ones(3,4)
primal_bias = torch.tensor([1.,2.,3.,4.])
tangent_input = torch.tensor([[2., 7., 2.], [3., 5., 1.]])
tangent_weight = torch.ones(3, 4)
tangent_bias = torch.ones(4)

def fn(input, weight, bias):
    return torch.matmul(input, weight) + bias

# All forward AD computation must be performed in the context of
# a ``dual_level`` context. All dual tensors created in such a context
# will have their tangents destroyed upon exit. This is to ensure that
# if the output or intermediate results of this computation are reused
# in a future forward AD computation, their tangents (which are associated
# with this computation) won't be confused with tangents from the later
# computation.
with fwAD.dual_level():
    # To create a dual tensor we associate a tensor, which we call the
    # primal with another tensor of the same size, which we call the tangent.
    # If the layout of the tangent is different from that of the primal,
    # The values of the tangent are copied into a new tensor with the same
    # metadata as the primal. Otherwise, the tangent itself is used as-is.
    #
    # It is also important to note that the dual tensor created by
    # ``make_dual`` is a view of the primal.
    dual_input = fwAD.make_dual(primal_input, tangent_input)
    dual_weight = fwAD.make_dual(primal_weight, tangent_weight)
    dual_bias = fwAD.make_dual(primal_bias, tangent_bias)
    # Tensors that do not not have an associated tangent are automatically
    # considered to have a zero-filled tangent of the same shape.
    dual_output = fn(dual_input, dual_weight, dual_bias)

    # Unpacking the dual returns a namedtuple with ``primal`` and ``tangent``
    # as attributes
    jvp = fwAD.unpack_dual(dual_output).tangent

assert fwAD.unpack_dual(dual_output).tangent is None
print(jvp)

tensor([[18., 18., 18., 18.],
        [25., 25., 25., 25.]])


In [120]:
import torch.nn as nn

model = nn.BatchNorm1d(1024)
primal_input = torch.tensor([[1.,2.,3.], [4.,5.,6.]])
tangent_input = torch.tensor([[2., 7., 2.], [3., 5., 1.]])

params = {name: p for name, p in model.named_parameters()}

with fwAD.dual_level():
    for name, p in params.items():
        print(name)
        print(p.shape)
        '''
        if name == 'weight':
            delattr(model, name)
            p.data = torch.ones(4, 3)
            setattr(model, name, fwAD.make_dual(p, torch.ones(4, 3)))
            setattr(model, name, p)
        else:
            delattr(model, name)
            p.data = torch.tensor([1.,2.,3.,4.])
            setattr(model, name, fwAD.make_dual(p, torch.ones(4)))
            setattr(model, name, p)
    dual_input = fwAD.make_dual(primal_input, tangent_input)
    #print(fwAD.unpack_dual(model.weight).primal)
    #fwAD.unpack_dual(dual_input).tangent =torch.tensor([[2., 7., 2.], [3., 5., 2.]])
    #print(dual_input)
    out = model(primal_input)
    jvp = fwAD.unpack_dual(out).tangent
res = model(primal_input)
print(model.weight)
print(res)
print(jvp)
print(out)
params = {name: p for name, p in model.named_parameters()}
print(params)
'''

weight
torch.Size([1024])
bias
torch.Size([1024])


In [119]:
lin = nn.Linear(10, 2)
lin.weight.data[0][0] = 1.
x = torch.randn(1, 10)
output = lin(x)
output.mean().backward()