<a href="https://colab.research.google.com/github/abrham17/Diffusion_model-UNet-implementation/blob/main/Diffusion_model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn

class UNet(nn.Module):
    """
    Small UNet for conditional Denoising Diffusion Probabilistic Models (DDPMs).

    args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        time_dim (int): Dimension of the time (and class) embedding.
        num_classes (int | None): Number of classes for class-conditional
            generation. ``None`` (or 0) disables class conditioning.
    returns:
        out (Tensor): Predicted noise tensor for the reverse diffusion process,
            shaped (batch, out_channels, H, W) with the same H, W as the input.
    process:
        1. Down sampling: in_channels -> 64 at full resolution, then
           64 -> 128 at half resolution (stride-2 conv).
        2. Bottleneck: 128 -> 256 at half resolution. The time embedding
           (plus the class embedding, when labels are given) is projected to
           256 channels and added to the bottleneck features here.
        3. Up sampling: 256 -> 128 back to full resolution (transposed conv),
           concatenated with the 64-channel skip from step 1, then
           128 + 64 -> 64 -> out_channels.
    """

    def __init__(self, in_channels=1, out_channels=1, time_dim=128, num_classes=None):
        super().__init__()
        self.down1 = nn.Sequential(
            nn.Conv2d(in_channels, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU()
        )
        self.down2 = nn.Sequential(
            nn.Conv2d(64, 128, 3, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU()
        )
        self.bottleneck = nn.Conv2d(128, 256, 3, padding=1)
        self.up1 = nn.Sequential(
            nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU()
        )
        self.up2 = nn.Sequential(
            nn.Conv2d(128 + 64, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU()
        )
        self.out = nn.Conv2d(64, out_channels, 3, padding=1)
        self.time_emb = nn.Linear(1, time_dim)
        self.class_emb = nn.Embedding(num_classes, time_dim) if num_classes else None
        # Maps the combined time/class embedding to the bottleneck width (256)
        # so it can be added per channel. The SiLU keeps time_emb -> emb_proj
        # from collapsing into a single linear map.
        self.emb_proj = nn.Sequential(nn.SiLU(), nn.Linear(time_dim, 256))

    def forward(self, x, t, c=None):
        """
        Predict the noise in ``x`` at timestep ``t``, optionally conditioned on ``c``.

        args:
            x (Tensor): Input of shape (batch, in_channels, H, W).
            t (Tensor): Timesteps of shape (batch, 1); a flat (batch,) or
                scalar tensor is also accepted. Cast to the embedding's dtype.
            c (Tensor | None): Integer class labels of shape (batch,), or
                None for unconditional prediction.
        returns:
            Tensor of shape (batch, out_channels, H, W).
        raises:
            ValueError: if ``c`` is given but the model was built without
                ``num_classes``.
        """
        if t.dim() <= 1:
            # nn.Linear(1, time_dim) needs a trailing dimension of size 1.
            t = t.reshape(-1, 1)
        emb = self.time_emb(t.to(self.time_emb.weight.dtype))
        if c is not None:
            if self.class_emb is None:
                raise ValueError(
                    "class labels were passed but the model was built without num_classes"
                )
            emb = emb + self.class_emb(c)

        d1 = self.down1(x)
        d2 = self.down2(d1)
        # Inject the conditioning signal as a per-channel bias on the bottleneck.
        b = self.bottleneck(d2) + self.emb_proj(emb)[:, :, None, None]
        u1 = self.up1(b)
        # For odd H/W, the stride-2 down / transposed-conv up path yields one
        # extra row/column. Crop so the skip concatenation lines up with d1
        # (no-op for even sizes).
        u1 = u1[..., : d1.shape[-2], : d1.shape[-1]]
        u2 = self.up2(torch.cat([u1, d1], dim=1))
        return self.out(u2)