In [34]:
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor

import onnx
import torch
# load a model pre-trained on COCO

In [35]:
# Number of output classes used for both replacement heads.
num_classes = 2

# Mask R-CNN (ResNet-50 + FPN) instance-segmentation model, pre-trained on COCO.
model = torchvision.models.detection.maskrcnn_resnet50_fpn(weights="DEFAULT")

# Read the input widths of the two existing heads; they size the replacements.
in_features = model.roi_heads.box_predictor.cls_score.in_features
in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
hidden_layer = 256

# Swap the pre-trained box head for a fresh one with `num_classes` outputs.
model.roi_heads.box_predictor = FastRCNNPredictor(
    in_channels=in_features,
    num_classes=num_classes,
)

# Swap the pre-trained mask head for a fresh one as well.
model.roi_heads.mask_predictor = MaskRCNNPredictor(
    in_channels=in_features_mask,
    dim_reduced=hidden_layer,
    num_classes=num_classes,
)


In [36]:
from torchvision.models.detection.transform import GeneralizedRCNNTransform

# Custom transform to handle 6 channels
class CustomTransform(GeneralizedRCNNTransform):
    """Detection transform that normalises with caller-supplied per-channel statistics.

    Used here so the model can be fed 6-channel images: ``image_mean`` and
    ``image_std`` carry one entry per input channel.
    """

    def __init__(self, min_size, max_size, image_mean, image_std):
        """Pass the resize bounds and per-channel mean/std through to the base transform."""
        super().__init__(min_size, max_size, image_mean, image_std)

    def normalize(self, image):
        """Return ``(image - mean) / std`` with mean/std applied per channel.

        ``mean`` and ``std`` are indexed as ``[:, None, None]``, so ``image`` is
        expected to carry its channel dimension third from last (e.g. ``(C, H, W)``).

        Raises:
            TypeError: if ``image`` is not a floating-point tensor. Without this
                guard the fractional mean/std would be truncated to zero when
                cast to an integer dtype and the result would be silently wrong.
        """
        if not image.is_floating_point():
            raise TypeError(
                f"Expected input images to be of floating type (in range [0, 1]), "
                f"but found type {image.dtype} instead"
            )
        # Build the statistics on the image's own dtype/device so no implicit
        # casts or device transfers happen in the arithmetic below.
        dtype, device = image.dtype, image.device
        mean = torch.as_tensor(self.image_mean, dtype=dtype, device=device)
        std = torch.as_tensor(self.image_std, dtype=dtype, device=device)
        return (image - mean[:, None, None]) / std[:, None, None]

# Per-channel normalisation statistics for the 6-channel input. Example values:
# the RGB mean/std repeated once for the three extra channels.
image_mean = [0.485, 0.456, 0.406] * 2
image_std = [0.229, 0.224, 0.225] * 2

# Transform instance carrying the 6-entry mean/std lists.
transform = CustomTransform(
    min_size=(800,),
    max_size=1333,
    image_mean=image_mean,
    image_std=image_std,
)



In [37]:
# named_modules() yields the root module first (under an empty name), so this
# prints the whole model as it is before the input stem is modified.
for name, layer in model.named_modules():
    print(name)
    print(layer)
    break

# Use the 6-channel normalisation/resize transform instead of the 3-channel default.
model.transform = transform

# Widen the backbone's first convolution from 3 to 6 input channels while
# keeping the pre-trained filters: the RGB weights are copied into channels
# 0-2 and duplicated into channels 3-5.
original_conv = model.backbone.body.conv1
if original_conv.in_channels == 3:  # guard so re-running this cell is a no-op
    modified_conv = torch.nn.Conv2d(6, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    # Match the existing layer's device/dtype so the swap also works off-CPU.
    modified_conv = modified_conv.to(device=original_conv.weight.device, dtype=original_conv.weight.dtype)
    with torch.no_grad():
        modified_conv.weight[:, :3, :, :] = original_conv.weight
        modified_conv.weight[:, 3:, :, :] = original_conv.weight
    # Install the layer that carries the copied weights. Assigning a freshly
    # constructed Conv2d here would discard them and leave the stem random.
    model.backbone.body.conv1 = modified_conv

# Print the root module again to confirm the new transform and 6-channel stem.
for name, layer in model.named_modules():
    print(name)
    print(layer)
    break




MaskRCNN(
  (transform): GeneralizedRCNNTransform(
      Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
      Resize(min_size=(800,), max_size=1333, mode='bilinear')
  )
  (backbone): BackboneWithFPN(
    (body): IntermediateLayerGetter(
      (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn1): FrozenBatchNorm2d(64, eps=0.0)
      (relu): ReLU(inplace=True)
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (layer1): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): FrozenBatchNorm2d(64, eps=0.0)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): FrozenBatchNorm2d(64, eps=0.0)
          (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): FrozenBatchNorm2d(256, eps=0.0)
          (relu): ReLU(i

In [38]:
# Smoke test: push one random 6-channel 224x224 image through the model.
dummy_input = torch.randn(1, 6, 224, 224)
model.eval()
# Call the module itself rather than .forward() so any registered hooks run,
# and disable autograd because this is inference only.
with torch.no_grad():
    predictions = model(dummy_input)
predictions

[{'boxes': tensor([[2.2052e+02, 2.1860e+02, 2.2400e+02, 2.2188e+02],
          [2.7984e+00, 8.9238e+01, 1.1584e+02, 1.8152e+02],
          [7.8833e-01, 7.8777e+01, 1.7768e+01, 2.0238e+02],
          [2.1769e+02, 2.1212e+02, 2.2400e+02, 2.2321e+02],
          [3.7625e+01, 1.4715e+02, 1.7344e+02, 2.0805e+02],
          [1.8248e+01, 6.6828e+01, 7.9317e+01, 1.9133e+02],
          [7.7333e-01, 1.5424e+02, 1.7287e+01, 2.1217e+02],
          [1.3892e+02, 2.5632e+01, 1.4242e+02, 2.9806e+01],
          [2.4837e+01, 1.1795e+02, 1.6938e+02, 1.9164e+02],
          [5.5595e+01, 1.1253e+02, 5.8595e+01, 1.1542e+02],
          [9.3248e+01, 1.1303e+02, 1.5798e+02, 2.0640e+02],
          [2.0903e+02, 1.3088e+02, 2.2400e+02, 2.1063e+02],
          [5.8954e+01, 1.1645e+02, 1.2110e+02, 2.0539e+02],
          [7.4229e+00, 6.1869e+01, 5.4219e+01, 1.9317e+02],
          [9.1605e+01, 1.1920e+02, 2.2104e+02, 1.9162e+02],
          [5.4710e+01, 1.4824e+02, 5.7451e+01, 1.5087e+02],
          [2.2129e+02, 2.1391e+