In [1]:
import torch
import torchvision
from torchvision.models.feature_extraction import create_feature_extractor

In [2]:
model = torchvision.models.resnet50()
model

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

In [6]:
state_dict = torch.load("./satlas_weights.pth")

In [25]:
def adjust_state_dict_prefix(state_dict, needed, prefix=None, prefix_allowed_count=None):
    """
    Adjusts the keys in the state dictionary by replacing 'backbone.backbone' prefix with 'backbone'.

    Args:
        state_dict (dict): Original state dictionary with 'backbone.backbone' prefixes.

    Returns:
        dict: Modified state dictionary with corrected prefixes.
    """
    new_state_dict = {}
    for key, value in state_dict.items():
        # Assure we're only keeping keys that we need for the current model component. 
        if not needed in key:
            continue

        # Update the key prefixes to match what the model expects.
        if prefix is not None:
            while key.count(prefix) > prefix_allowed_count:
                key = key.replace(prefix, '', 1)

        new_state_dict[key] = value
    return new_state_dict

In [34]:
state_dict = adjust_state_dict_prefix(state_dict, "resnet", "backbone.resnet.",  prefix_allowed_count=0)

In [35]:
model.load_state_dict(state_dict)

RuntimeError: Error(s) in loading state_dict for ResNet:
	size mismatch for conv1.weight: copying a param with shape torch.Size([64, 9, 7, 7]) from checkpoint, the shape in current model is torch.Size([64, 3, 7, 7]).

In [37]:
model.conv1 = torch.nn.Conv2d(9, 64, kernel_size=7, stride=2, padding=3, bias=False)

In [38]:
model.load_state_dict(state_dict)

<All keys matched successfully>

In [41]:
swin_model = torchvision.models.swin_v2_b()

In [55]:
swin_state_dict = torch.load("./satlas_swin.pth")

In [59]:
swin_state_dict = adjust_state_dict_prefix(swin_state_dict, "backbone", "backbone.", 0)

In [71]:
swin_model.features[0][0] = torch.nn.Conv2d(9, 128, kernel_size=4, stride=2, padding=3, bias=True)

In [72]:
swin_model.load_state_dict(swin_state_dict)

<All keys matched successfully>

In [83]:
model

ResNet(
  (conv1): Conv2d(9, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

In [None]:
mode

In [78]:
cf = create_feature_extractor(model, {"relu": "x1", "layer1": "x2", "layer2": "x3", "layer3": "x4", "layer4": "out"})

In [81]:
output = cf(torch.rand(10, 9, 128, 128))

In [94]:
output["x3"].shape

torch.Size([10, 512, 16, 16])

In [134]:
def swin_feature_extractor(model):
    def forward(x):
        outputs = []
        for layer in model.features:
            x = layer(x)
            outputs.append(x.permute(0, 3, 1, 2))
        return [outputs[-7], outputs[-5], outputs[-3], outputs[-1]]
    return lambda x: forward(x)

In [135]:
swin_cf = swin_feature_extractor(swin_model)

In [137]:
feats = swin_cf(torch.rand(10, 9, 128, 128))

In [139]:
feats[0].shape

torch.Size([10, 128, 66, 66])

In [140]:
feats[1].shape

torch.Size([10, 256, 33, 33])

In [141]:
feats[2].shape

torch.Size([10, 512, 17, 17])

In [142]:
len(feats)

4