In [1]:
import torch
from torch import nn
from transformers import BertTokenizer, BertForSequenceClassification, AdamW
from transformers import BigBirdTokenizer, BigBirdForSequenceClassification
from transformers import LongformerTokenizer, LongformerForSequenceClassification 
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
import pandas as pd
from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np
from copy import deepcopy
import seaborn as sns
import networkx as nx
import scipy
from scipy.stats import pearsonr
import pickle

In [2]:
from bertviz import model_view, head_view

In [3]:
from nltk.tokenize import word_tokenize
from nltk.corpus import stopwords
from nltk import pos_tag
import nltk

# Ensure necessary NLTK resources are available
nltk.download("punkt")
nltk.download("punkt_tab")
nltk.download("averaged_perceptron_tagger")
nltk.download("stopwords")
nltk.download('averaged_perceptron_tagger_eng')
nltk.download('vader_lexicon')

[nltk_data] Downloading package punkt to /home/zhizheng/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package punkt_tab to
[nltk_data]     /home/zhizheng/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!
[nltk_data] Downloading package averaged_perceptron_tagger to
[nltk_data]     /home/zhizheng/nltk_data...
[nltk_data]   Package averaged_perceptron_tagger is already up-to-
[nltk_data]       date!
[nltk_data] Downloading package stopwords to
[nltk_data]     /home/zhizheng/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!
[nltk_data] Downloading package averaged_perceptron_tagger_eng to
[nltk_data]     /home/zhizheng/nltk_data...
[nltk_data]   Package averaged_perceptron_tagger_eng is already up-to-
[nltk_data]       date!
[nltk_data] Downloading package vader_lexicon to
[nltk_data]     /home/zhizheng/nltk_data...
[nltk_data]   Package vader_lexicon is already up-to-date!


True

In [4]:
from nltk.sentiment import SentimentIntensityAnalyzer

# Initialize VADER Sentiment Intensity Analyzer
sia = SentimentIntensityAnalyzer()

# Define words to analyze
words = ["amazing", "terrible", "happy", "sad", "awesome", "horrible", "neutral"]

# Label words based on sentiment score
labels = {}
for word in words:
    sentiment_score = sia.polarity_scores(word)['compound']
    if sentiment_score > 0.5:
        labels[word] = "Strong Positive"
    elif sentiment_score < -0.5:
        labels[word] = "Strong Negative"
    else:
        labels[word] = "Neutral or Weak Sentiment"

# Output labeled words
print(labels)

{'amazing': 'Strong Positive', 'terrible': 'Neutral or Weak Sentiment', 'happy': 'Strong Positive', 'sad': 'Neutral or Weak Sentiment', 'awesome': 'Strong Positive', 'horrible': 'Strong Negative', 'neutral': 'Neutral or Weak Sentiment'}


In [5]:
class IMDBDataset(Dataset):
    def __init__(self, reviews, targets, tokenizer, max_length=512):
        self.reviews = reviews
        self.targets = targets
        self.tokenizer = tokenizer
        self.max_length = max_length
    
    def __len__(self):
        return len(self.reviews)
    
    def __getitem__(self, idx):
        review = str(self.reviews[idx])
        target = self.targets[idx]

        encoding = self.tokenizer.encode_plus(
            review,
            add_special_tokens=True,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt'
        )

        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': torch.tensor(target, dtype=torch.long),
            'review': review
        }

In [6]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [21]:
MODEL = "Bert"
FINETUNE = True

if MODEL == "Bert":
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    model = BertForSequenceClassification.from_pretrained(
        'bert-base-uncased',
        num_labels=2
    ).to(device)
    MAX_LEN=512
    if FINETUNE:
        model.load_state_dict(torch.load('./ckpts/best_bert_imdb_model.pt'))
elif MODEL == "BigBird":        
    tokenizer = BigBirdTokenizer.from_pretrained('google/bigbird-roberta-base')
    model = BigBirdForSequenceClassification.from_pretrained(
        'google/bigbird-roberta-base',
        num_labels=2
    ).to(device)
    MAX_LEN=4096
    if FINETUNE:
        model.load_state_dict(torch.load('./ckpts/best_bigbird_imdb_model.pt'))
elif MODEL == "Longformer":
    tokenizer = LongformerTokenizer.from_pretrained('allenai/longformer-base-4096')
    model = LongformerForSequenceClassification.from_pretrained(
        'allenai/longformer-base-4096',
        num_labels=2
    ).to(device)
    MAX_LEN=4096
    if FINETUNE:    
        model.load_state_dict(torch.load('./ckpts/best_longformer_imdb_model.pt'))
model.eval()

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  model.load_state_dict(torch.load('./ckpts/best_bert_imdb_model.pt'))


BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e

In [22]:
df = pd.read_csv('./dataset/IMDB Dataset.csv')
reviews = df['review'].values
labels = (df['sentiment'] == 'positive').astype(int).values

In [23]:
train_texts, temp_texts, train_labels, temp_labels = train_test_split(
    reviews, labels, test_size=0.3, random_state=42
)
val_texts, test_texts, val_labels, test_labels = train_test_split(
    temp_texts, temp_labels, test_size=0.5, random_state=42
)

In [24]:
# Create datasets
train_dataset = IMDBDataset(train_texts, train_labels, tokenizer, max_length=MAX_LEN)
val_dataset = IMDBDataset(val_texts, val_labels, tokenizer, max_length=MAX_LEN)
test_dataset = IMDBDataset(test_texts, test_labels, tokenizer, max_length=MAX_LEN)

# Create dataloaders
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=16)
test_loader = DataLoader(test_dataset, batch_size=1)

In [25]:
data_iter = iter(test_loader)
# next(data_iter)
test_data = next(data_iter)
input_ids = test_data['input_ids'].to(device)
attention_mask  = test_data['attention_mask'].to(device)
label = test_data['labels']
review = test_data['review']
with torch.no_grad():
    output = model(input_ids, attention_mask=attention_mask, output_attentions=True)
print(label)
print(review)

tensor([0])
['The biggest National Lampoon hit remains "Animal House", and rightly so. It was funny, raucous and good-natured.<br /><br />The exact opposite of every other National Lampoon film. Including "Class Reunion".<br /><br />PLEASE do not be fooled by the inclusion of Stephen Furst ("Flounder") from "Animal House". Or by the fact that John Hughes wrote this jumbled mess. This reunion is about as hilarious as root canal and twice as painful.<br /><br />One star, and that\'s being generous. Then again, I always thought most of my old classmates were demons, vampires and serial killers, too.']


In [26]:
_attentions = [att.detach().cpu().numpy() for att in output.attentions]
attentions_mat = np.asarray(_attentions)[:,0]
print(attentions_mat.shape)

(12, 12, 512, 512)


In [27]:
predicted = output['logits']
predicted

tensor([[0.3195, 0.1632]], device='cuda:0')

In [28]:
# Tokenize and part-of-speech tag the input text to filter syntax words
text = review[0]
tokens = word_tokenize(text)
pos_tags = pos_tag(tokens)
stop_words = set(stopwords.words("english"))
content_tokens = [word for word, pos in pos_tags if (pos.startswith("NN") or pos.startswith("VB") or pos.startswith("JJ")) and word.lower() not in stop_words and word not in ['<', '>', '/', 'br']]
print(content_tokens)


['biggest', 'National', 'Lampoon', 'hit', 'remains', 'Animal', 'House', 'funny', 'raucous', 'good-natured.', 'exact', 'opposite', 'National', 'Lampoon', 'film', 'Including', 'Class', 'Reunion', 'PLEASE', 'fooled', 'inclusion', 'Stephen', 'Furst', 'Flounder', 'Animal', 'House', 'fact', 'John', 'Hughes', 'wrote', 'jumbled', 'mess', 'reunion', 'hilarious', 'root', 'canal', 'painful.', 'One', 'star', "'s", 'generous', 'thought', 'old', 'classmates', 'demons', 'vampires', 'serial', 'killers']


In [29]:
# Define a mask to exclude special tokens, punctuation, and syntax words
input_ids = input_ids.to("cpu")
cls_token_id = tokenizer.cls_token_id
sep_token_id = tokenizer.sep_token_id
pad_token_id = tokenizer.pad_token_id
punctuation_ids = [tokenizer.convert_tokens_to_ids(p) for p in [".", ",", "!", "?", ":", ";", "-", "..."]]

non_special_token_mask = (input_ids != cls_token_id) & (input_ids != sep_token_id) & (input_ids != pad_token_id)
non_punctuation_mask = ~torch.isin(input_ids, torch.tensor(punctuation_ids))
content_word_mask = torch.tensor([[1 if tokenizer.decode(id).strip() in content_tokens else 0 for id in input_ids[0]]], dtype=torch.bool)

# Combine all masks
valid_token_mask = content_word_mask
 

In [30]:
# Find the indices of True values
if MODEL == "Longformer":
    true_indices = torch.where(valid_token_mask)[1]
    
    if len(true_indices) >= 11:
        valid_token_mask[0,true_indices[11:]] = False
    valid_token_mask = torch.cat((torch.ones(1, 1, dtype=torch.bool), valid_token_mask), dim=1)
else:
    valid_token_mask[0,0]=True
    true_indices = torch.where(valid_token_mask)[1]
    
    if len(true_indices) >= 12:
        valid_token_mask[0,true_indices[12:]] = False
    


In [31]:
if MODEL=="Longformer":
    ATTN = []
    for layer_attention, global_attention in zip(output.attentions, output.global_attentions):
        combined_attn = torch.zeros((1, layer_attention.shape[1],MAX_LEN+1, MAX_LEN+1)) 
        local_attn = torch.zeros((1, layer_attention.shape[1],MAX_LEN, MAX_LEN))  # Use the same device
        global_attention_weights = layer_attention[:, :, :, :1]
        local_attention_weights = layer_attention[:, :, :, 1:]
        half_window = int(model.longformer.config.attention_window[0] / 2)
        
        # Draw local attention 
        for i in range(MAX_LEN):
            if i > half_window and MAX_LEN - i > half_window + 1:
                local_attn[0, :, i, i - half_window:i + half_window + 1] = local_attention_weights[0, :, i, :]
            elif i <= half_window:
                local_attn[0, :, i, :i + half_window + 1] = local_attention_weights[0, :, i, half_window - i:]
            else:
                local_attn[0, :, i, i - half_window:] = local_attention_weights[0, :, i, :half_window + (MAX_LEN - i)]

        combined_attn[:, :, 1:, 1:] = local_attn
        combined_attn[:, :, 1:,0] += global_attention_weights[:, :, :, 0].cpu()
        combined_attn[:, :, 0, 1:] += global_attention[:, :, :, 0].cpu()
        ATTN.append(combined_attn.to(device))
else:
    ATTN = output.attentions
    


In [32]:
print(valid_token_mask.shape)
print(ATTN[0].shape)

torch.Size([1, 512])
torch.Size([1, 12, 512, 512])


In [33]:
masked_attn = []
for layer_attention in ATTN:
    # CLS token's attention in this layer
    cls_attention_to_content_words = layer_attention[:, :, valid_token_mask[0]]
    cls_attention_to_content_words = cls_attention_to_content_words[:, :, :,valid_token_mask[0]]
    cls_attention_to_content_words = cls_attention_to_content_words / torch.sum(cls_attention_to_content_words, dim=-1, keepdims=True)
    masked_attn.append(cls_attention_to_content_words)  # Average over heads

# Concatenate all layer attentions vertically
print(masked_attn[0].shape)
# Extract valid tokens for x-axis labels
if  MODEL == "Longformer":
    tokens = tokenizer.convert_ids_to_tokens(input_ids[0][valid_token_mask[0, 1:]])
    print(tokens)
    tokens.insert(0, "[CLS]")
    print(tokens)
else:
    tokens = tokenizer.convert_ids_to_tokens(input_ids[0][valid_token_mask[0]])

torch.Size([1, 12, 12, 12])


In [34]:
head_view(masked_attn, tokens) 

<IPython.core.display.Javascript object>