In [1]:
import torch
from torch import nn
import numpy as np
from modules.nn import *

In [2]:
class ResBlock3D(nn.Module):
    """Residual block built from 3D convolutions, conditioned on an embedding.

    Layout (every convolution is created with ``conv_nd(3, ...)``):

    * ``in_layers``:  GroupNorm -> SiLU -> 3x3x3 conv (``channels`` -> ``out_channels``)
    * ``emb_layers``: SiLU -> linear (``emb_channels`` -> ``out_channels``, or
      ``2 * out_channels`` when ``use_scale_shift_norm`` is set)
    * ``out_layers``: GroupNorm -> SiLU -> Dropout -> 3x3x3 conv wrapped in
      ``zero_module`` (presumably zero-initialised -- confirm in ``modules.nn``)
    * ``skip_connection``: identity when the channel count is unchanged, otherwise
      a 3x3x3 conv (``use_conv``) or a 1x1x1 conv; ``None`` when ``residual`` is False.

    :param channels: number of input channels.
    :param emb_channels: width of the conditioning embedding.
    :param dropout: dropout probability used inside ``out_layers``.
    :param out_channels: number of output channels; ``None`` keeps ``channels``.
    :param use_conv: when the channel count changes, use a 3x3x3 conv instead of
        a 1x1x1 conv on the skip path.
    :param use_scale_shift_norm: apply the embedding as
        ``norm(h) * (1 + scale) + shift`` after the output GroupNorm instead of
        adding it to ``h`` before ``out_layers``.
    :param use_checkpoint: forwarded to ``checkpoint`` (presumably toggles
        gradient checkpointing -- confirm in ``modules.nn``).
    :param residual: add the skip connection to the block output.
    :param num_groups: number of groups for both GroupNorm layers; ``channels``
        and ``out_channels`` must be divisible by it.
    """

    def __init__(
        self,
        channels,
        emb_channels,
        dropout,
        out_channels=None,
        use_conv=False,
        use_scale_shift_norm=False,
        use_checkpoint=False,
        residual = True,
        num_groups=4,
    ):
        super().__init__()
        self.channels = channels
        self.emb_channels = emb_channels
        self.dropout = dropout
        # A None default means "keep the channel count"; conv_nd cannot take None.
        self.out_channels = channels if out_channels is None else out_channels
        self.use_conv = use_conv
        self.residual = residual
        self.use_checkpoint = use_checkpoint
        self.use_scale_shift_norm = use_scale_shift_norm
        dims = 3

        self.in_layers = nn.Sequential(
            nn.GroupNorm(num_groups, channels),
            SiLU(),
            conv_nd(dims, channels, self.out_channels, 3, padding=1),
        )
        self.emb_layers = nn.Sequential(
            SiLU(),
            linear(
                emb_channels,
                # scale-shift mode needs one scale and one shift per channel
                2 * self.out_channels if use_scale_shift_norm else self.out_channels,
            ),
        )
        self.out_layers = nn.Sequential(
            nn.GroupNorm(num_groups, self.out_channels),
            SiLU(),
            nn.Dropout(p=dropout),
            zero_module(
                conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
            ),
        )

        if residual:
            if self.out_channels == channels:
                self.skip_connection = nn.Identity()
            elif use_conv:
                self.skip_connection = conv_nd(
                    dims, channels, self.out_channels, 3, padding=1
                )
            else:
                self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
        else:
            self.skip_connection = None

    def forward(self, x, emb):
        """
        Apply the block to a Tensor, conditioned on an embedding.

        :param x: an [N x channels x D x H x W] Tensor of features.
        :param emb: an [N x emb_channels] Tensor of embeddings.
        :return: an [N x out_channels x D x H x W] Tensor of outputs.
        """
        return checkpoint(
            self._forward, (x, emb), self.parameters(), self.use_checkpoint
        )

    def _forward(self, x, emb):
        h = self.in_layers(x)
        emb_out = self.emb_layers(emb).type(h.dtype)
        # Append singleton spatial dims so emb_out broadcasts over D, H, W.
        while len(emb_out.shape) < len(h.shape):
            emb_out = emb_out[..., None]
        if self.use_scale_shift_norm:
            out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
            scale, shift = torch.chunk(emb_out, 2, dim=1)
            h = out_norm(h) * (1 + scale) + shift
            h = out_rest(h)
        else:
            h = h + emb_out
            h = self.out_layers(h)
        if self.residual:
            h = self.skip_connection(x) + h
        return h

class Initial3DConvBlock(nn.Module):
    """Treat dim 1 of a 4D input as a depth axis, process it with 3D convolutions,
    then fold the result back into a 2D feature map.

    Shape flow (as exercised by the smoke test in this notebook):
    ``(B, channels_freq, H, W)`` -> unsqueeze -> ``(B, 1, channels_freq, H, W)``
    -> 3x3x3 conv -> ``(B, channels_3d, channels_freq, H, W)`` -> ResBlock3D
    -> merge dims 1 and 2 -> ``(B, channels_3d * channels_freq, H, W)``
    -> 3x3 conv -> ``(B, out_channels, H, W)``.

    :param emb_channels: width of the embedding passed to the inner ResBlock3D.
    :param dropout: stored only (see note below).
    :param channels_3d: number of feature channels of the 3D stage.
    :param channels_freq: required size of dim 1 of the input.
    :param out_channels: number of channels of the 2D output.
    :param use_conv: stored only (see note below).
    :param use_scale_shift_norm: stored only (see note below).
    :param use_checkpoint: stored only (see note below).
    :param residual: forwarded to the inner ResBlock3D.

    NOTE(review): ``dropout``, ``use_conv``, ``use_scale_shift_norm`` and
    ``use_checkpoint`` are kept as attributes but NOT forwarded; the inner block
    is hard-coded to dropout=0, use_conv=False, use_scale_shift_norm=True and
    use_checkpoint=False. Forwarding ``use_scale_shift_norm`` would change the
    shape of the inner block's embedding projection, so this is left as-is.
    TODO decide whether these arguments should be honoured.
    """

    def __init__(
        self,
        emb_channels,
        dropout,
        channels_3d,
        channels_freq,
        out_channels,
        use_conv=False,
        use_scale_shift_norm=False,
        use_checkpoint=False,
        residual = True,
    ):
        super().__init__()
        self.emb_channels = emb_channels
        self.dropout = dropout
        self.channels_3d = channels_3d
        self.channels_freq = channels_freq
        self.out_channels = out_channels
        self.use_conv = use_conv
        self.residual = residual
        self.use_checkpoint = use_checkpoint
        self.use_scale_shift_norm = use_scale_shift_norm

        # Single-channel 3D input -> channels_3d feature channels.
        self.initialConv = conv_nd(3, 1, channels_3d, kernel_size=3, padding=1)
        self.resBlock1 = ResBlock3D(
            channels = channels_3d,
            emb_channels = emb_channels,
            dropout = 0,
            out_channels = channels_3d,
            use_conv = False,
            use_scale_shift_norm = True,
            use_checkpoint = False,
            residual=residual,
        )
        # 2D conv over the merged (channels_3d * channels_freq) channel axis.
        self.outConv = conv_nd(2, channels_3d*channels_freq, out_channels, kernel_size=3, padding=1)

    def forward(self, x, emb):
        """
        :param x: a ``(B, channels_freq, H, W)`` tensor.
        :param emb: an embedding tensor for the inner ResBlock3D,
            presumably ``(B, emb_channels)``.
        :return: a ``(B, out_channels, H, W)`` tensor.
        :raises ValueError: if ``x`` is not 4D or its dim 1 is not ``channels_freq``
            (otherwise ``outConv`` would fail with an opaque channel mismatch).
        """
        if x.dim() != 4 or x.shape[1] != self.channels_freq:
            raise ValueError(
                f"expected input of shape (B, {self.channels_freq}, H, W), "
                f"got {tuple(x.shape)}"
            )
        features = self.initialConv(x.unsqueeze(1))
        features = self.resBlock1(features, emb)
        # (B, channels_3d, channels_freq, H, W) -> (B, channels_3d * channels_freq, H, W)
        features = features.flatten(1, 2)
        return self.outConv(features)

In [3]:
# Smoke test: push a random batch through Initial3DConvBlock and check the output shape.
torch.manual_seed(0)

B = 8
C = 80  # size of the input's dim 1; must equal the block's channels_freq
H = 96
W = 128
EMB_CHANNELS = 64 * 4
OUT_CHANNELS = 64

testBlock = Initial3DConvBlock(
    emb_channels=EMB_CHANNELS,
    dropout=0,
    channels_3d=8,
    channels_freq=C,
    out_channels=OUT_CHANNELS,
    use_conv=False,
    use_scale_shift_norm=True,
    use_checkpoint=False,
    residual=False,
)

x = torch.randn(B, C, H, W)
emb = torch.randn(B, EMB_CHANNELS)

with torch.inference_mode():
    y = testBlock(x, emb)

assert y.shape == (B, OUT_CHANNELS, H, W), y.shape
print(y.shape)

torch.Size([8, 64, 96, 128])


In [4]:
from torchinfo import summary

# The block consumes a float embedding (cell 3 uses torch.randn), so declare it
# as float here. torchinfo casts its random inputs to the requested dtype, so
# torch.int would presumably feed an all-zero integer tensor instead.
summary(
    testBlock,
    input_size=[(B, C, H, W), (B, testBlock.emb_channels)],
    dtypes=[torch.float, torch.float],
)

Layer (type:depth-idx)                   Output Shape              Param #
Initial3DConvBlock                       [8, 64, 96, 128]          --
├─Conv3d: 1-1                            [8, 8, 80, 96, 128]       224
├─ResBlock3D: 1-2                        [8, 8, 80, 96, 128]       --
│    └─Sequential: 2-1                   [8, 8, 80, 96, 128]       --
│    │    └─GroupNorm: 3-1               [8, 8, 80, 96, 128]       16
│    │    └─SiLU: 3-2                    [8, 8, 80, 96, 128]       --
│    │    └─Conv3d: 3-3                  [8, 8, 80, 96, 128]       1,736
│    └─Sequential: 2-2                   [8, 16]                   --
│    │    └─SiLU: 3-4                    [8, 256]                  --
│    │    └─Linear: 3-5                  [8, 16]                   4,112
│    └─Sequential: 2-3                   --                        --
│    │    └─GroupNorm: 3-6               [8, 8, 80, 96, 128]       16
│    │    └─SiLU: 3-7                    [8, 8, 80, 96, 128]       --
│    │  

In [5]:
from modules.model import UNETv2

# UNETv2 without residual connections or attention, summarised on three inputs:
# a (B, C, H, W) float tensor, a (B, 1, H, W) float tensor and a (B,) integer
# tensor (presumably diffusion timesteps -- confirm against UNETv2.forward).
nn_model = UNETv2(
    in_channels=80,
    residual=False,
    attention_res=[],
)
summary(
    nn_model,
    input_size=[(B, C, H, W), (B, 1, H, W), (B,)],
    dtypes=[torch.float, torch.float, torch.int],
)

Layer (type:depth-idx)                   Output Shape              Param #
UNETv2                                   [8, 1, 96, 128]           --
├─Sequential: 1-1                        [8, 256]                  --
│    └─PositionalEncoding: 2-1           [8, 64]                   --
│    └─Linear: 2-2                       [8, 256]                  16,640
│    └─SiLU: 2-3                         [8, 256]                  --
│    └─Linear: 2-4                       [8, 256]                  65,792
├─Conv2d: 1-2                            [8, 64, 96, 128]          5,184
├─ResBlock: 1-3                          [8, 64, 96, 128]          --
│    └─Sequential: 2-5                   [8, 64, 96, 128]          --
│    │    └─GroupNorm32: 3-1             [8, 64, 96, 128]          128
│    │    └─SiLU: 3-2                    [8, 64, 96, 128]          --
│    │    └─Conv2d: 3-3                  [8, 64, 96, 128]          36,928
│    └─Sequential: 2-6                   [8, 128]                  --

In [6]:
from modules.model import UNETv3

# Same configuration and input signature as the UNETv2 summary above; UNETv3
# differs in using Initial3DConvBlock as its first stage (see the layer table).
nn_model = UNETv3(
    in_channels=80,
    residual=False,
    attention_res=[],
)
summary(
    nn_model,
    input_size=[(B, C, H, W), (B, 1, H, W), (B,)],
    dtypes=[torch.float, torch.float, torch.int],
)

Layer (type:depth-idx)                   Output Shape              Param #
UNETv3                                   [8, 1, 96, 128]           --
├─Sequential: 1-1                        [8, 256]                  --
│    └─PositionalEncoding: 2-1           [8, 64]                   --
│    └─Linear: 2-2                       [8, 256]                  16,640
│    └─SiLU: 2-3                         [8, 256]                  --
│    └─Linear: 2-4                       [8, 256]                  65,792
├─Initial3DConvBlock: 1-2                [8, 64, 96, 128]          --
│    └─Conv3d: 2-5                       [8, 8, 80, 96, 128]       224
│    └─ResBlock3D: 2-6                   [8, 8, 80, 96, 128]       --
│    │    └─Sequential: 3-1              [8, 8, 80, 96, 128]       1,752
│    │    └─Sequential: 3-2              [8, 16]                   4,112
│    │    └─Sequential: 3-3              --                        1,752
│    └─Conv2d: 2-7                       [8, 64, 96, 128]          