# Imports

In [1]:
import torch
import torch.autograd as autograd

# Autograd Sanity Check

In [2]:
# Autograd sanity check: out = mean(3 * (x + 2)**2) over a 2x2 tensor of ones,
# so every entry of d(out)/dx is 6 * (1 + 2) / 4 = 4.5.
x = torch.ones(2, 2, requires_grad=True)
y = x + 2
z = y * y * 3
out = z.mean()

for shown in (x, y, z, out):
    print(shown)

out_grad = autograd.grad(out, x)
print(out_grad)


tensor([[1., 1.],
        [1., 1.]], requires_grad=True)
tensor([[3., 3.],
        [3., 3.]], grad_fn=<AddBackward0>)
tensor([[27., 27.],
        [27., 27.]], grad_fn=<MulBackward0>)
tensor(27., grad_fn=<MeanBackward0>)
(tensor([[4.5000, 4.5000],
        [4.5000, 4.5000]]),)


# Block Matrix–Vector Product (`avp`)

In [19]:
def _cross_hvp(loss, row_param, col_param, col_vector):
    """
    Mixed second-derivative product d/d(row_param) [ grad_{col_param}(loss) . col_vector ].

    :param loss: scalar objective to differentiate
    :param row_param: parameter tensor the outer derivative is taken with respect to
    :param col_param: parameter tensor the inner gradient is taken with respect to
    :param col_vector: 1-D tensor with ``col_param.numel()`` entries
    :return: the flattened product, or ``None`` when it is identically zero
        (``loss`` does not depend on ``col_param``, or that gradient does not
        depend on ``row_param``)
    :raises ValueError: if the inner gradient contains NaN
    """
    # create_graph=True is required so the gradient itself can be differentiated
    # a second time; retain_graph=True because the same loss graph is reused for
    # every other player's column (and possibly by later calls).
    (grad_col,) = autograd.grad(loss, col_param,
                                create_graph=True,
                                retain_graph=True,
                                allow_unused=True)
    if grad_col is None:
        # loss does not depend on col_param at all.
        return None

    grad_col_vec = grad_col.reshape(-1)
    if torch.isnan(grad_col_vec).any():
        raise ValueError('grad_param_vec nan')

    grad_vec_prod = torch.dot(grad_col_vec, col_vector)
    if not grad_vec_prod.requires_grad:
        # The gradient is constant (loss is linear in col_param), so every
        # mixed second derivative is zero; autograd.grad would raise here.
        return None

    (hvp,) = autograd.grad(grad_vec_prod, row_param,
                           retain_graph=True,
                           allow_unused=True)
    return None if hvp is None else hvp.reshape(-1)


def avp(
    loss_list,
    param_list,
    vector,
    lr=1,
    retain_graph=False,
    transpose=False
):
    """
    Multiply the players' block matrix A by ``vector``.

    Block (i, j) of A is the identity when i == j. Otherwise it is
    ``lr * d^2 L / (d param_i d param_j)``, where L is ``loss_list[i]``.
    When ``transpose=True``, L is ``loss_list[j]``, which yields A^T.

    :param loss_list: list of scalar objective functions, one per player
    :param param_list: list of parameter tensors, one per player; all must have
        the same number of elements so the per-player products can be stacked
    :param vector: one row per player (e.g. a 2-D tensor); row j has
        ``param_list[j].numel()`` entries
    :param lr: learning rate scaling the off-diagonal blocks
    :param retain_graph: kept for backward compatibility. The loss graphs are
        always retained, because each loss is differentiated once per other
        player and second-order products need the first gradient's graph.
    :param transpose: if True, compute A^T @ vector instead of A @ vector
    :return: tensor of shape ``(num_players, numel)`` whose row i is (A @ vector)_i
    :raises ValueError: on empty or mismatched inputs, or if a NaN appears in an
        intermediate gradient or in the result
    """
    n_players = len(param_list)
    if n_players == 0:
        raise ValueError('param_list is empty')
    if len(loss_list) != n_players:
        raise ValueError(
            f'loss_list has {len(loss_list)} entries but param_list has {n_players}')
    if len(vector) != n_players:
        raise ValueError(
            f'vector has {len(vector)} rows but param_list has {n_players}')

    prod_list = []
    for i, row_param in enumerate(param_list):
        # Accumulate in the parameter's own dtype/device (not float32/CPU).
        row_prod = torch.zeros(row_param.numel(),
                               dtype=row_param.dtype,
                               device=row_param.device)
        for j, (col_param, vector_elem) in enumerate(zip(param_list, vector)):
            if i == j:
                # Diagonal block of A is the identity.
                row_prod += vector_elem
                continue

            # Row i of A comes from player i's loss; row i of A^T from player j's.
            loss = loss_list[j] if transpose else loss_list[i]
            hvp_vec = _cross_hvp(loss, row_param, col_param, vector_elem)
            if hvp_vec is not None:
                row_prod += lr * hvp_vec
        prod_list.append(row_prod)

    Avp = torch.stack(prod_list)
    if torch.isnan(Avp).any():
        raise ValueError('Avp nan')

    return Avp
    


# Testing Meta-Matrix Product

## Two-Player Case

In [22]:
x_param = torch.tensor([1.0, 1.0], requires_grad=True)
y_param = torch.tensor([-1.0, -1.0], requires_grad=True)

# Alternative (quadratic) losses, kept for quick experimentation:
# x_loss = torch.dot(x_param, x_param) * torch.dot(y_param, y_param)
# y_loss =  - torch.dot(x_param, x_param) * torch.dot(y_param, y_param)

# Zero-sum game with cubic losses: x_loss = S(x) * S(y) and y_loss = -x_loss,
# where S(p) = sum(p ** 3).
x_loss = torch.sum(torch.pow(x_param, 3)) * torch.sum(torch.pow(y_param, 3))
y_loss = - torch.sum(torch.pow(x_param, 3)) * torch.sum(torch.pow(y_param, 3))

b = torch.tensor([[1.0, 1.0], [1.0, 1.0]])

# retain_graph=True: the same loss graphs are differentiated again by the second call.
result1 = avp([x_loss, y_loss], [x_param, y_param], b, 1, transpose=False, retain_graph=True)
result2 = avp([x_loss, y_loss], [x_param, y_param], b, 1, transpose=True, retain_graph=True)

print(result1)
print(result2)

# Hand-derived check: each cross term is +/- 3*x**2 * 3*sum(y**2 * b_y) = +/- 18,
# so the rows are 1 + 18 = 19 (from x_loss) and 1 - 18 = -17 (from y_loss).
# The transpose swaps which loss feeds each row.
assert torch.allclose(result1, torch.tensor([[19.0, 19.0], [-17.0, -17.0]]))
assert torch.allclose(result2, torch.tensor([[-17.0, -17.0], [19.0, 19.0]]))


tensor([[ 19.,  19.],
        [-17., -17.]])
tensor([[-17., -17.],
        [ 19.,  19.]])


## Three-Player Case


In [24]:
x_param = torch.tensor([1.0, 1.0], requires_grad=True)
y_param = torch.tensor([1.0, 1.0], requires_grad=True)
z_param = torch.tensor([1.0, 1.0], requires_grad=True)

# All three players share the same cubic loss S(x) * S(y) * S(z), with
# S(p) = sum(p ** 3), so the block matrix is symmetric and the transposed
# product equals the plain one.
x_loss = torch.sum(torch.pow(x_param, 3)) * torch.sum(torch.pow(y_param, 3)) * torch.sum(torch.pow(z_param, 3))
y_loss = torch.sum(torch.pow(x_param, 3)) * torch.sum(torch.pow(y_param, 3)) * torch.sum(torch.pow(z_param, 3))
z_loss = torch.sum(torch.pow(x_param, 3)) * torch.sum(torch.pow(y_param, 3)) * torch.sum(torch.pow(z_param, 3))

b = torch.tensor([[1.0, 1.0], [1.0, 1.0], [1.0, 1.0]])

# retain_graph=True: the same loss graphs are differentiated again by the second call.
result1 = avp([x_loss, y_loss, z_loss], [x_param, y_param, z_param], b, 1, transpose=False, retain_graph=True)
result2 = avp([x_loss, y_loss, z_loss], [x_param, y_param, z_param], b, 1, transpose=True, retain_graph=True)

print(result1)
print(result2)

# Hand-derived check: each of the two cross terms per row is
# 3*p_i**2 * 3*sum(p_j**2 * b_j) * S(p_k) = 3 * 6 * 2 = 36, so every entry is 1 + 72 = 73.
expected = torch.full((3, 2), 73.0)
assert torch.allclose(result1, expected)
assert torch.allclose(result2, expected)


tensor([[73., 73.],
        [73., 73.],
        [73., 73.]])
tensor([[73., 73.],
        [73., 73.],
        [73., 73.]])
