# Problem: Implement parameter initialization strategies for a CNN model in PyTorch

### Problem Statement
You are tasked with implementing and evaluating parameter initialization strategies for a CNN model in PyTorch.
Your goal is to initialize the weights and biases of the vanilla CNN model provided below and to comment on the implications of each strategy.

### Requirements
1. **Initialize** weights and biases in the following ways:
   - **Zero Initialization**: set all parameters to zero.
   - **Random Initialization**: draw parameters from a normal distribution whose standard deviation is fixed and does not depend on the layer's size.
   - **Xavier (Glorot) Initialization**: draw parameters from a normal distribution with **mean = 0 and variance = 2/(n_in + n_out)**, where n_in and n_out are the layer's fan-in and fan-out (this reduces to 1/n when n_in = n_out = n).
   - **Kaiming (He) Initialization**: draw parameters from a normal distribution with **mean = 0 and variance = 2/n_in**, where n_in is the layer's fan-in.
2. Train and compute accuracy for each strategy
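
Both variance rules come from asking that a layer neither shrink nor blow up the scale of its activations. A brief sketch of the standard argument, assuming zero-mean, independent weights $w_i$ and inputs $x_i$ independent of the weights, for a layer $y = \sum_{i=1}^{n_{\text{in}}} w_i x_i$:

$$
\operatorname{Var}(y) = n_{\text{in}}\,\operatorname{Var}(w)\,\mathbb{E}[x^2]
$$

For zero-mean inputs (the linear/tanh regime Xavier targets), $\mathbb{E}[x^2] = \operatorname{Var}(x)$, so keeping $\operatorname{Var}(y) = \operatorname{Var}(x)$ needs $\operatorname{Var}(w) = 1/n_{\text{in}}$. The same argument on the backward pass gives $1/n_{\text{out}}$, and Xavier uses the average fan:

$$
\operatorname{Var}(w) = \frac{2}{n_{\text{in}} + n_{\text{out}}}
$$

If instead $x = \max(0, z)$ is the ReLU output of a zero-mean, symmetric pre-activation $z$, then $\mathbb{E}[x^2] = \tfrac{1}{2}\operatorname{Var}(z)$, and keeping $\operatorname{Var}(y) = \operatorname{Var}(z)$ gives the Kaiming rule:

$$
\operatorname{Var}(w) = \frac{2}{n_{\text{in}}}
$$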
### Constraints
- Use the given CNN model and the training and testing helper functions for accuracy computations.
- Ensure the model is compatible with the CIFAR-10 dataset, which contains 10 classes.


<details>
  <summary>💡 Hint</summary>
  - Use `torch.nn.init` for weight initialization
  <br>
  - Resources to read: [All you need is a good init](https://arxiv.org/pdf/1511.06422)
</details>
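
Following the hint, all four strategies map onto functions in `torch.nn.init`. A minimal sketch, assuming the strategies apply to every `Conv2d`/`Conv3d`/`ConvTranspose3d`/`Linear` layer and that biases are zeroed except under the random strategy (both are our choices, not part of the problem statement); `init_params` is a hypothetical helper name:

```python
import torch
import torch.nn as nn

# Layer types this sketch re-initializes (an assumption; BatchNorm and other
# layers keep their PyTorch defaults)
LAYER_TYPES = (nn.Conv2d, nn.Conv3d, nn.ConvTranspose3d, nn.Linear)


def init_params(model: nn.Module, strategy: str, std: float = 1.0) -> nn.Module:
    """Re-initialize weights and biases of every matching layer in `model`, in place."""
    for m in model.modules():
        if not isinstance(m, LAYER_TYPES):
            continue
        if strategy == "zero":
            nn.init.zeros_(m.weight)
        elif strategy == "random":
            # Fixed-std normal draw that ignores layer size; std=1.0 is illustrative
            nn.init.normal_(m.weight, mean=0.0, std=std)
        elif strategy == "xavier":
            # Glorot normal: std = sqrt(2 / (fan_in + fan_out))
            nn.init.xavier_normal_(m.weight)
        elif strategy == "kaiming":
            # He normal for ReLU: std = sqrt(2 / fan_in)
            nn.init.kaiming_normal_(m.weight, mode="fan_in", nonlinearity="relu")
        else:
            raise ValueError(f"unknown strategy: {strategy!r}")
        if m.bias is not None:
            # Biases: random draw for "random", zeros otherwise (a common convention)
            if strategy == "random":
                nn.init.normal_(m.bias, mean=0.0, std=std)
            else:
                nn.init.zeros_(m.bias)
    return model


# Usage on a small stand-in network: print the first layer's weight std per strategy
torch.manual_seed(0)
net = nn.Sequential(nn.Linear(512, 256), nn.ReLU(), nn.Linear(256, 10))
for s in ("zero", "random", "xavier", "kaiming"):
    init_params(net, s)
    print(s, round(net[0].weight.std().item(), 4))
```

Applied to `MedCNN` as-is, this would also overwrite the pretrained ResNet-18 convolutions, since `model.modules()` walks into `resnet_model`.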

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

In [2]:
# Generate synthetic CT-scan data: [batch, slices, channels (RGB), width, height]
torch.manual_seed(42)
batch = 100
num_slices = 10
channels = 3
width = 256
height = 256

ct_images = torch.randn(size=(batch, num_slices, channels, width, height))
ct_images_reshape = ct_images.view(batch*num_slices, channels, width, height)
print(f"Input image dataset shape: {ct_images.shape} and {ct_images_reshape.shape}")

Input image dataset shape: torch.Size([100, 10, 3, 256, 256]) and torch.Size([1000, 3, 256, 256])
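
The `view` above folds the slice axis into the batch axis. Because `view` reinterprets contiguous memory in row-major order, row `i * num_slices + j` of the folded tensor is slice `j` of volume `i`, which is what lets `MedCNN.forward` unfold the backbone features back to `[B, D, ...]`. A small check on a toy tensor (the sizes are arbitrary):

```python
import torch

# Toy stand-in for the CT tensor: [batch, slices, channels, W, H]
b, d, c, w, h = 2, 3, 3, 4, 4
vol = torch.randn(b, d, c, w, h)

# Fold slices into the batch axis, as done before the 2D backbone
flat = vol.view(b * d, c, w, h)

# Row i * d + j of the folded tensor is slice j of volume i
for i in range(b):
    for j in range(d):
        assert torch.equal(flat[i * d + j], vol[i, j])

# Unfolding with the same (b, d) recovers the original tensor exactly
restored = flat.view(b, d, c, w, h)
print(torch.equal(restored, vol))  # True
```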


In [3]:
# Define the CNN model: pretrained 2D ResNet backbone followed by a 3D conv head
class MedCNN(nn.Module):
    def __init__(self, resnet_model):
        super(MedCNN, self).__init__()
        self.resnet_model = resnet_model
        self.conv1 = nn.Conv3d(512, 64, kernel_size=(3, 1, 1), stride=(1, 1, 1))
        self.conv2 = nn.Conv3d(64, 32, kernel_size=(3, 1, 1), stride=(1, 1, 1))
        self.transconv1 = nn.ConvTranspose3d(32, 3, kernel_size=(4, 16, 16), stride=(3, 16, 16))
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool3d(2, 2)
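        # fc1/fc2 are not used in forward() (their calls are commented out); fc2 alone
        # holds 128*128 x 128*128 ≈ 268M weights (about 1 GB in float32)
        self.fc1 = nn.Linear(32*3*4*4, 128*128)
        self.fc2 = nn.Linear(128*128, 128*128)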
        self.dropout = nn.Dropout(0.4)

    def forward(self, x):
        b, d, c, w, h = x.size()  # Input size: [B, D, C, W, H]
        print(f"Input shape [B, D, C, W, H]: {b, d, c, w, h}")
        x = x.view(b*d, c, w, h)  # Input to ResNet 2D conv layers: [B*D, C, W, H]
        features = self.resnet_model(x)
        print(f"ResNet output shape[B*D, C, W, H]: {features.shape}")
        _, new_c, new_w, new_h = features.size()
        x = features.view(b, d, new_c, new_w, new_h) #[B, D, C, W, H]
        x = torch.permute(x, (0, 2, 1, 3, 4)) #rearrange for 3DConv layers [B, C, D, W, H]
        print(f"Reshape Resnet output for 3DConv #1 [B, C, D, W, H]: {x.shape}")
        x = self.relu(self.conv1(x))
        print(f"Output shape 3D Conv #1 layer: {x.shape}")
        x = self.relu(self.pool(self.conv2(x)))
        print(f"Output shape 3D Conv #2 layer: {x.shape}")
        x = self.relu(self.transconv1(x))
        print(f"Output shape 3D Transposed Conv #1 layer: {x.shape}")
        # x = x.view(x.size()[0], -1)
        # x = self.relu(self.fc1(x))
        # print(f"Output shape Linear #1 layer: {x.shape}")
        # x = self.dropout(x)
        # x = self.fc2(x)
        # print(f"Output shape Linear #2 layer: {x.shape}")
        
        return x  # output of the 3D head: [B, 3, D', W', H']

In [4]:
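# Load ImageNet-pretrained ResNet-18 (`pretrained=True` is deprecated since torchvision 0.13)
resnet_model = torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.DEFAULT)
# Drop the final avgpool and fc layers to keep the [512, H/32, W/32] feature maps
resnet_model = nn.Sequential(*list(resnet_model.children())[:-2])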
model = MedCNN(resnet_model)
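
Because `resnet_model` is pretrained, applying an initialization strategy to every module would discard the ImageNet weights. One option, which is our assumption rather than something the problem statement specifies, is to re-initialize only the top-level children other than the backbone. A sketch on a shrunken stand-in model; `TinyMed` and `init_head` are hypothetical names, and this sketch zeroes biases under every strategy:

```python
import torch
import torch.nn as nn


class TinyMed(nn.Module):
    """Hypothetical stand-in for MedCNN: a 'pretrained' 2D backbone plus a 3D head."""

    def __init__(self):
        super().__init__()
        self.resnet_model = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3), nn.ReLU())
        self.conv1 = nn.Conv3d(8, 4, kernel_size=(3, 1, 1))
        self.transconv1 = nn.ConvTranspose3d(4, 3, kernel_size=(4, 16, 16), stride=(3, 16, 16))
        self.relu = nn.ReLU()


def init_head(model: nn.Module, strategy: str, skip: str = "resnet_model") -> None:
    """Re-initialize conv/linear layers in every top-level child except `skip`."""
    for name, child in model.named_children():
        if name == skip:
            continue  # keep the pretrained backbone weights as loaded
        for m in child.modules():
            if not isinstance(m, (nn.Conv3d, nn.ConvTranspose3d, nn.Linear)):
                continue
            if strategy == "zero":
                nn.init.zeros_(m.weight)
            elif strategy == "random":
                nn.init.normal_(m.weight, mean=0.0, std=1.0)
            elif strategy == "xavier":
                nn.init.xavier_normal_(m.weight)
            elif strategy == "kaiming":
                nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
            else:
                raise ValueError(f"unknown strategy: {strategy!r}")
            if m.bias is not None:
                nn.init.zeros_(m.bias)


torch.manual_seed(0)
net = TinyMed()
backbone_before = net.resnet_model[0].weight.detach().clone()
init_head(net, "zero")
print(torch.equal(net.resnet_model[0].weight, backbone_before))  # True: backbone untouched
print(torch.count_nonzero(net.conv1.weight).item())              # 0: head re-initialized
```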



In [5]:
x = model(ct_images)

Input shape [B, D, C, W, H]: (100, 10, 3, 256, 256)
ResNet output shape[B*D, C, W, H]: torch.Size([1000, 512, 8, 8])
Reshape Resnet output for 3DConv #1 [B, C, D, W, H]: torch.Size([100, 512, 10, 8, 8])
Output shape 3D Conv #1 layer: torch.Size([100, 64, 8, 8, 8])
Output shape 3D Conv #2 layer: torch.Size([100, 32, 3, 4, 4])
Output shape 3D Transposed Conv #1 layer: torch.Size([100, 3, 10, 64, 64])
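
The printed shapes can be checked by hand. ResNet-18 without its pooling and fc layers downsamples by 32 (256 to 8); after that, unpadded, undilated layers follow `out = (in - k) // s + 1` for `Conv3d`/`MaxPool3d` and `out = (in - 1) * s + k` for `ConvTranspose3d`. A sketch tracing the depth axis and the (identical) width/height axes; the helper names `conv_out` and `tconv_out` are our own:

```python
import torch
import torch.nn as nn


def conv_out(n: int, k: int, s: int = 1) -> int:
    # Output length of Conv3d / MaxPool3d along one axis (padding=0, dilation=1)
    return (n - k) // s + 1


def tconv_out(n: int, k: int, s: int = 1) -> int:
    # Output length of ConvTranspose3d along one axis (padding=0, output_padding=0, dilation=1)
    return (n - 1) * s + k


# Start from the reshaped ResNet output [B, 512, D=10, W=8, H=8]; W and H behave identically
d, hw = 10, 8
d, hw = conv_out(d, 3), conv_out(hw, 1)            # conv1, kernel (3, 1, 1)
d, hw = conv_out(d, 3), conv_out(hw, 1)            # conv2, kernel (3, 1, 1)
d, hw = conv_out(d, 2, 2), conv_out(hw, 2, 2)      # MaxPool3d(2, 2)
d, hw = tconv_out(d, 4, 3), tconv_out(hw, 16, 16)  # transconv1, kernel (4, 16, 16), stride (3, 16, 16)
print(d, hw)  # 10 64

# Cross-check the transposed conv on a single sample of the pooled shape [1, 32, 3, 4, 4]
tconv = nn.ConvTranspose3d(32, 3, kernel_size=(4, 16, 16), stride=(3, 16, 16))
with torch.no_grad():
    out = tconv(torch.zeros(1, 32, 3, 4, 4))
print(tuple(out.shape))  # (1, 3, 10, 64, 64)
```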
