# Problem: Train a 3D CNN for CT images
### Problem Statement
You are tasked with implementing and evaluating a 3D CNN model in PyTorch for semantic segmentation on synthetically generated CT images. 
Your goal is to review the input and label data shapes, then define a `MedCNN` model class with a `forward` method that follows an encoder-decoder architecture, with input and output channels chosen to match those shapes.   

### Requirements
1. **Implement** a `MedCNN` model class that uses `Conv3d` and `ConvTranspose3d` layers for downsampling and upsampling, respectively.
2. **Perform** transfer learning from a pretrained ResNet18, a common strategy for custom architectures.
3. **Use** Dice loss and train the model for 5 epochs.

### Constraints
- Use PyTorch's built-in convolution layers.
- Ensure there is a segmentation head at the end of the network.


<details>
  <summary>💡 Hint</summary>
  - Strip the final average-pooling (`avgpool`) and linear (`fc`) layers from ResNet18 using `list(resnet_model.children())[:-2]`
  <br>
  - [Conv3d](https://pytorch.org/docs/stable/generated/torch.nn.Conv3d.html)
  <br>
  - [ConvTranspose3d](https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose3d.html)
  <br>
  - [Forum discussion on model.children](https://discuss.pytorch.org/t/module-children-vs-module-modules/4551)
</details>

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

In [2]:
# Generate synthetic CT-scan data shaped [batch, slices, channels (RGB), width, height] and matching binary segmentation masks
torch.manual_seed(42)
batch = 100
num_slices = 10
channels = 3
width = 256
height = 256

ct_images = torch.randn(size=(batch, num_slices, channels, width, height))
segmentation_masks = (torch.randn(size=(batch, num_slices, 1, width, height))>0).float()

print(f"CT images (train examples) shape: {ct_images.shape}")
print(f"Segmentation binary masks (labels) shape: {segmentation_masks.shape}")

CT images (train examples) shape: torch.Size([100, 10, 3, 256, 256])
Segmentation binary masks (labels) shape: torch.Size([100, 10, 1, 256, 256])


In [3]:
# Define the MedCNN class and its forward method
class MedCNN(nn.Module):
    def __init__(self, backbone, out_channel=1):
        super(MedCNN, self).__init__()
        self.backbone = backbone
        
        # Reduce channels (512 -> 64); the ResNet backbone has already downsampled spatially
        self.conv1 = nn.Conv3d(512, 64, kernel_size=(3, 3, 3), padding=1)
        self.conv2 = nn.Conv3d(64, 64, kernel_size=(3, 3, 3), padding=1)
        
        #Upsample
        self.conv_transpose1 = nn.ConvTranspose3d(64, 32, kernel_size=(1, 4, 4), stride=(1, 4, 4))
        self.conv_transpose2 = nn.ConvTranspose3d(32, 16, kernel_size=(1, 8, 8), stride=(1, 8, 8))
        
        #Final convolution layer from 16 to 1 channel
        self.final_conv = nn.Conv3d(16, out_channel, kernel_size=1)
        self.relu = nn.ReLU()

    def forward(self, x):
        b, d, c, w, h = x.size() #Input size: [B, D, C, W, H]
        print(f"Input shape [B, D, C, W, H]: {b, d, c, w, h}")
        
        x = x.view(b*d, c, w, h) # Input to ResNet 2D conv layers [B*D, C, W, H]
        features = self.backbone(x)
        print(f"ResNet output shape[B*D, C, W, H]: {features.shape}")
        
        _, new_c, new_w, new_h = features.size()
        x = features.view(b, d, new_c, new_w, new_h) #[B, D, C, W, H]
        x = torch.permute(x, (0, 2, 1, 3, 4)) #rearrange for 3DConv layers [B, C, D, W, H]
        print(f"Reshape Resnet output for 3DConv #1 [B, C, D, W, H]: {x.shape}")
        
        # Channel-reducing 3D convolutions (spatial size is unchanged)
        x = self.relu(self.conv1(x))
        print(f"Output shape 3D Conv #1: {x.shape}")
        x = self.relu(self.conv2(x))
        print(f"Output shape 3D Conv #2: {x.shape}")
        
        #Upsampling
        x = self.relu(self.conv_transpose1(x))
        print(f"Output shape 3D Transposed Conv #1: {x.shape}")
        x = self.relu(self.conv_transpose2(x))
        print(f"Output shape 3D Transposed Conv #2: {x.shape}")

        #final segmentation
        x = torch.sigmoid(self.final_conv(x))
        print(f"Final shape: {x.shape}")
        
        return x
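
The kernel and stride choices in the two `ConvTranspose3d` layers can be checked by hand with the output-size formula from the PyTorch documentation. The helper below is a small sketch of that arithmetic; the function name `conv_transpose_out` is made up for this example.

```python
def conv_transpose_out(size, kernel, stride, padding=0, output_padding=0, dilation=1):
    """Output length along one dimension of a transposed convolution."""
    return (size - 1) * stride - 2 * padding + dilation * (kernel - 1) + output_padding + 1

# Width/height path: the 8x8 backbone features are upsampled by 4, then by 8.
after_first = conv_transpose_out(8, kernel=4, stride=4)
after_second = conv_transpose_out(after_first, kernel=8, stride=8)

# Depth path: kernel 1 and stride 1 leave the 10 slices untouched.
depth = conv_transpose_out(10, kernel=1, stride=1)

print(after_first, after_second, depth)  # 32 256 10
```

Because the kernel size equals the stride in both layers, each one multiplies the width and height by exactly its stride, so 8 x 4 x 8 recovers the original 256.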

In [4]:
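# `pretrained=True` is deprecated in recent torchvision releases; the `weights` argument replaces it
resnet_model = torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.DEFAULT)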
resnet_model = nn.Sequential(*list(resnet_model.children())[:-2])
model = MedCNN(backbone=resnet_model)



In [5]:
x = model(ct_images)

Input shape [B, D, C, W, H]: (100, 10, 3, 256, 256)
ResNet output shape[B*D, C, W, H]: torch.Size([1000, 512, 8, 8])
Reshape Resnet output for 3DConv #1 [B, C, D, W, H]: torch.Size([100, 512, 10, 8, 8])
Output shape 3D Conv #1: torch.Size([100, 64, 10, 8, 8])
Output shape 3D Conv #2: torch.Size([100, 64, 10, 8, 8])
Output shape 3D Transposed Conv #1: torch.Size([100, 32, 10, 32, 32])
Output shape 3D Transposed Conv #2: torch.Size([100, 16, 10, 256, 256])
Final shape: torch.Size([100, 1, 10, 256, 256])
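
Requirement 3 asks for Dice loss and five epochs of training, which the cells above do not reach. The sketch below shows one possible soft Dice loss and a minimal loop. To stay small and self-contained it uses a tiny stand-in model (`toy_model`) and reduced tensor sizes, all of which are assumptions for illustration rather than the notebook's own setup. It also permutes the masks from [B, D, 1, W, H] to [B, 1, D, W, H] so they line up with the model's output layout.

```python
import torch
import torch.nn as nn
import torch.optim as optim

def dice_loss(pred, target, eps=1e-6):
    """Soft Dice loss for probabilities `pred` and binary `target` of equal shape."""
    dims = tuple(range(1, pred.dim()))  # reduce over everything except the batch
    intersection = (pred * target).sum(dim=dims)
    denom = pred.sum(dim=dims) + target.sum(dim=dims)
    dice = (2 * intersection + eps) / (denom + eps)
    return 1 - dice.mean()

torch.manual_seed(0)

# Hypothetical small stand-ins for the CT volumes and masks.
images = torch.randn(4, 1, 2, 16, 16)                # [B, C, D, W, H]
masks = (torch.randn(4, 2, 1, 16, 16) > 0).float()   # [B, D, 1, W, H], as in the data cell
masks = masks.permute(0, 2, 1, 3, 4)                 # -> [B, 1, D, W, H]

# Stand-in model ending in a sigmoid, like the segmentation head above.
toy_model = nn.Sequential(nn.Conv3d(1, 1, kernel_size=3, padding=1), nn.Sigmoid())
optimizer = optim.Adam(toy_model.parameters(), lr=1e-3)

losses = []
for epoch in range(5):
    optimizer.zero_grad()
    loss = dice_loss(toy_model(images), masks)
    loss.backward()
    optimizer.step()
    losses.append(loss.item())
    print(f"Epoch {epoch + 1}: Dice loss = {loss.item():.4f}")
```

With the full `MedCNN` model, the same loop would likely need mini-batches rather than all 100 volumes at once, and the `print` calls inside `forward` would fire on every step, so they may be worth removing before training.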
