In [7]:
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
import json
import math
import os
import time

import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from accelerate import Accelerator, notebook_launcher
from accelerate.utils import set_seed
from PIL import Image
from timm.data import create_transform
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from timm.data.transforms import RandomResizedCropAndInterpolation
from torch.optim import AdamW
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.datasets.folder import default_loader
from transformers import AutoTokenizer, XLMRobertaTokenizer

from models.modeling_mplug import BertLMHeadModel, BertModel, BertConfig, FusionModel
from models.predictor import TextGenerator
from tasks.randaugrandaug import RandomAugment

### Visual transformer

In [None]:
class VisionEmbedding(nn.Module):
    """Convert a square image into a sequence of patch embeddings.

    A strided convolution cuts the image into non-overlapping patches and
    projects each one to ``embed_dim`` channels. Optionally a learnable mask
    token can replace selected patches and a learnable CLS token can be
    prepended to the sequence.

    Args:
        img_size: Side length of the (square) input image.
        patch_size: Side length of each (square) patch.
        in_chans: Number of input image channels.
        embed_dim: Size of every output token.
        contain_mask_token: Create ``mask_token`` (needed for ``masked_position``).
        prepend_cls_token: Create ``cls_token`` and prepend it in ``forward``.
    """

    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        embed_dim=768,
        contain_mask_token=False,
        prepend_cls_token=False,
    ):
        super().__init__()
        self.img_size = (img_size, img_size)
        self.patch_size = (patch_size, patch_size)

        # Grid of patches along (height, width).
        grid_rows = self.img_size[0] // self.patch_size[0]
        grid_cols = self.img_size[1] // self.patch_size[1]
        self.patch_shape = (grid_rows, grid_cols)
        self.num_patches = grid_cols * grid_rows

        # kernel == stride, so each output position sees exactly one patch.
        self.proj = nn.Conv2d(
            in_chans, embed_dim, kernel_size=self.patch_size, stride=self.patch_size
        )

        self.mask_token = (
            nn.Parameter(torch.zeros(1, 1, embed_dim)) if contain_mask_token else None
        )
        self.cls_token = (
            nn.Parameter(torch.zeros(1, 1, embed_dim)) if prepend_cls_token else None
        )

    def num_position_embeddings(self):
        """Return the output sequence length: patches plus one if a CLS token exists."""
        extra = 0 if self.cls_token is None else 1
        return self.num_patches + extra

    def forward(self, x, masked_position=None, **kwargs):
        """Embed ``x`` of shape (B, C, H, W) into (B, seq, embed_dim) tokens.

        Args:
            x: Image batch; H and W must equal the configured ``img_size``.
            masked_position: Optional (B, num_patches) weights; a weight of 1
                replaces that patch embedding with the mask token.
            **kwargs: Accepted and ignored.

        Raises:
            AssertionError: On an image-size mismatch, or when
                ``masked_position`` is given but no mask token was created.
        """
        _, _, H, W = x.shape
        assert (
            H == self.img_size[0] and W == self.img_size[1]
        ), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."

        # (B, D, rows, cols) -> (B, rows*cols, D)
        tokens = self.proj(x).flatten(2).transpose(1, 2)
        batch_size, seq_len, _ = tokens.size()

        if masked_position is not None:
            assert self.mask_token is not None
            fill = self.mask_token.expand(batch_size, seq_len, -1)
            weight = masked_position.unsqueeze(-1).type_as(fill)
            # Blend: weight 1 -> mask token, weight 0 -> original patch embedding.
            tokens = tokens * (1 - weight) + fill * weight

        if self.cls_token is None:
            return tokens

        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        return torch.cat((cls_tokens, tokens), dim=1)

In [None]:
def build_transform(is_train, args):
    """Build the image preprocessing pipeline for training or evaluation.

    Args:
        is_train: When truthy, build the augmenting training pipeline;
            otherwise build the deterministic resize pipeline.
        args: Mapping read for ``"input_size"`` and, in training mode only,
            ``"train_interpolation"`` and ``"randaug"``.

    Returns:
        A ``transforms.Compose`` ending in ``ToTensor`` and normalisation with
        the Inception mean/std constants.
    """
    normalize = transforms.Normalize(mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD)

    if not is_train:
        side = args["input_size"]
        return transforms.Compose([
            transforms.Resize((side, side), interpolation=3),
            transforms.ToTensor(),
            normalize,
        ])

    steps = [
        RandomResizedCropAndInterpolation(
            args["input_size"], scale=(0.5, 1.0), interpolation=args["train_interpolation"]
        ),
        transforms.RandomHorizontalFlip(),
    ]

    if args["randaug"]:
        allowed_ops = [
            'Identity', 'AutoContrast', 'Equalize', 'Brightness', 'Sharpness',
            'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate',
        ]
        steps.append(RandomAugment(2, 7, isPIL=True, augs=allowed_ops))

    steps.extend([transforms.ToTensor(), normalize])
    return transforms.Compose(steps)

### Config

In [2]:
def get_config():
    return {
        'task': None,
        'input_size': 224,
        'train_interpolation': 'bicubic',
        'randaug': True,
        'text_encoder': 'uitnlp/CafeBERT',
        'text_decoder': 'uitnlp/CafeBERT',
        'vision_width': 768,
        'roberta_config': '.\\configs\\config_bert.json',
        'beam_size': 5,
        'min_length': 1,
        'max_length': 10,
        'start_epoch': 0,
        'max_epoch': 10, # default 20
        'batch_size': 2, # default 128
        'seed': 42,
        'lr': 5e-4,
        'min_lr': 1e-6,
        'warmup_epochs': 5,
        'warmup_steps': -1,
        'update_freq': 1,
        'checkpoint_dir': '..\\save_states',
        'eos': '[SEP]'
             
    }

config = get_config()

### Module model

In [None]:
class BEIPLUG(nn.Module):
    """Visual question answering model.

    The image goes through a patch embedding, the question through a text
    encoder, both are merged by a fusion encoder, and a text decoder either
    scores a reference answer (training) or generates one through the beam
    generator (inference).

    Args:
        tokenizer: Tokenizer; only its ``pad_token_id`` is read here.
        config: Mapping with at least ``'roberta_config'``, ``'text_encoder'``
            and ``'text_decoder'``; it is also handed to ``TextGenerator``.
    """

    def __init__(self, tokenizer, config) -> None:
        super().__init__()

        self.tokenizer = tokenizer
        self.module_setting(config)
        self.visual_encoder = VisionEmbedding(contain_mask_token=True, prepend_cls_token=True)
        self.text_encoder = BertModel.from_pretrained(config['text_encoder'], config=self.config_encoder, add_pooling_layer=False)
        self.fusion_encoder = FusionModel.from_pretrained(config['text_encoder'], config=self.config_fusion, add_pooling_layer=False)
        self.text_decoder = BertLMHeadModel.from_pretrained(config['text_decoder'], config=self.config_decoder)
        self.beam_generator = TextGenerator(args=config, model=self.text_decoder)

    def forward(self, image, question, answer=None, train=True):
        """Compute the answer loss (training) or generate answers (inference).

        Args:
            image: Image batch tensor.
            question: Tokenized questions exposing ``input_ids`` and ``attention_mask``.
            answer: Tokenized answers exposing ``input_ids`` and
                ``attention_mask``. Required when ``train`` is true.
            train: ``True`` returns the loss, ``False`` returns generations.

        Returns:
            The decoder loss summed and divided by the batch size when
            ``train`` is true, otherwise ``(topk_ids, topk_probs)`` as returned
            by the beam generator.

        Raises:
            ValueError: If ``train`` is true and ``answer`` is ``None``.
        """
        if train and answer is None:
            raise ValueError("`answer` must be provided when train=True")

        # Match the image dtype to the model weights (e.g. under mixed precision).
        image = image.to(dtype=next(self.parameters()).dtype)
        image_embeds = self.visual_encoder(image)
        # Every visual token is attended to.
        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image.device)

        question_states, question_atts = self._encode_question(question, image_embeds, image_atts)

        if not train:
            return self.generation(question_states=question_states, question_atts=question_atts)

        # Padding positions are set to -100 so they are excluded from the loss.
        answer_target = answer.input_ids.masked_fill(
            answer.input_ids == self.tokenizer.pad_token_id, -100
        )
        answer_output = self.text_decoder(
            input_ids=answer.input_ids,
            attention_mask=answer.attention_mask,
            encoder_hidden_states=question_states,
            encoder_attention_mask=question_atts,
            labels=answer_target,
            return_dict=True,
            reduction='none',
        )
        return answer_output.loss.sum() / image.size(0)

    def _encode_question(self, question, image_embeds, image_atts):
        """Encode the question, fuse it with the image, and concatenate both streams.

        Returns:
            ``(states, attention_mask)`` where the image tokens come first and
            the question tokens second along the sequence dimension.
        """
        text_output = self.text_encoder(
            input_ids=question.input_ids,
            attention_mask=question.attention_mask,
            return_dict=True,
        )
        image_output, question_output = self.fusion_encoder(
            encoder_embeds=text_output.last_hidden_state,
            attention_mask=question.attention_mask,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_atts,
            return_dict=False,
        )
        states = torch.cat([image_output, question_output], dim=1)
        attention_mask = torch.cat([image_atts, question.attention_mask], dim=1)
        return states, attention_mask

    def module_setting(self, config):
        """Load the encoder, fusion and decoder configs from ``config['roberta_config']``."""
        self.config_encoder = BertConfig.from_json_file(config['roberta_config'])
        self.config_encoder.num_hidden_layers = self.config_encoder.text_encoder_layers
        self.config_fusion = BertConfig.from_json_file(config['roberta_config'])
        self.config_decoder = BertConfig.from_json_file(config['roberta_config'])
        self.config_decoder.add_cross_attention = True
        self.config_decoder.num_hidden_layers = self.config_decoder.text_decode_layers

    def generation(self, question_states, question_atts):
        """Run beam search on the fused states; returns ``(topk_ids, topk_scores)``."""
        encoder_inputs = [question_states, question_atts]
        topk_ids, topk_scores = self.beam_generator.translate_batch(encoder_inputs)
        return topk_ids, topk_scores

### Dataset

In [None]:
class Dataset(torch.utils.data.Dataset):
    """Question/image dataset backed by one JSON annotation file.

    The base class is spelled out as ``torch.utils.data.Dataset`` because this
    class reuses the name ``Dataset``; inheriting from the bare name would make
    the class subclass its own previous definition whenever the cell is re-run.

    The JSON file must contain an ``'annotations'`` mapping (question id ->
    record with ``'question'``, ``'image_id'`` and, for training, ``'answer'``)
    and an ``'images'`` mapping (image id as string -> file name).

    Args:
        data_path: Path of the JSON annotation file.
        image_path: Directory that holds the image files.
        is_train: Selects the training transform and the sample layout.
        config: Settings forwarded to ``build_transform``.
    """

    def __init__(self, data_path, image_path, is_train: bool, config) -> None:
        super().__init__()
        self.is_train = is_train
        self.image_transform = build_transform(is_train=is_train, args=config)
        # Close the file deterministically and do not rely on the platform's
        # default encoding when reading the annotations.
        with open(data_path, 'r', encoding='utf-8') as annotation_file:
            self.dataset = json.load(annotation_file)
        self.question_list = list(self.dataset['annotations'].keys())
        self.image_path = image_path

    def __len__(self):
        """Number of questions in the annotation file."""
        return len(self.question_list)

    def _load_image(self, image_id):
        """Resolve ``image_id`` to a file, load it and apply the transform."""
        file_name = self.dataset['images'][str(image_id)]
        # os.path.join uses the running platform's separator instead of a
        # hard-coded backslash.
        image = default_loader(os.path.join(self.image_path, file_name))
        return self.image_transform(image)

    def __getitem__(self, idx):
        """Return one sample.

        Returns:
            ``(image, question, answer)`` when ``is_train`` is true, otherwise
            ``(image, question, question_id)``.
        """
        question_id = self.question_list[idx]
        record = self.dataset['annotations'][question_id]
        image = self._load_image(record['image_id'])

        if self.is_train:
            return image, record['question'], record['answer']
        return image, record['question'], question_id

### Dataloader

In [None]:
def create_loader(dataset: Dataset, is_train: bool, batch_size: int, num_workers: int = 0):
    """Wrap ``dataset`` in a ``DataLoader`` configured for training or evaluation.

    Training loaders shuffle and drop the incomplete last batch; evaluation
    loaders keep the original order and every sample. Memory pinning is always
    requested.

    Args:
        dataset: Dataset to iterate over.
        is_train: Selects the training or the evaluation behaviour.
        batch_size: Number of samples per batch.
        num_workers: Number of loader worker processes.

    Returns:
        The configured ``DataLoader``.
    """
    # Shuffling and dropping the ragged tail are both tied to training mode.
    training_mode = bool(is_train)

    return DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=training_mode,
        num_workers=num_workers,
        pin_memory=True,
        drop_last=training_mode,
    )

### Training with Accelerator notebook

In [None]:
# Root of the data directory. Paths are built with os.path.join so they use the
# running platform's separator; on Windows the resulting strings are the same
# as the former hard-coded backslash paths. The trailing "" keeps the trailing
# separator that `data_path` has always carried.
data_path = os.path.join(".", "data", "")

# Training split: samples are (image, question, answer).
train_dataset = Dataset(
    os.path.join(data_path, "training-data", "vlsp2023_train_data.json"),
    os.path.join(data_path, "training-data", "training-images"),
    is_train=True, config=config,
)
# Development and test splits: samples are (image, question, question_id).
dev_dataset = Dataset(
    os.path.join(data_path, "public-test-data", "vlsp2023_dev_data.json"),
    os.path.join(data_path, "public-test-data", "dev-images"),
    is_train=False, config=config,
)
test_dataset = Dataset(
    os.path.join(data_path, "private-test-data", "vlsp2023_test_data.json"),
    os.path.join(data_path, "private-test-data", "test-images"),
    is_train=False, config=config,
)

train_loader = create_loader(train_dataset, is_train=True, batch_size=config['batch_size'])
dev_loader = create_loader(dev_dataset, is_train=False, batch_size=config['batch_size'])
test_loader = create_loader(test_dataset, is_train=False, batch_size=config['batch_size'])

### Tokenizer

In [9]:
# Load the tokenizer that belongs to the text-encoder checkpoint named in `config`.
# NOTE: the output captured below shows this call failing with a ValueError that asks
# for the `sentencepiece` package; install it in the kernel environment and re-run.
tokenizer = AutoTokenizer.from_pretrained(config['text_encoder'])

ValueError: Couldn't instantiate the backend tokenizer from one of: 
(1) a `tokenizers` library serialization file, 
(2) a slow tokenizer instance to convert or 
(3) an equivalent slow tokenizer class to instantiate and convert. 
You need to have sentencepiece installed to convert a slow tokenizer to a fast one.

### Learning-rate scheduler

In [None]:
def cosine_scheduler(base_value, final_value, epochs, niter_per_ep, warmup_epochs=0,
                     start_warmup_value=0, warmup_steps=-1, sched_type="cos"):
    """Build a per-iteration schedule: linear warm-up followed by a decay.

    Args:
        base_value: Value reached at the end of warm-up / start of the decay.
        final_value: Value approached at the end of the decay.
        epochs: Number of epochs covered by the schedule.
        niter_per_ep: Number of iterations per epoch.
        warmup_epochs: Warm-up length in epochs (used when ``warmup_steps <= 0``).
        start_warmup_value: First value of the warm-up ramp.
        warmup_steps: Warm-up length in iterations; when positive it takes
            precedence over ``warmup_epochs``.
        sched_type: ``"cos"`` for cosine decay or ``"linear"`` for linear decay.

    Returns:
        numpy.ndarray of length ``epochs * niter_per_ep``.

    Raises:
        NotImplementedError: For an unknown ``sched_type``.
        AssertionError: If the warm-up is longer than the whole schedule.
    """
    total_iters = epochs * niter_per_ep
    warmup_iters = warmup_steps if warmup_steps > 0 else warmup_epochs * niter_per_ep

    # Decide on the warm-up from its resolved length. Testing `warmup_epochs`
    # alone would skip the ramp when only `warmup_steps` is given and leave the
    # schedule `warmup_steps` entries short.
    if warmup_iters > 0:
        warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters)
    else:
        warmup_schedule = np.array([])

    decay_iters = total_iters - warmup_iters
    if sched_type == "cos":
        iters = np.arange(decay_iters)
        # max(..., 1) only protects the empty case from a division by zero.
        phase = np.pi * iters / max(len(iters), 1)
        schedule = final_value + 0.5 * (base_value - final_value) * (1 + np.cos(phase))
    elif sched_type == "linear":
        schedule = np.linspace(base_value, final_value, decay_iters)
    else:
        raise NotImplementedError()

    schedule = np.concatenate((warmup_schedule, schedule))

    assert len(schedule) == total_iters, (
        f"schedule has {len(schedule)} entries but {total_iters} iterations were requested"
    )
    return schedule

In [None]:
def is_dist_avail_and_initialized():
    """Return True only when torch.distributed is available and initialised."""
    # Short-circuit: the initialisation query is made only if the package is usable.
    return dist.is_available() and dist.is_initialized()


def get_world_size():
    """Return the distributed world size, or 1 outside a distributed run."""
    return dist.get_world_size() if is_dist_avail_and_initialized() else 1

In [None]:
def main():
    """Train BEIPLUG with Accelerate, evaluate each epoch and checkpoint.

    Reads the notebook-level ``config``, ``tokenizer``, ``train_loader``,
    ``dev_loader`` and ``test_loader``. Every epoch runs a training pass, a
    pass over the development loader and saves the accelerator state; the test
    loader is decoded every ``test_every`` epochs and after the last epoch.
    """
    set_seed(config['seed'])
    # Initialize accelerator
    accelerator = Accelerator(log_with='wandb')
    device = accelerator.device

    # Initialize BEIPLUG model
    model = BEIPLUG(tokenizer=tokenizer, config=config)

    # Compile optimizer
    optimizer = AdamW(model.parameters(), lr=config['lr'])

    # Prepare objects for accelerator. The results get new local names:
    # re-binding `train_loader` & co. here would make them local variables and
    # reading the notebook-level loaders would raise UnboundLocalError.
    model, optimizer, train_dl, dev_dl, test_dl = accelerator.prepare(
        model, optimizer, train_loader, dev_loader, test_loader
    )

    # Initialize experiments tracker (a project name is mandatory).
    accelerator.init_trackers(project_name=config.get('task') or 'beiplug')

    # len(train_dl) already counts batches, so only the accumulation factor
    # separates it from the number of optimizer updates per epoch.
    update_freq = max(1, config['update_freq'])
    num_training_steps_per_epoch = len(train_dl) // update_freq

    # Per-update learning-rate values for the whole run.
    lr_schedule = cosine_scheduler(
        base_value=config['lr'], final_value=config['min_lr'], epochs=config['max_epoch'],
        niter_per_ep=num_training_steps_per_epoch, warmup_epochs=config['warmup_epochs'], warmup_steps=config['warmup_steps']
    )

    start_epoch = config['start_epoch']
    max_epoch = config['max_epoch']
    test_every = 10  # run the test split every this many epochs

    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    accelerator.print(f'Number of training parameters {n_parameters}')
    accelerator.print(f"Batch size {config['batch_size']}")
    accelerator.print('Start training')
    start_time = time.time()

    # Defined up front so the epoch summary works even if a loader is empty.
    training_loss = float('nan')
    development_loss = float('nan')
    all_predict_answer = {}

    for epoch in range(start_epoch, max_epoch):
        # training
        model.train(True)
        optimizer.zero_grad()
        # Index of this epoch's first update inside `lr_schedule`.
        start_steps = epoch * num_training_steps_per_epoch
        for data_iter_step, (image, question, answer) in enumerate(train_dl):
            step = data_iter_step // update_freq
            if step >= num_training_steps_per_epoch:
                # Trailing micro-batches that cannot fill a complete update.
                continue
            global_step = start_steps + step

            # Set the learning rate once per optimizer update.
            if data_iter_step % update_freq == 0:
                for param_group in optimizer.param_groups:
                    # Plain AdamW groups carry no 'lr_scale'; treat it as 1.
                    param_group['lr'] = lr_schedule[global_step] * param_group.get('lr_scale', 1.0)

            question_input = tokenizer(question, padding='longest', truncation=True, return_tensors='pt').to(device)
            answer_input = tokenizer(answer, padding='longest', return_tensors='pt').to(device)
            loss = model(image, question_input, answer_input, train=True)
            training_loss = loss.item()
            accelerator.log({"training_loss": training_loss})
            # Scale so that accumulated gradients average over the micro-batches.
            accelerator.backward(loss / update_freq)
            if (data_iter_step + 1) % update_freq == 0:
                optimizer.step()
                optimizer.zero_grad()

        # evaluating
        # NOTE(review): the third element yielded by `dev_loader` depends on how the
        # dev dataset was built; with `is_train=False` it is the question id, not the
        # reference answer, so this loss would be computed against ids - confirm.
        model.eval()
        for image, question, answer in dev_dl:
            with torch.no_grad():
                question_input = tokenizer(question, padding='longest', truncation=True, return_tensors='pt').to(device)
                answer_input = tokenizer(answer, padding='longest', return_tensors='pt').to(device)
                development_loss = model(image, question_input, answer_input, train=True).item()
            accelerator.log({"development_loss": development_loss})

        accelerator.print(f"Epoch {epoch+1}: Training loss {training_loss} Evaluation loss {development_loss} .")

        # testing: every `test_every` epochs and always after the final epoch.
        # (The former `epoch >= 10 and epoch % 10 == 0` never fired for max_epoch <= 10.)
        if (epoch + 1) % test_every == 0 or epoch + 1 == max_epoch:
            model.eval()
            all_predict_answer = {}
            for image, question, question_ids in test_dl:
                with torch.no_grad():
                    question_input = tokenizer(question, padding='longest', truncation=True, return_tensors='pt').to(device)
                    # `answer` is not needed for generation.
                    topk_ids, topk_probs = model(image, question_input, None, train=False)
                # NOTE(review): gathering assumes tensor-like inputs when several
                # processes are used; verify before raising num_processes above 1.
                topk_ids = accelerator.gather(topk_ids)
                question_ids = accelerator.gather(question_ids)

                for question_id, topk_id in zip(question_ids, topk_ids):
                    # skip_special_tokens strips whatever special tokens the loaded
                    # tokenizer defines; the explicit replaces are kept for BERT-style names.
                    predict_answer = tokenizer.decode(topk_id, skip_special_tokens=True).replace("[SEP]", "").replace("[CLS]", "").replace("[PAD]", "").strip()
                    all_predict_answer[question_id] = predict_answer

        # Save model checkpoint after epoch
        accelerator.save_state(config['checkpoint_dir'])

    accelerator.print("End training")
    end_time = time.time()
    accelerator.print(f"Total time {end_time - start_time}")
    accelerator.end_training()

In [None]:
# Run `main` through Accelerate's notebook launcher with a single process.
notebook_launcher(main, num_processes=1)