# Equilibrium Model

This class of models solves the bilevel problem

$$\min_{\lambda=(\gamma, \theta)} \sum_{i=1}^n E_i(w_i(\gamma), \theta)$$
subject to the fixed-point (equilibrium) constraints
$$w_i(\gamma) = \phi_i(w_i(\gamma),\gamma), \quad i = 1,\dots,n$$

The experiment in the paper uses the dynamics
$$\phi_i(w_i, \gamma)=\tanh(Aw_i + B x_i + c), \quad \gamma = (A, B, c)$$

In [1]:
import torch
import torch.nn as nn
from torch import autograd

The following reparameterization of the matrix $A$ guarantees that its spectral norm is strictly less than 1, $\|A\|<1$. Since $\tanh$ is 1-Lipschitz, this makes the dynamics $\phi_i$ a contraction in $w_i$.

In [2]:
class HouseHolderMatrix(nn.Module):
    """Orthogonal matrix built as a product of Householder reflections.

    Each learnable column vector ``v`` of shape ``(n_dims, 1)`` defines the
    reflection ``H(v) = I - 2 v v^T / ||v||^2``. ``forward`` returns
    ``H(v_1) @ H(v_2) @ ... @ H(v_rank)``. A product of reflections is
    orthogonal, so the result has spectral norm 1. ``MatrixA`` uses it to
    rotate a diagonal matrix while keeping control of the norm.
    """

    def __init__(self, n_dims, rank=3):
        super().__init__()
        self.n_dims = n_dims
        # One random (n_dims, 1) vector per reflection.
        self.vectors = nn.ParameterList(
            [nn.Parameter(torch.randn(n_dims, 1)) for _ in range(rank)]
        )
        # Identity kept as a buffer so it moves with the module (e.g. .to()).
        self.register_buffer("eye", torch.eye(n_dims))

    def householder(self, v):
        """Return the reflection ``I - 2 v v^T / ||v||^2`` for column vector ``v``."""
        squared_norm = torch.norm(v) ** 2
        return self.eye - 2. * v @ v.t() / squared_norm

    def forward(self):
        """Multiply the reflections together, left to right (requires rank >= 1)."""
        reflections = [self.householder(vec) for vec in self.vectors]
        product = reflections[0]
        for reflection in reflections[1:]:
            product = product @ reflection
        return product


class MatrixA(nn.Module):
    """Contraction matrix ``A = P @ D @ P`` for the lower-level dynamics.

    ``P`` is an orthogonal product of Householder reflections, and ``D`` is
    diagonal with entries ``sigmoid(diag) * (1 - epsilon)``. For
    ``epsilon < 1`` each entry lies in ``(0, 1 - epsilon)``. Because
    ``||P|| = 1``, the spectral norm satisfies
    ``||A|| <= max(D) < 1 - epsilon``, which is below 1.
    """

    def __init__(self, n_dims, rank=3):
        super().__init__()
        self.householder = HouseHolderMatrix(n_dims, rank)
        # Unconstrained logits of the diagonal; squashed by sigmoid in forward.
        self.diag = nn.Parameter(torch.randn(n_dims, ))

    def forward(self, epsilon=0.8):
        """Return ``A``; larger ``epsilon`` gives a tighter norm bound ``1 - epsilon``."""
        scaled_diag = torch.sigmoid(self.diag) * (1. - epsilon)
        rotation = self.householder()
        return rotation @ torch.diag(scaled_diag) @ rotation

### Lower-level dynamics
This is simply a fixed-point iteration: starting from $w$, apply $w \leftarrow \phi(w, \gamma)$ for $T$ steps.

In [3]:
class DynamicModel(nn.Module):
    """Lower-level solver: an unrolled fixed-point iteration.

    Holds a zero-initialised starting state ``w`` of shape
    ``(n_data, input_dim)`` and applies a caller-supplied step function to it
    ``T`` times.
    """

    def __init__(self, n_data, input_dim, T=10):
        super().__init__()
        self.n_data, self.input_dim, self.T = n_data, input_dim, T
        # Starting point of the iteration. It is registered as a parameter,
        # presumably so callers can differentiate with respect to it.
        self.w = nn.Parameter(torch.zeros(n_data, input_dim))

    def forward(self, x, dynamic_func):
        """Return ``dynamic_func(state, x)`` iterated ``T`` times from ``self.w``.

        With ``T == 0`` the parameter ``self.w`` itself is returned.
        """
        state = self.w
        for _step in range(self.T):
            state = dynamic_func(state, x)
        return state

class HyperModel(nn.Module):
    """Equilibrium model with lower-level dynamics and an upper-level classifier.

    Upper-level variables (see ``hyper_parameters``):
        the contraction matrix ``A``, the input map ``B`` (weight and bias),
        and the linear ``classifier`` head.
    Lower-level variables (see ``parameters``):
        the starting state held by ``dynamic_model``.

    Each lower-level step computes ``w <- tanh(w @ A + B(x))``.
    """

    def __init__(self, n_data, input_dim, hidden_dim, n_classes=10):

        super().__init__()
        # hyperparameters
        self.A = MatrixA(hidden_dim, rank=3)
        self.B = nn.Linear(input_dim, hidden_dim)
        self.classifier = nn.Linear(hidden_dim, n_classes)

        # parameters: the lower-level state has one hidden_dim row per data
        # point. NOTE(review): every x passed to forward presumably needs
        # n_data rows (or a broadcast-compatible shape).
        self.dynamic_model = DynamicModel(n_data=n_data, input_dim=hidden_dim)

        self.criterion = nn.CrossEntropyLoss()

    def dynamic_func(self, w, x):
        """Apply one step of the dynamics, ``tanh(w @ A + B(x))``."""
        return torch.tanh(w @ self.A() + self.B(x))

    def forward(self, x, train=True):
        """Run the fixed-point iteration on ``x``.

        Returns:
            ``(w, logit)``. ``w`` is the final state. ``logit`` is
            ``classifier(w)`` when ``train`` is False and ``None`` otherwise.
        """
        # A and B(x) do not depend on w, so build them once per call instead
        # of once per iteration. Each step still equals dynamic_func(w, x).
        A = self.A()
        Bx = self.B(x)

        def step(w, _x):
            return torch.tanh(w @ A + Bx)

        w = self.dynamic_model(x, step)
        logit = None if train else self.classifier(w)
        return w, logit

    def validation_loss(self, x_val, y_val):
        """Cross-entropy of the classifier on the state reached from ``x_val``."""
        _, logit = self.forward(x_val, train=False)
        return self.criterion(logit, y_val)

    def train_fixed_point_iteration(self, x_train, y_train):
        """Return the lower-level state after iterating on ``x_train``.

        ``y_train`` is accepted for interface symmetry but is not used.
        """
        w, _ = self.forward(x_train)
        return w

    @property
    def hyper_parameters(self):
        """Upper-level variables as a list: A's, B's and the classifier's parameters."""
        return list(self.A.parameters()) + list(self.B.parameters()) + list(self.classifier.parameters())

    @property
    def parameters(self):
        """Lower-level variables (the parameters of ``dynamic_model``) as a list.

        NOTE(review): this property shadows ``nn.Module.parameters()``.
        ``model.parameters()`` therefore fails on this class, as do Module
        methods that call it internally (e.g. ``zero_grad`` in current
        PyTorch). It is kept because callers read ``model.parameters`` as a
        list.
        """
        return list(self.dynamic_model.parameters())

### Prepare data
Use 5000 data points each for training and validation, taken from the MNIST training split.

In [4]:
from torchvision import datasets, transforms
normalize = transforms.Compose([transforms.Normalize((0.1307,), (0.3081,))])
# NOTE(review): `normalize` is defined but not applied to the data in this cell.
# download=True fetches MNIST into `root` only if it is missing. Without it,
# a fresh checkout fails with "Dataset not found".
train_data = datasets.MNIST(root=".", train=True, download=True)

num_train = num_val = 5000
# Train and validation are the first num_train and last num_val examples of
# the training split (disjoint while the split has >= num_train + num_val rows).
x_train, y_train = train_data.data[:num_train], train_data.targets[:num_train]
x_val, y_val = train_data.data[-num_val:], train_data.targets[-num_val:]
def transform_data(x):
    """Scale pixel values by 1/255 and flatten each 28x28 image to a 784-vector."""
    scaled = x / 255.
    return scaled.reshape(-1, 28 * 28)
# Flatten each image to a 784-dim float vector scaled by 1/255
# (pixels presumably uint8 in 0..255, giving values in [0, 1]).
x_train, x_val = transform_data(x_train), transform_data(x_val)

### Create models

In [5]:
# Build the equilibrium model on flattened 28x28 = 784-pixel inputs, plus an
# optimizer that updates only the upper-level (hyper)parameters.
hidden_dim = 200
model = HyperModel(num_train, 28 * 28, hidden_dim)
hyper_optimizer = torch.optim.SGD(params=model.hyper_parameters, lr=0.01)


In [6]:
def evaluate(net=None, x=None, y=None):
    """Compute loss and accuracy without tracking gradients.

    Args:
        net: model to evaluate. It must return ``(state, logit)`` from
            ``net(x, train=False)`` and expose ``criterion``. Defaults to the
            module-level ``model``.
        x: inputs; defaults to the module-level ``x_val``.
        y: integer class labels; defaults to the module-level ``y_val``.

    Returns:
        ``(loss, accuracy)`` as Python floats; accuracy is the fraction of
        rows whose argmax logit equals the label.
    """
    # Conditional expressions only evaluate the global when it is needed.
    net = model if net is None else net
    x = x_val if x is None else x
    y = y_val if y is None else y

    with torch.no_grad():
        _, logit = net(x, train=False)
        loss = net.criterion(logit, y)
        # argmax already yields integer class ids, so compare with y directly.
        correct = (torch.argmax(logit, dim=1) == y).sum()
    return loss.item(), correct.item() / x.shape[0]
    

## Optimization
`BaseHyperOpt` cannot be used here, so the following is a slightly modified implementation.

In [7]:
# Upper-level optimization loop. Each iteration approximates the
# hypergradient of the validation loss with respect to
# model.hyper_parameters by implicit differentiation. The
# inverse-Jacobian-transpose solve is replaced by a K-term Neumann series,
# and the result drives one SGD step.
n_iter = 500
# K: number of Neumann-series terms.
K = 20 
# NOTE(review): the loop variable `iter` shadows the builtin iter().
for iter in range(n_iter):
    # fixed point iteration
    # w: state after T unrolled steps on x_train, starting from the
    # lower-level parameter returned by model.parameters.
    w = model.train_fixed_point_iteration(x_train, y_train)

    # hypergradient step
    # g = d val_loss / d model.parameters. The graph is retained because
    # `direct` below differentiates val_loss again.
    # NOTE(review): val_loss re-runs the dynamics on x_val from the same
    # starting parameter rather than using the train-set state `w`; confirm
    # this matches the intended formulation.
    val_loss = model.validation_loss(x_val, y_val)
    dval_dparam = autograd.grad(
        val_loss, 
        model.parameters,
        retain_graph=True)
    
    # Vector-Jacobian product v -> J^T v with J = dw / d model.parameters.
    # retain_graph=True lets it be called repeatedly on the same graph.
    # NOTE(review): J is the Jacobian of the whole T-step map with respect to
    # its starting point, not of a single step phi at the fixed point as in
    # the usual implicit-differentiation formula; confirm.
    def mvp(v):
        return autograd.grad(
            w,
            model.parameters,
            grad_outputs=v,
            retain_graph=True
        )
    
    # Neumann series: v_{k+1} = J^T v_k + g, approximating (I - J^T)^{-1} g.
    # It converges when J is a contraction.
    v = dval_dparam
    for k in range(K):
        output = mvp(v)
        v = [o_ + e_ for o_, e_ in zip(output, dval_dparam)]
    
    # Indirect term: (dw / d hyper_parameters)^T v. Entries are None for
    # hyperparameters that w does not depend on (e.g. the classifier, which
    # the train-mode forward does not use). Without retain_graph, w's graph
    # is released after this call.
    indirect = autograd.grad(
        w,
        model.hyper_parameters,
        grad_outputs=v,
        allow_unused=True
    )

    # Direct term: dependence of val_loss on the hyperparameters through the
    # validation forward pass.
    direct = autograd.grad(
        val_loss,
        model.hyper_parameters,
        allow_unused=True
    )

    # Sum the two terms, treating a None entry as zero.
    total_grad = []
    for d, i in zip(direct, indirect):
        if d is None and i is None:
            raise RuntimeError("Both of them should not be None")
        elif d is None and i is not None:
            total_grad.append(i)
        elif i is None and d is not None:
            total_grad.append(d)
        else:
            total_grad.append(d+i)
        
    # Install the hypergradient manually, then take one optimizer step.
    hyper_optimizer.zero_grad()
    for p, g in zip(model.hyper_parameters, total_grad):
        p.grad = g
    hyper_optimizer.step()

    # Report validation loss/accuracy every 20 iterations.
    if iter % 20 == 0:
        loss, acc = evaluate()
        print(f"Iter {iter} \t Val loss {loss:.3f} \t Acc: {acc:.4f}")




Iter 0 	 Val loss 2.302 	 Acc: 0.0902
Iter 20 	 Val loss 2.194 	 Acc: 0.3302
Iter 40 	 Val loss 2.092 	 Acc: 0.5572
Iter 60 	 Val loss 1.992 	 Acc: 0.6836
Iter 80 	 Val loss 1.893 	 Acc: 0.7336
Iter 100 	 Val loss 1.795 	 Acc: 0.7614
Iter 120 	 Val loss 1.698 	 Acc: 0.7782
Iter 140 	 Val loss 1.604 	 Acc: 0.7916
Iter 160 	 Val loss 1.513 	 Acc: 0.8020
Iter 180 	 Val loss 1.426 	 Acc: 0.8088
Iter 200 	 Val loss 1.344 	 Acc: 0.8158
Iter 220 	 Val loss 1.268 	 Acc: 0.8206
Iter 240 	 Val loss 1.198 	 Acc: 0.8262
Iter 260 	 Val loss 1.133 	 Acc: 0.8298
Iter 280 	 Val loss 1.075 	 Acc: 0.8358
Iter 300 	 Val loss 1.022 	 Acc: 0.8404
Iter 320 	 Val loss 0.974 	 Acc: 0.8448
Iter 340 	 Val loss 0.930 	 Acc: 0.8472
Iter 360 	 Val loss 0.890 	 Acc: 0.8506
Iter 380 	 Val loss 0.854 	 Acc: 0.8548
Iter 400 	 Val loss 0.821 	 Acc: 0.8580
Iter 420 	 Val loss 0.791 	 Acc: 0.8616
Iter 440 	 Val loss 0.764 	 Acc: 0.8640
Iter 460 	 Val loss 0.739 	 Acc: 0.8680
Iter 480 	 Val loss 0.716 	 Acc: 0.8718
