### Encoder(P2P) --> Diffusion(DDPM) --> Decoder(P2P)

In [1]:
import os
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
import torchvision.transforms as transforms
from datasets import load_dataset
import matplotlib.pyplot as plt
import wandb
from diffusers import DDPMScheduler
import numpy as np
from tqdm import tqdm
from torchinfo import summary
import torch.cuda.amp as amp


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
config = {
    "project_name": "Pix2Pix_Diffusion_Pipeline",
    "device": "cuda:1" if torch.cuda.is_available() else "cpu",
    "batch_size": 512,
    "timesteps": 1000,
    "embedding_dim": 512,
    "time_emb_dim": 256,
    "learning_rate": 2e-4,
    "num_epochs": 50,
    "save_checkpoint_interval": 5,
    "diffusion_loss_weight": 0.65,
    "latent_loss_weight": 0.1,
    "sample_interval": 10,
    "val_interval": 10,
    "use_mixed_precision": True,
    "logging": {
        "use_wandb": True,
        "sample_dir": "./outputs/pipeline_samples",
        "checkpoint_dir": "./outputs/pipeline_checkpoints",
        "plot_dir": "./outputs/pipeline_plots"
    }
}

for dir_path in [config["logging"]["sample_dir"], config["logging"]["checkpoint_dir"], config["logging"]["plot_dir"]]:
    try:
        os.makedirs(dir_path, exist_ok=True)
        print(f"📁 Created directory: {dir_path}")
    except Exception as e:
        print(f"⚠️ Failed to create directory {dir_path}: {e}")


📁 Created directory: ./outputs/pipeline_samples
📁 Created directory: ./outputs/pipeline_checkpoints
📁 Created directory: ./outputs/pipeline_plots


In [3]:
from diffusers import DDPMScheduler
import math

class DiffusionAutoencoder(nn.Module):
    """Diffusion Autoencoder combining an encoder, U-Net, and decoder for image reconstruction."""
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.device = torch.device(config["device"])
        self.embedding_dim = config["embedding_dim"]
        self.time_emb_dim = config["time_emb_dim"]
        self.timesteps = config["timesteps"]
        
        self.encoder = self.Encoder(in_channels=3, features=64, embedding_dim=self.embedding_dim)
        self.unet = self.ImprovedMLP_UNet(embedding_dim=self.embedding_dim, time_dim=self.time_emb_dim)
        self.decoder = self.Decoder(out_channels=3, features=64, embedding_dim=self.embedding_dim)
        self.scheduler = self.DiffusionScheduler(timesteps=self.timesteps)
        
        self.to(self.device)

    class SinusoidalPositionEmbeddings(nn.Module):
        """Generates sinusoidal position embeddings for timesteps."""
        def __init__(self, dim):
            super().__init__()
            self.dim = dim

        def forward(self, time):
            device = time.device
            half_dim = self.dim // 2
            embeddings = math.log(10000) / (half_dim - 1)
            embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
            embeddings = time[:, None] * embeddings[None, :]
            embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
            if self.dim % 2 == 1:
                embeddings = F.pad(embeddings, (0, 1, 0, 0))
            return embeddings

    class Encoder(nn.Module):
        """Encodes input images into a latent representation."""
        def __init__(self, in_channels, features, embedding_dim):
            super().__init__()
            self.initial = nn.Sequential(
                nn.Conv2d(in_channels, features, kernel_size=4, stride=2, padding=1, bias=False),
                nn.LeakyReLU(0.2, inplace=True)
            )
            self.down1 = self._block(features, features * 2)
            self.down2 = self._block(features * 2, features * 4)
            self.down3 = self._block(features * 4, features * 8)
            self.down4 = self._block(features * 8, features * 8)
            self.final = nn.Sequential(
                nn.AdaptiveAvgPool2d(1),
                nn.Flatten(),
                nn.Linear(features * 8, embedding_dim)
            )

        def _block(self, in_channels, out_channels):
            return nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.LeakyReLU(0.2, inplace=True)
            )

        def forward(self, x):
            d1 = self.initial(x)
            d2 = self.down1(d1)
            d3 = self.down2(d2)
            d4 = self.down3(d3)
            d5 = self.down4(d4)
            embedding = self.final(d5)
            return embedding, [d1, d2, d3, d4, d5]

    class ResidualBlock(nn.Module):
        """Residual block for U-Net with timestep embeddings."""
        def __init__(self, in_channels, out_channels, time_emb_dim, dropout=0.1):
            super().__init__()
            self.time_mlp = nn.Sequential(
                nn.Linear(time_emb_dim, out_channels),
                nn.GELU()
            )
            self.block1 = nn.Sequential(
                nn.Linear(in_channels, out_channels),
                nn.GELU(),
                nn.Dropout(dropout)
            )
            self.block2 = nn.Sequential(
                nn.Linear(out_channels, out_channels),
                nn.GELU(),
                nn.Dropout(dropout)
            )
            self.residual_conv = nn.Linear(in_channels, out_channels) if in_channels != out_channels else nn.Identity()
            self.layer_norm = nn.LayerNorm(out_channels)

        def forward(self, x, t):
            h = self.block1(x)
            time_emb = self.time_mlp(t)
            h = h + time_emb
            h = self.block2(h)
            return self.layer_norm(h + self.residual_conv(x))

    class ImprovedMLP_UNet(nn.Module):
        """U-Net model for denoising latent representations with timestep conditioning."""
        def __init__(self, embedding_dim, hidden_dim=1024, time_dim=256, dropout=0.1):
            super().__init__()
            self.time_mlp = nn.Sequential(
                DiffusionAutoencoder.SinusoidalPositionEmbeddings(time_dim),
                nn.Linear(time_dim, time_dim * 2),
                nn.GELU(),
                nn.Linear(time_dim * 2, time_dim),
            )
            self.down1 = DiffusionAutoencoder.ResidualBlock(embedding_dim, hidden_dim, time_dim, dropout)
            self.down2 = DiffusionAutoencoder.ResidualBlock(hidden_dim, hidden_dim, time_dim, dropout)
            self.mid = DiffusionAutoencoder.ResidualBlock(hidden_dim, hidden_dim, time_dim, dropout)
            self.up1 = DiffusionAutoencoder.ResidualBlock(hidden_dim * 2, hidden_dim, time_dim, dropout)
            self.up2 = DiffusionAutoencoder.ResidualBlock(hidden_dim * 2, embedding_dim, time_dim, dropout)
            self.final = nn.Sequential(
                nn.Linear(embedding_dim, embedding_dim),
                nn.Tanh()
            )

        def forward(self, x, t):
            time_emb = self.time_mlp(t)
            down1 = self.down1(x, time_emb)
            down2 = self.down2(down1, time_emb)
            mid = self.mid(down2, time_emb)
            up1 = self.up1(torch.cat([mid, down2], dim=1), time_emb)
            up2 = self.up2(torch.cat([up1, down1], dim=1), time_emb)
            return self.final(up2)

    class Decoder(nn.Module):
        """Decodes latent representations back to images."""
        def __init__(self, out_channels, features, embedding_dim):
            super().__init__()
            self.project = nn.Sequential(
                nn.Linear(embedding_dim, 512 * 2 * 2),
                nn.ReLU(inplace=True)
            )
            self.up1 = self._block(512, 512)
            self.up2 = self._block(512 + 512, 256)
            self.up3 = self._block(256 + 256, 128)
            self.up4 = self._block(128 + 128, 64)
            self.final = nn.Sequential(
                nn.ConvTranspose2d(64 + 64, out_channels, kernel_size=4, stride=2, padding=1),
                nn.Tanh()
            )

        def _block(self, in_channels, out_channels):
            return nn.Sequential(
                nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True)
            )

        def forward(self, x, encoder_features=None):
            x = self.project(x)
            x = x.view(-1, 512, 2, 2)
            x1 = self.up1(x)
            if encoder_features:
                x2 = self.up2(torch.cat([x1, encoder_features[3]], dim=1))
                x3 = self.up3(torch.cat([x2, encoder_features[2]], dim=1))
                x4 = self.up4(torch.cat([x3, encoder_features[1]], dim=1))
                output = self.final(torch.cat([x4, encoder_features[0]], dim=1))
            else:
                output = self.final(x1)
            return output

    class DiffusionScheduler:
        """Manages noise scheduling for diffusion process using DDPMScheduler."""
        def __init__(self, timesteps):
            self.scheduler = DDPMScheduler(num_train_timesteps=timesteps)

        def add_noise(self, x, t):
            noise = torch.randn_like(x)
            noisy_x = self.scheduler.add_noise(x, noise, t)
            return noisy_x, noise

    def forward(self, img, t):
        """Forward pass through encoder, U-Net, and decoder."""
        latent, encoder_features = self.encoder(img)
        noisy_latent, noise = self.scheduler.add_noise(latent, t)
        denoised_latent = self.unet(noisy_latent, t)
        reconstructed_img = self.decoder(denoised_latent, encoder_features)
        return reconstructed_img, noise, latent, denoised_latent



# purifier = DiffusionAutoencoder(config).to(config['device'])

In [4]:

def sample_images(model, loader, epoch, config, prefix="train", num_samples=4):
    """Generate and save sample images, logging to WandB."""
    model.eval()
    sample_dir = config["logging"]["sample_dir"]
    wandb_images = []
    with torch.no_grad():
        for batch in loader:
            images = batch['image'][:num_samples].to(model.device)
            t = torch.randint(0, model.timesteps, (num_samples,), device=model.device).long()
            output, _, _, _ = model(images, t)
            output = output.cpu().numpy().transpose(0, 2, 3, 1)
            images = images.cpu().numpy().transpose(0, 2, 3, 1)

            output = (output * 0.5 + 0.5).clip(0, 1)
            images = (images * 0.5 + 0.5).clip(0, 1)

            for i in range(num_samples):
                fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(6, 3))
                ax1.imshow(images[i])
                ax1.set_title("Input")
                ax1.axis("off")
                ax2.imshow(output[i])
                ax2.set_title("Reconstructed")
                ax2.axis("off")

                sample_path = os.path.join(sample_dir, f"{prefix}_epoch_{epoch}_sample_{i}.png")
                plt.savefig(sample_path, bbox_inches="tight")
                plt.close(fig)
                print(f"📸 Saved {prefix} sample: {sample_path}")

                # if config["logging"]["use_wandb"]:
                #     wandb_images.append(wandb.Image(sample_path, caption=f"{prefix.capitalize()} Epoch {epoch} Sample {i}"))

            # if config["logging"]["use_wandb"] and wandb_images:
            #     wandb.log({f"{prefix}_samples": wandb_images}, commit=False)
            break


In [10]:

def validate_model(model, val_loader, config, epoch):
    """Evaluate model on validation set and return loss metrics."""
    model.eval()
    total_loss = 0
    total_recon_loss = 0
    total_diffusion_loss = 0
    total_latent_loss = 0
    num_batches = 0

    with torch.no_grad():
        for batch in val_loader:
            images = batch['image'].to(model.device)
            t = torch.randint(0, model.timesteps, (images.size(0),), device=model.device).long()
            output, noise, latent, denoised_latent = model(images, t)

            recon_loss = F.mse_loss(output, batch['target'].to(model.device))
            diffusion_loss = F.mse_loss(denoised_latent, latent)
            latent_loss = F.mse_loss(denoised_latent, latent.detach())
            loss = recon_loss + config["diffusion_loss_weight"] * diffusion_loss + config["latent_loss_weight"] * latent_loss

            total_loss += loss.item()
            total_recon_loss += recon_loss.item()
            total_diffusion_loss += diffusion_loss.item()
            total_latent_loss += latent_loss.item()
            num_batches += 1

    avg_loss = total_loss / num_batches
    avg_recon_loss = total_recon_loss / num_batches
    avg_diffusion_loss = total_diffusion_loss / num_batches
    avg_latent_loss = total_latent_loss / num_batches

    metrics = {
        "val_loss": avg_loss,
        "val_recon_loss": avg_recon_loss,
        "val_diffusion_loss": avg_diffusion_loss,
        "val_latent_loss": avg_latent_loss
    }
    return metrics



In [11]:
def train_model(model, train_loader, val_loader, config, optimizer=None):
    """Train the diffusion autoencoder with mixed precision and validation."""
    if optimizer is None:
        optimizer = torch.optim.Adam(model.parameters(), lr=config["learning_rate"])

    scaler = amp.GradScaler(enabled=config["use_mixed_precision"])
    
    # if config["logging"]["use_wandb"]:
        # try:
        #     wandb.init(project=config["project_name"], config=config)
        #     print("✅ WandB initialized")
        # except Exception as e:
        #     print(f"⚠️ WandB initialization failed: {e}")
        #     config["logging"]["use_wandb"] = False

    for epoch in tqdm(range(config["num_epochs"]), desc="Epochs", colour="green"):
        model.train()
        total_loss = 0
        total_recon_loss = 0
        total_diffusion_loss = 0
        total_latent_loss = 0
        num_batches = 0

        for batch in tqdm(train_loader, desc="Batches", leave=False):
            images = batch['image'].to(model.device)
            t = torch.randint(0, model.timesteps, (images.size(0),), device=model.device).long()

            optimizer.zero_grad()
            with amp.autocast(enabled=config["use_mixed_precision"]):
                output, noise, latent, denoised_latent = model(images, t)
                recon_loss = F.mse_loss(output, batch['target'].to(model.device))
                diffusion_loss = F.mse_loss(denoised_latent, latent)
                latent_loss = F.mse_loss(denoised_latent, latent.detach())
                loss = recon_loss + config["diffusion_loss_weight"] * diffusion_loss + config["latent_loss_weight"] * latent_loss

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            total_loss += loss.item()
            total_recon_loss += recon_loss.item()
            total_diffusion_loss += diffusion_loss.item()
            total_latent_loss += latent_loss.item()
            num_batches += 1

            # if config["logging"]["use_wandb"]:
            #     wandb.log({
            #         "train_loss": loss.item(),
            #         "train_recon_loss": recon_loss.item(),
            #         "train_diffusion_loss": diffusion_loss.item(),
            #         "train_latent_loss": latent_loss.item()
            #     }, commit=False)

        avg_loss = total_loss / num_batches
        print(f"Epoch {epoch + 1}/{config['num_epochs']}, "
              f"Train Loss: {avg_loss:.4f}, "
              f"Recon: {total_recon_loss / num_batches:.4f}, "
              f"Diffusion: {total_diffusion_loss / num_batches:.4f}, "
              f"Latent: {total_latent_loss / num_batches:.4f}")

        if (epoch + 1) % config["val_interval"] == 0:
            val_metrics = validate_model(model, val_loader, config, epoch + 1)
            print(f"Validation - Loss: {val_metrics['val_loss']:.4f}, "
                  f"Recon: {val_metrics['val_recon_loss']:.4f}, "
                  f"Diffusion: {val_metrics['val_diffusion_loss']:.4f}, "
                  f"Latent: {val_metrics['val_latent_loss']:.4f}")
            # if config["logging"]["use_wandb"]:
            #     wandb.log(val_metrics, commit=False)

        if (epoch + 1) % config["sample_interval"] == 0:
            sample_images(model, train_loader, epoch + 1, config, prefix="train")
            sample_images(model, val_loader, epoch + 1, config, prefix="val")

        if (epoch + 1) % config["save_checkpoint_interval"] == 0:
            checkpoint_path = os.path.join(config["logging"]["checkpoint_dir"], f"model_epoch_{epoch + 1}.pth")
            torch.save(model.state_dict(), checkpoint_path)
            print(f"💾 Saved checkpoint: {checkpoint_path}")

        # if config["logging"]["use_wandb"]:
        #     wandb.log({"epoch": epoch + 1, "avg_train_loss": avg_loss}, commit=True)

    final_checkpoint_path = os.path.join(config["logging"]["checkpoint_dir"], "model_final.pth")
    torch.save(model.state_dict(), final_checkpoint_path)
    print(f"💾 Saved final checkpoint: {final_checkpoint_path}")
    print("✅ Training completed!")


In [22]:
class ImagenetDataset(torch.utils.data.Dataset):
    def __init__(self, dataset, transform=None):
        self.dataset = dataset
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = self.dataset[idx]
        image = item['image']
        if self.transform:
            image = self.transform(image)
        noised_image = image + torch.randn_like(image) * torch.rand(1).item()
        return {'image': noised_image, 'target': image}

transform = transforms.Compose([
    transforms.Lambda(lambda img: img.convert("RGB")),
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
    transforms.Normalize((0.4802, 0.4481, 0.3975), (0.2302, 0.2265, 0.2262)),
])

dataset = load_dataset("zh-plus/tiny-imagenet", split="train")
dataset = ImagenetDataset(dataset, transform=transform)
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

train_loader = DataLoader(train_dataset, batch_size=config["batch_size"], shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=config["batch_size"], shuffle=False, num_workers=4)
# from torch.utils.data import Dataset
# class ImageNetDataset(Dataset):
#     def __init__(self, root_dir, transform=None):
#         self.root_dir = root_dir
#         self.transform = transform
#         self.image_paths = [os.path.join(root, file) for root, _, files in os.walk(root_dir) for file in files if file.lower().endswith(('.png', '.jpg', '.jpeg'))]

#     def __len__(self):
#         return len(self.image_paths)

#     def __getitem__(self, idx):
#         img_path = self.image_paths[idx]
#         image = Image.open(img_path).convert("RGB")
#         if self.transform:
#             image = self.transform(image)

#         noised_image = image+torch.randn_like(image)
#         return {'image':noised_image, 'target':image}


# transform = transforms.Compose([
#     transforms.Resize((224, 224)),
#     transforms.ToTensor(),
#     transforms.Normalize((0.4802, 0.4481, 0.3975), (0.2302, 0.2265, 0.2262)),
# ])

# imagenet_dir = "./val"
# dataset = ImageNetDataset(root_dir=imagenet_dir, transform=transform)
# train_size = int(0.8 * len(dataset))
# val_size = len(dataset) - train_size
# train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

# train_loader = DataLoader(train_dataset, batch_size=config["batch_size"], shuffle=True, num_workers=4)
# val_loader = DataLoader(val_dataset, batch_size=config["batch_size"], shuffle=False, num_workers=4)


In [23]:
model = DiffusionAutoencoder(config)

In [24]:
from PIL import Image
train_model(model, train_loader, val_loader, config)

  scaler = amp.GradScaler(enabled=config["use_mixed_precision"])
  with amp.autocast(enabled=config["use_mixed_precision"]):
Epochs:   2%|[32m▏         [0m| 1/50 [00:09<07:59,  9.78s/it]

Epoch 1/50, Train Loss: 0.4991, Recon: 0.4912, Diffusion: 0.0104, Latent: 0.0104


Epochs:   4%|[32m▍         [0m| 2/50 [00:18<07:07,  8.90s/it]

Epoch 2/50, Train Loss: 0.3342, Recon: 0.3339, Diffusion: 0.0004, Latent: 0.0004


Epochs:   6%|[32m▌         [0m| 3/50 [00:26<06:40,  8.53s/it]

Epoch 3/50, Train Loss: 0.3180, Recon: 0.3178, Diffusion: 0.0002, Latent: 0.0002


Epochs:   8%|[32m▊         [0m| 4/50 [00:34<06:26,  8.41s/it]

Epoch 4/50, Train Loss: 0.3098, Recon: 0.3096, Diffusion: 0.0002, Latent: 0.0002




Epoch 5/50, Train Loss: 0.3050, Recon: 0.3049, Diffusion: 0.0001, Latent: 0.0001


Epochs:  10%|[32m█         [0m| 5/50 [00:44<06:39,  8.87s/it]

💾 Saved checkpoint: ./outputs/pipeline_checkpoints/model_epoch_5.pth


Epochs:  12%|[32m█▏        [0m| 6/50 [00:52<06:19,  8.63s/it]

Epoch 6/50, Train Loss: 0.3021, Recon: 0.3021, Diffusion: 0.0001, Latent: 0.0001


Epochs:  14%|[32m█▍        [0m| 7/50 [01:00<06:00,  8.39s/it]

Epoch 7/50, Train Loss: 0.2995, Recon: 0.2995, Diffusion: 0.0001, Latent: 0.0001


Epochs:  16%|[32m█▌        [0m| 8/50 [01:08<05:47,  8.28s/it]

Epoch 8/50, Train Loss: 0.2974, Recon: 0.2974, Diffusion: 0.0001, Latent: 0.0001


Epochs:  18%|[32m█▊        [0m| 9/50 [01:17<05:54,  8.64s/it]

Epoch 9/50, Train Loss: 0.2960, Recon: 0.2960, Diffusion: 0.0001, Latent: 0.0001




Epoch 10/50, Train Loss: 0.2946, Recon: 0.2945, Diffusion: 0.0001, Latent: 0.0001
Validation - Loss: 0.2946, Recon: 0.2946, Diffusion: 0.0000, Latent: 0.0000
📸 Saved train sample: ./outputs/pipeline_samples/train_epoch_10_sample_0.png
📸 Saved train sample: ./outputs/pipeline_samples/train_epoch_10_sample_1.png
📸 Saved train sample: ./outputs/pipeline_samples/train_epoch_10_sample_2.png
📸 Saved train sample: ./outputs/pipeline_samples/train_epoch_10_sample_3.png
📸 Saved val sample: ./outputs/pipeline_samples/val_epoch_10_sample_0.png
📸 Saved val sample: ./outputs/pipeline_samples/val_epoch_10_sample_1.png
📸 Saved val sample: ./outputs/pipeline_samples/val_epoch_10_sample_2.png
📸 Saved val sample: ./outputs/pipeline_samples/val_epoch_10_sample_3.png


Epochs:  20%|[32m██        [0m| 10/50 [01:29<06:31,  9.78s/it]

💾 Saved checkpoint: ./outputs/pipeline_checkpoints/model_epoch_10.pth


Epochs:  22%|[32m██▏       [0m| 11/50 [01:37<05:59,  9.23s/it]

Epoch 11/50, Train Loss: 0.2934, Recon: 0.2933, Diffusion: 0.0000, Latent: 0.0000


Epochs:  24%|[32m██▍       [0m| 12/50 [01:45<05:37,  8.88s/it]

Epoch 12/50, Train Loss: 0.2922, Recon: 0.2921, Diffusion: 0.0000, Latent: 0.0000


Epochs:  26%|[32m██▌       [0m| 13/50 [01:55<05:34,  9.03s/it]

Epoch 13/50, Train Loss: 0.2910, Recon: 0.2909, Diffusion: 0.0000, Latent: 0.0000


Epochs:  28%|[32m██▊       [0m| 14/50 [02:03<05:14,  8.73s/it]

Epoch 14/50, Train Loss: 0.2906, Recon: 0.2906, Diffusion: 0.0000, Latent: 0.0000




Epoch 15/50, Train Loss: 0.2896, Recon: 0.2896, Diffusion: 0.0000, Latent: 0.0000


Epochs:  30%|[32m███       [0m| 15/50 [02:11<05:02,  8.64s/it]

💾 Saved checkpoint: ./outputs/pipeline_checkpoints/model_epoch_15.pth


Epochs:  32%|[32m███▏      [0m| 16/50 [02:20<04:49,  8.50s/it]

Epoch 16/50, Train Loss: 0.2888, Recon: 0.2887, Diffusion: 0.0000, Latent: 0.0000


Epochs:  34%|[32m███▍      [0m| 17/50 [02:29<04:51,  8.84s/it]

Epoch 17/50, Train Loss: 0.2881, Recon: 0.2881, Diffusion: 0.0000, Latent: 0.0000


Epochs:  36%|[32m███▌      [0m| 18/50 [02:37<04:35,  8.62s/it]

Epoch 18/50, Train Loss: 0.2875, Recon: 0.2875, Diffusion: 0.0000, Latent: 0.0000


Epochs:  38%|[32m███▊      [0m| 19/50 [02:45<04:22,  8.47s/it]

Epoch 19/50, Train Loss: 0.2872, Recon: 0.2871, Diffusion: 0.0000, Latent: 0.0000




Epoch 20/50, Train Loss: 0.2863, Recon: 0.2863, Diffusion: 0.0000, Latent: 0.0000
Validation - Loss: 0.2865, Recon: 0.2865, Diffusion: 0.0000, Latent: 0.0000
📸 Saved train sample: ./outputs/pipeline_samples/train_epoch_20_sample_0.png
📸 Saved train sample: ./outputs/pipeline_samples/train_epoch_20_sample_1.png
📸 Saved train sample: ./outputs/pipeline_samples/train_epoch_20_sample_2.png
📸 Saved train sample: ./outputs/pipeline_samples/train_epoch_20_sample_3.png
📸 Saved val sample: ./outputs/pipeline_samples/val_epoch_20_sample_0.png
📸 Saved val sample: ./outputs/pipeline_samples/val_epoch_20_sample_1.png
📸 Saved val sample: ./outputs/pipeline_samples/val_epoch_20_sample_2.png
📸 Saved val sample: ./outputs/pipeline_samples/val_epoch_20_sample_3.png


Epochs:  40%|[32m████      [0m| 20/50 [02:59<05:01, 10.04s/it]

💾 Saved checkpoint: ./outputs/pipeline_checkpoints/model_epoch_20.pth


Epochs:  42%|[32m████▏     [0m| 21/50 [03:07<04:33,  9.44s/it]

Epoch 21/50, Train Loss: 0.2860, Recon: 0.2860, Diffusion: 0.0000, Latent: 0.0000


Epochs:  44%|[32m████▍     [0m| 22/50 [03:15<04:12,  9.03s/it]

Epoch 22/50, Train Loss: 0.2852, Recon: 0.2852, Diffusion: 0.0000, Latent: 0.0000


Epochs:  46%|[32m████▌     [0m| 23/50 [03:23<03:56,  8.76s/it]

Epoch 23/50, Train Loss: 0.2849, Recon: 0.2849, Diffusion: 0.0000, Latent: 0.0000


Epochs:  48%|[32m████▊     [0m| 24/50 [03:33<03:52,  8.92s/it]

Epoch 24/50, Train Loss: 0.2846, Recon: 0.2846, Diffusion: 0.0000, Latent: 0.0000




Epoch 25/50, Train Loss: 0.2843, Recon: 0.2842, Diffusion: 0.0000, Latent: 0.0000


Epochs:  50%|[32m█████     [0m| 25/50 [03:41<03:38,  8.76s/it]

💾 Saved checkpoint: ./outputs/pipeline_checkpoints/model_epoch_25.pth


Epochs:  52%|[32m█████▏    [0m| 26/50 [03:49<03:26,  8.58s/it]

Epoch 26/50, Train Loss: 0.2838, Recon: 0.2838, Diffusion: 0.0000, Latent: 0.0000


Epochs:  54%|[32m█████▍    [0m| 27/50 [03:57<03:13,  8.41s/it]

Epoch 27/50, Train Loss: 0.2837, Recon: 0.2837, Diffusion: 0.0000, Latent: 0.0000


Epochs:  56%|[32m█████▌    [0m| 28/50 [04:07<03:12,  8.76s/it]

Epoch 28/50, Train Loss: 0.2830, Recon: 0.2830, Diffusion: 0.0000, Latent: 0.0000


Epochs:  58%|[32m█████▊    [0m| 29/50 [04:15<03:01,  8.62s/it]

Epoch 29/50, Train Loss: 0.2830, Recon: 0.2830, Diffusion: 0.0000, Latent: 0.0000




Epoch 30/50, Train Loss: 0.2827, Recon: 0.2827, Diffusion: 0.0000, Latent: 0.0000
Validation - Loss: 0.2832, Recon: 0.2832, Diffusion: 0.0000, Latent: 0.0000
📸 Saved train sample: ./outputs/pipeline_samples/train_epoch_30_sample_0.png
📸 Saved train sample: ./outputs/pipeline_samples/train_epoch_30_sample_1.png
📸 Saved train sample: ./outputs/pipeline_samples/train_epoch_30_sample_2.png
📸 Saved train sample: ./outputs/pipeline_samples/train_epoch_30_sample_3.png
📸 Saved val sample: ./outputs/pipeline_samples/val_epoch_30_sample_0.png
📸 Saved val sample: ./outputs/pipeline_samples/val_epoch_30_sample_1.png
📸 Saved val sample: ./outputs/pipeline_samples/val_epoch_30_sample_2.png
📸 Saved val sample: ./outputs/pipeline_samples/val_epoch_30_sample_3.png


Epochs:  60%|[32m██████    [0m| 30/50 [04:28<03:15,  9.79s/it]

💾 Saved checkpoint: ./outputs/pipeline_checkpoints/model_epoch_30.pth


Epochs:  62%|[32m██████▏   [0m| 31/50 [04:37<03:05,  9.76s/it]

Epoch 31/50, Train Loss: 0.2822, Recon: 0.2822, Diffusion: 0.0000, Latent: 0.0000


Epochs:  64%|[32m██████▍   [0m| 32/50 [04:46<02:47,  9.32s/it]

Epoch 32/50, Train Loss: 0.2820, Recon: 0.2820, Diffusion: 0.0000, Latent: 0.0000


Epochs:  66%|[32m██████▌   [0m| 33/50 [04:54<02:33,  9.06s/it]

Epoch 33/50, Train Loss: 0.2818, Recon: 0.2818, Diffusion: 0.0000, Latent: 0.0000


Epochs:  68%|[32m██████▊   [0m| 34/50 [05:02<02:20,  8.77s/it]

Epoch 34/50, Train Loss: 0.2814, Recon: 0.2814, Diffusion: 0.0000, Latent: 0.0000




Epoch 35/50, Train Loss: 0.2811, Recon: 0.2811, Diffusion: 0.0000, Latent: 0.0000


Epochs:  70%|[32m███████   [0m| 35/50 [05:12<02:16,  9.08s/it]

💾 Saved checkpoint: ./outputs/pipeline_checkpoints/model_epoch_35.pth


Epochs:  72%|[32m███████▏  [0m| 36/50 [05:20<02:03,  8.80s/it]

Epoch 36/50, Train Loss: 0.2812, Recon: 0.2812, Diffusion: 0.0000, Latent: 0.0000


Epochs:  74%|[32m███████▍  [0m| 37/50 [05:28<01:51,  8.57s/it]

Epoch 37/50, Train Loss: 0.2808, Recon: 0.2808, Diffusion: 0.0000, Latent: 0.0000


Epochs:  76%|[32m███████▌  [0m| 38/50 [05:36<01:41,  8.43s/it]

Epoch 38/50, Train Loss: 0.2806, Recon: 0.2806, Diffusion: 0.0000, Latent: 0.0000


Epochs:  78%|[32m███████▊  [0m| 39/50 [05:46<01:35,  8.71s/it]

Epoch 39/50, Train Loss: 0.2807, Recon: 0.2807, Diffusion: 0.0000, Latent: 0.0000




Epoch 40/50, Train Loss: 0.2803, Recon: 0.2803, Diffusion: 0.0000, Latent: 0.0000
Validation - Loss: 0.2807, Recon: 0.2807, Diffusion: 0.0000, Latent: 0.0000
📸 Saved train sample: ./outputs/pipeline_samples/train_epoch_40_sample_0.png
📸 Saved train sample: ./outputs/pipeline_samples/train_epoch_40_sample_1.png
📸 Saved train sample: ./outputs/pipeline_samples/train_epoch_40_sample_2.png
📸 Saved train sample: ./outputs/pipeline_samples/train_epoch_40_sample_3.png
📸 Saved val sample: ./outputs/pipeline_samples/val_epoch_40_sample_0.png
📸 Saved val sample: ./outputs/pipeline_samples/val_epoch_40_sample_1.png
📸 Saved val sample: ./outputs/pipeline_samples/val_epoch_40_sample_2.png
📸 Saved val sample: ./outputs/pipeline_samples/val_epoch_40_sample_3.png


Epochs:  80%|[32m████████  [0m| 40/50 [05:58<01:37,  9.76s/it]

💾 Saved checkpoint: ./outputs/pipeline_checkpoints/model_epoch_40.pth


Epochs:  82%|[32m████████▏ [0m| 41/50 [06:06<01:23,  9.24s/it]

Epoch 41/50, Train Loss: 0.2801, Recon: 0.2801, Diffusion: 0.0000, Latent: 0.0000


Epochs:  84%|[32m████████▍ [0m| 42/50 [06:14<01:10,  8.87s/it]

Epoch 42/50, Train Loss: 0.2800, Recon: 0.2800, Diffusion: 0.0000, Latent: 0.0000


Epochs:  86%|[32m████████▌ [0m| 43/50 [06:23<01:03,  9.06s/it]

Epoch 43/50, Train Loss: 0.2798, Recon: 0.2798, Diffusion: 0.0000, Latent: 0.0000


Epochs:  88%|[32m████████▊ [0m| 44/50 [06:31<00:52,  8.75s/it]

Epoch 44/50, Train Loss: 0.2796, Recon: 0.2795, Diffusion: 0.0000, Latent: 0.0000




Epoch 45/50, Train Loss: 0.2794, Recon: 0.2794, Diffusion: 0.0000, Latent: 0.0000


Epochs:  90%|[32m█████████ [0m| 45/50 [06:40<00:43,  8.64s/it]

💾 Saved checkpoint: ./outputs/pipeline_checkpoints/model_epoch_45.pth


Epochs:  92%|[32m█████████▏[0m| 46/50 [06:48<00:33,  8.47s/it]

Epoch 46/50, Train Loss: 0.2793, Recon: 0.2793, Diffusion: 0.0000, Latent: 0.0000


Epochs:  94%|[32m█████████▍[0m| 47/50 [06:57<00:26,  8.73s/it]

Epoch 47/50, Train Loss: 0.2792, Recon: 0.2792, Diffusion: 0.0000, Latent: 0.0000


Epochs:  96%|[32m█████████▌[0m| 48/50 [07:05<00:17,  8.51s/it]

Epoch 48/50, Train Loss: 0.2790, Recon: 0.2790, Diffusion: 0.0000, Latent: 0.0000


Epochs:  98%|[32m█████████▊[0m| 49/50 [07:13<00:08,  8.38s/it]

Epoch 49/50, Train Loss: 0.2788, Recon: 0.2788, Diffusion: 0.0000, Latent: 0.0000




Epoch 50/50, Train Loss: 0.2787, Recon: 0.2787, Diffusion: 0.0000, Latent: 0.0000
Validation - Loss: 0.2790, Recon: 0.2790, Diffusion: 0.0000, Latent: 0.0000
📸 Saved train sample: ./outputs/pipeline_samples/train_epoch_50_sample_0.png
📸 Saved train sample: ./outputs/pipeline_samples/train_epoch_50_sample_1.png
📸 Saved train sample: ./outputs/pipeline_samples/train_epoch_50_sample_2.png
📸 Saved train sample: ./outputs/pipeline_samples/train_epoch_50_sample_3.png
📸 Saved val sample: ./outputs/pipeline_samples/val_epoch_50_sample_0.png
📸 Saved val sample: ./outputs/pipeline_samples/val_epoch_50_sample_1.png
📸 Saved val sample: ./outputs/pipeline_samples/val_epoch_50_sample_2.png
📸 Saved val sample: ./outputs/pipeline_samples/val_epoch_50_sample_3.png


Epochs: 100%|[32m██████████[0m| 50/50 [07:27<00:00,  8.94s/it]

💾 Saved checkpoint: ./outputs/pipeline_checkpoints/model_epoch_50.pth





💾 Saved final checkpoint: ./outputs/pipeline_checkpoints/model_final.pth
✅ Training completed!
