In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.nn.functional as F
import csv
import os
from transformers import AutoTokenizer, AutoModel
from torch.utils.data import random_split
from sklearn.model_selection import GroupShuffleSplit
from scipy.stats import spearmanr
from tqdm import tqdm 
import geoopt
from datetime import datetime
import json
import numpy as np
from transformers import AutoTokenizer, AutoModel
from sklearn.metrics import roc_curve, auc, precision_recall_curve, average_precision_score, f1_score
import matplotlib.pyplot as plt
from tqdm import tqdm
from scipy.special import expit  # For sigmoid function

# The sentence encoder is redefined below so that this evaluation script is standalone.
# Pairs are scored with plain cosine similarity; no poincare_distance function is defined or used here.

class SentenceEncoder(nn.Module):
    """Wrap a pretrained transformer and expose unit-length sentence vectors.

    The sentence vector is the hidden state at position 0 of every sequence
    (the CLS position), rescaled to unit L2 norm so that a dot product between
    two vectors equals their cosine similarity.
    """

    def __init__(self, model_name='sentence-transformers/all-MiniLM-L6-v2'):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(model_name)

    def forward(self, input_ids, attention_mask):
        """Encode a batch and return one L2-normalised vector per sequence."""
        hidden_states = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
        ).last_hidden_state
        # Position 0 along the sequence axis holds the CLS token.
        first_token = hidden_states[:, 0]
        return F.normalize(first_token, p=2, dim=1)


# Dataset class for evaluating on pairs with binary labels
class BinaryLabelDataset(Dataset):
    """Sentence-pair dataset with binary labels, read from a tab-separated file.

    The file is expected to have a header row followed by rows with at least
    five columns: the label in column 0 and the two sentences in columns 3
    and 4. Sentences are tokenized lazily in ``__getitem__``.

    Args:
        file_path: Path to the tab-separated file.
        tokenizer: Callable tokenizer; it is invoked with ``padding``,
            ``truncation``, ``max_length``, ``return_tensors`` and
            ``return_token_type_ids`` keyword arguments and must return a
            mapping of tensors with a leading batch dimension of 1.
        max_length: Length every sentence is padded/truncated to.
    """

    def __init__(self, file_path, tokenizer, max_length=128):
        self.data = self.read_file(file_path)
        self.tokenizer = tokenizer
        self.max_length = max_length
        print(f"Loaded {len(self.data)} sentence pairs")

    def read_file(self, file_path):
        """Parse the TSV file into ``(sentence1, sentence2, label)`` tuples.

        Rows with fewer than five columns, or whose label cannot be parsed as
        a number, are skipped and included in the printed problem-row count.
        Any numeric label is reduced to binary: values > 0 become 1,
        everything else becomes 0.

        Args:
            file_path: Path to the tab-separated file.

        Returns:
            List of ``(str, str, int)`` tuples with stripped sentences.
        """
        data = []
        problem_rows = 0

        # newline='' is what the csv module expects so that it can handle
        # line endings itself.
        with open(file_path, 'r', encoding='utf-8', newline='') as file:
            # quotechar=None disables quote handling: sentences may contain
            # unbalanced double quotes.
            csv_reader = csv.reader(file, delimiter='\t', quotechar=None)
            next(csv_reader, None)  # Skip the header row

            for row in csv_reader:
                if len(row) < 5:
                    problem_rows += 1
                    continue

                try:
                    # Going through float supports both "1" and "1.0".
                    label_value = int(float(row[0]))
                except (ValueError, OverflowError):
                    # Non-numeric, NaN or infinite label: skip the row, but
                    # make the skip visible in the reported count.
                    problem_rows += 1
                    continue

                label = 1 if label_value > 0 else 0
                data.append((row[3].strip(), row[4].strip(), label))

        print("!!!!!!total problem rows = ", problem_rows)
        return data

    def __len__(self):
        return len(self.data)

    def _encode(self, sentence):
        """Tokenize one sentence and drop the leading batch dimension."""
        encoded = self.tokenizer(
            sentence,
            padding='max_length',
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
            return_token_type_ids=False
        )
        return {k: v.squeeze(0) for k, v in encoded.items()}

    def __getitem__(self, idx):
        """Return the tokenized pair and its label as a float32 tensor."""
        sentence1, sentence2, label = self.data[idx]

        return {
            'sent1_input': self._encode(sentence1),
            'sent2_input': self._encode(sentence2),
            'label': torch.tensor(label, dtype=torch.float32)
        }


def collate_fn_eval(batch):
    """Combine per-example dicts into a single batched dict.

    Every tensor found under 'sent1_input' and 'sent2_input' is stacked along
    a new leading axis, key by key (keys are taken from the first example);
    the per-example 'label' tensors are stacked into 'labels'.
    """
    def _stack_field(field):
        tensor_names = batch[0][field]
        return {
            name: torch.stack([example[field][name] for example in batch])
            for name in tensor_names
        }

    first_sentences = _stack_field('sent1_input')
    second_sentences = _stack_field('sent2_input')
    stacked_labels = torch.stack([example['label'] for example in batch])

    return {
        'sent1_input': first_sentences,
        'sent2_input': second_sentences,
        'labels': stacked_labels,
    }


def compute_similarity(embed1, embed2):
    """Return the row-wise dot product of two embedding batches.

    When both inputs hold unit-length rows (the encoder's forward pass
    L2-normalises its output), this dot product is the cosine similarity.
    """
    elementwise_product = embed1 * embed2
    return elementwise_product.sum(dim=1)


def find_optimal_threshold(labels, similarities, thresholds=None):
    """Find the similarity threshold that maximises F1 for binary labels.

    A pair is predicted positive when ``similarity >= threshold``. Every
    candidate threshold is scored and the first one reaching the highest F1
    is returned.

    Args:
        labels: Ground truth labels (array-like of 0/1).
        similarities: Similarity scores (array-like, same length as labels).
        thresholds: Optional array-like of candidate thresholds. Defaults to
            100 evenly spaced values in [0, 1]. Pass a grid that includes
            negative values (e.g. ``np.linspace(-1, 1, 201)``) when the
            scores can be negative, as cosine similarities can.

    Returns:
        Tuple of (optimal_threshold, max_f1, best_precision, best_recall),
        all plain floats.

    Raises:
        ValueError: If ``thresholds`` is given but empty.
    """
    # Accept plain lists as well as arrays; element-wise comparison needs arrays.
    labels = np.asarray(labels)
    similarities = np.asarray(similarities)

    if thresholds is None:
        thresholds = np.linspace(0, 1, 100)
    else:
        thresholds = np.asarray(thresholds, dtype=float).ravel()
    if thresholds.size == 0:
        raise ValueError("thresholds must contain at least one value")

    # These masks do not depend on the threshold, so compute them once.
    is_positive = labels == 1
    is_negative = labels == 0

    f1_scores = []
    precisions = []
    recalls = []

    for threshold in thresholds:
        predicted_positive = similarities >= threshold

        true_positives = np.sum(predicted_positive & is_positive)
        false_positives = np.sum(predicted_positive & is_negative)
        false_negatives = np.sum(~predicted_positive & is_positive)

        # Undefined ratios (empty denominators) are scored as 0.
        predicted_total = true_positives + false_positives
        actual_total = true_positives + false_negatives
        precision = true_positives / predicted_total if predicted_total > 0 else 0
        recall = true_positives / actual_total if actual_total > 0 else 0
        f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0

        precisions.append(precision)
        recalls.append(recall)
        f1_scores.append(f1)

    # argmax returns the first maximum, i.e. the lowest threshold among ties.
    optimal_idx = int(np.argmax(f1_scores))

    return (
        float(thresholds[optimal_idx]),
        float(f1_scores[optimal_idx]),
        float(precisions[optimal_idx]),
        float(recalls[optimal_idx]),
    )


def evaluate_model(model, data_loader, device):
    """Score every sentence pair and summarise binary-classification quality.

    Args:
        model: Trained sentence encoder model
        data_loader: DataLoader yielding dicts with 'sent1_input',
            'sent2_input' and 'labels'
        device: Computation device

    Returns:
        Dictionary with ROC/PR AUC, the best F1 with its threshold, precision
        and recall, the ROC and PR curve points, and the raw similarities and
        labels.
    """
    def _on_device(inputs):
        return {name: tensor.to(device) for name, tensor in inputs.items()}

    model.eval()
    similarity_scores = []
    gold_labels = []

    with torch.no_grad():
        for batch in tqdm(data_loader, desc="Evaluating"):
            first_inputs = _on_device(batch['sent1_input'])
            second_inputs = _on_device(batch['sent2_input'])
            batch_labels = batch['labels'].cpu().numpy()

            # Embed both sides of each pair.
            first_embeddings = model(**first_inputs)
            second_embeddings = model(**second_inputs)

            # Cosine similarity per pair, moved to host memory.
            batch_scores = compute_similarity(first_embeddings, second_embeddings)
            batch_scores = batch_scores.cpu().numpy()

            similarity_scores.extend(batch_scores)
            gold_labels.extend(batch_labels)

    all_similarities = np.array(similarity_scores)
    all_labels = np.array(gold_labels)

    # Best-F1 operating point.
    best = find_optimal_threshold(all_labels, all_similarities)
    threshold, max_f1, best_precision, best_recall = best

    # ROC curve and its area.
    fpr, tpr, _ = roc_curve(all_labels, all_similarities)
    roc_auc = auc(fpr, tpr)

    # Precision-recall curve and average precision.
    precision, recall, _ = precision_recall_curve(all_labels, all_similarities)
    pr_auc = average_precision_score(all_labels, all_similarities)

    metrics = {
        'roc_auc': roc_auc,
        'pr_auc': pr_auc,
        'f1_score': max_f1,
        'optimal_threshold': threshold,
        'best_precision': best_precision,
        'best_recall': best_recall,
        'fpr': fpr,
        'tpr': tpr,
        'precision': precision,
        'recall': recall,
        'similarities': all_similarities,
        'labels': all_labels,
    }
    return metrics


def plot_curves(results, model_name, save_dir="./plots"):
    """Plot ROC, PR, similarity-distribution and metric-vs-threshold figures.

    Four PNG files are written to ``save_dir``, each prefixed with
    ``model_name``: ``_roc``, ``_pr``, ``_sim_dist`` and
    ``_threshold_metrics``.

    Args:
        results: Dictionary with evaluation metrics. The keys read are
            'fpr', 'tpr', 'roc_auc', 'precision', 'recall', 'pr_auc',
            'similarities', 'labels', 'optimal_threshold' and 'f1_score'.
        model_name: Name of the model for plot titles and file names
        save_dir: Directory to save plots (created if missing)
    """
    os.makedirs(save_dir, exist_ok=True)

    def _save_and_close(suffix):
        # Write the current figure, then close it to free its memory.
        plt.savefig(os.path.join(save_dir, f"{model_name}_{suffix}.png"))
        plt.close()

    def _mark_optimal_threshold():
        plt.axvline(x=results['optimal_threshold'], color='r', linestyle='--',
                    label=f'Optimal threshold = {results["optimal_threshold"]:.3f}')

    # Boolean-mask indexing below requires arrays.
    similarities = np.asarray(results['similarities'])
    labels = np.asarray(results['labels'])

    # ROC curve
    plt.figure(figsize=(10, 8))
    plt.plot(results['fpr'], results['tpr'], lw=2, label=f'ROC curve (AUC = {results["roc_auc"]:.3f})')
    plt.plot([0, 1], [0, 1], 'k--', lw=2)  # chance-level diagonal
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title(f'ROC Curve - {model_name}')
    plt.legend(loc="lower right")
    _save_and_close("roc")

    # PR curve
    plt.figure(figsize=(10, 8))
    plt.plot(results['recall'], results['precision'], lw=2, label=f'PR curve (AUC = {results["pr_auc"]:.3f})')
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('Recall')
    plt.ylabel('Precision')
    plt.title(f'Precision-Recall Curve - {model_name}')
    plt.legend(loc="lower left")
    _save_and_close("pr")

    # Similarity distribution
    plt.figure(figsize=(12, 8))
    positive_sim = similarities[labels == 1]
    negative_sim = similarities[labels == 0]

    plt.hist(positive_sim, bins=50, alpha=0.5, label='Positive pairs', density=True)
    plt.hist(negative_sim, bins=50, alpha=0.5, label='Negative pairs', density=True)
    _mark_optimal_threshold()
    plt.xlabel('Cosine Similarity')
    plt.ylabel('Density')
    plt.title(f'Similarity Distribution - {model_name} (F1 = {results["f1_score"]:.3f})')
    plt.legend()
    _save_and_close("sim_dist")

    # Threshold vs F1 Score
    thresholds = np.linspace(0, 1, 100)
    f1_scores = []
    precisions = []
    recalls = []

    for threshold in thresholds:
        predictions = (similarities >= threshold).astype(int)
        # zero_division=0 keeps the 0.0 score for thresholds where nothing is
        # predicted positive, without emitting UndefinedMetricWarning.
        f1_scores.append(f1_score(labels, predictions, zero_division=0))

        # Calculate precision and recall
        true_positives = np.sum((predictions == 1) & (labels == 1))
        false_positives = np.sum((predictions == 1) & (labels == 0))
        false_negatives = np.sum((predictions == 0) & (labels == 1))

        precision = true_positives / (true_positives + false_positives) if (true_positives + false_positives) > 0 else 0
        recall = true_positives / (true_positives + false_negatives) if (true_positives + false_negatives) > 0 else 0

        precisions.append(precision)
        recalls.append(recall)

    plt.figure(figsize=(12, 8))
    plt.plot(thresholds, f1_scores, label='F1 Score')
    plt.plot(thresholds, precisions, label='Precision')
    plt.plot(thresholds, recalls, label='Recall')
    _mark_optimal_threshold()
    plt.xlabel('Threshold')
    plt.ylabel('Score')
    plt.title(f'Metrics vs Threshold - {model_name}')
    plt.legend()
    plt.grid(True, alpha=0.3)
    _save_and_close("threshold_metrics")


def main():
    """Evaluate a saved encoder checkpoint on the MSR paraphrase test set.

    Loads the tokenizer and dataset, restores the model weights from
    ``best_model.pt``, prints the metrics, writes the plots to ``./plots``
    and stores a JSON summary under ``results/``.
    """
    # Set device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')

    # Path to your binary labeled dataset
    binary_dataset_path = 'MSRParaphraseCorpus/msr_paraphrase_test.txt'  # Replace with your dataset path

    # Load the binary dataset
    binary_dataset = BinaryLabelDataset(binary_dataset_path, tokenizer)

    # Create DataLoader (no shuffling: evaluation order does not matter)
    binary_loader = DataLoader(
        binary_dataset,
        batch_size=32,
        shuffle=False,
        collate_fn=collate_fn_eval
    )

    db_name = "msrp"
    model_name = "best_model"
    model_path = f'{model_name}.pt'
    # Same "<dataset>_<model>" tag for the plots and the JSON summary.
    run_name = f"{db_name}_{model_name}"

    print(f"Evaluating model: {model_path}")

    model = SentenceEncoder().to(device)

    # Load weights. weights_only=True restricts unpickling to tensors and
    # plain containers, which is all a state_dict needs, and avoids the
    # FutureWarning that torch.load emits otherwise.
    state_dict = torch.load(model_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)

    # Evaluate the model
    results = evaluate_model(model, binary_loader, device)

    # Print results
    print(f"Model: {model_name}")
    print(f"ROC AUC: {results['roc_auc']:.4f}")
    print(f"PR AUC: {results['pr_auc']:.4f}")
    print(f"Best F1 Score: {results['f1_score']:.4f}")
    print(f"Optimal Threshold: {results['optimal_threshold']:.4f}")
    print(f"Precision at optimal threshold: {results['best_precision']:.4f}")
    print(f"Recall at optimal threshold: {results['best_recall']:.4f}")

    # Plot curves. The second positional argument of plot_curves is the name
    # used in titles/file names and the third is the output directory; the
    # model name was previously passed as that directory by mistake.
    plot_curves(results, run_name)

    # Save results (cast to float so NumPy scalars serialise as JSON numbers)
    os.makedirs("results", exist_ok=True)
    json_results = {
        'model': model_name,
        'roc_auc': float(results['roc_auc']),
        'pr_auc': float(results['pr_auc']),
        'f1_score': float(results['f1_score']),
        'optimal_threshold': float(results['optimal_threshold']),
        'precision': float(results['best_precision']),
        'recall': float(results['best_recall'])
    }
    with open(f"results/{run_name}_eval.json", 'w') as f:
        json.dump(json_results, f, indent=2)



# Run the evaluation only when executed as a script or notebook cell, so that
# importing this module has no side effects.
if __name__ == "__main__":
    # Create directories if they don't exist
    os.makedirs("results", exist_ok=True)
    os.makedirs("plots", exist_ok=True)

    main()



Using device: cuda
!!!!!!total problem rows =  0
Loaded 1725 sentence pairs
Evaluating model: best_model.pt


  model.load_state_dict(torch.load(model_path, map_location=device))
Evaluating: 100%|███████████████████████████████████████████████████████████████████████| 54/54 [00:03<00:00, 16.04it/s]


Model: best_model
ROC AUC: 0.7261
PR AUC: 0.8373
Best F1 Score: 0.8082
Optimal Threshold: 0.4848
Precision at optimal threshold: 0.6908
Recall at optimal threshold: 0.9738
