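# Regular usage of collie

This notebook builds a minimal end-to-end collie pipeline (Transformer → Tuner → Trainer → Evaluator → Pusher) on synthetic data and runs it with the `Orchestrator`, logging every stage to MLflow.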

In [1]:
import os
import sys
# Make the in-repo `collie` package importable from this examples directory.
sys.path.append("../..")

import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
from torch.utils.data import DataLoader, TensorDataset

from collie.core import (
    Transformer,
    Tuner,
    Trainer,
    Evaluator,
    Pusher,
    TrainerPayload,
    TransformerPayload,
    TunerPayload,
    EvaluatorPayload,
    PusherPayload,
    Orchestrator
)

from collie import Event

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Seed torch's global RNG so the synthetic data, the DataLoader shuffle order
# and the model's weight initialisation are reproducible on Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

num_samples = 1000  # rows of synthetic training data
input_dim = 20      # number of feature columns
num_classes = 4     # labels are drawn from range(num_classes)

## Transformer

In [3]:
class MLPTransformer(Transformer):
    """First pipeline stage: fabricate a synthetic classification dataset.

    Features are standard-normal draws of shape (num_samples, input_dim).
    Labels are uniform integers in [0, num_classes) drawn independently of
    the features. Only a training split is produced; the validation and test
    slots are left as None.
    """

    def __init__(self) -> None:
        super().__init__()

    def handle(self, event) -> Event:
        # The incoming event carries nothing this stage needs.
        features = torch.randn(num_samples, input_dim)
        labels = torch.randint(0, num_classes, (num_samples,))

        feature_columns = [f"feature_{i}" for i in range(input_dim)]
        train_data = pd.concat(
            [
                pd.DataFrame(features.numpy(), columns=feature_columns),
                pd.DataFrame(labels.numpy(), columns=["label"]),
            ],
            axis=1,
        )

        payload = TransformerPayload(
            train_data=train_data,
            validation_data=None,
            test_data=None,
        )
        return Event(payload=payload)

## Tuner

In [4]:
class MLPTuner(Tuner):
    """Second pipeline stage: supply hyperparameters and forward the data.

    This is a dummy tuner. It does not search anything; it returns the values
    given to the constructor. Its defaults are the values the example has
    always used.

    Args:
        learning_rate: Learning rate handed to the trainer.
        batch_size: Mini-batch size handed to the trainer.
    """

    def __init__(self, learning_rate: float = 0.001, batch_size: int = 32) -> None:
        super().__init__()
        self._learning_rate = learning_rate
        self._batch_size = batch_size

    def handle(self, event: Event) -> Event:
        hyperparameters = {
            "learning_rate": self._learning_rate,
            "batch_size": self._batch_size,
        }
        # The next stage only sees this event, so the train/validation/test
        # splits received from the transformer must be passed along here.
        upstream = event.payload
        return Event(
            payload=TunerPayload(
                hyperparameters=hyperparameters,
                train_data=upstream.train_data,
                validation_data=upstream.validation_data,
                test_data=upstream.test_data,
            )
        )

## Trainer

In [9]:
class SimpleClassifier(nn.Module):
    """Two-layer MLP: Linear -> ReLU -> Linear, returning raw logits.

    Args:
        in_features: Width of each input row. Defaults to the notebook's
            ``input_dim`` when omitted.
        hidden_dim: Width of the hidden layer (64 in the original example).
        out_features: Number of output logits. Defaults to the notebook's
            ``num_classes`` when omitted.
    """

    def __init__(self, in_features=None, hidden_dim=64, out_features=None):
        super().__init__()
        # Resolve the notebook-level constants lazily so a bare
        # SimpleClassifier() behaves exactly as before.
        if in_features is None:
            in_features = input_dim
        if out_features is None:
            out_features = num_classes
        self.fc1 = nn.Linear(in_features, hidden_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, out_features)

    def forward(self, x):
        """Map a (batch, in_features) tensor to (batch, out_features) logits."""
        x = self.relu(self.fc1(x))
        return self.fc2(x)


class MLPTrainer(Trainer):
    """Third pipeline stage: train ``SimpleClassifier`` on the upstream data.

    ``event.payload`` must expose:

    * ``hyperparameters``: a dict read for ``learning_rate``, ``batch_size``
      and, optionally, ``epochs``. Missing keys fall back to the class
      defaults below instead of passing ``None`` to Adam / DataLoader.
    * ``train_data``: a DataFrame with an integer ``label`` column. All
      other columns must be numeric, because they are cast to float32
      features.
    """

    DEFAULT_LEARNING_RATE = 0.001
    DEFAULT_BATCH_SIZE = 32
    DEFAULT_EPOCHS = 10

    def __init__(self):
        super().__init__()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Put the weights on the same device the batches are sent to in the
        # training loop. Previously only the batches moved, which crashes on
        # CUDA machines with a device-mismatch error.
        self.model = SimpleClassifier().to(self.device)
        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = None
        self.scheduler = None

    def handle(self, event):
        """Fit the model and return it with the last epoch's mean batch loss.

        Raises:
            ValueError: if ``train_data`` yields no batches.
        """
        hyperparameters = event.payload.hyperparameters
        learning_rate = hyperparameters.get("learning_rate", self.DEFAULT_LEARNING_RATE)
        batch_size = hyperparameters.get("batch_size", self.DEFAULT_BATCH_SIZE)
        epochs = hyperparameters.get("epochs", self.DEFAULT_EPOCHS)
        print(f"learning_rate: {learning_rate}, batch_size: {batch_size}")

        self.optimizer = optim.Adam(self.model.parameters(), lr=learning_rate)
        self.scheduler = optim.lr_scheduler.StepLR(self.optimizer, step_size=10, gamma=0.1)

        dataloader = self._build_dataloader(event.payload.train_data, batch_size)
        if len(dataloader) == 0:
            raise ValueError("train_data is empty; nothing to train on")

        epoch_loss = float("nan")
        for epoch in range(1, epochs + 1):
            epoch_loss = self._run_epoch(dataloader)
            # Log the rate that was in effect during this epoch, then advance
            # the schedule. The scheduler was previously never stepped, so it
            # had no effect.
            self.log_metric("learning rate", self.scheduler.get_last_lr()[0], step=epoch)
            self.log_metric("loss", round(epoch_loss, 3), step=epoch)
            self.scheduler.step()

        return Event(
            payload=TrainerPayload(
                model=self.model,
                train_loss=epoch_loss,
                val_loss=None
            )
        )

    def _build_dataloader(self, train_data, batch_size):
        """Turn the feature/label DataFrame into a shuffling DataLoader."""
        features = torch.tensor(train_data.drop("label", axis=1).values, dtype=torch.float32)
        labels = torch.tensor(train_data["label"].values, dtype=torch.long)
        return DataLoader(TensorDataset(features, labels), batch_size=batch_size, shuffle=True)

    def _run_epoch(self, dataloader):
        """Run one optimisation pass over ``dataloader``; return the mean batch loss."""
        self.model.train()
        total_loss = 0.0
        for xb, yb in dataloader:
            xb, yb = xb.to(self.device), yb.to(self.device)
            self.optimizer.zero_grad()
            loss = self.criterion(self.model(xb), yb)
            loss.backward()
            self.optimizer.step()
            total_loss += loss.item()
        return total_loss / len(dataloader)

## Evaluator

In [10]:
class MLPEvaluator(Evaluator):
    """Fourth pipeline stage: report experiment vs. production metrics.

    The experiment metric is the trainer's ``train_loss``. The production
    metric is a mock value, because there is no deployed model in this
    example. The payload declares ``greater_is_better=False`` because the
    metric is a loss. NOTE(review): how the two values are compared is
    decided by the ``Evaluator`` base class, which is not shown here.

    Args:
        registered_model_name: Forwarded to ``Evaluator``.
        model_uri: Forwarded to ``Evaluator``.
        production_metric: Mocked production metric (10 in the original example).
    """

    def __init__(
        self,
        registered_model_name="MLPClassifier",
        model_uri="",
        production_metric=10,
    ) -> None:
        super().__init__(
            registered_model_name=registered_model_name,
            model_uri=model_uri
        )
        self._production_metric = production_metric

    def handle(self, event):
        train_loss = event.payload.train_loss
        return Event(
            payload=EvaluatorPayload(
                metrics={"Experiment": train_loss, "Production": self._production_metric},
                greater_is_better=False
            )
        )

## Pusher

In [11]:
class MLPPusher(Pusher):
    """Final pipeline stage: report where the model was pushed.

    Nothing is deployed in this example. The stage ignores its input event
    and always answers with the placeholder URI "mlp_model_uri".
    """

    def __init__(self, registered_model_name="MLPClassifier") -> None:
        super().__init__(registered_model_name=registered_model_name)

    def handle(self, event):
        pushed = PusherPayload(model_uri="mlp_model_uri")
        return Event(payload=pushed)

## Main

In [12]:
# Honour MLflow's standard MLFLOW_TRACKING_URI variable so the notebook can
# target another tracking server without edits. Without it, the local server
# used for the run recorded below is kept.
TRACKING_URI = os.environ.get("MLFLOW_TRACKING_URI", "http://localhost:5001")

# Stages run in list order: data -> hyperparameters -> training -> evaluation -> push.
orchestrator = Orchestrator(
    tracking_uri=TRACKING_URI,
    components=[
        MLPTransformer(),
        MLPTuner(),
        MLPTrainer(),
        MLPEvaluator(),
        MLPPusher()
    ],
    mlflow_tags={"Example": "MLP"},
    experiment_name="MLP2",
)
orchestrator.run()

2025/10/06 21:07:24 INFO mlflow.system_metrics.system_metrics_monitor: Skip logging GPU metrics. Set logger level to DEBUG for more details.
2025/10/06 21:07:24 INFO mlflow.system_metrics.system_metrics_monitor: Started monitoring system metrics.
  return _dataset_source_registry.resolve(
  return _dataset_source_registry.resolve(
2025/10/06 21:07:25 INFO mlflow.system_metrics.system_metrics_monitor: Stopping system metrics monitoring...
2025/10/06 21:07:25 INFO mlflow.system_metrics.system_metrics_monitor: Successfully terminated system metrics monitoring!
2025/10/06 21:07:25 INFO mlflow.system_metrics.system_metrics_monitor: Skip logging GPU metrics. Set logger level to DEBUG for more details.
2025/10/06 21:07:25 INFO mlflow.system_metrics.system_metrics_monitor: Started monitoring system metrics.
2025/10/06 21:07:25 INFO mlflow.system_metrics.system_metrics_monitor: Stopping system metrics monitoring...
2025/10/06 21:07:25 INFO mlflow.system_metrics.system_metrics_monitor: Successfu

🏃 View run Transformer at: http://localhost:5001/#/experiments/1/runs/ec2e397c62a641a79838467c47baf065
🧪 View experiment at: http://localhost:5001/#/experiments/1
🏃 View run Tuner at: http://localhost:5001/#/experiments/1/runs/c09df186464d4cdfb9ec00d1f4118260
🧪 View experiment at: http://localhost:5001/#/experiments/1
learning_rate: 0.001, batch_size: 32


2025/10/06 21:07:33 INFO mlflow.system_metrics.system_metrics_monitor: Stopping system metrics monitoring...
2025/10/06 21:07:33 INFO mlflow.system_metrics.system_metrics_monitor: Successfully terminated system metrics monitoring!
2025/10/06 21:07:33 INFO mlflow.system_metrics.system_metrics_monitor: Skip logging GPU metrics. Set logger level to DEBUG for more details.
2025/10/06 21:07:33 INFO mlflow.system_metrics.system_metrics_monitor: Started monitoring system metrics.
Registered model 'MLPClassifier' already exists. Creating a new version of this model...


🏃 View run Trainer at: http://localhost:5001/#/experiments/1/runs/b3722103ac2a4208bdac2681cd35f050
🧪 View experiment at: http://localhost:5001/#/experiments/1
Model URI: runs:/b3722103ac2a4208bdac2681cd35f050/model


2025/10/06 21:07:33 INFO mlflow.store.model_registry.abstract_store: Waiting up to 300 seconds for model version to finish creation. Model name: MLPClassifier, version 2
Created version '2' of model 'MLPClassifier'.
  latest_versions = self.mlflow_client.get_latest_versions(
  self.mlflow_client.transition_model_version_stage(
2025/10/06 21:07:33 INFO mlflow.system_metrics.system_metrics_monitor: Stopping system metrics monitoring...
2025/10/06 21:07:33 INFO mlflow.system_metrics.system_metrics_monitor: Successfully terminated system metrics monitoring!


🏃 View run Evaluator at: http://localhost:5001/#/experiments/1/runs/8b41e983ba034ec7a0aa69215a3c706f
🧪 View experiment at: http://localhost:5001/#/experiments/1
🏃 View run Pusher at: http://localhost:5001/#/experiments/1/runs/f0a4772ccebf4b2c9d960b303a739129
🧪 View experiment at: http://localhost:5001/#/experiments/1
🏃 View run Orchestrator at: http://localhost:5001/#/experiments/1/runs/dfaa1c3f46224f8b99a6883dc1895ed6
🧪 View experiment at: http://localhost:5001/#/experiments/1


In [None]:
# TODO: add tests covering the Pusher and Tuner components.