In [1]:
import torch
from torch.autograd import forward_ad
from torch.nn.utils import parameters_to_vector


def log_softmax(x):
    return (x.exp()/x.exp().sum()).log()


def softmax(x):
  return x.exp()/x.exp().sum()


def criterion(y, y_hat):
    pred = log_softmax(y)
    target = torch.nn.functional.one_hot(y_hat, num_classes=10)
    return -torch.mul(target, pred).mean()

def l2_loss(y, y_hat):
    return torch.abs(y-y_hat).mean()


In [6]:
# FORWARD AD FUNCTIONS 

def preparation(model):
    pointer = 0
    indices = []
    for p in model.parameters():
        p.requires_grad_(False)  # turn off grad for the original parameter node
        num_param = p.numel()
        index = torch.tensor(range(pointer, num_param + pointer))
        indices.append(index)
        pointer += num_param
    return indices


def forward_optim(model, criteria, x, y, indices, eta=2e-4):
    """
    Function performing forward pass + SGD
    """
    def f(params: torch.Tensor):
        for p, index in zip(model.parameters(), indices):
            p.mul_(0)
            p.add_(params.index_select(0, index).view_as(p.data))
        out = model(x)
        loss = criteria(out, y)
        return loss

    v = [torch.normal(mean=0, std=1, size=p.shape)
         for p in model.parameters()]

    primal = parameters_to_vector(model.parameters())
    tangent = parameters_to_vector(v)

    with torch.no_grad():
        with forward_ad.dual_level():
            dual = forward_ad.make_dual(primal, tangent)
            rst = f(dual)
            my_loss, jvp = forward_ad.unpack_dual(rst)

        # here SGD starts
        i = 0
        for name, param in model.named_parameters():
            if 'weight' in name:
                param.add_(-eta*jvp*tangent[i])
                i += 1

    return my_loss

A small test on linear regression

In [5]:
torch.manual_seed(111)

class LinearRegressor(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.regressor = torch.nn.Linear(
            in_features=3, out_features=1, bias=False)

    def forward(self, X):
        return self.regressor(X)

# data creation for lin regression ------
x1 = torch.arange(0, 10).float().unsqueeze(-1)
x2 = torch.arange(0, 10).float().unsqueeze(-1)
x3 = torch.arange(0, 10).float().unsqueeze(-1)
X = torch.cat((x1, x2, x3), dim=1)
eps = torch.normal(0, .03, (10, 3))
X[:, (0, 1, 2)] += (eps)
y = torch.arange(0, 10).float().unsqueeze(-1)
# ------------------------------------------


model = LinearRegressor()
indeces = preparation(model)  # calling the preparation ALWAYS CALL AFTER definition of model

# training
for epoch in range(400):
    my_loss = forward_optim(model, l2_loss, X, y, indeces, eta=2e-4)
    if epoch % 20 == 19:
        print(f"Epoch [{epoch+1}] /  Loss : [{my_loss}]")


Epoch [20] /  Loss : [3.374274492263794]
Epoch [40] /  Loss : [3.0811939239501953]
Epoch [60] /  Loss : [2.7915120124816895]
Epoch [80] /  Loss : [2.6944334506988525]
Epoch [100] /  Loss : [2.492323398590088]
Epoch [120] /  Loss : [2.2233099937438965]
Epoch [140] /  Loss : [1.8233766555786133]
Epoch [160] /  Loss : [1.5074561834335327]
Epoch [180] /  Loss : [1.3212858438491821]
Epoch [200] /  Loss : [1.1421191692352295]
Epoch [220] /  Loss : [0.8348051905632019]
Epoch [240] /  Loss : [0.6050833463668823]
Epoch [260] /  Loss : [0.30616238713264465]
Epoch [280] /  Loss : [0.11028136312961578]
Epoch [300] /  Loss : [0.02170298434793949]
Epoch [320] /  Loss : [0.021589595824480057]
Epoch [340] /  Loss : [0.02150806598365307]
Epoch [360] /  Loss : [0.021565936505794525]
Epoch [380] /  Loss : [0.02149207890033722]
Epoch [400] /  Loss : [0.021580006927251816]
