In [21]:
# Add project root to Python path
import sys
from pathlib import Path
import sys, os
sys.path.append(os.path.abspath("src"))
project_root = Path().absolute().parent
if str(project_root) not in sys.path:
    sys.path.insert(0, str(project_root))

In [22]:
# Clean up any active MLflow runs
import mlflow
if mlflow.active_run():
    mlflow.end_run()
print("MLflow cleanup complete")

MLflow cleanup complete


In [23]:
import hydra
from omegaconf import DictConfig
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as T
from src.codegen.blueprint import Blueprint
from src.codegen.renderer import render_blueprint
from src.codegen.validator import validate_blueprint_dict
import time
import os
import mlflow

class EarlyStop:
    def __init__(self, patience=3):
        self.patience = patience
        self.best = None
        self.counter = 0

    def step(self, val):
        if self.best is None or val > self.best:
            self.best = val
            self.counter = 0
            return False
        else:
            self.counter += 1
            return self.counter >= self.patience

@hydra.main(version_base=None, config_path=None, config_name=None)
def train(cfg: DictConfig):
    # cfg expected to contain:
    # cfg.blueprint (dict), cfg.train.batch_size, cfg.train.epochs, cfg.device, cfg.log_dir
    mlflow.start_run()
    bp = Blueprint.from_dict(cfg.blueprint)
    model = render_blueprint(bp)
    device = torch.device(cfg.device)
    model.to(device)

    # dataset loader (example: cifar10)
    transform = T.Compose([T.ToTensor(), T.Normalize((0.5,), (0.5,))])
    trainset = torchvision.datasets.CIFAR10(root=cfg.data.root, train=True, download=True, transform=transform)
    valset = torchvision.datasets.CIFAR10(root=cfg.data.root, train=False, download=True, transform=transform)
    train_loader = DataLoader(trainset, batch_size=cfg.train.batch_size, shuffle=True, num_workers=2)
    val_loader = DataLoader(valset, batch_size=cfg.train.batch_size, shuffle=False, num_workers=2)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=cfg.train.lr, momentum=0.9)

    early = EarlyStop(patience=cfg.train.patience)
    best_val = 0.0
    for epoch in range(cfg.train.epochs):
        model.train()
        running = 0.0
        start = time.time()
        for i, (images, targets) in enumerate(train_loader):
            images, targets = images.to(device), targets.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
            running += loss.item()
            if i % 100 == 0:
                mlflow.log_metric("train_loss_batch", loss.item())

        # validation
        model.eval()
        correct = 0
        total = 0
        val_loss = 0.0
        with torch.no_grad():
            for images, targets in val_loader:
                images, targets = images.to(device), targets.to(device)
                outputs = model(images)
                val_loss += criterion(outputs, targets).item()
                _, predicted = outputs.max(1)
                total += targets.size(0)
                correct += predicted.eq(targets).sum().item()
        val_acc = correct / total
        mlflow.log_metric("val_acc", val_acc)
        mlflow.log_metric("val_loss", val_loss)

        print(f"Epoch {epoch}: val_acc={val_acc:.4f}, val_loss={val_loss:.4f}, time={(time.time()-start):.1f}s")
        if val_acc > best_val:
            best_val = val_acc
            # checkpoint
            os.makedirs(cfg.log_dir, exist_ok=True)
            torch.save(model.state_dict(), os.path.join(cfg.log_dir, "best.pth"))
            mlflow.log_artifact(os.path.join(cfg.log_dir, "best.pth"))

        if early.step(val_acc):
            print("Early stopping triggered")
            break
    mlflow.end_run()

if __name__ == "__main__":
    # example: python train.py --config-name=conf
    import sys
    # For simple invocation without Hydra config files, build cfg from dict:
    from omegaconf import OmegaConf
    example_cfg = {
        "blueprint": {
            "input_shape": [3, 32, 32],  # CIFAR-10 images are 3x32x32
            "num_classes": 10,
            "layers": []
        }, 
        "data": {"root": "./data"},
        "train": {"batch_size": 64, "epochs": 5, "lr": 0.01, "patience": 3},
        "device": "cuda" if torch.cuda.is_available() else "cpu",
        "log_dir": "./logs/exp1"
    }
    cfg = OmegaConf.create(example_cfg)
    train(cfg)


Epoch 0: val_acc=0.1974, val_loss=345.4610, time=26.3s
Epoch 1: val_acc=0.2078, val_loss=340.7155, time=25.0s
Epoch 1: val_acc=0.2078, val_loss=340.7155, time=25.0s
Epoch 2: val_acc=0.2124, val_loss=337.6188, time=27.8s
Epoch 2: val_acc=0.2124, val_loss=337.6188, time=27.8s
Epoch 3: val_acc=0.2230, val_loss=335.3273, time=29.6s
Epoch 3: val_acc=0.2230, val_loss=335.3273, time=29.6s
Epoch 4: val_acc=0.2189, val_loss=333.6264, time=29.1s
Epoch 4: val_acc=0.2189, val_loss=333.6264, time=29.1s
