In [None]:
import torch
import torch.nn as nn
from transformers import LayoutLMv3Model, LayoutLMv3Config

class LayoutLMv3ClassificationHead(nn.Module):
    """
    Classification head applied to encoder features.

    The forward pass is: dropout -> dense -> tanh -> dropout -> output projection.
    Reference implementation: RobertaClassificationHead.

    Args:
        config: Object exposing ``hidden_size``, ``classifier_dropout`` and
            ``hidden_dropout_prob``.
        pool_feature: When truthy, the dense layer takes an input that is three
            times ``hidden_size`` wide instead of ``hidden_size``.
        num_labels: Width of the output projection.
    """

    def __init__(self, config, pool_feature=False, num_labels=0):
        super().__init__()
        self.pool_feature = pool_feature

        # Only the input width of the dense layer depends on pooling.
        dense_in_features = config.hidden_size * 3 if pool_feature else config.hidden_size
        self.dense = nn.Linear(dense_in_features, config.hidden_size)

        # Use the classifier-specific rate when configured, otherwise the
        # model-wide hidden dropout rate.
        dropout_rate = config.classifier_dropout
        if dropout_rate is None:
            dropout_rate = config.hidden_dropout_prob
        self.dropout = nn.Dropout(dropout_rate)

        self.out_proj = nn.Linear(config.hidden_size, num_labels)

    def forward(self, x):
        """Return the projected logits for the feature tensor ``x``."""
        projected = self.dense(self.dropout(x))
        activated = torch.tanh(projected)
        return self.out_proj(self.dropout(activated))


In [None]:
class LayoutLMv3PreTrainingLoss():
    """
    Combined pre-training loss: the sum of three cross-entropy terms, one per
    objective (MLM, MIM and WPA).

    Instances are callable with three logits tensors followed by the three
    matching label tensors, in that order.

    Args:
        ignore_index: Label value marking positions that do not contribute to
            an objective's loss.
    """

    def __init__(self, ignore_index=-100):
        self.ignore_index = ignore_index

    def _objective_loss(self, logits, labels):
        """Mean cross-entropy over the scored positions of one objective."""
        # Sum and divide manually instead of using reduction="mean": when every
        # label equals ignore_index the mean is 0/0 (NaN), whereas this yields 0.
        n_scored = (labels != self.ignore_index).sum().clamp(min=1)
        summed = nn.functional.cross_entropy(
            logits,
            labels,
            ignore_index=self.ignore_index,
            reduction="sum",
        )
        return summed / n_scored

    def __call__(self, mlm_logits, mim_logits, wpa_logits, mlm_labels, mim_labels, wpa_labels):
        """
        Return the scalar total loss.

        Each ``*_logits`` tensor is expected as (N, num_classes) and each
        ``*_labels`` tensor as (N,) holding class indices or ``ignore_index``.
        """
        return (
            self._objective_loss(mlm_logits, mlm_labels)
            + self._objective_loss(mim_logits, mim_labels)
            + self._objective_loss(wpa_logits, wpa_labels)
        )


In [None]:
class LayoutLMv3BertimbauEmbModel(nn.Module):
    """
    LayoutLMv3 encoder that is either built from scratch around a supplied
    word-embedding layer, or restored from a pretrained checkpoint.

    Args:
        config: Mapping of keyword arguments for ``LayoutLMv3Config``. Only read
            when ``checkpoint`` is None.
        bertimbau_emb_layer: Module installed as ``embeddings.word_embeddings``
            of the freshly built model. Only used when ``checkpoint`` is None.
            NOTE(review): its embedding width presumably has to match
            ``hidden_size`` in ``config`` - confirm against the caller.
        checkpoint: Value handed to ``from_pretrained``. When given it takes
            precedence: ``config`` and ``bertimbau_emb_layer`` are not used.

    Attributes:
        model: The underlying ``LayoutLMv3Model``.
        configuration / config: The configuration of ``model``.

    Raises:
        ValueError: If both ``bertimbau_emb_layer`` and ``checkpoint`` are None.
    """

    def __init__(self, config, bertimbau_emb_layer = None, checkpoint = None):
        super().__init__()

        if bertimbau_emb_layer is None and checkpoint is None:
            # ValueError is still caught by callers that handle Exception.
            raise ValueError("If no word embedding layer is passed than a checkpoint must be passed.")

        if checkpoint is None:
            self.configuration = LayoutLMv3Config(**config)
            self.model = LayoutLMv3Model(self.configuration)
            self.model.embeddings.word_embeddings = bertimbau_emb_layer
        else:
            # Set `configuration` on this path too so the attribute always exists.
            self.configuration = LayoutLMv3Config.from_pretrained(checkpoint)
            self.model = LayoutLMv3Model.from_pretrained(
                checkpoint,
                config=self.configuration,
            )

        # Expose the resolved configuration under the conventional name so that
        # code holding this wrapper can read e.g. `wrapper.config.vocab_size`.
        self.config = self.model.config

    def forward(self, processed_data_batches=None, **kwargs):
        """
        Run the wrapped model and return its output unchanged.

        Inputs may be given as a single mapping (``processed_data_batches``), as
        keyword arguments, or both; keyword arguments override mapping entries.
        """
        model_inputs = {}
        if processed_data_batches is not None:
            model_inputs.update(processed_data_batches)
        model_inputs.update(kwargs)
        # Call the module rather than `.forward` so registered hooks run.
        return self.model(**model_inputs)

class LayoutLMv3forPreTraining(nn.Module):
    """
    LayoutLMv3 encoder with three classification heads (MLM, MIM, WPA) for
    pre-training.

    Args:
        config: Mapping forwarded to ``LayoutLMv3BertimbauEmbModel``.
        bertimbau_emb_layer: Forwarded to ``LayoutLMv3BertimbauEmbModel``.
        checkpoint: Forwarded to ``LayoutLMv3BertimbauEmbModel``.
        imgtok_codebook_len: Number of classes of the MIM head. The default of
            0 produces a head with no outputs, so pass the real size for
            training.
    """

    def __init__(self, config : dict, bertimbau_emb_layer = None, checkpoint = None, imgtok_codebook_len = 0):
        super().__init__()

        self.model = LayoutLMv3BertimbauEmbModel(
            config,
            bertimbau_emb_layer,
            checkpoint
        )

        # Read the configuration from the Hugging Face model held by the
        # wrapper (`self.model.model`), which always carries one.
        encoder_config = self.model.model.config

        self.mlm_num_labels = encoder_config.vocab_size
        self.mim_num_labels = imgtok_codebook_len
        self.wpa_num_labels = 2

        self.mlm_head = LayoutLMv3ClassificationHead(
            encoder_config,
            num_labels=self.mlm_num_labels
        )

        self.mim_head = LayoutLMv3ClassificationHead(
            encoder_config,
            num_labels=self.mim_num_labels
        )

        self.wpa_head = LayoutLMv3ClassificationHead(
            encoder_config,
            num_labels=self.wpa_num_labels
        )

    def forward(self, processed_data_batches, mlm_labels=None, mim_labels=None, wpa_labels=None):
        """
        Encode a batch and apply the three heads to the encoder's first output.

        Args:
            processed_data_batches: Mapping of encoder inputs, passed to the
                wrapped model as a single positional argument.
            mlm_labels, mim_labels, wpa_labels: Optional label tensors. They are
                flattened and passed to ``LayoutLMv3PreTrainingLoss`` next to
                logits flattened to (-1, num_labels), so each needs one entry
                per row of the corresponding flattened logits.
                NOTE(review): every head runs over the full output sequence;
                confirm how positions outside an objective should be labelled.

        Returns:
            dict with keys ``loss`` (None unless all three label tensors are
            given), ``mlm_logits``, ``mim_logits`` and ``wpa_logits``.

        Raises:
            ValueError: If only some of the three label tensors are given.
        """
        outputs = self.model(processed_data_batches)

        last_hidden_state = outputs[0]

        mlm_output = self.mlm_head(last_hidden_state)
        mim_output = self.mim_head(last_hidden_state)
        wpa_output = self.wpa_head(last_hidden_state)

        labels = (mlm_labels, mim_labels, wpa_labels)
        n_given = sum(lbl is not None for lbl in labels)
        if 0 < n_given < len(labels):
            raise ValueError(
                "mlm_labels, mim_labels and wpa_labels must be passed together or not at all."
            )

        loss = None
        if n_given == len(labels):
            loss_func = LayoutLMv3PreTrainingLoss()
            loss = loss_func(
                mlm_output.view(-1, self.mlm_num_labels),
                mim_output.view(-1, self.mim_num_labels),
                wpa_output.view(-1, self.wpa_num_labels),
                mlm_labels.reshape(-1),
                mim_labels.reshape(-1),
                wpa_labels.reshape(-1)
            )

        return {
            "loss": loss,
            "mlm_logits": mlm_output,
            "mim_logits": mim_output,
            "wpa_logits": wpa_output,
        }

        
        