In [None]:
# For Groq API client
pip install groq

In [None]:
# For metrics calculation
pip install scikit-learn

In [None]:
# For progress bars
pip install tqdm

In [None]:
# For environment variables
pip install python-dotenv

In [None]:
import os
import json
from datetime import datetime
from dotenv import load_dotenv
from groq import Groq
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
import logging
import time
from json import JSONDecodeError
from tqdm import tqdm

# Setup logging
# Records at INFO level and above are sent to two handlers at once: a log
# file and a stream handler, both using the same
# "timestamp - level - message" format.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        # Relative path, so the file is created in the current working directory.
        logging.FileHandler('npe_classification.log'),
        # No stream argument: the handler falls back to its default (sys.stderr).
        logging.StreamHandler()
    ]
)
# Module-level logger used by every function defined in this file.
logger = logging.getLogger(__name__)

# Configuration
# JSON dataset that main() hands to load_data(). main() reads the
# "Commit SHA", "Category", "Commit Message", "Patch" and "Added Lines"
# fields of every record in it.
DATA_PATH = "/root/workspace/npe_project/llm/NPEPatches.json"
# Directory that main() passes to save_results(); it is created right away,
# as a side effect of running this module-level code.
OUTPUT_DIR = "results"
os.makedirs(OUTPUT_DIR, exist_ok=True)

# API Configuration
# Number of attempts NPEAgent.process makes before returning a fallback response.
MAX_RETRIES = 3
# Seconds slept (time.sleep) after a failed attempt before the next one.
RETRY_DELAY = 5
# Forwarded as `timeout` to the chat-completion call
# (presumably seconds - confirm against the Groq client documentation).
API_TIMEOUT = 30
# Seconds slept (time.sleep) before every API attempt.
RATE_LIMIT_DELAY = 2

# Load environment variables
# python-dotenv: loads variables from a .env file, if one is found, into the
# process environment so that os.getenv below can see them.
load_dotenv()
api_key = os.getenv('GROQ_API_KEY')
if not api_key:
    # Fail fast: a missing or empty key stops execution here, before any
    # function below is defined or called.
    raise ValueError("GROQ_API_KEY environment variable is required")

def load_data(filepath):
    """Read the UTF-8 JSON document at ``filepath`` and return its content.

    On success the number of loaded records is logged at INFO level.

    Raises:
        FileNotFoundError: the file does not exist (logged, then re-raised).
        json.JSONDecodeError: the file is not valid JSON (logged, then
            re-raised).
    """
    try:
        with open(filepath, mode='r', encoding='utf-8') as source:
            records = json.load(source)
    except json.JSONDecodeError:
        logger.error(f"Invalid JSON format in file: {filepath}")
        raise
    except FileNotFoundError:
        logger.error(f"Data file not found: {filepath}")
        raise
    else:
        # Only reached when the file was opened and parsed without error.
        logger.info(f"Successfully loaded {len(records)} records from {filepath}")
        return records

class NPEAgent:
    """One LLM-backed step (detector, classifier or evaluator) of the pipeline.

    An agent pairs a system prompt with a chat-completion client and asks the
    model to answer with a JSON object.

    Args:
        client: Object exposing ``chat.completions.create`` (main() passes a
            ``Groq`` client).
        role: Role name. "evaluator" and "classifier" select role-specific
            fallback responses; any other role gets the detector-style one.
        prompt: System prompt sent with every request.
        model: Model identifier passed to the completion call. Defaults to
            the model this file has always used.
    """

    def __init__(self, client, role, prompt, model="deepseek-r1-distill-llama-70b"):
        self.client = client
        self.role = role
        self.prompt = prompt
        self.model = model
        # Kept for backward compatibility; nothing in this class updates it.
        self.retry_count = 0

    def _build_messages(self, content, previous_results=None):
        """Return the system + user message list for one request."""
        if previous_results:
            user_message = (
                f"Previous analysis: {json.dumps(previous_results)}\n\nNew content: {content}"
            )
        else:
            user_message = content
        return [
            {"role": "system", "content": self.prompt},
            {"role": "user", "content": user_message},
        ]

    def process(self, content, previous_results=None):
        """Send ``content`` to the model and return the parsed JSON answer.

        Args:
            content: Text placed in the user message.
            previous_results: Optional earlier analysis; when truthy it is
                JSON-encoded and prepended to ``content``.

        Returns:
            The value parsed from the model's JSON reply, or the role's
            fallback response (see ``_get_default_response``) once all
            ``MAX_RETRIES`` attempts have failed. This method does not raise
            for API or parsing errors.
        """
        try:
            # The messages do not change between attempts, so build them once.
            messages = self._build_messages(content, previous_results)
        except (TypeError, ValueError) as e:
            # Serialising previous_results fails the same way every time;
            # retrying (and sleeping) would only waste time.
            logger.warning(f"Could not serialise previous results for {self.role}: {str(e)}")
            return self._get_default_response(f"Error: {str(e)}")

        # Also covers MAX_RETRIES < 1, where the loop body never runs and the
        # method would otherwise fall through and return None.
        failure_reason = "No attempts were made"
        for attempt in range(MAX_RETRIES):
            try:
                # Throttle every attempt, including the first one.
                time.sleep(RATE_LIMIT_DELAY)
                completion = self.client.chat.completions.create(
                    model=self.model,
                    messages=messages,
                    temperature=0.1,
                    max_tokens=500,
                    response_format={"type": "json_object"},
                    timeout=API_TIMEOUT
                )
                return json.loads(completion.choices[0].message.content)
            except JSONDecodeError:
                logger.warning(f"Attempt {attempt + 1}: Invalid JSON from {self.role}")
                failure_reason = "JSON parsing error"
            except Exception as e:
                # Deliberately broad: any client/network failure is retried
                # and finally converted into a fallback response.
                logger.warning(f"Attempt {attempt + 1} failed: {str(e)}")
                failure_reason = f"Error: {str(e)}"

            # Back off only when another attempt will follow.
            if attempt < MAX_RETRIES - 1:
                time.sleep(RETRY_DELAY)

        return self._get_default_response(failure_reason)

    def _get_default_response(self, reason):
        """Return a conservative "no NPE" answer in the schema of this role.

        Args:
            reason: Explanation of the failure, stored in the response.
        """
        if self.role == "evaluator":
            return {
                "final_decision": "Not-NPE",
                "confidence": 0.0,
                "feedback": reason
            }
        if self.role == "classifier":
            # Mirror the schema requested by the classifier prompt so the
            # evaluator is not handed a detector-shaped object.
            return {
                "is_npe_fix": False,
                "pattern_match": 0.0,
                "identified_patterns": [],
                "reasoning": reason
            }
        return {
            "npe_found": False,
            "confidence": 0.0,
            "reasoning": reason
        }

def calculate_metrics(y_true, y_pred, pos_label="NPE-Fixes", neg_label="Not-NPE"):
    """Compute binary classification metrics for the NPE-fix labels.

    Args:
        y_true: Sequence of ground-truth labels.
        y_pred: Sequence of predicted labels, aligned with ``y_true``.
        pos_label: Label treated as the positive class.
        neg_label: Label treated as the negative class.

    Returns:
        Dict with "accuracy", "precision", "recall", "f1", a
        "confusion_matrix" dict of integer counts, and the false positive
        ("fpr") and false negative ("fnr") rates.

    Raises:
        Exception: any error raised by scikit-learn is logged and re-raised.
    """
    try:
        # Pin the label order explicitly. Without `labels`, scikit-learn uses
        # the sorted label values, and "NPE-Fixes" sorts before "Not-NPE"
        # (uppercase 'P' < lowercase 'o'). That puts the positive class in
        # row/column 0, so ravel() would hand back the four cells for the
        # wrong positive class. A fixed order also guarantees a 2x2 matrix
        # when only one of the labels occurs in the data.
        tn, fp, fn, tp = confusion_matrix(
            y_true, y_pred, labels=[neg_label, pos_label]
        ).ravel()
        return {
            "accuracy": accuracy_score(y_true, y_pred),
            "precision": precision_score(y_true, y_pred, pos_label=pos_label),
            "recall": recall_score(y_true, y_pred, pos_label=pos_label),
            "f1": f1_score(y_true, y_pred, pos_label=pos_label),
            "confusion_matrix": {
                "true_negatives": int(tn),
                "false_positives": int(fp),
                "false_negatives": int(fn),
                "true_positives": int(tp)
            },
            # Guard the denominators: a class may be absent from y_true.
            "fpr": fp / (fp + tn) if (fp + tn) > 0 else 0,
            "fnr": fn / (fn + tp) if (fn + tp) > 0 else 0
        }
    except Exception as e:
        logger.error(f"Error calculating metrics: {str(e)}")
        raise

def print_metrics(metrics):
    """Print a human-readable report of the computed metrics to stdout.

    Args:
        metrics: Dict holding the "accuracy", "precision", "recall", "f1",
            "fpr" and "fnr" scores plus a "confusion_matrix" dict with the
            four count entries.
    """
    # (printed title, key in `metrics`) - scores are shown with 4 decimals.
    score_rows = (
        ("Accuracy", "accuracy"),
        ("Precision", "precision"),
        ("Recall", "recall"),
        ("F1 Score", "f1"),
        ("False Positive Rate", "fpr"),
        ("False Negative Rate", "fnr"),
    )
    # (printed title, key in metrics['confusion_matrix']) - raw counts.
    count_rows = (
        ("True Negatives", "true_negatives"),
        ("False Positives", "false_positives"),
        ("False Negatives", "false_negatives"),
        ("True Positives", "true_positives"),
    )

    print("\nClassification Results:")
    print("-" * 50)
    for title, key in score_rows:
        print(f"{title}: {metrics[key]:.4f}")

    print("\nConfusion Matrix:")
    counts = metrics['confusion_matrix']
    for title, key in count_rows:
        print(f"{title}: {counts[key]}")

def save_results(results, metrics, output_dir):
    """Write the metrics and misclassified commits to a timestamped JSON file.

    Args:
        results: Dict with at least the "y_true" and "misclassified" lists.
        metrics: Metrics dict stored under the "metrics" key.
        output_dir: Directory for the result file; created if it is missing.

    Returns:
        Path of the written file,
        ``<output_dir>/classification_results_<YYYYmmdd_HHMMSS>.json``.
    """
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    # Only the module-level OUTPUT_DIR is pre-created, so make sure a
    # caller-supplied directory exists too. An empty string means the
    # current directory, which os.makedirs would reject.
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    output_file = os.path.join(output_dir, f"classification_results_{timestamp}.json")

    full_results = {
        "metrics": metrics,
        "misclassified_commits": results["misclassified"],
        "run_timestamp": timestamp,
        "total_commits_processed": len(results["y_true"]),
        "total_misclassified": len(results["misclassified"])
    }

    # Explicit encoding, consistent with load_data, so the output does not
    # depend on the platform default.
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(full_results, f, indent=2)
    logger.info(f"Results saved to {output_file}")
    return output_file

def _normalize_decision(decision, valid_labels=("NPE-Fixes", "Not-NPE")):
    """Map an evaluator decision onto one of ``valid_labels``, else return None."""
    if not isinstance(decision, str):
        return None
    cleaned = decision.strip().casefold()
    for label in valid_labels:
        if cleaned == label.casefold():
            return label
    return None


def multi_agent_classify(client, commit_message, patch, added_lines):
    """Classify one commit as "NPE-Fixes" or "Not-NPE" with three chained agents.

    The detector analyses the commit first, the classifier sees the commit
    plus the detector's answer, and the evaluator sees the commit plus both
    earlier answers and makes the final decision.

    Args:
        client: Chat-completion client handed to every ``NPEAgent``.
        commit_message: Commit message text.
        patch: Patch text of the commit.
        added_lines: Lines added by the commit.

    Returns:
        "NPE-Fixes" or "Not-NPE". Any failure, malformed agent answer or
        unrecognised final decision yields "Not-NPE"; this function does not
        raise.
    """
    try:
        detector = NPEAgent(client, "detector", """
            You are a specialized NullPointerException (NPE) detector. 
            Analyze the given code changes and respond with a JSON object containing:
            {
                "npe_found": boolean,
                "confidence": float between 0-1,
                "reasoning": string explanation
            }
        """)
        
        classifier = NPEAgent(client, "classifier", """
            You are a code pattern classifier specializing in NPE fixes.
            Analyze the code and respond with a JSON object containing:
            {
                "is_npe_fix": boolean,
                "pattern_match": float between 0-1,
                "identified_patterns": array of strings
            }
        """)
        
        evaluator = NPEAgent(client, "evaluator", """
            You are a senior code reviewer evaluating NPE fix classifications.
            Review the analysis and respond with a JSON object containing:
            {
                "final_decision": string ("NPE-Fixes" or "Not-NPE"),
                "confidence": float between 0-1,
                "feedback": string explanation
            }
        """)
        
        # All three agents receive the same JSON-encoded commit description.
        content = json.dumps({
            "commit_message": commit_message,
            "patch": patch,
            "added_lines": added_lines
        })
        
        detection_result = detector.process(content)
        if not isinstance(detection_result, dict):
            logger.warning("Invalid detection result format")
            return "Not-NPE"
            
        classification_result = classifier.process(content, detection_result)
        if not isinstance(classification_result, dict):
            logger.warning("Invalid classification result format")
            return "Not-NPE"
            
        final_result = evaluator.process(content, {
            "detection": detection_result,
            "classification": classification_result
        })
        
        if isinstance(final_result, dict) and "final_decision" in final_result:
            # The model is free to answer with any string. Only the two
            # dataset labels may reach the metrics, otherwise the binary
            # precision/recall computation fails on an unexpected class.
            decision = _normalize_decision(final_result["final_decision"])
            if decision is None:
                logger.warning(
                    f"Unrecognized final decision {final_result['final_decision']!r}; "
                    "defaulting to Not-NPE"
                )
                return "Not-NPE"
            logger.info(f"Classification confidence: {final_result.get('confidence', 0.0)}")
            return decision
        return "Not-NPE"
        
    except Exception as e:
        # Best-effort by design: one bad commit must not abort the whole run.
        logger.error(f"Error in multi-agent classification: {str(e)}")
        return "Not-NPE"

def main():
    """Classify every commit in DATA_PATH and report/save the metrics.

    Loads the dataset, runs ``multi_agent_classify`` on each record, collects
    true and predicted labels together with the misclassified commits, then
    prints the metrics and stores them under OUTPUT_DIR. A record that fails
    is logged and skipped; it does not contribute to the metrics.

    Raises:
        Exception: any error outside the per-commit loop (client creation,
            data loading, metric calculation, saving) is logged and
            re-raised.
    """
    try:
        client = Groq(api_key=api_key)
        results = {"y_true": [], "y_pred": [], "misclassified": []}
        
        data = load_data(DATA_PATH)
        total_commits = len(data)
        logger.info(f"Starting processing of {total_commits} commits...")
        
        with tqdm(total=total_commits, desc="Processing commits") as pbar:
            for idx, item in enumerate(data, 1):
                # Bound before the try block so the error handler below can
                # always name the commit, even when reading the SHA itself
                # fails. Otherwise it would hit an unbound name or report
                # the previous commit's SHA.
                commit_sha = "Unknown"
                try:
                    commit_sha = item.get("Commit SHA", "Unknown")
                    pbar.set_description(f"Processing {commit_sha}")
                    
                    # "Category" is mandatory: a record without a ground-truth
                    # label raises KeyError and is skipped by the handler.
                    true_label = item["Category"]
                    final_pred = multi_agent_classify(
                        client,
                        item.get("Commit Message", ""),
                        item.get("Patch", ""),
                        item.get("Added Lines", "")
                    )
                    
                    results["y_true"].append(true_label)
                    results["y_pred"].append(final_pred)
                    
                    if final_pred != true_label:
                        results["misclassified"].append({
                            "Commit SHA": commit_sha,
                            "True Label": true_label,
                            "Predicted": final_pred,
                            "Patch": item.get("Patch", "")
                        })
                        logger.warning(f"Misclassification on commit {commit_sha}")
                    
                except Exception as e:
                    logger.error(f"Error processing commit {idx} ({commit_sha}): {str(e)}")
                finally:
                    # Advance for failed records too, so the bar reflects
                    # every record that has been handled.
                    pbar.update(1)
        
        if results["y_true"] and results["y_pred"]:
            metrics = calculate_metrics(results["y_true"], results["y_pred"])
            print_metrics(metrics)
            output_file = save_results(results, metrics, OUTPUT_DIR)
            logger.info(f"\nClassification completed successfully. Results saved to {output_file}")
            
            print("\nSummary:")
            print(f"Total commits processed: {len(results['y_true'])}")
            print(f"Total misclassified: {len(results['misclassified'])}")
            print(f"Success rate: {(1 - len(results['misclassified'])/len(results['y_true']))*100:.2f}%")
        else:
            logger.error("No valid predictions were made")
        
    except Exception as e:
        logger.error(f"Fatal error in main execution: {str(e)}")
        raise
    finally:
        # Only logs; there are no resources to release here.
        logger.info("Cleaning up resources...")

# Entry point: run the whole classification pipeline only when this file is
# executed directly, not when it is imported as a module.
if __name__ == "__main__":
    main()