# Imports

In [4]:
import torch
import torch.autograd as autograd

# Testing Stuff

In [None]:
# Sanity check of torch.autograd on a tiny graph: out = mean(3 * (x + 2)^2).
# x is a 2x2 leaf tensor of ones that tracks gradients.
x = torch.ones(2, 2, requires_grad=True)
print(x)
# y = x + 2, so every entry of y is 3.
y = x + 2
print(y)
# z = 3 * y^2 element-wise (27 everywhere); out is the scalar mean of z.
z = y * y * 3
out = z.mean()

print(z)
print(out)


# d(out)/dx = 6 * (x + 2) / 4 = 4.5 for every entry; autograd.grad returns a
# tuple with one gradient tensor per input.
print(autograd.grad(out, x))


# Hessian Vector Product Trick (from Hongkai)

In [5]:
def Hvp_vec(grad_vec, params, vec, retain_graph=False):
    '''
    Compute a Hessian vector product by differentiating ``grad_vec`` once more.

    Parameters:
        - grad_vec: Tensor of which the Hessian vector product will be computed.
          It must be differentiable w.r.t. ``params`` (e.g. a gradient obtained
          with ``create_graph=True``).
        - params: a single Tensor or an iterable of Tensors, w.r.t which the
          Hessian will be computed
        - vec: The "vector" in Hessian vector product; it is passed to autograd
          as ``grad_outputs`` for ``grad_vec``
        - retain_graph: forwarded to ``autograd.grad``; set it to True when the
          graph is needed again after this call
    return: Hessian vector product as one flat tensor, the per-parameter
        results concatenated in the order of ``params``. Parameters that
        ``grad_vec`` does not depend on contribute zeros.
    raises: ValueError if ``grad_vec``, ``vec`` or the result contains NaN
    '''
    # Normalise once: a bare tensor would otherwise be iterated row by row in
    # the padding loop below, and a generator would already be exhausted by
    # autograd.grad before that loop runs.
    if isinstance(params, torch.Tensor):
        params = (params,)
    else:
        params = tuple(params)

    if torch.isnan(grad_vec).any():
        raise ValueError('Gradvec nan')
    if torch.isnan(vec).any():
        raise ValueError('vector nan')

    grad_grad = autograd.grad(grad_vec, params, grad_outputs=vec, retain_graph=retain_graph,
                              allow_unused=True)

    # zero padding for None: allow_unused=True yields None for every parameter
    # that grad_vec does not depend on.
    grad_list = []
    for p, g in zip(params, grad_grad):
        if g is None:
            grad_list.append(torch.zeros_like(p).view(-1))
        else:
            grad_list.append(g.contiguous().view(-1))
    hvp = torch.cat(grad_list)
    if torch.isnan(hvp).any():
        raise ValueError('hvp Nan')
    return hvp


# Function to Compute Avp

In [23]:
def avp(
    vector_list,
    loss_list,
    param_list,
    lr,
    transpose=False
):
    """
    Compute the matrix-vector product A v for the block "meta-matrix" A.

    A has identity blocks on the diagonal and, for i != j, the off-diagonal
    block lr * d^2 loss_i / (d param_i d param_j). Row i of the result is
    therefore

        v_i + lr * sum_{j != i} [d^2 loss_i / (d param_i d param_j)] v_j

    and the rows are concatenated into a single flat tensor.

    :param vector_list: list of vectors for each player (tensor per player,
        with as many elements as that player's parameter)
    :param loss_list: list of objective functions for each player (scalar
        tensors attached to the autograd graph of the parameters)
    :param param_list: list of parameter vectors for each player (tensor per
        player, requires_grad=True)
    :param lr: learning rate
    :param transpose: reserved for the transposed product; not implemented
    :return: flat tensor holding A v, one segment per player
    :raises NotImplementedError: if ``transpose`` is True
    :raises ValueError: if the lists are empty or differ in length, or if a
        NaN shows up in an intermediate or in the result
    """
    # TODO(jjma): add transpose case
    if transpose:
        # Fail loudly instead of silently returning the non-transposed product.
        raise NotImplementedError('avp: transpose=True is not implemented')

    if not (len(vector_list) == len(loss_list) == len(param_list)):
        raise ValueError('vector_list, loss_list and param_list must have the same length')
    if len(param_list) == 0:
        raise ValueError('avp needs at least one player')

    prod_list = []

    for i, (loss, row_param) in enumerate(zip(loss_list, param_list)):
        # Diagonal block is the identity, so it contributes the player's own
        # vector entry (not the parameter itself).
        row_prod = vector_list[i].reshape(-1)

        for j, (col_param, vector_elem) in enumerate(zip(param_list, vector_list)):
            if i == j:
                continue

            # create_graph=True keeps the gradient differentiable so that it
            # can be differentiated a second time w.r.t. row_param.
            grad_param = autograd.grad(loss, col_param, create_graph=True, retain_graph=True,
                                       allow_unused=True)
            if any(g is None for g in grad_param):
                # loss_i does not depend on param_j: the mixed block is zero.
                continue
            grad_param_vec = torch.cat([g.contiguous().view(-1) for g in grad_param])

            if torch.isnan(grad_param_vec).any():
                raise ValueError('grad_param_vec nan')

            if not grad_param_vec.requires_grad:
                # The gradient is constant w.r.t. every parameter, so the
                # mixed second derivative is zero.
                continue

            # retain_graph=True: the caller's loss graphs may share nodes and
            # are needed again for the remaining (i, j) pairs.
            hessian_vec_prod = Hvp_vec(grad_param_vec, [row_param], vector_elem.reshape(-1),
                                       retain_graph=True)

            # Out-of-place accumulation leaves the caller's vectors untouched.
            row_prod = row_prod + lr * hessian_vec_prod

        prod_list.append(row_prod)

    Avp = torch.cat(prod_list)
    if torch.isnan(Avp).any():
        raise ValueError('Avp nan')

    return Avp
    


# Testing Meta-Matrix Product