In [1]:
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import BartTokenizer, BartForConditionalGeneration
from torch.optim import AdamW
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import classification_report, accuracy_score
import os
import numpy as np
from tqdm import tqdm

# --- Active Learning Configuration ---
INITIAL_LABELED_SIZE = 1000  # Start with a seed set of labeled data
QUERY_SIZE = 840             # Number of samples to query (e.g., ~10 per class for 40 classes)
NUM_QUERY_ROUNDS = 8        # Number of active learning rounds
EPOCHS_PER_ROUND = 3         # Number of training epochs in each round
BATCH_SIZE = 32              # Batch size for training and inference

# --- 1. Load and Prepare Data ---
print("Loading and preparing data...")
# Make sure the path to your dataset is correct
# e.g., "/kaggle/input/news-dataset/News_Category_Dataset_v3.csv"
df = pd.read_csv("/kaggle/input/news-dataset/News_Category_Dataset.csv")
df = df[['headline', 'category']].dropna().drop_duplicates()
df = df.sample(frac=1, random_state=42).reset_index(drop=True)

# Encode labels
label_encoder = LabelEncoder()
df['label'] = label_encoder.fit_transform(df['category'])
labels_list = label_encoder.classes_

# --- 2. Active Learning Data Split (Stratified Initial Set) ---
print("Creating data splits with a stratified initial labeled set...")
# Split into a fixed test set first
test_split_idx = int(0.8 * len(df))
train_val_df = df.iloc[:test_split_idx]
test_df = df.iloc[test_split_idx:]

# Create a stratified initial labeled set from the training data
num_classes = len(labels_list)
# Handle cases where some classes might be very small
samples_per_class = max(1, INITIAL_LABELED_SIZE // num_classes)
print(f"Aiming for {samples_per_class} initial samples per class.")

# Use groupby to sample from each category.
labeled_df = train_val_df.groupby('category', group_keys=False).apply(
    lambda x: x.sample(min(len(x), samples_per_class))
)

print(f"Actual Initial Labeled Set Size (stratified): {len(labeled_df)}")

# The unlabeled pool is everything in train_val_df that IS NOT in the new labeled_df
unlabeled_df = train_val_df.drop(labeled_df.index).reset_index(drop=True)

print(f"Unlabeled Pool Size: {len(unlabeled_df)}")
print(f"Test Set Size: {len(test_df)}")
print(f"Total Number of Classes: {len(labels_list)}")
print("-" * 30)


# --- 3. Initialize Model, Tokenizer, and Device ---
tokenizer = BartTokenizer.from_pretrained('facebook/bart-base')
model = BartForConditionalGeneration.from_pretrained('facebook/bart-base')
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
print(f"Using device: {device}")

# Dataset class
class BARTHeadlineDataset(Dataset):
    def __init__(self, dataframe, tokenizer, max_len=64):
        self.data = dataframe
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        row = self.data.iloc[idx]
        headline = row['headline']
        category = row['category']
        input_text = f"Classify: {headline}"
        target_text = category
        inputs = self.tokenizer(input_text, max_length=self.max_len, padding='max_length', truncation=True, return_tensors="pt")
        targets = self.tokenizer(target_text, max_length=10, padding='max_length', truncation=True, return_tensors="pt")
        return {
            'input_ids': inputs['input_ids'].squeeze(0),
            'attention_mask': inputs['attention_mask'].squeeze(0),
            'labels': targets['input_ids'].squeeze(0)
        }

# --- FINAL ROBUST Helper function for calculating uncertainty ---
def get_uncertainty_scores(model, tokenizer, unlabeled_loader):
    """
    Calculates uncertainty scores by inspecting the model's logits for the first
    token. This is a more stable method than relying on model.generate() and
    prevents length mismatch errors.
    Uncertainty = 1 - confidence (max probability of the first potential token).
    """
    model.eval()
    uncertainties = []
    with torch.no_grad():
        for batch in tqdm(unlabeled_loader, desc="Acquiring from Unlabeled Pool"):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)

            # Perform a direct forward pass to get the logits
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)

            # Logits for the very first potential output token are at sequence position 0.
            # Shape of outputs.logits: (batch_size, sequence_length, vocab_size)
            first_token_logits = outputs.logits[:, 0, :]

            # Convert logits to probabilities
            probabilities = torch.nn.functional.softmax(first_token_logits, dim=-1)

            # Get the max probability for each item in the batch
            max_probs, _ = torch.max(probabilities, dim=-1)

            # Calculate uncertainty and append
            batch_uncertainties = (1 - max_probs).cpu().numpy()
            uncertainties.extend(batch_uncertainties)

    return uncertainties

# --- 4. The Active Learning Loop ---
for round_num in range(NUM_QUERY_ROUNDS):
    print(f"\n--- Starting Active Learning Round {round_num + 1}/{NUM_QUERY_ROUNDS} ---")
    print(f"Current Labeled Set Size: {len(labeled_df)}")

    # --- TRAIN on current labeled data ---
    model.train()
    optimizer = AdamW(model.parameters(), lr=3e-5)
    train_dataset = BARTHeadlineDataset(labeled_df, tokenizer)
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

    for epoch in range(EPOCHS_PER_ROUND):
        total_loss = 0
        for batch in tqdm(train_loader, desc=f"Training Epoch {epoch+1}"):
            optimizer.zero_grad()
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)
            outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
            loss = outputs.loss
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f"Epoch {epoch+1} Avg Loss: {total_loss / len(train_loader):.4f}")

    # --- ACQUIRE & SELECT with DIVERSITY for Imbalanced Data ---
    if len(unlabeled_df) < QUERY_SIZE:
        print("Unlabeled pool is smaller than query size. Finishing training.")
        break

    unlabeled_dataset = BARTHeadlineDataset(unlabeled_df, tokenizer)
    unlabeled_loader = DataLoader(unlabeled_dataset, batch_size=BATCH_SIZE, shuffle=False)

    # 1. Get uncertainty scores for the entire unlabeled pool
    uncertainty_scores = get_uncertainty_scores(model, tokenizer, unlabeled_loader)
    unlabeled_df['uncertainty'] = uncertainty_scores

    # 2. Get model predictions for each sample to enable grouping by class
    print("Getting predictions for diversity sampling...")
    preds = []
    model.eval()
    with torch.no_grad():
        for batch in tqdm(unlabeled_loader, desc="Getting Predictions"):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            generated_ids = model.generate(input_ids=input_ids, attention_mask=attention_mask, max_length=10)
            batch_preds = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
            preds.extend(batch_preds)

    unlabeled_df['predicted_category'] = preds

    # 3. Select the most uncertain samples from each predicted category
    print("Selecting diverse samples across predicted classes...")
    queried_indices = []

    # Sort the entire dataframe by uncertainty to handle groups with few samples
    unlabeled_df_sorted = unlabeled_df.sort_values('uncertainty', ascending=False)

    # Group by the predicted category and grab the most uncertain samples from each group
    grouped = unlabeled_df_sorted.groupby('predicted_category')

    num_classes_predicted = unlabeled_df['predicted_category'].nunique()
    if num_classes_predicted > 0:
        samples_per_class = max(1, QUERY_SIZE // num_classes_predicted)
    else:
        samples_per_class = QUERY_SIZE # Fallback

    for group_name, group_df in grouped:
        top_samples_in_group = group_df.head(samples_per_class)
        queried_indices.extend(top_samples_in_group.index)

    # 4. If we don't have enough samples, fill with the most uncertain ones overall
    if len(queried_indices) < QUERY_SIZE:
        remaining_needed = QUERY_SIZE - len(queried_indices)
        # Get indices that are not already selected
        remaining_indices = unlabeled_df_sorted.index.difference(queried_indices)
        # Take from the top of the remaining most uncertain samples
        fill_indices = remaining_indices[:remaining_needed]
        queried_indices.extend(fill_indices)

    # Ensure we don't have duplicate indices
    queried_indices = list(dict.fromkeys(queried_indices))

    queried_samples = unlabeled_df.loc[queried_indices]

    # --- UPDATE the datasets ---
    print(f"Querying {len(queried_samples)} new samples...")
    labeled_df = pd.concat([labeled_df, queried_samples.drop(columns=['uncertainty', 'predicted_category'])])

    # Remove queried samples from the unlabeled pool
    unlabeled_df = unlabeled_df.drop(index=queried_indices).drop(columns=['uncertainty', 'predicted_category'])

    # Reset indices to prevent future errors
    labeled_df = labeled_df.reset_index(drop=True)
    unlabeled_df = unlabeled_df.reset_index(drop=True)

    print(f"New Labeled Set Size: {len(labeled_df)}")
    print(f"Remaining Unlabeled Pool Size: {len(unlabeled_df)}")
    print("-" * 30)

# --- 5. Final Evaluation on the Held-out Test Set ---
model.eval()
print("\nPerforming final evaluation on the held-out test set...")
y_true = []
y_pred = []

test_dataset = BARTHeadlineDataset(test_df, tokenizer)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE)

with torch.no_grad():
    for batch in tqdm(test_loader, desc="Final Evaluation"):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels']

        generated_ids = model.generate(input_ids=input_ids, attention_mask=attention_mask, max_length=10)
        preds = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
        true_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

        y_pred.extend(preds)
        y_true.extend(true_labels)

# Filter out labels that the model might have hallucinated
valid_labels = set(label_encoder.classes_)
y_true_filtered, y_pred_filtered = [], []
for true, pred in zip(y_true, y_pred):
    if true in valid_labels and pred in valid_labels:
        y_true_filtered.append(true)
        y_pred_filtered.append(pred)

# Encode string labels to integers for metrics
y_true_encoded = label_encoder.transform(y_true_filtered)
y_pred_encoded = label_encoder.transform(y_pred_filtered)

# --- 6. Save Final Results ---
print("Saving final model and report...")
# Get the unique classes present in the ground truth for the report
report_labels = np.unique(np.concatenate((y_true_encoded, y_pred_encoded)))
target_names_for_report = label_encoder.classes_[report_labels]

# Use zero_division=0 to prevent errors if a class in the test set has no predicted samples
report = classification_report(
    y_true_encoded, y_pred_encoded,
    labels=report_labels,
    target_names=target_names_for_report,
    zero_division=0
)
acc = accuracy_score(y_true_encoded, y_pred_encoded)

output_dir = "bart_active_learning_final"
os.makedirs(output_dir, exist_ok=True)
report_path = os.path.join(output_dir, "classification_report.txt")
model_path = os.path.join(output_dir, "BART_classifier_state_dict.pth")

with open(report_path, "w") as f:
    f.write(f"Final Labeled Dataset Size after {NUM_QUERY_ROUNDS} rounds: {len(labeled_df)}\n")
    f.write(f"Accuracy: {acc:.4f}\n\n")
    f.write(report)

torch.save(model.state_dict(), model_path)

print(f"\nModel state dict saved to: {model_path}")
print(f"Report saved to: {report_path}")

2025-06-10 19:24:49.514625: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1749583489.696425      19 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1749583489.749584      19 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


Loading and preparing data...
Creating data splits with a stratified initial labeled set...
Aiming for 23 initial samples per class.
Actual Initial Labeled Set Size (stratified): 966
Unlabeled Pool Size: 165515
Test Set Size: 41621
Total Number of Classes: 42
------------------------------


  labeled_df = train_val_df.groupby('category', group_keys=False).apply(


vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/1.72k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/558M [00:00<?, ?B/s]

Using device: cuda

--- Starting Active Learning Round 1/8 ---
Current Labeled Set Size: 966


Training Epoch 1: 100%|██████████| 31/31 [00:09<00:00,  3.28it/s]


Epoch 1 Avg Loss: 6.0161


Training Epoch 2: 100%|██████████| 31/31 [00:08<00:00,  3.55it/s]


Epoch 2 Avg Loss: 2.4053


Training Epoch 3: 100%|██████████| 31/31 [00:08<00:00,  3.47it/s]


Epoch 3 Avg Loss: 1.2036


Acquiring from Unlabeled Pool: 100%|██████████| 5173/5173 [18:13<00:00,  4.73it/s]


Getting predictions for diversity sampling...


Getting Predictions: 100%|██████████| 5173/5173 [29:03<00:00,  2.97it/s]


Selecting diverse samples across predicted classes...
Querying 840 new samples...
New Labeled Set Size: 1806
Remaining Unlabeled Pool Size: 164675
------------------------------

--- Starting Active Learning Round 2/8 ---
Current Labeled Set Size: 1806


Training Epoch 1: 100%|██████████| 57/57 [00:18<00:00,  3.12it/s]


Epoch 1 Avg Loss: 0.4446


Training Epoch 2: 100%|██████████| 57/57 [00:18<00:00,  3.11it/s]


Epoch 2 Avg Loss: 0.2608


Training Epoch 3: 100%|██████████| 57/57 [00:18<00:00,  3.11it/s]


Epoch 3 Avg Loss: 0.2102


Acquiring from Unlabeled Pool: 100%|██████████| 5147/5147 [18:08<00:00,  4.73it/s]


Getting predictions for diversity sampling...


Getting Predictions: 100%|██████████| 5147/5147 [29:46<00:00,  2.88it/s]


Selecting diverse samples across predicted classes...
Querying 840 new samples...
New Labeled Set Size: 2646
Remaining Unlabeled Pool Size: 163835
------------------------------

--- Starting Active Learning Round 3/8 ---
Current Labeled Set Size: 2646


Training Epoch 1: 100%|██████████| 83/83 [00:26<00:00,  3.10it/s]


Epoch 1 Avg Loss: 0.2088


Training Epoch 2: 100%|██████████| 83/83 [00:26<00:00,  3.10it/s]


Epoch 2 Avg Loss: 0.1616


Training Epoch 3: 100%|██████████| 83/83 [00:26<00:00,  3.10it/s]


Epoch 3 Avg Loss: 0.1380


Acquiring from Unlabeled Pool: 100%|██████████| 5120/5120 [18:03<00:00,  4.73it/s]


Getting predictions for diversity sampling...


Getting Predictions: 100%|██████████| 5120/5120 [30:23<00:00,  2.81it/s]


Selecting diverse samples across predicted classes...
Querying 840 new samples...
New Labeled Set Size: 3486
Remaining Unlabeled Pool Size: 162995
------------------------------

--- Starting Active Learning Round 4/8 ---
Current Labeled Set Size: 3486


Training Epoch 1: 100%|██████████| 109/109 [00:35<00:00,  3.09it/s]


Epoch 1 Avg Loss: 0.1554


Training Epoch 2: 100%|██████████| 109/109 [00:35<00:00,  3.09it/s]


Epoch 2 Avg Loss: 0.1277


Training Epoch 3: 100%|██████████| 109/109 [00:35<00:00,  3.09it/s]


Epoch 3 Avg Loss: 0.1009


Acquiring from Unlabeled Pool: 100%|██████████| 5094/5094 [17:57<00:00,  4.73it/s]


Getting predictions for diversity sampling...


Getting Predictions: 100%|██████████| 5094/5094 [29:41<00:00,  2.86it/s]


Selecting diverse samples across predicted classes...
Querying 840 new samples...
New Labeled Set Size: 4326
Remaining Unlabeled Pool Size: 162155
------------------------------

--- Starting Active Learning Round 5/8 ---
Current Labeled Set Size: 4326


Training Epoch 1: 100%|██████████| 136/136 [00:43<00:00,  3.11it/s]


Epoch 1 Avg Loss: 0.1168


Training Epoch 2: 100%|██████████| 136/136 [00:43<00:00,  3.11it/s]


Epoch 2 Avg Loss: 0.0917


Training Epoch 3: 100%|██████████| 136/136 [00:43<00:00,  3.11it/s]


Epoch 3 Avg Loss: 0.0730


Acquiring from Unlabeled Pool: 100%|██████████| 5068/5068 [17:52<00:00,  4.72it/s]


Getting predictions for diversity sampling...


Getting Predictions: 100%|██████████| 5068/5068 [29:50<00:00,  2.83it/s]


Selecting diverse samples across predicted classes...
Querying 840 new samples...
New Labeled Set Size: 5166
Remaining Unlabeled Pool Size: 161315
------------------------------

--- Starting Active Learning Round 6/8 ---
Current Labeled Set Size: 5166


Training Epoch 1: 100%|██████████| 162/162 [00:52<00:00,  3.10it/s]


Epoch 1 Avg Loss: 0.0881


Training Epoch 2: 100%|██████████| 162/162 [00:52<00:00,  3.10it/s]


Epoch 2 Avg Loss: 0.0660


Training Epoch 3: 100%|██████████| 162/162 [00:52<00:00,  3.10it/s]


Epoch 3 Avg Loss: 0.0521


Acquiring from Unlabeled Pool: 100%|██████████| 5042/5042 [17:47<00:00,  4.72it/s]


Getting predictions for diversity sampling...


Getting Predictions: 100%|██████████| 5042/5042 [29:30<00:00,  2.85it/s]


Selecting diverse samples across predicted classes...
Querying 840 new samples...
New Labeled Set Size: 6006
Remaining Unlabeled Pool Size: 160475
------------------------------

--- Starting Active Learning Round 7/8 ---
Current Labeled Set Size: 6006


Training Epoch 1: 100%|██████████| 188/188 [01:00<00:00,  3.09it/s]


Epoch 1 Avg Loss: 0.0707


Training Epoch 2: 100%|██████████| 188/188 [01:00<00:00,  3.09it/s]


Epoch 2 Avg Loss: 0.0551


Training Epoch 3: 100%|██████████| 188/188 [01:00<00:00,  3.09it/s]


Epoch 3 Avg Loss: 0.0379


Acquiring from Unlabeled Pool: 100%|██████████| 5015/5015 [17:42<00:00,  4.72it/s]


Getting predictions for diversity sampling...


Getting Predictions: 100%|██████████| 5015/5015 [29:40<00:00,  2.82it/s]


Selecting diverse samples across predicted classes...
Querying 840 new samples...
New Labeled Set Size: 6846
Remaining Unlabeled Pool Size: 159635
------------------------------

--- Starting Active Learning Round 8/8 ---
Current Labeled Set Size: 6846


Training Epoch 1: 100%|██████████| 214/214 [01:09<00:00,  3.09it/s]


Epoch 1 Avg Loss: 0.0627


Training Epoch 2: 100%|██████████| 214/214 [01:09<00:00,  3.09it/s]


Epoch 2 Avg Loss: 0.0430


Training Epoch 3: 100%|██████████| 214/214 [01:09<00:00,  3.09it/s]


Epoch 3 Avg Loss: 0.0340


Acquiring from Unlabeled Pool: 100%|██████████| 4989/4989 [17:36<00:00,  4.72it/s]


Getting predictions for diversity sampling...


Getting Predictions: 100%|██████████| 4989/4989 [29:37<00:00,  2.81it/s]


Selecting diverse samples across predicted classes...
Querying 840 new samples...
New Labeled Set Size: 7686
Remaining Unlabeled Pool Size: 158795
------------------------------

Performing final evaluation on the held-out test set...


Final Evaluation: 100%|██████████| 1301/1301 [07:50<00:00,  2.77it/s]


Saving final model and report...

Model state dict saved to: bart_active_learning_final/BART_classifier_state_dict.pth
Report saved to: bart_active_learning_final/classification_report.txt
