In [None]:
# %% Cell 1: Imports and Setup
import os
import json
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import DistilBertModel, DistilBertTokenizer  # Using DistilBERT
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
from sklearn.preprocessing import StandardScaler
import joblib
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import sklearn.metrics

# Seed NumPy and PyTorch up front so repeated runs draw the same random numbers.
np.random.seed(42)
torch.manual_seed(42)

# Prefer the GPU whenever PyTorch can see one; otherwise stay on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"[Setup] Using device: {device}")

# Bigger batches on the GPU, small ones on the CPU.
batch_size = 16 if device.type == "cuda" else 4
print(f"[Setup] Using batch size: {batch_size}")


In [None]:
# %% Cell 2: Model and Dataset Definitions

# ------------------------------
# Define the DistilBERT-based model.
# ------------------------------
class DeceptionBERTModel(nn.Module):
    """DistilBERT text encoder fused with an MLP over per-message metadata.

    The first-token hidden state produced by DistilBERT is concatenated with a
    64-dimensional projection of the metadata vector and pushed through a
    feed-forward head that emits a single logit per example. No sigmoid is
    applied inside the model.

    Args:
        distilbert_model_name: checkpoint name or path handed to
            ``DistilBertModel.from_pretrained``.
        metadata_dim: number of metadata features supplied per example.
    """

    def __init__(self, distilbert_model_name="distilbert-base-uncased", metadata_dim=9):
        super().__init__()

        print("[Model] Loading DistilBERT model...")
        self.bert = DistilBertModel.from_pretrained(distilbert_model_name)
        # Width of DistilBERT's hidden states, read from the loaded config.
        self.bert_dim = self.bert.config.dim

        # Metadata branch: metadata_dim -> 128 -> 64.
        meta_hidden, meta_out = 128, 64
        self.metadata_fc = nn.Sequential(
            nn.Linear(metadata_dim, meta_hidden),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(meta_hidden, meta_out),
            nn.ReLU(),
            nn.Dropout(0.2),
        )

        # Fusion head: (text + metadata) -> 512 -> 128 -> 64 -> 1 logit.
        fused_dim = self.bert_dim + meta_out
        self.combined_fc = nn.Sequential(
            nn.Linear(fused_dim, 512),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Linear(512, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(64, 1),
        )
        print("[Model] Model architecture defined.")

    def forward(self, input_ids, attention_mask, metadata):
        """Return one logit per example.

        The text embedding is the hidden state at sequence position 0; it is
        concatenated with the metadata projection along dim 1 before the head.
        """
        encoded = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        text_vec = encoded.last_hidden_state[:, 0, :]
        meta_vec = self.metadata_fc(metadata)
        fused = torch.cat((text_vec, meta_vec), dim=1)
        return self.combined_fc(fused)


# ------------------------------
# Define the custom dataset class.
# ------------------------------
class DeceptionDataset(Dataset):
    """Map-style dataset pairing a tokenized message with its metadata row.

    Args:
        texts: indexable collection of messages; each entry is passed through
            ``str()`` before tokenization.
        metadata: indexable collection of per-message feature rows.
        labels: optional indexable collection of targets; when omitted the
            returned items carry no ``"label"`` key.
        tokenizer: callable invoked with the Hugging Face tokenizer keyword
            arguments used below.
        max_length: length every message is padded / truncated to.
    """

    def __init__(self, texts, metadata, labels=None, tokenizer=None, max_length=64):
        self.texts, self.metadata = texts, metadata
        self.labels = labels  # "is_deceptive" targets (may be None at inference time)
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        encoded = self.tokenizer(
            str(self.texts[idx]),
            add_special_tokens=True,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        # The tokenizer returns a leading batch axis of size 1; drop it.
        sample = {
            "input_ids": encoded["input_ids"].squeeze(0),
            "attention_mask": encoded["attention_mask"].squeeze(0),
            "metadata": torch.tensor(self.metadata[idx], dtype=torch.float),
        }
        if self.labels is None:
            return sample
        sample["label"] = torch.tensor(self.labels[idx], dtype=torch.float)
        return sample

# Status marker for this cell; the tag now follows the "[Cell N]" pattern used by the other cells.
print("[Cell 2] Model and dataset classes loaded.")


In [None]:
# %% Cell 3: Data Loading and Preprocessing Functions

def load_processed_data(data_dir="processed_data_balanced"):
    """Load the processed train / validation / test CSV files.

    The three files (``processed_train_balanced.csv``, ``processed_val.csv``,
    ``processed_test.csv``) are read from ``data_dir`` when that directory
    contains all of them. Otherwise the function falls back to the Kaggle
    input directory ``/kaggle/input/processed-data`` that was previously
    hard-coded, so existing calls keep working there.

    Args:
        data_dir: directory expected to hold the three CSV files.

    Returns:
        Tuple ``(train_df, val_df, test_df)`` of pandas DataFrames.

    Raises:
        FileNotFoundError: if the files are in neither location (raised by
            ``pd.read_csv``).
        KeyError: if the training frame has no ``is_truthful`` column.
    """
    print("[Data] Loading processed data...")

    kaggle_dir = "/kaggle/input/processed-data"
    file_names = {
        "train": "processed_train_balanced.csv",
        "val": "processed_val.csv",
        "test": "processed_test.csv",
    }

    # Previously data_dir was accepted but ignored; use it whenever it is usable.
    has_local_copy = bool(data_dir) and all(
        os.path.isfile(os.path.join(data_dir, name)) for name in file_names.values()
    )
    base_dir = data_dir if has_local_copy else kaggle_dir
    print(f"[Data] Reading CSV files from: {base_dir}")

    train_df = pd.read_csv(os.path.join(base_dir, file_names["train"]))
    val_df = pd.read_csv(os.path.join(base_dir, file_names["val"]))
    test_df = pd.read_csv(os.path.join(base_dir, file_names["test"]))

    print(f"[Data] Train data shape: {train_df.shape}")
    print(f"[Data] Validation data shape: {val_df.shape}")
    print(f"[Data] Test data shape: {test_df.shape}")

    # Report the class balance in terms of the original 'is_truthful' column.
    train_truthful = train_df["is_truthful"].sum()
    train_deceptive = len(train_df) - train_truthful
    print(f"[Data] Original training class balance (truthful): {train_truthful} truthful, {train_deceptive} deceptive")
    print(f"[Data] Truthful percentage: {train_truthful / len(train_df) * 100:.2f}%")

    return train_df, val_df, test_df

def get_metadata_features(df):
    """Return the metadata column names to feed the model.

    Nine base features are always selected; ``sender_is_player``,
    ``prev_msg_truthful`` and ``game_stage`` are appended (in that order) only
    when ``df`` actually has a column of that name.
    """
    always_used = [
        "message_length",
        "word_count",
        "question_count",
        "exclamation_count",
        "has_uncertainty",
        "has_certainty",
        "conversation_length",
        "msg_position_in_convo",
        "position_ratio",
    ]
    optional = ("sender_is_player", "prev_msg_truthful", "game_stage")
    metadata_features = always_used + [name for name in optional if name in df.columns]

    print(f"[Data] Selected metadata features: {metadata_features}")
    return metadata_features

def create_datasets(train_df, val_df, test_df, tokenizer, metadata_features):
    """Build the train / validation / test ``DeceptionDataset`` objects.

    Steps:
      1. Take the text from ``cleaned_message`` (missing values become "").
      2. Derive ``is_deceptive`` labels as ``1 - is_truthful``.
      3. Coerce every metadata column to numeric, filling missing values
         with 0.
      4. Standardize the metadata with a ``StandardScaler`` fitted on the
         training split only, and persist the scaler to
         ``models/bert_full_metadata_scaler.pkl``.

    The input frames are copied first, so the caller's DataFrames are left
    untouched by the dtype coercion.

    Args:
        train_df, val_df, test_df: DataFrames holding ``cleaned_message``,
            ``is_truthful`` and every column named in ``metadata_features``.
        tokenizer: tokenizer forwarded to ``DeceptionDataset``.
        metadata_features: list of metadata column names.

    Returns:
        Tuple ``(train_dataset, val_dataset, test_dataset)``.
    """
    print("[Data] Creating datasets...")

    # Copy so the in-place column rewrites below never leak into the caller's frames.
    train_df, val_df, test_df = train_df.copy(), val_df.copy(), test_df.copy()

    # Get texts from cleaned_message column
    train_texts = train_df["cleaned_message"].fillna("").values
    val_texts = val_df["cleaned_message"].fillna("").values
    test_texts = test_df["cleaned_message"].fillna("").values

    # Convert the original 'is_truthful' to 'is_deceptive':
    # 1 indicates deceptive (minority), 0 indicates truthful.
    train_labels = (1 - train_df["is_truthful"]).values
    val_labels = (1 - val_df["is_truthful"]).values
    test_labels = (1 - test_df["is_truthful"]).values

    # Convert metadata columns to numeric if necessary
    bool_tokens = [True, False, "True", "False"]
    bool_map = {True: 1, False: 0, "True": 1, "False": 0}
    for col in metadata_features:
        for df in (train_df, val_df, test_df):
            if col not in df.columns:
                continue
            if (df[col].dtype == bool) or (df[col].dtype == "object" and df[col].isin(bool_tokens).all()):
                df[col] = df[col].map(bool_map)
            if df[col].dtype == "object":
                try:
                    df[col] = pd.to_numeric(df[col], errors="coerce").fillna(0)
                except Exception as e:
                    print(f"[Data] Converting column {col} to numeric, error: {e}")
                    df[col] = df[col].astype("category").cat.codes
            # Object columns already get fillna(0) above; apply the same policy to
            # numeric columns, because StandardScaler passes NaN through and a single
            # NaN feature would turn the model output (and the loss) into NaN.
            missing = int(df[col].isna().sum())
            if missing:
                print(f"[Data] Warning: column '{col}' has {missing} missing value(s); filling with 0.")
                df[col] = df[col].fillna(0)

    print(f"[Data] Using {len(metadata_features)} metadata features: {metadata_features}")
    train_metadata = train_df[metadata_features].values
    val_metadata = val_df[metadata_features].values
    test_metadata = test_df[metadata_features].values

    # Safety net: if the extracted array is still non-numeric (object dtype),
    # convert each column value by value.
    for i, feature in enumerate(metadata_features):
        for data, name in [(train_metadata, "train"), (val_metadata, "val"), (test_metadata, "test")]:
            if not np.issubdtype(data[:, i].dtype, np.number):
                print(f"[Data] Warning: {name} feature '{feature}' is not numeric; converting.")
                data[:, i] = np.array([float(str(x).replace("True", "1").replace("False", "0")) for x in data[:, i]])

    # Fit on the training split only; validation and test reuse those statistics.
    metadata_scaler = StandardScaler()
    train_metadata = metadata_scaler.fit_transform(train_metadata)
    val_metadata = metadata_scaler.transform(val_metadata)
    test_metadata = metadata_scaler.transform(test_metadata)

    # Make sure the target directory exists so this function also works when
    # it is called without main() having created it.
    os.makedirs("models", exist_ok=True)
    joblib.dump(metadata_scaler, "models/bert_full_metadata_scaler.pkl")
    print("[Data] Metadata scaling complete and scaler saved.")

    train_dataset = DeceptionDataset(train_texts, train_metadata, train_labels, tokenizer)
    val_dataset = DeceptionDataset(val_texts, val_metadata, val_labels, tokenizer)
    test_dataset = DeceptionDataset(test_texts, test_metadata, test_labels, tokenizer)

    print("[Data] Datasets created successfully.")
    return train_dataset, val_dataset, test_dataset

# Status marker for this cell; the tag now follows the "[Cell N]" pattern used by the other cells.
print("[Cell 3] Data loading and preprocessing functions defined.")


In [None]:
# %% Cell 4: Training and Evaluation Functions

def _plot_learning_curves(history, out_path="results/bert_full_learning_curves.png"):
    """Draw per-epoch loss, accuracy and F1 from ``history`` and save the figure to ``out_path``."""
    panels = [
        ("Loss", [("train_loss", "Train"), ("val_loss", "Validation")]),
        ("Accuracy", [("train_acc", "Train"), ("val_acc", "Validation")]),
        ("F1 Score", [("val_f1", "Validation")]),
    ]
    plt.figure(figsize=(12, 4))
    for position, (title, series) in enumerate(panels, start=1):
        plt.subplot(1, 3, position)
        for key, label in series:
            plt.plot(history[key], label=label)
        plt.title(title)
        plt.xlabel("Epoch")
        plt.ylabel(title)
        plt.legend()
    plt.tight_layout()
    plt.savefig(out_path)
    plt.close()


def train_model(model, train_dataset, val_dataset, device, batch_size=16, num_epochs=3, learning_rate=2e-5):
    """Fine-tune ``model`` and return it with the best validation-F1 weights loaded.

    Training uses ``BCEWithLogitsLoss`` with a fixed ``pos_weight`` of 3.0 for
    the deceptive class (label 1), AdamW, and a one-cycle learning-rate
    schedule stepped once per batch. After every epoch the model is evaluated
    on ``val_dataset`` at a 0.5 probability threshold; the checkpoint with the
    highest F1 is written to ``models/bert_full_model_best.pt`` and training
    stops early after 2 epochs without improvement. Learning curves are saved
    to ``results/bert_full_learning_curves.png``.

    Args:
        model: module called as ``model(input_ids, attention_mask, metadata)``
            and returning logits of shape ``[batch, 1]``.
        train_dataset, val_dataset: datasets whose items carry ``input_ids``,
            ``attention_mask``, ``metadata`` and ``label``.
        device: torch device the batches are moved to.
        batch_size: batch size for both loaders.
        num_epochs: maximum number of epochs.
        learning_rate: peak learning rate for the one-cycle schedule.

    Returns:
        ``model`` with the best checkpoint's ``state_dict`` loaded.
    """
    print("[Train] Starting training...")

    best_model_path = "models/bert_full_model_best.pt"
    # Ensure output locations exist even when this is called outside main().
    os.makedirs("models", exist_ok=True)
    os.makedirs("results", exist_ok=True)

    # Class counts for logging. Read the label array directly when the dataset
    # exposes one; iterating the dataset would tokenize every sample just to count.
    raw_labels = getattr(train_dataset, "labels", None)
    if raw_labels is not None:
        train_labels = np.asarray(raw_labels, dtype=float)
    else:
        train_labels = np.array([item["label"].item() for item in train_dataset])
    deceptive_count = np.sum(train_labels)
    truthful_count = len(train_labels) - deceptive_count
    print(f"[Train] Deceptive count: {deceptive_count}, Truthful count: {truthful_count}")

    # Use pos_weight to emphasize the deceptive (minority) class (label 1)
    weight_for_deceptive = 3.0
    criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([weight_for_deceptive]).to(device))
    print(f"[Train] Using loss pos_weight: {weight_for_deceptive} for deceptive class.")

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=0.01)
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer, max_lr=learning_rate, total_steps=len(train_loader) * num_epochs
    )

    # Start below any reachable F1 so the first epoch always writes a checkpoint.
    # With the previous 0.0 start, a run whose validation F1 stayed at 0 never
    # saved anything and the final load failed (or picked up a stale file).
    best_val_f1 = -1.0
    patience = 2
    epochs_no_improve = 0
    history = {"train_loss": [], "train_acc": [], "val_loss": [], "val_acc": [], "val_f1": []}

    for epoch in range(num_epochs):
        model.train()
        train_loss = 0.0
        train_correct = 0
        train_total = 0

        print(f"[Train] Epoch {epoch+1}/{num_epochs} - Training started.")
        for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} (Train)"):
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            metadata = batch["metadata"].to(device)
            # Labels arrive as [batch]; the model emits [batch, 1].
            labels = batch["label"].to(device).unsqueeze(1)

            optimizer.zero_grad()
            outputs = model(input_ids, attention_mask, metadata)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            scheduler.step()

            # Weight the batch-mean loss by batch size so the epoch average is per sample.
            train_loss += loss.item() * input_ids.size(0)
            train_total += input_ids.size(0)
            preds = (torch.sigmoid(outputs) >= 0.5).float()  # using consistent threshold 0.5
            train_correct += (preds == labels).sum().item()

        train_loss = train_loss / train_total
        train_acc = train_correct / train_total
        history["train_loss"].append(train_loss)
        history["train_acc"].append(train_acc)

        print(f"[Train] Epoch {epoch+1} completed. Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}")

        # Validation phase
        model.eval()
        val_loss = 0.0
        val_preds = []
        val_labels_list = []
        print(f"[Val] Starting validation for Epoch {epoch+1}...")
        with torch.no_grad():
            for batch in tqdm(val_loader, desc=f"Epoch {epoch+1}/{num_epochs} (Val)"):
                input_ids = batch["input_ids"].to(device)
                attention_mask = batch["attention_mask"].to(device)
                metadata = batch["metadata"].to(device)
                labels = batch["label"].to(device).unsqueeze(1)
                outputs = model(input_ids, attention_mask, metadata)
                loss = criterion(outputs, labels)
                val_loss += loss.item() * input_ids.size(0)
                preds = (torch.sigmoid(outputs) >= 0.5).float()
                val_preds.extend(preds.cpu().numpy())
                val_labels_list.extend(labels.cpu().numpy())
        val_loss = val_loss / len(val_dataset)
        val_preds = np.array(val_preds).flatten()
        val_labels_list = np.array(val_labels_list).flatten()
        val_acc = accuracy_score(val_labels_list, val_preds)
        # zero_division=0 matches calculate_metrics and avoids a warning when
        # the model predicts no positives in an epoch.
        val_f1 = f1_score(val_labels_list, val_preds, zero_division=0)
        history["val_loss"].append(val_loss)
        history["val_acc"].append(val_acc)
        history["val_f1"].append(val_f1)

        print(f"[Val] Epoch {epoch+1}: Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}, Val F1: {val_f1:.4f}")

        # Early stopping check
        if val_f1 > best_val_f1:
            best_val_f1 = val_f1
            epochs_no_improve = 0
            torch.save(model.state_dict(), best_model_path)
            print(f"[Train] Model improved! Best F1 updated: {best_val_f1:.4f}")
        else:
            epochs_no_improve += 1
            print(f"[Train] No improvement for {epochs_no_improve} epoch(s).")
        if epochs_no_improve >= patience:
            print(f"[Train] Early stopping triggered at epoch {epoch+1}.")
            break

    # Plot learning curves
    print("[Train] Plotting learning curves...")
    _plot_learning_curves(history)

    # map_location keeps the reload working when the checkpoint was written on
    # a different device than the one in use now.
    model.load_state_dict(torch.load(best_model_path, map_location=device))
    print(f"[Train] Loaded best model with validation F1: {best_val_f1:.4f}")
    return model

def evaluate_model(model, test_dataset, device, batch_size=16):
    """Score ``model`` on ``test_dataset`` and persist the results.

    Predictions are thresholded at a probability of 0.5. The per-example
    prediction, probability and true label are written to
    ``results/bert_full_predictions.csv``; the metrics are computed by
    ``calculate_metrics`` and a threshold sweep is run via
    ``calibrate_threshold`` on the same labels and probabilities.

    Returns:
        The metrics dictionary produced by ``calculate_metrics``.
    """
    print("[Evaluate] Evaluating model on test set...")
    loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    model.eval()

    collected_preds, collected_probs, collected_labels = [], [], []

    threshold = 0.5  # same cut-off that is used during training/validation
    print(f"[Evaluate] Using threshold: {threshold}")
    with torch.no_grad():
        for batch in tqdm(loader, desc="Evaluating"):
            ids = batch["input_ids"].to(device)
            mask = batch["attention_mask"].to(device)
            meta = batch["metadata"].to(device)
            targets = batch["label"].to(device).unsqueeze(1)

            probabilities = torch.sigmoid(model(ids, mask, meta))
            decisions = (probabilities >= threshold).float()

            collected_preds.extend(decisions.cpu().numpy())
            collected_probs.extend(probabilities.cpu().numpy())
            collected_labels.extend(targets.cpu().numpy())

    # Each collected entry is a length-1 row; flatten to plain 1-D arrays.
    test_preds = np.array(collected_preds).flatten()
    test_probs = np.array(collected_probs).flatten()
    test_labels_list = np.array(collected_labels).flatten()
    metrics = calculate_metrics(test_labels_list, test_preds, test_probs, threshold)

    pd.DataFrame(
        {
            "prediction": test_preds,
            "probability": test_probs,
            "true_label": test_labels_list,
        }
    ).to_csv("results/bert_full_predictions.csv", index=False)

    calibrate_threshold(test_labels_list, test_probs)
    return metrics

def calculate_metrics(y_true, y_pred, y_prob, threshold=0.5):
    """Compute binary classification metrics and persist them.

    Label 1 is treated as the positive ("Deceptive") class. Besides returning
    the metrics, the function saves a confusion-matrix heatmap to
    ``results/bert_full_confusion_matrix.png`` and a JSON summary to
    ``results/bert_full_metrics.json``.

    Args:
        y_true: array-like of true labels (0 / 1).
        y_pred: array-like of predicted labels (0 / 1).
        y_prob: predicted probabilities; accepted for interface compatibility
            but not used in the computation.
        threshold: decision threshold that produced ``y_pred``; only recorded
            in the JSON file.

    Returns:
        Dict with ``accuracy``, ``precision``, ``recall``, ``f1``,
        ``specificity``, ``sensitivity``, ``balanced_accuracy`` and
        ``confusion_matrix`` (2x2 array laid out as [[tn, fp], [fn, tp]]).
    """
    print("[Metrics] Calculating metrics...")
    acc = accuracy_score(y_true, y_pred)
    precision = precision_score(y_true, y_pred, zero_division=0)
    recall = recall_score(y_true, y_pred, zero_division=0)
    f1 = f1_score(y_true, y_pred, zero_division=0)

    # Pinning the label order always yields a 2x2 matrix, including when only
    # one class occurs in the data, so no manual reshaping is required.
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
    tn, fp, fn, tp = cm.ravel()

    specificity = tn / (tn + fp) if (tn + fp) != 0 else 0
    sensitivity = tp / (tp + fn) if (tp + fn) != 0 else 0
    balanced_acc = (specificity + sensitivity) / 2
    print(f"[Metrics] Accuracy: {acc:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1 Score: {f1:.4f}")
    print(f"[Metrics] Specificity: {specificity:.4f}, Sensitivity: {sensitivity:.4f}")
    print(f"[Metrics] Confusion Matrix:\n{cm}")

    # Make sure the output directory exists when called outside main().
    os.makedirs("results", exist_ok=True)

    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", xticklabels=["Truthful", "Deceptive"], yticklabels=["Truthful", "Deceptive"])
    plt.xlabel("Predicted")
    plt.ylabel("True")
    plt.title("Confusion Matrix")
    plt.tight_layout()
    plt.savefig("results/bert_full_confusion_matrix.png")
    plt.close()

    # Cast to built-in floats / lists: NumPy scalars and arrays are not JSON serializable.
    with open("results/bert_full_metrics.json", "w") as f:
        json.dump({
            "accuracy": float(acc),
            "balanced_accuracy": float(balanced_acc),
            "precision": float(precision),
            "recall": float(recall),
            "f1": float(f1),
            "specificity": float(specificity),
            "sensitivity": float(sensitivity),
            "threshold": float(threshold),
            "confusion_matrix": cm.tolist()
        }, f, indent=4)

    return {
        "accuracy": acc,
        "precision": precision,
        "recall": recall,
        "f1": f1,
        "specificity": specificity,
        "sensitivity": sensitivity,
        "balanced_accuracy": balanced_acc,
        "confusion_matrix": cm
    }

def calibrate_threshold(y_true, y_prob):
    """Pick the ROC threshold that maximizes the geometric mean of TPR and TNR.

    The ROC curve (with the chosen point marked) is saved to
    ``results/bert_full_roc_curve.png`` and the metrics at the chosen
    threshold are written to ``results/bert_full_calibration.json``.

    Args:
        y_true: array of true labels.
        y_prob: array of predicted probabilities (compared element-wise
            against the chosen threshold).

    Returns:
        The selected threshold.
    """
    print("[Calibrate] Calibrating classification threshold...")
    fpr, tpr, thresholds = sklearn.metrics.roc_curve(y_true, y_prob)
    roc_auc = sklearn.metrics.auc(fpr, tpr)

    # G-mean = sqrt(sensitivity * specificity); take the ROC point where it peaks.
    ix = np.argmax(np.sqrt(tpr * (1 - fpr)))
    best_threshold = thresholds[ix]
    y_pred_best = (y_prob >= best_threshold).astype(int)

    cm_best = confusion_matrix(y_true, y_pred_best)
    summary = {
        "best_threshold": float(best_threshold),
        "accuracy": float(accuracy_score(y_true, y_pred_best)),
        "precision": float(precision_score(y_true, y_pred_best, zero_division=0)),
        "recall": float(recall_score(y_true, y_pred_best, zero_division=0)),
        "f1": float(f1_score(y_true, y_pred_best, zero_division=0)),
        "roc_auc": float(roc_auc),
        "confusion_matrix": cm_best.tolist(),
    }

    plt.figure(figsize=(10, 8))
    plt.plot(fpr, tpr, color="darkorange", lw=2, label=f"ROC curve (AUC = {roc_auc:.4f})")
    plt.plot([0, 1], [0, 1], color="navy", lw=2, linestyle="--")
    plt.scatter(fpr[ix], tpr[ix], marker="o", color="black", label=f"Best threshold = {best_threshold:.4f}")
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.title("Receiver Operating Characteristic (ROC) Curve")
    plt.legend(loc="lower right")
    plt.savefig("results/bert_full_roc_curve.png")
    plt.close()

    print(f"[Calibrate] Best threshold: {best_threshold:.4f}")
    print(f"[Calibrate] Accuracy: {summary['accuracy']:.4f}, F1: {summary['f1']:.4f}")
    print(f"[Calibrate] Confusion Matrix with best threshold:\n{cm_best}")

    with open("results/bert_full_calibration.json", "w") as f:
        json.dump(summary, f, indent=4)
    return best_threshold

# Status marker printed once the definitions above in this cell have been executed.
print("[Cell 4] Training and evaluation functions defined.")


In [None]:
# %% Cell 5: Main Routine

def main():
    """Run the full DistilBERT pipeline: load, balance, train, evaluate, save.

    Reads the module-level ``device`` and ``batch_size`` set in the setup
    cell. Artifacts are written below ``models/`` and ``results/``.
    """
    os.makedirs("models", exist_ok=True)
    os.makedirs("results", exist_ok=True)

    print("[Main] Loading processed data...")
    train_df, val_df, test_df = load_processed_data()

    # ---------------------------------------------------
    # Balance the training data by oversampling the deceptive class.
    # Use the original is_truthful field and later invert to get is_deceptive.
    # ---------------------------------------------------
    truthful_df = train_df[train_df["is_truthful"] == 1]
    deceptive_df = train_df[train_df["is_truthful"] == 0]
    n_to_sample = len(truthful_df) - len(deceptive_df)
    if n_to_sample > 0 and len(deceptive_df) > 0:
        print(f"[Main] Oversampling deceptive data: need {n_to_sample} extra samples.")
        oversampled_deceptive = deceptive_df.sample(n_to_sample, replace=True, random_state=42)
        train_df = pd.concat([truthful_df, deceptive_df, oversampled_deceptive])
    else:
        # DataFrame.sample raises for a negative n, and it cannot draw from an
        # empty frame. Both can happen here (the input file is already named
        # "balanced"), so skip oversampling instead of crashing.
        print(f"[Main] Skipping oversampling (truthful={len(truthful_df)}, deceptive={len(deceptive_df)}).")
        train_df = pd.concat([truthful_df, deceptive_df])
    # Shuffle so the appended duplicates are not clustered at the end.
    train_df = train_df.sample(frac=1, random_state=42).reset_index(drop=True)
    print(f"[Main] Balanced train data shape: {train_df.shape}")
    print(f"[Main] Validation data shape: {val_df.shape}")
    print(f"[Main] Test data shape: {test_df.shape}")
    train_truthful = train_df["is_truthful"].sum()
    train_deceptive = len(train_df) - train_truthful
    print(f"[Main] After balancing: {train_truthful} truthful, {train_deceptive} deceptive (original labels)")
    print(f"[Main] Truthful percentage: {train_truthful / len(train_df) * 100:.2f}%")

    print("[Main] Initializing tokenizer...")
    model_name = "distilbert-base-uncased"
    tokenizer = DistilBertTokenizer.from_pretrained(model_name)
    print(f"[Main] Using tokenizer from model: {model_name}")

    metadata_features = get_metadata_features(train_df)
    train_dataset, val_dataset, test_dataset = create_datasets(train_df, val_df, test_df, tokenizer, metadata_features)

    print("[Main] Initializing model...")
    # metadata_dim must match the number of selected features, which varies
    # with the optional columns present in the data.
    model = DeceptionBERTModel(model_name, metadata_dim=len(metadata_features))
    model = model.to(device)
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"[Main] Total parameters: {total_params:,}")
    print(f"[Main] Trainable parameters: {trainable_params:,}")
    print(f"[Main] Percentage trainable: {trainable_params / total_params * 100:.2f}%")

    print("[Main] Starting training process...")
    model = train_model(model, train_dataset, val_dataset, device, batch_size, num_epochs=3)

    print("[Main] Evaluating model on test set...")
    evaluate_model(model, test_dataset, device, batch_size)

    print("[Main] Saving final model and tokenizer...")
    torch.save(model.state_dict(), "models/bert_full_model.pt")
    os.makedirs("models/bert_full_tokenizer", exist_ok=True)
    tokenizer.save_pretrained("models/bert_full_tokenizer")
    print("[Main] Process completed.")

if __name__ == "__main__":
    main()


# **Updated approach based on the results above: RoBERTa encoder with conversation context and focal loss**

In [None]:
# %% Cell 1: Imports and Setup
import os
import json
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import RobertaModel, RobertaTokenizer  # Use RoBERTa for advanced representations
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
from sklearn.preprocessing import StandardScaler
import joblib
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import sklearn.metrics

# Fix the NumPy and PyTorch seeds so the run is repeatable.
np.random.seed(42)
torch.manual_seed(42)

# Device configuration: CUDA when available, CPU otherwise.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"[Setup] Using device: {device}")

# Batch size follows the device: 16 on GPU, 4 on CPU.
batch_size = 16 if device.type == "cuda" else 4
print(f"[Setup] Using batch size: {batch_size}")


In [None]:


# --- Define Focal Loss ---
class FocalLoss(nn.Module):
    """Focal loss for binary classification on raw logits.

    Per element the loss is ``alpha * (1 - p_t) ** gamma * BCE``, where
    ``p_t`` is the probability the model assigns to the true class.

    Args:
        alpha: scalar factor multiplied into the loss of every element
            (positives and negatives alike).
        gamma: focusing exponent; larger values down-weight well-classified
            elements more strongly.
        reduction: ``"mean"`` or ``"sum"``; any other value returns the
            unreduced per-element loss.
    """

    def __init__(self, alpha=1, gamma=2, reduction="mean"):
        super().__init__()
        self.alpha, self.gamma = alpha, gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        """Compute the loss for logits ``inputs`` and float 0/1 ``targets``."""
        probs = torch.sigmoid(inputs)
        # Probability assigned to whichever class is actually correct.
        p_t = targets * probs + (1 - targets) * (1 - probs)
        per_element_bce = nn.functional.binary_cross_entropy_with_logits(
            inputs, targets, reduction="none"
        )
        focal_weight = self.alpha * (1 - p_t) ** self.gamma
        loss = focal_weight * per_element_bce

        if self.reduction == "mean":
            return loss.mean()
        if self.reduction == "sum":
            return loss.sum()
        return loss

# Status marker printed once the FocalLoss class above has been defined.
print("[Cell 2] FocalLoss defined.")

# --- Hierarchical Deception Model ---
class HierarchicalDeceptionModel(nn.Module):
    """RoBERTa encoder over the current message and its preceding context.

    One shared RoBERTa instance encodes the current message and every context
    message. The context embeddings are run through a bidirectional LSTM and
    mean-pooled over the context axis; the result is concatenated with the
    current-message embedding and a 64-d metadata projection, then mapped to
    a single logit.
    """

    def __init__(self, roberta_model_name="roberta-base", metadata_dim=9,
                 hidden_size=768, lstm_hidden_size=256, num_context_msgs=5):
        """
        Args:
            roberta_model_name: checkpoint passed to ``RobertaModel.from_pretrained``.
            metadata_dim: number of metadata features per example.
            hidden_size: accepted but not read here; the encoder width is
                taken from the loaded model's config instead.
            lstm_hidden_size: hidden size per direction of the context LSTM.
            num_context_msgs: stored on the instance; ``forward`` derives the
                actual context count from the input tensor's shape.
        """
        super().__init__()
        print("[Model] Loading RoBERTa model...")
        self.roberta = RobertaModel.from_pretrained(roberta_model_name)
        self.roberta_hidden = self.roberta.config.hidden_size
        self.num_context_msgs = num_context_msgs

        # Aggregates the per-message context embeddings (same encoder as the
        # current message, so no second transformer is created).
        self.context_lstm = nn.LSTM(
            input_size=self.roberta_hidden,
            hidden_size=lstm_hidden_size,
            batch_first=True,
            bidirectional=True,
        )
        # Forward and backward states are concatenated, doubling the width.
        self.lstm_out_dim = 2 * lstm_hidden_size

        # Metadata branch: metadata_dim -> 128 -> 64.
        meta_out = 64
        self.metadata_fc = nn.Sequential(
            nn.Linear(metadata_dim, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, meta_out),
            nn.ReLU(),
            nn.Dropout(0.2),
        )

        # Fusion head over [current message | context summary | metadata].
        fusion_dim = self.roberta_hidden + self.lstm_out_dim + meta_out
        self.fusion_fc = nn.Sequential(
            nn.Linear(fusion_dim, 512),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Linear(512, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(64, 1),
        )

        print("[Model] HierarchicalDeceptionModel initialized.")

    def forward(self, current_input_ids, current_attention_mask,
                context_input_ids, context_attention_mask, metadata):
        """Return one logit per example.

        Args:
            current_input_ids, current_attention_mask: tokens of the current message.
            context_input_ids: [batch, num_context_msgs, seq_len]
            context_attention_mask: [batch, num_context_msgs, seq_len]
            metadata: [batch, metadata_dim]
        """
        # Current message: hidden state at position 0.
        current_embedding = self.roberta(
            input_ids=current_input_ids, attention_mask=current_attention_mask
        ).last_hidden_state[:, 0, :]

        # Fold the context axis into the batch axis so all context messages
        # go through the encoder in a single call, then unfold again.
        n_batch, n_context, seq_len = context_input_ids.shape
        context_encoded = self.roberta(
            input_ids=context_input_ids.view(-1, seq_len),
            attention_mask=context_attention_mask.view(-1, seq_len),
        )
        context_embeddings = context_encoded.last_hidden_state[:, 0, :].view(n_batch, n_context, -1)

        # Summarize the context by averaging the LSTM outputs over the
        # context axis; the final hidden/cell states are not used.
        lstm_out, _ = self.context_lstm(context_embeddings)
        context_rep = lstm_out.mean(dim=1)

        meta_out = self.metadata_fc(metadata)

        fused = torch.cat((current_embedding, context_rep, meta_out), dim=1)
        return self.fusion_fc(fused)

# Status marker printed once the HierarchicalDeceptionModel class above has been defined.
print("[Cell 2] HierarchicalDeceptionModel defined.")

# --- Hierarchical Dataset ---
class HierarchicalDeceptionDataset(Dataset):
    """Dataset that tokenizes a message together with its preceding context messages.

    NOTE(review): this definition appears to be cut off. ``__getitem__`` ends
    with ``context_attention_mask = torch`` (which binds the ``torch`` module
    itself) and has no ``return``, so indexing an instance yields ``None``.
    A class with the same name is defined again in a later cell and replaces
    this one once that cell has run.
    """
    def __init__(self, texts, context_texts, metadata, labels=None, tokenizer=None, 
                 max_length=64, max_context=5, context_delim="||"):
        """
        texts: current message texts.
        context_texts: context field (each entry is a string with previous messages separated by context_delim).
        metadata: metadata features.
        labels: target labels.
        max_length: maximum token length per message.
        max_context: maximum number of context messages to use.
        context_delim: delimiter used in the context field.
        """
        self.texts = texts
        self.context_texts = context_texts
        self.metadata = metadata
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.max_context = max_context
        self.context_delim = context_delim

    def __len__(self):
        # One item per current message.
        return len(self.texts)

    def __getitem__(self, idx):
        # Process current message
        current_text = str(self.texts[idx])
        current_enc = self.tokenizer(
            current_text,
            add_special_tokens=True,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )
        # Drop the leading batch axis of size 1 that return_tensors="pt" adds.
        current_input_ids = current_enc["input_ids"].squeeze(0)
        current_attention_mask = current_enc["attention_mask"].squeeze(0)
        
        # Process context: if context_texts is missing or empty, use a list with one empty string.
        # NOTE(review): a NaN/None entry is stringified by str() to "nan"/"None" and would be
        # treated as a single context message -- confirm the context column has no missing values.
        context_str = str(self.context_texts[idx]) if self.context_texts is not None else ""
        # Split by delimiter. If no context exists, provide empty message(s)
        context_msgs = [msg.strip() for msg in context_str.split(self.context_delim) if msg.strip()]
        # Limit number of context messages
        context_msgs = context_msgs[-self.max_context:]  # use last max_context messages
        # If not enough messages, pad with empty strings
        # (padding is inserted at the front, so real messages stay at the end of the list)
        while len(context_msgs) < self.max_context:
            context_msgs.insert(0, "")
        
        # Tokenize each context message
        context_encodings = [self.tokenizer(
            msg,
            add_special_tokens=True,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        ) for msg in context_msgs]
        # Stack input_ids and attention masks along new dimension: shape [max_context, seq_len]
        context_input_ids = torch.stack([enc["input_ids"].squeeze(0) for enc in context_encodings])
        # NOTE(review): the statement below is incomplete -- it assigns the torch module, not a
        # stacked attention mask, and the method ends here without returning an item.
        context_attention_mask = torch


In [None]:
#Cell 2: HierarchicalDeceptionDataset Definition with Error Handling

class HierarchicalDeceptionDataset(Dataset):
    """Dataset yielding a tokenized message, its tokenized context window, and metadata.

    Each item is a dict with:
      - "current_input_ids" / "current_attention_mask": tokenizer output for the
        current message with its leading dimension squeezed away.
      - "context_input_ids" / "context_attention_mask": tokenizer outputs for exactly
        ``max_context`` context messages, stacked along a new first dimension.
      - "metadata": float tensor built from ``metadata[idx]``.
      - "label": float tensor built from ``labels[idx]`` (only when labels were given).
    """

    def __init__(self, texts, context_texts, metadata, labels=None, tokenizer=None,
                 max_length=64, max_context=5, context_delim="||"):
        """
        texts: current message texts.
        context_texts: context field (each entry is a string with previous messages
            separated by context_delim). May be None, in which case every sample
            gets an all-empty context window.
        metadata: metadata features.
        labels: target labels.
        tokenizer: callable used to tokenize each message; it is invoked with
            add_special_tokens / max_length / padding / truncation / return_tensors.
        max_length: maximum token length per message.
        max_context: maximum number of context messages to use (must be >= 1).
        context_delim: delimiter used in the context field.

        Raises ValueError if max_context is smaller than 1.
        """
        # A window of 0 would make the `[-max_context:]` slice keep *all* messages
        # (``lst[-0:]`` is the whole list), producing ragged or empty stacks.
        if max_context < 1:
            raise ValueError(f"max_context must be >= 1, got {max_context}")
        self.texts = texts
        self.context_texts = context_texts
        self.metadata = metadata
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.max_context = max_context
        self.context_delim = context_delim

    def __len__(self):
        """Return the number of samples (length of ``texts``)."""
        return len(self.texts)

    @staticmethod
    def _as_text(value):
        """Return ``value`` as a string, mapping None and floating-point NaN to ""."""
        if value is None:
            return ""
        # np.floating covers np.float32 etc., which are not subclasses of float.
        if isinstance(value, (float, np.floating)) and np.isnan(value):
            return ""
        return str(value)

    def _encode(self, text):
        """Tokenize one message and return (input_ids, attention_mask) without the leading dim."""
        enc = self.tokenizer(
            text,
            add_special_tokens=True,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )
        return enc["input_ids"].squeeze(0), enc["attention_mask"].squeeze(0)

    def _context_messages(self, idx):
        """Return exactly ``max_context`` context strings for sample ``idx``, left-padded with ""."""
        # Tolerate a dataset built without any context at all.
        raw_entry = None if self.context_texts is None else self.context_texts[idx]
        context_str = self._as_text(raw_entry)
        # Split on the delimiter and drop blank pieces.
        messages = [piece.strip() for piece in context_str.split(self.context_delim) if piece.strip()]
        # Keep the most recent messages; pad on the left so real messages stay last.
        messages = messages[-self.max_context:]
        return [""] * (self.max_context - len(messages)) + messages

    def __getitem__(self, idx):
        """Build the tensor dict for sample ``idx`` (see the class docstring for keys)."""
        current_input_ids, current_attention_mask = self._encode(self._as_text(self.texts[idx]))

        # Empty padding strings are tokenized too, so every sample has the same window size.
        encoded_context = [self._encode(msg) for msg in self._context_messages(idx)]
        context_input_ids = torch.stack([ids for ids, _ in encoded_context])
        context_attention_mask = torch.stack([mask for _, mask in encoded_context])

        # Process metadata (missing values are expected to have been handled in preprocessing)
        meta = torch.tensor(self.metadata[idx], dtype=torch.float)

        item = {
            "current_input_ids": current_input_ids,
            "current_attention_mask": current_attention_mask,
            "context_input_ids": context_input_ids,
            "context_attention_mask": context_attention_mask,
            "metadata": meta,
        }
        if self.labels is not None:
            item["label"] = torch.tensor(self.labels[idx], dtype=torch.float)

        return item

print("[Cell 2 Revised] HierarchicalDeceptionDataset defined with improved error handling.")


In [None]:
# %% Cell 3: Data Loading and Preprocessing Functions

def load_processed_data(data_dir="processed_data_balanced"):
    """Load the processed train/val/test CSVs and print shapes and class balance.

    data_dir: directory expected to contain ``processed_train_balanced.csv``,
        ``processed_val.csv`` and ``processed_test.csv``. It is used when the
        training file is found there; otherwise the function falls back to the
        Kaggle input directory ``/kaggle/input/processed-data`` (the location
        that was previously hard-coded, so existing calls keep working).

    Returns (train_df, val_df, test_df) as pandas DataFrames.
    """
    kaggle_dir = "/kaggle/input/processed-data"
    train_name = "processed_train_balanced.csv"
    # Honour the caller's directory only if it actually holds the data.
    if data_dir and os.path.isfile(os.path.join(data_dir, train_name)):
        source_dir = data_dir
    else:
        source_dir = kaggle_dir

    print("[Data] Loading processed CSV data...")
    train_df = pd.read_csv(f"{source_dir}/{train_name}")
    val_df = pd.read_csv(f"{source_dir}/processed_val.csv")
    test_df = pd.read_csv(f"{source_dir}/processed_test.csv")

    print(f"[Data] Train shape: {train_df.shape}")
    print(f"[Data] Validation shape: {val_df.shape}")
    print(f"[Data] Test shape: {test_df.shape}")

    # Report original class balance (using 'is_truthful')
    train_truthful = train_df["is_truthful"].sum()
    train_deceptive = len(train_df) - train_truthful
    # Guard the percentage against an empty training file.
    truthful_pct = train_truthful / len(train_df) * 100 if len(train_df) else 0.0
    print(f"[Data] Original train balance: {train_truthful} truthful, {train_deceptive} deceptive")
    print(f"[Data] Truthful percentage: {truthful_pct:.2f}%")

    return train_df, val_df, test_df

def get_metadata_features(df):
    """Build the metadata feature list: nine core names plus any optional column found in ``df``."""
    core_features = [
        "message_length", "word_count", "question_count", "exclamation_count",
        "has_uncertainty", "has_certainty", "conversation_length", "msg_position_in_convo", "position_ratio"
    ]
    # Optional extras are appended in this fixed order, only when the frame has them.
    optional_features = ("sender_is_player", "prev_msg_truthful", "game_stage")
    metadata_features = core_features + [name for name in optional_features if name in df.columns]
    print(f"[Data] Metadata features: {metadata_features}")
    return metadata_features

def _metadata_to_numeric(df, metadata_features):
    """Return a numeric copy of ``df[metadata_features]``; ``df`` itself is left untouched."""
    meta = df[metadata_features].copy()
    bool_tokens = [True, False, "True", "False"]
    bool_map = {True: 1, False: 0, "True": 1, "False": 0}
    for col in metadata_features:
        column = meta[col]
        # Boolean columns (real bools, or object columns holding only bool-like tokens) -> 0/1.
        if (column.dtype == bool) or (column.dtype == "object" and column.isin(bool_tokens).all()):
            meta[col] = column.map(bool_map)
        # Anything still object-typed is coerced; unparseable entries become 0.
        if meta[col].dtype == "object":
            try:
                meta[col] = pd.to_numeric(meta[col], errors="coerce").fillna(0)
            except Exception as e:
                print(f"[Data] Converting column {col}: error {e}")
                meta[col] = meta[col].astype("category").cat.codes
    return meta


def create_datasets(train_df, val_df, test_df, tokenizer, metadata_features, max_context=5, max_length=64):
    """
    Create hierarchical datasets.
    Assumes the CSVs have a column "cleaned_message" for current text,
    a "context" column that contains previous messages delimited by "||",
    and an "is_truthful" column (which will be inverted to "is_deceptive").

    train_df / val_df / test_df: input DataFrames; they are read but not modified
        (metadata conversion happens on copies).
    tokenizer: tokenizer passed through to HierarchicalDeceptionDataset.
    metadata_features: column names to use as metadata; a missing column raises KeyError.
    max_context: number of context messages per sample.
    max_length: maximum token length per message (defaults to the previous fixed value, 64).

    Side effect: fits a StandardScaler on the training metadata and saves it to
    "models/rob_focal_metadata_scaler.pkl" (the directory is created if needed).

    Returns (train_dataset, val_dataset, test_dataset).
    """
    print("[Data] Creating hierarchical datasets...")
    frames = (train_df, val_df, test_df)

    # Get texts and context
    train_texts, val_texts, test_texts = (
        df["cleaned_message"].fillna("").values for df in frames
    )

    # For context, if not provided (decided by the training frame), fill with empty strings.
    if "context" in train_df.columns:
        train_context, val_context, test_context = (
            df["context"].fillna("").values for df in frames
        )
    else:
        train_context, val_context, test_context = (
            np.array([""] * len(df)) for df in frames
        )

    # Invert is_truthful to get is_deceptive label: 1 for deceptive.
    train_labels, val_labels, test_labels = (
        (1 - df["is_truthful"]).values for df in frames
    )

    # Convert booleans and objects to numerics on copies so the caller's frames stay intact.
    train_meta_df, val_meta_df, test_meta_df = (
        _metadata_to_numeric(df, metadata_features) for df in frames
    )

    print(f"[Data] Using metadata features: {metadata_features}")
    train_metadata = train_meta_df.values
    val_metadata = val_meta_df.values
    test_metadata = test_meta_df.values

    # Scale metadata: statistics come from the training split only.
    metadata_scaler = StandardScaler()
    train_metadata = metadata_scaler.fit_transform(train_metadata)
    val_metadata = metadata_scaler.transform(val_metadata)
    test_metadata = metadata_scaler.transform(test_metadata)
    scaler_path = "models/rob_focal_metadata_scaler.pkl"
    # joblib.dump does not create parent directories.
    os.makedirs(os.path.dirname(scaler_path), exist_ok=True)
    joblib.dump(metadata_scaler, scaler_path)
    print("[Data] Metadata scaling complete.")

    # Create hierarchical datasets
    train_dataset = HierarchicalDeceptionDataset(train_texts, train_context, train_metadata, train_labels,
                                                 tokenizer, max_length=max_length, max_context=max_context)
    val_dataset = HierarchicalDeceptionDataset(val_texts, val_context, val_metadata, val_labels,
                                               tokenizer, max_length=max_length, max_context=max_context)
    test_dataset = HierarchicalDeceptionDataset(test_texts, test_context, test_metadata, test_labels,
                                                tokenizer, max_length=max_length, max_context=max_context)
    print("[Data] Hierarchical datasets created.")
    return train_dataset, val_dataset, test_dataset

print("[Cell 3] Data loading and preprocessing functions defined.")


In [None]:
# %% Cell 4: Training and Evaluation Functions

def _move_batch_to_device(batch, device):
    """Move one collated batch to ``device``; returns (model_inputs_tuple, labels)."""
    model_inputs = (
        batch["current_input_ids"].to(device),
        batch["current_attention_mask"].to(device),
        batch["context_input_ids"].to(device),
        batch["context_attention_mask"].to(device),
        batch["metadata"].to(device),
    )
    # unsqueeze(1) adds a trailing dimension to the labels before they are
    # compared with the model outputs.
    labels = batch["label"].to(device).unsqueeze(1)
    return model_inputs, labels


def _save_learning_curves(history, path):
    """Plot loss, accuracy and validation macro-F1 per epoch and save the figure to ``path``."""
    fig, (loss_ax, acc_ax, f1_ax) = plt.subplots(1, 3, figsize=(12, 4))

    loss_ax.plot(history["train_loss"], label="Train Loss")
    loss_ax.plot(history["val_loss"], label="Val Loss")
    loss_ax.set_xlabel("Epoch")
    loss_ax.set_ylabel("Loss")
    loss_ax.set_title("Loss")
    loss_ax.legend()

    acc_ax.plot(history["train_acc"], label="Train Acc")
    acc_ax.plot(history["val_acc"], label="Val Acc")
    acc_ax.set_xlabel("Epoch")
    acc_ax.set_ylabel("Accuracy")
    acc_ax.set_title("Accuracy")
    acc_ax.legend()

    f1_ax.plot(history["val_macro_f1"], label="Val Macro-F1")
    f1_ax.set_xlabel("Epoch")
    f1_ax.set_ylabel("Macro-F1")
    f1_ax.set_title("Macro F1 Score")
    f1_ax.legend()

    fig.tight_layout()
    fig.savefig(path)
    plt.close(fig)


def train_model(model, train_dataset, val_dataset, device, batch_size=16, num_epochs=3, learning_rate=2e-5,
                patience=2, checkpoint_path="models/hierarchical_model_best_1.pt",
                curves_path="results/hierarchical_learning_curves_1.png"):
    """Train ``model`` with focal loss and early stopping on validation macro-F1.

    model: module called as model(current_ids, current_mask, context_ids, context_mask, metadata);
        its output is passed directly to the loss and to torch.sigmoid.
    train_dataset / val_dataset: datasets whose items carry the keys read in
        ``_move_batch_to_device`` (including "label").
    device: torch device the batches are moved to.
    batch_size, num_epochs, learning_rate: usual training settings.
    patience: number of consecutive non-improving epochs before stopping early.
    checkpoint_path: where the best state_dict is written and later reloaded from.
    curves_path: where the learning-curve figure is saved.

    Returns the model with the best checkpoint's weights loaded.
    """
    print("[Train] Starting training process...")

    # torch.save / savefig do not create parent directories.
    for out_path in (checkpoint_path, curves_path):
        os.makedirs(os.path.dirname(out_path) or ".", exist_ok=True)

    # Use focal loss with alpha to emphasize positive (deceptive) examples.
    focal_loss = FocalLoss(alpha=3.0, gamma=2, reduction="mean")
    print("[Train] Using FocalLoss with alpha=3.0 and gamma=2.")

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=0.01)
    # The scheduler is stepped once per batch, so it needs the total batch count.
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer, max_lr=learning_rate, total_steps=len(train_loader)*num_epochs
    )

    # Start below any reachable score so the first epoch always writes a checkpoint;
    # otherwise a run that never beats 0.0 would reload a stale or missing file.
    best_val_macro_f1 = float("-inf")
    epochs_no_improve = 0
    history = {"train_loss": [], "train_acc": [], "val_loss": [], "val_acc": [], "val_macro_f1": []}

    for epoch in range(num_epochs):
        model.train()
        epoch_loss = 0.0
        correct = 0
        total = 0
        print(f"[Train] Epoch {epoch+1}/{num_epochs} started.")
        for batch in tqdm(train_loader, desc=f"Epoch {epoch+1} Train"):
            model_inputs, labels = _move_batch_to_device(batch, device)
            n_samples = labels.size(0)

            optimizer.zero_grad()
            outputs = model(*model_inputs)
            loss = focal_loss(outputs, labels)
            loss.backward()
            optimizer.step()
            scheduler.step()

            # Weight the batch-mean loss by batch size so the epoch average is per-sample.
            epoch_loss += loss.item() * n_samples
            total += n_samples
            preds = (torch.sigmoid(outputs) >= 0.5).float()
            correct += (preds == labels).sum().item()

        avg_loss = epoch_loss / total
        train_acc = correct / total
        history["train_loss"].append(avg_loss)
        history["train_acc"].append(train_acc)
        print(f"[Train] Epoch {epoch+1} finished. Avg Loss: {avg_loss:.4f}, Train Acc: {train_acc:.4f}")

        # Validation
        model.eval()
        val_loss = 0.0
        all_preds, all_labels = [], []
        with torch.no_grad():
            for batch in tqdm(val_loader, desc=f"Epoch {epoch+1} Validation"):
                model_inputs, labels = _move_batch_to_device(batch, device)

                outputs = model(*model_inputs)
                loss = focal_loss(outputs, labels)
                val_loss += loss.item() * labels.size(0)
                preds = (torch.sigmoid(outputs) >= 0.5).float()
                all_preds.extend(preds.cpu().numpy().flatten())
                all_labels.extend(labels.cpu().numpy().flatten())

        avg_val_loss = val_loss / len(val_dataset)
        val_acc = accuracy_score(all_labels, all_preds)
        # Compute macro-F1 using average="macro"
        val_macro_f1 = f1_score(all_labels, all_preds, average="macro", zero_division=0)
        history["val_loss"].append(avg_val_loss)
        history["val_acc"].append(val_acc)
        history["val_macro_f1"].append(val_macro_f1)
        print(f"[Validation] Epoch {epoch+1}: Loss: {avg_val_loss:.4f}, Acc: {val_acc:.4f}, Macro-F1: {val_macro_f1:.4f}")

        if val_macro_f1 > best_val_macro_f1:
            best_val_macro_f1 = val_macro_f1
            epochs_no_improve = 0
            torch.save(model.state_dict(), checkpoint_path)
            print(f"[Train] Model improved! Best Macro-F1: {best_val_macro_f1:.4f}")
        else:
            epochs_no_improve += 1
            print(f"[Train] No improvement for {epochs_no_improve} epoch(s).")
        if epochs_no_improve >= patience:
            print("[Train] Early stopping triggered.")
            break

    # Plot training curves
    _save_learning_curves(history, curves_path)

    # map_location keeps the reload working when the checkpoint's device differs from `device`.
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    print(f"[Train] Loaded best model with Validation Macro-F1: {best_val_macro_f1:.4f}")
    return model

def evaluate_model(model, test_dataset, device, batch_size=16, threshold=0.5, output_dir="results"):
    """Score ``model`` on ``test_dataset``, save artefacts, and return the metrics dict.

    model: module called as model(current_ids, current_mask, context_ids, context_mask, metadata);
        its output is passed through torch.sigmoid to obtain probabilities.
    test_dataset: dataset whose items include a "label" entry.
    device: torch device the batches are moved to.
    batch_size: evaluation batch size.
    threshold: probability at or above which a sample is predicted deceptive (1).
    output_dir: directory receiving the confusion-matrix PNG, the predictions CSV
        and the metrics JSON (created if missing).

    Returns a dict with accuracy, precision, recall, f1, macro_f1 and the
    2x2 confusion matrix (rows = true class, columns = predicted class).
    """
    print("[Evaluate] Evaluating on test set...")
    os.makedirs(output_dir, exist_ok=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    model.eval()
    all_preds, all_labels, all_probs = [], [], []
    with torch.no_grad():
        for batch in tqdm(test_loader, desc="Evaluating"):
            current_ids = batch["current_input_ids"].to(device)
            current_mask = batch["current_attention_mask"].to(device)
            context_ids = batch["context_input_ids"].to(device)
            context_mask = batch["context_attention_mask"].to(device)
            metadata = batch["metadata"].to(device)
            labels = batch["label"].to(device).unsqueeze(1)
            outputs = model(current_ids, current_mask, context_ids, context_mask, metadata)
            probs = torch.sigmoid(outputs)
            preds = (probs >= threshold).float()
            all_preds.extend(preds.cpu().numpy().flatten())
            all_labels.extend(labels.cpu().numpy().flatten())
            all_probs.extend(probs.cpu().numpy().flatten())
    all_preds = np.array(all_preds)
    all_labels = np.array(all_labels)
    all_probs = np.array(all_probs)

    # Compute metrics (plain floats so they serialise cleanly to JSON).
    acc = float(accuracy_score(all_labels, all_preds))
    precision = float(precision_score(all_labels, all_preds, zero_division=0))
    recall = float(recall_score(all_labels, all_preds, zero_division=0))
    f1 = float(f1_score(all_labels, all_preds, zero_division=0))
    macro_f1 = float(f1_score(all_labels, all_preds, average="macro", zero_division=0))
    # Pin both classes so the matrix stays 2x2 even if only one class occurs;
    # otherwise it would not line up with the two tick labels below.
    cm = confusion_matrix(all_labels, all_preds, labels=[0.0, 1.0])

    print("[Evaluate] Metrics on test set:")
    print(f"  Accuracy: {acc:.4f}")
    print(f"  Precision: {precision:.4f}")
    print(f"  Recall: {recall:.4f}")
    print(f"  F1 Score: {f1:.4f}")
    print(f"  Macro F1 Score: {macro_f1:.4f}")
    print(f"  Confusion Matrix:\n{cm}")

    plt.figure(figsize=(8,6))
    sns.heatmap(cm, annot=True, fmt="d", cmap="Blues",
                xticklabels=["Truthful", "Deceptive"],
                yticklabels=["Truthful", "Deceptive"])
    plt.xlabel("Predicted")
    plt.ylabel("True")
    plt.title("Test Confusion Matrix")
    plt.tight_layout()
    plt.savefig(f"{output_dir}/hierarchical_confusion_matrix_1.png")
    plt.close()

    # Save predictions to CSV
    df_preds = pd.DataFrame({"prediction": all_preds, "probability": all_probs, "true_label": all_labels})
    df_preds.to_csv(f"{output_dir}/hierarchical_predictions_1.csv", index=False)

    # Save overall metrics as JSON
    metrics = {
        "accuracy": acc,
        "precision": precision,
        "recall": recall,
        "f1": f1,
        "macro_f1": macro_f1,
        "confusion_matrix": cm.tolist()
    }
    with open(f"{output_dir}/hierarchical_metrics_1.json", "w") as f:
        json.dump(metrics, f, indent=4)

    return metrics

print("[Cell 4] Training and evaluation functions defined.")


In [None]:
# %% Cell 5: Main Routine

def main():
    """Run the full pipeline: load data, balance, build datasets, train, evaluate, save.

    Reads the module-level ``device`` and ``batch_size`` set in the setup cell.
    Writes checkpoints/tokenizer under "models/" and evaluation artefacts under "results/".
    """
    # Imported here because the notebook's first import cell only brings in the
    # DistilBERT classes; this keeps main() runnable on a fresh kernel.
    from transformers import RobertaTokenizer

    os.makedirs("models", exist_ok=True)
    os.makedirs("results", exist_ok=True)

    print("[Main] Loading processed data...")
    train_df, val_df, test_df = load_processed_data()

    # Balance training data based on original is_truthful and then invert label.
    truthful_df = train_df[train_df["is_truthful"] == 1]
    deceptive_df = train_df[train_df["is_truthful"] == 0]
    n_to_sample = len(truthful_df) - len(deceptive_df)
    parts = [truthful_df, deceptive_df]
    # DataFrame.sample rejects a negative count and cannot draw from an empty frame,
    # so only oversample when the deceptive class is a non-empty minority.
    if n_to_sample > 0 and len(deceptive_df) > 0:
        print(f"[Main] Oversampling deceptive class: adding {n_to_sample} samples.")
        parts.append(deceptive_df.sample(n_to_sample, replace=True, random_state=42))
    else:
        print("[Main] No oversampling performed.")
    train_df = pd.concat(parts)
    train_df = train_df.sample(frac=1, random_state=42).reset_index(drop=True)
    print(f"[Main] Balanced train data shape: {train_df.shape}")
    print(f"[Main] Validation data shape: {val_df.shape}")
    print(f"[Main] Test data shape: {test_df.shape}")
    train_truthful = train_df["is_truthful"].sum()
    train_deceptive = len(train_df) - train_truthful
    print(f"[Main] Original labels: {train_truthful} truthful, {train_deceptive} deceptive")
    print(f"[Main] Truthful percentage: {train_truthful / len(train_df) * 100:.2f}%")

    # Initialize tokenizer from RoBERTa
    print("[Main] Initializing tokenizer...")
    model_name = "roberta-base"
    tokenizer = RobertaTokenizer.from_pretrained(model_name)
    print(f"[Main] Tokenizer loaded from {model_name}")

    # Get metadata feature list
    metadata_features = get_metadata_features(train_df)

    # One value drives both the dataset window and the model so they cannot drift apart.
    max_context = 5

    # Create hierarchical datasets.
    # (Ensure that your CSVs include a "context" column; if not, context will default to empty strings.)
    train_dataset, val_dataset, test_dataset = create_datasets(
        train_df, val_df, test_df, tokenizer, metadata_features, max_context=max_context
    )

    # Initialize Hierarchical Model
    print("[Main] Initializing Hierarchical Deception Model...")
    model = HierarchicalDeceptionModel(
        roberta_model_name=model_name, metadata_dim=len(metadata_features), num_context_msgs=max_context
    )
    model = model.to(device)
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"[Main] Total parameters: {total_params:,}")
    print(f"[Main] Trainable parameters: {trainable_params:,} ({trainable_params/total_params*100:.2f}%)")

    # Train model
    print("[Main] Starting training...")
    model = train_model(model, train_dataset, val_dataset, device, batch_size=batch_size, num_epochs=5, learning_rate=2e-5)

    # Evaluate on test set
    print("[Main] Evaluating model on test set...")
    test_metrics = evaluate_model(model, test_dataset, device, batch_size=batch_size)
    print("[Main] Test Metrics:")
    for k, v in test_metrics.items():
        print(f"  {k}: {v}")

    # Save final model and tokenizer
    print("[Main] Saving final model and tokenizer...")
    torch.save(model.state_dict(), "models/hierarchical_model_1.pt")
    os.makedirs("models/hierarchical_tokenizer_1", exist_ok=True)
    tokenizer.save_pretrained("models/hierarchical_tokenizer_1")
    print("[Main] Process completed.")
    
if __name__ == "__main__":
    main()


In [None]:
import os
import torch
import pandas as pd
from torch.utils.data import DataLoader
from transformers import RobertaTokenizer
from sklearn.preprocessing import StandardScaler
import joblib

# Stand-alone inference cell: reload the best checkpoint, score the test set,
# save the predictions, and (when labels exist) print test metrics.
# NOTE(review): this cell never assigns `model`, `test_dataset` or `test_df`; it
# only runs if they already exist as globals in the kernel. `main()` keeps its
# own copies as locals, so confirm where these are meant to come from.
metadata_dim = 12  # match the checkpoint (only referenced by the commented-out constructor below)
# model = HierarchicalDeceptionModel(metadata_dim=metadata_dim, max_context=max_context).to(device)


# Device configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# NOTE(review): TOKENIZER_PATH is not used anywhere in this cell.
TOKENIZER_PATH = "/kaggle/working/models/hierarchical_tokenizer"

# NOTE(review): TEST_CSV is assigned again (same value) in the constants group below.
TEST_CSV = "/kaggle/input/processed-data/processed_test.csv"  # Adjust this if your test file is differently named

# Paths and settings for inference. Within this cell only MODEL_PATH and
# BATCH_SIZE are actually read; the others are defined but unused here.
MODEL_PATH    = "models/hierarchical_model_best_1.pt"
SCALER_PATH   = "models/rob_focal_metadata_scaler.pkl"
TEST_CSV      = "/kaggle/input/processed-data/processed_test.csv"
TOKENIZER_NAME= "roberta-base"
BATCH_SIZE    = 16
MAX_CONTEXT   = 5

# 5) Load model
print("[Inference] Loading model checkpoint...")
# The constructor call is commented out, so the weights are loaded into whatever
# `model` object is already present in the kernel.
# model = HierarchicalDeceptionModel(
#     # TOKENIZER_NAME="roberta-base",
#     metadata_dim=len(metadata_features),
#     max_context=MAX_CONTEXT,
#     lstm_hidden=256
# ).to(device)
model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
model.eval()

# 6) Run inference
print("[Inference] Running inference...")
# shuffle=False keeps predictions in dataset order so they can be attached to test_df below.
loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
all_probs, all_preds = [], []

with torch.no_grad():
    for batch in loader:
        curr_ids   = batch["current_input_ids"].to(device)
        curr_mask  = batch["current_attention_mask"].to(device)
        ctx_ids    = batch["context_input_ids"].to(device)
        ctx_mask   = batch["context_attention_mask"].to(device)
        metadata   = batch["metadata"].to(device)

        # NOTE(review): this unpacks a 2-tuple from the model, whereas train_model /
        # evaluate_model use the model's return value directly as logits — confirm
        # which model class this cell is meant to run against.
        logits, _ = model(curr_ids, curr_mask, ctx_ids, ctx_mask, metadata)
        probs = torch.sigmoid(logits).squeeze(1)
        preds = (probs >= 0.5).long()

        all_probs.extend(probs.cpu().tolist())
        all_preds.extend(preds.cpu().tolist())

# 7) Save predictions
print("[Inference] Saving predictions...")
# Assumes test_df has exactly one row per item of test_dataset, in the same order — TODO confirm.
test_df["predicted_prob"]  = all_probs
test_df["predicted_label"] = all_preds
os.makedirs("results", exist_ok=True)
out_csv = "results/test_predictions.csv"
test_df.to_csv(out_csv, index=False)
print(f"[Inference] Predictions saved to {out_csv}")

# 8) (Optional) Compute & print test metrics if true labels exist
# NOTE(review): the metric functions used below are not imported by this cell's own
# import lines; they must already be in scope from an earlier cell.
if "is_truthful" in test_df.columns:
    # Invert is_truthful so that 1 means deceptive, matching the predicted labels.
    y_true = (1 - test_df["is_truthful"].values).astype(int)
    y_pred = test_df["predicted_label"].values
    acc      = accuracy_score(y_true, y_pred)
    prec     = precision_score(y_true, y_pred, zero_division=0)
    rec      = recall_score(y_true, y_pred, zero_division=0)
    f1       = f1_score(y_true, y_pred, zero_division=0)
    macro_f1 = f1_score(y_true, y_pred, average="macro", zero_division=0)
    cm       = confusion_matrix(y_true, y_pred)
    print(f"\n[Test Metrics] Acc: {acc:.4f} | Prec: {prec:.4f} | Rec: {rec:.4f} | F1: {f1:.4f} | Macro-F1: {macro_f1:.4f}")
    print("Confusion Matrix:\n", cm)

# *Incorporating richer context using hierarchical attention, refining metadata (especially power-dynamic features), experimenting with alternative transformer architectures, and further tuning or augmenting the loss function (e.g., with advanced focal loss or other cost-sensitive methods).*

In [None]:
# %% Cell 1: Imports, Setup, and Hyperparameters
import os
import json
import joblib
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import RobertaModel, RobertaTokenizer, AutoTokenizer
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm

# Seed torch first, then NumPy, so repeated runs are comparable.
torch.manual_seed(42)
np.random.seed(42)

# Pick the GPU when torch can see one, otherwise the CPU; batch size follows the device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"[Setup] Using device: {device}")

if device.type == "cuda":
    BATCH_SIZE = 16
else:
    BATCH_SIZE = 4
print(f"[Setup] Using batch size: {BATCH_SIZE}")

# Hyperparameters
# MAX_LENGTH: maximum token length for each message
# MAX_CONTEXT: maximum number of context messages per sample
MAX_LENGTH, MAX_CONTEXT = 64, 5
LEARNING_RATE, NUM_EPOCHS = 2e-5, 5
FGM_EPSILON = 1.0
FOCAL_ALPHA, FOCAL_GAMMA = 3.0, 2

# Transformer checkpoint name; swap for another model if desired.
TRANSFORMER_MODEL_NAME = "roberta-base"

# Input CSV locations (adjust as needed)
TRAIN_CSV_PATH = "/kaggle/input/processed-data/processed_train_balanced.csv"
VAL_CSV_PATH = "/kaggle/input/processed-data/processed_val.csv"
TEST_CSV_PATH = "/kaggle/input/processed-data/processed_test.csv"

# Output directories for models and results are created up front.
os.makedirs("models", exist_ok=True)
os.makedirs("results", exist_ok=True)
print("[Setup] Hyperparameters and paths are set.")

# %% Cell 2: Data Loading, Preprocessing, and Balancing
def load_data(train_path, val_path, test_path):
    """Read the train, validation and test CSVs and return them as a 3-tuple of DataFrames."""
    print("[Data] Loading CSV files...")
    # Read in train -> val -> test order, one frame per path.
    train_df, val_df, test_df = [pd.read_csv(csv_path) for csv_path in (train_path, val_path, test_path)]
    print(f"[Data] Train shape: {train_df.shape}, Val shape: {val_df.shape}, Test shape: {test_df.shape}")
    return train_df, val_df, test_df

def get_metadata_feature_names(df):
    """Return the metadata column names: nine fixed features plus optional ones present in ``df``."""
    # Fixed feature set (adjust as needed)
    metadata_features = [
        "message_length", "word_count", "question_count", "exclamation_count",
        "has_uncertainty", "has_certainty", "conversation_length", "msg_position_in_convo", "position_ratio"
    ]
    # Optional columns are appended, in this order, only when the frame provides them.
    optional_cols = ("sender_is_player", "prev_msg_truthful", "game_stage")
    metadata_features.extend(name for name in optional_cols if name in df.columns)
    print(f"[Data] Using metadata features: {metadata_features}")
    return metadata_features

def preprocess_data(df, metadata_features):
    """Fill missing text fields and make object-typed metadata numeric, modifying ``df`` in place.

    Missing "cleaned_message" / "context" values become ""; a "context" column is
    created (all "") when absent. Object-typed metadata columns are coerced with
    ``pd.to_numeric`` and unparseable entries become 0. Returns the same ``df``.
    """
    # Text fields: replace missing values with empty strings.
    df["cleaned_message"] = df["cleaned_message"].fillna("")
    df["context"] = df["context"].fillna("") if "context" in df.columns else ""

    # Metadata: only columns that exist and are still object-typed need converting.
    for name in metadata_features:
        if name not in df.columns or df[name].dtype != object:
            continue
        try:
            numeric = pd.to_numeric(df[name], errors="coerce")
            df[name] = numeric.fillna(0)
        except Exception as e:
            print(f"[Data] Conversion error for {name}: {e}")
            df[name] = df[name].astype("category").cat.codes
    return df

def process_labels(df):
    """Add an "is_deceptive" column to ``df`` in place and return ``df``.

    The column is ``1 - is_truthful`` (1 = deceptive, 0 = truthful) when
    "is_truthful" exists, otherwise the placeholder value -1.
    """
    has_truth_column = "is_truthful" in df.columns
    df["is_deceptive"] = (1 - df["is_truthful"]) if has_truth_column else -1
    return df

def balance_training_data(train_df):
    """Oversample deceptive rows up to the truthful count, then shuffle and reindex.

    Rows are split on ``is_truthful`` (1 = truthful, 0 = deceptive). When truthful
    rows outnumber deceptive ones, the shortfall is drawn from the deceptive rows
    with replacement (random_state=42). The result is always shuffled
    (random_state=42) and given a fresh 0..n-1 index.
    """
    truthful_df = train_df[train_df["is_truthful"].eq(1)]
    deceptive_df = train_df[train_df["is_truthful"].eq(0)]
    shortfall = len(truthful_df) - len(deceptive_df)

    if shortfall <= 0:
        print("[Balance] No oversampling needed.")
    else:
        print(f"[Balance] Oversampling deceptive class by {shortfall} samples...")
        extra_deceptive = deceptive_df.sample(shortfall, replace=True, random_state=42)
        train_df = pd.concat([truthful_df, deceptive_df, extra_deceptive])

    shuffled = train_df.sample(frac=1, random_state=42)
    return shuffled.reset_index(drop=True)

def scale_metadata(train_df, val_df, test_df, metadata_features):
    """Standardise the metadata columns of the three splits and persist the scaler.

    Rows with a missing value in any metadata feature are dropped before scaling,
    so a returned array can have fewer rows than its input frame; a warning is
    printed whenever that happens because the array then no longer lines up
    row-for-row with the caller's DataFrame.

    The StandardScaler is fitted on the training split only and saved to
    "models/metadata_scaler_2.pkl" (the directory is created if needed).

    Raises AssertionError if any frame lacks one of ``metadata_features``.
    Returns (train_meta, val_meta, test_meta) as scaled arrays.
    """
    frames = {"train": train_df, "val": val_df, "test": test_df}

    # Validate first: dropna(subset=...) would otherwise fail with a bare KeyError
    # before this check could ever report the problem.
    for df in frames.values():
        assert set(metadata_features).issubset(df.columns), "Some metadata features are missing."

    complete = {}
    for split_name, df in frames.items():
        kept = df.dropna(subset=metadata_features)
        n_dropped = len(df) - len(kept)
        if n_dropped:
            print(f"[Data] Warning: dropped {n_dropped} {split_name} row(s) with missing metadata; "
                  "the scaled array has fewer rows than the input frame.")
        complete[split_name] = kept

    scaler = StandardScaler()
    train_meta = scaler.fit_transform(complete["train"][metadata_features].values)
    val_meta   = scaler.transform(complete["val"][metadata_features].values)
    test_meta  = scaler.transform(complete["test"][metadata_features].values)

    # joblib.dump does not create parent directories.
    os.makedirs("models", exist_ok=True)
    joblib.dump(scaler, "models/metadata_scaler_2.pkl")
    print("[Data] Metadata scaler saved to models/metadata_scaler_2.pkl")

    return train_meta, val_meta, test_meta

# %% Cell 3: Model and Dataset Definitions
# Hierarchical Attention Module
class AttentionLayer(nn.Module):
    """Additive pooling over the context axis using a single learned scoring layer."""

    def __init__(self, input_dim):
        super().__init__()
        # One scalar score per context position.
        self.attn = nn.Linear(input_dim, 1)

    def forward(self, inputs):
        """Pool ``inputs`` of shape [batch, num_context, hidden_dim].

        Returns (context_vector, attn_weights): the weighted sum over dim 1 with
        shape [batch, hidden_dim], and the softmax weights with shape
        [batch, num_context, 1].
        """
        scores = self.attn(inputs)
        # Normalise across the context positions (dim=1), not across features.
        weights = scores.softmax(dim=1)
        pooled = (weights * inputs).sum(dim=1)
        return pooled, weights

# Final Hierarchical Attention Deception Model
class HierarchicalAttentionDeceptionModel(nn.Module):
    """Binary classifier fusing a message, its conversation context and metadata.

    The same transformer encodes the current message and every context
    message; context embeddings pass through a bidirectional LSTM and are
    pooled with ``AttentionLayer``. The pooled context, the current-message
    embedding and a 64-d metadata embedding are concatenated and fed to an MLP
    that emits a single logit.
    """

    def __init__(self, transformer_model_name=TRANSFORMER_MODEL_NAME, metadata_dim=9, max_context=5, lstm_hidden=256):
        super().__init__()
        print("[Model] Loading transformer model...")
        self.transformer = RobertaModel.from_pretrained(transformer_model_name)
        # Hidden width of the encoder (e.g. 768 for roberta-base).
        self.transformer_hidden = self.transformer.config.hidden_size
        self.max_context = max_context

        # Sequence model over the per-message context embeddings.
        self.context_lstm = nn.LSTM(
            input_size=self.transformer_hidden,
            hidden_size=lstm_hidden,
            batch_first=True,
            bidirectional=True,
        )
        # Bidirectional => forward and backward states are concatenated.
        self.lstm_output_dim = 2 * lstm_hidden

        # Attention pooling over the LSTM outputs.
        self.attn_layer = AttentionLayer(self.lstm_output_dim)

        # Metadata branch: metadata_dim -> 128 -> 64.
        self.metadata_fc = nn.Sequential(
            nn.Linear(metadata_dim, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Dropout(0.2),
        )

        # Fusion head over [current message | context vector | metadata].
        fusion_dim = self.transformer_hidden + self.lstm_output_dim + 64
        head_layers = []
        in_features = fusion_dim
        for out_features, drop_p in ((512, 0.4), (128, 0.3), (64, 0.3)):
            head_layers.extend([nn.Linear(in_features, out_features), nn.ReLU(), nn.Dropout(drop_p)])
            in_features = out_features
        head_layers.append(nn.Linear(in_features, 1))
        self.fusion_fc = nn.Sequential(*head_layers)
        print("[Model] HierarchicalAttentionDeceptionModel initialized.")

    def forward(self, current_input_ids, current_attention_mask, context_input_ids, context_attention_mask, metadata):
        """Compute the logit and the context attention weights.

        ``context_input_ids`` / ``context_attention_mask`` are 3-D
        ([batch, num_ctx, seq_len]); they are flattened so the transformer
        sees one sequence per row and reshaped back afterwards.

        Returns:
            (output, attn_weights): ``output`` comes from the fusion head
            (last layer has one unit); ``attn_weights`` is what
            ``AttentionLayer`` returns for the LSTM outputs.
        """
        # First-token hidden state of the current message: [batch, hidden].
        current_repr = self.transformer(
            input_ids=current_input_ids,
            attention_mask=current_attention_mask,
        ).last_hidden_state[:, 0, :]

        # Encode every context message in one flattened pass.
        batch_size, num_ctx, seq_len = context_input_ids.shape
        ctx_hidden = self.transformer(
            input_ids=context_input_ids.view(-1, seq_len),
            attention_mask=context_attention_mask.view(-1, seq_len),
        ).last_hidden_state
        ctx_embeddings = ctx_hidden[:, 0, :].view(batch_size, num_ctx, self.transformer_hidden)

        # Contextualise the message embeddings, then attention-pool them.
        lstm_out, _ = self.context_lstm(ctx_embeddings)
        context_vector, attn_weights = self.attn_layer(lstm_out)

        meta_out = self.metadata_fc(metadata)

        fused = torch.cat((current_repr, context_vector, meta_out), dim=1)
        return self.fusion_fc(fused), attn_weights

# Dataset Definition
class HierarchicalDeceptionDataset(Dataset):
    """Dataset yielding a tokenized message, its tokenized context and metadata."""

    def __init__(self, texts, context_texts, metadata, labels=None, tokenizer=None, 
                 max_length=64, max_context=5, context_delim="||"):
        """
        texts: list of current message strings.
        context_texts: list of context strings (each composed of messages separated by context_delim).
        metadata: numpy array of metadata features (rows must match texts length).
        labels: list of target labels (1 for deceptive, 0 for truthful)
        tokenizer: Hugging Face tokenizer.
        max_length: every message is padded/truncated to this many tokens.
        max_context: number of context messages kept per example.
        context_delim: separator between messages inside a context string.
        """
        self.texts = texts
        self.context_texts = context_texts
        self.metadata = metadata
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.max_context = max_context
        self.context_delim = context_delim

    def __len__(self):
        return len(self.texts)

    @staticmethod
    def _as_text(value):
        """Map None / float NaN to an empty string; stringify anything else."""
        if value is None:
            return ""
        if isinstance(value, float) and np.isnan(value):
            return ""
        return str(value)

    def _tokenize(self, text):
        """Tokenize one string to a fixed-length encoding with a leading batch dim of 1."""
        return self.tokenizer(
            text,
            add_special_tokens=True,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )

    def __getitem__(self, idx):
        # Current message.
        message_enc = self._tokenize(self._as_text(self.texts[idx]))

        # Context: split on the delimiter, drop blanks, keep the most recent
        # max_context messages and left-pad with empty strings.
        raw_context = self._as_text(self.context_texts[idx])
        stripped = (piece.strip() for piece in raw_context.split(self.context_delim))
        recent = [piece for piece in stripped if piece][-self.max_context:]
        padded = [""] * (self.max_context - len(recent)) + recent
        context_encs = [self._tokenize(piece) for piece in padded]

        item = {
            "current_input_ids": message_enc["input_ids"].squeeze(0),
            "current_attention_mask": message_enc["attention_mask"].squeeze(0),
            "context_input_ids": torch.stack([enc["input_ids"].squeeze(0) for enc in context_encs]),
            "context_attention_mask": torch.stack([enc["attention_mask"].squeeze(0) for enc in context_encs]),
            # Row idx of the (already scaled) metadata array.
            "metadata": torch.tensor(self.metadata[idx], dtype=torch.float),
        }
        if self.labels is not None:
            item["label"] = torch.tensor(self.labels[idx], dtype=torch.float)
        return item

print("[Cell 3] Model and dataset definitions loaded.")

# %% Cell 4: Adversarial Training (FGM) and Focal Loss
class FGM:
    """Fast Gradient Method adversarial perturbation of embedding weights.

    Usage per training step: run a normal backward pass, call ``attack()`` to
    nudge the embedding weights along their gradient, run a second
    forward/backward pass on the perturbed weights, then call ``restore()``
    before the optimizer step.

    Args:
        model: module whose parameters are scanned by name.
        epsilon: L2 norm of the perturbation added to each matched parameter.
        emb_name: substring identifying the parameters to perturb.
    """

    def __init__(self, model, epsilon=FGM_EPSILON, emb_name="embeddings.word_embeddings.weight"):
        self.model = model
        self.epsilon = epsilon
        self.emb_name = emb_name
        self.backup = {}

    def attack(self):
        """Add ``epsilon * grad / ||grad||`` to every matched trainable parameter."""
        for name, param in self.model.named_parameters():
            if not (param.requires_grad and self.emb_name in name):
                continue
            # A matched parameter that took no part in the last backward pass
            # has no gradient; torch.norm(None) would raise.
            if param.grad is None:
                continue
            self.backup[name] = param.data.clone()
            norm = torch.norm(param.grad)
            # Skip zero norms (division by zero) and non-finite norms, which
            # would write NaN/inf into the embedding table.
            if norm != 0 and torch.isfinite(norm):
                r_at = self.epsilon * param.grad / norm
                param.data.add_(r_at)

    def restore(self):
        """Put back the weights saved by the last ``attack()`` and clear the backup."""
        for name, param in self.model.named_parameters():
            if param.requires_grad and self.emb_name in name and name in self.backup:
                param.data = self.backup[name]
        self.backup = {}

class FocalLoss(nn.Module):
    """Focal loss on logits for binary targets.

    Each element's BCE-with-logits loss is scaled by
    ``alpha * (1 - p_t) ** gamma``, where ``p_t`` is the predicted probability
    of the element's true class.

    NOTE(review): ``alpha`` multiplies every element identically here, so it
    acts as a global scale rather than a per-class weight — confirm that this
    is the intended form.
    """

    def __init__(self, alpha=FOCAL_ALPHA, gamma=FOCAL_GAMMA, reduction="mean"):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        """Return the focal loss of logits ``inputs`` against ``targets``.

        ``reduction`` of "mean" or "sum" reduces to a scalar; any other value
        returns the unreduced per-element loss.
        """
        per_element_bce = nn.functional.binary_cross_entropy_with_logits(
            inputs, targets, reduction="none"
        )
        probs = torch.sigmoid(inputs)
        # Probability the model assigns to the true class of each element.
        p_t = targets * probs + (1 - targets) * (1 - probs)
        focal = self.alpha * (1 - p_t) ** self.gamma * per_element_bce

        if self.reduction == "mean":
            return focal.mean()
        if self.reduction == "sum":
            return focal.sum()
        return focal

print("[Cell 4] FGM and FocalLoss defined.")

# %% Cell 5: Training, Evaluation, and Dataset Creation Functions
def train_model(model, train_dataset, val_dataset, device, batch_size=BATCH_SIZE, num_epochs=NUM_EPOCHS, learning_rate=LEARNING_RATE):
    """Train ``model`` with focal loss, FGM adversarial steps and early stopping.

    Each step accumulates gradients from a clean pass and from a pass on
    FGM-perturbed embeddings before the optimizer update. After every epoch
    the model is checkpointed and validated; the weights with the best
    validation macro-F1 are kept in ``models/best_hierarchical_model_2.pt``.
    Training stops after 2 consecutive epochs without improvement.

    Side effects: writes ``model_checkpoint_epoch_{n}.pt`` to the working
    directory, the best weights under ``models/`` and
    ``results/training_curves.png``.

    Args:
        model: module called as ``model(cur_ids, cur_mask, ctx_ids, ctx_mask, meta)``
            and returning ``(logits, attention_weights)``.
        train_dataset, val_dataset: datasets yielding the dict produced by
            ``HierarchicalDeceptionDataset`` (including ``"label"``).
        device: torch device batches are moved to.
        batch_size, num_epochs, learning_rate: training hyperparameters.

    Returns:
        ``model`` with the best validation weights loaded.
    """
    print("[Train] Starting training process...")
    best_model_path = "models/best_hierarchical_model_2.pt"
    # Output directories are not guaranteed to exist on a fresh kernel.
    os.makedirs("models", exist_ok=True)
    os.makedirs("results", exist_ok=True)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader   = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    criterion = FocalLoss(alpha=FOCAL_ALPHA, gamma=FOCAL_GAMMA, reduction="mean")
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=0.01)
    scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=learning_rate, 
                                                     total_steps=len(train_loader)*num_epochs)
    fgm = FGM(model, epsilon=FGM_EPSILON)

    def _unpack(batch):
        """Move one batch to ``device``; labels become [batch, 1] to match the logits."""
        return (
            batch["current_input_ids"].to(device),
            batch["current_attention_mask"].to(device),
            batch["context_input_ids"].to(device),
            batch["context_attention_mask"].to(device),
            batch["metadata"].to(device),
            batch["label"].to(device).unsqueeze(1),
        )

    # Start below any attainable macro-F1 so the first epoch always saves a
    # "best" file; otherwise a run stuck at 0.0 would never write one and the
    # final load would fail or pick up a stale file from an earlier run.
    best_val_macro_f1 = float("-inf")
    patience = 2
    epochs_no_improve = 0
    history = {"train_loss": [], "train_acc": [], "val_loss": [], "val_acc": [], "val_macro_f1": []}

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0.0
        correct = 0
        total = 0
        print(f"[Train] Epoch {epoch+1}/{num_epochs} started.")
        for batch in tqdm(train_loader, desc=f"Epoch {epoch+1} Train"):
            current_ids, current_mask, context_ids, context_mask, metadata, labels = _unpack(batch)

            optimizer.zero_grad()
            outputs, _ = model(current_ids, current_mask, context_ids, context_mask, metadata)
            loss = criterion(outputs, labels)
            loss.backward()

            # FGM adversarial attack: gradients of the perturbed pass are
            # accumulated on top of the clean ones, then weights are restored.
            fgm.attack()
            outputs_adv, _ = model(current_ids, current_mask, context_ids, context_mask, metadata)
            loss_adv = criterion(outputs_adv, labels)
            loss_adv.backward()
            fgm.restore()

            optimizer.step()
            scheduler.step()

            total_loss += loss.item() * current_ids.size(0)
            total += current_ids.size(0)
            preds = (torch.sigmoid(outputs) >= 0.5).float()
            correct += (preds == labels).sum().item()

        avg_loss = total_loss / total
        train_acc = correct / total
        history["train_loss"].append(avg_loss)
        history["train_acc"].append(train_acc)
        print(f"[Train] Epoch {epoch+1} finished. Loss: {avg_loss:.4f}, Acc: {train_acc:.4f}")

        # Checkpoint AFTER the epoch's updates so the file named for epoch N
        # really holds the weights produced by epoch N.
        checkpoint_path = f"model_checkpoint_epoch_{epoch+1}.pt"
        torch.save(model.state_dict(), checkpoint_path)
        print(f"Model checkpoint saved at {checkpoint_path}")

        # Validation phase
        model.eval()
        val_loss = 0.0
        all_preds, all_labels = [], []
        with torch.no_grad():
            for batch in tqdm(val_loader, desc=f"Epoch {epoch+1} Validation"):
                current_ids, current_mask, context_ids, context_mask, metadata, labels = _unpack(batch)
                outputs, _ = model(current_ids, current_mask, context_ids, context_mask, metadata)
                loss = criterion(outputs, labels)
                val_loss += loss.item() * current_ids.size(0)
                preds = (torch.sigmoid(outputs) >= 0.5).float()
                all_preds.extend(preds.cpu().numpy().flatten())
                all_labels.extend(labels.cpu().numpy().flatten())
        avg_val_loss = val_loss / len(val_dataset)
        val_acc = accuracy_score(all_labels, all_preds)
        val_macro_f1 = f1_score(all_labels, all_preds, average="macro", zero_division=0)
        history["val_loss"].append(avg_val_loss)
        history["val_acc"].append(val_acc)
        history["val_macro_f1"].append(val_macro_f1)
        print(f"[Validation] Epoch {epoch+1}: Loss: {avg_val_loss:.4f}, Acc: {val_acc:.4f}, Macro-F1: {val_macro_f1:.4f}")

        if val_macro_f1 > best_val_macro_f1:
            best_val_macro_f1 = val_macro_f1
            epochs_no_improve = 0
            torch.save(model.state_dict(), best_model_path)
            print(f"[Train] Model improved. Saving model with Macro-F1: {best_val_macro_f1:.4f}")
        else:
            epochs_no_improve += 1
            print(f"[Train] No improvement for {epochs_no_improve} epoch(s).")
        if epochs_no_improve >= patience:
            print("[Train] Early stopping triggered.")
            break

    # Plot training curves
    plt.figure(figsize=(12,4))
    plt.subplot(1,3,1)
    plt.plot(history["train_loss"], label="Train Loss")
    plt.plot(history["val_loss"], label="Val Loss")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.legend()
    plt.subplot(1,3,2)
    plt.plot(history["train_acc"], label="Train Acc")
    plt.plot(history["val_acc"], label="Val Acc")
    plt.xlabel("Epoch")
    plt.ylabel("Accuracy")
    plt.legend()
    plt.subplot(1,3,3)
    plt.plot(history["val_macro_f1"], label="Val Macro-F1")
    plt.xlabel("Epoch")
    plt.ylabel("Macro-F1")
    plt.legend()
    plt.tight_layout()
    plt.savefig("results/training_curves.png")
    plt.close()

    # Load best model for evaluation; map_location keeps this working when the
    # file was written from a different device.
    model.load_state_dict(torch.load(best_model_path, map_location=device))
    print(f"[Train] Loaded best model with Macro-F1: {best_val_macro_f1:.4f}")
    return model

def evaluate_model(model, test_dataset, device, batch_size=BATCH_SIZE):
    """Evaluate ``model`` on ``test_dataset`` and persist the results.

    Predictions are thresholded at probability 0.5. Writes a confusion-matrix
    heatmap, a per-example predictions CSV and a metrics JSON under
    ``results/``.

    Args:
        model: module returning ``(logits, attention_weights)``.
        test_dataset: dataset yielding the dict produced by
            ``HierarchicalDeceptionDataset`` (including ``"label"``).
        device: torch device batches are moved to.
        batch_size: evaluation batch size.

    Returns:
        dict with accuracy, precision, recall, f1, macro_f1 and the 2x2
        confusion matrix (as nested lists; rows = true, columns = predicted).
    """
    print("[Evaluate] Evaluating on test set...")
    # The output directory is not guaranteed to exist on a fresh kernel.
    os.makedirs("results", exist_ok=True)

    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    model.eval()
    all_preds, all_labels, all_probs = [], [], []
    with torch.no_grad():
        for batch in tqdm(test_loader, desc="Evaluating"):
            current_ids = batch["current_input_ids"].to(device)
            current_mask = batch["current_attention_mask"].to(device)
            context_ids = batch["context_input_ids"].to(device)
            context_mask = batch["context_attention_mask"].to(device)
            metadata = batch["metadata"].to(device)
            labels = batch["label"].to(device).unsqueeze(1)
            outputs, _ = model(current_ids, current_mask, context_ids, context_mask, metadata)
            probs = torch.sigmoid(outputs)
            preds = (probs >= 0.5).float()
            all_preds.extend(preds.cpu().numpy().flatten())
            all_labels.extend(labels.cpu().numpy().flatten())
            all_probs.extend(probs.cpu().numpy().flatten())
    all_preds = np.array(all_preds)
    all_labels = np.array(all_labels)
    all_probs = np.array(all_probs)

    acc = accuracy_score(all_labels, all_preds)
    precision = precision_score(all_labels, all_preds, zero_division=0)
    recall = recall_score(all_labels, all_preds, zero_division=0)
    f1 = f1_score(all_labels, all_preds, zero_division=0)
    macro_f1 = f1_score(all_labels, all_preds, average="macro", zero_division=0)
    # Pin both classes so the matrix is always 2x2; without `labels`, a run in
    # which only one class appears yields a 1x1 matrix that does not match the
    # two tick labels used in the heatmap below.
    cm = confusion_matrix(all_labels.astype(int), all_preds.astype(int), labels=[0, 1])

    print("[Evaluate] Metrics on test set:")
    print(f"  Accuracy: {acc:.4f}")
    print(f"  Precision: {precision:.4f}")
    print(f"  Recall: {recall:.4f}")
    print(f"  F1 Score: {f1:.4f}")
    print(f"  Macro F1 Score: {macro_f1:.4f}")
    print(f"  Confusion Matrix:\n{cm}")

    plt.figure(figsize=(8,6))
    sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", 
                xticklabels=["Truthful", "Deceptive"],
                yticklabels=["Truthful", "Deceptive"])
    plt.xlabel("Predicted")
    plt.ylabel("True")
    plt.title("Test Confusion Matrix")
    plt.tight_layout()
    plt.savefig("results/hierarchical_confusion_matrix_2.png")
    plt.close()

    # Save predictions and metrics
    df_preds = pd.DataFrame({"prediction": all_preds, "probability": all_probs, "true_label": all_labels})
    df_preds.to_csv("results/hierarchical_predictions_2.csv", index=False)

    metrics = {
        "accuracy": acc,
        "precision": precision,
        "recall": recall,
        "f1": f1,
        "macro_f1": macro_f1,
        "confusion_matrix": cm.tolist()
    }
    with open("results/hierarchical_metrics_2.json", "w") as f:
        json.dump(metrics, f, indent=4)

    return metrics

def create_datasets(train_df, val_df, test_df, tokenizer, metadata_features):
    """Build the train/val/test ``HierarchicalDeceptionDataset`` objects.

    Metadata is scaled AFTER all data modifications (like balancing) are
    performed. Rows with missing metadata are removed from each frame first:
    ``scale_metadata`` drops such rows internally, so building the text and
    label lists from the unfiltered frames would leave them longer than, and
    misaligned with, the scaled metadata arrays.

    Args:
        train_df, val_df, test_df: frames with ``cleaned_message``,
            ``context``, ``is_deceptive`` and the metadata columns.
        tokenizer: tokenizer handed to each dataset.
        metadata_features: list of metadata column names.

    Returns:
        Tuple ``(train_dataset, val_dataset, test_dataset)``.
    """
    train_df, val_df, test_df = (
        df.dropna(subset=metadata_features) for df in (train_df, val_df, test_df)
    )
    train_meta, val_meta, test_meta = scale_metadata(train_df, val_df, test_df, metadata_features)

    def _build(df, meta):
        """Wrap one split and its scaled metadata in a dataset."""
        return HierarchicalDeceptionDataset(
            texts=df["cleaned_message"].tolist(),
            context_texts=df["context"].tolist(),
            metadata=meta,
            labels=df["is_deceptive"].tolist(),
            tokenizer=tokenizer,
            max_length=MAX_LENGTH,
            max_context=MAX_CONTEXT
        )

    train_dataset = _build(train_df, train_meta)
    val_dataset = _build(val_df, val_meta)
    test_dataset = _build(test_df, test_meta)
    return train_dataset, val_dataset, test_dataset

def initialize_and_train_model(train_dataset, val_dataset, metadata_features, tokenizer):
    """Instantiate the hierarchical model on ``device``, report its size and train it.

    ``tokenizer`` is accepted but not used by this function. The metadata
    input width is taken from ``len(metadata_features)``.

    Returns:
        The model returned by ``train_model``.
    """
    print("[Init] Initializing Hierarchical Deception Model...")
    model = HierarchicalAttentionDeceptionModel(
        transformer_model_name=TRANSFORMER_MODEL_NAME,
        metadata_dim=len(metadata_features),
        max_context=MAX_CONTEXT,
        lstm_hidden=256,
    )
    model = model.to(device)

    # Count all parameters and the subset that will receive gradients.
    total_params = 0
    trainable_params = 0
    for param in model.parameters():
        count = param.numel()
        total_params += count
        if param.requires_grad:
            trainable_params += count
    print(f"[Init] Model Parameters: {total_params:,} total / {trainable_params:,} trainable")

    return train_model(
        model,
        train_dataset,
        val_dataset,
        device,
        batch_size=BATCH_SIZE,
        num_epochs=NUM_EPOCHS,
        learning_rate=LEARNING_RATE,
    )

def evaluate_and_save(model, test_dataset, tokenizer, model_path, tokenizer_path):
    """Evaluate on the test set, print the metrics, then save model and tokenizer.

    Args:
        model: trained model passed to ``evaluate_model``.
        test_dataset: dataset to evaluate on.
        tokenizer: tokenizer saved with ``save_pretrained``.
        model_path: file the model ``state_dict`` is written to.
        tokenizer_path: directory the tokenizer is written to.
    """
    print("[Eval] Evaluating model on test set...")
    test_metrics = evaluate_model(model, test_dataset, device, batch_size=BATCH_SIZE)
    print("[Eval] Test Metrics:")
    for k, v in test_metrics.items():
        # For scalar metric values
        if isinstance(v, float):
            print(f"  {k}: {v:.4f}")
        else:
            print(f"  {k}: {v}")
    print("[Save] Saving model and tokenizer...")
    # Create the model's parent directory, mirroring what is already done for
    # the tokenizer directory; torch.save does not create missing folders.
    model_dir = os.path.dirname(model_path)
    if model_dir:
        os.makedirs(model_dir, exist_ok=True)
    torch.save(model.state_dict(), model_path)
    os.makedirs(tokenizer_path, exist_ok=True)
    tokenizer.save_pretrained(tokenizer_path)
    print("[Save] Model and tokenizer saved successfully.")

# %% Cell 6: Main Execution Pipeline
def main():
    """Run the full pipeline: load, preprocess, balance, train, evaluate, save."""
    print("[Main] Loading data...")
    train_df, val_df, test_df = load_data(TRAIN_CSV_PATH, VAL_CSV_PATH, TEST_CSV_PATH)

    # Feature names come from the training frame; every split is then
    # preprocessed, and afterwards every split gets its labels processed.
    metadata_features = get_metadata_feature_names(train_df)
    preprocessed = [preprocess_data(frame, metadata_features) for frame in (train_df, val_df, test_df)]
    train_df, val_df, test_df = [process_labels(frame) for frame in preprocessed]

    # Only the training split is rebalanced.
    print("[Main] Balancing training data...")
    train_df = balance_training_data(train_df)

    print("[Main] Initializing tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(TRANSFORMER_MODEL_NAME)
    print(f"[Main] Tokenizer loaded: {TRANSFORMER_MODEL_NAME}")

    print("[Main] Creating datasets...")
    datasets = create_datasets(train_df, val_df, test_df, tokenizer, metadata_features)
    train_dataset, val_dataset, test_dataset = datasets
    print(f"[Main] Dataset sizes — Train: {len(train_dataset)}, Val: {len(val_dataset)}, Test: {len(test_dataset)}")

    model = initialize_and_train_model(train_dataset, val_dataset, metadata_features, tokenizer)

    evaluate_and_save(
        model,
        test_dataset,
        tokenizer,
        model_path="models/hierarchical_model_2.pt",
        tokenizer_path="models/hierarchical_tokenizer_2",
    )


if __name__ == "__main__":
    main()


In [None]:
# %% Cell X: Inference and Evaluation
import torch
import pandas as pd
from torch.utils.data import DataLoader
from transformers import AutoTokenizer
import joblib
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
import json

# Constants and paths. SCALER_PATH and MODEL_PATH are the files written by
# scale_metadata and train_model earlier in this notebook.
SCALER_PATH = "models/metadata_scaler_2.pkl"
MODEL_PATH = "models/best_hierarchical_model_2.pt"
TEST_CSV_PATH = "/kaggle/input/processed-data/processed_test.csv"

# This cell relies on names defined by earlier cells of the notebook:
# TRANSFORMER_MODEL_NAME, MAX_LENGTH, MAX_CONTEXT, BATCH_SIZE, device
# HierarchicalAttentionDeceptionModel, HierarchicalDeceptionDataset
# NOTE(review): get_metadata_feature_names is redefined further down in this
# cell and process_labels is not called here — confirm which is intended.

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(TRANSFORMER_MODEL_NAME)

# Load and preprocess test data
df_test = pd.read_csv(TEST_CSV_PATH)
# Derive the positive-class label (1 = deceptive) from the truthfulness flag.
# If "is_truthful" is absent the CSV must already provide "is_deceptive",
# because that column is read when the test dataset is built.
if "is_truthful" in df_test.columns:
    df_test["is_deceptive"] = 1 - df_test["is_truthful"]
# Derive metadata feature names
def get_metadata_feature_names(df):
    """Return the metadata column names to use for ``df``.

    Nine base features are always included; each of ``sender_is_player``,
    ``prev_msg_truthful`` and ``game_stage`` is appended (in that order) only
    when ``df`` has a column of that name.
    """
    base_features = [
        "message_length", "word_count", "question_count", "exclamation_count",
        "has_uncertainty", "has_certainty", "conversation_length", "msg_position_in_convo", "position_ratio"
    ]
    optional_features = ("sender_is_player", "prev_msg_truthful", "game_stage")
    present_optional = [name for name in optional_features if name in df.columns]
    return base_features + present_optional
# NOTE(review): the feature list is derived from the TEST frame's columns; it
# must match the columns the saved scaler was fitted on — confirm.
metadata_features = get_metadata_feature_names(df_test)
# Fill missing metadata: absent columns are created as 0, and non-numeric or
# missing values are coerced to 0. (scale_metadata, by contrast, drops rows
# with missing metadata instead of filling them.)
for col in metadata_features:
    if col not in df_test:
        df_test[col] = 0
    df_test[col] = pd.to_numeric(df_test[col], errors='coerce').fillna(0)

# Scale metadata features with the scaler persisted during training.
# joblib.load unpickles the file, so only load scaler files you trust.
scaler = joblib.load(SCALER_PATH)
scaled_meta = scaler.transform(df_test[metadata_features].values)

# Create test dataset (labels available for evaluation).
# A missing "context" column falls back to empty strings for every row.
test_dataset = HierarchicalDeceptionDataset(
    texts=df_test["cleaned_message"].fillna("").tolist(),
    context_texts=df_test.get("context", pd.Series([""]*len(df_test))).fillna("").tolist(),
    metadata=scaled_meta,
    labels=df_test["is_deceptive"].tolist(),
    tokenizer=tokenizer,
    max_length=MAX_LENGTH,
    max_context=MAX_CONTEXT
)

# Initialize model and load weights.
# The constructor takes the checkpoint name via `transformer_model_name`; it
# has no `tokenizer` parameter, so passing a tokenizer object raised TypeError.
model = HierarchicalAttentionDeceptionModel(
    transformer_model_name=TRANSFORMER_MODEL_NAME,
    metadata_dim=len(metadata_features),
    max_context=MAX_CONTEXT,
    lstm_hidden=256
).to(device)
# map_location lets weights saved on one device load on another.
model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
model.eval()

# Run inference and collect predictions.
# shuffle=False keeps predictions in the same order as the rows of df_test,
# which matters because they are later assigned back as new columns.
loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
all_preds, all_probs, all_labels = [], [], []
with torch.no_grad():
    for batch in loader:
        curr_ids = batch["current_input_ids"].to(device)
        curr_mask = batch["current_attention_mask"].to(device)
        ctx_ids = batch["context_input_ids"].to(device)
        ctx_mask = batch["context_attention_mask"].to(device)
        meta = batch["metadata"].to(device)
        labels = batch["label"].to(device)

        # The model returns (logits, attention_weights); only logits are used.
        outputs, _ = model(curr_ids, curr_mask, ctx_ids, ctx_mask, meta)
        # Drop the trailing size-1 dimension so probs aligns with labels.
        probs = torch.sigmoid(outputs).squeeze(1)
        # Threshold at 0.5: 1 = deceptive, 0 = truthful.
        preds = (probs >= 0.5).long()

        all_probs.extend(probs.cpu().numpy())
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

# Evaluate metrics (binary scores treat label 1, "deceptive", as positive).
acc = accuracy_score(all_labels, all_preds)
prec = precision_score(all_labels, all_preds, zero_division=0)
rec = recall_score(all_labels, all_preds, zero_division=0)
f1 = f1_score(all_labels, all_preds, zero_division=0)
macro_f1 = f1_score(all_labels, all_preds, average='macro', zero_division=0)
cm = confusion_matrix(all_labels, all_preds)

# Print metrics
print("Evaluation Metrics on Test Set:")
print(f"  Accuracy: {acc:.4f}")
print(f"  Precision: {prec:.4f}")
print(f"  Recall: {rec:.4f}")
print(f"  F1 Score: {f1:.4f}")
print(f"  Macro F1: {macro_f1:.4f}")
print("  Confusion Matrix:")
print(cm)

# Save predictions and metrics.
# This cell can run on its own, so make sure the output directory exists
# before writing into it.
import os
os.makedirs("results", exist_ok=True)
df_test["predicted_label"] = all_preds
df_test["predicted_prob"] = all_probs
output_pred_path = "results/test_predictions.csv"
df_test.to_csv(output_pred_path, index=False)
metrics = {
    "accuracy": acc,
    "precision": prec,
    "recall": rec,
    "f1": f1,
    "macro_f1": macro_f1,
    "confusion_matrix": cm.tolist()
}
with open("results/test_metrics.json", "w") as f:
    json.dump(metrics, f, indent=4)
# Also print the saved metrics dictionary
print("Saved metrics:")
print(json.dumps(metrics, indent=4))

print(f"Inference and evaluation complete. Predictions saved to {output_pred_path} and metrics to results/test_metrics.json")


In [None]:
# # %% Cell 7: Interactive Inference for a Single Example
# def interactive_inference(model, tokenizer, sample_text, sample_context, sample_metadata):
#     model.eval()
#     enc_current = tokenizer(
#         sample_text,
#         add_special_tokens=True,
#         max_length=MAX_LENGTH,
#         padding="max_length",
#         truncation=True,
#         return_tensors="pt"
#     )
#     current_ids = enc_current["input_ids"].to(device)
#     current_mask = enc_current["attention_mask"].to(device)
    
#     context_msgs = [msg.strip() for msg in sample_context.split("||") if msg.strip()]
#     context_msgs = context_msgs[-MAX_CONTEXT:]
#     while len(context_msgs) < MAX_CONTEXT:
#         context_msgs.insert(0, "")
#     context_encodings = [tokenizer(
#         msg,
#         add_special_tokens=True,
#         max_length=MAX_LENGTH,
#         padding="max_length",
#         truncation=True,
#         return_tensors="pt"
#     ) for msg in context_msgs]
#     context_ids = torch.stack([enc["input_ids"].squeeze(0) for enc in context_encodings]).unsqueeze(0).to(device)
#     context_mask = torch.stack([enc["attention_mask"].squeeze(0) for enc in context_encodings]).unsqueeze(0).to(device)
    
#     meta = torch.tensor(sample_metadata, dtype=torch.float).unsqueeze(0).to(device)
    
#     with torch.no_grad():
#         output, attn_weights = model(current_ids, current_mask, context_ids, context_mask, meta)
#         prob = torch.sigmoid(output).item()
#         pred = int(prob >= 0.5)
#     print(f"[Interactive] Prediction: {pred} (Probability: {prob:.4f})")
#     print(f"[Interactive] Attention Weights on Context: {attn_weights.squeeze(0).cpu().numpy()}")
#     return pred, prob, attn_weights

# # Example usage:
# sample_text = "I believe your proposal has merit, but we must be cautious."
# sample_context = "Let's discuss our options.||Your previous message was insightful."
# sample_metadata = test_meta[0]  # using first sample's metadata as an example
# interactive_inference(model, tokenizer, sample_text, sample_context, sample_metadata)


# **Another attempt at implementing the above method**

• Switched to the AdamW optimizer with weight decay and introduced a learning rate scheduler.

• Added a placeholder for text data augmentation.

• Updated the metadata branch by including a simple attention mechanism (a “MetadataAttention” module) to allow the network to weigh metadata features dynamically.

• Made minor adjustments in dropout rates and hyperparameters.

In [None]:
import os
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW  
from torch.utils.data import Dataset, DataLoader
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
from transformers import RobertaTokenizer, RobertaModel, get_linear_schedule_with_warmup
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix

print("[Setup] All packages imported.")


In [None]:
# import random

def augment_text(text, aug_prob=0.3):
    """
    A simple placeholder for text augmentation.
    Currently, it randomly lowercases some words as a dummy augmentation.
    Replace or extend this function with your preferred augmentation technique.

    Args:
        text: input string; it is split on whitespace, so runs of whitespace
            collapse to single spaces in the result.
        aug_prob: independent probability of lowercasing each word.

    Returns:
        The augmented string.
    """
    # Imported locally because the cell-level `import random` above is
    # commented out; without it this function raises NameError.
    import random

    words = text.split()
    new_words = [word.lower() if random.random() < aug_prob else word for word in words]
    return " ".join(new_words)

# # Example usage:
# sample_text = "This is an Example of a Diplomatic message."
# augmented_text = augment_text(sample_text)
# print("[Augmentation] Original:", sample_text)
# print("[Augmentation] Augmented:", augmented_text)


In [None]:
def engineer_metadata(df):
    """Return a copy of ``df`` with four derived metadata columns added.

    Added columns:
        score_delta: ``game_score - game_score_delta``.
        is_sender_leading: 1 where ``game_score`` exceeds the mean
            ``game_score`` of this frame, else 0.
        punctuation_density: exclamation plus question marks per character
            (``message_length`` offset by 1e-5 to avoid division by zero).
        score_ratio: ``game_score / (game_score_delta + 1e-5)``.

    The input frame is not modified.
    """
    score = df['game_score']
    delta = df['game_score_delta']
    mark_count = df['exclamation_count'] + df['question_count']

    enriched = df.assign(
        score_delta=score - delta,
        is_sender_leading=(score > score.mean()).astype(int),
        punctuation_density=mark_count / (df['message_length'] + 1e-5),
        score_ratio=score / (delta + 1e-5),
    )
    print("[FeatureEng] Added score_delta, is_sender_leading, punctuation_density, score_ratio.")
    return enriched

# Load your CSV files and apply metadata engineering.
# Paths point at a Kaggle input dataset; adjust them when running elsewhere.
train_df = pd.read_csv("/kaggle/input/processed-data/processed_train_balanced.csv")
val_df   = pd.read_csv("/kaggle/input/processed-data/processed_val.csv")
test_df  = pd.read_csv("/kaggle/input/processed-data/processed_test.csv")

print(f"[Data] Train shape: {train_df.shape}")
print(f"[Data] Validation shape: {val_df.shape}")
print(f"[Data] Test shape: {test_df.shape}")

# is_sender_leading is computed against each split's own mean game_score,
# so its threshold differs between train, validation and test.
train_df = engineer_metadata(train_df)
val_df = engineer_metadata(val_df)
test_df = engineer_metadata(test_df)

# List of metadata features (update if needed).
# NOTE(review): the four columns added by engineer_metadata (score_delta,
# is_sender_leading, punctuation_density, score_ratio) are not in this list —
# confirm whether they are meant to be used as model inputs.
metadata_features = ['message_length', 'word_count', 'question_count', 'exclamation_count',
                     'has_uncertainty', 'has_certainty', 'conversation_length',
                     'msg_position_in_convo', 'position_ratio', 'sender_is_player',
                     'prev_msg_truthful', 'game_stage']


In [None]:
# Load the pretrained "roberta-base" tokenizer once at module level so the
# cells below can reuse the same instance.
print("[Data] Initializing tokenizer...")
tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
print("[Data] Tokenizer loaded from roberta-base")

class HierarchicalDataset(Dataset):
    """Per-message dataset yielding tokenized text, replicated context and metadata.

    Each item is a ``(sample, label)`` pair. ``sample`` is a dict with:

    * ``input_ids`` / ``attention_mask`` -- the tokenized message, one row of
      ``max_length`` tokens;
    * ``context_input_ids`` / ``context_attention_mask`` -- ``max_context`` rows
      of ``max_length`` tokens. Every row holds the same text: the row's
      ``context`` value when it is a non-empty string, otherwise the message;
    * ``metadata`` -- float32 vector built from ``metadata_features``; values
      that cannot be parsed as numbers (or are missing) become 0.

    ``label`` is a one-element float32 tensor read from the ``is_truthful`` column.

    Args:
        df: Frame with ``message``, ``is_truthful`` and the metadata columns.
            A ``context`` column is optional.
        tokenizer: Callable taking ``(text, padding=, truncation=, max_length=,
            return_tensors="pt")`` and returning a mapping with ``input_ids``
            and ``attention_mask`` tensors that carry a leading batch axis of 1.
        metadata_features: Column names that make up the metadata vector.
        max_length: Token length every sequence is padded / truncated to.
        max_context: Number of context rows produced per item.
        augment: When true, the message is passed through ``augment_text``.
    """

    def __init__(self, df, tokenizer, metadata_features, max_length=64, max_context=5, augment=False):
        self.df = df.reset_index(drop=True)
        self.tokenizer = tokenizer
        self.metadata_features = metadata_features
        self.max_length = max_length
        self.max_context = max_context
        self.augment = augment

    def __len__(self):
        return len(self.df)

    @staticmethod
    def _as_text(value):
        """Coerce a cell value to ``str``; missing values (NaN/None) become ``""``."""
        if isinstance(value, str):
            return value
        return "" if pd.isna(value) else str(value)

    def _encode(self, text):
        """Tokenize one string and return 1-D ``(input_ids, attention_mask)`` tensors."""
        enc = self.tokenizer(text,
                             padding='max_length',
                             truncation=True,
                             max_length=self.max_length,
                             return_tensors="pt")
        # squeeze(0) removes only the batch axis; a bare squeeze() would also
        # drop the sequence axis when max_length == 1.
        return enc["input_ids"].squeeze(0), enc["attention_mask"].squeeze(0)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]

        # CSV cells can come back as NaN (float); the tokenizer needs a string.
        message = self._as_text(row['message'])
        if self.augment:
            message = augment_text(message)

        # Use the row's context when present; otherwise fall back to the
        # (possibly augmented) message as dummy context.
        context = self._as_text(row.get('context', ""))
        if context == "":
            context = message

        input_ids, attention_mask = self._encode(message)

        # All context rows contain the same text, so tokenize it once and
        # replicate the result instead of tokenizing max_context times.
        ctx_ids, ctx_mask = self._encode(context)
        ctx_input_ids = ctx_ids.unsqueeze(0).repeat(self.max_context, 1)
        ctx_attention_mask = ctx_mask.unsqueeze(0).repeat(self.max_context, 1)

        # Metadata vector: unparseable or missing entries are replaced by 0.
        meta_values = pd.to_numeric(row[self.metadata_features], errors='coerce').fillna(0)
        meta = torch.tensor(meta_values.values.astype(np.float32))

        # Binary target taken from the 'is_truthful' column.
        label = torch.tensor([row['is_truthful']], dtype=torch.float32)

        sample = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "context_input_ids": ctx_input_ids,
            "context_attention_mask": ctx_attention_mask,
            "metadata": meta
        }
        return sample, label

# Sequence length, context rows per item and loader batch size.
# NOTE: this assignment replaces the device-dependent batch_size chosen in the
# setup cell at the top of the notebook.
max_length, max_context, batch_size = 64, 5, 8

# Only the training split is augmented (set augment=False to disable).
train_dataset = HierarchicalDataset(
    train_df, tokenizer, metadata_features,
    max_length=max_length, max_context=max_context, augment=True,
)
val_dataset = HierarchicalDataset(
    val_df, tokenizer, metadata_features,
    max_length=max_length, max_context=max_context, augment=False,
)
test_dataset = HierarchicalDataset(
    test_df, tokenizer, metadata_features,
    max_length=max_length, max_context=max_context, augment=False,
)

# Shuffle the training data only; evaluation order stays fixed.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)


In [None]:
epsilon = 1e-8  # keeps the class-weight denominator non-zero when a class has no samples


class ClassBalancedFocalLoss(nn.Module):
    """Binary focal loss with class-balanced ("effective number") weights.

    For every example the loss is ``-w_c * (1 - p_t) ** gamma * log(p_t)`` where
    ``p_t`` is the predicted probability of the true class and
    ``w_c = (1 - beta) / (1 - beta ** n_c + epsilon)`` for class count ``n_c``.

    Args:
        beta: Base of the effective-number term.
        gamma: Focusing exponent of the focal term.
        reduction: ``'mean'`` or ``'sum'``; any other value returns the
            per-example losses as a 1-D tensor.
        class_counts: Optional ``(num_negative, num_positive)`` counts used for
            the class weights. When ``None`` (default) the counts are taken
            from the labels of each batch, as before.

    Raises:
        ValueError: If ``class_counts`` is given but does not have two entries.
    """

    def __init__(self, beta: float = 0.999, gamma: float = 2.0, reduction='mean', class_counts=None):
        super().__init__()
        self.beta = beta
        self.gamma = gamma
        self.reduction = reduction
        if class_counts is not None:
            class_counts = tuple(float(c) for c in class_counts)
            if len(class_counts) != 2:
                raise ValueError("class_counts must be (num_negative, num_positive)")
        self.class_counts = class_counts

    def _class_weights(self, num_neg, num_pos, device):
        """Return the ``[w_negative, w_positive]`` tensor for the given counts."""
        effective_num_neg = 1.0 - np.power(self.beta, num_neg)
        effective_num_pos = 1.0 - np.power(self.beta, num_pos)
        return torch.tensor([
            (1 - self.beta) / (effective_num_neg + epsilon),
            (1 - self.beta) / (effective_num_pos + epsilon)
        ], dtype=torch.float, device=device)

    def forward(self, logits, labels):
        """Compute the loss for raw ``logits`` and 0/1 ``labels``.

        Both tensors are flattened first, so ``(B,)`` and ``(B, 1)`` inputs can
        be mixed without accidentally broadcasting to ``(B, B)``.
        """
        logits = logits.reshape(-1)
        labels = labels.reshape(-1)
        is_pos = labels == 1

        if self.class_counts is None:
            num_pos = torch.sum(is_pos).item()
            num_neg = torch.sum(labels == 0).item()
        else:
            num_neg, num_pos = self.class_counts
        weights = self._class_weights(num_neg, num_pos, logits.device)

        # log p_t via logsigmoid: sigmoid(x) for positives, sigmoid(-x) for
        # negatives. Unlike log(sigmoid(x) + 1e-9) this neither saturates nor
        # loses its gradient when the prediction is confidently wrong.
        logpt = nn.functional.logsigmoid(torch.where(is_pos, logits, -logits))
        pt = logpt.exp()
        focal_term = (1 - pt) ** self.gamma
        loss = -weights[labels.long()] * focal_term * logpt

        if self.reduction == 'mean':
            return loss.mean()
        elif self.reduction == 'sum':
            return loss.sum()
        return loss


In [None]:
class MetadataAttention(nn.Module):
    """Feature-wise gating: rescale each input feature by a learned softmax weight.

    A small MLP (``input_dim -> attn_hidden -> input_dim`` with a tanh in
    between) produces one score per feature. The scores are normalised with a
    softmax over the feature axis and multiplied element-wise into the input.

    Args:
        input_dim: Number of features in (and out of) the layer.
        attn_hidden: Width of the hidden scoring layer.
    """

    def __init__(self, input_dim, attn_hidden=32):
        super().__init__()
        scorer_layers = [
            nn.Linear(input_dim, attn_hidden),
            nn.Tanh(),
            nn.Linear(attn_hidden, input_dim),
        ]
        self.fc = nn.Sequential(*scorer_layers)

    def forward(self, x):
        # x: (batch_size, input_dim); the softmax runs across the feature axis.
        feature_scores = self.fc(x)
        gates = torch.softmax(feature_scores, dim=1)
        return gates * x

class HierarchicalDeceptionModel(nn.Module):
    """Classifier combining a message, its context messages and metadata.

    The same encoder embeds the current message and every context message
    (first-token hidden state). The context embeddings go through multi-head
    self-attention, are mean-pooled and projected; the metadata vector goes
    through a small MLP followed by ``MetadataAttention``. The three parts are
    concatenated and mapped to a single logit.

    Args:
        roberta_model_name: Checkpoint passed to ``RobertaModel.from_pretrained``.
        metadata_dim: Length of the metadata vector.
        max_context: Stored on the instance; ``forward`` reads the number of
            context messages from the input shape instead.
    """

    def __init__(self, roberta_model_name='roberta-base', metadata_dim=12, max_context=5):
        super().__init__()
        print("[Model] Loading RoBERTa model...")
        self.encoder = RobertaModel.from_pretrained(roberta_model_name)
        self.hidden_dim = self.encoder.config.hidden_size
        self.max_context = max_context

        hidden = self.hidden_dim
        meta_width = 64

        # Self-attention across the context messages, then a linear projection.
        self.context_attention = nn.MultiheadAttention(embed_dim=hidden, num_heads=4, batch_first=True)
        self.context_proj = nn.Linear(hidden, hidden)

        # Metadata branch: MLP followed by feature-wise attention.
        self.metadata_fc = nn.Sequential(
            nn.Linear(metadata_dim, meta_width),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(meta_width, meta_width),
        )
        self.metadata_attn = MetadataAttention(meta_width, attn_hidden=32)

        # Classifier over [message | context | metadata].
        self.output_head = nn.Sequential(
            nn.Linear(hidden * 2 + meta_width, 128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, 1),
        )

    def _first_token_state(self, input_ids, attention_mask):
        """Run the encoder and return the hidden state at sequence position 0."""
        encoded = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        return encoded.last_hidden_state[:, 0, :]

    def forward(self, input_ids, attention_mask, context_input_ids, context_attention_mask, metadata):
        """Return ``(logits, context_attn_out)``.

        ``logits`` comes from the final linear layer (one value per example);
        ``context_attn_out`` is the output of the context self-attention
        before pooling.
        """
        # Current message.
        message_vec = self._first_token_state(input_ids, attention_mask)

        # Context messages: fold the context axis into the batch axis for the
        # encoder, then unfold it again.
        batch, n_ctx, seq_len = context_input_ids.shape
        flat_ids = context_input_ids.view(batch * n_ctx, seq_len)
        flat_mask = context_attention_mask.view(batch * n_ctx, seq_len)
        ctx_states = self._first_token_state(flat_ids, flat_mask).view(batch, n_ctx, -1)

        # Self-attention over the context, mean-pool, project.
        context_attn_out, _ = self.context_attention(ctx_states, ctx_states, ctx_states)
        context_vec = self.context_proj(context_attn_out.mean(dim=1))

        # Metadata branch.
        meta_vec = self.metadata_attn(self.metadata_fc(metadata))

        # Fuse and classify.
        fused = torch.cat([message_vec, context_vec, meta_vec], dim=1)
        logits = self.output_head(fused)
        return logits, context_attn_out

print("[Model] HierarchicalDeceptionModel updated and ready.")


In [None]:
class DeceptionTrainer:
    """Training / evaluation helper for a model returning ``(logits, extra)``.

    Batches are expected as ``(inputs, labels)`` where ``inputs`` is a dict of
    tensors passed to the model as keyword arguments.

    Args:
        model: Module to train; moved to ``device`` on construction.
        train_loader: Iterable of training batches.
        val_loader: Iterable of validation batches (default for ``evaluate``).
        device: Device that batches and the model are moved to.
        loss_fn: Callable ``loss_fn(logits, labels)`` returning a scalar tensor.
        optimizer: Optimizer stepped once per training batch.
        scheduler: Optional LR scheduler, stepped once per training batch.
        save_dir: Directory for ``save_checkpoint``; created if missing.

    ``best_macro_f1`` starts at -1 and is only stored here; this class never
    updates it, that is left to the caller.
    """

    def __init__(self, model, train_loader, val_loader, device, loss_fn, optimizer, scheduler=None, save_dir="checkpoints"):
        self.model = model.to(device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.device = device
        self.loss_fn = loss_fn
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.best_macro_f1 = -1
        self.save_dir = save_dir
        os.makedirs(self.save_dir, exist_ok=True)

    def train_epoch(self):
        """Run one pass over ``train_loader``.

        Returns:
            ``(mean_loss, accuracy)`` over all examples seen, using a 0.5
            threshold on the sigmoid of the logits. An empty loader yields
            ``(0.0, 0.0)`` instead of raising ``ZeroDivisionError``.
        """
        self.model.train()
        total_loss, correct, total = 0.0, 0, 0
        for batch in tqdm(self.train_loader, desc="[Train]"):
            self.optimizer.zero_grad()
            inputs, labels = batch
            inputs = {k: v.to(self.device) for k, v in inputs.items()}
            labels = labels.to(self.device)

            outputs, _ = self.model(**inputs)
            loss = self.loss_fn(outputs, labels)
            loss.backward()
            self.optimizer.step()
            if self.scheduler:
                # Per-batch stepping: the schedule is defined in optimizer steps.
                self.scheduler.step()

            # Weight the batch loss by its size so the epoch mean is per example.
            total_loss += loss.item() * labels.size(0)
            preds = (torch.sigmoid(outputs) >= 0.5).float()
            correct += (preds == labels).sum().item()
            total += labels.size(0)

        denom = max(total, 1)  # guard against an empty loader
        return total_loss / denom, correct / denom

    def evaluate(self, loader=None, desc="[Validate]"):
        """Compute metrics over ``loader`` with gradients disabled.

        Args:
            loader: Iterable of batches; defaults to ``self.val_loader``.
            desc: Progress-bar label.

        Returns:
            The metrics dict produced by ``_compute_metrics``.
        """
        if loader is None:
            loader = self.val_loader
        self.model.eval()
        all_preds, all_labels = [], []
        with torch.no_grad():
            for batch in tqdm(loader, desc=desc):
                inputs, labels = batch
                inputs = {k: v.to(self.device) for k, v in inputs.items()}
                labels = labels.to(self.device)
                outputs, _ = self.model(**inputs)
                preds = (torch.sigmoid(outputs) >= 0.5).float()
                all_preds.extend(preds.cpu().numpy().flatten())
                all_labels.extend(labels.cpu().numpy().flatten())
        return self._compute_metrics(all_labels, all_preds)

    def _compute_metrics(self, y_true, y_pred):
        """Print a short summary and return accuracy, precision, recall, F1,
        macro-F1 and the confusion matrix (as nested lists) in a dict."""
        acc = accuracy_score(y_true, y_pred)
        prec = precision_score(y_true, y_pred, zero_division=0)
        rec = recall_score(y_true, y_pred, zero_division=0)
        f1 = f1_score(y_true, y_pred, zero_division=0)
        macro = f1_score(y_true, y_pred, average='macro', zero_division=0)
        cm = confusion_matrix(y_true, y_pred)
        print("[Metrics] Accuracy:", acc, "| F1:", f1, "| Macro-F1:", macro)
        return {"accuracy": acc, "precision": prec, "recall": rec, "f1": f1, "macro_f1": macro, "confusion_matrix": cm.tolist()}

    def save_checkpoint(self, epoch, metrics):
        """Write model, optimizer and scheduler state plus ``metrics`` to
        ``<save_dir>/epoch_<epoch>.pt``."""
        save_path = os.path.join(self.save_dir, f"epoch_{epoch}.pt")
        torch.save({
            "epoch": epoch,
            "model_state_dict": self.model.state_dict(),
            "optimizer_state_dict": self.optimizer.state_dict(),
            "scheduler_state_dict": self.scheduler.state_dict() if self.scheduler else None,
            "metrics": metrics
        }, save_path)
        print(f"[Checkpoint] Saved model checkpoint to {save_path}")
        
#########################################
# Optimization Setup
#########################################
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"[Main] Using device: {device}")

# The metadata branch of the model is sized from the feature list.
metadata_dim = len(metadata_features)
model = HierarchicalDeceptionModel(metadata_dim=metadata_dim, max_context=max_context)

# Class-balanced focal loss on the raw logits.
loss_fn = ClassBalancedFocalLoss(beta=0.999, gamma=2.0, reduction='mean')

# AdamW with weight decay over every model parameter.
optimizer = AdamW(model.parameters(), lr=2e-5, weight_decay=0.01)

# Linear warmup over the first 10% of steps, then linear decay.
# The factor 5 is the assumed number of epochs; keep it equal to num_epochs
# in the training-loop cell.
num_train_steps = len(train_loader) * 5
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=int(0.1*num_train_steps),
    num_training_steps=num_train_steps,
)

trainer = DeceptionTrainer(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    device=device,
    loss_fn=loss_fn,
    optimizer=optimizer,
    scheduler=scheduler,
)

In [None]:
import random  # stdlib RNG used by augment_text() while iterating the training set

# torch and numpy are seeded in the setup cell; seed the stdlib generator as
# well so the text augmentation is repeatable from run to run.
random.seed(42)

# Must equal the epoch count assumed when the scheduler's num_train_steps was computed.
num_epochs = 5
os.makedirs("models", exist_ok=True)

for epoch in range(1, num_epochs + 1):
    print(f"[Train] Epoch {epoch}/{num_epochs} started.")
    train_loss, train_acc = trainer.train_epoch()
    print(f"[Train] Epoch {epoch} finished. Avg Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}")

    metrics = trainer.evaluate()
    macro_f1 = metrics["macro_f1"]
    # evaluate() returns no loss, so the value shown is the training loss and
    # is labelled as such.
    print(f"[Validation] Epoch {epoch}: Train Loss: {train_loss:.4f}, Acc: {metrics['accuracy']:.4f}, Macro-F1: {macro_f1:.4f}")

    # Keep only the weights with the best validation macro-F1 seen so far.
    if macro_f1 > trainer.best_macro_f1:
        trainer.best_macro_f1 = macro_f1
        torch.save(model.state_dict(), "models/hierarchical_model2someimp_best.pt")
        print(f"[Train] Model improved! Best Macro-F1: {macro_f1:.4f}")
    else:
        print("[Train] No improvement for this epoch.")


In [None]:
print("[Main] Loading best model for test evaluation...")
# Load from the same relative path the training loop saved to, so the cell
# does not depend on the working directory being /kaggle/working.
# map_location lets a checkpoint written on a GPU be restored on a CPU-only host.
best_model_path = "models/hierarchical_model2someimp_best.pt"
model.load_state_dict(torch.load(best_model_path, map_location=device))
model.to(device)
model.eval()

# Collect thresholded predictions (sigmoid >= 0.5) and labels over the test set.
all_preds, all_labels = [], []
with torch.no_grad():
    for batch in tqdm(test_loader, desc="[Test Evaluation]"):
        inputs, labels = batch
        inputs = {k: v.to(device) for k, v in inputs.items()}
        labels = labels.to(device)
        outputs, _ = model(**inputs)
        preds = (torch.sigmoid(outputs) >= 0.5).float()
        all_preds.extend(preds.cpu().numpy().flatten())
        all_labels.extend(labels.cpu().numpy().flatten())

test_acc = accuracy_score(all_labels, all_preds)
test_prec = precision_score(all_labels, all_preds, zero_division=0)
test_rec = recall_score(all_labels, all_preds, zero_division=0)
test_f1 = f1_score(all_labels, all_preds, zero_division=0)
test_macro = f1_score(all_labels, all_preds, average='macro', zero_division=0)
test_cm = confusion_matrix(all_labels, all_preds)

print("[Evaluate] Metrics on test set:")
print(f"  Accuracy: {test_acc:.4f}")
print(f"  Precision: {test_prec:.4f}")
print(f"  Recall: {test_rec:.4f}")
print(f"  F1 Score: {test_f1:.4f}")
print(f"  Macro F1 Score: {test_macro:.4f}")
print(f"  Confusion Matrix:\n{test_cm}")

# Persist the evaluated weights (the best checkpoint loaded above) and the tokenizer.
os.makedirs("models", exist_ok=True)
torch.save(model.state_dict(), "models/hierarchical_model2some_final.pt")
tokenizer.save_pretrained("models/tokenizer2some")
print("[Main] Final model and tokenizer saved.")
print("[Main] Process completed.")
