In [None]:
!pip install typing groq openai -q

In [None]:
!git clone https://github.com/7MMA7/Vulprune

In [None]:
import json
import re
import os
import time
import threading
from typing import List, Dict, Any, Optional, Tuple, Set
from dataclasses import dataclass, field, asdict
from abc import ABC, abstractmethod
from enum import Enum
import pickle
from pathlib import Path
from collections import defaultdict
import random
import difflib
import subprocess
import csv
from groq import Groq
from openai import OpenAI

In [None]:
def get_20_random_deterministic_pair_indices() -> Tuple[List[int], List[int]]:
    """Draw a reproducible random sample of pair indices from the PrimeVul file.

    The pair indices ``0 .. total_pairs-1`` are shuffled with a fixed seed and
    the first ``MAX_PAIRS_TO_LOAD`` of the shuffled order are selected.

    Returns:
        ``(all_pair_ranks, selected_indices)``: the full shuffled list of pair
        indices and its leading slice. Both lists are empty when the data file
        is missing or cannot be read, so callers can always unpack two values.

    Side effects:
        Re-seeds the global ``random`` module with ``RANDOM_SEED``.
    """
    file_path = Path("/content/VulnerabilitiesDetectionProject/PrimeVul/primevul_filtered.jsonl")
    MAX_PAIRS_TO_LOAD = 20
    RANDOM_SEED = 42

    if not file_path.exists():
        print(f"Error: Data file not found at {file_path}")
        return [], []

    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            all_lines = f.readlines()
    except Exception as e:
        print(f"Error reading file: {e}")
        return [], []

    # NOTE(review): two lines are excluded from the count before pairing;
    # presumably non-sample lines in the file -- confirm against the dataset.
    total_lines = len(all_lines) - 2
    # Clamp so a near-empty file yields zero pairs instead of a negative count.
    total_pairs = max(total_lines // 2, 0)

    all_pair_ranks = list(range(total_pairs))

    random.seed(RANDOM_SEED)
    random.shuffle(all_pair_ranks)

    selected_indices = all_pair_ranks[:MAX_PAIRS_TO_LOAD]

    print(f"Total available pairs: {total_pairs}")
    print(f"Number of pairs selected: {len(selected_indices)}")
    return all_pair_ranks, selected_indices

all_pair_ranks, selected_pair_ranks = get_20_random_deterministic_pair_indices()

In [None]:
# =============================================================================
# CONFIGURATION
# =============================================================================

class Config:
    """Static settings: directory layout, LLM defaults and batching knobs.

    All values are class attributes and are read through the class itself
    (for example ``Config.CHECKPOINT_DIR``).
    """

    # --- workspace layout, everything rooted at BASE_DIR ---
    BASE_DIR = Path("/content/vulprune")
    DATA_DIR = BASE_DIR / "data"
    RESULTS_DIR = BASE_DIR / "results"
    CACHE_DIR = BASE_DIR / "cache"
    CHECKPOINT_DIR = BASE_DIR / "checkpoints"

    # --- PrimeVul sub-tree ---
    PRIMEVUL_DIR = BASE_DIR / "primevul"
    PRIMEVUL_DATA_DIR = PRIMEVUL_DIR / "data"
    PRIMEVUL_RESULTS_DIR = PRIMEVUL_DIR / "results"

    # JSON file with per-CWE example diffs (lives outside BASE_DIR).
    CWE_EXAMPLES_PATH = Path("/content/VulnerabilitiesDetectionProject/cwe_example/cwe_commit_diff_vuln_examples.json")

    # --- report folders, one per external analysis tool ---
    FLAWFINDER_DIR = BASE_DIR / "flawfinder" / "reports"
    CPPCHECK_DIR = BASE_DIR / "cppcheck" / "reports"
    SONARQUBE_DIR = BASE_DIR / "sonarqube" / "reports"

    # --- LLM request defaults ---
    LLM_TEMPERATURE = 0.0
    LLM_MAX_TOKENS = 4096

    # --- execution knobs ---
    MAX_CONCURRENT_AGENTS = 10
    ENABLE_CACHING = True
    BATCH_SIZE = 10
    CHECKPOINT_INTERVAL = 3

    @classmethod
    def setup_directories(cls):
        """Create the five core working directories (parents included).

        Only BASE_DIR, DATA_DIR, RESULTS_DIR, CACHE_DIR and CHECKPOINT_DIR are
        created; the PrimeVul and tool-report folders are not touched here.
        """
        core_dirs = (
            cls.BASE_DIR,
            cls.DATA_DIR,
            cls.RESULTS_DIR,
            cls.CACHE_DIR,
            cls.CHECKPOINT_DIR,
        )
        for directory in core_dirs:
            directory.mkdir(parents=True, exist_ok=True)

In [None]:
# =============================================================================
# DATA STRUCTURES
# =============================================================================

@dataclass
class CodeInput:
    """Container for one code sample plus optional provenance metadata.

    Field meanings are taken from their names; NOTE(review): confirm against
    the loader that builds these objects.
    """
    id: str                                # identifier of the sample
    code: str                              # source text of the sample
    project: str                           # project the sample belongs to
    commit_id: Optional[str] = None        # optional; defaults to None
    commit_message: Optional[str] = None   # optional; defaults to None
    file_name: Optional[str] = None        # optional; defaults to None

@dataclass
class PairedCheckpoint:
    """Typed view of a paired-evaluation checkpoint.

    The field names match the keys of the dictionary that
    ``PairedCheckpointManager.save`` writes to disk.
    """
    last_pair_index: int                     # index of the last pair covered by the checkpoint
    vuln_predictions: List[Dict[str, Any]]   # prediction records, one dict each
    clean_predictions: List[Dict[str, Any]]  # prediction records, one dict each
    pair_results: List[Dict[str, Any]]       # per-pair result records
    timestamp: float                         # save() stores time.time() under this key
    llm_stats: Dict[str, Any]                # LLM usage statistics snapshot
    metadata: Dict[str, Any]                 # free-form extra information

@dataclass
class GroundTruth:
    """Reference label for one sample (id, CWE, binary target, optional commit)."""

    id: str
    cwe: str
    label: int  # target in primevul_filtered.jsonl
    commit_id: Optional[str] = None

@dataclass
class PredictionResult:
    """Outcome produced for one sample: verdict, CWEs, confidence and phase data."""

    id: str
    is_vulnerable: bool
    detected_cwes: List[str]
    confidence: float
    phases: Dict[str, Any]  # per-phase payload; structure is not fixed by this class

@dataclass
class VulnerabilityReport:
    """A single finding reported by an agent."""
    cwe: str                  # CWE identifier string of the finding
    code_lines: List[int]     # line numbers the finding refers to
    description: str          # free-text description of the issue
    source_agent: str         # name of the agent that produced the report
    confidence: float = 0.0   # defaults to 0.0 when the agent gives none
    evidence: List[str] = field(default_factory=list)  # fresh empty list per instance

    def to_dict(self):
        """Return the report as a plain dictionary (via ``dataclasses.asdict``)."""
        return asdict(self)

@dataclass
class VulnerabilityHypothesis:
    """A candidate vulnerability together with the assumptions it relies on.

    The class supplies its own ``__init__``; ``@dataclass`` keeps an explicitly
    defined ``__init__`` and still generates ``__repr__``/``__eq__`` from the
    annotated fields below.
    """

    cwe: str
    code_lines: List[int]
    description: str
    source_agent: str
    confidence: float
    trigger_lines: List[int]
    evidence: List[str]
    assumptions: List[str] = field(default_factory=list)
    trigger_path: str = ""

    def __init__(self, cwe: str, code_lines: List[int], description: str,
                 source_agent: str,
                 confidence: float,
                 trigger_lines: List[int],
                 evidence: List[str],
                 assumptions: List[str] = None,
                 trigger_path: str = ""):
        """Store every argument; ``None`` for ``evidence``/``assumptions`` becomes ``[]``."""
        # Both list arguments tolerate None and fall back to a fresh empty list.
        if evidence is None:
            evidence = []
        if assumptions is None:
            assumptions = []

        self.cwe = cwe
        self.code_lines = code_lines
        self.description = description
        self.source_agent = source_agent
        self.confidence = confidence
        self.trigger_lines = trigger_lines
        self.evidence = evidence
        self.assumptions = assumptions
        self.trigger_path = trigger_path

    def to_dict(self):
        """Return the hypothesis as a plain dictionary (via ``dataclasses.asdict``)."""
        return asdict(self)

@dataclass
class ValidationResult:
    """Verdict reached for one hypothesis, with the reasoning behind it."""

    hypothesis: VulnerabilityHypothesis
    assumption_status: Dict[str, str]
    path_valid: bool
    reasoning: str
    final_verdict: bool

    def to_dict(self):
        """Return a dictionary view; the nested hypothesis is expanded via its ``to_dict``."""
        payload = {"hypothesis": self.hypothesis.to_dict()}
        # The remaining fields are copied as-is, in declaration order.
        for key in ("assumption_status", "path_valid", "reasoning", "final_verdict"):
            payload[key] = getattr(self, key)
        return payload

@dataclass
class ProgramContext:
    """Optional structural information about a program; every field has a default."""
    cfg: Optional[Dict] = None         # presumably a control-flow graph -- TODO confirm
    dfg: Optional[Dict] = None         # presumably a data-flow graph -- TODO confirm
    call_graph: Optional[Dict] = None  # call-graph data, None when not available
    functions: List[str] = field(default_factory=list)  # fresh empty list per instance
    imports: List[str] = field(default_factory=list)    # fresh empty list per instance

    def to_dict(self):
        """Return the context as a plain dictionary (via ``dataclasses.asdict``)."""
        return asdict(self)

@dataclass
class CWEExampleLoader:
    """Loads per-CWE example diffs from a JSON file and renders them as prompt text."""

    @staticmethod
    def load_examples(path: Path) -> Dict[str, str]:
        """Read ``path`` and return ``{cwe: formatted example}``.

        The file must hold a JSON object mapping a CWE id to an object with an
        optional ``"cve"`` string and a ``"diff"`` (list of lines or any other
        value, which is stringified). A missing file, unreadable JSON or an
        unexpected structure yields ``{}`` after printing a message.
        """
        if not path.exists():
            print(f"Warning: Examples file not found at {path}")
            return {}

        def render(cwe, content):
            # A list is treated as individual diff lines; anything else is stringified.
            diff = content.get("diff", [])
            snippet = "\n".join(diff) if isinstance(diff, list) else str(diff)
            return (
                f"Example of {cwe} ({content.get('cve', 'N/A')}):\n"
                f"```match\n{snippet}\n```"
            )

        try:
            with open(path, "r", encoding="utf-8") as handle:
                raw = json.load(handle)
            return {cwe: render(cwe, content) for cwe, content in raw.items()}
        except Exception as e:
            print(f"Error loading CWE examples: {e}")
            return {}

In [None]:
class PairedCheckpointManager:
    """Persists and restores the paired-evaluation checkpoint as one JSON file.

    The file is ``paired_evaluation_checkpoint.json`` inside ``checkpoint_dir``.
    All methods report failures by printing a message and returning
    ``False``/``None`` instead of raising.
    """

    def __init__(self, checkpoint_dir: Path = None):
        """Use ``checkpoint_dir`` when given, otherwise ``Config.CHECKPOINT_DIR``."""
        self.checkpoint_dir = checkpoint_dir or Config.CHECKPOINT_DIR
        self.checkpoint_path = self.checkpoint_dir / "paired_evaluation_checkpoint.json"

    def save(self,
             last_pair_index: int,
             vuln_predictions: List[Dict],
             clean_predictions: List[Dict],
             pair_results: List[Dict],
             llm_stats: Dict,
             metadata: Dict = None):
        """Write a checkpoint; return ``True`` on success, ``False`` on any error.

        The payload is serialized in memory first and written to a sibling
        temporary file that is then renamed over the real checkpoint. A value
        that cannot be serialized, or an interrupted write, therefore never
        damages the previously saved checkpoint. The checkpoint directory is
        created when it does not exist yet.
        """
        checkpoint = {
            "last_pair_index": last_pair_index,
            "vuln_predictions": vuln_predictions,
            "clean_predictions": clean_predictions,
            "pair_results": pair_results,
            "timestamp": time.time(),
            "llm_stats": llm_stats,
            "metadata": metadata or {}
        }

        tmp_path = self.checkpoint_path.with_name(self.checkpoint_path.name + ".tmp")
        try:
            # Serialize before touching the disk so bad data cannot truncate the file.
            serialized = json.dumps(checkpoint, indent=2)
            self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
            with open(tmp_path, "w", encoding="utf-8") as f:
                f.write(serialized)
            tmp_path.replace(self.checkpoint_path)
            return True
        except Exception as e:
            print(f"Failed to save checkpoint: {e}")
            try:
                # Best-effort cleanup of a partially written temporary file.
                if tmp_path.exists():
                    tmp_path.unlink()
            except OSError:
                pass
            return False

    def load(self) -> Optional[Dict]:
        """Return the stored checkpoint dictionary, or ``None`` if absent or unreadable."""
        if not self.checkpoint_path.exists():
            return None

        try:
            with open(self.checkpoint_path, "r", encoding="utf-8") as f:
                return json.load(f)
        except Exception as e:
            print(f"Failed to load checkpoint: {e}")
            return None

    def exists(self) -> bool:
        """Return whether a checkpoint file is present."""
        return self.checkpoint_path.exists()

    def delete(self):
        """Remove the checkpoint file; ``True`` when it is gone afterwards, ``False`` on error."""
        if self.checkpoint_path.exists():
            try:
                self.checkpoint_path.unlink()
                return True
            except Exception as e:
                print(f"Failed to delete checkpoint: {e}")
                return False
        return True

    def get_info(self) -> Optional[Dict]:
        """Return a small summary of the checkpoint (counts, timestamp, metadata).

        Returns ``None`` when there is no checkpoint or it cannot be read.
        Missing keys fall back to ``0`` / empty collections.
        """
        if not self.checkpoint_path.exists():
            return None

        try:
            with open(self.checkpoint_path, "r", encoding="utf-8") as f:
                data = json.load(f)

            return {
                "last_pair_index": data.get("last_pair_index", 0),
                "num_vuln_predictions": len(data.get("vuln_predictions", [])),
                "num_clean_predictions": len(data.get("clean_predictions", [])),
                "num_pair_results": len(data.get("pair_results", [])),
                "timestamp": data.get("timestamp", 0),
                "metadata": data.get("metadata", {})
            }
        except Exception as e:
            print(f"Failed to read checkpoint info: {e}")
            return None

In [None]:
# =============================================================================
# MULTI-MODELS CONFIGURATION
# =============================================================================

class APIConfig:
    """Credentials, endpoints and model catalogues for every supported provider.

    API keys are taken from the environment variables ``GROQ_API_KEY``,
    ``OPENAI_API_KEY`` and ``DEEPSEEK_API_KEY`` so that real secrets never
    have to be pasted into the notebook. When a variable is unset the
    original placeholder string is used, exactly as before.
    """

    GROQ_API_KEY = os.environ.get("GROQ_API_KEY", "YOUR_GROQ_API_KEY")
    OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "YOUR_OPENAI_API_KEY")
    DEEPSEEK_API_KEY = os.environ.get("DEEPSEEK_API_KEY", "YOUR_DEEPSEEK_API_KEY")

    OPENAI_BASE_URL = "https://api.openai.com/v1"
    DEEPSEEK_BASE_URL = "https://api.deepseek.com/v1"

    # Each catalogue maps a short alias to the provider's model name and a
    # priority; the provider classes pick the smallest priority value first.
    GROQ_MODELS = {
        "llama-3.3-70b": {
            "name": "llama-3.3-70b-versatile",
            "priority": 1
        },
        "llama-3.1-8b": {
            "name": "llama-3.1-8b-instant",
            "priority": 2
        }
    }

    OPENAI_MODELS = {
        "gpt-4o": {
            "name": "gpt-4o",
            "priority": 1
        }
    }

    DEEPSEEK_MODELS = {
        "deepseek-chat": {
            "name": "deepseek-chat",
            "priority": 1
        }
    }

In [None]:
# =============================================================================
# MULTI-PROVIDERS DATA STRUCTURE
# =============================================================================

class LLMProvider(Enum):
    """Supported LLM back ends; each value is the provider's lowercase name."""
    GROQ = "groq"
    OPENAI = "openai"
    DEEPSEEK = "deepseek"

@dataclass
class ModelConfig:
    """Settings and usage counter for a single model of one provider."""

    name: str
    provider: LLMProvider
    priority: int
    enabled: bool = True
    api_key: str = ""
    base_url: Optional[str] = None
    call_count: int = 0

    def to_dict(self) -> Dict[str, Any]:
        """Return a reportable summary of the model.

        ``api_key`` and ``base_url`` are left out, and the provider is given
        as its enum value (a plain string).
        """
        return dict(
            name=self.name,
            provider=self.provider.value,
            priority=self.priority,
            enabled=self.enabled,
            call_count=self.call_count,
        )

In [None]:
# =============================================================================
# PROVIDERS ABSTRACT INTERFACE
# =============================================================================

class BaseLLMProvider(ABC):
    """Common retry, caching and model-selection logic for one LLM provider.

    Subclasses implement ``_make_api_call``; everything else lives here.

    Counters:
        * ``total_calls`` - calls that actually reached the API (cache hits are
          not counted).
        * ``failed_calls`` - calls that exhausted all retries. Individual
          failed attempts are tracked per model in ``model_stats[...]["errors"]``,
          which keeps ``success_rate`` within ``[0, 1]``.
    """

    # Seconds a rate-limited model stays disabled before it is re-enabled.
    RATE_LIMIT_COOLDOWN_SECONDS = 60
    # Pause between attempts after an error that is not a rate limit.
    ERROR_BACKOFF_SECONDS = 2

    def __init__(self, api_key: str, models: List["ModelConfig"], auto_rotate_on_error: bool = False):
        """Store the model list and initialise counters and the response cache."""
        self.api_key = api_key
        self.available_models = models
        self.current_model_index = 0
        self.total_calls = 0
        self.failed_calls = 0
        self.auto_rotate_on_error = auto_rotate_on_error  # True only for Groq
        self.model_stats = {m.name: {"calls": 0, "errors": 0} for m in models}
        self.cache = {}

    @abstractmethod
    def _make_api_call(self, model: "ModelConfig", messages: List[Dict],
                       temperature: float, max_tokens: int,
                       response_format: str) -> str:
        """Send ``messages`` to ``model`` and return the text of the reply."""
        pass

    def _get_next_model(self) -> "ModelConfig":
        """Pick the model for the next attempt.

        With auto-rotation the enabled model with the fewest calls wins
        (priority breaks ties); otherwise the enabled model with the smallest
        priority value. When no model is enabled the first configured model is
        returned.
        """
        enabled_models = [m for m in self.available_models if m.enabled]
        if not enabled_models:
            return self.available_models[0]

        if self.auto_rotate_on_error:
            return min(enabled_models, key=lambda m: (m.call_count, m.priority))
        return min(enabled_models, key=lambda m: m.priority)

    def _cool_down_model(self, model) -> None:
        """Disable a rate-limited model and schedule its re-activation.

        The model is passed to the timer explicitly. A closure over the retry
        loop's variable would re-enable whichever model the loop selected
        last and leave the rate-limited one disabled.
        """
        print(f"{model.name} rate limited, rotating to next model...")
        model.enabled = False

        def reactivate(target):
            target.enabled = True
            print(f"{target.name} reactivated")

        timer = threading.Timer(self.RATE_LIMIT_COOLDOWN_SECONDS, reactivate, args=(model,))
        timer.daemon = True
        timer.name = f"reactivate-{model.name}"
        timer.start()

        # Nothing left to rotate to: wait out the cooldown, then re-enable everything.
        if all(not m.enabled for m in self.available_models):
            print(f"Waiting {self.RATE_LIMIT_COOLDOWN_SECONDS} seconds...")
            time.sleep(self.RATE_LIMIT_COOLDOWN_SECONDS)
            for m in self.available_models:
                m.enabled = True

    def call(self, prompt: str, temperature: float = 0.0,
             max_tokens: int = 4096, response_format: str = "json",
             max_retries: int = 5) -> str:
        """Run ``prompt`` through the provider, retrying up to ``max_retries`` times.

        Returns the model's reply. When every attempt fails, a JSON string
        ``{"error": "LLM call failed after all retries", "details": ...}`` is
        returned instead of raising. Successful replies are cached (when
        ``Config.ENABLE_CACHING`` is true) under a key built from the prompt,
        temperature, token limit and response format, so a text request never
        receives a reply cached for a JSON request.
        """
        cache_key = f"{hash(prompt)}_{temperature}_{max_tokens}_{response_format}"
        if Config.ENABLE_CACHING and cache_key in self.cache:
            return self.cache[cache_key]

        self.total_calls += 1

        # The messages do not depend on the attempt, so build them once.
        messages = [
            {
                "role": "system",
                "content": "You are an expert security researcher specialized in software vulnerability detection. Always respond with valid JSON when requested."
            },
            {
                "role": "user",
                "content": prompt
            }
        ]
        if response_format == "json" and "Return JSON:" not in prompt and "Return STRICT JSON:" not in prompt:
            messages[-1]["content"] += "\n\nIMPORTANT: Respond with valid JSON only."

        last_error = None
        for _ in range(max_retries):
            current_model = self._get_next_model()

            try:
                result = self._make_api_call(
                    current_model, messages, temperature, max_tokens, response_format
                )
            except Exception as e:
                error_str = str(e)
                last_error = e
                self.model_stats[current_model.name]["errors"] += 1

                if self.auto_rotate_on_error and ("rate_limit" in error_str.lower() or "429" in error_str):
                    self._cool_down_model(current_model)
                    continue

                print(f"Error with {current_model.name}: {error_str[:100]}")
                time.sleep(self.ERROR_BACKOFF_SECONDS)
                continue

            current_model.call_count += 1
            self.model_stats[current_model.name]["calls"] += 1

            if Config.ENABLE_CACHING:
                self.cache[cache_key] = result

            return result

        self.failed_calls += 1
        print(f"All retries exhausted. Last error: {last_error}")
        return json.dumps({
            "error": "LLM call failed after all retries",
            "details": str(last_error)
        })

    def get_stats(self) -> dict:
        """Return call counters, cache size and per-model usage."""
        return {
            "provider": self.__class__.__name__,
            "total_calls": self.total_calls,
            "failed_calls": self.failed_calls,
            "success_rate": (self.total_calls - self.failed_calls) / self.total_calls if self.total_calls > 0 else 0,
            "cached_entries": len(self.cache),
            "model_stats": self.model_stats,
            "current_model_usage": {
                m.name: {
                    "call_count": m.call_count,
                    "enabled": m.enabled
                }
                for m in self.available_models
            }
        }

    def reset_stats(self):
        """Zero every counter and re-enable all models; the cache is kept."""
        for model in self.available_models:
            model.call_count = 0
            model.enabled = True
        self.total_calls = 0
        self.failed_calls = 0
        self.model_stats = {m.name: {"calls": 0, "errors": 0} for m in self.available_models}
        print(f"Stats reset for all {self.__class__.__name__} models")

In [None]:
# =============================================================================
# GROQ IMPLEMENTATION
# =============================================================================

class GroqProvider(BaseLLMProvider):
    """Groq back end; the only provider created with auto-rotation switched on."""

    def __init__(self, api_key: str, model_names: Optional[List[str]] = None):
        """Register the Groq models (optionally only the aliases in ``model_names``)."""
        models = [
            ModelConfig(
                name=spec["name"],
                provider=LLMProvider.GROQ,
                priority=spec["priority"],
                api_key=api_key
            )
            for alias, spec in APIConfig.GROQ_MODELS.items()
            if model_names is None or alias in model_names
        ]

        super().__init__(api_key, models, auto_rotate_on_error=True)
        self.client = Groq(api_key=api_key)

        print(f"Groq Provider initialized with {len(models)} models (auto-rotation enabled)")
        for entry in models:
            print(f"   {entry.name} (Priority: {entry.priority})")

    def _make_api_call(self, model: ModelConfig, messages: List[Dict],
                       temperature: float, max_tokens: int,
                       response_format: str) -> str:
        """Issue one chat-completion request and return the first choice's text."""
        request = dict(
            model=model.name,
            messages=messages,
            temperature=temperature,
            max_tokens=max_tokens,
        )
        # JSON mode is requested only when the caller asked for JSON.
        if response_format == "json":
            request["response_format"] = {"type": "json_object"}

        completion = self.client.chat.completions.create(**request)
        return completion.choices[0].message.content

In [None]:
# =============================================================================
# OPENAI (GPT-4o) IMPLEMENTATION
# =============================================================================

class OpenAIProvider(BaseLLMProvider):
    """OpenAI back end; created without auto-rotation."""

    def __init__(self, api_key: str, model_names: Optional[List[str]] = None):
        """Register the OpenAI models (optionally only the aliases in ``model_names``)."""
        models = [
            ModelConfig(
                name=spec["name"],
                provider=LLMProvider.OPENAI,
                priority=spec["priority"],
                api_key=api_key,
                base_url=APIConfig.OPENAI_BASE_URL
            )
            for alias, spec in APIConfig.OPENAI_MODELS.items()
            if model_names is None or alias in model_names
        ]

        super().__init__(api_key, models, auto_rotate_on_error=False)
        # The client is built with its default endpoint; base_url above is stored on the model only.
        self.client = OpenAI(api_key=api_key)

        print(f"OpenAI Provider initialized with {len(models)} model(s) (sufficient quotas)")
        for entry in models:
            print(f"{entry.name} (Priority: {entry.priority})")

    def _make_api_call(self, model: ModelConfig, messages: List[Dict],
                       temperature: float, max_tokens: int,
                       response_format: str) -> str:
        """Issue one chat-completion request and return the first choice's text."""
        # JSON mode is added only when the caller asked for JSON.
        json_mode = {"response_format": {"type": "json_object"}} if response_format == "json" else {}

        completion = self.client.chat.completions.create(
            model=model.name,
            messages=messages,
            temperature=temperature,
            max_tokens=max_tokens,
            **json_mode
        )
        return completion.choices[0].message.content

In [None]:
# =============================================================================
# DEEPSEEK IMPLEMENTATION
# =============================================================================

# DeepSeek exposes an OpenAI-compatible API, so the OpenAI client is reused with DeepSeek's base URL.

class DeepSeekProvider(BaseLLMProvider):
    """DeepSeek back end, reached through the OpenAI client pointed at DeepSeek's URL."""

    def __init__(self, api_key: str, model_names: Optional[List[str]] = None):
        """Register the DeepSeek models (optionally only the aliases in ``model_names``)."""
        models = [
            ModelConfig(
                name=spec["name"],
                provider=LLMProvider.DEEPSEEK,
                priority=spec["priority"],
                api_key=api_key,
                base_url=APIConfig.DEEPSEEK_BASE_URL
            )
            for alias, spec in APIConfig.DEEPSEEK_MODELS.items()
            if model_names is None or alias in model_names
        ]

        super().__init__(api_key, models, auto_rotate_on_error=False)
        self.client = OpenAI(
            api_key=api_key,
            base_url=APIConfig.DEEPSEEK_BASE_URL
        )

        print(f"DeepSeek Provider initialized with {len(models)} model(s) (sufficient quotas)")
        for entry in models:
            print(f"   • {entry.name} (Priority: {entry.priority})")

    def _make_api_call(self, model: ModelConfig, messages: List[Dict],
                       temperature: float, max_tokens: int,
                       response_format: str) -> str:
        """Issue one chat-completion request and return the first choice's text."""
        # JSON mode is added only when the caller asked for JSON.
        json_mode = {"response_format": {"type": "json_object"}} if response_format == "json" else {}

        completion = self.client.chat.completions.create(
            model=model.name,
            messages=messages,
            temperature=temperature,
            max_tokens=max_tokens,
            **json_mode
        )
        return completion.choices[0].message.content

In [None]:
# =============================================================================
# UNIFIED MULTI-PROVIDER MANAGER
# =============================================================================

class UnifiedLLMInterface:
    """Front end that sends a prompt to one of several providers with fallback.

    Per-provider statistics are keyed by the provider's class name:
    ``calls`` counts every attempt made on that provider, ``errors`` counts the
    attempts that failed, and ``success_rate`` is ``(calls - errors) / calls``.
    """

    def __init__(self, providers: List["BaseLLMProvider"],
                 rotation_strategy: str = "priority"):
        """Store the providers and the strategy used to pick the first one to try.

        ``rotation_strategy`` is one of ``"priority"``, ``"round_robin"`` or
        ``"least_used"``; any other value falls back to the first provider.
        """
        self.providers = providers
        self.rotation_strategy = rotation_strategy
        self.current_provider_index = 0
        self.total_calls = 0
        self.failed_calls = 0
        self.provider_stats = {
            p.__class__.__name__: {"calls": 0, "errors": 0, "success_rate": 0.0}
            for p in providers
        }

        print(f"\nUnified LLM Interface initialized")
        print(f"   Strategy: {rotation_strategy}")
        print(f"   Providers: {len(providers)}")
        for p in providers:
            print(f"      {p.__class__.__name__} ({len(p.available_models)} model(s))")

    @staticmethod
    def _best_enabled_priority(provider) -> float:
        """Smallest priority among the provider's enabled models (inf when none)."""
        return min((m.priority for m in provider.available_models if m.enabled),
                   default=float("inf"))

    @staticmethod
    def _is_error_response(result) -> bool:
        """Tell whether ``result`` is a failure payload rather than a real answer.

        A failure is a JSON object with a top-level ``"error"`` key, the
        shape providers return after exhausting their retries. The word
        "error" appearing inside a normal answer is NOT a failure. A non-string
        result is treated as a failure.
        """
        if not isinstance(result, str):
            return True
        try:
            parsed = json.loads(result)
        except ValueError:
            return False
        return isinstance(parsed, dict) and "error" in parsed

    def _record_outcome(self, provider_name: str, failed: bool) -> None:
        """Count one attempt on ``provider_name`` and refresh its success rate."""
        stats = self.provider_stats[provider_name]
        stats["calls"] += 1
        if failed:
            stats["errors"] += 1
        stats["success_rate"] = (stats["calls"] - stats["errors"]) / stats["calls"]

    def _get_next_provider(self) -> "BaseLLMProvider":
        """Return the provider the rotation strategy wants to try first."""
        if self.rotation_strategy == "round_robin":
            provider = self.providers[self.current_provider_index]
            self.current_provider_index = (self.current_provider_index + 1) % len(self.providers)
            return provider

        if self.rotation_strategy == "least_used":
            return min(self.providers,
                       key=lambda p: self.provider_stats[p.__class__.__name__]["calls"])

        if self.rotation_strategy == "priority":
            # min() keeps the first provider on ties, i.e. list order decides.
            best = min(self.providers, key=self._best_enabled_priority)
            if self._best_enabled_priority(best) != float("inf"):
                return best

        return self.providers[0]

    def _providers_in_call_order(self, preferred_provider: Optional[str]) -> List[Tuple[Any, bool]]:
        """Build the attempt order as ``(provider, is_preferred)`` pairs.

        Preferred providers (class name starts with ``preferred_provider``,
        case-insensitive) come first, then the strategy's pick, then the rest
        in list order. Every provider appears at most once, so a provider that
        already failed is never retried within the same call.
        """
        ordered: List[Tuple[Any, bool]] = []
        seen = set()

        def push(provider, is_preferred):
            if id(provider) not in seen:
                seen.add(id(provider))
                ordered.append((provider, is_preferred))

        if preferred_provider:
            wanted = preferred_provider.lower()
            for provider in self.providers:
                if provider.__class__.__name__.lower().startswith(wanted):
                    push(provider, True)

        if self.providers:
            push(self._get_next_provider(), False)
            for provider in self.providers:
                push(provider, False)

        return ordered

    def call(self, prompt: str, temperature: float = 0.0,
             max_tokens: int = 4096, response_format: str = "json",
             max_retries: int = 3, preferred_provider: Optional[str] = None) -> str:
        """Send ``prompt`` to the first provider that produces a real answer.

        Returns the provider's reply, or the JSON string
        ``{"error": "All providers failed", ...}`` when every provider raised
        or returned a failure payload.
        """
        self.total_calls += 1

        for provider, is_preferred in self._providers_in_call_order(preferred_provider):
            provider_name = provider.__class__.__name__
            try:
                result = provider.call(prompt, temperature, max_tokens,
                                       response_format, max_retries)
            except Exception as e:
                self._record_outcome(provider_name, failed=True)
                if is_preferred:
                    print(f"Preferred provider {preferred_provider} failed: {e}")
                else:
                    print(f"{provider_name} failed, trying next provider...")
                continue

            if self._is_error_response(result):
                # The provider already reported its own failure; move on quietly.
                self._record_outcome(provider_name, failed=True)
                continue

            self._record_outcome(provider_name, failed=False)
            return result

        self.failed_calls += 1
        return json.dumps({
            "error": "All providers failed",
            "details": "No provider could complete the request"
        })

    def get_stats(self) -> Dict[str, Any]:
        """Return overall counters plus per-provider statistics and model usage."""
        stats = {
            "total_calls": self.total_calls,
            "failed_calls": self.failed_calls,
            "success_rate": (self.total_calls - self.failed_calls) / self.total_calls
                           if self.total_calls > 0 else 0,
            "providers": {}
        }

        for provider in self.providers:
            provider_name = provider.__class__.__name__
            provider_stats = provider.get_stats()

            stats["providers"][provider_name] = {
                "calls": self.provider_stats[provider_name]["calls"],
                "errors": self.provider_stats[provider_name]["errors"],
                "success_rate": self.provider_stats[provider_name]["success_rate"],
                "models": provider_stats["current_model_usage"],
                "cache_size": provider_stats["cached_entries"]
            }

        return stats

    def reset_stats(self):
        """Reset this interface's counters and those of every provider."""
        for provider in self.providers:
            provider.reset_stats()

        for key in self.provider_stats:
            self.provider_stats[key] = {"calls": 0, "errors": 0, "success_rate": 0.0}

        self.total_calls = 0
        self.failed_calls = 0

        print("All providers stats reset")

    def print_stats(self):
        """Print a human-readable report of ``get_stats()``."""
        stats = self.get_stats()

        print(f"\n{'='*70}")
        print("UNIFIED LLM INTERFACE STATISTICS")
        print(f"{'='*70}")
        print(f"Total Calls: {stats['total_calls']}")
        print(f"Failed Calls: {stats['failed_calls']}")
        print(f"Overall Success Rate: {stats['success_rate']*100:.2f}%")
        print(f"\n{'Provider Stats':-^70}")

        for provider_name, provider_stats in stats["providers"].items():
            print(f"\n{provider_name}:")
            print(f"  Calls: {provider_stats['calls']}")
            print(f"  Errors: {provider_stats['errors']}")
            print(f"  Success Rate: {provider_stats['success_rate']*100:.2f}%")
            print(f"  Cache Size: {provider_stats['cache_size']}")
            print(f"  Models:")

            for model_name, model_stats in provider_stats["models"].items():
                print(f"    {model_name}: {model_stats['call_count']} calls - "
                      f"{'✅' if model_stats['enabled'] else '❌'}")

        print(f"\n{'='*70}\n")

In [None]:
# =============================================================================
# PREPROCESSING
# =============================================================================

class CodePreprocessor:
    """Stateless helpers for numbering, un-numbering and tidying source text."""

    @staticmethod
    def add_line_numbers(code: str) -> str:
        """Prefix every line with ``L<n>: `` (1-based)."""
        lines = code.split('\n')
        return '\n'.join([f"L{i+1}: {line}" for i, line in enumerate(lines)])

    @staticmethod
    def remove_line_numbers(code: str) -> str:
        """Strip a leading ``L<n>:`` marker from every line.

        Only the marker and the single separator space are removed, so the
        code's own indentation survives and this is the exact inverse of
        ``add_line_numbers``.
        """
        lines = code.split('\n')
        return '\n'.join([re.sub(r'^L\d+: ?', '', line) for line in lines])

    @staticmethod
    def extract_line_numbers(text: str) -> List[int]:
        """Return the sorted, de-duplicated numbers of all ``L<n>`` references in ``text``.

        The ``L`` must start a word, so tokens such as ``SSL3`` or ``NULL0``
        are not mistaken for line references.
        """
        matches = re.findall(r'\bL(\d+)', text)
        return sorted({int(m) for m in matches})

    @staticmethod
    def normalize_code(code: str) -> str:
        """Collapse runs of blank lines to one, drop trailing whitespace, trim the ends."""
        code = re.sub(r'\n\s*\n', '\n\n', code)
        code = '\n'.join([line.rstrip() for line in code.split('\n')])
        return code.strip()

In [None]:
# =============================================================================
# BASE AGENT CLASS
# =============================================================================

class BaseAgent(ABC):
    """Abstract base for all agents: holds the LLM handle, a name and a run counter."""

    def __init__(self, llm: UnifiedLLMInterface, name: str):
        """Store the LLM interface and agent name; ``execution_count`` starts at 0."""
        self.llm = llm
        self.name = name
        self.execution_count = 0

    @abstractmethod
    def analyze(self, code: str) -> Any:
        """Analyze ``code``; the return type is defined by each subclass."""
        pass

In [None]:
# =============================================================================
# PHASE 1: MULTI-VIEW DETECTION AGENTS
# =============================================================================

class MetaAgent(BaseAgent):
    """Asks the LLM which specialized agents to activate for a piece of code."""

    # Agents that are always part of the result, whatever the LLM answers.
    BASELINE_AGENTS = ["StaticAnalyzerAgent", "BehaviorAnalyzerAgent", "MemoryLayoutAgent"]

    def analyze(self, code: str) -> List[str]:
        """Return the agent names to run for ``code``.

        The result always starts with the baseline agents, followed by any
        additional names from the LLM's ``"activated_agents"`` list, in the
        order the LLM gave them and without duplicates. The order is
        therefore deterministic. An unparsable or malformed reply yields the
        baseline agents only. A new list is returned on every call, so callers
        may modify it freely.
        """
        prompt = f"""You are MetaAgent. Select specialized agents for maximum vulnerability recall.
        Code:
        {code}

        Available Agents:
        - StaticAnalyzerAgent: Pattern-based static analysis (ALWAYS ACTIVE)
        - BehaviorAnalyzerAgent: Control/data flow reasoning (ALWAYS ACTIVE)
        - MemoryLayoutAgent: Memory safety analysis (ALWAYS ACTIVE)
        - FormatStringAgent: Format string vulnerabilities
        - FilePermissionAgent: File I/O and permissions
        - AuthFlowAgent: Authentication/authorization
        - CryptoConfigAgent: Cryptographic misuse
        - ConcurrencyAnalyzerAgent: Race conditions
        - ErrorHandlingAgent: Error handling issues
        - CodeInjectionAgent: Command/code injection

        Return JSON: {{"activated_agents": ["agent1", "agent2", ...]}}
        Include baseline agents + relevant specialized agents."""

        # Counted the same way SpecializedAgent counts its runs.
        self.execution_count += 1
        response = self.llm.call(prompt)

        selected = list(self.BASELINE_AGENTS)
        try:
            data = json.loads(response)
        except (TypeError, ValueError):
            # Not JSON at all (JSONDecodeError is a ValueError): baseline only.
            return selected

        requested = data.get("activated_agents") if isinstance(data, dict) else None
        if not isinstance(requested, list):
            return selected

        for agent in requested:
            # Ignore anything that is not a plain agent name.
            if isinstance(agent, str) and agent not in selected:
                selected.append(agent)
        return selected

class SpecializedAgent(BaseAgent):
    """LLM-backed detector that looks at the code from one CWE-focused angle.

    Attributes:
        cwe_focus: CWE identifiers the agent is asked to analyze for.
        triggers: Keywords listed in the prompt as typical triggers.
        description: One-line role description placed at the top of the prompt.
        cwe_examples: Mapping from CWE identifier to example text; entries
            whose key is in ``cwe_focus`` are injected into the prompt.
    """

    def __init__(self, llm: UnifiedLLMInterface, name: str, cwe_focus: List[str],
                 triggers: List[str], description: str, cwe_examples: Optional[Dict[str, str]] = None):
        super().__init__(llm, name)
        self.cwe_focus = cwe_focus
        self.triggers = triggers
        self.description = description
        self.cwe_examples = cwe_examples or {}

    @staticmethod
    def _coerce_confidence(value: Any, default: float = 0.5) -> float:
        """Turn an LLM-provided confidence into a float, else ``default``."""
        try:
            confidence = float(value)
        except (TypeError, ValueError):
            return default
        # NaN never compares equal to itself; it would poison later max() calls.
        return default if confidence != confidence else confidence

    @staticmethod
    def _coerce_evidence(value: Any) -> List[str]:
        """Normalize the ``evidence`` field to a list of strings."""
        if value is None:
            return []
        if isinstance(value, str):
            return [value]
        if isinstance(value, (list, tuple)):
            return [item if isinstance(item, str) else str(item) for item in value]
        return [str(value)]

    def _build_report(self, finding: Any) -> VulnerabilityReport:
        """Convert one entry of the ``findings`` array into a report.

        Raises:
            TypeError: if ``finding`` is not a JSON object.
        """
        if not isinstance(finding, dict):
            raise TypeError(f"finding is {type(finding).__name__}, expected object")

        description = finding.get("description")
        return VulnerabilityReport(
            cwe=str(finding.get("cwe") or "UNKNOWN"),
            code_lines=CodePreprocessor.extract_line_numbers(finding.get("code_lines", "")),
            description="" if description is None else str(description),
            source_agent=self.name,
            confidence=self._coerce_confidence(finding.get("confidence", 0.5)),
            evidence=self._coerce_evidence(finding.get("evidence", []))
        )

    def analyze(self, code: str) -> List[VulnerabilityReport]:
        """Ask the LLM for findings in ``code`` and return them as reports.

        Args:
            code: Source code embedded verbatim in the prompt.

        Returns:
            One report per usable finding. A reply that cannot be parsed, or
            that has no ``findings`` list, yields an empty list. A single
            malformed finding is skipped and does not discard the others.
        """

        examples_text = ""
        relevant_examples = [
            self.cwe_examples[cwe] for cwe in self.cwe_focus if cwe in self.cwe_examples
        ]

        if relevant_examples:
            examples_text = "\nHere are REAL WORLD EXAMPLES of what you are looking for:\n" + "\n".join(relevant_examples) + "\n"

        prompt = f"""You are {self.name}. {self.description}
                  Analyze for: {', '.join(self.cwe_focus)}
                  Typical triggers: {', '.join(self.triggers)}

                  {examples_text}

                  MAXIMIZE RECALL - flag anything suspicious from your perspective.

                  Code:
                  {code}

                  Return STRICT JSON:
                  {{
                    "vulnerability_reported": true/false,
                    "findings": [
                      {{
                        "cwe": "CWE-XXX",
                        "code_lines": "L10,L15",
                        "description": "Specific issue found",
                        "evidence": ["evidence snippet 1", "evidence snippet 2"],
                        "confidence": 0.7
                      }}
                    ]
                  }}"""

        self.execution_count += 1
        response = self.llm.call(prompt)

        try:
            data = json.loads(response)
            findings = data.get("findings", [])
        except Exception as e:
            print(f"{self.name} failed: {e}")
            return []

        if not isinstance(findings, list):
            print(f"{self.name} failed: 'findings' is not a list")
            return []

        reports = []
        for position, finding in enumerate(findings):
            # Isolate each finding so one bad entry cannot drop the rest.
            try:
                reports.append(self._build_report(finding))
            except Exception as e:
                print(f"{self.name} skipped finding {position}: {e}")

        return reports

class StaticToolAgent(BaseAgent):
    """Agent that relays a pre-computed static-analysis report read from disk.

    No LLM call is made: the report file content becomes the evidence of a
    single ``VulnerabilityReport``.
    """

    def __init__(self, llm: UnifiedLLMInterface, name: str, report_dir: Path):
        super().__init__(llm, name)
        self.report_dir = report_dir

    def analyze(self, code_input: CodeInput) -> List[VulnerabilityReport]:
        """Load the report matching ``code_input`` for this tool.

        For the agent named "Sonarqube" the file is
        ``analysis-results_<pair_index>/filtered_results.jsonl`` (``pair_index``
        taken from ``code_input.metadata``); for any other name it is
        ``<suffix>_report.txt`` where ``<suffix>`` is the part of
        ``code_input.id`` after the last underscore.

        Returns:
            A one-element list with the raw report, or an empty list when the
            file cannot be located, is missing, is blank, or cannot be read.
        """

        is_sonarqube = self.name == "Sonarqube"

        # Step 1: work out where the report should live.
        try:
            if is_sonarqube:
                pair_rank = code_input.metadata.get("pair_index")
                if pair_rank is None:
                    return []
                report_path = self.report_dir / f"analysis-results_{pair_rank}" / "filtered_results.jsonl"
            else:
                item_idx = code_input.id.split('_')[-1]
                report_path = self.report_dir / f"{item_idx}_report.txt"
        except Exception:
            return []

        if not report_path.exists():
            return []

        # Step 2: read it and wrap the text in a report.
        try:
            with open(report_path, "r", encoding="utf-8", errors="replace") as handle:
                content = handle.read().strip()

            if not content:
                return []

            if is_sonarqube:
                summary = f"Raw static analysis output from {self.name} for pair index {pair_rank}."
            else:
                summary = f"Raw static analysis output from {self.name}."

            return [VulnerabilityReport(
                cwe=f"External-Tool-{self.name}",
                code_lines=[],
                description=summary,
                source_agent=self.name,
                confidence=0.9,
                evidence=[content]
            )]

        except Exception as e:
            if is_sonarqube:
                print(f"Erreur de lecture du rapport Sonarqube : {e}")
            else:
                print(f"Error reading report for {self.name}: {e}")
            return []

class AggregatorAgent(BaseAgent):
    """Merges Phase 1 reports so that each CWE appears exactly once."""

    @staticmethod
    def _unique_in_order(items) -> list:
        """Drop duplicates from ``items`` while keeping first-seen order.

        Unhashable items (e.g. dicts) are compared by equality instead of
        being hashed, so they do not raise.
        """
        seen = set()
        unique = []
        for item in items:
            try:
                is_new = item not in seen
                if is_new:
                    seen.add(item)
            except TypeError:
                is_new = item not in unique
            if is_new:
                unique.append(item)
        return unique

    def analyze(self, reports: List[VulnerabilityReport]) -> List[VulnerabilityReport]:
        """Combine reports that share a CWE into a single report.

        Args:
            reports: Reports produced by the detection agents.

        Returns:
            One report per distinct CWE, in order of first appearance. Line
            numbers are unioned and sorted, descriptions are joined with
            " | ", agent names with ", ", evidence is concatenated, and the
            confidence is the highest one in the group (never below 0.0).
            Joined fields are de-duplicated in first-seen order so the result
            is identical from run to run.
        """

        if not reports:
            return []

        grouped: Dict[str, List[VulnerabilityReport]] = {}
        for report in reports:
            grouped.setdefault(report.cwe, []).append(report)

        aggregated = []
        for cwe, cwe_reports in grouped.items():
            all_lines: Set[int] = set()
            descriptions = []
            agents = []
            evidence = []
            max_conf = 0.0

            for r in cwe_reports:
                all_lines.update(r.code_lines)
                descriptions.append(r.description)
                agents.append(r.source_agent)
                evidence.extend(r.evidence)
                max_conf = max(max_conf, r.confidence)

            # A set() here would order strings by their per-process hash and
            # make later prompts (and evidence[:2] slices) non-reproducible.
            aggregated.append(VulnerabilityReport(
                cwe=cwe,
                code_lines=sorted(all_lines),
                description=" | ".join(self._unique_in_order(descriptions)),
                source_agent=", ".join(self._unique_in_order(agents)),
                confidence=max_conf,
                evidence=self._unique_in_order(evidence)
            ))

        return aggregated

In [None]:
# =============================================================================
# PHASE 2: HYPOTHESIS CONSTRUCTION
# =============================================================================

class TriggerPlannerAgent(BaseAgent):
    """Phase 2 agent: turns each report into a hypothesis with a trigger path."""

    @staticmethod
    def _make_hypothesis(report: VulnerabilityReport, evidence: list, trigger_lines: Any,
                         assumptions: List[str], trigger_path: str) -> VulnerabilityHypothesis:
        """Build a hypothesis from ``report`` plus the planner-specific fields."""
        hypothesis = VulnerabilityHypothesis(
            cwe=report.cwe,
            code_lines=report.code_lines,
            description=report.description,
            source_agent=report.source_agent,
            confidence=report.confidence,
            trigger_lines=trigger_lines,
            evidence=evidence
        )
        hypothesis.assumptions = assumptions
        hypothesis.trigger_path = trigger_path
        return hypothesis

    @staticmethod
    def _normalize_assumptions(raw: Any) -> List[str]:
        """Return the ``assumptions`` field as a list of strings."""
        if raw is None:
            return []
        if isinstance(raw, str):
            return [raw]
        if isinstance(raw, (list, tuple)):
            return [item if isinstance(item, str) else str(item) for item in raw]
        return [str(raw)]

    def analyze(self, reports: List[VulnerabilityReport], code: str) -> List[VulnerabilityHypothesis]:
        """Ask the LLM for a trigger path for every report.

        Args:
            reports: Reports to expand (one LLM call per report).
            code: Source code embedded verbatim in each prompt.

        Returns:
            One hypothesis per report, except reports the LLM marks with
            ``is_valid`` false, which are dropped. When the reply is not a
            JSON object, a minimal fallback hypothesis is kept so the report
            is not lost.
        """
        hypotheses = []

        print(f"   [DEBUG] TriggerPlanner received {len(reports)} reports")

        for idx, report in enumerate(reports):
            print(f"   [DEBUG] Processing report {idx+1}/{len(reports)}: {report.cwe}")

            evidence = getattr(report, "evidence", None) or []
            if isinstance(evidence, str):
                evidence = [evidence]
            # str() keeps the join from raising on non-string evidence items.
            evidence_preview = ', '.join(str(item) for item in evidence[:2]) if evidence else 'N/A'

            prompt = f"""You are TriggerPlannerAgent. Build a FEASIBLE trigger path for this vulnerability.
            DO NOT judge validity - that comes later. Just construct the path.

            Report:
            - CWE: {report.cwe}
            - Lines: {report.code_lines}
            - Description: {report.description}
            - Source: {report.source_agent}
            - Evidence: {evidence_preview}

            Code:
            {code}

            Construct:
            1. Assumptions: What must be true for exploitation? (List specific, verifiable assumptions)
            2. Trigger Path: attacker-source → transforms → sink
            3. Trigger Lines: Specific lines where the vulnerability can be triggered

            Return STRICT JSON:
            {{
              "cwe": "{report.cwe}",
              "assumptions": ["A1: Input comes from untrusted source", "A2: No bounds checking"],
              "trigger_path": "user_input → strcpy → buffer overflow at L15",
              "trigger_lines": [10, 15],
              "is_valid": true
            }}

            Only set is_valid to false if the vulnerability is clearly impossible."""

            response = self.llm.call(prompt)

            try:
                data = json.loads(response)
                if not isinstance(data, dict):
                    # Valid JSON that is not an object (list, string, ...)
                    # is handled like a parse failure instead of crashing.
                    raise ValueError(f"expected a JSON object, got {type(data).__name__}")
            except (TypeError, ValueError) as e:
                # json.JSONDecodeError is a ValueError subclass.
                print(f"   [ERROR] TriggerPlanner JSON parsing failed for {report.cwe}: {e}")
                print(f"   [ERROR] Response: {str(response)[:300]}")
                hypotheses.append(self._make_hypothesis(
                    report, evidence,
                    trigger_lines=report.code_lines,
                    assumptions=["Default: Requires manual review"],
                    trigger_path="Unknown - parsing error"
                ))
                print("   [FALLBACK] Created minimal hypothesis due to error")
                continue

            assumptions = self._normalize_assumptions(data.get("assumptions"))

            is_valid = data.get("is_valid", True)
            if isinstance(is_valid, str):
                # A quoted "false" would otherwise be truthy.
                is_valid = is_valid.strip().lower() not in ("false", "no", "0")

            print(f"   [DEBUG] TriggerPlanner parsed JSON for {report.cwe}")
            print(f"   [DEBUG] is_valid={is_valid}, assumptions={len(assumptions)}")

            if not is_valid:
                print("   [SKIP] Hypothesis marked as invalid by LLM")
                continue

            trigger_lines = data.get("trigger_lines", report.code_lines)
            if not isinstance(trigger_lines, list):
                trigger_lines = report.code_lines

            trigger_path = data.get("trigger_path") or "Unknown path"
            if not isinstance(trigger_path, str):
                trigger_path = str(trigger_path)

            hypothesis = self._make_hypothesis(
                report, evidence,
                trigger_lines=trigger_lines,
                assumptions=assumptions,
                trigger_path=trigger_path
            )
            print(f"   [SUCCESS] Created hypothesis with {len(hypothesis.assumptions)} assumptions")
            hypotheses.append(hypothesis)

        print(f"   [SUMMARY] TriggerPlanner created {len(hypotheses)} hypotheses total")
        return hypotheses

In [None]:
# =============================================================================
# PHASE 3: ASSUMPTION VALIDATION
# =============================================================================

class AssumptionPrunerAgent(BaseAgent):
    """Phase 3 agent: has the LLM grade each hypothesis's assumptions.

    All hypotheses are reviewed in a single LLM call. Only those the LLM
    marks ``should_proceed`` with ``overall_confidence`` above 0.5 survive.
    """

    @staticmethod
    def _as_float(value: Any, default: float = 0.0) -> float:
        """Convert an LLM-provided number (possibly quoted) to float."""
        try:
            return float(value)
        except (TypeError, ValueError):
            return default

    @staticmethod
    def _as_bool(value: Any) -> bool:
        """Interpret a JSON flag; a quoted "false"/"no"/"0" counts as False."""
        if isinstance(value, str):
            return value.strip().lower() not in ("", "false", "no", "0")
        return bool(value)

    @staticmethod
    def _resolve_index(raw_id: Any, count: int) -> Optional[int]:
        """Map ``hypothesis_id`` to a list index in ``[0, count)`` or None.

        Accepts ints, integral floats and digit strings. Negative or
        out-of-range values are rejected rather than wrapping around.
        """
        if isinstance(raw_id, bool):
            return None
        if isinstance(raw_id, str):
            text = raw_id.strip()
            if not text.isdigit():
                return None
            raw_id = int(text)
        elif isinstance(raw_id, float):
            if not raw_id.is_integer():
                return None
            raw_id = int(raw_id)
        if not isinstance(raw_id, int):
            return None
        return raw_id if 0 <= raw_id < count else None

    def _format_hypotheses(self, hypotheses: List[VulnerabilityHypothesis]) -> str:
        """Render the hypotheses as the text block embedded in the prompt.

        Each entry shows the 0-based ``hypothesis_id`` that :meth:`analyze`
        uses to map the LLM's verdicts back onto ``hypotheses``.
        """
        formatted = []
        for i, h in enumerate(hypotheses):
            assumptions_str = "\n    ".join(h.assumptions) if h.assumptions else "None specified"
            formatted.append(
                # The explicit id removes the mismatch between the 1-based
                # label and the 0-based index expected in the reply.
                f"Hypothesis {i+1} (hypothesis_id={i}) [SOURCE: {h.source_agent}]:\n"
                f"  CWE: {h.cwe}\n"
                f"  Lines: {h.code_lines}\n"
                f"  Trigger Lines: {h.trigger_lines}\n"
                f"  Trigger Path: {h.trigger_path}\n"
                f"  Assumptions:\n    {assumptions_str}\n"
                f"  Description: {h.description}\n"
                f"  Confidence: {h.confidence:.2f}\n"
                f"  Evidence: {', '.join(str(item) for item in h.evidence[:2]) if h.evidence else 'N/A'}\n"
                f"---"
            )
        return "\n".join(formatted)

    def analyze(self, hypotheses: List[VulnerabilityHypothesis], code: str) -> List[Tuple[VulnerabilityHypothesis, Dict[str, str]]]:
        """Validate the assumptions of ``hypotheses`` against ``code``.

        Args:
            hypotheses: Hypotheses produced by the trigger planner.
            code: Source code embedded verbatim in the prompt.

        Returns:
            ``(hypothesis, assumption_status)`` tuples for every hypothesis
            the LLM let through. The hypothesis is a copy whose confidence is
            replaced by the LLM's ``overall_confidence``. ``assumption_status``
            is the LLM's ``assumption_validations`` mapping, or ``{}`` if that
            field is not an object. Returns ``[]`` when the reply is not a
            JSON list.
        """

        if not hypotheses:
            print("   [DEBUG] No hypotheses to validate")
            return []

        print(f"   [DEBUG] AssumptionPruner received {len(hypotheses)} hypotheses")

        hypotheses_with_assumptions = [h for h in hypotheses if h.assumptions]
        print(f"   [DEBUG] {len(hypotheses_with_assumptions)} hypotheses have assumptions")

        formatted_hypotheses = self._format_hypotheses(hypotheses)
        print(f"   [DEBUG] Formatted hypotheses length: {len(formatted_hypotheses)} chars")

        prompt = f"""You are the **Assumption Pruner Agent ({self.name})**, a senior security engineer reviewing vulnerability hypotheses.

**YOUR TASK:**
For each hypothesis, validate its assumptions and assess overall confidence.

**EVALUATION FACTORS:**

**For each assumption, determine:**
- "VALID" - The assumption is clearly true in the code
- "LIKELY" - The assumption is probably true but not certain
- "UNCERTAIN" - Cannot determine from the code
- "INVALID" - The assumption is clearly false

**Overall confidence scoring:**
- Increase confidence if:
  * Well-established vulnerability pattern
  * Known-unsafe functions used
  * Untrusted input sources
  * Missing protection mechanisms
  * Clear attack path

- Decrease confidence if:
  * Non-existent code references
  * Flagged lines are non-executable
  * Multiple protection layers present
  * Improbable preconditions required
  * Code handles the edge case
  * Input is validated/sanitized

**Hypotheses to Review:**
{formatted_hypotheses}

**Code Under Analysis:**
{code}

**OUTPUT FORMAT (JSON list):**
For each hypothesis, return:
[
  {{
    "hypothesis_id": 0,
    "cwe": "CWE-XXX",
    "assumption_validations": {{
      "A1: Input comes from untrusted source": "VALID",
      "A2: No bounds checking": "LIKELY"
    }},
    "overall_confidence": 0.8,
    "assessment": "Brief explanation of your assessment",
    "should_proceed": true
  }}
]

Only set should_proceed to true if overall_confidence > 0.5."""

        self.execution_count += 1
        response = self.llm.call(prompt)

        print(f"   [DEBUG] LLM response length: {len(response)} chars")
        print(f"   [DEBUG] Response preview: {response[:200]}...")

        try:
            data = json.loads(response.strip())
            print(f"   [DEBUG] Parsed JSON successfully, type: {type(data)}")

            if not isinstance(data, list):
                print(f"   [ERROR] Pruner did not return a list. Response: {response[:500]}")
                return []

            print(f"   [DEBUG] JSON contains {len(data)} entries")

            validated_tuples = []

            for i, finding in enumerate(data):
                if not isinstance(finding, dict):
                    print(f"   [DEBUG] Skipping finding {i} - not a JSON object")
                    continue

                confidence = self._as_float(finding.get("overall_confidence", 0.0))
                should_proceed = self._as_bool(finding.get("should_proceed", False))
                print(f"   [DEBUG] Processing finding {i+1}: confidence={confidence}, should_proceed={should_proceed}")

                if not (should_proceed and confidence > 0.5):
                    print(f"   [DEBUG] Skipping finding {i} - should_proceed={should_proceed}, confidence={confidence}")
                    continue

                raw_id = finding.get("hypothesis_id", i)
                hypothesis_id = self._resolve_index(raw_id, len(hypotheses))
                if hypothesis_id is None:
                    print(f"   [DEBUG] Skipping - invalid hypothesis_id: {raw_id} (max: {len(hypotheses)-1})")
                    continue

                original_hypothesis = hypotheses[hypothesis_id]

                updated_hypothesis = VulnerabilityHypothesis(
                    cwe=original_hypothesis.cwe,
                    code_lines=original_hypothesis.code_lines,
                    description=original_hypothesis.description,
                    source_agent=original_hypothesis.source_agent,
                    confidence=confidence,
                    trigger_lines=original_hypothesis.trigger_lines,
                    evidence=original_hypothesis.evidence,
                    assumptions=original_hypothesis.assumptions,
                    trigger_path=original_hypothesis.trigger_path
                )

                assumption_status = finding.get("assumption_validations", {})
                if not isinstance(assumption_status, dict):
                    # Downstream code iterates .items() on this value.
                    assumption_status = {}

                validated_tuples.append((updated_hypothesis, assumption_status))
                print(f"   [DEBUG] Added hypothesis {hypothesis_id} to validated list")

            print(f"   [SUCCESS] Pruner validated {len(validated_tuples)}/{len(hypotheses)} hypotheses")
            return validated_tuples

        except json.JSONDecodeError as e:
            print(f"   [ERROR] {self.name} JSON parsing failed: {e}")
            print(f"   [ERROR] Response was: {response[:500]}")
            return []
        except Exception as e:
            print(f"   [ERROR] {self.name} failed: {e}")
            print(f"   [ERROR] Response: {response[:500]}")
            import traceback
            traceback.print_exc()
            return []

In [None]:
# =============================================================================
# PHASE 4: PATH VERIFICATION
# =============================================================================

class FinalValidatorAgent(BaseAgent):
    """Phase 4 agent: asks the LLM for a binary exploitability verdict."""

    @staticmethod
    def _as_bool(value: Any) -> bool:
        """Interpret a JSON flag; a quoted "false"/"no"/"0" counts as False.

        Non-string values keep their ordinary Python truthiness.
        """
        if isinstance(value, str):
            return value.strip().lower() not in ("", "false", "no", "0")
        return bool(value)

    def analyze(self, validated_hypotheses: List[Tuple[VulnerabilityHypothesis, Dict[str, str]]],
                code: str) -> List[ValidationResult]:
        """Produce one validation result per validated hypothesis.

        Args:
            validated_hypotheses: ``(hypothesis, assumption_status)`` pairs
                from the assumption pruner.
            code: Source code embedded verbatim in each prompt.

        Returns:
            A ``ValidationResult`` for every input pair (one LLM call each).
            If the reply cannot be parsed as a JSON object the vulnerability
            is kept (``path_valid`` and ``final_verdict`` both True), in line
            with the prompt's "when uncertain, retain" policy.
        """

        results = []

        for hyp, assumption_status in validated_hypotheses:
            # Guard against a non-mapping status so prompt building never
            # raises outside the try block below.
            status_items = assumption_status.items() if isinstance(assumption_status, dict) else []
            assumptions_str = "\n".join(
                f"    - {assumption}: {status}" for assumption, status in status_items
            )
            evidence_preview = ', '.join(str(item) for item in hyp.evidence[:3]) if hyp.evidence else 'N/A'

            prompt = f"""You are FinalValidatorAgent. Binary decision: Is this vulnerability EXPLOITABLE?

            Hypothesis:
            - CWE: {hyp.cwe}
            - Code Lines: {hyp.code_lines}
            - Trigger Lines: {hyp.trigger_lines}
            - Trigger Path: {hyp.trigger_path}
            - Description: {hyp.description}
            - Current Confidence: {hyp.confidence}

            Validated Assumptions:
{assumptions_str}

            Evidence: {evidence_preview}

            Code:
            {code}

            Check for protections that COMPLETELY BLOCK exploitation:
            - Bounds checks before sink
            - Null checks
            - Early returns / error handling
            - Input sanitization
            - Protective API behavior

            IMPORTANT:
            - Path is INVALID only if protections COMPLETELY BLOCK ALL routes
            - If assumptions are mostly VALID/LIKELY, lean towards exploitable
            - When uncertain, RETAIN the vulnerability (false positive is better than false negative)

            Return STRICT JSON:
            {{
              "path_valid": true/false,
              "reasoning": "detailed explanation of why path is/isn't exploitable",
              "protective_mechanisms": ["protection1", "protection2"],
              "blocks_all_paths": false,
              "final_verdict": true/false
            }}

            final_verdict should be true if the vulnerability is exploitable."""

            response = self.llm.call(prompt)

            try:
                data = json.loads(response)
                if not isinstance(data, dict):
                    raise ValueError(f"expected a JSON object, got {type(data).__name__}")

                results.append(ValidationResult(
                    hypothesis=hyp,
                    assumption_status=assumption_status,
                    # A quoted "false" is a non-empty string and would
                    # otherwise be counted as an exploitable verdict.
                    path_valid=self._as_bool(data.get("path_valid", True)),
                    reasoning=data.get("reasoning", ""),
                    final_verdict=self._as_bool(data.get("final_verdict", True))
                ))
            except Exception as e:
                print(f"FinalValidator failed: {e}")
                results.append(ValidationResult(
                    hypothesis=hyp,
                    assumption_status=assumption_status,
                    path_valid=True,
                    reasoning=f"Validation error: {e} - Keeping vulnerability by default",
                    final_verdict=True
                ))

        return results

In [None]:
# =============================================================================
# COMPLETE PIPELINE WITH CHECKPOINTS
# =============================================================================

class VulPrunePipeline:
    """End-to-end detector chaining the four analysis phases.

    Phase 1 gathers reports (LLM agents chosen by the meta agent plus
    on-disk static tool reports), Phase 2 builds hypotheses, Phase 3 prunes
    them by assumption validation, and Phase 4 issues the final verdicts.
    """

    def __init__(self, llm: UnifiedLLMInterface):
        """Create every agent of the pipeline around the shared ``llm``."""
        self.llm = llm

        self.cwe_examples = CWEExampleLoader.load_examples(Config.CWE_EXAMPLES_PATH)
        print(f"Pipeline loaded {len(self.cwe_examples)} CWE few-shot examples")

        self.meta_agent = MetaAgent(llm, "MetaAgent")
        self.aggregator = AggregatorAgent(llm, "AggregatorAgent")
        self.trigger_planner = TriggerPlannerAgent(llm, "TriggerPlannerAgent")
        self.assumption_pruner = AssumptionPrunerAgent(llm, "AssumptionPrunerAgent")
        self.final_validator = FinalValidatorAgent(llm, "FinalValidatorAgent")

        self.specialized_agents = self._init_specialized_agents()

        self.static_agents = [
            StaticToolAgent(llm, "Flawfinder", Config.FLAWFINDER_DIR),
            StaticToolAgent(llm, "Cppcheck", Config.CPPCHECK_DIR),
            StaticToolAgent(llm, "Sonarqube", Config.SONARQUBE_DIR)
        ]

    def _init_specialized_agents(self) -> Dict[str, SpecializedAgent]:
        """Build the registry of LLM detection agents, keyed by agent name.

        Each spec is ``(name, cwe_focus, triggers, description)``; every
        agent shares ``self.llm`` and ``self.cwe_examples``.
        """
        # NOTE(review): "CWE-022", "CWE-078" and "CWE-089" are zero-padded,
        # unlike the other identifiers here; confirm they match the keys
        # produced by CWEExampleLoader, otherwise no example is injected.
        agent_specs = [
            ("StaticAnalyzerAgent",
             ["CWE-476", "CWE-787", "CWE-125"],
             ["NULL", "buffer", "pointer", "array"],
             "Static pattern recognition for common vulnerabilities"),
            ("BehaviorAnalyzerAgent",
             ["CWE-703", "CWE-416", "CWE-20"],
             ["leak", "use-after-free", "error", "validation"],
             "Control/data flow analysis for runtime issues"),
            ("MemoryLayoutAgent",
             ["CWE-787", "CWE-125", "CWE-190"],
             ["memcpy", "strcpy", "buffer", "overflow", "underflow"],
             "Memory operations and buffer safety"),
            ("FormatStringAgent",
             ["CWE-134"],
             ["printf", "sprintf", "fprintf", "format", "%s", "%d"],
             "Format string vulnerability detection"),
            ("FilePermissionAgent",
             ["CWE-022"],
             ["fopen", "chmod", "access", "file", "path"],
             "File I/O and permission issues"),
            ("AuthFlowAgent",
             ["CWE-862", "CWE-863"],
             ["auth", "permission", "access", "login", "session", "role"],
             "Authentication and authorization flow analysis"),
            ("CryptoConfigAgent",
             ["CWE-327", "CWE-338"],
             ["crypto", "encrypt", "decrypt", "hash", "random", "key"],
             "Cryptographic misuse and weak configurations"),
            ("ConcurrencyAnalyzerAgent",
             ["CWE-362", "CWE-366"],
             ["thread", "mutex", "lock", "atomic", "race", "concurrent"],
             "Race conditions and synchronization issues"),
            ("ErrorHandlingAgent",
             ["CWE-703", "CWE-404"],
             ["try", "catch", "error", "exception", "return", "cleanup"],
             "Error handling and resource management"),
            ("CodeInjectionAgent",
             ["CWE-078", "CWE-089"],
             ["eval", "exec", "system", "shell", "command", "sql"],
             "Command and code injection vulnerabilities"),
        ]
        return {
            name: SpecializedAgent(self.llm, name, cwe_focus, triggers, description, self.cwe_examples)
            for name, cwe_focus, triggers, description in agent_specs
        }

    def detect(self, code_input: CodeInput) -> PredictionResult:
        """Run the four phases on one code sample.

        Args:
            code_input: Sample to analyze; ``id`` and ``code`` are read here
                and the whole object is handed to the static tool agents.

        Returns:
            A ``PredictionResult`` that is vulnerable when at least one
            hypothesis receives a positive final verdict. ``phases`` holds
            the intermediate results of every phase that ran; the pipeline
            returns early (not vulnerable) as soon as a phase yields nothing.
        """

        print(f"\n{'='*60}")
        print(f"Analyzing: {code_input.id}")
        print(f"{'='*60}")

        numbered_code = CodePreprocessor.add_line_numbers(code_input.code)

        # PHASE 1: Multi-View Detection
        print("\nPhase 1: Multi-View Detection")
        activated_agents = self.meta_agent.analyze(numbered_code)

        llm_reports = []
        static_tool_reports = []

        print("   -> Running Static Tools (Flawfinder, Cppcheck, Sonarqube)...")
        for agent in self.static_agents:
            reports = agent.analyze(code_input)
            if reports:
                print(f"      {agent.name} found report.")
                static_tool_reports.extend(reports)

        for agent_name in activated_agents:
            # Names the LLM invented that are not in the registry are ignored.
            if agent_name in self.specialized_agents:
                agent = self.specialized_agents[agent_name]
                reports = agent.analyze(numbered_code)
                llm_reports.extend(reports)

        aggregated_llm_reports = self.aggregator.analyze(llm_reports)

        phase1_results = {
            "activated_agents": activated_agents,
            # Count before aggregation; the aggregated count is already
            # available as len(aggregated_reports).
            "raw_reports": len(llm_reports),
            "aggregated_reports": [r.to_dict() for r in aggregated_llm_reports],
            "static_tool_reports": len(static_tool_reports)
        }

        all_reports_for_planner = aggregated_llm_reports + static_tool_reports

        if not all_reports_for_planner:
            print("   No vulnerabilities detected in Phase 1")
            return PredictionResult(
                id=code_input.id,
                is_vulnerable=False,
                detected_cwes=[],
                confidence=0.0,
                phases={"phase1": phase1_results}
            )

        # PHASE 2: Hypothesis Construction
        print("\nPhase 2: Hypothesis Construction")
        print(f"   Input: {len(all_reports_for_planner)} reports")
        hypotheses = self.trigger_planner.analyze(all_reports_for_planner, numbered_code)
        print(f"   Output: {len(hypotheses)} hypotheses")

        if not hypotheses:
            print("   No valid hypotheses constructed")
            return PredictionResult(
                id=code_input.id,
                is_vulnerable=False,
                detected_cwes=[],
                confidence=0.0,
                phases={
                    "phase1": phase1_results,
                    "phase2": {"hypotheses_count": 0}
                }
            )

        print(f"   Constructed {len(hypotheses)} hypotheses")

        # PHASE 3: Assumption Validation
        print("\nPhase 3: Assumption Validation")
        print(f"   Input: {len(hypotheses)} hypotheses")
        validated_tuples = self.assumption_pruner.analyze(hypotheses, numbered_code)
        print(f"   Output: {len(validated_tuples)} validated tuples")

        if not validated_tuples:
            print("   No hypotheses passed assumption validation")
            return PredictionResult(
                id=code_input.id,
                is_vulnerable=False,
                detected_cwes=[],
                confidence=0.0,
                phases={
                    "phase1": phase1_results,
                    "phase2": {"hypotheses": [h.to_dict() for h in hypotheses]},
                    "phase3": {"validated_count": 0}
                }
            )

        # PHASE 4: Final Validation
        print("\nPhase 4: Final Validation")
        print(f"   Input: {len(validated_tuples)} tuples")
        validations = self.final_validator.analyze(validated_tuples, numbered_code)
        print(f"   Output: {len(validations)} validation results")

        # Final decision
        exploitable = [v for v in validations if v.final_verdict]
        detected_cwes = list({v.hypothesis.cwe for v in exploitable})

        print(f"   Final verdict: {len(exploitable)}/{len(validations)} exploitable vulnerabilities")

        return PredictionResult(
            id=code_input.id,
            is_vulnerable=len(exploitable) > 0,
            detected_cwes=detected_cwes,
            confidence=max([v.hypothesis.confidence for v in exploitable], default=0.0),
            phases={
                "phase1": phase1_results,
                "phase2": {
                    "hypotheses_count": len(hypotheses),
                    "hypotheses": [h.to_dict() for h in hypotheses]
                },
                "phase3": {
                    "validated_count": len(validated_tuples),
                    "validated_hypotheses": [
                        {
                            "hypothesis": hyp.to_dict(),
                            "assumption_status": status
                        }
                        for hyp, status in validated_tuples
                    ]
                },
                "phase4": {
                    "total_validations": len(validations),
                    "exploitable_count": len(exploitable),
                    "validations": [v.to_dict() for v in validations]
                }
            })

    def _save_checkpoint(self, results: List[Dict], last_index: int, checkpoint_path: Path):
        """Persist progress as JSON without risking the previous checkpoint.

        The payload is written to a sibling ``<name>.tmp`` file and moved
        over ``checkpoint_path`` only after serialization succeeded, so a
        failed or interrupted dump leaves the last good checkpoint intact.
        Failures are reported on stdout and not raised.

        Args:
            results: Result dictionaries accumulated so far.
            last_index: Index of the last processed item.
            checkpoint_path: Destination file.
        """

        checkpoint = {
            "results": results,
            "last_index": last_index,
            "timestamp": time.time(),
            "llm_stats": self.llm.get_stats()
        }
        checkpoint_path = Path(checkpoint_path)
        tmp_path = checkpoint_path.with_name(checkpoint_path.name + ".tmp")
        try:
            checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
            with open(tmp_path, "w") as f:
                json.dump(checkpoint, f, indent=2)
            os.replace(tmp_path, checkpoint_path)
        except Exception as e:
            print(f"Failed to save checkpoint: {e}")
            try:
                tmp_path.unlink()
            except OSError:
                # Nothing to clean up (the temp file was never created).
                pass

In [None]:
# =============================================================================
# DATASET LOADERS
# =============================================================================

class DatasetManager:
    """Loads paired vulnerable/fixed functions from the filtered PrimeVul JSONL file."""

    # Location of the filtered PrimeVul file inside the cloned project repository.
    DEFAULT_PRIMEVUL_PATH = Path("/content/VulnerabilitiesDetectionProject/PrimeVul/primevul_filtered.jsonl")

    # Sample size used when random sampling is requested without an explicit max_pairs.
    DEFAULT_RANDOM_SAMPLE_SIZE = 10

    @staticmethod
    def load_primevul_pairs(
        max_pairs: Optional[int] = None,
        random_sample: Any = True,
        data_path: Optional[Path] = None,
    ) -> Tuple[List[Dict], List[Dict], List[Dict]]:
        """Load vuln/fix pairs and return them as parallel sample lists.

        Entries are grouped by ``commit_id``; a pair is the first entry with
        ``target == 1`` together with the first entry with ``target == 0`` of
        the same commit. Commits that lack one of the two sides are only counted.

        Args:
            max_pairs: Maximum number of pairs to return when ``random_sample``
                is a flag. Ignored when explicit positions are given.
            random_sample: Selects which pairs are returned:
                * ``True`` / ``False`` (or any other non-int, non-sequence
                  value): a flag. When truthy and more pairs are available
                  than requested, the pairs are shuffled with the global
                  ``random`` state and the first ``max_pairs`` (10 when
                  ``max_pairs`` is unset) are kept; otherwise the first
                  ``max_pairs`` pairs (or all of them) are kept in file order.
                * an ``int``: the single pair at that **1-based** position.
                * a ``list`` / ``tuple`` of ints: the pairs at those
                  **0-based** positions, in the given order.
            data_path: Optional override of the JSONL file location; defaults
                to ``DEFAULT_PRIMEVUL_PATH``.

        Returns:
            ``(clean_samples, vuln_samples, pair_stats)``. The first two lists
            are index-aligned sample dicts (``idx``, ``id``, ``project``,
            ``dataset``, ``cwe``, ``cve``, ``cve_desc``, ``code``, ``target``);
            ``pair_stats`` holds one dict per pair with the similarity ratio
            between the vulnerable and the fixed function.

        Raises:
            FileNotFoundError: If the dataset file does not exist.
            IndexError: If an explicit position is outside the available pairs.
        """
        paired_path = Path(data_path) if data_path is not None else DatasetManager.DEFAULT_PRIMEVUL_PATH

        if not paired_path.exists():
            raise FileNotFoundError(f"{paired_path} not found. Git clone first.")

        print(f"Loading from {paired_path.name}")

        data = []
        skipped = 0
        with open(paired_path, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                try:
                    entry = json.loads(line)
                except json.JSONDecodeError:
                    skipped += 1
                    continue
                # Only JSON objects can be grouped by commit below.
                if not isinstance(entry, dict):
                    skipped += 1
                    continue
                data.append(entry)

        print(f"Loaded {len(data)} total entries")
        if skipped:
            print(f"    Skipped {skipped} malformed lines")

        groups = defaultdict(list)
        for o in data:
            groups[o.get("commit_id")].append(o)

        pairs = []
        count_vuln_alone = 0
        count_clean_alone = 0

        for cid, items in groups.items():
            vuln = next((x for x in items if x.get("target") == 1), None)
            fix = next((x for x in items if x.get("target") == 0), None)
            if vuln and fix:
                pairs.append((vuln, fix))
            elif vuln:
                count_vuln_alone += 1
            elif fix:
                count_clean_alone += 1

        print(f"Found {len(pairs)} valid vuln/fix pairs")
        print(f"    {count_vuln_alone} vuln alone | {count_clean_alone} fix alone")

        # bool is a subclass of int, so True/False must be excluded here or the
        # sampling flag would be mistaken for a 1-based position.
        if isinstance(random_sample, int) and not isinstance(random_sample, bool):
            selected_pairs = [pairs[random_sample - 1]]
        elif isinstance(random_sample, (list, tuple)):
            selected_pairs = [pairs[i] for i in random_sample]
        else:
            sample_size = max_pairs or DatasetManager.DEFAULT_RANDOM_SAMPLE_SIZE
            if random_sample and len(pairs) > sample_size:
                random.shuffle(pairs)
                selected_pairs = pairs[:sample_size]
            else:
                selected_pairs = pairs[:max_pairs] if max_pairs else pairs

        clean_samples = []
        vuln_samples = []
        results = []

        for i, (v, f) in enumerate(selected_pairs, 1):
            # Joining the split lines normalizes line endings before comparison.
            func_v = "\n".join(v.get("func", "").splitlines())
            func_f = "\n".join(f.get("func", "").splitlines())
            ratio = difflib.SequenceMatcher(None, func_v, func_f).ratio()

            results.append({
                "pair_id": i,
                "commit_id": v.get("commit_id"),
                "idx_vuln": v.get("idx"),
                "idx_fix": f.get("idx"),
                "ratio": ratio
            })

            # Metadata is taken from the vulnerable entry and shared by both sides.
            base_entry_v = {
                "idx": v.get("idx"),
                "id": v.get("commit_id"),
                "project": v.get("project", "unknown"),
                "dataset": "PrimeVul",
                "cwe": v.get("cwe", []),
                "cve": v.get("cve", "None"),
                "cve_desc": v.get("cve_desc", "N/A")
            }

            base_entry_f = base_entry_v.copy()
            base_entry_f["idx"] = f.get("idx")

            vuln_samples.append({
                **base_entry_v,
                "code": v.get("func", ""),
                "target": 1
            })

            clean_samples.append({
                **base_entry_f,
                "code": f.get("func", ""),
                "target": 0
            })

        print(f"Selected {len(vuln_samples)} vuln & {len(clean_samples)} clean samples")
        return clean_samples, vuln_samples, results

In [None]:
# =============================================================================
# MAIN EVALUATION FUNCTION
# =============================================================================

def _paired_eval_primary_cwe(sample: Dict) -> str:
    """Return the first ground-truth CWE of a sample, or 'UNKNOWN' when absent or empty."""
    cwe = sample.get('cwe')
    if isinstance(cwe, (list, tuple)):
        cwe = cwe[0] if cwe else None
    return 'UNKNOWN' if cwe is None else str(cwe)


def _paired_eval_detect(pipeline, code_input) -> Dict:
    """Run the pipeline on one sample and return a prediction record, or an error record."""
    try:
        prediction = pipeline.detect(code_input)
        return {
            "id": prediction.id,
            "is_vulnerable": prediction.is_vulnerable,
            "detected_cwes": prediction.detected_cwes,
            "confidence": prediction.confidence,
            "phases": getattr(prediction, 'phases', {})
        }
    except Exception as e:
        print(f"  Error: {e}")
        return {
            "id": code_input.id,
            "error": str(e)
        }


def _paired_eval_score(pair_id: int, vuln_record: Dict, clean_record: Dict, vuln_truth) -> Dict:
    """Classify one fully analysed pair as P-C, P-R, both_vuln or both_safe."""
    cwe_matched = False
    if vuln_truth.cwe and vuln_truth.cwe != 'UNKNOWN':
        # NOTE: substring containment, so e.g. 'CWE-12' also matches 'CWE-125'.
        cwe_matched = any(vuln_truth.cwe in str(dc) for dc in vuln_record['detected_cwes'])

    vuln_correct = vuln_record['is_vulnerable'] == True
    clean_correct = clean_record['is_vulnerable'] == False

    is_pair_correct = vuln_correct and clean_correct
    is_pair_reversed = (not vuln_correct) and (not clean_correct)

    if is_pair_correct:
        classification = "P-C"
    elif is_pair_reversed:
        classification = "P-R"
    elif vuln_record['is_vulnerable'] and clean_record['is_vulnerable']:
        classification = "both_vuln"
    else:
        classification = "both_safe"

    return {
        "pair_id": pair_id,
        "commit_id": vuln_truth.commit_id,
        "ground_truth_cwe": vuln_truth.cwe,
        "vulnerable": {
            "prediction": vuln_record['is_vulnerable'],
            "detected_cwes": vuln_record['detected_cwes'],
            "cwe_matched": cwe_matched,
            "correct": vuln_correct
        },
        "fixed": {
            "prediction": clean_record['is_vulnerable'],
            "detected_cwes": clean_record['detected_cwes'],
            "correct": clean_correct
        },
        "pair_result": {
            "is_correct": is_pair_correct,
            "is_reversed": is_pair_reversed,
            "classification": classification
        }
    }


def _paired_eval_metrics(pair_results: List[Dict]) -> Dict:
    """Aggregate scored pairs into classification counts, accuracy and F1 (percentages)."""
    def pct(part: int, whole: int) -> float:
        return (part / whole * 100) if whole else 0.0

    counts = {"P-C": 0, "P-R": 0, "both_vuln": 0, "both_safe": 0}
    for scored_pair in pair_results:
        label = scored_pair["pair_result"]["classification"]
        counts[label] = counts.get(label, 0) + 1

    scored = len(pair_results)
    true_pos = sum(1 for p in pair_results if p["vulnerable"]["correct"])
    true_neg = sum(1 for p in pair_results if p["fixed"]["correct"])
    false_pos = scored - true_neg
    false_neg = scored - true_pos

    return {
        "pairs_scored": scored,
        **counts,
        "P-C%": pct(counts["P-C"], scored),
        "P-R%": pct(counts["P-R"], scored),
        "ACC%": pct(true_pos + true_neg, 2 * scored),
        "F1%": pct(2 * true_pos, 2 * true_pos + false_pos + false_neg)
    }


def evaluate_primevul_paired_with_checkpoints(
    pipeline: VulPrunePipeline,
    num_pairs: int = 10,
    random_sample: Any = True,
    checkpoint_interval: int = 3
):
    """Evaluate the pipeline on PrimeVul vuln/fix pairs with checkpointing.

    For every pair the vulnerable and the fixed function are analysed with
    ``pipeline.detect``; when both analyses succeed the pair is classified as
    ``P-C`` (both right), ``P-R`` (both wrong), ``both_vuln`` or ``both_safe``.
    Progress is written through ``PairedCheckpointManager`` every
    ``checkpoint_interval`` pairs as well as on interruption or failure, and
    an existing checkpoint is resumed automatically.

    Args:
        pipeline: Detection pipeline; must expose ``detect`` and ``llm.get_stats``.
        num_pairs: Number of pairs to evaluate (clamped to the pairs loaded).
        random_sample: Forwarded to ``DatasetManager.load_primevul_pairs``
            (a sampling flag, a 1-based position, or a list of 0-based positions).
            NOTE(review): with a truthy flag the loader shuffles, so a resumed
            run may see different pairs than the interrupted one; pass explicit
            positions for a reproducible resume.
        checkpoint_interval: Save a checkpoint every this many pairs
            (values below 1 are treated as 1).

    Returns:
        A dict with ``vuln_predictions``, ``clean_predictions``,
        ``pair_results``, ``metrics`` and ``llm_stats`` when the evaluation
        completes; ``None`` when data loading fails or the run is interrupted
        or aborted (progress is then kept in the checkpoint).
    """
    print("\n" + "="*70)
    print("VULPRUNE PAIRED EVALUATION (WITH CHECKPOINT SUPPORT)")
    print("="*70 + "\n")

    checkpoint_manager = PairedCheckpointManager()

    checkpoint_data = checkpoint_manager.load()

    if checkpoint_data:
        print("Checkpoint found! Loading previous progress...")
        last_pair_index = checkpoint_data.get("last_pair_index", -1)
        vuln_predictions = checkpoint_data.get("vuln_predictions", [])
        clean_predictions = checkpoint_data.get("clean_predictions", [])
        pair_results = checkpoint_data.get("pair_results", [])

        print(f"   Resuming from pair {last_pair_index + 2}/{num_pairs}")
        print(f"   Already processed: {len(vuln_predictions)} vulnerable samples")
        print(f"   Already processed: {len(clean_predictions)} clean samples")
        print(f"   Pair results so far: {len(pair_results)}\n")

        start_pair_index = last_pair_index + 1
    else:
        print("No checkpoint found. Starting fresh evaluation.\n")
        vuln_predictions = []
        clean_predictions = []
        pair_results = []
        start_pair_index = 0

    try:
        clean_samples, vuln_samples, _ = DatasetManager.load_primevul_pairs(
            max_pairs=num_pairs,
            random_sample=random_sample
        )
    except Exception as e:
        print(f"Failed to load data: {e}")
        return None

    print(f"Loaded {len(vuln_samples)} vulnerable + {len(clean_samples)} fixed samples\n")

    # Never index past what the loader actually returned.
    total_pairs = min(num_pairs, len(vuln_samples), len(clean_samples))
    if total_pairs < num_pairs:
        print(f"Only {total_pairs} pairs available (requested {num_pairs}); evaluating {total_pairs}.\n")

    vuln_inputs = [
        CodeInput(id=f"vuln_{v['idx']}", code=v['code'], project=v.get('project'))
        for v in vuln_samples
    ]

    clean_inputs = [
        CodeInput(id=f"clean_{c['idx']}", code=c['code'], project=c.get('project'))
        for c in clean_samples
    ]

    vuln_truths = [
        GroundTruth(
            id=f"vuln_{v['idx']}",
            cwe=_paired_eval_primary_cwe(v),
            label=1,
            commit_id=v.get('id')
        )
        for v in vuln_samples
    ]

    # Index of the last pair whose BOTH samples were analysed and recorded.
    last_completed_index = start_pair_index - 1
    save_every = max(1, checkpoint_interval)

    def _save_progress(extra_metadata: Optional[Dict] = None) -> None:
        """Write the current progress through the checkpoint manager."""
        metadata = {
            "num_pairs": num_pairs,
            "random_sample": random_sample,
            "checkpoint_interval": checkpoint_interval
        }
        if extra_metadata:
            metadata.update(extra_metadata)
        checkpoint_manager.save(
            last_pair_index=last_completed_index,
            vuln_predictions=vuln_predictions,
            clean_predictions=clean_predictions,
            pair_results=pair_results,
            llm_stats=pipeline.llm.get_stats(),
            metadata=metadata
        )

    try:
        for i in range(start_pair_index, total_pairs):
            print(f"\n{'='*70}")
            print(f"Processing PAIR {i+1}/{total_pairs}")
            print(f"{'='*70}\n")

            print(f"[VULN] Analyzing vulnerable sample {vuln_inputs[i].id}...")
            vuln_record = _paired_eval_detect(pipeline, vuln_inputs[i])

            print(f"\n[CLEAN] Analyzing clean sample {clean_inputs[i].id}...")
            clean_record = _paired_eval_detect(pipeline, clean_inputs[i])

            # Record both sides only once the whole pair is done, so an
            # interruption in between never leaves the two lists misaligned.
            vuln_predictions.append(vuln_record)
            clean_predictions.append(clean_record)

            if "error" not in vuln_record and "error" not in clean_record:
                pair_result = _paired_eval_score(i + 1, vuln_record, clean_record, vuln_truths[i])
                pair_results.append(pair_result)
                print(f"\nPair Result: {pair_result['pair_result']['classification']}")

            last_completed_index = i

            if (i + 1) % save_every == 0:
                print(f"\nSaving checkpoint at pair {i+1}/{total_pairs}...")
                _save_progress()

                stats = pipeline.llm.get_stats()
                print(f"Progress: {stats['total_calls']} LLM calls, "
                      f"{stats['success_rate']*100:.1f}% success rate")

    except KeyboardInterrupt:
        print("\n\nEvaluation interrupted by user!")
        print("Saving current progress to checkpoint...")
        # The interrupted pair is not marked as done; it is re-run on resume.
        _save_progress({"interrupted": True})
        print("Checkpoint saved. Run resume_paired_evaluation() to continue.")
        return None

    except Exception as e:
        print(f"\n\nEvaluation failed with error: {e}")
        print("Saving partial results to checkpoint...")
        _save_progress({"error": str(e)})
        import traceback
        traceback.print_exc()
        return None

    # Build the return value before the checkpoint is discarded.
    evaluation_results = {
        "vuln_predictions": vuln_predictions,
        "clean_predictions": clean_predictions,
        "pair_results": pair_results,
        "metrics": _paired_eval_metrics(pair_results),
        "llm_stats": pipeline.llm.get_stats()
    }

    print("\nCleaning up checkpoint...")
    checkpoint_manager.delete()

    print("\nEVALUATION COMPLETE!\n")

    return evaluation_results

In [None]:
# =============================================================================
# RESUME FROM CHECKPOINT
# =============================================================================

def resume_paired_evaluation(pipeline=None):
    """Continue a paired evaluation from the stored checkpoint.

    The number of pairs, the sampling argument and (when recorded) the
    checkpoint interval are read from the checkpoint metadata and handed to
    ``evaluate_primevul_paired_with_checkpoints``, which picks up the saved
    progress itself.

    Args:
        pipeline: Pipeline to continue with. When omitted, a new
            ``VulPrunePipeline`` is built on ``create_groq_only()``, the same
            construction the notebook's execution cell uses.

    Returns:
        Whatever ``evaluate_primevul_paired_with_checkpoints`` returns, or
        ``None`` when no usable checkpoint exists.
    """
    checkpoint_manager = PairedCheckpointManager()

    if not checkpoint_manager.exists():
        print("No checkpoint found!")
        print("   Use evaluate_primevul_paired_with_checkpoints() to start a new evaluation")
        return None

    print("Checkpoint found! Loading...")

    checkpoint_data = checkpoint_manager.load()
    if not checkpoint_data:
        # The file exists but could not be read back; nothing to resume from.
        print("Checkpoint could not be loaded; nothing to resume.")
        return None

    checkpoint_info = checkpoint_manager.get_info()
    if checkpoint_info:
        print(f"   Last processed pair: {checkpoint_info['last_pair_index'] + 1}")
        print(f"   Vulnerable predictions: {checkpoint_info['num_vuln_predictions']}")
        print(f"   Clean predictions: {checkpoint_info['num_clean_predictions']}")
        print(f"   Pair results: {checkpoint_info['num_pair_results']}")
        print(f"   Checkpoint time: {time.ctime(checkpoint_info['timestamp'])}")

    if pipeline is None:
        print("\nInitializing pipeline...")
        Config.setup_directories()
        pipeline = VulPrunePipeline(create_groq_only())

    metadata = checkpoint_data.get("metadata") or {}
    num_pairs = metadata.get("num_pairs", 10)
    random_sample = metadata.get("random_sample", True)
    # Taken from the loaded data so the summary does not depend on get_info().
    last_pair_index = checkpoint_data.get("last_pair_index", -1)

    print(f"\nResuming evaluation...")
    print(f"   Total pairs: {num_pairs}")
    print(f"   Remaining pairs: {num_pairs - (last_pair_index + 1)}")

    return evaluate_primevul_paired_with_checkpoints(
        pipeline=pipeline,
        num_pairs=num_pairs,
        random_sample=random_sample,
        checkpoint_interval=metadata.get("checkpoint_interval", 3)
    )

In [None]:
# =============================================================================
# CHECKPOINT MANAGEMENT UTILITIES
# =============================================================================

def list_paired_checkpoints():
    """Print a short summary of the stored paired-evaluation checkpoint, if any."""
    manager = PairedCheckpointManager()
    if not manager.exists():
        print("No paired evaluation checkpoints found")
        return

    summary = manager.get_info()
    if not summary:
        # Nothing to report when the manager yields no summary.
        return

    print("Paired Evaluation Checkpoint:")
    saved_at = time.ctime(summary['timestamp'])
    print(f"   Date: {saved_at}")
    pairs_done = summary['last_pair_index'] + 1
    print(f"   Progress: Pair {pairs_done}")
    print(f"   Vulnerable predictions: {summary['num_vuln_predictions']}")
    print(f"   Clean predictions: {summary['num_clean_predictions']}")
    print(f"   Pair results: {summary['num_pair_results']}")

    # Metadata is optional and only shown when present and non-empty.
    if summary.get('metadata'):
        print(f"   Metadata: {summary['metadata']}")

def clear_paired_checkpoints():
    """Delete the stored paired-evaluation checkpoint and report the outcome."""
    manager = PairedCheckpointManager()
    if not manager.exists():
        print("No checkpoints to clear")
        return

    # delete() reports success through its truthiness.
    outcome = (
        "Paired evaluation checkpoint deleted"
        if manager.delete()
        else "Failed to delete checkpoint"
    )
    print(outcome)

def get_paired_checkpoint_stats():
    """Print the LLM usage stored in the checkpoint and return the checkpoint data.

    Returns:
        The loaded checkpoint data, or ``None`` when no checkpoint exists or
        it loads as empty.
    """
    manager = PairedCheckpointManager()
    if not manager.exists():
        print("No active checkpoint")
        return None

    snapshot = manager.load()
    if not snapshot:
        return None

    usage = snapshot.get("llm_stats", {})
    if usage:
        print(f"\nLLM Usage:")
        print(f"  Total calls: {usage.get('total_calls', 0)}")
        # success_rate is stored as a fraction and shown as a percentage.
        print(f"  Success rate: {usage.get('success_rate', 0)*100:.1f}%")

    # Closing rule is printed whether or not usage statistics were available.
    print("="*50)
    return snapshot

In [None]:
# =============================================================================
# POST-ANALYSIS TOOLS FOR PAIRED RESULTS
# =============================================================================

def analyze_results(results_path: str = None):
    """Print a detailed analysis of a saved paired-evaluation results file.

    The report covers sample counts, the stored paired metrics, the per-CWE
    detection rate, false positives, false negatives and how often the
    correct CWE was reported. All ratios are printed as 0.0% when their
    denominator is zero, so result files without vulnerable samples (or
    without any sample) are reported instead of raising.

    Args:
        results_path: JSON results file to analyse. When omitted, the most
            recently modified ``*paired*.json`` (falling back to
            ``*results*.json``) in ``Config.RESULTS_DIR`` is used.

    Returns:
        None. The report is written to stdout.
    """
    def _pct(part: int, whole: int) -> float:
        """Percentage of ``part`` in ``whole``; 0.0 when ``whole`` is zero."""
        return (part / whole * 100) if whole else 0.0

    if results_path is None:
        results_dir = Config.RESULTS_DIR
        result_files = sorted(results_dir.glob("*paired*.json"), key=lambda x: x.stat().st_mtime)

        if not result_files:
            result_files = sorted(results_dir.glob("*results*.json"), key=lambda x: x.stat().st_mtime)

        if not result_files:
            print("No results files found")
            return

        # Sorted by modification time, so the last entry is the newest file.
        results_path = result_files[-1]
        print(f"Using latest results: {results_path.name}")

    with open(results_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    results = data.get("results", [])
    metrics = data.get("metrics", {})
    cwe_detection = data.get("cwe_detection", {})

    print(f"\n{'='*70}")
    print("DETAILED PAIRED RESULTS ANALYSIS")
    print(f"{'='*70}\n")

    total = len(results)
    with_errors = sum(1 for r in results if "error" in r)
    successful = total - with_errors

    # Records without a truthy 'is_vulnerable' are counted as fixed samples.
    vuln_samples = [r for r in results if r.get('is_vulnerable', False)]
    fixed_samples = [r for r in results if not r.get('is_vulnerable', False)]

    print(f"Total samples: {total}")
    print(f"  Vulnerable samples: {len(vuln_samples)}")
    print(f"  Fixed samples: {len(fixed_samples)}")
    print(f"Successful: {successful}")
    print(f"Errors: {with_errors}")

    if metrics:
        print(f"\nPaired Metrics Summary:")
        print(f"  P-C (Pair Correct): {metrics.get('P-C', 0)} ({metrics.get('P-C%', 0):.2f}%)")
        print(f"  P-R (Pair Reversed): {metrics.get('P-R', 0)} ({metrics.get('P-R%', 0):.2f}%)")
        print(f"  VPS (Pair Score): {metrics.get('VPS', 0)} ({metrics.get('VPS%', 0):.2f}%)")
        print(f"  Accuracy: {metrics.get('ACC%', 0):.2f}%")
        print(f"  F1 Score: {metrics.get('F1%', 0):.2f}%")

    if cwe_detection:
        print(f"\nCWE Detection Rate:")
        print(f"{'CWE':<15} {'Total':>10} {'Detected':>10} {'Rate':>10}")
        print("-"*50)

        for cwe, stats in sorted(cwe_detection.items(),
                                key=lambda x: x[1].get('total', 0), reverse=True):
            total_cwe = stats.get('total', 0)
            detected_cwe = stats.get('detected', 0)
            rate = _pct(detected_cwe, total_cwe)
            print(f"{cwe:<15} {total_cwe:>10} {detected_cwe:>10} {rate:>9.1f}%")

    print(f"\nFalse Positives Analysis:")
    fps = [r for r in fixed_samples if r.get('prediction') == 1]
    print(f"  Total FPs: {len(fps)} / {len(fixed_samples)} fixed samples")

    if fps:
        fp_cwes = {}
        for fp in fps:
            detected = fp.get('detected_cwes', [])
            for cwe in detected:
                fp_cwes[cwe] = fp_cwes.get(cwe, 0) + 1

        print(f"  Most common false CWEs:")
        for cwe, count in sorted(fp_cwes.items(), key=lambda x: x[1], reverse=True)[:5]:
            print(f"    - {cwe}: {count} times")

    print(f"\n❌ False Negatives Analysis:")
    fns = [r for r in vuln_samples if r.get('prediction') == 0]
    print(f"  Total FNs: {len(fns)} / {len(vuln_samples)} vulnerable samples")

    if fns:
        fn_cwes = {}
        for fn in fns:
            cwe = fn.get('ground_truth_cwe', 'UNKNOWN')
            fn_cwes[cwe] = fn_cwes.get(cwe, 0) + 1

        print(f"  Missed CWE types:")
        for cwe, count in sorted(fn_cwes.items(), key=lambda x: x[1], reverse=True):
            print(f"    - {cwe}: {count} times ({_pct(count, len(vuln_samples)):.1f}%)")

    print(f"\nCWE Match Analysis (Correct CWE Detection):")
    vuln_with_match = [r for r in vuln_samples if r.get('cwe_matched', False)]
    vuln_detected = [r for r in vuln_samples if r.get('prediction') == 1]

    print(f"  Vulnerable samples detected: {len(vuln_detected)} / {len(vuln_samples)}")
    print(f"  With correct CWE: {len(vuln_with_match)} / {len(vuln_samples)} ({_pct(len(vuln_with_match), len(vuln_samples)):.1f}%)")

    if vuln_detected:
        match_rate = _pct(len(vuln_with_match), len(vuln_detected))
        print(f"  Match rate (among detected): {len(vuln_with_match)} / {len(vuln_detected)} ({match_rate:.1f}%)")

    print(f"\n{'='*70}\n")

def export_results_csv(results_path: str = None, output_path: str = None):
    """Export a saved results file to CSV, one row per sample.

    A companion per-pair summary is always written to
    ``Config.RESULTS_DIR / "pairs_summary.csv"`` through
    ``_export_pairs_summary``.

    Args:
        results_path: JSON results file to export. When omitted, the most
            recently modified ``*paired*.json`` (falling back to
            ``*results*.json``) in ``Config.RESULTS_DIR`` is used.
        output_path: Destination CSV. Defaults to
            ``Config.RESULTS_DIR / "paired_results_export.csv"``.

    Returns:
        None. Progress is reported on stdout.
    """
    if results_path is None:
        results_dir = Config.RESULTS_DIR
        result_files = sorted(results_dir.glob("*paired*.json"), key=lambda x: x.stat().st_mtime)
        if not result_files:
            result_files = sorted(results_dir.glob("*results*.json"), key=lambda x: x.stat().st_mtime)
        if not result_files:
            print("❌ No results files found")
            return
        # Sorted by modification time, so the last entry is the newest file.
        results_path = result_files[-1]

    if output_path is None:
        output_path = Config.RESULTS_DIR / "paired_results_export.csv"

    with open(results_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    results = data.get("results", [])

    with open(output_path, "w", newline='', encoding='utf-8') as f:
        writer = csv.writer(f)

        writer.writerow([
            "ID", "Commit_ID", "Is_Vulnerable", "Ground_Truth_CWE", "Target",
            "Prediction", "Detected_CWEs", "Num_Detected_CWEs",
            "Correct", "CWE_Matched", "Error"
        ])

        for r in results:
            # Tolerate a missing or null CWE list and non-string identifiers.
            detected = r.get("detected_cwes") or []
            prediction = r.get("prediction")
            # A record without a prediction (e.g. an error record) is never
            # "correct", even though None == None for a missing target.
            is_correct = prediction is not None and prediction == r.get("target")

            writer.writerow([
                r.get("id", ""),
                r.get("commit_id", ""),
                r.get("is_vulnerable", ""),
                r.get("ground_truth_cwe", ""),
                r.get("target", ""),
                r.get("prediction", ""),
                ";".join(str(cwe) for cwe in detected),
                len(detected),
                is_correct,
                r.get("cwe_matched", False),
                "error" in r
            ])

    print(f"Results exported to: {output_path}")
    print(f"   Total rows: {len(results)}")

    pairs_output = Config.RESULTS_DIR / "pairs_summary.csv"
    _export_pairs_summary(results, pairs_output)

def _export_pairs_summary(results: List[Dict], output_path: Path):

    pairs = {}
    for r in results:
        commit_id = r.get('commit_id', '')
        if not commit_id:
            continue

        if commit_id not in pairs:
            pairs[commit_id] = {
                'commit_id': commit_id,
                'cwe': r.get('ground_truth_cwe', ''),
                'vuln_pred': None,
                'fixed_pred': None,
                'vuln_cwes': [],
                'fixed_cwes': []
            }

        if r.get('is_vulnerable'):
            pairs[commit_id]['vuln_pred'] = r.get('prediction')
            pairs[commit_id]['vuln_cwes'] = r.get('detected_cwes', [])
        else:
            pairs[commit_id]['fixed_pred'] = r.get('prediction')
            pairs[commit_id]['fixed_cwes'] = r.get('detected_cwes', [])

    with open(output_path, "w", newline='', encoding='utf-8') as f:
        writer = csv.writer(f)

        writer.writerow([
            "Commit_ID", "Ground_Truth_CWE",
            "Vuln_Pred", "Fixed_Pred",
            "Vuln_CWEs", "Fixed_CWEs",
            "Pair_Correct", "Pair_Reversed"
        ])

        for commit_id, pair in pairs.items():
            pair_correct = pair['vuln_pred'] == 1 and pair['fixed_pred'] == 0
            pair_reversed = pair['vuln_pred'] == 0 and pair['fixed_pred'] == 1

            writer.writerow([
                commit_id,
                pair['cwe'],
                pair['vuln_pred'],
                pair['fixed_pred'],
                ";".join(pair['vuln_cwes']),
                ";".join(pair['fixed_cwes']),
                pair_correct,
                pair_reversed
            ])

    print(f"Pairs summary exported to: {output_path}")
    print(f"   Total pairs: {len(pairs)}")

In [None]:
# =============================================================================
# PROVIDERS INITIALISATION
# =============================================================================

def create_groq_only():
    """Build an LLM interface whose only provider is Groq.

    The provider gets ``APIConfig.GROQ_API_KEY`` and ``model_names=None``
    (presumably the provider's own default models; confirm in GroqProvider).
    """
    providers = [GroqProvider(api_key=APIConfig.GROQ_API_KEY, model_names=None)]
    return UnifiedLLMInterface(providers, rotation_strategy="priority")

def create_openai_only():
    """Build an LLM interface whose only provider is OpenAI, restricted to ``gpt-4o``."""
    providers = [OpenAIProvider(api_key=APIConfig.OPENAI_API_KEY, model_names=["gpt-4o"])]
    return UnifiedLLMInterface(providers, rotation_strategy="priority")

def create_deepseek_only():
    """Build an LLM interface whose only provider is DeepSeek, restricted to ``deepseek-chat``."""
    providers = [DeepSeekProvider(api_key=APIConfig.DEEPSEEK_API_KEY, model_names=["deepseek-chat"])]
    return UnifiedLLMInterface(providers, rotation_strategy="priority")

In [None]:
# =============================================================================
# PROVIDER SELECTION AND EXECUTION
# =============================================================================

# Set up the project directories before the run (see Config.setup_directories).
Config.setup_directories()

# Choose exactly one backend: keep one assignment active, leave the others commented out.
llm = create_groq_only()            # OPTION 1: Groq only
# llm = create_openai_only()        # OPTION 2: OpenAI gpt-4o only
# llm = create_deepseek_only()      # OPTION 3: DeepSeek only

pipeline = VulPrunePipeline(llm)

# The list holds 0-based positions into the loaded vuln/fix pairs;
# num_pairs equals its length (21). A checkpoint is written after every pair.
results = evaluate_primevul_paired_with_checkpoints(
    pipeline=pipeline,
    num_pairs=21,
    random_sample=[
        154, 113, 196, 151, 133, 86, 11,
        33, 324, 162, 38, 37, 231, 283,
        271, 267, 19, 290, 227, 169, 168,
    ],
    checkpoint_interval=1,
)

# Report the LLM usage accumulated during the run.
llm.print_stats()