# Exercise 6: Attention

## Task 1: Implementation of Self Attention

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
import numpy as np
from gensim.downloader import load as gensim_load
from datasets import load_dataset, concatenate_datasets

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
glove = gensim_load('glove-wiki-gigaword-100')

In [3]:
def load_imdb(n_samples=100):
    dataset = load_dataset("imdb", split="train")

    # How many samples per class
    n_per_class = n_samples // 2

    # Filter each class
    pos = dataset.filter(lambda x: x["label"] == 1).shuffle(seed=42).select(range(n_per_class))
    neg = dataset.filter(lambda x: x["label"] == 0).shuffle(seed=42).select(range(n_per_class))
    
    # Combine and shuffle
    balanced = concatenate_datasets([pos, neg]).shuffle(seed=42)

    texts = balanced["text"]
    labels = balanced["label"]
    return texts, labels

texts, labels = load_imdb(n_samples=1000) # adjust n_samples based on your computational resources

In [4]:
from transformers import AutoTokenizer

# Load a HuggingFace tokenizer (e.g., bert-base-uncased)
hf_tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

def tokenize(text):
    """
    Tokenize input text using HuggingFace tokenizer.

    Args:
        text (str): Input sentence.

    Returns:
        List[str]: List of tokens.
    """
    return hf_tokenizer.tokenize(text)

In [5]:
def vectorize(tokens, max_len=100):
    """
    Convert tokens to GloVe vectors with padding/truncation
    
    Args:
        tokens: list of string tokens
        max_len: maximum sequence length
    
    Returns:
        numpy array of shape (max_len, embedding_dim)
    """
    # Get embedding dimension from GloVe (should be 100 for glove-wiki-gigaword-100)
    embedding_dim = glove.vector_size
    
    # Initialize result matrix with zeros (padding)
    result = np.zeros((max_len, embedding_dim))
    
    # Process tokens up to max_len
    for i, token in enumerate(tokens[:max_len]):
        try:
            # Get GloVe vector for the token
            vector = glove[token.lower()]  # Convert to lowercase for better matching
            result[i] = vector
        except KeyError:
            # Token not in GloVe vocabulary - leave as zero vector
            # Alternatively, you could use a random vector or UNK token
            pass
    
    return result

In [6]:
# Test the vectorize function on a sample sentence
sample_text = texts[0]
print("Sample text:", sample_text)
print("Sample text length:", len(sample_text.split()))
sample_tokens = sample_text.split()  # simple tokenization
vectorized = vectorize(sample_tokens, max_len=100)
print(vectorized.shape)
print(vectorized)

Sample text: The fine cast cannot uplift this routine tale of a secretary murdered by her married paramour. In fact there are more questions than answers in this one-sided tale of romance and murder; and since we are only provided with the prosecution's side, none of these questions will be answered. This is the type of fare that appeals to the "He Woman, Man Hater" clubs of America. As presented, it is the tale of an innocent woman who just happens to be "caught up" in a romance with a married, high-profile attorney. Is it possible that IF, she had not been two timing her boy friend and having an affair with a married man, the whole nasty murderous, sordid incident could have been avoided? When you watch this, don't worry about going to the 'fridge, you won't miss anything.
Sample text length: 138
(100, 100)
[[-0.038194   -0.24487001  0.72812003 ... -0.1459      0.82779998
   0.27061999]
 [ 0.50524002  0.49487001 -0.57381999 ... -0.41922     0.66611999
  -0.1165    ]
 [-0.40184     0.

In [7]:
import pandas as pd

train_df = pd.DataFrame({
    "text": texts,
    "label": labels
})

train_df["tokens"] = train_df["text"].apply(tokenize)
train_df["vectors"] = train_df["tokens"].apply(lambda x: vectorize(x, max_len=100))

Token indices sequence length is longer than the specified maximum sequence length for this model (781 > 512). Running this sequence through the model will result in indexing errors


In [8]:
class SelfAttention(nn.Module):
    def __init__(self, d_model, d_k):
        super().__init__()
        self.WQ = nn.Linear(d_model, d_k)
        self.WK = nn.Linear(d_model, d_k)
        self.WV = nn.Linear(d_model, d_k)

        self.d_k = d_k

    def forward(self, x):
        Q = self.WQ(x)  # (batch_size, seq_len, d_k)
        K = self.WK(x)  # (batch_size, seq_len, d_k)
        V = self.WV(x)  # (batch_size, seq_len, d_k)
        
        K_transpose = K.transpose(-2, -1)  # (batch_size, d_k, seq_len)
        attention_scores = torch.matmul(Q, K_transpose)  # (batch_size, seq_len, seq_len)
        scaled_attention_scores = attention_scores / np.sqrt(self.d_k)  # Scale by sqrt(d_k)
        attention_weights = F.softmax(scaled_attention_scores, dim=-1)  # (batch_size, seq_len, seq_len)        
        context_vector = torch.matmul(attention_weights, V)

        return context_vector

## Task 2: Adding a Classification Layer

In [9]:
class BinaryClassificationModel(nn.Module):
    def __init__(self, d_model, d_k):
        super().__init__()
        self.attention = SelfAttention(d_model, d_k)
        self.classifier = nn.Linear(d_k, 1)  # Final layer for binary classification

    def forward(self, x):
        context_vector = self.attention(x) # shape : (batch_size, seq_len, d_k)
        context_vector = context_vector.mean(dim=1)  # Aggregate over sequence length
        logits = self.classifier(context_vector)
        return logits

## Task 3: Training

In [10]:
from tqdm.auto import tqdm

def train_model(model, X, y, epochs=10, lr=1e-3, batch_size=8):
    model.train()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.BCEWithLogitsLoss()
    dataset = torch.utils.data.TensorDataset(X, y)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
    
    # Outer progress bar for epochs
    epoch_pbar = tqdm(range(epochs), desc="Training", unit="epoch")
    
    for epoch in epoch_pbar:
        total_loss = 0
        correct_predictions = 0
        total_samples = 0
        
        # Inner progress bar for batches
        batch_pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}", 
                         leave=False, unit="batch")
        
        for batch in batch_pbar:
            inputs, labels = batch
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
            predicted_probabilities = torch.sigmoid(outputs)
            predictions = (predicted_probabilities > 0.5).float()
            correct_predictions += (predictions.squeeze() == labels).sum().item()
            total_samples += labels.numel()
            
            # Update batch progress bar with current loss
            batch_pbar.set_postfix({"Loss": f"{loss.item():.4f}"})
        
        # Calculate epoch statistics
        avg_loss = total_loss / len(dataloader)
        accuracy = correct_predictions / total_samples
        
        # Update epoch progress bar with metrics
        epoch_pbar.set_postfix({
            "Avg Loss": f"{avg_loss:.4f}", 
            "Accuracy": f"{accuracy:.4f}"
        })
        
        # Print epoch summary
        print(f"\nEpoch {epoch+1}/{epochs} - Avg Loss: {avg_loss:.4f}, Accuracy: {accuracy:.4f}")

In [11]:
EMBEDDING_DIM = 100  # GloVe embedding dimension
d_model = EMBEDDING_DIM
d_k = 64 # try different values for d_k
model = BinaryClassificationModel(d_model=d_model, d_k=d_k)

X = torch.tensor(train_df["vectors"].tolist(), dtype=torch.float32)
y = torch.tensor(train_df["label"].tolist(), dtype=torch.float32).unsqueeze(1)  # Ensure y is of shape (batch_size, 1)
train_model(model, X, y)

  X = torch.tensor(train_df["vectors"].tolist(), dtype=torch.float32)
Training:  10%|█         | 1/10 [00:01<00:12,  1.39s/epoch, Avg Loss=0.6876, Accuracy=4.0160]


Epoch 1/10 - Avg Loss: 0.6876, Accuracy: 4.0160


Training:  20%|██        | 2/10 [00:02<00:09,  1.23s/epoch, Avg Loss=0.6121, Accuracy=4.1120]


Epoch 2/10 - Avg Loss: 0.6121, Accuracy: 4.1120


Training:  30%|███       | 3/10 [00:03<00:09,  1.29s/epoch, Avg Loss=0.5493, Accuracy=4.1400]


Epoch 3/10 - Avg Loss: 0.5493, Accuracy: 4.1400


Training:  40%|████      | 4/10 [00:05<00:07,  1.25s/epoch, Avg Loss=0.5223, Accuracy=4.2440]


Epoch 4/10 - Avg Loss: 0.5223, Accuracy: 4.2440


Training:  50%|█████     | 5/10 [00:06<00:05,  1.19s/epoch, Avg Loss=0.5093, Accuracy=4.2480]


Epoch 5/10 - Avg Loss: 0.5093, Accuracy: 4.2480


Training:  60%|██████    | 6/10 [00:07<00:04,  1.13s/epoch, Avg Loss=0.4904, Accuracy=4.3520]


Epoch 6/10 - Avg Loss: 0.4904, Accuracy: 4.3520


Training:  70%|███████   | 7/10 [00:08<00:03,  1.13s/epoch, Avg Loss=0.4715, Accuracy=4.3180]


Epoch 7/10 - Avg Loss: 0.4715, Accuracy: 4.3180


Training:  80%|████████  | 8/10 [00:09<00:02,  1.06s/epoch, Avg Loss=0.4593, Accuracy=4.2120]


Epoch 8/10 - Avg Loss: 0.4593, Accuracy: 4.2120


Training:  90%|█████████ | 9/10 [00:10<00:01,  1.01s/epoch, Avg Loss=0.4544, Accuracy=4.2180]


Epoch 9/10 - Avg Loss: 0.4544, Accuracy: 4.2180


Training: 100%|██████████| 10/10 [00:11<00:00,  1.11s/epoch, Avg Loss=0.4420, Accuracy=4.2680]
Training: 100%|██████████| 10/10 [00:11<00:00,  1.11s/epoch, Avg Loss=0.4420, Accuracy=4.2680]



Epoch 10/10 - Avg Loss: 0.4420, Accuracy: 4.2680


## Task 4: Inference

In [15]:
def predict_sentiment(text):
    tokens = tokenize(text)
    vectorized = vectorize(tokens, max_len=100)
    vectorized_tensor = torch.tensor(vectorized, dtype=torch.float32).unsqueeze(0)  # Add batch dimension
    model.eval()
    with torch.no_grad():
        logits = model(vectorized_tensor)
        probabilities = torch.sigmoid(logits)
        prediction = (probabilities > 0.5).float()
    return prediction.item(), probabilities.item()

In [16]:
predict_sentiment("This movie was fantastic and full of suspense!")

(1.0, 0.5533190369606018)

In [17]:
predict_sentiment("This movie was bad")

(0.0, 0.4657359719276428)