In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class PositionalNoiseEmbedding(nn.Module):
    """Sinusoidal embedding of a per-sample noise level.

    Maps a batch of B scalar noise levels to a (B, embedding_dim) tensor of
    sin/cos features at geometrically spaced frequencies (transformer-style
    positional encoding).

    Args:
        embedding_dim: Size of the output feature vector. Odd sizes are
            supported by zero-padding the last column.
    """

    def __init__(self, embedding_dim):
        super().__init__()
        self.embedding_dim = embedding_dim

    def forward(self, sigma):
        """Embed ``sigma``.

        Args:
            sigma: Tensor of shape (B,) with one noise level per sample; a
                0-d tensor is treated as a batch of one.

        Returns:
            Tensor of shape (B, embedding_dim) laid out as
            ``[sin(sigma * f_k), cos(sigma * f_k)]`` with
            ``f_k = 10000 ** (-k / max(half_dim - 1, 1))``, zero-padded by one
            column when ``embedding_dim`` is odd.
        """
        device = sigma.device
        if sigma.dim() == 0:
            # Accept a single scalar noise level; the indexing below needs a batch axis.
            sigma = sigma[None]
        half_dim = self.embedding_dim // 2
        # With a single frequency the denominator (half_dim - 1) would be 0 and
        # produce inf * 0 = NaN; clamping it makes that lone frequency exactly 1.
        scale = torch.log(torch.tensor(10000.0, device=device)) / max(half_dim - 1, 1)
        freqs = torch.exp(torch.arange(half_dim, device=device) * -scale)
        args = sigma[:, None] * freqs[None, :]  # (B, half_dim)
        emb = torch.cat([torch.sin(args), torch.cos(args)], dim=-1)  # (B, 2 * half_dim)
        if self.embedding_dim % 2 == 1:
            # 2 * half_dim is one short for odd sizes; pad so callers always get embedding_dim.
            emb = F.pad(emb, (0, 1))
        return emb

class UNetDDPMpp(nn.Module):
    """Small two-level U-Net conditioned on a per-sample noise level.

    Layout: two encoder stages with 2x max-pool downsampling, a bottleneck,
    and two decoder stages that upsample with a stride-2 transposed conv and
    fuse the matching encoder output via channel concatenation. A sinusoidal
    embedding of ``sigma`` is projected once per stage and added to that
    stage's input as a per-channel bias.

    Args:
        in_channels: Channels of the input image ``x``.
        out_channels: Base feature width. The first stage has this many
            channels and deeper stages use 2x and 4x. Despite its name, it is
            not the number of channels the network returns.
        time_emb_dim: Size of the sinusoidal noise-level embedding.
        pred_channels: Channels produced by the final 1x1 conv. The default
            of 1 keeps the previous behavior.
            NOTE(review): a noise-prediction head usually needs
            ``pred_channels == in_channels``; confirm the intended target.
    """

    def __init__(self, in_channels=3, out_channels=3, time_emb_dim=128, pred_channels=1):
        super().__init__()
        # Noise-level embedding, projected per stage to that stage's input width.
        self.time_embed = PositionalNoiseEmbedding(time_emb_dim)
        # Stage 1's bias is added to the raw input x, so it must have in_channels
        # channels (out_channels only worked when the two happened to be equal).
        self.time_emb_proj1 = nn.Linear(time_emb_dim, in_channels)
        self.time_emb_proj2 = nn.Linear(time_emb_dim, out_channels)
        self.time_emb_proj3 = nn.Linear(time_emb_dim, out_channels * 2)

        # Encoder
        self.encoder1 = self.convolution_block(in_channels, out_channels)
        self.encoder2 = self.convolution_block(out_channels, out_channels * 2)

        # Bottleneck
        self.bottleneck = self.convolution_block(out_channels * 2, out_channels * 4)

        # Decoder: each up_block upsamples, then fuses with the skip connection.
        self.decoder2 = self.up_block(out_channels * 4, out_channels * 2)
        self.decoder1 = self.up_block(out_channels * 2, out_channels)

        # Final 1x1 projection to the prediction channels.
        self.last_conv = nn.Conv2d(out_channels, pred_channels, kernel_size=1)

    def convolution_block(self, in_c, out_c):
        """Return two 3x3 conv + ReLU layers mapping in_c -> out_c, keeping H and W."""
        return nn.Sequential(
            nn.Conv2d(in_c, out_c, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_c, out_c, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )

    def up_block(self, in_c, out_c):
        """Return ``[upsample, fuse]``.

        ``upsample`` is a stride-2 transposed conv (in_c -> out_c, doubles H and
        W). ``fuse`` is a conv block that takes the upsampled tensor
        concatenated with an out_c-channel skip tensor (2 * out_c -> out_c).
        """
        return nn.Sequential(
            nn.ConvTranspose2d(in_c, out_c, kernel_size=2, stride=2),
            self.convolution_block(out_c * 2, out_c),
        )

    @staticmethod
    def _match_spatial(x, ref):
        """Zero-pad ``x`` on H/W so its spatial size equals ``ref``'s.

        Max-pooling floors odd sizes, so the 2x transposed conv can come back
        one pixel smaller than the skip tensor it is concatenated with.
        """
        diff_h = ref.shape[-2] - x.shape[-2]
        diff_w = ref.shape[-1] - x.shape[-1]
        if diff_h == 0 and diff_w == 0:
            return x
        return F.pad(
            x,
            (diff_w // 2, diff_w - diff_w // 2, diff_h // 2, diff_h - diff_h // 2),
        )

    def forward(self, x, sigma):
        """Run the U-Net on ``x`` conditioned on the noise level ``sigma``.

        Per the author's notes the output is meant to be the predicted noise.

        Args:
            x: Image batch of shape (B, in_channels, H, W).
            sigma: Per-sample noise levels of shape (B,).

        Returns:
            Tensor of shape (B, pred_channels, H, W).
        """
        t_emb = self.time_embed(sigma)  # (B, time_emb_dim)

        # Project per stage and reshape to (B, C, 1, 1) so the embedding
        # broadcasts over H and W as a per-channel bias.
        emb1 = self.time_emb_proj1(t_emb)[:, :, None, None]
        emb2 = self.time_emb_proj2(t_emb)[:, :, None, None]
        emb3 = self.time_emb_proj3(t_emb)[:, :, None, None]

        enc1 = self.encoder1(x + emb1)
        enc2 = self.encoder2(F.max_pool2d(enc1, 2) + emb2)
        b = self.bottleneck(F.max_pool2d(enc2, 2) + emb3)

        d2 = self._match_spatial(self.decoder2[0](b), enc2)
        d2 = self.decoder2[1](torch.cat([d2, enc2], dim=1))

        d1 = self._match_spatial(self.decoder1[0](d2), enc1)
        d1 = self.decoder1[1](torch.cat([d1, enc1], dim=1))

        return self.last_conv(d1)

    def beta_schedule(self, x):
        """Placeholder for a noise schedule; not implemented and returns None.

        TODO: the DDPM wrapper in this notebook already takes its own
        ``beta_schedule``; consider keeping the schedule there rather than
        on the denoiser network.
        """
        pass

# Smoke test: one forward pass on a random image with a random noise level.
model = UNetDDPMpp()
x = torch.randn(1, 3, 128, 128)  # batch_size=1, 3 channels, 128x128 image
# A noise level is a standard deviation and must be non-negative, so draw it
# from torch.rand ([0, 1)) rather than torch.randn (which can be negative).
sigma = torch.rand(1)  # one noise level for the single batch element

with torch.no_grad():  # inference-only check; no autograd graph needed
    y = model(x, sigma)
print(y.shape)  # Output: torch.Size([1, 1, 128, 128])



s tensor([-0.1273])
enc2 Sequential(
  (0): Conv2d(3, 6, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (1): ReLU(inplace=True)
  (2): Conv2d(6, 6, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (3): ReLU(inplace=True)
)
ss torch.Size([1, 3, 64, 64])
sdf torch.Size([1, 3, 1, 1])
torch.Size([1, 1, 128, 128])


In [None]:
from models.unet import UNet

class DDPM(nn.Module):
    """Thin diffusion wrapper around a denoising network.

    Args:
        model: Network invoked as ``model(x_t, t)``. Per the author's notes
            it is meant to predict the noise contained in ``x_t``.
        beta_schedule: Noise schedule object; stored as-is and not used yet.
    """

    def __init__(self, model, beta_schedule):
        super().__init__()
        self.model, self.beta_schedule = model, beta_schedule

    def forward(self, x_t, t):
        """Delegate to the wrapped network and return its prediction unchanged."""
        prediction = self.model(x_t, t)
        return prediction

    def sample(self, num_samples):
        """Draw ``num_samples`` samples -- not implemented yet; returns None.

        TODO: implement the reverse process (deterministic or stochastic).
        """
        return None
