In [50]:
import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer

# Default device for the model below: the GPU when PyTorch can see one,
# otherwise the CPU (kept as a plain string, e.g. for `.to(device)`).
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

class CustomSeq2SeqModel(nn.Module):
    """Score a (source, target) text pair plus a categorical feature vector.

    Both texts are encoded by the same pretrained Hugging Face encoder. The
    [CLS] embedding of each is concatenated with ``category``; the target
    vector is used as ``tgt`` and the source vector as ``memory`` of an
    ``nn.TransformerDecoder``, whose output goes through a linear layer and a
    softmax.

    Args:
        num_categories: Width of the ``category`` tensor passed to ``forward``.
        decoder_layers: Number of stacked ``TransformerDecoderLayer``s.
        device: Device the whole module (encoder, decoder, head) is moved to.
        nhead: Attention heads per decoder layer; must divide
            ``hidden_size + num_categories`` (388 for the default encoder
            with 4 categories, so the default of 1 always works).
        num_labels: Size of the returned probability vector.
        encoder_name: Hugging Face model id or path of the shared encoder.
    """

    def __init__(
        self,
        num_categories,
        decoder_layers=3,
        device=device,
        *,
        nhead=1,
        num_labels=3,
        encoder_name="microsoft/MiniLM-L12-H384-uncased",
    ):
        super().__init__()
        self.model = AutoModel.from_pretrained(encoder_name)
        self.num_categories = num_categories
        self.device = device
        # Decoder width = encoder [CLS] width + categorical features
        # (384 + 4 = 388 for MiniLM with 4 categories, the value that used to
        # be hard-coded here and in the Linear layer).
        d_model = self.model.config.hidden_size + num_categories
        self.decoder = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(
                d_model=d_model,
                nhead=nhead,
                batch_first=True,
            ),
            decoder_layers,
        )
        # NOTE(review): the head returns probabilities; if this is trained
        # with nn.CrossEntropyLoss, feed it logits instead (drop the Softmax).
        self.dense = nn.Sequential(
            nn.Linear(d_model, num_labels),
            nn.Softmax(-1),
        )
        # Move every submodule, not only the encoder, so forward() never mixes
        # CPU and CUDA parameters.
        self.to(device)

    def forward(self, input_ids, category, target, attention_mask=None, token_type_ids=None):
        """Return class probabilities of shape ``(batch, num_labels)``.

        Args:
            input_ids: Source token ids for the encoder.
            category: Categorical features of shape ``(batch, num_categories)``;
                cast to the embeddings' dtype/device, so an integer one-hot
                tensor is accepted.
            target: Dict of encoder keyword arguments for the target text
                (e.g. ``input_ids`` and ``attention_mask``).
            attention_mask: Optional source attention mask for the encoder.
            token_type_ids: Optional source token type ids for the encoder.
        """
        # [CLS] embedding of the final encoder layer. last_hidden_state is the
        # final layer's output (what hidden_states[-1] held) without keeping
        # every intermediate layer's activations around.
        src_cls = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        ).last_hidden_state[:, 0, :]
        tgt_cls = self.model(**target).last_hidden_state[:, 0, :]

        # Concatenate with the categorical data, aligned to the embeddings.
        category = category.to(device=src_cls.device, dtype=src_cls.dtype)
        memory = torch.cat((src_cls, category), dim=-1).unsqueeze(1)
        tgt = torch.cat((tgt_cls, category), dim=-1).unsqueeze(1)

        # Each sample is a length-1 sequence: (batch, 1, d_model). A 2-D
        # (batch, d_model) tensor would be read by PyTorch as ONE unbatched
        # sequence of length `batch`, letting samples attend to each other.
        output = self.decoder(tgt=tgt, memory=memory).squeeze(1)
        return self.dense(output)


# Example usage with toy token ids (real ids would come from the matching
# tokenizer, e.g. AutoTokenizer.from_pretrained("microsoft/MiniLM-L12-H384-uncased")).
model = CustomSeq2SeqModel(num_categories=4)
# Inference demo: disable decoder dropout so the printed probabilities are
# deterministic. Call model.train() before any training loop.
model.eval()

batch_size = 1
# Inputs are created on the same device as the model; CPU tensors would fail
# against a CUDA model.
category = torch.tensor([[0, 0, 1, 0]], device=device)
tokens_in = {
    "input_ids": torch.tensor([[1, 1]], device=device),
    "attention_mask": torch.tensor([[1, 1]], device=device),
}
tokens_out = {
    "input_ids": torch.tensor([[0, 0, 0]], device=device),
    "attention_mask": torch.tensor([[0, 1, 1]], device=device),
}
output = model(**tokens_in, target=tokens_out, category=category)
output

tensor([[0.4038, 0.2003, 0.3960]], grad_fn=<SoftmaxBackward0>)