## Draft 1: Run SMA in MMF

In [None]:
!pip install --upgrade pip
!pip install --root-user-action=ignore requests
!pip install --root-user-action=ignore
!pip install -q transformers


In [None]:
## Download data

!pip install kaggle

!sudo rm -rf /root/.kaggle


!mkdir ~/.kaggle

%cd .

!cp kaggle.json ~/.kaggle/


!chmod 600 ~/.kaggle/kaggle.json


!sudo chown `whoami` ~/.kaggle/kaggle.json



!export KAGGLE_CONFIG_DIR='/.kaggle/'

%cd .

!kaggle datasets download -d anhnguyen14/vitextcap-combined-data

!ls '/workspace'

%cd /workspace/M4C
!ls


!pip install -q evaluate
!pip install -q rouge_score
!pip install --upgrade nltk

import nltk
nltk.__version__

from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Dataloader

# !unzip -q '/content/drive/MyDrive/ViTextCap/Data/Object Features/vinvl_extraction.zip' -d Object_Features
# !unzip -q /content/drive/MyDrive/ViTextCap/Data/FastText.zip -d FastText
# !unzip -q '/content/drive/MyDrive/ViTextCap/Data/SwimTextSpotter/ocr_1.zip.zip' -d ocr_features
# !unzip -q '/content/drive/MyDrive/ViTextCap/Data/SwimTextSpotter/ocr_2.zip.zip' -d ocr_features

!sudo apt-get update
!sudo apt-get install unzip


%cd .


!ls

!unzip -q '/content/drive/MyDrive/ViTextCap/Data/combined_data.zip' -d combined_data


## .

!pip install pycocoevalcap


!pip -q install evaluate

from pycocoevalcap.cider.cider import Cider

!pip install scikit-learn


from glob import glob
import nltk
import math
import torch
import numpy as np
from torch import nn
from random import choice
from tqdm.auto import tqdm
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from nltk.tokenize import word_tokenize
from torch.nn.utils.rnn import pad_sequence
from transformers import AutoTokenizer, AutoModel
from sklearn.model_selection import train_test_split
nltk.download('punkt')

class ViTextCapsDataset(Dataset):
    """ViTextCaps samples stored one per file as a pickled dict inside ``.npy``.

    ``__getitem__`` returns the per-image features of one file; ``collate_fn``
    expands every image into one example per reference caption and builds the
    padded tensors, target ids and attention masks.

    NOTE(review): files are read with ``allow_pickle=True``, which can execute
    arbitrary code — only point this at trusted data.
    """

    def __init__(self, tokenizer, data_folder_path=None, mask_value=64000.0):
        """
        Args:
            tokenizer: HuggingFace-style tokenizer used to encode caption words.
            data_folder_path: directory with one sample file per image. Required
                (the ``None`` default is kept only for signature compatibility).
            mask_value: sentinel used to pad OCR boxes in ``collate_fn``; padded
                rows are detected by it and then overwritten with 1.

        Raises:
            ValueError: if ``data_folder_path`` is None.
        """
        if data_folder_path is None:
            raise ValueError('data_folder_path must be given')
        # glob() order is filesystem-dependent; sort for a reproducible
        # index -> file mapping.
        self.data_paths = sorted(glob(data_folder_path + '/*'))
        # Stand-in OCR embedding (one 300-d row) for images without OCR tokens.
        self.dummy_tensor = torch.ones((1, 300))
        self.mask_value = mask_value
        self.tokenizer = tokenizer

    def __getitem__(self, idx):
        """Load sample ``idx`` and convert its numeric fields to tensors."""
        sample = np.load(self.data_paths[idx], allow_pickle=True).item()
        obj = sample['obj']
        ocr = sample['ocr']
        fasttext = ocr['fasttext_token']
        return {
            'id': sample['image_id'],
            'captions': sample['captions'],
            'obj_boxes': torch.tensor(obj['boxes']),
            'obj_features': torch.tensor(obj['features']),
            'ocr_texts': ocr['texts'],
            'ocr_boxes': torch.tensor(ocr['boxes']),
            'ocr_token_embeddings': torch.tensor(fasttext) if len(fasttext) > 0 else self.dummy_tensor,
            'ocr_rec_features': torch.tensor(ocr['rec_features']),
            'ocr_det_features': torch.tensor(ocr['det_features']),
        }

    def __len__(self):
        return len(self.data_paths)

    def _encode_captions(self, captions, ocr_texts_per_caption):
        """Encode captions into a padded (num_captions, max_len) tensor of label ids.

        A caption word that appears in its image's OCR tokens but not in the
        tokenizer vocabulary becomes ``tokenizer.vocab_size + 1 + ocr_index``;
        other words become tokenizer ids. Each sequence ends with 2 (<eos>)
        and sequences are padded with 1.
        """
        # get_vocab() builds a new dict on every call; fetch it once per batch
        # instead of once per caption word.
        vocab = self.tokenizer.get_vocab()
        ocr_offset = self.tokenizer.vocab_size + 1
        labels = []
        for caption, ocr_texts in zip(captions, ocr_texts_per_caption):
            ids = []
            for token in word_tokenize(caption):
                if token in ocr_texts and token not in vocab:
                    ids.append(ocr_texts.index(token) + ocr_offset)
                else:
                    # Drop the first/last ids (presumably the special tokens
                    # the tokenizer adds around every input).
                    ids += self.tokenizer(token)['input_ids'][1:-1]
            ids.append(2)  # 2 is <eos> in tokenizer
            labels.append(torch.tensor(ids))
        return pad_sequence(labels, batch_first=True, padding_value=1)

    def collate_fn(self, batch):
        """Collate images into one example per (image, reference caption) pair.

        Returns a dict with the stacked object tensors, padded OCR tensors,
        'labels' (see ``_encode_captions``), 'join_attn_mask' (validity of
        [object | OCR | decoder] positions), 'texts' (OCR tokens) and
        'raw_captions' (the full reference list) for every expanded example.

        NOTE(review): images with no OCR get a (1, 300) dummy embedding but
        their boxes/rec/det features are left as-is — confirm that those
        tensors pad cleanly against the others.
        """
        raw_captions = []
        captions_ = []
        obj_boxes = []
        obj_features = []
        ocr_boxes = []
        ocr_token_embeddings = []
        ocr_rec_features = []
        ocr_det_features = []
        texts_ = []

        for each in batch:
            captions = each['captions']
            n_caps = len(captions)
            # Every expanded example keeps ALL references of its image.
            raw_captions.extend([captions] * n_caps)
            captions_.extend(captions)

            obj_boxes.extend([each['obj_boxes']] * n_caps)
            obj_features.extend([each['obj_features']] * n_caps)

            ocr_boxes.extend([each['ocr_boxes']] * n_caps)
            ocr_token_embeddings.extend([each['ocr_token_embeddings']] * n_caps)
            ocr_rec_features.extend([each['ocr_rec_features']] * n_caps)
            ocr_det_features.extend([each['ocr_det_features']] * n_caps)

            texts_.extend([each['ocr_texts']] * n_caps)

        # Object features have a fixed count per image, so they can be stacked.
        obj_boxes = torch.stack(obj_boxes)
        obj_features = torch.stack(obj_features)

        # OCR counts vary per image: pad to the longest in the batch.
        ocr_boxes = pad_sequence(ocr_boxes, batch_first=True, padding_value=self.mask_value)
        ocr_token_embeddings = pad_sequence(ocr_token_embeddings, batch_first=True, padding_value=1)
        ocr_rec_features = pad_sequence(ocr_rec_features, batch_first=True, padding_value=1)
        ocr_det_features = pad_sequence(ocr_det_features, batch_first=True, padding_value=1)

        labels_ = self._encode_captions(captions_, texts_)

        # 1 for real caption tokens, 0 for padding (pad id 1).
        dec_mask = torch.ones_like(labels_).masked_fill(labels_ == 1, 0)  # batch_size, seq_length

        # OCR rows padded with mask_value are invalid; column 0 gives one flag per OCR token.
        ocr_attn_mask = torch.ones_like(ocr_boxes).masked_fill(ocr_boxes == self.mask_value, 0)[:, :, 0]
        ocr_boxes = ocr_boxes.masked_fill(ocr_boxes == self.mask_value, 1)

        obj_attn_mask = torch.ones(size=(obj_boxes.size(0), obj_boxes.size(1)))  # batch_size, seq_length
        join_attn_mask = torch.cat([obj_attn_mask, ocr_attn_mask, dec_mask], dim=-1)

        return {
            'obj_boxes': obj_boxes,
            'obj_features': obj_features,
            'ocr_boxes': ocr_boxes,
            'ocr_token_embeddings': ocr_token_embeddings,
            'ocr_rec_features': ocr_rec_features,
            'ocr_det_features': ocr_det_features,
            'join_attn_mask': join_attn_mask,
            'labels': labels_,
            'texts': texts_,
            'raw_captions': raw_captions,
        }

# The datasets below need a tokenizer; without this line
# ViTextCapsDataset(tokenizer, ...) raises NameError.
tokenizer = AutoTokenizer.from_pretrained("vinai/phobert-base-v2")
# NOTE(review): convert_prediction_to_ans defaults to an OCR id offset of 64001,
# which must equal tokenizer.vocab_size + 1 as used by collate_fn — confirm.
#phobert_model = AutoModel.from_pretrained("vinai/phobert-base-v2")

# Create Dataset
# NOTE: `train` is rebound later by `def train(...)`; the dataloader below keeps
# its own reference to the training dataset.
train = ViTextCapsDataset(tokenizer, 'combined_data/train')
dev = ViTextCapsDataset(tokenizer, 'combined_data/dev')
test = ViTextCapsDataset(tokenizer, 'combined_data/test')

train_dataloader = DataLoader(train, batch_size=4, shuffle=True, collate_fn=train.collate_fn)
# Evaluation splits are not shuffled so batch composition (and per-batch metrics) is reproducible.
test_dataloader = DataLoader(test, batch_size=4, shuffle=False, collate_fn=test.collate_fn)
dev_dataloader = DataLoader(dev, batch_size=4, shuffle=False, collate_fn=dev.collate_fn)

# Smoke test: collate one training batch.
next(iter(train_dataloader))

import evaluate
bleu = evaluate.load('bleu')
rouge = evaluate.load('rouge')
meteor = evaluate.load('meteor')
cider_scorer = Cider()



class SMA(Pythia):
    """SMA (Structured Multimodal Attentions) model for text-aware captioning.

    Question-conditioned features (s_*), graph outputs (g_*) and updated OCR
    embeddings are projected to the MMT hidden size and fused by ``self.mmt``.
    Every decoding step is scored over the fixed vocabulary
    (``self.classifier``) and over the image's OCR tokens
    (``self.ocr_ptr_net``); the two score sets are concatenated.

    NOTE(review): several names used here are not imported anywhere in the
    visible notebook (Pythia, BertConfig, BertLayerNorm, registry,
    ClassifierLayer, ImageEncoder, ImageEmbedding). ``build`` also does not
    create ``text_bert`` / ``text_embeddings_out_dim``, even though they are
    used. Presumably they come from the Pythia base class or from code outside
    this notebook — confirm before running.
    """

    def __init__(self, config):
        super().__init__(config)

    def build(self):
        """Create all sub-modules (creation order kept stable for reproducible init)."""
        self.mmt_config = BertConfig(**self.config.mmt)
        self.mmt = MMT(self.mmt_config)
        hidden_size = self.mmt_config.hidden_size
        # s_o / s_t are each a concatenation of three tensors in forward()
        # (presumably 1536-d each, hence 3*1536).
        self.so_to_mmt_in = nn.Linear(3*1536, hidden_size)
        self.st_to_mmt_in = nn.Linear(3*1536, hidden_size)
        self.so_layer_norm = BertLayerNorm(hidden_size)
        self.st_layer_norm = BertLayerNorm(hidden_size)
        self.so_drop = nn.Dropout(0.1)
        self.st_drop = nn.Dropout(0.1)
        # g_o arrives 2048-d, g_t 300-d.
        self.linear_go_to_mmt_in = nn.Linear(2048, hidden_size)
        self.linear_gt_to_mmt_in = nn.Linear(300, hidden_size)
        self.go_layer_norm = BertLayerNorm(hidden_size)
        self.gt_layer_norm = BertLayerNorm(hidden_size)
        self.go_drop = nn.Dropout(0.1)
        self.gt_drop = nn.Dropout(0.1)
        self.linear_updated_ocr_to_mmt_in = nn.Linear(300, hidden_size)
        self.updated_ocr_layer_norm = BertLayerNorm(hidden_size)
        self.updated_ocr_drop = nn.Dropout(self.config.ocr.dropout_prob)
        # Projects the concatenated [g_O, g_T] joint embedding back to 768.
        self.linear_joint = nn.Linear(1536,768)
        self.answer_processor = registry.get(self._datasets[0] + "_answer_processor")
        self.ocr_ptr_net = OcrPtrNet(**self.config.classifier.ocr_ptr_net)
        # modules requiring custom learning rates (usually for finetuning)
        self.finetune_modules = []
        self._build_obj_encoding()
        self._build_ocr_encoding()
        self._build_feature_embedding("image")
        self._build_feature_embedding("context")
        num_choices = registry.get(self._datasets[0] + "_num_final_outputs")
        # NOTE(review): the 50 subtracted here presumably reserves the OCR copy
        # slots that ocr_ptr_net scores — confirm against the answer processor.
        # Keep this a ClassifierLayer: forward() calls it to score every
        # decoding step over the fixed vocabulary, and feeds
        # self.classifier.module.weight to the MMT as the answer-embedding
        # table. Replacing it with an nn.LSTM breaks both uses, because an LSTM
        # returns a tuple and has no `.module`.
        self.classifier = ClassifierLayer(
            self.config["classifier"]["type"],
            in_dim=768,
            out_dim=num_choices-50,
            **self.config["classifier"]["params"]
        )

    def _build_feature_embedding(self, prefix):
        """Build ``<prefix>_feature_embedding`` plus its dim attributes.

        Sets ``<prefix>_feature_dim`` from the config and
        ``<prefix>_feature_embeddings_out_dim`` to
        ``embedding.out_dim * <prefix>_feature_dim``.
        """
        feature_dim = self.config[prefix + "_feature_dim"]
        setattr(self, prefix + "_feature_dim", feature_dim)
        embedding_params = self.config[prefix + "_feature_embeddings"][0]
        feature_embedding = ImageEmbedding(
            feature_dim,
            self.text_embeddings_out_dim,
            **embedding_params
        )
        setattr(
            self, prefix + "_feature_embeddings_out_dim",
            feature_embedding.out_dim * feature_dim
        )
        setattr(self, prefix + "_feature_embedding", feature_embedding)

    def _build_obj_encoding(self):
        self.obj_dim = 2048
        # object appearance feature: Faster R-CNN
        self.obj_faster_rcnn_fc7 = ImageEncoder(
            encoder_type='finetune_faster_rcnn_fpn_fc7',
            in_dim=2048,
            weights_file='detectron/fc6/fc7_w.pkl',
            bias_file='detectron/fc6/fc7_b.pkl',
            model_data_dir=self.config["model_data_dir"]
        )
        # apply smaller lr to pretrained Faster R-CNN fc7
        self.finetune_modules.append({
            'module': self.obj_faster_rcnn_fc7,
            'lr_scale': 0.1,
        })
        # OBJ location feature: relative bounding box coordinates (4-dim)
        self.linear_obj_bbox_to_mmt_in = nn.Linear(
            4, self.obj_dim
        )
        self.obj_feat_layer_norm = nn.LayerNorm(self.obj_dim)
        self.obj_bbox_layer_norm = nn.LayerNorm(self.obj_dim)
        self.obj_drop = nn.Dropout(0.1)

    def _build_ocr_encoding(self):
        self.ocr_fastext_dim = 300
        self.ocr_phoc_dim = 604
        self.ocr_RCNN_dim = 2048
        self.transformer_cnn_dim = 512
        # OCR appearance feature: Faster R-CNN
        self.ocr_faster_rcnn_fc7 = ImageEncoder(
            encoder_type='finetune_faster_rcnn_fpn_fc7',
            in_dim=2048,
            weights_file='detectron/fc6/fc7_w.pkl',
            bias_file='detectron/fc6/fc7_b.pkl',
            model_data_dir=self.config["model_data_dir"]
        )
        self.finetune_modules.append({
            'module': self.ocr_faster_rcnn_fc7,
            'lr_scale': 0.1,
        })

        # OCR appearance: FastText + Faster R-CNN fc7 + PHOC + transformer
        # feature (300 + 2048 + 604 + 512), projected to 300-d.
        self.linear_ocr_appear_to_mmt_in = nn.Linear(
            self.ocr_fastext_dim+self.ocr_RCNN_dim+self.ocr_phoc_dim+self.transformer_cnn_dim, self.ocr_fastext_dim
        )
        # OCR location feature: relative bounding box coordinates (4-dim)
        self.linear_ocr_bbox_to_mmt_in = nn.Linear(
            4, self.ocr_fastext_dim
        )
        self.ocr_feat_layer_norm = nn.LayerNorm(self.ocr_fastext_dim)
        self.ocr_bbox_layer_norm = nn.LayerNorm(self.ocr_fastext_dim)
        self.ocr_drop = nn.Dropout(0.1)

    def _forward_obj_encoding(self, sample_list):
        """Encode the first 36 object regions: appearance (fc7) + box location."""
        obj_fc6 = sample_list.image_feature_0[:,:36,:]
        obj_fc7 = self.obj_faster_rcnn_fc7(obj_fc6)
        obj_fc7 = F.normalize(obj_fc7, dim=-1)

        obj_feat = obj_fc7
        obj_bbox = sample_list.obj_bbox[:,:36]
        obj_mmt_in = (
            self.obj_feat_layer_norm(
                obj_feat
            ) + self.obj_bbox_layer_norm(
                self.linear_obj_bbox_to_mmt_in(obj_bbox)
            )
        )
        obj_mmt_in = self.obj_drop(obj_mmt_in)
        return obj_mmt_in

    def _forward_ocr_encoding(self, sample_list):
        """Encode OCR tokens: concatenated appearance features + box location."""
        # OCR FastText feature (300-dim)
        ocr_fasttext = sample_list.context_feature_0
        ocr_fasttext = F.normalize(ocr_fasttext, dim=-1)
        assert ocr_fasttext.size(-1) == 300

        # OCR PHOC feature (604-dim)
        ocr_phoc = sample_list.context_phoc
        ocr_phoc = F.normalize(ocr_phoc, dim=-1)
        assert ocr_phoc.size(-1) == 604

        # OCR appearance feature: Faster R-CNN fc7
        ocr_fc6 = sample_list.image_feature_1[:,:ocr_fasttext.size(1),:]
        ocr_fc7 = self.ocr_faster_rcnn_fc7(ocr_fc6)
        ocr_fc7 = F.normalize(ocr_fc7, dim=-1)
        assert ocr_fc7.size(-1) == 2048

        # OCR appearance feature: Transformer global representation feature
        ocr_trans = sample_list.image_feature_2[:,:ocr_fasttext.size(1),:]
        ocr_trans = F.normalize(ocr_trans, dim=-1)
        assert ocr_trans.size(-1) == 512

        ocr_feat = torch.cat(
            [ocr_fasttext, ocr_fc7, ocr_phoc, ocr_trans],
            dim=-1
        )

        ocr_bbox = sample_list.ocr_bbox.coordinates
        ocr_mmt_in = (
                    self.ocr_feat_layer_norm(
                        self.linear_ocr_appear_to_mmt_in(ocr_feat)
                    ) + self.ocr_bbox_layer_norm(
                        self.linear_ocr_bbox_to_mmt_in(ocr_bbox)
                    )
        )
        ocr_mmt_in = self.ocr_drop(ocr_mmt_in)
        return ocr_mmt_in

    def get_optimizer_parameters(self, config):
        """Return param groups: default-lr params first, then scaled-lr finetune modules."""
        optimizer_param_groups = []

        base_lr = config.optimizer_attributes.params.lr
        # collect all the parameters that need different/scaled lr
        finetune_params_set = set()
        for m in self.finetune_modules:
            optimizer_param_groups.append({
                "params": list(m['module'].parameters()),
                "lr": base_lr * m['lr_scale']
            })
            finetune_params_set.update(m['module'].parameters())
        # remaining_params are those parameters w/ default lr
        remaining_params = [
            p for p in self.parameters() if p not in finetune_params_set
        ]
        # put the default lr parameters at the beginning
        # so that the printed lr (of group 0) matches the default lr
        optimizer_param_groups.insert(0, {"params": remaining_params})

        return optimizer_param_groups

    @staticmethod
    def _all_valid_mask(x):
        """Float32 mask of ones shaped like the first two dims of ``x``."""
        return torch.ones(x.size(0), x.size(1), dtype=torch.float32, device=x.device)

    def _fuse_joint_embedding(self, mmt_results):
        """Combine the MMT s_*/g_* outputs into the projected joint embedding."""
        g_O = mmt_results["mmt_so_output"]*mmt_results["mmt_go_output"]
        g_T = mmt_results["mmt_st_output"]*mmt_results["mmt_gt_output"]
        update_joint_embedding = torch.cat((g_O, g_T),dim=-1) # torch.Size([32, 1, 1536])
        return self.linear_joint(update_joint_embedding)

    def _score_decoder_steps(self, mmt_results, update_joint_embedding, ocr_mask):
        """Score every decoding step over the fixed vocabulary and the OCR tokens.

        The joint embedding replaces the first decoder output; the remaining
        decoder outputs are kept.
        """
        mmt_dec_output = mmt_results["mmt_dec_output"] # torch.Size([32, 12, 768])
        score_feature = torch.cat([update_joint_embedding, mmt_dec_output[:,1:,:]], dim=-2)
        fixed_scores = self.classifier(score_feature)
        dynamic_ocr_scores = self.ocr_ptr_net(
            score_feature, mmt_results["mmt_ocr_output"], ocr_mask
        )
        return torch.cat([fixed_scores, dynamic_ocr_scores], dim=-1)

    def forward(self, sample_list):
        """Return ``{"scores": ...}`` with fixed-vocab and OCR scores per step.

        Training uses teacher forcing with ``train_prev_inds``; evaluation
        decodes greedily for ``train_prev_inds.size(1)`` steps.
        """
        txt_inds = sample_list.text
        txt_mask = _get_mask(sample_list.text_len, sample_list.text.size(1))
        sample_list.text = self.text_bert(txt_inds=txt_inds,txt_mask=txt_mask)

        _, s_o, s_oo, s_ot, s_t, s_tt, s_to = self.process_text_embedding(sample_list)

        obj_encoded_feats = self._forward_obj_encoding(sample_list)
        ocr_encoded_feats = self._forward_ocr_encoding(sample_list)
        g_o = self.process_feature_embedding(
            "image", sample_list, s_o, s_homo=s_oo, s_hetero=s_ot,
            pre_ques_embed=sample_list.text, obj_feats=obj_encoded_feats, ocr_feats=ocr_encoded_feats
        )
        g_t, updated_ocr = self.process_feature_embedding(
            "context", sample_list, s_t, s_homo=s_tt, s_hetero=s_to,
            pre_ques_embed=sample_list.text, obj_feats=obj_encoded_feats, ocr_feats=ocr_encoded_feats
        )  # torch.Size([128, 350])

        s_o = torch.cat((s_o,s_oo,s_ot),dim=-1)
        s_t = torch.cat((s_t,s_tt,s_to),dim=-1)
        s_o = self.so_drop(self.so_layer_norm(self.so_to_mmt_in(s_o.unsqueeze(1))))
        s_t = self.st_drop(self.st_layer_norm(self.st_to_mmt_in(s_t.unsqueeze(1))))
        so_mask = self._all_valid_mask(s_o)
        st_mask = self._all_valid_mask(s_t)
        g_o = self.go_drop(self.go_layer_norm(self.linear_go_to_mmt_in(g_o)))
        g_t = self.gt_drop(self.gt_layer_norm(self.linear_gt_to_mmt_in(g_t)))
        go_mask = self._all_valid_mask(g_o)
        gt_mask = self._all_valid_mask(g_t)

        ocr_emb = self.updated_ocr_drop(self.updated_ocr_layer_norm(self.linear_updated_ocr_to_mmt_in(updated_ocr)))
        ocr_tokens = sample_list.context
        # binary mask of valid OCR vs padding
        ocr_nums = sample_list.context_info_0.max_features
        ocr_mask = _get_mask(ocr_nums, ocr_tokens.size(1))

        mmt_inputs = (s_o, so_mask, s_t, st_mask, g_o, go_mask, g_t, gt_mask)

        if self.training:
            prev_inds = sample_list.train_prev_inds.clone()
            mmt_results = self.mmt(
                *mmt_inputs,
                ocr_emb=ocr_emb,
                ocr_mask=ocr_mask,
                fixed_ans_emb=self.classifier.module.weight,
                prev_inds=prev_inds,
            )
            update_joint_embedding = self._fuse_joint_embedding(mmt_results)
            scores = self._score_decoder_steps(mmt_results, update_joint_embedding, ocr_mask)
        else:
            dec_step_num = sample_list.train_prev_inds.size(1)
            # fill prev_inds with BOS_IDX at index 0, and zeros elsewhere
            prev_inds = torch.zeros_like(
                sample_list.train_prev_inds
            )
            prev_inds[:, 0] = self.answer_processor.BOS_IDX

            # greedy decoding at test time
            for t in range(dec_step_num):
                mmt_results = self.mmt(
                    *mmt_inputs,
                    ocr_emb=ocr_emb,
                    ocr_mask=ocr_mask,
                    fixed_ans_emb=self.classifier.module.weight,
                    prev_inds=prev_inds,
                )
                # The joint embedding is taken from the first decoding pass only.
                if t == 0:
                    update_joint_embedding = self._fuse_joint_embedding(mmt_results)
                scores = self._score_decoder_steps(mmt_results, update_joint_embedding, ocr_mask)
                # find the highest scoring output (either a fixed vocab
                # or an OCR), and add it to prev_inds for auto-regressive
                # decoding
                argmax_inds = scores.argmax(dim=-1)
                prev_inds[:, 1:] = argmax_inds[:, :-1]

        return {"scores": scores}

class TextBert(BertPreTrainedModel):
    """BERT-style text encoder: token embeddings followed by a BertEncoder stack.

    ``forward`` returns the first element of the encoder output (the hidden
    states of the last layer).
    """

    def __init__(self, config):
        super().__init__(config)
        self.embeddings = BertEmbeddings(config)
        self.encoder = BertEncoder(config)
        # Older pytorch_transformers releases used self.apply(self.init_weights) here.
        self.init_weights()

    def forward(self, txt_inds, txt_mask):
        """Encode ``txt_inds``; positions where ``txt_mask`` is 0 are masked out."""
        embedded = self.embeddings(txt_inds)

        # Broadcastable additive mask: 0 where txt_mask is 1, -10000 where it is 0.
        additive_mask = (1.0 - txt_mask.unsqueeze(1).unsqueeze(2)) * -10000.0
        assert not additive_mask.requires_grad

        no_head_mask = [None] * self.config.num_hidden_layers
        hidden_states = self.encoder(
            embedded,
            additive_mask,
            head_mask=no_head_mask
        )[0]
        return hidden_states

class MMT(BertPreTrainedModel):
    """Multimodal transformer over [so, st, go, gt, OCR, decoder] tokens.

    Every position may attend to the valid encoder-side positions. Decoder
    positions additionally attend to earlier decoder positions through a
    causal mask; encoder positions never attend to decoder positions.
    """

    def __init__(self, config):
        super().__init__(config)
        self.prev_pred_embeddings = PrevPredEmbeddings(config)
        self.encoder = BertEncoder(config)
        # Older pytorch_transformers releases used self.apply(self.init_weights) here.
        self.init_weights()

    def forward(self,
                so,so_mask,st,st_mask,
                go,go_mask,gt,gt_mask,
                ocr_emb,
                ocr_mask,
                fixed_ans_emb,
                prev_inds):
        """Run one joint encoder/decoder pass.

        Args:
            so, st, go, gt, ocr_emb: encoder-side embeddings, concatenated along
                dim 1 in this order; the matching *_mask tensors mark valid
                positions with 1.
            fixed_ans_emb: embedding table of the fixed vocabulary, used with
                ``ocr_emb`` to embed ``prev_inds``.
            prev_inds: ids predicted at earlier decoding steps.

        Returns:
            dict holding the full output sequence ('mmt_seq_output') and its
            per-modality slices ('mmt_so_output' ... 'mmt_dec_output').
        """
        dec_emb = self.prev_pred_embeddings(fixed_ans_emb, ocr_emb, prev_inds)
        # Decoder columns start fully masked, so encoder rows cannot see them;
        # the decoder->decoder causal block is filled in further below.
        dec_mask = torch.zeros(dec_emb.size(0), dec_emb.size(1), dtype=torch.float32, device=dec_emb.device)

        segment_masks = [so_mask, st_mask, go_mask, gt_mask, ocr_mask, dec_mask]
        encoder_inputs = torch.cat([so, st, go, gt, ocr_emb, dec_emb], dim=1)
        attention_mask = torch.cat(segment_masks, dim=1)

        # [begin, end) offsets of each encoder-side modality in the joint sequence.
        output_keys = (
            'mmt_so_output',
            'mmt_st_output',
            'mmt_go_output',
            'mmt_gt_output',
            'mmt_ocr_output',
        )
        spans = []
        cursor = 0
        for modality_mask in segment_masks[:-1]:
            stop = cursor + modality_mask.size(-1)
            spans.append((cursor, stop))
            cursor = stop
        dec_len = dec_mask.size(-1)

        # Prefix-LM mask of shape [batch, 1, seq, seq]: every row starts as a
        # copy of the 2-D validity mask, then the decoder block gets a causal mask.
        total_len = attention_mask.size(1)
        square_mask = attention_mask.unsqueeze(1).unsqueeze(2).repeat(1, 1, total_len, 1)
        square_mask[:, :, -dec_len:, -dec_len:] = _get_causal_mask(dec_len, encoder_inputs.device)

        # Flip to an additive mask: allowed pairs -> 0, blocked pairs -> -10000.
        square_mask = (1.0 - square_mask) * -10000.0
        assert not square_mask.requires_grad

        encoder_outputs = self.encoder(
            encoder_inputs,
            square_mask,
            head_mask=[None] * self.config.num_hidden_layers
        )
        seq_out = encoder_outputs[0]

        results = {'mmt_seq_output': seq_out}
        for key, (start, stop) in zip(output_keys, spans):
            results[key] = seq_out[:, start:stop]
        results['mmt_dec_output'] = seq_out[:, -dec_len:]
        return results

class OcrPtrNet(nn.Module):
    """Pointer network scoring each query against every OCR key.

    Scores are scaled dot products of linearly projected queries and keys.
    Keys whose ``attention_mask`` entry is 0 receive an additive -10000.
    """

    def __init__(self, hidden_size, query_key_size=None):
        super().__init__()
        self.hidden_size = hidden_size
        self.query_key_size = hidden_size if query_key_size is None else query_key_size
        self.query = nn.Linear(hidden_size, self.query_key_size)
        self.key = nn.Linear(hidden_size, self.query_key_size)

    def forward(self, query_inputs, key_inputs, attention_mask):
        """Return scores of shape (batch, n_query, n_key), or (batch, n_key) for 2-D queries.

        ``attention_mask`` must be 2-D (batch, n_key), with 1 marking a valid key.
        """
        mask_bias = (1.0 - attention_mask) * -10000.0
        assert mask_bias.dim() == 2
        mask_bias = mask_bias.unsqueeze(1)

        projected_query = self.query(query_inputs)
        single_query = projected_query.dim() == 2
        if single_query:
            projected_query = projected_query.unsqueeze(1)
        projected_key = self.key(key_inputs)

        logits = projected_query @ projected_key.transpose(-1, -2) / math.sqrt(self.query_key_size)
        logits = logits + mask_bias
        return logits.squeeze(1) if single_query else logits

class PrevPredEmbeddings(nn.Module):
    """Embed the ids predicted at previous decoding steps.

    Each id in ``prev_inds`` selects a row of [fixed-vocab embeddings;
    this example's OCR embeddings]. Position embeddings and token-type
    embeddings (0 = vocab id, 1 = OCR id) are added on top.
    """

    def __init__(self, config):
        super().__init__()
        max_dec_length = 100
        max_type_num = 5
        hidden = config.hidden_size
        eps = config.layer_norm_eps

        self.position_embeddings = nn.Embedding(max_dec_length, hidden)
        self.token_type_embeddings = nn.Embedding(max_type_num, hidden)

        self.ans_layer_norm = BertLayerNorm(hidden, eps=eps)
        self.ocr_layer_norm = BertLayerNorm(hidden, eps=eps)
        self.emb_layer_norm = BertLayerNorm(hidden, eps=eps)
        self.emb_dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, ans_emb, ocr_emb, prev_inds):
        """Return decoder input embeddings, one per entry of ``prev_inds``.

        ``ans_emb`` is a 2-D (vocab, hidden) table; ``prev_inds`` is a 2-D long
        tensor in which ids >= vocab point into ``ocr_emb``.
        """
        assert prev_inds.dim() == 2 and prev_inds.dtype == torch.long
        assert ans_emb.dim() == 2

        batch_size, seq_length = prev_inds.size(0), prev_inds.size(1)
        vocab_count = ans_emb.size(0)

        # Normalise both tables separately so they share a common scale before concatenation.
        ans_normed = self.ans_layer_norm(ans_emb)
        ocr_normed = self.ocr_layer_norm(ocr_emb)
        assert ans_normed.size(-1) == ocr_normed.size(-1)
        lookup_table = torch.cat(
            [ans_normed.unsqueeze(0).expand(batch_size, -1, -1), ocr_normed],
            dim=1,
        )
        raw_dec_emb = _batch_gather(lookup_table, prev_inds)

        positions = torch.arange(seq_length, dtype=torch.long, device=ocr_normed.device)
        positions = positions.unsqueeze(0).expand(batch_size, seq_length)
        token_types = prev_inds.ge(vocab_count).long()
        extra = self.position_embeddings(positions) + self.token_type_embeddings(token_types)
        extra = self.emb_dropout(self.emb_layer_norm(extra))

        return raw_dec_emb + extra

def _get_mask(nums, max_num):
    # non_pad_mask: b x lq, torch.float32, 0. on PAD
    batch_size = nums.size(0)
    arange = torch.arange(0, max_num).unsqueeze(0).expand(batch_size, -1)
    non_pad_mask = arange.to(nums.device).lt(nums.unsqueeze(-1))
    non_pad_mask = non_pad_mask.type(torch.float32)
    return non_pad_mask

@functools.lru_cache(maxsize=32)
def _get_causal_mask(seq_length, device):
    # generate a lower triangular mask
    mask = torch.zeros(seq_length, seq_length, device=device)
    for i in range(seq_length):
        for j in range(i+1):
            mask[i, j] = 1.
    return mask

def _batch_gather(x, inds):
    assert x.dim() == 3
    batch_size = x.size(0)
    length = x.size(1)
    dim = x.size(2)
    x_flat = x.view(batch_size*length, dim)

    batch_offsets = torch.arange(batch_size, device=inds.device) * length
    batch_offsets = batch_offsets.unsqueeze(-1)
    assert batch_offsets.dim() == inds.dim()
    inds_flat = batch_offsets + inds
    results = F.embedding(inds_flat, x_flat)
    return results

def convert_prediction_to_ans(prediction_list, sample_texts, tokenizer, ocr_offset=64001):
    """Turn predicted id sequences into caption strings.

    Ids below ``ocr_offset`` are vocabulary ids; each consecutive run of them is
    decoded with ``tokenizer.decode(..., skip_special_tokens=True)``. An id
    ``w >= ocr_offset`` copies the OCR token ``text[w - ocr_offset]`` of the
    same example. OCR ids pointing past the end of that example's OCR list
    (e.g. padded OCR slots) are skipped instead of raising IndexError.

    Args:
        prediction_list: one sequence of integer ids per example.
        sample_texts: one list of OCR token strings per example.
        tokenizer: object providing ``decode(ids, skip_special_tokens=...)``.
        ocr_offset: first id denoting an OCR token. The default 64001 equals the
            ``tokenizer.vocab_size + 1`` offset used by
            ``ViTextCapsDataset.collate_fn`` for a 64000-entry vocabulary.

    Returns:
        List of caption strings, one per example.
    """
    def _decode(ids):
        # Skip special tokens: every label ends with <eos> (id 2), so a correct
        # prediction would otherwise leave a literal "</s>" in the caption.
        return tokenizer.decode(ids, skip_special_tokens=True)

    raw_predicts = []
    for token_ids, ocr_tokens in zip(prediction_list, sample_texts):
        pieces = []
        vocab_run = []

        for w in token_ids:
            if w >= ocr_offset:  # OCR token
                if vocab_run:  # flush the preceding vocabulary run first
                    pieces.append(_decode(vocab_run))
                    vocab_run = []
                ocr_index = w - ocr_offset
                if ocr_index < len(ocr_tokens):
                    pieces.append(ocr_tokens[ocr_index])
            else:
                vocab_run.append(w)

        if vocab_run:
            pieces.append(_decode(vocab_run))

        # Drop empty pieces (e.g. a run consisting only of special tokens).
        raw_predicts.append(' '.join(p for p in pieces if p))

    return raw_predicts

def ignore_padding(labels, outputs, padding_values):
    """Drop padded positions from each row of ``outputs``.

    labels / outputs: tensors whose first dim is the batch and whose rows
    line up elementwise (labels[b] masks outputs[b]).
    padding_values: label value that marks padding.

    Returns a list holding, for each row b of ``labels``, the Python list of
    entries of outputs[b] at positions where labels[b] != padding_values.
    """
    keep = labels != padding_values
    return [outputs[row][row_keep].tolist() for row, row_keep in enumerate(keep)]

def compute_metrics(outputs, labels, padding_values, raw_captions, ocr_tokens, tokenizer):
    """Compute token accuracy and caption metrics for one batch.

    Args:
        outputs: logits whose last dim spans vocabulary + OCR slots; the
            argmax over it gives the predicted token ids.
        labels: target ids with the same shape as ``outputs`` minus its last
            dim; positions equal to ``padding_values`` are ignored.
        padding_values: label id that marks padding.
        raw_captions: per-sample lists of reference caption strings.
        ocr_tokens: per-sample OCR token lists used to decode OCR-copy ids.
        tokenizer: decodes vocabulary ids to text.

    Returns:
        [acc, bleu1, bleu2, bleu3, bleu4, rougeL, meteor, cider]. ``acc`` is
        0.0 when the batch has no non-padding label (this used to raise
        ZeroDivisionError).

    Relies on the module-level metric objects ``bleu``, ``rouge``, ``meteor``
    and ``cider_scorer``.
    """
    # Computed once; reused for the decoded captions and for token accuracy.
    pred_ids = outputs.argmax(dim=-1)

    pred_ans = convert_prediction_to_ans(
        ignore_padding(labels, pred_ids, padding_values), ocr_tokens, tokenizer)

    # Spot-check: print one sample for roughly a quarter of batches
    # (P(N(0, 1) > 0.7) is about 0.24).
    check = np.random.randn()
    if check > 0.7:
        print('-'*30)
        print(f'prediction: {pred_ans[0]}')
        print(f'ground-tru: {raw_captions[0][0]}')

    bleu_scores = [
        bleu.compute(predictions=pred_ans, references=raw_captions, max_order=order)['bleu']
        for order in (1, 2, 3, 4)
    ]
    rouge_score = rouge.compute(predictions=pred_ans, references=raw_captions)['rougeL']
    meteor_score = meteor.compute(predictions=pred_ans, references=raw_captions)['meteor']

    # CIDEr takes {id: [hypothesis]} and {id: [references...]}.
    hypotheses_dict = {i: [h] for i, h in enumerate(pred_ans)}
    references_dict = {i: r for i, r in enumerate(raw_captions)}
    cider_score, _ = cider_scorer.compute_score(references_dict, hypotheses_dict)

    # Token accuracy over non-padding positions. reshape (not view) also
    # copes with non-contiguous inputs.
    flat_labels = labels.reshape(-1)
    mask = flat_labels != padding_values
    n_tokens = int(mask.sum())
    if n_tokens == 0:
        acc = 0.0
    else:
        acc = (pred_ids.reshape(-1)[mask] == flat_labels[mask]).sum().item() / n_tokens

    return [acc, *bleu_scores, rouge_score, meteor_score, cider_score]

def _mean_metrics(batch_metrics, metric_names):
    """Average per-batch metric rows column-wise.

    batch_metrics: list of equal-length rows, one per batch, columns in the
    order returned by compute_metrics. metric_names: name of each column.
    Returns {name: mean over batches}. An empty ``batch_metrics`` raises
    IndexError (np.array([]) is 1-D), as before.
    """
    table = np.array(batch_metrics)
    return {name: table[:, col].mean() for col, name in enumerate(metric_names)}


def train(model, train_dataloader, test_dataloader, criterion, optimizer, tokenizer, epochs, device='cpu'):
    """Train ``model`` for ``epochs`` epochs, evaluating after every epoch.

    Per training batch: ``outputs = model(sample, device)``, loss =
    ``criterion(outputs.mT, labels)``, backward, gradient-norm clipping to 1,
    then ``optimizer.step()``. Once per epoch, right after the batch at 80% of
    the training loader, every param group's learning rate is multiplied by
    0.95. After evaluation, epochs 0, 5, 10, ... are checkpointed to
    ``model_epoch_{epoch}.pt``.

    Args:
        model: module called as ``model(sample, device)``; returns logits of
            shape (batch, T, vocab + ocr_tokens) per the inline comment.
        train_dataloader / test_dataloader: yield dict samples with keys
            'labels', 'raw_captions' and 'texts'.
        criterion: loss taking (outputs.mT, labels).
        optimizer: torch optimizer over ``model``'s parameters.
        tokenizer: passed to compute_metrics to decode predictions.
        epochs: number of epochs.
        device: device for the model and labels.

    Returns:
        (train_losses, train_evals, test_losses, test_evals): lists of
        per-epoch mean losses and dicts mapping metric name -> list of
        per-epoch means.
    """
    metric_names = ('accuracy', 'bleu1', 'bleu2', 'bleu3', 'bleu4', 'rouge', 'meteor', 'cider')
    label_pad = 1  # label id treated as padding by compute_metrics

    # Send model to device
    model = model.to(device)

    train_losses, test_losses = [], []
    train_evals = {name: [] for name in metric_names}
    test_evals = {name: [] for name in metric_names}

    for epoch in range(epochs):
        train_loss, test_loss = 0, 0
        train_batches, test_batches = [], []
        # 1-based batch index after which the lr is decayed once this epoch.
        decay_step = round(len(train_dataloader) * 0.8)

        ### Train ###
        model.train()
        for step, sample in enumerate(tqdm(train_dataloader), start=1):
            outputs = model(sample, device)  # batch_size, T, vocab + ocr_tokens
            labels = sample['labels'].to(device)  # transferred once, reused below
            loss = criterion(outputs.mT, labels)
            optimizer.zero_grad()  # clear gradients left over from the previous step

            train_loss += loss.item()
            train_batches.append(compute_metrics(outputs, labels, label_pad, sample['raw_captions'], sample['texts'], tokenizer))

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1)
            optimizer.step()

            if step == decay_step:
                for group in optimizer.param_groups:
                    group['lr'] *= 0.95

        ### Evaluate ###
        model.eval()
        with torch.inference_mode():
            for sample in tqdm(test_dataloader):
                outputs = model(sample, device)  # batch_size, T, vocab + ocr_tokens
                labels = sample['labels'].to(device)
                loss = criterion(outputs.mT, labels)

                test_loss += loss.item()
                test_batches.append(compute_metrics(outputs, labels, label_pad, sample['raw_captions'], sample['texts'], tokenizer))

        # Save checkpoints. 'loss' is only the most recent batch loss
        # (normally the last evaluation batch), not an epoch average.
        if epoch % 5 == 0:
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': loss.item(),
            }, f"model_epoch_{epoch}.pt")

        test_loss = test_loss / len(test_dataloader)
        train_loss = train_loss / len(train_dataloader)
        tr = _mean_metrics(train_batches, metric_names)
        te = _mean_metrics(test_batches, metric_names)

        train_losses.append(train_loss)
        test_losses.append(test_loss)
        for name in metric_names:
            train_evals[name].append(tr[name])
            test_evals[name].append(te[name])

        # Tracking the model
        print(f'Epoch: {epoch}')
        print(f'Train loss    : {train_loss:.4f}  | Test loss   : {test_loss:.4f}')
        print(f"Train acc     : {tr['accuracy']:.4f}   | Test acc    : {te['accuracy']:.4f}")
        print(f"Train bleu1   : {tr['bleu1']:.4f} | Test bleu1  : {te['bleu1']:.4f}")
        print(f"Train bleu2   : {tr['bleu2']:.4f} | Test bleu2  : {te['bleu2']:.4f}")
        print(f"Train bleu3   : {tr['bleu3']:.4f} | Test bleu3  : {te['bleu3']:.4f}")
        print(f"Train bleu4   : {tr['bleu4']:.4f} | Test bleu4  : {te['bleu4']:.4f}")
        print(f"Train rouge   : {tr['rouge']:.4f} | Test rouge  : {te['rouge']:.4f}")
        print(f"Train meteor  : {tr['meteor']:.4f}| Test meteor : {te['meteor']:.4f}")
        print(f"Train cider   : {tr['cider']:.4f} | Test cider  : {te['cider']:.4f}")

    return train_losses, train_evals, test_losses, test_evals

def initialize_weights(m):
    """Xavier-uniform initialise ``m.weight`` in place when it is a >=2-D tensor.

    Meant for ``model.apply(initialize_weights)``. Modules without a
    ``weight`` attribute, with ``weight = None`` (e.g.
    ``nn.LayerNorm(..., elementwise_affine=False)``, which previously crashed
    with AttributeError), or with a 1-D weight are left untouched. Biases are
    never modified.
    """
    weight = getattr(m, 'weight', None)
    if isinstance(weight, torch.Tensor) and weight.dim() > 1:
        # nn.init runs under no_grad itself, so the Parameter can be passed
        # directly instead of going through the legacy `.data` alias.
        nn.init.xavier_uniform_(weight)



def loss_fn(outputs, labels, padding_value):
    """Positive-weighted one-vs-all BCE over the joint vocab + OCR label space.

    outputs: logits whose last dim is the vocab + OCR size; every leading dim
    is flattened. labels: class ids with outputs' leading shape; positions
    equal to ``padding_value`` are dropped.

    Each kept label becomes a one-hot target row. The positive class is
    weighted by (#zero targets / #one targets) so the single 1 per row is
    not drowned out by the zeros. Returns the summed BCEWithLogits loss.
    """
    num_classes = outputs.size(-1)
    flat_logits = outputs.reshape(-1, num_classes)
    flat_labels = labels.reshape(-1)

    keep = flat_labels != padding_value
    logits = flat_logits[keep]
    target_ids = flat_labels[keep]

    # One-hot targets: row r has a single 1 at column target_ids[r].
    targets = torch.zeros_like(logits)
    targets[torch.arange(len(target_ids)), target_ids] = 1

    n_neg = (targets == 0.).sum()
    n_pos = targets.sum()
    bce = nn.BCEWithLogitsLoss(pos_weight=n_neg / n_pos, reduction='sum')
    return bce(logits, targets)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
device  # notebook cell output: show the chosen device

# Freeze PhoBERT's word-embedding table; M4C receives it as `fixed_ans_emb`.
# Assigning `.requires_grad = False` on the *module* (as this line used to do)
# only creates a plain Python attribute and freezes nothing.
# `Module.requires_grad_(False)` clears the flag on every parameter of the
# module, including `.weight`.
phobert_model.embeddings.word_embeddings.requires_grad_(False)
fixed_ans_emb = phobert_model.embeddings.word_embeddings.weight
model = M4C(obj_in_dim=1024,
            ocr_in_dim=812,
            hidden_size=768,
            n_heads=12,
            d_k=64,
            n_layers=4,
            vocab_size=tokenizer.vocab_size + 1,
            fixed_ans_emb=fixed_ans_emb)
# NOTE(review): if M4C registers fixed_ans_emb as one of its own parameters,
# this apply() will Xavier-reinitialise the pretrained embedding in place —
# confirm that is intended.
model.apply(initialize_weights);
# Alternative BCE-based loss (label id 1 = padding); the training run below
# uses a CrossEntropyLoss `criterion` instead.
loss_function = lambda outputs, labels: loss_fn(outputs, labels, padding_value=1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)


# Noam-style learning-rate multiplier: grows linearly for `warmup` steps,
# then decays as step ** -0.5, scaled by model_size ** -0.5 and `factor`.
factor = 0.1
model_size = 768
warmup = 400
def warmup_lr(current_epoch):
    """Return the LR multiplier for the zero-based step ``current_epoch``.

    The step is shifted to 1-based so the first call (0) avoids 0 ** -0.5.
    The two branches of the min() cross at step == warmup.
    """
    step = current_epoch + 1
    decay = step ** (-0.5)
    ramp = step * warmup ** (-1.5)
    return factor * (model_size ** (-0.5) * min(decay, ramp))

# Create a LambdaLR scheduler for learning rate warm-up.
# LambdaLR sets each param group's lr to its initial lr (1e-4 from the Adam
# optimizer created earlier) times warmup_lr(step).
# NOTE(review): this scheduler only drives the preview loop below. It is not
# passed to train(), and `optimizer` is rebuilt a few lines down, so training
# does not use this warm-up schedule.
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=warmup_lr)

import matplotlib.pyplot as plt

# Preview the warm-up multiplier curve for steps 1..19999.
plt.plot(np.arange(1, 20000), [warmup_lr(i) for i in range(1, 20000)])

# Step the preview scheduler 24 times and print the resulting lr (demo only;
# it mutates the optimizer that is replaced below).
for epoch in range(1, 25):
    scheduler.step()
    print('Epoch {}, lr {}'.format(epoch, optimizer.param_groups[0]['lr']))

base_lr = 1e-4
# Label id 1 is padding and is excluded from the loss.
criterion = nn.CrossEntropyLoss(ignore_index=1)
# Fresh optimizer over the parameter groups the model class defines
# (get_optimizer_parameters lives on M4C, not visible here — presumably
# per-module lr groups; TODO confirm).
optimizer = torch.optim.Adam(model.get_optimizer_parameters(base_lr), lr=base_lr)

# train() returns (losses, metric dict, losses, metric dict). Despite their
# names, train_accies / test_accies hold dicts mapping metric name ->
# per-epoch means. Validation runs on the dev split.
train_losses, train_accies, test_losses, test_accies = train(model,
                                                             train_dataloader,
                                                             dev_dataloader,
                                                             criterion,
                                                             optimizer,
                                                             tokenizer,
                                                             epochs=20,
                                                             device=device)

# Bundle a model object, the trained weights and the optimizer state into one
# file. 'model' is a *freshly constructed* M4C with the same hyper-parameters
# (its own weights are not the trained ones); the trained weights are in
# 'state_dict'.
# NOTE(review): building a second full model just to pickle it costs memory,
# and pickling an nn.Module ties the checkpoint to the class definition.
# Consider saving the constructor kwargs instead, after checking whatever
# loads checkpoint.pth.
checkpoint = {'model': M4C(obj_in_dim=1024,
            ocr_in_dim=812,
            hidden_size=768,
            n_heads=12,
            d_k=64,
            n_layers=4,
            vocab_size=tokenizer.vocab_size + 1,
            fixed_ans_emb=fixed_ans_emb),
              'state_dict': model.state_dict(),
              'optimizer' : optimizer.state_dict()}

torch.save(checkpoint, 'checkpoint.pth')

# Generate x-axis values (epochs): 1..number of recorded epochs.
epochs = range(1, len(train_losses) + 1)
def plot_(epochs, train_losses, ylabel='Training Loss', title='Training Loss Over Epochs'):
    """Plot one per-epoch series as a marked line and show the figure.

    Args:
        epochs: x values (epoch numbers).
        train_losses: y values; any 1-D per-epoch series works, despite the
            parameter name.
        ylabel: y-axis label; defaults to the original training-loss wording.
        title: figure title; defaults to the original training-loss wording.
    """
    plt.plot(epochs, train_losses, marker='o', linestyle='-')
    plt.xlabel('Epoch')
    plt.ylabel(ylabel)
    plt.title(title)
    plt.grid(True)
    plt.show()

# Training-loss curve.
# (The former extra call `plot_(epochs, train_accies)` was removed:
# train_accies is a dict of metric lists, which plt.plot cannot draw as one
# y-series and which raised an error. The per-metric figures below cover it.)
plot_(epochs, train_losses)


import matplotlib.pyplot as plt


def _plot_metric_curves(metric_history, title):
    """Draw every tracked metric against epoch number on one figure.

    Args:
        metric_history: dict mapping metric name -> list of per-epoch values,
            as returned by train() / test_fn().
        title: figure title.

    Returns:
        The x-axis range used (1..number of epochs).
    """
    # (metric key, matplotlib colour, legend label)
    styles = (
        ('accuracy', 'b', 'Accuracy'),
        ('bleu1', 'g', 'BLEU-1'),
        ('bleu2', 'r', 'BLEU-2'),
        ('bleu3', 'c', 'BLEU-3'),
        ('bleu4', 'm', 'BLEU-4'),
        ('rouge', 'y', 'ROUGE'),
        ('meteor', 'k', 'METEOR'),
        ('cider', 'orange', 'CIDEr'),
    )
    x = range(1, len(metric_history['accuracy']) + 1)

    plt.figure(figsize=(10, 6))
    for key, colour, label in styles:
        plt.plot(x, metric_history[key], colour, label=label)

    plt.title(title)
    plt.xlabel('Epochs')
    plt.ylabel('Score')
    plt.legend()
    plt.grid(True)
    plt.show()
    return x


epochs = _plot_metric_curves(train_accies, 'TRAINING - Metrics Over Epochs')
epochs = _plot_metric_curves(test_accies, 'VALIDATION - Metrics Over Epochs')


def test_fn(model, test_dataloader, criterion, optimizer, tokenizer, device='cpu'):
    """Evaluate ``model`` on ``test_dataloader`` and print averaged metrics.

    Every batch is run as ``model(sample, device, test=True)`` under
    ``torch.inference_mode()``. The function accumulates
    ``criterion(outputs.mT, labels)`` and the per-batch metrics from
    compute_metrics, with label id 1 treated as padding.

    Args:
        model: module to evaluate; moved to ``device`` and put in eval mode.
        test_dataloader: yields dict samples with keys 'labels',
            'raw_captions' and 'texts'.
        criterion: loss taking (outputs.mT, labels).
        optimizer: unused; kept so existing call sites keep working.
        tokenizer: passed to compute_metrics to decode predictions.
        device: device for the model and labels.

    Returns:
        (test_loss, test_evals): mean per-batch loss, and a dict mapping each
        metric name to a one-element list holding its mean over batches.

    Raises:
        ValueError: if ``test_dataloader`` yields no batches.
    """
    metric_names = ('accuracy', 'bleu1', 'bleu2', 'bleu3', 'bleu4', 'rouge', 'meteor', 'cider')
    # Short names used in the printed report, left-aligned to width 7.
    report_labels = ('acc', 'bleu1', 'bleu2', 'bleu3', 'bleu4', 'rouge', 'meteor', 'cider')

    model.to(device)
    model.eval()
    test_loss = 0
    batch_metrics = []

    with torch.inference_mode():
        for sample in tqdm(test_dataloader):
            outputs = model(sample, device, test=True)  # batch_size, T, vocab + ocr_tokens
            labels = sample['labels'].to(device)  # transferred once, reused below
            loss = criterion(outputs.mT, labels)

            test_loss += loss.item()
            batch_metrics.append(compute_metrics(outputs, labels, 1, sample['raw_captions'], sample['texts'], tokenizer))

    if not batch_metrics:
        raise ValueError('test_dataloader produced no batches; nothing to evaluate')

    test_loss = test_loss / len(test_dataloader)
    table = np.array(batch_metrics)
    means = [table[:, col].mean() for col in range(len(metric_names))]
    test_evals = {name: [value] for name, value in zip(metric_names, means)}

    print(f'Test loss   : {test_loss:.4f}')
    for label, value in zip(report_labels, means):
        print(f'Test {label:<7}: {value:.4f}')

    return test_loss, test_evals

# Evaluate the trained model on the held-out test split. test_fn calls the
# model with test=True and prints the loss plus every caption metric.
model = model.to(device)
test_fn(model, test_dataloader, criterion, optimizer, tokenizer, device=device)



In [None]:
class Pythia(BaseModel):
    """Pythia-style model whose feature-embedding step also builds relation
    (edge) features between detected objects ("o") and OCR tokens ("t").

    build() creates all sub-modules from ``self.config`` and the global
    registry: a word embedding, text embeddings, image feature encoders and
    embeddings, an image/text combine layer and a classifier.

    NOTE(review): the pieces of this class do not currently fit together —
    see the note in forward().
    """

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self._global_config = registry.get("config")
        # Comma-separated dataset names; the first one is used for registry
        # lookups (text processor, number of outputs).
        self._datasets = self._global_config.datasets.split(",")

    def build(self):
        """Create every sub-module from config.

        Order matters: _init_feature_embeddings reads text_embeddings_out_dim
        (set by _init_text_embeddings), and the classifier's input size is
        read from the combine layer.
        """
        self._build_word_embedding()
        self._init_text_embeddings("text")
        self._init_feature_encoders("image")
        self._init_feature_embeddings("image")
        self._init_combine_layer("image", "text")
        self._init_classifier(self._get_classifier_input_dim())
        self._init_extras()

    def _build_word_embedding(self):
        """Build ``self.word_embedding`` (300-d) from the first dataset's
        text-processor vocabulary."""
        assert len(self._datasets) > 0
        text_processor = registry.get(self._datasets[0] + "_text_processor")
        vocab = text_processor.vocab
        self.word_embedding = vocab.get_embedding(torch.nn.Embedding, embedding_dim=300)

    def _init_text_embeddings(self, attr="text"):
        """Build one TextEmbedding per entry of ``config[<attr>_embeddings]``.

        Stores them as an nn.ModuleList in ``self.<attr>_embeddings`` and the
        sum of their ``text_out_dim`` in ``self.<attr>_embeddings_out_dim``.
        """
        if "embeddings" not in attr:
            attr += "_embeddings"

        text_embeddings = []
        text_embeddings_list_config = self.config[attr]

        embeddings_out_dim = 0

        for text_embedding in text_embeddings_list_config:
            embedding_type = text_embedding.type
            embedding_kwargs = ConfigNode(text_embedding.params)

            self._update_text_embedding_args(embedding_kwargs)

            embedding = TextEmbedding(embedding_type, **embedding_kwargs)

            text_embeddings.append(embedding)
            embeddings_out_dim += embedding.text_out_dim

        setattr(self, attr + "_out_dim", embeddings_out_dim)
        setattr(self, attr, nn.ModuleList(text_embeddings))

    def _update_text_embedding_args(self, args):
        """Add ``config["model_data_dir"]`` to the embedding kwargs (mutates ``args``)."""
        args["model_data_dir"] = self.config["model_data_dir"]

    def _init_feature_encoders(self, attr):
        """Build ImageEncoders from ``config[<attr>_feature_encodings]``.

        ``self.<attr>_feature_dim`` starts as ``config[<attr>_feature_dim]`` and
        is overwritten by each encoder's ``out_dim``, so it ends up as the last
        encoder's output width. ``encoder_kwargs`` is presumably the config's
        own params mapping, so ``model_data_dir`` may be written back into the
        config — confirm.
        """
        feat_encoders = []
        feat_encoders_list_config = self.config[attr + "_feature_encodings"]
        feature_dim = self.config[attr + "_feature_dim"]
        setattr(self, attr + "_feature_dim", feature_dim)

        for feat_encoder in feat_encoders_list_config:
            encoder_type = feat_encoder["type"]
            encoder_kwargs = feat_encoder["params"]
            encoder_kwargs["model_data_dir"] = self.config["model_data_dir"]

            feat_model = ImageEncoder(encoder_type, feature_dim, **encoder_kwargs)

            feat_encoders.append(feat_model)
            setattr(self, attr + "_feature_dim", feat_model.out_dim)

        setattr(self, attr + "_feature_encoders", nn.ModuleList(feat_encoders))

    def _init_feature_embeddings(self, attr):
        """Build ImageEmbedding lists, one list per configured feature encoding.

        Each list has one ImageEmbedding per entry of
        ``config[<attr>_feature_embeddings]``. Stores them in
        ``self.<attr>_feature_embeddings_list``.
        ``self.<attr>_feature_embeddings_out_dim`` is set to
        (sum of all embeddings' out_dim) * ``self.<attr>_feature_dim``.
        """
        feature_embeddings_list = []
        num_feature_feat = len(
            getattr(self.config, "{}_feature_encodings".format(attr))
        )

        # Temporary accumulator, removed again at the end of this method.
        self.feature_embeddings_out_dim = 0

        for _ in range(num_feature_feat):
            feature_embeddings = []
            feature_attn_model_list = self.config[attr + "_feature_embeddings"]

            for feature_attn_model_params in feature_attn_model_list:
                feature_embedding = ImageEmbedding(
                    getattr(self, attr + "_feature_dim"),
                    self.text_embeddings_out_dim,
                    **feature_attn_model_params
                )
                feature_embeddings.append(feature_embedding)
                self.feature_embeddings_out_dim += feature_embedding.out_dim

            feature_embeddings = nn.ModuleList(feature_embeddings)
            feature_embeddings_list.append(feature_embeddings)

        self.feature_embeddings_out_dim *= getattr(self, attr + "_feature_dim")

        setattr(
            self, attr + "_feature_embeddings_out_dim", self.feature_embeddings_out_dim
        )
        del self.feature_embeddings_out_dim
        setattr(
            self,
            attr + "_feature_embeddings_list",
            nn.ModuleList(feature_embeddings_list),
        )

    def _get_embeddings_attr(self, attr):
        """Return the name of the attribute that holds ``attr``'s output width."""
        if hasattr(self, attr + "_embeddings_out_dim"):
            return attr + "_embeddings_out_dim"
        return attr + "_feature_embeddings_out_dim"

    def _init_combine_layer(self, attr1, attr2):
        """Build the ModalCombineLayer for (attr1, attr2) from
        ``config[<attr1>_<attr2>_modal_combine]``.

        The layer is stored as ``self.<attr1>_<attr2>_multi_modal_combine_layer``.
        """
        config_attr = attr1 + "_" + attr2 + "_modal_combine"

        multi_modal_combine_layer = ModalCombineLayer(
            self.config[config_attr]["type"],
            getattr(self, self._get_embeddings_attr(attr1)),
            getattr(self, self._get_embeddings_attr(attr2)),
            **self.config[config_attr]["params"]
        )

        setattr(
            self,
            attr1 + "_" + attr2 + "_multi_modal_combine_layer",
            multi_modal_combine_layer,
        )

    def _init_classifier(self, combined_embedding_dim):
        """Build ``self.classifier``, mapping the combined embedding to the
        first dataset's ``_num_final_outputs``."""
        # TODO: Later support multihead
        num_choices = registry.get(self._datasets[0] + "_num_final_outputs")

        self.classifier = ClassifierLayer(
            self.config["classifier"]["type"],
            in_dim=combined_embedding_dim,
            out_dim=num_choices,
            **self.config["classifier"]["params"]
        )

    def _init_extras(self):
        """No intermediate model by default. forward() applies it when set."""
        self.inter_model = None

    def get_optimizer_parameters(self, config):
        """Return optimizer parameter groups.

        The image feature encoders get 0.1x the configured learning rate
        (``config["optimizer_attributes"]["params"]["lr"]``). All other groups
        use the optimizer's default lr.
        """
        combine_layer = self.image_text_multi_modal_combine_layer
        params = [
            {"params": self.word_embedding.parameters()},
            {"params": self.image_feature_embeddings_list.parameters()},
            {"params": self.text_embeddings.parameters()},
            {"params": combine_layer.parameters()},
            {"params": self.classifier.parameters()},
            {
                "params": self.image_feature_encoders.parameters(),
                "lr": (config["optimizer_attributes"]["params"]["lr"] * 0.1),
            },
        ]

        return params

    def _get_classifier_input_dim(self):
        """Width of the image/text combine layer's output."""
        return self.image_text_multi_modal_combine_layer.out_dim

    def process_text_embedding(
        self, sample_list, embedding_attr="text_embeddings", info=None
    ):
        """Run every text-embedding model over the sample.

        PreExtractedEmbedding models are indexed by ``sample_list.question_id``.
        Other models receive the ``text`` (or ``context``) attribute. Returns
        the first seven outputs of the FIRST embedding model only. ``info`` is
        unused.
        """
        text_embeddings = []

        # Get "text" attribute in case of "text_embeddings" case
        # and "context" attribute in case of "context_embeddings"
        texts = getattr(sample_list, embedding_attr.split("_")[0])

        # Get embedding models
        text_embedding_models = getattr(self, embedding_attr)

        for text_embedding_model in text_embedding_models:
            # TODO: Move this logic inside
            if isinstance(text_embedding_model, PreExtractedEmbedding):
                embedding = text_embedding_model(sample_list.question_id)
            else:
                embedding = text_embedding_model(texts)
            text_embeddings.append(embedding)

        return text_embeddings[0][0], text_embeddings[0][1], text_embeddings[0][2], text_embeddings[0][3], text_embeddings[0][4], text_embeddings[0][5], text_embeddings[0][6]

    def process_feature_embedding(
        self, attr, sample_list, s_central,
        s_homo=None, s_hetero=None, pre_ques_embed=None,
        obj_feats=None, ocr_feats=None
    ):
        """Embed object ("image") or OCR ("context") nodes with their relation edges.

        Args:
            attr: "image" (object branch) or "context" (OCR branch).
            sample_list: batch container. Its ``ocr_bbox`` provides the edge
                index tensors ``edge_oo/ot/tt/to`` and edge feature tensors
                ``edge_oofeats/otfeats/ttfeats/tofeats``.
            s_central: guiding question features (s_o for objects, s_t for
                OCR); the original note gives torch.Size([128, 2048]).
            s_homo: same-type relation guide (s_oo / s_tt), optional.
            s_hetero: cross-type relation guide (s_ot / s_to), optional.
            pre_ques_embed: extra question embedding; needed for any relation
                variant to be used.
            obj_feats / ocr_feats: 3-D node features (batch, nodes, dim).
                Both are required (their ``.shape`` is read unconditionally).

        Returns:
            "image": the single output ``g_o`` of ``self.image_feature_embedding``.
            "context": ``(g_t, updated_ocr)`` from ``self.context_feature_embedding``.
            Any other ``attr`` falls through and returns None.

        Which relation inputs are passed depends on which optional guides are
        given (ablation variants).
        """
        batch, bbox_num, obj_feat_dim = obj_feats.shape
        _, _, ocr_feat_dim = ocr_feats.shape
        # Neighbours per node and width of the per-edge geometric features:
        # the edge_*feats tensors must be (nodes, knn_k, loc_dim) per sample
        # for the assignments below to fit.
        knn_k = 5
        loc_dim = 5

        # Append one all-zero node per sample, so an edge index equal to the
        # node count (presumably the "no neighbour" sentinel — confirm)
        # gathers zeros.
        temp_expand_obj_feat = obj_feats[0][0]
        temp_expand_obj_feat = temp_expand_obj_feat.expand(batch,1,obj_feat_dim)*0
        temp_expand_obj_feat = torch.cat((obj_feats,temp_expand_obj_feat),1)

        temp_expand_ocr_feat = ocr_feats[0][0]
        temp_expand_ocr_feat = temp_expand_ocr_feat.expand(batch,1,ocr_feat_dim)*0
        temp_expand_ocr_feat = torch.cat((ocr_feats,temp_expand_ocr_feat),1)

        if attr == 'image':
            batch_size_t = sample_list.get_batch_size()
            # Get "image_feature_0". Only sliced, not used further below; the
            # slice still raises if the feature is missing (None).
            feature = getattr(
                sample_list, "{}_feature_{:d}".format(attr, 0), None
            )
            feature = feature[:batch_size_t]
            # Info for the 0th feature, key format "image_info_0".
            feature_info = getattr(sample_list, "{}_info_{:d}".format(attr, 0), {})
            # For Pythia, we need max_features to mask attention
            feature_dim = getattr(feature_info, "max_features", None)
            if feature_dim is not None:
                feature_dim = feature_dim[:batch_size_t]
            # NOTE(review): no code in this class sets `image_feature_embedding`
            # (_init_feature_embeddings sets `image_feature_embeddings_list`);
            # presumably a subclass provides it — confirm.
            feature_embedding_model = getattr(self, attr + "_feature_embedding")
            encoded_feature = obj_feats
            batch, bbox_num, obj_feat_dim = encoded_feature.shape

            # object-object edges: [edge geometry, neighbour object feature].
            # Buffers live on the input's device (previously hard-coded .cuda()).
            obj_obj_edge_feature = torch.zeros(
                (batch, bbox_num, knn_k, obj_feat_dim+loc_dim),
                dtype=torch.float32, device=encoded_feature.device)
            oo_edge = getattr(getattr(sample_list, "ocr_bbox"), "edge_oo")
            oo_edgefeats = getattr(getattr(sample_list, "ocr_bbox"), "edge_oofeats")
            for i in range (batch):
                obj_obj_edge_feature[i] = torch.cat((oo_edgefeats[i], temp_expand_obj_feat[i][oo_edge[i]]),2)

            # object-text edges: [edge geometry, neighbour OCR feature]
            obj_text_edge_feature = torch.zeros(
                (batch, bbox_num, knn_k, ocr_feat_dim+loc_dim),
                dtype=torch.float32, device=encoded_feature.device)
            ot_edge = getattr(getattr(sample_list, "ocr_bbox"), "edge_ot")
            ot_edgefeats = getattr(getattr(sample_list, "ocr_bbox"), "edge_otfeats")
            for i in range (batch):
                obj_text_edge_feature[i] = torch.cat((ot_edgefeats[i], temp_expand_ocr_feat[i][ot_edge[i]]),2)

            oo_edge_feature = obj_obj_edge_feature
            ot_edge_feature = obj_text_edge_feature

            s_o, s_oo, s_ot = s_central, s_homo, s_hetero
            # Ablation variants.
            # o feature + oo relation + ot relation
            if (s_oo is not None) and (oo_edge_feature is not None) and (s_ot is not None) and (ot_edge_feature is not None) and (pre_ques_embed is not None):
                inp = (attr, encoded_feature, s_o, feature_dim, s_oo, oo_edge_feature, s_ot, ot_edge_feature,pre_ques_embed)
            # o feature + oo relation
            elif (s_oo is not None) and (oo_edge_feature is not None) and (pre_ques_embed is not None):
                inp = (attr, encoded_feature, s_o, feature_dim, s_oo, oo_edge_feature, pre_ques_embed)
            # o feature + ot relation
            elif (s_ot is not None) and (ot_edge_feature is not None) and (pre_ques_embed is not None):
                inp = (attr, encoded_feature, s_o, feature_dim, s_ot, ot_edge_feature,pre_ques_embed)
            # o feature only
            else:
                inp = (attr, encoded_feature, s_o, feature_dim)

            g_o = feature_embedding_model(*inp)
            return g_o

        elif attr == 'context':
            batch_size_t = sample_list.get_batch_size()
            # Get "context_feature_0". Only sliced, not used further below.
            feature = getattr(
                sample_list, "{}_feature_{:d}".format(attr, 0), None
            )
            feature = feature[:batch_size_t]
            # Info for the 0th feature, key format "context_info_0".
            feature_info = getattr(sample_list, "{}_info_{:d}".format(attr, 0), {})
            # For Pythia, we need max_features to mask attention
            feature_dim = getattr(feature_info, "max_features", None)
            if feature_dim is not None:
                feature_dim = feature_dim[:batch_size_t]
            # NOTE(review): `context_feature_embedding` is not set anywhere in
            # this class — presumably provided by a subclass; confirm.
            feature_embedding_model = getattr(self, "context_feature_embedding")
            encoded_feature = ocr_feats
            batch, bbox_num, _ = encoded_feature.shape

            # text-text edges: [edge geometry, neighbour OCR feature]
            text_text_edge_feature = torch.zeros(
                (batch, bbox_num, knn_k, ocr_feat_dim+loc_dim),
                dtype=torch.float32, device=encoded_feature.device)
            tt_edge = getattr(getattr(sample_list, "ocr_bbox"), "edge_tt")
            tt_edgefeats = getattr(getattr(sample_list, "ocr_bbox"), "edge_ttfeats")
            for i in range (batch):
                text_text_edge_feature[i] = torch.cat((tt_edgefeats[i], temp_expand_ocr_feat[i][tt_edge[i]]),2)

            # text-object edges: [edge geometry, neighbour object feature]
            text_obj_edge_feature = torch.zeros(
                (batch, bbox_num, knn_k, obj_feat_dim+loc_dim),
                dtype=torch.float32, device=encoded_feature.device)
            to_edge = getattr(getattr(sample_list, "ocr_bbox"), "edge_to")
            to_edgefeats = getattr(getattr(sample_list, "ocr_bbox"), "edge_tofeats")
            for i in range (batch):
                text_obj_edge_feature[i] = torch.cat((to_edgefeats[i], temp_expand_obj_feat[i][to_edge[i]]),2)

            tt_edge_feature = text_text_edge_feature
            to_edge_feature = text_obj_edge_feature

            s_t, s_tt, s_to = s_central, s_homo, s_hetero
            # Ablation variants.
            # t feature + tt relation + to relation
            if (s_tt is not None) and (tt_edge_feature is not None) and (s_to is not None) and (to_edge_feature is not None) and (pre_ques_embed is not None):
                inp = (attr, encoded_feature, s_t, feature_dim, s_tt, tt_edge_feature, s_to, to_edge_feature,pre_ques_embed)
            # t feature + tt relation
            elif (s_tt is not None) and (tt_edge_feature is not None) and (pre_ques_embed is not None):
                inp = (attr, encoded_feature, s_t, feature_dim, s_tt, tt_edge_feature, pre_ques_embed)
            # t feature + to relation
            elif (s_to is not None) and (to_edge_feature is not None) and (pre_ques_embed is not None):
                inp = (attr, encoded_feature, s_t, feature_dim, s_to, to_edge_feature,pre_ques_embed)
            # t feature only
            else:
                inp = (attr, encoded_feature, s_t, feature_dim)

            g_t, updated_ocr = feature_embedding_model(*inp)
            return g_t, updated_ocr

    def combine_embeddings(self, *args):
        """Apply the combine layer named after ``args[0]`` (feature names) to
        the embeddings in ``args[1]``."""
        feature_names = args[0]
        feature_embeddings = args[1]

        layer = "_".join(feature_names) + "_multi_modal_combine_layer"
        return getattr(self, layer)(*feature_embeddings)

    def calculate_logits(self, joint_embedding, **kwargs):
        """Classifier scores for the joint embedding (``kwargs`` unused)."""
        return self.classifier(joint_embedding)

    def forward(self, sample_list):
        """Embed text and image features, combine them and classify.

        Overwrites ``sample_list.text`` with its word embedding in place.
        Returns ``{"scores": logits}``.
        """
        # NOTE(review): as written, this method does not match its helpers:
        #   * process_text_embedding returns a 7-tuple, which is passed whole
        #     as s_central;
        #   * process_feature_embedding("image", ...) returns a single value,
        #     but two are unpacked here;
        #   * obj_feats / ocr_feats are not passed, yet
        #     process_feature_embedding reads obj_feats.shape unconditionally.
        # Presumably a subclass overrides forward — confirm before relying on it.
        sample_list.text = self.word_embedding(sample_list.text)
        text_embedding_total = self.process_text_embedding(sample_list)

        image_embedding_total, _ = self.process_feature_embedding(
            "image", sample_list, text_embedding_total
        )

        if self.inter_model is not None:
            image_embedding_total = self.inter_model(image_embedding_total)

        joint_embedding = self.combine_embeddings(
            ["image", "text"], [image_embedding_total, text_embedding_total]
        )

        model_output = {"scores": self.calculate_logits(joint_embedding)}

        return model_output