<a href="https://colab.research.google.com/github/Avniiii2606/Image-To-Image-Diffusion/blob/main/image_to_image_diffusion.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Environment bootstrap for the notebook: fetch the two Kaggle datasets, mount
# Google Drive, disable Weights & Biases and release cached CUDA memory.
import kagglehub


# Target domain (real human faces); dataset_download returns the local path of the files.
path = kagglehub.dataset_download("ashwingupta3012/human-faces")

print("Path to dataset files:", path)

# Source domain (Victorian-era images); kept under a separate name so both paths stay available.
path1 = kagglehub.dataset_download("elibooklover/victorian400")

print("Path to dataset files:", path1)

# Colab-only: exposes Google Drive under /content/drive (a Drive path is read later for the demo image).
from google.colab import drive
drive.mount('/content/drive')

import os
# NOTE(review): presumably silences Weights & Biases logging in libraries that honour this variable — confirm it is still needed.
os.environ["WANDB_DISABLED"] = "true"

import torch
# Releases unused cached blocks held by PyTorch's CUDA caching allocator; does not free live tensors.
torch.cuda.empty_cache()

Path to dataset files: /root/.cache/kagglehub/datasets/ashwingupta3012/human-faces/versions/1
Path to dataset files: /root/.cache/kagglehub/datasets/elibooklover/victorian400/versions/5
Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
 # CUDA_LAUNCH_BLOCKING is an *environment variable*; assigning a Python variable of
 # the same name has no effect. Export it through os.environ so CUDA kernels run
 # synchronously and errors are reported at the failing call site.
 # NOTE(review): this must be set before the first CUDA call in the process to take
 # effect — restart the runtime and run this cell first if CUDA was already used.
 import os
 os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

In [None]:
import torch.distributed as dist
import torch.nn.parallel
from torch.utils.data.distributed import DistributedSampler

def setup_ddp(rank, world_size):
    """Join the NCCL process group and bind this process to its GPU.

    Args:
        rank: Index of this process within the group; also used as the CUDA
            device index for the process.
        world_size: Total number of processes taking part in training.

    The default ``env://`` rendezvous used by ``init_process_group`` reads
    ``MASTER_ADDR`` and ``MASTER_PORT`` from the environment. Launchers such as
    ``torchrun`` export them; when they are absent (e.g. a manual
    ``mp.spawn`` from a notebook) single-node defaults are supplied so the
    call does not fail. Existing values are never overridden.
    """
    import os

    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "29500")

    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)


In [None]:
!pip uninstall torch torchvision torchaudio
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118

Found existing installation: torch 2.6.0+cu118
Uninstalling torch-2.6.0+cu118:
  Would remove:
    /usr/local/bin/torchfrtrace
    /usr/local/bin/torchrun
    /usr/local/lib/python3.11/dist-packages/functorch/*
    /usr/local/lib/python3.11/dist-packages/torch-2.6.0+cu118.dist-info/*
    /usr/local/lib/python3.11/dist-packages/torch/*
    /usr/local/lib/python3.11/dist-packages/torchgen/*
Proceed (Y/n)? Y
  Successfully uninstalled torch-2.6.0+cu118
Found existing installation: torchvision 0.21.0+cu118
Uninstalling torchvision-0.21.0+cu118:
  Would remove:
    /usr/local/lib/python3.11/dist-packages/torchvision-0.21.0+cu118.dist-info/*
    /usr/local/lib/python3.11/dist-packages/torchvision.libs/libcudart.60cfec8e.so.11.0
    /usr/local/lib/python3.11/dist-packages/torchvision.libs/libjpeg.1c1c4b09.so.8
    /usr/local/lib/python3.11/dist-packages/torchvision.libs/libnvjpeg.70530407.so.11
    /usr/local/lib/python3.11/dist-packages/torchvision.libs/libpng16.0364a1db.so.16
    /usr/local

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, datasets
import numpy as np
from PIL import Image
import os
from tqdm import tqdm
from pathlib import Path
import math

class FaceTranslationDataset(Dataset):
    """Source/target face image pairs (painting/statue -> real) read from two folders.

    Both folders are scanned recursively for ``*.jpg`` files and each listing is
    sorted by path. Item ``i`` pairs the i-th source file with the i-th target
    file, so the pairing is purely positional.

    Args:
        source_dir: Root folder of the source-domain images.
        target_dir: Root folder of the target-domain images.
        image_size: Side length of the square output; images are resized and
            centre-cropped to ``image_size`` and scaled to [-1, 1].
    """
    def __init__(self, source_dir, target_dir, image_size=128):
        # Recursive listing of every .jpg below each root, ordered by path.
        source_files = list(Path(source_dir).rglob('*.jpg'))
        source_files.sort()
        target_files = list(Path(target_dir).rglob('*.jpg'))
        target_files.sort()
        self.source_paths = source_files
        self.target_paths = target_files

        # Report how many files each domain contributed.
        print(f"Found {len(self.source_paths)} source images")
        print(f"Found {len(self.target_paths)} target images")

        # Resize -> centre crop -> tensor in [0, 1] -> per-channel shift/scale to [-1, 1].
        half = (0.5, 0.5, 0.5)
        pipeline = [
            transforms.Resize(image_size),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
            transforms.Normalize(half, half),
        ]
        self.transform = transforms.Compose(pipeline)

    def __len__(self):
        # The shorter listing bounds the number of usable pairs.
        n_source, n_target = len(self.source_paths), len(self.target_paths)
        return n_source if n_source < n_target else n_target

    def __getitem__(self, idx):
        # Open both images (source first), then run each through the shared transform.
        source_rgb, target_rgb = (
            Image.open(listing[idx]).convert('RGB')
            for listing in (self.source_paths, self.target_paths)
        )
        sample = {'source': self.transform(source_rgb)}
        sample['target'] = self.transform(target_rgb)
        return sample

class CrossAttention(nn.Module):
    """Multi-head cross attention from a feature map (queries) to a context image (keys/values).

    Args:
        channels: Number of channels of the feature map ``x``; must be
            divisible by ``num_heads``.
        num_heads: Number of attention heads.
        context_channels: Number of channels of the raw context tensor
            (3 for an RGB conditioning image).

    Fixes relative to the previous implementation:
      * keys/values are reshaped with the *context's* own spatial size, so the
        context no longer has to share the resolution of ``x``;
      * attention goes through ``F.scaled_dot_product_attention`` when it is
        available, which can avoid materialising the full
        ``(H*W) x (H_ctx*W_ctx)`` attention matrix (8 heads over 128x128
        tokens is an 8 GiB float32 matrix);
      * the residual connection adds the block's input, not its normalised copy.
    """
    def __init__(self, channels, num_heads=8, context_channels=3):
        super().__init__()
        if channels % num_heads != 0:
            raise ValueError(
                f"channels ({channels}) must be divisible by num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        self.channels = channels
        self.scale = (channels // num_heads) ** -0.5

        # Separate single-group norms for the feature map and the projected context.
        self.norm_x = nn.GroupNorm(1, channels)
        self.norm_context = nn.GroupNorm(1, channels)

        # 1x1 convolutions act as per-pixel linear projections.
        self.to_q = nn.Conv2d(channels, channels, 1)
        self.to_k = nn.Conv2d(channels, channels, 1)
        self.to_v = nn.Conv2d(channels, channels, 1)
        self.to_out = nn.Conv2d(channels, channels, 1)

        # Lifts the raw context image to the feature width before normalisation.
        self.context_proj = nn.Conv2d(in_channels=context_channels, out_channels=channels, kernel_size=1)

    def forward(self, x, context):
        """Attend from every pixel of ``x`` to every pixel of ``context``.

        Args:
            x: Feature map of shape ``(batch, channels, h, w)``.
            context: Conditioning tensor of shape
                ``(batch, context_channels, h_ctx, w_ctx)``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        batch, c, h, w = x.shape
        head_dim = c // self.num_heads

        # Project the context to `channels` and record its own token count.
        context = self.context_proj(context)
        context_tokens = context.shape[-2] * context.shape[-1]

        residual = x
        x_normed = self.norm_x(x)
        context = self.norm_context(context)

        # (batch, heads, head_dim, tokens) -> (batch, heads, tokens, head_dim)
        q = self.to_q(x_normed).reshape(batch, self.num_heads, head_dim, h * w).transpose(-2, -1)
        k = self.to_k(context).reshape(batch, self.num_heads, head_dim, context_tokens).transpose(-2, -1)
        v = self.to_v(context).reshape(batch, self.num_heads, head_dim, context_tokens).transpose(-2, -1)

        if hasattr(F, "scaled_dot_product_attention"):
            # Default scaling is 1/sqrt(head_dim), identical to self.scale.
            out = F.scaled_dot_product_attention(q, k, v)
        else:
            # Fallback for PyTorch < 2.0: explicit softmax(QK^T * scale) V.
            attention = torch.softmax(torch.matmul(q, k.transpose(-2, -1)) * self.scale, dim=-1)
            out = torch.matmul(attention, v)

        # Back to (batch, channels, h, w), then output projection plus skip connection.
        out = out.transpose(-2, -1).reshape(batch, c, h, w)
        return self.to_out(out) + residual

class ConditionalUNet(nn.Module):
    """U-Net noise predictor conditioned on a timestep and a source image.

    Args:
        in_channels: Channels of the noisy input image.
        out_channels: Channels of the predicted noise.
        time_emb_dim: Width of the timestep embedding.

    Layout (for an ``H x W`` input; H and W should be divisible by 8):
        conv_in  -> 64 @ H
        down1    -> 128 @ H/2,  down2 -> 128 @ H/4,  down3 -> 512 @ H/8
        mid      -> 512 @ H/8
        up1      -> 128 @ H/4,  up2 -> 128 @ H/2,    up3 -> 64 @ H
        conv_out -> out_channels @ H

    Fixes relative to the previous implementation:
      * the timestep embedding was computed but never used; it is now projected
        and added to the features of every stage;
      * decoder stages upsample *before* concatenating the skip tensor, and
        their channel counts match the tensors actually produced (the old
        ConvTranspose2d layers expected 1024/512/128 channels and received
        skips at twice the spatial size);
      * the conditioning image is resized to the resolution of the feature map
        it is attended from.
    """
    def __init__(self, in_channels=3, out_channels=3, time_emb_dim=128):
        super().__init__()

        # Timestep -> sinusoidal features -> small MLP.
        self.time_mlp = nn.Sequential(
            SinusoidalPositionEmbeddings(time_emb_dim),
            nn.Linear(time_emb_dim, time_emb_dim * 2),
            nn.GELU(),
            nn.Linear(time_emb_dim * 2, time_emb_dim)
        )

        self.conv_in = nn.Conv2d(in_channels, 64, 3, padding=1)

        # Encoder stages: [conv, cross-attention, conv]; pooling follows each stage.
        self.down1 = nn.ModuleList([
            nn.Conv2d(64, 128, 3, padding=1),
            CrossAttention(128),
            nn.Conv2d(128, 128, 3, padding=1)
        ])

        self.down2 = nn.ModuleList([
            nn.Conv2d(128, 128, 3, padding=1),
            CrossAttention(128),
            nn.Conv2d(128, 128, 3, padding=1)
        ])

        self.down3 = nn.ModuleList([
            nn.Conv2d(128, 512, 3, padding=1),
            CrossAttention(512),
            nn.Conv2d(512, 512, 3, padding=1)
        ])

        # Bottleneck.
        self.mid = nn.ModuleList([
            nn.Conv2d(512, 512, 3, padding=1),
            CrossAttention(512),
            nn.Conv2d(512, 512, 3, padding=1)
        ])

        # Decoder stages: [upsample x2, fuse conv over (upsampled + skip), cross-attention, conv].
        # Skips are the *inputs* of the encoder stages: 128 @ H/4, 128 @ H/2, 64 @ H.
        self.up1 = nn.ModuleList([
            nn.ConvTranspose2d(512, 128, 2, stride=2),
            nn.Conv2d(128 + 128, 128, 3, padding=1),
            CrossAttention(128),
            nn.Conv2d(128, 128, 3, padding=1)
        ])

        self.up2 = nn.ModuleList([
            nn.ConvTranspose2d(128, 128, 2, stride=2),
            nn.Conv2d(128 + 128, 128, 3, padding=1),
            CrossAttention(128),
            nn.Conv2d(128, 128, 3, padding=1)
        ])

        self.up3 = nn.ModuleList([
            nn.ConvTranspose2d(128, 64, 2, stride=2),
            nn.Conv2d(64 + 64, 64, 3, padding=1),
            CrossAttention(64),
            nn.Conv2d(64, 64, 3, padding=1)
        ])

        self.conv_out = nn.Conv2d(64, out_channels, 1)

        # One 2x2 max-pool per encoder stage.
        self.downs = nn.ModuleList([
            nn.MaxPool2d(2),
            nn.MaxPool2d(2),
            nn.MaxPool2d(2)
        ])

        # Per-stage projections of the time embedding, in execution order:
        # down1, down2, down3, mid, up1, up2, up3.
        self.time_proj = nn.ModuleList([
            nn.Linear(time_emb_dim, width)
            for width in (128, 128, 512, 512, 128, 128, 64)
        ])

    @staticmethod
    def _context_for(features, context):
        """Return ``context`` resized to the spatial size of ``features``."""
        if context.shape[-2:] == features.shape[-2:]:
            return context
        return F.interpolate(context, size=features.shape[-2:], mode="bilinear", align_corners=False)

    def forward(self, x, t, context):
        """Predict the noise contained in ``x``.

        Args:
            x: Noisy image batch ``(batch, in_channels, H, W)``.
            t: Timestep per sample, shape ``(batch,)``.
            context: Conditioning image batch ``(batch, 3, H_ctx, W_ctx)``.

        Returns:
            Tensor of shape ``(batch, out_channels, H, W)``.
        """
        t_emb = self.time_mlp(t)

        x = self.conv_in(x)

        # Inputs of the encoder stages, reused as decoder skip connections.
        skips = []

        # Encoder
        # NOTE(review): the first stage attends at full resolution, so the number
        # of attention tokens is H*W there; this dominates memory and compute.
        for i, down in enumerate([self.down1, self.down2, self.down3]):
            skips.append(x)
            x = F.gelu(down[0](x) + self.time_proj[i](t_emb)[:, :, None, None])
            x = down[1](x, self._context_for(x, context))
            x = F.gelu(down[2](x))
            x = self.downs[i](x)

        # Middle
        x = F.gelu(self.mid[0](x) + self.time_proj[3](t_emb)[:, :, None, None])
        x = self.mid[1](x, self._context_for(x, context))
        x = F.gelu(self.mid[2](x))

        # Decoder
        for i, up in enumerate([self.up1, self.up2, self.up3]):
            skip = skips[-i - 1]
            x = up[0](x)
            if x.shape[-2:] != skip.shape[-2:]:
                # Pooling floors odd sizes; realign with the skip before concatenating.
                x = F.interpolate(x, size=skip.shape[-2:], mode="nearest")
            x = torch.cat([x, skip], dim=1)
            x = F.gelu(up[1](x) + self.time_proj[4 + i](t_emb)[:, :, None, None])
            x = up[2](x, self._context_for(x, context))
            x = F.gelu(up[3](x))

        return self.conv_out(x)

class SinusoidalPositionEmbeddings(nn.Module):
    """Sinusoidal encoding of diffusion timesteps.

    For ``dim`` requested features, ``dim // 2`` geometrically spaced
    frequencies are used; the output is the sines of all frequencies followed
    by the cosines, giving ``2 * (dim // 2)`` features per timestep.
    """
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        n_freqs = self.dim // 2
        # Log-spaced decay from frequency 1 down to 1/10000.
        decay = math.log(10000) / (n_freqs - 1)
        positions = torch.arange(n_freqs, device=time.device)
        freqs = torch.exp(positions * -decay)
        # Outer product: one row per timestep, one column per frequency.
        angles = time[:, None] * freqs[None, :]
        return torch.cat((angles.sin(), angles.cos()), dim=-1)

def train(source_dir, target_dir, num_epochs=100, batch_size=1, device="cuda", gradient_accumulation_steps=4):
    """Train the conditional U-Net to predict the noise added to target images.

    Args:
        source_dir: Folder with the conditioning (source-domain) images.
        target_dir: Folder with the images to be generated (target domain).
        num_epochs: Number of passes over the dataset.
        batch_size: Samples per forward/backward pass.
        device: Device used for the model, the data and the noise schedule.
        gradient_accumulation_steps: Number of batches whose gradients are
            summed before one optimizer step (values below 1 are treated as 1).

    Returns:
        The trained model.

    Fixes relative to the previous implementation:
      * ``loss.backward()`` ran twice on the same graph (RuntimeError) and an
        unconditional ``zero_grad``/``step`` pair discarded the accumulated
        gradients; there is now exactly one backward per batch and one
        optimizer step per accumulation window;
      * the reported loss is the unscaled per-batch loss;
      * a trailing, incomplete accumulation window is still applied;
      * the diffusion schedule is created on ``device`` rather than on its
        own default device.
    """
    accumulation = max(1, int(gradient_accumulation_steps))

    # Data (128x128 images keep memory use manageable).
    dataset = FaceTranslationDataset(source_dir, target_dir, image_size=128)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)
    num_batches = len(dataloader)

    # Model, noise schedule and optimizer.
    model = ConditionalUNet().to(device)
    diffusion = FaceTranslationDiffusion(device=device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0.0
        optimizer.zero_grad()

        for batch_idx, batch in enumerate(tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}")):
            source_images = batch['source'].to(device)
            target_images = batch['target'].to(device)

            # One random timestep per sample.
            t = torch.randint(0, diffusion.noise_steps, (source_images.shape[0],), device=device)

            # Forward diffusion of the target, then noise prediction conditioned on the source.
            noisy_target, noise = diffusion.noise_images(target_images, t)
            predicted_noise = model(noisy_target, t, source_images)

            loss = F.mse_loss(noise, predicted_noise)
            total_loss += loss.item()

            # Scale so the summed gradients equal the mean over the accumulation window.
            (loss / accumulation).backward()

            window_complete = (batch_idx + 1) % accumulation == 0
            last_batch = (batch_idx + 1) == num_batches
            if window_complete or last_batch:
                optimizer.step()
                optimizer.zero_grad()

        avg_loss = total_loss / max(1, num_batches)
        print(f"Epoch {epoch+1}, Average Loss: {avg_loss:.4f}")

        # Save sample generations periodically.
        if (epoch + 1) % 10 == 0:
            save_sample_generations(model, diffusion, dataloader, epoch, device)

    return model

def save_sample_generations(model, diffusion, dataloader, epoch, device):
    """Generate one translation and write it next to its source image.

    Args:
        model: Noise-prediction network passed to ``diffusion.sample``.
        diffusion: Object providing ``sample(model, source_image)``.
        dataloader: Loader whose first batch supplies the source image.
        epoch: Zero-based epoch index; the file is named with ``epoch + 1``.
        device: Device the source image is moved to.

    Writes ``samples/epoch_<epoch+1>.png`` (source on the left, generation on
    the right). The model's training/eval mode is restored afterwards.
    """
    # `save_image` is not imported at module level in this file; import it here
    # so the function does not fail with a NameError.
    from torchvision.utils import save_image

    # save_image does not create missing directories.
    os.makedirs("samples", exist_ok=True)

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            sample_source = next(iter(dataloader))['source'][:1].to(device)
            sample_result = diffusion.sample(model, sample_source)

            save_image(
                torch.cat([sample_source, sample_result], dim=0),
                f"samples/epoch_{epoch+1}.png",
                normalize=True,
                nrow=2
            )
    finally:
        model.train(was_training)

import torch
import torchvision
from PIL import Image
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation, PillowWriter
import imageio
from IPython.display import display, HTML
import numpy as np
from torchvision.utils import make_grid
import io

class FaceTranslationVisualizer:
    """Visualises the reverse-diffusion process of a trained translation model.

    Args:
        model: Noise-prediction network called as ``model(x, t, source)``.
        diffusion: Object exposing ``noise_steps``, ``img_size``, ``alphas``,
            ``alpha_hat`` and ``betas``.
        device: Device on which sampling runs.

    Fixes relative to the previous implementation:
      * tensor inputs are moved to ``device`` (previously only non-tensor
        inputs were), and grids are assembled on the CPU so source and
        generated images may live on different devices;
      * frames are passed to the grid in [-1, 1]; they used to be rescaled to
        [0, 1] first and then normalised a second time;
      * the reverse process runs over *every* timestep and only frame capture
        is subsampled; applying single-step DDPM coefficients while jumping
        over timesteps does not denoise correctly;
      * ``num_frames`` larger than ``noise_steps`` no longer yields a zero step.
    """
    def __init__(self, model, diffusion, device="cuda"):
        self.model = model
        self.diffusion = diffusion
        self.device = device

    def tensor_to_image(self, tensor):
        """Convert a ``(C, H, W)`` tensor in [-1, 1] to an ``(H, W, C)`` uint8 array."""
        img = tensor.cpu().detach()
        img = (img.clamp(-1, 1) + 1) / 2
        img = img.permute(1, 2, 0).numpy()
        return (img * 255).astype(np.uint8)

    def create_comparison_grid(self, source_img, generated_img):
        """Return a side-by-side grid (source left, generated right).

        Inputs are ``(C, H, W)`` tensors or ``(H, W, C)`` arrays with values in
        [-1, 1]; the returned grid is a CPU tensor scaled to [0, 1].
        """
        if isinstance(source_img, np.ndarray):
            source_img = torch.from_numpy(source_img).permute(2, 0, 1)
        if isinstance(generated_img, np.ndarray):
            generated_img = torch.from_numpy(generated_img).permute(2, 0, 1)

        # make_grid stacks its inputs, so they must share a device.
        source_img = source_img.detach().cpu()
        generated_img = generated_img.detach().cpu()

        grid = make_grid(
            [source_img, generated_img],
            nrow=2,
            normalize=True,
            value_range=(-1, 1)
        )
        return grid

    def save_generation_process(self, source_image, save_path="generation_process.gif",
                              num_frames=50, fps=10):
        """Run reverse diffusion for one source image and save the progress as a GIF.

        Args:
            source_image: ``(1, 3, H, W)`` or ``(3, H, W)`` tensor in [-1, 1],
                or an image accepted by ``torchvision.transforms.ToTensor``.
            save_path: Destination of the animated GIF.
            num_frames: Approximate number of frames to capture.
            fps: Playback speed of the GIF.

        Returns:
            The list of captured PIL frames (noisiest first).
        """
        self.model.eval()
        frames = []

        # Prepare source image
        if not isinstance(source_image, torch.Tensor):
            transform = torchvision.transforms.ToTensor()
            source_image = transform(source_image)
            source_image = source_image * 2 - 1  # Scale to [-1, 1]
        if source_image.dim() == 3:
            source_image = source_image.unsqueeze(0)
        source_image = source_image[:1].to(self.device)

        # Initialize with random noise
        x = torch.randn((1, 3, self.diffusion.img_size, self.diffusion.img_size), device=self.device)

        total_steps = self.diffusion.noise_steps
        step_size = max(1, total_steps // max(1, num_frames))

        with torch.no_grad():
            for i in reversed(range(total_steps)):
                t = torch.tensor([i], device=self.device)

                # Predict noise and take one reverse-diffusion step.
                predicted_noise = self.model(x, t, source_image)

                alpha = self.diffusion.alphas[i]
                alpha_hat = self.diffusion.alpha_hat[i]
                beta = self.diffusion.betas[i]

                # No fresh noise is injected on the final step.
                if i > 0:
                    noise = torch.randn_like(x)
                else:
                    noise = torch.zeros_like(x)

                x = 1 / torch.sqrt(alpha) * (
                    x - ((1 - alpha) / (torch.sqrt(1 - alpha_hat))) * predicted_noise
                ) + torch.sqrt(beta) * noise

                # Capture a frame every `step_size` steps (step 0 is always captured).
                if i % step_size == 0:
                    grid = self.create_comparison_grid(
                        source_image[0],
                        x[0].clamp(-1, 1)
                    )
                    grid_img = (grid.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
                    frames.append(Image.fromarray(grid_img))

        # Save as GIF
        frames[0].save(
            save_path,
            save_all=True,
            append_images=frames[1:],
            duration=1000/fps,
            loop=0
        )
        return frames

    def display_interactive(self, source_image, figsize=(12, 6)):
        """Build a matplotlib animation of the generation process.

        Also writes the GIF produced by ``save_generation_process`` with its
        default path. Returns the ``FuncAnimation`` object.
        """
        frames = self.save_generation_process(source_image)

        fig, ax = plt.subplots(figsize=figsize)
        ax.axis('off')

        def animate(frame):
            ax.clear()
            ax.imshow(frame)
            ax.axis('off')
            ax.set_title('Left: Source Image | Right: Generated Image')

        anim = FuncAnimation(
            fig,
            animate,
            frames=frames,
            interval=100,
            repeat=True
        )
        # Close the figure so the notebook does not also render a static copy.
        plt.close(fig)
        return anim

    def display_static_comparison(self, source_image, generated_image, figsize=(12, 6)):
        """Show source and generated images side by side in a single figure."""
        plt.figure(figsize=figsize)

        grid = self.create_comparison_grid(source_image, generated_image)
        grid_img = grid.permute(1, 2, 0).cpu().numpy()

        plt.imshow(grid_img)
        plt.axis('off')
        plt.title('Left: Source Image | Right: Generated Image')
        plt.show()

# Modify the training loop to include visualization
def train_with_visualization_ddp(rank, world_size, source_dir, target_dir, num_epochs=100, batch_size=1, save_dir="results"):
    """Per-process entry point for multi-GPU training with DistributedDataParallel.

    Args:
        rank: Index of this process; also the CUDA device index it uses.
        world_size: Total number of participating processes.
        source_dir: Folder with the conditioning (source-domain) images.
        target_dir: Folder with the target-domain images.
        num_epochs: Number of passes over the dataset.
        batch_size: Per-process batch size.
        save_dir: Directory receiving ``model_epoch_<N>.pth`` every 10 epochs
            (written by rank 0 only).

    Fixes relative to the previous implementation:
      * ``save_dir`` is created before the first checkpoint is written;
      * checkpoints hold the wrapped module's weights, so their keys carry no
        ``module.`` prefix and load directly into a plain ``ConditionalUNet``;
      * the process group is destroyed even when training raises;
      * the noise schedule is created on this process's device.
    """
    setup_ddp(rank, world_size)

    try:
        device = torch.device(f"cuda:{rank}")
        model = ConditionalUNet().to(device)
        diffusion = FaceTranslationDiffusion(device=device)

        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[rank])

        # DistributedSampler gives every process a disjoint shard of the dataset.
        dataset = FaceTranslationDataset(source_dir, target_dir, image_size=128)
        sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
        dataloader = DataLoader(dataset, batch_size=batch_size, sampler=sampler, num_workers=4)

        optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

        if rank == 0:
            os.makedirs(save_dir, exist_ok=True)

        for epoch in range(num_epochs):
            sampler.set_epoch(epoch)  # Reshuffle differently on every epoch
            model.train()

            for batch in dataloader:
                source_images = batch['source'].to(device)
                target_images = batch['target'].to(device)

                t = torch.randint(0, diffusion.noise_steps, (source_images.shape[0],), device=device)
                noisy_target, noise = diffusion.noise_images(target_images, t)
                predicted_noise = model(noisy_target, t, source_images)

                loss = F.mse_loss(noise, predicted_noise)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            # Only rank 0 writes, to avoid concurrent writes of the same file.
            if rank == 0 and (epoch + 1) % 10 == 0:
                torch.save(
                    model.module.state_dict(),
                    os.path.join(save_dir, f"model_epoch_{epoch+1}.pth")
                )
    finally:
        dist.destroy_process_group()  # Clean up distributed training


def demo_visualization(source_image_path, weights_path="model_weights.pth", device="cuda"):
    """Load trained weights and visualise the translation of one image.

    Args:
        source_image_path: Path of the source (conditioning) image.
        weights_path: Checkpoint containing a ``ConditionalUNet`` state dict.
        device: Device used for the model, the schedule and sampling.

    Returns:
        The animation object produced by
        ``FaceTranslationVisualizer.display_interactive``.

    Fixes relative to the previous implementation: the source tensor is moved
    to the model's device before sampling, both images are brought to the CPU
    before being combined into one grid, and the checkpoint is mapped onto
    ``device`` when loaded.
    """
    # Load and preprocess the source image (same pipeline as training: [-1, 1], 128x128).
    source_image = Image.open(source_image_path).convert('RGB')
    transform = transforms.Compose([
        transforms.Resize(128),
        transforms.CenterCrop(128),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    source_tensor = transform(source_image).unsqueeze(0).to(device)

    # Build the model and load its weights before anything samples from it.
    model = ConditionalUNet().to(device)
    model.load_state_dict(torch.load(weights_path, map_location=device))
    model.eval()

    diffusion = FaceTranslationDiffusion(device=device)
    visualizer = FaceTranslationVisualizer(model, diffusion, device=device)

    # Generate and display transformation process
    print("Generating transformation process...")
    anim = visualizer.display_interactive(source_tensor)

    # Generate final image and display comparison
    print("Generating final comparison...")
    with torch.no_grad():
        generated_image = diffusion.sample(model, source_tensor)
    visualizer.display_static_comparison(source_tensor[0].cpu(), generated_image[0].cpu())

    return anim


import torch
import math
import torch.nn.functional as F

class FaceTranslationDiffusion:
    """Linear-beta noise schedule with forward noising and conditional sampling.

    Args:
        noise_steps: Number of diffusion timesteps.
        beta_start: First value of the linearly spaced betas.
        beta_end: Last value of the linearly spaced betas.
        img_size: Side length of the square images produced by ``sample``.
        device: Device holding the schedule tensors and the sampled images.
    """
    def __init__(self, noise_steps=1000, beta_start=1e-4, beta_end=0.02, img_size=128, device="cuda"):
        self.noise_steps, self.beta_start, self.beta_end = noise_steps, beta_start, beta_end
        self.img_size = img_size
        self.device = device

        # betas -> alphas = 1 - beta -> running product of alphas.
        self.betas = torch.linspace(beta_start, beta_end, noise_steps, device=self.device)
        self.alphas = 1. - self.betas
        self.alpha_hat = torch.cumprod(self.alphas, dim=0)

    def noise_images(self, x, t):
        """Return ``(noised x, noise)`` for per-sample timesteps ``t``."""
        signal_scale = torch.sqrt(self.alpha_hat[t])[:, None, None, None]
        noise_scale = torch.sqrt(1 - self.alpha_hat[t])[:, None, None, None]
        eps = torch.randn_like(x)
        noised = signal_scale * x + noise_scale * eps
        return noised, eps

    def sample_timesteps(self, n):
        """Draw ``n`` timesteps uniformly from ``[0, noise_steps)``."""
        return torch.randint(0, self.noise_steps, (n,), device=self.device)

    def sample(self, model, source_image, n=1):
        """Generate ``n`` images conditioned on ``source_image``; output is clamped to [-1, 1]."""
        model.eval()
        shape = (n, 3, self.img_size, self.img_size)
        with torch.no_grad():
            # Start from pure Gaussian noise and walk the timesteps backwards.
            x = torch.randn(shape, device=self.device)
            countdown = reversed(range(0, self.noise_steps))
            for i in tqdm(countdown, desc='sampling loop time step', total=self.noise_steps):
                t = torch.full((n,), i, device=self.device, dtype=torch.long)
                predicted_noise = model(x, t, source_image)

                alpha = self.alphas[t][:, None, None, None]
                alpha_hat = self.alpha_hat[t][:, None, None, None]
                beta = self.betas[t][:, None, None, None]

                # The last step (i == 0) is deterministic.
                noise = torch.randn_like(x) if i > 0 else torch.zeros_like(x)

                noise_coef = (1 - alpha) / (torch.sqrt(1 - alpha_hat))
                x = 1 / torch.sqrt(alpha) * (x - noise_coef * predicted_noise) + torch.sqrt(beta) * noise
        # Keep the result in the same [-1, 1] range as the training data.
        return x.clamp(-1, 1)



if __name__ == "__main__":
    # Ask the CUDA caching allocator for expandable segments to limit fragmentation.
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

    # Dataset locations: Victorian images are the source domain, human faces the target.
    source_dir = "/root/.cache/kagglehub/datasets/elibooklover/victorian400/versions/5"
    target_dir = "/root/.cache/kagglehub/datasets/ashwingupta3012/human-faces/versions/1"

    # Train with single-sample batches, accumulating gradients over 4 batches.
    train(source_dir, target_dir, num_epochs=100, batch_size=1, gradient_accumulation_steps=4)

    # Visualise the translation of one image stored on Google Drive.
    source_image_path = "/content/drive/MyDrive/victorian_potrait.jpg.png"
    animation = demo_visualization(source_image_path)

    # Persist the animation as a GIF.
    animation.save('transformation.gif', writer='pillow')


Found 8 source images
Found 6973 target images


Epoch 1/100:   0%|          | 0/8 [00:02<?, ?it/s]


OutOfMemoryError: CUDA out of memory. Tried to allocate 8.00 GiB. GPU 0 has a total capacity of 14.74 GiB of which 6.49 GiB is free. Process 33284 has 8.25 GiB memory in use. Of the allocated memory 8.12 GiB is allocated by PyTorch, and 2.09 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)