<a href="https://colab.research.google.com/github/ahteshamsalamatansari/colabcodes/blob/main/Test.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Adaptive AI Data Enrichment Agent v2.1 (Fixed Version)
# Added better timeout handling and error recovery

import pandas as pd
import requests
import json
import time
import re
from typing import Dict, List, Optional, Tuple, Any
from dataclasses import dataclass
from concurrent.futures import ThreadPoolExecutor, as_completed
import logging
from openai import OpenAI
import warnings
warnings.filterwarnings('ignore')

# Configuration
@dataclass
class Config:
    """Runtime settings for the enrichment agent.

    Every field has a default, so ``Config()`` can be built without
    arguments. The two key defaults are placeholder strings.
    """
    # Sent as the Bearer token in the Authorization header of Perplexity calls.
    perplexity_key: str = "YOUR_PERPLEXITY_KEY"
    # Passed as api_key when the OpenAI client is created.
    openai_key: str = "YOUR_OPENAI_KEY"
    # NOTE(review): not read anywhere in the visible part of this file.
    confidence_threshold: float = 0.7
    # Total number of attempts call_perplexity makes per query.
    max_retries: int = 2  # Reduced from 3
    # Seconds: base of the retry backoff and the pause between queries.
    rate_limit_delay: float = 2.0  # Increased from 1.0
    # NOTE(review): not read anywhere in the visible part of this file.
    max_workers: int = 1  # Reduced from 3 for better control
    # Seconds: timeout for Perplexity HTTP requests and for the OpenAI client.
    request_timeout: int = 15  # Added timeout

@dataclass
class EnrichmentResult:
    """One accepted value for one target field of one row."""
    # Name of the target column the value belongs to.
    field: str
    # The accepted value, stripped of surrounding whitespace.
    value: str
    # Confidence label; the validator in this file produces 'high' or 'medium'.
    confidence: str
    # Where the value came from (the validator defaults this to 'Perplexity').
    source: str
    # Query text that produced the value; filled in by enrich_single_field.
    query_used: str
    # Query category (e.g. 'general'); filled in by enrich_single_field.
    search_type: str

class AdaptiveDataEnrichmentAgent:
    """
    Data Enrichment Agent using Perplexity Search API
    Enhanced with better timeout and error handling
    """

    def __init__(self, config: Config):
        """Store the settings, reset the counters and prepare both API clients.

        Args:
            config: Keys, timeout and retry settings used by every request.
        """
        self.config = config

        # Bookkeeping that the request and enrichment methods update later.
        self.successful_requests = 0
        self.failed_requests = 0
        self.enrichment_logs = []

        # Logging comes first: setup_perplexity() writes to self.logger.
        self.setup_logging()
        self.setup_openai()
        self.setup_perplexity()

    def setup_logging(self):
        """Setup logging configuration"""
        logging.basicConfig(
            level=logging.INFO,
            format='%(asctime)s - %(levelname)s - %(message)s'
        )
        self.logger = logging.getLogger(__name__)

    def setup_openai(self):
        """Create the OpenAI client with the configured key and timeout."""
        settings = self.config
        self.openai_client = OpenAI(
            timeout=settings.request_timeout,
            api_key=settings.openai_key,
        )

    def setup_perplexity(self):
        """Setup Perplexity client using direct API calls with timeout"""
        self.perplexity_url = "https://api.perplexity.ai/chat/completions"
        self.perplexity_headers = {
            "Authorization": f"Bearer {self.config.perplexity_key}",
            "Content-Type": "application/json",
            "User-Agent": "DataEnrichmentAgent/2.1"
        }
        self.use_native_client = False
        self.default_model = "sonar"  # Updated to correct Perplexity model name
        self.logger.info("Using direct API calls to Perplexity with timeout handling")

    def test_perplexity_connection(self) -> bool:
        """Probe the Perplexity endpoint with a one-question request.

        The request is first sent with ``self.default_model``. Any HTTP 200
        counts as success, whether or not the answer mentions Paris. On
        another status the same payload is retried once per fallback model
        name, and the first model that answers 200 replaces
        ``self.default_model``. An HTTP 401 skips the fallbacks, because a
        rejected API key is rejected for every model.

        Returns:
            True if some request returned HTTP 200. False otherwise,
            including when the first request times out or raises.
        """
        try:
            test_query = "What is the capital of France?"

            payload = {
                "model": self.default_model,
                "messages": [
                    {"role": "user", "content": test_query}
                ],
                "max_tokens": 50,
                "temperature": 0.1
            }

            self.logger.info(f"Testing connection with model: {self.default_model}")

            response = requests.post(
                self.perplexity_url,
                json=payload,
                headers=self.perplexity_headers,
                timeout=self.config.request_timeout
            )

            if response.status_code == 200:
                data = response.json()
                # A 200 with a missing, null or empty "choices" list is still
                # a working connection, so read the content defensively.
                choices = data.get('choices') or [{}]
                message = choices[0].get('message') or {}
                content = message.get('content') or ''
                if 'paris' in content.lower():
                    self.logger.info("Perplexity API connection test successful")
                else:
                    self.logger.warning(f"Unexpected test response: {content[:100]}...")
                # Still consider it successful if we got a response
                return True

            self.logger.error(f"Perplexity API test failed: {response.status_code}")
            self.logger.error(f"Response: {response.text[:200]}...")

            if response.status_code == 401:
                # Switching models cannot fix a rejected key.
                self.logger.error("Authentication rejected; not trying alternative models")
                return False

            # Try alternative models
            alternative_models = [
                "sonar-pro",
                "sonar-small-online",
                "sonar-medium-online"
            ]

            for alt_model in alternative_models:
                self.logger.info(f"Trying alternative model: {alt_model}")
                payload["model"] = alt_model

                try:
                    alt_response = requests.post(
                        self.perplexity_url,
                        json=payload,
                        headers=self.perplexity_headers,
                        timeout=self.config.request_timeout
                    )

                    if alt_response.status_code == 200:
                        self.logger.info(f"Success with model: {alt_model}")
                        self.default_model = alt_model
                        return True
                except Exception as e:
                    # One failing fallback must not stop the remaining ones.
                    self.logger.warning(f"Model {alt_model} failed: {e}")

        except requests.exceptions.Timeout:
            self.logger.error("Perplexity API test timed out")
        except Exception as e:
            self.logger.error(f"Perplexity API test error: {e}")

        return False

    def detect_identifiers_and_targets(self, df: pd.DataFrame) -> Tuple[List[str], List[str]]:
        """Detect identifier and target columns"""
        columns = df.columns.tolist()
        identifier_columns = []
        target_columns = []

        # Primary identifier patterns
        strict_identifier_patterns = [
            'company name', 'business name', 'organization name', 'name',
            'website', 'url', 'domain', 'site'
        ]

        for col in columns:
            col_lower = col.lower().replace('_', ' ').replace('-', ' ')

            is_primary_identifier = False

            for pattern in strict_identifier_patterns:
                if pattern in col_lower:
                    if pattern in ['company name', 'business name', 'organization name', 'name']:
                        if col_lower in ['company name', 'business name', 'organization name', 'name'] or \
                           col_lower.endswith(' name'):
                            is_primary_identifier = True
                            break
                    elif pattern in ['website', 'url', 'domain', 'site']:
                        is_primary_identifier = True
                        break

            if is_primary_identifier:
                identifier_columns.append(col)
            else:
                target_columns.append(col)

        # Ensure we have at least one identifier
        if not identifier_columns:
            for col in columns:
                col_lower = col.lower()
                if 'name' in col_lower or 'company' in col_lower:
                    identifier_columns.append(col)
                    if col in target_columns:
                        target_columns.remove(col)
                    break

        self.logger.info(f"Identifiers: {identifier_columns}")
        self.logger.info(f"Targets: {target_columns}")

        return identifier_columns, target_columns

    def build_search_queries(self, target_field: str, company_name: str = None,
                           website: str = None) -> List[Dict[str, Any]]:
        """Build optimized search queries"""
        queries = []

        if company_name:
            # Field-specific queries
            field_lower = target_field.lower()

            if 'industry' in field_lower:
                queries.append({
                    'query': f"What industry is {company_name} in? What business sector does {company_name} operate in?",
                    'type': 'industry_specific',
                    'priority': 1
                })
            elif 'employee' in field_lower and ('count' in field_lower or 'size' in field_lower):
                queries.append({
                    'query': f"How many employees does {company_name} have? What is the employee count at {company_name}?",
                    'type': 'employee_specific',
                    'priority': 1
                })
            elif 'headquarters' in field_lower or 'address' in field_lower:
                queries.append({
                    'query': f"Where is {company_name} headquarters located? What is the address of {company_name}?",
                    'type': 'location_specific',
                    'priority': 1
                })
            elif 'linkedin' in field_lower:
                queries.append({
                    'query': f"What is the LinkedIn company page URL for {company_name}?",
                    'type': 'linkedin_specific',
                    'priority': 1
                })
            else:
                # General query
                queries.append({
                    'query': f"Find information about {target_field} for company {company_name}",
                    'type': 'general',
                    'priority': 2
                })

        return queries[:2]  # Limit to 2 queries

    def call_perplexity(self, query: str) -> Optional[Dict]:
        """Make Perplexity API call with improved timeout handling"""
        for attempt in range(self.config.max_retries):
            try:
                self.logger.info(f"Perplexity Query (attempt {attempt + 1}): {query[:100]}...")

                payload = {
                    "model": self.default_model,
                    "messages": [
                        {
                            "role": "system",
                            "content": "You are a business information researcher. Provide specific, factual information. Be concise and accurate."
                        },
                        {
                            "role": "user",
                            "content": query
                        }
                    ],
                    "max_tokens": 500,  # Reduced for faster responses
                    "temperature": 0.1,
                    "top_p": 0.9
                }

                response = requests.post(
                    self.perplexity_url,
                    json=payload,
                    headers=self.perplexity_headers,
                    timeout=self.config.request_timeout
                )

                if response.status_code == 200:
                    data = response.json()
                    self.successful_requests += 1
                    self.logger.info(f"Perplexity Success ({self.successful_requests} total)")
                    return data
                elif response.status_code == 429:
                    self.logger.warning("Rate limit hit, waiting longer...")
                    time.sleep(10)  # Wait 10 seconds for rate limit
                    continue
                else:
                    self.logger.warning(f"Perplexity HTTP {response.status_code}: {response.text[:200]}...")
                    self.failed_requests += 1

            except requests.exceptions.Timeout:
                self.logger.error(f"Request timed out after {self.config.request_timeout}s (attempt {attempt + 1})")
                self.failed_requests += 1
            except requests.exceptions.ConnectionError:
                self.logger.error(f"Connection error (attempt {attempt + 1})")
                self.failed_requests += 1
            except Exception as e:
                self.logger.error(f"Perplexity Error (attempt {attempt + 1}): {e}")
                self.failed_requests += 1

            # Progressive backoff
            if attempt < self.config.max_retries - 1:
                wait_time = self.config.rate_limit_delay * (2 ** attempt)
                self.logger.info(f"Waiting {wait_time}s before retry...")
                time.sleep(wait_time)

        return None

    def extract_from_perplexity_response(self, response: Dict, target_field: str) -> List[Dict]:
        """Extract data from Perplexity response"""
        candidates = []

        try:
            content = response['choices'][0]['message']['content']

            # Use AI extraction
            extracted_values = self._extract_values_with_ai(content, target_field)

            for value_info in extracted_values:
                candidates.append({
                    'value': value_info['value'],
                    'source': 'Perplexity Search',
                    'confidence': value_info.get('confidence', 'medium'),
                    'extraction_method': 'ai_extraction',
                    'full_context': content[:300] + "..." if len(content) > 300 else content
                })

            self.logger.info(f"Found {len(candidates)} candidates for {target_field}")

        except Exception as e:
            self.logger.error(f"Error extracting from response: {e}")

        return candidates

    def _extract_values_with_ai(self, content: str, target_field: str) -> List[Dict]:
        """Extract values using OpenAI with timeout"""
        try:
            prompt = f"""
            Extract the specific value for "{target_field}" from this text:

            {content}

            Return JSON format:
            {{
                "values": [
                    {{
                        "value": "extracted clean value",
                        "confidence": "high/medium/low"
                    }}
                ]
            }}

            Guidelines:
            - For industry: use standard industry terms
            - For employee count: use numbers only like "150" or "1,500"
            - For employee size: use ranges like "50-200" or descriptive terms
            - For headquarters: use city, state/country format
            - For LinkedIn: use full URL format
            - Return empty array if no clear value found
            """

            response = self.openai_client.chat.completions.create(
                model="gpt-4o-mini",
                messages=[{"role": "user", "content": prompt}],
                response_format={"type": "json_object"},
                temperature=0.1,
                max_tokens=200,
                timeout=10  # 10 second timeout
            )

            result = json.loads(response.choices[0].message.content)
            return result.get("values", [])

        except Exception as e:
            self.logger.error(f"AI extraction failed: {e}")
            return self._simple_pattern_extraction(content, target_field)

    def _simple_pattern_extraction(self, text: str, target_field: str) -> List[Dict]:
        """Simple pattern-based extraction as fallback"""
        field_lower = target_field.lower()

        # Basic patterns
        if 'employee' in field_lower:
            matches = re.findall(r'(\d{1,6})\s*(?:employees|staff|people)', text, re.IGNORECASE)
            if matches:
                return [{'value': matches[0], 'confidence': 'low'}]

        elif 'industry' in field_lower:
            # Look for common industry terms
            industries = ['pharmaceutical', 'biotech', 'healthcare', 'technology', 'manufacturing',
                         'software', 'medical', 'life sciences']
            for industry in industries:
                if industry in text.lower():
                    return [{'value': industry.title(), 'confidence': 'low'}]

        return []

    def validate_and_normalize_with_ai(self, candidates: List[Dict],
                                     target_field: str) -> Optional[EnrichmentResult]:
        """Validate candidates with AI"""
        if not candidates:
            return None

        try:
            # Simple validation - take the first high confidence candidate
            for candidate in candidates:
                if candidate.get('confidence') == 'high' and candidate['value'].strip():
                    return EnrichmentResult(
                        field=target_field,
                        value=candidate['value'].strip(),
                        confidence=candidate['confidence'],
                        source=candidate.get('source', 'Perplexity'),
                        query_used="",
                        search_type=""
                    )

            # Fallback to first candidate
            if candidates and candidates[0]['value'].strip():
                return EnrichmentResult(
                    field=target_field,
                    value=candidates[0]['value'].strip(),
                    confidence='medium',
                    source=candidates[0].get('source', 'Perplexity'),
                    query_used="",
                    search_type=""
                )

        except Exception as e:
            self.logger.error(f"Validation failed: {e}")

        return None

    def enrich_single_field(self, row_data: Dict, target_field: str,
                          identifier_cols: List[str]) -> Optional[EnrichmentResult]:
        """Enrich a single field with timeout protection"""
        # Skip if already has data
        current_value = row_data.get(target_field)
        if pd.notna(current_value) and str(current_value).strip():
            return None

        # Get identifiers
        company_name = None
        website = None

        for col in identifier_cols:
            col_lower = col.lower()
            value = row_data.get(col)

            if pd.notna(value) and str(value).strip():
                if 'name' in col_lower or 'company' in col_lower:
                    company_name = str(value).strip()
                elif 'website' in col_lower or 'url' in col_lower:
                    website = str(value).strip()

        if not company_name and not website:
            return None

        self.logger.info(f"Enriching '{target_field}' for: {company_name or website}")

        # Build queries
        query_configs = self.build_search_queries(target_field, company_name, website)

        for query_config in query_configs:
            query = query_config['query']
            search_type = query_config['type']

            self.logger.info(f"Trying {search_type}: {query[:50]}...")

            start_time = time.time()
            response = self.call_perplexity(query)
            elapsed_time = time.time() - start_time

            if response:
                candidates = self.extract_from_perplexity_response(response, target_field)

                if candidates:
                    result = self.validate_and_normalize_with_ai(candidates, target_field)
                    if result:
                        result.query_used = query
                        result.search_type = search_type
                        self.logger.info(f"Success: '{result.value}' ({elapsed_time:.1f}s)")
                        return result

            # Wait between attempts
            time.sleep(self.config.rate_limit_delay)

        self.logger.info(f"No data found for {target_field}")
        return None

    def enrich_dataframe(self, df: pd.DataFrame) -> Tuple[pd.DataFrame, List[Dict]]:
        """Enrich DataFrame with progress tracking"""
        self.logger.info("Starting DataFrame enrichment...")

        # Detect schema
        identifier_cols, target_cols = self.detect_identifiers_and_targets(df)

        if not identifier_cols:
            raise ValueError("No identifier columns detected!")

        if not target_cols:
            raise ValueError("No target columns detected!")

        # Create enriched copy
        enriched_df = df.copy()
        all_logs = []

        total_fields = len(target_cols) * len(df)
        enriched_count = 0
        processed_count = 0

        self.logger.info(f"Processing {len(df)} rows × {len(target_cols)} fields = {total_fields} total")

        # Process each row
        for row_idx, row in df.iterrows():
            row_data = row.to_dict()
            self.logger.info(f"\n--- Row {row_idx + 1}/{len(df)} ---")

            # Show identifiers
            identifiers = {col: str(row_data.get(col, '')).strip()
                          for col in identifier_cols
                          if pd.notna(row_data.get(col)) and str(row_data.get(col)).strip()}

            self.logger.info(f"Identifiers: {identifiers}")

            # Process each target field
            for target_field in target_cols:
                processed_count += 1
                try:
                    start_time = time.time()

                    result = self.enrich_single_field(row_data, target_field, identifier_cols)

                    process_time = time.time() - start_time

                    if result:
                        enriched_df.at[row_idx, target_field] = result.value
                        enriched_count += 1

                        log_entry = {
                            'timestamp': time.time(),
                            'row_index': int(row_idx),
                            'field': result.field,
                            'original_value': row_data.get(target_field, ''),
                            'enriched_value': result.value,
                            'confidence': result.confidence,
                            'source': result.source,
                            'process_time_seconds': round(process_time, 2),
                            'status': 'success'
                        }
                        all_logs.append(log_entry)

                        self.logger.info(f"✓ {target_field}: '{result.value}' [{result.confidence}]")

                    else:
                        log_entry = {
                            'timestamp': time.time(),
                            'row_index': int(row_idx),
                            'field': target_field,
                            'original_value': row_data.get(target_field, ''),
                            'enriched_value': None,
                            'process_time_seconds': round(process_time, 2),
                            'status': 'failed'
                        }
                        all_logs.append(log_entry)

                        self.logger.info(f"✗ {target_field}: No data found")

                    # Progress update
                    progress = (processed_count / total_fields) * 100
                    self.logger.info(f"Progress: {processed_count}/{total_fields} ({progress:.1f}%)")

                except Exception as e:
                    self.logger.error(f"Error with {target_field}: {e}")

                    log_entry = {
                        'timestamp': time.time(),
                        'row_index': int(row_idx),
                        'field': target_field,
                        'status': 'error',
                        'error_message': str(e)
                    }
                    all_logs.append(log_entry)

        # Summary
        success_rate = (enriched_count / total_fields) * 100 if total_fields > 0 else 0

        self.logger.info(f"\n=== ENRICHMENT COMPLETE ===")
        self.logger.info(f"Processed: {total_fields} fields")
        self.logger.info(f"Successful: {enriched_count} ({success_rate:.1f}%)")
        self.logger.info(f"API Success Rate: {self.successful_requests}/{self.successful_requests + self.failed_requests}")

        return enriched_df, all_logs


def main():
    """Run the interactive enrichment workflow end to end.

    Steps: read both API keys (without echoing them), check the Perplexity
    and OpenAI connections, obtain a CSV (Colab upload, or a filename typed
    by the user outside Colab), show the detected column structure and a
    time estimate, and after confirmation enrich the data and write
    ``enriched_<name>`` next to the input file. Problems are reported with
    ``print`` and end the run; nothing is raised to the caller.
    """
    # Local imports keep this block self-contained.
    import getpass
    from pathlib import Path

    print("Adaptive AI Data Enrichment Agent v2.1 (Fixed Version)")
    print("=" * 60)

    # Get API keys
    print("\nStep 1: Configuration")

    # getpass keeps the secrets out of the terminal / notebook output.
    perplexity_key = getpass.getpass("Enter your Perplexity API key: ").strip()
    openai_key = getpass.getpass("Enter your OpenAI API key: ").strip()

    if not perplexity_key or not openai_key:
        print("Both API keys are required!")
        return

    config = Config(
        perplexity_key=perplexity_key,
        openai_key=openai_key,
        request_timeout=15,
        rate_limit_delay=2.0,
        max_retries=2
    )

    # Initialize agent
    print("\nInitializing agent with improved timeout handling...")
    agent = AdaptiveDataEnrichmentAgent(config)

    # Test connections
    print("\nTesting API connections...")

    if not agent.test_perplexity_connection():
        print("ERROR: Perplexity API connection failed!")
        return
    print("✓ Perplexity API working")

    try:
        # Only success or failure matters; the reply itself is discarded.
        agent.openai_client.chat.completions.create(
            model="gpt-4o-mini",
            messages=[{"role": "user", "content": "Test"}],
            max_tokens=5,
            timeout=10
        )
        print("✓ OpenAI API working")
    except Exception as e:
        print(f"ERROR: OpenAI API failed: {e}")
        return

    # File upload
    print("\nStep 2: Upload CSV File")

    try:
        from google.colab import files
        uploaded = files.upload()
    except ImportError:
        print("Not in Colab - enter filename manually")
        filename = input("Enter CSV filename: ").strip()
        uploaded = {filename: None} if filename else {}

    if not uploaded:
        print("No file provided!")
        return

    filename = list(uploaded.keys())[0]
    print(f"Processing: {filename}")

    try:
        df = pd.read_csv(filename, encoding='utf-8')
        print(f"Loaded: {len(df)} rows, {len(df.columns)} columns")

        # Show structure
        identifier_cols, target_cols = agent.detect_identifiers_and_targets(df)

        print(f"\nDetected Structure:")
        print(f"  Identifiers: {identifier_cols}")
        print(f"  Targets: {target_cols}")

        total_enrichments = len(df) * len(target_cols)
        estimated_time = (total_enrichments * config.rate_limit_delay) / 60

        print(f"\nEstimates:")
        print(f"  Total enrichments: {total_enrichments}")
        print(f"  Estimated time: {estimated_time:.1f} minutes")

        proceed = input(f"\nProceed? (y/n): ").strip().lower()
        if proceed != 'y':
            return

        # Process
        print(f"\nStarting enrichment with timeout protection...")

        start_time = time.time()
        enriched_df, logs = agent.enrich_dataframe(df)
        total_time = time.time() - start_time

        # Save results
        # Prefix only the file name so an input such as "data/x.csv" is
        # written to "data/enriched_x.csv", not "enriched_data/x.csv".
        source_path = Path(filename)
        output_csv = str(source_path.with_name(f"enriched_{source_path.name}"))
        enriched_df.to_csv(output_csv, index=False)

        # Show summary
        successful = len([log for log in logs if log['status'] == 'success'])
        failed = len([log for log in logs if log['status'] == 'failed'])
        errors = len([log for log in logs if log['status'] == 'error'])
        attempted = successful + failed + errors
        # An empty CSV produces no log entries; avoid dividing by zero.
        success_rate = successful / attempted * 100 if attempted else 0.0

        print(f"\n=== FINAL RESULTS ===")
        print(f"Time: {total_time/60:.1f} minutes")
        print(f"Successful: {successful}")
        print(f"Failed: {failed}")
        print(f"Errors: {errors}")
        print(f"Success rate: {success_rate:.1f}%")
        print(f"Output saved: {output_csv}")

        # Download in Colab
        try:
            from google.colab import files as colab_files
            colab_files.download(output_csv)
        except ImportError:
            pass

    except Exception as e:
        print(f"Error: {e}")
        import traceback
        traceback.print_exc()

# Start the interactive workflow only when this module is executed directly,
# not when it is imported.
if __name__ == "__main__":
    main()