# Insert a pre-trained backbone into Mask2Former

In [2]:
import torch
from transformers import FocalNetForImageClassification, FocalNetConfig, Mask2FormerForUniversalSegmentation


# Checkpoint identifiers of the two pre-trained models loaded in this cell.
MASK2FORMER_CHECKPOINT = 'facebook/mask2former-swin-tiny-coco-instance'
FOCALNET_CHECKPOINT = 'microsoft/focalnet-tiny'

# Mask2Former segmentation model; keep a handle on its config as well.
model = Mask2FormerForUniversalSegmentation.from_pretrained(MASK2FORMER_CHECKPOINT)
config = model.config

# FocalNet image-classification model and its config.
focalnet_model = FocalNetForImageClassification.from_pretrained(FOCALNET_CHECKPOINT)
focalnet_config = focalnet_model.config

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Display the module tree of the loaded Mask2Former model.
model

Mask2FormerForUniversalSegmentation(
  (model): Mask2FormerModel(
    (pixel_level_module): Mask2FormerPixelLevelModule(
      (encoder): SwinBackbone(
        (embeddings): SwinEmbeddings(
          (patch_embeddings): SwinPatchEmbeddings(
            (projection): Conv2d(3, 96, kernel_size=(4, 4), stride=(4, 4))
          )
          (norm): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
          (dropout): Dropout(p=0.0, inplace=False)
        )
        (encoder): SwinEncoder(
          (layers): ModuleList(
            (0): SwinStage(
              (blocks): ModuleList(
                (0-1): 2 x SwinLayer(
                  (layernorm_before): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
                  (attention): SwinAttention(
                    (self): SwinSelfAttention(
                      (query): Linear(in_features=96, out_features=96, bias=True)
                      (key): Linear(in_features=96, out_features=96, bias=True)
                      (value

In [4]:
# Display the module tree of the loaded FocalNet classification model.
focalnet_model

FocalNetForImageClassification(
  (focalnet): FocalNetModel(
    (embeddings): FocalNetEmbeddings(
      (patch_embeddings): FocalNetPatchEmbeddings(
        (projection): Conv2d(3, 96, kernel_size=(4, 4), stride=(4, 4))
      )
      (norm): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): FocalNetEncoder(
      (stages): ModuleList(
        (0): FocalNetStage(
          (layers): ModuleList(
            (0): FocalNetLayer(
              (norm1): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
              (modulation): FocalNetModulation(
                (projection_in): Linear(in_features=96, out_features=195, bias=True)
                (projection_context): Conv2d(96, 96, kernel_size=(1, 1), stride=(1, 1))
                (activation): GELU(approximate='none')
                (projection_out): Linear(in_features=96, out_features=96, bias=True)
                (projection_dropout): Dropout(p=0.0, inplace=F

In [5]:
# Print the encoder currently attached to Mask2Former's pixel-level module,
# followed by the total number of parameter elements it holds.
pixel_encoder = model.model.pixel_level_module.encoder
print(pixel_encoder)
pixel_encoder_param_count = sum(weight.numel() for weight in pixel_encoder.parameters())
print(pixel_encoder_param_count)

SwinBackbone(
  (embeddings): SwinEmbeddings(
    (patch_embeddings): SwinPatchEmbeddings(
      (projection): Conv2d(3, 96, kernel_size=(4, 4), stride=(4, 4))
    )
    (norm): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
    (dropout): Dropout(p=0.0, inplace=False)
  )
  (encoder): SwinEncoder(
    (layers): ModuleList(
      (0): SwinStage(
        (blocks): ModuleList(
          (0-1): 2 x SwinLayer(
            (layernorm_before): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
            (attention): SwinAttention(
              (self): SwinSelfAttention(
                (query): Linear(in_features=96, out_features=96, bias=True)
                (key): Linear(in_features=96, out_features=96, bias=True)
                (value): Linear(in_features=96, out_features=96, bias=True)
                (dropout): Dropout(p=0.0, inplace=False)
              )
              (output): SwinSelfOutput(
                (dense): Linear(in_features=96, out_features=96, bias=True)
   

In [6]:
# Print the FocalNet base model (the classifier's `focalnet` submodule),
# followed by the total number of parameter elements it holds.
focalnet_base = focalnet_model.focalnet
print(focalnet_base)
focalnet_base_param_count = sum(weight.numel() for weight in focalnet_base.parameters())
print(focalnet_base_param_count)

FocalNetModel(
  (embeddings): FocalNetEmbeddings(
    (patch_embeddings): FocalNetPatchEmbeddings(
      (projection): Conv2d(3, 96, kernel_size=(4, 4), stride=(4, 4))
    )
    (norm): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
    (dropout): Dropout(p=0.0, inplace=False)
  )
  (encoder): FocalNetEncoder(
    (stages): ModuleList(
      (0): FocalNetStage(
        (layers): ModuleList(
          (0): FocalNetLayer(
            (norm1): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
            (modulation): FocalNetModulation(
              (projection_in): Linear(in_features=96, out_features=195, bias=True)
              (projection_context): Conv2d(96, 96, kernel_size=(1, 1), stride=(1, 1))
              (activation): GELU(approximate='none')
              (projection_out): Linear(in_features=96, out_features=96, bias=True)
              (projection_dropout): Dropout(p=0.0, inplace=False)
              (focal_layers): ModuleList(
                (0): Sequential(
   

# Replace the Swin-T backbone with the FocalNet model

In [7]:
# Swap the pixel-level encoder for the FocalNet base model taken from the
# classification checkpoint.
# NOTE(review): the encoder being replaced printed as a SwinBackbone, whereas
# focalnet_model.focalnet is a plain FocalNetModel. Confirm that the rest of
# Mask2Former can consume its outputs and that the feature channel sizes line
# up with the existing pixel decoder (a FocalNetBackbone may be required).
pixel_level_module = model.model.pixel_level_module
pixel_level_module.encoder = focalnet_model.focalnet
# model.backbone.config = focalnet_config
# model.backbone.model_name = 'microsoft/focalnet-tiny'


# Optional backbone freezing (currently disabled):

# for param in model.backbone.parameters():
#     param.requires_grad = False



# Other training hyperparameters, stored as extra attributes on the model config.
# NOTE(review): no code in this notebook reads these attributes back — verify
# that the training loop actually picks them up from the config.
training_overrides = {"learning_rate": 0.001, "num_train_epochs": 5}
for option_name, option_value in training_overrides.items():
    setattr(model.config, option_name, option_value)

# Display the model after the swap.
model

Mask2FormerForUniversalSegmentation(
  (model): Mask2FormerModel(
    (pixel_level_module): Mask2FormerPixelLevelModule(
      (encoder): FocalNetModel(
        (embeddings): FocalNetEmbeddings(
          (patch_embeddings): FocalNetPatchEmbeddings(
            (projection): Conv2d(3, 96, kernel_size=(4, 4), stride=(4, 4))
          )
          (norm): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
          (dropout): Dropout(p=0.0, inplace=False)
        )
        (encoder): FocalNetEncoder(
          (stages): ModuleList(
            (0): FocalNetStage(
              (layers): ModuleList(
                (0): FocalNetLayer(
                  (norm1): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
                  (modulation): FocalNetModulation(
                    (projection_in): Linear(in_features=96, out_features=195, bias=True)
                    (projection_context): Conv2d(96, 96, kernel_size=(1, 1), stride=(1, 1))
                    (activation): GELU(approxi

In [None]:
import json

# Input: JSON file keyed by category id; each value is read via its 'name' field.
id2label_file = 'datasets_config/coco-panoptic-id2label.json'

# Read the id -> category mapping from disk.
with open(id2label_file, 'r') as source_fh:
    id2label_data = json.load(source_fh)

# Invert it: category name -> integer id.
label2id_data = {
    entry['name']: int(raw_id)
    for raw_id, entry in id2label_data.items()
}

# Output: JSON file holding the inverted mapping.
label2id_file = 'datasets_config/coco-panoptic-label2id.json'

# Write the name -> id mapping, pretty-printed with a 4-space indent.
with open(label2id_file, 'w') as target_fh:
    json.dump(label2id_data, target_fh, indent=4)

print(f'Successfully created {label2id_file}.')


In [1]:
import datasets

# Directory holding the COCO data; handed to the loading script as `data_dir`.
COCO_DIR = "coco_datasets"

# Load the "2017" configuration through the dataset script `coco_dataset_script.py`.
ds = datasets.load_dataset(
    path="coco_dataset_script.py",
    name="2017",
    data_dir=COCO_DIR,
)

  from .autonotebook import tqdm as notebook_tqdm


Downloading and preparing dataset coco_dataset_script/2017 to /home/tanzila/.cache/huggingface/datasets/coco_dataset_script/2017-data_dir=coco_datasets/0.0.0/90661a729949a1f3fee4957f866f3dad26b41867b75efc9ea0decbeb4d6599cb...


                                                                        

Dataset coco_dataset_script downloaded and prepared to /home/tanzila/.cache/huggingface/datasets/coco_dataset_script/2017-data_dir=coco_datasets/0.0.0/90661a729949a1f3fee4957f866f3dad26b41867b75efc9ea0decbeb4d6599cb. Subsequent calls will reuse this data.


100%|██████████| 3/3 [00:02<00:00,  1.40it/s]


In [4]:
# Show the first record of the "train" split.
ds["train"][0]

{'image_id': 203564,
 'caption_id': 37,
 'caption': 'A bicycle replica with a clock as the front wheel.',
 'height': 400,
 'width': 400,
 'file_name': '000000203564.jpg',
 'coco_url': 'http://images.cocodataset.org/train2017/000000203564.jpg',
 'image_path': 'coco_datasets/train2017/000000203564.jpg'}