In [13]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import torchvision.transforms as T
import matplotlib.pyplot as plt
from PIL import Image

class PositionalNoiseEmbedding(nn.Module):
    """Sinusoidal (transformer-style) embedding of one noise-conditioning value per sample.

    With ``h = embedding_dim // 2`` the output is
    ``[sin(s * f_0), ..., sin(s * f_{h-1}), cos(s * f_0), ..., cos(s * f_{h-1})]``
    where the frequencies are ``f_k = 10000 ** (-k / (h - 1))``.

    Args:
        embedding_dim: number of output features. Odd values are supported by
            zero-padding one extra feature at the end.
    """

    def __init__(self, embedding_dim):
        super().__init__()
        self.embedding_dim = embedding_dim

    def forward(self, sigma):
        """Embed ``sigma`` (noise level / conditioning value).

        Args:
            sigma: tensor with one value per sample, normally shape (B,). A 0-d
                tensor or any shape with B elements is flattened to (B,).

        Returns:
            Tensor of shape (B, embedding_dim).
        """
        sigma = sigma.reshape(-1)
        device = sigma.device

        half_dim = self.embedding_dim // 2
        # Guard half_dim == 1: dividing by (half_dim - 1) == 0 would produce NaNs.
        denom = max(half_dim - 1, 1)
        scale = torch.log(torch.tensor(10000.0, device=device)) / denom
        freqs = torch.exp(torch.arange(half_dim, device=device) * -scale)
        args = sigma[:, None] * freqs[None, :]  # (B, half_dim)
        emb = torch.cat([torch.sin(args), torch.cos(args)], dim=-1)  # (B, 2 * half_dim)
        if self.embedding_dim % 2:
            # Odd widths: pad one zero feature so the width is exactly embedding_dim.
            emb = F.pad(emb, (0, 1))
        return emb

class UNetDDPMpp(nn.Module):
    """Minimal two-level U-Net conditioned on a per-sample noise value.

    Despite the name, this is a small U-Net (two pooling levels, plain
    Conv-ReLU blocks, additive noise embeddings), not the full DDPM++ network.

    Args:
        in_channels: channels of the input image ``x``.
        out_channels: base feature width of the first encoder level; the second
            level and the bottleneck use 2x and 4x this width. This is *not*
            the number of channels the network returns (see ``pred_channels``).
        time_emb_dim: width of the sinusoidal noise embedding.
        pred_channels: channels produced by the final 1x1 conv. Defaults to 1
            (the original behaviour); set it to ``in_channels`` to predict a
            separate value for every input channel.
    """

    def __init__(self, in_channels=3, out_channels=3, time_emb_dim=128, pred_channels=1):
        super().__init__()
        # Noise embedding, projected separately to the channel count of each
        # stage it is added to.
        self.time_embed = PositionalNoiseEmbedding(time_emb_dim)
        # emb1 is added to the raw input x, so it must have in_channels channels.
        self.time_emb_proj1 = nn.Linear(time_emb_dim, in_channels)
        self.time_emb_proj2 = nn.Linear(time_emb_dim, out_channels)
        self.time_emb_proj3 = nn.Linear(time_emb_dim, out_channels * 2)

        # Encoder
        self.encoder1 = self.convolution_block(in_channels, out_channels)
        self.encoder2 = self.convolution_block(out_channels, out_channels * 2)

        # Bottleneck
        self.bottleneck = self.convolution_block(out_channels * 2, out_channels * 4)

        # Decoder: upsample, then fuse with the matching encoder skip connection.
        self.decoder2 = self.up_block(out_channels * 4, out_channels * 2)
        self.decoder1 = self.up_block(out_channels * 2, out_channels)

        # Final 1x1 projection to the prediction channels.
        self.last_conv = nn.Conv2d(out_channels, pred_channels, kernel_size=1)

    def convolution_block(self, in_c, out_c):
        """Two 3x3 Conv + ReLU layers; padding=1 keeps the spatial size."""
        return nn.Sequential(
            nn.Conv2d(in_c, out_c, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_c, out_c, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )

    def up_block(self, in_c, out_c):
        """2x transposed-conv upsample, then a conv block over [upsampled, skip] (2*out_c channels)."""
        return nn.Sequential(
            nn.ConvTranspose2d(in_c, out_c, kernel_size=2, stride=2),
            self.convolution_block(out_c * 2, out_c)
        )

    @staticmethod
    def _pad_to(x, ref):
        """Zero-pad ``x`` so its (H, W) match ``ref``.

        Max-pooling floors odd sizes, so the upsampled map can be one pixel
        smaller than the skip connection it is concatenated with.
        """
        dh = ref.shape[-2] - x.shape[-2]
        dw = ref.shape[-1] - x.shape[-1]
        if dh == 0 and dw == 0:
            return x
        return F.pad(x, [dw // 2, dw - dw // 2, dh // 2, dh - dh // 2])

    def forward(self, x, sigma):
        """Run the U-Net on ``x`` conditioned on ``sigma``.

        Args:
            x: image batch of shape (B, in_channels, H, W).
            sigma: per-sample conditioning values fed to the positional
                embedding (presumably shape (B,)).

        Returns:
            Tensor of shape (B, pred_channels, H, W).
        """
        t_emb = self.time_embed(sigma)  # (B, time_emb_dim)

        # Project the embedding per stage and broadcast over space: (B, C, 1, 1).
        emb1 = self.time_emb_proj1(t_emb)[:, :, None, None]
        emb2 = self.time_emb_proj2(t_emb)[:, :, None, None]
        emb3 = self.time_emb_proj3(t_emb)[:, :, None, None]

        enc1 = self.encoder1(x + emb1)
        enc2 = self.encoder2(F.max_pool2d(enc1, 2) + emb2)
        b = self.bottleneck(F.max_pool2d(enc2, 2) + emb3)

        d2 = self._pad_to(self.decoder2[0](b), enc2)
        d2 = self.decoder2[1](torch.cat([d2, enc2], dim=1))

        d1 = self._pad_to(self.decoder1[0](d2), enc1)
        d1 = self.decoder1[1](torch.cat([d1, enc1], dim=1))

        return self.last_conv(d1)
    


#### test
torch.manual_seed(0)  # reproducible smoke test
model = UNetDDPMpp()
x = torch.randn(2, 3, 128, 128)  # batch of 2 RGB images, 128x128
sigma = torch.rand(x.shape[0]) + 1e-3  # noise levels must be positive (randn gave negatives)

y = model(x, sigma)
assert y.shape == (2, 1, 128, 128)  # last_conv emits a single channel by default
print(y.shape)  # torch.Size([2, 1, 128, 128])





torch.Size([2, 1, 128, 128])


In [14]:
class Denoiser(nn.Module):
    """EDM-style preconditioned denoiser for the VP noise schedule.

    Evaluates ``D(x; sigma) = c_skip * x + c_out * F(c_in * x, c_noise)`` with the
    VP preconditioning of Karras et al. (2022, Table 1):
    c_skip = 1, c_out = -sigma, c_in = 1 / sqrt(sigma^2 + 1),
    c_noise = (M - 1) * sigma^{-1}(sigma).

    Args:
        unet: network ``F``, called as ``unet(x_in, c_noise)`` with ``c_noise`` of shape (B,).
        M: number of discrete training timesteps; c_noise lies in [0, M - 1] for t in [0, 1].
        beta_d, beta_min: VP schedule parameters, where
            sigma(t) = sqrt(exp(beta_d * t^2 / 2 + beta_min * t) - 1).
    """

    def __init__(self, unet, M=1000, beta_d=19.9, beta_min=0.1):
        super().__init__()
        self.unet = unet
        self.M = M
        self.beta_d = beta_d
        self.beta_min = beta_min

    def get_precondition(self, sigma):
        """Return ``(c_skip, c_out, c_in, c_noise)``, each of shape (B, 1, 1, 1).

        Args:
            sigma: tensor holding one positive noise level per sample.
        """
        sigma = sigma.reshape(-1, 1, 1, 1)  # (B, 1, 1, 1) so the factors broadcast over images
        c_skip = torch.ones_like(sigma)
        c_out = -sigma
        c_in = 1 / torch.sqrt(sigma ** 2 + 1)

        # sigma^{-1} in EDM is the inverse of the *schedule* (noise level -> time t),
        # not 1 / sigma: solve beta_d/2 * t^2 + beta_min * t = log(1 + sigma^2) for t >= 0.
        t = (torch.sqrt(self.beta_min ** 2 + 2 * self.beta_d * torch.log1p(sigma ** 2))
             - self.beta_min) / self.beta_d
        c_noise = (self.M - 1) * t
        return c_skip, c_out, c_in, c_noise

    def denoise(self, x, sigma):
        """Return the denoised estimate D(x; sigma) with the same shape as ``x``.

        Args:
            x: noisy batch, (B, C, H, W).
            sigma: per-sample noise levels (shape (B,)), or a single scalar /
                one-element tensor that is broadcast to the whole batch.
        """
        sigma = torch.as_tensor(sigma, dtype=x.dtype, device=x.device)
        if sigma.numel() == 1:
            sigma = sigma.reshape(1).expand(x.shape[0])
        c_skip, c_out, c_in, c_noise = self.get_precondition(sigma)

        # The network is conditioned on c_noise (one value per sample), not on raw sigma.
        predicted_noise = self.unet(c_in * x, c_noise.reshape(-1))

        return c_skip * x + c_out * predicted_noise

    def forward(self, x, sigma):
        """Alias for :meth:`denoise`, so the module can be called as ``denoiser(x, sigma)``."""
        return self.denoise(x, sigma)

        


In [15]:
unet_model = UNetDDPMpp()
denoiser = Denoiser(unet_model)

# test
x = torch.randn(2, 3, 128, 128)  # noisy input batch
sigma = torch.tensor([0.1, 0.5])  # one noise level per sample

# Denoise the image
denoised = denoiser.denoise(x, sigma)
# c_skip * x keeps the input's 3 channels; the single-channel UNet output broadcasts over them.
assert denoised.shape == x.shape
print(denoised.shape)  # torch.Size([2, 3, 128, 128])


torch.Size([2, 3, 128, 128])


In [16]:
# Forward process
def beta_t(t, beta_min=0.1, beta_max=20.0):
    """Linear VP-SDE rate: beta(t) = beta_min + t * (beta_max - beta_min).

    Works for Python numbers and tensors alike (plain arithmetic only).
    """
    slope = beta_max - beta_min
    return beta_min + t * slope

def get_zigma_vp(t, beta_d=19.9, beta_min=0.1):
    """Noise level sigma(t) of the VP SDE in EDM parameterisation.

    sigma(t) = sqrt(exp(beta_d * t**2 / 2 + beta_min * t) - 1)

    Args:
        t: time (Python number, sequence, or tensor), normally in [0, 1].
        beta_d: slope of the linear beta schedule (beta_max - beta_min).
        beta_min: beta at t = 0.

    Returns:
        Tensor of sigma values with the same shape as ``t``.
    """
    t = torch.as_tensor(t)  # no copy / UserWarning when t is already a tensor
    # expm1 avoids catastrophic cancellation in exp(x) - 1 when t is small.
    return torch.sqrt(torch.expm1(0.5 * beta_d * t ** 2 + beta_min * t))

def get_time_step(i, steps, epsilon_s=0.001):
    """Return the i-th of ``steps`` evenly spaced times on [epsilon_s, 1] (inclusive).

    Matches ``numpy.linspace(epsilon_s, 1, steps)[i]``. epsilon_s > 0 keeps the
    grid away from t = 0.

    Args:
        i: step index in [0, steps - 1].
        steps: number of grid points.
        epsilon_s: first time on the grid.
    """
    if steps == 1:
        # Single-point grid: linspace returns just the start point.
        return epsilon_s
    return epsilon_s + (i * (1 - epsilon_s)) / (steps - 1)

def simulate_vp_sde(x0, steps=100, beta_min=0.1, beta_max=20.0):
    """Simulate the forward VP SDE with Euler-Maruyama and return every state.

    dx = -0.5 * beta(t) * x dt + sqrt(beta(t)) dW,
    beta(t) = beta_min + t * (beta_max - beta_min),  t in [0, 1].

    Args:
        x0: clean tensor at t = 0 (show_images expects (B, C, H, W)).
        steps: number of Euler-Maruyama steps; dt = 1 / steps.
        beta_min, beta_max: endpoints of the linear beta schedule.

    Returns:
        List of ``steps + 1`` CPU tensors: ``x0`` followed by the state after each step.

    Raises:
        ValueError: if ``steps < 1``.
    """
    if steps < 1:
        raise ValueError(f"steps must be >= 1, got {steps}")
    dt = 1.0 / steps
    x = x0.clone()
    images = [x0.cpu()]

    for i in range(steps):
        t = i * dt  # left endpoint of the i-th interval
        # The SDE needs the rate beta(t); the EDM noise level sigma(t) is a
        # different quantity and does not give a variance-preserving process.
        beta = beta_t(t, beta_min, beta_max)

        drift = -0.5 * beta * x * dt
        diffusion = (beta * dt) ** 0.5 * torch.randn_like(x)

        x = x + drift + diffusion
        images.append(x.cpu())

    return images
def _to_display_array(img):
    """Convert one (C, H, W) tensor into an array imshow accepts, min-max scaled to [0, 1]."""
    arr = img.detach().cpu().permute(1, 2, 0).numpy()
    if arr.shape[-1] == 1:
        arr = arr[..., 0]  # imshow rejects (H, W, 1); show single-channel data as 2-D
    lo, hi = arr.min(), arr.max()
    if hi > lo:
        return (arr - lo) / (hi - lo)
    return np.zeros_like(arr)  # constant image: avoid 0 / 0 -> NaN


def show_images(images, every=10):
    """Plot sample 0 of every ``every``-th batch in ``images`` on one row.

    Args:
        images: sequence of (B, C, H, W) tensors, e.g. the output of simulate_vp_sde.
        every: stride between displayed steps.
    """
    indices = list(range(0, len(images), every))
    if not indices:
        return
    # squeeze=False keeps a 2-D Axes array even when only one panel is drawn.
    fig, axs = plt.subplots(1, len(indices), figsize=(15, 3), squeeze=False)
    for ax, i in zip(axs[0], indices):
        arr = _to_display_array(images[i][0])
        ax.imshow(arr, cmap="gray" if arr.ndim == 2 else None)
        ax.set_title(f"Step {i}")
        ax.axis('off')
    plt.tight_layout()
    plt.show()


In [17]:

def print_img_fwd_process(img, size=128, steps=100, every=10):
    """Run the forward VP SDE on a PIL image and plot the noising trajectory.

    Args:
        img: PIL image; converted to RGB so the tensor always has 3 channels.
        size: side length the image is resized to (square), to match the model input.
        steps: number of SDE steps passed to simulate_vp_sde.
        every: plot every ``every``-th step.
    """
    transform = T.Compose([
        T.Resize((size, size)),
        T.ToTensor(),           # PIL -> float tensor in [0, 1]
        T.Normalize(0.5, 0.5),  # [0, 1] -> [-1, 1]
    ])
    # RGBA / grayscale / CMYK inputs would otherwise yield 4 or 1 channels.
    if img.mode != 'RGB':
        img = img.convert('RGB')
    x0 = transform(img).unsqueeze(0)  # (1, 3, size, size)

    noisy_images = simulate_vp_sde(x0, steps=steps)
    show_images(noisy_images, every=every)

In [None]:
IMAGE_PATH = '../dog.jpg'  # relative to the notebook's working directory
# `with` closes the file handle that PIL keeps open for lazy loading.
with Image.open(IMAGE_PATH) as img:
    print_img_fwd_process(img)

FileNotFoundError: [Errno 2] No such file or directory: 'dog.jpg'