In [None]:
# Fetch the project repository into the current working directory.
!git clone https://github.com/adamserag1/Interpretability-for-VRDU-models.git

In [None]:
!git pull https://github.com/adamserag1/Interpretability-for-VRDU-models.git

# Config

In [None]:
# Print the current working directory (sanity check before changing into the repo).
!pwd

In [None]:
# Make the cloned repository the kernel's working directory so that relative
# paths (e.g. requirements.txt) and the project packages resolve from here.
%cd /content/Interpretability-for-VRDU-models

In [None]:
!pip install -r requirements.txt

In [None]:
!pip install -U datasets

# Code

In [None]:
from datasets import load_from_disk
from transformers import LayoutLMv3ForSequenceClassification, AutoProcessor
import sys
import importlib
def reload_modules():
    """Re-import every already-loaded project module in place.

    Looks through a snapshot of ``sys.modules`` for module names that begin
    with ``vrdu_utils`` or ``Classification_Explain``, prints a line for each
    match and reloads it with ``importlib.reload``. When no such module is
    loaded the function does nothing.
    """
    project_prefixes = ('vrdu_utils', 'Classification_Explain')
    # Snapshot the names first: reloading can add entries to sys.modules.
    loaded_names = tuple(sys.modules)
    for name in loaded_names:
        if not name.startswith(project_prefixes):
            continue
        print(f"Reloading module: {name}")
        importlib.reload(sys.modules[name])

# Reload any project modules that are already in sys.modules so that edits on
# disk take effect without restarting the kernel. On a fresh kernel nothing is
# loaded yet, so this is a no-op; it must run BEFORE the wildcard imports below
# so they bind the refreshed definitions.
reload_modules()

# NOTE(review): wildcard imports — every public name of these project modules
# is injected into the notebook namespace.
from vrdu_utils.encoders import *
from Classification_Explain.lime import *
import torch
from transformers import LayoutLMv3ForSequenceClassification, AutoProcessor, BrosPreTrainedModel, BrosModel, AutoConfig, AutoTokenizer
# Prefer the GPU when one is visible to PyTorch, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

import warnings
from transformers import logging as hf_logging

# Silence FutureWarning messages raised from transformers.modeling_utils.
warnings.filterwarnings(
    "ignore",
    category=FutureWarning,
    module="transformers.modeling_utils",   # the module that emits the msg
)
# Restrict the transformers logger to error-level messages.
hf_logging.set_verbosity_error()

In [None]:
# Mount the user's Google Drive inside the Colab filesystem at /content/drive.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Recursively copy the dataset folder from the mounted Drive into /content.
!cp -r /content/drive/MyDrive/THESIS/rvl_cdip_financial_subset /content

In [None]:
class BrosForDocumentClassification(BrosPreTrainedModel):
    """BROS encoder topped with a linear document-classification head.

    The hidden state of the first token of each sequence is passed through
    dropout and a single linear layer that produces ``config.num_labels``
    logits.

    The sub-module attribute names (``bros``, ``dropout``, ``classifier``)
    determine the parameter names used when loading a checkpoint and must not
    be renamed.
    """

    def __init__(self, config):
        """Build the encoder and the classification head from ``config``.

        Args:
            config: model configuration; this class reads ``num_labels``,
                ``hidden_dropout_prob`` and ``hidden_size`` from it.
        """
        super().__init__(config)
        self.num_labels = config.num_labels

        self.bros = BrosModel(config)
        # Reference torch.nn explicitly: the notebook imports `torch` but never
        # imports `nn`, so a bare `nn` would only resolve if a wildcard import
        # happened to leak it.
        self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
        self.classifier = torch.nn.Linear(config.hidden_size, config.num_labels)

        # Standard Hugging Face end-of-__init__ hook (runs weight initialisation).
        self.post_init()

    def forward(
        self,
        input_ids=None,
        bbox=None,
        attention_mask=None,
        token_type_ids=None,
        labels=None,
        **kwargs
    ):
        """Classify a batch of documents.

        Args:
            input_ids: forwarded unchanged to the BROS encoder.
            bbox: forwarded unchanged to the BROS encoder.
            attention_mask: forwarded unchanged to the BROS encoder.
            token_type_ids: forwarded unchanged to the BROS encoder.
            labels: optional target class indices; when given, a
                cross-entropy loss is computed against the logits.
            **kwargs: accepted so callers may pass extra encoder outputs
                (e.g. pixel values); they are ignored.

        Returns:
            dict with keys ``"loss"`` (``None`` when ``labels`` is ``None``)
            and ``"logits"``.
        """
        outputs = self.bros(
            input_ids=input_ids,
            bbox=bbox,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )

        # Use the [CLS] token's representation (first token)
        cls_output = outputs.last_hidden_state[:, 0, :]  # shape: (batch_size, hidden_size)

        cls_output = self.dropout(cls_output)
        logits = self.classifier(cls_output)

        loss = None
        if labels is not None:
            loss_fct = torch.nn.CrossEntropyLoss()
            # Flatten so the loss sees one row of logits per target index.
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        return {
            "loss": loss,
            "logits": logits,
        }

In [None]:
# Label mapping shared by both fine-tuned document classifiers.
ID2LABEL = {0: "form", 1: "invoice", 2: "budget", 3: "file folder", 4: "questionnaire"}
LABEL2ID = {label: idx for idx, label in ID2LABEL.items()}

# Hugging Face Hub repositories holding the fine-tuned checkpoints.
BROS_CHECKPOINT = "adamadam111/bros-docclass-finetuned"
LLMV3_CHECKPOINT = "adamadam111/layoutlmv3-docclass-finetuned"

# --- BROS -------------------------------------------------------------------
bros_config = AutoConfig.from_pretrained(
    BROS_CHECKPOINT,
    num_labels=len(ID2LABEL),
    id2label=dict(ID2LABEL),
    label2id=dict(LABEL2ID),
)

BROS = BrosForDocumentClassification.from_pretrained(
    BROS_CHECKPOINT,
    config=bros_config,
)
# The tokenizer is loaded from the same repository as the model. The original
# repo id was misspelled ("nadamadam111/..."), which cannot be resolved.
BROS_t = AutoTokenizer.from_pretrained(BROS_CHECKPOINT, do_lower_case=True)

# --- LayoutLMv3 -------------------------------------------------------------
LLMV3 = LayoutLMv3ForSequenceClassification.from_pretrained(
    LLMV3_CHECKPOINT,
    num_labels=len(ID2LABEL),
    id2label=dict(ID2LABEL),
    label2id=dict(LABEL2ID),
)
# apply_ocr=False: the processor does not run OCR itself, so words and boxes
# must be supplied by the caller.
LLMV3_p = AutoProcessor.from_pretrained("microsoft/layoutlmv3-base", apply_ocr=False)



In [None]:
# Build one encoder per model from its processor / tokenizer. The factory
# functions are not defined in this notebook — presumably they come from the
# `vrdu_utils.encoders` wildcard import; confirm there for the exact contract.
LLMV3_encode = make_layoutlmv3_encoder(LLMV3_p)
BROS_encode = make_bros_encoder(BROS_t)