In [1]:
import sys
sys.path.append('../')
from tqdm import tqdm

import jraph
import jax.numpy as jnp
import jax.random as jr
import jax

import equinox as eqx
import equiformer.graphs as graphs
import equiformer.layers as layers
import equiformer.examples.tetris as tetris

from jax.lax import gather

from torch.utils.data import Dataset, DataLoader
import optax

#import lovely_jax as lj
#lj.monkey_patch()

In [6]:
import lovely_jax as lj
#lj.monkey_patch()

In [2]:
# One Tetris dataset, exposed through two loaders: `loader` collates samples
# into labelled graph batches of 8, `unbatched_loader` hands them out one by one.
dataset = tetris.TetrisDataset()
loader = DataLoader(
    dataset,
    batch_size=8,
    shuffle=False,
    collate_fn=tetris.labelled_graph_batcher,
)
# batch_size=None turns off the DataLoader's automatic batching.
unbatched_loader = DataLoader(dataset, shuffle=False, batch_size=None)

In [3]:
@eqx.filter_value_and_grad
def compute_loss(model, g, label):
    """Summed softmax cross-entropy of ``model``'s output on ``g``.

    Args:
        model: Callable applied to ``g``; its output is passed to the
            cross-entropy as logits.
        g: Input handed to ``model`` unchanged.
        label: Integer class label(s) scored against the model output.

    Returns:
        The cross-entropy terms summed into a single value. Because of the
        ``eqx.filter_value_and_grad`` decorator, calling ``compute_loss``
        yields a ``(loss, grads)`` pair, with gradients taken with respect to
        ``model`` (the first argument).
    """
    logits = model(g)
    per_example = optax.softmax_cross_entropy_with_integer_labels(logits, label)
    return jnp.sum(per_example)

In [4]:
@eqx.filter_jit
def make_step(model, g, label, optim, opt_state):
    """Run a single optimisation step and return the advanced training state.

    Args:
        model: Model to be updated.
        g: Input forwarded to ``compute_loss``.
        label: Target label(s) forwarded to ``compute_loss``.
        optim: Optimiser whose ``update`` method turns gradients into updates.
        opt_state: Optimiser state matching ``optim``.

    Returns:
        ``(loss, model, opt_state)``: the loss evaluated before the update,
        the model with the updates applied, and the new optimiser state.
    """
    loss_value, gradients = compute_loss(model, g, label)
    param_updates, next_opt_state = optim.update(gradients, opt_state)
    stepped_model = eqx.apply_updates(model, param_updates)
    return loss_value, stepped_model, next_opt_state

In [5]:
# Build the classifier from a fixed PRNG key.
model = tetris.ShapeClassifier(jr.PRNGKey(1))

# Adam with learning rate 5e-3; the optimiser state is initialised from the
# model's array leaves only (non-array leaves are filtered out).
optim = optax.adam(5e-3)
opt_state = optim.init(eqx.filter(model, eqx.is_array))

EPOCHS = 1500
progress_bar = tqdm(range(EPOCHS))
for epoch in progress_bar:
    # Sum of the per-batch losses over one pass of `loader`.
    epoch_loss = 0.0
    for (g, label) in loader:
        # If iterating `unbatched_loader` instead, wrap the label first:
        #label = jnp.array([label]) # necessary for unbatched loader
        # `model` and `opt_state` are rebound each step, so the state carries
        # over from batch to batch and from epoch to epoch.
        loss, model, opt_state = make_step(model, g, label, optim, opt_state)
        epoch_loss += loss
    # Show the epoch's accumulated loss in the progress bar.
    progress_bar.set_description(f"Loss: {epoch_loss:.4f}")

Loss: 0.0224: 100%|██████████| 1500/1500 [00:07<00:00, 195.50it/s]


In [6]:
def single_label(model, g) -> tuple[int, jnp.ndarray]:
    """Predict the class of a single input and return the probabilities too.

    Args:
        model: Callable mapping ``g`` to a vector of class logits.
        g: The single input to classify.

    Returns:
        ``(predicted_class, probs)``: the index of the most probable class as
        a plain Python ``int``, and the softmax of the model output.
    """
    probs = jax.nn.softmax(model(g))
    # argmax yields a 0-d array; int() turns it into a plain Python integer.
    return int(jnp.argmax(probs)), probs

# Rotated version of dataset
rotated_dataset = tetris.TetrisDataset(rotate_seed = 4)

# Compare predictions on each rotated sample with its unrotated counterpart.
# The iterables are zipped in the same order as the unpacking targets, so the
# `*_rot` names really hold the samples from `rotated_dataset`.
# NOTE(review): assumes both datasets yield the shapes in the same order — confirm.
for ((g_rot, y_rot), (g, y)) in zip(rotated_dataset, dataset):
    y_pred_rot, probs_rot = single_label(model, g_rot)
    y_pred, probs = single_label(model, g)

    print(y_pred_rot, "vs", y_pred, "vs truth which is", y_rot, "=", y, end="; ")
    # Largest absolute difference between the two probability vectors.
    print("max(Δp) =", jnp.max(jnp.abs(probs - probs_rot)))

0 vs 0 vs truth which is 0 = 0; max(Δp) = 1.0244548e-07
1 vs 1 vs truth which is 1 = 1; max(Δp) = 1.2805685e-07
2 vs 2 vs truth which is 2 = 2; max(Δp) = 5.296897e-09
3 vs 3 vs truth which is 3 = 3; max(Δp) = 4.48199e-09
4 vs 4 vs truth which is 4 = 4; max(Δp) = 1.4144462e-08
5 vs 5 vs truth which is 5 = 5; max(Δp) = 9.575674e-16
6 vs 6 vs truth which is 6 = 6; max(Δp) = 3.632158e-08
7 vs 7 vs truth which is 7 = 7; max(Δp) = 2.2118911e-08
