<a href="https://colab.research.google.com/github/Avniiii2606/Image-To-Image-Diffusion/blob/main/Diffusion_model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Download the two Kaggle datasets this notebook works with. Each call returns
# the local directory that holds the extracted files.
import kagglehub


# `human-faces` dataset: the directory printed here is the one the training
# cell below hard-codes as `target_dir`.
path = kagglehub.dataset_download("ashwingupta3012/human-faces")

print("Path to dataset files:", path)

# `victorian400` dataset: the directory printed here is the one the training
# cell below hard-codes as `source_dir`.
path1 = kagglehub.dataset_download("elibooklover/victorian400")

print("Path to dataset files:", path1)

Downloading from https://www.kaggle.com/api/v1/datasets/download/ashwingupta3012/human-faces?dataset_version_number=1...


100%|██████████| 1.82G/1.82G [00:22<00:00, 85.0MB/s]

Extracting files...





Path to dataset files: /root/.cache/kagglehub/datasets/ashwingupta3012/human-faces/versions/1
Downloading from https://www.kaggle.com/api/v1/datasets/download/elibooklover/victorian400?dataset_version_number=5...


100%|██████████| 484M/484M [00:04<00:00, 121MB/s]

Extracting files...





Path to dataset files: /root/.cache/kagglehub/datasets/elibooklover/victorian400/versions/5


In [None]:
# Mount Google Drive (Colab only). The demo at the end of the training cell
# reads its input portrait from /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Set WANDB_DISABLED before the later cell imports `wandb`.
# NOTE(review): presumably this turns Weights & Biases logging off; `wandb` is
# imported below but never called in this notebook - confirm it is still needed.
import os
os.environ["WANDB_DISABLED"] = "true"

In [None]:
# Ask PyTorch's CUDA caching allocator to release unused cached memory before
# the model is built in the next cell.
import torch
torch.cuda.empty_cache()

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, datasets
import numpy as np
from PIL import Image
import os
from tqdm import tqdm
import wandb  # for experiment tracking
from pathlib import Path
import math

class FaceTranslationDataset(Dataset):
    """Dataset yielding (source, target) image pairs for face translation.

    Images are discovered recursively under ``source_dir`` and ``target_dir``,
    sorted by path, and paired *by position*: item ``i`` is the i-th source
    file together with the i-th target file. No semantic pairing is performed.
    The dataset length is the size of the smaller of the two collections.

    Args:
        source_dir: Root folder of the conditioning images.
        target_dir: Root folder of the images to be generated.
        image_size (int): Side length images are resized / center-cropped to.
        extensions: Optional iterable of file suffixes (with leading dot) to
            accept. Matching is case-insensitive. Defaults to
            ``IMAGE_EXTENSIONS``.

    Each item is a dict with keys ``'source'`` and ``'target'`` holding
    ``3 x image_size x image_size`` tensors normalised to ``[-1, 1]``.
    """

    # Suffixes treated as images when `extensions` is not given. The original
    # '*.jpg' glob silently skipped every .png/.jpeg/.JPG file.
    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp', '.webp')

    def __init__(self, source_dir, target_dir, image_size=256, extensions=None):
        chosen = self.IMAGE_EXTENSIONS if extensions is None else extensions
        self.extensions = tuple(ext.lower() for ext in chosen)

        # Search for images recursively in subfolders
        self.source_paths = self._find_images(source_dir)
        self.target_paths = self._find_images(target_dir)

        # Print the number of images found
        print(f"Found {len(self.source_paths)} source images")
        print(f"Found {len(self.target_paths)} target images")

        self.transform = transforms.Compose([
            transforms.Resize(image_size),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

    def _find_images(self, root):
        """Return the sorted list of image files found anywhere below ``root``."""
        return sorted(
            candidate
            for candidate in Path(root).rglob('*')
            if candidate.is_file() and candidate.suffix.lower() in self.extensions
        )

    def _load(self, image_path):
        """Open one file, convert it to RGB and apply the tensor transform."""
        # The context manager closes the file handle as soon as the pixels
        # have been decoded instead of leaving that to garbage collection.
        with Image.open(image_path) as img:
            return self.transform(img.convert('RGB'))

    def __len__(self):
        # Only indices that exist in both collections can form a pair
        return min(len(self.source_paths), len(self.target_paths))

    def __getitem__(self, idx):
        return {
            'source': self._load(self.source_paths[idx]),
            'target': self._load(self.target_paths[idx])
        }

class CrossAttention(nn.Module):
    """Multi-head cross attention from a feature map onto a conditioning image.

    Queries come from the feature map ``x``; keys and values come from the
    3-channel ``context`` image after a 1x1 projection to ``channels``.

    Args:
        channels (int): Number of channels of ``x`` (and of the output).
        num_heads (int): Number of attention heads; must divide ``channels``.

    Raises:
        ValueError: If ``channels`` is not divisible by ``num_heads``.
    """
    def __init__(self, channels, num_heads=8):
        super().__init__()
        if channels % num_heads != 0:
            raise ValueError(
                f"channels ({channels}) must be divisible by num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        self.channels = channels
        self.scale = (channels // num_heads) ** -0.5

        # Separate normalisation layers for the feature map and the context
        self.norm_x = nn.GroupNorm(1, channels)
        self.norm_context = nn.GroupNorm(1, channels)

        self.to_q = nn.Conv2d(channels, channels, 1)
        self.to_k = nn.Conv2d(channels, channels, 1)
        self.to_v = nn.Conv2d(channels, channels, 1)
        self.to_out = nn.Conv2d(channels, channels, 1)

        # Lifts the RGB context image to `channels` feature channels
        self.context_proj = nn.Conv2d(in_channels=3, out_channels=channels, kernel_size=1)

    def forward(self, x, context):
        """Attend from ``x`` to ``context``.

        Args:
            x (torch.Tensor): Feature map of shape (B, channels, H, W).
            context (torch.Tensor): Conditioning image of shape (B, 3, Hc, Wc).
                If (Hc, Wc) differs from (H, W) it is average-pooled to (H, W).

        Returns:
            torch.Tensor of shape (B, channels, H, W): the attention output
            plus the normalised input.
        """
        batch, c, h, w = x.shape
        head_dim = c // self.num_heads

        # Inside a U-Net the feature map is smaller than the conditioning image
        # after the first pooling step. The original code reshaped k/v with the
        # feature map's h*w and crashed in that case. Pooling the raw image is
        # cheap and leaves the same-resolution case untouched.
        if tuple(context.shape[-2:]) != (h, w):
            context = F.adaptive_avg_pool2d(context, (h, w))

        # Project the context to the desired number of channels before normalization
        context = self.context_proj(context)

        # Normalize inputs using separate GroupNorms
        x = self.norm_x(x)
        context = self.norm_context(context)

        def split_heads(t):
            # (B, C, H, W) -> (B, heads, H*W, head_dim)
            return t.reshape(batch, self.num_heads, head_dim, h * w).transpose(-2, -1)

        q = split_heads(self.to_q(x))
        k = split_heads(self.to_k(context))
        v = split_heads(self.to_v(context))

        if hasattr(F, "scaled_dot_product_attention"):
            # The fused kernel can avoid materialising the (H*W) x (H*W)
            # attention matrix, which is what ran out of GPU memory at full
            # resolution. Its default scale is 1/sqrt(head_dim), i.e. self.scale.
            out = F.scaled_dot_product_attention(q, k, v)
        else:
            # Fallback for PyTorch versions without the fused kernel
            attention = torch.softmax(torch.matmul(q, k.transpose(-2, -1)) * self.scale, dim=-1)
            out = torch.matmul(attention, v)

        # (B, heads, H*W, head_dim) -> (B, C, H, W)
        out = out.transpose(-2, -1).reshape(batch, c, h, w)
        return self.to_out(out) + x

class ConditionalUNet(nn.Module):
    """U-Net noise predictor conditioned on a timestep and a source image.

    The network has three encoder stages (each followed by 2x max-pooling), a
    middle stage, and three decoder stages that upsample by 2x with transposed
    convolutions. Every stage contains a cross-attention layer onto the
    conditioning image and receives a learned projection of the timestep
    embedding.

    Args:
        in_channels (int): Channels of the noisy input image.
        out_channels (int): Channels of the predicted noise.
        time_emb_dim (int): Width of the timestep embedding.

    Note:
        Input height and width must be divisible by 8 (three pooling steps).
    """

    # Channel width at which the timestep embedding is injected, in forward
    # order: three encoder stages, the middle stage, three decoder stages.
    _STAGE_CHANNELS = (128, 256, 512, 512, 256, 128, 64)

    def __init__(self, in_channels=3, out_channels=3, time_emb_dim=256):
        super().__init__()

        # Time embedding
        self.time_mlp = nn.Sequential(
            SinusoidalPositionEmbeddings(time_emb_dim),
            nn.Linear(time_emb_dim, time_emb_dim * 2),
            nn.GELU(),
            nn.Linear(time_emb_dim * 2, time_emb_dim)
        )

        # Initial convolution
        self.conv_in = nn.Conv2d(in_channels, 64, 3, padding=1)

        # Encoder
        self.down1 = nn.ModuleList([
            nn.Conv2d(64, 128, 3, padding=1),
            CrossAttention(128),
            nn.Conv2d(128, 128, 3, padding=1)
        ])

        self.down2 = nn.ModuleList([
            nn.Conv2d(128, 256, 3, padding=1),
            CrossAttention(256),
            nn.Conv2d(256, 256, 3, padding=1)
        ])

        self.down3 = nn.ModuleList([
            nn.Conv2d(256, 512, 3, padding=1),
            CrossAttention(512),
            nn.Conv2d(512, 512, 3, padding=1)
        ])

        # Middle
        self.mid = nn.ModuleList([
            nn.Conv2d(512, 512, 3, padding=1),
            CrossAttention(512),
            nn.Conv2d(512, 512, 3, padding=1)
        ])

        # Decoder. Input channels are doubled because each stage receives the
        # current features concatenated with the matching encoder output.
        self.up1 = nn.ModuleList([
            nn.ConvTranspose2d(1024, 256, 2, stride=2),
            CrossAttention(256),
            nn.Conv2d(256, 256, 3, padding=1)
        ])

        self.up2 = nn.ModuleList([
            nn.ConvTranspose2d(512, 128, 2, stride=2),
            CrossAttention(128),
            nn.Conv2d(128, 128, 3, padding=1)
        ])

        self.up3 = nn.ModuleList([
            nn.ConvTranspose2d(256, 64, 2, stride=2),
            CrossAttention(64),
            nn.Conv2d(64, 64, 3, padding=1)
        ])

        # Final convolution
        self.conv_out = nn.Conv2d(64, out_channels, 1)

        # Downsample operations
        self.downs = nn.ModuleList([
            nn.MaxPool2d(2),
            nn.MaxPool2d(2),
            nn.MaxPool2d(2)
        ])

        # One linear projection of the timestep embedding per stage. Previously
        # the embedding was computed in forward() and then discarded, so the
        # network could not tell noise levels apart.
        self.time_proj = nn.ModuleList([
            nn.Linear(time_emb_dim, ch) for ch in self._STAGE_CHANNELS
        ])

    @staticmethod
    def _match_context(context, features):
        """Average-pool ``context`` to the spatial size of ``features``."""
        target = tuple(features.shape[-2:])
        if tuple(context.shape[-2:]) != target:
            return F.adaptive_avg_pool2d(context, target)
        return context

    def forward(self, x, t, context):
        """Predict the noise contained in ``x``.

        Args:
            x (torch.Tensor): Noisy image, shape (B, in_channels, H, W).
            t (torch.Tensor): Timesteps, shape (B,).
            context (torch.Tensor): Conditioning image, shape (B, 3, Hc, Wc).

        Returns:
            torch.Tensor of shape (B, out_channels, H, W).
        """
        t_emb = self.time_mlp(t)  # (B, time_emb_dim)

        def time_bias(stage):
            # Per-channel bias broadcast over the spatial dimensions
            return self.time_proj[stage](t_emb)[:, :, None, None]

        # Initial conv
        x = self.conv_in(x)

        # Encoder. Skips are cached AFTER pooling: the decoder layers are
        # declared for inputs of 512+512, 256+256 and 128+128 channels at the
        # pooled resolutions, which is exactly what the pooled outputs provide.
        # Caching the pre-stage tensor (as before) made torch.cat fail on both
        # spatial size and channel count.
        skips = []
        for i, down in enumerate([self.down1, self.down2, self.down3]):
            x = F.gelu(down[0](x) + time_bias(i))
            x = down[1](x, self._match_context(context, x))  # Cross attention
            x = F.gelu(down[2](x))
            x = self.downs[i](x)
            skips.append(x)

        # Middle
        x = F.gelu(self.mid[0](x) + time_bias(3))
        x = self.mid[1](x, self._match_context(context, x))  # Cross attention
        x = F.gelu(self.mid[2](x))

        # Decoder
        for i, up in enumerate([self.up1, self.up2, self.up3]):
            x = up[0](torch.cat([x, skips[-i - 1]], dim=1))
            x = x + time_bias(4 + i)
            x = up[1](x, self._match_context(context, x))  # Cross attention
            x = F.gelu(up[2](x))

        return self.conv_out(x)

class SinusoidalPositionEmbeddings(nn.Module):
    """Map a batch of scalar timesteps to sinusoidal feature vectors.

    The first ``dim // 2`` features are sines and the remaining ``dim // 2``
    are cosines of the timestep multiplied by geometrically spaced
    frequencies running from 1 down to 1/10000.
    """
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        # Half of the features are sines, the other half cosines
        n_freqs = self.dim // 2
        log_spacing = math.log(10000) / (n_freqs - 1)
        frequencies = torch.exp(torch.arange(n_freqs, device=time.device) * -log_spacing)
        # Outer product: one row per timestep, one column per frequency
        phases = time[:, None] * frequencies[None, :]
        return torch.cat((phases.sin(), phases.cos()), dim=-1)

def train(source_dir, target_dir, num_epochs=100, batch_size=4, device="cuda",
          gradient_accumulation_steps=1, image_size=128,
          checkpoint_path="model_weights.pth"):
    """Train the conditional U-Net to predict the noise added to target images.

    Args:
        source_dir: Folder with conditioning images.
        target_dir: Folder with images to be generated.
        num_epochs (int): Number of passes over the dataset.
        batch_size (int): Samples per forward/backward pass.
        device (str): Device the model and diffusion schedule live on.
        gradient_accumulation_steps (int): Number of batches whose gradients
            are summed before one optimizer step.
        image_size (int): Side length images are resized to. Also handed to
            the diffusion helper so sampling uses the training resolution.
        checkpoint_path: Where the final ``state_dict`` is written (this is
            the file ``demo_visualization`` loads). Pass ``None`` to skip.

    Returns:
        The trained ``ConditionalUNet``.

    Raises:
        ValueError: If ``gradient_accumulation_steps`` is smaller than 1.
    """
    if gradient_accumulation_steps < 1:
        raise ValueError("gradient_accumulation_steps must be >= 1")

    # Initialize dataset and dataloader
    dataset = FaceTranslationDataset(source_dir, target_dir, image_size=image_size)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)
    num_batches = len(dataloader)

    # Initialize model and diffusion. The diffusion helper is given the same
    # resolution and device as the data; its defaults (256, "cuda") did not
    # match the 128x128 training images.
    model = ConditionalUNet().to(device)
    diffusion = FaceTranslationDiffusion(img_size=image_size, device=device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0.0
        optimizer.zero_grad()

        for batch_idx, batch in enumerate(tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}")):
            source_images = batch['source'].to(device)
            target_images = batch['target'].to(device)

            # Sample random timesteps
            t = torch.randint(0, diffusion.noise_steps, (source_images.shape[0],)).to(device)

            # Add noise to target images
            noisy_target, noise = diffusion.noise_images(target_images, t)

            # Predict noise
            predicted_noise = model(noisy_target, t, source_images)

            # Calculate loss; log the unscaled value so the reported average
            # does not depend on the accumulation setting
            loss = F.mse_loss(noise, predicted_noise)
            total_loss += loss.item()

            # Exactly one backward pass per batch. The previous version called
            # backward() a second time on the same graph (a RuntimeError) and
            # zeroed the accumulated gradients in between.
            (loss / gradient_accumulation_steps).backward()

            # Step every n batches, and on the last batch so that a trailing
            # partial group of gradients is not carried into the next epoch
            is_last_batch = (batch_idx + 1) == num_batches
            if (batch_idx + 1) % gradient_accumulation_steps == 0 or is_last_batch:
                optimizer.step()
                optimizer.zero_grad()

        # Print progress
        avg_loss = total_loss / max(num_batches, 1)
        print(f"Epoch {epoch+1}, Average Loss: {avg_loss:.4f}")

        # Save sample generations periodically
        if (epoch + 1) % 10 == 0:
            save_sample_generations(model, diffusion, dataloader, epoch, device)

    if checkpoint_path is not None:
        torch.save(model.state_dict(), checkpoint_path)

    return model

def save_sample_generations(model, diffusion, dataloader, epoch, device):
    """Generate one translation and save it next to its source image.

    Takes the first source image of a fresh batch from ``dataloader``, runs
    ``diffusion.sample`` on it and writes ``samples/epoch_{epoch+1}.png``
    containing source (left) and result (right).

    Args:
        model: Noise-prediction network passed to ``diffusion.sample``.
        diffusion: Object providing ``sample(model, source_image)``.
        dataloader: Loader yielding dicts with a ``'source'`` tensor.
        epoch (int): Zero-based epoch index used in the file name.
        device: Device the source image is moved to.
    """
    # `save_image` was called without ever being imported (NameError).
    from torchvision.utils import save_image

    # save_image does not create missing parent directories
    os.makedirs("samples", exist_ok=True)

    was_training = model.training
    model.eval()
    with torch.no_grad():
        sample_source = next(iter(dataloader))['source'][:1].to(device)
        sample_result = diffusion.sample(model, sample_source)

        # Save the result. Both tensors are in [-1, 1]; a fixed value_range
        # keeps brightness comparable between epochs instead of min-max
        # stretching every file.
        save_image(
            torch.cat([sample_source, sample_result], dim=0),
            f"samples/epoch_{epoch+1}.png",
            normalize=True,
            value_range=(-1, 1),
            nrow=2
        )

    # Hand the model back in the mode it arrived in
    if was_training:
        model.train()

import torch
import torchvision
from PIL import Image
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation, PillowWriter
import imageio
from IPython.display import display, HTML
import numpy as np
from torchvision.utils import make_grid
import io

class FaceTranslationVisualizer:
    """Helpers for visualising the reverse diffusion of a trained model.

    Args:
        model: Noise-prediction network called as ``model(x, t, source)``.
        diffusion: Object exposing ``noise_steps``, ``alphas``, ``alpha_hat``
            and ``betas``.
        device: Device on which sampling is performed.
    """
    def __init__(self, model, diffusion, device="cuda"):
        self.model = model
        self.diffusion = diffusion
        self.device = device

    def tensor_to_image(self, tensor):
        """Convert a (C, H, W) tensor in [-1, 1] to an (H, W, C) uint8 array."""
        img = tensor.cpu().detach()
        img = (img.clamp(-1, 1) + 1) / 2
        img = img.permute(1, 2, 0).numpy()
        return (img * 255).astype(np.uint8)

    @staticmethod
    def _as_chw_tensor(image):
        """Return ``image`` as a detached float CPU tensor in (C, H, W) layout.

        (H, W, C) numpy arrays are transposed; uint8 arrays are additionally
        mapped from [0, 255] to [-1, 1] so they share the tensor value range.
        """
        if isinstance(image, np.ndarray):
            as_tensor = torch.from_numpy(image).permute(2, 0, 1)
            if as_tensor.dtype == torch.uint8:
                as_tensor = as_tensor.float() / 127.5 - 1.0
            image = as_tensor
        return image.detach().cpu().float()

    def create_comparison_grid(self, source_img, generated_img):
        """Place source and generated image side by side.

        Both inputs are expected in [-1, 1] as (C, H, W) tensors, or as
        (H, W, C) numpy arrays. Returns a CPU tensor with values in [0, 1].
        """
        # make_grid stacks its inputs, which fails when one tensor lives on
        # the GPU and the other on the CPU - bring both to the CPU first.
        grid = make_grid(
            [self._as_chw_tensor(source_img), self._as_chw_tensor(generated_img)],
            nrow=2,
            normalize=True,
            value_range=(-1, 1)
        )
        return grid

    def save_generation_process(self, source_image, save_path="generation_process.gif",
                              num_frames=50, fps=10):
        """Run the reverse process for one source image and save it as a GIF.

        Args:
            source_image: Conditioning image, either a tensor in [-1, 1] of
                shape (1, 3, H, W) or (3, H, W), or anything accepted by
                torchvision's ``ToTensor`` (then rescaled to [-1, 1]).
            save_path (str): Output GIF path.
            num_frames (int): Approximate number of frames to record.
            fps (int): Playback speed of the GIF.

        Returns:
            list of PIL images, one per recorded frame.
        """
        self.model.eval()
        frames = []

        # Prepare source image
        if not isinstance(source_image, torch.Tensor):
            transform = torchvision.transforms.ToTensor()
            source_image = transform(source_image).unsqueeze(0)
            source_image = source_image * 2 - 1  # Scale to [-1, 1]
        if source_image.dim() == 3:
            source_image = source_image.unsqueeze(0)
        # Tensor inputs used to stay on whatever device they arrived on,
        # which broke the forward pass for CPU tensors and a CUDA model.
        source_image = source_image[:1].to(self.device)

        # Initialize with random noise at the resolution of the source image
        height, width = source_image.shape[-2:]
        x = torch.randn((1, 3, height, width)).to(self.device)

        total_steps = self.diffusion.noise_steps
        # Guard against num_frames > noise_steps (step size 0)
        step_size = max(1, total_steps // max(1, num_frames))

        with torch.no_grad():
            # Every timestep is visited: the update below is the one-step DDPM
            # posterior and is only valid when going from step i to step i-1.
            # Striding over timesteps with these coefficients removes just a
            # fraction of the noise. Only frame capture is strided.
            for i in reversed(range(total_steps)):
                # Get current timestep
                t = torch.tensor([i]).to(self.device)

                # Predict noise and denoise
                predicted_noise = self.model(x, t, source_image)

                alpha = self.diffusion.alphas[i]
                alpha_hat = self.diffusion.alpha_hat[i]
                beta = self.diffusion.betas[i]

                if i > 0:
                    noise = torch.randn_like(x)
                else:
                    noise = torch.zeros_like(x)

                x = 1 / torch.sqrt(alpha) * (
                    x - ((1 - alpha) / (torch.sqrt(1 - alpha_hat))) * predicted_noise
                ) + torch.sqrt(beta) * noise

                # Create comparison frame
                if i % step_size == 0:
                    # Pass the sample in [-1, 1]; the grid helper maps that
                    # range to [0, 1]. Rescaling here as well (as before)
                    # squeezed the generated half into [0.5, 1].
                    grid = self.create_comparison_grid(
                        source_image[0],
                        x[0].clamp(-1, 1)
                    )
                    grid_img = (grid.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
                    frames.append(Image.fromarray(grid_img))

        # Save as GIF
        frames[0].save(
            save_path,
            save_all=True,
            append_images=frames[1:],
            duration=1000/fps,
            loop=0
        )
        return frames

    def display_interactive(self, source_image, figsize=(12, 6)):
        """Build a matplotlib animation of the generation process.

        Calls ``save_generation_process`` (which also writes the default GIF)
        and returns a ``FuncAnimation`` cycling through its frames.
        """
        frames = self.save_generation_process(source_image)

        fig, ax = plt.subplots(figsize=figsize)
        plt.axis('off')

        def animate(frame):
            ax.clear()
            ax.imshow(frame)
            ax.axis('off')
            ax.set_title('Left: Source Image | Right: Generated Image')

        anim = FuncAnimation(
            fig,
            animate,
            frames=frames,
            interval=100,
            repeat=True
        )
        # Close the figure so the notebook does not render an empty static plot
        plt.close()
        return anim

    def display_static_comparison(self, source_image, generated_image, figsize=(12, 6)):
        """Show source and generated image side by side with matplotlib."""
        plt.figure(figsize=figsize)

        grid = self.create_comparison_grid(source_image, generated_image)
        grid_img = grid.permute(1, 2, 0).cpu().numpy()

        plt.imshow(grid_img)
        plt.axis('off')
        plt.title('Left: Source Image | Right: Generated Image')
        plt.show()

# Training loop that additionally records the generation process every 10 epochs
def train_with_visualization(source_dir, target_dir, num_epochs=100, batch_size=8,
                           device="cuda", save_dir="results", image_size=128):
    """Train the conditional U-Net and periodically save visualisations.

    Every 10th epoch a GIF of the reverse diffusion and a static
    source-vs-generated comparison are written to ``save_dir``.

    Args:
        source_dir: Folder with conditioning images.
        target_dir: Folder with images to be generated.
        num_epochs (int): Number of passes over the dataset.
        batch_size (int): Samples per optimisation step.
        device (str): Device used for training and sampling.
        save_dir (str): Output folder for GIFs and comparison images.
        image_size (int): Side length images are resized to.

    Returns:
        The trained ``ConditionalUNet``.
    """
    # Initialize dataset, model and helpers. The earlier version referenced a
    # `dataloader` that was never created and had no training step at all.
    dataset = FaceTranslationDataset(source_dir, target_dir, image_size=image_size)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)

    model = ConditionalUNet().to(device)
    diffusion = FaceTranslationDiffusion(img_size=image_size, device=device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    visualizer = FaceTranslationVisualizer(model, diffusion, device)

    # Create save directory
    os.makedirs(save_dir, exist_ok=True)

    # Training loop with visualization
    for epoch in range(num_epochs):
        model.train()
        total_loss = 0.0

        for batch in tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}"):
            source_images = batch['source'].to(device)
            target_images = batch['target'].to(device)

            # Noise the targets at random timesteps and learn to predict that noise
            t = torch.randint(0, diffusion.noise_steps, (source_images.shape[0],)).to(device)
            noisy_target, noise = diffusion.noise_images(target_images, t)
            predicted_noise = model(noisy_target, t, source_images)
            loss = F.mse_loss(noise, predicted_noise)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        avg_loss = total_loss / max(len(dataloader), 1)
        print(f"Epoch {epoch+1}, Average Loss: {avg_loss:.4f}")

        # Generate and save visualization every n epochs
        if (epoch + 1) % 10 == 0:
            model.eval()
            with torch.no_grad():
                # Get a sample image
                sample_source = next(iter(dataloader))['source'][:1].to(device)

                # Generate and save the transformation process
                save_path = os.path.join(save_dir, f"epoch_{epoch+1}_process.gif")
                visualizer.save_generation_process(
                    sample_source,
                    save_path=save_path
                )

                # Generate final image
                final_generated = diffusion.sample(model, sample_source)

            # Draw the comparison on an explicit figure and save that figure.
            # Calling savefig after display_static_comparison (which ends in
            # plt.show()) saved a different, empty figure.
            grid = visualizer.create_comparison_grid(sample_source[0], final_generated[0])
            fig, ax = plt.subplots(figsize=(12, 6))
            ax.imshow(grid.permute(1, 2, 0).cpu().numpy())
            ax.axis('off')
            ax.set_title('Left: Source Image | Right: Generated Image')
            fig.savefig(os.path.join(save_dir, f"epoch_{epoch+1}_comparison.png"))
            plt.close(fig)

    return model

# Example usage
def demo_visualization(source_image_path, weights_path="model_weights.pth",
                       device="cuda", image_size=128):
    """Translate one image with a trained model and visualise the process.

    Args:
        source_image_path: Path of the conditioning image.
        weights_path: ``state_dict`` file of a trained ``ConditionalUNet``.
        device (str): Device used for the model and for sampling.
        image_size (int): Side length the image is resized / cropped to.
            Defaults to 128, the resolution ``train`` uses.

    Returns:
        The matplotlib animation produced by
        ``FaceTranslationVisualizer.display_interactive``.
    """
    # Load and preprocess source image. Normalize maps the image to [-1, 1],
    # the range the training dataset feeds to the model; without it the model
    # saw inputs in [0, 1] at inference time.
    with Image.open(source_image_path) as img:
        source_image = img.convert('RGB')
    transform = transforms.Compose([
        transforms.Resize(image_size),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    # Keep the tensor on the model's device; a CPU tensor cannot be passed to
    # a CUDA model.
    source_tensor = transform(source_image).unsqueeze(0).to(device)

    # Initialize models
    model = ConditionalUNet().to(device)
    diffusion = FaceTranslationDiffusion(img_size=image_size, device=device)
    visualizer = FaceTranslationVisualizer(model, diffusion, device)

    # Load model weights (assuming you have trained the model).
    # NOTE: torch.load unpickles the file - only load checkpoints you trust.
    model.load_state_dict(torch.load(weights_path, map_location=device))
    model.eval()

    # Generate and display transformation process
    print("Generating transformation process...")
    anim = visualizer.display_interactive(source_tensor)

    # Generate final image and display comparison
    print("Generating final comparison...")
    with torch.no_grad():
        generated_image = diffusion.sample(model, source_tensor)
    visualizer.display_static_comparison(source_tensor[0], generated_image[0])

    return anim


import torch
import math
import torch.nn.functional as F

class FaceTranslationDiffusion:
    """Linear-schedule DDPM utilities: forward noising and conditional sampling.

    Args:
        noise_steps (int): Number of diffusion timesteps.
        beta_start (float): First value of the linear beta schedule.
        beta_end (float): Last value of the linear beta schedule.
        img_size (int): Nominal image side length. Stored for callers; the
            sampler itself takes its resolution from the conditioning image.
        device (str): Device on which the schedule tensors are created.
    """
    def __init__(self, noise_steps=1000, beta_start=1e-4, beta_end=0.02, img_size=256, device="cuda"):
        self.noise_steps = noise_steps
        self.beta_start = beta_start
        self.beta_end = beta_end
        self.img_size = img_size
        self.device = device

        # Define beta schedule (all tensors live on self.device)
        self.betas = torch.linspace(beta_start, beta_end, noise_steps, device=self.device)
        self.alphas = 1. - self.betas
        self.alpha_hat = torch.cumprod(self.alphas, dim=0)

    def noise_images(self, x, t):
        """Diffuse ``x`` to timestep ``t``.

        Args:
            x (torch.Tensor): Clean images, shape (B, C, H, W).
            t (torch.Tensor): Integer timesteps, shape (B,).

        Returns:
            tuple: (noisy images, the Gaussian noise that was added).
        """
        signal_scale = torch.sqrt(self.alpha_hat[t])[:, None, None, None]
        noise_scale = torch.sqrt(1 - self.alpha_hat[t])[:, None, None, None]
        noise = torch.randn_like(x)
        return signal_scale * x + noise_scale * noise, noise

    def sample_timesteps(self, n):
        """Draw ``n`` timesteps uniformly from ``[0, noise_steps)``."""
        return torch.randint(0, self.noise_steps, (n,), device=self.device)

    def sample(self, model, source_image, n=1):
        """Generate images conditioned on ``source_image``.

        Args:
            model: Network called as ``model(x, t, source_image)`` returning
                the predicted noise.
            source_image (torch.Tensor): Conditioning image(s), shape
                (B, 3, H, W) or (3, H, W).
            n (int): Number of samples. A single conditioning image is
                repeated ``n`` times; if several conditioning images are
                given, one sample per image is produced and ``n`` is ignored.

        Returns:
            torch.Tensor in [-1, 1] with the spatial size of ``source_image``.
        """
        model.eval()
        with torch.no_grad():
            source_image = source_image.to(self.device)
            if source_image.dim() == 3:
                source_image = source_image.unsqueeze(0)

            # Noise batch and conditioning batch must agree, otherwise the
            # model's attention layers cannot pair them up.
            if source_image.shape[0] == 1 and n > 1:
                source_image = source_image.expand(n, -1, -1, -1)
            else:
                n = source_image.shape[0]

            # Use the conditioning image's resolution instead of self.img_size
            # so a helper built with the default size still works on other
            # image sizes.
            height, width = source_image.shape[-2:]
            x = torch.randn((n, 3, height, width), device=self.device)

            for i in tqdm(reversed(range(0, self.noise_steps)), desc='sampling loop time step', total=self.noise_steps):
                t = torch.full((n,), i, device=self.device, dtype=torch.long)
                predicted_noise = model(x, t, source_image)
                alpha = self.alphas[t][:, None, None, None]
                alpha_hat = self.alpha_hat[t][:, None, None, None]
                beta = self.betas[t][:, None, None, None]
                # No fresh noise is injected on the final step
                if i > 0:
                    noise = torch.randn_like(x)
                else:
                    noise = torch.zeros_like(x)
                x = 1 / torch.sqrt(alpha) * (x - ((1 - alpha) / (torch.sqrt(1 - alpha_hat))) * predicted_noise) + torch.sqrt(beta) * noise
        # Return the generated image in the range [-1, 1] for consistency
        return x.clamp(-1, 1)

if __name__ == "__main__":
    # Dataset roots: Victorian illustrations condition the model, real face
    # photographs are what it learns to generate.
    source_dir, target_dir = (
        "/root/.cache/kagglehub/datasets/elibooklover/victorian400/versions/5",
        "/root/.cache/kagglehub/datasets/ashwingupta3012/human-faces/versions/1",
    )

    # Train on the two folders
    train(source_dir, target_dir, num_epochs=100, batch_size=2)

    # Translate a single portrait stored on Google Drive ...
    source_image_path = "/content/drive/MyDrive/victorian_potrait.jpg.png"
    animation = demo_visualization(source_image_path)

    # ... and keep the resulting animation as a GIF
    animation.save('transformation.gif', writer='pillow')


Found 8 source images
Found 6973 target images


Epoch 1/100:   0%|          | 0/4 [00:01<?, ?it/s]


OutOfMemoryError: CUDA out of memory. Tried to allocate 16.00 GiB. GPU 0 has a total capacity of 14.74 GiB of which 14.39 GiB is free. Process 17453 has 354.00 MiB memory in use. Of the allocated memory 202.12 MiB is allocated by PyTorch, and 23.88 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
from torch.optim import Adam
from tqdm import tqdm
import os
from PIL import Image
import matplotlib.pyplot as plt
import torch.nn.functional as F
import numpy as np
import math

class DiffusionModel:
    """Unconditional DDPM with a linear beta schedule and a small U-Net."""

    def __init__(self, noise_steps=1000, beta_start=1e-4, beta_end=0.02, device='cuda' if torch.cuda.is_available() else 'cpu'):
        """
        Initialize diffusion model parameters

        Args:
            noise_steps (int): Number of diffusion steps
            beta_start (float): Starting noise scale
            beta_end (float): Ending noise scale
            device (str): Device to store tensors on ('cuda' or 'cpu')
        """
        self.noise_steps = noise_steps
        self.device = device

        # Generate noise schedule and move to device
        self.betas = torch.linspace(beta_start, beta_end, noise_steps).to(device)
        self.alphas = 1 - self.betas
        self.alpha_prod = torch.cumprod(self.alphas, dim=0).to(device)
        self.sqrt_alpha_prod = torch.sqrt(self.alpha_prod).to(device)
        self.sqrt_one_minus_alpha_prod = torch.sqrt(1 - self.alpha_prod).to(device)

    def forward_diffusion(self, x0, t):
        """
        Add noise to input image at timestep t

        Args:
            x0 (torch.Tensor): Original image
            t (torch.Tensor): Timestep

        Returns:
            Noisy image and noise
        """
        noise = torch.randn_like(x0)
        sqrt_alpha_prod_t = self.sqrt_alpha_prod[t].view(-1, 1, 1, 1)
        sqrt_one_minus_alpha_prod_t = self.sqrt_one_minus_alpha_prod[t].view(-1, 1, 1, 1)

        noisy_x = sqrt_alpha_prod_t * x0 + sqrt_one_minus_alpha_prod_t * noise
        return noisy_x, noise

    class UNet(nn.Module):
        def __init__(self, in_channels=3, out_channels=3, base_channels=64, time_emb_dim=128):
            """
            U-Net architecture for noise prediction

            Args:
                in_channels (int): Input image channels
                out_channels (int): Output noise channels
                base_channels (int): Base convolution channels
                time_emb_dim (int): Width of the timestep embedding (even, >= 2)

            Raises:
                ValueError: If ``time_emb_dim`` is odd or smaller than 2
            """
            super().__init__()
            if time_emb_dim < 2 or time_emb_dim % 2 != 0:
                raise ValueError("time_emb_dim must be an even number >= 2")
            self.time_emb_dim = time_emb_dim

            # Encoder
            self.enc1 = self._block(in_channels, base_channels)
            self.enc2 = self._block(base_channels, base_channels * 2)
            self.pool = nn.MaxPool2d(2)

            # Middle
            self.middle = self._block(base_channels * 2, base_channels * 4)

            # Decoder
            self.upconv2 = nn.ConvTranspose2d(base_channels * 4, base_channels * 2, kernel_size=2, stride=2)
            self.dec2 = self._block(base_channels * 4, base_channels * 2)

            self.upconv1 = nn.ConvTranspose2d(base_channels * 2, base_channels, kernel_size=2, stride=2)
            self.dec1 = self._block(base_channels * 2, base_channels)

            self.final_conv = nn.Conv2d(base_channels, out_channels, kernel_size=1)

            # Timestep conditioning. The previous forward() computed sin(t)
            # and never used it, so the network could not distinguish noise
            # levels. The embedding is now projected to each encoder width
            # and added as a per-channel bias.
            self.time_mlp = nn.Sequential(
                nn.Linear(time_emb_dim, time_emb_dim),
                nn.ReLU(),
                nn.Linear(time_emb_dim, time_emb_dim)
            )
            self.time_to_enc1 = nn.Linear(time_emb_dim, base_channels)
            self.time_to_enc2 = nn.Linear(time_emb_dim, base_channels * 2)
            self.time_to_middle = nn.Linear(time_emb_dim, base_channels * 4)

        def _block(self, in_channels, out_channels):
            """Standard convolutional block with normalization and activation"""
            return nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
                nn.Conv2d(out_channels, out_channels, 3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True)
            )

        def _timestep_embedding(self, t):
            """Sinusoidal embedding of shape (B, time_emb_dim) for timesteps ``t``."""
            half = self.time_emb_dim // 2
            exponent = torch.arange(half, device=t.device, dtype=torch.float32) / max(half - 1, 1)
            freqs = torch.exp(-math.log(10000.0) * exponent)
            phases = t.float()[:, None] * freqs[None, :]
            return torch.cat([phases.sin(), phases.cos()], dim=-1)

        def forward(self, x, t):
            """
            Forward pass through U-Net

            Args:
                x (torch.Tensor): Noisy input image, shape (B, C, H, W) with
                    H and W divisible by 4
                t (torch.Tensor): Timestep, shape (B,)

            Returns:
                Predicted noise
            """
            # Time embedding, shared by all stages
            t_emb = self.time_mlp(self._timestep_embedding(t.to(x.device)))

            def bias(projection):
                # (B, C) -> (B, C, 1, 1) so it broadcasts over H and W
                return projection(t_emb)[:, :, None, None]

            # Encoder
            enc1 = self.enc1(x) + bias(self.time_to_enc1)
            enc2 = self.enc2(self.pool(enc1)) + bias(self.time_to_enc2)

            # Middle
            middle = self.middle(self.pool(enc2)) + bias(self.time_to_middle)

            # Decoder
            dec2 = self.upconv2(middle)
            dec2 = torch.cat([dec2, enc2], dim=1)
            dec2 = self.dec2(dec2)

            dec1 = self.upconv1(dec2)
            dec1 = torch.cat([dec1, enc1], dim=1)
            dec1 = self.dec1(dec1)

            return self.final_conv(dec1)

    def reverse_diffusion(self, x, model):
        """
        Reverse the diffusion process to generate an image

        Args:
            x (torch.Tensor): Starting noisy image
            model (nn.Module): Noise prediction model

        Returns:
            Denoised image, clamped to [-1, 1] (the range of the normalised
            training images)
        """
        model.eval()
        with torch.no_grad():
            for t in reversed(range(self.noise_steps)):
                # Create the timestep tensor on the same device as the image
                t_tensor = torch.full((x.shape[0],), t, dtype=torch.long, device=x.device)
                predicted_noise = model(x, t_tensor)

                # Compute coefficients
                beta_t = self.betas[t]
                sqrt_one_minus_alpha_prod_t = self.sqrt_one_minus_alpha_prod[t]

                # Compute noise scaling
                noise_scaling = beta_t / sqrt_one_minus_alpha_prod_t

                # Denoise
                x = (1 / torch.sqrt(self.alphas[t])) * (
                    x - noise_scaling * predicted_noise
                )

                # Add noise if not at the final step
                if t > 0:
                    noise = torch.randn_like(x)
                    x = x + torch.sqrt(beta_t) * noise

        # Training images are normalised to [-1, 1] and the sample plotter
        # denormalises with (img + 1) / 2; clamping to [0, 1] here discarded
        # the darker half of the value range.
        return x.clamp(-1, 1)

class DiffusionTrainer:
    """Trains ``DiffusionModel.UNet`` on an ``ImageFolder`` dataset.

    Args:
        data_dir: Root folder in the layout expected by
            ``torchvision.datasets.ImageFolder`` (one sub-folder per class).
        image_size (int): Side length images are resized / cropped to.
        batch_size (int): Samples per optimisation step.
    """
    def __init__(self, data_dir, image_size=64, batch_size=32):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.image_size = image_size
        self.batch_size = batch_size

        # Initialize dataset and dataloader
        self.transform = transforms.Compose([
            transforms.Resize(image_size),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

        self.dataset = datasets.ImageFolder(
            root=data_dir,
            transform=self.transform
        )

        self.dataloader = DataLoader(
            self.dataset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=4,
            pin_memory=True
        )

        # Initialize model and diffusion process. The schedule is placed
        # explicitly on the trainer's device so both always agree.
        self.diffusion = DiffusionModel(device=self.device)
        self.model = self.diffusion.UNet(in_channels=3, out_channels=3).to(self.device)

        # Initialize optimizer
        self.optimizer = Adam(self.model.parameters(), lr=2e-4)

    def train(self, num_epochs=100, save_interval=10):
        """Train the diffusion model.

        Args:
            num_epochs (int): Number of passes over the dataset.
            save_interval (int): A checkpoint and a sample grid are written
                every ``save_interval`` epochs.
        """
        criterion = nn.MSELoss()

        for epoch in range(num_epochs):
            self.model.train()
            total_loss = 0

            progress_bar = tqdm(self.dataloader, desc=f"Epoch {epoch+1}/{num_epochs}")
            for batch_idx, (images, _) in enumerate(progress_bar):
                images = images.to(self.device)
                batch_size = images.shape[0]

                # Sample random timesteps
                t = torch.randint(0, self.diffusion.noise_steps, (batch_size,),
                                device=self.device)

                # Forward diffusion
                noisy_images, noise = self.diffusion.forward_diffusion(images, t)

                # Predict noise
                predicted_noise = self.model(noisy_images, t)

                # Calculate loss
                loss = criterion(predicted_noise, noise)

                # Backpropagation
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                total_loss += loss.item()

                # Update progress bar with the running mean loss
                progress_bar.set_postfix({"Loss": total_loss / (batch_idx + 1)})

            # Save checkpoint
            if (epoch + 1) % save_interval == 0:
                self.save_checkpoint(epoch + 1)
                self.generate_samples(epoch + 1)

    def save_checkpoint(self, epoch):
        """Write model and optimizer state to ``checkpoints/diffusion_epoch_{epoch}.pt``."""
        checkpoint_dir = "checkpoints"
        os.makedirs(checkpoint_dir, exist_ok=True)

        checkpoint_path = os.path.join(checkpoint_dir, f"diffusion_epoch_{epoch}.pt")
        torch.save({
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
        }, checkpoint_path)

    def generate_samples(self, epoch, num_samples=16):
        """Generate images and save them as ``samples/samples_epoch_{epoch}.png``.

        Args:
            epoch (int): Epoch number used in the file name.
            num_samples (int): Number of images to generate (>= 1).

        Raises:
            ValueError: If ``num_samples`` is smaller than 1.
        """
        if num_samples < 1:
            raise ValueError("num_samples must be >= 1")

        self.model.eval()
        samples_dir = "samples"
        os.makedirs(samples_dir, exist_ok=True)

        # Start from random noise
        x = torch.randn(num_samples, 3, self.image_size, self.image_size).to(self.device)

        # Generate images
        with torch.no_grad():
            generated_images = self.diffusion.reverse_diffusion(x, self.model)

        # Lay the images out on a near-square grid. The fixed 4x4 grid raised
        # IndexError for more than 16 samples and left framed empty axes for
        # fewer. For the default of 16 this is still 4x4 at 12x12 inches.
        cols = math.ceil(math.sqrt(num_samples))
        rows = math.ceil(num_samples / cols)
        fig, axs = plt.subplots(rows, cols, figsize=(3 * cols, 3 * rows), squeeze=False)
        for ax in axs.flat:
            ax.axis('off')

        for i in range(num_samples):
            img = generated_images[i].cpu()
            # Denormalize; clamp so imshow always receives values in [0, 1]
            img = ((img + 1) / 2).clamp(0, 1)
            img = img.permute(1, 2, 0).numpy()
            axs[i // cols, i % cols].imshow(img)

        fig.tight_layout()
        fig.savefig(os.path.join(samples_dir, f'samples_epoch_{epoch}.png'))
        plt.close(fig)

def main():
    """Download the image dataset and train the unconditional diffusion model."""
    # Resolve the dataset through kagglehub instead of hard-coding its cache
    # path: the dataset was never downloaded in this notebook, so the fixed
    # path raised FileNotFoundError. dataset_download returns the local
    # directory and reuses an existing download.
    import kagglehub

    # Configuration
    data_dir = kagglehub.dataset_download("pavansanagapati/images-dataset")
    image_size = 64
    batch_size = 32
    num_epochs = 100

    # Initialize trainer
    trainer = DiffusionTrainer(
        data_dir=data_dir,
        image_size=image_size,
        batch_size=batch_size
    )

    # Start training
    trainer.train(num_epochs=num_epochs)

if __name__ == "__main__":
    main()

FileNotFoundError: [Errno 2] No such file or directory: '/root/.cache/kagglehub/datasets/pavansanagapati/images-dataset/versions/1'