In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models.video as video_models

In [2]:
# Load the checkpoint file onto the CPU; its 'beta_encoder' entry is consumed
# further down when the encoder weights are restored.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source. Consider passing weights_only=True if the installed PyTorch
# supports it and the file holds only tensors/state dicts — confirm.
checkpoint = torch.load("harmonization_public.pt", map_location="cpu")

In [3]:
class UNet(nn.Module):
    """2D U-Net encoder/decoder with an optional conditioning tensor.

    The network has ``num_lvs`` resolution levels. Each encoder level runs a
    ``ConvBlock2d`` that doubles the channel count and then halves the spatial
    size with a 2x2 max-pool; the decoder mirrors this with ``Upsample`` blocks
    that concatenate the matching encoder feature map (skip connection).

    Args:
        in_ch: Number of channels of the input image.
        out_ch: Number of channels produced by the final convolution.
        conditional_ch: Number of channels of the optional ``condition`` tensor
            that is concatenated to the features before every encoder block.
        num_lvs: Number of down/up-sampling levels.
        base_ch: Channel count after the input convolution.
        final_act: ``'sigmoid'``, ``'relu'`` or ``'tanh'`` to apply that
            activation to the output; any other value (e.g. the default
            ``'noact'``) returns the raw convolution output.
    """

    def __init__(self, in_ch, out_ch, conditional_ch=0, num_lvs=4, base_ch=16, final_act='noact'):
        super().__init__()
        self.final_act = final_act
        self.in_conv = nn.Conv2d(in_ch, base_ch, 3, 1, 1)

        self.down_convs = nn.ModuleList()
        self.down_samples = nn.ModuleList()
        self.up_samples = nn.ModuleList()
        self.up_convs = nn.ModuleList()
        for lv in range(num_lvs):
            ch = base_ch * (2 ** lv)
            # The encoder block sees the features plus the tiled condition channels.
            self.down_convs.append(ConvBlock2d(ch + conditional_ch, ch * 2, ch * 2))
            self.down_samples.append(nn.MaxPool2d(kernel_size=2, stride=2))
            # The decoder receives ch * 4 channels from the level below, halves them,
            # and concatenates the ch * 2 skip channels, giving ch * 4 again.
            self.up_samples.append(Upsample(ch * 4))
            self.up_convs.append(ConvBlock2d(ch * 4, ch * 2, ch * 2))
        bottleneck_ch = base_ch * (2 ** num_lvs)
        self.bottleneck_conv = ConvBlock2d(bottleneck_ch, bottleneck_ch * 2, bottleneck_ch * 2)
        self.out_conv = nn.Sequential(nn.Conv2d(base_ch * 2, base_ch, 3, 1, 1),
                                      nn.LeakyReLU(0.1),
                                      nn.Conv2d(base_ch, out_ch, 3, 1, 1))

    def forward(self, in_tensor, condition=None):
        """Run the U-Net.

        Args:
            in_tensor: Input batch with ``in_ch`` channels.
            condition: Optional conditioning tensor. It is tiled along its last
                two dimensions to the current feature-map size, so it is
                presumably shaped ``(batch, conditional_ch, 1, 1)`` — confirm
                against the caller.

        Returns:
            Tensor with ``out_ch`` channels, after the configured final activation.
        """
        encoded_features = []
        x = self.in_conv(in_tensor)
        for down_conv, down_sample in zip(self.down_convs, self.down_samples):
            if condition is not None:
                # Tile with the real (height, width); using the width for both axes
                # breaks the concatenation on non-square feature maps.
                height, width = x.shape[-2], x.shape[-1]
                tiled_condition = condition.repeat(1, 1, height, width)
                x = torch.cat([x, tiled_condition], dim=1)
            down_conv_out = down_conv(x)
            x = down_sample(down_conv_out)
            encoded_features.append(down_conv_out)
        x = self.bottleneck_conv(x)
        # Walk back up from the deepest level, pairing each decoder stage with its skip.
        for encoded_feature, up_conv, up_sample in zip(reversed(encoded_features),
                                                       reversed(self.up_convs),
                                                       reversed(self.up_samples)):
            x = up_sample(x, encoded_feature)
            x = up_conv(x)
        x = self.out_conv(x)
        if self.final_act == 'sigmoid':
            return torch.sigmoid(x)
        if self.final_act == "relu":
            return torch.relu(x)
        if self.final_act == 'tanh':
            return torch.tanh(x)
        # Any other value means "no activation".
        return x


class ConvBlock2d(nn.Module):
    """Two consecutive 3x3 convolution stages.

    Each stage is ``Conv2d`` (kernel 3, stride 1, padding 1, so the spatial size
    is preserved) followed by ``InstanceNorm2d`` and ``LeakyReLU(0.1)``. The
    first stage maps ``in_ch -> mid_ch`` and the second ``mid_ch -> out_ch``.
    """

    def __init__(self, in_ch, mid_ch, out_ch):
        super().__init__()
        stages = []
        # Layer order inside the Sequential is kept so parameter keys stay conv.0 / conv.3.
        for src_ch, dst_ch in ((in_ch, mid_ch), (mid_ch, out_ch)):
            stages.append(nn.Conv2d(src_ch, dst_ch, 3, 1, 1))
            stages.append(nn.InstanceNorm2d(dst_ch))
            stages.append(nn.LeakyReLU(0.1))
        self.conv = nn.Sequential(*stages)

    def forward(self, in_tensor):
        """Apply both convolution stages to ``in_tensor`` and return the result."""
        features = self.conv(in_tensor)
        return features


class Upsample(nn.Module):
    """Decoder step: upsample, halve the channels, and attach the skip features.

    Args:
        in_ch: Channel count of the tensor coming from the deeper level. The
            internal convolution reduces it to ``in_ch // 2``.
    """

    def __init__(self, in_ch):
        super().__init__()
        out_ch = in_ch // 2
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, 1, 1),
            nn.InstanceNorm2d(out_ch),
            nn.LeakyReLU(0.1)
        )

    def forward(self, in_tensor, encoded_feature):
        """Upsample ``in_tensor`` and concatenate it after ``encoded_feature``.

        Args:
            in_tensor: Feature map from the deeper (lower-resolution) level.
            encoded_feature: Encoder feature map of the current level (skip
                connection); its last two dimensions define the target size.

        Returns:
            ``encoded_feature`` and the upsampled, convolved tensor joined along
            the channel dimension (dim 1).
        """
        # Resize to the skip tensor's spatial size rather than a fixed 2x factor:
        # identical when the size is exactly doubled, and it still lines up when
        # max-pooling floored an odd dimension on the way down.
        target_size = encoded_feature.shape[-2:]
        up_sampled_tensor = F.interpolate(in_tensor, size=target_size, mode='bilinear', align_corners=False)
        up_sampled_tensor = self.conv(up_sampled_tensor)
        return torch.cat([encoded_feature, up_sampled_tensor], dim=1)



In [4]:
def modify_r3d_for_5_channels(model_name='r3d_18', num_classes=2, in_channels=5):
    """Load a torchvision 3D ResNet and adapt its input and output layers.

    The first stem convolution is replaced by a freshly initialised ``Conv3d``
    that accepts ``in_channels`` input channels (5 by default, hence the
    function name) while keeping the original layer's output channels, kernel
    size, stride, padding and bias setting. The final fully connected layer is
    replaced by a new ``Linear`` layer with ``num_classes`` outputs.

    Args:
        model_name: ``'r3d_18'`` or ``'mc3_18'``.
        num_classes: Number of outputs of the new classification layer.
        in_channels: Number of input channels the new first convolution accepts.

    Returns:
        The modified model, built without pre-trained weights.

    Raises:
        ValueError: If ``in_channels`` is smaller than 1 or ``model_name`` is
            not one of the supported names.
    """
    # Validate before building anything so bad arguments fail fast.
    if in_channels < 1:
        raise ValueError(f"in_channels must be a positive integer, got {in_channels}")

    # Load the specified ResNet3D model, without pre-trained weights initially
    if model_name == 'r3d_18':
        model = video_models.r3d_18(weights=None)
    elif model_name == 'mc3_18':  # Another option
        model = video_models.mc3_18(weights=None)
    # Add more models here if needed (r3d_34, r3d_50...)
    else:
        raise ValueError(f"Unsupported model name: {model_name}")

    # Mirror the hyper-parameters of the original first convolution (stem[0])
    original_conv = model.stem[0]
    new_first_conv = nn.Conv3d(in_channels=in_channels,
                               out_channels=original_conv.out_channels,
                               kernel_size=original_conv.kernel_size,
                               stride=original_conv.stride,
                               padding=original_conv.padding,
                               bias=original_conv.bias is not None)  # Keep bias setting consistent

    # Replace the original first layer with the new one
    model.stem[0] = new_first_conv

    # Swap the final fully connected layer for one with the desired number of classes
    num_ftrs = model.fc.in_features
    model.fc = nn.Linear(num_ftrs, num_classes)  # Example: 2 classes for Autism vs Control

    return model


In [5]:
# Shape configuration for the smoke test: one single-channel 192x224 slice.
batch_size, input_ch = 1, 1
input_height, input_width = 192, 224
# Channel count of the encoder output; reused later as the channel dimension
# of the random tensor fed to the 3D ResNet.
beta_dim = 5

# Random stand-in for a real input slice, laid out as (batch, channel, height, width).
dummy_input = torch.randn((batch_size, input_ch, input_height, input_width))
print("Input Shape: ", dummy_input.shape)

Input Shape:  torch.Size([1, 1, 192, 224])


In [6]:
# Build the encoder from the shared config constants so its channel counts
# cannot drift from the values used to create the dummy tensors.
# 'none' is not one of the activations UNet.forward recognises, so the raw
# output of the last convolution is returned.
beta_encoder = UNet(in_ch=input_ch, out_ch=beta_dim, base_ch=8, final_act='none')
beta_encoder.load_state_dict(checkpoint['beta_encoder'])

beta_encoder.eval()
with torch.no_grad():
    beta = beta_encoder(dummy_input)
    print("Beta Shape: ", beta.shape)

# The encoder keeps the spatial size and emits beta_dim channels.
assert beta.shape == (batch_size, beta_dim, input_height, input_width)

Beta Shape:  torch.Size([1, 5, 192, 224])


In [7]:
# Size of the classifier head and the depth of the dummy 3D volume.
num_classes = 2
depth = 192

# Build the 5-channel R3D-18 and put it in evaluation mode; as the cell's last
# expression, eval() also makes the notebook display the module structure.
r3d_model = modify_r3d_for_5_channels(num_classes=num_classes, model_name='r3d_18')
r3d_model.eval()

VideoResNet(
  (stem): BasicStem(
    (0): Conv3d(5, 64, kernel_size=(3, 7, 7), stride=(1, 2, 2), padding=(1, 3, 3), bias=False)
    (1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
  )
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Sequential(
        (0): Conv3DSimple(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
        (1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
      )
      (conv2): Sequential(
        (0): Conv3DSimple(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
        (1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (relu): ReLU(inplace=True)
    )
    (1): BasicBlock(
      (conv1): Sequential(
        (0): Conv3DSimple(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
        (1):

In [8]:
# Random 5D volume for the 3D ResNet: (batch, channel, depth, height, width).
dummy_resnet_input = torch.randn(
    (batch_size, beta_dim, depth, input_height, input_width)
)
print("R3D Input Shape: ", dummy_resnet_input.shape)

R3D Input Shape:  torch.Size([1, 5, 192, 192, 224])


In [9]:
# Forward the dummy volume through the classifier without tracking gradients,
# then report the shape of the returned tensor.
with torch.no_grad():
    r3d_output = r3d_model(dummy_resnet_input)
print("R3D Output Shape: ", r3d_output.shape)

R3D Output Shape:  torch.Size([1, 2])
