In [1]:
import random
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import nibabel as nib
# from tqdm import tqdm
import medpy.metric as metric

import torch.nn as nn
import torch.nn.functional as F
import torch
import wandb

import monai
from monai.data import DataLoader, Dataset
from monai.transforms.utils import allow_missing_keys_mode
from monai.transforms import BatchInverseTransform
from monai.networks.nets import DynUNet

In [2]:
print(
    DynUNet(
        spatial_dims = 3,
        in_channels = 1, 
        out_channels = 2,
        kernel_size = [3, 3, 3, 3, 3, 3],
        strides = [1, 2, 2, 2, 2, 2], 
        upsample_kernel_size = [2, 2, 2, 2, 2]
    )(
        torch.rand((4, 1, 128, 128, 32))
    ).shape
)

torch.Size([4, 2, 128, 128, 32])


In [None]:
# [1, 2, 256, 256, 32]) torch.Size([1, 2, 256, 256, 32]
metric.hd95(torch.rand((1, 2, 256, 256, 32)).argmax(dim=1).squeeze(), torch.rand((1, 2, 256, 256, 32)).argmax(dim=1).squeeze(), voxelspacing=[1.464845, 1.464845, 10.0])

In [2]:
model2d = DynUNet(
    spatial_dims = 2,
    in_channels = 1, 
    out_channels = 2,
    kernel_size = [3, 3, 3, 3, 3, 3],
    strides = [1, 2, 2, 2, 2, 2], 
    upsample_kernel_size = [2, 2, 2, 2, 2]
)

model3d = DynUNet(
    spatial_dims = 3,
    in_channels = 1, 
    out_channels = 2,
    kernel_size = [3, 3, 3, 3, 3, 3],
    strides = [1, 2, 2, 2, 2, 2], 
    upsample_kernel_size = [2, 2, 2, 2, 2]
)

In [5]:
# They have the same number of modules
print(len(list(model2d.modules())), len(list(model3d.modules())))

114 114


In [210]:
print("2d conv layer shapes\t3d conv layers shapes")
for mod2d, mod3d in zip(model2d.modules(), model3d.modules()):
    if (isinstance(mod2d, nn.Conv2d) and isinstance(mod3d, nn.Conv3d)):
        weights2d = mod2d.weight.data.numpy().shape
        weights3d = mod3d.weight.data.numpy().shape
        state = "easy" if weights2d[:2] == weights3d[:2] else "????"
        print(state, weights2d, "\t", weights3d)

2d conv layer shapes	3d conv layers shapes
easy (32, 1, 3, 3) 	 (32, 1, 3, 3, 3)
easy (32, 32, 3, 3) 	 (32, 32, 3, 3, 3)
easy (64, 32, 3, 3) 	 (64, 32, 3, 3, 3)
easy (64, 64, 3, 3) 	 (64, 64, 3, 3, 3)
easy (128, 64, 3, 3) 	 (128, 64, 3, 3, 3)
easy (128, 128, 3, 3) 	 (128, 128, 3, 3, 3)
easy (256, 128, 3, 3) 	 (256, 128, 3, 3, 3)
easy (256, 256, 3, 3) 	 (256, 256, 3, 3, 3)
easy (320, 256, 3, 3) 	 (320, 256, 3, 3, 3)
easy (320, 320, 3, 3) 	 (320, 320, 3, 3, 3)
easy (320, 320, 3, 3) 	 (320, 320, 3, 3, 3)
easy (320, 320, 3, 3) 	 (320, 320, 3, 3, 3)
easy (320, 640, 3, 3) 	 (320, 640, 3, 3, 3)
easy (320, 320, 3, 3) 	 (320, 320, 3, 3, 3)
easy (256, 512, 3, 3) 	 (256, 512, 3, 3, 3)
easy (256, 256, 3, 3) 	 (256, 256, 3, 3, 3)
easy (128, 256, 3, 3) 	 (128, 256, 3, 3, 3)
easy (128, 128, 3, 3) 	 (128, 128, 3, 3, 3)
easy (64, 128, 3, 3) 	 (64, 128, 3, 3, 3)
easy (64, 64, 3, 3) 	 (64, 64, 3, 3, 3)
easy (32, 64, 3, 3) 	 (32, 64, 3, 3, 3)
easy (32, 32, 3, 3) 	 (32, 32, 3, 3, 3)
easy (2, 32, 1, 1) 	 (2

In [209]:
print("2d conv.T layer shapes\t3d conv.T layer shapes")
for mod2d, mod3d in zip(model2d.modules(), model3d.modules()):
    if (isinstance(mod2d, nn.ConvTranspose2d) and isinstance(mod3d, nn.ConvTranspose3d)):
        weights2d = mod2d.weight.data.numpy().shape
        weights3d = mod3d.weight.data.numpy().shape
        state = "easy" if weights2d[:2] == weights3d[:2] else "????"
        print(state, weights2d, "\t", weights3d)

2d conv.T layer shapes	3d conv.T layer shapes
easy (320, 320, 2, 2) 	 (320, 320, 2, 2, 2)
easy (320, 256, 2, 2) 	 (320, 256, 2, 2, 2)
easy (256, 128, 2, 2) 	 (256, 128, 2, 2, 2)
easy (128, 64, 2, 2) 	 (128, 64, 2, 2, 2)
easy (64, 32, 2, 2) 	 (64, 32, 2, 2, 2)


In [9]:
def init_normal_per_channel(from_module: nn.Module, to_module: nn.Module) -> None:
    """
    Initialize the weights of a model (to_module) using the normal distributions per conv channel in another module (from_module).
    The function will take the weights in an entire module i.e. shape [128, 64, 3, 3] and compute a single number mean and std with 
    shape [128, 64, 1], and tile it to shape [128, 64, 3, 3, 3]
    
    Args:
        from_module (nn.Module): Source module providing weights for initialization.
        to_module (nn.Module): Target module whose weights will be initialized.
    
    Note:
        The modules should be -very- similar as this functions makes the assumption the models have the same layers (either 2D or 3D) in the same order.
    """
    assert len(list(from_module.modules())) == len(list(to_module.modules())), "Error: Models should contain the 'same' layers"

    for from_mod, to_mod in zip(from_module.modules(), to_module.modules()):

        if isinstance(from_mod, (nn.Conv2d, nn.ConvTranspose2d)):

            # Handle the 'complicated' case where shapes dont match, for instance (512, 256, 3, 3) and (320, 256, 3, 3, 3).
            # TODO: We could try something more complicated, e.g. interpolation, random crop, tiling etc. for now we just take a simple crop
            if from_mod.weight.shape != to_mod.weight.shape[:4]:
                from_mod.weight.data = from_mod.weight.data[:to_mod.weight.shape[0], :to_mod.weight.shape[1]]

            # This ungodly line takes the mean of the 2D conv (3,3) and then expands the single number into a 3D convolution, i.e. [32, 64] -> [32, 64, 3, 3, 3]
            # TODO: I cant find a nicer way than the triple unsqueeze, though i am sure there is a nicer / more general way of doing it
            mean = torch.mean(from_mod.weight.data, axis = (-1, -2)).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).expand(to_mod.weight.data.shape)
            std  = torch.std(from_mod.weight.data, axis = (-1, -2), unbiased=False).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).expand(to_mod.weight.data.shape) # Unbiased, see here: https://github.com/pytorch/pytorch/issues/29372
            
            # Handle Conv2D layers on the form [2, 32, 1, 1]. They have only 1 number so std is 'nan', replace with 0
            std = torch.nan_to_num(std, nan = 0.0)

            to_mod.weight.data = torch.normal(mean, std = std)
                        
        elif isinstance(from_mod, nn.LeakyReLU):
            to_mod.negative_slope = from_mod.negative_slope

        elif isinstance(from_mod, nn.InstanceNorm2d):
            to_mod.weight.data = from_mod.weight.data


saved_model = torch.load("/work3/s204163/3dimaging_finalproject/weights/baseline2d_101/baseline2d_final.pt", map_location=torch.device('cpu'))
init_normal_per_channel(from_module=saved_model, to_module=model3d)

# init_normal_per_channel(from_module=model2d, to_module=model3d)

In [205]:
torch.randn([32, 32, 3, 3])[:10, :10].shape

torch.Size([10, 10, 3, 3])

In [117]:
a = torch.mean(torch.randn([2, 32, 3, 3]), axis=(-1,-2))

In [118]:
b = a.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).expand(2, 32, 3, 3, 3)
b.shape

torch.Size([2, 32, 3, 3, 3])

In [119]:
a[0,3]

tensor(0.5448)

In [120]:
b[0,3]

tensor([[[0.5448, 0.5448, 0.5448],
         [0.5448, 0.5448, 0.5448],
         [0.5448, 0.5448, 0.5448]],

        [[0.5448, 0.5448, 0.5448],
         [0.5448, 0.5448, 0.5448],
         [0.5448, 0.5448, 0.5448]],

        [[0.5448, 0.5448, 0.5448],
         [0.5448, 0.5448, 0.5448],
         [0.5448, 0.5448, 0.5448]]])

In [30]:
torch.mean(torch.arange(16,dtype=float).reshape(4, 1, 2, 2), axis=(-1, -2))

tensor([[ 1.5000],
        [ 5.5000],
        [ 9.5000],
        [13.5000]], dtype=torch.float64)

In [90]:
torch.normal(torch.zeros(4), std = torch.ones(4)).shape

torch.Size([4])

In [183]:
# Assuming you have your image tensor
image_original = torch.randn(512, 256, 3)

# Resize the image tensor using bilinear interpolation
image_resized = F.interpolate(image_original.unsqueeze(0).permute(0, 3, 1, 2), size=(320, 256), mode='bilinear', align_corners=False)
print(image_resized.shape)
# Remove the batch dimension and adjust the shape
image_resized = image_resized.squeeze(0).permute(1, 2, 0)

print(image_resized.shape)  # Output: torch.Size([320, 256, 3])


torch.Size([1, 3, 320, 256])
torch.Size([320, 256, 3])


In [173]:
# x = torch.ones(512, 256, 3, 3)

# x = F.interpolate(x.unsqueeze(0), size=(3,4,4), mode="trilinear").squeeze(0)
# x.shape

x = torch.ones(3,4,64,64)
x = x.permute(1,0,2,3)
x = F.interpolate(x.unsqueeze(0), size=(3, 2, 10), mode="bilinear").squeeze(0)
x = x.permute(1,0,2,3)
x.shape

NotImplementedError: Got 5D input, but bilinear mode needs 4D input

In [134]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

# Assuming you have your weights tensor
weights_original = torch.randn(512, 256, 3, 3)

# Reshape the original weights tensor to add a singleton dimension
# This is required to match the spatial dimensions for trilinear interpolation
weights_reshaped = weights_original.unsqueeze(0).unsqueeze(0)

# Resize the weights tensor using trilinear interpolation
weights_resized = F.interpolate(weights_reshaped, size=(320, 256, 3, 3), mode='trilinear', align_corners=False)

# Remove the singleton dimensions
weights_resized = weights_resized.squeeze(0).squeeze(0)

# Plot the original weights
fig = plt.figure(figsize=(10, 5))
ax1 = fig.add_subplot(1, 2, 1, projection='3d')
ax1.set_title('Original Weights')

x, y, z = torch.meshgrid(torch.arange(weights_original.shape[0]),
                          torch.arange(weights_original.shape[1]),
                          torch.arange(weights_original.shape[2] * 3),
                          indexing='ij')

ax1.scatter(x, y, z, c=weights_original.view(-1).numpy(), cmap='viridis')
ax1.set_xlabel('Dimension 1')
ax1.set_ylabel('Dimension 2')
ax1.set_zlabel('Dimension 3')

# Plot the resized weights
ax2 = fig.add_subplot(1, 2, 2, projection='3d')
ax2.set_title('Resized Weights')

x, y, z = torch.meshgrid(torch.arange(weights_resized.shape[0]),
                          torch.arange(weights_resized.shape[1]),
                          torch.arange(weights_resized.shape[2] * 3),
                          indexing='ij')

ax2.scatter(x, y, z, c=weights_resized.view(-1).numpy(), cmap='viridis')
ax2.set_xlabel('Dimension 1')
ax2.set_ylabel('Dimension 2')
ax2.set_zlabel('Dimension 3')

plt.tight_layout()
plt.show()


NotImplementedError: Input Error: Only 3D, 4D and 5D input Tensors supported (got 6D) for the modes: nearest | linear | bilinear | bicubic | trilinear | area | nearest-exact (got trilinear)

In [None]:
# def init_weights_from_distribution_of_mode():
#     pass

In [84]:
print("2d all layer shapes\t\t\t 3d all layers shapes")
for mod2d, mod3d in zip(model2d.modules(), model3d.modules()):
    try:
        weights2d = mod2d.weight.data.numpy().shape
        weights3d = mod3d.weight.data.numpy().shape
        state = "easy" if weights2d[:2] == weights3d[:2] else "????"
        print(state, mod2d.__class__.__name__, weights2d, "\t\t", mod3d.__class__.__name__, weights3d)
    except:
        pass

2d all layer shapes			 3d all layers shapes
easy Conv2d (32, 1, 3, 3) 		 Conv3d (32, 1, 3, 3, 3)
easy Conv2d (32, 32, 3, 3) 		 Conv3d (32, 32, 3, 3, 3)
easy InstanceNorm2d (32,) 		 InstanceNorm3d (32,)
easy InstanceNorm2d (32,) 		 InstanceNorm3d (32,)
easy Conv2d (64, 32, 3, 3) 		 Conv3d (64, 32, 3, 3, 3)
easy Conv2d (64, 64, 3, 3) 		 Conv3d (64, 64, 3, 3, 3)
easy InstanceNorm2d (64,) 		 InstanceNorm3d (64,)
easy InstanceNorm2d (64,) 		 InstanceNorm3d (64,)
easy Conv2d (128, 64, 3, 3) 		 Conv3d (128, 64, 3, 3, 3)
easy Conv2d (128, 128, 3, 3) 		 Conv3d (128, 128, 3, 3, 3)
easy InstanceNorm2d (128,) 		 InstanceNorm3d (128,)
easy InstanceNorm2d (128,) 		 InstanceNorm3d (128,)
easy Conv2d (256, 128, 3, 3) 		 Conv3d (256, 128, 3, 3, 3)
easy Conv2d (256, 256, 3, 3) 		 Conv3d (256, 256, 3, 3, 3)
easy InstanceNorm2d (256,) 		 InstanceNorm3d (256,)
easy InstanceNorm2d (256,) 		 InstanceNorm3d (256,)
???? Conv2d (512, 256, 3, 3) 		 Conv3d (320, 256, 3, 3, 3)
???? Conv2d (512, 512, 3, 3) 		 Conv3

In [100]:
print(model2d)

DynUNet(
  (input_block): UnetBasicBlock(
    (conv1): Convolution(
      (conv): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    )
    (conv2): Convolution(
      (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    )
    (lrelu): LeakyReLU(negative_slope=0.01, inplace=True)
    (norm1): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
    (norm2): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
  )
  (downsamples): ModuleList(
    (0): UnetBasicBlock(
      (conv1): Convolution(
        (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      )
      (conv2): Convolution(
        (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (lrelu): LeakyReLU(negative_slope=0.01, inplace=True)
      (norm1): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_run

In [117]:
def init_normal_per_module(from_module: nn.Module, to_module: nn.Module) -> None:
    """
    Initialize the weights of a model (to_module) using the normal distributions of the layers in another module (from_module).
    The function will take the weights in an entire module i.e. shape [128, 64, 3, 3] and compute a single number mean and std to 
    init the entire layer in another module.
    
    Args:
        from_module (nn.Module): Source module providing weights for initialization.
        to_module (nn.Module): Target module whose weights will be initialized.
    
    Note:
        The modules should be -very- similar as this functions makes the assumption the models have the same layers (either 2D or 3D) in the same order.
    """
    assert len(list(from_module.modules())) == len(list(to_module.modules())), "Error: Models should contain the 'same' layers"

    for from_mod, to_mod in zip(from_module.modules(), to_module.modules()):

        if isinstance(from_mod, (nn.Conv2d, nn.InstanceNorm2d, nn.ConvTranspose2d)):
            # Get distributions across the entire module, i.e. a single number for mean and std
            mean = torch.mean(from_mod.weight.data) 
            std  = torch.std(from_mod.weight.data)
            
            to_mod.weight.data.normal_(mean, std)

        elif isinstance(from_mod, nn.LeakyReLU):
            to_mod.negative_slope = from_mod.negative_slope

init_normal_per_module(model2d, model3d)

In [None]:
def init_normal_per_layer(from_module: nn.Module, to_module: nn.Module) -> None:
    pass

init_normal_per_layer(model2d, model3d)

In [None]:
to_mod.weight.data = torch.ones(to_mod.weight.data.shape)#torch.normal(mean.expand(weight_shape), std.expand(weight_shape))

In [118]:
# Init from "layer-wise"
a = torch.randn([32, 32, 3, 3])
torch.mean(a, axis=[0,1]).shape

torch.Size([3, 3])

In [115]:
for mod in model3d.modules():
    if isinstance(mod, (nn.Conv3d, nn.InstanceNorm3d, nn.ConvTranspose3d)):
        print(mod.weight.data)
        break

tensor([[[[[1., 1., 1.],
           [1., 1., 1.],
           [1., 1., 1.]],

          [[1., 1., 1.],
           [1., 1., 1.],
           [1., 1., 1.]],

          [[1., 1., 1.],
           [1., 1., 1.],
           [1., 1., 1.]]]],



        [[[[1., 1., 1.],
           [1., 1., 1.],
           [1., 1., 1.]],

          [[1., 1., 1.],
           [1., 1., 1.],
           [1., 1., 1.]],

          [[1., 1., 1.],
           [1., 1., 1.],
           [1., 1., 1.]]]],



        [[[[1., 1., 1.],
           [1., 1., 1.],
           [1., 1., 1.]],

          [[1., 1., 1.],
           [1., 1., 1.],
           [1., 1., 1.]],

          [[1., 1., 1.],
           [1., 1., 1.],
           [1., 1., 1.]]]],



        [[[[1., 1., 1.],
           [1., 1., 1.],
           [1., 1., 1.]],

          [[1., 1., 1.],
           [1., 1., 1.],
           [1., 1., 1.]],

          [[1., 1., 1.],
           [1., 1., 1.],
           [1., 1., 1.]]]],



        [[[[1., 1., 1.],
           [1., 1., 1.],
          

In [65]:
a = torch.randn(10,1)
torch.normal(torch.zeros(1), std = torch.ones(1))

tensor([1.5979])

In [69]:
torch.zeros(1).expand(10)

tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])