In [1]:
# Pin the version this notebook was last run with (see the install log below);
# %pip installs into the environment of the running kernel, unlike !pip.
%pip install -q neural-semigroups==0.5.2

Collecting neural-semigroups
  Downloading neural_semigroups-0.5.2-py3-none-any.whl (26 kB)
Installing collected packages: neural-semigroups
Successfully installed neural-semigroups-0.5.2
You should consider upgrading via the '/opt/conda/bin/python3.7 -m pip install --upgrade pip' command.[0m


In [2]:
from neural_semigroups.smallsemi_dataset import Smallsemi
from neural_semigroups import Magma
from neural_semigroups.utils import corrupt_input

# Configuration. Defined before `transform` so the function does not depend on
# globals that only appear later in the cell; the model cell below reads both.
cardinality = 4
# Corruption rate passed to `corrupt_input` (and to the model below).
dropout_rate = 1 - 1 / cardinality


def transform(x, dropout_rate=dropout_rate):
    """Build a ``(corrupted input, clean target)`` pair from one Smallsemi item.

    The table ``x[0]`` is passed through ``Magma.random_isomorphism`` and the
    result is encoded as ``Magma.probabilistic_cube``; that cube is the target.
    The input is the same cube passed through ``corrupt_input``.

    Args:
        x: one dataset item; only ``x[0]`` is used (presumably the Cayley
            table of a semigroup -- TODO confirm against Smallsemi).
        dropout_rate: corruption rate forwarded to ``corrupt_input``. The
            default is bound to the module-level ``dropout_rate`` when this
            function is defined, so reassigning that global in a later cell
            cannot silently change the data produced during training.

    Returns:
        ``(new_x, new_y)``: the corrupted cube and the clean cube, both with
        the clean cube's shape.
    """
    new_y = Magma(Magma(x[0]).random_isomorphism()).probabilistic_cube
    # corrupt_input is applied to a batch of one, then the batch axis is
    # dropped again. Using the tensor's own shape removes the dependency on
    # the global `cardinality`.
    new_x = corrupt_input(
        new_y.view(1, *new_y.shape),
        dropout_rate=dropout_rate,
    ).view(new_y.shape)
    return new_x, new_y


data = Smallsemi(
    root="/kaggle/input/smallsemi",
    cardinality=cardinality,
    transform=transform,
)

In [3]:
import torch
from torch.utils.data.dataset import random_split
from torch.utils.data import DataLoader

# Seed torch's global RNG so the split below (and the training-loader
# shuffling) is reproducible across Restart & Run All.
SEED = 42
BATCH_SIZE = 32
torch.manual_seed(SEED)

data_size = len(data)
test_size = data_size // 3
# The first split absorbs the remainder when data_size is not divisible by 3.
first_split, second_split, third_split = random_split(
    data, [data_size - 2 * test_size, test_size, test_size]
)
# Presumably consumed by learning_pipeline as (train, validation, test) --
# TODO confirm. Only the first (training) loader is reshuffled each epoch;
# evaluation order does not affect averaged metrics.
data_loaders = (
    DataLoader(first_split, batch_size=BATCH_SIZE, shuffle=True),
    DataLoader(second_split, batch_size=BATCH_SIZE),
    DataLoader(third_split, batch_size=BATCH_SIZE),
)
data_size

126


In [4]:
from neural_semigroups.associator_loss import AssociatorLoss
from torch import Tensor

def loss(prediction: Tensor, target: Tensor) -> Tensor:
    """Training objective: ``AssociatorLoss`` evaluated on the prediction alone.

    ``target`` is accepted only so the function fits the ``(prediction,
    target)`` calling convention used by ``ignite.metrics.Loss`` and the
    training pipeline; it is intentionally unused (the objective presumably
    penalises non-associativity of the predicted table, needing no label).
    """
    associator = AssociatorLoss()
    return associator(prediction)

In [5]:
from neural_semigroups import MagmaDAE

# Model: two hidden layers, each of width cardinality ** 3 (the number of
# entries in a cardinality x cardinality x cardinality cube), using the same
# dropout rate that corrupts the inputs in `transform`.
dae = MagmaDAE(
    cardinality=cardinality,
    hidden_dims=[cardinality ** 3] * 2,
    dropout_rate=dropout_rate,
)

In [6]:
import shutil

# Remove TensorBoard logs left by previous runs so the dashboard only shows
# this run. Pure Python, so it also works where `rm` is unavailable.
shutil.rmtree("runs", ignore_errors=True)

In [7]:
from neural_semigroups.training_helpers import learning_pipeline
from ignite.metrics.loss import Loss
from neural_semigroups.training_helpers import associative_ratio, guessed_ratio

# Hyperparameters handed to learning_pipeline.
params = dict(learning_rate=0.001, epochs=1000)
# ignite's Loss metric averages a (prediction, target) -> Tensor function over
# the evaluation batches; the ratio helpers are wrapped the same way.
metrics = dict(
    loss=Loss(loss),
    associative_ratio=Loss(associative_ratio),
    guessed_ratio=Loss(guessed_ratio),
)
learning_pipeline(params, dae, loss, metrics, data_loaders)

HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))

In [8]:
# Load the TensorBoard notebook extension that provides the %tensorboard magic
# used in the next cell.
%load_ext tensorboard

In [9]:
# Show the logs under ./runs (the directory cleared before training;
# presumably written there by learning_pipeline -- TODO confirm).
%tensorboard --logdir runs