In [1]:
# ============================================================
# 1. Setup (run once)
# ============================================================
import torch
from torch.utils.data import Dataset, DataLoader
import sys, os
sys.path.append(r"C:\Users\Hagai.LAPTOP-QAG9263N\Desktop\Thesis\repositories\ImagiNav")

from modules.autoencoder import AutoEncoder
from training.autoencoder_trainer import AutoEncoderTrainer


# ============================================================
# 2. Dummy dataset
# ============================================================
class DummyDataset(Dataset):
    """Map-style dataset of uniform-random tensors for smoke-testing a training loop.

    Args:
        length: Number of samples reported by ``len()``.
        shape: Shape of every sample tensor.
        seed: Optional base seed. With ``None`` (the default) each
            ``__getitem__`` call draws fresh values from the global torch RNG,
            so reading the same index twice gives different data. With an int,
            sample ``idx`` is generated from ``seed + idx`` and is identical on
            every access, which keeps e.g. validation losses comparable across
            epochs.
    """

    def __init__(self, length=32, shape=(3, 64, 64), seed=None):
        self.length = length
        self.shape = shape
        self.seed = seed

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        idx = int(idx)
        # Raise IndexError past the end: torch's map-style Dataset has no
        # __iter__, so plain iteration (`for x in ds`, `list(ds)`) relies on the
        # sequence protocol and would otherwise never stop.
        if not -self.length <= idx < self.length:
            raise IndexError(
                f"index {idx} out of range for DummyDataset of length {self.length}"
            )
        if idx < 0:
            idx += self.length

        if self.seed is None:
            sample = torch.rand(self.shape)
        else:
            # Per-index generator: reproducible without touching the global RNG.
            gen = torch.Generator().manual_seed(self.seed + idx)
            sample = torch.rand(self.shape, generator=gen)

        # A 1-tuple: the default collate turns a batch of these into a
        # one-element sequence holding the stacked (batch, *shape) tensor.
        # NOTE(review): presumably the trainer reads batch[0] — confirm.
        return (sample,)


# Tiny splits: 16 training samples and 8 validation samples, batches of 4.
train_dataset = DummyDataset(length=16)
val_dataset = DummyDataset(length=8)

train_loader = DataLoader(train_dataset, batch_size=4)
val_loader = DataLoader(val_dataset, batch_size=4)


# ============================================================
# 3. Instantiate components
# ============================================================
device = "cuda" if torch.cuda.is_available() else "cpu"

# NOTE(review): only the model is moved to `device`; the trainer receives no
# device argument — confirm AutoEncoderTrainer derives it from the model.
autoencoder = AutoEncoder.from_shape(
    in_channels=3,
    out_channels=3,
    base_channels=16,
    latent_channels=3,
    image_size=64,
    latent_base=16,
    norm="batch",
    act="relu"
).to(device)

loss_fn = torch.nn.MSELoss()

# Single root for all artifacts. The checkpoint dir is built with
# os.path.join so it uses one separator style (a hard-coded "/" produced
# mixed "/" and "\" checkpoint paths on Windows).
OUTPUT_DIR = "ae_test_outputs"

trainer = AutoEncoderTrainer(
    autoencoder=autoencoder,
    loss_fn=loss_fn,
    epochs=5,
    log_interval=1,       # log every step
    sample_interval=2,    # save reconstructions every 2 steps
    eval_interval=1,      # validate every epoch
    output_dir=OUTPUT_DIR,
    ckpt_dir=os.path.join(OUTPUT_DIR, "checkpoints"),
)


# ============================================================
# 4. Run one short training cycle
# ============================================================
trainer.fit(train_loader, val_loader)


# ============================================================
# 5. Inspect results
# ============================================================
print("\nAutoencoder training complete.")
print("Artifacts saved in:", os.path.abspath(trainer.output_dir))

# Tolerate trainers that do not expose a metric log: fall back to an empty list.
metric_log = getattr(trainer, "metric_log", [])
print("Metric log entries:", len(metric_log))
if metric_log:
    print("Example metric entry:", metric_log[0])


Training autoencoder for 5 epochs on cpu
[Epoch 1] Step 1 | Loss: 0.093619
[Epoch 1] Step 2 | Loss: 0.092660
[Sample] Saved: ae_test_outputs\samples\sample_step_2.png
[Epoch 1] Step 3 | Loss: 0.092034
[Epoch 1] Step 4 | Loss: 0.092135
[Sample] Saved: ae_test_outputs\samples\sample_step_4.png
[Epoch 1] Average training loss: 0.092612
[Epoch 1] Validation loss: 0.083098
[Checkpoint] Saved: ae_test_outputs/checkpoints\ae_epoch_1.pt
[Plot] Updated: ae_test_outputs\train_val_curve.png
[Metrics] Updated: ae_test_outputs\metrics.json
[Epoch 2] Step 5 | Loss: 0.091514
[Epoch 2] Step 6 | Loss: 0.090474
[Sample] Saved: ae_test_outputs\samples\sample_step_6.png
[Epoch 2] Step 7 | Loss: 0.089356
[Epoch 2] Step 8 | Loss: 0.089698
[Sample] Saved: ae_test_outputs\samples\sample_step_8.png
[Epoch 2] Average training loss: 0.090261
[Epoch 2] Validation loss: 0.083367
[Checkpoint] Saved: ae_test_outputs/checkpoints\ae_epoch_2.pt
[Plot] Updated: ae_test_outputs\train_val_curve.png
[Metrics] Updated: ae_t