In [None]:
import layoutparser as lp
import cv2
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
from layoutparser.models import Detectron2LayoutModel
from detectron2.modeling.roi_heads.fast_rcnn import FastRCNNOutputLayers
from ultralytics import YOLO
from detectron2.config import get_cfg
from detectron2.engine import DefaultTrainer
import os
from detectron2.data import MetadataCatalog

In [2]:
from detectron2.data import DatasetCatalog
from detectron2.data.datasets import register_coco_instances

# (dataset name, COCO annotation JSON, image root) for each split.
COCO_SPLITS = (
    ("newspaper_train", "coco_annotations/train.json", "train_photos/"),
    ("newspaper_val", "coco_annotations/val.json", "val_photos/"),
)

# Register each split only once: detectron2 refuses to register a dataset name
# twice, so without this guard re-running the cell in a live kernel fails.
# NOTE: if the paths above are edited, restart the kernel (or remove the
# dataset from DatasetCatalog) so the new paths are picked up.
for dataset_name, annotation_json, image_root in COCO_SPLITS:
    if dataset_name not in DatasetCatalog.list():
        register_coco_instances(dataset_name, {}, annotation_json, image_root)


In [None]:
# Class names of the 7-class newspaper detector, ordered by class index.
NEWSPAPER_LABELS = [
    "Photograph",
    "Illustration",
    "Map",
    "Comics/Cartoon",
    "Editorial Cartoon",
    "Headline",
    "Advertisement",
]

# Source model whose weights are reused further down in the notebook.
model_newspaper = Detectron2LayoutModel(
    config_path='config_newspaper.yml',
    label_map=dict(enumerate(NEWSPAPER_LABELS)),  # {0: "Photograph", ..., 6: "Advertisement"}
    extra_config=["MODEL.ROI_HEADS.SCORE_THRESH_TEST", 0.5],
    device="cuda",
)


In [4]:
# Reach through the layoutparser wrapper to the underlying network and its
# box predictor (the final classification / box-regression layers).
model = model_newspaper.model.model
old_head = model.roi_heads.box_predictor

# Keep independent copies of the old head's parameters so they can be
# transplanted into a smaller head later.
old_cls_weight, old_cls_bias, old_bbox_weight, old_bbox_bias = (
    param.data.clone()
    for param in (
        old_head.cls_score.weight,
        old_head.cls_score.bias,
        old_head.bbox_pred.weight,
        old_head.bbox_pred.bias,
    )
)

# Everything except the box predictor can be loaded into the new model as-is.
HEAD_PREFIX = "roi_heads.box_predictor."
old_state_dict = model.state_dict()
filtered_state_dict = {
    name: tensor
    for name, tensor in old_state_dict.items()
    if not name.startswith(HEAD_PREFIX)
}

In [None]:
# Class names of the reduced 2-class detector, ordered by class index.
TWO_CLASS_LABELS = ["Illustration", "Headline"]

# Target model: same wrapper, built from the 2-class config.
model_newspaper_2 = Detectron2LayoutModel(
    config_path='config_newspaper_2class.yml',
    label_map=dict(enumerate(TWO_CLASS_LABELS)),  # {0: "Illustration", 1: "Headline"}
    extra_config=["MODEL.ROI_HEADS.SCORE_THRESH_TEST", 0.5],
    device="cuda",
)


In [6]:
# Underlying network of the 2-class wrapper, plus a handle on its current
# (pre-transplant) parameters.
wrapped_2 = model_newspaper_2.model  # presumably layoutparser's predictor object; verify
model_2 = wrapped_2.model
new_state_dict = model_2.state_dict()

In [7]:
num_new_classes = 2

roi_heads_2 = model_2.roi_heads
in_features = roi_heads_2.box_predictor.cls_score.in_features

# Fresh box predictor sized for the reduced class set; it reuses the existing
# head's input shape and box transform, and is moved to the GPU straight away.
# NOTE(review): only these three arguments are passed, so any test-time
# settings of the replaced predictor (e.g. the 0.5 score threshold given to the
# wrapper) are presumably not carried over -- confirm before using model_2 for
# inference.
new_head = FastRCNNOutputLayers(
    input_shape=roi_heads_2.box_head.output_shape,
    box2box_transform=roi_heads_2.box_predictor.box2box_transform,
    num_classes=num_new_classes,
).to("cuda")


In [8]:
# Transplant the relevant rows of the old 7-class box predictor into the new
# 2-class head.
#
# A single new -> old class-index map drives BOTH the classification and the
# box-regression copy, so the two can never disagree. (Previously the
# classification row for new class 0 was taken from old index 0, "Photograph",
# while its box-regression rows came from old index 1, "Illustration".)
NEW_TO_OLD_CLASS = {
    0: 1,  # "Illustration" <- old "Illustration" (1)
    1: 5,  # "Title"        <- old "Headline" (5)
}

for new_idx, old_idx in NEW_TO_OLD_CLASS.items():
    # Classification: one row per class.
    new_head.cls_score.weight.data[new_idx] = old_cls_weight[old_idx]
    new_head.cls_score.bias.data[new_idx] = old_cls_bias[old_idx]

    # Class-specific box regression: 4 consecutive rows per class.
    new_rows = slice(new_idx * 4, (new_idx + 1) * 4)
    old_rows = slice(old_idx * 4, (old_idx + 1) * 4)
    new_head.bbox_pred.weight.data[new_rows] = old_bbox_weight[old_rows]
    new_head.bbox_pred.bias.data[new_rows] = old_bbox_bias[old_rows]

# Background class: last row of cls_score in both heads (it has no box rows).
new_head.cls_score.weight.data[-1] = old_cls_weight[-1]
new_head.cls_score.bias.data[-1] = old_cls_bias[-1]


In [9]:
# Copy every non-box-predictor weight from the 7-class model. strict=False is
# needed because filtered_state_dict deliberately has no box_predictor entries.
model_2.load_state_dict(state_dict=filtered_state_dict, strict=False)

# Then swap in the hand-initialised 2-class predictor.
target_roi_heads = model_2.roi_heads
target_roi_heads.box_predictor = new_head


In [10]:
# Training configuration, starting from detectron2's defaults.
cfg = get_cfg()

# --- Input resizing ---
input_cfg = cfg.INPUT
input_cfg.MIN_SIZE_TRAIN = (800,)
input_cfg.MAX_SIZE_TRAIN = 1333
input_cfg.MIN_SIZE_TRAIN_SAMPLING = "choice"

# --- Data and output locations ---
cfg.DATASETS.TRAIN = ("newspaper_train",)
cfg.DATASETS.TEST = ("newspaper_val",)
cfg.DATALOADER.NUM_WORKERS = 2
cfg.OUTPUT_DIR = "./output"

# --- Optimisation schedule ---
solver_cfg = cfg.SOLVER
solver_cfg.IMS_PER_BATCH = 2
solver_cfg.BASE_LR = 0.00025
solver_cfg.MAX_ITER = 3000
solver_cfg.STEPS = [1000, 2000]
solver_cfg.GAMMA = 0.2

# --- Model ---
# NOTE(review): the trainer cell in this notebook returns an already-built
# model from build_model, so these MODEL.* values may not reach that model
# (e.g. FREEZE_AT) -- confirm.
model_cfg = cfg.MODEL
model_cfg.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 128
model_cfg.ROI_HEADS.NUM_CLASSES = 2  # Illustration, Title
model_cfg.BACKBONE.FREEZE_AT = 5

# Class names attached to both dataset splits.
# NOTE(review): the 2-class label_map elsewhere in this notebook calls class 1
# "Headline"; check that "Title" agrees with the annotation files' categories.
for split_name in ("newspaper_train", "newspaper_val"):
    MetadataCatalog.get(split_name).thing_classes = ["Illustration", "Title"]


os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)

In [None]:
import torch
from detectron2.engine import DefaultTrainer

# The network to fine-tune: the 2-class model prepared in the cells above.
custom_model = model_2


class CustomTrainer(DefaultTrainer):
    """DefaultTrainer that trains the pre-built ``custom_model``."""

    @classmethod
    def build_model(cls, cfg):
        """Return the module-level ``custom_model``; ``cfg`` is not consulted."""
        return custom_model


trainer = CustomTrainer(cfg)
trainer.resume_or_load(resume=False)

# Quick sanity check that a GPU is visible before the long-running step.
print(torch.cuda.is_available())

trainer.train()
