In [None]:
import torch
from torch import nn
from transformers import BertModel, BertTokenizer, PreTrainedModel

class HLABertForSequenceClassification(PreTrainedModel):
    """BERT sequence classifier over a learned mix of per-layer [CLS] vectors.

    Instead of using only the last layer's [CLS] representation, the [CLS]
    hidden state from every transformer layer is combined with one learnable
    weight per layer (initialised to a uniform average), then passed through
    dropout and a linear classification head.
    """

    # Lets ``from_pretrained`` build the correct config type and map checkpoint
    # weights prefixed with ``bert.`` onto ``self.bert``.
    config_class = BertModel.config_class
    base_model_prefix = "bert"

    def __init__(self, config):
        super().__init__(config)
        self.bert = BertModel(config, add_pooling_layer=False)
        self.num_labels = config.num_labels
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # Learnable weights for layer aggregation: one per transformer layer
        # (the embedding output is excluded in ``forward``), initialised to a
        # uniform average.
        self.layer_weights = nn.Parameter(
            torch.ones(config.num_hidden_layers) / config.num_hidden_layers
        )
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, input_ids, attention_mask=None, token_type_ids=None, labels=None):
        """Classify a batch of token sequences.

        Args:
            input_ids: token ids, passed straight to ``BertModel``.
            attention_mask: optional attention mask for ``BertModel``.
            token_type_ids: optional segment ids for ``BertModel``.
            labels: optional integer class targets; when given, a
                cross-entropy loss is computed.

        Returns:
            dict with ``"loss"`` (``None`` when ``labels`` is ``None``) and
            ``"logits"`` of shape ``(batch, num_labels)``.
        """
        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            output_hidden_states=True,
            # Force a ModelOutput so ``.hidden_states`` works even when the
            # config has ``return_dict=False``.
            return_dict=True,
        )

        # ``hidden_states`` holds num_hidden_layers + 1 tensors: the embedding
        # output first, then one per transformer layer. Skip the embeddings so
        # the count matches ``layer_weights``, and keep only position 0 ([CLS])
        # since no other position is used -- avoids stacking the full
        # (layers, batch, seq_len, hidden) tensor.
        cls_per_layer = torch.stack(
            [layer_out[:, 0, :] for layer_out in outputs.hidden_states[1:]], dim=0
        )  # (num_layers, batch, hidden)

        # Weighted sum over the layer axis -> (batch, hidden).
        cls_token = torch.einsum("l,lbh->bh", self.layer_weights, cls_per_layer)

        cls_token = self.dropout(cls_token)
        logits = self.classifier(cls_token)

        loss = None
        if labels is not None:
            loss_fn = nn.CrossEntropyLoss()
            loss = loss_fn(logits.reshape(-1, self.num_labels), labels.reshape(-1))

        return {"loss": loss, "logits": logits}


  from .autonotebook import tqdm as notebook_tqdm


# BERT ANALYSIS