# GradVac Optimizer

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import random

The following implements a multi-task test setup for a neural network. A “backbone” generally refers to the shared base model of the overall architecture; it can be swapped out for a different one.

In [4]:
class MultiTaskScaffold(nn.Module):
    """Two-head network whose heads share a single MLP backbone.

    The backbone is two ``Linear`` + ``ReLU`` layers; ``head1`` and ``head2``
    are linear layers applied to the same backbone output.
    """

    def __init__(self, in_dim=64, hid=64, k1=5, k2=3):
        """Build the shared backbone and both heads.

        Args:
            in_dim: Size of the input feature vector.
            hid: Width of both backbone layers.
            k1: Output size of the first head.
            k2: Output size of the second head.
        """
        super().__init__()
        self.backbone = nn.Sequential(
            nn.Linear(in_dim, hid), nn.ReLU(),
            nn.Linear(hid, hid), nn.ReLU(),
        )
        self.head1 = nn.Linear(hid, k1)
        self.head2 = nn.Linear(hid, k2)

    def forward(self, x):
        """Return ``(head1_output, head2_output)`` computed from ``x``."""
        h = self.backbone(x)
        return self.head1(h), self.head2(h)

    @staticmethod
    def backbone_params(self, model=None):
        """Return the trainable parameters of the shared backbone.

        This is a static method whose first parameter is (for historical
        reasons) named ``self``. Both call forms are accepted:

        * ``MultiTaskScaffold.backbone_params(model)``
        * ``MultiTaskScaffold.backbone_params(placeholder, model)`` (legacy)

        Args:
            self: The model when called with one argument; otherwise an
                ignored placeholder.
            model: The model to inspect. When omitted, the first argument
                is used as the model.

        Returns:
            List of ``model.backbone`` parameters with ``requires_grad`` set.
        """
        # Previously ``model`` was required, so the natural one-argument
        # call failed with a TypeError; fall back to the first argument.
        target = self if model is None else model
        return [p for p in target.backbone.parameters() if p.requires_grad]

    @staticmethod
    def zero_all_grads(self, model=None):
        """Zero, in place, every existing ``.grad`` tensor of the model.

        Accepts the same two call forms as ``backbone_params``. Parameters
        whose ``.grad`` is ``None`` are left untouched.

        Args:
            self: The model when called with one argument; otherwise an
                ignored placeholder.
            model: The model to reset. When omitted, the first argument
                is used as the model.
        """
        target = self if model is None else model
        for p in target.parameters():
            if p.grad is not None:
                p.grad.zero_()

def backbone_params(model):
    """Collect the trainable parameters of ``model.backbone``.

    Args:
        model: Object exposing a ``backbone`` attribute with a
            ``parameters()`` method.

    Returns:
        List of the backbone parameters whose ``requires_grad`` is true,
        in the order ``parameters()`` yields them.
    """
    trainable = []
    for param in model.backbone.parameters():
        if param.requires_grad:
            trainable.append(param)
    return trainable

def zero_all_grads(model):
    """Zero, in place, every gradient tensor currently held by ``model``.

    Parameters whose ``.grad`` is ``None`` are skipped.

    Args:
        model: Object exposing a ``parameters()`` method.
    """
    for param in model.parameters():
        grad = param.grad
        if grad is None:
            continue
        grad.zero_()

def flat_grad(params):
    """Concatenate the gradients of ``params`` into one 1-D tensor.

    Parameters whose ``.grad`` is ``None`` are skipped. If no parameter
    has a gradient, the list passed to ``torch.cat`` is empty and
    ``torch.cat`` raises.

    Args:
        params: Iterable of tensors/parameters with a ``.grad`` attribute.

    Returns:
        1-D tensor holding every available gradient, flattened and joined
        in iteration order.
    """
    # reshape(-1) instead of view(-1): view raises on non-contiguous
    # gradients, while reshape returns the same values and only copies
    # when it has to.
    return torch.cat([p.grad.reshape(-1) for p in params
                      if p.grad is not None])


In [5]:
class GradVac:
    """Static helpers for comparing and adjusting two task gradients."""

    @staticmethod
    def collect_gradients(model, loss1, loss2):
        """Return the flattened backbone gradient of each loss.

        For each loss in turn: zero all existing gradients on ``model``,
        backpropagate the loss with ``retain_graph=True``, then flatten
        and detach the backbone gradients.

        Side effect: on return, the ``.grad`` fields of ``model`` hold
        the gradients of ``loss2`` only.

        Args:
            model: Model with a ``backbone`` submodule.
            loss1: Scalar loss of the first task.
            loss2: Scalar loss of the second task.

        Returns:
            Tuple ``(g1, g2)`` of detached 1-D gradient tensors.
        """
        zero_all_grads(model)
        loss1.backward(retain_graph=True)
        g1 = flat_grad(backbone_params(model)).detach()

        zero_all_grads(model)
        loss2.backward(retain_graph=True)
        g2 = flat_grad(backbone_params(model)).detach()

        return g1, g2

    @staticmethod
    def norms(g1, g2):
        """Return the pair ``(g1.norm(), g2.norm())``."""
        return g1.norm(), g2.norm()

    @staticmethod
    def cosine(g1, g2, eps=1e-12):
        """Return the cosine similarity of two 1-D gradient vectors.

        Args:
            g1: First gradient vector.
            g2: Second gradient vector.
            eps: Added to the product of norms to avoid division by zero.
        """
        n1, n2 = g1.norm(), g2.norm()
        return (g1 @ g2) / (n1 * n2 + eps)

    @staticmethod
    def _unit(v, eps=1e-12):
        """Return ``(v / n, n)`` where ``n = v.norm() + eps``.

        Note that the returned ``n`` includes ``eps``.
        """
        n = v.norm() + eps
        return v / n, n

    @staticmethod
    def gradvac_pair(g1, g2, target=0.0, alpha=0.5):
        """Adjust two gradients so their cosine moves toward ``target``.

        Each gradient is shifted along the part of the other gradient's
        direction that is orthogonal to itself. The shift is scaled by
        ``alpha * (target - s)`` and by the gradient's own norm, where
        ``s`` is the current cosine similarity.

        Args:
            g1: First 1-D gradient vector.
            g2: Second 1-D gradient vector.
            target: Cosine similarity to move toward.
            alpha: Step-size factor applied to the correction.

        Returns:
            Tuple ``(g1_new, g2_new)``; the inputs are not modified.
        """
        u1, n1 = GradVac._unit(g1)
        u2, n2 = GradVac._unit(g2)
        # Clamp guards against rounding pushing the cosine outside [-1, 1].
        s = torch.clamp(torch.dot(u1, u2), -1.0, 1.0)
        # Component of each unit vector orthogonal to the other one.
        d1 = u2 - s * u1
        d2 = u1 - s * u2

        g1_new = g1 + alpha * (target - s) * d1 * n1
        g2_new = g2 + alpha * (target - s) * d2 * n2
        return g1_new, g2_new