In [1]:
!pip install neural-semigroups

Collecting neural-semigroups
  Downloading neural_semigroups-0.5.4-py3-none-any.whl (25 kB)
Installing collected packages: neural-semigroups
Successfully installed neural-semigroups-0.5.4
You should consider upgrading via the '/opt/conda/bin/python3.7 -m pip install --upgrade pip' command.


In [2]:
from neural_semigroups import Magma
from neural_semigroups.utils import hide_cells, partial_table_to_cube
import torch

# Order of the magmas/semigroups used throughout: passed to Smallsemi and MagmaDAE,
# and used by `transform` to decide how many table cells to hide.
cardinality = 6


def transform(x):
    """Turn one dataset sample into a (corrupted cube, full cube) training pair.

    ``x[0]`` is wrapped in ``Magma`` (assumes the dataset yields tuples whose
    first item is a Cayley table -- TODO confirm against ``Smallsemi``).

    Steps:
      1. Draw a random isomorphic copy of the table via
         ``Magma.random_isomorphism``.
      2. With probability 1/2, transpose it. A transposed Cayley table is the
         opposite operation, which is associative whenever the original is,
         so this is a valid augmentation (presumably compensating for a
         dataset listed up to anti-isomorphism -- confirm).
      3. Hide ``cardinality**2 - cardinality`` cells with ``hide_cells``
         (presumably the number of cells to hide, leaving ``cardinality``
         visible -- confirm against its docs).

    Returns:
        Tuple ``(new_x, target)``: ``partial_table_to_cube`` applied to the
        partially hidden table and to the full augmented table, respectively.
    """
    # torch.rand samples uniformly from [0, 1), so this is a fair coin flip.
    # (torch.randn is standard normal and exceeds 0.5 only ~31% of the time.)
    transpose = torch.rand(1).item() < 0.5
    new_y = Magma(x[0]).random_isomorphism()
    if transpose:
        new_y = new_y.T
    new_x = partial_table_to_cube(
        hide_cells(new_y, cardinality * cardinality - cardinality)
    )
    return new_x, partial_table_to_cube(new_y)

In [3]:
from neural_semigroups.smallsemi_dataset import Smallsemi

# Location of the smallsemi data on Kaggle; change this when running elsewhere.
SMALLSEMI_ROOT = "/kaggle/input/smallsemi"

# Dataset of order-`cardinality` tables; every sample is passed through
# `transform`, so each access yields a freshly augmented (input, target) pair.
data = Smallsemi(
    cardinality=cardinality,
    root=SMALLSEMI_ROOT,
    transform=transform,
)

In [4]:
from torch.utils.data.dataset import random_split
from torch.utils.data import DataLoader

data_size = len(data)
print(data_size)

train_size = 1024
# The three lengths below always sum to data_size, so random_split's sum check
# cannot catch a dataset that is too small -- the last length would just go
# negative. Fail loudly instead.
assert data_size >= 2 * train_size, (
    f"need at least {2 * train_size} samples, got {data_size}"
)

# Seed the global torch RNG so the split is identical on every Restart & Run All.
# This also makes the later torch-driven draws (e.g. the coin in `transform`)
# deterministic from this point on.
SPLIT_SEED = 42
torch.manual_seed(SPLIT_SEED)

# Loaders in split order: train_size, train_size, then the remainder
# (presumably train / validation / test as learning_pipeline expects -- confirm).
# NOTE(review): no per-epoch shuffling; consider shuffle=True for the first loader.
data_loaders = tuple(
    DataLoader(data_split, batch_size=32)
    for data_split in random_split(
        data, [train_size, train_size, data_size - 2 * train_size]
    )
)

15973


In [5]:
from neural_semigroups.associator_loss import AssociatorLoss
from torch import Tensor

def loss(prediction: Tensor, target: Tensor) -> Tensor:
    """Training objective: the associator loss of the predicted tables.

    ``target`` is accepted only so the signature matches the
    ``(prediction, target)`` convention used by ignite's ``Loss`` metric; it is
    not used -- the objective depends on ``prediction`` alone.
    """
    criterion = AssociatorLoss()
    return criterion(prediction)

In [6]:
from neural_semigroups import MagmaDAE

# Two hidden layers, each `cardinality ** 3` wide (the same size as a
# cardinality x cardinality x cardinality cube -- presumably the model's
# input representation; confirm against MagmaDAE).
hidden_width = cardinality ** 3
dae = MagmaDAE(
    cardinality=cardinality,
    hidden_dims=[hidden_width, hidden_width],
)

In [7]:
# Load the TensorBoard notebook extension, which provides the %tensorboard magic.
%load_ext tensorboard

In [8]:
# Show an inline TensorBoard reading event files under ./runs
# (presumably where learning_pipeline writes its logs -- confirm).
%tensorboard --logdir runs

In [9]:
# Irreversibly delete ./runs left over from previous sessions so TensorBoard
# only shows the training run started in the next cell.
!rm -rf runs

In [10]:
from neural_semigroups.training_helpers import learning_pipeline
from ignite.metrics.loss import Loss
from neural_semigroups.training_helpers import associative_ratio, guessed_ratio

# Optimiser and schedule settings handed to learning_pipeline.
params = dict(learning_rate=0.001, epochs=1000)

# Every tracked quantity is wrapped in ignite's Loss metric, which averages the
# given function over each evaluation pass.
metric_functions = (
    ("loss", loss),
    ("associative_ratio", associative_ratio),
    ("guessed_ratio", guessed_ratio),
)
metrics = {name: Loss(fn) for name, fn in metric_functions}

learning_pipeline(params, dae, loss, metrics, data_loaders)

HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))