In [17]:
from transformers.models.bart.modeling_bart import BartDecoder, shift_tokens_right
from transformers import BartConfig
import torch.nn as nn

class TransformerDecoder(nn.Module):
    """BART decoder stack followed by a linear classification head.

    A ``BartDecoder`` is built from a fresh ``BartConfig`` and its final
    hidden states are projected onto ``num_class`` logits per position.

    Three ids directly above the real class range are reserved, which is
    why the decoder vocabulary has ``num_class + 3`` entries:

    * ``num_class``     -- start token (``self.start_tok``)
    * ``num_class + 1`` -- end token   (``self.end_tok``)
    * ``num_class + 2`` -- pad token   (``self.pad_tok``)

    The classifier only emits ``num_class`` logits, i.e. it scores the real
    classes and not the three reserved ids.

    Args:
        num_class: Number of real output classes.
        h_size: Hidden width used for both the decoder (``d_model``) and the
            classifier input. Presumably this has to match the last
            dimension of ``encoder_H`` given to ``forward`` -- confirm
            against the encoder in use.
        layers: Number of decoder layers.
        heads: Number of attention heads per decoder layer.
    """

    def __init__(self, num_class, h_size=768, layers=2, heads=12):
        # Initialise nn.Module first so attribute registration is in place
        # before anything is assigned on ``self``.
        super().__init__()

        # Special token ids live just past the real class ids.
        self.start_tok = num_class
        self.end_tok = num_class + 1
        self.pad_tok = num_class + 2

        # ``d_model`` follows ``h_size`` so the decoder output width always
        # agrees with the classifier's input width.
        # NOTE(review): the end token is registered as ``forced_eos_token_id``
        # rather than ``eos_token_id`` -- confirm that this is intended.
        bart_config = BartConfig(
            vocab_size=num_class + 3,
            d_model=h_size,
            decoder_layers=layers,
            decoder_attention_heads=heads,
            decoder_start_token_id=self.start_tok,
            forced_eos_token_id=self.end_tok,
            pad_token_id=self.pad_tok,
        )

        # model
        self.decoder = BartDecoder(bart_config)
        self.classifier = nn.Linear(h_size, num_class)

    def forward(self, encoder_H, encoder_mask, labels):
        """Run the decoder on right-shifted labels and return class logits.

        Args:
            encoder_H: Encoder hidden states, forwarded to the decoder as
                ``encoder_hidden_states``.
            encoder_mask: Forwarded as ``encoder_attention_mask``; may be
                ``None``.
            labels: Integer id tensor indexed as ``labels[:, 0]``, so at
                least 2-D; assumed to be (batch, target_len) -- confirm
                against the caller.

        Returns:
            The classifier applied to the decoder's last hidden state, with
            ``num_class`` values along the final dimension.
        """
        # Shift right by one so position t is fed the label from t - 1.
        # ``roll`` returns a new tensor, so the in-place write below does not
        # modify the caller's ``labels``; the wrapped-around last label that
        # lands in column 0 is overwritten with the start token.
        decoder_input_ids = labels.roll(shifts=1, dims=-1)
        decoder_input_ids[:, 0] = self.start_tok

        # Call the module itself (not ``.forward``) so that any registered
        # hooks are honoured.
        decoder_out = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=None,
            encoder_hidden_states=encoder_H,
            encoder_attention_mask=encoder_mask,
            return_dict=True,
        )
        return self.classifier(decoder_out.last_hidden_state)

import torch

# Smoke test: build a decoder for ten real classes with default settings.
T = TransformerDecoder(num_class=10)

# Random stand-ins for a batch of 4 sequences of length 10:
# integer labels in [0, 10) and 768-wide encoder states.
label_ids = torch.randint(low=0, high=10, size=(4, 10))
encoder_H = torch.rand(4, 10, 768)

# No encoder attention mask is supplied for this check.
x = T(encoder_H=encoder_H, encoder_mask=None, labels=label_ids)
print(x.shape)

torch.Size([4, 10, 10])
