In [None]:
import magno
from magno import deeptrack as dt

import tensorflow as tf
import tensorflow_addons as tfa

import numpy as np

# Only let TensorFlow's Python logger emit ERROR-level (and above) messages,
# keeping the training output readable.
tf.get_logger().setLevel('ERROR')

# Seed the global NumPy RNG (the augmentation lambdas below sample through
# np.random.*) and TensorFlow's global RNG so that Restart & Run All gives
# reproducible data sampling and weight initialization.
# NOTE(review): GPU kernels or a multi-threaded data generator may still
# introduce some nondeterminism.
SEED = 42
np.random.seed(SEED)
tf.random.set_seed(SEED)

#### 1. Create node dataframe

In [None]:
# Extract node-level data for training. The call returns a pair unpacked into
# `dfs` (a table whose "set" and "subset" columns are read in the next cells —
# presumably a pandas DataFrame; confirm in magno) and `labels`.
dfs, labels = magno.NodeExtractor(mode="training")

#### 2. Generate graph representations

In [None]:
# Turn the node table (plus its labels) into graph representations, attaching
# only the "centroids" property.
graph = magno.GraphExtractor(dfs, properties=["centroids"], labels=labels)

#### 3. Set up graph augmentation pipeline

In [None]:
# Largest set index, i.e., protein configuration id (NOT the number of sets).
# Random set ids below are drawn uniformly from 0..max_set inclusive.
# NOTE(review): this assumes set ids are contiguous integers starting at 0 —
# confirm against NodeExtractor's output.
max_set = dfs["set"].max()

# Sorted unique subset indices, i.e., realizations of the same protein
# configuration.
subsets = np.unique(dfs["subset"].values)

# Number of subgraphs in each batch: batch // 2 graphs are passed to the
# teacher and batch // 2 graphs are passed to the encoder.
batch = 8

# Fail fast here rather than deep inside the data generator: the teacher/encoder
# split requires an even batch, and np.random.choice(..., replace=False) raises
# if asked for more samples than there are distinct subsets.
assert batch % 2 == 0, f"batch must be even, got {batch}"
assert len(subsets) >= batch, (
    f"need at least {batch} distinct subsets to sample without replacement, "
    f"found {len(subsets)}"
)

feature = (
    dt.Value(graph)
    # Pick one random set id and `batch` distinct subset indices.
    >> dt.Lambda(
        magno.GetSubSet,
        randset=lambda: np.random.randint(max_set + 1),
        randsubsets=lambda: np.random.choice(subsets, batch, replace=False),
    )
    # Random rotation angles uniform in [0, 2*pi) and Gaussian translations
    # (std 0.05), `batch` of each — presumably one per sampled subgraph.
    >> dt.Lambda(
        magno.AugmentCentroids,
        rotate=lambda: np.random.rand(batch) * 2 * np.pi,
        translate=lambda: np.random.randn(batch, 2) * 0.05,
    )
    # Presumably splits the augmented batch into the teacher and encoder
    # halves described in the comment on `batch` — confirm in magno.
    >> magno.Splitter()
)

#### 5. Define model

In [None]:
# Encoder network: a CTMAGIK model over 2 input features per node, configured
# with output_type="cls_rep" (presumably a class-token representation —
# confirm in deeptrack). Print its layer summary for inspection.
encoder = dt.models.CTMAGIK(output_type="cls_rep", number_of_node_features=2)
encoder.summary()

In [None]:
# Teacher network: built with exactly the same CTMAGIK arguments as the encoder.
teacher = dt.models.CTMAGIK(output_type="cls_rep", number_of_node_features=2)

In [None]:
# Combine encoder and teacher into the MAGNO model and compile it.
model = dt.models.MAGNO(
    encoder,
    teacher,
    center_momentum=0.99,
    representation_size=128,
)

# NOTE: learning_rate and weight_decay are controlled during training by MAGNO's
# schedulers (the scheduler callbacks passed to model.fit); the values given
# here are only the optimizer's initial settings.
model.compile(
    optimizer=tfa.optimizers.AdamW(weight_decay=1e-5, learning_rate=1e-3),
)


#### 6. Train the network

In [None]:
# Data generator fed by the `feature` pipeline. Each resolved sample is indexed
# so that element 0 is used as the model input and element 1 as the label.
# NOTE(review): min/max_data_size presumably bound how many samples the
# generator keeps buffered — confirm in magno.
generator = magno.ContinuousGraphGenerator(
    feature,
    batch_size=1,
    min_data_size=1024,
    max_data_size=1025,
    batch_function=lambda sample: sample[0],
    label_function=lambda sample: sample[1],
)

In [None]:
epochs = 100

# Per-epoch hyperparameter schedules, each spanning all `epochs`. Arguments are
# presumably (initial value, final value, number of epochs) — confirm in magno.
MomentumSchedule = magno.CosineDecay(0.996, 1.0, epochs)        # momentum
LearningRateSchedule = magno.CosineDecay(1e-3, 1e-4, epochs)    # learning rate
WeightDecaySchedule = magno.CosineDecay(0.04, 0.4, epochs)      # weight decay
# Temperature: 0.04 -> 0.07 with a 30-epoch warmup.
TemperatureSchedule = magno.PiecewiseConstantDecay(
    0.04, 0.07, epochs, warmup_epochs=30
)

# The generator is entered as a context manager for the duration of training;
# the scheduler callbacks apply the schedules above while model.fit runs.
with generator:
    model.fit(
        generator,
        callbacks=[
            magno.MomentumScheduler(MomentumSchedule),
            magno.LearningRateScheduler(LearningRateSchedule),
            magno.WeightDecayScheduler(WeightDecaySchedule),
            magno.TemperatureScheduler(TemperatureSchedule),
        ],
        epochs=epochs,
    )
