# Convert 2D Pretrained Model to 3D

In [9]:
import torch
import torchvision.models as models
import torch.nn as nn
import copy

In [None]:
model_2d = models.densenet121(pretrained=True)
model_2d.eval()

DenseNet(
  (features): Sequential(
    (conv0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (norm0): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu0): ReLU(inplace=True)
    (pool0): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (denseblock1): _DenseBlock(
      (denselayer1): _DenseLayer(
        (norm1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu1): ReLU(inplace=True)
        (conv1): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu2): ReLU(inplace=True)
        (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (denselayer2): _DenseLayer(
        (norm1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu

In [3]:
model_2d = models.densenet121(pretrained=True)
model_2d.eval()
for name, module in model_2d.features.named_children():
    print("name: ",name)
    print("module: ",module)
    print('------------')



name:  conv0
module:  Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
------------
name:  norm0
module:  BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
------------
name:  relu0
module:  ReLU(inplace=True)
------------
name:  pool0
module:  MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
------------
name:  denseblock1
module:  _DenseBlock(
  (denselayer1): _DenseLayer(
    (norm1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu1): ReLU(inplace=True)
    (conv1): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu2): ReLU(inplace=True)
    (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  )
  (denselayer2): _DenseLayer(
    (norm1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stat

In [16]:
def inflate_densenet121_to_3d():

    # Loading pre-trained model
    model_2d = models.densenet121(pretrained=True)
    model_2d.eval()

    #Defining the 3D model class
    class DenseNet1213D(nn.Module):
        def __init__(self, model_2d):
            super(DenseNet1213D, self).__init__()


            def inflate_module(module, is_conv0=False):

                # Is current layer nn.Conv2d? 
                if isinstance(module, nn.Conv2d): 

                    # Extracting parameters from 2d
                    out_channels = module.out_channels
                    in_channels = 1 if is_conv0 else module.in_channels
                    kernel_size = module.kernel_size[0] # For square kernel
                    stride = module.stride[0]
                    padding = module.padding[0]
                    
                    weight_2d = module.weight.data # Shape: (out_channel, in_channel, h, w)
                    if is_conv0:
                        # Use first channel of RGB weights for grayscale
                        weight_2d = weight_2d[:, :1, :, :]

                    depth = kernel_size # Depth dimension

                    weight_3d = weight_2d.unsqueeze(2).repeat(1, 1, depth, 1, 1) / depth # adding new dim at 2nd index -> (out_channels, in_channels, depth, h, w) -> Normalization

                    # Defining the conv3d block
                    conv3d = nn.Conv3d(
                        in_channels=in_channels,
                        out_channels=out_channels,
                        kernel_size=(kernel_size, kernel_size, kernel_size),
                        stride=(stride, stride, stride),
                        padding=(padding,padding,padding)
                    )

                    conv3d.weight.data = weight_3d
                    
                    # For bias
                    if module.bias is not None:
                        conv3d.bias.data = module.bias.data

                    return conv3d
                

                # Handling non-convolutional layers

                # MaxPool layer
                elif isinstance(module, nn.MaxPool2d):
                    return nn.MaxPool3d(
                        kernel_size=module.kernel_size,
                        stride=module.stride,
                        padding=module.padding
                    )

                # AvgPool layer
                elif isinstance(module, nn.AvgPool2d):
                    return nn.AvgPool3d(
                        kernel_size=module.kernel_size,
                        stride=module.stride,
                        padding=module.padding
                    )    
                    
                # BatchNorm Layer    
                elif isinstance(module, nn.BatchNorm2d):
                    # Convert BatchNorm2d to BatchNorm3d
                    batchnorm3d = nn.BatchNorm3d(module.num_features)
                    if hasattr(module, 'weight') and module.weight is not None:
                        batchnorm3d.weight.data = module.weight.data
                    if hasattr(module, 'bias') and module.bias is not None:
                        batchnorm3d.bias.data = module.bias.data
                    if hasattr(module, 'running_mean'):
                        batchnorm3d.running_mean = module.running_mean
                    if hasattr(module, 'running_var'):
                        batchnorm3d.running_var = module.running_var
                    return batchnorm3d

                elif isinstance(module, (nn.Sequential, nn.Module)):
                    # Recursively inflate submodules
                    new_module = copy.deepcopy(module)
                    for name, child in new_module.named_children():
                        # Pass is_conv0=True for conv0
                        setattr(new_module, name, inflate_module(child, is_conv0=(name == 'conv0')))
                    return new_module
                else:
                    # Return unchanged layers (e.g., ReLU)
                    return module
            
            # Create features by inflating the 2D model's features
            self.features = inflate_module(model_2d.features)

        def forward(self, x):
            return self.features(x)
    
    return DenseNet1213D(model_2d)

In [18]:
model_3d = inflate_densenet121_to_3d()

In [19]:
model_3d

DenseNet1213D(
  (features): Sequential(
    (conv0): Conv3d(1, 64, kernel_size=(7, 7, 7), stride=(2, 2, 2), padding=(3, 3, 3))
    (norm0): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu0): ReLU(inplace=True)
    (pool0): MaxPool3d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (denseblock1): _DenseBlock(
      (denselayer1): _DenseLayer(
        (norm1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu1): ReLU(inplace=True)
        (conv1): Conv3d(64, 128, kernel_size=(1, 1, 1), stride=(1, 1, 1))
        (norm2): BatchNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu2): ReLU(inplace=True)
        (conv2): Conv3d(128, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
      )
      (denselayer2): _DenseLayer(
        (norm1): BatchNorm3d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu1): ReL