In [1]:
import torch
import torchvision as tv
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import proplot as pplt
import seaborn as sns
import os
from torchinfo import summary
import monai

In [2]:
# Instantiate torchvision's ConvNeXt-Large with weights=None (no pretrained weights are loaded).
model = tv.models.convnext_large(weights=None)

In [3]:
# Display torchinfo's per-layer output shapes and parameter counts for a (1, 3, 224, 224) input.
summary(model, (1, 3, 224, 224))

Layer (type:depth-idx)                        Output Shape              Param #
ConvNeXt                                      [1, 1000]                 --
├─Sequential: 1-1                             [1, 1536, 7, 7]           --
│    └─Conv2dNormActivation: 2-1              [1, 192, 56, 56]          --
│    │    └─Conv2d: 3-1                       [1, 192, 56, 56]          9,408
│    │    └─LayerNorm2d: 3-2                  [1, 192, 56, 56]          384
│    └─Sequential: 2-2                        [1, 192, 56, 56]          --
│    │    └─CNBlock: 3-3                      [1, 192, 56, 56]          306,048
│    │    └─CNBlock: 3-4                      [1, 192, 56, 56]          306,048
│    │    └─CNBlock: 3-5                      [1, 192, 56, 56]          306,048
│    └─Sequential: 2-3                        [1, 384, 28, 28]          --
│    │    └─LayerNorm2d: 3-6                  [1, 192, 56, 56]          384
│    │    └─Conv2d: 3-7                       [1, 384, 28, 28]          295

In [17]:
# Instantiate torchvision's ViT-L/16 with weights=None (no pretrained weights are loaded).
vit = tv.models.vit_l_16(weights=None)

In [20]:
# Run one forward pass of the ViT on CPU with a random (1, 3, 224, 224) input.
# Gradient tracking is disabled; the result is only kept in `o`.
with torch.no_grad():
    o = vit.cpu()(torch.randn(1,3, 224, 224))

In [21]:
# Display torchinfo's per-layer output shapes and parameter counts for a (1, 3, 224, 224) input.
summary(vit, (1, 3, 224, 224))

Layer (type:depth-idx)                             Output Shape              Param #
VisionTransformer                                  [1, 1000]                 1,024
├─Conv2d: 1-1                                      [1, 1024, 14, 14]         787,456
├─Encoder: 1-2                                     [1, 197, 1024]            201,728
│    └─Dropout: 2-1                                [1, 197, 1024]            --
│    └─Sequential: 2-2                             [1, 197, 1024]            --
│    │    └─EncoderBlock: 3-1                      [1, 197, 1024]            12,596,224
│    │    └─EncoderBlock: 3-2                      [1, 197, 1024]            12,596,224
│    │    └─EncoderBlock: 3-3                      [1, 197, 1024]            12,596,224
│    │    └─EncoderBlock: 3-4                      [1, 197, 1024]            12,596,224
│    │    └─EncoderBlock: 3-5                      [1, 197, 1024]            12,596,224
│    │    └─EncoderBlock: 3-6                      [1, 197, 10

### Notes

Okay, so the plan is to build an encoder that has two branches, one ViT and one ConvNeXt, both in 3D. The first layer of the ConvNeXt branch does not do much downsampling,
whereas the ViT branch could downsample a lot, since it has the attention mechanism for context.

Then the outputs of every three layers are compressed through a bottleneck and merged, so the transformer outputs are included in the U-Net and maybe vice versa. The decoder is then a standard U-Net decoder, but with my stochastic layers added in. The model also has a classification/regression head, taken from the middle of the U-Net, that is used for non-generative tasks; the decoder is used for generative tasks.

In [4]:
class ConvNextBlock3D(nn.Module):
    """Residual ConvNeXt-style block operating on 3D volumes.

    The residual branch is, in order: a depthwise 7x7x7 ``Conv3d``
    (``groups=in_planes``, padding 3, so the spatial size is preserved), a
    ``LayerNorm`` over the channel axis (the tensor is permuted to
    channels-last for the norm and back afterwards), a 1x1x1 ``Conv3d``
    expanding to ``4 * in_planes`` channels, ``GELU``, and a 1x1x1 ``Conv3d``
    projecting back to ``in_planes`` channels. The branch output is multiplied
    by a learnable per-channel scale, passed through stochastic depth and added
    to the block input.

    Args:
        in_planes: Number of input channels; the output has the same number.
        layer_scale: Initial value for every entry of the learnable
            per-channel scale applied to the residual branch.
        stochastic_depth_prob: Probability ``p`` passed to
            ``tv.ops.stochastic_depth`` for the residual branch.
    """

    def __init__(self, in_planes, layer_scale, stochastic_depth_prob):
        super().__init__()

        self.block = nn.Sequential(
            nn.Conv3d(in_planes, in_planes, kernel_size=7, padding=3, groups=in_planes, bias=True),
            # (N, C, D, H, W) -> (N, D, H, W, C) so LayerNorm normalises over channels.
            tv.ops.Permute([0, 2, 3, 4, 1]),
            nn.LayerNorm(in_planes),
            # Back to channels-first for the pointwise convolutions.
            tv.ops.Permute([0, 4, 1, 2, 3]),
            nn.Conv3d(in_planes, in_planes * 4, kernel_size=1, bias=True),
            nn.GELU(),
            nn.Conv3d(in_planes * 4, in_planes, kernel_size=1, bias=True)
        )

        # Shape (C, 1, 1, 1) broadcasts against (N, C, D, H, W).
        self.layer_scale = nn.Parameter(torch.ones(in_planes, 1, 1, 1) * layer_scale)
        self.stochastic_depth_prob = stochastic_depth_prob
        self.sdp_mode = "row"

    def forward(self, x):
        """Apply the block to ``x`` and return a tensor of the same shape.

        Args:
            x: Tensor of shape ``(N, in_planes, D, H, W)``.

        Returns:
            ``x`` plus the scaled residual branch. Stochastic depth is only
            active while the module is in training mode.
        """
        o = self.layer_scale * self.block(x)
        # The functional stochastic_depth defaults to training=True, so the
        # module's mode must be forwarded explicitly; otherwise the branch is
        # still randomly dropped after .eval().
        o = tv.ops.stochastic_depth(
            o, self.stochastic_depth_prob, self.sdp_mode, training=self.training
        )
        o += x
        return o

In [12]:
# Stack of ten identically configured 3D ConvNeXt blocks:
# 12 channels, layer scale 1, stochastic-depth probability 0.5.
cnb = nn.Sequential(*(ConvNextBlock3D(12, 1, 0.5) for _ in range(10)))

In [13]:
# Run one forward pass of the block stack on CPU with a random
# (1, 12, 60, 192, 240) volume. Gradient tracking is disabled.
with torch.no_grad():
    o = cnb.cpu()(torch.randn(1, 12, 60, 192, 240))

In [15]:
# Display torchinfo's per-layer output shapes and parameter counts for a (1, 12, 60, 192, 240) input.
summary(cnb, (1, 12, 60, 192, 240))

Layer (type:depth-idx)                   Output Shape              Param #
Sequential                               [1, 12, 60, 192, 240]     --
├─ConvNextBlock3D: 1-1                   [1, 12, 60, 192, 240]     12
│    └─Sequential: 2-1                   [1, 12, 60, 192, 240]     --
│    │    └─Conv3d: 3-1                  [1, 12, 60, 192, 240]     4,128
│    │    └─Permute: 3-2                 [1, 60, 192, 240, 12]     --
│    │    └─LayerNorm: 3-3               [1, 60, 192, 240, 12]     24
│    │    └─Permute: 3-4                 [1, 12, 60, 192, 240]     --
│    │    └─Conv3d: 3-5                  [1, 48, 60, 192, 240]     624
│    │    └─GELU: 3-6                    [1, 48, 60, 192, 240]     --
│    │    └─Conv3d: 3-7                  [1, 12, 60, 192, 240]     588
├─ConvNextBlock3D: 1-2                   [1, 12, 60, 192, 240]     12
│    └─Sequential: 2-2                   [1, 12, 60, 192, 240]     --
│    │    └─Conv3d: 3-8                  [1, 12, 60, 192, 240]     4,128
│    │ 