## Module 6 HW (to be submitted with Module 7 HW): From Building Blocks to U-Net - Understanding Architectural Innovation
(Even if you completed the homework with your teammates, you should submit individually and include your teammates' names.)

### Part 1: Understanding U-Net Architecture (15 points)
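U-Net was introduced by Ronneberger et al. (2015) for biomedical image segmentation. Using the same network trained on transmitted-light microscopy images (phase contrast and DIC), the authors won the ISBI cell tracking challenge 2015 in those categories by a large margin.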

![U-Net Architecture](https://lmb.informatik.uni-freiburg.de/people/ronneber/u-net/u-net-architecture.png)
*Figure 1: Original U-Net architecture (Ronneberger et al., 2015)*

### Starter Code: Basic U-Net Implementation

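```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class DoubleConv(nn.Module):
    """(convolution => [BN] => ReLU) * 2"""
    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        if not mid_channels:
            mid_channels = out_channels
        # padding=1 preserves H and W; bias=False since BatchNorm adds its own shift
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)

class Down(nn.Module):
    """Downscaling with maxpool then double conv"""
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),  # halves H and W
            DoubleConv(in_channels, out_channels)
        )

    def forward(self, x):
        return self.maxpool_conv(x)

class Up(nn.Module):
    """Upscaling then double conv"""
    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()
        # in_channels counts the channels *after* concatenation with the skip tensor
        if bilinear:
            # Parameter-free upsampling; the incoming x1 must already carry in_channels // 2
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
        else:
            # Learned upsampling that also halves the channels (in_channels -> in_channels // 2)
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        # x1: features from the decoder level below; x2: encoder skip features
        x1 = self.up(x1)
        # Tensors are NCHW; pad x1 so its H and W match x2 (they can differ by one
        # pixel when an encoder feature map had an odd size before pooling)
        diffY = x2.size(2) - x1.size(2)
        diffX = x2.size(3) - x1.size(3)

        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])

        # Skip connection by concatenation along the channel dimension
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)

class OutConv(nn.Module):
    """1x1 convolution mapping features to per-pixel class logits"""
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return self.conv(x)

class UNet(nn.Module):
    """Full U-Net from the blocks above (widths 64 to 1024 as in Ronneberger et al.)"""
    def __init__(self, n_channels=3, n_classes=1, bilinear=True):
        super().__init__()
        factor = 2 if bilinear else 1  # bilinear Up cannot reduce channels itself
        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 1024 // factor)
        self.up1 = Up(1024, 512 // factor, bilinear)
        self.up2 = Up(512, 256 // factor, bilinear)
        self.up3 = Up(256, 128 // factor, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        x1 = self.inc(x)        # contracting path
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)     # bottleneck
        x = self.up1(x5, x4)    # expanding path with skip connections
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        return self.outc(x)
```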
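The `F.pad` step in `Up.forward` matters whenever an encoder feature map has an odd height or width. The following minimal, self-contained sketch replays the pool, upsample, pad, and concatenate steps with plain tensor operations. The 64-channel, 75×75 map is an arbitrary size chosen for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Hypothetical encoder feature map with an odd spatial size (75 x 75)
skip = torch.randn(1, 64, 75, 75)

down = nn.MaxPool2d(2)(skip)  # floor(75 / 2) = 37, so one row/column is lost
up = F.interpolate(down, scale_factor=2, mode="bilinear", align_corners=True)  # 74 x 74

# Same padding arithmetic as Up.forward; F.pad takes [left, right, top, bottom]
diffY = skip.size(2) - up.size(2)
diffX = skip.size(3) - up.size(3)
up = F.pad(up, [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2])

merged = torch.cat([skip, up], dim=1)  # raises a RuntimeError without the pad
print(tuple(down.shape), tuple(up.shape), tuple(merged.shape))
```

The original paper used unpadded convolutions and cropped the encoder features instead. Padding the upsampled tensor, as the starter code does, leaves the skip features intact. For a four-level U-Net, you can avoid the mismatch entirely by choosing input sizes divisible by 16.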

#### Questions to Answer:

**Q1.1 (5 points):** Identify and explain the three major architectural components/concepts from our previous lectures that U-Net combines. For each component:
- Name the concept
- Explain how it's implemented in U-Net
- Describe what problem it solves

**Q1.2 (5 points):** Skip connections in U-Net use **concatenation** rather than **addition** (as in ResNet).
- What are the implications of this design choice?
- How does this affect the number of parameters?
- When might addition be preferred over concatenation?
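To see the parameter implication concretely, you can fuse two equally sized tensors both ways and count the weights of the convolution that follows. This is a minimal sketch. The channel count `C = 64`, the 32×32 spatial size, and the single 3×3 convolution after fusion are illustrative assumptions, not U-Net's exact layers.

```python
import torch
import torch.nn as nn

def n_params(module):
    return sum(p.numel() for p in module.parameters())

C = 64  # assumed channel count of both the skip and the upsampled tensor
skip = torch.randn(1, C, 32, 32)
up = torch.randn(1, C, 32, 32)

# U-Net style: concatenation doubles the channels the next conv must read
concat_conv = nn.Conv2d(2 * C, C, kernel_size=3, padding=1)
out_concat = concat_conv(torch.cat([skip, up], dim=1))

# ResNet style: element-wise addition keeps C channels (shapes must match)
add_conv = nn.Conv2d(C, C, kernel_size=3, padding=1)
out_add = add_conv(skip + up)

print(n_params(concat_conv), n_params(add_conv))
```

In this setting the concatenation path needs 73,792 parameters and the addition path needs 36,928. The difference comes from the convolution after `torch.cat`, which reads 2C input channels instead of C.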

**Q1.3 (5 points):** Critical Analysis:
- What are the potential limitations of the original U-Net architecture?
- How might U-Net struggle with very high-resolution images?
- Propose at least two modifications that could address these limitations.
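A back-of-the-envelope estimate shows why resolution becomes a bottleneck: activation memory grows with H × W. The sketch below assumes float32 activations and a single 64-channel feature map, which is the width of the original U-Net's first stage. A real forward and backward pass keeps several full-resolution tensors (convolution outputs, activations, and the stored skip tensor), so actual usage is a multiple of these figures.

```python
def feature_map_bytes(height, width, channels=64, batch=1, bytes_per_value=4):
    """Bytes for one activation tensor of shape (batch, channels, H, W)."""
    return batch * channels * height * width * bytes_per_value

for side in (512, 1024, 4096):
    mib = feature_map_bytes(side, side) / 2**20
    print(f"{side}x{side}: {mib:,.0f} MiB per 64-channel float32 feature map")
```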

---

### Part 2: Architectural Variants (15 points)

#### Variant 1: U-Net++ (Nested U-Net)

U-Net++ introduces nested skip connections to reduce the semantic gap between encoder and decoder features.

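```python
class UNetPlusPlus(nn.Module):
    """
    U-Net++: A Nested U-Net Architecture
    Zhou et al., 2018
    Sketch without deep supervision; reuses DoubleConv from the starter
    code. Input H and W should be divisible by 16 (four poolings).
    """
    def __init__(self, in_channels=3, out_channels=1):
        super().__init__()
        # Node X^{i,j}: i = depth (number of poolings), j = position along the
        # nested skip pathway; column j = 0 is the encoder backbone.
        # X^{0,0}  X^{0,1}  X^{0,2}  X^{0,3}  X^{0,4}
        # X^{1,0}  X^{1,1}  X^{1,2}  X^{1,3}
        # X^{2,0}  X^{2,1}  X^{2,2}
        # X^{3,0}  X^{3,1}
        # X^{4,0}
        # For j >= 1: X^{i,j} = conv(cat(X^{i,0}, ..., X^{i,j-1}, up(X^{i+1,j-1})))
        nb_filter = [32, 64, 128, 256, 512]
        self.depth = len(nb_filter)
        self.pool = nn.MaxPool2d(2)
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        # One DoubleConv per node, keyed "i_j"
        self.nodes = nn.ModuleDict()
        for i in range(self.depth):
            for j in range(self.depth - i):
                if j == 0:
                    in_ch = in_channels if i == 0 else nb_filter[i - 1]
                else:  # j same-level tensors + one upsampled tensor
                    in_ch = nb_filter[i] * j + nb_filter[i + 1]
                self.nodes[f"{i}_{j}"] = DoubleConv(in_ch, nb_filter[i])
        self.final = nn.Conv2d(nb_filter[0], out_channels, kernel_size=1)

    def forward(self, x):
        X = {}
        for i in range(self.depth):  # encoder backbone X^{i,0}
            inp = x if i == 0 else self.pool(X[(i - 1, 0)])
            X[(i, 0)] = self.nodes[f"{i}_0"](inp)
        for j in range(1, self.depth):  # nested pathways, column by column
            for i in range(self.depth - j):
                feats = [X[(i, k)] for k in range(j)] + [self.up(X[(i + 1, j - 1)])]
                X[(i, j)] = self.nodes[f"{i}_{j}"](torch.cat(feats, dim=1))
        return self.final(X[(0, self.depth - 1)])
```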
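The extra cost can be quantified by counting convolution weights for a standard U-Net decoder and for the nested pathways. The count below uses the same `nb_filter` widths as the `UNetPlusPlus` class above. It is a hand calculation under several assumptions:

- Each node has two 3×3 convolutions.
- Upsampling is bilinear, so it has no learned weights.
- There is no deep supervision.
- Biases, BatchNorm, and the final 1×1 convolution are ignored.

```python
NB_FILTER = [32, 64, 128, 256, 512]  # widths used in the UNetPlusPlus code

def double_conv_weights(c_in, c_out):
    """Weights of two 3x3 convs (c_in -> c_out -> c_out); biases/BN ignored."""
    return 9 * c_in * c_out + 9 * c_out * c_out

def encoder_weights(f, in_channels=3):
    """Backbone nodes X^{i,0}, shared by both architectures."""
    return sum(double_conv_weights(in_channels if i == 0 else f[i - 1], f[i])
               for i in range(len(f)))

def unet_weights(f, in_channels=3):
    # One decoder node per level: cat(skip, upsampled) has f[i] + f[i+1] channels
    decoder = sum(double_conv_weights(f[i] + f[i + 1], f[i]) for i in range(len(f) - 1))
    return encoder_weights(f, in_channels) + decoder

def unetpp_weights(f, in_channels=3):
    # Node X^{i,j} (j >= 1) reads j same-level tensors plus one upsampled tensor
    depth = len(f)
    nested = sum(double_conv_weights(j * f[i] + f[i + 1], f[i])
                 for i in range(depth - 1) for j in range(1, depth - i))
    return encoder_weights(f, in_channels) + nested

u, upp = unet_weights(NB_FILTER), unetpp_weights(NB_FILTER)
print(f"U-Net:   {u:,} conv weights")
print(f"U-Net++: {upp:,} conv weights ({upp / u:.2f}x)")
```

Under these assumptions U-Net++ carries roughly 17% more weights. The parameter count understates the extra compute, however. The four nested nodes X^{0,1} to X^{0,4} all run at full input resolution, whereas a standard U-Net has only one full-resolution decoder node.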

#### Variant 2: V-Net (Volumetric Convolutional Neural Networks)

V-Net extends U-Net to 3D volumetric data and adds residual connections within each stage.

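```python
class VNetBlock(nn.Module):
    """
    V-Net block with residual connection
    Milletari et al., 2016
    """
    def __init__(self, in_channels, out_channels, num_convs, kernel_size=5):
        super().__init__()
        # The paper uses 5x5x5 kernels; padding = kernel_size // 2 keeps D, H, W
        self.convs = nn.ModuleList([
            nn.Conv3d(in_channels if i == 0 else out_channels,
                      out_channels, kernel_size=kernel_size, padding=kernel_size // 2)
            for i in range(num_convs)
        ])
        # A separate PReLU per convolution, so each learns its own slopes
        self.activations = nn.ModuleList(
            [nn.PReLU(out_channels) for _ in range(num_convs)]
        )

        # Residual connection (1x1x1 projection only when channel counts differ)
        self.residual = nn.Conv3d(in_channels, out_channels, kernel_size=1) \
                        if in_channels != out_channels else nn.Identity()

    def forward(self, x):
        residual = self.residual(x)

        out = x
        for conv, act in zip(self.convs, self.activations):
            out = act(conv(out))

        # Element-wise sum of the stage output and its (projected) input
        return out + residual
```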

#### Questions to Answer:

**Q3.1 (5 points):** U-Net++ Analysis:
- What problem does the nested skip connection architecture solve?
- How does it reduce the "semantic gap"?
- What is the computational cost compared to standard U-Net?

**Q3.2 (5 points):** V-Net Contributions:
- Besides 3D extension, what are the key innovations in V-Net?
- Why are residual connections particularly important for volumetric data?
- How does the Dice loss implementation differ for 3D data?
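V-Net defines the Dice coefficient as $D = \frac{2\sum_i p_i g_i}{\sum_i p_i^2 + \sum_i g_i^2}$, with the sums running over all voxels $i$. The minimal PyTorch sketch below assumes predicted probabilities `probs` and a one-hot `target`, both of shape (N, C, *spatial). It shows that moving from 2D to 3D only changes which axes are reduced. The small `eps` term and the averaging over samples and classes are assumptions of this sketch, not part of the paper's formula.

```python
import torch

def soft_dice_loss(probs, target, eps=1e-6):
    """Soft Dice loss, 1 - D with D = 2*sum(p*g) / (sum(p^2) + sum(g^2)).

    probs, target: tensors of shape (N, C, *spatial), e.g. (N, C, H, W) for
    2D or (N, C, D, H, W) for 3D. eps (an added assumption) avoids 0/0.
    """
    spatial_dims = tuple(range(2, probs.dim()))  # all axes after N and C
    intersection = (probs * target).sum(dim=spatial_dims)
    denom = (probs ** 2).sum(dim=spatial_dims) + (target ** 2).sum(dim=spatial_dims)
    dice = (2 * intersection + eps) / (denom + eps)  # shape (N, C)
    return 1 - dice.mean()

# A 3D volume: the same function reduces over D, H, W
pred = torch.rand(2, 1, 16, 32, 32)
gt = (torch.rand(2, 1, 16, 32, 32) > 0.5).float()
print(soft_dice_loss(pred, gt).item())
```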

**Q3.3 (5 points):** Comparative Analysis:
Create a comparison table with at least 5 criteria comparing:
- Standard U-Net
- U-Net++
- V-Net

Consider: parameter count, memory usage, suitable applications, training difficulty, and inference speed.
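Two small helpers are enough for the parameter-count and inference-speed rows of the table. This sketch uses hypothetical helper names (`count_parameters`, `time_inference`). On a GPU, call `torch.cuda.reset_peak_memory_stats()` before the run and `torch.cuda.max_memory_allocated()` after it to read peak memory.

```python
import time
import torch
import torch.nn as nn

def count_parameters(model):
    """Number of trainable parameters."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

@torch.no_grad()
def time_inference(model, input_shape, n_runs=10, device="cpu"):
    """Average forward-pass time in seconds for one batch of random input."""
    model = model.to(device).eval()
    x = torch.randn(*input_shape, device=device)
    model(x)  # warm-up run, not timed
    if device == "cuda":
        torch.cuda.synchronize()  # wait for queued GPU work before timing
    start = time.perf_counter()
    for _ in range(n_runs):
        model(x)
    if device == "cuda":
        torch.cuda.synchronize()
    return (time.perf_counter() - start) / n_runs

model = nn.Conv2d(3, 8, kernel_size=3, padding=1)  # stand-in for a U-Net variant
print(count_parameters(model), time_inference(model, (1, 3, 64, 64)))
```

To keep the timings comparable, use the same input shape and device for every model. The untimed warm-up call keeps one-off setup costs out of the average.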


---

## Resources

### Papers
- [U-Net: Convolutional Networks for Biomedical Image Segmentation](https://arxiv.org/abs/1505.04597)
- [UNet++: A Nested U-Net Architecture for Medical Image Segmentation](https://arxiv.org/abs/1807.10165)
- [V-Net: Fully Convolutional Neural Networks for Volumetric Medical Image Segmentation](https://arxiv.org/abs/1606.04797)

### Implementations
- [PyTorch U-Net](https://github.com/milesial/Pytorch-UNet)
- [Segmentation Models PyTorch](https://github.com/qubvel/segmentation_models.pytorch)
- [MONAI (Medical Imaging)](https://monai.io/)

---