# In [None]:  (notebook cell marker left over from export)
import requests
import json
import time
import csv
import os
import sys
import logging
import re
from typing import Dict, List, Tuple, Optional
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from dotenv import load_dotenv
load_dotenv()

@dataclass
class ModelConfig:
    """Connection and request settings used by DomainClassifier for API calls."""
    # API root; "/chat/completions" is appended to it when a request is sent.
    base_url: str
    # Sent in the "Authorization: Bearer <api_key>" request header.
    api_key: str
    # Value of the "model" field in the request payload.
    model_name: str
    # Instruction text placed on the first line of the user message, ahead of
    # the "domain,content" line (despite the name, it acts as a prefix).
    input_suffix: str
    # Content longer than this many characters is sampled (start / middle /
    # end excerpts) before being sent.
    max_content_length: int = 1750
    # Total number of request attempts made per row.
    retry_attempts: int = 3
    # Seconds to sleep between attempts.
    retry_delay: int = 5

@dataclass
class ProcessingConfig:
    """File locations and row offset for one DomainClassifier.process_dataset run."""
    # Input CSV; rows are read by header name and must provide the
    # 'Domain' and 'Content' columns.
    input_file: str
    # Output CSV with the full result per row (Domain, Content, Label,
    # Classification, Reason, Confidence, Thought).
    output_label_file: str
    # Output CSV with the reduced column set (Domain, Content, Label, Confidence).
    output_fixed_file: str
    # Text file whose entire contents are sent as the system prompt.
    label_prompt_file: str
    # Number of data rows to skip at the start of the input file.
    start_row: int = 0
    # NOTE(review): not referenced by any code visible in this file.
    batch_size: int = 100
class DomainClassifier:
    """Classify domains by sending their page content to a chat-completions API.

    For every input row the classifier builds a system/user conversation,
    posts it to ``<base_url>/chat/completions``, extracts the JSON verdict
    from the reply and writes it to two CSV files.
    """

    # Fenced ```json ... ``` block inside the model's reply.
    _JSON_FENCE_RE = re.compile(r"```json\s*(.*?)\s*```", re.DOTALL)

    def __init__(self, model_config: "ModelConfig"):
        """Store the model settings and prepare logging and CSV limits."""
        self.config = model_config
        self.setup_logging()
        self.setup_csv_limits()

    def setup_logging(self):
        """Configure logging to a timestamped file under ``logs/`` and to the console.

        ``logging.basicConfig`` does nothing when the root logger already has
        handlers (for example when a second classifier is created), so in that
        case no new log file is created and the existing setup is reused.
        """
        self.logger = logging.getLogger(__name__)

        if logging.getLogger().handlers:
            # Creating a FileHandler here would leave an empty, never-written
            # file behind and make the "Logging to" message untrue.
            self.logger.info("Logging already configured; reusing existing handlers")
            return

        log_dir = Path('logs')
        log_dir.mkdir(parents=True, exist_ok=True)
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        log_filename = log_dir / f'domain_classification_{timestamp}.log'
        logging.basicConfig(
            level=logging.INFO,
            format='%(asctime)s - %(levelname)s - %(message)s',
            handlers=[
                # Explicit UTF-8: logged model output and page content are not
                # guaranteed to fit the platform's default encoding.
                logging.FileHandler(log_filename, encoding='utf-8'),
                logging.StreamHandler(),
            ],
        )
        self.logger.info(f"Logging to: {log_filename}")

    def setup_csv_limits(self):
        """Raise the csv field size limit to the largest value the platform accepts."""
        limit = sys.maxsize
        while True:
            try:
                csv.field_size_limit(limit)
                return
            except OverflowError:
                # Integer division keeps the value exact; float division would
                # round a number as large as sys.maxsize.
                limit //= 10

    def sample_content(self, content: str) -> str:
        """Shorten ``content`` to excerpts from its start, middle and end.

        Content no longer than ``max_content_length`` is returned unchanged.
        Otherwise 20% of the budget is taken from the start, 20% from the end
        and the remainder from around the midpoint. The three excerpts are
        joined with "... " separators, which are not counted in the budget.

        Returns an empty string when the limit is not positive and the content
        is non-empty, because nothing fits.
        """
        limit = self.config.max_content_length
        if len(content) <= limit:
            return content
        if limit <= 0:
            return ""

        head_len = int(limit * 0.2)
        tail_len = int(limit * 0.2)
        mid_len = limit - (head_len + tail_len)

        head = content[:head_len]
        # Slice by start + length so an odd mid_len is not shortened by one.
        mid_start = len(content) // 2 - mid_len // 2
        mid = content[mid_start:mid_start + mid_len]
        # Index from the front: content[-0:] would return the whole string
        # when the tail budget rounds down to zero.
        tail = content[len(content) - tail_len:]
        return f"{head}... {mid}... {tail}"

    @staticmethod
    def check_classification_consistency(answer: int, classification: str) -> bool:
        """Return True when the numeric label agrees with the textual classification.

        Labels: 0 benign, 1 gambling, 2 pornography/adult, 3 harmful/illegal.
        The comparison is a case-insensitive substring match; any other label
        is reported as inconsistent.
        """
        text = str(classification).lower()
        keywords_by_label = (
            (0, ("benign",)),
            (1, ("gambling",)),
            (2, ("porn", "adult")),
            (3, ("harmful", "illegal")),
        )
        return any(
            answer == label and any(word in text for word in keywords)
            for label, keywords in keywords_by_label
        )

    @staticmethod
    def _extract_json_result(message_content: Optional[str]) -> dict:
        """Extract the JSON object from a model reply.

        A fenced ```json block is preferred. Without one, the span from the
        first "{" to the last "}" is tried so that bare-JSON replies are
        accepted as well.

        Returns:
            The parsed object, or ``{}`` when the reply contains no usable JSON.

        Raises:
            ValueError: if the reply has no text, or a fenced block is present
                but is not a valid JSON object. The caller treats this as a
                failed attempt.
        """
        if not isinstance(message_content, str):
            raise ValueError("response message has no text content")

        fenced = DomainClassifier._JSON_FENCE_RE.search(message_content)
        if fenced:
            parsed = json.loads(fenced.group(1))
            if not isinstance(parsed, dict):
                raise ValueError("fenced JSON block is not an object")
            return parsed

        first = message_content.find("{")
        last = message_content.rfind("}")
        if 0 <= first < last:
            try:
                parsed = json.loads(message_content[first:last + 1])
            except ValueError:
                return {}
            if isinstance(parsed, dict):
                return parsed
        return {}

    @staticmethod
    def _coerce_int(value, default: int) -> int:
        """Convert ``value`` to int, returning ``default`` when it cannot be converted."""
        try:
            return int(value)
        except (ValueError, TypeError, OverflowError):
            return default

    def send_classification_request(self, conversation: List[Dict]) -> Tuple[dict, float]:
        """Post the conversation to the chat-completions endpoint, with retries.

        Args:
            conversation: Chat messages sent as the "messages" payload field.

        Returns:
            ``(result, elapsed_seconds)``. ``result`` is the JSON object found
            in the reply plus a "thought" key holding the model's reasoning.
            ``({}, 0.0)`` is returned when every attempt failed.
        """
        headers = {
            "Authorization": f"Bearer {self.config.api_key}",
            "Content-Type": "application/json",
        }
        # Optional headers, only sent when the environment provides them.
        referer = os.getenv("HTTP_REFERER")
        if referer:
            headers["HTTP-Referer"] = referer
        title = os.getenv("X_TITLE")
        if title:
            headers["X-Title"] = title

        # The request is identical on every attempt, so build it once.
        url = f"{self.config.base_url}/chat/completions"
        payload = {
            "model": self.config.model_name,
            "messages": conversation,
            "reasoning": {
                "effort": "high",
                "exclude": False,
            },
            "temperature": 0.4,
            "max_tokens": 10240,
        }

        attempts = self.config.retry_attempts
        for attempt in range(attempts):
            has_retries_left = attempt < attempts - 1
            try:
                started = time.monotonic()
                response = requests.post(url, headers=headers, json=payload, timeout=60)
                self.logger.info(f"Response status code: {response.status_code}")

                if response.status_code == 200:
                    elapsed_time = time.monotonic() - started
                    response_json = response.json()
                    self.logger.info(f"Response JSON: {response_json}")

                    message = response_json['choices'][0]['message']
                    result = self._extract_json_result(message.get('content'))
                    result["thought"] = message.get('reasoning') or 'No reasoning captured'
                    return result, elapsed_time

                # Keep the body (truncated) so failures can be diagnosed.
                self.logger.warning(
                    f"Request failed with status {response.status_code}: {response.text[:500]}"
                )
                if has_retries_left:
                    self.logger.warning(f"Retrying... Attempt {attempt+1}/{attempts}")
                    time.sleep(self.config.retry_delay)

            except Exception as e:
                # Best-effort boundary: network errors and malformed replies
                # both count as a failed attempt.
                self.logger.error(f"Request error (attempt {attempt+1}): {str(e)}")
                if has_retries_left:
                    time.sleep(self.config.retry_delay)

        self.logger.error(f"Giving up after {attempts} attempt(s)")
        return {}, 0.0

    def process_dataset(self, processing_config: "ProcessingConfig"):
        """Classify every row of the input CSV and write the two result CSVs.

        Rows before ``start_row`` are skipped. Each processed row is written to
        both output files and flushed immediately, so an interrupted run keeps
        the rows finished so far. Replies whose numeric label and textual
        classification disagree are logged as warnings.

        Raises:
            Exception: any error raised while reading or writing is logged and
                re-raised.
        """
        EXPECTED_CLASSIFICATION = {
            0: "Benign",
            1: "Gambling",
            2: "Pornography",
            3: "Harmful"
        }

        try:
            # Read the prompt before opening the outputs: a missing prompt file
            # must not leave freshly truncated result files behind.
            with open(processing_config.label_prompt_file, "r", encoding="utf-8") as prompt_file:
                labelling_prompt = prompt_file.read()

            Path(processing_config.output_label_file).parent.mkdir(parents=True, exist_ok=True)
            Path(processing_config.output_fixed_file).parent.mkdir(parents=True, exist_ok=True)

            countrow = processing_config.start_row
            # The input is opened first for the same reason as the prompt.
            # NOTE(review): the outputs are opened in 'w' mode even when
            # start_row > 0, so resuming overwrites earlier results - confirm
            # whether append mode is wanted for resumed runs.
            with open(processing_config.input_file, 'r', newline='', encoding='utf-8') as infile, \
                 open(processing_config.output_label_file, 'w', newline='', encoding='utf-8') as label_file, \
                 open(processing_config.output_fixed_file, 'w', newline='', encoding='utf-8') as fixed_file:

                label_writer = csv.DictWriter(label_file,
                    fieldnames=['Domain', 'Content', 'Label', 'Classification', 'Reason', 'Confidence', 'Thought'])
                fixed_writer = csv.DictWriter(fixed_file,
                    fieldnames=['Domain', 'Content', 'Label', 'Confidence'])

                label_writer.writeheader()
                fixed_writer.writeheader()

                reader = csv.DictReader(infile)
                for _ in range(processing_config.start_row):
                    next(reader, None)

                for row in reader:
                    # DictReader fills missing trailing fields with None.
                    domain = row['Domain'] or ''
                    content = row['Content'] or ''
                    sampled_content = self.sample_content(content)
                    input_text = f"{self.config.input_suffix}\n{domain},\"{sampled_content}\""

                    conversation = [
                        {"role": "system", "content": labelling_prompt},
                        {"role": "user", "content": input_text}
                    ]

                    result, elapsed_time = self.send_classification_request(conversation)

                    # Model output is untrusted: a non-numeric answer or
                    # confidence must not abort the whole run.
                    answer = self._coerce_int(result.get('answer', -1), -1)
                    classification = str(result.get('classification', 'Unknown'))
                    reason = str(result.get('reason', 'No reason provided'))
                    confidence = self._coerce_int(result.get('confidence', 0), 0)
                    thought = str(result.get('thought', 'No reasoning captured'))

                    label_writer.writerow({
                        'Domain': domain,
                        'Content': content,
                        'Label': answer,
                        'Classification': classification,
                        'Reason': reason,
                        'Confidence': confidence,
                        'Thought': thought
                    })

                    fixed_writer.writerow({
                        'Domain': domain,
                        'Content': content,
                        'Label': answer,
                        'Confidence': confidence
                    })

                    # Each row can take minutes to obtain; do not leave it in a buffer.
                    label_file.flush()
                    fixed_file.flush()

                    self.logger.info(f"{elapsed_time:.2f} {countrow} {domain} {result}")

                    if not DomainClassifier.check_classification_consistency(answer, classification):
                        self.logger.warning(
                            f"Hallucination detected for {countrow} {domain}, "
                            f"Answer: {answer} ({EXPECTED_CLASSIFICATION.get(answer)}), "
                            f"Classification: {classification}, "
                            f"Reason: {reason} "
                            f"Confidence: {confidence} \n"
                            f"Raw Output: {result} "
                        )
                    countrow += 1
        except Exception as e:
            self.logger.error(f"Processing error: {str(e)}")
            raise

if __name__ == "__main__":
    # Fail fast on a missing key: otherwise every request is sent with
    # "Bearer None", is rejected, and each row burns all retries (with a 90 s
    # pause between them) before the run produces only empty results.
    api_key = os.getenv("OPENROUTER_API_KEY")
    if not api_key:
        sys.exit("OPENROUTER_API_KEY is not set; export it or add it to the .env file.")

    # NOTE(review): the suffix ends with a hint that forces label 3 (Harmful);
    # presumably intentional for this pre-sorted input file - confirm before
    # reusing this configuration with another dataset.
    model_config = ModelConfig(
        base_url="https://openrouter.ai/api/v1",
        api_key=api_key,
        model_name="deepseek/deepseek-r1:free",
        input_suffix="Provide your answer and reasoning in English. Classify the given URL as 0 (benign), 1 (gambling), 2 (pornography), or 3 (harmful). Output MUST be JSON.\nHINT: THE FOLLOWING DOMAIN MUST BE CATEGORIZED AS Harmful (3)",
        max_content_length=10000,
        retry_delay=90
    )

    processing_config = ProcessingConfig(
        input_file="../Label-4.csv",
        output_label_file="../Label-4_labelled.csv",
        output_fixed_file="../Label-4_fixed.csv",
        label_prompt_file="prompts/labelling/labelling_promptv4.txt",
        start_row=0
    )

    classifier = DomainClassifier(model_config)
    classifier.process_dataset(processing_config)