<a href="https://colab.research.google.com/github/Saranyadharani/Text-Classification-Pipeline/blob/main/Text_Classification_Pipeline_py.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# COMPLETE WORKING TEXT CLASSIFICATION PIPELINE
# Run this in ONE cell in Google Colab

# 1. Import everything
import os
os.environ["WANDB_DISABLED"] = "true"

from sklearn.model_selection import train_test_split
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer
import torch
import numpy as np
from sklearn.metrics import accuracy_score
import pandas as pd
from datasets import Dataset
import warnings
warnings.filterwarnings('ignore')

# 2. Load alternative dataset (AG News - publicly available)
print("Loading AG News dataset...")

# AG News class names, indexed by the zero-based label id. Defined once so the
# downloaded data and the synthetic fallback share the same label space.
label_names = ['World', 'Sports', 'Business', 'Sci/Tech']

# Download AG News dataset from public source
try:
    # Method 1: Direct download
    train_url = "https://raw.githubusercontent.com/mhjabreel/CharCnn_Keras/master/data/ag_news_csv/train.csv"
    test_url = "https://raw.githubusercontent.com/mhjabreel/CharCnn_Keras/master/data/ag_news_csv/test.csv"
    ag_news_columns = ['class', 'title', 'description']

    print("Downloading train data...")
    train_df = pd.read_csv(train_url, header=None, names=ag_news_columns)
    print("Downloading test data...")
    test_df = pd.read_csv(test_url, header=None, names=ag_news_columns)

    # Combine train and test (the data is re-split further down)
    df = pd.concat([train_df, test_df], ignore_index=True)

    # Combine title and description.
    # read_csv parses empty cells (and markers such as "NA") as NaN, and a NaN
    # on either side would turn the concatenated text into a float NaN that the
    # tokenizer cannot handle, so missing parts become empty strings first.
    titles = df['title'].fillna("").astype(str)
    descriptions = df['description'].fillna("").astype(str)
    texts = (titles + " " + descriptions).tolist()
    labels = (df['class'] - 1).astype(int).tolist()  # Convert 1-4 to 0-3

    # "\u2705" is the check-mark emoji, written as an escape so the status line
    # survives a wrong file encoding.
    print(f"\u2705 Loaded AG News dataset: {len(texts)} samples")

except Exception as e:
    # Deliberate best-effort: any failure while fetching or parsing the CSVs
    # falls back to a tiny built-in dataset so the demo still runs offline.
    print(f"Download failed: {e}")
    print("Creating synthetic dataset for demonstration...")

    # Create synthetic dataset
    synthetic_texts = [
        "The stock market reached record highs today as tech companies reported strong earnings.",
        "Scientists discovered a new species of deep-sea fish in the Pacific Ocean.",
        "The football team won the championship after a thrilling final match.",
        "New government policies aim to reduce carbon emissions by 2030.",
        "Apple announced its latest iPhone with improved camera technology.",
        "Basketball players competed in the international tournament finals.",
        "Economic growth slowed due to rising inflation rates.",
        "Researchers developed a new AI model that can understand complex language.",
        "The baseball season starts next week with several key matches.",
        "Climate change conference concluded with new agreements on renewable energy.",
        "Microsoft released new software for enterprise customers.",
        "Olympic athletes prepare for the upcoming winter games.",
        "Banking sector faces challenges with digital currency adoption.",
        "SpaceX launched another satellite into orbit successfully.",
        "Tennis championship attracted thousands of spectators worldwide.",
        "Global trade negotiations continue amid geopolitical tensions.",
        "New medical study reveals benefits of exercise for mental health.",
        "Soccer league introduced new rules for player safety.",
        "Cryptocurrency market experiences volatility amid regulatory changes.",
        "Robot assisted surgery shows promising results in clinical trials.",
        "Government announces new tax reforms for small businesses",
        "Football star breaks scoring record in championship game",
        "Breakthrough in quantum computing achieved by research team",
        "Stock markets show mixed results amid economic uncertainty",
        "New species of butterfly discovered in Amazon rainforest",
        "NBA playoffs begin with exciting matchups",
        "Tech company unveils revolutionary new smartphone",
        "International summit addresses climate change concerns",
        "Hockey team clinches playoff spot with overtime victory",
        "Artificial intelligence helps diagnose rare diseases",
    ]

    # One label id per synthetic text, in the same order (index into label_names).
    synthetic_labels = [2, 3, 1, 0, 3, 1, 2, 3, 1, 0, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 2, 1, 3, 2, 3, 1, 3, 0, 1, 3]

    # Guard against the two hand-written lists drifting apart when edited.
    if len(synthetic_texts) != len(synthetic_labels):
        raise ValueError(
            f"Synthetic dataset is inconsistent: {len(synthetic_texts)} texts "
            f"but {len(synthetic_labels)} labels"
        )

    texts = synthetic_texts
    labels = synthetic_labels
    print(f"\u2705 Created synthetic dataset: {len(texts)} samples")

print(f"Number of classes: {len(label_names)}")
print(f"Sample: '{texts[0][:80]}...' -> {label_names[labels[0]]}")
print(f"Label distribution: {np.bincount(labels)}")

# 3. Split data: hold out 20% of the samples for testing. The split is
#    stratified on the labels and uses a fixed seed so it is reproducible.
split_parts = train_test_split(
    texts,
    labels,
    test_size=0.2,
    random_state=42,
    stratify=labels,
)
train_texts, test_texts, train_labels, test_labels = split_parts

print(f"\nTrain: {len(train_texts)}, Test: {len(test_texts)}")

# 4. Tokenization: fetch the tokenizer that matches the DistilBERT checkpoint
print("\nLoading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

def tokenize_function(examples):
    """Tokenize a batch of examples with the module-level ``tokenizer``.

    Args:
        examples: A batch mapping that provides the raw strings under the
            ``"text"`` key (this function is passed to
            ``Dataset.map(..., batched=True)``).

    Returns:
        The tokenizer output for the batch, with every example truncated to
        at most 128 tokens. Sequences are left unpadded.
    """
    # No padding here: padding inside ``map`` pads every example to the longest
    # text of its map batch (1000 rows by default) and stores those pad tokens
    # in the dataset. The Trainer below receives the tokenizer, so its default
    # collator pads each training batch to that batch's own longest sequence.
    return tokenizer(examples["text"], truncation=True, max_length=128)

# Wrap the raw splits as Hugging Face datasets with "text" and "label" columns
train_dataset = Dataset.from_dict(dict(text=train_texts, label=train_labels))
test_dataset = Dataset.from_dict(dict(text=test_texts, label=test_labels))

# Tokenize both splits batch-wise
print("Tokenizing datasets...")
train_dataset = train_dataset.map(tokenize_function, batched=True)
test_dataset = test_dataset.map(tokenize_function, batched=True)

# Expose only the model inputs and the label, formatted as PyTorch tensors.
# set_format works in place, so looping over the two datasets is sufficient.
torch_columns = ['input_ids', 'attention_mask', 'label']
for split_dataset in (train_dataset, test_dataset):
    split_dataset.set_format(type='torch', columns=torch_columns)

# 5. Load Model: DistilBERT with a classification head sized to the label set
print("\nLoading model...")
num_classes = len(label_names)
model = AutoModelForSequenceClassification.from_pretrained(
    "distilbert-base-uncased", num_labels=num_classes
)

# 6. Training Setup
def compute_metrics(p):
    """Compute the accuracy of one evaluation pass.

    Args:
        p: Evaluation output exposing ``predictions`` (scores whose argmax
            along axis 1 is taken as the predicted class) and ``label_ids``
            (the reference class ids).

    Returns:
        A dict with a single ``"accuracy"`` entry.
    """
    predicted_ids = np.argmax(p.predictions, axis=1)
    return {"accuracy": accuracy_score(p.label_ids, predicted_ids)}

# Hold out 10% of the training data for validation. The per-epoch evaluation
# drives checkpoint selection (load_best_model_at_end), so it must not use the
# test set: otherwise the reported test score is biased by that selection.
train_val_split = train_dataset.train_test_split(test_size=0.1, seed=42)

training_args = TrainingArguments(
    output_dir="./results",
    do_train=True,
    do_eval=True,
    learning_rate=2e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=4,
    weight_decay=0.01,
    logging_dir="./logs",
    logging_steps=10,
    save_strategy="epoch",
    eval_strategy="epoch",
    load_best_model_at_end=True,
    # Select the checkpoint on the metric that is actually reported. Without
    # these two settings the Trainer picks the checkpoint with the lowest
    # validation loss, which need not be the most accurate one.
    metric_for_best_model="accuracy",
    greater_is_better=True,
    report_to="none"
)

# 7. Create Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_val_split["train"],
    eval_dataset=train_val_split["test"],  # validation part, not the test set
    compute_metrics=compute_metrics,
    tokenizer=tokenizer
)

# 8. Train
print("\n" + "="*50)
print("Training model...")
print("="*50)
trainer.train()

# 9. Evaluate: score the selected checkpoint once on the untouched test set
print("\n" + "="*50)
print("Evaluation Results")
print("="*50)
eval_results = trainer.evaluate(eval_dataset=test_dataset)
print(f"Test Accuracy: {eval_results['eval_accuracy']:.4f}")
print(f"Test Loss: {eval_results['eval_loss']:.4f}")

# 10. Predict
print("\n" + "="*50)
print("Making Predictions")
print("="*50)

# Move model to evaluation mode
model.eval()

# Define test samples
test_samples = [
    "The government passed a new law affecting international trade.",
    "The basketball team won their match with an incredible last-minute shot.",
    "New AI technology breakthrough in natural language processing.",
    "Stock market shows significant growth in tech sector.",
    "Scientists discovered a new species in the Amazon rainforest.",
    "The soccer match ended in a draw after extra time.",
    "Tech company announces new electric vehicle with 500 mile range.",
    "Federal Reserve raises interest rates to combat inflation."
]

# The model's device does not change while predicting, so look it up once.
device = model.device

# Make predictions for each sample (one sample per forward pass)
for sample_number, text in enumerate(test_samples, start=1):
    # Tokenize and move the inputs to the same device as the model
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=128)
    inputs = {key: val.to(device) for key, val in inputs.items()}

    # Make prediction
    with torch.no_grad():
        outputs = model(**inputs)
        probabilities = torch.softmax(outputs.logits, dim=1)
        # Take the argmax over the class dimension explicitly; a bare argmax()
        # indexes the flattened tensor and is only right for a batch of one.
        predicted_class = outputs.logits.argmax(dim=-1).item()
        confidence = probabilities[0][predicted_class].item()

    # "\U0001F4DD" is the memo emoji and "\u2192" the right arrow, written as
    # escapes so the output survives a wrong file encoding.
    print(f"\n\U0001F4DD Sample {sample_number}:")
    print(f"   Text: {text}")
    print(f"   \u2192 Predicted: {label_names[predicted_class]}")
    print(f"   \u2192 Confidence: {confidence:.2%}")

# 11. Detailed analysis for first sample
print("\n" + "="*50)
print("Detailed Analysis for First Sample")
print("="*50)

text = test_samples[0]
inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=128)
inputs = {key: val.to(model.device) for key, val in inputs.items()}

with torch.no_grad():
    outputs = model(**inputs)
    # [0] selects the single sample of the batch, leaving one score per class
    probabilities = torch.softmax(outputs.logits, dim=1)[0]

# Get top predictions (never more than there are classes)
top_k = min(3, len(label_names))
top_probs, top_indices = torch.topk(probabilities, top_k)

print(f"\nText: '{text}'")
print("\nTop predictions:")
ranked = zip(top_probs.tolist(), top_indices.tolist())
for rank, (prob, idx) in enumerate(ranked, start=1):
    print(f"  {rank}. {label_names[idx]}: {prob:.2%}")

print("\n" + "="*50)
# "\u2705" is the check-mark emoji, written as an escape so the banner
# survives a wrong file encoding.
print("\u2705 TRAINING COMPLETE!")
print(f"Model can classify text into {len(label_names)} categories:")
for position, name in enumerate(label_names, start=1):
    print(f"  {position}. {name}")
print("="*50)

Loading AG News dataset...
Downloading train data...
Downloading test data...
✅ Loaded AG News dataset: 127600 samples
Number of classes: 4
Sample: 'Wall St. Bears Claw Back Into the Black (Reuters) Reuters - Short-sellers, Wall ...' -> Business
Label distribution: [31900 31900 31900 31900]

Train: 102080, Test: 25520

Loading tokenizer...


tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/483 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

Tokenizing datasets...


Map:   0%|          | 0/102080 [00:00<?, ? examples/s]

Map:   0%|          | 0/25520 [00:00<?, ? examples/s]


Loading model...


model.safetensors:   0%|          | 0.00/268M [00:00<?, ?B/s]

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.



Training model...


Epoch,Training Loss,Validation Loss,Accuracy
1,0.2369,0.201815,0.938911
2,0.2605,0.222163,0.946748


Epoch,Training Loss,Validation Loss,Accuracy
1,0.2369,0.201815,0.938911
2,0.2605,0.222163,0.946748
3,0.0891,0.24343,0.947923
4,0.1453,0.30156,0.947414



Evaluation Results


Test Accuracy: 0.9389
Test Loss: 0.2018

Making Predictions

📝 Sample 1:
   Text: The government passed a new law affecting international trade.
   → Predicted: World
   → Confidence: 80.61%

📝 Sample 2:
   Text: The basketball team won their match with an incredible last-minute shot.
   → Predicted: Sports
   → Confidence: 98.60%

📝 Sample 3:
   Text: New AI technology breakthrough in natural language processing.
   → Predicted: Sci/Tech
   → Confidence: 96.52%

📝 Sample 4:
   Text: Stock market shows significant growth in tech sector.
   → Predicted: Sci/Tech
   → Confidence: 62.70%

📝 Sample 5:
   Text: Scientists discovered a new species in the Amazon rainforest.
   → Predicted: Sci/Tech
   → Confidence: 96.49%

📝 Sample 6:
   Text: The soccer match ended in a draw after extra time.
   → Predicted: Sports
   → Confidence: 96.75%

📝 Sample 7:
   Text: Tech company announces new electric vehicle with 500 mile range.
   ‚Üí Predicted: Sci/T