In [2]:
import os
from functools import partial
from tqdm.notebook import tqdm

from jax import random

import dataset
import train
from models import MFLinear, MFNN, MFNNSI, MFNNResidual, LinearModel, Embedding

In [3]:
# Load the dataset and declare every model configuration to train.
# Each `models` entry maps a results-directory name to a
# (model constructor, constructor kwargs) pair. `kwargs` holds the arguments
# shared by every model: the dataset RMS as `scale` and the matrix shape as
# `samples`.
ds = dataset.Dataset("../data/polybench/20-40.npz")
kwargs = dict(scale=ds.rms, samples=ds.matrix.shape)

pca = ds.opcodes_pca

models = dict(
    # Linear matrix-factorization baselines, currently disabled:
    # rank1=(MFLinear, dict(dim=1, **kwargs)),
    # rank8=(MFLinear, dict(dim=8, **kwargs)),
    # rank32=(MFLinear, dict(dim=32, **kwargs)),

    # MF networks without a side_info argument.
    nn8=(MFNN, dict(dim=(8, 8), **kwargs)),
    nn32=(MFNN, dict(dim=(32, 32), **kwargs)),

    # MF networks with side information. NOTE(review): dim=(0, k) presumably
    # disables the learned-embedding half so only side_info is used —
    # confirm against models.MFNNSI.
    opcodes8=(MFNNSI, dict(side_info=ds.opcodes, dim=(0, 8), **kwargs)),
    opcodes32=(MFNNSI, dict(side_info=ds.opcodes, dim=(0, 32), **kwargs)),
    pca8=(MFNNSI, dict(side_info=pca[:, :8], dim=(0, 8), **kwargs)),
    pca32=(MFNNSI, dict(side_info=pca[:, :32], dim=(0, 32), **kwargs)),
    both8=(MFNNSI, dict(side_info=pca[:, :8], dim=(8, 8), **kwargs)),
    both32=(MFNNSI, dict(side_info=pca[:, :32], dim=(32, 32), **kwargs)),

    # Linear models over raw opcodes or the leading 32 PCA components.
    linearopcodes=(LinearModel, dict(features=ds.opcodes, **kwargs)),
    linearpca=(LinearModel, dict(features=pca[:, :32], **kwargs)),

    embedding=(Embedding, dict(side_info=pca[:, :32], **kwargs)),
)

# Values passed as `p` to ReplicateTrainer.train_replicates in the training cell.
sparsity = [0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95]

In [4]:
# Train every configured model at every sparsity level and save each run's
# replicate results to <RESULTS_DIR>/<model_name>/<sparsity>.npz.
#
# NOTE(review): the same PRNG key is passed to every (model, sparsity) run,
# presumably so all models see identical replicate draws and can be compared
# pairwise. Confirm this is intended.
SEED = 42
REPLICATES = 200
BATCH_SIZE = 64
RESULTS_DIR = "results"
# Set to False to skip (model, sparsity) pairs whose .npz already exists.
# This makes Restart & Run All cheap after a completed run.
# True keeps the always-retrain behaviour.
OVERWRITE = True

key = random.PRNGKey(SEED)

# The loop variables are deliberately not named `kwargs`. Rebinding it here
# would clobber the shared kwargs dict defined next to `models` in the config
# cell, leaving later cells with the last model's arguments instead.
for model_name, (constr, model_kwargs) in models.items():
    model_dir = os.path.join(RESULTS_DIR, model_name)
    os.makedirs(model_dir, exist_ok=True)

    for s in sparsity:
        out_path = os.path.join(model_dir, f"{s}.npz")
        if not OVERWRITE and os.path.exists(out_path):
            continue

        # The optimizer and training schedule are read from the model class
        # itself. The class is curried with its per-model kwargs so the
        # trainer can build fresh instances.
        trainer = train.ReplicateTrainer(
            ds, partial(constr, **model_kwargs), optimizer=constr.optimizer,
            epochs=constr.epochs, epoch_size=constr.epoch_size,
            batch=BATCH_SIZE, tqdm=partial(tqdm, desc=f"{model_name} : {s}"))

        # p=s: presumably the sparsity / observed fraction used to split the
        # matrix for each replicate. TODO: confirm against train.py.
        res = trainer.train_replicates(key, replicates=REPLICATES, p=s)
        trainer.save_results(res, out_path)

nn8 : 0.05:   0%|          | 0/100 [00:00<?, ?it/s]