# grad-corr and dir-sharpness

Objective
- compute gradient correlation and
- compute directional sharpness
- as presented in https://arxiv.org/pdf/2306.00204.pdf

$$f(x_{t+1}) = f(x_t) + \nabla f(x_t) ^T (x_{t+1} - x_t) + \frac{1}{2} (x_{t+1}-x_t) \nabla^2 f(x_t) (x_{t+1}-x_t) + O( \eta ^3)$$ 

# load utils to go ahead

In [1]:
from zeroptim.mlp import MLP
from zeroptim.data import loader
from zeroptim.utils import parse_yaml_config
import torch

In [2]:
dataloader = loader('mnist-digits')
m = MLP(**parse_yaml_config('mlp.yaml'))
opt = torch.optim.SGD(m.parameters(), lr=0.1)
crit = torch.nn.CrossEntropyLoss()
m.norm()

tensor(28.3087, grad_fn=<AddBackward0>)

In [3]:
m

MLP(
  (act_func): ReLU()
  (out_func): Softmax(dim=1)
  (layers): Sequential(
    (0): Sequential(
      (0): Linear(in_features=784, out_features=128, bias=True)
      (1): ReLU()
    )
    (1): Sequential(
      (0): Linear(in_features=128, out_features=64, bias=True)
      (1): ReLU()
    )
    (2): Linear(in_features=64, out_features=10, bias=True)
  )
)

# perform a single opt step

In [4]:
# fix (inputs, targets) for a batch
inputs, targets = next(iter(dataloader))
inputs = inputs.flatten(start_dim=1)

# store a clone of model m at time t
m_ = m.clone()

# optimization step on original model m
opt.zero_grad()
loss = crit(m(inputs), targets)
loss.backward()
opt.step()

# compute the parameters update step
update_step = [p.detach() - p_.detach() for p, p_ in zip(m.parameters(), m_.parameters())]

# verify there was a change in model performances
print(crit(m_(inputs), targets))
print(crit(m(inputs), targets))

tensor(2.3130, grad_fn=<NllLossBackward0>)
tensor(2.3015, grad_fn=<NllLossBackward0>)


In [5]:
print(len(update_step)) # one update-step entry per weights&biases
print(update_step[0].shape) # verify that shape of layer1 is correct

6
torch.Size([128, 784])


# try to compute grad-corr using `jvp`
- the thing is that `jvp` was intended to compute gradients wrt. inputs
- not wrt. the model parameters; so we need to hack it a little bit

## failed tentative
- didnt understand properly how jvp computed Jacobian at first

In [6]:
# say we want to compute jvp for first layer only
p_ = next(m_.parameters()).detach()
p = next(m.parameters()).detach()
ps = p - p_

In [7]:
torch.autograd.functional.jvp(
    func = lambda p: crit(m_(inputs), targets), # previous timestep model
    inputs = p_, # previous timestep parameters (dummy input to trick jvp)
    v = ps, # parameters update step, the direction where we stepped
)

(tensor(2.3130), tensor(0.))

In [8]:
# the jvp result is null. this make sense as 
# jvp perturbs func() with small steps in inputs in direction of v
# and because our func is a constant, then jvp is null

## second tentative
- bit hacky and pretty ugly but should work in theory
- don't understand why it does not lead to some jvp value

In [9]:
# verify both things point to the same weight matrix
print(m.state_dict()['layers.0.0.weight'].shape)
print(next(m.parameters()).shape)

torch.Size([128, 784])
torch.Size([128, 784])


In [10]:
def func(params: torch.Tensor, wnb_id='layers.0.0.weight'):
    model = m_.clone() # clone original model (timestep-1)
    sd = model.state_dict() # get params dict of that model
    sd[wnb_id] = params # hack and update a specific layer
    model.load_state_dict(sd) # load updated hacked state dict
    return crit(model(inputs), targets) # loss wrt. new params

In [11]:
torch.autograd.functional.jvp(
    func = func, # hacky function to compute jvp for a layer at least
    inputs = next(m_.parameters()).detach(), # 'layers.0.0.weight'
    v = ps, # parameters update step (direction)
)

(tensor(2.3130), tensor(0.))

In [12]:
# really confused about why it's not working either
# jvp should be able to perturb func() with params in direction of v
# and return the jvp of func wrt. params in direction of v

In [13]:
# anyway this is nothing efficient...

# just do it manually

In [14]:
# fix (inputs, targets) for a batch
inputs, targets = next(iter(dataloader))
inputs = inputs.flatten(start_dim=1)

# store a clone of model m at time t
# could also only store p_ = [p.detach() for p in m.parameters()]
m_ = m.clone()

# optimization step on original model m
opt.zero_grad()
loss = crit(m(inputs), targets)
loss.backward()
opt.step()

# parameter update step or direction
vs = [
    p.detach() - p_.detach() 
    for p, p_ in zip(m.parameters(), m_.parameters())
]

# compute jvp manually;
# sum the product of grad_i * update_step_i for each parameter i
# WARNING: need to make sure p.grad still exist at this point
gcorr = sum(
    (p.grad * v).sum() 
    for p, v in zip(m.parameters(), vs)
)

# print final results
print("grad-corr:", gcorr)

# verify there was a change in model performances
print(crit(m_(inputs), targets))
print(crit(m(inputs), targets)) 

grad-corr: tensor(-0.0190)
tensor(2.2829, grad_fn=<NllLossBackward0>)
tensor(2.2639, grad_fn=<NllLossBackward0>)


# same for directional sharpness

In [15]:
# fix (inputs, targets) for a batch
inputs, targets = next(iter(dataloader))
inputs = inputs.flatten(start_dim=1)

# store previous parameters
p_ = [p.detach().clone() for p in m.parameters()]

# compute first order gradients
opt.zero_grad()
loss = crit(m(inputs), targets)
loss.backward()
opt.step()

# store gradients and update-directions
gs = [p.grad for p in m.parameters()]
vs = [p2 - p1 for p2, p1 in zip(m.parameters(), p_)]

# compute gradient correlations (grad * Dw)
gcorr = sum((g * v).sum()for g, v in zip(gs, vs))

# compute directional sharpness (Dw * H * Dw)
gs = [g.requires_grad_(True) for g in gs]
vs = [v.requires_grad_(True) for v in vs]
p_ = [p.requires_grad_(True) for p in p_]
# vHv = 

# print final results
print("grad-corr:", gcorr)
# print("dir-sharp:", vHv)

grad-corr: tensor(-0.0133, grad_fn=<AddBackward0>)


In [16]:
# get the same problem when using torch.autograd.functional.vhp

In [17]:
def func(*params):
    # re-assing parameters to model
    with torch.no_grad():
        for p, new_p in zip(m.parameters(), params):
            p.copy_(new_p.data)
    return crit(m(inputs), targets)

# compute the vector-Hessian-product
_, vhp = torch.autograd.functional.vhp(func, tuple(m.parameters()), tuple(v for v in vs))

# flatten vhp and compute the dot product with flattened v
v_flat = torch.cat([v.flatten() for v in vs])
vhp_flat = torch.cat([vhpi.flatten() for vhpi in vhp])
vHv = torch.dot(vhp_flat, v_flat)

# print the result
print("dir-sharp:", vHv)

dir-sharp: tensor(0., grad_fn=<DotBackward0>)


In [18]:
# try to compute manually

In [19]:
torch.autograd.grad(gs, p_, vs, allow_unused=True) # hum

(None, None, None, None, None, None)