In [16]:
import torch
import torch.nn as nn

In [17]:
class SinusoidalPositionEmbeddings(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        device = time.device
        half_dim = self.dim // 2
        embeddings = torch.log(torch.tensor(10000, device=device)) / (half_dim - 1)
        embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
        embeddings = time[:, None] * embeddings[None, :]
        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
        return embeddings

In [18]:
class TimeAwareBlock(nn.Module):
    def __init__(self, in_ch, out_ch, time_emb_dim):
        super().__init__()
        # Time projection: turns the time vector into the same size as our channels
        self.time_mlp = nn.Linear(time_emb_dim, out_ch)
        
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.relu = nn.ReLU()

    def forward(self, x, t_emb):
        # 1. First convolution
        h = self.relu(self.conv1(x))
        
        # 2. Process time: Linear -> Add dimension to match (B, C, 1, 1)
        time_feat = self.relu(self.time_mlp(t_emb))
        time_feat = time_feat[(...,) + (None,) * 2] # This makes it (Batch, Out_Ch, 1, 1)
        
        # 3. Inject time! We add the time info to every pixel in the feature map
        h = h + time_feat
        
        # 4. Second convolution
        return self.relu(self.conv2(h))

In [19]:
class TinyUNet(nn.Module):
    def __init__(self):
        super().__init__()

        time_dim=256

        self.time_mlp = nn.Sequential(
            SinusoidalPositionEmbeddings(time_dim),
            nn.Linear(time_dim, time_dim),
            nn.GELU()
        )
        
        # 1. ENCODER (The Downward path)
        self.enc1 = TimeAwareBlock(1, 16,time_dim)   # 28x28 -> 28x28
        self.pool1 = nn.MaxPool2d(2)         # 28x28 -> 14x14
        
        self.enc2 = TimeAwareBlock(16, 32,time_dim)  # 14x14 -> 14x14
        self.pool2 = nn.MaxPool2d(2)         # 14x14 -> 7x7
        
        # 2. BOTTLENECK
        self.mid = TimeAwareBlock(32, 64,time_dim)   # 7x7
        
        # 3. DECODER (The Upward path)
        self.up2 = nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2) # 7x7 -> 14x14
        self.dec2 = TimeAwareBlock(64, 32,time_dim) # Why 64? Because we concatenate!
        
        self.up1 = nn.ConvTranspose2d(32, 16, kernel_size=2, stride=2) # 14x14 -> 28x28
        self.dec1 = TimeAwareBlock(32, 16,time_dim)
        
        self.final_conv = nn.Conv2d(16, 1, kernel_size=1) # Back to 1 channel (MNIST)

    # def Conv_block(self, in_ch, out_ch):
    #     return nn.Sequential(
    #         nn.Conv2d(in_ch, out_ch, 3, padding=1),
    #         nn.ReLU(),
    #         nn.Conv2d(out_ch, out_ch, 3, padding=1),
    #         nn.ReLU()
    #     )

    def forward(self, x, time):
        time_dim=self.time_mlp(time)
        print(f"time_dim Shape{time_dim.shape}")
        # Step 1: Encoder
        e1 = self.enc1(x, time_dim)
        print(f"e1 Shape{e1.shape}")
        p1 = self.pool1(e1)
        print(f"p1 Shape{p1.shape}")
        
        e2 = self.enc2(p1,time_dim)
        print(f"e2 Shape{e2.shape}")
        p2 = self.pool2(e2)
        print(f"p2 Shape{p2.shape}")
        
        # Step 2: Bottleneck
        m = self.mid(p2,time_dim)
        print(f"m Shape{m.shape}")
        
        # Step 3: Decoder with Skip Connections
        # Hint: use torch.cat([upsampled_tensor, encoder_tensor], dim=1)
        
        u2= self.up2(m)
        print(f"u1 Shape{u2.shape}")
        x = torch.cat([u2, e2], dim=1)
        print(f"x Shape{x.shape}") 
        x = self.dec2(x, time_dim)
        print(f"x Shape{x.shape}")
        
        u1= self.up1(x)
        print(f"u2 Shape{u1.shape}")
        x= torch.cat([u1,e1], dim=1)
        print(f"x Shape{x.shape}")
        x= self.dec1(x,time_dim)
        print(f"x Shape{x.shape}")
        x = self.final_conv(x)
        print(f"final Shape{x.shape}") 
        
        
        return x

In [20]:
# 1. Initialize your model
device = "cuda" if torch.cuda.is_available() else "cpu"
model = TinyUNet().to(device)

# 2. Create "Mock" data
# Shape: (Batch Size=8, Channels=1, Height=28, Width=28)
batch_size = 8
mock_images = torch.randn(batch_size, 1, 28, 28).to(device)

# Shape: (Batch Size=8,) - representing random timesteps t
mock_time = torch.randint(0, 1000, (batch_size,)).to(device)

# 3. Pass it through the model
try:
    output = model(mock_images, mock_time)
    print("✅ Success!")
    print(f"Input shape:  {mock_images.shape}")
    print(f"Output shape: {output.shape}")
    
    # Check if shapes match
    assert output.shape == mock_images.shape, "Shape Mismatch!"
    
except Exception as e:
    print("❌ Model Crashed!")
    print(e)

time_dim Shapetorch.Size([8, 256])
e1 Shapetorch.Size([8, 16, 28, 28])
p1 Shapetorch.Size([8, 16, 14, 14])
e2 Shapetorch.Size([8, 32, 14, 14])
p2 Shapetorch.Size([8, 32, 7, 7])
m Shapetorch.Size([8, 64, 7, 7])
u1 Shapetorch.Size([8, 32, 14, 14])
x Shapetorch.Size([8, 64, 14, 14])
x Shapetorch.Size([8, 32, 14, 14])
u2 Shapetorch.Size([8, 16, 28, 28])
x Shapetorch.Size([8, 32, 28, 28])
x Shapetorch.Size([8, 16, 28, 28])
final Shapetorch.Size([8, 1, 28, 28])
✅ Success!
Input shape:  torch.Size([8, 1, 28, 28])
Output shape: torch.Size([8, 1, 28, 28])


In [21]:
total_params = sum(p.numel() for p in model.parameters())
print(f"Total Parameters: {total_params:,}")

if total_params > 1500000:
    print("⚠️ Warning: Model is a bit large for the 'tiny' requirement.")

Total Parameters: 223,665


In [22]:
# 1. Run forward pass
output = model(mock_images, mock_time)

# 2. Calculate a fake loss (mean of the output)
loss = output.mean()

# 3. Calculate gradients
loss.backward()

# 4. Check if the very first layer has gradients
# Replace 'enc1' with whatever you named your first layer
first_layer_grad = model.enc1.conv1.weight.grad
if first_layer_grad is not None:
    print("✅ Gradients are flowing correctly to the start of the model!")
else:
    print("❌ Gradients are blocked. Check your skip connections!")

time_dim Shapetorch.Size([8, 256])
e1 Shapetorch.Size([8, 16, 28, 28])
p1 Shapetorch.Size([8, 16, 14, 14])
e2 Shapetorch.Size([8, 32, 14, 14])
p2 Shapetorch.Size([8, 32, 7, 7])
m Shapetorch.Size([8, 64, 7, 7])
u1 Shapetorch.Size([8, 32, 14, 14])
x Shapetorch.Size([8, 64, 14, 14])
x Shapetorch.Size([8, 32, 14, 14])
u2 Shapetorch.Size([8, 16, 28, 28])
x Shapetorch.Size([8, 32, 28, 28])
x Shapetorch.Size([8, 16, 28, 28])
final Shapetorch.Size([8, 1, 28, 28])
✅ Gradients are flowing correctly to the start of the model!
