# Imports

In [1]:
import os, sys
import logging
import argparse
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from datetime import datetime
import seaborn as sns
from scipy import stats

#@title Imports external sources
import os
import io
import PIL.Image, PIL.ImageDraw, PIL.ImageFont
import base64
import zipfile
import json
import requests
import matplotlib.pylab as pl
import numpy as np
import glob
import requests
import random as pyrandom
from concurrent import futures
from functools import partial
from scipy.ndimage import rotate
from IPython.display import Image, HTML, clear_output
from tqdm import tqdm_notebook, tnrange
import time
from typing import Any, MutableMapping, NamedTuple, Tuple, Optional

# !pip install --quiet --upgrade jax
# !pip install --quiet --upgrade jaxlib 
import jax
from jax import grad, jit, vmap
import jax.numpy as jnp

# !pip install --quiet -U dm-haiku
# !pip install --quiet -U optax
import haiku as hk
import optax
import math
# !pip install --quiet -U ml_collections
from ml_collections import config_dict
import matplotlib.pylab as pl
import matplotlib.colors as mcolors
colors = pl.colormaps['Dark2'] 

# !pip install ipython-autotime

markers = ['o', 'v', '^', '*',  's', 'D']
colors = ['blue', 'red', 'green', 'purple']
# sns.set_theme(style="darkgrid")
sns.set_theme(style="ticks")

fontsize=30

game = 'Chicken Game' #@param ["Matching Pennies", "IPD", "Ultimatum", "Chicken Game", "Tandem", "Balduzzi", "Hamiltonian"]

# Import Internal Sources

In [2]:
from src.transformer import Transformer
from src.data import create_reg_data_classic_token, create_weights
from src.config import config
from src.train import *

# Hyperparameters

In [3]:
GAMMA = 0.96 #@param {type:"number"}
NUM_RUNS =  10#@param {type:"number"}
NUM_EPOCHS =  500#@param {type:"number"}
ALPHA =  5.0#@param {type:"number"}
BETA =   5.0#@param {type:"number"}
SMOOTHING = 0.99

if game == 'IPD':
  INPUT_DIM = 10
  STD = 0.1
else:
  INPUT_DIM = 2
  STD =  1.0

if game in ['Tandem', 'Balduzzi', 'Hamiltonian']:
  BATCH_SIZE =  8
  NUM_INNERLOOP_SHORT =  120000
  NUM_INNERLOOP_LONG =  120000
  NUM_NODES =  8
  interval = 1
  LR_SCHEDULER = 0.8
  LR = 1e-1
else:
  BATCH_SIZE =  64
  NUM_INNERLOOP_SHORT =  80000
  NUM_INNERLOOP_LONG =  80000
  NUM_NODES =  16
  interval = 7
  LR_SCHEDULER = 1.0
  LR=0.001

OUTPUT_DIM=INPUT_DIM//2

hyper_params = {
    "gamma": GAMMA,
    "num_runs": NUM_RUNS,
    "num_epochs": NUM_EPOCHS,
    "alpha": ALPHA,
    "std": STD,
    "batch_size": BATCH_SIZE,
    "num_innerloop_short": NUM_INNERLOOP_SHORT,
    "num_innerloop_long": NUM_INNERLOOP_LONG,
    "num_nodes": NUM_NODES,
    "beta": BETA,
    "interval": interval,
    "input_dim": INPUT_DIM,
    "output_dim": OUTPUT_DIM,
    "lr_scheduler": LR_SCHEDULER,
    "lr": LR,
    "smoothing": SMOOTHING

}

# Game Defintions

In [4]:
def tandem():
    dims = [1, 1]

    def Ls(th):
        x, y = th
        # Tandem loss (quadratic loss for moving forward + linear penalty for pedalling backwards)
        L_1 = (x + y)**2 - 2.0 * x
        L_2 = (x + y)**2 - 2.0 * y
        return [L_1, L_2]
    return dims, Ls


def tandem_cubed():
    dims = [1, 1]

    def Ls(th):
        x, y = th
        # Tandem loss (quadratic loss for moving forward + linear penalty for pedalling backwards)
        L_1 = (x + y)**4 - 2.0 * x
        L_2 = (x + y)**4 - 2.0 * y
        return [L_1, L_2]
    return dims, Ls


def ultimatum():
  dims = [1, 1]
  def Ls(th):
    x, y = th
    p_fair = jax.nn.sigmoid(x)
    p_accept = jax.nn.sigmoid(y)
    L_1 = -(5*p_fair + 8*(1-p_fair)*p_accept)
    L_2 = -(5*p_fair + 2*(1-p_fair)*p_accept)
    return [L_1, L_2]
  return dims, Ls


def balduzzi():
  dims = [1, 1]
  def Ls(th):
    x, y = th
    L_1 = 0.5*(x**2) + 10*x*y
    L_2 = 0.5*(y**2) - 10*x*y
    return [L_1, L_2]
  return dims, Ls


def hamiltonian_game():
  dims=[1, 1]
  def Ls(th):
    x, y = th
    L_1 = x*y
    L_2 = -x*y
    return [L_1, L_2]
  return dims, Ls


def matching_pennies():
  dims = [1, 1]
  payout_mat_1 = jnp.array([[1,-1],[-1,1]])
  payout_mat_2 = -payout_mat_1
  def Ls(th):
    p_1, p_2 = jax.nn.sigmoid(th[0]), jax.nn.sigmoid(th[1])
    x, y = jnp.concatenate([p_1, 1-p_1]), jnp.concatenate([p_2, 1-p_2])
    L_1 = jnp.matmul(jnp.matmul(x, payout_mat_1), y)
    L_2 = jnp.matmul(jnp.matmul(x, payout_mat_2), y)
    return [L_1, L_2]
  return dims, Ls


def matching_pennies_batch(batch_size=128):
  dims = [1, 1]
  payout_mat_1 = jnp.array([[1,-1],[-1,1]])
  payout_mat_2 = -payout_mat_1
  payout_mat_1 = payout_mat_1.reshape((1, 2, 2)).repeat(batch_size, 1, 1)
  payout_mat_2 = payout_mat_2.reshape((1, 2, 2)).repeat(batch_size, 1, 1)
  def Ls(th):
    p_1, p_2 = jax.nn.sigmoid(th[0]), jax.nn.sigmoid(th[1])
    x, y = jnp.concatenate([p_1, 1-p_1], dim=-1), jnp.concatenate([p_2, 1-p_2], dim=-1)
    L_1 = jnp.matmul(jnp.matmul(x.unsqueeze(1), payout_mat_1), y.unsqueeze(-1))
    L_2 = jnp.matmul(jnp.matmul(x.unsqueeze(1), payout_mat_2), y.unsqueeze(-1))
    return [L_1.squeeze(-1), L_2.squeeze(-1)]
  return dims, Ls


def chicken_game():
  dims = [1, 1]
  payout_mat_1 = jnp.array([[0, -1],[1, -100]])
  payout_mat_2 = jnp.array([[0, 1],[-1, -100]])
  def Ls(th):
    p_1, p_2 = jax.nn.sigmoid(th[0]), jax.nn.sigmoid(th[1])
    x, y = jnp.concatenate([p_1, 1-p_1]), jnp.concatenate([p_2, 1-p_2])
    L_1 = -jnp.matmul(jnp.matmul(x, payout_mat_1), y)
    L_2 = -jnp.matmul(jnp.matmul(x, payout_mat_2), y)
    return [L_1, L_2]
  return dims, Ls


def chicken_game_batch(batch_size=128):
  dims = [1, 1]
  payout_mat_1 = jnp.array([[0, -1],[1, -100]])
  payout_mat_2 = jnp.array([[0, 1],[-1, -100]])
  payout_mat_1 = payout_mat_1.reshape((1, 2, 2)).repeat(batch_size, 1, 1)
  payout_mat_2 = payout_mat_2.reshape((1, 2, 2)).repeat(batch_size, 1, 1)
  def Ls(th):
    p_1, p_2 = jax.nn.sigmoid(th[0]), jax.nn.sigmoid(th[1])
    x, y = jnp.concatenate([p_1, 1-p_1], dim=-1), jnp.concatenate([p_2, 1-p_2], dim=-1)
    L_1 = -jnp.matmul(jnp.matmul(x.unsqueeze(1), payout_mat_1), y.unsqueeze(-1))
    L_2 = -jnp.matmul(jnp.matmul(x.unsqueeze(1), payout_mat_2), y.unsqueeze(-1))
    return [L_1.squeeze(-1), L_2.squeeze(-1)]
  return dims, Ls
  

def ipd_batched(gamma=0.96):
  dims = [5, 5]
  payout_mat_1 = jnp.array([[-1,-3],[0,-2]])
  payout_mat_2 = payout_mat_1.T
  payout_mat_1 = payout_mat_1.reshape((1, 2, 2)).repeat(hyper_params['batch_size'], 1, 1)
  payout_mat_2 = payout_mat_2.reshape((1, 2, 2)).repeat(hyper_params['batch_size'], 1, 1)
  def Ls(th):
    p_1_0 = jax.nn.sigmoid(th[0][:, 0:1])
    p_2_0 = jax.nn.sigmoid(th[1][:, 0:1])
    p = jnp.concatenate([p_1_0*p_2_0, p_1_0*(1-p_2_0), (1-p_1_0)*p_2_0, (1-p_1_0)*(1-p_2_0)], dim=-1)
    p_1 = jnp.reshape(jax.nn.sigmoid(th[0][:, 1:5]), (hyper_params['batch_size'], 4, 1))
    p_2 = jnp.reshape(jax.nn.sigmoid(th[1][:, 1:5]), (hyper_params['batch_size'], 4, 1))
    P = jnp.concatenate([p_1*p_2, p_1*(1-p_2), (1-p_1)*p_2, (1-p_1)*(1-p_2)], dim=-1)
    x = jnp.eye(4).reshape((1, 4, 4))
    eyes = x.repeat(hyper_params['batch_size'], 1, 1)


    M = -jnp.matmul(p.unsqueeze(1), jnp.linalg.inv(jnp.eye(4)-gamma*P))
    L_1 = jnp.matmul(M, jnp.reshape(payout_mat_1, (hyper_params['batch_size'], 4, 1)))
    L_2 = jnp.matmul(M, jnp.reshape(payout_mat_2, (hyper_params['batch_size'], 4, 1)))
    return [L_1.squeeze(-1), L_2.squeeze(-1)]
  return dims, Ls


def ipd(gamma=0.96):
  dims = [5, 5]
  payout_mat_1 = jnp.array([[-1,-3],[0,-2]])
  payout_mat_2 = payout_mat_1.T
  def Ls(th):
    p_1_0 = jax.nn.sigmoid(th[0][0:1])
    p_2_0 = jax.nn.sigmoid(th[1][0:1])
    p = jnp.concatenate([p_1_0*p_2_0, p_1_0*(1-p_2_0), (1-p_1_0)*p_2_0, (1-p_1_0)*(1-p_2_0)])
    p_1 = jnp.reshape(jax.nn.sigmoid(th[0][1:5]), (4, 1))
    p_2 = jnp.reshape(jax.nn.sigmoid(th[1][1:5]), (4, 1))
    P = jnp.concatenate([p_1*p_2, p_1*(1-p_2), (1-p_1)*p_2, (1-p_1)*(1-p_2)], axis=1)
    M = -jnp.matmul(p, jnp.linalg.inv(jnp.eye(4)-gamma*P))
    L_1 = jnp.matmul(M, jnp.reshape(payout_mat_1, (4, 1)))
    L_2 = jnp.matmul(M, jnp.reshape(payout_mat_2, (4, 1)))
    return [L_1, L_2]
  return dims, Ls

# Networks

In [5]:
class NetRelu(hk.Module):
    def __init__(
            self,
            hidden_dim: int,
            output_dim: int,
            name: Optional[str] = None,
    ):
        super().__init__(name=name)
        
        self.linear1 = hk.Linear(hidden_dim,
                                w_init=hk.initializers.RandomNormal(mean=0, stddev=0.1),
                                b_init=hk.initializers.Constant(0),
                                )
        self.linear2 = hk.Linear(hidden_dim,
                                w_init=hk.initializers.RandomNormal(mean=0, stddev=0.1),
                                b_init=hk.initializers.Constant(0),
                                )
        self.linear3 = hk.Linear(output_dim,
                                w_init=hk.initializers.RandomNormal(mean=0, stddev=0.1),
                                b_init=hk.initializers.Constant(0),
                                )
        self.relu = jax.nn.ReLU()
    
    def __call__(self, x):
        x = self.linear1(x)
        x = self.relu(x)
        x = self.linear2(x)
        x = self.linear3(x)
        return x
    
class NetTanh(hk.Module):
    def __init__(self,
            hidden_dim: int,
            output_dim: int,
            name: Optional[str] = None,
                 ):
        super().__init__(name=name)
        self.main_body = hk.nets.MLP(
            [hidden_dim, hidden_dim, hidden_dim, output_dim],
            w_init=hk.initializers.RandomNormal(mean=0, stddev=0.1),
            b_init=hk.initializers.Constant(0),
            activate_final=False,
            activation=jnp.tanh,
        )
    def __call__(self, x):
        return self.main_body(x)
    


In [6]:
def make_tanh_network():
    def forward_fn(x):
        return NetTanh(hyper_params['num_nodes'], hyper_params['output_dim'])(x)
    network = hk.without_apply_rng(hk.transform(forward_fn))
    return network

In [7]:
_, Ls = ipd()
# def Ls(th):
#     agent1, agent2 = th
#     L_1 = agent1*agent2
#     L_2 = -(agent1*agent2)
#     return [L_1, L_2]

def loss1_fn(theta1, theta2):
    L1, _ = Ls([theta1, theta2])
    # print(L1.shape)
    return L1[0]

def loss2_fn(theta1, theta2):
    _, L2 = Ls([theta1, theta2])
    # print(L2.shape)
    return L2[0]

def lola_loss1_fn(theta1, theta2):
    theta2_prime = theta2 - lr2 * grad(loss2_fn, argnums=1)(theta1, theta2)
    L1, _ = Ls([theta1, theta2_prime])
    # print(L1.shape)
    return L1[0]

def lola_loss2_fn(theta1, theta2):
    theta1_prime = theta1 - lr1 * grad(loss1_fn, argnums=0)(theta1, theta2)
    _, L2 = Ls([theta1_prime, theta2])
    # print(L1.shape)
    return L2[0]

def update(theta1, theta2, lr1=0.1, lr2=0.1):
    return theta1 - lr1 * grad(loss1_fn, argnums=0)(theta1, theta2), theta2 - lr2 * grad(loss2_fn, argnums=1)(theta1, theta2)


def update_lola(theta1, theta2, lr1=0.1, lr2=0.1):
    # theta1_prime = theta1 - lr1 * grad(loss1_fn, argnums=0)(theta1, theta2)
    # theta2_prime = theta2 - lr2 * grad(loss2_fn, argnums=1)(theta1, theta2)
    return theta1 - lr1 * grad(lola_loss1_fn, argnums=0)(theta1, theta2), theta2 - lr2 * grad(lola_loss2_fn, argnums=1)(theta1, theta2)

rng = jax.random.PRNGKey(400)
# network = make_tanh_network()
# theta1 = network.init(rng, jnp.zeros((hyper_params['input_dim'],)))
# theta2 = network.init(rng, jnp.zeros((hyper_params['input_dim'],)))
theta1 = jax.random.normal(rng, (5,))
rng, rng2 = jax.random.split(rng)
theta2 = jax.random.normal(rng2, (5,))
lr1 = lr2 = 1.0
for i in range(100):
    theta1, theta2 = update_lola(theta1, theta2, lr1=lr1, lr2=lr2)
    print("Loss 1: ", (1-hyper_params['gamma'])*loss1_fn(theta1, theta2))
    print("Loss 2: ", (1-hyper_params['gamma'])*loss2_fn(theta1, theta2))
# grad(loss1_fn, argnums=0)(theta1, theta2)

Loss 1:  1.7698076
Loss 2:  1.8566042
Loss 1:  1.5100355
Loss 2:  1.618503
Loss 1:  1.4860876
Loss 2:  1.584396
Loss 1:  1.4730017
Loss 2:  1.5651535
Loss 1:  1.4641087
Loss 2:  1.5523251
Loss 1:  1.4567251
Loss 2:  1.5422539
Loss 1:  1.4488493
Loss 2:  1.5324996
Loss 1:  1.4373871
Loss 2:  1.5198469
Loss 1:  1.4127336
Loss 2:  1.4949945
Loss 1:  1.3205707
Loss 2:  1.4029701
Loss 1:  1.1003684
Loss 2:  1.1621422
Loss 1:  1.0713732
Loss 2:  1.1465768
Loss 1:  1.0320722
Loss 2:  1.1208047
Loss 1:  1.0029134
Loss 2:  1.1014441
Loss 1:  0.9895268
Loss 2:  1.0942748
Loss 1:  0.982332
Loss 2:  1.0921891
Loss 1:  0.97750777
Loss 2:  1.0922587
Loss 1:  0.9737406
Loss 2:  1.0935657
Loss 1:  0.9704256
Loss 2:  1.0958728
Loss 1:  0.9671755
Loss 2:  1.0993227
Loss 1:  0.96358985
Loss 2:  1.1044354
Loss 1:  0.9590684
Loss 2:  1.1124297
Loss 1:  0.95236516
Loss 2:  1.126219
Loss 1:  0.94028753
Loss 2:  1.15402
Loss 1:  0.9133317
Loss 2:  1.2252698
Loss 1:  0.9006468
Loss 2:  1.4567435
Loss 1:  1.448

# Transformers Playing Games

In [8]:
import dataclasses
from typing import Optional

import haiku as hk
import jax
import jax.numpy as jnp
from src.attn import (MLP,
                      MultiHeadAttention,
                      TokenVocab,
                      create_pos_encoding,
                      LNorm,
                      layer_norm)


@dataclasses.dataclass
class Transformer(hk.Module):
  """A flexible Transformer implementation.
  """

  def __init__(
      self,
      num_heads: int = 2,
      widening_factor: int = 4,
      num_layers: int = 3,
      key_size: int = 5,
      embedding_size: int = 64,
      output_size: int = 1,
      in_context_length: int = 17,
      in_context_length_test: int = 17,
      test_points: int = 1,
      dropout_rate: float = 0,
      only_attention: bool = True,
      use_layer_norm: bool = True,
      use_pe: bool = True,
      pe_size: int = 6,
      concat_pe: bool = False,
      output_mapping: bool = False,
      input_mapping: bool = False,
      use_bias_p: bool = True,
      zero_embeddings: bool = False,
      deq: bool = True,
      init_scale: float = 0.02,
      use_softmax: bool = False,
      use_non_lin_mix: bool = False,
      first_layer_sm: bool = False,
      y_update: bool = False,
      input_mlp: bool = False,
      input_mlp_out_dim: int = 0,
      gd_mlp_config: bool = False,
      sum_norm: bool = False,
      dampening: float = 1.0,
      clip: float = 0.0,
      ana_copy: bool = False,
      flip: bool = False,
      vocab_size: int = 0,
      vocab_token_dim: int = 0,
      vocab_init: int = 0.01,
      return_logits: bool = False,
      include_query: bool = False,
      name: Optional[str] = None,
  ):


    """Initialises the module.

    Args:
      num_heads: Number of heads in the self-attention module.
      widening_factor: Blow up in the hidden layer of MLP.
      num_layers: Number of transformer layers, usually one due DEQ behaviour.
      key_size: Key and querie size.
      embedding_size: Embedding size.
      output_size: Output size.
      in_context_length: Sequence length.
      test_points: Number of test points.
      dropout_rate: Optional dropout layer with rate dropout_rate if not None.
      only_attention: Only the attention layer without the MLP.
      use_layer_norm: Use layer norm or not.
      use_pe: Use positional encoding. 
      pe_size: Positional encoding size.
      concat_pe: Concat pe.
      output_mapping: Use output mapping.
      input_mapping: Use input mapping.
      lin_proj_after_att: Linear projection after attention layer.
      use_bias_p: Use bias parameter in the linear operations in the network.
      zero_embeddings: Use zero embeddings.
      full_state_update: Update context tokens or only querry.
      deq: Use recurrent transformer.
      y_update: Update only output states e.g. as in gradient descent.
      input_mlp: Use MLP instead of linear embedding.
      input_mlp_out_dim: Output dim of input MLP.
      gd_mlp_config: Gradient descent special MLP config.
      sum_norm: Use sum normalization from Schlag et. al 2012
      dampening: Dampen forward dynamics
      clip: Clip the activations to some value
      ana_copy: Return full prediction stack instead of last entry.
      include_query: Include query vector in computation.
      name : Optional name for this module.
    """

    super().__init__(name=name)
    self.num_heads = num_heads
    self.widening_factor = widening_factor
    self.num_layers = num_layers
    self.key_size = key_size
    self.dropout_rate = dropout_rate
    self.only_attention = only_attention
    self.use_layer_norm = use_layer_norm
    self.use_pe = use_pe
    self.pe_size = pe_size
    self.concat_pe = concat_pe
    self.output_mapping = output_mapping
    self.input_mapping = input_mapping
    self.use_bias_p = use_bias_p
    self.embedding_size = embedding_size
    self.output_size = output_size
    self.in_context_length = in_context_length
    self.in_context_length_test = in_context_length_test
    self.zero_embeddings = zero_embeddings
    self.init_scale = init_scale
    self.use_softmax = use_softmax
    self.use_non_lin_mix = use_non_lin_mix
    self.first_layer_sm = first_layer_sm
    self.deq = deq
    self.y_update = y_update
    self.input_mlp = input_mlp
    self.input_mlp_out_dim = input_mlp_out_dim
    self.gd_mlp_config = gd_mlp_config
    self.sum_norm = sum_norm
    self.dampening = dampening
    self.clip = clip
    self.ana_copy = ana_copy
    self.vocab_size = vocab_size
    self.vocab_token_dim = vocab_token_dim
    self.vocab_init = vocab_init
    self.return_logits = return_logits
    self.include_query = include_query

    if pe_size > 0:
      self.pos_encoding = create_pos_encoding(in_context_length, pe_size, flip)
      self.pos_encoding_test = create_pos_encoding(in_context_length_test,
                                                   pe_size, flip)
    else:
      self.pos_encoding = None

  def trans_block(self, h, nl):
    # First the attention block.

    if self.deq:
      h_norm = self.lnorm1(h) if self.use_layer_norm else h
      if not self.include_query:
        key = h_norm[:, :-1, :]
        value = h_norm[:, :-1, :]
      else:
        key = h_norm
        value = h_norm

      h_attn, att_map =self.attn_block(h_norm,key,value)

    h_attn = hk.dropout(hk.next_rng_key(), self.dropout_rate, h_attn)

    h = h + self.dampening*h_attn

    if self.clip > 0:
      h = jnp.clip(h, -self.clip, self.clip)

    return h, att_map

  def __call__(
      self,
      x: jnp.ndarray,
      is_training: bool,
      predict_test: bool
  ) -> jnp.ndarray:

    """Computes the transformer forward pass.

    Args:
      x: Inputs.
      is_training: Whether we're training or not.
      predict_test: Test or train prediction.
    Returns:
      Array of shape [B, T, H].
    """

    self.w_init = hk.initializers.VarianceScaling(self.init_scale)
    self.dropout_rate = self.dropout_rate if is_training else 0.

    embeddings = x

    h = embeddings

    if len(h.shape) == 2:
      _, model_size = h.shape
    elif len(h.shape) == 3:
      _, _, model_size = h.shape
    self.model_size = model_size
    if self.deq:
      self.attn_block = MultiHeadAttention(num_heads=self.num_heads,
                                           key_size=self.key_size,
                                           model_size=model_size,
                                           w_init=self.w_init,
                                           use_softmax=self.use_softmax,
                                           use_non_lin_mix=self.use_non_lin_mix,
                                           use_bias_p=self.use_bias_p,
                                           sum_normalization=self.sum_norm
                                           )

    st = h[:, -1, -1]*(-1.0) if not self.ana_copy else (h if self.include_query else h[:, :-1, :])
    stack_h = [] if not self.input_mlp else [st]
    stack_att = []
    for nl in range(self.num_layers):
      h, att_map = self.trans_block(h, nl)
      # intermediate readout of test prediction
      st = h[:, -1, -1]*(-1.0) if not self.ana_copy else (h if self.include_query else h[:, :-1, :])
      stack_h.append(st)
      stack_att.append(att_map)
    out = hk.Linear(self.output_size)(h) if self.output_mapping else h

    return(out, stack_h, stack_att)

In [9]:
_Metrics = MutableMapping[str, Any]

class TrainState(NamedTuple):
  """Container for the training state."""
  params: hk.Params
  opt_state: optax.OptState
  rng: jnp.DeviceArray
  step: jnp.DeviceArray

def forward(tokens: jnp.ndarray, is_training: bool, gd: bool):
    """Transformer forward."""
    if config.classic_token_const:
        in_context_length = config.dataset_size*2 + 1
    else:
        in_context_length = config.dataset_size + 1
        
    tr = Transformer(
        num_heads=config.num_heads,
        num_layers=config.num_layers,
        widening_factor=config.widening_factor,
        key_size=config.key_size,
        embedding_size=config.emb_size,
        only_attention=config.att_only_trans,
        in_context_length=in_context_length,
        output_size=config.output_size,
        dropout_rate=config.dropout_rate,
        use_pe=config.pos_enc,
        pe_size=config.pos_enc_size,
        concat_pe=config.concat_pos_enc,
        output_mapping=config.out_proj,
        input_mapping=config.in_proj,
        use_layer_norm=config.layer_norm,
        use_bias_p=config.use_bias,
        deq=config.deq,
        y_update=config.y_update,
        use_softmax=config.use_softmax,
        use_non_lin_mix=config.use_non_lin_mix,
        first_layer_sm=config.first_layer_sm,
        zero_embeddings=config.zero_pos_enc,
        init_scale=config.init_scale,
        input_mlp=config.input_mlp,
        input_mlp_out_dim=config.input_mlp_out_dim,
        sum_norm=config.sum_norm,
        dampening=config.dampening,
        clip=config.clip,
        ana_copy=config.ana_copy
        )


    return tr(tokens, is_training=is_training, predict_test=False)


def compute_loss(preds, targets):
    assert preds.shape == targets.shape
    return 0.5*jnp.sum((targets-preds)**2)/targets.shape[0]

@hk.transform
def loss_fn(data: jnp.ndarray, gd) -> jnp.ndarray:
    """Computes the MSE loss between targets and predictions."""
    preds, _, _ = forward(data[0], True, gd)
    targets = data[1][:, -1]
    preds = preds[:, -1, -1]*(-1.0)
    return compute_loss(preds, targets)

_, Ls = ipd()
# def Ls(th):
#     agent1, agent2 = th
#     L_1 = agent1*agent2
#     L_2 = -(agent1*agent2)
#     return [L_1, L_2]

def loss1_fn(theta1, theta2):
    theta1 = forward
    L1, _ = Ls([theta1, theta2])
    # print(L1.shape)
    return L1[0]

def loss2_fn(theta1, theta2):
    _, L2 = Ls([theta1, theta2])
    # print(L2.shape)
    return L2[0]

@partial(jax.jit, static_argnums=(2))
def update(state: TrainState, optimiser, gd=False)->Tuple[TrainState, _Metrics]:
    """Does an SGD step and returns training state as well as metrics."""
    rng, new_rng = jax.random.split(state.rng)
    jit_loss_apply = jit(loss_fn.apply, static_argnums=3)
    loss_and_grad_fn = jax.value_and_grad(jit_loss_apply)
    loss, gradients = loss_and_grad_fn(state.params, rng, gd)

    updates, new_opt_state = optimiser.update(gradients, state.opt_state,
                                                state.params)
    new_params = optax.apply_updates(state.params, updates)

    new_state = TrainState(
        params=new_params,
        opt_state=new_opt_state,
        rng=new_rng,
        step=state.step + 1,
    )

    metrics = {
        'step': state.step,
        'train_loss': loss,
    }
    return new_state, metrics


def init_model(rng, optimiser) -> TrainState:
    """Init haiku transform modules to create train and test state."""
    train_rng, test_rng = jax.random.split(rng, num=2)

    initial_params = loss_fn.init(rng, gd=False)
    initial_opt_state = optimiser.init(initial_params)

    return TrainState(
        params=initial_params,
        opt_state=initial_opt_state,
        rng=train_rng,
        step=np.array(0))


def init():
    """Init data creator, model, optimizer, etc."""
    rng = jax.random.PRNGKey(config.seed)
    rng, train_rng = jax.random.split(rng, 2)

    lr = config.lr
    if config.adam:
        optimiser = optax.chain(
            optax.clip_by_global_norm(config.grad_clip_value),
            optax.adamw(learning_rate=lr, b1=config.b1, b2=config.b2,
                        weight_decay=config.wd),
        )
    else:
        optimiser = optax.chain(
            optax.clip_by_global_norm(config.grad_clip_value),
            optax.sgd(learning_rate=lr,),
        )

    train_state, test_state = init_model(rng, optimiser)
    return optimiser, train_state, test_state, rng

In [10]:
optimiser1, train_state1, _, rng = init()
optimiser2, train_state2, _, rng = init()
for i in range(100):
    train_state1, metrics = update(train_state1, optimiser1)
    train_state2, metrics = update(train_state1, optimiser2)

TypeError: loss_fn() missing 1 required positional argument: 'data'