# In [None]:
import requests
import json
import time
import csv
import os
import sys
import logging
import re
from typing import Dict, List, Tuple, Optional
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from dotenv import load_dotenv
from tqdm import tqdm

# Populate os.environ from a local .env file, if one exists, so the os.getenv()
# lookups in the __main__ block (e.g. VLLM_CHAT_URL, MODEL_NAME) can be configured
# without exporting shell variables.
load_dotenv()  

@dataclass
class ModelConfig:
    """Endpoint, model and sampling settings used by DomainClassifier for every request."""
    # URL that health-check and classification requests are POSTed to (the default in
    # __main__ is a vLLM /v1/chat/completions endpoint).
    chat_url: str
    # Sent as the "model" field of every request payload.
    model_name: str
    # Instruction text placed BEFORE the `<domain>,"<content>"` line of each user
    # message -- despite the name, it is used as a prefix.
    input_suffix: str
    # Content longer than this many characters is sampled down (start/middle/end)
    # by DomainClassifier.sample_content before being sent.
    max_content_length: int = 1750
    # Maximum number of attempts per classification request.
    retry_attempts: int = 3
    # Seconds slept between failed attempts (passed to time.sleep).
    retry_delay: int = 5
    # Sampling parameters forwarded as-is in each classification request payload
    # (the health check uses its own max_tokens/temperature).
    max_tokens: int = 14000
    temperature: float = 0.6
    seed: int = 111

@dataclass
class ProcessingConfig:
    """File locations and starting row for one DomainClassifier.process_dataset run."""
    # Input CSV; rows are read by the "Domain" and "Content" columns.
    input_file: str
    # Detailed output CSV: Domain, Content, Label, Classification, Reason, Confidence, Thought.
    output_label_file: str
    # Compact output CSV: Domain, Content, Label, Confidence.
    output_fixed_file: str
    # Text file whose entire contents are sent as the system prompt.
    label_prompt_file: str
    # Number of data rows (after the header) to skip before classifying.
    start_row: int = 0
    # NOTE(review): not referenced by the code visible here -- rows are classified
    # one request at a time. Confirm whether this is still needed.
    batch_size: int = 100

class DomainClassifier:
    """Classify domains into content categories by querying a chat-completions server.

    Each input CSV row (``Domain``, ``Content``) is sent to the model together with a
    labelling system prompt. The model's JSON verdict (``answer``, ``classification``,
    ``reason``, ``confidence``) and its ``<think>`` reasoning are written to two CSVs.
    """

    # Reasoning section emitted before the verdict; surrounding whitespace is tolerated.
    _THINK_RE = re.compile(r"<think>\s*(.*?)\s*</think>", re.DOTALL)
    # Fenced code block, with or without a ``json`` language tag.
    _JSON_FENCE_RE = re.compile(r"```(?:json)?\s*(.*?)\s*```", re.DOTALL)
    # Bare JSON object allowing one level of nested braces.
    _JSON_OBJECT_RE = re.compile(r"\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}")

    # Numeric label -> lowercase keywords expected in the model's free-text classification.
    _CATEGORY_KEYWORDS = {
        0: ("benign",),
        1: ("gambling",),
        2: ("porn", "adult"),
        3: ("harmful", "illegal"),
    }

    def __init__(self, model_config: "ModelConfig"):
        """Store the model settings and prepare logging and the CSV reader.

        Args:
            model_config: Endpoint, model and sampling settings used for every request.
        """
        self.config = model_config
        self.setup_logging()
        self.setup_csv_limits()

    def setup_logging(self):
        """Send log records to a new timestamped file under ``logs/`` and to the console.

        ``force=True`` replaces handlers installed by an earlier call (e.g. a re-run
        notebook cell). Without it ``basicConfig`` does nothing once the root logger has
        handlers, and the new log file would be created but never written to.
        """
        log_dir = Path('logs')
        log_dir.mkdir(exist_ok=True)

        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        log_filename = log_dir / f'domain_classification_{timestamp}.log'

        logging.basicConfig(
            level=logging.INFO,
            format='%(asctime)s - %(levelname)s - %(message)s',
            handlers=[
                # Explicit UTF-8: logged domains, content and raw model output may be non-ASCII.
                logging.FileHandler(log_filename, encoding='utf-8'),
                logging.StreamHandler(),
            ],
            force=True,
        )
        self.logger = logging.getLogger(__name__)
        self.logger.info(f"Logging to: {log_filename}")

    def setup_csv_limits(self):
        """Raise the csv module's field size limit as high as the platform accepts.

        ``Content`` cells can exceed the csv module's default limit. ``sys.maxsize`` can
        overflow the C integer used by ``csv.field_size_limit`` on some platforms, so the
        candidate is shrunk tenfold until it is accepted.
        """
        limit = sys.maxsize
        while True:
            try:
                csv.field_size_limit(limit)
                return
            except OverflowError:
                limit //= 10

    def check_model_health(self) -> bool:
        """Send a tiny chat request and report whether the server answered with HTTP 200.

        Returns:
            True if the endpoint returned status 200, False on any other status or error
            (the reason is logged).
        """
        try:
            test_payload = {
                "model": self.config.model_name,
                "messages": [{"role": "user", "content": "Hello"}],
                "max_tokens": 10,
                "temperature": 0.1
            }

            response = requests.post(
                self.config.chat_url,
                json=test_payload,
                timeout=30
            )

            if response.status_code == 200:
                self.logger.info(f"Model {self.config.model_name} is ready")
                return True
            self.logger.error(f"Model health check failed: {response.status_code}, {response.text}")
            return False

        except Exception as e:
            self.logger.error(f"Error checking model health: {str(e)}")
            return False

    def sample_content(self, content: str) -> str:
        """Shorten *content* to roughly ``max_content_length`` characters, keeping context.

        Content within the limit is returned unchanged. Otherwise the result is the first
        20%, a 60% window centred on the midpoint, and the last 20% of the limit, joined
        with ``"... "`` (so the result is the limit plus the 8 separator characters).
        """
        limit = self.config.max_content_length
        if len(content) <= limit:
            return content

        start_len = int(limit * 0.2)  # 20% for start
        end_len = int(limit * 0.2)    # 20% for end
        mid_len = limit - (start_len + end_len)

        start = content[:start_len]
        mid_start = len(content) // 2 - mid_len // 2
        # Slice exactly mid_len characters (the old symmetric slice lost one when odd).
        mid = content[mid_start:mid_start + mid_len]
        # Positive index: content[-0:] would return the WHOLE string when end_len == 0.
        end = content[len(content) - end_len:]

        return f"{start}... {mid}... {end}"

    @staticmethod
    def check_classification_consistency(answer: int, classification: str) -> bool:
        """Return True if the numeric *answer* agrees with the free-text *classification*.

        Agreement means the lowercased classification contains one of the keywords listed
        for that label in ``_CATEGORY_KEYWORDS`` (4 categories, labels 0-3). Labels
        outside that mapping never agree.
        """
        text = str(classification).lower()
        keywords = DomainClassifier._CATEGORY_KEYWORDS.get(answer, ())
        return any(word in text for word in keywords)

    @staticmethod
    def _safe_int(value, default: int) -> int:
        """Convert an untrusted model-output value to int, or return *default* if impossible."""
        try:
            return int(value)
        except (TypeError, ValueError, OverflowError):
            return default

    @classmethod
    def _find_json_object(cls, text: str) -> Optional[dict]:
        """Return the first JSON object found in *text* (fenced block first), else None."""
        for pattern, group in ((cls._JSON_FENCE_RE, 1), (cls._JSON_OBJECT_RE, 0)):
            match = pattern.search(text)
            if match is None:
                continue
            try:
                parsed = json.loads(match.group(group))
            except json.JSONDecodeError:
                continue  # fall through to the next, looser pattern
            if isinstance(parsed, dict):
                return parsed
        return None

    def _parse_model_output(self, message_content: str) -> dict:
        """Split a raw completion into its JSON verdict plus a ``thought`` key.

        ``thought`` holds the ``<think>`` section, or ``"No thought found."``. If no JSON
        object can be parsed, the result contains only ``thought``.
        """
        thought_match = self._THINK_RE.search(message_content)
        thought = thought_match.group(1).strip() if thought_match else "No thought found."

        # Prefer JSON that appears after the reasoning so braces quoted inside the thought
        # are not mistaken for the verdict; fall back to the whole completion.
        search_spaces = [message_content]
        if thought_match:
            search_spaces.insert(0, message_content[thought_match.end():])

        result = None
        for text in search_spaces:
            result = self._find_json_object(text)
            if result is not None:
                break
        if result is None:
            self.logger.warning("No valid JSON object found in model output")
            result = {}

        result["thought"] = thought
        return result

    def send_classification_request(self, conversation: List[Dict]) -> Tuple[dict, float]:
        """Send one classification conversation to the model, retrying on failure.

        Every failed attempt (exception, timeout, non-200 status, missing choices or empty
        message) is logged and followed by ``retry_delay`` seconds of sleep before the
        next one, for up to ``retry_attempts`` attempts.

        Args:
            conversation: Chat messages (``role``/``content`` dicts) sent as ``messages``.

        Returns:
            ``(result, elapsed_seconds)``. ``result`` is the parsed JSON verdict plus a
            ``thought`` key. Its verdict keys may be missing if no JSON could be parsed.
            Returns ``({}, 0.0)`` if every attempt failed.
        """
        payload = {
            "model": self.config.model_name,
            "messages": conversation,
            "max_tokens": self.config.max_tokens,
            "temperature": self.config.temperature,
            "seed": self.config.seed,
            "stream": False
        }

        for attempt in range(self.config.retry_attempts):
            if attempt > 0:
                # Back off before every retry, whatever made the previous attempt fail.
                time.sleep(self.config.retry_delay)
            try:
                start_time = time.time()
                response = requests.post(
                    self.config.chat_url,
                    json=payload,
                    timeout=120,  # Generous: long reasoning completions can take a while
                    headers={"Content-Type": "application/json"}
                )
                elapsed_time = time.time() - start_time

                if response.status_code != 200:
                    self.logger.error(f"HTTP error {response.status_code}: {response.text}")
                    continue

                response_json = response.json()
                choices = response_json.get("choices") or []
                if not choices:
                    self.logger.error("No choices in response")
                    continue

                message_content = (choices[0].get("message") or {}).get("content") or ""
                if not message_content:
                    self.logger.error("Empty message content")
                    continue

                result = self._parse_model_output(message_content)

                usage = response_json.get("usage") or {}
                if usage:
                    self.logger.debug(f"Token usage - Prompt: {usage.get('prompt_tokens', 0)}, "
                                      f"Completion: {usage.get('completion_tokens', 0)}, "
                                      f"Total: {usage.get('total_tokens', 0)}")

                return result, elapsed_time

            except requests.exceptions.Timeout:
                self.logger.error(f"Request timeout (attempt {attempt+1})")
            except Exception as e:
                # Best effort: one bad response must not abort the whole dataset run.
                self.logger.error(f"Request error (attempt {attempt+1}): {str(e)}")

        return {}, 0.0

    @staticmethod
    def _count_data_rows(path: str, required_columns=("Domain", "Content")) -> int:
        """Return the number of data rows (excluding the header) in the CSV at *path*.

        An empty file counts as 0 rows.

        Raises:
            ValueError: If the header lacks any of *required_columns*.
        """
        with open(path, 'r', newline='', encoding='utf-8') as infile:
            reader = csv.DictReader(infile)
            header = reader.fieldnames
            if header is None:
                return 0
            missing = [col for col in required_columns if col not in header]
            if missing:
                raise ValueError(f"Input CSV {path} is missing required column(s): {missing}")
            return sum(1 for _ in reader)

    def process_dataset(self, processing_config: "ProcessingConfig"):
        """Classify every input row from ``start_row`` onward and write both result CSVs.

        Args:
            processing_config: Input, output and prompt paths plus the row to start from.

        Raises:
            RuntimeError: If the model health check fails.
            ValueError: If the input CSV lacks a ``Domain`` or ``Content`` column.
            Exception: Any other error is logged with its traceback and re-raised.
        """
        EXPECTED_CLASSIFICATION = {
            0: "Benign",
            1: "Gambling",
            2: "Pornography",
            3: "Harmful"
        }
        try:
            if not self.check_model_health():
                raise RuntimeError("Model health check failed")

            Path(processing_config.output_label_file).parent.mkdir(parents=True, exist_ok=True)
            Path(processing_config.output_fixed_file).parent.mkdir(parents=True, exist_ok=True)

            # Read the prompt and validate the input before opening the outputs, so a
            # missing prompt file or bad header does not truncate existing result files.
            with open(processing_config.label_prompt_file, "r", encoding="utf-8") as prompt_file:
                labelling_prompt = prompt_file.read()

            # Clamp at 0 so an oversized start_row never gives tqdm a negative total.
            total_rows = max(0, self._count_data_rows(processing_config.input_file)
                             - processing_config.start_row)
            countrow = processing_config.start_row

            # NOTE(review): both outputs are opened with 'w', so a run started with
            # start_row > 0 replaces earlier results rather than appending to them.
            with open(processing_config.output_label_file, 'w', newline='', encoding='utf-8') as label_file, \
                 open(processing_config.output_fixed_file, 'w', newline='', encoding='utf-8') as fixed_file, \
                 open(processing_config.input_file, 'r', newline='', encoding='utf-8') as infile:

                label_writer = csv.DictWriter(label_file,
                    fieldnames=['Domain', 'Content', 'Label', 'Classification', 'Reason', 'Confidence', 'Thought'])
                fixed_writer = csv.DictWriter(fixed_file,
                    fieldnames=['Domain', 'Content', 'Label', 'Confidence'])
                label_writer.writeheader()
                fixed_writer.writeheader()

                reader = csv.DictReader(infile)
                # Skip to start row
                for _ in range(processing_config.start_row):
                    next(reader, None)

                with tqdm(
                    total=total_rows,
                    desc="Classifying domains",
                    unit="samples",
                    unit_scale=True,
                    dynamic_ncols=True,
                    bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]'
                ) as progress_bar:
                    for row in reader:
                        # Short rows yield None for missing cells; treat them as empty text.
                        domain = row.get('Domain') or ''
                        content = row.get('Content') or ''

                        sampled_content = self.sample_content(content)
                        input_text = f"{self.config.input_suffix}\n{domain},\"{sampled_content}\""

                        conversation = [
                            {"role": "system", "content": labelling_prompt},
                            {"role": "user", "content": input_text}
                        ]

                        result, elapsed_time = self.send_classification_request(conversation)

                        # Model output is untrusted: coerce numeric fields without letting a
                        # malformed value (null, "high", nested object) abort the whole run.
                        answer = self._safe_int(result.get('answer'), -1)
                        classification = str(result.get('classification', 'Unknown'))
                        reason = str(result.get('reason', 'No reason provided'))
                        confidence = self._safe_int(result.get('confidence'), 0)
                        thought = str(result.get('thought', 'No reasoning captured'))

                        label_writer.writerow({
                            'Domain': domain,
                            'Content': content,
                            'Label': answer,
                            'Classification': classification,
                            'Reason': reason,
                            'Confidence': confidence,
                            'Thought': thought
                        })
                        fixed_writer.writerow({
                            'Domain': domain,
                            'Content': content,
                            'Label': answer,
                            'Confidence': confidence
                        })
                        # Each row costs a model call, so flushing is cheap and keeps
                        # completed rows on disk if the run is interrupted.
                        label_file.flush()
                        fixed_file.flush()

                        progress_bar.set_postfix({
                            'Current': domain[:20] + '...' if len(domain) > 20 else domain,
                            'Label': answer,
                            'Time': f"{elapsed_time:.2f}s"
                        })
                        progress_bar.update(1)

                        # Log only every 500th row; the progress bar shows the rest.
                        if countrow % 500 == 0:
                            self.logger.info(f"Processed {countrow}: {domain} -> {answer} ({elapsed_time:.2f}s)")

                        if not result:
                            # Every request attempt failed; this is not a model hallucination.
                            self.logger.warning(
                                f"Classification request failed for {countrow} {domain}; "
                                f"wrote Label {answer}"
                            )
                        elif not DomainClassifier.check_classification_consistency(answer, classification):
                            self.logger.warning(
                                f"Hallucination detected for {countrow} {domain}, "
                                f"Answer: {answer} ({EXPECTED_CLASSIFICATION.get(answer)}), "
                                f"Classification: {classification}, "
                                f"Reason: {reason} "
                                f"Confidence: {confidence} \n"
                                f"Raw Output: {result} "
                            )
                        countrow += 1

                self.logger.info(f"Processing completed. Total processed: {countrow - processing_config.start_row}")

        except Exception as e:
            self.logger.exception(f"Processing error: {str(e)}")
            raise

if __name__ == "__main__":
    # Endpoint, model and input file can be overridden via environment variables
    # (including a .env file loaded by load_dotenv() at the top of the file).
    model_config = ModelConfig(
        chat_url=os.getenv("VLLM_CHAT_URL", "http://localhost:8000/v1/chat/completions"),
        model_name=os.getenv("MODEL_NAME", "jordinia/NetPro-Qwen3-0.6B-2105"),
        input_suffix="Classify the given URL as 0 (benign), 1 (gambling), 2 (pornography), or 3 (harmful). Output MUST be JSON.\n",
        max_content_length=20000,
        max_tokens=14000,
        temperature=0.6,
        seed=111
    )

    # Name the output files after the model actually queried (last path component of
    # model_name). Overriding MODEL_NAME then cannot mislabel or overwrite another model's
    # results. With the default model this gives the same names as before.
    model_tag = model_config.model_name.rstrip("/").rsplit("/", 1)[-1]

    processing_config = ProcessingConfig(
        input_file=os.getenv("INPUT_FILE", "/home/fishmon/AJ/LLM-Finetuning/Malicious-Web/dataset/test.csv"),
        output_label_file=f"{model_tag}_label_full.csv",
        output_fixed_file=f"{model_tag}_fixed_full.csv",
        label_prompt_file="prompts/labelling/labelling_promptv4.txt",
        start_row=0
    )

    classifier = DomainClassifier(model_config)
    classifier.process_dataset(processing_config)