In [1]:
import torch
import torch.nn as nn
from transformers import BertModel
from transformers import BertTokenizer

# Module-level tokenizer shared by HierarchicalBERTSentiment.forward. It is
# loaded from the same 'bert-base-uncased' checkpoint as the model's BERT
# encoder so that token ids line up with the encoder's vocabulary.
tokenizer = BertTokenizer.from_pretrained(pretrained_model_name_or_path='bert-base-uncased')

class HierarchicalBERTSentiment(nn.Module):
    """Document-level classifier: BERT encoder plus word- and sentence-level attention.

    Pipeline for each document:

    1. Each sentence is tokenized and encoded by BERT.
    2. For each token, the hidden states of the top ``num_layers`` entries are summed.
    3. Word-level attention pools the tokens into one sentence vector.
    4. Sentence-level attention pools the sentence vectors into one document vector.
    5. A linear classifier maps the document vector to logits.

    Args:
        hidden_dim: Width of token/sentence vectors. Must equal the hidden
            size of the loaded BERT checkpoint.
        num_classes: Number of output logits.
        num_layers: How many of the last entries of BERT's ``hidden_states``
            tuple to sum (e.g. 4 or 8). Must be >= 1. Values larger than the
            tuple length simply use every entry.

    Raises:
        ValueError: If ``num_layers`` is smaller than 1.
    """

    def __init__(self, hidden_dim=768, num_classes=3, num_layers=4):
        # Validate before the expensive checkpoint load. num_layers == 0 would
        # be silently wrong: ``hidden_states[-0:]`` selects *every* entry.
        if num_layers < 1:
            raise ValueError(f"num_layers must be >= 1, got {num_layers}")
        super().__init__()

        # BERT configured to also return every layer's hidden states, so that
        # several layers can be aggregated.
        self.bert = BertModel.from_pretrained('bert-base-uncased', output_hidden_states=True)

        # Number of top hidden-state entries to sum.
        self.num_layers = num_layers

        # One-logit linear scorers for word-level and sentence-level attention.
        self.word_attention = nn.Linear(hidden_dim, 1)
        self.sentence_attention = nn.Linear(hidden_dim, 1)

        # Final classification layer.
        self.classifier = nn.Linear(hidden_dim, num_classes)

    def _sum_top_layers(self, hidden_states):
        """Element-wise sum of the last ``num_layers`` entries of ``hidden_states``.

        The result has the same shape as a single entry
        (``[batch, seq_len, hidden_dim]`` for BERT outputs).
        """
        return torch.stack(hidden_states[-self.num_layers:]).sum(dim=0)

    def aggregate_hidden_layers(self, hidden_states):
        """Return the layer-summed [CLS] embedding of every sequence in the batch.

        Args:
            hidden_states: Sequence of ``[batch, seq_len, hidden_dim]`` tensors,
                as returned by BERT with ``output_hidden_states=True``.

        Returns:
            Tensor of shape ``[batch, hidden_dim]`` (token position 0, i.e. [CLS]).
        """
        return self._sum_top_layers(hidden_states)[:, 0, :]

    @staticmethod
    def _attention_pool(scorer, embeddings):
        """Softmax-weighted sum of the rows of ``embeddings`` ([n, hidden_dim]).

        ``scorer`` maps each row to one logit. The softmax over the n rows
        produces weights that sum to 1. Returns a ``[hidden_dim]`` vector.
        """
        weights = torch.softmax(scorer(embeddings), dim=0)  # [n, 1]
        return torch.sum(weights * embeddings, dim=0)  # [hidden_dim]

    def word_level_attention(self, word_embeddings):
        """Pool token vectors ``[num_words, hidden_dim]`` into one sentence vector ``[hidden_dim]``."""
        return self._attention_pool(self.word_attention, word_embeddings)

    def sentence_level_attention(self, sentence_embeddings):
        """Pool sentence vectors ``[num_sentences, hidden_dim]`` into one document vector ``[hidden_dim]``."""
        return self._attention_pool(self.sentence_attention, sentence_embeddings)

    def forward(self, documents):
        """Classify a single document.

        Args:
            documents: ONE document, given as an iterable of sentence strings.
                Despite the plural name, all sentences are pooled into a
                single document embedding.

        Returns:
            Logits tensor of shape ``[1, num_classes]``.

        Raises:
            ValueError: If ``documents`` yields no sentences.
        """
        # The tokenizer always produces CPU tensors; send them to BERT's device
        # so the model also works after ``.to('cuda')``.
        device = next(self.bert.parameters()).device
        sentence_embeddings = []

        for sentence in documents:
            inputs = tokenizer(sentence, return_tensors='pt', truncation=True, padding=True).to(device)
            outputs = self.bert(**inputs)

            # Per-token embeddings of the single sequence in this batch:
            # [seq_len, hidden_dim]. Keeping only [CLS] here would leave a
            # single row, making word-level attention a constant weight of 1.
            word_embeddings = self._sum_top_layers(outputs.hidden_states)[0]

            sentence_embeddings.append(self.word_level_attention(word_embeddings))

        if not sentence_embeddings:
            raise ValueError("documents must contain at least one sentence")

        stacked = torch.stack(sentence_embeddings)  # [num_sentences, hidden_dim]
        document_embedding = self.sentence_level_attention(stacked)  # [hidden_dim]
        return self.classifier(document_embedding.unsqueeze(0))  # [1, num_classes]


  from .autonotebook import tqdm as notebook_tqdm
