In [1]:
import warnings
warnings.filterwarnings("ignore")


import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
from copy import deepcopy
import yaml
import sys
sys.path.append('./')
from basics.models.simMIM.build_simmim import build_model
from basics.models.model import parse_model

In [2]:
# Load the pretrained SimMIM model from a different source
with open('basics/models/config/simmim/pr.yaml', 'r') as f:
    cnfg = yaml.safe_load(f)
simmim = build_model(cnfg, False)  # Replace with actual function to load SimMIM model

# Remove the classification head from the SimMIM model
simmim = torch.nn.Sequential(*list(simmim.children())[:-1])

(576, 576) (4, 4)
input_resolution (144, 144)
input_resolution (72, 72)
input_resolution (36, 36)
input_resolution (18, 18)


In [3]:
simmim

Sequential(
  (0): PatchEmbed(
    (proj): Conv2d(3, 192, kernel_size=(4, 4), stride=(4, 4))
    (norm): LayerNorm((192,), eps=1e-05, elementwise_affine=True)
  )
  (1): Dropout(p=0.0, inplace=False)
  (2): ModuleList(
    (0): BasicLayer(
      dim=192, input_resolution=(144, 144), depth=2
      (blocks): ModuleList(
        (0): SwinTransformerBlock(
          dim=192, input_resolution=(144, 144), num_heads=4, window_size=6, shift_size=0, mlp_ratio=4.0
          (norm1): LayerNorm((192,), eps=1e-05, elementwise_affine=True)
          (attn): WindowAttention(
            dim=192, window_size=(6, 6), num_heads=4
            (qkv): Linear(in_features=192, out_features=576, bias=True)
            (attn_drop): Dropout(p=0.0, inplace=False)
            (proj): Linear(in_features=192, out_features=192, bias=True)
            (proj_drop): Dropout(p=0.0, inplace=False)
            (softmax): Softmax(dim=-1)
          )
          (drop_path): Identity()
          (norm2): LayerNorm((192,), eps

In [4]:
with open('models/model.yaml', 'r') as f:
    cfg = yaml.safe_load(f)
# Load the YOLOv5 model
yolov5 = parse_model(deepcopy(cfg),'head', ch=[3], config=None)  # Replace with actual function to load YOLOv5

8 [[10, 13, 16, 30, 33, 23]] [128]


In [5]:
type(yolov5)

tuple

In [6]:
yolov5

(Sequential(
   (0): Conv(
     (conv): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
     (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
     (act): SiLU()
   )
   (1): Upsample(scale_factor=2.0, mode='nearest')
   (2): Concat()
   (3): C3(
     (cv1): Conv(
       (conv): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
       (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
       (act): SiLU()
     )
     (cv2): Conv(
       (conv): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
       (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
       (act): SiLU()
     )
     (cv3): Conv(
       (conv): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
       (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
       (act): SiLU()
     )
     (m): Sequential(
       (0): Bottlenec

In [7]:
# Convert the tuple to a list
yolov5_list = list(yolov5)

# Replace the backbone in the yolov5 model
yolov5_list[0] = simmim

# Convert the list back to a tuple
yolov5 = tuple(yolov5_list)

# Print the modified yolov5 model
print(yolov5)


(Sequential(
  (0): PatchEmbed(
    (proj): Conv2d(3, 192, kernel_size=(4, 4), stride=(4, 4))
    (norm): LayerNorm((192,), eps=1e-05, elementwise_affine=True)
  )
  (1): Dropout(p=0.0, inplace=False)
  (2): ModuleList(
    (0): BasicLayer(
      dim=192, input_resolution=(144, 144), depth=2
      (blocks): ModuleList(
        (0): SwinTransformerBlock(
          dim=192, input_resolution=(144, 144), num_heads=4, window_size=6, shift_size=0, mlp_ratio=4.0
          (norm1): LayerNorm((192,), eps=1e-05, elementwise_affine=True)
          (attn): WindowAttention(
            dim=192, window_size=(6, 6), num_heads=4
            (qkv): Linear(in_features=192, out_features=576, bias=True)
            (attn_drop): Dropout(p=0.0, inplace=False)
            (proj): Linear(in_features=192, out_features=192, bias=True)
            (proj_drop): Dropout(p=0.0, inplace=False)
            (softmax): Softmax(dim=-1)
          )
          (drop_path): Identity()
          (norm2): LayerNorm((192,), ep