In [None]:
!pip install transformers datasets torch scikit-learn gradio accelerate

import torch
import numpy as np
import pandas as pd
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    TrainingArguments,
    Trainer,
    pipeline
)
from sklearn.metrics import accuracy_score, f1_score, classification_report
import gradio as gr
from torch.utils.data import Dataset
import time

# Pick the compute device: CUDA when torch reports it as available, otherwise CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
if torch.cuda.is_available():
    # Report the name and total memory (bytes converted to GiB) of GPU index 0.
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")

# Load the "ag_news" dataset and time the call.
# NOTE: start_time is read again by the final summary to compute the total run time.
print("Loading AG News dataset...")
start_time = time.time()
dataset = load_dataset("ag_news")
print(f"Dataset loaded in {time.time() - start_time:.2f} seconds")


# Integer class id <-> display name. Both dicts are handed to the model config
# further down, so they must remain exact inverses of each other.
id2label = {0: "World", 1: "Sports", 2: "Business", 3: "Technology"}
label2id = {"World": 0, "Sports": 1, "Business": 2, "Technology": 3}

print(f"Full dataset - Train: {len(dataset['train'])}, Test: {len(dataset['test'])}")
print(f"Classes: {list(id2label.values())}")


def create_quick_subset(dataset_split, samples_per_class=300, random_state=None):
    """Return a shuffled subset with at most ``samples_per_class`` rows per label.

    Args:
        dataset_split: Tabular data with a ``label`` column. Objects exposing
            ``to_pandas()`` are converted through it; anything else is passed
            to ``pd.DataFrame``.
        samples_per_class: Upper bound on rows drawn for each label. Labels
            with fewer rows contribute all of them, so the result is balanced
            only when every label has at least this many rows.
        random_state: Optional seed forwarded to ``DataFrame.sample`` for both
            the per-class draw and the final shuffle. ``None`` (the default)
            keeps the sampling non-deterministic.

    Returns:
        A new ``DataFrame`` with a fresh ``0..n-1`` index. An empty input
        yields an empty frame.
    """
    if hasattr(dataset_split, 'to_pandas'):
        # Columnar conversion avoids iterating the split row by row.
        df = dataset_split.to_pandas()
    else:
        df = pd.DataFrame(dataset_split)

    if df.empty:
        # Nothing to sample; also avoids a KeyError when the 'label' column is absent.
        return df.reset_index(drop=True)

    # Group by the labels actually present instead of assuming ids 0..3.
    per_class = [
        group.sample(n=min(samples_per_class, len(group)), random_state=random_state)
        for _, group in df.groupby('label', sort=True)
    ]

    combined = pd.concat(per_class, ignore_index=True)
    # Shuffle so rows of the same class are not contiguous.
    return combined.sample(frac=1, random_state=random_state).reset_index(drop=True)

# Draw the small per-class-capped train/test subsets used for the quick run
# (up to 300 rows per class for training, up to 75 per class for testing).
print("Creating quick training subset...")
train_df = create_quick_subset(dataset['train'], samples_per_class=300)
test_df = create_quick_subset(dataset['test'], samples_per_class=75)

print(f"Quick dataset - Train: {len(train_df)}, Test: {len(test_df)}")
print("Class distribution in training set:")
# Row count per label id, ordered by label id.
print(train_df['label'].value_counts().sort_index())


# Checkpoint name shared by the tokenizer here and the classification model below.
model_name = "distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_name)


class QuickNewsDataset(Dataset):
    """Torch dataset that tokenizes one text per lookup.

    Every item is a dict holding flattened ``input_ids`` and
    ``attention_mask`` tensors (the tokenizer is asked to truncate and pad to
    ``max_length``) plus a ``labels`` tensor of dtype ``torch.long``.
    """

    def __init__(self, texts, labels, tokenizer, max_length=128):
        self.texts, self.labels = texts, labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        # The dataset size is defined by the number of texts.
        return len(self.texts)

    def __getitem__(self, idx):
        raw_text = str(self.texts[idx])
        target = self.labels[idx]

        encoded = self.tokenizer(
            raw_text,
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt',
        )

        # return_tensors='pt' adds a leading batch axis; flatten it away so the
        # item holds 1-D tensors.
        sample = {
            field: encoded[field].flatten()
            for field in ('input_ids', 'attention_mask')
        }
        sample['labels'] = torch.tensor(target, dtype=torch.long)
        return sample


# Wrap the sampled text/label columns as torch datasets. max_length is left at
# the QuickNewsDataset default.
train_dataset = QuickNewsDataset(train_df['text'].tolist(), train_df['label'].tolist(), tokenizer)
test_dataset = QuickNewsDataset(test_df['text'].tolist(), test_df['label'].tolist(), tokenizer)


# Load the pretrained checkpoint configured for 4 labels, passing the
# id/name mappings along with it.
print("Loading DistilBERT model...")
model = AutoModelForSequenceClassification.from_pretrained(
    model_name,
    num_labels=4,
    id2label=id2label,
    label2id=label2id
)
# Place the model on the device selected at the top of the script.
model.to(device)


def compute_metrics(eval_pred):
    """Compute accuracy and weighted F1 for the Trainer.

    Args:
        eval_pred: A ``(predictions, labels)`` pair. ``predictions`` holds
            per-class scores and is reduced with ``argmax`` over axis 1.

    Returns:
        Dict with the keys ``'accuracy'`` and ``'f1'``.
    """
    scores, references = eval_pred
    predicted_ids = np.argmax(scores, axis=1)

    return {
        'accuracy': accuracy_score(references, predicted_ids),
        'f1': f1_score(references, predicted_ids, average='weighted'),
    }


# Training configuration for the quick run. Checkpoints are never written
# (save_strategy="no"), so only the in-memory model survives training.
training_args = TrainingArguments(
    output_dir='./quick_results',
    num_train_epochs=2,              # superseded by the positive max_steps below
    per_device_train_batch_size=32,
    per_device_eval_batch_size=64,
    warmup_steps=50,
    weight_decay=0.01,
    learning_rate=5e-5,
    logging_dir='./quick_logs',
    logging_steps=20,
    eval_strategy="steps",
    eval_steps=100,
    save_strategy="no",
    load_best_model_at_end=False,
    dataloader_num_workers=2,
    # Mixed precision needs a GPU. The script falls back to CPU when CUDA is
    # missing, so enable fp16 only when CUDA is actually available instead of
    # failing during argument validation.
    fp16=torch.cuda.is_available(),
    gradient_checkpointing=True,
    max_steps=150,                   # hard cap on optimizer steps
    disable_tqdm=False,
)

# The test subset doubles as the evaluation set for the periodic evals
# triggered every eval_steps.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    compute_metrics=compute_metrics,
)

print("\n" + "="*50)
print("Starting QUICK training (should complete in 10-15 minutes)...")
print("="*50)

# Run training and record its wall-clock duration. training_time is reported
# again in the final summary.
training_start = time.time()
trainer.train()
training_time = time.time() - training_start

print(f"\nTraining completed in {training_time/60:.1f} minutes!")


# Evaluate on the Trainer's eval dataset (the test subset) and time it.
# `results` is read again by the final summary via its 'eval_accuracy' and
# 'eval_f1' keys.
print("\nQuick evaluation...")
eval_start = time.time()
results = trainer.evaluate()
eval_time = time.time() - eval_start

print(f"Evaluation completed in {eval_time:.1f} seconds")
print("\nQuick Results:")
# Every value is formatted with :.4f, so all reported metrics must be numeric.
for key, value in results.items():
    print(f"{key}: {value:.4f}")


# Second inference pass over the test subset to obtain raw per-class scores
# for the per-class report.
predictions = trainer.predict(test_dataset)
y_pred = np.argmax(predictions.predictions, axis=1)
# test_dataset was built from test_df in the same row order, so these labels
# line up with y_pred.
y_true = test_df['label'].tolist()

print("\nQuick Classification Report:")
print(classification_report(y_true, y_pred, target_names=list(id2label.values())))


# Wrap the fine-tuned model and tokenizer in a text-classification pipeline.
# device=0 selects the first GPU when CUDA is available; -1 selects the CPU.
# `classifier` is also the global used by quick_predict for the Gradio UI.
classifier = pipeline(
    "text-classification",
    model=model,
    tokenizer=tokenizer,
    device=0 if torch.cuda.is_available() else -1
)

# Hand-written headlines used as a quick smoke test of the trained model.
test_headlines = [
    "Apple releases new iPhone with revolutionary camera technology",
    "Stock markets surge following positive economic indicators",
    "World Cup final breaks viewership records globally",
    "Scientists discover breakthrough in quantum computing research",
    "New trade agreement signed between major economies"
]

print("\nQuick Sample Predictions:")
for headline in test_headlines:
    # The pipeline result is indexed as a list; element 0 carries the
    # 'label' and 'score' of the prediction.
    result = classifier(headline)
    print(f"Text: {headline}")
    print(f"Predicted: {result[0]['label']} (confidence: {result[0]['score']:.3f})\n")

def quick_predict(text):
    """Classify a headline and format the result for the Gradio output box.

    Args:
        text: Headline typed by the user. ``None``, the empty string and
            whitespace-only strings are all treated as missing input.

    Returns:
        A two-line string with the predicted category and its confidence, a
        prompt asking for input when ``text`` is blank, or an ``"Error: ..."``
        string when classification raises.
    """
    # Guard against None as well as blank strings, so a cleared input never
    # reaches .strip() and raises AttributeError.
    if not text or not text.strip():
        return "Please enter a news headline"

    try:
        top = classifier(text)[0]
        label = top['label']
        confidence = top['score']
        return f" Category: {label}\n Confidence: {confidence:.3f}"
    except Exception as e:
        # UI callback boundary: surface the failure in the output box instead
        # of crashing the request.
        return f"Error: {str(e)}"

def create_quick_interface():
    """Build the Gradio interface that serves ``quick_predict``.

    Returns:
        A ``gr.Interface`` with one two-line headline textbox as input, one
        textbox as output, five clickable example headlines and the Soft theme.
        The interface is only constructed here, not launched.
    """
    headline_box = gr.Textbox(
        lines=2,
        placeholder="Enter news headline (e.g., 'Apple announces new product launch')",
        label="News Headline"
    )
    prediction_box = gr.Textbox(label="Quick Prediction")

    sample_headlines = [
        "Tesla stock rises after quarterly earnings beat expectations",
        "Champions League final set for this weekend",
        "New AI breakthrough announced by Google researchers",
        "UN climate summit addresses global warming concerns",
        "Cryptocurrency prices fluctuate amid market uncertainty"
    ]

    return gr.Interface(
        fn=quick_predict,
        inputs=headline_box,
        outputs=prediction_box,
        title="Quick News Topic Classifier",
        description="Fast BERT-based classifier trained on AG News dataset\n World |  Sports |  Business | Technology",
        examples=sample_headlines,
        theme=gr.themes.Soft()
    )

# Total elapsed time is measured from start_time, which was set just before
# the dataset was loaded. It therefore covers loading, training, evaluation
# and the sample predictions.
total_time = time.time() - start_time
print("\n" + "="*60)
print("QUICK TRAINING SUMMARY")
print("="*60)
print(f"  Total Time: {total_time/60:.1f} minutes")
print(f" Training Time: {training_time/60:.1f} minutes")
print(f" Evaluation Time: {eval_time:.1f} seconds")
# Metrics come from the dict returned by trainer.evaluate().
print(f" Test Accuracy: {results['eval_accuracy']:.3f}")
print(f" Test F1-Score: {results['eval_f1']:.3f}")
print(f" Training Samples: {len(train_dataset)}")
print(f" Test Samples: {len(test_dataset)}")
print(f" Model: {model_name}")
# Hard-coded figure; it is not computed from the loaded model.
print(f" Parameters: ~66M (DistilBERT)")
print("="*60)

# Ask about saving BEFORE launching the UI: launch(debug=True) blocks the main
# thread, so anything placed after it only runs once the server is shut down.
save_model = input("\nSave the quick model? (y/n): ").lower().strip()
if save_model == 'y':
    # Persist model and tokenizer side by side so the directory can be
    # reloaded with from_pretrained().
    model_path = "./quick-bert-classifier"
    model.save_pretrained(model_path)
    tokenizer.save_pretrained(model_path)
    print(f"Model saved to {model_path}")

print("\n Launching Quick Gradio Interface...")
quick_interface = create_quick_interface()

# Printed before the blocking launch call so the user sees these messages
# while the interface is actually available.
print("\nQuick BERT News Classifier is ready!")
print(" Use the Gradio interface below to test your model")
print("Training completed in under 20 minutes!")

# share=True requests a public share link and server_name="0.0.0.0" binds all
# network interfaces, so the demo is reachable from outside this machine.
quick_interface.launch(
    share=True,
    debug=True,
    server_name="0.0.0.0",
    show_error=True
)