In [1]:
import sqlite3
import io
import pickle
import time
import os
import torch
import torch.nn as nn
import torch.optim as optim
import mlflow
import pandas as pd
import numpy as np
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import datasets, transforms
from mlflow.models import infer_signature
from mlflow import MlflowClient

# ==========================================
# 1. DATABASE SETUP (One-time setup)
# ==========================================
def create_sqlite_mnist(db_path="mnist.db"):
    """
    Downloads MNIST, converts images to bytes, and stores them in SQLite.
    """
    if os.path.exists(db_path):
        print(f"Database {db_path} already exists. Skipping creation.")
        return

    print("Creating SQLite database from MNIST dataset...")
    # Download standard MNIST
    transform = transforms.ToTensor()
    mnist_data = datasets.MNIST('./data', train=True, download=True, transform=transform)
    
    conn = sqlite3.connect(db_path)
    cursor = conn.cursor()
    
    # Create table: ID, Label, ImageBlob
    cursor.execute('CREATE TABLE IF NOT EXISTS train_data (id INTEGER PRIMARY KEY, label INTEGER, image BLOB)')
    
    # Batch insert for speed
    data_to_insert = []
    for idx, (img_tensor, label) in enumerate(mnist_data):
        # Convert tensor to numpy, then pickle to bytes
        img_np = img_tensor.numpy() 
        img_bytes = pickle.dumps(img_np)
        data_to_insert.append((idx, label, img_bytes))
        
        if idx % 1000 == 0:
            print(f"Processing image {idx}/{len(mnist_data)}...")

    cursor.executemany('INSERT INTO train_data VALUES (?, ?, ?)', data_to_insert)
    conn.commit()
    conn.close()
    print("Database creation complete.")

# ==========================================
# 2. CUSTOM DATASET CLASS
# ==========================================
class SqliteDataset(Dataset):
    """
    Custom PyTorch Dataset that reads from SQLite.
    """
    def __init__(self, db_path, table_name="train_data", transform=None):
        self.db_path = db_path
        self.table_name = table_name
        self.transform = transform
        
        # Connect to get total length
        conn = sqlite3.connect(self.db_path)
        cursor = conn.cursor()
        cursor.execute(f"SELECT COUNT(*) FROM {self.table_name}")
        self.length = cursor.fetchone()[0]
        conn.close()

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        # Open a new connection per thread (SQLite requirement for concurrency)
        conn = sqlite3.connect(self.db_path)
        cursor = conn.cursor()
        
        # Fetch specific row (id is 0-indexed in our insertion logic)
        cursor.execute(f"SELECT label, image FROM {self.table_name} WHERE id=?", (idx,))
        label, img_bytes = cursor.fetchone()
        conn.close()
        
        # Deserialize
        img_np = pickle.loads(img_bytes)
        img_tensor = torch.from_numpy(img_np)
        
        return img_tensor, label

# ==========================================
# 3. MODEL DEFINITION
# ==========================================
class MnistModel(nn.Module):
    def __init__(self, hidden_size=128, dropout_rate=0.2):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, hidden_size),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(hidden_size, 10),
        )

    def forward(self, x):
        x = self.flatten(x)
        return self.linear_relu_stack(x)

# ==========================================
# 4. TRAINING FUNCTION WITH MLFLOW
# ==========================================
def train_and_log():
    # --- Configuration ---
    DB_PATH = "mnist.db"
    EXPERIMENT_NAME = "MNIST_SQLite_Experiment"
    REGISTERED_MODEL_NAME = "MNIST_Classifier_SQLite"
    
    # Hyperparameters to log
    params = {
        "epochs": 3,
        "batch_size": 64,
        "learning_rate": 0.001,
        "hidden_size": 128,
        "dropout": 0.2,
        "checkpoint_interval": 1  # Save checkpoint every N epochs
    }

    # --- Setup ---
    # 1. Prepare DB
    create_sqlite_mnist(DB_PATH)
    
    # 2. Setup DataLoaders
    full_dataset = SqliteDataset(DB_PATH)
    # Split for validaiton
    train_size = int(0.8 * len(full_dataset))
    val_size = len(full_dataset) - train_size
    train_set, val_set = random_split(full_dataset, [train_size, val_size])
    
    train_loader = DataLoader(train_set, batch_size=params["batch_size"], shuffle=True)
    val_loader = DataLoader(val_set, batch_size=params["batch_size"], shuffle=False)

    # 3. Setup MLflow
    mlflow.set_experiment(EXPERIMENT_NAME)
    
    # Enable System Metrics (CPU/GPU/RAM usage)
    mlflow.enable_system_metrics_logging()

    # --- Start Run ---
    with mlflow.start_run() as run:
        print(f"Starting Run ID: {run.info.run_id}")
        
        # Log Hyperparameters
        mlflow.log_params(params)
        
        # Initialize Model & Optimizer
        model = MnistModel(hidden_size=params["hidden_size"], dropout_rate=params["dropout"])
        loss_fn = nn.CrossEntropyLoss()
        optimizer = optim.Adam(model.parameters(), lr=params["learning_rate"])

        # Infer Signature (Input/Output Schema)
        # Grab a dummy batch to define input shape
        dummy_input, _ = next(iter(train_loader))
        dummy_output = model(dummy_input)
        signature = infer_signature(dummy_input.numpy(), dummy_output.detach().numpy())

        # --- 1. Log the Dataset ---
        # We grab a single batch to act as a "schema definition" and profile
        # We point the 'source' to our local sqlite file
        sample_data, sample_targets = next(iter(train_loader))
        
        # Convert to numpy for MLflow interpretation
        dataset_source_path = os.path.abspath(DB_PATH)
        
        # specific MLflow data object
        dataset_info = mlflow.data.from_numpy(
            features=sample_data.numpy(), 
            targets=sample_targets.numpy(), 
            name="mnist_sqlite_train", 
            source=dataset_source_path  # Links the run to this specific file
        )
        
        # Log it to the run
        print("Logging dataset info to MLflow...")
        mlflow.log_input(dataset_info, context="training")

        # --- Training Loop ---
        for epoch in range(params["epochs"]):
            model.train()
            running_loss = 0.0
            correct = 0
            total = 0

            for batch_idx, (data, target) in enumerate(train_loader):
                optimizer.zero_grad()
                output = model(data)
                loss = loss_fn(output, target)
                loss.backward()
                optimizer.step()

                running_loss += loss.item()
                _, predicted = output.max(1)
                total += target.size(0)
                correct += predicted.eq(target).sum().item()

            # Calculate Epoch Metrics
            train_loss = running_loss / len(train_loader)
            train_acc = correct / total

            # Validation Step
            model.eval()
            val_correct = 0
            val_total = 0
            with torch.no_grad():
                for data, target in val_loader:
                    output = model(data)
                    _, predicted = output.max(1)
                    val_total += target.size(0)
                    val_correct += predicted.eq(target).sum().item()
            val_acc = val_correct / val_total

            print(f"Epoch {epoch+1}: Train Loss={train_loss:.4f}, Val Acc={val_acc:.4f}")

            # Log Metrics
            mlflow.log_metrics({
                "train_loss": train_loss,
                "train_accuracy": train_acc,
                "val_accuracy": val_acc
            }, step=epoch)

            # Checkpoint Tracking (Log model state every N epochs)
            if (epoch + 1) % params["checkpoint_interval"] == 0:
                print(f"Saving checkpoint for epoch {epoch+1}...")
                mlflow.pytorch.log_model(
                    pytorch_model=model,
                    artifact_path=f"checkpoints/epoch_{epoch+1}",
                    signature=signature,
                    input_example=dummy_input.numpy()  # Saves a sample input file
                )

        # --- Final Model Logging & Registration ---
        print("Logging final model...")
        model_info = mlflow.pytorch.log_model(
            pytorch_model=model,
            artifact_path="model",
            signature=signature,
            registered_model_name=REGISTERED_MODEL_NAME  # Automatically creates/updates version in Registry
        )
        
        return model_info, params

# ==========================================
# 5. LOADING BEST MODEL FOR TESTING
# ==========================================
def load_and_test_best_model():
    EXPERIMENT_NAME = "MNIST_SQLite_Experiment"
    
    print("\n--- Finding Best Model ---")
    
    # 1. Search for the best run in this experiment
    current_experiment = mlflow.get_experiment_by_name(EXPERIMENT_NAME)
    if not current_experiment:
        print("Experiment not found!")
        return

    # Search runs, order by 'val_accuracy' descending, take top 1
    best_run = mlflow.search_runs(
        experiment_ids=[current_experiment.experiment_id],
        order_by=["metrics.val_accuracy DESC"],
        max_results=1
    ).iloc[0]

    run_id = best_run.run_id
    best_acc = best_run["metrics.val_accuracy"]
    print(f"Best Run ID: {run_id} with Val Accuracy: {best_acc}")

    # 2. Load the model from that run
    # Format: runs:/<run_id>/<artifact_path>
    model_uri = f"runs:/{run_id}/model"
    print(f"Loading model from: {model_uri}")
    
    loaded_model = mlflow.pytorch.load_model(model_uri)
    
    # 3. Test Prediction
    print("Running inference on random noise...")
    loaded_model.eval()
    dummy_input = torch.randn(1, 1, 28, 28) # Single random image
    with torch.no_grad():
        prediction = loaded_model(dummy_input)
        predicted_class = prediction.argmax().item()
    
    print(f"Model prediction (class index): {predicted_class}")
    print("Success! Workflow complete.")

if __name__ == "__main__":
    # Run the pipeline
    train_and_log()
    load_and_test_best_model()

Creating SQLite database from MNIST dataset...


100%|██████████| 9.91M/9.91M [00:14<00:00, 675kB/s] 
100%|██████████| 28.9k/28.9k [00:00<00:00, 49.3kB/s]
100%|██████████| 1.65M/1.65M [00:03<00:00, 519kB/s] 
100%|██████████| 4.54k/4.54k [00:00<00:00, 4.24MB/s]


Processing image 0/60000...
Processing image 1000/60000...
Processing image 2000/60000...
Processing image 3000/60000...
Processing image 4000/60000...
Processing image 5000/60000...
Processing image 6000/60000...
Processing image 7000/60000...
Processing image 8000/60000...
Processing image 9000/60000...
Processing image 10000/60000...
Processing image 11000/60000...
Processing image 12000/60000...
Processing image 13000/60000...
Processing image 14000/60000...
Processing image 15000/60000...
Processing image 16000/60000...
Processing image 17000/60000...
Processing image 18000/60000...
Processing image 19000/60000...
Processing image 20000/60000...
Processing image 21000/60000...
Processing image 22000/60000...
Processing image 23000/60000...
Processing image 24000/60000...
Processing image 25000/60000...
Processing image 26000/60000...
Processing image 27000/60000...
Processing image 28000/60000...
Processing image 29000/60000...
Processing image 30000/60000...
Processing image 3100

2026/01/02 15:18:41 INFO mlflow.store.db.utils: Creating initial MLflow database tables...
2026/01/02 15:18:41 INFO mlflow.store.db.utils: Updating database tables
2026/01/02 15:18:41 INFO alembic.runtime.migration: Context impl SQLiteImpl.
2026/01/02 15:18:41 INFO alembic.runtime.migration: Will assume non-transactional DDL.
2026/01/02 15:18:41 INFO alembic.runtime.migration: Running upgrade  -> 451aebb31d03, add metric step
2026/01/02 15:18:41 INFO alembic.runtime.migration: Running upgrade 451aebb31d03 -> 90e64c465722, migrate user column to tags
2026/01/02 15:18:41 INFO alembic.runtime.migration: Running upgrade 90e64c465722 -> 181f10493468, allow nulls for metric values
2026/01/02 15:18:42 INFO alembic.runtime.migration: Running upgrade 181f10493468 -> df50e92ffc5e, Add Experiment Tags Table
2026/01/02 15:18:42 INFO alembic.runtime.migration: Running upgrade df50e92ffc5e -> 7ac759974ad8, Update run tags with larger limit
2026/01/02 15:18:42 INFO alembic.runtime.migration: Running 

Starting Run ID: 426cfd29a96a4eee9e6e93aec2af8bde


  return _dataset_source_registry.resolve(
  return _dataset_source_registry.resolve(


Logging dataset info to MLflow...


2026/01/02 15:19:18 INFO mlflow.system_metrics.system_metrics_monitor: Stopping system metrics monitoring...
2026/01/02 15:19:18 INFO mlflow.system_metrics.system_metrics_monitor: Successfully terminated system metrics monitoring!


Epoch 1: Train Loss=0.4049, Val Acc=0.9344
Saving checkpoint for epoch 1...


MlflowException: Invalid model name ('checkpoints/epoch_1') provided. Model name must be a non-empty string and cannot contain the following characters: ('/', ':', '.', '%', '"', "'")

In [None]:
CLI: mlflow server --port 5000
HTTP: http://localhost:5000