In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class UNet(nn.Module):
    """Small two-level U-Net built from plain conv/ReLU blocks.

    Layout: two encoder stages (each followed by a 2x max-pool in ``forward``),
    a bottleneck, two decoder stages with skip connections, and a 1x1
    convolution head.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Despite its name, this is used as the base feature
            width: the encoder stages produce ``out_channels`` and
            ``out_channels * 2`` channels and the bottleneck
            ``out_channels * 4``. The final 1x1 convolution always emits a
            single channel.
        time_emb_dim: Accepted for interface compatibility but currently
            unused; no time/noise conditioning is wired into the network.
    """

    def __init__(self, in_channels=3, out_channels=3, time_emb_dim=128):
        super().__init__()
        # Encoder
        self.encoder1 = self.convolution_block(in_channels, out_channels)
        self.encoder2 = self.convolution_block(out_channels, out_channels * 2)

        # Bottleneck
        self.bottleneck = self.convolution_block(out_channels * 2, out_channels * 4)

        # Decoder
        self.decoder2 = self.up_block(out_channels * 4, out_channels * 2)
        self.decoder1 = self.up_block(out_channels * 2, out_channels)

        # Final output: 1x1 convolution down to a single channel
        self.last_conv = nn.Conv2d(out_channels, 1, kernel_size=1)

    def convolution_block(self, in_c, out_c):
        """Return two 3x3 conv + ReLU layers mapping ``in_c`` to ``out_c`` channels.

        ``padding=1`` keeps the spatial size unchanged.
        """
        return nn.Sequential(
            nn.Conv2d(in_c, out_c, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_c, out_c, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )

    def up_block(self, in_c, out_c):
        """Return a decoder stage as a two-element ``nn.Sequential``.

        Element 0 is a stride-2 transposed convolution (``in_c`` -> ``out_c``
        channels, doubling H and W). Element 1 is a convolution block that
        expects ``out_c * 2`` input channels, i.e. the upsampled map
        concatenated with a skip tensor of ``out_c`` channels.

        The container cannot be called end-to-end because the concatenation
        happens between the two elements; ``forward`` indexes them separately.
        """
        return nn.Sequential(
            nn.ConvTranspose2d(in_c, out_c, kernel_size=2, stride=2),
            self.convolution_block(out_c * 2, out_c)
        )

    @staticmethod
    def _match_spatial_size(upsampled, skip):
        """Zero-pad ``upsampled`` so its last two dims equal those of ``skip``.

        A 2x max-pool floors odd sizes while the transposed convolution only
        doubles, so the upsampled map can be one pixel short of the skip
        tensor. When the sizes already agree the tensor is returned untouched.
        """
        diff_h = skip.size(-2) - upsampled.size(-2)
        diff_w = skip.size(-1) - upsampled.size(-1)
        if diff_h == 0 and diff_w == 0:
            return upsampled
        # F.pad takes (left, right, top, bottom) for the last two dimensions.
        return F.pad(
            upsampled,
            [diff_w // 2, diff_w - diff_w // 2, diff_h // 2, diff_h - diff_h // 2],
        )

    def _decode(self, stage, x, skip):
        """Run one decoder stage: upsample, align with ``skip``, concat, convolve."""
        upsampled = self._match_spatial_size(stage[0](x), skip)
        return stage[1](torch.cat([upsampled, skip], dim=1))

    def forward(self, x):
        """Run the network on ``x`` and return a single-channel map.

        ``x`` is pooled by 2 twice, so each spatial dimension must be large
        enough to survive both poolings. Sizes that are not multiples of 4 are
        supported: upsampled maps are zero-padded to the skip-connection size.
        The output has the same spatial size as ``x`` and 1 channel.
        """
        # TODO: time/noise conditioning is not implemented yet; the network
        # only sees ``x``.
        enc1 = self.encoder1(x)
        enc2 = self.encoder2(F.max_pool2d(enc1, 2))

        b = self.bottleneck(F.max_pool2d(enc2, 2))

        d2 = self._decode(self.decoder2, b, enc2)
        d1 = self._decode(self.decoder1, d2, enc1)

        return self.last_conv(d1)
    

# Smoke test: push one random 3-channel batch through the network and check
# that the output keeps the spatial size and has a single channel.
model = UNet()
x = torch.randn(1, 3, 128, 128)  # batch_size=1, 3 channels, 128x128 image
y = model(x)
assert y.shape == (1, 1, 128, 128), y.shape
print(y.shape)  # Output: torch.Size([1, 1, 128, 128])

torch.Size([1, 1, 128, 128])


In [None]:

    # def beta_schedule(self, x):
    #     pass

    # def get_sigma_schedule(T, sigma_min=0.01, sigma_max=1.0):
    #    return torch.linspace(sigma_max, sigma_min, T)
    
    # def training_step(model, x0, sigmas, criterion):
    #     # Sample a timestep
    #     T = sigmas.shape[0]
    #     batch_size = x0.size(0)
    #     t = torch.randint(0, T, (batch_size,), device=x0.device)
    #     sigma_t = sigmas[t].to(x0.device)  # [B]

    #     # Sample noise and create noisy input
    #     noise = torch.randn_like(x0)
    #     x_noisy = x0 + sigma_t[:, None, None, None] * noise  # [B, C, H, W]

    #     # Predict noise
    #     pred_noise = model(x_noisy, sigma_t)

    #     # VP Loss: MSE between true noise and predicted noise
    #     loss = criterion(pred_noise, noise)
    #     return loss
    
    # def get_precondition(self, sigma, M=1000): #based on table1 EDM
    #     sigma = sigma.view(-1, 1, 1, 1)  # Shape: [B, 1, 1, 1]
    #     c_skip = torch.ones_like(sigma) 
    #     c_out = -sigma
    #     c_in = 1/(torch.sqrt(sigma**2 + 1))

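    #     # NOTE(review): in EDM Table 1 (VP column) "sigma^{-1}" denotes the
    #     # inverse of the noise schedule sigma(t), i.e. the t with sigma(t) == sigma;
    #     # the expression below computes the reciprocal 1/sigma -- TODO confirm intent.
    #     c_noise = (M - 1) * sigma**(-1) # inverse sigma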

    #     return c_skip, c_out, c_in, c_noise
    
    # def denoise(self, x, sigma, M=1000):

    #     c_skip, c_out, c_in, c_noise = self.get_precondition(sigma, M)

    #     # Reshape for broadcasting over image
    #     c_in = c_in.view(-1, 1, 1, 1)
    #     c_out = c_out.view(-1, 1, 1, 1)
    #     c_skip = c_skip.view(-1, 1, 1, 1)

    #     c_noise = c_noise.view(-1, 1, 1, 1)

    #     # Predict noise from preconditioned input
    #     x_in = c_in * x
    #     noise_cond = c_noise  # Pass to U-Net
    #     predicted_noise = self(x_in, noise_cond)

    #     # Recover denoised image
    #     x_denoised = c_skip * x + c_out * predicted_noise
    #     return x_denoised