In [None]:
import os
import json
import random
import numpy as np
from collections import Counter, OrderedDict
from tqdm.auto import tqdm
import wandb

import cv2
from PIL import Image
from matplotlib import pyplot as plt

import albumentations as A
import albumentations.augmentations.functional as F
from albumentations.pytorch import ToTensorV2

import torch
import torch.nn as nn
import torch.nn.functional as F
torch.set_printoptions(sci_mode=False)
from torch.utils.data import Dataset, DataLoader, default_collate
from timm.models.layers import LayerNorm2d
import torchshow

from utils import load_model_and_may_interpolate
from modeling_utils import _get_base_config, _get_large_config

torch.backends.cudnn.benchmark = True

In [2]:
# Scratch check: a 0-dim learnable parameter initialised to ln(1 / 0.07) ≈ 2.659.
logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

In [None]:
# exp() undoes the log used at initialisation, so this displays ≈ 14.29 (= 1 / 0.07).
logit_scale.exp()

In [4]:
%load_ext autoreload
%autoreload 2

In [5]:
from transformers import XLMRobertaTokenizer, AutoConfig
from transformers import AutoImageProcessor, XLMRobertaTokenizer
from torchscale.architecture.config import EncoderConfig
from lion_pytorch import Lion
from accelerate import Accelerator
from accelerate.utils import set_seed
from accelerate import notebook_launcher, DistributedDataParallelKwargs

In [6]:
from BEiT3_adapter.panoptic_dataset import COCOPanopticDataset

In [7]:
class HpConfig:
    """Hyper-parameters shared by the data, model, optimizer and training cells.

    Used as a plain namespace: the attributes are read directly from the class
    (``HpConfig.img_size``), it is never instantiated in this notebook.
    """

    # --- data / model geometry ---
    img_size = 640            # side length used by the augmentations and the backbone config
    drop_path = 0.1           # forwarded as ``drop_path_rate`` to the backbone config
    grad_ckpt = False         # forwarded as ``with_cp`` to the backbone config

    # --- batching ---
    batch_size = 2            # batch size of the training DataLoader
    val_batch_size = 1        # batch size of both validation DataLoaders
    grad_acc_steps = 1        # passed to Accelerator(gradient_accumulation_steps=...)

    # --- optimisation ---
    lr = 1e-4                 # base value; configure_optimizer divides it by 5 for Lion
    weight_decay = 0.05       # base value; configure_optimizer multiplies it by 5 for Lion

    # --- runtime ---
    num_gpu = 2               # scales the step thresholds of the LR schedule
    mixed_precision = 'bf16'  # passed to Accelerator(mixed_precision=...)

    # --- special text tokens added to the tokenizer and handed to the datasets ---
    wls_token = '<wls>'
    sep_token = '▁;'

In [8]:
def get_dataloaders(accelerator):
    """Build the COCO training loader plus the COCO and ADE20K validation loaders.

    Args:
        accelerator: Accepted so all builder functions share a call shape; it is
            not read anywhere in this function.

    Returns:
        tuple: ``(coco_train_loader, coco_val_loader, ade_val_loader)``.
    """
    # Tokenizer, extended with the two special tokens from HpConfig.
    tokenizer = XLMRobertaTokenizer("../beit3_weights/beit3.spm")
    tokenizer.add_tokens([HpConfig.wls_token, HpConfig.sep_token])

    # Both image processors share the same preprocessing switches.
    processor_kwargs = dict(
        do_resize=False, do_rescale=True, do_normalize=True, ignore_index=0,
    )
    coco_mask2former_processor = AutoImageProcessor.from_pretrained(
        "facebook/mask2former-swin-base-coco-panoptic", **processor_kwargs
    )
    ade_mask2former_processor = AutoImageProcessor.from_pretrained(
        "facebook/mask2former-swin-large-ade-panoptic", **processor_kwargs
    )

    side = HpConfig.img_size

    def pad_to_square():
        # Constant-border padding up to side x side, anchored at the top-left corner.
        return A.PadIfNeeded(
            side, side,
            position=A.PadIfNeeded.PositionType.TOP_LEFT,
            border_mode=cv2.BORDER_CONSTANT,
        )

    # Training: flip, rescale the short side to one of 0.5x .. 2.0x of `side`
    # (in steps of 0.1x), pad if needed, then take a random side x side crop.
    train_scales = [side * tenth // 10 for tenth in range(5, 21)]
    train_transform = A.Compose([
        A.HorizontalFlip(p=0.5),
        A.SmallestMaxSize(train_scales, p=1.0),
        pad_to_square(),
        A.RandomCrop(side, side),
    ])

    # Validation: fit the long side to `side`, then pad to a square.
    val_transform = A.Compose([
        A.LongestMaxSize(side, p=1.0),
        pad_to_square(),
    ])

    def read_json(path):
        with open(path) as file:
            return json.load(file)

    coco_train_ann = read_json('../../datasets/COCO/annotations/panoptic_train2017.json')
    coco_val_ann = read_json('../../datasets/COCO/annotations/panoptic_val2017.json')
    ade_val_ann = read_json('../../datasets/ADE20K/from_mmdet/ADEChallengeData2016/ade20k_panoptic_val.json')

    # Text-related arguments that are identical for all three datasets.
    text_kwargs = dict(
        use_text=True,
        tokenizer=tokenizer,
        sep_token=HpConfig.sep_token,
        use_sep=True,
        wls_token=HpConfig.wls_token,
    )

    coco_train_dataset = COCOPanopticDataset(
        coco_train_ann,
        '../../datasets/COCO/train2017',
        '../../datasets/COCO/annotations/panoptic_train2017',
        transform=train_transform,
        processor=coco_mask2former_processor,
        num_sampled_label=133,
        max_sep_num=3,
        **text_kwargs,
    )
    coco_val_dataset = COCOPanopticDataset(
        coco_val_ann,
        '../../datasets/COCO/val2017',
        '../../datasets/COCO/annotations/panoptic_val2017',
        transform=val_transform,
        processor=coco_mask2former_processor,
        num_sampled_label=133,
        max_sep_num=1,
        **text_kwargs,
    )
    ade_val_dataset = COCOPanopticDataset(
        ade_val_ann,
        '../../datasets/ADE20K/from_mmdet/ADEChallengeData2016/images/validation',
        '../../datasets/ADE20K/from_mmdet/ADEChallengeData2016/ade20k_panoptic_val',
        transform=val_transform,
        processor=ade_mask2former_processor,
        num_sampled_label=150,
        max_sep_num=1,
        **text_kwargs,
    )

    def custom_collate(batch):
        """Collate one batch, keeping per-sample label entries as Python lists."""
        first_elem = batch[0]
        collated_batch = {}

        # These entries are removed from every sample and kept as lists so that
        # default_collate below never sees them.
        for key in ('mask_labels', 'class_labels', 'origin_class_labels'):
            if key in first_elem:
                collated_batch[key] = [sample.pop(key) for sample in batch]

        # Token ids are padded by the tokenizer rather than by default_collate.
        if 'input_ids' in first_elem:
            padded_text = tokenizer.pad(
                [{'input_ids': sample.pop('input_ids')} for sample in batch],
                max_length=640, padding=True,
            )
            collated_batch.update(padded_text)

        # Whatever is left in the samples goes through the stock collate.
        collated_batch.update(default_collate(batch))
        return collated_batch

    def make_loader(dataset, batch_size, is_train):
        # The training loader shuffles and drops the last partial batch; the
        # validation loaders do neither.
        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=is_train,
            num_workers=8,
            pin_memory=True,
            drop_last=is_train,
            collate_fn=custom_collate,
        )

    coco_train_loader = make_loader(coco_train_dataset, HpConfig.batch_size, True)
    coco_val_loader = make_loader(coco_val_dataset, HpConfig.val_batch_size, False)
    ade_val_loader = make_loader(ade_val_dataset, HpConfig.val_batch_size, False)

    return coco_train_loader, coco_val_loader, ade_val_loader

In [20]:
class CustomEmbedding(nn.Module):
    """Embedding lookup that routes ids to one of two tables by a split index.

    Ids strictly below ``split_idx`` are looked up in ``old_embedding``; all
    other ids are looked up in ``new_embedding`` at position ``id - split_idx``.

    Args:
        old_embedding: Embedding module serving ids ``< split_idx``; must expose
            ``num_embeddings``.
        new_embedding: Embedding module serving ids ``>= split_idx``.
        split_idx: First id that is served by ``new_embedding``.
    """

    def __init__(self, old_embedding, new_embedding, split_idx):
        super().__init__()
        self.old_embedding = old_embedding
        self.new_embedding = new_embedding
        self.split_idx = split_idx

    def forward(self, input_ids):
        """Return the embeddings for ``input_ids`` with one trailing feature axis added."""
        # Both tables are queried for every position, so the indices are clamped
        # into a range each table accepts; the unwanted result is discarded below.
        last_old_row = self.old_embedding.num_embeddings - 1
        ids_for_old = input_ids.clamp(max=last_old_row)
        ids_for_new = (input_ids - self.split_idx).clamp(min=0)

        from_old = self.old_embedding(ids_for_old)
        from_new = self.new_embedding(ids_for_new)

        # Per-token selector, broadcast over the embedding dimension.
        served_by_old = (input_ids < self.split_idx).unsqueeze(-1)
        return torch.where(served_by_old, from_old, from_new)

def create_model(accelerator, load_weight=True, freeze_backbone=True, interpolate_pos=False, add_new_embedding=False):
    """Build the BEiT-3 based universal segmentation model.

    Args:
        accelerator: Accelerate handle; only ``is_main_process`` is read, to gate
            weight loading, position interpolation and the parameter report.
        load_weight: Load the BEiT-3 checkpoint into the encoder and then the
            Mask2Former checkpoint into the whole model (main process only).
        freeze_backbone: Freeze every parameter whose name contains one of the
            backbone keywords, make all others trainable, then re-enable the
            position embeddings (and the new token embedding, if any).
        interpolate_pos: Linearly resample rows ``2:130`` of
            ``embed_positions.B`` to 640 rows and write them to rows ``2:642``
            (main process only).
        add_new_embedding: Wrap the text embedding in ``CustomEmbedding`` with a
            one-row table initialised from row 0 of the existing embedding.

    Returns:
        The constructed ``BEiT3SegForUniversalSegmentation`` instance.
    """
    from BEiT3_adapter.beit3_seg_ov_v2 import BEiT3SegForUniversalSegmentation

    mask2former_config = AutoConfig.from_pretrained("facebook/mask2former-swin-base-coco-panoptic", )
    mask2former_config.backbone_config = dict(
        beit3_args=_get_large_config(
            img_size=HpConfig.img_size,
            drop_path_rate=HpConfig.drop_path,
            checkpoint_activations=False,
        ),
        deform_num_heads=16,
        deform_ratio=0.5,
        interaction_indexes=[[0, 5], [6, 11], [12, 17], [18, 23]],

        init_values=1e-6,
        conv_inplane=64,
        n_points=4,
        cffn_ratio=0.25,
        with_cp=HpConfig.grad_ckpt,
    )
    mask2former_config.backbone_dim = 768
    mask2former_config.num_labels = 133
    mask2former_config.use_text_cross_attn = True
    mask2former_config.use_text_features = True
    mask2former_config.use_text_contrastive_loss = True
    mask2former_config.use_objectness_loss = True
    mask2former_config.match_once_only = False
    mask2former_config.drop_first_ce_loss = True
    mask2former_config.encoder_layers=6
    mask2former_config.decoder_layers=10
    mask2former_config.objectness_weight = 2

    beit3_seg = BEiT3SegForUniversalSegmentation(mask2former_config)
    beit3_seg = beit3_seg.apply(beit3_seg._init_weights)
    beit3_seg.model.pixel_level_module.encoder.init_weights()

    # NOTE(review): weights are loaded (and positions interpolated) on the main
    # process only. Other ranks presumably receive them when the model is wrapped
    # for distributed training and rank 0's state is broadcast -- confirm.
    if load_weight:
        if accelerator.is_main_process:
            print('Loading BEiT3 pretraind weight...')
            load_model_and_may_interpolate(
                '../beit3_weights/beit3_base_patch16_224.pth',
                beit3_seg.model.pixel_level_module.encoder,
                'model|module',
                'beit3.',
            )
            print()
            # torch.load unpickles the file: only use checkpoints from a trusted
            # source. map_location='cpu' keeps loading independent of the device
            # the checkpoint was saved from; load_state_dict copies into the
            # model's own tensors afterwards.
            mask2former_pretrained_weigths = torch.load(
                './training_checkpoints/vit_adapter_mask2former_coco_768.pth',
                map_location='cpu',
            )
            beit3_seg_param_shapes = {n: v.shape for n, v in beit3_seg.state_dict().items()}
            # Drop checkpoint entries whose shape disagrees with the model so the
            # non-strict load below does not fail on them. Iterate over a snapshot
            # of the keys because entries are deleted inside the loop.
            for name in list(mask2former_pretrained_weigths):
                v_shape = mask2former_pretrained_weigths[name].shape
                if name in beit3_seg_param_shapes and v_shape != beit3_seg_param_shapes[name]:
                    print('mismatch:', name, v_shape, beit3_seg_param_shapes[name])
                    del mask2former_pretrained_weigths[name]
            r = beit3_seg.load_state_dict(mask2former_pretrained_weigths, strict=False)
            print(r)

    if interpolate_pos:
        if accelerator.is_main_process:
            with torch.no_grad():
                pos_table = beit3_seg.model.pixel_level_module.encoder.encoder.embed_positions.B.weight
                origin_pos = pos_table[2:130].clone()
                # (128, D) -> (1, D, 128) -> linear resample to 640 -> (640, D)
                new_pos = F.interpolate(origin_pos.unsqueeze(0).permute(0, 2, 1), 640, mode='linear').permute(0, 2, 1)[0]
                pos_table[2:640+2] = new_pos

    if add_new_embedding:
        print('Creating new embedding...')
        old_embedding = beit3_seg.model.pixel_level_module.encoder.text_embed
        new_embedding_init_weight = old_embedding.weight[[0]].detach().clone()
        # Take the table shape from the cloned weight instead of hard-coding
        # (1, 768): nn.Embedding requires `_weight` to match exactly, so this
        # stays correct for any text-embedding width.
        new_embedding = nn.Embedding(
            new_embedding_init_weight.shape[0],
            new_embedding_init_weight.shape[1],
            _weight=new_embedding_init_weight,
        )
        # NOTE(review): 64002 is the first token id routed to the new table;
        # presumably the id the tokenizer assigns to the added special token --
        # verify against the tokenizer built in get_dataloaders.
        beit3_seg.model.pixel_level_module.encoder.text_embed = CustomEmbedding(
            old_embedding,
            new_embedding,
            64002,
        )

    if freeze_backbone:
        freeze_keywords = [
            'model.pixel_level_module.encoder.text_embed', 
            'model.pixel_level_module.encoder.vision_embed', 
            'model.pixel_level_module.encoder.encoder', 
        ]
        for name, param in beit3_seg.named_parameters():
            is_backbone = any(kw in name for kw in freeze_keywords)
            param.requires_grad_(not is_backbone)
        # Exceptions to the freeze: both position-embedding modules stay trainable,
        # and so does the newly added token embedding.
        beit3_seg.model.pixel_level_module.encoder.encoder.embed_positions.A.requires_grad_(True)
        beit3_seg.model.pixel_level_module.encoder.encoder.embed_positions.B.requires_grad_(True)
        if add_new_embedding:
            beit3_seg.model.pixel_level_module.encoder.text_embed.new_embedding.weight.requires_grad_(True)

        # Report which parameters ended up trainable ('o') and frozen ('x').
        train_names = []
        freeze_names = []
        for name, param in beit3_seg.named_parameters():
            (train_names if param.requires_grad else freeze_names).append(name)

        if accelerator.is_main_process:
            for name in train_names:
                print('o', name)
            for name in freeze_names:
                print('x', name)

    return beit3_seg

In [10]:
def configure_optimizer(accelerator, model):
    """Create the Lion optimizer and its step-wise LR schedule for ``model``.

    Trainable parameters are split into four groups (head / backbone, each with
    and without weight decay). Backbone groups use 0.2x the head learning rate.

    Args:
        accelerator: Not read in this function.
        model: Module whose parameters with ``requires_grad=True`` are optimized.

    Returns:
        tuple: ``(optimizer, lr_scheduler)`` where ``lr_scheduler`` is a
        ``LambdaLR`` with linear warm-up followed by two step decays.
    """
    def get_parameter_names(model, forbidden_layer_types):
        """
        Returns the names of the model parameters that are not inside a forbidden layer.
        """
        result = []
        for name, child in model.named_children():
            result += [
                f"{name}.{n}"
                for n in get_parameter_names(child, forbidden_layer_types)
                if not isinstance(child, tuple(forbidden_layer_types))
            ]
        # Add model specific parameters (defined with nn.Parameter) since they are not in any child.
        result += list(model._parameters.keys())
        return result

    # Lion is configured with a 5x smaller LR and a 5x larger weight decay than
    # the base values in HpConfig.
    lion_lr = HpConfig.lr / 5
    lion_weight_decay = HpConfig.weight_decay * 5

    # Weight decay applies to everything outside the norm layers whose name
    # contains none of these substrings. A set makes the per-parameter
    # membership test below O(1) instead of a list scan.
    no_decay_names = ["bias", "embed_positions", "queries_embedder", "psuedo_class_embedder", "position_embedding"]
    decay_parameters = {
        name
        for name in get_parameter_names(model, [nn.LayerNorm, LayerNorm2d])
        if not any(ndn in name for ndn in no_decay_names)
    }

    param_groups = {
        "backbone_decay": [],
        "backbone_no_decay": [],
        "head_decay": [],
        "head_no_decay": [],
    }
    # NOTE(review): a parameter counts as backbone only if its name contains
    # 'beit3'. The freeze keywords used in create_model do not contain that
    # substring, so the backbone groups may well be empty -- confirm against
    # the model's actual parameter names.
    for n, p in model.named_parameters():
        if not p.requires_grad:
            continue
        scope = 'backbone' if 'beit3' in n else 'head'
        decay = 'decay' if n in decay_parameters else 'no_decay'
        param_groups[f'{scope}_{decay}'].append((n, p))

    # Group order matters: the training loop logs learning rates as group_0..3.
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in param_groups['head_decay']],
            "weight_decay": lion_weight_decay,
            "lr": lion_lr,
        },
        {
            "params": [p for n, p in param_groups['head_no_decay']],
            "weight_decay": 0.0,
            "lr": lion_lr,
        },
        {
            "params": [p for n, p in param_groups['backbone_decay']],
            "weight_decay": lion_weight_decay,
            "lr": lion_lr*0.2,
        },
        {
            "params": [p for n, p in param_groups['backbone_no_decay']],
            "weight_decay": 0.0,
            "lr": lion_lr*0.2,
        },
    ]

    # Every group carries its own lr / weight_decay, so none is passed here.
    optimizer = Lion(optimizer_grouped_parameters)

    def lr_lambda(step):
        # Multiplier on each group's LR: linear warm-up, then 1.0, 0.1, 0.01.
        # NOTE(review): thresholds are scaled by num_gpu, presumably because the
        # prepared scheduler advances once per process per optimizer step -- confirm.
        if step < 2000*HpConfig.num_gpu:
            return step/(2000*HpConfig.num_gpu)
        elif step > 40000*HpConfig.num_gpu:
            return 0.01
        elif step > 30000*HpConfig.num_gpu:
            return 0.1
        else:
            return 1

    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

    return optimizer, lr_scheduler

In [11]:
def training_loop(seed: int = 42):
    """Train the segmentation model on COCO panoptic with Accelerate.

    Intended to be run once per process (see the commented ``notebook_launcher``
    call in the last cell). Checkpoints are written with
    ``accelerator.save_state`` at step 2000 and at every multiple of 4000.

    Args:
        seed: Seed passed to ``accelerate.utils.set_seed``.
    """
    set_seed(seed)
    # Initialize accelerator
    kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
    accelerator = Accelerator(
        mixed_precision=HpConfig.mixed_precision,
        gradient_accumulation_steps=HpConfig.grad_acc_steps,
        log_with='wandb',
        kwargs_handlers=[kwargs],
    )

    # Build dataloaders, model and optimizer
    coco_train_loader, coco_val_loader, ade_val_loader = get_dataloaders(accelerator)
    # model = create_model(accelerator, load_weight=True, freeze_backbone=True, interpolate_pos=True)
    model = create_model(
        accelerator, load_weight=False, freeze_backbone=True,
        interpolate_pos=False, add_new_embedding=True,
    )
    optimizer, lr_scheduler = configure_optimizer(accelerator, model)

    coco_train_loader, coco_val_loader, ade_val_loader = accelerator.prepare(
        coco_train_loader, coco_val_loader, ade_val_loader
    )
    model, optimizer, lr_scheduler = accelerator.prepare(
        model, optimizer, lr_scheduler
    )
    # lr_scheduler.step_with_optimizer = False

    # Compile only the backbone encoder. unwrap_model is used instead of
    # `model.module` so this also works when prepare() does not add a
    # distributed wrapper (e.g. a single-process run).
    pixel_encoder = accelerator.unwrap_model(model).model.pixel_level_module.encoder
    pixel_encoder.encoder = torch.compile(
        pixel_encoder.encoder,
        mode='max-autotune',
        # mode="reduce-overhead",
    )

    # Experiment tracking is currently disabled; with no tracker initialised
    # the accelerator.log call below has nothing to write to.
    # if accelerator.is_local_main_process:
    #     accelerator.init_trackers(
    #         "BEiT3_Seg_Acc",
    #         config={
    #             'img_size': 640,
    #         },
    #     )
    #     wandb.run.log_code(
    #         ".",
    #         include_fn=lambda path: path.endswith(".py") or path.endswith(".ipynb"),
    #     )

    # NOTE(review): the epoch loop stops once 50000 steps have been reached at an
    # epoch boundary, while the in-epoch break uses 90000. The two thresholds
    # disagree; both are kept as found -- confirm which one is intended.
    max_steps = 50000
    hard_stop_step = 90000

    global_step = 0
    while global_step < max_steps:
        model.train()
        for batch in tqdm(coco_train_loader, disable=not accelerator.is_local_main_process):
            # accumulate() makes HpConfig.grad_acc_steps effective: gradients are
            # only synchronised and applied every grad_acc_steps batches. With
            # grad_acc_steps == 1 every batch is an optimizer step.
            with accelerator.accumulate(model):
                outputs = model(
                    pixel_values=batch['pixel_values'],
                    input_ids=batch['input_ids'],
                    cat_input_ids=batch['cat_token_idxs'],
                    text_padding_position=1-batch['attention_mask'],
                    class_labels=batch['origin_class_labels'],
                    mask_labels=batch['mask_labels'],
                    return_loss_dict=True,
                )

                accelerator.backward(outputs.loss)
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                if not accelerator.optimizer_step_was_skipped:
                    lr_scheduler.step()
                # Clear gradients after the step, not before the forward pass,
                # so gradients accumulated over earlier batches are not wiped.
                optimizer.zero_grad()

            step_log = {
                "train": {'loss': outputs.loss},
                "losses": outputs.loss_dict,
                "learning rates": {f"group_{i}":lr for i, lr in enumerate(lr_scheduler.get_last_lr())},
            }
            accelerator.log(step_log, step=global_step)

            # Checkpoint once early (step 2000) and then at every multiple of 4000.
            if (global_step != 0 and global_step % 4000 == 0) or global_step == 2000:
                accelerator.print(f'Saving model on step: {global_step}..')
                accelerator.wait_for_everyone()
                accelerator.save_state(f'training_checkpoints/adapter-v11-{global_step}')
                # accelerator.save_model(model, f'training_checkpoints/adapter-v2-{global_step}')
                accelerator.print('Model Saved!')

            if global_step >= hard_stop_step:
                break

            global_step += 1

    accelerator.end_training()

In [12]:
# args = (96, )
# notebook_launcher(training_loop, args, num_processes=HpConfig.num_gpu)