In [None]:
import os
import json
import random
import numpy as np
from collections import Counter, OrderedDict
from tqdm.auto import tqdm

import cv2
from PIL import Image
from matplotlib import pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, default_collate
from timm.models.layers import LayerNorm2d
import torchshow

In [2]:
from transformers import XLMRobertaTokenizer, AutoConfig
from transformers import AutoImageProcessor, XLMRobertaTokenizer
from torchscale.architecture.config import EncoderConfig

In [3]:
%load_ext autoreload
%autoreload 2

In [4]:
class HpConfig:
    """Hyper-parameters for this notebook, kept as plain class attributes.

    The class is used as a namespace only: values are read directly off the
    class (e.g. ``HpConfig.img_size``) and no instance is created in the
    cells shown here.
    """

    # --- model / input geometry -------------------------------------------
    # Square input resolution; passed to the BEiT-3 config builder and used
    # to size the dummy images and masks in the smoke test.
    img_size = 640
    # Forwarded to the BEiT-3 config builder as ``drop_path_rate``.
    drop_path = 0.1
    # Forwarded to the backbone settings as ``with_cp``.
    grad_ckpt = False

    # --- optimisation -----------------------------------------------------
    # The values below are not referenced by the cells visible here;
    # presumably consumed by the training code -- TODO confirm.
    lr = 1e-4
    weight_decay = 0.05
    mixed_precision = "bf16"

    # --- batching / hardware ----------------------------------------------
    # Also not referenced by the visible cells (the smoke test uses its own
    # local batch size).
    batch_size = 2
    val_batch_size = 1
    grad_acc_steps = 4
    num_gpu = 2

In [5]:
from utils import load_model_and_may_interpolate
from modeling_utils import _get_large_config

In [6]:
from beit3_seg import BEiT3SegForUniversalSegmentation
    
# Start from the published Mask2Former (Swin-B, COCO panoptic) configuration
# and override the fields this model needs.
mask2former_config = AutoConfig.from_pretrained(
    "facebook/mask2former-swin-base-coco-panoptic"
)

# Backbone settings are attached to the config as a plain dict.
# NOTE(review): ``checkpoint_activations`` is fixed to False while ``with_cp``
# follows ``HpConfig.grad_ckpt`` -- the two switches are set independently
# here; confirm that is intended.
mask2former_config.backbone_config = {
    "beit3_args": _get_large_config(
        img_size=HpConfig.img_size,
        drop_path_rate=HpConfig.drop_path,
        checkpoint_activations=False,
    ),
    "deform_num_heads": 16,
    "deform_ratio": 0.5,
    "interaction_indexes": [[0, 5], [6, 11], [12, 17], [18, 23]],
    "init_values": 1e-6,
    "conv_inplane": 64,
    "n_points": 4,
    "cffn_ratio": 0.25,
    "with_cp": HpConfig.grad_ckpt,
    "num_segments": 1000,
}

# Sizes: backbone feature width, label count, encoder / decoder depth.
mask2former_config.backbone_dim = 1024
mask2former_config.num_labels = 3
mask2former_config.encoder_layers = 6
mask2former_config.decoder_layers = 10

# Text-related switches read by the segmentation model (their exact effect is
# defined in ``beit3_seg``, which is not shown in this notebook).
mask2former_config.use_text_cross_attn = True
mask2former_config.use_text_features = True
mask2former_config.use_text_contrastive_loss = True

# Loss / matching switches, all disabled for this run.
mask2former_config.use_objectness_loss = False
mask2former_config.match_once_only = False
mask2former_config.drop_first_ce_loss = False

# Seed PyTorch's global RNG before any weights are drawn. Without this, every
# "Restart & Run All" produces a different initialisation, and the loss printed
# at the end of the smoke test cannot be compared between runs. When the cells
# are executed in order, the dummy inputs sampled further down continue from
# this same generator state, so they become repeatable as well.
torch.manual_seed(0)

# Build the segmentation model from the customised config.
beit3_seg = BEiT3SegForUniversalSegmentation(mask2former_config)
# Run the model's own ``_init_weights`` over its sub-modules via ``apply``; the
# result of ``apply`` is what the rest of the notebook uses as the model.
beit3_seg = beit3_seg.apply(beit3_seg._init_weights)
# The pixel-level encoder exposes a dedicated ``init_weights`` routine; it is
# called last, after the generic pass above.
beit3_seg.model.pixel_level_module.encoder.init_weights()

In [None]:
# Display the model's repr so the assembled module tree can be inspected.
beit3_seg

In [None]:
# Load the SentencePiece vocabulary from the local ``beit3.spm`` file.
tokenizer = XLMRobertaTokenizer(vocab_file="./beit3.spm")

# Register "<WLS>" as an extra token; it is used as a marker in front of each
# category name in the prompts below. The call is left as the last expression
# so the cell shows its return value.
# NOTE(review): the new token receives an id beyond the original vocabulary --
# confirm the model's text embedding table is large enough for it.
tokenizer.add_tokens(["<WLS>"])

In [None]:
# Show how a sample prompt is split into tokens, to check how the "<WLS>"
# markers and the category names come out of the tokenizer.
tokenizer.tokenize("<WLS>dog;<WLS>cat;<WLS>rabbit;")

In [10]:
# Dummy batch for a forward-pass smoke test. The order of the random draws
# (images first, then one mask tensor per sample) is kept fixed so the values
# obtained from a given RNG state do not change.
bs = 4

# Random images: (bs, 3, img_size, img_size).
pixel_values = torch.randn(bs, 3, HpConfig.img_size, HpConfig.img_size)

# The same three-category prompt for every sample in the batch.
input_ids = tokenizer(
    ["<WLS>dog;<WLS>cat;<WLS>rabbit;"] * bs,
    return_tensors="pt",
).input_ids

# One row [0, 3, 6] per sample.
# NOTE(review): these indices are hard-coded; presumably they are meant to
# address the category markers inside ``input_ids`` -- confirm against the
# tokenizer output shown in the previous cell.
cat_input_ids = torch.tensor([0, 3, 6]).repeat(bs, 1)

# Per sample: two random binary masks at image resolution, stored as float32
# and created on the CPU before being moved to the GPU.
mask_labels = [
    torch.randint(0, 2, (2, HpConfig.img_size, HpConfig.img_size)).to(
        device="cuda", dtype=torch.float32
    )
    for _ in range(bs)
]

# Per sample: class ids [1, 2], one id per mask.
class_labels = torch.tensor([1, 2]).repeat(bs, 1)

In [11]:
# Move the model to the GPU and switch it to evaluation mode before the
# forward pass below.
beit3_seg = beit3_seg.to("cuda").eval()

In [None]:
# Smoke test: a single forward pass with gradient tracking disabled. Labels are
# supplied as well, and the next cell reads ``outputs.loss``.
with torch.no_grad():
    outputs = beit3_seg(
        # Image and text inputs are still on the CPU at this point.
        pixel_values=pixel_values.cuda(),
        input_ids=input_ids.cuda(),
        cat_input_ids=cat_input_ids.cuda(),
        # Targets: the mask tensors were already moved to the GPU when they
        # were created, so the list is passed through unchanged.
        class_labels=class_labels.cuda(),
        mask_labels=mask_labels,
    )

In [None]:
# Display the loss attribute of the forward-pass output.
outputs.loss