In [28]:
from mads_datasets import DatasetFactoryProvider, DatasetType
from mltrainer.preprocessors import PaddedPreprocessor

import sys
import os

# Make the sibling `models` and `dev` directories importable from this notebook.
# The paths are relative to the current working directory of the kernel.
for _extra_dir in ("../models", "../dev"):
    _extra_path = os.path.abspath(_extra_dir)
    # Only append once: re-running this cell must not keep growing sys.path
    # with duplicate entries.
    if _extra_path not in sys.path:
        sys.path.append(_extra_path)

# Preprocessor applied to every batch by the datastreamers
# (presumably pads variable-length sequences to a common length — confirm in mltrainer).
preprocessor = PaddedPreprocessor()

# Build the gestures dataset factory and ask it for batched streamers.
gesturesdatasetfactory = DatasetFactoryProvider.create_factory(DatasetType.GESTURES)
streamers = gesturesdatasetfactory.create_datastreamer(
    preprocessor=preprocessor,
    batchsize=32,
)

# Keep the streamer objects themselves: their len() is used later as the
# number of train/valid steps per epoch.
train, valid = streamers["train"], streamers["valid"]

# The objects handed to the trainer as data loaders.
trainstreamer, validstreamer = train.stream(), valid.stream()

[32m2025-05-27 18:03:39.375[0m | [1mINFO    [0m | [36mmads_datasets.base[0m:[36mdownload_data[0m:[36m121[0m - [1mFolder already exists at /home/azureuser/.cache/mads_datasets/gestures[0m
100%|[38;2;30;71;6m██████████[0m| 2600/2600 [00:00<00:00, 2626.71it/s]
100%|[38;2;30;71;6m██████████[0m| 651/651 [00:00<00:00, 2791.73it/s]


In [29]:
from mltrainer import TrainerSettings, ReportTypes
from mltrainer.metrics import Accuracy
import torch

# Objective optimised by the trainer.
loss_fn = torch.nn.CrossEntropyLoss()

# Metric reported alongside the loss after each epoch.
accuracy = Accuracy()

In [30]:
from pathlib import Path


# Training configuration shared by every run in this notebook.
settings = TrainerSettings(
    epochs=10,
    metrics=[accuracy],
    logdir=Path("gestures"),
    # One epoch covers every batch the streamers can produce.
    train_steps=len(train),
    valid_steps=len(valid),
    reporttypes=[
        ReportTypes.TOML,
        ReportTypes.TENSORBOARD,
        ReportTypes.MLFLOW,
    ],
    # Forwarded to the learning-rate scheduler.
    scheduler_kwargs=dict(factor=0.5, patience=5),
    # Forwarded to the early-stopping helper.
    earlystop_kwargs=dict(
        save=False,  # when enabled: save every best model and restore the best one; disabled here
        verbose=True,
        patience=5,  # epochs without improvement before training is stopped
        delta=0.0,  # minimum change that counts as an improvement
    ),
)
# Display the resolved settings (including library defaults) as the cell output.
settings

epochs: 10
metrics: [Accuracy]
logdir: gestures
train_steps: 81
valid_steps: 20
reporttypes: [<ReportTypes.TOML: 'TOML'>, <ReportTypes.TENSORBOARD: 'TENSORBOARD'>, <ReportTypes.MLFLOW: 'MLFLOW'>]
optimizer_kwargs: {'lr': 0.001, 'weight_decay': 1e-05}
scheduler_kwargs: {'factor': 0.5, 'patience': 5}
earlystop_kwargs: {'save': False, 'verbose': True, 'patience': 5, 'delta': 0.0}

In [35]:
import timeit
import mlflow
from datetime import datetime

from mltrainer import Trainer
from torch import optim

from RNN import GRUmodel, ModelConfig

# Track runs in a local file store relative to the working directory,
# grouped under the "gestures" experiment.
mlflow.set_tracking_uri("file:./mlruns")
mlflow.set_experiment("gestures")

# Directory that receives the manually saved model file.
# exist_ok=True keeps the cell re-runnable without a separate existence check.
modeldir = Path("gestures").resolve()
modeldir.mkdir(parents=True, exist_ok=True)

with mlflow.start_run():
    mlflow.set_tag("model", "GRU")
    mlflow.set_tag("dev", "melissa")
    config = ModelConfig(
        input_size=3,  # fixed — presumably the number of features per timestep; confirm against the data
        hidden_size=64,
        num_layers=2,
        output_size=20,  # fixed — presumably the number of gesture classes; confirm against the data
        dropout=0.4,
    )

    model = GRUmodel(
        config=config,
    )

    # The optimizer and scheduler are passed as classes, not instances.
    trainer = Trainer(
        model=model,
        settings=settings,
        loss_fn=loss_fn,
        optimizer=optim.Adam,
        traindataloader=trainstreamer,
        validdataloader=validstreamer,
        scheduler=optim.lr_scheduler.ReduceLROnPlateau,
    )

    # A bare `timeit` expression is a no-op, so measure the wall-clock
    # duration of the training loop explicitly and record it with the run.
    train_start = timeit.default_timer()
    trainer.loop()
    train_seconds = timeit.default_timer() - train_start
    mlflow.log_metric("train_seconds", train_seconds)
    print(f"Training took {train_seconds:.1f} s")

    # earlystop_kwargs may be None (early stopping disabled) or lack "save";
    # in both cases nothing else has saved the model, so save it here.
    earlystop_kwargs = settings.earlystop_kwargs or {}
    if not earlystop_kwargs.get("save", False):
        tag = datetime.now().strftime("%Y%m%d-%H%M-")
        modelpath = modeldir / (tag + "model.pt")
        # NOTE: this pickles the whole module object, so loading it later
        # requires the GRUmodel class (module `RNN`) to be importable.
        torch.save(model, modelpath)

[32m2025-05-27 18:06:55.875[0m | [1mINFO    [0m | [36mmltrainer.trainer[0m:[36mdir_add_timestamp[0m:[36m24[0m - [1mLogging to gestures/20250527-180655[0m
[32m2025-05-27 18:06:55.877[0m | [1mINFO    [0m | [36mmltrainer.trainer[0m:[36m__init__[0m:[36m68[0m - [1mFound earlystop_kwargs in settings.Set to None if you dont want earlystopping.[0m
  0%|[38;2;30;71;6m          [0m| 0/10 [00:00<?, ?it/s]

100%|[38;2;30;71;6m██████████[0m| 81/81 [00:03<00:00, 26.97it/s]
[32m2025-05-27 18:06:59.013[0m | [1mINFO    [0m | [36mmltrainer.trainer[0m:[36mreport[0m:[36m205[0m - [1mEpoch 0 train 2.6559 test 2.2622 metric ['0.1969'][0m
100%|[38;2;30;71;6m██████████[0m| 81/81 [00:02<00:00, 34.12it/s]
[32m2025-05-27 18:07:01.520[0m | [1mINFO    [0m | [36mmltrainer.trainer[0m:[36mreport[0m:[36m205[0m - [1mEpoch 2 train 2.1841 test 2.1000 metric ['0.2219'][0m
100%|[38;2;30;71;6m██████████[0m| 81/81 [00:02<00:00, 31.42it/s]
[32m2025-05-27 18:07:04.225[0m | [1mINFO    [0m | [36mmltrainer.trainer[0m:[36mreport[0m:[36m205[0m - [1mEpoch 4 train 2.0261 test 1.8505 metric ['0.3797'][0m
100%|[38;2;30;71;6m██████████[0m| 81/81 [00:02<00:00, 33.34it/s]
[32m2025-05-27 18:07:06.789[0m | [1mINFO    [0m | [36mmltrainer.trainer[0m:[36mreport[0m:[36m205[0m - [1mEpoch 6 train 1.6785 test 1.3366 metric ['0.5766'][0m
100%|[38;2;30;71;6m██████████[0m| 81/81 [00:02

KeyboardInterrupt: 