# In [None]:
import pandas as pd
import torch
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments
from torch.utils.data import Dataset
from transformers import EarlyStoppingCallback, TrainerCallback

# Define paths
DATA_PATH = "/root/workspace/npe_project/Dataset/NPEPatches.json"
MODEL_PATH = "microsoft/unixcoder-base"

def load_and_clean_data(file_path):
    """Read the patch dataset with ``pd.read_csv`` and clean it.

    The file is decoded with the first encoding that does not raise
    ``UnicodeDecodeError``. After loading, missing ``Patch`` values become
    empty strings, rows with a duplicated ``Patch`` are dropped (the first
    occurrence is kept), and rows without a ``Category`` are dropped.

    Args:
        file_path: Path (or any source accepted by ``pd.read_csv``) of the
            dataset. It must provide ``Patch`` and ``Category`` columns.

    Returns:
        The cleaned ``pandas.DataFrame``.

    Raises:
        KeyError: If the ``Patch`` or ``Category`` column is missing.
        ValueError: If none of the attempted encodings can decode the file.
    """
    # Order matters: latin1 maps every possible byte and therefore never
    # raises, so it has to be the last resort. cp1252 rejects a handful of
    # bytes and decodes Windows punctuation (smart quotes, dashes) correctly,
    # so it is tried first. 'iso-8859-1' is just an alias of latin1.
    encodings = ['utf-8', 'cp1252', 'latin1']
    for encoding in encodings:
        try:
            data = pd.read_csv(file_path, encoding=encoding)
        except UnicodeDecodeError:
            continue

        missing_columns = [
            column for column in ("Patch", "Category")
            if column not in data.columns
        ]
        if missing_columns:
            raise KeyError(
                f"Required column(s) {missing_columns} not found in {file_path}"
            )

        # Convert patches to string and handle NaN
        data['Patch'] = data['Patch'].fillna('').astype(str)
        # Clean and validate data
        data = data.drop_duplicates(subset=["Patch"])
        data = data.dropna(subset=["Category"])
        return data
    raise ValueError("Could not read file with any of the attempted encodings")

class NPECommitDataset(Dataset):
    """Map-style dataset that tokenizes one text sample per lookup.

    Each item is a dict holding ``input_ids``, ``attention_mask`` and
    ``labels`` tensors.
    """

    def __init__(self, texts, labels, tokenizer, max_len=512):
        """Store the samples and the tokenizer settings.

        Args:
            texts: Iterable of samples; every element is converted with
                ``str()``.
            labels: Indexable collection of labels, aligned with ``texts``.
            tokenizer: Callable invoked as ``tokenizer(text, max_length=...,
                padding='max_length', truncation=True, return_tensors='pt')``
                whose result is indexed by ``'input_ids'`` and
                ``'attention_mask'``.
            max_len: Value forwarded to the tokenizer as ``max_length``.
        """
        self.texts = list(map(str, texts))
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        """Return the number of stored text samples."""
        return len(self.texts)

    def __getitem__(self, idx):
        """Tokenize sample ``idx`` and return its tensors with the label."""
        raw_text = self.texts[idx]
        target = self.labels[idx]

        # Defensive: only re-coerce if the stored entry is not a string.
        sample = raw_text if isinstance(raw_text, str) else str(raw_text)

        tokenized = self.tokenizer(
            sample,
            max_length=self.max_len,
            padding='max_length',
            truncation=True,
            return_tensors='pt',
        )

        # squeeze(0) removes the leading dimension of each tokenizer output.
        item = {
            field: tokenized[field].squeeze(0)
            for field in ('input_ids', 'attention_mask')
        }
        item['labels'] = torch.tensor(target, dtype=torch.long)
        return item

# Load and preprocess data
try:
    data = load_and_clean_data(DATA_PATH)
    label_mapping = {'NPE-Fixes': 1, 'Not-NPE': 0}
    data["Category"] = data["Category"].map(label_mapping)

    # Split data
    X = data["Patch"].values  
    y = data["Category"].values
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=42
    )

    # Model preparation
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
    train_dataset = NPECommitDataset(X_train, y_train, tokenizer)
    test_dataset = NPECommitDataset(X_test, y_test, tokenizer)
    model = AutoModelForSequenceClassification.from_pretrained(MODEL_PATH, num_labels=2)

    # Training arguments
    training_args = TrainingArguments(
        output_dir="./results",
        num_train_epochs=5,
        per_device_train_batch_size=16,
        per_device_eval_batch_size=16,
        evaluation_strategy="epoch",
        save_strategy="epoch",
        logging_dir="./logs",
        logging_steps=100,
        load_best_model_at_end=True,
        metric_for_best_model="accuracy",
        greater_is_better=True,
        fp16=True,
        weight_decay=0.01,
        learning_rate=3e-5,
        warmup_steps=0,
        save_total_limit=2,
        adam_epsilon=1e-8,
        adam_beta1=0.9,
        max_grad_norm=1.0,
    )

    def compute_metrics(eval_pred):
        """Compute binary classification metrics for one evaluation pass.

        Args:
            eval_pred: Pair ``(logits, labels)``. The predicted class is the
                argmax of ``logits`` along dimension 1; labels are expected
                to be 0 or 1.

        Returns:
            Dict with ``accuracy``, ``precision``, ``recall``, ``f1_score``,
            ``fpr`` (false positive rate) and ``fnr`` (false negative rate).
            A rate whose denominator is zero is reported as 0.0.
        """
        logits, labels = eval_pred
        predictions = torch.argmax(torch.tensor(logits), dim=1).numpy()
        accuracy = accuracy_score(labels, predictions)
        # zero_division=0 keeps the value sklearn already falls back to when
        # a class is never predicted/present, without emitting a warning.
        precision = precision_score(labels, predictions, average="binary", zero_division=0)
        recall = recall_score(labels, predictions, average="binary", zero_division=0)
        f1 = f1_score(labels, predictions, average="binary", zero_division=0)
        # Pin labels so the matrix is always 2x2; otherwise a split holding a
        # single class yields a 1x1 matrix and the 4-way unpacking fails.
        tn, fp, fn, tp = confusion_matrix(labels, predictions, labels=[0, 1]).ravel()
        negatives = fp + tn
        positives = fn + tp
        # Guard the denominators: no negatives (or positives) would otherwise
        # produce NaN and poison any later averaging.
        fpr = float(fp) / float(negatives) if negatives else 0.0
        fnr = float(fn) / float(positives) if positives else 0.0
        return {
            "accuracy": accuracy,
            "precision": precision,
            "recall": recall,
            "f1_score": f1,
            "fpr": fpr,
            "fnr": fnr,
        }

    class MetricsLoggerCallback(TrainerCallback):
        """Trainer callback that records the metrics of every evaluation.

        Each ``on_evaluate`` call appends one value to each history list; a
        metric missing from the reported dict is recorded as 0.
        """

        def __init__(self):
            super().__init__()
            # One history list per tracked metric, in evaluation order.
            (
                self.epoch_accuracies,
                self.epoch_precisions,
                self.epoch_recalls,
                self.epoch_f1_scores,
                self.epoch_fprs,
                self.epoch_fnrs,
            ) = ([] for _ in range(6))

        def on_evaluate(self, args, state, control, metrics=None, **kwargs):
            """Append the reported ``eval_*`` values to their histories."""
            if metrics is None:
                return

            tracked = (
                (self.epoch_accuracies, "eval_accuracy"),
                (self.epoch_precisions, "eval_precision"),
                (self.epoch_recalls, "eval_recall"),
                (self.epoch_f1_scores, "eval_f1_score"),
                (self.epoch_fprs, "eval_fpr"),
                (self.epoch_fnrs, "eval_fnr"),
            )
            for history, metric_key in tracked:
                history.append(metrics.get(metric_key, 0))

    # Training setup
    early_stopping_callback = EarlyStoppingCallback(early_stopping_patience=2)
    metrics_logger = MetricsLoggerCallback()
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=test_dataset,
        tokenizer=tokenizer,
        compute_metrics=compute_metrics,
        callbacks=[early_stopping_callback, metrics_logger],
    )

    # Execute training
    trainer.train()
    eval_results = trainer.evaluate()

    # Calculate and display metrics
    avg_metrics = {
        "Average Accuracy": np.mean(metrics_logger.epoch_accuracies),
        "Average Precision": np.mean(metrics_logger.epoch_precisions),
        "Average Recall": np.mean(metrics_logger.epoch_recalls),
        "Average F1-Score": np.mean(metrics_logger.epoch_f1_scores),
        "Average FPR": np.mean(metrics_logger.epoch_fprs),
        "Average FNR": np.mean(metrics_logger.epoch_fnrs),
    }

    print("Average Metrics:")
    for metric, value in avg_metrics.items():
        print(f"{metric}: {value:.4f}")

except Exception as e:
    print(f"Error occurred: {str(e)}")
    raise