In [10]:
import os
from typing import *

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
from IPython.display import Image, display, clear_output
from sklearn.manifold import TSNE
from torch import Tensor
from torch.distributions import Normal
from torchvision.utils import make_grid
import torch.optim as optim



In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F


def double_conv_unpadded(in_channels, out_channels):
    """
    Two consecutive 3x3 unpadded convolutions (padding=0).
    Total size reduction per block is 4 pixels (2 from first conv + 2 from second conv).
    """
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=0),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=0),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True)
    )


class EncoderBlock(nn.Module):
    """
    U-Net Encoder step: Conv-Conv block followed by Max Pooling.
    Returns both the feature map (for skip) and the pooled map (for next block).
    """
    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Conv block
        self.conv = double_conv_unpadded(in_channels, out_channels)
        # Pooling layer 
        self.pool = nn.MaxPool2d(2)

    def forward(self, x):
        conv_output = self.conv(x)
        pool_output = self.pool(conv_output)
        return conv_output, pool_output


class DecoderBlock(nn.Module):
    """
    U-Net Decoder step: Up-convolution, Cropping, Concatenation, and Refinement.
    """
    def __init__(self, in_channels, skip_channels, out_channels):
        super().__init__()
        
        self.up = nn.ConvTranspose2d(
            in_channels, 
            out_channels, 
            kernel_size=2, 
            stride=2
        )
        self.conv = double_conv_unpadded(out_channels + skip_channels, out_channels)

    def forward(self, x_bottom, x_skip):
        
        # Double the size of the feature map from the layer below
        x_up = self.up(x_bottom) 
        
        # The skip connection (x_skip) is LARGER than the upsampled map (x_up) in this architecture.
        # We must calculate the difference and crop the larger x_skip to match the smaller x_up.
        diff_h = x_skip.size(2) - x_up.size(2)
        diff_w = x_skip.size(3) - x_up.size(3)
        
        # ensure the sizes match
        if diff_h < 0 or diff_w < 0:
             raise RuntimeError(f"Cropping error: x_up ({x_up.shape}) is unexpectedly larger than x_skip ({x_skip.shape}). Check encoder padding.")

        # apply croping to the larger x_skip tensor to match x_up size
        x_skip_cropped = x_skip[:, :, 
            diff_h // 2 : x_skip.size(2) - diff_h + diff_h // 2, 
            diff_w // 2 : x_skip.size(3) - diff_w + diff_w // 2
        ]
        
        # concatenate the x_up with skipped conection
        x_combined = torch.cat([x_up, x_skip_cropped], dim=1)
        
        # Pass combined map through the convolutional block
        return self.conv(x_combined)



class Encoder(nn.Module):
    """
    Block-based U-Net Encoder: 5 levels of downsampling.
    """
    def __init__(self):
        super().__init__()
        
        # blocks 
        self.block1 = EncoderBlock(1, 64)        # 1 -> 64 channels
        self.block2 = EncoderBlock(64, 128)      # 64 -> 128 channels
        self.block3 = EncoderBlock(128, 256)     # 128 -> 256 channels
        self.block4 = EncoderBlock(256, 512)     # 256 -> 512 channels
        
        # bttleneck 
        self.bottleneck = double_conv_unpadded(512, 1024) 

    def forward(self, x):
        x1, p1 = self.block1(x)
        x2, p2 = self.block2(p1)
        x3, p3 = self.block3(p2)
        x4, p4 = self.block4(p3)
        x5 = self.bottleneck(p4)
        
        return x1, x2, x3, x4, x5
    
class Bottleneck(nn.Module):
    def __init__(self, input_dim, z_dim):
        C, H, W = input_dim
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(C * H * W, z_dim)
        self.af1 = nn.ReLU()
        self.fc2 = nn.Linear(z_dim, C*H*W)
        self.C, self.H, self.W = C, H, W
    def forward(self, x):
        x_flat = self.flatten(x)
        x = self.fc1(x_flat)
        x = self.af1(x)
        x = self.fc2(x)
        return x

class Decoder(nn.Module):
    def __init__(self, num_classes=1): 
        super().__init__()
        
        # 1024 -> 512 (combines with Encoder x4, 512 channels)
        self.upconv4 = DecoderBlock(in_channels=1024, skip_channels=512, out_channels=512) 
        
        # 512 -> 256 (Combines with Encoder x3, 256 channels)
        self.upconv3 = DecoderBlock(in_channels=512, skip_channels=256, out_channels=256) 
        
        # 256 -> 128 (Combines with Encoder x2, 128 channels)
        self.upconv2 = DecoderBlock(in_channels=256, skip_channels=128, out_channels=128) 
        
        # 128 -> 64 (Combines with Encoder x1, 64 channels)
        self.upconv1 = DecoderBlock(in_channels=128, skip_channels=64, out_channels=64) 
        
        self.out_conv = nn.Conv2d(64, num_classes, kernel_size=1) 

    def forward(self, x5_bottleneck, x4_skip, x3_skip, x2_skip, x1_skip):
        d4 = self.upconv4(x5_bottleneck, x4_skip)
        d3 = self.upconv3(d4, x3_skip)
        d2 = self.upconv2(d3, x2_skip)
        d1 = self.upconv1(d2, x1_skip)
        
        return self.out_conv(d1)



class Unet(nn.Module):
    def __init__(self, num_classes=1, in_hid=1, hid_hid=1):
        super().__init__()
        self.encoder = Encoder()
        self.bottleneck = Bottleneck(in_hid, hid_hid)
        self.decoder = Decoder(num_classes)
        
    def forward(self, x):
        # Encoder
        x1, x2, x3, x4, x5 = self.encoder(x)
        print(x5.shape, x4.shape)
        x5 = self.bottleneck(x5)
        
        # Decoder
        output = self.decoder(x5, x4, x3, x2, x1)
        
        return output



if __name__ == '__main__':
    # Test with the input size from your diagram (572x572)
    model = Unet(num_classes=2, in_hid = 28*28, hid_hid=1024) 
    
    # Input is 1 sample, 1 channel, 572x572
    test_input = torch.randn(1, 1, 572, 572) 
    
    # Run the model
    output = model(test_input)
    
    # Expected output size (original size - (4 pixels * 4 downsampling steps) = 572 - 16 = 556? No, cropping handles this)
    # The output size should match the size after the last unpadded double conv:
    # 572 -> 568 (x1)
    # The decoder will align the final output to this size.
    print(f"Unet Output shape: {list(output.shape)}") # Expected: [1, 2, 568, 568]

TypeError: cannot unpack non-iterable int object

In [None]:
# build train_loader, test_laoder, and valid_loader.



In [None]:
EPOCH = 10
lr = 1e-3
optimizer = optim.Adam(lr=lr)
criterion_l = nn.MSELoss()
criterion_u = nn.CrossEntropyLoss
valid_loss = []
train_loss = []
    
Unet.train()

for epoch in range(EPOCH):
    batch_loss = []
    for x_l, y_l in train_loader:
        optimizer.zero_grad()


        x = input # some kind of data splittin here 
        output = Unet(x)



        loss_l.backward()
        optimizer.step()

        batch_loss.append(loss)
        
        loss_


