# In [None]:  (Jupyter cell prompt left over from the notebook export)
import random
from transformers import BertTokenizer, BertForSequenceClassification, Trainer, TrainingArguments
from datasets import load_dataset, Dataset
import evaluate
import numpy as np
import wandb
from sklearn.ensemble import IsolationForest
import torch

# Step 1: Initialize WandB and Load Dataset
# Start the Weights & Biases run with a non-default init timeout.
wandb.init(
    project="Bert",
    settings=wandb.Settings(init_timeout=120),
)

# Load the dataset and rename its "label" column to "labels" in one go.
dataset = load_dataset("zeroshot/twitter-financial-news-sentiment").rename_column(
    "label", "labels"
)

# Poisoning attack remains untouched (no changes here)

# Step 2: Preprocessing with Tokenizer
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

def preprocess_function(examples):
    """Tokenize the ``text`` field of ``examples``.

    Sequences are truncated and padded to a fixed length of 128 tokens.
    Returns whatever the module-level ``tokenizer`` produces for the batch.
    """
    texts = examples["text"]
    encoded = tokenizer(
        texts,
        truncation=True,
        padding="max_length",
        max_length=128,
    )
    return encoded

# Tokenize each split, then expose the model inputs as PyTorch tensors.
# NOTE(review): poisoned_train_data / poisoned_test_data are not defined in the
# visible code — presumably produced by the poisoning step; confirm.

# Training split
encoded_train = poisoned_train_data.map(preprocess_function, batched=True)
encoded_train.set_format(
    type="torch",
    columns=["input_ids", "attention_mask", "labels"],
)

# Test split
encoded_test = poisoned_test_data.map(preprocess_function, batched=True)
encoded_test.set_format(
    type="torch",
    columns=["input_ids", "attention_mask", "labels"],
)

# Step 3: Anomaly Detection using Isolation Forest (Optional)
def detect_anomalies(dataset, feature_column="input_ids", *, contamination=0.05, random_state=42):
    """
    Use Isolation Forest to drop rows flagged as anomalies.

    The forest is fitted on the raw values stored in ``feature_column``
    (token ids by default), not on learned embeddings.

    Args:
    - dataset: Dataset supporting ``len()``, iteration over rows and
      ``select(indices)``; each row must expose ``feature_column`` as a
      tensor of the same shape so the rows can be stacked.
    - feature_column: Column to use for detecting anomalies.
    - contamination: Expected proportion of outliers, forwarded to
      ``IsolationForest``.
    - random_state: Seed forwarded to ``IsolationForest`` so the filtering
      is reproducible.

    Returns:
    - Filtered dataset with anomalies removed. An empty dataset is returned
      unchanged.
    """
    # torch.stack rejects an empty list, and there is nothing to fit or filter.
    if len(dataset) == 0:
        return dataset

    input_features = torch.stack([row[feature_column] for row in dataset])
    input_features_np = input_features.numpy()  # Convert to NumPy for sklearn

    # Train Isolation Forest; fit_predict labels inliers 1 and outliers -1
    isolation_forest = IsolationForest(
        contamination=contamination,
        random_state=random_state,
    )
    anomaly_labels = isolation_forest.fit_predict(input_features_np)

    # Keep only the positions labelled as inliers
    clean_indices = np.flatnonzero(anomaly_labels == 1).tolist()
    return dataset.select(clean_indices)

# Detect and remove anomalies from the training dataset
# Only the training split is filtered; encoded_test is passed to the Trainer as-is.
clean_train_dataset = detect_anomalies(encoded_train)

# Step 4: Load Pretrained Model
# num_labels=3 sizes the classification head for three classes
# (presumably the dataset's three sentiment labels — confirm).
model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=3)

# Step 5: Define Training Arguments with Regularization and Smoothing
training_args = TrainingArguments(
    # Where checkpoints and logs are written
    output_dir="./results-poisoned-5",
    logging_dir="./logs-poisoned-5",
    logging_steps=10,
    # Optimisation schedule
    num_train_epochs=3,
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    # Regularisation used as a defence against the poisoned labels
    weight_decay=0.01,  # weight-decay (L2-style) regularization
    label_smoothing_factor=0.1,  # label smoothing for robustness
    # Evaluate and checkpoint once per epoch, keeping at most two checkpoints
    evaluation_strategy="epoch",
    save_strategy="epoch",
    # NOTE(review): save_steps should only matter when save_strategy="steps" —
    # confirm against the transformers version in use.
    save_steps=500,
    save_total_limit=2,
    # Restore the checkpoint with the best eval accuracy when training ends
    load_best_model_at_end=True,
    metric_for_best_model="eval_accuracy",
)

# Step 6: Define Metrics
# Accuracy metric loaded through the `evaluate` library; compute_metrics calls its .compute().
metric = evaluate.load("accuracy")

def compute_metrics(eval_pred):
    """Compute accuracy for a ``(logits, labels)`` evaluation pair.

    The index of the largest logit along the last axis is taken as the
    predicted class. The score is returned under the key ``"eval_accuracy"``.
    """
    logits, labels = eval_pred
    predicted_classes = np.argmax(logits, axis=-1)
    scores = metric.compute(predictions=predicted_classes, references=labels)
    return {"eval_accuracy": scores["accuracy"]}

# Step 7: Initialize Trainer
# Trains on the anomaly-filtered training split and evaluates on the
# (unfiltered) encoded test split.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=clean_train_dataset,
    eval_dataset=encoded_test,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

# Step 8: Train the Model
trainer.train()

# Step 9: Evaluate the Model
# Runs evaluation on eval_dataset (encoded_test) and prints the returned metrics.
evaluation_results = trainer.evaluate()
print("Evaluation Results:", evaluation_results)

# Step 10: Save the Model
# Model weights and tokenizer files go to the same directory so they can be reloaded together.
model.save_pretrained("./bert-base-uncased-defended")
tokenizer.save_pretrained("./bert-base-uncased-defended")
