In [None]:
"Conjugate function for Mx = b solving"

import torch
import torch.autograd as autograd
import torch.nn as nn

def Hvp_vec(grad_vec, params, vec, retain_graph=False):
    """Return the flattened vector-Jacobian product d(grad_vec . vec) / d(params).

    When ``grad_vec`` is a flattened first-order gradient that still carries its
    autograd graph, the result is a (mixed) Hessian-vector product.

    :param grad_vec: 1-D tensor that is differentiable w.r.t. ``params``.
    :param params: a tensor or an iterable of tensors to differentiate with respect to.
    :param vec: tensor with the same shape as ``grad_vec``; passed to
        ``autograd.grad`` as ``grad_outputs``.
    :param retain_graph: forwarded to ``autograd.grad``; pass True if the graph of
        ``grad_vec`` is needed again after this call.
    :return: 1-D tensor, the concatenation of the per-parameter results in the
        order of ``params``. Parameters that ``grad_vec`` does not depend on
        contribute a block of zeros of matching size.
    :raises ValueError: if ``grad_vec``, ``vec`` or the result contains NaN.
    """
    if torch.isnan(grad_vec).any():
        print('grad vec nan')
        raise ValueError('grad Nan')
    if torch.isnan(vec).any():
        print('vec nan')
        raise ValueError('vec Nan')

    # Materialise once: the collection is traversed twice (autograd call + zero fill),
    # which would silently exhaust a generator such as ``model.parameters()``.
    params = [params] if isinstance(params, torch.Tensor) else list(params)

    # A single autograd call. Calling it a second time as a fallback would fail
    # whenever retain_graph is False, because the first call already freed the graph.
    grad_grad = autograd.grad(grad_vec, params, grad_outputs=vec, retain_graph=retain_graph,
                              allow_unused=True)

    # ``allow_unused=True`` yields None for parameters outside the graph; substitute
    # *flattened* zeros so that torch.cat also works for multi-dimensional parameters.
    pieces = [
        p.new_zeros(p.numel()) if g is None else g.contiguous().view(-1)
        for g, p in zip(grad_grad, params)
    ]
    hvp = torch.cat(pieces)
    if torch.isnan(hvp).any():
        print('hvp nan')
        raise ValueError('hvp Nan')
    return hvp

def zero_grad(params):
    """Detach and zero the accumulated ``.grad`` of every tensor in ``params``.

    :param params: a tensor or an iterable of tensors; entries whose ``.grad`` is
        None are skipped.
    """
    if isinstance(params, torch.Tensor):
        params = [params]
    for p in params:
        if p.grad is None:
            continue
        # The out-of-place ``detach()`` returns a new tensor and leaves ``p.grad``
        # attached to its graph; the in-place variants actually cut the history.
        if p.grad.grad_fn is not None:
            p.grad.detach_()
        else:
            p.grad.requires_grad_(False)
        p.grad.zero_()

def conjugate_gradient(grad_x, grad_y, x_params, y_params, b, x=None, nsteps=10, residual_tol=1e-18,
                       lr=1.0, device=torch.device('cpu')):
    '''
    Approximately solve ``A x = b`` with the conjugate gradient method, where

        h_1 = D_yx * p
        h_2 = D_xy * D_yx * p
        A   = I + lr ** 2 * D_xy * D_yx

    and the mixed second derivatives are applied through ``Hvp_vec``.

    :param grad_x: flattened gradient w.r.t. ``x_params`` (with autograd graph)
    :param grad_y: flattened gradient w.r.t. ``y_params`` (with autograd graph)
    :param x_params: parameters matching the dimension of ``b``
    :param y_params: parameters of the other player
    :param b: vec, right-hand side (1-D)
    :param x: optional initial guess; it is copied, never modified in place
    :param nsteps: max number of steps
    :param residual_tol: relative tolerance; iteration stops once
        ``|r|^2 < residual_tol * |b|^2``
    :param lr: step size appearing in ``A``
    :param device: kept for backward compatibility; the solution is allocated
        like ``b`` (same device and dtype)
    :return: tuple ``(A ** -1 * b approximation, number of iterations performed)``
    '''

    def apply_A(v):
        # A v = v + lr^2 * D_xy (D_yx v); both graphs are retained for later calls.
        h_1 = Hvp_vec(grad_vec=grad_x, params=y_params, vec=v, retain_graph=True)
        h_2 = Hvp_vec(grad_vec=grad_y, params=x_params, vec=h_1, retain_graph=True)
        return v + lr * lr * h_2

    b = b.detach()
    if x is None:
        x = torch.zeros_like(b)
        r = b.clone()
    else:
        # Copy the warm start so the caller's tensor is untouched, and start from
        # the true residual b - A x instead of pretending the guess is zero.
        x = x.clone().detach()
        r = b - apply_A(x)
    p = r.clone()
    rdotr = torch.dot(r, r)
    tol = residual_tol * torch.dot(b, b)

    steps_done = 0
    # A zero residual would make the first step size 0 / 0 = NaN.
    if rdotr == 0 or rdotr < tol:
        return x, steps_done

    for _ in range(nsteps):
        Avp_ = apply_A(p)
        alpha = rdotr / torch.dot(p, Avp_)
        x.add_(alpha * p)
        r.add_(-alpha * Avp_)
        new_rdotr = torch.dot(r, r)
        steps_done += 1
        if new_rdotr < tol:
            break
        beta = new_rdotr / rdotr
        p = r + beta * p
        rdotr = new_rdotr
    return x, steps_done

In [None]:
import time

class BCGD(object):
    """Two-player optimizer that alternates which player's update is solved with CG.

    ``max_params`` are moved along ``+lr * cg_x`` and ``min_params`` along
    ``-lr * cg_y``. On each call to :meth:`step` one of the two directions is
    obtained from ``conjugate_gradient`` and the other is derived from it with one
    extra Hessian-vector product; ``solve_x`` is flipped after every step.
    """

    def __init__(self, max_params, min_params, lr=1e-3, weight_decay=0, device=torch.device('cpu'),
                 solve_x=False, collect_info=True):
        """
        :param max_params: tensor or iterable of tensors updated by ascent
        :param min_params: tensor or iterable of tensors updated by descent
        :param lr: step size, also used inside the CG operator
        :param weight_decay: if non-zero, ``weight_decay * p`` is subtracted from
            every parameter before the update is applied
        :param device: forwarded to ``conjugate_gradient``
        :param solve_x: which branch of :meth:`step` runs first
        :param collect_info: store norms / timing for :meth:`getinfo`
        """
        # Stored as lists: the collections are traversed several times per step, so
        # a generator (e.g. ``model.parameters()``) would be empty after first use,
        # and iterating a bare tensor would yield element views instead of the leaf.
        self.max_params = self._as_param_list(max_params)
        self.min_params = self._as_param_list(min_params)
        self.lr = lr
        self.weight_decay = weight_decay
        self.device = device
        self.solve_x = solve_x
        self.collect_info = collect_info

        # Last update directions, reused as CG initial guesses.
        self.old_x = None
        self.old_y = None

        # Statistics reported by getinfo(); None until the first step() has run.
        self.norm_gx = None
        self.norm_gy = None
        self.norm_px = None
        self.norm_py = None
        self.norm_cgx = None
        self.norm_cgy = None
        self.timer = None
        self.iter_num = None

    @staticmethod
    def _as_param_list(params):
        """Normalise a tensor or an iterable of tensors to a list of tensors."""
        if isinstance(params, torch.Tensor):
            return [params]
        return list(params)

    @staticmethod
    def _flat_grad(loss, params):
        """Flattened gradient of ``loss`` w.r.t. ``params``, keeping the graph."""
        grads = autograd.grad(loss, params, create_graph=True, retain_graph=True, allow_unused=True)
        # allow_unused=True returns None for parameters the loss does not touch.
        return torch.cat([
            p.new_zeros(p.numel()) if g is None else g.contiguous().view(-1)
            for g, p in zip(grads, params)
        ])

    @staticmethod
    def _cg_steps(numel):
        """CG iteration budget: one step per 10000 unknowns, but never zero."""
        return max(1, numel // 10000)

    def _apply_update(self, params, direction, step_size):
        """Add ``step_size * direction`` (a flat vector) to ``params`` in place."""
        # Validate before touching anything so a mismatch cannot leave the
        # parameters half-updated.
        if sum(p.numel() for p in params) != direction.numel():
            raise ValueError('CG size mismatch')
        offset = 0
        with torch.no_grad():
            for p in params:
                if self.weight_decay != 0:
                    p.data.add_(- self.weight_decay * p.data)
                count = p.numel()
                p.data.add_(step_size * direction[offset: offset + count].reshape(p.shape))
                offset += count

    def zero_grad(self):
        """Clear the accumulated gradients of both players' parameters."""
        zero_grad(self.max_params)
        zero_grad(self.min_params)

    def getinfo(self):
        """Return ``(norm_gx, norm_gy, norm_px, norm_py, norm_cgx, norm_cgy, timer, iter_num)``.

        Entries are None if :meth:`step` has not been called yet.

        :raises ValueError: if the optimizer was created with ``collect_info=False``.
        """
        if self.collect_info:
            return self.norm_gx, self.norm_gy, self.norm_px, self.norm_py, self.norm_cgx, self.norm_cgy, \
                   self.timer, self.iter_num
        else:
            raise ValueError(
                'No update information stored. Set collect_info=True before call this method')

    def step(self, loss):
        """Perform one update of both players for the scalar ``loss``.

        ``loss`` must be freshly computed from the current parameters and still
        carry its autograd graph.
        """
        grad_x_vec = self._flat_grad(loss, self.max_params)
        grad_y_vec = self._flat_grad(loss, self.min_params)

        hvp_x_vec = Hvp_vec(grad_y_vec, self.max_params, grad_y_vec,
                            retain_graph=True)  # h_xy * d_y
        hvp_y_vec = Hvp_vec(grad_x_vec, self.min_params, grad_x_vec,
                            retain_graph=True)  # h_yx * d_x

        # Right-hand sides of the two linear systems.
        p_x = torch.add(grad_x_vec, - self.lr * hvp_x_vec)
        p_y = torch.add(grad_y_vec, self.lr * hvp_y_vec)
        if self.collect_info:
            self.norm_px = torch.norm(p_x.detach(), p=2)
            self.norm_py = torch.norm(p_y.detach(), p=2)
            start = time.perf_counter()

        if self.solve_x:
            # Solve for the min player's direction, then derive the max player's.
            cg_y, self.iter_num = conjugate_gradient(grad_x=grad_y_vec, grad_y=grad_x_vec,
                                                     x_params=self.min_params,
                                                     y_params=self.max_params, b=p_y, x=self.old_y,
                                                     nsteps=self._cg_steps(p_y.shape[0]),
                                                     lr=self.lr, device=self.device)
            hcg = Hvp_vec(grad_y_vec, self.max_params, cg_y)
            # Detached so the cached warm start does not keep this step's graph alive.
            cg_x = torch.add(grad_x_vec, - self.lr * hcg).detach()
            self.old_x = cg_x
        else:
            # Solve for the max player's direction, then derive the min player's.
            cg_x, self.iter_num = conjugate_gradient(grad_x=grad_x_vec, grad_y=grad_y_vec,
                                                     x_params=self.max_params,
                                                     y_params=self.min_params, b=p_x, x=self.old_x,
                                                     nsteps=self._cg_steps(p_x.shape[0]),
                                                     lr=self.lr, device=self.device)
            hcg = Hvp_vec(grad_x_vec, self.min_params, cg_x)
            cg_y = torch.add(grad_y_vec, self.lr * hcg).detach()
            self.old_y = cg_y

        if self.collect_info:
            self.timer = time.perf_counter() - start

        self._apply_update(self.max_params, cg_x, self.lr)
        self._apply_update(self.min_params, cg_y, - self.lr)

        if self.collect_info:
            self.norm_gx = torch.norm(grad_x_vec.detach(), p=2)
            self.norm_gy = torch.norm(grad_y_vec.detach(), p=2)
            self.norm_cgx = torch.norm(cg_x, p=2)
            self.norm_cgy = torch.norm(cg_y, p=2)
        # Alternate which player's system is solved on the next call.
        self.solve_x = not self.solve_x


In [None]:
import matplotlib.pyplot as plt

x = torch.tensor([0.5], requires_grad=True)  # 1-D vector, max player
y = torch.tensor([0.5], requires_grad=True)  # 1-D vector, min player

# Hyperparameters
eta = 0.2    # learning rate
gamma = 1.0  # consensus rate (not used by BCGD)
alpha = 1.0


def objective(x, y):
    """Objective of the max player for the selected test case."""
    # Test case 1: bilinear game
    return alpha * torch.dot(x, y)
    # Test case 2:
    # return alpha * (torch.dot(x, x) - torch.dot(y, y))
    # Test case 3: TODO


f = objective(x, y)  # opt function
g = -1 * f           # objective of the opposing player

solver = BCGD([x], [y], lr=eta)
epoch = 10

steps = [y.item()]  # trajectory of y, starting at its initial value

for e in range(epoch):
    # The loss must be rebuilt every epoch: the solver changes x and y in place,
    # so a graph created before the loop describes stale parameter values.
    f = objective(x, y)
    solver.step(loss=f)

    steps.append(y.item())

fig, ax = plt.subplots()
ax.plot(steps)
ax.set(title="BCGD on the bilinear game", xlabel="epoch", ylabel="y");

Most of the source code is taken from:

```
@misc{schfer2019implicit,
    title={Implicit competitive regularization in GANs},
    author={Florian Schäfer and Hongkai Zheng and Anima Anandkumar},
    year={2019},
    eprint={1910.05852},
    archivePrefix={arXiv},
    primaryClass={cs.LG}
}
```