# Training Code

### 📦 Imports

In [None]:
import os

from src.train.train import train
from src.utils.data import get_dataloaders
from src.utils.environment import get_device, set_seed

### 🛠️ Configuration

In [None]:
# --- What to train ---------------------------------------------------------
method = "diffusion"  # or "flow"
model_name = "unet"

# --- Optimisation hyperparameters -------------------------------------------
learning_rate = 1e-3
batch_size = 128
epochs = 10

# --- Reproducibility / checkpointing ----------------------------------------
seed = 1337
# Forwarded to train() as `checkpoint_path`; None means no checkpoint is given.
checkpoint = None  # e.g., "checkpoints/last.ckpt"

### 🧪 Setup: Seed and Device

In [None]:
# Apply the configured seed before anything else in this cell runs.
set_seed(seed)

# Ensure the "checkpoints" directory exists; exist_ok=True makes re-running
# this cell harmless when the directory is already there.
# NOTE(review): presumably this is where training writes checkpoints - confirm.
os.makedirs("checkpoints", exist_ok=True)

# Device handle chosen by the project helper; passed to train() later on.
device = get_device()

### 📥 Data Loading

In [None]:
# Build both loaders in one call, using the batch size from the configuration
# cell; the helper's result is unpacked as (training, validation).
(train_loader, val_loader) = get_dataloaders(
    batch_size=batch_size,
)

### 🧠 Model Training

In [None]:
# NOTE: The number of classes is currently hardcoded to 10.
num_classes = 10
# Extra keyword arguments handed to train() via `model_kwargs`.
model_kwargs = dict(num_classes=num_classes)

# Launch training with the values from the configuration cell. All arguments
# are passed by keyword.
trained_model = train(
    # Model / method selection
    method=method,
    model_name=model_name,
    model_kwargs=model_kwargs,
    # Data
    dataloader=train_loader,
    val_loader=val_loader,
    # Optimisation
    num_epochs=epochs,
    learning_rate=learning_rate,
    # Runtime
    device=device,
    checkpoint_path=checkpoint,
    # NOTE(review): presumably turns on Weights & Biases logging - confirm in
    # src.train.train.
    use_wandb=True,
)