In [None]:
!pip install torch torchvision transformers datasets matplotlib pandas scikit-learn wordcloud
!pip install --upgrade --force-reinstall fsspec datasets

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from transformers import BertTokenizer, BertForSequenceClassification
from torch.optim import AdamW
from datasets import load_dataset
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
from sklearn.metrics import accuracy_score, confusion_matrix, ConfusionMatrixDisplay, f1_score
from wordcloud import WordCloud
from sklearn.model_selection import train_test_split

# Apply matplotlib's 'fivethirtyeight' style sheet globally; every figure drawn
# in the cells below inherits it.
plt.style.use('fivethirtyeight')

In [None]:
# --- LOAD YAHOO ANSWERS TOPICS DATASET ---
topics_data = load_dataset("yahoo_answers_topics")
# Convert each split with Dataset.to_pandas() instead of pd.DataFrame(split):
# the latter first materializes every row as a Python dict, which is slow and
# memory-hungry for a corpus of this size.
answers_df = pd.concat(
    [topics_data[split].to_pandas() for split in ('train', 'test')],
    ignore_index=True,
)
# Shuffle the pooled rows with a fixed seed so reruns see the same order, then
# renumber the index from zero.
answers_df = answers_df.sample(frac=1, random_state=123).reset_index(drop=True)

In [None]:
# --- SUBSAMPLE FOR DEMO ---
# Draw a 6,000-row working set and a disjoint 1,000-row evaluation set, both
# stratified on the topic id; the fixed random_state keeps the draw repeatable.
split_options = dict(
    train_size=6000,
    test_size=1000,
    stratify=answers_df['topic'],
    random_state=123,
)
df_main, df_eval = train_test_split(answers_df, **split_options)

In [None]:
# Display names for the ten topics. The list position is used as the lookup
# key for the integer values in the `topic` column and for predicted class ids.
topic_names = [
    "Society & Culture",
    "Science & Mathematics",
    "Health",
    "Education & Reference",
    "Computers & Internet",
    "Sports",
    "Business & Finance",
    "Entertainment & Music",
    "Family & Relationships",
    "Politics & Government",
]

In [None]:
# --- LABEL DISTRIBUTION ---
# One bar per topic: sample counts in df_main, ordered by topic id so that the
# bars line up with topic_names.
bar_colors = [
    '#5f0a87', '#a4508b', '#f7971e', '#ffd452', '#e84a5f',
    '#2a363b', '#ffb300', '#457fca', '#6a9113', '#373b44',
]
counts_by_topic = df_main['topic'].value_counts().sort_index()

fig, ax = plt.subplots(figsize=(10, 3.2))
ax.bar(topic_names, counts_by_topic, color=bar_colors)
ax.set_title('Yahoo Answers Topic Distribution', fontsize=13, color='#5f0a87')
ax.set_ylabel('Number of Samples')
# Tilt the long topic names so neighbouring labels do not overlap.
ax.tick_params(axis='x', labelrotation=32, labelsize=10)
fig.tight_layout()
plt.show()

In [None]:
# --- TEXT LENGTH HIST ---
# Whitespace-token count of each question body. The column is attached with
# .assign (rebinding df_main to a new frame) rather than written in place:
# df_main comes out of train_test_split, and in-place column assignment on such
# a frame can raise pandas' SettingWithCopyWarning. Re-running the cell simply
# recomputes q_len.
df_main = df_main.assign(
    q_len=df_main['question_content'].map(lambda text: len(str(text).split()))
)

fig, ax = plt.subplots(figsize=(6, 3.2))
ax.hist(df_main['q_len'], bins=40, color='#a4508b', edgecolor='#fff', alpha=0.88)
ax.set_title('Question Lengths (Words)', fontsize=11, color='#ffb300')
ax.set_xlabel('Words')
ax.set_ylabel('Frequency')
fig.tight_layout()
plt.show()

In [None]:
# --- WORD CLOUDS FOR THE FIRST 3 CLASSES ---
# Topic ids 0, 1 and 2 are simply the first three entries of topic_names; they
# are not selected by frequency.
for idx in [0, 1, 2]:
    topic_questions = df_main.loc[df_main['topic'] == idx, 'question_content']
    # Drop missing entries and coerce the rest to str so that a stray non-string
    # value cannot make " ".join raise a TypeError.
    q_text = " ".join(topic_questions.dropna().astype(str))
    wc = WordCloud(width=700, height=250, background_color='white', colormap='Spectral').generate(q_text)

    fig, ax = plt.subplots(figsize=(7, 2.7))
    ax.imshow(wc, interpolation='bilinear')
    ax.axis('off')
    ax.set_title(f"Word Cloud: {topic_names[idx]}", fontsize=12, color='#a4508b')
    fig.tight_layout()
    plt.show()

In [None]:
# --- BERT TOKENIZER, DATASET, AND DATALOADER ---
# Fixed sequence length handed to the tokenizer as max_length: every question
# is truncated or padded to exactly this many token positions.
maxlen = 48
# Tokenizer matching the 'bert-base-uncased' checkpoint that is fine-tuned below.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

class YahooTopicDataset(Dataset):
    """Map-style dataset that tokenizes one question per lookup.

    Args:
        question_list: Indexable collection of question texts; each entry is
            passed through ``str()`` before tokenization.
        label_list: Indexable collection of integer class ids, parallel to
            ``question_list``.
        tokenizer: Callable accepting ``(text, truncation=..., padding=...,
            max_length=..., return_tensors=...)`` and returning a mapping of
            tensors that carry a leading batch axis.
        maxlen: Length every encoded sequence is truncated/padded to.
    """

    def __init__(self, question_list, label_list, tokenizer, maxlen):
        self.question_list = question_list
        self.label_list = label_list
        self.tokenizer = tokenizer
        self.maxlen = maxlen

    def __len__(self):
        """Return the number of questions held by the dataset."""
        return len(self.question_list)

    def __getitem__(self, idx):
        """Tokenize the question at ``idx`` and return its tensors.

        Returns:
            dict: every field produced by the tokenizer with the leading batch
            axis removed, plus ``'labels'`` holding the class id as a
            ``torch.long`` scalar tensor.
        """
        text = str(self.question_list[idx])
        encoded = self.tokenizer(
            text,
            truncation=True,
            padding='max_length',
            max_length=self.maxlen,
            return_tensors='pt',
        )
        sample = {}
        for field, tensor in encoded.items():
            # The tokenizer output has a batch axis of size one; remove it so
            # the DataLoader can stack samples into a batch itself.
            sample[field] = tensor.squeeze(0)
        sample['labels'] = torch.tensor(self.label_list[idx], dtype=torch.long)
        return sample

# Wrap the question text and topic id columns of each subset in a dataset.
train_data = YahooTopicDataset(
    question_list=df_main['question_content'].tolist(),
    label_list=df_main['topic'].tolist(),
    tokenizer=tokenizer,
    maxlen=maxlen,
)
test_data = YahooTopicDataset(
    question_list=df_eval['question_content'].tolist(),
    label_list=df_eval['topic'].tolist(),
    tokenizer=tokenizer,
    maxlen=maxlen,
)

# Training batches are reshuffled on every pass; the evaluation loader keeps
# its default ordering.
train_loader = DataLoader(dataset=train_data, batch_size=22, shuffle=True)
test_loader = DataLoader(dataset=test_data, batch_size=22)


In [None]:
# --- MODEL SETUP ---
# Seed torch before any random work so that the freshly initialized
# classification head and the shuffle order of train_loader repeat across
# Restart & Run All. 123 matches the random_state used for the data splits.
torch.manual_seed(123)

# Prefer the GPU when one is visible; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Size the output layer from topic_names so the label list and the model head
# cannot drift apart.
topic_model = BertForSequenceClassification.from_pretrained(
    'bert-base-uncased', num_labels=len(topic_names)
)
topic_model = topic_model.to(device)
optimizer = AdamW(topic_model.parameters(), lr=2e-5)

In [None]:
# --- TRAINING LOOP ---
epochs = 20
loss_vals = []  # mean training loss of each epoch, in order
topic_model.train()
for epoch in range(epochs):
    running_loss = 0
    for batch in train_loader:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        targets = batch['labels'].to(device)

        optimizer.zero_grad()
        # Labels are passed in so the forward pass returns the loss directly.
        step_loss = topic_model(
            input_ids=input_ids, attention_mask=attention_mask, labels=targets
        ).loss
        step_loss.backward()
        optimizer.step()
        running_loss += step_loss.item()

    # Average over the number of batches seen in this epoch.
    avg = running_loss / len(train_loader)
    print(f"Epoch {epoch+1} | Loss: {avg:.4f}")
    loss_vals.append(avg)

# Loss curve: one marker per epoch, with a 1-based epoch axis.
fig, ax = plt.subplots(figsize=(6, 3.5))
ax.plot(range(1, epochs + 1), loss_vals, marker='^', color='#ffb300', linewidth=2)
ax.set_title('Training Loss (Yahoo Answers)', fontsize=12, color='#a4508b')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
fig.tight_layout()
plt.show()

In [None]:
# --- EVALUATE ---
topic_model.eval()
preds, trues = [], []  # predicted / true class ids, in test_loader order
with torch.no_grad():
    for batch in test_loader:
        ids = batch['input_ids'].to(device)
        mask = batch['attention_mask'].to(device)
        outs = topic_model(input_ids=ids, attention_mask=mask)
        # Highest-scoring class per row; .tolist() yields plain Python ints.
        preds.extend(torch.argmax(outs.logits, dim=1).cpu().tolist())
        # Labels are only compared on the host, so they never need to visit
        # the device.
        trues.extend(batch['labels'].tolist())
acc = accuracy_score(trues, preds)
f1 = f1_score(trues, preds, average='weighted')
print(f"\nTest Accuracy: {acc:.4f}")
print(f"Test F1 Score: {f1:.4f}")

# Pin the label set explicitly: without it the matrix only covers classes that
# actually occur in trues/preds and could end up smaller than the number of
# display labels, which makes ConfusionMatrixDisplay fail.
cm = confusion_matrix(trues, preds, labels=list(range(len(topic_names))))
fig, ax = plt.subplots(figsize=(9, 8))
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=topic_names)

# Custom colormap (any matplotlib cmap, or try 'PuRd', 'plasma', 'Blues', etc.)
disp.plot(cmap='PuRd', ax=ax, colorbar=True, values_format='d')

# Make tick labels readable
plt.setp(ax.get_xticklabels(), rotation=45, ha='right', fontsize=12)
ax.tick_params(axis='y', labelsize=12)

# Turn the grid off (the global style sheet would otherwise draw grid lines
# across the cells), then add a custom title and tighten the layout.
ax.grid(False)
ax.set_title('Yahoo Answers Topic Confusion Matrix', fontsize=16, color='#5f0a87', pad=20)
ax.set_xlabel('Predicted label', fontsize=14)
ax.set_ylabel('True label', fontsize=14)
fig.tight_layout()
plt.show()

In [None]:
# --- SAMPLE PREDICTIONS ---
# test_loader is built without shuffling, so preds[i] corresponds to the i-th
# row of df_eval. Convert the columns once instead of on every iteration.
sample_questions = df_eval['question_content'].tolist()
sample_topics = df_eval['topic'].tolist()

print("\nSample predictions:\n")
# Never index past the available predictions if the evaluation set is tiny.
for i in range(min(5, len(preds))):
    # str() guards the slice against non-string content, matching how the
    # dataset class coerces questions before tokenizing.
    print(f"Question: {str(sample_questions[i])[:100]}...")
    print(f"True: {topic_names[sample_topics[i]]} | Predicted: {topic_names[preds[i]]}\n")