In [1]:
import torch 
import torch.nn as nn

In [2]:

import torch 
import torch.nn as nn
class LinearScehduler:
    """Linear beta noise schedule for DDPM-style diffusion.

    Precomputes, for every timestep index 0..num_timestamps-1, the schedule
    quantities used by the forward (noising) process and by a single reverse
    (denoising) step.

    :param num_timestamps: number of diffusion steps T
    :param beta_start: beta at timestep 0
    :param beta_end: beta at timestep T-1
    """

    def __init__(self, num_timestamps, beta_start, beta_end):
        self.num_timestamps = num_timestamps
        self.beta_start = beta_start
        self.beta_end = beta_end

        # beta_t increases linearly from beta_start (t=0) to beta_end (t=T-1)
        self.betas = torch.linspace(beta_start, beta_end, num_timestamps)
        self.alphas = 1. - self.betas
        # alpha_bar_t = prod_{s<=t} alpha_s
        self.alpha_cumilative_product = torch.cumprod(self.alphas, dim=0)
        # sqrt(alpha_bar_t) and sqrt(1 - alpha_bar_t): image / noise weights of q(x_t | x_0)
        self.alpha_sqroot_cumilative_prod = torch.sqrt(self.alpha_cumilative_product)
        self.one_minus_alpha_squareroot = torch.sqrt(1. - self.alpha_cumilative_product)

    def add_noise(self, original_image, noise, t):
        """Forward diffusion: compute x_t from x_0 and the given noise.

        x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise

        :param original_image: clean images x_0, shape (B, ...), e.g. BxCxHxW
        :param noise: noise tensor with the same shape as original_image
        :param t: integer timestep per sample, shape (B,)
        :return: noisy images x_t, same shape as original_image
        """
        batch_size = original_image.shape[0]
        # Per-sample coefficients become (B, 1, ..., 1) so they broadcast over all non-batch dims
        coeff_shape = (batch_size,) + (1,) * (original_image.dim() - 1)
        # Move the (length-T) schedule to the images' device so GPU timesteps/images work
        device = original_image.device
        sqrt_alpha_bar = self.alpha_sqroot_cumilative_prod.to(device)[t].reshape(coeff_shape)
        sqrt_one_minus_alpha_bar = self.one_minus_alpha_squareroot.to(device)[t].reshape(coeff_shape)
        return sqrt_alpha_bar * original_image + sqrt_one_minus_alpha_bar * noise

    def reverse_process(self, xt, noise_predicted, t):
        """Reverse diffusion: one DDPM step from x_t to x_{t-1}.

        :param xt: noisy sample at timestep t
        :param noise_predicted: model's estimate of the noise in xt, same shape as xt
        :param t: a single timestep (int or one-element tensor) shared by the whole batch
        :return: tuple (sample, x0). sample is x_{t-1} drawn from the posterior
            q(x_{t-1} | x_t, x0) (the deterministic posterior mean when t == 0);
            x0 is the clean-image estimate clamped to [-1, 1].
        """
        t = int(t)
        beta_t = self.betas[t]
        alpha_t = self.alphas[t]
        alpha_bar_t = self.alpha_cumilative_product[t]
        sqrt_alpha_bar_t = self.alpha_sqroot_cumilative_prod[t]
        sqrt_one_minus_alpha_bar_t = self.one_minus_alpha_squareroot[t]

        # Invert x_t = sqrt(abar)*x0 + sqrt(1-abar)*eps using the predicted noise
        x0 = (xt - sqrt_one_minus_alpha_bar_t * noise_predicted) / sqrt_alpha_bar_t
        x0 = torch.clamp(x0, -1., 1.)

        # Posterior mean: (x_t - beta_t / sqrt(1 - alpha_bar_t) * eps) / sqrt(alpha_t)
        mean = (xt - beta_t * noise_predicted / sqrt_one_minus_alpha_bar_t) / torch.sqrt(alpha_t)

        if t == 0:
            # Last step: no noise is added
            return mean, x0

        # Posterior variance: beta_t * (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t)
        variance = beta_t * (1. - self.alpha_cumilative_product[t - 1]) / (1. - alpha_bar_t)
        sigma = variance ** 0.5
        # Reparameterization trick: sample = mean + sigma * z, z ~ N(0, I)
        z = torch.randn_like(xt)
        return mean + sigma * z, x0

    
        
        

In [3]:

import torch
import torch.nn as nn

class Swish(nn.Module):
    """Swish (a.k.a. SiLU) activation, applied elementwise: f(x) = x * sigmoid(x)."""

    def forward(self, x):
        gate = torch.sigmoid(x)
        return gate * x

class TimeEmbed(nn.Module):
    """Sinusoidal timestep embedding followed by a learned projection.

    Each timestep t is mapped to [sin(t / f_i), cos(t / f_i)] with
    f_i = 10000 ** (i / (D // 2)) for i = 0 .. D//2 - 1. That vector is then
    passed through a single Linear(D, D) layer and a Swish activation.

    :param t_embed_dim: embedding size D
    return: BxD embedding representation of B time steps.
    """

    def __init__(self, t_embed_dim):
        super().__init__()
        self.t_embed_dim = t_embed_dim
        self.fc = nn.Linear(t_embed_dim, t_embed_dim)
        self.swish = Swish()

    def forward(self, t):
        """
        :param t: timesteps as a tensor of shape (B,), or a scalar (int / 0-dim tensor, treated as B=1)
        :return: embeddings of shape (B, t_embed_dim)
        """
        # Accept Python ints and keep the computation on the projection's device
        t = torch.as_tensor(t, device=self.fc.weight.device)
        if t.dim() == 0:
            t = t[None]

        half_dim = self.t_embed_dim // 2
        # Frequencies 10000^(2i/D); max(...) only guards the degenerate D == 1 case (empty arange)
        factor = 10000 ** (torch.arange(
            start=0, end=half_dim, dtype=torch.float32, device=t.device
        ) / max(half_dim, 1))

        # (B,) -> (B, 1) so every timestep is divided by every frequency: (B, D // 2)
        t_emb = t[:, None] / factor
        t_emb = torch.cat([torch.sin(t_emb), torch.cos(t_emb)], dim=-1)  # (B, 2 * (D // 2))
        if self.t_embed_dim % 2 == 1:
            # Odd D gives D - 1 sin/cos features; zero-pad one column so fc's input size matches
            t_emb = nn.functional.pad(t_emb, (0, 1))

        # Match the projection's parameter dtype (e.g. int/float64 timesteps with a float32 model)
        t_emb = t_emb.to(self.fc.weight.dtype)

        # Final projection with non-linearity
        return self.swish(self.fc(t_emb))


In [4]:
class Down_block(nn.Module):
    def __init__(self,)

SyntaxError: expected ':' (4064722152.py, line 1)