# Illustration of Forward-Backward-Forward Algorithm
## Min-Max-Problem with box constraints (linear classifier model)

The forward-backward-forward (FBF) algorithm was originally formulated for monotone inclusions and is applicable to variational inequality problems (VIPs).
VIPs also cover the class of zero-sum games, i.e., min-max problems with a common objective in two variables:

$\min\limits_{x \in H} \max\limits_{y \in G} F(x, y)$

This type of problem fits the Wasserstein-GAN formulation with weight clipping (see https://arxiv.org/abs/1701.07875), where $x$ and $y$ are the parameters of the generator and the discriminator network, respectively, and the constraint set $G$ is a d-dimensional cube.

The algorithm applied to this specific setting looks as follows:

$u_k = P_H \left[ x_k - \alpha \nabla_x F(x_k, y_k)\right]$

$v_k = P_G \left[ y_k + \alpha \nabla_y F(x_k, y_k)\right]$

$x_{k+1} = u_k - \alpha \nabla_x F(u_k, v_k) + \alpha \nabla_x F(x_k, y_k)$

$y_{k+1} = v_k + \alpha \nabla_y F(u_k, v_k) - \alpha \nabla_y F(x_k, y_k)$
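The update rules above can be sketched in plain Python on a scalar toy problem. The objective $F(x, y) = \frac{1}{2}x^2 + xy - \frac{1}{2}y^2$ (saddle point $(0, 0)$), the choices $H = \mathbb{R}$ (so $P_H$ is the identity) and $G = [-0.5, 0.5]$ (so $P_G$ is clipping), and the function name `fbf_box` are assumptions made for illustration only. The stepsize $\alpha = 0.1$ is below $1/\sqrt{2}$, the inverse Lipschitz constant of the gradient operator of this toy $F$.

```python
def fbf_box(grad_x, grad_y, x, y, alpha, clip, iters):
    """Hypothetical FBF loop for min_x max_y F(x, y) with x unconstrained
    and y restricted to the interval [-clip, clip]."""

    def proj_g(t):
        # projection onto [-clip, clip] (a 1-dimensional "cube")
        return max(-clip, min(clip, t))

    for _ in range(iters):
        gx, gy = grad_x(x, y), grad_y(x, y)
        # forward-backward step: descent in x, ascent in y, then project
        u = x - alpha * gx          # P_H is the identity here
        v = proj_g(y + alpha * gy)
        # second forward step: gradient at (u, v) minus the gradient at (x, y)
        gu, gv = grad_x(u, v), grad_y(u, v)
        x = u - alpha * gu + alpha * gx
        y = v + alpha * gv - alpha * gy
    return x, y


# toy objective F(x, y) = x**2/2 + x*y - y**2/2
x, y = fbf_box(lambda x, y: x + y,   # grad_x F
               lambda x, y: x - y,   # grad_y F
               x=1.0, y=0.5, alpha=0.1, clip=0.5, iters=300)
print(x, y)  # both values should be close to the saddle point (0, 0)
```

Note that, as in the formulas, only $u_k$ and $v_k$ are guaranteed to be feasible; $x_{k+1}$ and $y_{k+1}$ may leave the constraint set slightly.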

The FBF method converges if $F(x, y)$ is differentiable with Lipschitz continuous gradient, convex in $x$ and concave in $y$, the constraint sets $H$ and $G$ are nonempty, closed and convex, and the stepsize $\alpha$ is smaller than the inverse of the Lipschitz constant of the gradient.

In the absence of constraints (and thus of projections), the method reduces to the so-called "extra-gradient method" (for its application to GANs see https://arxiv.org/abs/1802.10551).
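This reduction follows directly from the update rules: with $P_H$ and $P_G$ equal to the identity, the gradient at $(x_k, y_k)$ cancels in the second forward step.

$u_k = x_k - \alpha \nabla_x F(x_k, y_k), \qquad v_k = y_k + \alpha \nabla_y F(x_k, y_k)$

$x_{k+1} = x_k - \alpha \nabla_x F(x_k, y_k) - \alpha \nabla_x F(u_k, v_k) + \alpha \nabla_x F(x_k, y_k) = x_k - \alpha \nabla_x F(u_k, v_k)$

$y_{k+1} = y_k + \alpha \nabla_y F(x_k, y_k) + \alpha \nabla_y F(u_k, v_k) - \alpha \nabla_y F(x_k, y_k) = y_k + \alpha \nabla_y F(u_k, v_k)$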

The implementation of one step (i.e., getting from $x_{k}$ to $x_{k+1}$) is split into two phases:

1. "extrapolation":
    1. compute update (either via SGD or Adam)
    2. do descent step
    3. store update (e.g., $- \alpha \nabla_x F(x_k, y_k)$)

2. "step":
    1. compute update (either via SGD or Adam)
    2. do descent step and subtract stored update
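The two phases can be sketched as a PyTorch optimiser. This is a minimal sketch assuming plain SGD updates without momentum; the class name `FBFSGDSketch` is hypothetical and this is not the implementation of `optim.FBFSGD`. It only mimics the interface seen in this notebook (`extrapolation()`, `step()` and the `updates_copy` list).

```python
import torch
from torch.optim import Optimizer


class FBFSGDSketch(Optimizer):
    """Hypothetical FBF optimiser with plain SGD updates (no momentum)."""

    def __init__(self, params, lr=0.1):
        super().__init__(params, dict(lr=lr))
        # copies of -lr * grad from the extrapolation phase, one per parameter
        self.updates_copy = []

    @torch.no_grad()
    def extrapolation(self):
        """Phase 1: descent step from x_k to u_k, storing each update."""
        self.updates_copy = []
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    self.updates_copy.append(None)  # keep list aligned with params
                    continue
                update = -group["lr"] * p.grad  # new tensor, independent of p.grad
                p.add_(update)
                self.updates_copy.append(update)

    @torch.no_grad()
    def step(self, closure=None):
        """Phase 2: descent step from u_k, then subtract the stored update."""
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        stored = iter(self.updates_copy)
        for group in self.param_groups:
            for p in group["params"]:
                old_update = next(stored, None)
                if p.grad is None:
                    continue
                p.add_(-group["lr"] * p.grad)  # gradient evaluated at u_k
                if old_update is not None:
                    # x_{k+1} = u_k - lr * grad(u_k) + lr * grad(x_k)
                    p.sub_(old_update)
        self.updates_copy = []  # the stored update is consumed by this step
        return loss
```

It is meant to be used in the same order as in this notebook: backward pass, `extrapolation()`, projection, backward pass at the extrapolated point, `step()`.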
    
Note: The projection (for a d-dimensional cube this amounts to "weight clipping") is done directly in the training script, e.g., "train_fbfadam.py", and is not part of the optimiser class.
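Clipping is indeed the Euclidean projection onto the cube $[-c, c]^d$, because the squared distance splits into independent per-coordinate terms. The helper names `project_cube` and `check_projection` below are hypothetical; in PyTorch the same operation is `p.data.clamp_(-clip, clip)`, as used later in this notebook.

```python
import random


def project_cube(v, clip):
    """Euclidean projection onto [-clip, clip]^d: clip every coordinate."""
    return [max(-clip, min(clip, t)) for t in v]


def check_projection(v, clip, trials=1000, seed=0):
    """Brute-force sanity check: no random point of the cube is closer to v."""
    rng = random.Random(seed)
    p = project_cube(v, clip)
    best = sum((a - b) ** 2 for a, b in zip(v, p))
    for _ in range(trials):
        q = [rng.uniform(-clip, clip) for _ in v]
        if sum((a - b) ** 2 for a, b in zip(v, q)) < best:
            return False
    return True


print(project_cube([0.7, -0.2, -1.3], 0.5))  # [0.5, -0.2, -0.5]
```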

The purpose of this notebook is to illustrate the implementation of the FBF method, in particular the two key methods of the FBF optimiser class. This is done for "FBFSGD" ("FBFAdam" works in a similar fashion) and for a linear classifier model (one fully connected layer without activation).
Only one component is shown, since the algorithm treats both components the same way apart from the opposite sign of the objective function.

In [1]:
import torch
import torch.nn as nn

In [2]:
# define toy instance of a neural network
class LinClas(nn.Module):
    def __init__(self):
        super(LinClas, self).__init__()
        self.fc = nn.Linear(5, 1)

    def forward(self, x):
        x = self.fc(x)
        return x

In [3]:
# define toy example of a loss function
def loss(x):
    return x*x

In [4]:
# print function for optimiser (weights, gradient and copy of update)
def print_opt():
    for group in opt.param_groups:
        for p in group['params']:
            print(f"Weights\n{p}")
            print(f"Gradient\n{p.grad}\n")
    print(f"Updates_Copy:\n{opt.updates_copy}\n")

In [5]:
# radius of d-dimensional cube
clip = 0.5

# input of the neural network (a single sample with 5 features, used as the whole batch)
inp = torch.Tensor([-0.1, 0., 0.1, 0.2, 0.3])

First we instantiate a fully connected (1-layer) neural network and have a look at the initial weights.

In [6]:
A = LinClas()
A.state_dict()

OrderedDict([('fc.weight',
              tensor([[ 0.1992, -0.4454, -0.2189,  0.4182, -0.0485]])),
             ('fc.bias', tensor([-0.4138]))])

Now we import the FBF optimiser class "FBFSGD" and set up an instance with a fixed stepsize $\alpha$ (argument `lr`). To make it easy to follow what happens, no additional features such as momentum or Nesterov acceleration are enabled.
To check that all the parameters of the network are tracked, make use of `print_opt()`.

In [7]:
from optim import FBFSGD
opt = FBFSGD(A.parameters(), lr = 0.1)
print_opt()

Weights
Parameter containing:
tensor([[ 0.1992, -0.4454, -0.2189,  0.4182, -0.0485]], requires_grad=True)
Gradient
None

Weights
Parameter containing:
tensor([-0.4138], requires_grad=True)
Gradient
None

Updates_Copy:
[]



### Computation of gradient

Compute the output of the network for the given input and the resulting loss.

In [8]:
outp = A(inp)
print(f"Output: {outp}")
lc_loss = loss(outp)
print(f"Loss: {lc_loss}")

Output: tensor([-0.3865], grad_fn=<AddBackward0>)
Loss: tensor([0.1494], grad_fn=<MulBackward0>)


Clear any previously stored gradients.

In [9]:
opt.zero_grad()
print_opt()

Weights
Parameter containing:
tensor([[ 0.1992, -0.4454, -0.2189,  0.4182, -0.0485]], requires_grad=True)
Gradient
None

Weights
Parameter containing:
tensor([-0.4138], requires_grad=True)
Gradient
None

Updates_Copy:
[]



Backpropagate the loss through the network to get gradients with respect to each weight.

In [10]:
lc_loss.backward()
print_opt()

Weights
Parameter containing:
tensor([[ 0.1992, -0.4454, -0.2189,  0.4182, -0.0485]], requires_grad=True)
Gradient
tensor([[ 0.0773,  0.0000, -0.0773, -0.1546, -0.2319]])

Weights
Parameter containing:
tensor([-0.4138], requires_grad=True)
Gradient
tensor([-0.7731])

Updates_Copy:
[]
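The printed gradients can be checked by hand: for $\text{out} = w^\top x + b$ and loss $\text{out}^2$, the chain rule gives $\nabla_w = 2\,\text{out}\,x$ and $\nabla_b = 2\,\text{out}$. The pure-Python helper `linear_sq_loss_grad` below is hypothetical and uses the printed (rounded) weights, so agreement is only up to about four decimals.

```python
def linear_sq_loss_grad(w, b, x):
    """Forward pass and gradient of loss = out**2 for out = w . x + b."""
    out = sum(wi * xi for wi, xi in zip(w, x)) + b
    dloss_dout = 2.0 * out                   # derivative of out**2
    grad_w = [dloss_dout * xi for xi in x]   # d out / d w_i = x_i
    grad_b = dloss_dout                      # d out / d b = 1
    return out, grad_w, grad_b


out, grad_w, grad_b = linear_sq_loss_grad(
    [0.1992, -0.4454, -0.2189, 0.4182, -0.0485], -0.4138,
    [-0.1, 0.0, 0.1, 0.2, 0.3])
print(out, grad_w, grad_b)  # approx. -0.3865, [0.0773, 0, -0.0773, -0.1546, -0.2319], -0.7731
```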



### Extrapolation

In [11]:
opt.extrapolation()
print_opt()

Weights
Parameter containing:
tensor([[ 0.1914, -0.4454, -0.2112,  0.4337, -0.0253]], requires_grad=True)
Gradient
tensor([[ 0.0773,  0.0000, -0.0773, -0.1546, -0.2319]])

Weights
Parameter containing:
tensor([-0.3365], requires_grad=True)
Gradient
tensor([-0.7731])

Updates_Copy:
[tensor([[-0.0077, -0.0000,  0.0077,  0.0155,  0.0232]]), tensor([0.0773])]
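As a check on the bias entry (values rounded to four decimals, $\alpha = 0.1$), the extrapolation computes $u_k = x_k - \alpha \nabla_x F(x_k, y_k)$, and the stored copy is exactly the update $-\alpha \nabla_x F(x_k, y_k)$:

$-\alpha \nabla_b = -0.1 \cdot (-0.7731) = 0.0773, \qquad u_b = -0.4138 + 0.0773 = -0.3365$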



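### Projection

Project the extrapolated weights onto the cube $[-\text{clip}, \text{clip}]^d$ with clip = 0.5 by clipping. Here all weights already lie inside the cube, so they remain unchanged.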

In [12]:
for p in A.parameters():
    p.data.clamp_(-clip, clip)
print_opt()

Weights
Parameter containing:
tensor([[ 0.1914, -0.4454, -0.2112,  0.4337, -0.0253]], requires_grad=True)
Gradient
tensor([[ 0.0773,  0.0000, -0.0773, -0.1546, -0.2319]])

Weights
Parameter containing:
tensor([-0.3365], requires_grad=True)
Gradient
tensor([-0.7731])

Updates_Copy:
[tensor([[-0.0077, -0.0000,  0.0077,  0.0155,  0.0232]]), tensor([0.0773])]



### Computation of gradient

Compute the output of the network for the given input and the resulting loss.

In [13]:
outp = A(inp)
print(f"Output: {outp}")
lc_loss = loss(outp)
print(f"Loss: {lc_loss}")

Output: tensor([-0.2976], grad_fn=<AddBackward0>)
Loss: tensor([0.0886], grad_fn=<MulBackward0>)


Clear any previously stored gradients.

In [14]:
opt.zero_grad()
print_opt()

Weights
Parameter containing:
tensor([[ 0.1914, -0.4454, -0.2112,  0.4337, -0.0253]], requires_grad=True)
Gradient
tensor([[0., 0., 0., 0., 0.]])

Weights
Parameter containing:
tensor([-0.3365], requires_grad=True)
Gradient
tensor([0.])

Updates_Copy:
[tensor([[-0.0077, -0.0000,  0.0077,  0.0155,  0.0232]]), tensor([0.0773])]



Backpropagate the loss through the network to get gradients with respect to each weight.

In [15]:
lc_loss.backward()
print_opt()

Weights
Parameter containing:
tensor([[ 0.1914, -0.4454, -0.2112,  0.4337, -0.0253]], requires_grad=True)
Gradient
tensor([[ 0.0595,  0.0000, -0.0595, -0.1191, -0.1786]])

Weights
Parameter containing:
tensor([-0.3365], requires_grad=True)
Gradient
tensor([-0.5953])

Updates_Copy:
[tensor([[-0.0077, -0.0000,  0.0077,  0.0155,  0.0232]]), tensor([0.0773])]



### Step

In [16]:
opt.step()
print_opt()

Weights
Parameter containing:
tensor([[ 0.1932, -0.4454, -0.2129,  0.4301, -0.0307]], requires_grad=True)
Gradient
tensor([[ 0.0595,  0.0000, -0.0595, -0.1191, -0.1786]])

Weights
Parameter containing:
tensor([-0.3543], requires_grad=True)
Gradient
tensor([-0.5953])

Updates_Copy:
[]
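As a check on the bias entry (values rounded to four decimals), the step computes $x_{k+1} = u_k - \alpha \nabla_x F(u_k, v_k) + \alpha \nabla_x F(x_k, y_k)$, where adding the last term amounts to subtracting the stored update:

$x_{k+1} = -0.3365 - 0.1 \cdot (-0.5953) - 0.0773 = -0.3543$

Because the projection was inactive, this coincides with the extra-gradient update $x_k - \alpha \nabla_x F(u_k, v_k) = -0.4138 + 0.0595 = -0.3543$. The stored update is consumed by the step, so `Updates_Copy` is empty again.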

