In [None]:
# ============================================================
# 1. Setup (run once)
# ============================================================
import torch
from torch.utils.data import Dataset, DataLoader
import sys, os

# Repository root containing the project packages imported below
# (`modules`, `training`). Set the IMAGINAV_ROOT environment variable to point
# at a different checkout; the default is the original author's local path.
REPO_ROOT = os.environ.get(
    "IMAGINAV_ROOT",
    r"C:\Users\Hagai.LAPTOP-QAG9263N\Desktop\Thesis\repositories\ImagiNav",
)
# Only append once, so re-running this cell in the same kernel does not keep
# growing sys.path with duplicate entries.
if REPO_ROOT not in sys.path:
    sys.path.append(REPO_ROOT)

# Seed torch's global RNG (CPU and CUDA) so weight initialization and any
# torch-based random sampling done later in the notebook are reproducible.
SEED = 0
torch.manual_seed(SEED)

from modules.autoencoder import AutoEncoder
from modules.unet import UNet
from modules.diffusion import LatentDiffusion
from modules.scheduler import CosineScheduler
from training.diffusion_trainer import DiffusionTrainer



In [None]:

# ============================================================
# 2. Dummy dataset
# ============================================================
class DummyDataset(Dataset):
    """Fixed-length map-style dataset of all-zero tensors, for smoke-testing the pipeline.

    Args:
        length: Number of samples reported by ``__len__``.
        shape: Shape of every sample tensor (e.g. ``(C, H, W)``).

    Each item is a 1-tuple ``(tensor,)``. The DataLoader's default collate turns
    a batch of such tuples into a 1-element list ``[batched_tensor]``; presumably
    the trainer reads ``batch[0]`` — TODO confirm against DiffusionTrainer.
    """

    def __init__(self, length=32, shape=(3, 64, 64)):
        self.length = length
        self.shape = shape

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        # Bounds-check so plain iteration (`for x in ds`, `list(ds)`), which falls
        # back to the __getitem__ sequence protocol, stops at `length` instead of
        # looping forever. Negative indices follow normal Python semantics.
        if not -self.length <= idx < self.length:
            raise IndexError(
                f"index {idx} out of range for DummyDataset of length {self.length}"
            )
        # Wrapped in a 1-tuple so each batch is a sequence of tensors.
        return (torch.zeros(self.shape),)



# Tiny deterministic loaders: 16 train samples -> 4 batches, 8 val samples -> 2 batches.
BATCH_SIZE = 4
train_loader, val_loader = (
    DataLoader(DummyDataset(length=n_samples), batch_size=BATCH_SIZE)
    for n_samples in (16, 8)
)




In [None]:
# ============================================================
# 3. Instantiate components (aligned latent AE + UNet)
# ============================================================
# Every model below is moved to the GPU when one is visible, otherwise the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Latent geometry shared by the autoencoder and the UNet: the UNet's input and
# output channel counts are tied to `latent_channels`, so both must agree on it.
latent_channels = 3
latent_base = 16
image_size = 64

# Autoencoder, built through its `from_shape` constructor.
autoencoder_kwargs = dict(
    in_channels=3,
    out_channels=3,
    base_channels=16,
    latent_channels=latent_channels,
    image_size=image_size,
    latent_base=latent_base,
    norm="batch",
    act="relu",
)
autoencoder = AutoEncoder.from_shape(**autoencoder_kwargs).to(device)

# UNet working in the autoencoder's latent space (channels in == channels out).
unet_kwargs = dict(
    in_channels=latent_channels,
    out_channels=latent_channels,
    base_channels=16,
    depth=3,
)
unet = UNet(**unet_kwargs).to(device)

# Short 10-step cosine schedule for this test run, plus the wrapper that
# combines UNet, schedule and autoencoder.
# NOTE(review): `latent_diffusion` is not passed to the trainer in the next
# cell — confirm whether it is still needed.
scheduler = CosineScheduler(num_steps=10)
latent_diffusion = LatentDiffusion(unet, scheduler, autoencoder)


In [None]:

# Run settings for the smoke test, gathered in one place so they are easy to tweak.
# Interval units (steps vs. epochs) are defined by DiffusionTrainer — TODO confirm.
output_dir = "test_outputs"
trainer_kwargs = dict(
    epochs=10,
    log_interval=1,       # log every step
    sample_interval=2,    # create artifacts every 2 steps
    eval_interval=4,
    output_dir=output_dir,
    ckpt_dir=f"{output_dir}/checkpoints",
)
trainer = DiffusionTrainer(
    unet=unet,
    autoencoder=autoencoder,
    scheduler=scheduler,
    **trainer_kwargs,
)

# ============================================================
# 4. Run one short training cycle
# ============================================================
trainer.fit(train_loader, val_loader)

# ============================================================
# 5. Inspect results
# ============================================================
# The trainer may lack `metric_log` or leave it as None; normalize it to a list
# once, so the count and the example below cannot fail with a TypeError on None.
metric_log = getattr(trainer, "metric_log", None) or []

print("\nTraining complete.")
print("Artifacts saved in:", os.path.abspath(trainer.output_dir))
print("Metric log entries:", len(metric_log))

# Display one sample of recorded metrics (if any were recorded).
if metric_log:
    print("Example metrics:", metric_log[0])
