In [2]:
from mads_datasets import DatasetFactoryProvider, DatasetType
from mltrainer.preprocessors import PaddedPreprocessor

from mltrainer import Trainer
from torch import optim

from mltrainer import TrainerSettings, ReportTypes
from mltrainer.metrics import Accuracy
import torch

import sys 
import os
sys.path.append(os.path.abspath('../networks'))
sys.path.append(os.path.abspath('../dev'))

from RNN import ModelConfig

preprocessor = PaddedPreprocessor()

gesturesdatasetfactory = DatasetFactoryProvider.create_factory(DatasetType.GESTURES)
streamers = gesturesdatasetfactory.create_datastreamer(batchsize=32, preprocessor=preprocessor)
train = streamers["train"]
valid = streamers["valid"]

trainstreamer = train.stream()
validstreamer = valid.stream()

[32m2025-06-04 18:33:43.423[0m | [1mINFO    [0m | [36mmads_datasets.base[0m:[36mdownload_data[0m:[36m121[0m - [1mFolder already exists at /home/azureuser/.cache/mads_datasets/gestures[0m
100%|[38;2;30;71;6m██████████[0m| 2600/2600 [00:00<00:00, 3304.16it/s]
100%|[38;2;30;71;6m██████████[0m| 651/651 [00:00<00:00, 3126.43it/s]


In [3]:
accuracy = Accuracy()
loss_fn = torch.nn.CrossEntropyLoss()

config = ModelConfig(
    input_size=3, # vast
    hidden_size=128,
    num_layers=2,
    output_size=20, # vast
    dropout=0.2,
)

In [4]:
from pathlib import Path

settings = TrainerSettings(
    epochs=10, 
    metrics=[accuracy],
    logdir=Path("gestures"),
    train_steps=len(train),
    valid_steps=len(valid),
    reporttypes=[ReportTypes.TOML, ReportTypes.TENSORBOARD],
    scheduler_kwargs={"factor": 0.5, "patience": 5},
    earlystop_kwargs = {
        "save": False, # save every best model, and restore the best one
        "verbose": True,
        "patience": 5, # number of epochs with no improvement after which training will be stopped
    }
)
settings

epochs: 10
metrics: [Accuracy]
logdir: gestures
train_steps: 81
valid_steps: 20
reporttypes: [<ReportTypes.TOML: 'TOML'>, <ReportTypes.TENSORBOARD: 'TENSORBOARD'>]
optimizer_kwargs: {'lr': 0.001, 'weight_decay': 1e-05}
scheduler_kwargs: {'factor': 0.5, 'patience': 5}
earlystop_kwargs: {'save': False, 'verbose': True, 'patience': 5}

In [5]:
from RNN import RecurrentNeuralNetwork

model = RecurrentNeuralNetwork(
    config=config,
)

trainer = Trainer(
    model=model,
    settings=settings,
    loss_fn=loss_fn,
    optimizer=optim.Adam,
    traindataloader=trainstreamer,
    validdataloader=validstreamer,
    scheduler=optim.lr_scheduler.ReduceLROnPlateau,
)

trainer.loop()

[32m2025-06-04 18:33:57.093[0m | [1mINFO    [0m | [36mmltrainer.trainer[0m:[36mdir_add_timestamp[0m:[36m24[0m - [1mLogging to gestures/20250604-183357[0m
[32m2025-06-04 18:33:58.185[0m | [1mINFO    [0m | [36mmltrainer.trainer[0m:[36m__init__[0m:[36m68[0m - [1mFound earlystop_kwargs in settings.Set to None if you dont want earlystopping.[0m
100%|[38;2;30;71;6m██████████[0m| 81/81 [00:01<00:00, 79.67it/s]
[32m2025-06-04 18:33:59.337[0m | [1mINFO    [0m | [36mmltrainer.trainer[0m:[36mreport[0m:[36m209[0m - [1mEpoch 0 train 2.7087 test 2.4861 metric ['0.1109'][0m
100%|[38;2;30;71;6m██████████[0m| 81/81 [00:01<00:00, 79.60it/s]
[32m2025-06-04 18:34:00.441[0m | [1mINFO    [0m | [36mmltrainer.trainer[0m:[36mreport[0m:[36m209[0m - [1mEpoch 1 train 2.4734 test 2.7625 metric ['0.1062'][0m
[32m2025-06-04 18:34:00.442[0m | [1mINFO    [0m | [36mmltrainer.trainer[0m:[36m__call__[0m:[36m252[0m - [1mbest loss: 2.4861, current loss 2.7625.C

In [6]:
from RNN import GRUWithAttention

model = GRUWithAttention(
    config=config,
)

trainer = Trainer(
    model=model,
    settings=settings,
    loss_fn=loss_fn,
    optimizer=optim.Adam,
    traindataloader=trainstreamer,
    validdataloader=validstreamer,
    scheduler=optim.lr_scheduler.ReduceLROnPlateau,
)

trainer.loop()

[32m2025-06-04 18:34:11.003[0m | [1mINFO    [0m | [36mmltrainer.trainer[0m:[36mdir_add_timestamp[0m:[36m24[0m - [1mLogging to gestures/20250604-183411[0m
[32m2025-06-04 18:34:11.004[0m | [1mINFO    [0m | [36mmltrainer.trainer[0m:[36m__init__[0m:[36m68[0m - [1mFound earlystop_kwargs in settings.Set to None if you dont want earlystopping.[0m
100%|[38;2;30;71;6m██████████[0m| 81/81 [00:03<00:00, 26.94it/s]
[32m2025-06-04 18:34:14.184[0m | [1mINFO    [0m | [36mmltrainer.trainer[0m:[36mreport[0m:[36m209[0m - [1mEpoch 0 train 2.1419 test 1.4894 metric ['0.5516'][0m
100%|[38;2;30;71;6m██████████[0m| 81/81 [00:02<00:00, 27.64it/s]
[32m2025-06-04 18:34:17.306[0m | [1mINFO    [0m | [36mmltrainer.trainer[0m:[36mreport[0m:[36m209[0m - [1mEpoch 1 train 0.9624 test 0.4935 metric ['0.8578'][0m
100%|[38;2;30;71;6m██████████[0m| 81/81 [00:03<00:00, 26.06it/s]
[32m2025-06-04 18:34:20.597[0m | [1mINFO    [0m | [36mmltrainer.trainer[0m:[36mrepor

In [8]:
from RNN import RecurrentNeuralNetworkWithAttention

model = RecurrentNeuralNetworkWithAttention(
    config=config,
)

trainer = Trainer(
    model=model,
    settings=settings,
    loss_fn=loss_fn,
    optimizer=optim.Adam,
    traindataloader=trainstreamer,
    validdataloader=validstreamer,
    scheduler=optim.lr_scheduler.ReduceLROnPlateau,
)

trainer.loop()

[32m2025-06-04 18:36:09.466[0m | [1mINFO    [0m | [36mmltrainer.trainer[0m:[36mdir_add_timestamp[0m:[36m24[0m - [1mLogging to gestures/20250604-183609[0m
[32m2025-06-04 18:36:09.467[0m | [1mINFO    [0m | [36mmltrainer.trainer[0m:[36m__init__[0m:[36m68[0m - [1mFound earlystop_kwargs in settings.Set to None if you dont want earlystopping.[0m
100%|[38;2;30;71;6m██████████[0m| 81/81 [00:01<00:00, 65.91it/s]
[32m2025-06-04 18:36:10.791[0m | [1mINFO    [0m | [36mmltrainer.trainer[0m:[36mreport[0m:[36m209[0m - [1mEpoch 0 train 2.0716 test 1.4498 metric ['0.5078'][0m
100%|[38;2;30;71;6m██████████[0m| 81/81 [00:01<00:00, 68.60it/s]
[32m2025-06-04 18:36:12.062[0m | [1mINFO    [0m | [36mmltrainer.trainer[0m:[36mreport[0m:[36m209[0m - [1mEpoch 1 train 1.2472 test 1.0210 metric ['0.6609'][0m
100%|[38;2;30;71;6m██████████[0m| 81/81 [00:01<00:00, 66.57it/s]
[32m2025-06-04 18:36:13.365[0m | [1mINFO    [0m | [36mmltrainer.trainer[0m:[36mrepor