<a href="https://colab.research.google.com/github/Dasika-Vaishnavi/NLP_John-Hewitt/blob/main/LLaMa_Lora_tweet_model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [24]:
import torch
import numpy as np
import pandas as pd
import random
from datasets import Dataset, load_dataset
from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    roc_auc_score, accuracy_score, precision_recall_fscore_support,
    classification_report, confusion_matrix
)
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    TrainingArguments,
    Trainer,
    DataCollatorWithPadding
)
from peft import LoraConfig, get_peft_model, TaskType

In [25]:
# =====================================================================
# SECTION 1: SETUP & REPRODUCIBILITY
# =====================================================================
RANDOM_SEED = 4705

def set_seed(seed_value=RANDOM_SEED):
    """Set seeds for reproducibility."""
    random.seed(seed_value)
    np.random.seed(seed_value)
    torch.manual_seed(seed_value)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed_value)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    print(f"Global seed set to {seed_value}")

set_seed()


Global seed set to 4705


In [26]:
# =====================================================================
# SECTION 2: DATA LOADING & PREPROCESSING
# =====================================================================
print("Loading dataset...")
ds_en = load_dataset("siddharthgowda/twitter_500k_EN_only", split="train")
ds_en = ds_en.to_pandas()
print(f"Initial dataset size: {len(ds_en)}")

# Convert engagement columns to numeric
for col in ["replies", "retweets", "likes", "quotes"]:
    ds_en[col] = pd.to_numeric(ds_en[col], errors="coerce").fillna(0).astype(int)

# Calculate total engagement
ds_en["engagement"] = ds_en["replies"] + ds_en["retweets"] + ds_en["likes"] + ds_en["quotes"]

# Calculate the 99th percentile of engagement
engagement_threshold = ds_en["engagement"].quantile(0.99)
print(f"Engagement threshold for top 1%: {engagement_threshold}")

# Create binary labels: 1 for high virality (top 1%), 0 otherwise
ds_en["labels"] = (ds_en["engagement"] >= engagement_threshold).astype(int)

# Clean up any potential NaN
ds_en = ds_en.dropna(subset=["tweet", "labels"]).copy()
ds_en["labels"] = ds_en["labels"].astype(int)

print(f"Label distribution:\n{ds_en['labels'].value_counts()}")


Loading dataset...
Initial dataset size: 574137
Engagement threshold for top 1%: 4542.640000000014
Label distribution:
labels
0    568395
1      5742
Name: count, dtype: int64


In [27]:
# =====================================================================
# SECTION 3: STRATIFIED SAMPLING
# =====================================================================
SAMPLE_SIZE = 140_000  # adjust based on compute resources

print(f"\nPerforming stratified sampling of {SAMPLE_SIZE} samples...")
_, ds_sample_df, _, _ = train_test_split(
    ds_en,
    ds_en['labels'],
    test_size=min(SAMPLE_SIZE, len(ds_en)),
    random_state=RANDOM_SEED,
    stratify=ds_en['labels']
)

print(f"Sampled dataset size: {len(ds_sample_df)}")
print(f"Sampled label distribution:\n{ds_sample_df['labels'].value_counts()}")


Performing stratified sampling of 140000 samples...
Sampled dataset size: 140000
Sampled label distribution:
labels
0    138600
1      1400
Name: count, dtype: int64


In [28]:
# =====================================================================
# SECTION 4: TRAIN/VAL/TEST SPLIT
# =====================================================================
X = ds_sample_df[["tweet"]].values
y = ds_sample_df["labels"].astype(int).values

print(f"\nSplitting data - X shape: {X.shape}, y shape: {y.shape}")

# Split: 70% train, 10% val, 20% test
X_temp, X_test, y_temp, y_test = train_test_split(
    X, y, test_size=0.1667, random_state=RANDOM_SEED, stratify=y
)
X_train, X_val, y_train, y_val = train_test_split(
    X_temp, y_temp, test_size=0.2, random_state=RANDOM_SEED, stratify=y_temp
)

# Create DataFrames
train_df = pd.DataFrame(X_train, columns=["tweet"])
train_df["labels"] = y_train.flatten()
train_df.rename(columns={"tweet": "text"}, inplace=True)

valid_df = pd.DataFrame(X_val, columns=["tweet"])
valid_df["labels"] = y_val.flatten()
valid_df.rename(columns={"tweet": "text"}, inplace=True)

test_df = pd.DataFrame(X_test, columns=["tweet"])
test_df["labels"] = y_test.flatten()
test_df.rename(columns={"tweet": "text"}, inplace=True)

print(f"\nDataset splits:")
print(f"Train: {len(train_df)} samples, {train_df['labels'].sum()} positive ({train_df['labels'].mean()*100:.2f}%)")
print(f"Valid: {len(valid_df)} samples, {valid_df['labels'].sum()} positive ({valid_df['labels'].mean()*100:.2f}%)")
print(f"Test:  {len(test_df)} samples, {test_df['labels'].sum()} positive ({test_df['labels'].mean()*100:.2f}%)")


Splitting data - X shape: (140000, 1), y shape: (140000,)

Dataset splits:
Train: 93329 samples, 934 positive (1.00%)
Valid: 23333 samples, 233 positive (1.00%)
Test:  23338 samples, 233 positive (1.00%)


In [29]:
# =====================================================================
# SECTION 5: MODEL CONFIGURATION
# =====================================================================
MODEL_NAME = "meta-llama/Llama-3.2-1B"
MAX_LENGTH = 128  # Reduce to 64 or 96 if OOM
OUTPUT_DIR = "./llama-virality-final"


In [30]:
from huggingface_hub import login
login(new_session=False)

In [31]:
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("text-generation", model="meta-llama/Llama-3.2-1B")

Device set to use cuda:0


In [10]:
# Load model directly
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B")

In [32]:
# =====================================================================
# SECTION 6: LOAD MODEL WITH MEMORY OPTIMIZATIONS
# =====================================================================
print(f"\nLoading model: {MODEL_NAME}...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Set pad token (LLaMA doesn't have one by default)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Use bfloat16 if available, else float16
dtype = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8 else torch.float16
print(f"Using dtype: {dtype}")

model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_NAME,
    num_labels=2,
    device_map="auto",
    torch_dtype=dtype,
    low_cpu_mem_usage=True
)

model.config.pad_token_id = tokenizer.pad_token_id

# Ensure the classification head is in float32 for stability
if hasattr(model, 'score'):
    model.score = model.score.float()
elif hasattr(model, 'classifier'):
    model.classifier = model.classifier.float()


Loading model: meta-llama/Llama-3.2-1B...
Using dtype: torch.float16


Some weights of LlamaForSequenceClassification were not initialized from the model checkpoint at meta-llama/Llama-3.2-1B and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [33]:
# =====================================================================
# SECTION 7: APPLY LoRA FOR EFFICIENT TRAINING
# =====================================================================
print("Applying LoRA for parameter-efficient fine-tuning...")
lora_config = LoraConfig(
    task_type=TaskType.SEQ_CLS,
    r=8,  # rank - reduce to 4 if OOM
    lora_alpha=16,
    lora_dropout=0.1,
    target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
    bias="none",
)

model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

# Enable gradient checkpointing
model.gradient_checkpointing_enable()

Applying LoRA for parameter-efficient fine-tuning...
trainable params: 1,708,032 || all params: 1,237,526,528 || trainable%: 0.1380


In [13]:
# =====================================================================
# SECTION 7: APPLY LoRA FOR EFFICIENT TRAINING
# =====================================================================
print("Applying LoRA for parameter-efficient fine-tuning...")
lora_config = LoraConfig(
    task_type=TaskType.SEQ_CLS,
    r=8,  # rank - reduce to 4 if OOM
    lora_alpha=16,
    lora_dropout=0.1,
    target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
    bias="none",
)

model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

# Enable gradient checkpointing
model.gradient_checkpointing_enable()

Applying LoRA for parameter-efficient fine-tuning...
trainable params: 1,708,032 || all params: 1,237,526,528 || trainable%: 0.1380




In [34]:
# =====================================================================
# SECTION 8: TOKENIZE DATA
# =====================================================================
print("\nTokenizing datasets...")

def tokenize_function(examples):
    return tokenizer(
        examples["text"],
        padding=False,
        truncation=True,
        max_length=MAX_LENGTH,
    )

# Convert to Hugging Face Dataset format
train_dataset = Dataset.from_pandas(train_df[["text", "labels"]])
val_dataset = Dataset.from_pandas(valid_df[["text", "labels"]])
test_dataset = Dataset.from_pandas(test_df[["text", "labels"]])

# Tokenize
tokenized_train = train_dataset.map(tokenize_function, batched=True, remove_columns=["text"])
tokenized_val = val_dataset.map(tokenize_function, batched=True, remove_columns=["text"])
tokenized_test = test_dataset.map(tokenize_function, batched=True, remove_columns=["text"])

print(f"Tokenization complete:")
print(f"  Train: {len(tokenized_train)} samples")
print(f"  Valid: {len(tokenized_val)} samples")
print(f"  Test:  {len(tokenized_test)} samples")



Tokenizing datasets...


Map:   0%|          | 0/93329 [00:00<?, ? examples/s]

Map:   0%|          | 0/23333 [00:00<?, ? examples/s]

Map:   0%|          | 0/23338 [00:00<?, ? examples/s]

Tokenization complete:
  Train: 93329 samples
  Valid: 23333 samples
  Test:  23338 samples


In [36]:
# =====================================================================
# SECTION 9: DATA COLLATOR & METRICS
# =====================================================================
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    probs = torch.softmax(torch.tensor(predictions), dim=1)[:, 1].numpy()
    preds = np.argmax(predictions, axis=1)

    accuracy = accuracy_score(labels, preds)
    precision, recall, f1, _ = precision_recall_fscore_support(
        labels, preds, average='binary', zero_division=0
    )

    try:
        auc = roc_auc_score(labels, probs)
    except:
        auc = 0.0

    return {
        "accuracy": accuracy,
        "precision": precision,
        "recall": recall,
        "f1": f1,
        "auc": auc,
    }

In [37]:
# =====================================================================
# SECTION 10: TRAINING ARGUMENTS (MEMORY-EFFICIENT)
# =====================================================================
# Detect if GPU supports bfloat16 (better than fp16 for training stability)
use_bf16 = torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8

training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,

    # MEMORY OPTIMIZATIONS
    per_device_train_batch_size=2,  # Reduce to 1 if OOM
    per_device_eval_batch_size=4,
    gradient_accumulation_steps=8,  # Effective batch = 2*8 = 16
    bf16=use_bf16,  # Use bfloat16 if available (Ampere GPUs and newer)
    fp16=not use_bf16,  # Fallback to fp16 for older GPUs
    fp16_full_eval=True,  # Use fp16 for evaluation to save memory

    # TRAINING SCHEDULE
    num_train_epochs=3,
    learning_rate=2e-4,  # Higher LR works well with LoRA
    warmup_ratio=0.1,
    weight_decay=0.01,

    # EVALUATION & LOGGING
    eval_strategy="steps",
    eval_steps=200,
    save_strategy="steps",
    save_steps=200,
    save_total_limit=2,
    load_best_model_at_end=True,
    metric_for_best_model="f1",

    logging_steps=100,
    logging_dir=f"{OUTPUT_DIR}/logs",

    # EFFICIENCY
    dataloader_num_workers=2,
    gradient_checkpointing=True,
    optim="adamw_torch",

    # MISC
    report_to="none",
    remove_unused_columns=True,
    seed=RANDOM_SEED,
)

In [38]:
# =====================================================================
# SECTION 11: INITIALIZE TRAINER
# =====================================================================
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train,
    eval_dataset=tokenized_val,
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

  trainer = Trainer(


In [39]:
# =====================================================================
# SECTION 12: TRAIN MODEL
# =====================================================================
print("\n" + "="*70)
print("STARTING TRAINING")
print("="*70)

try:
    trainer.train()
    print("\n✓ Training completed successfully!")
except RuntimeError as e:
    if "out of memory" in str(e):
        print("\n❌ OUT OF MEMORY ERROR")
        print("\nTroubleshooting steps:")
        print("1. Reduce per_device_train_batch_size to 1")
        print("2. Increase gradient_accumulation_steps to 16 or 32")
        print("3. Reduce MAX_LENGTH to 64 or 96")
        print("4. Reduce LoRA rank to r=4")
        print("5. Use fewer target_modules: ['q_proj', 'v_proj']")
        raise
    else:
        raise

The model is already on multiple devices. Skipping the move to device specified in `args`.



STARTING TRAINING




Step,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1,Auc
200,2.952,0.077138,0.989457,0.0,0.0,0.0,0.456954
400,0.794,0.101092,0.989757,0.0,0.0,0.0,0.490831
600,1.0662,0.083713,0.9898,0.0,0.0,0.0,0.547416
800,0.655,0.099779,0.989886,0.0,0.0,0.0,0.579982
1000,0.9735,0.100743,0.989886,0.0,0.0,0.0,0.632213
1200,0.5161,0.108556,0.989886,0.0,0.0,0.0,0.661512
1400,1.0057,0.090792,0.989886,0.0,0.0,0.0,0.693506
1600,0.7581,0.094145,0.989928,0.0,0.0,0.0,0.724878
1800,0.8159,0.080923,0.989971,0.0,0.0,0.0,0.73863
2000,0.6046,0.090267,0.989971,0.0,0.0,0.0,0.7485




Step,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1,Auc
200,2.952,0.077138,0.989457,0.0,0.0,0.0,0.456954
400,0.794,0.101092,0.989757,0.0,0.0,0.0,0.490831
600,1.0662,0.083713,0.9898,0.0,0.0,0.0,0.547416
800,0.655,0.099779,0.989886,0.0,0.0,0.0,0.579982
1000,0.9735,0.100743,0.989886,0.0,0.0,0.0,0.632213
1200,0.5161,0.108556,0.989886,0.0,0.0,0.0,0.661512
1400,1.0057,0.090792,0.989886,0.0,0.0,0.0,0.693506
1600,0.7581,0.094145,0.989928,0.0,0.0,0.0,0.724878
1800,0.8159,0.080923,0.989971,0.0,0.0,0.0,0.73863
2000,0.6046,0.090267,0.989971,0.0,0.0,0.0,0.7485





✓ Training completed successfully!


In [46]:
# =====================================================================
# SECTION 13: EVALUATE ON TEST SET
# =====================================================================
print("\n" + "="*70)
print("EVALUATING ON TEST SET")
print("="*70)

test_results = trainer.evaluate(tokenized_test)
print("\nTest Metrics from Trainer:")
for key, value in test_results.items():
    print(f"  {key}: {value:.4f}")



EVALUATING ON TEST SET



Test Metrics from Trainer:
  eval_loss: 0.0680
  eval_model_preparation_time: 0.0070
  eval_accuracy: 0.9900
  eval_precision: 0.5000
  eval_recall: 0.0043
  eval_f1: 0.0085
  eval_auc: 0.8117
  eval_runtime: 254.3036
  eval_samples_per_second: 91.7720
  eval_steps_per_second: 22.9450


In [47]:

# =====================================================================
# SECTION 14: DETAILED PREDICTIONS & ANALYSIS
# =====================================================================
print("\nGenerating detailed predictions...")
test_preds = trainer.predict(tokenized_test)
test_proba = torch.softmax(torch.tensor(test_preds.predictions), dim=1)[:, 1].cpu().numpy()
test_pred_labels = np.argmax(test_preds.predictions, axis=1)

print("\n" + "="*70)
print("CLASSIFICATION REPORT")
print("="*70)
print(classification_report(test_df["labels"].values, test_pred_labels,
                          target_names=["Non-Viral", "Viral"]))

print("\n" + "="*70)
print("CONFUSION MATRIX")
print("="*70)
cm = confusion_matrix(test_df["labels"].values, test_pred_labels)
print(f"                Predicted")
print(f"              Non-Viral  Viral")
print(f"Actual Non-Viral  {cm[0,0]:6d}   {cm[0,1]:5d}")
print(f"       Viral      {cm[1,0]:6d}   {cm[1,1]:5d}")


Generating detailed predictions...



CLASSIFICATION REPORT
              precision    recall  f1-score   support

   Non-Viral       0.99      1.00      0.99     23105
       Viral       0.50      0.00      0.01       233

    accuracy                           0.99     23338
   macro avg       0.75      0.50      0.50     23338
weighted avg       0.99      0.99      0.99     23338


CONFUSION MATRIX
                Predicted
              Non-Viral  Viral
Actual Non-Viral   23104       1
       Viral         232       1


In [48]:
# =====================================================================
# SECTION 15: FINAL RESULTS SUMMARY
# =====================================================================
llama_results = {
    "model_name": "Llama-3.2-1B Fine-tuned (LoRA)",
    "accuracy": accuracy_score(test_df["labels"].values, test_pred_labels),
    "precision": precision_recall_fscore_support(test_df["labels"].values, test_pred_labels, average='binary')[0],
    "recall": precision_recall_fscore_support(test_df["labels"].values, test_pred_labels, average='binary')[1],
    "f1": precision_recall_fscore_support(test_df["labels"].values, test_pred_labels, average='binary')[2],
    "auc": roc_auc_score(test_df["labels"].values, test_proba),
}

print("\n" + "="*70)
print("FINAL RESULTS SUMMARY")
print("="*70)
for metric, value in llama_results.items():
    if metric != "model_name":
        print(f"{metric.upper():15s}: {value:.4f}")
    else:
        print(f"MODEL: {value}")


FINAL RESULTS SUMMARY
MODEL: Llama-3.2-1B Fine-tuned (LoRA)
ACCURACY       : 0.9900
PRECISION      : 0.5000
RECALL         : 0.0043
F1             : 0.0085
AUC            : 0.8117


In [49]:
# =====================================================================
# SECTION 16: SAVE MODEL
# =====================================================================
print(f"\n{'='*70}")
print(f"SAVING MODEL")
print(f"{'='*70}")
trainer.save_model(OUTPUT_DIR)
tokenizer.save_pretrained(OUTPUT_DIR)
print(f"✓ Model saved to: {OUTPUT_DIR}")

# Save results to file
results_path = f"{OUTPUT_DIR}/test_results.txt"
with open(results_path, 'w') as f:
    f.write("="*70 + "\n")
    f.write("LLAMA-3.2-1B FINE-TUNING RESULTS\n")
    f.write("="*70 + "\n\n")
    for metric, value in llama_results.items():
        if metric != "model_name":
            f.write(f"{metric.upper():15s}: {value:.4f}\n")
        else:
            f.write(f"MODEL: {value}\n")
    f.write("\n" + classification_report(test_df["labels"].values, test_pred_labels,
                                         target_names=["Non-Viral", "Viral"]))

print(f"✓ Results saved to: {results_path}")


SAVING MODEL
✓ Model saved to: ./llama-virality-final
✓ Results saved to: ./llama-virality-final/test_results.txt


In [50]:
# =====================================================================
# SECTION 17: INFERENCE ON RANDOM SAMPLES
# =====================================================================
print("\n" + "="*70)
print("TESTING ON RANDOM SAMPLES")
print("="*70)

# Select 10 random samples from the test set
random_indices = np.random.choice(len(test_df), size=10, replace=False)
random_samples = test_df.iloc[random_indices].copy()

# Prepare samples for prediction
sample_texts = random_samples["text"].tolist()
sample_labels = random_samples["labels"].values

# Tokenize samples
sample_encodings = tokenizer(
    sample_texts,
    padding=True,
    truncation=True,
    max_length=MAX_LENGTH,
    return_tensors="pt"
)

# Move to the same device as model
device = next(model.parameters()).device
sample_encodings = {k: v.to(device) for k, v in sample_encodings.items()}

# Get predictions
model.eval()
with torch.no_grad():
    outputs = model(**sample_encodings)
    logits = outputs.logits
    probs = torch.softmax(logits, dim=1)
    predictions = torch.argmax(logits, dim=1)

# Convert to numpy
probs_np = probs.cpu().numpy()
predictions_np = predictions.cpu().numpy()

# Display results
print("\n" + "="*70)
print("RANDOM SAMPLE PREDICTIONS")
print("="*70)

for i, (text, true_label, pred_label, prob) in enumerate(zip(
    sample_texts, sample_labels, predictions_np, probs_np
), 1):
    viral_prob = prob[1] * 100  # Probability of being viral
    non_viral_prob = prob[0] * 100  # Probability of being non-viral

    print(f"\n{'='*70}")
    print(f"SAMPLE #{i}")
    print(f"{'='*70}")
    print(f"Tweet: {text[:200]}{'...' if len(text) > 200 else ''}")
    print(f"\nTrue Label:      {'🔥 VIRAL' if true_label == 1 else '📊 Non-Viral'}")
    print(f"Predicted Label: {'🔥 VIRAL' if pred_label == 1 else '📊 Non-Viral'}")
    print(f"Correct:         {'✓ YES' if true_label == pred_label else '✗ NO'}")
    print(f"\nProbabilities:")
    print(f"  Non-Viral: {non_viral_prob:5.2f}%")
    print(f"  Viral:     {viral_prob:5.2f}%")

# Calculate accuracy on random samples
random_accuracy = accuracy_score(sample_labels, predictions_np)
print(f"\n{'='*70}")
print(f"RANDOM SAMPLE ACCURACY: {random_accuracy*100:.2f}% ({int(random_accuracy*10)}/10 correct)")
print(f"{'='*70}")


TESTING ON RANDOM SAMPLES

RANDOM SAMPLE PREDICTIONS

SAMPLE #1
Tweet: Royalty 🙌😍 Buy now pay later with afterpay! Available at checkout on orders $35 plus 

:
:
All of our lashes can be worn up to 25 times with proper care ♻️
:
:
:
: 
#biglashes #makeupmaffia #browsford...

True Label:      📊 Non-Viral
Predicted Label: 📊 Non-Viral
Correct:         ✓ YES

Probabilities:
  Non-Viral: 99.97%
  Viral:      0.03%

SAMPLE #2
Tweet: @acj1225 @sanjin_xr @RealEyeman WRONG.

She was forced to become a lesbian in the DLC.

True Label:      📊 Non-Viral
Predicted Label: 📊 Non-Viral
Correct:         ✓ YES

Probabilities:
  Non-Viral: 99.99%
  Viral:      0.01%

SAMPLE #3
Tweet: Diplomacy doesn’t get you far in sports. They only remember the lie you told us. Ron gotta get media training. This ain’t Carolina

True Label:      📊 Non-Viral
Predicted Label: 📊 Non-Viral
Correct:         ✓ YES

Probabilities:
  Non-Viral: 100.00%
  Viral:      0.00%

SAMPLE #4
Tweet: Happy Valentine's Day to Ronniecoln 🧡

In [51]:
# =====================================================================
# SECTION 17: INFERENCE ON RANDOM SAMPLES + PROMPT ENGINEERING
# =====================================================================
print("\n" + "="*70)
print("TESTING ON RANDOM SAMPLES WITH PROMPT ENGINEERING")
print("="*70)

# Select 10 random samples from the test set
random_indices = np.random.choice(len(test_df), size=10, replace=False)
random_samples = test_df.iloc[random_indices].copy()

# Prepare samples for prediction
sample_texts = random_samples["text"].tolist()
sample_labels = random_samples["labels"].values

def predict_virality(texts, model, tokenizer, max_length=MAX_LENGTH):
    """Helper function to predict virality for a list of texts"""
    encodings = tokenizer(
        texts,
        padding=True,
        truncation=True,
        max_length=max_length,
        return_tensors="pt"
    )

    device = next(model.parameters()).device
    encodings = {k: v.to(device) for k, v in encodings.items()}

    model.eval()
    with torch.no_grad():
        outputs = model(**encodings)
        logits = outputs.logits
        probs = torch.softmax(logits, dim=1)
        predictions = torch.argmax(logits, dim=1)

    return predictions.cpu().numpy(), probs.cpu().numpy()

def prompt_engineer_for_virality(tweet):
    """
    Apply viral tweet engineering techniques:
    1. Add emotional hooks
    2. Create urgency/scarcity
    3. Use power words
    4. Add social proof elements
    5. Make it controversial/thought-provoking
    6. Use numbers and statistics
    7. Add call-to-action elements
    """

    # Extract core message (simplified approach)
    core = tweet.strip()

    # Viral engineering strategies (pick based on tweet content)
    strategies = [
        # Strategy 1: Emotional hook + urgency
        f"🚨 BREAKING: {core}\n\nThis changes EVERYTHING. Thread 🧵👇",

        # Strategy 2: Controversy + engagement
        f"Unpopular opinion: {core}\n\nChange my mind. 👇",

        # Strategy 3: Social proof + FOMO
        f"10M+ people are talking about this:\n\n{core}\n\nDon't miss out 🔥",

        # Strategy 4: Question + curiosity gap
        f"Why is nobody talking about this?\n\n{core}\n\nRetweet if you agree 🔄",

        # Strategy 5: Shock value + numbers
        f"97% of people don't know this:\n\n{core}\n\nLet that sink in. 💭",
    ]

    # For this demo, we'll return all 5 variations
    return strategies

# Display original predictions
print("\n" + "="*70)
print("ORIGINAL TWEETS - PREDICTIONS")
print("="*70)

original_preds, original_probs = predict_virality(sample_texts, model, tokenizer)

original_results = []
for i, (text, true_label, pred_label, prob) in enumerate(zip(
    sample_texts, sample_labels, original_preds, original_probs
), 1):
    viral_prob = prob[1] * 100
    original_results.append({
        'text': text,
        'true_label': true_label,
        'pred_label': pred_label,
        'viral_prob': viral_prob
    })

    print(f"\n{'='*70}")
    print(f"SAMPLE #{i}")
    print(f"{'='*70}")
    print(f"Tweet: {text[:150]}{'...' if len(text) > 150 else ''}")
    print(f"True Label:      {'🔥 VIRAL' if true_label == 1 else '📊 Non-Viral'}")
    print(f"Predicted:       {'🔥 VIRAL' if pred_label == 1 else '📊 Non-Viral'}")
    print(f"Viral Probability: {viral_prob:.2f}%")

# Now apply prompt engineering
print("\n\n" + "="*70)
print("PROMPT-ENGINEERED VERSIONS")
print("="*70)

engineered_results = []

for i, (text, result) in enumerate(zip(sample_texts, original_results), 1):
    print(f"\n{'#'*70}")
    print(f"SAMPLE #{i} - VIRAL ENGINEERING EXPERIMENTS")
    print(f"{'#'*70}")

    # Get engineered versions
    engineered_versions = prompt_engineer_for_virality(text)

    # Predict for all versions
    eng_preds, eng_probs = predict_virality(engineered_versions, model, tokenizer)

    print(f"\n📝 ORIGINAL (Viral prob: {result['viral_prob']:.2f}%):")
    print(f"   {text[:120]}{'...' if len(text) > 120 else ''}")

    best_improvement = 0
    best_version = None
    best_prob = result['viral_prob']

    for j, (eng_text, eng_pred, eng_prob) in enumerate(zip(
        engineered_versions, eng_preds, eng_probs
    ), 1):
        viral_prob = eng_prob[1] * 100
        improvement = viral_prob - result['viral_prob']

        print(f"\n🔧 VERSION {j} (Viral prob: {viral_prob:.2f}%, {improvement:+.2f}% change):")
        print(f"   {eng_text[:200]}{'...' if len(eng_text) > 200 else ''}")
        print(f"   Status: {'🔥 VIRAL' if eng_pred == 1 else '📊 Non-Viral'}")

        if viral_prob > best_prob:
            best_prob = viral_prob
            best_improvement = improvement
            best_version = j

    if best_version:
        print(f"\n✨ BEST VERSION: #{best_version} with {best_prob:.2f}% viral probability (+{best_improvement:.2f}%)")
    else:
        print(f"\n⚠️ Original tweet had highest viral probability")

    engineered_results.append({
        'original_prob': result['viral_prob'],
        'best_engineered_prob': best_prob,
        'improvement': best_improvement
    })

# Summary statistics
print("\n\n" + "="*70)
print("PROMPT ENGINEERING SUMMARY")
print("="*70)

avg_original = np.mean([r['original_prob'] for r in engineered_results])
avg_best = np.mean([r['best_engineered_prob'] for r in engineered_results])
avg_improvement = np.mean([r['improvement'] for r in engineered_results])

improved_count = sum(1 for r in engineered_results if r['improvement'] > 0)

print(f"\nAverage Original Viral Probability:    {avg_original:.2f}%")
print(f"Average Best Engineered Probability:   {avg_best:.2f}%")
print(f"Average Improvement:                   {avg_improvement:+.2f}%")
print(f"Samples Improved:                      {improved_count}/10 ({improved_count*10}%)")

max_improvement = max(engineered_results, key=lambda x: x['improvement'])
print(f"\nLargest Improvement:                   +{max_improvement['improvement']:.2f}%")
print(f"  (from {max_improvement['original_prob']:.2f}% to {max_improvement['best_engineered_prob']:.2f}%)")


TESTING ON RANDOM SAMPLES WITH PROMPT ENGINEERING

ORIGINAL TWEETS - PREDICTIONS

SAMPLE #1
Tweet: Is it tiiiiiiime to release the sweet treats... 🧛‍♂️ Muhhhhhaaaaaahhhhhhh! 🎃 everyone deserve to be a kid 💬 No matter how old you get! Soooon 🤗 https:...
True Label:      📊 Non-Viral
Predicted:       📊 Non-Viral
Viral Probability: 0.33%

SAMPLE #2
Tweet: As U.S. COVID deaths near 1 million, advocates press for a memorial day : NPR https://t.co/JwPfUtMSGa
True Label:      📊 Non-Viral
Predicted:       📊 Non-Viral
Viral Probability: 0.02%

SAMPLE #3
Tweet: Donald Trump is on Fox News trying to explain himself. There’s no excuse for his inaction. If he would have acted two weeks earlier thousands of lives...
True Label:      📊 Non-Viral
Predicted:       📊 Non-Viral
Viral Probability: 0.25%

SAMPLE #4
Tweet: Interesting to note that, in Huruma this pastor and his church grabbed a whole section of public road that passed through his church, fenced it turned...
True Label:      📊 Non-Viral
Pred