In [8]:
from typing import List, Mapping

import torch
from torch import nn
from transformers import AutoConfig, AutoModel

In [None]:
class BertForSequenceClassification(nn.Module):
    """Pretrained encoder followed by mean pooling, dropout and a linear head.

    The encoder is loaded with ``AutoModel.from_pretrained``; its last hidden
    state is averaged over the sequence axis, passed through dropout
    (p=0.3) and projected to ``num_classes`` scores.

    Args:
        num_classes: Number of output scores per sequence.
        pretrained_model_name: Name or path handed to ``AutoConfig`` /
            ``AutoModel``.
        mask_padding: When ``True``, the mean is weighted by the attention
            mask so positions with mask 0 do not contribute. The default
            (``False``) averages over every position, padding included,
            which keeps the numerics of models trained with the original
            pooling.
    """

    def __init__(self,
                 num_classes: int,
                 pretrained_model_name: str = 'bert-base-uncased',
                 mask_padding: bool = False):
        super().__init__()
        config = AutoConfig.from_pretrained(
            pretrained_model_name, num_labels=num_classes)
        self.model = AutoModel.from_pretrained(
            pretrained_model_name, config=config)
        self.classifier = nn.Linear(config.hidden_size, num_classes)
        self.dropout = nn.Dropout(p=0.3)
        self.mask_padding = mask_padding

    @staticmethod
    def _as_batch(tensor: torch.Tensor) -> torch.Tensor:
        """Return ``tensor`` as a 2-D (bs, seq_size) batch.

        Accepts (seq_size,), (bs, seq_size) or (bs, 1, seq_size). Only the
        singleton middle axis is removed: a bare ``squeeze()`` would also
        drop a batch axis of size 1 and hand the encoder a 1-D tensor.
        """
        if tensor.dim() == 1:
            return tensor.unsqueeze(0)
        if tensor.dim() == 3 and tensor.size(1) == 1:
            return tensor.squeeze(1)
        return tensor

    def forward(self, input_seqs: Mapping[str, torch.Tensor]) -> torch.Tensor:
        """Compute class scores for a batch of tokenized sequences.

        Args:
            input_seqs: Mapping with ``'input_ids'`` and ``'attention_mask'``
                tensors, each shaped as accepted by ``_as_batch``.

        Returns:
            Tensor of shape (bs, num_classes) with unnormalized scores.
        """
        input_ids = self._as_batch(input_seqs['input_ids'])
        attention_mask = self._as_batch(input_seqs['attention_mask'])

        # Only the final layer is used below, so intermediate hidden states
        # are not requested (keeping them alive just costs memory).
        bert_output = self.model(input_ids=input_ids,
                                 attention_mask=attention_mask,
                                 return_dict=True)
        last_hidden_state = bert_output.last_hidden_state  # (bs, seq_size, hidden_size)

        if self.mask_padding:
            weights = attention_mask.unsqueeze(-1).to(last_hidden_state.dtype)
            token_sum = (last_hidden_state * weights).sum(dim=1)
            # clamp guards against division by zero for an all-zero mask row
            token_count = weights.sum(dim=1).clamp(min=1.0)
            sentence_vector = token_sum / token_count  # (bs, hidden_size)
        else:
            sentence_vector = last_hidden_state.mean(dim=1)  # (bs, hidden_size)

        scores = self.classifier(self.dropout(sentence_vector))
        return scores