In [None]:
import torch as th
import torch.nn as nn
import torch.nn.functional as F
import torchvision as tv
from torch import optim

In [None]:
class AdaHessian(optim.Optimizer):
    """Adam-like optimizer that preconditions with a Hutchinson estimate of the Hessian diagonal.

    For every parameter the optimizer keeps an exponential moving average of
    the gradient (``mom``) and of the estimated Hessian diagonal
    (``hess_mom``). The update is::

        p <- p - lr * (mom / (1 - beta0**t)) / (sqrt(|hess_mom / (1 - beta1**t)|) + eps)

    The Hessian diagonal is estimated by differentiating ``p.grad`` a second
    time, so gradients must have been produced with
    ``loss.backward(create_graph=True)`` before ``step()`` is called.

    Args:
        params: iterable of parameters (or parameter-group dicts) to optimize.
        wd: decoupled weight-decay coefficient; each step multiplies the
            parameter by ``1 - wd * lr``. ``0`` disables decay.
        mc_iters: number of random probe vectors averaged per step.
        lr: learning rate stored in the parameter-group defaults.
        betas: ``(beta0, beta1)`` decay rates for the gradient average and the
            Hessian average respectively.
        eps: constant added to the denominator for numerical stability.
        control_variate: if True, differentiate
            ``p.grad - betas[1] * hess_mom * p`` instead of ``p.grad`` when
            estimating the Hessian diagonal.
    """

    def __init__(self, params, wd=0, mc_iters=1, lr=.001, betas=(.9, .999), eps=1e-8, control_variate=False):
        super(AdaHessian, self).__init__(params, defaults={'lr': lr})
        self.state = dict()
        self.lr = lr
        self.betas = betas
        self.control_variate = control_variate
        self.eps = eps
        self.mc_iters = mc_iters
        self.n_steps = 0
        self.wd = wd
        for group in self.param_groups:
            for p in group['params']:
                self.state[p] = dict(mom=th.zeros_like(p.data), hess_mom=th.zeros_like(p.data))
                # The current Hessian-diagonal estimate lives directly on the parameter.
                p.hess = th.zeros_like(p.data)

    def zero_hessian(self):
        """Reset the per-parameter Hessian-diagonal estimate ``p.hess`` to zeros."""
        for group in self.param_groups:
            for p in group['params']:
                p.hess = th.zeros_like(p.data)

    def set_hessian(self):
        """Accumulate a Hutchinson estimate of the Hessian diagonal into ``p.hess``.

        For each of ``mc_iters`` probes a standard-normal vector ``z`` is drawn
        per parameter and the vector-Jacobian product of the gradients with
        ``z`` is taken; ``hz * z / mc_iters`` is added to ``p.hess``.
        Parameters without a gradient are skipped and keep a zero estimate.
        """
        params = []
        vals = []
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    # Parameter did not take part in the last backward pass.
                    continue
                params.append(p)
                if self.control_variate:
                    # Built outside no_grad on purpose: the dependence on p must be tracked.
                    vals.append(p.grad - self.state[p]["hess_mom"].detach() * p * self.betas[1])
                else:
                    vals.append(p.grad)

        if not params:
            return

        for mc_iter in range(self.mc_iters):
            with th.no_grad():
                z_values = [th.randn_like(p.data) for p in params]
                # Only the final probe may release the graph.
                keep_graph = mc_iter != self.mc_iters - 1
                hz_values = th.autograd.grad(vals, params, z_values, retain_graph=keep_graph)
                for p, z, hz in zip(params, z_values, hz_values):
                    p.hess += hz * z / self.mc_iters

    def step(self, closure=None):
        """Perform one optimization step.

        Args:
            closure: optional callable that re-evaluates the model and returns
                the loss; it is invoked before the Hessian is estimated.

        Returns:
            The value returned by ``closure``, or ``None`` if none was given.
        """
        loss = None
        if closure is not None:
            loss = closure()
        self.n_steps += 1
        self.zero_hessian()
        self.set_hessian()
        beta0, beta1 = self.betas[0], self.betas[1]
        bias_correction_0 = 1 - beta0 ** self.n_steps
        bias_correction_1 = 1 - beta1 ** self.n_steps
        with th.no_grad():
            for group in self.param_groups:
                lr = group['lr']
                # Bias-correct the first moment by DIVIDING: mom_hat = mom / (1 - beta0**t).
                step_size = lr / bias_correction_0
                for p in group['params']:
                    if p.grad is None:
                        continue
                    if self.wd != 0:
                        # Decoupled weight decay, scaled by the group's current learning rate.
                        p.mul_(1 - self.wd * lr)

                    mom = self.state[p]['mom']
                    mom.mul_(beta0).add_(p.grad, alpha=1 - beta0)
                    hess_mom = self.state[p]['hess_mom']
                    hess_mom.mul_(beta1).add_(p.hess, alpha=1 - beta1)
                    denominator = th.abs(hess_mom / bias_correction_1).pow(1 / 2).add_(self.eps)
                    p.addcdiv_(mom, denominator, value=-step_size)

        return loss





In [None]:
def compute_accuracy(y_hat, y):
    """Return the fraction of rows whose arg-max over the last axis equals ``y``.

    Args:
        y_hat: score tensor; the predicted class is the arg-max over its last
            dimension, and the result is normalised by ``y_hat.size(0)``.
        y: tensor of target class indices compared element-wise with the
            predictions.

    Returns:
        A 0-dim tensor: number of matches divided by ``y_hat.size(0)``.
    """
    predicted = th.argmax(y_hat, -1)
    n_correct = th.sum(predicted.eq(y))
    return n_correct / y_hat.size(0)


In [None]:

def train_one_epoch(model, train_dataloader, optimizer, epoch):
    """Train ``model`` for one pass over ``train_dataloader``.

    Gradients are produced with ``create_graph=True`` so that the optimizer
    can differentiate them a second time inside ``optimizer.step()``.

    Args:
        model: the network to train; batches are moved to the device of its
            first parameter.
        train_dataloader: iterable yielding ``(x, y)`` batches.
        optimizer: optimizer whose ``zero_grad()`` / ``step()`` are called
            once per batch.
        epoch: epoch index, used only in the progress message.

    Returns:
        ``(mean_loss, mean_accuracy)`` averaged over the number of batches.
    """
    criterion = nn.CrossEntropyLoss()
    # Follow the model's device instead of assuming CUDA.
    device = next(model.parameters()).device
    model.train()
    average_loss = 0
    average_accuracy = 0
    for idx, batch in enumerate(train_dataloader):
        optimizer.zero_grad()
        x, y = batch
        x = x.to(device)
        y = y.to(device)
        y_hat = model(x)
        loss = criterion(y_hat, y)
        loss.backward(create_graph=True)
        accuracy = compute_accuracy(y_hat, y)
        optimizer.step()
        # Drop the autograd graph before logging/accumulating; otherwise the
        # running sum keeps every batch's graph alive for the entire epoch.
        loss = loss.detach()
        if idx % 100 == 0:
            print(f"epoch: {epoch} idx: {idx} \t loss = {loss} , accuracy = {accuracy*100}%")
        average_loss += loss
        average_accuracy += accuracy
    return average_loss / len(train_dataloader) , average_accuracy / len(train_dataloader)


In [None]:

def evaluate(model, val_dataloader ):
    """Compute the mean cross-entropy loss and accuracy of ``model`` on ``val_dataloader``.

    The model is switched to evaluation mode and autograd is disabled for the
    duration of the call; the model's previous train/eval mode is restored
    before returning.

    Args:
        model: the network to evaluate; batches are moved to the device of its
            first parameter.
        val_dataloader: iterable yielding ``(x, y)`` batches.

    Returns:
        ``(mean_loss, mean_accuracy)`` averaged over the number of batches.
    """
    criterion = nn.CrossEntropyLoss()
    # Follow the model's device instead of assuming CUDA.
    device = next(model.parameters()).device
    average_loss = 0
    average_accuracy = 0
    was_training = model.training
    model.eval()  # use running statistics in BatchNorm, disable dropout
    try:
        with th.no_grad():  # no graph is needed for validation
            for idx, batch in enumerate(val_dataloader):
                x, y = batch
                x = x.to(device)
                y = y.to(device)
                y_hat = model(x)
                loss = criterion(y_hat, y)
                accuracy = compute_accuracy(y_hat, y)
                average_loss += loss
                average_accuracy += accuracy
    finally:
        # Leave the model in whatever mode the caller had it in.
        model.train(was_training)
    return average_loss / len(val_dataloader) , average_accuracy / len(val_dataloader)


In [None]:
def experiment(epochs=20, bsz=64, lr=.001, wd=.01, data_root="\\data"):
    """Train a ResNet-18 on CIFAR-10 with AdaHessian and print per-epoch metrics.

    Args:
        epochs: number of passes over the training set.
        bsz: batch size for both the training and validation loaders.
        lr: learning rate passed to ``AdaHessian``.
        wd: weight-decay coefficient passed to ``AdaHessian``.
        data_root: directory CIFAR-10 is downloaded to / read from. The
            default is a literal backslash followed by ``data``, the same
            value the previous hard-coded string evaluated to.
            NOTE(review): this was possibly meant to be "./data" - confirm.
    """
    model = tv.models.resnet18()
    model.cuda()
    optimizer = AdaHessian(model.parameters(), lr=lr, wd=wd)

    # Per-channel mean / std used to normalise the images in both pipelines.
    channel_mean = (0.4914, 0.4822, 0.4465)
    channel_std = (0.2023, 0.1994, 0.2010)

    # Training data gets light augmentation; validation data is only normalised.
    train_transform = tv.transforms.Compose([
        tv.transforms.RandomHorizontalFlip(),
        tv.transforms.RandomCrop(32, padding=4),
        tv.transforms.ToTensor(),
        tv.transforms.Normalize(channel_mean, channel_std),
    ])

    val_transform = tv.transforms.Compose([
        tv.transforms.ToTensor(),
        tv.transforms.Normalize(channel_mean, channel_std),
    ])

    trainset = tv.datasets.CIFAR10(data_root, download=True, transform=train_transform)
    valset = tv.datasets.CIFAR10(data_root, train=False, download=True, transform=val_transform)
    train_loader = th.utils.data.DataLoader(trainset, batch_size=bsz, shuffle=True)
    val_loader = th.utils.data.DataLoader(valset, batch_size=bsz, shuffle=False)

    for i in range(epochs):
        train_loss, train_accuracy = train_one_epoch(model, train_loader, optimizer, i)
        val_loss, val_accuracy = evaluate(model, val_loader)
        print(f"Epoch [{i}/{epochs}]: train_loss={train_loss}, val_loss={val_loss}, train_acc={train_accuracy*100}%, val_acc={val_accuracy*100}%\n")

In [None]:
# Run the experiment defined in the previous cell (downloads CIFAR-10 and trains on the GPU).
experiment()