# PCGrad Implementation in Python
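PCGrad (Projecting Conflicting Gradients) is a gradient-descent modification that is mainly used for multi-task training: when the gradients of two tasks conflict (their dot product is negative), each one is projected so that the conflicting component is removed before the gradients are combined.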

In [46]:
import logging

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import pdb
import numpy as np
import copy
import random

from fontTools.svgLib.path import shapes
from torch.optim.adamw import adamw

In [25]:
class PCGrad:
    """Thin container for the optimizer and reduction mode used by PCGrad.

    Args:
        optimizer: the wrapped ``torch.optim`` optimizer; exposed as
            ``self.optimizer``.
        reduction: how the task gradients are intended to be merged,
            ``'mean'`` or ``'sum'`` (the same options accepted by
            ``pc_project_conflicting``); exposed as ``self.reduction``.

    Raises:
        ValueError: if ``reduction`` is neither ``'mean'`` nor ``'sum'``.
    """

    def __init__(self, optimizer, reduction='mean'):
        # Validate eagerly so a typo fails at construction time rather than
        # deep inside a training step.
        if reduction not in ('mean', 'sum'):
            raise ValueError("reduction must be 'sum' or 'mean'")
        self.optimizer = optimizer
        # Previously this argument was accepted but silently discarded.
        self.reduction = reduction

Set a seed for reproducibility and draw random values for the input and the expected results used for testing. Then define a fixed
tensor that stands in for a model prediction.

In [5]:
# Fix the RNG so that the random tensors below are reproducible.
torch.manual_seed(4)
input = torch.randn(2, 3)
expectations = torch.randn(2, 4)
# Hard-coded tensor that stands in for a model prediction.
predictions = torch.tensor(
    [
        [-1.6053, 0.2325, 2.2399],
        [0.8473, 1.2006, -0.4016],
    ]
)
print(input, predictions)


tensor([[-1.6053,  0.2325,  2.2399],
        [ 0.8473,  1.2006, -0.4016]]) tensor([[-1.6053,  0.2325,  2.2399],
        [ 0.8473,  1.2006, -0.4016]])


Compute the losses of the model's current predictions. First create two loss functions, L1 (mean absolute error) and MSE,
and then apply both to the input and the model's predictions. These per-task losses are
our starting point for PCGrad.

In [6]:
# Two example task losses: mean absolute error (L1) and mean squared error.
loss1_fn = nn.L1Loss()
loss2_fn = nn.MSELoss()
print(loss1_fn(input, predictions), loss2_fn(input, predictions))

tensor(2.8007e-05) tensor(1.0211e-09)


Create a random tensor that serves as our model parameters, then initialize an Adam optimizer over it.

In [29]:
# Random tensor standing in for the model parameters. requires_grad=True is
# required: without it autograd never populates `parameter.grad`, so the
# PCGrad helpers would only ever see zero ("absent") gradients.
parameter = torch.randn(1, 3, requires_grad=True)
optimizer = torch.optim.Adam(params=[parameter])

Now we compute the projected (conflict-free) gradients for the given loss objectives.

In [30]:
def pc_backward(objectives):
    """Run one PCGrad backward pass over several task losses.

    Each objective is back-propagated separately (via ``pc_package_grad``),
    conflicting gradient components are projected away
    (``pc_project_conflicting``), and the merged gradient is reshaped per
    parameter and written into ``.grad`` of the parameters held by the
    module-level ``optimizer``. This function does not step the optimizer;
    the caller is expected to do so.

    Args:
        objectives: iterable of scalar loss tensors, one per task.

    Raises:
        ValueError: if ``objectives`` is empty.
    """
    objectives = list(objectives)
    if not objectives:
        raise ValueError("objectives must contain at least one loss")
    grads, shapes, has_grads = pc_package_grad(objectives)
    merged = pc_project_conflicting(grads, has_grads)
    # Every objective is read from the same optimizer, so the parameter
    # shapes recorded for the first objective apply to all of them.
    per_param = pc_unflatten_grad(merged, shapes[0])
    # pc_set_grad reads the optimizer from `.optimizer` of its first argument;
    # there is no `self` at module level, so wrap the global optimizer.
    pc_set_grad(PCGrad(optimizer), per_param)

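For each objective, back-propagate, read the gradients of the optimizer's parameters, and package them into lists of flattened gradient values, parameter shapes, and presence masks (1 where a parameter received a gradient, 0 otherwise).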

In [31]:
def pc_package_grad(objectives):
    """Back-propagate each objective and collect its flattened gradients.

    Uses the module-level ``optimizer`` to clear gradients between
    objectives, and ``pc_retrieve_grad`` to read them.

    Args:
        objectives: iterable of scalar loss tensors, one per task.

    Returns:
        ``(grads, shapes, has_grads)``:
        grads: list of 1-D gradient tensors, one per objective;
        shapes: list (one per objective) of lists of parameter shapes;
        has_grads: list of 1-D presence masks, one per objective.
    """
    grads, shapes, has_grads = [], [], []
    for objective in objectives:
        optimizer.zero_grad(set_to_none=True)
        # `retain_graph` (not `retain_grad`, which backward() rejects) keeps the
        # autograd graph alive so objectives sharing part of it can each be
        # back-propagated in turn.
        objective.backward(retain_graph=True)
        grad, shape, has_grad = pc_retrieve_grad()
        # Flatten THIS objective's gradients, not the accumulating result list.
        grads.append(flatten_grads(grad, shape))
        has_grads.append(flatten_grads(has_grad, shape))
        shapes.append(shape)
    return grads, shapes, has_grads

This function reads the gradients of the optimizer's parameters for the current objective.
It distinguishes parameters that actually received a gradient from those that did not: a missing gradient is replaced by zeros and marked as absent in the mask.


In [12]:
def pc_retrieve_grad(opt=None):
    """Collect gradients, shapes and presence masks for every optimizer parameter.

    Args:
        opt: optimizer whose ``param_groups`` are read. Defaults to the
            module-level ``optimizer`` so existing zero-argument calls keep
            working.

    Returns:
        ``(grad, shape, has_grad)``: three lists with one entry per
        parameter, in ``param_groups`` order.
        grad: a clone of ``param.grad``, or zeros when it is ``None``;
        shape: the parameter's shape;
        has_grad: ones where a gradient exists, zeros otherwise.
    """
    if opt is None:
        opt = optimizer
    grad, shape, has_grad = [], [], []
    for group in opt.param_groups:
        for param in group['params']:
            shape.append(param.shape)
            if param.grad is None:
                # No gradient reached this parameter: use zeros and mark it
                # absent. zeros_like already matches param's device.
                grad.append(torch.zeros_like(param))
                has_grad.append(torch.zeros_like(param))
            else:
                # Clone so later zero_grad()/backward() calls don't alias it.
                grad.append(param.grad.clone())
                has_grad.append(torch.ones_like(param))
    return grad, shape, has_grad


Flattening converts a multi-dimensional array into a 1-D array. Here, all gradient tensors are flattened and concatenated into a single 1-D tensor.

In [34]:
def flatten_grads(grads, shapes):
    """Concatenate an iterable of tensors into a single 1-D tensor.

    Each element is flattened and the results are joined in iteration
    order. ``shapes`` is accepted but not used.
    """
    return torch.cat(list(map(torch.flatten, grads)))

# Demo: iterating a 4x4 tensor yields its rows, so the result is the
# row-major flattening of the whole matrix.
ex_grads = torch.randn(4, 4)
print("grads: ", ex_grads, " shapes: ", ex_grads.shape)
flatted_ex_grads = flatten_grads(ex_grads, ex_grads.shape)
print("flatted grads: ", flatted_ex_grads)


grads:  tensor([[-0.6889, -0.8556, -0.0143, -0.4549],
        [ 1.5472, -1.5344, -0.5646, -0.2765],
        [-1.0467, -1.1993,  0.2494, -1.2676],
        [ 0.7497, -0.1904,  1.0569,  1.0551]])  shapes:  torch.Size([4, 4])
flatted grads:  tensor([-0.6889, -0.8556, -0.0143, -0.4549,  1.5472, -1.5344, -0.5646, -0.2765,
        -1.0467, -1.1993,  0.2494, -1.2676,  0.7497, -0.1904,  1.0569,  1.0551])


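This function compares the directions of the gradients computed for the different tasks. If two task gradients $g_i$ and $g_j$ conflict, i.e. $g_i \cdot g_j < 0$,
the component of $g_i$ along $g_j$ is removed, $g_i \leftarrow g_i - \frac{g_i \cdot g_j}{\lVert g_j \rVert^2}\, g_j$. Finally, the corrected gradients are merged into a single gradient.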

In [56]:
def pc_project_conflicting(grads, has_grads, reduction="sum", shapes=None):
    """Remove pairwise gradient conflicts between tasks and merge the result.

    For every task ``i``, the other tasks ``j`` are visited in random order.
    Whenever ``g_i . g_j < 0`` the component of ``g_i`` along the ORIGINAL
    ``g_j`` is removed: ``g_i <- g_i - (g_i . g_j / ||g_j||^2) * g_j``.

    Args:
        grads: list of 1-D gradient tensors, one per task, all the same length.
        has_grads: list of 1-D masks of the same length; non-zero where the
            task produced a gradient for that coordinate.
        reduction: how to merge coordinates that EVERY task has a gradient
            for, ``"sum"`` or ``"mean"``. Coordinates missing from at least
            one task are always summed.
        shapes: unused.

    Returns:
        1-D tensor with the merged gradient.

    Raises:
        ValueError: if ``reduction`` is not ``"sum"`` or ``"mean"``.
    """
    # Validate before doing any work (ValueError is still an Exception).
    if reduction not in ("sum", "mean"):
        raise ValueError("reduction must be 'sum' or 'mean'")
    # Boolean mask: coordinates for which ALL tasks produced a gradient.
    shared = torch.stack(has_grads, dim=0).bool().all(dim=0)
    # Clone so the caller's gradient tensors are never mutated.
    pc_grad = [g.clone() for g in grads]
    num_task = len(grads)
    for i in range(num_task):
        order = list(range(num_task))
        random.shuffle(order)
        logging.debug("task %d projection order: %s", i, order)
        for j in order:
            if i == j:
                continue
            gj = grads[j]
            dot = torch.dot(pc_grad[i], gj)
            if dot < 0:  # the two gradients point in conflicting directions
                denom = gj.norm().pow(2)
                if denom > 0:
                    # Subtract the component of g_i along g_j.
                    pc_grad[i] = pc_grad[i] - (dot / denom) * gj
                    logging.debug("task %d projected against task %d: %s", i, j, pc_grad[i])
    stacked = torch.stack(pc_grad, dim=0)  # (num_task, num_coords)
    merged = torch.zeros_like(stacked[0])
    if reduction == "sum":
        merged[shared] = stacked[:, shared].sum(dim=0)
    else:
        merged[shared] = stacked[:, shared].mean(dim=0)
    # A coordinate only some tasks touch is not averaged against the zeros
    # filled in for the tasks that lack it; it is summed instead.
    merged[~shared] = stacked[:, ~shared].sum(dim=0)
    return merged


Expand the flat gradient tensor back into the original per-parameter shapes.

In [None]:
def pc_unflatten_grad(grads, shapes):
    """Split a flat gradient tensor back into per-parameter tensors.

    Args:
        grads: 1-D tensor holding the concatenated (flattened) gradients.
        shapes: sequence of target shapes, in the order the gradients were
            flattened.

    Returns:
        List of tensors, one per entry in ``shapes``. Each is a copy, not a
        view into ``grads``.
    """
    unflatten_grad, idx = [], 0
    for shape in shapes:
        # int(): np.prod(()) is the float 1.0 for scalar (0-d) parameters,
        # and a float cannot be used as a slice bound.
        length = int(np.prod(shape))
        unflatten_grad.append(grads[idx:idx + length].view(shape).clone())
        idx += length
    return unflatten_grad




Write the corrected gradients back to the parameters of the original
optimizer.

In [57]:
def pc_set_grad(self, grads):
    """Assign ``grads`` as the ``.grad`` of the optimizer's parameters, in order.

    Args:
        self: object exposing the target optimizer as ``.optimizer``
            (for example a ``PCGrad`` instance).
        grads: sequence of tensors, one per parameter in ``param_groups``
            order, each shaped like its parameter.

    Raises:
        ValueError: if the number of gradients differs from the number of
            parameters.
    """
    # `param_groups` (the original `para_groups` typo raised AttributeError).
    params = [p for group in self.optimizer.param_groups for p in group['params']]
    if len(params) != len(grads):
        raise ValueError(f"expected {len(params)} gradients, got {len(grads)}")
    for param, grad in zip(params, grads):
        param.grad = grad
