In [1]:
# ============================================================
# 1. Setup (run once)
# ============================================================
import torch
from torch.utils.data import Dataset, DataLoader
import sys, os

# Root of the ImagiNav checkout that provides the `modules` and `training`
# packages imported below. Set the IMAGINAV_ROOT environment variable to run
# this notebook on another machine; when it is unset or empty, the original
# hard-coded location is used.
REPO_ROOT = os.environ.get("IMAGINAV_ROOT") or (
    r"C:\Users\Hagai.LAPTOP-QAG9263N\Desktop\Thesis\repositories\ImagiNav"
)
# Guarded so that re-running this cell does not pile up duplicate entries.
if REPO_ROOT not in sys.path:
    sys.path.append(REPO_ROOT)
from modules.autoencoder import AutoEncoder
from modules.unet import UNet
from modules.diffusion import LatentDiffusion
from modules.scheduler import CosineScheduler
from training.diffusion_trainer import DiffusionTrainer



In [2]:

# ============================================================
# 2. Dummy dataset
# ============================================================
class DummyDataset(Dataset):
    """Map-style dataset that serves identical all-zero tensors.

    Intended as a stand-in for real data when smoke-testing a training loop.

    Args:
        length: Number of samples the dataset reports via ``len()``.
        shape: Shape of the zero tensor returned for every sample.
    """

    def __init__(self, length=32, shape=(3, 64, 64)):
        self.length = length
        self.shape = shape

    def __len__(self):
        """Return the configured number of samples."""
        return self.length

    def __getitem__(self, idx):
        """Return the sample at ``idx`` as a 1-tuple holding a zero tensor.

        Raises:
            IndexError: If ``idx`` is a plain integer outside
                ``[-length, length)``.
        """
        # Without this check, plain iteration (`for x in ds`, `list(ds)`)
        # never stops: Python's fallback iteration protocol keeps calling
        # __getitem__ with 0, 1, 2, ... until it sees an IndexError.
        # Only plain integers are range-checked; any other index type keeps
        # the original permissive behaviour.
        if isinstance(idx, int) and not (-self.length <= idx < self.length):
            raise IndexError(
                f"index {idx} is out of range for DummyDataset of length {self.length}"
            )
        # return tuple to preserve batch dimension
        return (torch.zeros(self.shape),)



# Build the training (16 samples) and validation (8 samples) loaders in one
# pass; both use a batch size of 4 and otherwise default DataLoader settings.
train_loader, val_loader = (
    DataLoader(DummyDataset(length=sample_count), batch_size=4)
    for sample_count in (16, 8)
)




In [3]:
# ============================================================
# 3. Instantiate components (aligned latent AE + UNet)
# ============================================================
# Seed the global RNG before any module is built so that weight
# initialisation and the subsequent training run can be reproduced with
# Restart Kernel -> Run All.
SEED = 0
torch.manual_seed(SEED)

# Prefer the GPU when one is visible to PyTorch, otherwise stay on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Define shared latent geometry for AE and UNet
latent_channels = 3
latent_base = 16
image_size = 64

# Autoencoder built via new from_shape API
autoencoder = AutoEncoder.from_shape(
    in_channels=3,
    out_channels=3,
    base_channels=16,
    latent_channels=latent_channels,
    image_size=image_size,
    latent_base=latent_base,
    norm="batch",
    act="relu"
).to(device)

# UNet configured to operate on the same latent space: its input and output
# channel counts are both tied to `latent_channels`.
unet = UNet(
    in_channels=latent_channels,
    out_channels=latent_channels,
    base_channels=16,
    depth=3
).to(device)

# A 10-step schedule keeps the sampling passes short.
scheduler = CosineScheduler(num_steps=10)
# Bundles the three parts above. The trainer cell below is handed
# unet / autoencoder / scheduler individually rather than this object.
latent_diffusion = LatentDiffusion(unet, scheduler, autoencoder)


In [4]:

# Build the trainer from the components instantiated in the previous cell.
# The *_interval values are counted in training steps: the recorded output
# shows a metric line at every step and validation at steps 4, 8, 12, ...
trainer = DiffusionTrainer(
    unet=unet,
    autoencoder=autoencoder,
    scheduler=scheduler,
    epochs=10,            # progress bars in the recorded output read "Epoch k/10"
    log_interval=1,       # log every step
    sample_interval=2,    # create artifacts every 2 steps
    eval_interval=4,      # validate every 4 steps
    output_dir="test_outputs",               # relative path; resolved against the working directory
    ckpt_dir="test_outputs/checkpoints",     # checkpoints are written under output_dir
)

# ============================================================
# 4. Run one short training cycle
# ============================================================
# With the loaders defined above this is 10 epochs of 4 batches each
# (16 samples / batch size 4), i.e. 40 training steps in total.
# The call prints progress bars, per-step metrics, validation losses and
# checkpoint notices (see the recorded output of this cell).
trainer.fit(train_loader, val_loader)

# ============================================================
# 5. Inspect results
# ============================================================
print("\nTraining complete.")
print("Artifacts saved in:", os.path.abspath(trainer.output_dir))

# Read the metric log once. A trainer without the attribute, or with it set
# to None, is treated as having no entries, so len() below cannot fail.
metric_log = getattr(trainer, "metric_log", None) or []
print("Metric log entries:", len(metric_log))

# Display one sample of recorded metrics (if available)
if metric_log:
    print("Example metrics:", metric_log[0])


Dataset diversity baseline: 0.0000


Epoch 1/10:   0%|          | 0/4 [00:00<?, ?it/s, loss=1.0782]

Generating...


Diffusion sampling: 100%|██████████| 10/10 [00:00<00:00, 145.72it/s]
Epoch 1/10:  25%|██▌       | 1/4 [00:00<00:00,  7.41it/s, loss=1.0776]

Decoding...
[1] loss=1.0782, snr=0.781, cos=-0.035, div=0.002
Generating...


Diffusion sampling: 100%|██████████| 10/10 [00:00<00:00, 150.05it/s]


Decoding...
[2] loss=1.0776, snr=0.551, cos=-0.070, div=0.003
Generating...


Diffusion sampling: 100%|██████████| 10/10 [00:00<00:00, 94.59it/s]


Decoding...


Epoch 1/10:  50%|█████     | 2/4 [00:00<00:01,  1.93it/s, loss=1.0689]

Generating...


Diffusion sampling: 100%|██████████| 10/10 [00:00<00:00, 128.32it/s]
Epoch 1/10:  75%|███████▌  | 3/4 [00:01<00:00,  2.89it/s, loss=0.9993]

Decoding...
[3] loss=1.0689, snr=0.632, cos=-0.042, div=0.002
Generating...


Diffusion sampling: 100%|██████████| 10/10 [00:00<00:00, 110.89it/s]


Decoding...
[4] loss=0.9993, snr=0.135, cos=0.003, div=0.002
Generating...


Diffusion sampling: 100%|██████████| 10/10 [00:00<00:00, 77.87it/s]


Decoding...


                                                                      

[Validation @ step 4] loss=1.0680
[Checkpoint] Saved: test_outputs/checkpoints\best_val.pt
[Config] Saved UNet config → test_outputs/checkpoints\unet_config.yaml
[Config] Saved AutoEncoder config → test_outputs/checkpoints\autoencoder_config.yaml
[Validation] New best model saved → test_outputs/checkpoints\best_val.pt
[Checkpoint] Saved: test_outputs/checkpoints/epoch_1.pt


Epoch 2/10:   0%|          | 0/4 [00:00<?, ?it/s, loss=1.0962]

Generating...


Diffusion sampling: 100%|██████████| 10/10 [00:00<00:00, 93.42it/s]
Epoch 2/10:  25%|██▌       | 1/4 [00:00<00:00,  5.31it/s, loss=1.0170]

Decoding...
[5] loss=1.0962, snr=0.615, cos=-0.064, div=0.002
Generating...


Diffusion sampling: 100%|██████████| 10/10 [00:00<00:00, 93.63it/s]


Decoding...
[6] loss=1.0170, snr=0.364, cos=-0.030, div=0.003
Generating...


Diffusion sampling: 100%|██████████| 10/10 [00:00<00:00, 72.00it/s]


Decoding...


Epoch 2/10:  50%|█████     | 2/4 [00:01<00:01,  1.58it/s, loss=1.0408]

Generating...


Diffusion sampling: 100%|██████████| 10/10 [00:00<00:00, 36.22it/s]
Epoch 2/10:  75%|███████▌  | 3/4 [00:01<00:00,  1.97it/s, loss=1.0056]

Decoding...
[7] loss=1.0408, snr=0.666, cos=-0.065, div=0.002
Generating...


Diffusion sampling: 100%|██████████| 10/10 [00:00<00:00, 62.50it/s]


Decoding...
[8] loss=1.0056, snr=0.728, cos=-0.028, div=0.002
Generating...


Diffusion sampling: 100%|██████████| 10/10 [00:00<00:00, 49.88it/s]


Decoding...


                                                                      

[Validation @ step 8] loss=1.0035
[Checkpoint] Saved: test_outputs/checkpoints\best_val.pt
[Validation] New best model saved → test_outputs/checkpoints\best_val.pt
[Checkpoint] Saved: test_outputs/checkpoints/epoch_2.pt


Epoch 3/10:   0%|          | 0/4 [00:00<?, ?it/s, loss=1.0796]

Generating...


Diffusion sampling: 100%|██████████| 10/10 [00:00<00:00, 55.24it/s]
Epoch 3/10:  25%|██▌       | 1/4 [00:00<00:01,  2.45it/s, loss=1.0414]

Decoding...
[9] loss=1.0796, snr=0.562, cos=-0.013, div=0.003
Generating...


Diffusion sampling: 100%|██████████| 10/10 [00:00<00:00, 67.10it/s]


Decoding...
[10] loss=1.0414, snr=0.522, cos=-0.030, div=0.002
Generating...


Diffusion sampling: 100%|██████████| 10/10 [00:00<00:00, 52.97it/s]


Decoding...


Epoch 3/10:  50%|█████     | 2/4 [00:01<00:01,  1.33it/s, loss=1.0671]

Generating...


Diffusion sampling: 100%|██████████| 10/10 [00:00<00:00, 102.79it/s]
Epoch 3/10:  75%|███████▌  | 3/4 [00:01<00:00,  2.06it/s, loss=1.0648]

Decoding...
[11] loss=1.0671, snr=0.744, cos=-0.034, div=0.002
Generating...


Diffusion sampling: 100%|██████████| 10/10 [00:00<00:00, 38.22it/s]


Decoding...
[12] loss=1.0648, snr=0.688, cos=-0.048, div=0.002
Generating...


Diffusion sampling: 100%|██████████| 10/10 [00:00<00:00, 58.26it/s]


Decoding...


                                                                      

[Validation @ step 12] loss=1.0083
[Checkpoint] Saved: test_outputs/checkpoints/epoch_3.pt


Epoch 4/10:   0%|          | 0/4 [00:00<?, ?it/s, loss=1.0190]

Generating...


Diffusion sampling: 100%|██████████| 10/10 [00:00<00:00, 74.07it/s]
Epoch 4/10:  25%|██▌       | 1/4 [00:00<00:00,  4.08it/s, loss=1.0190]

Decoding...
[13] loss=1.0190, snr=0.430, cos=-0.020, div=0.002


Epoch 4/10:  25%|██▌       | 1/4 [00:00<00:00,  4.08it/s, loss=1.0060]

Generating...


Diffusion sampling: 100%|██████████| 10/10 [00:00<00:00, 49.78it/s]


Decoding...
[14] loss=1.0060, snr=0.393, cos=0.035, div=0.003
Generating...


Diffusion sampling: 100%|██████████| 10/10 [00:00<00:00, 27.42it/s]


Decoding...


Epoch 4/10:  50%|█████     | 2/4 [00:01<00:01,  1.06it/s, loss=0.9892]

Generating...


Diffusion sampling: 100%|██████████| 10/10 [00:00<00:00, 83.41it/s]
Epoch 4/10:  75%|███████▌  | 3/4 [00:01<00:00,  1.66it/s, loss=0.9892]

Decoding...
[15] loss=0.9892, snr=0.470, cos=0.016, div=0.003


Epoch 4/10:  75%|███████▌  | 3/4 [00:01<00:00,  1.66it/s, loss=0.9763]

Generating...


Diffusion sampling: 100%|██████████| 10/10 [00:00<00:00, 23.16it/s]


Decoding...
[16] loss=0.9763, snr=0.155, cos=0.020, div=0.002
Generating...


Diffusion sampling: 100%|██████████| 10/10 [00:00<00:00, 53.44it/s]


Decoding...


                                                                      

[Validation @ step 16] loss=0.9898
[Checkpoint] Saved: test_outputs/checkpoints\best_val.pt
[Validation] New best model saved → test_outputs/checkpoints\best_val.pt
[Checkpoint] Saved: test_outputs/checkpoints/epoch_4.pt


Epoch 5/10:   0%|          | 0/4 [00:00<?, ?it/s, loss=0.9963]

Generating...


Diffusion sampling: 100%|██████████| 10/10 [00:00<00:00, 41.90it/s]
Epoch 5/10:  25%|██▌       | 1/4 [00:00<00:00,  3.11it/s, loss=0.9783]

Decoding...
[17] loss=0.9963, snr=0.304, cos=0.027, div=0.003
Generating...


Diffusion sampling: 100%|██████████| 10/10 [00:00<00:00, 75.58it/s]


Decoding...
[18] loss=0.9783, snr=0.751, cos=0.048, div=0.002
Generating...


Diffusion sampling: 100%|██████████| 10/10 [00:00<00:00, 50.81it/s]


Decoding...


Epoch 5/10:  50%|█████     | 2/4 [00:01<00:01,  1.22it/s, loss=1.0075]

Generating...


Diffusion sampling: 100%|██████████| 10/10 [00:00<00:00, 77.24it/s]
Epoch 5/10:  75%|███████▌  | 3/4 [00:01<00:00,  1.83it/s, loss=1.0075]

Decoding...
[19] loss=1.0075, snr=0.804, cos=0.082, div=0.003


Epoch 5/10:  75%|███████▌  | 3/4 [00:01<00:00,  1.83it/s, loss=1.0316]

Generating...


Diffusion sampling: 100%|██████████| 10/10 [00:00<00:00, 31.03it/s]


Decoding...
[20] loss=1.0316, snr=0.664, cos=0.074, div=0.002
Generating...


Diffusion sampling: 100%|██████████| 10/10 [00:00<00:00, 44.05it/s]


Decoding...


                                                                      

[Validation @ step 20] loss=1.0254
[Checkpoint] Saved: test_outputs/checkpoints/epoch_5.pt


Epoch 6/10:   0%|          | 0/4 [00:00<?, ?it/s, loss=1.0203]

Generating...


Diffusion sampling: 100%|██████████| 10/10 [00:00<00:00, 60.42it/s]
Epoch 6/10:  25%|██▌       | 1/4 [00:00<00:00,  3.63it/s, loss=1.0203]

Decoding...
[21] loss=1.0203, snr=0.770, cos=0.086, div=0.003


Epoch 6/10:  25%|██▌       | 1/4 [00:00<00:00,  3.63it/s, loss=0.9639]

Generating...


Diffusion sampling: 100%|██████████| 10/10 [00:00<00:00, 71.88it/s]


Decoding...
[22] loss=0.9639, snr=0.466, cos=0.079, div=0.002
Generating...


Diffusion sampling: 100%|██████████| 10/10 [00:00<00:00, 61.78it/s]


Decoding...


Epoch 6/10:  50%|█████     | 2/4 [00:01<00:01,  1.39it/s, loss=0.9920]

Generating...


Diffusion sampling: 100%|██████████| 10/10 [00:00<00:00, 87.47it/s]
Epoch 6/10:  75%|███████▌  | 3/4 [00:01<00:00,  2.09it/s, loss=0.9920]

Decoding...
[23] loss=0.9920, snr=0.557, cos=0.122, div=0.003


Epoch 6/10:  75%|███████▌  | 3/4 [00:01<00:00,  2.09it/s, loss=0.9866]

Generating...


Diffusion sampling: 100%|██████████| 10/10 [00:00<00:00, 39.11it/s]


Decoding...
[24] loss=0.9866, snr=0.560, cos=0.083, div=0.002
Generating...


Diffusion sampling: 100%|██████████| 10/10 [00:00<00:00, 64.25it/s]

Decoding...



                                                                      

[Validation @ step 24] loss=0.9819
[Checkpoint] Saved: test_outputs/checkpoints\best_val.pt
[Validation] New best model saved → test_outputs/checkpoints\best_val.pt
[Checkpoint] Saved: test_outputs/checkpoints/epoch_6.pt


Epoch 7/10:   0%|          | 0/4 [00:00<?, ?it/s, loss=0.9730]

Generating...


Diffusion sampling: 100%|██████████| 10/10 [00:00<00:00, 86.33it/s]
Epoch 7/10:  25%|██▌       | 1/4 [00:00<00:00,  4.89it/s, loss=0.9726]

Decoding...
[25] loss=0.9730, snr=0.763, cos=0.153, div=0.002
Generating...


Diffusion sampling: 100%|██████████| 10/10 [00:00<00:00, 84.43it/s]


Decoding...
[26] loss=0.9726, snr=0.401, cos=0.144, div=0.002
Generating...


Diffusion sampling: 100%|██████████| 10/10 [00:00<00:00, 56.53it/s]


Decoding...


Epoch 7/10:  50%|█████     | 2/4 [00:01<00:01,  1.33it/s, loss=1.0138]

Generating...


Diffusion sampling: 100%|██████████| 10/10 [00:00<00:00, 76.39it/s]
Epoch 7/10:  75%|███████▌  | 3/4 [00:01<00:00,  1.94it/s, loss=1.0138]

Decoding...
[27] loss=1.0138, snr=0.423, cos=0.164, div=0.003


Epoch 7/10:  75%|███████▌  | 3/4 [00:01<00:00,  1.94it/s, loss=0.9606]

Generating...


Diffusion sampling: 100%|██████████| 10/10 [00:00<00:00, 27.86it/s]


Decoding...
[28] loss=0.9606, snr=0.645, cos=0.154, div=0.003
Generating...


Diffusion sampling: 100%|██████████| 10/10 [00:00<00:00, 47.70it/s]


Decoding...


                                                                      

[Validation @ step 28] loss=0.9501
[Checkpoint] Saved: test_outputs/checkpoints\best_val.pt
[Validation] New best model saved → test_outputs/checkpoints\best_val.pt
[Checkpoint] Saved: test_outputs/checkpoints/epoch_7.pt


Epoch 8/10:   0%|          | 0/4 [00:00<?, ?it/s, loss=0.9898]

Generating...


Diffusion sampling: 100%|██████████| 10/10 [00:00<00:00, 71.68it/s]
Epoch 8/10:  25%|██▌       | 1/4 [00:00<00:00,  4.36it/s, loss=0.9898]

Decoding...
[29] loss=0.9898, snr=0.398, cos=0.149, div=0.003


Epoch 8/10:  25%|██▌       | 1/4 [00:00<00:00,  4.36it/s, loss=0.9247]

Generating...


Diffusion sampling: 100%|██████████| 10/10 [00:00<00:00, 69.55it/s]


Decoding...
[30] loss=0.9247, snr=0.714, cos=0.217, div=0.002
Generating...


Diffusion sampling: 100%|██████████| 10/10 [00:00<00:00, 67.35it/s]

Decoding...



Epoch 8/10:  50%|█████     | 2/4 [00:01<00:01,  1.44it/s, loss=0.9844]

Generating...


Diffusion sampling: 100%|██████████| 10/10 [00:00<00:00, 68.33it/s]
Epoch 8/10:  75%|███████▌  | 3/4 [00:01<00:00,  2.01it/s, loss=0.9844]

Decoding...
[31] loss=0.9844, snr=0.607, cos=0.232, div=0.002


Epoch 8/10:  75%|███████▌  | 3/4 [00:01<00:00,  2.01it/s, loss=0.9904]

Generating...


Diffusion sampling: 100%|██████████| 10/10 [00:00<00:00, 62.68it/s]


Decoding...
[32] loss=0.9904, snr=0.625, cos=0.215, div=0.003
Generating...


Diffusion sampling: 100%|██████████| 10/10 [00:00<00:00, 59.52it/s]


Decoding...


                                                                      

[Validation @ step 32] loss=0.9520
[Checkpoint] Saved: test_outputs/checkpoints/epoch_8.pt


Epoch 9/10:   0%|          | 0/4 [00:00<?, ?it/s, loss=0.9707]

Generating...


Diffusion sampling: 100%|██████████| 10/10 [00:00<00:00, 91.91it/s]
Epoch 9/10:  25%|██▌       | 1/4 [00:00<00:00,  5.33it/s, loss=0.9707]

Decoding...
[33] loss=0.9707, snr=0.667, cos=0.264, div=0.002


Epoch 9/10:  25%|██▌       | 1/4 [00:00<00:00,  5.33it/s, loss=0.9798]

Generating...


Diffusion sampling: 100%|██████████| 10/10 [00:00<00:00, 58.61it/s]


Decoding...
[34] loss=0.9798, snr=0.762, cos=0.309, div=0.003
Generating...


Diffusion sampling: 100%|██████████| 10/10 [00:00<00:00, 55.33it/s]


Decoding...


Epoch 9/10:  50%|█████     | 2/4 [00:01<00:01,  1.37it/s, loss=0.9417]

Generating...


Diffusion sampling: 100%|██████████| 10/10 [00:00<00:00, 42.07it/s]
Epoch 9/10:  75%|███████▌  | 3/4 [00:01<00:00,  1.84it/s, loss=0.9218]

Decoding...
[35] loss=0.9417, snr=0.621, cos=0.316, div=0.002
Generating...


Diffusion sampling: 100%|██████████| 10/10 [00:00<00:00, 84.82it/s]


Decoding...
[36] loss=0.9218, snr=0.910, cos=0.385, div=0.002
Generating...


Diffusion sampling: 100%|██████████| 10/10 [00:00<00:00, 64.73it/s]


Decoding...


                                                                      

[Validation @ step 36] loss=0.9110
[Checkpoint] Saved: test_outputs/checkpoints\best_val.pt
[Validation] New best model saved → test_outputs/checkpoints\best_val.pt
[Checkpoint] Saved: test_outputs/checkpoints/epoch_9.pt


Epoch 10/10:   0%|          | 0/4 [00:00<?, ?it/s, loss=0.9395]

Generating...


Diffusion sampling: 100%|██████████| 10/10 [00:00<00:00, 95.44it/s]
Epoch 10/10:  25%|██▌       | 1/4 [00:00<00:00,  5.17it/s, loss=0.9692]

Decoding...
[37] loss=0.9395, snr=0.740, cos=0.354, div=0.002
Generating...


Diffusion sampling: 100%|██████████| 10/10 [00:00<00:00, 72.19it/s]


Decoding...
[38] loss=0.9692, snr=0.514, cos=0.332, div=0.002
Generating...


Diffusion sampling: 100%|██████████| 10/10 [00:00<00:00, 50.35it/s]


Decoding...


Epoch 10/10:  50%|█████     | 2/4 [00:01<00:01,  1.49it/s, loss=0.9366]

Generating...


Diffusion sampling: 100%|██████████| 10/10 [00:00<00:00, 80.74it/s]
Epoch 10/10:  75%|███████▌  | 3/4 [00:01<00:00,  2.16it/s, loss=0.9317]

Decoding...
[39] loss=0.9366, snr=0.600, cos=0.353, div=0.003
Generating...


Diffusion sampling: 100%|██████████| 10/10 [00:00<00:00, 70.90it/s]


Decoding...
[40] loss=0.9317, snr=0.333, cos=0.306, div=0.002
Generating...


Diffusion sampling: 100%|██████████| 10/10 [00:00<00:00, 53.57it/s]


Decoding...


                                                                       

[Validation @ step 40] loss=0.9771
[Checkpoint] Saved: test_outputs/checkpoints/epoch_10.pt

Training complete.
Artifacts saved in: c:\Users\Hagai.LAPTOP-QAG9263N\Desktop\Thesis\repositories\ImagiNav\notebooks\test_outputs
Metric log entries: 40
Example metrics: {'step': 1, 'loss': 1.0782171487808228, 'snr': 0.7807475328445435, 'cosine': -0.0346669927239418, 'grad_norm': 1.142073154449463, 'diversity': 0.002451847540214658}


