In [6]:
import random
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import nibabel as nib
# from tqdm import tqdm
import medpy.metric as metric

import torch.nn as nn
import torch
import wandb

import monai
from monai.data import DataLoader, Dataset
from monai.transforms.utils import allow_missing_keys_mode
from monai.transforms import BatchInverseTransform
from monai.networks.nets import DynUNet

In [8]:
# Smoke test: push one random batch through a freshly built 3D DynUNet and
# print the shape of what comes out.
shape_check_net = DynUNet(
    spatial_dims = 3,
    in_channels = 1, 
    out_channels = 2,
    kernel_size = [3, 3, 3, 3, 3, 3],
    strides = [1, 2, 2, 2, 2, 2], 
    upsample_kernel_size = [2, 2, 2, 2, 2]
)

# Only the output shape is inspected, so there is no reason to record an
# autograd graph for this forward pass.
with torch.no_grad():
    print(shape_check_net(torch.rand((4, 1, 128, 128, 32))).shape)

torch.Size([4, 2, 128, 128, 32])


In [None]:
# Sanity check of medpy's HD95 on two random tensors of shape
# [1, 2, 256, 256, 32], each reduced to a label map with argmax over dim 1.
# medpy works on NumPy arrays, so the torch label maps are converted with
# .numpy() before being handed over.
hd95_voxel_spacing = [1.464845, 1.464845, 10.0]
hd95_result_labels = torch.rand((1, 2, 256, 256, 32)).argmax(dim=1).squeeze().numpy()
hd95_reference_labels = torch.rand((1, 2, 256, 256, 32)).argmax(dim=1).squeeze().numpy()
metric.hd95(hd95_result_labels, hd95_reference_labels, voxelspacing=hd95_voxel_spacing)

In [108]:
def _make_dynunet(spatial_dims):
    """Build a DynUNet with the notebook's fixed configuration for the given dimensionality."""
    # Fresh list objects are created on every call so the two networks never
    # share their configuration lists.
    return DynUNet(
        spatial_dims=spatial_dims,
        in_channels=1,
        out_channels=2,
        kernel_size=[3, 3, 3, 3, 3, 3],
        strides=[1, 2, 2, 2, 2, 2],
        upsample_kernel_size=[2, 2, 2, 2, 2],
    )


# Same architecture settings, once in 2D and once in 3D (2D is built first).
model2d = _make_dynunet(2)
model3d = _make_dynunet(3)

In [5]:
# The 2D and 3D networks should report the same number of submodules.
n_modules_2d = sum(1 for _ in model2d.modules())
n_modules_3d = sum(1 for _ in model3d.modules())
print(n_modules_2d, n_modules_3d)

114 114


In [83]:
# Walk both networks in lockstep and, for every Conv2d/Conv3d pair, compare the
# first two weight dimensions. Pairs that agree are tagged "easy", the others "????".
print("2d conv layer shapes\t3d conv layers shapes")
for layer_2d, layer_3d in zip(model2d.modules(), model3d.modules()):
    is_conv_pair = isinstance(layer_2d, nn.Conv2d) and isinstance(layer_3d, nn.Conv3d)
    if not is_conv_pair:
        continue

    shape_2d = tuple(layer_2d.weight.shape)
    shape_3d = tuple(layer_3d.weight.shape)

    if shape_2d[:2] == shape_3d[:2]:
        tag = "easy"
    else:
        tag = "????"
    print(tag, shape_2d, "\t", shape_3d)

2d conv layer shapes	3d conv layers shapes
easy (32, 1, 3, 3) 	 (32, 1, 3, 3, 3)
easy (32, 32, 3, 3) 	 (32, 32, 3, 3, 3)
easy (64, 32, 3, 3) 	 (64, 32, 3, 3, 3)
easy (64, 64, 3, 3) 	 (64, 64, 3, 3, 3)
easy (128, 64, 3, 3) 	 (128, 64, 3, 3, 3)
easy (128, 128, 3, 3) 	 (128, 128, 3, 3, 3)
easy (256, 128, 3, 3) 	 (256, 128, 3, 3, 3)
easy (256, 256, 3, 3) 	 (256, 256, 3, 3, 3)
???? (512, 256, 3, 3) 	 (320, 256, 3, 3, 3)
???? (512, 512, 3, 3) 	 (320, 320, 3, 3, 3)
???? (512, 512, 3, 3) 	 (320, 320, 3, 3, 3)
???? (512, 512, 3, 3) 	 (320, 320, 3, 3, 3)
???? (512, 1024, 3, 3) 	 (320, 640, 3, 3, 3)
???? (512, 512, 3, 3) 	 (320, 320, 3, 3, 3)
easy (256, 512, 3, 3) 	 (256, 512, 3, 3, 3)
easy (256, 256, 3, 3) 	 (256, 256, 3, 3, 3)
easy (128, 256, 3, 3) 	 (128, 256, 3, 3, 3)
easy (128, 128, 3, 3) 	 (128, 128, 3, 3, 3)
easy (64, 128, 3, 3) 	 (64, 128, 3, 3, 3)
easy (64, 64, 3, 3) 	 (64, 64, 3, 3, 3)
easy (32, 64, 3, 3) 	 (32, 64, 3, 3, 3)
easy (32, 32, 3, 3) 	 (32, 32, 3, 3, 3)
easy (2, 32, 1, 1) 	 (

In [32]:
# Scratch: look at the first slice of a random tensor with the same shape as
# the first 2D conv weight listed above, (32, 1, 3, 3).
torch.randn(32, 1, 3, 3)[0]

tensor([[[ 0.6326, -0.8669, -0.9719],
         [ 2.2894,  0.6494, -0.7566],
         [-0.7648,  0.1312,  0.4432]]])

In [None]:
# def init_weights_from_distribution_of_mode():
#     pass

In [84]:
# Same lockstep walk as for the conv layers, but for every module pair that
# carries a weight tensor, with the class name printed next to each shape.
print("2d all layer shapes\t\t\t 3d all layers shapes")
for mod2d, mod3d in zip(model2d.modules(), model3d.modules()):
    weight2d = getattr(mod2d, "weight", None)
    weight3d = getattr(mod3d, "weight", None)

    # Modules without a weight (containers, activations, ...) are skipped
    # explicitly, so that any real error still surfaces instead of being
    # swallowed by a bare `except`.
    if weight2d is None or weight3d is None:
        continue

    weights2d = tuple(weight2d.shape)
    weights3d = tuple(weight3d.shape)
    state = "easy" if weights2d[:2] == weights3d[:2] else "????"
    print(state, mod2d.__class__.__name__, weights2d, "\t\t", mod3d.__class__.__name__, weights3d)

2d all layer shapes			 3d all layers shapes
easy Conv2d (32, 1, 3, 3) 		 Conv3d (32, 1, 3, 3, 3)
easy Conv2d (32, 32, 3, 3) 		 Conv3d (32, 32, 3, 3, 3)
easy InstanceNorm2d (32,) 		 InstanceNorm3d (32,)
easy InstanceNorm2d (32,) 		 InstanceNorm3d (32,)
easy Conv2d (64, 32, 3, 3) 		 Conv3d (64, 32, 3, 3, 3)
easy Conv2d (64, 64, 3, 3) 		 Conv3d (64, 64, 3, 3, 3)
easy InstanceNorm2d (64,) 		 InstanceNorm3d (64,)
easy InstanceNorm2d (64,) 		 InstanceNorm3d (64,)
easy Conv2d (128, 64, 3, 3) 		 Conv3d (128, 64, 3, 3, 3)
easy Conv2d (128, 128, 3, 3) 		 Conv3d (128, 128, 3, 3, 3)
easy InstanceNorm2d (128,) 		 InstanceNorm3d (128,)
easy InstanceNorm2d (128,) 		 InstanceNorm3d (128,)
easy Conv2d (256, 128, 3, 3) 		 Conv3d (256, 128, 3, 3, 3)
easy Conv2d (256, 256, 3, 3) 		 Conv3d (256, 256, 3, 3, 3)
easy InstanceNorm2d (256,) 		 InstanceNorm3d (256,)
easy InstanceNorm2d (256,) 		 InstanceNorm3d (256,)
???? Conv2d (512, 256, 3, 3) 		 Conv3d (320, 256, 3, 3, 3)
???? Conv2d (512, 512, 3, 3) 		 Conv3

In [100]:
# Dump the full module tree of the 2D network to inspect its layers.
print(model2d)

DynUNet(
  (input_block): UnetBasicBlock(
    (conv1): Convolution(
      (conv): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    )
    (conv2): Convolution(
      (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    )
    (lrelu): LeakyReLU(negative_slope=0.01, inplace=True)
    (norm1): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
    (norm2): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
  )
  (downsamples): ModuleList(
    (0): UnetBasicBlock(
      (conv1): Convolution(
        (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      )
      (conv2): Convolution(
        (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (lrelu): LeakyReLU(negative_slope=0.01, inplace=True)
      (norm1): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_run

In [117]:
def init_normal_per_module(from_module: nn.Module, to_module: nn.Module) -> None:
    """
    Initialize the weights of a model (to_module) from per-module weight statistics of another model (from_module).

    Both models are walked in lockstep. For every convolution, transposed convolution or instance-norm
    module in from_module (1D, 2D or 3D), the whole weight tensor, e.g. shape [128, 64, 3, 3], is reduced
    to a single mean and a single standard deviation, and the weight of the module at the same position in
    to_module is re-drawn in place from that normal distribution. The weight shapes may differ between the
    two models, since only the two scalars are transferred. For LeakyReLU modules the negative slope is copied.

    Args:
        from_module (nn.Module): Source module providing the weight statistics.
        to_module (nn.Module): Target module whose weights are overwritten in place.

    Raises:
        AssertionError: If the two models do not contain the same number of modules.

    Note:
        The modules should be -very- similar, as this function assumes that the models have the same layers
        (either 2D or 3D) in the same order. Only `weight` is touched; biases are left as they are.
        Pairs in which either weight is None (e.g. InstanceNorm with affine=False) are skipped.
    """
    from_mods = list(from_module.modules())
    to_mods = list(to_module.modules())
    assert len(from_mods) == len(to_mods), "Error: Models should contain the 'same' layers"

    # Accept every dimensionality so the transfer works in both directions
    # (2D -> 3D as well as 3D -> 2D), not only from a 2D source.
    weighted_types = (
        nn.Conv1d, nn.Conv2d, nn.Conv3d,
        nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d,
        nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d,
    )

    # In-place re-initialization of parameters must not be tracked by autograd.
    with torch.no_grad():
        for from_mod, to_mod in zip(from_mods, to_mods):

            if isinstance(from_mod, weighted_types):
                source_weight = from_mod.weight
                target_weight = to_mod.weight
                if source_weight is None or target_weight is None:
                    # No affine weight on one side: nothing to read or nothing to write.
                    continue

                # Get the distribution across the entire module, i.e. a single number for mean and std
                mean = source_weight.mean().item()
                # The sample std of a single value is NaN, which normal_ rejects; treat it as zero spread.
                std = source_weight.std().item() if source_weight.numel() > 1 else 0.0

                target_weight.normal_(mean, std)

            elif isinstance(from_mod, nn.LeakyReLU):
                to_mod.negative_slope = from_mod.negative_slope

# Re-draw model3d's weights in place from the per-module weight statistics of model2d.
init_normal_per_module(model2d, model3d)

In [None]:
def init_normal_per_layer(from_module: nn.Module, to_module: nn.Module) -> None:
    """
    Placeholder for a "per layer" variant of weight initialization.

    Not implemented yet: the body is a no-op, so calling it leaves both modules unchanged.

    Args:
        from_module (nn.Module): Presumably the source module to take statistics from (currently unused).
        to_module (nn.Module): Presumably the target module to initialize (currently unused).
    """
    pass

# Has no effect for now, because the function is still a stub.
init_normal_per_layer(model2d, model3d)

In [None]:
# Debug scratch: overwrite a module's weight with all ones.
# NOTE(review): `to_mod` is not defined at top level in the visible cells (it only
# appears as a loop variable inside a function), so this cell presumably fails on a
# fresh kernel; confirm and remove or move it.
to_mod.weight.data = torch.ones(to_mod.weight.data.shape)#torch.normal(mean.expand(weight_shape), std.expand(weight_shape))

In [118]:
# "Layer-wise" init experiment: averaging a conv-shaped tensor over its first
# two dimensions leaves one value per kernel position.
a = torch.randn(32, 32, 3, 3)
a.mean(dim=(0, 1)).shape

torch.Size([3, 3])

In [115]:
# Show the weight tensor of the first conv / instance-norm / transposed-conv
# module found in the 3D network (nothing is printed if there is none).
first_weighted_3d = next(
    (
        candidate
        for candidate in model3d.modules()
        if isinstance(candidate, (nn.Conv3d, nn.InstanceNorm3d, nn.ConvTranspose3d))
    ),
    None,
)
if first_weighted_3d is not None:
    print(first_weighted_3d.weight.data)

tensor([[[[[1., 1., 1.],
           [1., 1., 1.],
           [1., 1., 1.]],

          [[1., 1., 1.],
           [1., 1., 1.],
           [1., 1., 1.]],

          [[1., 1., 1.],
           [1., 1., 1.],
           [1., 1., 1.]]]],



        [[[[1., 1., 1.],
           [1., 1., 1.],
           [1., 1., 1.]],

          [[1., 1., 1.],
           [1., 1., 1.],
           [1., 1., 1.]],

          [[1., 1., 1.],
           [1., 1., 1.],
           [1., 1., 1.]]]],



        [[[[1., 1., 1.],
           [1., 1., 1.],
           [1., 1., 1.]],

          [[1., 1., 1.],
           [1., 1., 1.],
           [1., 1., 1.]],

          [[1., 1., 1.],
           [1., 1., 1.],
           [1., 1., 1.]]]],



        [[[[1., 1., 1.],
           [1., 1., 1.],
           [1., 1., 1.]],

          [[1., 1., 1.],
           [1., 1., 1.],
           [1., 1., 1.]],

          [[1., 1., 1.],
           [1., 1., 1.],
           [1., 1., 1.]]]],



        [[[[1., 1., 1.],
           [1., 1., 1.],
          

In [65]:
# Scratch: draw a single sample from a normal distribution whose mean and std
# are given as one-element tensors.
a = torch.randn((10, 1))
torch.normal(mean=torch.zeros(1), std=torch.ones(1))

tensor([1.5979])

In [69]:
# Scratch: expand a one-element tensor to length 10.
torch.zeros(1).expand(10)

tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])