# Equilibrium Model

This class of models follows the formula

$$\min_{\lambda=(\gamma, \theta)} \sum_{i=1}^n E_i(w_i(\gamma), \theta)$$
$$w_i(\gamma) = \phi_i(w_i(\gamma),\gamma), \quad i = 1\dots n$$

The experiment in the paper studies
$$\phi_i(w_i, \gamma)=\tanh(Aw_i + B x_i + c)$$

In [1]:
import torch
import torch.nn as nn


In [46]:
class HouseHolderMatrix(nn.Module):
    """Orthogonal matrix parameterized as a product of Householder reflections.

    Each of the ``rank`` learnable column vectors ``v`` defines a reflection
    ``H(v) = I - 2 v v^T / ||v||^2``. Calling the module returns the product
    ``H(v_1) @ H(v_2) @ ... @ H(v_rank)``. Every factor is orthogonal, so the
    product is orthogonal as well.

    Per the original author's note, this is meant as a building block for
    constructing ``A`` so that the dynamic is a contraction.

    Args:
        n_dims: Side length of the square matrix.
        rank: Number of Householder reflections to multiply together.
            ``rank=0`` yields the identity matrix (the empty product).
    """

    def __init__(self, n_dims, rank=3):
        super().__init__()
        self.n_dims = n_dims
        # One (n_dims, 1) column vector per reflection, drawn from N(0, 1).
        self.vectors = nn.ParameterList(
            [nn.Parameter(torch.randn(n_dims, 1)) for _ in range(rank)]
        )
        # Registered as a buffer so the identity follows the module across
        # .to(device) / .to(dtype) calls without being a trainable parameter.
        self.register_buffer("eye", torch.eye(n_dims))

    def forward(self):
        """Return the (n_dims, n_dims) product of all Householder reflections."""
        product = None
        for v in self.vectors:
            reflection = self.householder(v)
            product = reflection if product is None else product @ reflection
        if product is None:
            # No reflections: the empty product is the identity. Return a copy
            # so callers cannot mutate the registered buffer in place.
            return self.eye.clone()
        return product

    def householder(self, v):
        """Return the reflection ``I - 2 v v^T / ||v||^2`` for column vector ``v``.

        ``v`` must be non-zero; an all-zero vector divides 0 by 0 and yields NaNs.
        """
        return self.eye - 2.0 * v @ v.t() / torch.norm(v) ** 2




In [44]:
class MatrixA(nn.Module):
    """Placeholder for the matrix ``A`` built from Householder factors.

    Holds a ``HouseHolderMatrix`` sub-module and a learnable vector ``diag``
    of length ``n_dims`` initialised from a standard normal. How the two are
    combined is not defined yet: ``forward`` is unimplemented.

    Args:
        n_dims: Side length of the square matrix.
        rank: Number of Householder reflections passed to ``HouseHolderMatrix``.
    """

    def __init__(self, n_dims, rank=3):
        super().__init__()
        # Sub-module is created first, then the diagonal entries.
        self.householder = HouseHolderMatrix(n_dims, rank)
        diag_init = torch.randn(n_dims)
        self.diag = nn.Parameter(diag_init)

    def forward(self):
        """Not implemented; always raises ``NotImplementedError``."""
        # NOTE(review): presumably meant to combine self.householder() with
        # self.diag into A -- confirm the intended formula before implementing.
        raise NotImplementedError

tensor([[ 0.6030, -0.0686, -0.7948],
        [ 0.7575, -0.2630,  0.5975],
        [-0.2500, -0.9623, -0.1066]], grad_fn=<MmBackward>)

In [45]:
# Orthogonality check: for an orthogonal matrix Q, Q @ Q.T is the identity.
# NOTE(review): `A` is defined in a cell not shown here -- presumably a
# HouseHolderMatrix instance; confirm.
A()@A().t()

tensor([[ 1.0000e+00, -3.7253e-08,  2.9802e-08],
        [-3.7253e-08,  1.0000e+00, -5.9605e-08],
        [ 2.9802e-08, -5.9605e-08,  1.0000e+00]], grad_fn=<MmBackward>)

In [None]:
class Core(nn.Module):
    """Single update step ``phi(w, x) = tanh(A w + B x)`` of the dynamic.

    ``A`` maps the hidden state to itself and ``B`` maps the input into the
    hidden space. Both are ``nn.Linear`` layers with their default bias, so
    the effective offset ``c`` in the formula is the sum of the two biases.

    Args:
        input_dim: Size of the last dimension of the input ``x``.
        hidden_dim: Size of the last dimension of the state ``w``.
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.input_dim, self.hidden_dim = input_dim, hidden_dim

        # State-to-state map first, then input-to-state map.
        self.A = nn.Linear(hidden_dim, hidden_dim)
        self.B = nn.Linear(input_dim, hidden_dim)

    def forward(self, x, w):
        """Return ``tanh(A(w) + B(x))``.

        Args:
            x: Input tensor whose last dimension is ``input_dim``.
            w: Current state tensor whose last dimension is ``hidden_dim``.
        """
        state_term = self.A(w)
        input_term = self.B(x)
        return (state_term + input_term).tanh()
    
