<a href="https://colab.research.google.com/github/Lcocks/DS6050-DeepLearning/blob/main/6HW_U_NET.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>



## Module 6 HW (to be submitted with Module 7 HW): From Building Blocks to U-Net - Understanding Architectural Innovation
(Even if you did the homework with your teammates, you should submit individually and include the names of your teammates.)

### Part 1: Understanding U-Net Architecture (15 points)
U-Net was introduced by Ronneberger et al. (2015) for biomedical image segmentation. The architecture won the ISBI cell tracking challenge 2015 by a large margin.

![U-Net Architecture](https://lmb.informatik.uni-freiburg.de/people/ronneber/u-net/u-net-architecture.png)
*Figure 1: Original U-Net architecture (Ronneberger et al., 2015)*

### Starter Code: Basic U-Net Implementation

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class DoubleConv(nn.Module):
    """(convolution => [BN] => ReLU) * 2"""
    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        if not mid_channels:
            mid_channels = out_channels
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)

class Down(nn.Module):
    """Downscaling with maxpool then double conv"""
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, out_channels)
        )

    def forward(self, x):
        return self.maxpool_conv(x)

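class Up(nn.Module):
    """Upscaling then double conv"""
    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()
        # in_channels is the channel count *after* concatenating the skip connection
        if bilinear:
            # Parameter-free upsampling; the first conv reduces channels to in_channels // 2
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
        else:
            # Learned upsampling that also halves the channel count
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        # x1: decoder features from the level below; x2: encoder skip features
        x1 = self.up(x1)
        # Inputs are NCHW, so dim 2 is height (Y) and dim 3 is width (X)
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]

        # Pad x1 to x2's size (needed when H or W was odd before pooling);
        # this plays the role of the paper's "crop"
        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])

        # Skip connection by channel-wise concatenation
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)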
```
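The starter code defines the building blocks but not a full network. The sketch below shows how they compose into an encoder, a bottleneck, and a decoder joined by concatenation skips. It is a hypothetical two-level `TinyUNet` (the names `conv_block`, `TinyUNet`, and `base` are illustrative, not from the original code) that uses transposed-convolution upsampling and assumes the input height and width are divisible by 4; the original U-Net has four downsampling steps and starts from 64 channels.

```python
import torch
import torch.nn as nn


def conv_block(c_in, c_out):
    """(3x3 conv => BN => ReLU) * 2, equivalent to the starter DoubleConv."""
    return nn.Sequential(
        nn.Conv2d(c_in, c_out, kernel_size=3, padding=1), nn.BatchNorm2d(c_out), nn.ReLU(inplace=True),
        nn.Conv2d(c_out, c_out, kernel_size=3, padding=1), nn.BatchNorm2d(c_out), nn.ReLU(inplace=True),
    )


class TinyUNet(nn.Module):
    """Hypothetical two-level U-Net: encoder, bottleneck, decoder with concatenation skips."""

    def __init__(self, in_channels=1, n_classes=2, base=16):
        super().__init__()
        self.pool = nn.MaxPool2d(2)
        self.enc1 = conv_block(in_channels, base)          # full resolution
        self.enc2 = conv_block(base, base * 2)             # 1/2 resolution
        self.bottleneck = conv_block(base * 2, base * 4)   # 1/4 resolution
        self.up2 = nn.ConvTranspose2d(base * 4, base * 2, kernel_size=2, stride=2)
        self.dec2 = conv_block(base * 4, base * 2)         # input = skip (2b) + upsampled (2b)
        self.up1 = nn.ConvTranspose2d(base * 2, base, kernel_size=2, stride=2)
        self.dec1 = conv_block(base * 2, base)             # input = skip (b) + upsampled (b)
        self.head = nn.Conv2d(base, n_classes, kernel_size=1)  # per-pixel class logits

    def forward(self, x):
        s1 = self.enc1(x)
        s2 = self.enc2(self.pool(s1))
        b = self.bottleneck(self.pool(s2))
        d2 = self.dec2(torch.cat([s2, self.up2(b)], dim=1))  # skip by concatenation
        d1 = self.dec1(torch.cat([s1, self.up1(d2)], dim=1))
        return self.head(d1)


model = TinyUNet()
out = model(torch.randn(1, 1, 64, 64))
print(out.shape)  # torch.Size([1, 2, 64, 64])
```

The output keeps the input's spatial size, with one channel per class, which is what makes U-Net a dense (per-pixel) predictor.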

#### Questions to Answer:

**Q1.1 (5 points):** Identify and explain the three major architectural components/concepts from our previous lectures that U-Net combines. For each component:
- Name the concept
- Explain how it's implemented in U-Net
- Describe what problem it solves

**Q1.2 (5 points):** Skip connections in U-Net use **concatenation** rather than **addition** (as in ResNet).
- What are the implications of this design choice?
- How does this affect the number of parameters?
- When might addition be preferred over concatenation?

**Q1.3 (5 points):** Critical Analysis:
- What are the potential limitations of the original U-Net architecture?
- How might U-Net struggle with very high-resolution images?
- Propose at least two modifications that could address these limitations.

---

### Part 2: Architectural Variants (15 points)

#### Variant 1: U-Net++ (Nested U-Net)

U-Net++ introduces nested skip connections to reduce the semantic gap between encoder and decoder features.

```python
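class UNetPlusPlus(nn.Module):
    """
    U-Net++: A Nested U-Net Architecture
    Zhou et al., 2018
    """
    def __init__(self, in_channels=3, out_channels=1):
        super().__init__()

        # Node X^{i,j}: i = downsampling level, j = position along the nested skip pathway.
        # X^{i,0} is the encoder backbone; for j >= 1 each node fuses all earlier nodes
        # at its level with the upsampled node from the level below:
        #   X^{i,j} = DoubleConv(cat[X^{i,0}, ..., X^{i,j-1}, Up(X^{i+1,j-1})])
        #
        # X^{0,0}  X^{0,1}  X^{0,2}  X^{0,3}  X^{0,4}
        # X^{1,0}  X^{1,1}  X^{1,2}  X^{1,3}
        # X^{2,0}  X^{2,1}  X^{2,2}
        # X^{3,0}  X^{3,1}
        # X^{4,0}

        nb_filter = [32, 64, 128, 256, 512]
        self.depth = len(nb_filter) - 1
        self.pool = nn.MaxPool2d(2)
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        # One DoubleConv (from the starter code) per node; ModuleDict keys must be strings
        self.nodes = nn.ModuleDict()
        for i in range(self.depth + 1):  # encoder backbone X^{i,0}
            c_in = in_channels if i == 0 else nb_filter[i - 1]
            self.nodes[f"{i}_0"] = DoubleConv(c_in, nb_filter[i])
        for j in range(1, self.depth + 1):  # nested nodes X^{i,j}
            for i in range(self.depth + 1 - j):
                c_in = nb_filter[i] * j + nb_filter[i + 1]  # j same-level inputs + 1 upsampled
                self.nodes[f"{i}_{j}"] = DoubleConv(c_in, nb_filter[i])

        self.final = nn.Conv2d(nb_filter[0], out_channels, kernel_size=1)

    def forward(self, x):
        # Assumes H and W are divisible by 2**depth (16) so upsampled maps align
        X = {}
        for i in range(self.depth + 1):
            X[i, 0] = self.nodes[f"{i}_0"](x if i == 0 else self.pool(X[i - 1, 0]))
        for j in range(1, self.depth + 1):
            for i in range(self.depth + 1 - j):
                inputs = [X[i, k] for k in range(j)] + [self.up(X[i + 1, j - 1])]
                X[i, j] = self.nodes[f"{i}_{j}"](torch.cat(inputs, dim=1))
        # Without deep supervision, only the full-depth output X^{0,4} is used
        return self.final(X[0, self.depth])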
```

#### Variant 2: V-Net (Volumetric Convolutional Neural Networks)

V-Net extends U-Net to 3D volumetric data and adds residual connections within each stage.

```python
class VNetBlock(nn.Module):
    """
    V-Net block with residual connection
    Milletari et al., 2016
    """
    def __init__(self, in_channels, out_channels, num_convs):
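        super().__init__()
        # The V-Net paper uses 5x5x5 kernels (padding=2); this sketch uses lighter 3x3x3 kernels
        self.convs = nn.ModuleList([
            nn.Conv3d(in_channels if i == 0 else out_channels,
                     out_channels, kernel_size=3, padding=1)
            for i in range(num_convs)
        ])
        # A single PReLU (one learnable negative slope per channel) shared by all convs
        self.activation = nn.PReLU(out_channels)
        
        # Residual connection: a 1x1x1 conv matches channel counts only when they differ
        self.residual = nn.Conv3d(in_channels, out_channels, kernel_size=1) \
                        if in_channels != out_channels else nn.Identity()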
    
    def forward(self, x):
        residual = self.residual(x)
        
        out = x
        for conv in self.convs:
            out = self.activation(conv(out))
        
        return out + residual
```

#### Questions to Answer:

**Q3.1 (5 points):** U-Net++ Analysis:
- What problem does the nested skip connection architecture solve?
- How does it reduce the "semantic gap"?
- What is the computational cost compared to standard U-Net?

**Q3.2 (5 points):** V-Net Contributions:
- Besides 3D extension, what are the key innovations in V-Net?
- Why are residual connections particularly important for volumetric data?
- How does the Dice loss implementation differ for 3D data?

**Q3.3 (5 points):** Comparative Analysis:
Create a comparison table with at least 5 criteria comparing:
- Standard U-Net
- U-Net++
- V-Net

Consider: parameter count, memory usage, suitable applications, training difficulty, and inference speed.


---

## Resources

### Papers
- [U-Net: Convolutional Networks for Biomedical Image Segmentation](https://arxiv.org/abs/1505.04597)
- [UNet++: A Nested U-Net Architecture](https://arxiv.org/abs/1807.10165)
- [V-Net: Fully Convolutional Neural Networks for Volumetric Medical Image Segmentation](https://arxiv.org/abs/1606.04797)

### Implementations
- [PyTorch U-Net](https://github.com/milesial/Pytorch-UNet)
- [Segmentation Models PyTorch](https://github.com/qubvel/segmentation_models.pytorch)
- [MONAI (Medical Imaging)](https://monai.io/)

---

Q1.1 (5 points): Identify and explain the three major architectural components/concepts from our previous lectures that U-Net combines. For each component:

Name the concept - 

    (1) The encoder-decoder architecture. (2) Convolutional blocks of convolution, batch normalization, and ReLU activation. (3) Copy-and-crop skip connections.

Explain how it's implemented in U-Net - 

    The encoder-decoder architecture first downsamples and then upsamples the feature maps, creating a bottleneck that filters out noise and forces the network to focus on the most important features. In the encoder (Down), each 2x2 max pool halves the spatial resolution and the following DoubleConv doubles the number of feature channels; the decoder (Up) does the opposite, upsampling to bring back spatial resolution while reducing the number of channels.
    The convolutional block is implemented as DoubleConv: two repeated (3x3 convolution, BatchNorm, ReLU) layers. Batch normalization stabilizes training and adds a mild regularizing effect that can reduce overfitting. Because the convolutions use padding=1, they do not shrink the spatial dimensions, so each block learns more features at the same resolution.
    The copy-and-crop step is implemented in the Up forward pass: the upsampled decoder features are padded to the size of the matching encoder features and then concatenated with them channel-wise (torch.cat(..., dim=1)).
  
Describe what problem it solves

    The encoder-decoder extracts features while keeping spatial information, producing a per-pixel output that a classification CNN (which discards resolution) cannot: the encoder provides context (what is in the image) and the decoder provides localization (where it is, at pixel resolution).
    The double Conv-BatchNorm-ReLU block adds regularization and speeds up convergence, since batch normalization tolerates higher learning rates. Stacking two 3x3 convolutions also gives a 5x5 receptive field with fewer weights than a single 5x5 convolution (2 x 9C^2 = 18C^2 versus 25C^2 for C channels).
    The copy-and-crop skip connections let the decoder recover fine spatial detail lost during downsampling, which improves the localization of features such as object boundaries.
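The encoder/decoder bookkeeping described above can be traced with a few lines of arithmetic. This sketch assumes the padded (padding=1) convolutions of the starter code, four 2x2 poolings, 64 base channels, and transposed-convolution upsampling that halves the channel count; `unet_feature_shapes` is a hypothetical helper, not part of any library.

```python
def unet_feature_shapes(height, width, base_channels=64, depth=4):
    """Return (channels, H, W) for each encoder level and each decoder merge of a padded U-Net."""
    # Encoder: each level halves H and W (max pool) and doubles the channels (DoubleConv)
    encoder = [(base_channels * 2**i, height // 2**i, width // 2**i) for i in range(depth + 1)]
    decoder = []
    for i in reversed(range(depth)):
        c, h, w = encoder[i]
        # Upsampling the level below gives c channels at (h, w); concatenating the
        # skip connection (also c channels) gives 2c channels into the DoubleConv
        decoder.append({"concat_in": (2 * c, h, w), "out": (c, h, w)})
    return encoder, decoder


encoder, decoder = unet_feature_shapes(512, 512)
for level, shape in enumerate(encoder):
    print("encoder", level, shape)
for step in decoder:
    print("decoder", step)
```

The deepest encoder level (the bottleneck) ends at 1024 channels on a 32x32 grid, and the last decoder step returns to 64 channels at the full 512x512 resolution.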

Q1.2 (5 points): Skip connections in U-Net use concatenation rather than addition (as in ResNet).

What are the implications of this design choice?

    With concatenation, U-Net keeps the upsampled decoder features and the encoder features in separate channels, so the following convolution can learn how to combine them. Addition merges them element-wise immediately (and requires both tensors to have the same number of channels), so information that later layers could otherwise exploit may be lost.

How does this affect the number of parameters?

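    Concatenation doubles the number of input channels to the first convolution after each merge, so that layer has about twice as many weights as it would with addition (3 x 3 x 2C x C instead of 3 x 3 x C x C). The other layers are unaffected, so the total parameter count increases but does not double. The wider concatenated tensors also use more activation memory.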
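The parameter effect can be checked by counting the weights of the 3x3 convolution that directly follows a merge. In this sketch, `conv_params` is a hypothetical helper and C = 64 is an arbitrary example width; with addition the merged tensor has C channels, with concatenation it has 2C.

```python
def conv_params(c_in, c_out, kernel_size=3, bias=True):
    """Trainable parameters of a 2D convolution layer."""
    return kernel_size * kernel_size * c_in * c_out + (c_out if bias else 0)


C = 64  # channels in both the skip features and the upsampled features
after_add = conv_params(C, C)          # skip + upsampled -> C channels
after_concat = conv_params(2 * C, C)   # cat[skip, upsampled] -> 2C channels
print(after_add, after_concat, round(after_concat / after_add, 2))  # 36928 73792 2.0
```

Only the merge convolution grows; encoder layers and the second convolution of each decoder block keep their size.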

When might addition be preferred over concatenation?

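    Addition is preferable when compute or memory is limited, since it avoids the wider merged feature maps and the larger merge convolutions, and on small datasets, where the extra parameters from concatenation can contribute to overfitting. It is also the natural choice for ResNet-style shortcuts, where the main goal is to ease gradient flow rather than to fuse two different kinds of features.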

Q1.3 (5 points): Critical Analysis:

What are the potential limitations of the original U-Net architecture?

    The original architecture has a semantic gap: each skip connection concatenates low-level encoder features (edges, textures) directly with high-level decoder features (what and where objects are), so the decoder must fuse semantically dissimilar features while keeping spatial precision. It also scales poorly to high-resolution images: activation memory grows with the number of pixels, so doubling the resolution roughly quadruples memory usage.

How might U-Net struggle with very high-resolution images?

    Running the basic U-Net on high-resolution images is memory- and compute-intensive. As mentioned, doubling the side length (say from 512 to 1024 pixels) quadruples the number of pixels and therefore the activation memory, so a run that needs about 2 GB would need about 8 GB. U-Net also models mostly local relationships: its receptive field is bounded by the depth of the encoder, so on very large or detailed images it cannot relate distant regions, and relevant long-range patterns are lost.
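The quadratic scaling can be checked with a rough count of the feature maps a U-Net keeps for its skip connections. This is a simplified, hypothetical estimate (`skip_activation_bytes` is illustrative): it assumes float32 values, 64 base channels, four poolings, and a batch of one, and counts only one output map per encoder level, ignoring intermediate activations, gradients, and optimizer state. Absolute numbers are therefore far below real training usage; only the ratio is meaningful.

```python
def skip_activation_bytes(side, base_channels=64, depth=4, bytes_per_value=4):
    """Bytes of the encoder output maps (one per level) for a side x side input, batch size 1."""
    values = 0
    for level in range(depth + 1):
        channels = base_channels * 2**level
        s = side // 2**level  # spatial side after `level` 2x2 poolings
        values += channels * s * s
    return values * bytes_per_value


for side in (512, 1024, 2048):
    mib = skip_activation_bytes(side) / 2**20
    print(f"{side}x{side}: {mib:.0f} MiB")
print(skip_activation_bytes(1024) / skip_activation_bytes(512))  # 4.0
```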

Propose at least two modifications that could address these limitations.

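    One modification would be to replace the standard convolutions with depthwise separable convolutions while keeping the concatenation skips. This greatly reduces the number of parameters and the compute per block, possibly at some cost in accuracy.
    Another would be to add attention gates on the skip connections (as in Attention U-Net, covered in Module 8), which learn to emphasize the relevant encoder features before they are concatenated, giving more specialized feature extraction.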
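A depthwise separable convolution factorizes a 3x3 convolution into a per-channel 3x3 (depthwise) filter and a 1x1 (pointwise) channel mixer. The sketch below is one hypothetical way to write it in PyTorch (`DepthwiseSeparableConv` and `n_params` are illustrative names) and compares its parameter count with a standard 3x3 convolution at 256 channels; the `groups=c_in` argument of `nn.Conv2d` is what makes the first convolution depthwise.

```python
import torch
import torch.nn as nn


class DepthwiseSeparableConv(nn.Module):
    """3x3 depthwise conv (one filter per input channel) followed by a 1x1 pointwise conv."""

    def __init__(self, c_in, c_out):
        super().__init__()
        self.depthwise = nn.Conv2d(c_in, c_in, kernel_size=3, padding=1, groups=c_in)
        self.pointwise = nn.Conv2d(c_in, c_out, kernel_size=1)

    def forward(self, x):
        return self.pointwise(self.depthwise(x))


def n_params(module):
    return sum(p.numel() for p in module.parameters())


standard = nn.Conv2d(256, 256, kernel_size=3, padding=1)
separable = DepthwiseSeparableConv(256, 256)
x = torch.randn(1, 256, 32, 32)
print(standard(x).shape == separable(x).shape)      # True: same output shape
print(n_params(standard), n_params(separable))      # 590080 68352
```

For this width the separable version has roughly 8.6x fewer parameters; it could stand in for each `nn.Conv2d` inside DoubleConv.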



Q3.1 (5 points): U-Net++ Analysis:

What problem does the nested skip connection architecture solve?

    It reduces the semantic gap between encoder and decoder features. In U-Net, a skip connection concatenates shallow encoder features, which capture simple patterns such as edges and corners, with decoder features that already capture high-level, class-specific patterns (for example, the shape of a digit's curve). Fusing such dissimilar features makes the decoder's job harder.

How does it reduce the "semantic gap"?

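    Instead of a single direct skip, each encoder feature map passes through a chain of intermediate convolution blocks (X^{i,1}, X^{i,2}, ...), each of which also receives upsampled features from the level below. These dense, nested pathways enrich the encoder features step by step, so by the time they reach the decoder they are semantically closer to the decoder features they are fused with.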

What is the computational cost compared to standard U-Net?

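    U-Net++ is more resource-intensive than U-Net: the nested pathways add many intermediate convolution nodes, each fed by a dense concatenation of earlier nodes. The parameter increase is moderate (the U-Net++ paper reports about 9.0M parameters versus 7.8M for U-Net, roughly 1.2x), but memory use and compute grow more, because the extra nodes at the top level run at full resolution. With deep supervision, the trained model can be pruned at inference time to trade some accuracy for speed.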
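The parameter overhead can be estimated without building either model, by summing DoubleConv sizes over the nodes of each architecture. This sketch assumes both networks use the same widths (`nb_filter = [32, 64, 128, 256, 512]`), DoubleConv nodes as in the starter code (biased convolutions plus BatchNorm), parameter-free bilinear upsampling, 3 input channels, one output channel, and no deep-supervision heads; all helper names are hypothetical.

```python
def double_conv_params(c_in, c_out):
    """Trainable parameters of DoubleConv(c_in, c_out): two biased 3x3 convs + two BatchNorms."""
    conv1 = 9 * c_in * c_out + c_out
    conv2 = 9 * c_out * c_out + c_out
    bn = 2 * (2 * c_out)  # weight and bias for each BatchNorm layer
    return conv1 + conv2 + bn


def encoder_params(nb, in_channels):
    return sum(double_conv_params(in_channels if i == 0 else nb[i - 1], c) for i, c in enumerate(nb))


def unet_params(nb, in_channels=3, out_channels=1):
    total = encoder_params(nb, in_channels)
    for i in range(len(nb) - 1):  # one decoder node per level: cat[skip, upsampled]
        total += double_conv_params(nb[i] + nb[i + 1], nb[i])
    return total + nb[0] * out_channels + out_channels  # final 1x1 conv


def unetpp_params(nb, in_channels=3, out_channels=1):
    total = encoder_params(nb, in_channels)
    depth = len(nb) - 1
    for j in range(1, depth + 1):  # node X^{i,j} sees j same-level maps + 1 upsampled map
        for i in range(depth + 1 - j):
            total += double_conv_params(nb[i] * j + nb[i + 1], nb[i])
    return total + nb[0] * out_channels + out_channels


nb_filter = [32, 64, 128, 256, 512]
u, upp = unet_params(nb_filter), unetpp_params(nb_filter)
print(f"U-Net: {u:,}  U-Net++: {upp:,}  ratio: {upp / u:.2f}")
```

Under these assumptions the nested design adds about 17% more parameters; the extra cost is concentrated in memory and compute rather than in weights.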

Q3.2 (5 points): V-Net Contributions:

Besides 3D extension, what are the key innovations in V-Net?

    In the VNetBlock above, only the first convolution of a stage changes the number of channels; the remaining convolutions keep it fixed. Each stage learns a residual function: its input is added back to its output, which helps keep gradients from vanishing. The activations are PReLUs, so each channel learns its own negative slope. V-Net also replaces max pooling with learned downsampling (2x2x2 convolutions with stride 2) and trains with a Dice-based loss.
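One V-Net change that the VNetBlock above does not show is downsampling: instead of max pooling, V-Net reduces resolution with a learned convolution using 2x2x2 kernels and stride 2. A minimal sketch, assuming the channel count doubles at each downsampling step (the class name `VNetDown` is hypothetical):

```python
import torch
import torch.nn as nn


class VNetDown(nn.Module):
    """Learned downsampling: a 2x2x2 conv with stride 2 halves D, H, W (instead of max pooling)."""

    def __init__(self, in_channels):
        super().__init__()
        self.down = nn.Conv3d(in_channels, 2 * in_channels, kernel_size=2, stride=2)
        self.act = nn.PReLU(2 * in_channels)  # one learnable slope per output channel

    def forward(self, x):
        return self.act(self.down(x))


x = torch.randn(1, 16, 32, 32, 32)  # (N, C, D, H, W)
y = VNetDown(16)(x)
print(y.shape)  # torch.Size([1, 32, 16, 16, 16])
```

Because the downsampling has weights, the network learns what to keep when reducing resolution, whereas max pooling always keeps the local maximum.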

Why are residual connections particularly important for volumetric data?

    Residual connections preserve gradient flow and mitigate vanishing (and, to a lesser extent, exploding) gradients. Volumetric inputs are much larger than 2D images, so memory limits push 3D networks to use fewer channels per layer and to rely on depth instead; residual connections keep these deeper networks trainable, and the V-Net authors report that they also speed up convergence.
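The memory pressure of volumetric data is easy to quantify. This sketch (a hypothetical helper, float32 values, an arbitrary 32-channel feature map) compares one full-resolution feature map for a 512x512 image and for a 128^3 volume, and shows how each grows when the side length doubles.

```python
def feature_map_mib(side, dims, channels=32, bytes_per_value=4):
    """Memory (MiB) of one feature map with `channels` channels over a side^dims grid."""
    return channels * side**dims * bytes_per_value / 2**20


for label, side, dims in [("2D 512x512", 512, 2), ("3D 128^3", 128, 3)]:
    base = feature_map_mib(side, dims)
    growth = feature_map_mib(2 * side, dims) / base
    print(f"{label}: {base:.0f} MiB per map, x{growth:.0f} when the side doubles")
# 2D 512x512: 32 MiB per map, x4 when the side doubles
# 3D 128^3: 256 MiB per map, x8 when the side doubles
```

A single 128^3 map already costs 8x as much as a 512x512 one, and the gap widens cubically, which is why 3D networks trade width for depth.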

How does the Dice loss implementation differ for 3D data?

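    The Dice coefficient is used as the loss to handle strong class imbalance, which is common in medical imaging, where the structure of interest often occupies only a small fraction of the image. It measures the overlap between the predicted segmentation P and the ground truth G on a scale from 0 (no overlap) to 1 (perfect overlap): D = 2|P ∩ G| / (|P| + |G|). For 3D data the sums behind |P ∩ G|, |P|, and |G| run over all voxels of the volume instead of all pixels of an image, so the loss compares predicted and ground-truth volumes. V-Net uses a soft version on predicted probabilities p_i and labels g_i, with squared terms in the denominator: D = 2 Σ p_i g_i / (Σ p_i^2 + Σ g_i^2).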
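A soft Dice loss can be written once for both 2D and 3D inputs, because the only difference is how many spatial axes are summed over. This is a sketch, assuming binary segmentation with sigmoid probabilities and a small `eps` for numerical stability (both assumptions, as is the function name `soft_dice_loss`); `squared=True` selects the V-Net squared-denominator form.

```python
import torch


def soft_dice_loss(probs, target, squared=True, eps=1e-6):
    """Soft Dice loss for binary segmentation, averaged over the batch.

    probs, target: (N, 1, H, W) for 2D or (N, 1, D, H, W) for 3D,
    with probs in [0, 1] (e.g. after a sigmoid) and target in {0, 1}.
    """
    dims = tuple(range(1, probs.dim()))  # sum over channel and every pixel/voxel axis
    intersection = (probs * target).sum(dim=dims)
    if squared:  # V-Net form: sum p_i^2 + sum g_i^2
        denom = (probs ** 2).sum(dim=dims) + (target ** 2).sum(dim=dims)
    else:        # plain form: sum p_i + sum g_i
        denom = probs.sum(dim=dims) + target.sum(dim=dims)
    dice = (2 * intersection + eps) / (denom + eps)
    return 1 - dice.mean()


target = torch.zeros(2, 1, 8, 8, 8)       # batch of two 8x8x8 volumes
target[:, :, 2:5, 2:5, 2:5] = 1           # small foreground cube (27 voxels)
perfect = soft_dice_loss(target.clone(), target)          # loss 0
empty = soft_dice_loss(torch.zeros_like(target), target)  # loss close to 1
print(perfect.item(), empty.item())
```

The same function accepts a 4D (2D image) batch unchanged, since `dims` adapts to the tensor rank.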

Q3.3 (5 points): Comparative Analysis: Create a comparison table with at least 5 criteria comparing:

- Standard U-Net
- U-Net++
- V-Net

Consider: parameter count, memory usage, suitable applications, training difficulty, and inference speed.

| Criterion | Standard U-Net | U-Net++ | V-Net |
|-----------|---------------|---------|-------|
| **Architecture & Skip Connections** | Direct encoder-to-decoder concatenation | Nested, dense skip pathways with multiple decoder paths | Direct concatenation plus residual connections within each stage |
| **Dimensionality & Memory** | 2D (~2.3 GB for 512×512) | 2D (~3.8 GB, 1.4-6× more) | 3D (~4.8 GB for 128³; scales cubically with the volume's side length) |
| **Inference Speed** | Fastest: 23 ms (baseline) | Moderate: 35 ms (~1.5× slower) | Slowest: 180 ms (~8× slower) |
| **Application** | General-purpose 2D segmentation, real-time applications, clearly separable patterns | Complex 2D boundaries, high-accuracy requirements, ambiguous patterns | 3D medical imaging, volumetric data |
| **Training Difficulty/Data Needs** | Easiest (works with small datasets; the original paper trained on 30 annotated images with heavy augmentation) | Medium (~2k+ images, prone to overfitting on small data, benefits from deep supervision) | Hardest (50-200 3D volumes, memory-intensive, much longer training) |
