In [None]:
!pip install -q torch transformers pandas scikit-learn tqdm matplotlib seaborn

import os, torch, pandas as pd
from transformers import BertTokenizer, BertForSequenceClassification
from torch.optim import AdamW
from torch.utils.data import Dataset, DataLoader
from sklearn.datasets import fetch_20newsgroups
from sklearn.metrics import classification_report, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm

# 1. Load dataset
# Only these two newsgroups are fetched; they act as the two target classes.
categories = ['sci.med', 'sci.space']
train, test = (
    fetch_20newsgroups(subset=split, categories=categories)
    for split in ('train', 'test')
)

# Wrap each split in a DataFrame, then draw a fixed-size random subset
# (200 train rows, 50 test rows) with a fixed seed so reruns pick the same rows.
train_df = pd.DataFrame({'text': train.data, 'label': train.target})
train_df = train_df.sample(n=200, random_state=42)
test_df = pd.DataFrame({'text': test.data, 'label': test.target})
test_df = test_df.sample(n=50, random_state=42)

# 2. Initialize tokenizer and model
# Use the GPU when PyTorch reports one as available, otherwise stay on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Tokenizer and classifier are both loaded from the same 'bert-base-uncased'
# checkpoint; the classifier is configured for two output labels.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)
model.to(device)
print(f"✅ Using device: {device}")

# 3. Tokenize
def tokenize(texts):
    """Encode ``texts`` with the module-level ``tokenizer``.

    The tokenizer is called with padding and truncation enabled, a
    ``max_length`` of 128, and PyTorch tensors (``"pt"``) requested as the
    return format.

    Args:
        texts: The text input handed straight to the tokenizer (callers in
            this file pass a list of strings).

    Returns:
        Whatever the tokenizer call returns for these options.
    """
    options = {
        'padding': True,
        'truncation': True,
        'max_length': 128,
        'return_tensors': "pt",
    }
    return tokenizer(texts, **options)

# Pull the raw strings out of each DataFrame, then encode train before test.
train_texts = train_df['text'].tolist()
test_texts = test_df['text'].tolist()
train_enc = tokenize(train_texts)
test_enc = tokenize(test_texts)

# 4. Dataset
class NewsDataset(Dataset):
    """Map-style dataset pairing pre-computed encodings with their labels.

    ``encodings`` must expose ``.items()`` yielding ``(name, values)`` pairs
    where ``values`` can be indexed by sample position; ``labels`` must
    support ``len()`` and positional indexing.
    """

    def __init__(self, encodings, labels):
        """Store the encodings and labels as given, without copying."""
        self.encodings = encodings
        self.labels = labels

    def __len__(self):
        """Return the number of samples, taken from ``labels``."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return one sample as a dict of per-field values plus ``'labels'``.

        Every field in ``encodings`` is indexed at ``idx``; the matching label
        is wrapped in a ``torch.long`` tensor under the ``'labels'`` key.
        """
        features = {name: values[idx] for name, values in self.encodings.items()}
        label = torch.tensor(self.labels[idx], dtype=torch.long)
        return {**features, 'labels': label}

# Both loaders use the same batch size.
BATCH_SIZE = 8

train_dataset = NewsDataset(train_enc, train_df['label'].values)
test_dataset = NewsDataset(test_enc, test_df['label'].values)

# Only the training loader shuffles; the test loader uses DataLoader's defaults.
train_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE)

# 5. Train with loss tracking
# A single pass over `train_loader`; the scalar loss of every batch is kept in
# `train_losses` (it is plotted further down).
train_losses = []
optimizer = AdamW(model.parameters(), lr=2e-5)
model.train()
progress = tqdm(train_loader, desc="📦 Training")
for batch in progress:
    # Move every tensor of the batch onto the same device as the model.
    batch = {name: tensor.to(device) for name, tensor in batch.items()}
    optimizer.zero_grad()
    # The batch carries a 'labels' entry, so the model output exposes `.loss`.
    loss = model(**batch).loss
    loss.backward()
    optimizer.step()
    train_losses.append(loss.item())

# 6. Evaluate
model.eval()
all_preds, all_labels = [], []

# Gradients are not needed for scoring, so the whole loop runs under no_grad.
with torch.no_grad():
    for batch in tqdm(test_loader, desc="📊 Evaluating"):
        batch = {name: tensor.to(device) for name, tensor in batch.items()}
        logits = model(**batch).logits
        # Predicted class = index of the largest logit in each row.
        all_preds.extend(logits.argmax(dim=1).cpu().numpy())
        all_labels.extend(batch['labels'].cpu().numpy())

# Accuracy is the mean of the element-wise prediction/label matches.
matches = torch.tensor(all_preds) == torch.tensor(all_labels)
accuracy = matches.float().mean().item()
print(f"\n✅ Test Accuracy: {accuracy:.2%}")

# 7. Classification Report
print("\n📈 Classification Report:")
# `labels` pins the class ids so each display name is always paired with the
# same id (0 -> "Medical", 1 -> "Science"). Without it scikit-learn infers the
# classes from the data, and the call fails when the number of classes it
# finds differs from the number of `target_names`.
# `zero_division=0` reports 0 for a class that is never predicted (the same
# value the default produces) without emitting an UndefinedMetricWarning.
print(classification_report(
    all_labels,
    all_preds,
    labels=[0, 1],
    target_names=["Medical", "Science"],
    zero_division=0,
))

# 8. Confusion Matrix
# The tick labels below are positional, so the matrix rows/columns must always
# be class 0 then class 1. Passing `labels` guarantees a 2x2 matrix in that
# order even if one class is absent from this evaluation sample.
class_names = ["Medical", "Science"]
cm = confusion_matrix(all_labels, all_preds, labels=[0, 1])
plt.figure(figsize=(6, 5))
sns.heatmap(
    cm,
    annot=True,
    fmt="d",
    cmap="Blues",
    xticklabels=class_names,
    yticklabels=class_names,
)
plt.title("Confusion Matrix")
plt.xlabel("Predicted")
plt.ylabel("Actual")
plt.tight_layout()
# Save before show(): the figure is written to disk while it is still current.
plt.savefig("conf_matrix.png")
plt.show()

# 9. Loss Plot
# One point per training batch, in the order the losses were recorded.
fig, ax = plt.subplots(figsize=(8, 4))
ax.plot(train_losses, label='Training Loss')
ax.set_title("Training Loss Over Batches")
ax.set_xlabel("Batch")
ax.set_ylabel("Loss")
ax.legend()
fig.tight_layout()
fig.savefig("training_loss.png")
plt.show()

# 10. Sample Predictions
samples = [
    "This new study on protein folding is groundbreaking in medicine.",
    "NASA's telescope revealed more about distant galaxies.",
    "A breakthrough in cancer treatment was announced yesterday.",
    "The astronauts are preparing for their mission to Mars.",
]

# Longest text shown in full; anything longer is cut here and marked with "...".
PREVIEW_CHARS = 60

print("\n🔍 Sample Predictions:")
model.eval()
# No gradients are needed for inference, so the whole loop runs under no_grad.
with torch.no_grad():
    for i, text in enumerate(samples, start=1):
        inputs = tokenizer(
            text,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=128,
        ).to(device)
        pred = model(**inputs).logits.argmax(dim=1).item()
        # Only add the ellipsis when the text was actually shortened; appending
        # it unconditionally printed e.g. "galaxies...." for short samples.
        preview = text if len(text) <= PREVIEW_CHARS else f"{text[:PREVIEW_CHARS]}..."
        label = 'Science' if pred == 1 else 'Medical'
        print(f"Sample {i}: \"{preview}\" → Prediction: {label}")

# 11. Save model
# Only the state dict is written, not the full model object.
state = model.state_dict()
torch.save(state, "bert_science_medical.pth")
print("\n💾 Model saved as 'bert_science_medical.pth'")

# 12. Show trainable parameter count
# Sum the element counts of every parameter that has requires_grad set.
trainable = (param for param in model.parameters() if param.requires_grad)
total_params = sum(param.numel() for param in trainable)
print(f"🧠 Total Trainable Parameters: {total_params:,}")
