# Downloads

In [None]:
from pydrive.auth import GoogleAuth
from pydrive.drive import GoogleDrive
from google.colab import auth
from oauth2client.client import GoogleCredentials
# Authenticate the Colab user and create the PyDrive client (`drive`) that
# the download cells below use.
# This only needs to be done once per notebook.
auth.authenticate_user()
gauth = GoogleAuth()
# Hand the application-default credentials to PyDrive instead of running
# PyDrive's own auth flow.
gauth.credentials = GoogleCredentials.get_application_default()
drive = GoogleDrive(gauth)

In [None]:
# Save the Drive file with this id locally as 'mask-rcnn-predict_pkl.zip'
# (extracted by an `!unzip` cell further down).
file_id = '1-F07uiQtPtlISfml0Y_0xFxgirEmiXuq' # URL id.
downloaded = drive.CreateFile({'id': file_id})
downloaded.GetContentFile('mask-rcnn-predict_pkl.zip')

In [None]:
# Save the Drive file with this id locally as 'text_feature_bert_ltfeat.zip'
# (extracted by an `!unzip` cell further down).
file_id = '13Yq7zielAiyBZJu6OjIrj8orlPTQ0OQd' # URL id.
downloaded = drive.CreateFile({'id': file_id})
downloaded.GetContentFile('text_feature_bert_ltfeat.zip')

In [None]:
# Save the Drive file with this id locally as 'OpenCQA_Graph.zip'
# (extracted by an `!unzip` cell further down).
file_id = '1wSLMZ-Qjoe6EoxYvW_4GBwwaXCE9xvf6' # URL id.
downloaded = drive.CreateFile({'id': file_id})
downloaded.GetContentFile('OpenCQA_Graph.zip')

In [None]:
# Save the Drive file with this id locally as 'OpenCQA_Graph_6_rels_pie.zip'
# (extracted by an `!unzip` cell further down).
file_id = '1Zo7t0j2jZ2jzDa5Y0h7cCq5D6YVGKrKx' # URL id.
downloaded = drive.CreateFile({'id': file_id})
downloaded.GetContentFile('OpenCQA_Graph_6_rels_pie.zip')

In [None]:
# Save the Drive file with this id locally as
# 'unichart_patch_object_pred_ae.zip' (extracted by an `!unzip` cell further down).
file_id = '1xYcRy1EMQF1Iyj0ZbIyEc-2b377n0KP4' # URL id.
downloaded = drive.CreateFile({'id': file_id})
downloaded.GetContentFile('unichart_patch_object_pred_ae.zip')

In [None]:
# Save the Drive file with this id locally as
# 'unichart_patch_len_pred_ae.zip' (extracted by an `!unzip` cell further down).
file_id = '1_15A3I1he-SH_yS0vkZInDDEZakz-VLb' # URL id.
downloaded = drive.CreateFile({'id': file_id})
downloaded.GetContentFile('unichart_patch_len_pred_ae.zip')

In [None]:
!unzip mask-rcnn-predict_pkl.zip

In [None]:
!unzip text_feature_bert_ltfeat.zip

In [None]:
!unzip OpenCQA_Graph.zip

In [None]:
!unzip OpenCQA_Graph_6_rels_pie.zip

In [None]:
!unzip unichart_patch_object_pred_ae.zip

In [None]:
!unzip unichart_patch_len_pred_ae.zip

In [None]:
!git clone https://github.com/vis-nlp/OpenCQA.git

In [None]:
!pip install sentencepiece

In [None]:
!pip install wandb

In [None]:
!pip install sacrebleu
!pip install sacremoses

# Utils

In [None]:
import torch
# Seed PyTorch's random number generator with a fixed value.
torch.manual_seed(42)

In [None]:
import os

def create_folder_if_not_exists(folder_path):
    """Create ``folder_path`` (with any missing parents) unless it already exists.

    A one-line status message is printed in either case.

    Args:
        folder_path: Path of the directory to create.

    Returns:
        None.
    """
    # EAFP: attempt the creation and react to "already there", instead of
    # checking first. A path that appears between a check and the creation
    # (e.g. from a parallel worker) therefore cannot crash this call.
    try:
        os.makedirs(folder_path)
    except FileExistsError:
        print(f"Folder '{folder_path}' already exists.")
    else:
        print(f"Folder '{folder_path}' created successfully.")

In [None]:
import json
import pandas as pd
import pickle

In [None]:
import torch

In [None]:
import torch.nn as nn

In [None]:
def evaluate(model, dataloader, evaluator, criteria='bleu'):
    """Predict an answer for every batch and score the predictions.

    Each batch must be a dict holding a 'question_ids' entry; every other
    value is moved to the global ``device`` before ``model.test_step`` is
    called. Gradients are disabled while predicting.

    Args:
        model: Object exposing ``test_step(batch)`` that returns a dict with
            a 'pred_ans' entry aligned with the batch's question ids.
        dataloader: Sized iterable of batches.
        evaluator: Object exposing ``evaluate_raw(quesid2ans, criteria=...)``.
        criteria: Forwarded unchanged to ``evaluator.evaluate_raw``.

    Returns:
        Whatever ``evaluator.evaluate_raw`` returns.
    """
    from tqdm.autonotebook import tqdm

    predictions = {}
    with torch.no_grad(), tqdm(range(len(dataloader))) as progress:
        for batch in dataloader:
            # Question ids are not moved to the device; pull them out first.
            question_ids = batch.pop('question_ids')
            on_device = {key: value.to(device) for key, value in batch.items()}
            outputs = model.test_step(on_device)
            predictions.update(zip(question_ids, outputs['pred_ans']))
            progress.update(1)

    # The evaluator receives its own copy of the question-id -> answer map.
    return evaluator.evaluate_raw(dict(predictions), criteria=criteria)

In [None]:
def pad_zeros(matrix, desired_shape):
    """Zero-pad a 2-D tensor at the bottom and right up to ``desired_shape``.

    Args:
        matrix: Tensor to pad. When its first dimension is 0, a fresh
            all-zero tensor of ``desired_shape`` is returned instead.
        desired_shape: ``(rows, cols)`` of the result.

    Returns:
        Tensor of shape ``desired_shape``.
    """
    if matrix.shape[0] != 0:
        extra_rows = desired_shape[0] - matrix.shape[0]
        extra_cols = desired_shape[1] - matrix.shape[1]
        # F.pad takes the last dimension first: (left, right, top, bottom).
        padding = (0, extra_cols, 0, extra_rows)
        return torch.nn.functional.pad(matrix, padding, value=0)
    return torch.zeros((desired_shape))

In [None]:
def csv_to_textdf(src):
    """Flatten a CSV file into one string.

    The header and every data row are each rendered as their cells joined
    by ' | '; the resulting lines are then joined by ' & '.

    Args:
        src: Anything ``pandas.read_csv`` accepts.

    Returns:
        The flattened table as a single string.
    """
    frame = pd.read_csv(src)
    header_line = ' | '.join(list(frame.columns))
    # One ' | '-separated line per data row, cells stringified with str().
    body_lines = [' | '.join(map(str, record)) for _, record in frame.iterrows()]
    return ' & '.join([header_line, *body_lines])

In [None]:
import re
import numpy as np
import torch
import torch.distributed as dist
import collections
import logging

def get_area(pos):
    """Compute the area of every box.

    Args
        pos: [B, N, 4]
            (x1, x2, y1, y2)

    Return
        area : [B, N], i.e. (y2 - y1) * (x2 - x1)
    """
    return (pos[:, :, 3] - pos[:, :, 2]) * (pos[:, :, 1] - pos[:, :, 0])

def get_relative_distance(pos):
    """Pairwise coordinate differences between all boxes of each sample.

    Args
        pos: [B, N, 4]
            (x1, x2, y1, y2)

    Return
        out : [B, N, N, 4] with out[b, i, j] = pos[b, j] - pos[b, i]
    """
    # Broadcast a [B, 1, N, 4] view against a [B, N, 1, 4] view.
    return pos[:, None] - pos[:, :, None]


class LossMeter:
    """Computes and stores the running average of the latest ``maxlen`` values."""

    def __init__(self, maxlen=100):
        # A bounded deque drops the oldest entry once ``maxlen`` is reached.
        self.vals = collections.deque(maxlen=maxlen)

    def __len__(self):
        """Number of values currently held."""
        return len(self.vals)

    def update(self, new_val):
        """Record one more value."""
        self.vals.append(new_val)

    @property
    def val(self):
        """Mean of the held values (raises ZeroDivisionError when empty)."""
        window = self.vals
        return sum(window) / len(window)

    def __repr__(self):
        return str(self.val)


def count_parameters(model):
    """Return the total number of elements in ``model``'s trainable parameters.

    Only parameters with ``requires_grad`` set are counted.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total


def set_global_logging_level(level=logging.ERROR, prefices=("",)):
    """
    Override logging levels of different modules based on their name as a prefix.
    It needs to be invoked after the modules have been loaded so that their loggers have been initialized.

    Args:
        - level: desired level. e.g. logging.INFO. Optional. Default is logging.ERROR
        - prefices: iterable of one or more str prefices to match (e.g. ["transformers", "torch"]). Optional.
          Default is `("",)` to match all active loggers.
          The match is a case-sensitive `module_name.startswith(prefix)`
    """
    # Import the standard module locally: a later cell rebinds the global
    # name `logging` to `transformers.utils.logging`, and the logger registry
    # lookup below must not depend on which module that name points to.
    import logging as std_logging

    # Escape every prefix so it is matched literally; otherwise a "." in
    # "transformers.models" would match any character.
    prefix_re = re.compile("^(?:" + "|".join(re.escape(p) for p in prefices) + ")")
    # Iterate over a snapshot of the registered logger names.
    for name in list(std_logging.root.manager.loggerDict):
        if prefix_re.match(name):
            std_logging.getLogger(name).setLevel(level)


def get_iou(anchors, gt_boxes):
    """Intersection-over-union between every anchor and every ground-truth box.

    Boxes are (x1, y1, x2, y2); widths and heights are computed with a +1
    term, i.e. coordinates are treated as inclusive.

    anchors: (N, 4) torch floattensor
    gt_boxes: (K, 4) torch floattensor, or a single box of shape (4,)
    returns: (N, K) tensor of overlaps between anchors and gt_boxes
    """
    num_anchors = anchors.size(0)

    # Promote a single box to a batch of one.
    if gt_boxes.size() == (4,):
        gt_boxes = gt_boxes.view(1, 4)
    num_gt = gt_boxes.size(0)

    def box_area(b):
        return (b[:, 2] - b[:, 0] + 1) * (b[:, 3] - b[:, 1] + 1)

    area_gt = box_area(gt_boxes).view(1, num_gt)
    area_anchor = box_area(anchors).view(num_anchors, 1)

    # Pair every anchor with every ground-truth box: both become (N, K, 4).
    paired_anchor = anchors.view(num_anchors, 1, 4).expand(num_anchors, num_gt, 4)
    paired_gt = gt_boxes.view(1, num_gt, 4).expand(num_anchors, num_gt, 4)

    def overlap_extent(lo, hi):
        # Length of the shared interval along one axis, floored at zero.
        extent = (
            torch.min(paired_anchor[:, :, hi], paired_gt[:, :, hi])
            - torch.max(paired_anchor[:, :, lo], paired_gt[:, :, lo])
            + 1
        )
        extent[extent < 0] = 0
        return extent

    inter_w = overlap_extent(0, 2)
    inter_h = overlap_extent(1, 3)

    union = area_anchor + area_gt - (inter_w * inter_h)
    return inter_w * inter_h / union


def xywh_to_xyxy(boxes):
    """Convert [x y w h] box format to [x1 y1 x2 y2] format.

    The far corner is ``(x + w - 1, y + h - 1)``.
    """
    top_left = boxes[:, 0:2]
    bottom_right = top_left + boxes[:, 2:4] - 1
    return np.concatenate((top_left, bottom_right), axis=1)

## Graph Layers

In [None]:
import math

import torch

from torch.nn.parameter import Parameter
from torch.nn.modules.module import Module


class GraphConvolution(Module):
    """
    Simple GCN layer, similar to https://arxiv.org/abs/1609.02907

    Computes ``(adj [* dis]) @ (input @ weight) [+ bias]``.
    """

    def __init__(self, in_features, out_features, bias=True):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.FloatTensor(in_features, out_features))
        if not bias:
            # Register the name so that ``self.bias`` exists and is None.
            self.register_parameter('bias', None)
        else:
            self.bias = Parameter(torch.FloatTensor(out_features))
        self.reset_parameters()

    def reset_parameters(self):
        """Draw weight and bias uniformly from +-1/sqrt(out_features)."""
        bound = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-bound, bound)
        if self.bias is not None:
            self.bias.data.uniform_(-bound, bound)

    def forward(self, input, adj, dis):
        """Apply the layer.

        ``dis`` is an optional weighting multiplied element-wise into
        ``adj``; pass None to use ``adj`` as is.
        """
        projected = torch.matmul(input, self.weight)
        propagation = adj if dis is None else adj * dis
        aggregated = torch.matmul(propagation, projected)
        if self.bias is None:
            return aggregated
        return aggregated + self.bias

    def __repr__(self):
        return f"{self.__class__.__name__} ({self.in_features} -> {self.out_features})"

In [None]:
import torch.nn as nn
import torch.nn.functional as F


class GCN(nn.Module):
    """Two stacked GraphConvolution layers with ReLU and dropout in between."""

    def __init__(self, nfeat, nhid, ofeat, dropout):
        """
        Args:
            nfeat: Input feature size.
            nhid: Hidden feature size.
            ofeat: Output feature size.
            dropout: Dropout probability applied after the first layer.
        """
        super().__init__()

        self.ofeat = ofeat
        self.dropout = dropout
        self.gc1 = GraphConvolution(nfeat, nhid)
        self.gc2 = GraphConvolution(nhid, ofeat)

    def forward(self, x, adj, dis):
        """Run both layers; ``adj`` and ``dis`` are forwarded to each of them."""
        hidden = F.relu(self.gc1(x, adj, dis))
        # Dropout is only active while the module is in training mode.
        hidden = F.dropout(hidden, p=self.dropout, training=self.training)
        return self.gc2(hidden, adj, dis)

In [None]:
class GraphFuse(nn.Module):
    """Fuse the outputs of two 768-wide GCNs through a two-layer MLP.

    One GCN runs over ``vis_feat`` with ``ful_adj`` (weighted by
    ``ful_weights``), the other over ``text_feat`` with ``sem_adj``.
    """

    def __init__(self):
        super().__init__()
        self.fg_gcn = GCN(768, 768, 768, 0.2)
        self.sem_gcn = GCN(768, 768, 768, 0.2)
        self.fc1 = nn.Linear(768 * 2, 768)
        self.fc2 = nn.Linear(768, 768)

    def forward(self, ful_adj, sem_adj, ful_weights, sem_weights, vis_feat, text_feat, batch_size, graph_mask):
        """Return the fused node features.

        ``sem_weights`` and ``batch_size`` are accepted but not used here.
        ``graph_mask`` is multiplied into the semantic GCN output.
        """
        visual_nodes = self.fg_gcn(vis_feat, ful_adj, ful_weights).clone()
        semantic_nodes = self.sem_gcn(text_feat, sem_adj, None).clone()
        semantic_nodes = semantic_nodes * graph_mask

        # Trim the semantic nodes to the visual node count before
        # concatenating along the feature axis.
        node_count = visual_nodes.shape[1]
        fused = torch.cat((visual_nodes, semantic_nodes[:, :node_count]), dim=2)

        fused = F.relu(self.fc1(fused))
        fused = F.dropout(fused, 0.1, training=self.training)
        return F.relu(self.fc2(fused))

## VL-T5

In [None]:
# Global compute device: CUDA when available, otherwise CPU.
# `evaluate()` above moves every batch tensor to this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
from transformers import T5Tokenizer, T5TokenizerFast, PreTrainedTokenizer, PreTrainedTokenizerFast, PreTrainedTokenizerBase
import re
import sentencepiece as spm

class VLT5Tokenizer(T5Tokenizer):
    """SentencePiece T5 tokenizer extended with ``<vis_extra_id_N>`` tokens.

    Id layout implemented by the conversion methods below: the SentencePiece
    pieces come first, then the ``<extra_id_*>`` block, then the
    ``<vis_extra_id_*>`` block. Inside each block the numbering is reversed,
    so ``<vis_extra_id_0>`` maps to the highest id.
    """

    def __init__(
        self,
        vocab_file,
        eos_token="</s>",
        unk_token="<unk>",
        pad_token="<pad>",
        extra_ids=100,
        vis_extra_ids=100,
        additional_special_tokens=None,
        **kwargs
    ):
        """
        Args:
            vocab_file: Path of the SentencePiece model to load.
            eos_token, unk_token, pad_token: Special token strings.
            extra_ids: Number of ``<extra_id_N>`` tokens.
            vis_extra_ids: Number of ``<vis_extra_id_N>`` tokens.
            additional_special_tokens: Optional extra special tokens. When
                given together with ``extra_ids > 0`` it must contain exactly
                ``extra_ids`` distinct tokens containing "extra_id". The
                sequence passed in is not modified.
            **kwargs: Forwarded to ``PreTrainedTokenizer.__init__``.

        Raises:
            ValueError: If the "extra_id" token count does not match
                ``extra_ids``.
        """
        # Add extra_ids to the special token list
        if extra_ids > 0 and additional_special_tokens is None:
            additional_special_tokens = ["<extra_id_{}>".format(i) for i in range(extra_ids)]
        elif extra_ids > 0 and additional_special_tokens is not None:
            # Check that we have the right number of extra_id special tokens
            extra_tokens = len(set(filter(lambda x: bool("extra_id" in x), additional_special_tokens)))
            if extra_tokens != extra_ids:
                raise ValueError(
                    f"Both extra_ids ({extra_ids}) and additional_special_tokens ({additional_special_tokens}) are provided to T5Tokenizer. "
                    "In this case the additional_special_tokens must include the extra_ids tokens"
                )
        self.sp_model = spm.SentencePieceProcessor()
        self.sp_model.Load(vocab_file)
        if vis_extra_ids > 0:
            # Work on a private copy: this avoids mutating the caller's list
            # and also covers additional_special_tokens being None (which
            # happens when extra_ids == 0 and no tokens were supplied).
            additional_special_tokens = list(additional_special_tokens or [])
            additional_special_tokens.extend(["<vis_extra_id_{}>".format(i) for i in range(vis_extra_ids)])

        self.vocab_file = vocab_file
        self._extra_ids = extra_ids
        self._vis_extra_ids = vis_extra_ids

        # PreTrainedTokenizer.__init__ is called directly (not via super()),
        # so T5Tokenizer.__init__ is not run.
        PreTrainedTokenizer.__init__(
            self,
            eos_token=eos_token,
            unk_token=unk_token,
            pad_token=pad_token,
            extra_ids=extra_ids,
            additional_special_tokens=additional_special_tokens,
            **kwargs,
        )

    @property
    def vocab_size(self):
        """SentencePiece piece count plus both extra-id blocks."""
        return self.sp_model.get_piece_size() + self._extra_ids + self._vis_extra_ids

    def get_vocab(self):
        """Return a token -> id dict for all ids, updated with the added tokens."""
        vocab = {self.convert_ids_to_tokens(
            i): i for i in range(self.vocab_size)}
        vocab.update(self.added_tokens_encoder)
        return vocab

    def _convert_token_to_id(self, token):
        """ Converts a token (str) in an id using the vocab. """
        match = re.match(r"<extra_id_(\d+)>", token)
        if match:
            num = int(match.group(1))
            return self.vocab_size - num - 1 - self._vis_extra_ids
        match = re.match(r"<vis_extra_id_(\d+)>", token)
        if match:
            num = int(match.group(1))
            return self.vocab_size - num - 1
        # Anything else, including a malformed sentinel such as
        # "<extra_id_x>", is looked up in the SentencePiece vocabulary
        # instead of failing on a missing regex match.
        return self.sp_model.piece_to_id(token)

    def _convert_id_to_token(self, index):
        """Converts an index (integer) in a token (str) using the vocab."""
        if index < self.sp_model.get_piece_size():
            token = self.sp_model.IdToPiece(index)
        else:
            if index > self.sp_model.get_piece_size() + self._extra_ids - 1:
                token = "<vis_extra_id_{}>".format(self.vocab_size - 1 - index)
            else:
                token = "<extra_id_{}>".format(self.vocab_size - self._vis_extra_ids - 1 - index)
        return token


# Below are for Rust-based Fast Tokenizer

from transformers.convert_slow_tokenizer import SpmConverter
from tokenizers import Tokenizer, decoders, normalizers, pre_tokenizers, processors
from typing import Any, Dict, List, Optional, Tuple, Union


class VLT5Converter(SpmConverter):
    """SpmConverter variant that knows about the two extra-id token blocks."""

    def vocab(self, proto):
        """Return (token, score) pairs for the converted tokenizer.

        The SentencePiece pieces come first, followed by the
        ``<extra_id_*>`` tokens and then the ``<vis_extra_id_*>`` tokens,
        each block in descending index order with score 0.0.
        """
        tokenizer = self.original_tokenizer
        entries = [(piece.piece, piece.score) for piece in proto.pieces]
        blocks = (
            ("<extra_id_{}>", tokenizer._extra_ids),
            ("<vis_extra_id_{}>", tokenizer._vis_extra_ids),
        )
        for template, count in blocks:
            entries.extend((template.format(i), 0.0) for i in reversed(range(count)))
        return entries

    def post_processor(self):
        """Build a template processor that appends "</s>" after each sequence."""
        eos = "</s>"
        eos_id = self.original_tokenizer.convert_tokens_to_ids(eos)
        return processors.TemplateProcessing(
            single=["$A", eos],
            pair=["$A", eos, "$B", eos],
            special_tokens=[(eos, eos_id)],
        )


def convert_slow_vlt5tokenizer(vlt5tokenizer):
    """Convert a slow VLT5 tokenizer via ``VLT5Converter`` and return the result."""
    converter = VLT5Converter(vlt5tokenizer)
    return converter.converted()


class VLT5TokenizerFast(T5TokenizerFast):
    """Fast counterpart of ``VLT5Tokenizer``.

    The backend tokenizer is obtained by building a slow ``VLT5Tokenizer``
    and converting it with ``convert_slow_vlt5tokenizer``.
    """

    slow_tokenizer_class = VLT5Tokenizer

    prefix_tokens: List[int] = []

    def __init__(
        self,
        vocab_file,
        tokenizer_file=None,
        eos_token="</s>",
        unk_token="<unk>",
        pad_token="<pad>",
        extra_ids=100,
        vis_extra_ids=100,
        additional_special_tokens=None,
        **kwargs
    ):
        """
        Args:
            vocab_file: Path of the SentencePiece model, passed to the slow
                tokenizer.
            tokenizer_file: Forwarded to the slow tokenizer and to the base
                initializer.
            eos_token, unk_token, pad_token: Special token strings.
            extra_ids: Number of ``<extra_id_N>`` tokens.
            vis_extra_ids: Number of ``<vis_extra_id_N>`` tokens.
            additional_special_tokens: Optional extra special tokens. When
                given together with ``extra_ids > 0`` it must contain exactly
                ``extra_ids`` distinct tokens containing "extra_id". The
                sequence passed in is not modified.
            **kwargs: Forwarded to both the slow tokenizer and the base
                initializer.

        Raises:
            ValueError: If the "extra_id" token count does not match
                ``extra_ids``.
        """
        # Add extra_ids to the special token list
        if extra_ids > 0 and additional_special_tokens is None:
            additional_special_tokens = ["<extra_id_{}>".format(i) for i in range(extra_ids)]
        elif extra_ids > 0 and additional_special_tokens is not None:
            # Check that we have the right number of extra_id special tokens
            extra_tokens = len(set(filter(lambda x: bool("extra_id" in x), additional_special_tokens)))
            if extra_tokens != extra_ids:
                raise ValueError(
                    f"Both extra_ids ({extra_ids}) and additional_special_tokens ({additional_special_tokens}) are provided to T5Tokenizer. "
                    "In this case the additional_special_tokens must include the extra_ids tokens"
                )

        if vis_extra_ids > 0:
            # Work on a private copy: this avoids mutating the caller's list
            # and also covers additional_special_tokens being None (which
            # happens when extra_ids == 0 and no tokens were supplied).
            additional_special_tokens = list(additional_special_tokens or [])
            additional_special_tokens.extend(["<vis_extra_id_{}>".format(i) for i in range(vis_extra_ids)])

        # The slow tokenizer rebuilds its own special-token list from
        # extra_ids / vis_extra_ids, so the list is not forwarded to it.
        slow_tokenizer = self.slow_tokenizer_class(
            vocab_file,
            tokenizer_file=tokenizer_file,
            eos_token=eos_token,
            unk_token=unk_token,
            pad_token=pad_token,
            extra_ids=extra_ids,
            vis_extra_ids=vis_extra_ids,
            **kwargs
        )
        fast_tokenizer = convert_slow_vlt5tokenizer(slow_tokenizer)
        self._tokenizer = fast_tokenizer

        # PreTrainedTokenizerBase.__init__ is called directly (not via
        # super()), so the intermediate fast-tokenizer initializers are not run.
        PreTrainedTokenizerBase.__init__(
            self,
            tokenizer_file=tokenizer_file,
            eos_token=eos_token,
            unk_token=unk_token,
            pad_token=pad_token,
            extra_ids=extra_ids,
            vis_extra_ids=vis_extra_ids,
            additional_special_tokens=additional_special_tokens,
            **kwargs,
        )

        self.vocab_file = vocab_file
        self._extra_ids = extra_ids
        self._vis_extra_ids = vis_extra_ids

In [None]:
from dataclasses import dataclass

from transformers.models.t5.modeling_t5 import (
    T5Stack, T5Block, T5LayerNorm, T5LayerSelfAttention, T5LayerFF, T5LayerCrossAttention,
    T5PreTrainedModel, T5ForConditionalGeneration
)

import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss

from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple
import copy

from transformers.modeling_outputs import ModelOutput, BaseModelOutput, BaseModelOutputWithPast, BaseModelOutputWithPastAndCrossAttentions, Seq2SeqLMOutput, Seq2SeqModelOutput
from transformers.modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
from transformers.utils import logging
from transformers import BeamScorer, BeamSearchScorer

# from utils import *

# Module-level logger. NOTE: `logging` here is `transformers.utils.logging`
# (imported just above), which shadows the standard-library `logging`
# imported in an earlier cell.
logger = logging.get_logger(__name__)


class VisualEmbedding(nn.Module):
    """Embed per-object visual features and their boxes into ``d_model`` space.

    The result is the sum of a linear projection of the features, a linear
    projection of the box (4 coordinates plus area) and, when
    ``config.use_vis_order_embedding`` is set, an image-order and an
    object-order embedding.

    Layer normalisation (``config.use_vis_layer_norm``) is applied either to
    each of the two projections separately
    (``config.individual_vis_layer_norm``) or once to the summed embedding.
    """

    def __init__(self, config, obj_order_embedding):
        """
        Args:
            config: Object providing ``feat_dim``, ``pos_dim``, ``n_images``,
                ``d_model``, ``layer_norm_epsilon`` and the three flags named
                in the class docstring.
            obj_order_embedding: Embedding module used for object order ids;
                only stored when ``config.use_vis_order_embedding`` is set.
        """
        super().__init__()
        self.config = config
        feat_dim = config.feat_dim
        pos_dim = config.pos_dim
        n_images = config.n_images

        # Per-projection norms exist only when both flags are on; with
        # individual_vis_layer_norm off, a single norm is created at the end.
        norm_each_projection = bool(
            config.individual_vis_layer_norm and config.use_vis_layer_norm)

        # Object feature encoding
        self.feat_embedding = self._projection(feat_dim, norm_each_projection)
        # Box encoding: the input is the box plus its area, hence pos_dim + 1.
        self.absolute_vis_pos_embedding = self._projection(
            pos_dim + 1, norm_each_projection)

        if config.use_vis_order_embedding:
            self.obj_order_embedding = obj_order_embedding
            self.img_order_embedding = nn.Embedding(n_images, config.d_model)

        if not config.individual_vis_layer_norm and config.use_vis_layer_norm:
            self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)

    def _projection(self, in_dim, with_norm):
        """Return ``Sequential(Linear(in_dim, d_model)[, T5LayerNorm])``."""
        layers = [nn.Linear(in_dim, self.config.d_model)]
        if with_norm:
            layers.append(T5LayerNorm(self.config.d_model, eps=self.config.layer_norm_epsilon))
        return nn.Sequential(*layers)

    def get_area(self, pos):
        """
        Args
            pos: [B, N, 4]
                (x1, x2, y1, y2)
        Return
            area : [B, N]
        """
        # [B, N]
        height = pos[:, :, 3] - pos[:, :, 2]
        width = pos[:, :, 1] - pos[:, :, 0]
        area = height * width
        return area

    def forward(self, feats, pos, img_order_ids=None, obj_order_ids=None):
        """
        Args
            feats: [B, N, feat_dim]
            pos: [B, N, 4]
                (x1, x2, y1, y2)
            img_order_ids: optional index tensor for the image-order
                embedding; defaults to all zeros of shape [1, N]
            obj_order_ids: optional index tensor for the object-order
                embedding; defaults to 0..N-1 of shape [1, N]
        Return
            vis_embedding: [B, N, d_model]
        """

        B, N, _ = feats.size()
        assert pos.size() == (B, N, 4)

        feat_embedding = self.feat_embedding(feats)

        device = feats.device

        area = self.get_area(pos).unsqueeze(2) # [B, N, 1]
        pos = torch.cat([pos, area], dim=2) # [B, N, 5]

        # [B, N, d_model]
        absolute_vis_pos_embedding = self.absolute_vis_pos_embedding(pos)

        if self.config.use_vis_order_embedding:
            if img_order_ids is None:
                # Shape [1, N]; broadcasts over the batch when summed below.
                img_order_ids = torch.zeros(N, dtype=torch.long, device=device)
                img_order_ids = img_order_ids.unsqueeze(0)
            img_order_embedding = self.img_order_embedding(img_order_ids)

            if obj_order_ids is None:
                obj_order_ids = torch.arange(N, dtype=torch.long, device=device)
                obj_order_ids = obj_order_ids.unsqueeze(0)
            # Object order k is looked up at row (num_embeddings - 1 - k),
            # i.e. counted from the end of the embedding table.
            obj_order_ids = self.obj_order_embedding.num_embeddings - obj_order_ids - 1
            obj_order_embedding = self.obj_order_embedding(obj_order_ids)

            vis_embedding = feat_embedding + absolute_vis_pos_embedding + \
                img_order_embedding + obj_order_embedding

        else:
            vis_embedding = feat_embedding + absolute_vis_pos_embedding

        if not self.config.individual_vis_layer_norm:
            if self.config.use_vis_layer_norm:
                vis_embedding = self.layer_norm(vis_embedding)

        return vis_embedding


class JointEncoder(T5Stack):
    """Encoder stack whose input is token embeddings followed by visual embeddings.

    The text embeddings and the output of ``VisualEmbedding`` are
    concatenated along the sequence axis (text first) and run through the
    T5 blocks. Relative position bias is computed for text-to-text
    attention only; every entry involving a visual position starts at zero.
    """

    def __init__(self, config, embed_tokens=None):
        """Build the stack by hand.

        Args:
            config: Model config; ``config.is_decoder`` must be False.
            embed_tokens: Token embedding module. It is also handed to
                ``VisualEmbedding`` as the object-order embedding.
        """
        # The MRO lookup starts *after* T5Stack, so T5Stack.__init__ is not
        # run; the submodules are created explicitly below instead.
        super(T5Stack, self).__init__(config)
        self.config = config

        self.embed_tokens = embed_tokens
        self.is_decoder = self.config.is_decoder
        assert self.config.is_decoder is False

        self.visual_embedding = VisualEmbedding(self.config, embed_tokens)

        # Only the first block owns a relative attention bias table.
        self.block = nn.ModuleList(
            [T5Block(config, has_relative_attention_bias=(i == 0))
                for i in range(config.num_layers)]
        )
        self.final_layer_norm = T5LayerNorm(
            config.d_model, eps=config.layer_norm_epsilon)
        self.dropout = nn.Dropout(config.dropout_rate)

        self.init_weights()
        self.model_parallel = False
        self.device_map = None

    def set_input_embeddings(self, new_embeddings):
        """Replace the token embedding and the object-order embedding that aliases it."""
        self.embed_tokens = new_embeddings
        self.visual_embedding.obj_order_embedding = new_embeddings

    def forward(
        self,
        input_ids=None,
        attention_mask=None,

        vis_inputs=None,
        vis_attention_mask=None,

        inputs_embeds=None,
        head_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """Encode text tokens together with visual inputs.

        Args:
            input_ids: Token ids; embedded with ``embed_tokens`` when
                ``inputs_embeds`` is None.
            attention_mask: Text mask. When None it is derived from
                ``input_ids`` (non-pad positions).
                NOTE(review): that derivation needs ``input_ids`` even when
                ``inputs_embeds`` is supplied - confirm callers always pass one.
            vis_inputs: Indexable of ``(vis_feats, boxes[, img_order_ids
                [, obj_order_ids]])``; required.
            vis_attention_mask: Visual mask; defaults to all ones.
            inputs_embeds: Pre-computed text embeddings (optional).
            head_mask, past_key_values, use_cache, output_attentions:
                Forwarded to the blocks.
            output_hidden_states: When truthy, only the final hidden state
                is collected (per-layer collection is commented out below).
            return_dict: When falsy a tuple of the non-None outputs is
                returned, otherwise a
                ``BaseModelOutputWithPastAndCrossAttentions``.
        """

        if inputs_embeds is None:
            assert self.embed_tokens is not None, "You have to initialize the model with valid token embeddings"
            inputs_embeds = self.embed_tokens(input_ids)

        # B: batch size, L: text sequence length.
        B, L = inputs_embeds.size()[:-1]

        vis_feats = vis_inputs[0]
        boxes = vis_inputs[1]
        img_order_ids = None
        obj_order_ids = None
        # The order ids are optional trailing entries of vis_inputs.
        if len(vis_inputs) >= 3:
            img_order_ids = vis_inputs[2]
        if len(vis_inputs) == 4:
            obj_order_ids = vis_inputs[3]

        vis_embeds = self.visual_embedding(
            vis_feats, boxes, img_order_ids, obj_order_ids)

        # V_L: number of visual positions appended after the text.
        V_L = vis_embeds.size(1)
        inputs_embeds = torch.cat([inputs_embeds, vis_embeds], dim=1)
        if attention_mask is None:
            attention_mask = input_ids.ne(self.config.pad_token_id).to(dtype=inputs_embeds.dtype, device=inputs_embeds.device)

        if vis_attention_mask is None:
            vis_attention_mask = attention_mask.new_ones(B, V_L)

        # Joint mask over [text ; visual], matching the concatenation above.
        attention_mask = torch.cat([attention_mask, vis_attention_mask], dim=1)

        # Expand the joint mask into the form the attention layers consume.
        extended_attention_mask = self.get_extended_attention_mask(
            attention_mask,
            (B, L+V_L),
            inputs_embeds.device)

        # initialize past_key_values with `None` if past does not exist
        if past_key_values is None:
            past_key_values = [None] * len(self.block)

        # Prepare head mask if needed
        head_mask = self.get_head_mask(head_mask, self.config.num_layers)
        present_key_value_states = () if use_cache else None
        all_hidden_states = () if output_hidden_states else None
        # NOTE: nothing is appended to all_attentions below (that code is
        # commented out), so it stays an empty tuple when requested.
        all_attentions = () if output_attentions else None
        all_cross_attentions = () if (output_attentions and self.is_decoder) else None
        # position_bias = None
        # encoder_decoder_position_bias = None

        hidden_states = self.dropout(inputs_embeds)

        if self.config.num_layers > 0:

            assert self.block[0].layer[0].SelfAttention.has_relative_attention_bias

            seq_length = L + V_L
            q_len = seq_length
            k_len = seq_length

            # [1, n_heads, Q_len, K_len]
            text_position_bias = self.block[0].layer[0].SelfAttention.compute_bias(
                L, L)
            num_heads = text_position_bias.size(1)
            # Start from zeros for the full joint sequence and fill in only
            # the text-to-text corner.
            position_bias = text_position_bias.new_zeros(
                1, num_heads, seq_length, seq_length)
            position_bias[:, :, :L, :L] = text_position_bias

            # print('position_bias size', position_bias.size())
            # print('attention_mask size', attention_mask.size())
            # print('extended_attention_mask size', extended_attention_mask.size())
            # relative position bias only between Text <-> Text
            # no relative position bias Text -> Vision
            # no relative position bias Vision -> Text
            # no relative position bias Vision <-> Vision
            # position_bias[:, :, L:, :] = 0
            # position_bias[:, :, :, L:] = 0
            # Fold the attention mask into the bias before the first block.
            position_bias = position_bias + extended_attention_mask

            for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)):

                # if output_hidden_states:
                #     all_hidden_states = all_hidden_states + (hidden_states,)

                layer_outputs = layer_module(
                    hidden_states,
                    attention_mask=extended_attention_mask,
                    position_bias=position_bias,
                    encoder_hidden_states=None,
                    encoder_attention_mask=None,
                    encoder_decoder_position_bias=None,
                    layer_head_mask=head_mask[i],
                    past_key_value=past_key_value,
                    use_cache=use_cache,
                    output_attentions=output_attentions,
                )
                # Insert a None placeholder at index 1 so the indices used
                # below line up.
                # NOTE(review): this only triggers for `use_cache is None`,
                # not for `use_cache=False` - confirm the block's output
                # layout in that case for the installed transformers version.
                if use_cache is None:
                  layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:]
                # layer_outputs is a tuple with:
                # hidden-states, key-value-states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)
                hidden_states, present_key_value_state = layer_outputs[:2]

                # We share the position biases between the layers - the first layer store them
                # layer_outputs = hidden-states, key-value-states (self-attention weights),
                # (self-attention position bias), (cross-attention weights), (cross-attention position bias)
                position_bias = layer_outputs[2]

                # append next layer key value states
                if use_cache:
                    present_key_value_states = present_key_value_states + \
                        (present_key_value_state,)

                # if output_attentions:
                #     all_attentions = all_attentions + (layer_outputs[3],)
                #     if self.is_decoder:
                #         all_cross_attentions = all_cross_attentions + \
                #             (layer_outputs[5],)

        hidden_states = self.final_layer_norm(hidden_states)
        hidden_states = self.dropout(hidden_states)

        # Add last layer
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        # Tuple form: entries that are None are dropped, so positions shift
        # depending on which outputs were requested.
        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    present_key_value_states,
                    all_hidden_states,
                    all_attentions,
                    all_cross_attentions,
                ]
                if v is not None
            )
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=present_key_value_states,
            hidden_states=all_hidden_states,
            attentions=all_attentions,
            cross_attentions=all_cross_attentions,
        )


class VLT5(T5ForConditionalGeneration):
    """T5 conditional-generation model with a joint encoder and graph fusion.

    Differences from the stock ``T5ForConditionalGeneration`` that are visible
    in this class:

    * the encoder is a ``JointEncoder`` which additionally receives
      ``vis_inputs`` / ``vis_attention_mask``;
    * after encoding, the last ``_NUM_OBJ_TOKENS`` encoder positions are fed
      through ``GraphFuse`` together with linearly projected text features,
      and the result is added onto those positions before decoding;
    * ``forward`` always returns a ``VLSeq2SeqLMOutput``.
    """

    _keys_to_ignore_on_load_missing = [
        r"encoder\.embed_tokens\.weight",
        r"decoder\.embed_tokens\.weight",
        r"lm_head\.weight",
    ]
    _keys_to_ignore_on_load_unexpected = [
        r"decoder\.block\.0\.layer\.1\.EncDecAttention\.relative_attention_bias\.weight",
    ]

    # Number of trailing encoder positions that are treated as graph nodes.
    # NOTE(review): presumably the number of visual object slots appended by
    # the joint encoder (the dataset's pad_len) - confirm against the caller.
    _NUM_OBJ_TOKENS = 36

    def __init__(self, config):
        # Deliberately skips T5ForConditionalGeneration.__init__ (the super()
        # call starts the MRO lookup *after* it) so that the encoder can be
        # replaced by JointEncoder below.
        super(T5ForConditionalGeneration, self).__init__(config)

        self.config = config

        self.model_dim = config.d_model

        self.shared = nn.Embedding(config.vocab_size, config.d_model)

        encoder_config = copy.deepcopy(config)
        encoder_config.is_decoder = False
        encoder_config.use_cache = False
        encoder_config.is_encoder_decoder = False

        #---- Modified ----#
        # self.encoder = T5Stack(encoder_config, self.shared)
        self.encoder = JointEncoder(encoder_config, self.shared)
        #------------------#

        decoder_config = copy.deepcopy(config)
        decoder_config.is_decoder = True
        decoder_config.is_encoder_decoder = False

        self.decoder = T5Stack(decoder_config, self.shared)

        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)

        self.init_weights()

        # Graph-fusion modules are created after init_weights(), so at
        # construction time they keep their default PyTorch initialisation.
        self.gf = GraphFuse()
        # NOTE(review): 768 is hard-coded; it has to match both the feature
        # size of `rtexts_feats` and what GraphFuse expects - confirm.
        self.proj = nn.Linear(768, 768)
        self.norm = nn.LayerNorm(768)
        # Model parallel
        self.model_parallel = False
        self.device_map = None

    def set_input_embeddings(self, new_embeddings):
        """Replace the shared token embedding on the model, encoder and decoder."""
        self.shared = new_embeddings
        self.encoder.set_input_embeddings(new_embeddings)
        self.decoder.set_input_embeddings(new_embeddings)

    def extend_vocab(self, vocab_size):
        """Grow the shared embedding / LM head to ``vocab_size`` rows.

        Existing rows are copied over; new rows keep the default
        initialisation of ``nn.Embedding``. The LM head is re-tied to the
        shared embedding and ``vocab_size`` is updated on all configs.
        """
        new_shared = nn.Embedding(vocab_size, self.config.d_model)
        old_weight = self.shared.weight.data.detach().clone()
        old_vocab_size = old_weight.size(0)
        new_shared.weight.data[:old_vocab_size, :] = old_weight
        self.shared = new_shared

        new_lm_head = nn.Linear(self.config.d_model, vocab_size, bias=False)
        old_weight = self.lm_head.weight.data.detach().clone()
        old_vocab_size = old_weight.size(0)
        new_lm_head.weight.data[:old_vocab_size, :] = old_weight
        self.lm_head = new_lm_head

        # The previous code dereferenced `self.vis_encoder`, which __init__
        # never creates (the visual pathway lives in `self.encoder`), so this
        # method always raised AttributeError.
        # NOTE(review): assumes the joint encoder exposes
        # `visual_embedding.obj_order_embedding`; the update is skipped when
        # it does not - confirm against JointEncoder.
        visual_embedding = getattr(self.encoder, "visual_embedding", None)
        if visual_embedding is not None and hasattr(visual_embedding, "obj_order_embedding"):
            visual_embedding.obj_order_embedding = self.shared

        self.encoder.embed_tokens = self.shared
        self.decoder.embed_tokens = self.shared

        # Tie the output projection to the (new) shared embedding.
        self.lm_head.weight = self.shared.weight

        self.config.vocab_size = vocab_size
        self.encoder.config.vocab_size = vocab_size
        self.decoder.config.vocab_size = vocab_size

    # @add_start_docstrings_to_callable(T5_INPUTS_DOCSTRING)
    # @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        encoder_outputs=None,

        vis_inputs=None,
        vis_attention_mask=None,

        decoder_input_ids=None,
        decoder_attention_mask=None,
        past_key_values=None,
        use_cache=None,
        labels=None,
        inputs_embeds=None,
        decoder_inputs_embeds=None,
        head_mask=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        reduce_loss=False,
        return_hidden_state=False,
        full_adj = None,
        sem_adj = None,
        full_weights = None,
        rtexts_feats = None,
        graph_mask = None,
        **kwargs,
    ):
        """Encode (unless ``encoder_outputs`` is given), fuse graph features, decode.

        Args:
            input_ids, attention_mask, inputs_embeds, head_mask: forwarded to
                the encoder. ``attention_mask`` defaults to the non-pad
                positions of ``input_ids``.
            encoder_outputs: pre-computed encoder outputs; when given the
                encoder is not run.
            vis_inputs, vis_attention_mask: forwarded to the joint encoder.
                ``vis_attention_mask`` defaults to all ones over the encoder
                positions that follow the text positions.
            decoder_input_ids, decoder_attention_mask, decoder_inputs_embeds,
            past_key_values, use_cache: forwarded to the decoder. When
                ``labels`` is given and no decoder input is, the decoder
                input is ``labels`` shifted right.
            labels: target ids; positions equal to ``-100`` are ignored by
                the loss.
            reduce_loss: ``True`` -> mean-reduced loss; ``False`` ->
                unreduced, flattened per-token loss.
            return_hidden_state: when true, return the rescaled decoder
                output instead of a ``VLSeq2SeqLMOutput``.
            full_adj, sem_adj, full_weights, rtexts_feats, graph_mask:
                inputs of ``GraphFuse``; they are required (``full_adj`` and
                ``rtexts_feats`` are dereferenced unconditionally).

        Returns:
            ``VLSeq2SeqLMOutput`` with ``loss``, ``logits``,
            ``past_key_values``, ``decoder_last_hidden_state`` and
            ``decoder_hidden_states`` set; or a tensor when
            ``return_hidden_state`` is true.
        """
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if encoder_outputs is None:

            encoder_outputs = self.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                inputs_embeds=inputs_embeds,

                vis_inputs=vis_inputs,
                vis_attention_mask=vis_attention_mask,

                head_mask=head_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )

        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
            encoder_outputs = BaseModelOutput(
                last_hidden_state=encoder_outputs[0],
                hidden_states=encoder_outputs[1] if len(
                    encoder_outputs) > 1 else None,
                attentions=encoder_outputs[2] if len(
                    encoder_outputs) > 2 else None,
            )

        hidden_states = encoder_outputs[0]
        if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
            # get decoder inputs from shifting lm labels to the right
            decoder_input_ids = self._shift_right(labels)

        # If decoding with past key value states, only the last tokens
        # should be given as an input
        if past_key_values is not None:
            assert labels is None, "Decoder should not use cached key value states when training."
            if decoder_input_ids is not None:
                decoder_input_ids = decoder_input_ids[:, -1:]
            if decoder_inputs_embeds is not None:
                decoder_inputs_embeds = decoder_inputs_embeds[:, -1:]

        if attention_mask is None:
            attention_mask = input_ids.ne(self.config.pad_token_id).to(dtype=hidden_states.dtype, device=hidden_states.device)
        if vis_attention_mask is None:
            B, L = attention_mask.size()
            # Everything after the L text positions is attended to in full.
            V_L = encoder_outputs[0].size(1) - L
            vis_attention_mask = attention_mask.new_ones(B, V_L)
        encoder_attention_mask = torch.cat([attention_mask, vis_attention_mask], dim=1)

        n_obj = self._NUM_OBJ_TOKENS
        obj_features = hidden_states[:, -n_obj:]
        gf_feats = self.gf(full_adj, sem_adj, full_weights, None, obj_features, self.proj(rtexts_feats), full_adj.shape[0], graph_mask)
        gf_feats_input = torch.zeros_like(hidden_states)
        gf_feats_input[:, -n_obj:] = gf_feats
        # Out-of-place add: `hidden_states` aliases `encoder_outputs[0]`.
        # An in-place `+=` would (a) mutate encoder outputs that generate()
        # re-uses on every decoding step, re-adding the graph features each
        # time, and (b) modify a tensor autograd may have saved for backward.
        hidden_states = hidden_states + gf_feats_input

        # Decode. The decoder output is read by attribute below, so it is
        # always requested as a ModelOutput regardless of `return_dict`.
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            inputs_embeds=decoder_inputs_embeds,
            past_key_values=past_key_values,

            encoder_hidden_states=hidden_states,
            encoder_attention_mask=encoder_attention_mask,

            head_mask=head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )

        sequence_output = decoder_outputs[0]

        assert self.config.tie_word_embeddings is True

        if self.config.tie_word_embeddings:
            # Rescale output before projecting on vocab
            # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
            sequence_output = sequence_output * (self.model_dim ** -0.5)

        if return_hidden_state:
            return sequence_output

        lm_logits = self.lm_head(sequence_output)

        loss = None
        if labels is not None:
            # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666
            if reduce_loss:
                loss_fct = CrossEntropyLoss(ignore_index=-100)
            else:
                # Per-token loss, flattened to (batch * target_len,); the
                # caller is expected to reshape / mask / reduce it.
                loss_fct = CrossEntropyLoss(ignore_index=-100, reduction='none')
            loss = loss_fct(
                lm_logits.view(-1, lm_logits.size(-1)),
                labels.view(-1))

        # A tuple (return_dict=False) return path is not implemented; the
        # encoder / visual-encoder fields of the output are not populated.
        return VLSeq2SeqLMOutput(
            loss=loss,
            logits=lm_logits,
            past_key_values=decoder_outputs.past_key_values,
            decoder_last_hidden_state=decoder_outputs.last_hidden_state,
            decoder_hidden_states=decoder_outputs.hidden_states,
        )

    def prepare_inputs_for_generation(
        self, input_ids, past=None, attention_mask=None, use_cache=None,
        encoder_outputs=None,
        full_adj = None,
        sem_adj = None,
        full_weights = None,
        rtexts_feats = None,
        graph_mask = None,
        **kwargs):
        """Build the keyword arguments for one ``forward`` call during generation.

        ``input_ids`` here are the decoder ids generated so far; they are
        passed on as ``decoder_input_ids``. The graph inputs are forwarded
        unchanged on every step, and ``vis_attention_mask`` is forwarded when
        it is present in ``kwargs``.
        """
        # cut decoder_input_ids if past is used
        if past is not None:
            input_ids = input_ids[:, -1:]

        output = {
            "decoder_input_ids": input_ids,
            "past_key_values": past,
            "encoder_outputs": encoder_outputs,
            "attention_mask": attention_mask,
            "use_cache": use_cache,
            "full_adj": full_adj,
            "sem_adj": sem_adj,
            "full_weights": full_weights,
            "rtexts_feats": rtexts_feats,
            "graph_mask": graph_mask,
        }

        if 'vis_attention_mask' in kwargs:
            output['vis_attention_mask'] = kwargs['vis_attention_mask']

        return output

    @staticmethod
    def _expand_inputs_for_generation(
        input_ids: torch.LongTensor,
        expand_size: int = 1,
        is_encoder_decoder: bool = False,
        attention_mask: torch.LongTensor = None,
        encoder_outputs: ModelOutput = None,
        **model_kwargs
    ) -> Tuple[torch.LongTensor, Dict[str, Any]]:
        """Repeat every batch row ``expand_size`` times (e.g. once per beam).

        Besides ``input_ids``, the same row expansion is applied to the
        attention masks, the encoder's ``last_hidden_state`` and the graph
        inputs found in ``model_kwargs``. Graph inputs that are missing or
        ``None`` are left untouched instead of raising ``KeyError``.

        Returns:
            The expanded ``input_ids`` and the updated ``model_kwargs``.
        """
        # Row indices [0,0,...,1,1,...]: each source row repeated expand_size times.
        expanded_return_idx = (
            torch.arange(input_ids.shape[0]).view(-1, 1).repeat(1,
                                                                expand_size).view(-1).to(input_ids.device)
        )
        input_ids = input_ids.index_select(0, expanded_return_idx)

        if "token_type_ids" in model_kwargs:
            token_type_ids = model_kwargs["token_type_ids"]
            model_kwargs["token_type_ids"] = token_type_ids.index_select(
                0, expanded_return_idx)

        if attention_mask is not None:
            model_kwargs["attention_mask"] = attention_mask.index_select(
                0, expanded_return_idx)

        if model_kwargs.get("vis_attention_mask", None) is not None:
            model_kwargs['vis_attention_mask'] = model_kwargs['vis_attention_mask'].index_select(
                0, expanded_return_idx)

        if is_encoder_decoder:
            assert encoder_outputs is not None
            encoder_outputs["last_hidden_state"] = encoder_outputs.last_hidden_state.index_select(
                0, expanded_return_idx
            )
            model_kwargs["encoder_outputs"] = encoder_outputs

        # Graph inputs are batched along dim 0 as well and must follow the
        # same expansion so they stay aligned with the expanded beams.
        for key in ('full_adj', 'sem_adj', 'full_weights', 'rtexts_feats', 'graph_mask'):
            value = model_kwargs.get(key)
            if value is not None:
                model_kwargs[key] = value.index_select(0, expanded_return_idx)

        return input_ids, model_kwargs


@dataclass
class VLSeq2SeqLMOutput(ModelOutput):
    """
    Output container for the vision-language sequence-to-sequence model (returned by :obj:`VLT5.forward`).

    Every field defaults to :obj:`None`. The forward pass in this file only fills :obj:`loss`, :obj:`logits`,
    :obj:`past_key_values`, :obj:`decoder_last_hidden_state` and :obj:`decoder_hidden_states`; the remaining
    fields are declared but left unset there.

    Args:
        loss (:obj:`torch.FloatTensor`, `optional`, returned when :obj:`labels` is provided):
            Language modeling loss. NOTE(review): the producing forward pass returns an unreduced, flattened
            per-token loss when called with ``reduce_loss=False`` (``reduction='none'``), so the shape is only
            :obj:`(1,)`-like when the loss is reduced.
        logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        past_key_values (:obj:`List[torch.FloatTensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
            List of :obj:`torch.FloatTensor` of length :obj:`config.n_layers`,  with each tensor of shape
            :obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`).

            Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be
            used (see ``past_key_values`` input) to speed up sequential decoding.
        decoder_last_hidden_state (:obj:`torch.FloatTensor`, `optional`):
            Last hidden state of the decoder (``decoder_outputs.last_hidden_state`` in the forward pass).
        decoder_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the decoder at the output of each layer plus the initial embedding outputs.
        decoder_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.

            Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the
            self-attention heads.
        encoder_last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
            Sequence of hidden-states at the output of the last layer of the encoder of the model.
        encoder_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the encoder at the output of each layer plus the initial embedding outputs.
        encoder_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.

            Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the
            self-attention heads.
        vis_encoder_last_hidden_state, vis_encoder_hidden_states, vis_encoder_attentions (`optional`):
            Slots for the outputs of a separate visual encoder. NOTE(review): nothing in the visible forward pass
            assigns them (the corresponding arguments are commented out there) - presumably kept for
            compatibility; confirm before relying on them.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    past_key_values: Optional[List[torch.FloatTensor]] = None
    # NOTE(review): annotated as a tuple, but the forward pass stores the
    # decoder's `last_hidden_state` here - confirm the intended type.
    decoder_last_hidden_state: Optional[Tuple[torch.FloatTensor]] = None
    decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
    encoder_last_hidden_state: Optional[torch.FloatTensor] = None
    encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None

    # Visual-encoder slots; not populated by the forward pass in this file.
    vis_encoder_last_hidden_state: Optional[torch.FloatTensor] = None
    vis_encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    vis_encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None

    # cross_encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None

### VL-T5

In [None]:
from pathlib import Path
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

class VLT5VQA(VLT5):
    """VLT5 wrapper exposing ``train_step`` / ``test_step`` for question answering.

    Two modes, selected by ``config.classifier``:

    * classifier: an ``answer_head`` MLP scores ``num_answers`` candidates
      from the last decoder hidden state (BCE-with-logits loss);
    * generative (otherwise): answers are produced as text with the language
      modeling head (``labels`` for training, ``generate`` for inference).

    Args:
        config: model config; must provide ``classifier``.
        tokenizer: used by ``test_step`` to decode generated ids.
        num_answers: size of the answer vocabulary (classifier mode).
        label2ans: mapping from answer index to answer (classifier mode).
    """

    # Batch entries that VLT5.forward needs for its graph-fusion step.
    _GRAPH_KEYS = ('full_adj', 'sem_adj', 'full_weights', 'rtexts_feats', 'graph_mask')

    def __init__(self, config, tokenizer, num_answers=None, label2ans=None):
        super().__init__(config)

        if config.classifier:
            self.answer_head = nn.Sequential(
                nn.Linear(config.d_model, config.d_model * 2),
                nn.GELU(),
                nn.LayerNorm(config.d_model * 2),
                nn.Linear(config.d_model * 2, num_answers)
            )

        self.num_answers = num_answers
        self.label2ans = label2ans
        self.bce_loss = nn.BCEWithLogitsLoss()
        self.tokenizer = tokenizer

    def _graph_inputs(self, batch):
        """Collect the graph-related batch entries as forward()/generate() kwargs."""
        return {key: batch[key] for key in self._GRAPH_KEYS}

    def train_step(self, batch):
        """Run one forward pass on ``batch`` and return ``{'loss': scalar}``.

        Reads ``vis_feats``, ``boxes``, ``input_ids`` and the graph entries
        from ``batch``; additionally ``targets`` in classifier mode, or
        ``target_ids`` and ``scores`` in generative mode. Only the targets
        and scores are moved to the model device here.
        """
        device = next(self.parameters()).device
        vis_feats = batch['vis_feats']
        input_ids = batch['input_ids']
        vis_pos = batch['boxes']
        graph_inputs = self._graph_inputs(batch)

        if self.config.classifier:
            B = len(input_ids)

            # A single start token per example: the decoder runs one step.
            decoder_input_ids = torch.ones(
                B, 1, dtype=torch.long, device=device) * self.config.decoder_start_token_id

            # The graph inputs are passed here too: VLT5.forward dereferences
            # them unconditionally and fails without them.
            output = self(
                input_ids=input_ids,
                vis_inputs=(vis_feats, vis_pos),
                decoder_input_ids=decoder_input_ids,
                output_hidden_states=True,
                return_dict=True,
                **graph_inputs,
            )
            target = batch['targets'].to(device)

            last_layer_hidden_state = output.decoder_hidden_states[-1]
            last_hidden_state = last_layer_hidden_state.view(B, -1, self.config.d_model)[:, -1]

            # [B, num_answers]
            logit = self.answer_head(last_hidden_state)

            loss = self.bce_loss(logit, target)

        else:
            lm_labels = batch["target_ids"].to(device)

            output = self(
                input_ids=input_ids,
                vis_inputs=(vis_feats, vis_pos),
                labels=lm_labels,
                return_dict=True,
                **graph_inputs,
            )
            assert 'loss' in output

            # forward() returns an unreduced per-token loss (reduce_loss is
            # left at False); average it over the non-ignored target tokens.
            lm_mask = (lm_labels != -100).float()
            B, L = lm_labels.size()

            loss = output['loss']

            loss = loss.view(B, L) * lm_mask

            loss = loss.sum(dim=1) / lm_mask.sum(dim=1).clamp(min=1)  # B

            # Flatten the per-example weights to (B,). A (B, 1) tensor would
            # broadcast against the (B,) loss into a (B, B) matrix and weight
            # every example with every score.
            scores = batch['scores'].to(device=device).reshape(-1)
            loss = loss * scores

            loss = loss.mean()

        result = {
            'loss': loss
        }

        return result

    @torch.no_grad()
    def test_step(self, batch, **kwargs):
        """Predict answers for ``batch`` without tracking gradients.

        Switches the module to eval mode (and leaves it there). In generative
        mode ``kwargs`` are forwarded to ``generate``; ``num_beams`` (4) and
        ``max_length`` (128) are defaults that ``kwargs`` may override.

        Returns:
            dict with ``pred_ans`` (list of answers) and, in generative mode,
            ``token_ids`` (the raw generated ids).
        """
        self.eval()
        device = next(self.parameters()).device
        vis_feats = batch['vis_feats']
        input_ids = batch['input_ids']
        vis_pos = batch['boxes']
        graph_inputs = self._graph_inputs(batch)
        result = {}
        if self.config.classifier:
            B = len(input_ids)

            decoder_input_ids = torch.ones(
                B, 1, dtype=torch.long, device=device) * self.config.decoder_start_token_id

            output = self(
                input_ids=input_ids,
                vis_inputs=(vis_feats, vis_pos),
                decoder_input_ids=decoder_input_ids,
                output_hidden_states=True,
                return_dict=True,
                **graph_inputs,
            )

            last_layer_hidden_state = output.decoder_hidden_states[-1]
            last_hidden_state = last_layer_hidden_state.view(B, -1, self.config.d_model)[:, -1]

            # [B, num_answers]
            logit = self.answer_head(last_hidden_state)

            # Highest-scoring answer index per example.
            score, pred_ans_id = logit.max(1)
            pred_ans_id = pred_ans_id.cpu().numpy()
            pred_ans = [self.label2ans[ans_id] for ans_id in pred_ans_id]

            result['pred_ans'] = pred_ans

        else:
            # Defaults first, caller overrides second: passing num_beams or
            # max_length through **kwargs no longer raises a duplicate
            # keyword TypeError.
            gen_kwargs = {'num_beams': 4, 'max_length': 128}
            gen_kwargs.update(kwargs)

            output = self.generate(
                input_ids=input_ids,
                vis_inputs=(vis_feats, vis_pos),
                **graph_inputs,
                **gen_kwargs
            )
            generated_sents = self.tokenizer.batch_decode(output, skip_special_tokens=True)
            result['token_ids'] = output
            result['pred_ans'] = generated_sents

        return result

In [None]:
def create_config():
    """Return the pretrained 't5-base' config extended with VL-T5 settings.

    The base configuration is loaded through ``T5Config.from_pretrained`` and
    the extra attributes listed below are then set on it, in order.
    """
    from transformers import T5Config, BartConfig

    config = T5Config.from_pretrained('t5-base')

    overrides = {
        # Visual input description.
        'feat_dim': 2048,
        'pos_dim': 4,
        'n_images': 2,
        'use_vis_order_embedding': True,
        # Dropout rates.
        'dropout_rate': 0.1,
        'dropout': 0.1,
        'attention_dropout': 0.1,
        'activation_dropout': 0.1,
        # Layer-norm options and loss list.
        'use_vis_layer_norm': True,
        'individual_vis_layer_norm': True,
        'losses': 'lm,obj,attr,feat',
        'share_vis_lang_layer_norm': True,
        # Generative (not classifier) answer head.
        'classifier': False,
    }
    for attr_name, attr_value in overrides.items():
        setattr(config, attr_name, attr_value)

    return config

In [None]:
def dump_result(quesid2ans: dict, path):
    """
    Write predictions to a JSON file in the VQA online-evaluation format.

    The file holds a list with one object per question::

        {"question_id": <id>, "answer": <answer>}

    :param quesid2ans: dict of quesid --> ans
    :param path: The desired path of saved file.
    """
    with open(path, 'w') as out_file:
        records = [
            {'question_id': ques_id, 'answer': ans}
            for ques_id, ans in quesid2ans.items()
        ]
        json.dump(records, out_file, indent=4, sort_keys=True)

In [None]:
# Build the model configuration once; it is shared by the cells below.
config = create_config()

In [None]:
from transformers import T5Tokenizer, BartTokenizer, T5TokenizerFast, BartTokenizerFast

In [None]:
# Load the VL-T5 fast tokenizer from the 't5-base' vocabulary files.
# NOTE(review): `return_tensors` is normally an argument of the tokenization
# call, not of `from_pretrained` - confirm it has any effect here.
vlt5_tokenizer = VLT5TokenizerFast.from_pretrained(
    't5-base',
    do_lower_case=True,
    return_tensors = 'pt'
    )

In [None]:
# Instantiate VLT5VQA from the 't5-base' checkpoint with the custom config.
# `tokenizer` is a required argument of VLT5VQA.__init__; presumably
# `from_pretrained` forwards it there as an extra keyword - confirm.
model = VLT5VQA.from_pretrained('t5-base', config=config, tokenizer=vlt5_tokenizer)

In [None]:
# Resize the model's token embeddings to the tokenizer's `vocab_size`.
# NOTE(review): confirm that `vocab_size` of this tokenizer class includes
# every extra/added token the inputs can contain.
model.resize_token_embeddings(vlt5_tokenizer.vocab_size)

In [None]:
from pprint import pprint

In [None]:
def load_checkpoint(model, ckpt_path):
    """Load weights from ``ckpt_path`` into ``model`` (non-strict).

    The state dict is read on CPU. Keys stored under a ``vis_encoder.`` or
    ``model.vis_encoder.`` prefix are renamed to the matching ``encoder.`` /
    ``model.encoder.`` prefix before loading. The checkpoint path and the
    result of ``model.load_state_dict`` are printed.
    """
    state_dict = load_state_dict(ckpt_path, 'cpu')

    # (prefix found in the checkpoint, prefix to store the entry under)
    prefix_renames = (
        ("vis_encoder.", 'encoder.'),
        ("model.vis_encoder.", 'model.encoder.'),
    )

    # Iterate over a snapshot of the keys: the dict is modified in the loop.
    for old_key in list(state_dict.keys()):
        for old_prefix, new_prefix in prefix_renames:
            if old_key.startswith(old_prefix):
                renamed_key = new_prefix + old_key[len(old_prefix):]
                state_dict[renamed_key] = state_dict.pop(old_key)
                break

    load_report = model.load_state_dict(state_dict, strict=False)
    print('Model loaded from ', ckpt_path)
    pprint(load_report)

## Dataset

In [None]:
import json

In [None]:
from torch.utils.data import Dataset, DataLoader
import random
class GraphDataset(Dataset):
    """Map-style dataset yielding one graph / vision / text sample per row.

    Each row of ``df`` supplies the paths of four pickle files (columns
    ``graph_sem``, ``graph_base``, ``feature`` and ``text``) together with
    the ``input_texts``, ``answers`` and ``question_ids`` values.

    Args:
        df: table of examples; rows are looked up with ``df.loc[idx]``.
        pad_len: number of node / object slots the base-graph matrices,
            visual features and boxes are truncated and padded to.
        tokenizer: callable returning a mapping with an ``input_ids`` entry.
    """

    def __init__(self, df, pad_len, tokenizer):
        self.df = df
        self.pad_len = pad_len
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.df)

    @staticmethod
    def _load_pickle(path):
        """Unpickle ``path`` and close the file handle afterwards."""
        # NOTE: pickle can execute arbitrary code while loading; only use
        # feature files that come from a trusted source.
        with open(path, 'rb') as handle:
            return pickle.load(handle)

    def __getitem__(self, idx):
        """Load, truncate and pad the tensors of example ``idx``.

        Returns:
            dict with the tensors ``full_weights``, ``full_adj``,
            ``sem_adj``, ``rtexts_feats``, ``vis_feats``, ``boxes`` and
            ``graph_mask``, the python lists ``input_ids``, ``target_ids``
            and ``scores`` (always ``[1]``), and ``question_ids``.
        """
        item = self.df.loc[idx]
        sem_data = self._load_pickle(item['graph_sem'])
        base_data = self._load_pickle(item['graph_base'])
        feats = self._load_pickle(item['feature'])
        text_data = self._load_pickle(item['text'])

        # Truncate to at most pad_len nodes / objects; padding happens below.
        full_weights = torch.tensor(base_data['full_weights'])[:self.pad_len, :self.pad_len]
        full_adj = torch.tensor(base_data['normalized_adj'])[:self.pad_len, :self.pad_len]
        sem_adj = torch.tensor(sem_data['normalized_adj'])
        vis_feats = torch.tensor(feats['visual_feats'], dtype=torch.float32)[:self.pad_len]

        bboxes = torch.tensor(feats['bboxes'], dtype=torch.float32)[:self.pad_len]
        bboxes = pad_zeros(bboxes, (self.pad_len, 4))

        # Label features first, then text features, along dim 0.
        label_feats = torch.tensor(text_data['label_feature'], dtype=torch.float32)
        rtexts_feats = torch.tensor(text_data['text_feature'], dtype=torch.float32)
        text_feats = torch.cat((label_feats, rtexts_feats), dim=0)

        # One row per semantic-graph node; the first rows are set to one.
        # Must run before `full_adj` is padded: its first dimension is still
        # the truncated, unpadded size here.
        graph_mask = torch.zeros(sem_adj.shape[0], 768)
        graph_mask[:full_adj.shape[0]] = 1

        input_ids = self.tokenizer(item['input_texts'], truncation=True, max_length=512)['input_ids']
        target_ids = self.tokenizer(item['answers'])['input_ids']

        # Bring the base-graph matrices and visual features to the fixed
        # pad_len-based shapes requested from pad_zeros.
        full_weights = pad_zeros(full_weights, (self.pad_len, self.pad_len))
        full_adj = pad_zeros(full_adj, (self.pad_len, self.pad_len))
        vis_feats = pad_zeros(vis_feats, (self.pad_len, 2048))
        question_ids = item['question_ids']

        sample = {'full_weights': full_weights, 'full_adj': full_adj, 'sem_adj': sem_adj, 'scores': [1], 'rtexts_feats': text_feats, 'vis_feats': vis_feats,
                  'boxes': bboxes, 'input_ids': input_ids, 'target_ids': target_ids, 'question_ids': question_ids,
                  'graph_mask': graph_mask}

        return sample

In [None]:
def collator(batch):
    """Merge a list of ``GraphDataset`` samples into one batch dict.

    * ``full_weights``, ``full_adj``, ``vis_feats`` and ``boxes`` are stacked
      as they are.
    * ``sem_adj``, ``graph_mask`` and ``rtexts_feats`` are first passed
      through ``pad_zeros`` with the largest first dimension in the batch
      (at least 36), then stacked.
    * ``input_ids`` / ``target_ids`` are right-padded with 0 to the longest
      sequence; afterwards every 0 in ``target_ids`` is replaced by -100 so
      the loss ignores it.
    * ``+inf`` entries of ``vis_feats`` and ``+/-inf`` entries of ``boxes``
      are replaced by 0.

    Each field is stacked once with ``torch.stack`` instead of growing a
    tensor with ``torch.cat`` per sample, which re-copied the accumulated
    batch on every iteration.
    """
    max_input = max(len(b['input_ids']) for b in batch)
    max_target = max(len(b['target_ids']) for b in batch)
    max_texts = max(max(b['rtexts_feats'].shape[0] for b in batch), 36)
    max_mask = max(max(b['graph_mask'].shape[0] for b in batch), 36)
    max_sem = max(max(b['sem_adj'].shape[0] for b in batch), 36)

    def stack(key):
        # Stack the already equally-shaped per-sample tensors along dim 0.
        return torch.stack([b[key] for b in batch])

    def pad_token_ids(ids, length):
        # Right-pad a python list of token ids with 0 up to `length`.
        return ids + [0] * (length - len(ids))

    input_batch = {
        'full_weights': stack('full_weights'),
        'full_adj': stack('full_adj'),
        'sem_adj': torch.stack(
            [pad_zeros(b['sem_adj'], (max_sem, max_sem)) for b in batch]),
        'input_ids': torch.tensor(
            [pad_token_ids(b['input_ids'], max_input) for b in batch]),
        'target_ids': torch.tensor(
            [pad_token_ids(b['target_ids'], max_target) for b in batch]),
        'scores': torch.tensor([b['scores'] for b in batch]),
        'graph_mask': torch.stack(
            [pad_zeros(b['graph_mask'], (max_mask, 768)) for b in batch]),
        'rtexts_feats': torch.stack(
            [pad_zeros(b['rtexts_feats'], (max_texts, 768)) for b in batch]),
        'vis_feats': stack('vis_feats'),
        'boxes': stack('boxes'),
        'question_ids': [b['question_ids'] for b in batch],
    }

    # Neutralise infinite entries in the visual inputs.
    input_batch['vis_feats'][input_batch['vis_feats'] == float("Inf")] = 0
    input_batch['boxes'][input_batch['boxes'] == float("Inf")] = 0
    input_batch['boxes'][input_batch['boxes'] == float("-Inf")] = 0

    # Every target position holding id 0 (the padding value used above)
    # becomes -100.
    word_mask = input_batch['target_ids'] != 0
    input_batch['target_ids'][~word_mask] = -100
    return input_batch

In [None]:
def create_df(split, root='/content'):
    """Build the example table for one dataset split.

    Reads ``{root}/OpenCQA/etc/data/{split}.json`` (a mapping from example id
    to a record) and, per example, ``{root}/OpenCQA/bboxes/{uid}.json`` (a
    list of objects with a ``sentence`` entry). Every file is closed as soon
    as it has been parsed.

    Args:
        split: name of the split file without the ``.json`` extension.
        root: directory containing the extracted resources; the default
            reproduces the original hard-coded ``/content`` paths.

    Returns:
        ``pandas.DataFrame`` with the columns ``queries``, ``answers``,
        ``question_ids``, ``graph_base``, ``graph_sem``, ``feature``,
        ``input_texts`` and ``text``, one row per example.
    """
    columns = ['queries', 'answers', 'question_ids', 'graph_base',
               'graph_sem', 'feature', 'input_texts', 'text']

    with open(f'{root}/OpenCQA/etc/data/{split}.json') as pairs_file:
        pairs = json.load(pairs_file)

    rows = []
    for uid, pair in pairs.items():
        with open(f'{root}/OpenCQA/bboxes/{uid}.json') as ocr_file:
            ocr_data = json.load(ocr_file)
        ocr_text = '|'.join([s['sentence'] for s in ocr_data])

        rows.append({
            # pair[3] is used as the query text, pair[-2] as the answer.
            'queries': pair[3],
            'answers': str(pair[-2]),
            'question_ids': uid,
            'graph_base': f'{root}/OpenCQA_Graph/{uid}.pkl',
            'graph_sem': f'{root}/OpenCQA_Graph_6_rels/{uid}.pkl',
            'feature': f'{root}/mask-rcnn-predict_pkl/{uid}.pkl',
            # Model input: query, pair[1] and the OCR sentences.
            'input_texts': pair[3] + ' <SEP> ' + pair[1] + ' <SEP> ' + ocr_text + '<SEP>',
            'text': f'{root}/text_feature_bert/{uid}.pkl',
        })

    # Explicit columns keep the schema (and column order) for empty splits.
    return pd.DataFrame(rows, columns=columns)

## Evaluator

In [None]:
from nltk.translate.bleu_score import sentence_bleu
from nltk.tokenize import word_tokenize
from sacrebleu.metrics import BLEU, CHRF, TER
from sacremoses import MosesPunctNormalizer, MosesTokenizer, MosesDetokenizer
import pandas as pd

In [None]:
import csv
import json
from statistics import mean, stdev
import sys
import re

In [None]:
class VQAEvaluator:
    """Score generated answers against the gold answers of a dataframe.

    The instance keeps a ``question id -> gold answer`` lookup built from the
    dataframe passed to the constructor, plus the answer-normalisation tables
    taken from the reference VQA evaluation script:
    https://github.com/GT-Vision-Lab/VQA/blob/master/PythonEvaluationTools/vqaEvaluation/vqaEval.py
    """

    # Tokens that never count as a gold "record" in the content-selection score.
    _CS_FILLERS = frozenset([
        'in', 'the', 'and', 'or', 'an', 'as', 'can', 'be', 'a', ':', '-',
        'to', 'but', 'is', 'of', 'it', 'on', '.', 'at', '(', ')', ',', ';',
    ])

    def __init__(self, df):
        """Build the gold-answer lookup and the normalisation tables.

        Args:
            df: DataFrame with a ``question_ids`` and an ``answers`` column.
                When an id occurs more than once the last row wins.
        """
        self.qidtoans = dict(zip(df['question_ids'], df['answers']))

        # Result containers; evaluate_raw() resets them on every call.
        self.accuracy = {}
        self.evalQA = {}
        self.evalQuesType = {}
        self.evalAnsType = {}

        # Table kept verbatim from the reference script linked in the class docstring.
        self.contractions = {
            "aint": "ain't", "arent": "aren't", "cant": "can't", "couldve": "could've", "couldnt": "couldn't",
            "couldn'tve": "couldn't've", "couldnt've": "couldn't've", "didnt": "didn't", "doesnt": "doesn't", "dont": "don't", "hadnt": "hadn't",
            "hadnt've": "hadn't've", "hadn'tve": "hadn't've", "hasnt": "hasn't", "havent": "haven't", "hed": "he'd", "hed've": "he'd've",
            "he'dve": "he'd've", "hes": "he's", "howd": "how'd", "howll": "how'll", "hows": "how's", "Id've": "I'd've", "I'dve": "I'd've",
            "Im": "I'm", "Ive": "I've", "isnt": "isn't", "itd": "it'd", "itd've": "it'd've", "it'dve": "it'd've", "itll": "it'll", "let's": "let's",
            "maam": "ma'am", "mightnt": "mightn't", "mightnt've": "mightn't've", "mightn'tve": "mightn't've", "mightve": "might've",
            "mustnt": "mustn't", "mustve": "must've", "neednt": "needn't", "notve": "not've", "oclock": "o'clock", "oughtnt": "oughtn't",
            "ow's'at": "'ow's'at", "'ows'at": "'ow's'at", "'ow'sat": "'ow's'at", "shant": "shan't", "shed've": "she'd've", "she'dve": "she'd've",
            "she's": "she's", "shouldve": "should've", "shouldnt": "shouldn't", "shouldnt've": "shouldn't've", "shouldn'tve": "shouldn't've",
            "somebody'd": "somebodyd", "somebodyd've": "somebody'd've", "somebody'dve": "somebody'd've", "somebodyll": "somebody'll",
            "somebodys": "somebody's", "someoned": "someone'd", "someoned've": "someone'd've", "someone'dve": "someone'd've",
            "someonell": "someone'll", "someones": "someone's", "somethingd": "something'd", "somethingd've": "something'd've",
            "something'dve": "something'd've", "somethingll": "something'll", "thats": "that's", "thered": "there'd", "thered've": "there'd've",
            "there'dve": "there'd've", "therere": "there're", "theres": "there's", "theyd": "they'd", "theyd've": "they'd've",
            "they'dve": "they'd've", "theyll": "they'll", "theyre": "they're", "theyve": "they've", "twas": "'twas", "wasnt": "wasn't",
            "wed've": "we'd've", "we'dve": "we'd've", "weve": "we've", "werent": "weren't", "whatll": "what'll", "whatre": "what're",
            "whats": "what's", "whatve": "what've", "whens": "when's", "whered": "where'd", "wheres": "where's", "whereve": "where've",
            "whod": "who'd", "whod've": "who'd've", "who'dve": "who'd've", "wholl": "who'll", "whos": "who's", "whove": "who've", "whyll": "why'll",
            "whyre": "why're", "whys": "why's", "wont": "won't", "wouldve": "would've", "wouldnt": "wouldn't", "wouldnt've": "wouldn't've",
            "wouldn'tve": "wouldn't've", "yall": "y'all", "yall'll": "y'all'll", "y'allll": "y'all'll", "yall'd've": "y'all'd've",
            "y'alld've": "y'all'd've", "y'all'dve": "y'all'd've", "youd": "you'd", "youd've": "you'd've", "you'dve": "you'd've",
            "youll": "you'll", "youre": "you're", "youve": "you've",
        }

        # Number words replaced by digits in processDigitArticle().
        self.manualMap = {
            'none': '0',
            'zero': '0',
            'one': '1',
            'two': '2',
            'three': '3',
            'four': '4',
            'five': '5',
            'six': '6',
            'seven': '7',
            'eight': '8',
            'nine': '9',
            'ten': '10',
        }

        # Words dropped by processDigitArticle().
        self.articles = ['a', 'an', 'the']

        # Raw strings keep the patterns identical while avoiding invalid-escape warnings.
        # NOTE(review): ``(?!<=\d)`` is a negative lookahead for the literal text
        # "<=" + digit, not a lookbehind; the pattern is left unchanged so results
        # stay comparable with the reference script.
        self.periodStrip = re.compile(r"(?!<=\d)(\.)(?!\d)")
        self.commaStrip = re.compile(r"(\d)(\,)(\d)")
        self.punct = [';', r"/", '[', ']', '"', '{', '}',
                      '(', ')', '=', '+', '\\', '_', '-',
                      '>', '<', '@', '`', ',', '?', '!']

        # Number of decimals kept by the setEval* helpers.
        self.n = 2

    def dump_result(self, quesid2ans: dict, path):
        """
        Dump results to a json file, which could be submitted to the VQA online evaluation.
        VQA json file submission requirement:
            results = [result]
            result = {
                "question_id": int,
                "answer": str
            }
        :param quesid2ans: dict of quesid --> ans
        :param path: The desired path of saved file.
        """
        result = [{'question_id': ques_id, 'answer': ans}
                  for ques_id, ans in quesid2ans.items()]
        with open(path, 'w') as f:
            json.dump(result, f, indent=4, sort_keys=True)

    def evaluate_raw(self, quesid2ans: dict, is_topk_optimal=None, criteria='bleu'):
        """Score generated answers with the requested criterion.

        Args:
            quesid2ans: Mapping of question id to the generated answer text.
            is_topk_optimal: Unused; kept so existing callers keep working.
            criteria: ``'bleu'`` for corpus BLEU (sacrebleu) or ``'cs'`` for the
                content-selection score read from the ``/content/*.txt`` files.

        Returns:
            ``(accuracy, outputs)`` where ``accuracy`` is ``self.accuracy``
            (its ``'overall'`` entry holds the score) and ``outputs`` is the
            list of normalised, detokenised predictions for ``'bleu'`` or
            ``None`` for ``'cs'``.

        Raises:
            ValueError: If ``criteria`` is not one of the supported values.
        """
        self.accuracy = {}
        self.evalQA = {}
        self.evalQuesType = {}
        self.evalAnsType = {}

        if criteria == 'bleu':
            return self._evaluate_bleu(quesid2ans)
        if criteria == 'cs':
            return self._evaluate_content_selection(quesid2ans)
        raise ValueError(f"Unsupported criteria {criteria!r}; expected 'bleu' or 'cs'")

    def _evaluate_bleu(self, quesid2ans):
        """Compute corpus BLEU between the predictions and the gold answers."""
        mpn = MosesPunctNormalizer()
        mt = MosesTokenizer(lang="en")
        md = MosesDetokenizer(lang="en")

        def detokenize(sent):
            return md.detokenize(mt.tokenize(mpn.normalize(sent)))

        hypotheses = [detokenize(self.normalize_answer(ans)) for ans in quesid2ans.values()]

        # Pair each prediction with the gold answer of the same question id.
        # Relying on dict order alone silently mis-pairs them whenever the
        # predictions arrive in a different order than the dataframe rows.
        if all(ques_id in self.qidtoans for ques_id in quesid2ans):
            gold = [self.qidtoans[ques_id] for ques_id in quesid2ans]
        else:
            # Ids do not line up with the dataframe: keep the positional pairing.
            gold = list(self.qidtoans.values())
        references = [detokenize(answer) for answer in gold]

        bleuscore = BLEU().corpus_score(hypotheses, [references]).score
        self.setAccuracy(bleuscore)
        return self.accuracy, hypotheses

    def _evaluate_content_selection(self, quesid2ans):
        """Compute the mean content-selection ratio (in percent) over the test files."""
        summaries = [self.normalize_answer(ans) for ans in quesid2ans.values()]

        scores = []
        with open('/content/testData.txt', 'r', encoding='utf-8') as data_file, \
                open('/content/testTitles.txt', 'r', encoding='utf-8') as title_file, \
                open('/content/targetAnswers.txt', 'r', encoding='utf-8') as gold_file:
            # Predictions are matched to file lines by position.
            for idx, (data_line, title_line, gold_line) in enumerate(zip(data_file, title_file, gold_file)):
                scores.append(self._content_selection_ratio(
                    data_line, title_line, gold_line, summaries[idx]))

        # mean() raises on an empty sequence; report 0 when there is nothing to score.
        self.setAccuracy(mean(scores) * 100 if scores else 0.0)
        return self.accuracy, None

    @classmethod
    def _content_selection_ratio(cls, data_line, title_line, gold_line, summary):
        """Return the fraction of gold records that also occur in ``summary``.

        A record is a distinct, lower-cased gold token that is not a filler and
        appears as a substring of the data line (underscores turned into
        spaces) or of the title line. Each record is counted at most once.
        """
        haystack = " ".join([data_line.replace("_", " "), title_line]).lower()

        records = []
        for token in gold_line.split():
            token = token.lower()
            if token in haystack and token not in cls._CS_FILLERS and token not in records:
                records.append(token)
        if not records:
            return 0

        remaining = set(records)
        matched = 0
        for token in summary.split():
            token = token.lower()
            if token in remaining:
                remaining.discard(token)
                matched += 1
        return matched / len(records)

    def normalize_answer(self, resAns):
        """Replace the ``<pad>`` and ``</s>`` markers with spaces and strip the result."""
        resAns = resAns.replace('<pad>', ' ')
        resAns = resAns.replace('</s>', ' ')
        return resAns.strip()

    def processPunctuation(self, inText):
        """Remove or blank out punctuation following the reference VQA rules.

        A punctuation mark is deleted when it touches a space somewhere in the
        input or when the input contains a digit-comma-digit group; otherwise
        it is replaced by a space. Periods matched by ``self.periodStrip`` are
        removed afterwards.
        """
        has_digit_comma = self.commaStrip.search(inText) is not None
        outText = inText
        for p in self.punct:
            if (p + ' ' in inText or ' ' + p in inText) or has_digit_comma:
                outText = outText.replace(p, '')
            else:
                outText = outText.replace(p, ' ')
        # The third positional argument of Pattern.sub is ``count``; passing
        # re.UNICODE there capped the substitution at 32 periods.
        return self.periodStrip.sub("", outText)

    def processDigitArticle(self, inText):
        """Lower-case the text, map number words to digits, drop articles and
        restore apostrophes in known contractions."""
        outText = []
        for word in inText.lower().split():
            # .get() instead of .setdefault(): a lookup must not grow the table.
            word = self.manualMap.get(word, word)
            if word not in self.articles:
                outText.append(word)
        outText = [self.contractions.get(word, word) for word in outText]
        return ' '.join(outText)

    def setEvalQA(self, quesId, acc):
        """Store the percentage score of one question."""
        self.evalQA[quesId] = round(100*acc, self.n)

    def setEvalQuesType(self, quesId, quesType, acc):
        """Store the percentage score of one question under its question type."""
        if quesType not in self.evalQuesType:
            self.evalQuesType[quesType] = {}
        self.evalQuesType[quesType][quesId] = round(100*acc, self.n)

    def setEvalAnsType(self, quesId, ansType, acc):
        """Store the percentage score of one question under its answer type."""
        if ansType not in self.evalAnsType:
            self.evalAnsType[ansType] = {}
        self.evalAnsType[ansType][quesId] = round(100*acc, self.n)

    def setAccuracy(self, bleuscore):
        """Record ``bleuscore`` as the overall score in ``self.accuracy``."""
        self.accuracy['overall'] = bleuscore

    def within_percent(self, predicted, golden, tolerance=0.05):
        """Return True when ``predicted`` lies within ``tolerance`` (relative) of ``golden``."""
        # abs() keeps the interval well-formed for negative gold values; without
        # it the lower bound ended up above the upper bound.
        margin = abs(golden * tolerance)
        return golden - margin <= predicted <= golden + margin

    def relaxed_correctness(self, target: str,
                            prediction: str,
                            max_relative_change: float = 0.05) -> bool:
        """Calculates relaxed correctness.

        The correctness tolerates certain error ratio defined by max_relative_change.
        See https://arxiv.org/pdf/2203.10244.pdf, end of section 5.1:
        “Following Methani et al. (2020), we use a relaxed accuracy measure for the
        numeric answers to allow a minor inaccuracy that may result from the automatic
        data extraction process. We consider an answer to be correct if it is within
        5% of the gold answer. For non-numeric answers, we still need an exact match
        to consider an answer to be correct.”

        Args:
          target: Target string.
          prediction: Predicted string.
          max_relative_change: Maximum relative change.

        Returns:
          Whether the prediction was correct given the specified tolerance.
        """
        prediction_float = self._to_float(prediction)
        target_float = self._to_float(target)
        # A zero (or non-numeric) target falls through to the string comparison,
        # which also avoids dividing by zero.
        if prediction_float is not None and target_float:
            relative_change = abs(prediction_float -
                                  target_float) / abs(target_float)
            return relative_change <= max_relative_change
        return prediction.lower() == target.lower()

    def _to_float(self, text: str):
        """Parse ``text`` as a float (``'12%'`` becomes ``0.12``); return None if it is not numeric."""
        try:
            if text.endswith('%'):
                # Convert percentages to floats.
                return float(text.rstrip('%')) / 100.0
            return float(text)
        except ValueError:
            return None

# Train

In [None]:
def load_state_dict(state_dict_path, loc='cpu'):
    """Load a saved state dict and drop the ``module.`` prefix from its keys.

    Args:
        state_dict_path: Path of the file to read with ``torch.load``.
        loc: Value forwarded as ``map_location``; defaults to ``'cpu'``.

    Returns:
        The loaded mapping, modified in place so that every key starting with
        ``module.`` (multi-GPU checkpoint) is stored under the name without
        that prefix (single-GPU layout).
    """
    weights = torch.load(state_dict_path, map_location=loc)
    prefix = "module."
    # Collect the affected names up front: the mapping is mutated while renaming.
    wrapped_names = [name for name in weights if name.startswith(prefix)]
    for name in wrapped_names:
        weights[name[len(prefix):]] = weights.pop(name)
    return weights

In [None]:
# Load the weights stored at the pretrain path into `model`.
# `load_checkpoint` and `model` are defined outside this part of the notebook.
load_checkpoint(model, '/content/drive/MyDrive/VL-T5/snap/pretrain/VLT5/Epoch30.pth')

In [None]:
# Resize the token embeddings to the size of the tokenizer vocabulary.
model.resize_token_embeddings(len(vlt5_tokenizer))

In [None]:
import pandas as pd

In [None]:
# One row per training example (see create_df).
df = create_df('train')

In [None]:
# NOTE(review): the meaning of the positional 36 is defined by GraphDataset,
# which is outside this part of the notebook -- confirm before changing it.
train_dataset = GraphDataset(df, 36, vlt5_tokenizer)

In [None]:
# Training batches are shuffled and assembled by `collator`.
train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=12, collate_fn=collator, num_workers=12)

In [None]:
val_df = create_df('val')

In [None]:
val_df.head()

In [None]:
val_dataset = GraphDataset(val_df, 36, vlt5_tokenizer)

In [None]:
# Validation order is kept fixed (shuffle=False).
val_dataloader = DataLoader(val_dataset, shuffle=False, batch_size=12, collate_fn=collator, num_workers=12)

In [None]:
# AdamW over all model parameters with a learning rate of 5e-4.
optimizer = torch.optim.AdamW(model.parameters(), lr=0.0005)

In [None]:
model = model.to(device)

In [None]:
# Xavier-initialise the projection weights in place. The trailing underscore is
# the supported in-place initializer; `xavier_uniform` without it is deprecated.
torch.nn.init.xavier_uniform_(model.proj.weight)

In [None]:
# Gold answers for validation come from val_df.
evaluator = VQAEvaluator(val_df)

In [None]:
import subprocess
# Best validation score seen so far; best.pt is only written when it is exceeded.
max_acc = 0
save_folder = '/content/drive/MyDrive/PL-NL/QA/OpenCQA/models/vl_t5_after_encoder_lr5_2layers'
create_folder_if_not_exists(save_folder)
# Per-epoch mean training loss and per-evaluation validation score.
save_losses = []
save_acc = []
from tqdm.autonotebook import tqdm
for epoch in range(1, 101):
  model.train()
  total_loss = 0
  # The progress bar is advanced manually, once per batch.
  with tqdm(range(len(train_dataloader))) as pbar:
    for bidx, batch in enumerate(train_dataloader):
        # Drop the ids first: every remaining value is moved with .to(device).
        batch.pop('question_ids')
        batch = {k: v.to(device) for k, v in batch.items()}
        result = model.train_step(batch)
        loss = result['loss']
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        pbar.update(1)
        total_loss += loss.detach().item()
  # Mean loss over the batches of this epoch.
  total_loss = total_loss / len(train_dataloader)
  save_losses.append(total_loss)
  print(f'epoch {epoch} loss: {total_loss}')
  # Validate every 10th epoch.
  if epoch % 10 == 0:
      model.eval()

      # `evaluate` is defined outside this part of the notebook; its first
      # result is indexed with 'overall' and reported as BLEU below.
      acc, anses = evaluate(model, val_dataloader, evaluator)
      acc = acc['overall']
      save_acc.append(acc)

      print(f'Epoch {epoch} BLEU: {acc}')
      # Keep a separate copy of the best-scoring weights.
      if acc > max_acc:
        max_acc = acc
        torch.save({
              'epoch': epoch,
              'model_state_dict': model.state_dict(),
              'optimizer_state_dict': optimizer.state_dict(),
              'acc': acc,
              }, f'{save_folder}/best.pt')

  # Periodic checkpoint every 10th epoch, including the loss/score history so far.
  if (epoch) % 10 == 0:
    torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': save_losses,
            'acc': save_acc,
            }, f'{save_folder}/model_{epoch}.pt')

# Test

In [None]:
def load_state_dict(state_dict_path, loc='cpu'):
    """Load the ``model_state_dict`` entry of a checkpoint and drop ``module.`` prefixes.

    Args:
        state_dict_path: Path of the checkpoint file to read with ``torch.load``.
        loc: Value forwarded as ``map_location``; defaults to ``'cpu'``.

    Returns:
        The ``'model_state_dict'`` mapping of the checkpoint, modified in place
        so that every key starting with ``module.`` (multi-GPU checkpoint) is
        stored under the name without that prefix (single-GPU layout).
    """
    weights = torch.load(state_dict_path, map_location=loc)['model_state_dict']
    prefix = "module."
    # Collect the affected names up front: the mapping is mutated while renaming.
    wrapped_names = [name for name in weights if name.startswith(prefix)]
    for name in wrapped_names:
        weights[name[len(prefix):]] = weights.pop(name)
    return weights

In [None]:
# Load onto the CPU: this works on runtimes without a GPU and avoids keeping a
# second copy of the saved tensors (weights and optimizer state) on the GPU.
# load_state_dict() copies the values into the model's own parameters.
checkpoint = torch.load('/content/drive/MyDrive/PL-NL/QA/OpenCQA/models/vl_t5_after_encoder_lr5_2layers/best.pt', map_location='cpu')

In [None]:
# Resize the token embeddings to the tokenizer vocabulary before the weights are restored.
model.resize_token_embeddings(len(vlt5_tokenizer))

In [None]:
# Restore the weights saved under 'model_state_dict' in best.pt.
model.load_state_dict(checkpoint['model_state_dict'])

In [None]:
test_df = create_df('test')

In [None]:
test_df.head()

In [None]:
test_dataset = GraphDataset(test_df, 36, vlt5_tokenizer)

In [None]:
# Test order is kept fixed (shuffle=False).
test_dataloader = DataLoader(test_dataset, shuffle=False, batch_size=12, collate_fn=collator, num_workers=12)

In [None]:
model = model.to(device)

In [None]:
# Gold answers for the test split come from test_df.
evaluator = VQAEvaluator(test_df)

In [None]:
max_acc = 0
from tqdm.autonotebook import tqdm
model.eval()
# `evaluate` is defined outside this part of the notebook.
acc, results = evaluate(model, test_dataloader, evaluator)

In [None]:
# Display the test score.
acc