In [1]:
import torch
import nltk
import numpy as np
import matplotlib.pyplot as plt
from transformers import RobertaModel, RobertaTokenizerFast
import torch.nn as nn
import logging
import sys

2025-04-16 03:49:23.971135: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1744775364.227323      19 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1744775364.300826      19 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [2]:
# A notebook kernel (or a previous run of this cell) may already have attached
# handlers to the root logger; basicConfig is a no-op while any exist, so strip
# them one at a time until the list is empty.
while logging.root.handlers:
    logging.root.removeHandler(logging.root.handlers[0])

# Send INFO-and-above records to stdout so they appear inline in cell output.
logging.basicConfig(
    stream=sys.stdout,
    level=logging.INFO,
    datefmt='%Y-%m-%d %H:%M:%S',
    format='%(asctime)s - %(levelname)s - %(message)s',
)
logger = logging.getLogger(__name__)

In [3]:
# AttentionLayer, copied from the original training code. Its parameter names must match the saved checkpoint for load_state_dict to succeed.
class AttentionLayer(nn.Module):
    """Attention pooling with a learned context vector, as used in HAN.

    Each timestep is projected with ``tanh(Linear(h))``, scored by a dot
    product with a learned context vector, softmax-normalised over the time
    axis, and the input sequence is summed with those weights.

    Args:
        hidden_size: feature size of the incoming sequence (its last dimension).
    """
    def __init__(self, hidden_size):
        super().__init__()
        # The attribute names ``attention`` and ``context_vector`` appear in the
        # state_dict, so renaming them would break loading saved checkpoints.
        self.attention = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.Tanh()
        )
        # Small random context vector; an all-zero vector would score every
        # timestep identically.
        self.context_vector = nn.Parameter(torch.zeros(hidden_size))
        nn.init.normal_(self.context_vector, std=0.02)

    def forward(self, sequence, mask=None):
        """Pool ``sequence`` over its time axis.

        Args:
            sequence: tensor of shape [B, T, hidden].
            mask: optional [B, T] mask. Truthy entries are attended to; falsy
                entries get zero weight. Bool or 0/1 integer masks are accepted.

        Returns:
            (out, alpha): pooled vectors [B, hidden] and attention weights
            [B, T]. A row whose mask is entirely false gets uniform weights.
        """
        u = self.attention(sequence)                   # [B, T, hidden]
        scores = torch.matmul(u, self.context_vector)  # [B, T]

        if mask is not None:
            # .bool(): on an integer tensor ``~`` is a bitwise NOT (1 -> -2) and
            # masked_fill only accepts bool masks.
            # finfo(...).min: the most negative *finite* value for the score
            # dtype. A hard-coded -1e30 overflows float16, and -inf would turn
            # fully-masked rows into NaN after softmax.
            fill_value = torch.finfo(scores.dtype).min
            scores = scores.masked_fill(~mask.bool(), fill_value)

        alpha = torch.softmax(scores, dim=1)           # [B, T]
        out = torch.sum(sequence * alpha.unsqueeze(-1), dim=1)  # [B, hidden]
        return out, alpha

# HAN_RoBERTa, copied from the original training code. Its parameter names and sizes must match the saved checkpoint for load_state_dict to succeed.
class HAN_RoBERTa(nn.Module):
    """Hierarchical Attention Network on top of RoBERTa token embeddings.

    Pipeline: RoBERTa (each sentence encoded independently) -> word BiGRU ->
    word attention -> sentence BiGRU -> sentence attention ->
    LayerNorm / Dropout / Linear classifier.
    """
    def __init__(self, pretrained_model='roberta-base', word_gru_hidden=64,
                 sent_gru_hidden=64, num_classes=1, dropout=0.2, freeze_bert=True):
        """
        Args:
            pretrained_model: name or path passed to RobertaModel.from_pretrained.
            word_gru_hidden: hidden size per direction of the word-level GRU.
            sent_gru_hidden: hidden size per direction of the sentence-level GRU.
            num_classes: classifier output size (1 -> one logit per document).
            dropout: dropout between stacked GRU layers and before the classifier.
            freeze_bert: if True, all RoBERTa parameters get requires_grad=False.
        """
        super().__init__()
        # RoBERTa encoder - the pooler is disabled because only token states are used.
        self.roberta = RobertaModel.from_pretrained(
            pretrained_model,
            add_pooling_layer=False
        )
        hidden_size = self.roberta.config.hidden_size

        if freeze_bert:
            for param in self.roberta.parameters():
                param.requires_grad = False

        # Word-level GRU. nn.GRU only applies `dropout` between stacked layers,
        # hence 2 layers.
        self.word_gru = nn.GRU(
            input_size=hidden_size,
            hidden_size=word_gru_hidden,
            bidirectional=True,
            batch_first=True,
            num_layers=2,
            dropout=dropout
        )
        self.word_attention = AttentionLayer(2*word_gru_hidden)

        # Sentence-level GRU (same 2-layer reasoning as above).
        self.sent_gru = nn.GRU(
            input_size=2*word_gru_hidden,
            hidden_size=sent_gru_hidden,
            bidirectional=True,
            batch_first=True,
            num_layers=2,
            dropout=dropout
        )
        self.sent_attention = AttentionLayer(2*sent_gru_hidden)

        # Classification head with LayerNorm for stability.
        self.classifier = nn.Sequential(
            nn.LayerNorm(2*sent_gru_hidden),
            nn.Dropout(dropout),
            nn.Linear(2*sent_gru_hidden, num_classes)
        )

        self._init_weights()

    def _init_weights(self):
        """Initialise every non-RoBERTa parameter.

        - Matrices (>= 2-D weights) get Xavier-uniform.
        - Biases get zeros.
        - LayerNorm gains get ones, i.e. the identity scale. Sending them down
          the generic 1-D branch would draw them from N(0, 0.02) and nearly
          zero the classifier input.

        Pretrained RoBERTa weights are skipped. The attention context vectors
        match neither 'weight' nor 'bias' by name, so they keep their own
        initialisation.
        """
        layer_norm_weights = {
            id(module.weight)
            for module in self.modules()
            if isinstance(module, nn.LayerNorm) and module.weight is not None
        }
        for name, param in self.named_parameters():
            if 'roberta' in name:  # Don't reinitialize pretrained weights
                continue
            if id(param) in layer_norm_weights:
                nn.init.ones_(param)
            elif 'weight' in name:
                if param.dim() >= 2:
                    nn.init.xavier_uniform_(param)
                else:
                    nn.init.normal_(param, mean=0.0, std=0.02)
            elif 'bias' in name:
                nn.init.zeros_(param)

    def forward(self, input_ids, attention_mask, sent_mask):
        """Classify a batch of documents.

        Args:
            input_ids: [B, S, T] token ids (S sentences of T tokens per document).
            attention_mask: [B, S, T] token mask; nonzero marks real tokens.
            sent_mask: [B, S] mask; truthy marks real (non-padding) sentences.

        Returns:
            (logits, word_alpha, sent_alpha):
            - logits: [B] when num_classes == 1 (else [B, num_classes]).
            - word_alpha: [B, S, T] word-level attention weights.
            - sent_alpha: [B, S] sentence-level attention weights.
            For an empty batch: an empty float tensor on the input's device,
            and None for both attention outputs.
        """
        B, S, T = input_ids.size()

        if B == 0:
            return input_ids.new_zeros((0,), dtype=torch.float), None, None

        # Flatten sentences so RoBERTa encodes each one independently.
        # reshape (not view) also accepts non-contiguous inputs.
        flat_input_ids = input_ids.reshape(B*S, T)
        flat_attn_mask = attention_mask.reshape(B*S, T)

        outputs = self.roberta(
            input_ids=flat_input_ids,
            attention_mask=flat_attn_mask,
            output_attentions=False
        )
        H = outputs.last_hidden_state                    # [B*S, T, hidden]

        # NOTE: padded positions still pass through the GRUs (no packing);
        # the masks only take effect inside the attention layers.
        word_enc_out, _ = self.word_gru(H)               # [B*S, T, 2*word_hidden]
        s, word_alpha = self.word_attention(word_enc_out, flat_attn_mask.bool())

        s = s.view(B, S, -1)                             # [B, S, 2*word_hidden]

        sent_enc_out, _ = self.sent_gru(s)               # [B, S, 2*sent_hidden]
        v, sent_alpha = self.sent_attention(sent_enc_out, sent_mask)

        logits = self.classifier(v).squeeze(-1)          # [B] when num_classes == 1

        return logits, word_alpha.view(B, S, T), sent_alpha

# Function for tokenization and preprocessing
def collate_fn(batch, tokenizer, max_tokens=50):
    """Tokenise a batch of documents into padded HAN inputs.

    Args:
        batch: iterable of ``(sentences, label)`` pairs, where ``sentences`` is a
            list of sentence strings (it may be empty).
        tokenizer: callable tokenizer, invoked as
            ``tokenizer(sentences, padding='max_length', truncation=True,
            max_length=max_tokens, return_tensors='pt')``; its ``pad_token_id``
            attribute (if any) is used for padding sentences.
        max_tokens: tokens per sentence (T).

    Returns:
        (input_ids, attention_mask, sent_mask, labels):
        - input_ids: [B, S, T] long.
        - attention_mask: [B, S, T] long.
        - sent_mask: [B, S] bool.
        - labels: [B] float.
        S is the largest sentence count in the batch (at least 1). Padding
        sentences contain the pad id (0 if the tokenizer has none), an all-zero
        attention mask, and False in ``sent_mask``. An empty batch yields
        tensors with B == 0.
    """
    # Materialise once: the batch is scanned twice (sentence count, then encoding).
    batch = list(batch)

    pad_id = getattr(tokenizer, 'pad_token_id', None)
    if pad_id is None:
        pad_id = 0

    # Largest sentence count in the batch, never below 1.
    max_sents = max([len(sentences) for sentences, _ in batch] + [1])

    if not batch:
        return (
            torch.zeros((0, max_sents, max_tokens), dtype=torch.long),
            torch.zeros((0, max_sents, max_tokens), dtype=torch.long),
            torch.zeros((0, max_sents), dtype=torch.bool),
            torch.zeros((0,), dtype=torch.float),
        )

    all_labels = []
    docs_input_ids = []
    docs_attention_mask = []
    docs_sent_mask = []

    for sentences, label in batch:
        all_labels.append(label)

        if len(sentences) == 0:
            # No tokenizer call for an empty document; it becomes pure padding below.
            input_ids = torch.zeros((0, max_tokens), dtype=torch.long)
            attention_mask = torch.zeros((0, max_tokens), dtype=torch.long)
        else:
            try:
                enc = tokenizer(
                    sentences,
                    padding='max_length',
                    truncation=True,
                    max_length=max_tokens,
                    return_tensors='pt'
                )
                input_ids = enc['input_ids']            # [num_sents, max_tokens]
                attention_mask = enc['attention_mask']
            except Exception as e:
                # Best effort: keep the batch alive with a single blank sentence.
                logger.warning(f"Tokenization error: {str(e)}. Using empty tensors.")
                input_ids = torch.full((1, max_tokens), pad_id, dtype=torch.long)
                attention_mask = torch.zeros((1, max_tokens), dtype=torch.long)
        num_sents = input_ids.size(0)

        # Pad every document up to max_sents sentences.
        pad_sents = max(max_sents - num_sents, 0)
        if pad_sents > 0:
            input_ids = torch.cat(
                [input_ids, torch.full((pad_sents, max_tokens), pad_id, dtype=torch.long)], dim=0)
            attention_mask = torch.cat(
                [attention_mask, torch.zeros((pad_sents, max_tokens), dtype=torch.long)], dim=0)

        docs_input_ids.append(input_ids)
        docs_attention_mask.append(attention_mask)
        docs_sent_mask.append([True] * num_sents + [False] * pad_sents)

    input_ids = torch.stack(docs_input_ids, dim=0)               # [B, S, T]
    attention_mask = torch.stack(docs_attention_mask, dim=0)     # [B, S, T]
    sent_mask = torch.tensor(docs_sent_mask, dtype=torch.bool)   # [B, S]
    labels = torch.tensor(all_labels, dtype=torch.float)

    return input_ids, attention_mask, sent_mask, labels

# Function to predict and visualize attention
def predict_and_visualize_attention(model, tokenizer, text, device, max_tokens=50, max_sentences=5,
                                    output_path='sentiment_attention_visualization.png'):
    """Predict sentiment for ``text`` and plot sentence- and word-level attention.

    The text is split into sentences with NLTK (at most ``max_sentences``) and
    encoded with ``collate_fn``. A two-panel figure is written to
    ``output_path``:
    - left: the top-5 sentences by sentence attention;
    - right: the top-10 tokens of the most-attended sentence.

    Args:
        model: HAN model returning ``(logits, word_alpha, sent_alpha)``.
        tokenizer: tokenizer accepted by ``collate_fn``, which must also provide
            ``convert_ids_to_tokens``.
        text: raw input text.
        device: torch device the model lives on.
        max_tokens: tokens per sentence.
        max_sentences: maximum number of sentences kept (at least 1 is used).
        output_path: where the figure is saved.

    Returns:
        dict with keys ``prediction`` ('Positive'/'Negative'), ``probability``
        (sigmoid of the logit), ``most_important_sentence`` and ``top_tokens``
        (list of ``(token, weight)`` pairs, highest weight first).
    """
    model.eval()

    sentences = _split_sentences(text, max_sentences)

    enc = collate_fn([(sentences, 0)], tokenizer, max_tokens=max_tokens)
    input_ids, attn_mask, sent_mask, _ = [x.to(device) for x in enc]

    with torch.no_grad():
        logits, word_alpha, sent_alpha = model(input_ids, attn_mask, sent_mask)

    prob = torch.sigmoid(logits).item()
    pred_label = 'Positive' if prob > 0.5 else 'Negative'

    # --- Sentence level: real (non-padding) sentences of the single document ---
    sent_weights = sent_alpha[0].cpu().numpy()
    valid_sent_idxs = np.flatnonzero(sent_mask[0].cpu().numpy())
    valid_sent_idxs = valid_sent_idxs[valid_sent_idxs < len(sentences)]
    valid_sent_weights = sent_weights[valid_sent_idxs]

    sent_order = np.argsort(valid_sent_weights)[::-1][:5]   # top 5, highest first
    top_sentences = [sentences[valid_sent_idxs[i]] for i in sent_order]
    top_sent_weights = valid_sent_weights[sent_order]

    fig, (ax_sent, ax_word) = plt.subplots(1, 2, figsize=(15, 7))

    sent_positions = np.arange(len(top_sentences))
    ax_sent.barh(sent_positions, top_sent_weights, color=_attention_colors(top_sent_weights))
    ax_sent.set_yticks(sent_positions)
    ax_sent.set_yticklabels([_shorten_label(s) for s in top_sentences])
    ax_sent.set_title(f'Sentence-level Attention\nPrediction: {pred_label} ({prob:.2f})')
    ax_sent.set_xlabel('Attention Weight')

    # --- Word level: tokens of the most-attended sentence ---
    most_important_sent_idx = int(valid_sent_idxs[np.argmax(valid_sent_weights)])
    tokens = tokenizer.convert_ids_to_tokens(input_ids[0, most_important_sent_idx].cpu().tolist())
    weights = word_alpha[0, most_important_sent_idx].cpu().numpy()
    token_mask = attn_mask[0, most_important_sent_idx].cpu().numpy()

    # Drop padding and sentence-boundary tokens.
    special_tokens = {'<s>', '</s>', '<pad>'}
    keep = [i for i, token in enumerate(tokens)
            if token_mask[i] == 1 and token not in special_tokens]
    valid_tokens = np.array([tokens[i] for i in keep])
    valid_weights = np.array([weights[i] for i in keep])

    token_order = np.argsort(valid_weights)[-10:][::-1]      # top 10, highest first
    sorted_tokens = valid_tokens[token_order]
    sorted_weights = valid_weights[token_order]

    if len(sorted_tokens) > 0:
        token_positions = np.arange(len(sorted_tokens))
        ax_word.barh(token_positions, sorted_weights, color=_attention_colors(sorted_weights))
        ax_word.set_yticks(token_positions)
        ax_word.set_yticklabels(sorted_tokens)
    ax_word.set_title(
        f'Word-level Attention\nMost important sentence: '
        f'"{_shorten_label(sentences[most_important_sent_idx])}"')
    ax_word.set_xlabel('Attention Weight')

    fig.tight_layout()
    fig.savefig(output_path)
    plt.close(fig)

    return {
        'prediction': pred_label,
        'probability': prob,
        'most_important_sentence': sentences[most_important_sent_idx],
        'top_tokens': list(zip(sorted_tokens, sorted_weights))
    }


def _split_sentences(text, max_sentences):
    """Split ``text`` into at most ``max_sentences`` (>= 1) non-blank sentences.

    Missing NLTK tokenizer data is downloaded once and the split retried. If
    splitting still fails, a warning is logged and the whole text is treated as
    one sentence. Never returns an empty list.
    """
    try:
        try:
            sentences = nltk.tokenize.sent_tokenize(text)
        except LookupError:
            # Older NLTK releases use the 'punkt' resource, newer ones 'punkt_tab'.
            nltk.download('punkt', quiet=True)
            nltk.download('punkt_tab', quiet=True)
            sentences = nltk.tokenize.sent_tokenize(text)
    except Exception as exc:
        logger.warning(f"Sentence splitting failed ({exc}); using the whole text as one sentence.")
        sentences = [text]
    sentences = [s for s in sentences if s.strip()] or [text]
    return sentences[:max(1, max_sentences)]


def _shorten_label(text, width=40):
    """Truncate ``text`` to ``width`` characters, appending '...' only when cut."""
    return f"{text[:width]}..." if len(text) > width else text


def _attention_colors(weights):
    """Viridis colours for ``weights``, scaled so the largest weight maps to 1."""
    weights = np.asarray(weights, dtype=float)
    peak = weights.max() if weights.size else 0.0
    return plt.cm.viridis(weights / peak if peak > 0 else weights)

# Main function to load model and make predictions
def main(model_path='/kaggle/input/han_model/pytorch/default/2/best_han_model.pt', text=None):
    """Load the trained HAN checkpoint, analyse one text, and save an attention plot.

    Args:
        model_path: path to a ``state_dict`` saved from ``HAN_RoBERTa``.
        text: text to analyse; if None, a built-in positive movie review is used.

    Returns:
        The result dict from ``predict_and_visualize_attention``, or None if the
        checkpoint could not be loaded (the error is logged).
    """
    MAX_SENTS = 5
    MAX_TOKENS = 50
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    logger.info(f"Using device: {DEVICE}")

    tokenizer = RobertaTokenizerFast.from_pretrained('roberta-base')

    # Hyperparameters must match training so the state_dict keys and shapes line up.
    model = HAN_RoBERTa(
        pretrained_model='roberta-base',
        word_gru_hidden=64,
        sent_gru_hidden=64,
        dropout=0.2,
        freeze_bert=True
    )

    try:
        logger.info(f"Loading model from {model_path}")
        # weights_only=True restricts unpickling to tensors and plain containers,
        # so a tampered checkpoint file cannot execute arbitrary code.
        state_dict = torch.load(model_path, map_location=DEVICE, weights_only=True)
        model.load_state_dict(state_dict)
        model.to(DEVICE)
        logger.info("Model loaded successfully")
    except Exception as e:
        logger.error(f"Error loading model: {str(e)}")
        return None

    if text is not None:
        sample_text = text
    else:
        sample_text = "I absolutely loved this movie! The acting was superb and the plot kept me engaged throughout. Definitely recommend it to everyone!"

    logger.info(f"Analyzing text: {sample_text}")

    result = predict_and_visualize_attention(
        model=model,
        tokenizer=tokenizer,
        text=sample_text,
        device=DEVICE,
        max_tokens=MAX_TOKENS,
        max_sentences=MAX_SENTS
    )

    logger.info(f"Prediction: {result['prediction']} with confidence {result['probability']:.4f}")
    logger.info(f"Most important sentence: {result['most_important_sentence']}")
    logger.info("Top tokens with attention weights:")
    for token, weight in result['top_tokens']:
        logger.info(f"  {token}: {weight:.4f}")

    logger.info("Attention visualization saved as 'sentiment_attention_visualization.png'")
    return result

if __name__ == "__main__":
    main()

2025-04-16 03:49:37 - INFO - Using device: cuda


tokenizer_config.json:   0%|          | 0.00/25.0 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/481 [00:00<?, ?B/s]

Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`




model.safetensors:   0%|          | 0.00/499M [00:00<?, ?B/s]

2025-04-16 03:49:41 - INFO - Loading model from /kaggle/input/han_model/pytorch/default/2/best_han_model.pt


  model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))


2025-04-16 03:49:47 - INFO - Model loaded successfully
2025-04-16 03:49:47 - INFO - Analyzing text: I absolutely loved this movie! The acting was superb and the plot kept me engaged throughout. Definitely recommend it to everyone!
2025-04-16 03:49:48 - INFO - Prediction: Positive with confidence 0.9607
2025-04-16 03:49:48 - INFO - Most important sentence: Definitely recommend it to everyone!
2025-04-16 03:49:48 - INFO - Top tokens with attention weights:
2025-04-16 03:49:48 - INFO -   Def: 0.1440
2025-04-16 03:49:48 - INFO -   initely: 0.1282
2025-04-16 03:49:48 - INFO -   Ġrecommend: 0.1012
2025-04-16 03:49:48 - INFO -   !: 0.0871
2025-04-16 03:49:48 - INFO -   Ġto: 0.0757
2025-04-16 03:49:48 - INFO -   Ġit: 0.0707
2025-04-16 03:49:48 - INFO -   Ġeveryone: 0.0700
2025-04-16 03:49:48 - INFO - Attention visualization saved as 'sentiment_attention_visualization.png'
