In [None]:
from typing import Literal, Sequence
import torch
import torch.nn as nn
from monai.networks.nets.resnet import resnet10, resnet50

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")  # prefer GPU when present

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def get_mri_vectorizer(model_type: Literal[10, 50]) -> nn.Module:
    """Build a single-channel MONAI 3-D ResNet for turning MRI volumes into features.

    Args:
        model_type: ResNet depth. Only 10 (``resnet10``) and 50 (``resnet50``)
            are supported.

    Returns:
        The network constructed with ``n_input_channels=1`` and
        ``feed_forward=False``. In MONAI, the latter omits the final
        fully-connected layer.

    Raises:
        ValueError: If ``model_type`` is neither 10 nor 50.
    """
    if model_type not in (10, 50):
        raise ValueError("Invalid model type.")
    builder = resnet10 if model_type == 10 else resnet50
    return builder(n_input_channels=1, feed_forward=False)


def normalize_to_01(tensors: Sequence[torch.Tensor]) -> list[torch.Tensor]:
    """Jointly min-max normalize a sequence of tensors to the range [0, 1].

    One global minimum and maximum are taken over *all* tensors. The relative
    scale between tensors is therefore preserved; they are not normalized
    independently.

    Args:
        tensors: Tensors to normalize. Each must be non-empty, because
            ``Tensor.max``/``Tensor.min`` raise on empty tensors. Their
            per-tensor extrema are combined with ``torch.stack``, so they must
            be stackable (e.g. on the same device).

    Returns:
        A new list containing ``(t - min) / (max - min)`` for each input.
        An empty input sequence yields an empty list. If every value is
        identical (``max == min``), each output is all zeros rather than NaN.
    """
    if len(tensors) == 0:
        return []
    max_val = torch.max(torch.stack([t.max() for t in tensors]))
    min_val = torch.min(torch.stack([t.min() for t in tensors]))
    value_range = max_val - min_val
    # Constant input would give 0/0 = NaN. Dividing by 1 instead maps every value
    # to 0. torch.where keeps this on-device (no host sync from a Python `if`).
    value_range = torch.where(value_range == 0, torch.ones_like(value_range), value_range)
    return [(t - min_val) / value_range for t in tensors]


def gaussian_pdf(x: int, window_size: int, sigma: float) -> float:
    """Return the Gaussian *exponent* for tap ``x`` of a ``window_size`` kernel.

    Despite the name, this is not a normalized density. It returns
    ``-(x - c)**2 / (2 * sigma**2)`` with ``c = window_size // 2``, which is
    meant to be exponentiated and normalized afterwards (as ``gaussian_3d``
    does). For an even ``window_size``, ``c`` sits right of the middle, so the
    resulting kernel is not symmetric.
    """
    center = window_size // 2
    offset_sq = (x - center) ** 2
    denominator = float(2 * sigma**2)
    return -offset_sq / denominator


def gaussian_3d(window_size: int, sigma: float) -> torch.Tensor:
    """Return a normalized 1-D Gaussian kernel of length ``window_size``.

    Note: despite the ``_3d`` suffix, the result is one-dimensional, with shape
    ``(window_size,)``; ``create_3d_window`` expands it to 3-D. The weights
    are divided by their sum, so they add up to 1 (up to rounding).
    """
    exponents = [gaussian_pdf(tap, window_size, sigma) for tap in range(window_size)]
    weights = torch.tensor(exponents).exp()
    return weights / weights.sum()


def create_3d_window(window_size: int, sigma: float = 1.5) -> torch.Tensor:
    """Build a separable 3-D Gaussian window of shape ``(1, 1, w, w, w)``.

    Entry ``[0, 0, i, j, k]`` equals ``g[i] * (g[j] * g[k])``, where ``g`` is
    the normalized 1-D kernel from ``gaussian_3d``. The result is cast to
    float32. The two leading singleton dims presumably make it usable directly
    as a single-channel conv3d weight; the caller is not shown here.
    """
    column = gaussian_3d(window_size, sigma).unsqueeze(1)  # (w, 1)
    plane = column @ column.T  # (w, w): g[j] * g[k]
    volume = (column @ plane.reshape(1, -1)).reshape(
        window_size, window_size, window_size
    )
    return volume.float()[None, None]

In [3]:
# Instantiate both MRI backbones on `device`, then wrap each one with torch.compile.
rn10, rn50 = (torch.compile(get_mri_vectorizer(depth).to(device)) for depth in (10, 50))

In [4]:
# Random probe volumes in (batch, channel, D, H, W) layout. They are allocated
# directly on `device`, which avoids a host allocation followed by a copy
# (the large one is 256**3 floats).
test_input_small = torch.randn(1, 1, 64, 64, 64, device=device)
test_input_large = torch.randn(1, 1, 256, 256, 256, device=device)

In [16]:
# Run rn10's stem and residual stages by hand, stopping after layer4 so the
# spatial feature map is kept. no_grad avoids storing activations for backward.
# NOTE(review): rn10 is never switched to eval() in the visible cells, so the
# BatchNorm layers run in training mode here.
with torch.no_grad():
    feature_map_small = test_input_small
    for stage in (
        rn10.conv1, rn10.bn1, rn10.act, rn10.maxpool,
        rn10.layer1, rn10.layer2, rn10.layer3, rn10.layer4,
    ):
        feature_map_small = stage(feature_map_small)
feature_map_small.size()

torch.Size([1, 512, 4, 4, 4])

In [6]:
# Full rn50 forward pass. no_grad skips building the autograd graph for this shape probe.
with torch.no_grad():
    rn50_out_small = rn50(test_input_small)
rn50_out_small.size()

torch.Size([1, 2048])

In [17]:
# Same manual stage-by-stage pass on the 256**3 volume. rn10's conv1 keeps full
# resolution (stride 1), so no_grad matters here: without it, every
# intermediate activation would be retained for backward.
with torch.no_grad():
    feature_map_large = test_input_large
    for stage in (
        rn10.conv1, rn10.bn1, rn10.act, rn10.maxpool,
        rn10.layer1, rn10.layer2, rn10.layer3, rn10.layer4,
    ):
        feature_map_large = stage(feature_map_large)
feature_map_large.size()

torch.Size([1, 512, 16, 16, 16])

In [9]:
# Full rn50 forward on the large volume. no_grad avoids keeping activations
# for backward, which would be very memory-hungry at 256**3.
with torch.no_grad():
    rn50_out_large = rn50(test_input_large)
rn50_out_large.size()

torch.Size([1, 2048])

In [10]:
# Inspect the rn10 architecture: layer names, channel counts, kernel sizes and strides.
rn10

ResNet(
  (conv1): Conv3d(1, 64, kernel_size=(7, 7, 7), stride=(1, 1, 1), padding=(3, 3, 3), bias=False)
  (bn1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (act): ReLU(inplace=True)
  (maxpool): MaxPool3d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): ResNetBlock(
      (conv1): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
      (bn1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act): ReLU(inplace=True)
      (conv2): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
      (bn2): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer2): Sequential(
    (0): ResNetBlock(
      (conv1): Conv3d(64, 128, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1), bias=False)
      (bn1): BatchNorm3d(128, eps=1e-05, momentum=0.1, affine=T