In [5]:
import glob
import os

import numpy as np
import jax.numpy as jnp
import tensorflow as tf

from architectures.azresnet import AZResnet, AZResnetConfig
from trainer import TrainerModule
from constants import POLICY_LABELS, BOARD_HEIGHT, BOARD_WIDTH, NUM_BUGHOUSE_CHANNELS

In [6]:
# --- Configuration -----------------------------------------------------------
batch_size = 1024
DATA_DIR = "../data/fics_training_data"
VAL_PATH = os.path.join(DATA_DIR, "checkpoint0.npz")  # archive the validation slice is taken from
VAL_SIZE = 2**12          # number of leading samples of VAL_PATH held out for validation
SHUFFLE_BUFFER = 2**16    # tf.data shuffle buffer size used for every loader
PLANE_KEYS = ("board_planes", "move_planes", "value_planes")


def load_planes(path):
    """Read the arrays named in PLANE_KEYS from one .npz archive.

    The archive is opened in a ``with`` block so its file handle is closed
    before returning; the arrays are returned in a plain dict keyed like the
    archive, so ``data["board_planes"]`` etc. keep working.
    """
    with np.load(path) as archive:
        return {key: archive[key] for key in PLANE_KEYS}


def make_loader(boards, moves, values, batch_size, shuffle_buffer=SHUFFLE_BUFFER):
    """Build a shuffled, batched tf.data pipeline pinned to the CPU device."""
    with tf.device('/CPU:0'):
        dataset = tf.data.Dataset.from_tensor_slices((boards, moves, values))
        return dataset.shuffle(buffer_size=shuffle_buffer).batch(batch_size)


# --- Model / trainer ---------------------------------------------------------
trainer = TrainerModule(
    model_name="AZResNet",
    model_class=AZResnet,
    model_configs=AZResnetConfig(
        num_blocks=15,
        channels=256,
        policy_channels=4,
        value_channels=8,
        num_policy_labels=len(POLICY_LABELS),
    ),
    optimizer_name='lion',
    optimizer_params={'learning_rate': 0.00001},
    # All-ones batch handed to the trainer; presumably only used to
    # initialise the model for this input shape -- TODO confirm in trainer.py.
    x=jnp.ones((batch_size, BOARD_HEIGHT, 2 * BOARD_WIDTH, NUM_BUGHOUSE_CHANNELS)),
)
# NOTE(review): the recorded output of this cell is
#   TypeError: init_optimizer() takes 1 positional argument but 3 were given
# This call passes no arguments, so the failing call is presumably made
# elsewhere (e.g. inside TrainerModule.train_model) -- confirm against
# trainer.py and fix the signature/call there.
trainer.init_optimizer()

# --- Validation set ----------------------------------------------------------
# The first VAL_SIZE samples of VAL_PATH are the held-out validation set.
data = load_planes(VAL_PATH)
val_loader = make_loader(
    data["board_planes"][:VAL_SIZE],
    data["move_planes"][:VAL_SIZE],
    data["value_planes"][:VAL_SIZE],
    batch_size,
)

# --- Training loop: one pass per checkpoint archive --------------------------
# sorted() makes the training order reproducible (glob order is unspecified);
# "*.npz" skips stray files that could not be indexed by array name anyway.
for path in sorted(glob.glob(os.path.join(DATA_DIR, "*.npz"))):
    data = load_planes(path)

    # Keep the validation samples out of the training data: the validation
    # archive is also matched by the glob, so skip its first VAL_SIZE rows.
    start = VAL_SIZE if os.path.samefile(path, VAL_PATH) else 0
    boards = data["board_planes"][start:]
    moves = data["move_planes"][start:]
    values = data["value_planes"][start:]
    if len(boards) == 0:
        # Nothing left to train on once the validation slice is removed.
        continue

    train_loader = make_loader(boards, moves, values, batch_size)

    trainer.train_model(train_loader, batch_size)
    trainer.save_checkpoint()

    policy_acc, value_acc = trainer.eval_model(val_loader, batch_size)
    print(policy_acc, value_acc)



TypeError: init_optimizer() takes 1 positional argument but 3 were given