In [2]:
import torch
import torch.nn as nn
import torchvision
from torchsummary import summary
from torchvision.models import resnet50
from torchvision.models.resnet import Bottleneck

In [3]:
class ResNet4Channel(nn.Module):
    """ResNet-50 feature extractor adapted to 4-channel (e.g. RGB-D) input.

    The stem convolution of a torchvision ResNet-50 is replaced by a
    convolution with 4 input channels: the first three reuse the backbone's
    RGB filters and the fourth is initialised as a copy of one of them
    (green, index 1, by default). ``bn1``, ``relu``, ``maxpool`` and the four
    residual stages are taken from the backbone unchanged. No pooling or
    classification head is attached; ``forward`` returns the feature maps of
    ``layer1`` .. ``layer4``.

    Args:
        out_features: Currently unused (no classification head is built);
            kept so existing ``ResNet4Channel(out_features=...)`` calls work.
        weights: Forwarded to ``torchvision.models.resnet50``. Accepts a
            ``ResNet50_Weights`` member, its name as a string, or ``None`` for
            random initialisation. The default ``"IMAGENET1K_V1"`` is the
            checkpoint that the previous ``weights=ResNet50_Weights`` (enum
            class) argument resolved to through torchvision's deprecated
            legacy ``pretrained`` path.
        extra_channel_source: Index of the pretrained RGB input channel whose
            filter initialises the 4th input channel.
    """

    def __init__(self, out_features=1000, weights="IMAGENET1K_V1", extra_channel_source=1):
        super().__init__()
        backbone = resnet50(weights=weights)
        stem = backbone.conv1

        # Input is (N, 4, H, W): 3 RGB channels plus one extra channel.
        # Geometry is copied from the backbone's own stem so the pretrained
        # filters line up with the new convolution.
        self.conv1 = nn.Conv2d(
            4,
            stem.out_channels,
            kernel_size=stem.kernel_size,
            stride=stem.stride,
            padding=stem.padding,
            bias=False,
        )
        self.conv1.weight = nn.Parameter(
            self._expand_first_conv_weight(stem.weight, extra_channel_source)
        )
        self.bn1 = backbone.bn1
        self.relu = backbone.relu
        self.maxpool = backbone.maxpool

        # Residual stages, reused as-is from the backbone.
        self.layer1 = backbone.layer1
        self.layer2 = backbone.layer2
        self.layer3 = backbone.layer3
        self.layer4 = backbone.layer4

    @staticmethod
    def _expand_first_conv_weight(weight, source_channel=1):
        """Append one input channel to a conv weight by copying an existing one.

        Args:
            weight: Conv2d weight of shape (out_channels, in_channels, kH, kW).
            source_channel: Input-channel index (negative allowed) whose filter
                is duplicated into the new last input channel.

        Returns:
            A tensor of shape (out_channels, in_channels + 1, kH, kW) that is
            detached from ``weight``'s autograd graph, so the new parameter is
            a fresh leaf rather than a view of the backbone's weight.
        """
        base = weight.detach()
        # Integer indexing + unsqueeze (rather than a slice) keeps the channel
        # dimension and handles negative indices correctly.
        extra = base[:, source_channel].unsqueeze(1)
        return torch.cat((base, extra), dim=1).clone()

    def forward(self, x):
        """Run the backbone and return the outputs of all four residual stages.

        Args:
            x: Input batch, expected to be (N, 4, H, W).

        Returns:
            Tuple ``(x1, x2, x3, x4)`` of the ``layer1`` .. ``layer4`` outputs.
            For ResNet-50 these have 256 / 512 / 1024 / 2048 channels at 1/4,
            1/8, 1/16 and 1/32 of the input resolution (32x32 .. 4x4 for a
            128x128 input).
        """
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x1 = self.layer1(x)
        x2 = self.layer2(x1)
        x3 = self.layer3(x2)
        x4 = self.layer4(x3)
        return x1, x2, x3, x4

    def print_summary(self, input_size):
        """Print a layer-by-layer summary via ``torchsummary.summary``.

        Args:
            input_size: Per-sample input shape without the batch dimension,
                e.g. ``(4, 128, 128)``.
        """
        summary(self, input_size)

In [5]:
# Smoke test: build the model and check the multi-scale feature-map shapes.
torch.manual_seed(0)
model = ResNet4Channel()
# eval() makes BatchNorm use (and not overwrite) its pretrained running
# statistics; in train mode a forward pass on random noise would update them.
model.eval()

# Per-sample input shape (channels, height, width): 4-channel RGB-D at 128x128.
input_size = (4, 128, 128)
# For a layer-by-layer table (requires torchsummary): model.print_summary(input_size)

with torch.no_grad():
    # Call the module (not .forward) so registered hooks run.
    x1, x2, x3, x4 = model(torch.randn(1, *input_size))

assert [tuple(f.shape) for f in (x1, x2, x3, x4)] == [
    (1, 256, 32, 32),
    (1, 512, 16, 16),
    (1, 1024, 8, 8),
    (1, 2048, 4, 4),
], "unexpected feature-map shapes"
print(x1.shape, x2.shape, x3.shape, x4.shape, sep='\n')



torch.Size([1, 256, 32, 32])
torch.Size([1, 512, 16, 16])
torch.Size([1, 1024, 8, 8])
torch.Size([1, 2048, 4, 4])
