In [1]:
import os
from functools import partial
from tqdm.notebook import tqdm

from jax import random
import optax

import dataset
import train
from models import MFLinear, MFNN, MFNNSI

In [2]:
# Load the benchmark dataset used by every experiment in this notebook.
ds = dataset.Dataset("../data/polybench/20-40.npz")

# Keyword arguments forwarded unchanged to every model constructor below.
kwargs = dict(scale=ds.rms, samples=ds.matrix.shape)

# PCA-transformed opcode features; column slices are taken per model below.
pca = ds.opcodes_pca

# Registry mapping a model name to a partially-applied constructor.
# Insertion order is the order in which the models are trained later on.
models = {}

# Baselines that receive no extra positional feature array.
models["mf_linear"] = partial(MFLinear, dim=1, **kwargs)
models["mf_nn_8"] = partial(MFNN, dim=(8, 8), **kwargs)
models["mf_nn_32"] = partial(MFNN, dim=(32, 32), **kwargs)

# MFNNSI given the raw opcode array, first `dim` entry set to 0.
# NOTE(review): presumably the first `dim` entry sizes a learned embedding
# and 0 disables it -- confirm against models.MFNNSI.
models["mf_opcodes_8"] = partial(MFNNSI, ds.opcodes, dim=(0, 8), **kwargs)
models["mf_opcodes_32"] = partial(MFNNSI, ds.opcodes, dim=(0, 32), **kwargs)

# MFNNSI given the leading PCA columns, first `dim` entry set to 0.
models["mf_pca_8"] = partial(MFNNSI, pca[:, :8], dim=(0, 8), **kwargs)
models["mf_pca_32"] = partial(MFNNSI, pca[:, :32], dim=(0, 32), **kwargs)

# MFNNSI given the leading PCA columns, both `dim` entries non-zero.
models["mf_both_8"] = partial(MFNNSI, pca[:, :8], dim=(8, 8), **kwargs)
models["mf_both_32"] = partial(MFNNSI, pca[:, :32], dim=(32, 32), **kwargs)

# Values swept as the `p` argument of `train_replicates` in the training cell.
sparsity = [
    0.05, 0.1, 0.25,
    0.5,
    0.75, 0.9, 0.95,
]

In [4]:
# One optimizer definition and one PRNG key are shared by the whole sweep.
optimizer = optax.adam(0.001)
key = random.PRNGKey(42)

# Train every registered model at every sparsity level and write one
# results file per (model, sparsity) pair under results/<model>/.
for model, constructor in models.items():

    # Per-model output directory; created on first use, reused on re-runs.
    model_dir = os.path.join("results", model)
    os.makedirs(model_dir, exist_ok=True)

    for s in sparsity:
        # Progress-bar factory labelled with the current configuration.
        progress = partial(tqdm, desc="{} : {}".format(model, s))

        trainer = train.ReplicateTrainer(
            ds, constructor,
            optimizer=optimizer,
            epochs=100,
            epoch_size=100,
            batch=64,
            tqdm=progress,
        )

        # NOTE(review): the same `key` is passed for every model and every
        # sparsity level -- presumably so all configurations share identical
        # replicate seeds; confirm this is intended.
        res = trainer.train_replicates(key, replicates=100, p=s)

        # The file name is the sparsity value formatted with str.format.
        out_path = os.path.join(model_dir, "{}.npz".format(s))
        trainer.save_results(res, out_path)

mf_linear : 0.05:   0%|          | 0/100 [00:00<?, ?it/s]