# 20 Newsgroups Classification - Hierarchical Attention Network
## Experiment 3: Complete Implementation with Report Data Collection

---

### Architecture (as per problem specification):
1. **Contextual Encoder (Bi-LSTM)** - extracts word-level features
2. **Word-Level Filtering Attention** - self-attention to identify important words
3. **Sentence Representation** - aggregates word features via learned pooling
4. **Document-Level Cross-Attention** - filtered words as queries, sentences as keys/values
5. **Classification Layer** - predicts newsgroup label

### Report Deliverables Data Collection:
- (ii) Baseline comparisons with metrics and charts
- (iii) Attention distribution visualizations
- (iv) Per-class error analysis
- (v) Attention influence analysis (quantitative + qualitative)
- (vii) Failure mode analysis for two attention stages

## 1. Environment Setup

In [None]:
# Install PyTorch and its companion packages; `-q` keeps pip output quiet.
!pip install torch torchvision torchaudio -q
# Install the ML / NLP / plotting libraries imported by the next cell.
!pip install scikit-learn nltk matplotlib seaborn tqdm pandas numpy gensim -q

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from collections import Counter, defaultdict
import re, json, os, tarfile, textwrap
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

from sklearn.datasets import fetch_20newsgroups
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.naive_bayes import MultinomialNB
from sklearn.svm import LinearSVC
from sklearn.utils.class_weight import compute_class_weight
from sklearn.utils import Bunch

import nltk
nltk.download('punkt', quiet=True)
nltk.download('stopwords', quiet=True)
nltk.download('punkt_tab', quiet=True)
from nltk.tokenize import word_tokenize, sent_tokenize
from nltk.corpus import stopwords

import gensim.downloader as api

# Seed every random number generator used by the notebook so runs are repeatable.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)
cuda_available = torch.cuda.is_available()
if cuda_available:
    torch.cuda.manual_seed_all(SEED)

# Run on the GPU when one is visible, otherwise on the CPU.
device = torch.device('cuda') if cuda_available else torch.device('cpu')
print(f"Using device: {device}")

## 2. Load Dataset

In [None]:
# The 20 newsgroup category names. The position of a name in this list is
# used as its integer class label by the loader and the reporting cells below.
TARGET_NAMES = [
    'alt.atheism', 'comp.graphics', 'comp.os.ms-windows.misc',
    'comp.sys.ibm.pc.hardware', 'comp.sys.mac.hardware', 'comp.windows.x',
    'misc.forsale', 'rec.autos', 'rec.motorcycles', 'rec.sport.baseball',
    'rec.sport.hockey', 'sci.crypt', 'sci.electronics', 'sci.med',
    'sci.space', 'soc.religion.christian', 'talk.politics.guns',
    'talk.politics.mideast', 'talk.politics.misc', 'talk.religion.misc'
]

def _strip_newsgroup_sections(text, remove):
    """Return `text` with the parts named in `remove` cut out.

    Recognised names are 'headers', 'footers' and 'quotes'; they are applied
    in that order, each on the result of the previous one.
    """
    if 'headers' in remove:
        # Everything up to and including the first blank line is the header.
        lines = text.split('\n')
        for i, line in enumerate(lines):
            if line.strip() == '':
                text = '\n'.join(lines[i + 1:])
                break
    if 'footers' in remove:
        # A line holding only '--' starts the signature; search from the end
        # so only the last such marker cuts the text.
        lines = text.split('\n')
        for i in range(len(lines) - 1, -1, -1):
            if lines[i].strip() == '--':
                text = '\n'.join(lines[:i])
                break
    if 'quotes' in remove:
        # Quoted reply lines start with '>' once leading whitespace is ignored.
        text = '\n'.join(l for l in text.split('\n') if not l.strip().startswith('>'))
    return text


def load_newsgroups_from_folder(folder_path, remove=()):
    """Load an extracted 20 Newsgroups split from disk.

    Args:
        folder_path: Directory holding one sub-directory per category in
            TARGET_NAMES. Categories without a directory are skipped.
        remove: Iterable of section names ('headers', 'footers', 'quotes')
            to strip from every post.

    Returns:
        A Bunch with `data` (list of post texts), `target` (numpy array of
        class indices into TARGET_NAMES) and `target_names`.

    Files that cannot be opened or read are skipped (best effort). Files are
    visited in sorted name order so the result does not depend on the
    platform's directory listing order.
    """
    data, target = [], []
    for label_idx, category in enumerate(TARGET_NAMES):
        category_path = os.path.join(folder_path, category)
        if not os.path.isdir(category_path):
            continue
        for filename in sorted(os.listdir(category_path)):
            filepath = os.path.join(category_path, filename)
            try:
                # latin-1 maps every byte to a character, so decoding cannot fail.
                with open(filepath, 'r', encoding='latin-1') as f:
                    text = f.read()
            except OSError:
                # Unreadable entry (permissions, nested directory, ...): skip it.
                continue
            data.append(_strip_newsgroup_sections(text, remove))
            target.append(label_idx)
    return Bunch(data=data, target=np.array(target), target_names=TARGET_NAMES)

# Preferred path: let scikit-learn download/cache the dataset and strip
# headers, footers and quoted replies.
try:
    newsgroups_train = fetch_20newsgroups(subset='train', remove=('headers', 'footers', 'quotes'))
    newsgroups_test = fetch_20newsgroups(subset='test', remove=('headers', 'footers', 'quotes'))
    print("Dataset loaded via sklearn.")
except Exception as e:
    # Fallback: fetch the archive manually, extract it once, and parse the
    # folders with load_newsgroups_from_folder.
    print(f"Direct download failed: {e}\nUsing manual download...")
    data_home = os.path.expanduser('~/scikit_learn_data')
    twenty_home = os.path.join(data_home, '20news_home')
    os.makedirs(twenty_home, exist_ok=True)
    archive_path = os.path.join(twenty_home, '20news-bydate.tar.gz')
    # Download only when the archive is not already on disk.
    if not os.path.exists(archive_path):
        !wget --user-agent="Mozilla/5.0" -O "{archive_path}" "http://qwone.com/~jason/20Newsgroups/20news-bydate.tar.gz"
    train_path = os.path.join(twenty_home, '20news-bydate-train')
    # Extract only when the train folder is missing.
    # NOTE(review): extractall trusts the member paths stored in the archive;
    # consider restricting extraction if the download source is not trusted.
    if not os.path.exists(train_path):
        with tarfile.open(archive_path, 'r:gz') as tar:
            tar.extractall(path=twenty_home)
    newsgroups_train = load_newsgroups_from_folder(os.path.join(twenty_home, '20news-bydate-train'), remove=('headers', 'footers', 'quotes'))
    newsgroups_test = load_newsgroups_from_folder(os.path.join(twenty_home, '20news-bydate-test'), remove=('headers', 'footers', 'quotes'))

print(f"Training samples: {len(newsgroups_train.data)}")
print(f"Test samples: {len(newsgroups_test.data)}")

## 3. Class Distribution Analysis

In [None]:
# Count documents per class label in each split.
train_counts = Counter(newsgroups_train.target)
test_counts = Counter(newsgroups_test.target)

# Per-class counts in label order, reused for the table and the imbalance ratio.
n_train_docs = len(newsgroups_train.data)
train_counts_list = [train_counts[i] for i in range(20)]
test_counts_list = [test_counts[i] for i in range(20)]

class_data = [
    {'ID': i, 'Category': name, 'Train': n_tr, 'Test': n_te,
     'Total': n_tr + n_te,
     'Train%': round(n_tr / n_train_docs * 100, 2)}
    for i, (name, n_tr, n_te) in enumerate(zip(TARGET_NAMES, train_counts_list, test_counts_list))
]

class_df = pd.DataFrame(class_data)
print("CLASS DISTRIBUTION\n" + "="*80)
print(class_df.to_string(index=False))
class_df.to_csv('class_distribution.csv', index=False)

# Ratio between the largest and the smallest training class.
imbalance_ratio = max(train_counts_list) / min(train_counts_list)
print(f"\nImbalance Ratio: {imbalance_ratio:.2f}")

In [None]:
# Grouped bar chart: one train/test pair of bars per newsgroup.
x = np.arange(20)
width = 0.35
train_heights = [train_counts[i] for i in range(20)]
test_heights = [test_counts[i] for i in range(20)]

fig, ax = plt.subplots(figsize=(14, 6))
ax.bar(x - width/2, train_heights, width, label='Train', color='steelblue')
ax.bar(x + width/2, test_heights, width, label='Test', color='darkorange')

ax.set_title('20 Newsgroups - Class Distribution', fontsize=14)
ax.set_xlabel('Category', fontsize=12)
ax.set_ylabel('Number of Documents', fontsize=12)
ax.set_xticks(x)
ax.set_xticklabels(TARGET_NAMES, rotation=45, ha='right', fontsize=9)
ax.legend()

plt.tight_layout()
plt.savefig('class_distribution.png', dpi=150, bbox_inches='tight')
plt.show()

In [None]:
# Show the first training document (whitespace-collapsed, 300 chars) of a few classes.
print("SAMPLE DOCUMENTS\n" + "="*80)
for class_idx in (0, 1, 7, 11, 15):
    indices = np.flatnonzero(newsgroups_train.target == class_idx)
    raw_doc = newsgroups_train.data[indices[0]]
    text = ' '.join(raw_doc.split())[:300]
    header = f"\n[{class_idx}] {TARGET_NAMES[class_idx]}\n" + "-"*40
    print(header)
    print(textwrap.fill(text + "...", width=80))

## 4. Text Preprocessing

In [None]:
class TextPreprocessor:
    """Cleans raw posts, splits them into sentences/words and maps words to ids.

    Index 0 is reserved for '<PAD>' and index 1 for '<UNK>'.
    """

    def __init__(self, min_word_freq=2, max_vocab_size=50000, max_sent_len=50, max_doc_sents=30):
        """Store the truncation/vocabulary limits and start with the two special tokens."""
        self.min_word_freq = min_word_freq
        self.max_vocab_size = max_vocab_size
        self.max_sent_len = max_sent_len
        self.max_doc_sents = max_doc_sents
        self.word2idx = {'<PAD>': 0, '<UNK>': 1}
        self.idx2word = {0: '<PAD>', 1: '<UNK>'}
        self.vocab_size = 2

    def clean_text(self, text):
        """Lower-case `text`, drop e-mail addresses and URLs, keep only
        letters, digits, whitespace and '.', '!', '?', and collapse whitespace."""
        text = text.lower()
        substitutions = (
            (r'\S+@\S+', ''),
            (r'http\S+|www\S+', ''),
            (r'[^a-zA-Z0-9\s\.\!\?]', ' '),
            (r'\s+', ' '),
        )
        for pattern, replacement in substitutions:
            text = re.sub(pattern, replacement, text)
        return text.strip()

    def tokenize_document(self, text):
        """Return a list of sentences, each a list of word strings.

        Only the first `max_doc_sents` sentences are considered; words must be
        alphabetic and longer than one character; sentences left empty by the
        filter are dropped and the rest are cut to `max_sent_len` words.
        """
        sentences = sent_tokenize(self.clean_text(text))[:self.max_doc_sents]
        doc_tokens = []
        for sent in sentences:
            words = [tok for tok in word_tokenize(sent) if tok.isalpha() and len(tok) > 1]
            if not words:
                continue
            doc_tokens.append(words[:self.max_sent_len])
        return doc_tokens

    def build_vocab(self, documents):
        """Count words over `documents` and register the frequent ones.

        The `max_vocab_size` most common words are considered, and of those
        only words seen at least `min_word_freq` times are added. Returns self.
        """
        word_freq = Counter()
        for doc in tqdm(documents, desc="Building vocabulary"):
            for sent in self.tokenize_document(doc):
                word_freq.update(sent)
        for word, count in word_freq.most_common(self.max_vocab_size):
            if count < self.min_word_freq:
                continue
            new_index = self.vocab_size
            self.word2idx[word] = new_index
            self.idx2word[new_index] = word
            self.vocab_size = new_index + 1
        print(f"Vocabulary size: {self.vocab_size}")
        return self

    def encode_document(self, text):
        """Return `text` as a list of sentences of word ids.

        Unknown words map to 1 ('<UNK>'). A document with no usable sentence
        is encoded as a single sentence holding one '<UNK>' id.
        """
        encoded_doc = []
        for sent in self.tokenize_document(text):
            ids = [self.word2idx.get(word, 1) for word in sent]
            if ids:
                encoded_doc.append(ids)
        return encoded_doc or [[1]]

# Build the shared preprocessor; the vocabulary is fitted on the training split only.
preprocessor = TextPreprocessor(min_word_freq=2, max_vocab_size=50000, max_sent_len=50, max_doc_sents=30)
preprocessor.build_vocab(newsgroups_train.data)

## 5. Load GloVe Embeddings

In [None]:
print("Loading GloVe embeddings...")
glove_vectors = api.load('glove-wiki-gigaword-300')
print(f"Loaded {len(glove_vectors)} word vectors")

# One row per vocabulary entry: the GloVe vector when available, otherwise
# a small random vector. Words are visited in vocabulary order, so the random
# draws are consumed in the same sequence on every run.
EMBED_DIM = 300
embedding_matrix = np.zeros((preprocessor.vocab_size, EMBED_DIM))
found_count = 0
for word, idx in tqdm(preprocessor.word2idx.items(), desc="Building embedding matrix"):
    if word not in glove_vectors:
        embedding_matrix[idx] = np.random.normal(0, 0.1, EMBED_DIM)
        continue
    embedding_matrix[idx] = glove_vectors[word]
    found_count += 1

# The padding row (index 0) must stay all zeros.
embedding_matrix[0] = 0.0
embedding_matrix = torch.FloatTensor(embedding_matrix)
print(f"Embedding coverage: {found_count}/{preprocessor.vocab_size} ({found_count/preprocessor.vocab_size*100:.1f}%)")

## 6. Dataset and DataLoader

In [None]:
class NewsGroupDataset(Dataset):
    """Dataset that encodes each raw document on access.

    Each item is `(encoded_document, label, idx)`, where `encoded_document`
    is whatever `preprocessor.encode_document` returns for the raw text.
    """

    def __init__(self, documents, labels, preprocessor):
        self.documents = documents
        self.labels = labels
        self.preprocessor = preprocessor

    def __len__(self):
        return len(self.documents)

    def __getitem__(self, idx):
        encoded = self.preprocessor.encode_document(self.documents[idx])
        return encoded, self.labels[idx], idx

def collate_fn(batch):
    """Pad a list of `(encoded_doc, label, index)` items into batch tensors.

    Returns a tuple of LongTensors:
        docs         (batch, max_sents, max_words), padded with 0
        labels       (batch,)
        doc_lengths  (batch,) number of real sentences per document
        sent_lengths (batch, max_sents) words per sentence; padded sentence
                     slots get length 1 so downstream packing never sees 0
        indices      (batch,) the items' dataset indices
    """
    docs, labels, indices = zip(*batch)
    doc_lengths = [len(doc) for doc in docs]
    max_doc_len = max(doc_lengths)
    max_sent_len = max(max(len(sent) for sent in doc) for doc in docs)

    padded_docs, sent_lengths = [], []
    for doc in docs:
        n_missing = max_doc_len - len(doc)
        sent_lengths.append([len(sent) for sent in doc] + [1] * n_missing)
        rows = [sent + [0] * (max_sent_len - len(sent)) for sent in doc]
        rows.extend([0] * max_sent_len for _ in range(n_missing))
        padded_docs.append(rows)

    return (
        torch.LongTensor(padded_docs),
        torch.LongTensor(labels),
        torch.LongTensor(doc_lengths),
        torch.LongTensor(sent_lengths),
        torch.LongTensor(indices),
    )

train_dataset = NewsGroupDataset(newsgroups_train.data, newsgroups_train.target, preprocessor)
test_dataset = NewsGroupDataset(newsgroups_test.data, newsgroups_test.target, preprocessor)

# Hold out 10% of the training documents for validation, with a seeded split.
train_size = int(0.9 * len(train_dataset))
val_size = len(train_dataset) - train_size
split_generator = torch.Generator().manual_seed(SEED)
train_subset, val_subset = torch.utils.data.random_split(train_dataset, [train_size, val_size], generator=split_generator)

# Only the training loader shuffles; validation and test keep dataset order.
BATCH_SIZE = 32
loader_kwargs = dict(batch_size=BATCH_SIZE, collate_fn=collate_fn)
train_loader = DataLoader(train_subset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_subset, shuffle=False, **loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)
print(f"Train batches: {len(train_loader)}, Val batches: {len(val_loader)}, Test batches: {len(test_loader)}")

## 7. Model Architecture

In [None]:
class WordLevelEncoder(nn.Module):
    """Bi-LSTM over word embeddings with a projected residual connection.

    `hidden_dim` is the per-direction LSTM size; the output feature size
    (exposed as `self.hidden_dim`) is twice that.
    """

    def __init__(self, embedding_matrix, hidden_dim, num_layers=2, dropout=0.4):
        super().__init__()
        embed_dim = embedding_matrix.shape[1]
        out_dim = hidden_dim * 2
        # Inter-layer LSTM dropout only applies to stacked LSTMs.
        lstm_dropout = dropout if num_layers > 1 else 0
        self.embedding = nn.Embedding.from_pretrained(embedding_matrix, freeze=False, padding_idx=0)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, num_layers=num_layers, bidirectional=True, batch_first=True, dropout=lstm_dropout)
        self.residual_proj = nn.Linear(embed_dim, out_dim)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(out_dim)
        self.hidden_dim = out_dim

    def forward(self, x, lengths):
        """Encode a batch of sentences.

        Args:
            x: (batch, max_words) word ids.
            lengths: (batch,) sentence lengths; values below 1 are treated as 1.

        Returns:
            (batch, max_words, 2 * hidden_dim) layer-normalised features.
        """
        embedded = self.dropout(self.embedding(x))
        residual = self.residual_proj(embedded)
        safe_lengths = lengths.cpu().clamp(min=1)
        packed = pack_padded_sequence(embedded, safe_lengths, batch_first=True, enforce_sorted=False)
        packed_out, _ = self.lstm(packed)
        # total_length zero-pads the output back to the input width so it
        # lines up with the residual branch.
        outputs, _ = pad_packed_sequence(packed_out, batch_first=True, total_length=x.size(1))
        return self.layer_norm(outputs + residual)

class WordLevelAttention(nn.Module):
    """Additive attention pooling over a sequence of hidden states."""

    def __init__(self, hidden_dim, attention_dim, dropout=0.3):
        super().__init__()
        self.attention = nn.Sequential(nn.Linear(hidden_dim, attention_dim), nn.Tanh(), nn.Dropout(dropout))
        self.context = nn.Linear(attention_dim, 1, bias=False)

    def forward(self, hidden_states, mask=None):
        """Pool `hidden_states` (batch, seq, hidden_dim) into one vector per item.

        Positions where `mask` equals 0 receive a score of -1e9 before the softmax.

        Returns:
            (pooled, weights): the attention-weighted sum (batch, hidden_dim)
            and the attention weights (batch, seq).
        """
        projected = self.attention(hidden_states)
        attn_scores = self.context(projected).squeeze(-1)
        if mask is not None:
            attn_scores = attn_scores.masked_fill(mask == 0, -1e9)
        attn_weights = F.softmax(attn_scores, dim=-1)
        pooled = torch.bmm(attn_weights.unsqueeze(1), hidden_states).squeeze(1)
        return pooled, attn_weights

class SentenceEncoder(nn.Module):
    """Encodes the words of a sentence and pools them with word-level attention."""

    def __init__(self, embedding_matrix, word_hidden_dim, attention_dim, num_layers=2, dropout=0.4):
        super().__init__()
        self.word_encoder = WordLevelEncoder(embedding_matrix, word_hidden_dim, num_layers, dropout)
        self.word_attention = WordLevelAttention(word_hidden_dim * 2, attention_dim, dropout)
        self.hidden_dim = word_hidden_dim * 2

    def forward(self, sentences, sent_lengths):
        """Encode one sentence slot for every document in the batch.

        Args:
            sentences: (batch, max_words) word ids.
            sent_lengths: (batch,) number of real words in each sentence.

        Returns:
            (sentence_vector, word_attention, word_hidden_states).
        """
        word_hidden = self.word_encoder(sentences, sent_lengths)
        # True for positions holding a real word, False for padding.
        positions = torch.arange(sentences.size(1), device=sentences.device)
        mask = positions.unsqueeze(0) < sent_lengths.unsqueeze(1)
        sent_repr, word_attn = self.word_attention(word_hidden, mask)
        return sent_repr, word_attn, word_hidden

class DocumentCrossAttention(nn.Module):
    """Multi-head cross-attention from word queries to sentence keys/values."""

    def __init__(self, hidden_dim, num_heads=8, dropout=0.4):
        """
        Args:
            hidden_dim: Feature size of queries, keys, values and the output.
            num_heads: Number of attention heads; must divide `hidden_dim`.
            dropout: Dropout probability applied to the attention weights.

        Raises:
            ValueError: If `hidden_dim` is not divisible by `num_heads`.
        """
        super().__init__()
        # Fail early: otherwise the head split in forward() dies with an
        # opaque reshape error on the first batch.
        if hidden_dim % num_heads != 0:
            raise ValueError(f"hidden_dim ({hidden_dim}) must be divisible by num_heads ({num_heads})")
        self.hidden_dim, self.num_heads, self.head_dim = hidden_dim, num_heads, hidden_dim // num_heads
        self.query_proj = nn.Linear(hidden_dim, hidden_dim)
        self.key_proj = nn.Linear(hidden_dim, hidden_dim)
        self.value_proj = nn.Linear(hidden_dim, hidden_dim)
        self.output_proj = nn.Linear(hidden_dim, hidden_dim)
        self.dropout = nn.Dropout(dropout)

    def _split_heads(self, tensor, batch_size, length):
        """Reshape (batch, length, hidden) to (batch, heads, length, head_dim)."""
        return tensor.view(batch_size, length, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, word_queries, sentence_kv, sent_mask=None):
        """Attend from every word query to the document's sentences.

        Args:
            word_queries: (batch, num_queries, hidden_dim).
            sentence_kv: (batch, num_sents, hidden_dim), used as keys and values.
            sent_mask: Optional (batch, num_sents); sentences where it equals 0
                receive a score of -1e9 before the softmax.

        Returns:
            (output, weights): output is (batch, hidden_dim), the projected
            result averaged over queries; weights is (batch, num_queries,
            num_sents), the attention weights averaged over heads. The weights
            are taken after dropout, so in training mode they may not sum to 1.
        """
        batch_size, num_queries, num_sents = word_queries.size(0), word_queries.size(1), sentence_kv.size(1)
        Q = self._split_heads(self.query_proj(word_queries), batch_size, num_queries)
        K = self._split_heads(self.key_proj(sentence_kv), batch_size, num_sents)
        V = self._split_heads(self.value_proj(sentence_kv), batch_size, num_sents)

        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / np.sqrt(self.head_dim)
        if sent_mask is not None:
            # Broadcast the sentence mask over heads and queries.
            attn_scores = attn_scores.masked_fill(sent_mask.unsqueeze(1).unsqueeze(2) == 0, -1e9)
        attn_weights = self.dropout(F.softmax(attn_scores, dim=-1))

        attended = torch.matmul(attn_weights, V).transpose(1, 2).contiguous().view(batch_size, num_queries, self.hidden_dim)
        return self.output_proj(attended).mean(dim=1), attn_weights.mean(dim=1)

In [None]:
class HierarchicalAttentionNetwork(nn.Module):
    """Document classifier combining sentence-level attention pooling with
    cross-attention from the top-attended words to the sentence encodings.

    The two document representations are mixed by a learned sigmoid gate
    before the classification head.
    """

    def __init__(self, embedding_matrix, word_hidden_dim, sent_hidden_dim, attention_dim, num_classes, num_heads=8, num_layers=2, dropout=0.4, top_k_words=15):
        """
        Args:
            embedding_matrix: Pretrained embedding weights passed to the sentence encoder.
            word_hidden_dim: Per-direction size of the word-level Bi-LSTM.
            sent_hidden_dim: Per-direction size of the sentence-level Bi-LSTM.
            attention_dim: Hidden size of the additive attention modules.
            num_classes: Number of output classes.
            num_heads: Heads used by the document cross-attention.
            num_layers: Number of layers in both LSTMs.
            dropout: Dropout probability shared by all sub-modules.
            top_k_words: Words kept per sentence as cross-attention queries.
        """
        super().__init__()
        self.top_k_words = top_k_words
        self.sentence_encoder = SentenceEncoder(embedding_matrix, word_hidden_dim, attention_dim, num_layers, dropout)
        # Bidirectional encoders double the feature size.
        word_repr_dim = word_hidden_dim * 2
        self.sent_lstm = nn.LSTM(word_repr_dim, sent_hidden_dim, num_layers=num_layers, bidirectional=True, batch_first=True, dropout=dropout if num_layers > 1 else 0)
        sent_repr_dim = sent_hidden_dim * 2
        self.sent_residual_proj = nn.Linear(word_repr_dim, sent_repr_dim)
        self.sent_layer_norm = nn.LayerNorm(sent_repr_dim)
        self.sent_attention = WordLevelAttention(sent_repr_dim, attention_dim, dropout)
        self.cross_attention = DocumentCrossAttention(sent_repr_dim, num_heads=num_heads, dropout=dropout)
        # Projects word features into the sentence feature space so they can act as queries.
        self.word_proj = nn.Linear(word_repr_dim, sent_repr_dim)
        self.gate = nn.Sequential(nn.Linear(sent_repr_dim * 2, sent_repr_dim), nn.Sigmoid())
        self.classifier = nn.Sequential(nn.Linear(sent_repr_dim, sent_repr_dim // 2), nn.ReLU(), nn.Dropout(dropout), nn.Linear(sent_repr_dim // 2, num_classes))
        self.dropout = nn.Dropout(dropout)

    def forward(self, docs, doc_lengths, sent_lengths, return_attention=False):
        """Classify a batch of documents.

        Args:
            docs: (batch, max_sents, max_words) word ids.
            doc_lengths: (batch,) number of real sentences per document.
            sent_lengths: (batch, max_sents) number of real words per sentence.
            return_attention: When True, also return a dict of attention tensors.

        Returns:
            `logits`, or `(logits, attention_dict)` when `return_attention` is
            True. The dict holds 'word_attention', 'sentence_attention',
            'cross_attention', 'top_word_indices' and 'gate_weights'.
        """
        batch_size, max_sents, max_words = docs.size()
        sent_reprs, all_word_attns, all_word_hiddens = [], [], []
        # Encode one sentence slot at a time across the whole batch.
        for i in range(max_sents):
            sent_repr, word_attn, word_hidden = self.sentence_encoder(docs[:, i, :], sent_lengths[:, i])
            sent_reprs.append(sent_repr); all_word_attns.append(word_attn); all_word_hiddens.append(word_hidden)
        # Stack along a new sentence axis (dim=1).
        sent_reprs = torch.stack(sent_reprs, dim=1)
        all_word_attns = torch.stack(all_word_attns, dim=1)
        all_word_hiddens = torch.stack(all_word_hiddens, dim=1)

        # Sentence-level Bi-LSTM with a projected residual, mirroring the word encoder.
        residual = self.sent_residual_proj(sent_reprs)
        packed_sents = pack_padded_sequence(sent_reprs, doc_lengths.cpu().clamp(min=1), batch_first=True, enforce_sorted=False)
        sent_outputs, _ = self.sent_lstm(packed_sents)
        sent_outputs, _ = pad_packed_sequence(sent_outputs, batch_first=True)
        # pad_packed_sequence may return fewer than max_sents steps; zero-pad
        # back so the residual addition lines up.
        if sent_outputs.size(1) < max_sents:
            sent_outputs = torch.cat([sent_outputs, torch.zeros(batch_size, max_sents-sent_outputs.size(1), sent_outputs.size(2), device=sent_outputs.device)], dim=1)
        sent_outputs = self.sent_layer_norm(sent_outputs + residual)

        # True for real sentences, False for padded sentence slots.
        sent_mask = torch.arange(max_sents, device=docs.device).unsqueeze(0) < doc_lengths.unsqueeze(1)
        sent_attn_output, sent_attn_weights = self.sent_attention(sent_outputs, sent_mask)

        # Word filtering: keep the k highest-attention words of every sentence
        # slot and flatten them into one query sequence of length max_sents*k.
        # NOTE(review): the top-k is taken over every sentence slot and word
        # position, so padded slots/positions also contribute queries and the
        # cross-attention averages over them — confirm this is intended.
        k = min(self.top_k_words, max_words)
        _, top_word_indices = torch.topk(all_word_attns, k, dim=-1)
        filtered_words = torch.gather(all_word_hiddens.view(batch_size*max_sents, max_words, -1), dim=1, index=top_word_indices.view(batch_size*max_sents, k).unsqueeze(-1).expand(-1, -1, all_word_hiddens.size(-1))).view(batch_size, max_sents*k, -1)
        filtered_words_proj = self.dropout(self.word_proj(filtered_words))
        cross_attn_output, cross_attn_weights = self.cross_attention(filtered_words_proj, sent_outputs, sent_mask)

        # Element-wise gate mixes the cross-attention and sentence-attention views.
        gate_weights = self.gate(torch.cat([sent_attn_output, cross_attn_output], dim=-1))
        doc_repr = gate_weights * cross_attn_output + (1 - gate_weights) * sent_attn_output
        logits = self.classifier(doc_repr)

        if return_attention:
            return logits, {'word_attention': all_word_attns, 'sentence_attention': sent_attn_weights, 'cross_attention': cross_attn_weights, 'top_word_indices': top_word_indices, 'gate_weights': gate_weights}
        return logits

In [None]:
# Model hyperparameters.
WORD_HIDDEN_DIM = 256
SENT_HIDDEN_DIM = 256
ATTENTION_DIM = 128
NUM_CLASSES = 20
NUM_HEADS = 8
NUM_LSTM_LAYERS = 2
DROPOUT = 0.4
TOP_K_WORDS = 15

model = HierarchicalAttentionNetwork(
    embedding_matrix=embedding_matrix,
    word_hidden_dim=WORD_HIDDEN_DIM,
    sent_hidden_dim=SENT_HIDDEN_DIM,
    attention_dim=ATTENTION_DIM,
    num_classes=NUM_CLASSES,
    num_heads=NUM_HEADS,
    num_layers=NUM_LSTM_LAYERS,
    dropout=DROPOUT,
    top_k_words=TOP_K_WORDS,
).to(device)

# Count every parameter tensor of the model.
total_params = sum(p.numel() for p in model.parameters())
print(f"Total parameters: {total_params:,}")
print(model)

## 8. Training

In [None]:
# Optimisation hyperparameters.
LEARNING_RATE, WEIGHT_DECAY, LABEL_SMOOTHING = 0.001, 1e-4, 0.1
NUM_EPOCHS, PATIENCE, GRAD_CLIP = 20, 5, 1.0

# 'balanced' class weights computed from the full training labels, fed to the loss.
train_labels_all = newsgroups_train.target
class_weights = compute_class_weight('balanced', classes=np.unique(train_labels_all), y=train_labels_all)
class_weights_tensor = torch.FloatTensor(class_weights).to(device)

criterion = nn.CrossEntropyLoss(weight=class_weights_tensor, label_smoothing=LABEL_SMOOTHING)
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
# Cosine decay of the learning rate over the full epoch budget.
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS, eta_min=1e-6)

def train_epoch(model, loader, criterion, optimizer, device, grad_clip):
    """Run one optimisation pass over `loader`.

    Gradients are clipped to a total norm of `grad_clip` before each step.

    Returns:
        (mean_batch_loss, accuracy): the loss averaged over batches and the
        fraction of correctly classified training samples.
    """
    model.train()
    running_loss, n_correct, n_seen = 0, 0, 0
    for docs, labels, doc_lengths, sent_lengths, _ in tqdm(loader, desc="Training"):
        docs, labels, doc_lengths, sent_lengths = (
            tensor.to(device) for tensor in (docs, labels, doc_lengths, sent_lengths)
        )

        optimizer.zero_grad()
        logits = model(docs, doc_lengths, sent_lengths)
        loss = criterion(logits, labels)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
        optimizer.step()

        running_loss += loss.item()
        predictions = logits.argmax(dim=1)
        n_correct += (predictions == labels).sum().item()
        n_seen += labels.size(0)
    return running_loss / len(loader), n_correct / n_seen

def evaluate(model, loader, criterion, device):
    """Evaluate `model` on `loader` without gradient tracking.

    Returns:
        (metrics, predictions, labels, indices): `metrics` is a dict with
        'loss' (mean over batches), 'accuracy', weighted 'precision',
        'recall' and 'f1', and 'macro_f1'; `predictions` and `labels` are
        numpy arrays; `indices` is the list of dataset indices in the order
        the samples were seen.
    """
    model.eval()
    loss_sum = 0
    all_preds, all_labels, all_indices = [], [], []
    with torch.no_grad():
        for docs, labels, doc_lengths, sent_lengths, indices in tqdm(loader, desc="Evaluating"):
            docs, labels = docs.to(device), labels.to(device)
            doc_lengths, sent_lengths = doc_lengths.to(device), sent_lengths.to(device)
            logits = model(docs, doc_lengths, sent_lengths)
            loss_sum += criterion(logits, labels).item()
            all_preds.extend(logits.argmax(dim=1).cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
            all_indices.extend(indices.cpu().numpy())

    all_preds, all_labels = np.array(all_preds), np.array(all_labels)
    weighted = dict(average='weighted', zero_division=0)
    metrics = {
        'loss': loss_sum/len(loader),
        'accuracy': accuracy_score(all_labels, all_preds),
        'precision': precision_score(all_labels, all_preds, **weighted),
        'recall': recall_score(all_labels, all_preds, **weighted),
        'f1': f1_score(all_labels, all_preds, **weighted),
        'macro_f1': f1_score(all_labels, all_preds, average='macro', zero_division=0),
    }
    return metrics, all_preds, all_labels, all_indices

In [None]:
# Per-epoch training curves; one entry is appended to every list each epoch.
history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': [], 'val_f1': [], 'lr': []}
# Start below any reachable F1 so the first epoch always writes a checkpoint;
# otherwise a run whose F1 stays at 0.0 would leave no (or a stale) best_model.pt.
best_val_f1, early_stop_counter = float('-inf'), 0

for epoch in range(NUM_EPOCHS):
    print(f"\nEpoch {epoch+1}/{NUM_EPOCHS}")
    # Read the learning rate before the scheduler advances, so the logged
    # value is the one this epoch actually trained with.
    epoch_lr = optimizer.param_groups[0]['lr']
    train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device, GRAD_CLIP)
    val_metrics, _, _, _ = evaluate(model, val_loader, criterion, device)
    scheduler.step()

    history['train_loss'].append(train_loss); history['train_acc'].append(train_acc)
    history['val_loss'].append(val_metrics['loss']); history['val_acc'].append(val_metrics['accuracy'])
    history['val_f1'].append(val_metrics['f1']); history['lr'].append(epoch_lr)

    print(f"Train Loss: {train_loss:.4f}, Acc: {train_acc:.4f} | Val Loss: {val_metrics['loss']:.4f}, Acc: {val_metrics['accuracy']:.4f}, F1: {val_metrics['f1']:.4f}")

    # Checkpoint on improvement of the weighted validation F1; stop after
    # PATIENCE consecutive epochs without improvement.
    if val_metrics['f1'] > best_val_f1:
        best_val_f1, early_stop_counter = val_metrics['f1'], 0
        torch.save(model.state_dict(), 'best_model.pt')
        print("✓ Saved best model!")
    else:
        early_stop_counter += 1
        if early_stop_counter >= PATIENCE:
            print(f"Early stopping at epoch {epoch+1}")
            break

print(f"\nBest Validation F1: {best_val_f1:.4f}")

In [None]:
# 2x2 grid: loss, accuracy, validation F1 and learning rate per epoch.
fig, axes = plt.subplots(2, 2, figsize=(14, 10))
(loss_ax, acc_ax), (f1_ax, lr_ax) = axes

loss_ax.plot(history['train_loss'], label='Train')
loss_ax.plot(history['val_loss'], label='Val')
loss_ax.set_title('Loss')
loss_ax.legend()

acc_ax.plot(history['train_acc'], label='Train')
acc_ax.plot(history['val_acc'], label='Val')
acc_ax.set_title('Accuracy')
acc_ax.legend()

f1_ax.plot(history['val_f1'])
f1_ax.set_title('Validation F1')

lr_ax.plot(history['lr'])
lr_ax.set_title('Learning Rate')

plt.tight_layout()
plt.savefig('training_history.png', dpi=150)
plt.show()
# Keep the raw curves next to the figure.
pd.DataFrame(history).to_csv('training_history.csv', index=False)

## 9. Baseline Models

In [None]:
# TF-IDF features fitted on the training split only.
tfidf = TfidfVectorizer(max_features=10000, stop_words='english')
X_train_tfidf = tfidf.fit_transform(newsgroups_train.data)
X_test_tfidf = tfidf.transform(newsgroups_test.data)
y_train, y_test = newsgroups_train.target, newsgroups_test.target

# Classical baselines, each scored on the test split.
baseline_models = [
    ('Logistic Regression', LogisticRegression(max_iter=1000, class_weight='balanced')),
    ('Naive Bayes', MultinomialNB()),
    ('Linear SVM', LinearSVC(max_iter=2000, class_weight='balanced')),
]

baselines = {}
for name, clf in baseline_models:
    print(f"Training {name}...")
    clf.fit(X_train_tfidf, y_train)
    preds = clf.predict(X_test_tfidf)
    scores = {
        'accuracy': accuracy_score(y_test, preds),
        'precision': precision_score(y_test, preds, average='weighted'),
        'recall': recall_score(y_test, preds, average='weighted'),
        'f1': f1_score(y_test, preds, average='weighted'),
        'macro_f1': f1_score(y_test, preds, average='macro'),
    }
    baselines[name] = scores
    print(f"  Accuracy: {baselines[name]['accuracy']:.4f}, F1: {baselines[name]['f1']:.4f}")

## 10. Test Evaluation & Comparison

In [None]:
# Restore the best checkpoint. map_location keeps the load working when the
# checkpoint was written on a GPU and the notebook is now running on CPU.
model.load_state_dict(torch.load('best_model.pt', map_location=device))
test_metrics, test_preds, test_labels, test_indices = evaluate(model, test_loader, criterion, device)
# Register the network next to the classical baselines for the comparison table.
baselines['HAN (Ours)'] = {k: test_metrics[k] for k in ['accuracy', 'precision', 'recall', 'f1', 'macro_f1']}

print("\nTEST RESULTS\n" + "="*60)
print(f"Accuracy: {test_metrics['accuracy']:.4f}, Precision: {test_metrics['precision']:.4f}")
print(f"Recall: {test_metrics['recall']:.4f}, F1: {test_metrics['f1']:.4f}, Macro F1: {test_metrics['macro_f1']:.4f}")

# One row per model; metrics are formatted as 4-decimal strings for the report.
comparison_df = pd.DataFrame([{'Model': k, **{m.capitalize(): f"{v[m]:.4f}" for m in ['accuracy','precision','recall','f1']}} for k,v in baselines.items()])
print("\nMODEL COMPARISON\n", comparison_df.to_string(index=False))
comparison_df.to_csv('model_comparison.csv', index=False)

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
grouped_ax, f1_ax = axes
models = list(baselines)
x = np.arange(len(models))

# Left panel: four metrics per model as grouped bars.
metric_colors = {'accuracy': '#2ecc71', 'precision': '#3498db', 'recall': '#9b59b6', 'f1': '#e74c3c'}
for i, (metric, color) in enumerate(metric_colors.items()):
    heights = [baselines[m][metric] for m in models]
    grouped_ax.bar(x + i*0.2, heights, 0.2, label=metric.capitalize(), color=color)
# 0.3 centres the tick under the four 0.2-wide bars.
grouped_ax.set_xticks(x + 0.3)
grouped_ax.set_xticklabels(models, rotation=15, ha='right')
grouped_ax.legend()
grouped_ax.set_title('Model Comparison')

# Right panel: weighted F1 only, with the network highlighted.
f1_values = [baselines[m]['f1'] for m in models]
colors = ['darkgreen' if m == 'HAN (Ours)' else 'steelblue' for m in models]
bars = f1_ax.bar(models, f1_values, color=colors)
for bar, val in zip(bars, f1_values):
    f1_ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.01, f'{val:.3f}', ha='center')
f1_ax.set_title('F1 Score Comparison')

plt.tight_layout()
plt.savefig('model_comparison.png', dpi=150)
plt.show()

## 11. Error Analysis (Deliverable iv)

In [None]:
# Pin the label set so the matrix is always len(TARGET_NAMES) square and its
# rows/columns line up with the tick labels, even if a class never occurs in
# the true or predicted labels.
cm = confusion_matrix(test_labels, test_preds, labels=np.arange(len(TARGET_NAMES)))
plt.figure(figsize=(16, 14))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=TARGET_NAMES, yticklabels=TARGET_NAMES)
plt.xlabel('Predicted'); plt.ylabel('True'); plt.title('Confusion Matrix')
plt.xticks(rotation=45, ha='right')
plt.tight_layout()
plt.savefig('confusion_matrix.png', dpi=150)
plt.show()

In [None]:
print("\nClassification Report:")
print(classification_report(test_labels, test_preds, target_names=TARGET_NAMES))

# Same report as a dict, flattened to one row per class and sorted so the
# weakest classes come first.
report_dict = classification_report(test_labels, test_preds, target_names=TARGET_NAMES, output_dict=True)
per_class_rows = [
    {'ID': i, 'Category': name, 'Precision': stats['precision'],
     'Recall': stats['recall'], 'F1-Score': stats['f1-score'],
     'Support': stats['support']}
    for i, (name, stats) in enumerate((n, report_dict[n]) for n in TARGET_NAMES)
]
per_class_df = pd.DataFrame(per_class_rows).sort_values('F1-Score')
per_class_df.to_csv('per_class_metrics.csv', index=False)
print("\nPer-Class Metrics (sorted by F1):")
print(per_class_df.to_string(index=False))

In [None]:
# Horizontal bars, one per class, coloured from red (low F1) to green (high F1).
f1_scores = per_class_df['F1-Score']
colors = plt.cm.RdYlGn(f1_scores.values)

fig, ax = plt.subplots(figsize=(12, 8))
bars = ax.barh(per_class_df['Category'], f1_scores, color=colors)
# Dashed reference line at the overall weighted F1.
ax.axvline(x=test_metrics['f1'], color='red', linestyle='--', label=f"Overall F1: {test_metrics['f1']:.3f}")
ax.set_xlabel('F1 Score')
ax.set_title('Per-Class F1 Scores')
ax.legend()
for bar, val in zip(bars, f1_scores):
    y_center = bar.get_y() + bar.get_height()/2
    ax.text(val + 0.01, y_center, f'{val:.3f}', va='center')

plt.tight_layout()
plt.savefig('per_class_f1.png', dpi=150)
plt.show()

In [None]:
# Row-normalise so each entry is the share of a true class sent to a predicted class.
row_totals = cm.sum(axis=1, keepdims=True)
cm_norm = cm.astype('float') / row_totals

# Every off-diagonal cell with at least one error, in row-major order.
confused_pairs = [
    {'True': TARGET_NAMES[i], 'Predicted': TARGET_NAMES[j], 'Count': cm[i, j], 'Pct': cm_norm[i, j]*100}
    for i in range(20)
    for j in range(20)
    if i != j and cm[i, j] > 0
]
confused_df = pd.DataFrame(confused_pairs).sort_values('Count', ascending=False)
print("\nTop 15 Misclassifications:")
print(confused_df.head(15).to_string(index=False))
confused_df.to_csv('confusion_analysis.csv', index=False)

## 12. Attention Analysis (Deliverable iii & v)

In [None]:
# Per-document attention statistics on the test set.
attention_stats = {'word_entropy': [], 'word_max': [], 'sent_entropy': [], 'sent_max': [], 'gate_mean': [], 'correct': [], 'true_label': [], 'pred_label': []}
model.eval()
with torch.no_grad():
    for docs, labels, doc_lengths, sent_lengths, _ in tqdm(test_loader, desc="Collecting attention"):
        # Real sentence count per document, read before the tensor moves to the device.
        n_valid_sents = doc_lengths.tolist()
        docs, doc_lengths, sent_lengths = docs.to(device), doc_lengths.to(device), sent_lengths.to(device)
        logits, attn = model(docs, doc_lengths, sent_lengths, return_attention=True)
        preds = logits.argmax(dim=1)
        # Move predictions/labels to Python lists once instead of calling .item() per sample.
        pred_list, label_list = preds.cpu().tolist(), labels.tolist()
        word_attn, sent_attn, gate = attn['word_attention'].cpu().numpy(), attn['sentence_attention'].cpu().numpy(), attn['gate_weights'].cpu().numpy()
        for i in range(docs.size(0)):
            # Only the document's real sentences: padded sentence slots are
            # collated with length 1, so their attention is 1.0 on a single
            # PAD token, which would force word_max to 1.0 and pull the mean
            # entropy towards 0 by an amount that depends on batch padding.
            doc_word_attn = word_attn[i, :n_valid_sents[i]]
            attention_stats['word_entropy'].append(np.mean(-np.sum(doc_word_attn * np.log(doc_word_attn + 1e-10), axis=-1)))
            attention_stats['word_max'].append(np.max(doc_word_attn))
            # Padded sentences already carry zero sentence attention, so they add nothing here.
            attention_stats['sent_entropy'].append(-np.sum(sent_attn[i] * np.log(sent_attn[i] + 1e-10)))
            attention_stats['sent_max'].append(np.max(sent_attn[i]))
            attention_stats['gate_mean'].append(np.mean(gate[i]))
            attention_stats['correct'].append(int(pred_list[i] == label_list[i]))
            attention_stats['true_label'].append(label_list[i])
            attention_stats['pred_label'].append(pred_list[i])

attention_df = pd.DataFrame(attention_stats)
attention_df.to_csv('attention_statistics.csv', index=False)

In [None]:
# Split the per-document statistics by prediction outcome.
is_correct = attention_df['correct'] == 1
is_incorrect = attention_df['correct'] == 0
correct_df = attention_df.loc[is_correct]
incorrect_df = attention_df.loc[is_incorrect]

print("\nATTENTION INFLUENCE ANALYSIS (Deliverable v)\n" + "="*60)
print(f"Correct: {len(correct_df)}, Incorrect: {len(incorrect_df)}")
# Mean ± standard deviation of each statistic for both groups.
for metric in ('word_entropy', 'word_max', 'sent_entropy', 'gate_mean'):
    print(f"\n{metric}:")
    print(f"  Correct: {correct_df[metric].mean():.4f} ± {correct_df[metric].std():.4f}")
    print(f"  Incorrect: {incorrect_df[metric].mean():.4f} ± {incorrect_df[metric].std():.4f}")

In [None]:
# Overlaid histograms (correct vs incorrect predictions), one panel per statistic.
panel_titles = {
    'word_entropy': 'Word Attention Entropy',
    'sent_entropy': 'Sentence Attention Entropy',
    'gate_mean': 'Gate Weight',
    'word_max': 'Max Word Attention',
}
fig, axes = plt.subplots(2, 2, figsize=(14, 10))
for ax, (metric, title) in zip(axes.flat, panel_titles.items()):
    ax.hist(correct_df[metric], bins=30, alpha=0.7, label='Correct', color='green')
    ax.hist(incorrect_df[metric], bins=30, alpha=0.7, label='Incorrect', color='red')
    ax.set_xlabel(title)
    ax.set_ylabel('Frequency')
    ax.legend()
plt.tight_layout()
plt.savefig('attention_statistics.png', dpi=150)
plt.show()

In [None]:
def get_attention_for_sample(model, doc, label, preprocessor, device):
    """Run the model on one raw document and collect its attention maps.

    Args:
        model: Called as ``model(docs, doc_lengths, sent_lengths,
            return_attention=True)``; must return ``(logits, attn)`` where
            ``attn`` has ``'word_attention'`` and ``'sentence_attention'``
            tensors with a leading batch dimension.
        doc: Raw document handed to ``preprocessor``.
        label: Ground-truth class index for ``doc``.
        preprocessor: Object providing ``encode_document(doc)`` (a sequence of
            per-sentence token-id lists) and ``tokenize_document(doc)``.
        device: Device the input tensors are moved to.

    Returns:
        dict with keys ``'tokens'`` (output of ``tokenize_document``),
        ``'word_attention'`` and ``'sentence_attention'`` (NumPy arrays for
        this single document), ``'prediction'`` (argmax class index),
        ``'true_label'`` (``label`` unchanged) and ``'correct'``
        (``prediction == label``).

    Raises:
        ValueError: If the document encodes to zero sentences, since there is
            nothing to pad or feed to the model.
    """
    model.eval()
    encoded_doc = preprocessor.encode_document(doc)
    if len(encoded_doc) == 0:
        raise ValueError("document encodes to zero sentences; cannot compute attention")

    doc_length = len(encoded_doc)
    sent_lengths = [len(sent) for sent in encoded_doc]
    max_sent_len = max(sent_lengths)
    # Right-pad every sentence with id 0 so the document forms a rectangular
    # tensor (0 is presumably the padding index — confirm against the vocab).
    padded_doc = [sent + [0] * (max_sent_len - len(sent)) for sent in encoded_doc]

    # Wrap everything in an outer list to create a batch of size one.
    docs_t = torch.LongTensor([padded_doc]).to(device)
    doc_lengths_t = torch.LongTensor([doc_length]).to(device)
    sent_lengths_t = torch.LongTensor([sent_lengths]).to(device)

    with torch.no_grad():
        logits, attn = model(docs_t, doc_lengths_t, sent_lengths_t, return_attention=True)

    # Compute the predicted class once and reuse it for both output fields.
    prediction = logits.argmax(dim=1).item()
    return {'tokens': preprocessor.tokenize_document(doc),
            'word_attention': attn['word_attention'][0].cpu().numpy(),
            'sentence_attention': attn['sentence_attention'][0].cpu().numpy(),
            'prediction': prediction,
            'true_label': label,
            'correct': prediction == label}

# Qualitative check: for one correctly and one incorrectly classified test
# document, print the three highest-weighted words of its first three sentences.
sample_pools = [('CORRECT', np.where(test_preds == test_labels)[0]),
                ('INCORRECT', np.where(test_preds != test_labels)[0])]
for sample_type, candidates in sample_pools:
    # A perfect (or completely wrong) model leaves one pool empty; indexing
    # [0] on it would raise IndexError, so report and move on instead.
    if candidates.size == 0:
        print(f"\n{sample_type}: no such sample in the test set")
        continue
    idx = candidates[0]
    attn_data = get_attention_for_sample(model, newsgroups_test.data[idx], newsgroups_test.target[idx], preprocessor, device)
    print(f"\n{sample_type}: True={TARGET_NAMES[attn_data['true_label']]}, Pred={TARGET_NAMES[attn_data['prediction']]}")
    for si, (toks, wa) in enumerate(zip(attn_data['tokens'][:3], attn_data['word_attention'][:3])):
        if not toks:
            continue
        # Restrict to the real tokens before ranking so padded positions can
        # never be selected; every resulting index is therefore < len(toks).
        top_idx = np.argsort(wa[:len(toks)])[-3:][::-1]
        top_words = ', '.join(f'{toks[i]}({wa[i]:.3f})' for i in top_idx)
        print(f"  Sent {si}: {top_words}")

## 13. Failure Mode Analysis (Deliverable vii)

In [None]:
# Deliverable (vii): re-run the model on at most the first 100 misclassified
# test documents and record one row of attention summaries per document.
incorrect_idx = np.where(test_preds != test_labels)[0]
failure_columns = ['idx', 'true_class', 'pred_class', 'word_max_mean', 'word_entropy',
                   'sent_max', 'sent_entropy', 'num_sents']
failure_data = []
for idx in tqdm(incorrect_idx[:100], desc="Analyzing failures"):
    attn_data = get_attention_for_sample(model, newsgroups_test.data[idx], newsgroups_test.target[idx], preprocessor, device)
    wa, sa = attn_data['word_attention'], attn_data['sentence_attention']
    # NOTE(review): `wa` may include padded word positions; confirm the model
    # assigns them zero weight, otherwise they dilute the statistics below.
    failure_data.append({
        'idx': idx,
        'true_class': TARGET_NAMES[attn_data['true_label']],
        'pred_class': TARGET_NAMES[attn_data['prediction']],
        # Peak weight along the last axis, averaged over the remaining rows.
        'word_max_mean': np.mean(np.max(wa, axis=-1)),
        # Shannon entropy along the last axis (1e-10 guards log(0)), averaged.
        'word_entropy': np.mean(-np.sum(wa * np.log(wa + 1e-10), axis=-1)),
        'sent_max': np.max(sa),
        'sent_entropy': -np.sum(sa * np.log(sa + 1e-10)),
        'num_sents': len(attn_data['tokens']),
    })
# Naming the columns explicitly keeps the frame (and the CSV header) well
# formed even when there are no misclassified samples; without it an empty
# list yields a column-less frame and later column lookups raise KeyError.
failure_df = pd.DataFrame(failure_data, columns=failure_columns)
failure_df.to_csv('failure_analysis.csv', index=False)

In [None]:
print("\nFAILURE MODES IDENTIFIED\n" + "="*60)

# Thresholds for the two entropy-based modes are the medians of the failure
# set itself.
# NOTE(review): because the cut-off is the set's own median, modes 2 and 4
# flag at most about half of the failures by construction; they describe the
# upper half of the distribution rather than an independent criterion.
sent_entropy_median = failure_df['sent_entropy'].median()
word_entropy_median = failure_df['word_entropy'].median()

# Each entry: (heading, boolean mask over failure_df rows, explanation).
modes = [
    ("1. WORD OVER-CONCENTRATION",
     failure_df['word_max_mean'] > 0.5,
     "Attention focuses on single words, losing context"),
    ("2. DIFFUSE SENTENCE ATTENTION",
     failure_df['sent_entropy'] > sent_entropy_median,
     "Cross-attention spreads too evenly"),
    ("3. SHORT DOCUMENTS",
     failure_df['num_sents'] < 3,
     "Insufficient hierarchical context"),
    ("4. HIGH WORD ENTROPY",
     failure_df['word_entropy'] > word_entropy_median,
     "Word attention too uncertain"),
]

n_failures = len(failure_df)
for name, mask, desc in modes:
    count = mask.sum()
    print(f"\n{name}: {count} cases ({count/n_failures*100:.1f}%)")
    print(f"   {desc}")

In [None]:
# One histogram per failure-mode statistic, each with a dashed vertical line
# at the threshold used to flag that mode.
fig, axes = plt.subplots(2, 2, figsize=(14, 10))
panel_specs = [
    ('word_max_mean', 'Word Over-Concentration', 0.5),
    ('sent_entropy', 'Sentence Entropy', failure_df['sent_entropy'].median()),
    ('num_sents', 'Document Length', 3),
    ('word_entropy', 'Word Entropy', failure_df['word_entropy'].median()),
]
for ax, (col, title, thresh) in zip(axes.flat, panel_specs):
    ax.hist(failure_df[col], bins=20, color='red', alpha=0.7, edgecolor='black')
    ax.axvline(x=thresh, color='black', linestyle='--', label=f'Threshold: {thresh:.2f}')
    ax.set_xlabel(title)
    ax.set_ylabel('Frequency')
    ax.legend()
plt.suptitle('Failure Mode Analysis', fontsize=14)
plt.tight_layout()
plt.savefig('failure_modes.png', dpi=150)
plt.show()

In [None]:
# Hand-written narrative for deliverable (vii). The text is a fixed literal:
# nothing in it is interpolated from `failure_df`.
# NOTE(review): modes 3 and 4 described here (information loss, representation
# collapse) are not the modes 3 and 4 counted in code (short documents, high
# word entropy) — confirm which set the report should present.
failure_report = """
FAILURE MODE ANALYSIS REPORT
============================

FAILURE MODE 1: Word-Level Over-Concentration
- Problem: Attention focuses too heavily on single words
- Impact: Important contextual words filtered out before cross-attention
- Fix: Temperature scaling, attention smoothing, increase top_k

FAILURE MODE 2: Diffuse Sentence Attention
- Problem: Cross-attention spreads weight too evenly
- Impact: Model fails to identify relevant sentences
- Fix: Sentence positional encoding, sparse attention

FAILURE MODE 3: Information Loss at Word Filtering
- Problem: Important words filtered at word-level stage
- Impact: Cannot contribute to cross-attention queries
- Fix: Soft filtering, residual connections, auxiliary supervision

FAILURE MODE 4: Encoder Representation Collapse
- Problem: Similar representations for different words
- Impact: Attention selection becomes arbitrary
- Fix: Pre-trained BERT, contrastive loss, orthogonality regularization
"""

# Echo the report in the notebook, then persist the same text to disk.
print(failure_report)
with open('failure_mode_report.txt', mode='w') as report_file:
    report_file.write(failure_report)

## 14. Save Summary

In [None]:
# Single-record summary of the run: data sizes, hyper-parameters and headline
# metrics. Key order is preserved because it defines the CSV column order.
summary = {
    'experiment': 'HAN v3',
    'train_samples': len(newsgroups_train.data),
    'test_samples': len(newsgroups_test.data),
    'vocab_size': preprocessor.vocab_size,
    'embed_dim': EMBED_DIM,
    'word_hidden_dim': WORD_HIDDEN_DIM,
    'sent_hidden_dim': SENT_HIDDEN_DIM,
    'num_heads': NUM_HEADS,
    'dropout': DROPOUT,
    'top_k_words': TOP_K_WORDS,
    'epochs_trained': len(history['train_loss']),
    'best_val_f1': best_val_f1,
    'test_accuracy': test_metrics['accuracy'],
    'test_f1': test_metrics['f1'],
    'test_macro_f1': test_metrics['macro_f1'],
    'total_params': total_params,
}

# Persist the same record twice: as a one-row CSV and as pretty-printed JSON.
pd.DataFrame([summary]).to_csv('experiment_summary.csv', index=False)
with open('experiment_summary.json', 'w') as f:
    json.dump(summary, f, indent=2)

print("\nEXPERIMENT SUMMARY\n" + "="*60)
for k, v in summary.items():
    print(f"  {k}: {v}")

# Fixed list of artefact names; it is printed as-is and not checked against
# what actually exists on disk.
saved_files = [
    'best_model.pt', 'class_distribution.csv', 'class_distribution.png',
    'training_history.csv', 'training_history.png',
    'model_comparison.csv', 'model_comparison.png', 'confusion_matrix.png',
    'per_class_metrics.csv', 'per_class_f1.png',
    'confusion_analysis.csv', 'attention_statistics.csv', 'attention_statistics.png',
    'failure_analysis.csv',
    'failure_modes.png', 'failure_mode_report.txt',
    'experiment_summary.csv', 'experiment_summary.json',
]
print("\n\nSAVED FILES:")
for f in saved_files:
    print(f"  - {f}")

In [None]:
# Optional: Save to Google Drive
# Only a failed import means "not running in Colab". Any other problem
# (mount refused, Drive full, copy error) is a real failure and is reported
# as such instead of being hidden behind a bare `except:`.
try:
    from google.colab import drive
except ImportError:
    print("Not in Colab - files saved locally")
else:
    import shutil
    dest = '/content/drive/MyDrive/20_Newsgroups_HAN_v3'
    try:
        drive.mount('/content/drive')
        os.makedirs(dest, exist_ok=True)
        # Copy every artefact in the working directory, selected by extension.
        for fname in os.listdir('.'):
            if fname.endswith(('.pt', '.csv', '.png', '.json', '.txt')):
                shutil.copy(fname, dest)
    except Exception as err:
        # Best-effort export: keep the notebook running, but say what went
        # wrong so a failed backup is not mistaken for a successful one.
        print(f"Could not save files to Google Drive ({type(err).__name__}: {err}) - files saved locally")
    else:
        print(f"Files saved to {dest}")