In [1]:
!pip install autogen
!pip install groq
!pip install tavily-python
!pip install scipy
!pip install ctgan
!pip install numpy

Collecting autogen
  Downloading autogen-0.8.7-py3-none-any.whl.metadata (24 kB)
Collecting pyautogen==0.8.7 (from autogen)
  Downloading pyautogen-0.8.7-py3-none-any.whl.metadata (35 kB)
Collecting asyncer==0.0.8 (from pyautogen==0.8.7->autogen)
  Downloading asyncer-0.0.8-py3-none-any.whl.metadata (6.7 kB)
Collecting diskcache (from pyautogen==0.8.7->autogen)
  Downloading diskcache-5.6.3-py3-none-any.whl.metadata (20 kB)
Collecting docker (from pyautogen==0.8.7->autogen)
  Downloading docker-7.1.0-py3-none-any.whl.metadata (3.8 kB)
Collecting python-dotenv (from pyautogen==0.8.7->autogen)
  Downloading python_dotenv-1.1.0-py3-none-any.whl.metadata (24 kB)
Collecting tiktoken (from pyautogen==0.8.7->autogen)
  Downloading tiktoken-0.9.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.7 kB)
Downloading autogen-0.8.7-py3-none-any.whl (13 kB)
Downloading pyautogen-0.8.7-py3-none-any.whl (739 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m739.3/7

In [3]:
from IPython import get_ipython
from IPython.display import display
import autogen
from autogen import Agent, UserProxyAgent, AssistantAgent
import pandas as pd
import numpy as np
import re
import requests
import json
from typing import Dict, List, Any, Tuple
import time
from tavily import TavilyClient

class TavilySearch:
    def __init__(self, api_key):
        self.api_key = api_key
        self.cache = {}  # Initialize the cache

    def search(self, query, search_depth="basic", include_domains=None, exclude_domains=None, max_results=5):
        # Check if the query is in the cache
        cache_key = (query, search_depth, include_domains, exclude_domains, max_results)
        if cache_key in self.cache:
            return self.cache[cache_key]  # Return cached result

        # Create the Tavily client
        client = TavilyClient(api_key=self.api_key)

        # Prepare additional parameters if provided
        kwargs = {"search_depth": search_depth, "max_results": max_results}
        if include_domains:
            kwargs["include_domains"] = include_domains
        if exclude_domains:
            kwargs["exclude_domains"] = exclude_domains

        try:
            # Make the API request and cache the result
            result = client.search(query=query, **kwargs)
            self.cache[cache_key] = result
            return result
        except Exception as e:
            print(f"Error making Tavily API request: {e}")
            return {"results": []}

# Default configuration for the LLM
config_list = [
    {
        "model": "deepseek-r1-distill-llama-70b",
        "api_key": "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",
        "api_type": "groq"
    }
]

class KnowledgeAgent:
    """
    Agent 1: Knowledge Agent -
    Attribute extraction
    Data lookup for Data Distributions
    Feature relations
    Generating Dataset Schema.
    """
    def __init__(self, config_list=config_list, tavily_api_key="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"):
        self.config_list = config_list
        self.tavily = TavilySearch(api_key=tavily_api_key)
        self.setup_agents()

    def setup_agents(self):
        self.attribute_extractor = AssistantAgent(
            name="AttributeExtractor",
            llm_config={"config_list": self.config_list},
            system_message="""You are an expert at extracting data attributes from prompts.
            Given a prompt about a dataset, identify all potential attributes/features, their likely data types,
            and any constraints mentioned. Format your response as a JSON object where each key is the attribute name,
            and the value is an object with "type" and "constraints".
            Example output:
            {
              "attributename": {
                "type": "int/float etc..",
                "constraints": { "min": 0, "max": 120 } etc..
              },
            }
            """
        )

        self.data_lookup = AssistantAgent(
            name="DataLookup",
            llm_config={"config_list": self.config_list},
            system_message="""You are an expert at determining appropriate distributions and data characteristics
            for dataset attributes. Given a list of attributes and information retrieved from online searches,
            determine the most appropriate distribution type, realistic value ranges, and common constraints
            for each attribute. Format your response as a JSON object."""
        )

        self.dataset_structure = AssistantAgent(
            name="DatasetStructure",
            llm_config={"config_list": self.config_list},
            system_message="""You are an expert at finalizing dataset structures. Given attributes, distributions,
            and constraints, produce a comprehensive dataset schema. Include data types, value ranges,
            distributions, and any rules. Format your response as a JSON schema that could be used to generate synthetic data."""
        )

        self.user_proxy = UserProxyAgent(
            name="User",
            human_input_mode="NEVER",
            max_consecutive_auto_reply=0
        )

    def extract_attributes(self, prompt: str) -> Dict:
      self.user_proxy.initiate_chat(
          self.attribute_extractor,
          message=f"Extract the attributes from this prompt: {prompt}"
      )
      response = self.attribute_extractor.last_message()["content"]

      # Parse the JSON response
      try:
          attributes = self._extract_json_from_text(response)

          # Validate that each attribute is a dictionary with proper structure
          for attr_name, attr_info in list(attributes.items()):
              if not isinstance(attr_info, dict):
                  # Convert string values to a proper dictionary structure
                  attributes[attr_name] = {
                      "type": "string",
                      "description": str(attr_info)
                  }

          return attributes
      except Exception as e:
          print(f"Error parsing attributes: {e}")
          return {}

    def online_lookup(self, attribute_name: str, data_type: str) -> Dict:
        search_query = f"common distribution for {attribute_name} in datasets {data_type} data"
        results = self.tavily.search(search_query, search_depth="advanced")

        # Extract relevant information from search results
        relevant_info = ""
        for result in results.get("results", []):
            relevant_info += f"\nTitle: {result.get('title', '')}\n"
            relevant_info += f"Content: {result.get('content', '')}\n"

        # If no results were found, use a more general query
        if not relevant_info:
            general_query = f"typical values and distribution for {data_type} data in machine learning"
            general_results = self.tavily.search(general_query)
            for result in general_results.get("results", []):
                relevant_info += f"\nTitle: {result.get('title', '')}\n"
                relevant_info += f"Content: {result.get('content', '')}\n"

        return {
            "attribute": attribute_name,
            "data_type": data_type,
            "search_results": relevant_info
        }

    def split_into_chunks(self, text: str, chunk_size: int = 500) -> List[str]:
        """Split the input text into smaller chunks."""
        tokens = text.split()
        return [" ".join(tokens[i:i + chunk_size]) for i in range(0, len(tokens), chunk_size)]

    def lookup_distributions(self, attributes: Dict) -> Dict:
      online_info = []  # Initialize online_info here
      for attr_name, attr_info in attributes.items():
          # Ensure attr_info is a dictionary
          if not isinstance(attr_info, dict):
              attr_info = {
                  "type": "string",
                  "description": str(attr_info)
              }
              attributes[attr_name] = attr_info

          data_type = attr_info.get("type", "string").lower()
          print(f"Looking up information for {attr_name} ({data_type})...")

          info = self.online_lookup(attr_name, data_type)
          online_info.append(info)
          time.sleep(1)  # delay to avoid rate limiting

          search_results_text = json.dumps(online_info, indent=2)
          attributes_text = json.dumps(attributes, indent=2)

          # Split inputs into smaller chunks
          search_results_chunks = self.split_into_chunks(search_results_text, chunk_size=500)
          attributes_chunks = self.split_into_chunks(attributes_text, chunk_size=500)

          distributions = {}
          for chunk in search_results_chunks:
              self.user_proxy.initiate_chat(
                  self.data_lookup,
                  message=f"""Based on this chunk of search results:

                  {chunk}

                  And these attribute definitions:

                  {attributes_text}

                  Determine appropriate distributions and constraints for each attribute.
                  """
              )
              response = self.data_lookup.last_message()["content"]
              try:
                  chunk_distributions = self._extract_json_from_text(response)
                  distributions.update(chunk_distributions)
              except Exception as e:
                  print(f"Error parsing distributions for chunk: {e}")

          return distributions

    def finalize_dataset_structure(self, attributes: Dict, distributions: Dict) -> Dict:
        combined_info = {
            "attributes": attributes,
            "distributions": distributions
        }

        self.user_proxy.initiate_chat(
            self.dataset_structure,
            message=f"Create a final dataset structure based on this information: {json.dumps(combined_info, indent=2)}"
        )
        response = self.dataset_structure.last_message()["content"]

        try:
            # First, try to extract a full schema with properties
            # Look for a schema that contains a "properties" field which would include all attributes
            full_schema_pattern = r'\{(?:[^{}]|(?:\{[^{}]*\}))*"properties"\s*:\s*\{(?:[^{}]|(?:\{[^{}]*\}))*\}(?:[^{}]|(?:\{[^{}]*\}))*\}'
            full_schema_matches = re.findall(full_schema_pattern, response, re.DOTALL)

            if full_schema_matches:
                for match in full_schema_matches:
                    try:
                        schema = json.loads(match)
                        # Verify it has properties and they include multiple attributes
                        if isinstance(schema, dict) and "properties" in schema and len(schema["properties"]) > 1:
                            print(f"Successfully extracted complete schema with {len(schema['properties'])} properties")
                            return schema
                    except json.JSONDecodeError:
                        continue

            # If we couldn't find a full schema, try extracting from code blocks
            code_blocks = re.findall(r'```(?:json)?\s*([\s\S]*?)```', response)
            for block in code_blocks:
                try:
                    schema = json.loads(block)
                    if isinstance(schema, dict) and "properties" in schema and len(schema["properties"]) > 1:
                        print(f"Successfully extracted complete schema from code block with {len(schema['properties'])} properties")
                        return schema
                except json.JSONDecodeError:
                    continue

            # As a last resort, use the regular _extract_json_from_text method
            return self._extract_json_from_text(response)

        except Exception as e:
            print(f"Error parsing dataset structure: {e}")
            print(f"Raw response from DatasetStructure agent: {response}")
            return {}

    def _extract_json_from_text(self, text: str) -> Dict:
      try:
          code_blocks = re.findall(r'```(?:json)?\s*([\s\S]*?)```', text)
          for block in code_blocks:
              data = json.loads(block)
              if all(isinstance(v, dict) for v in data.values()):
                  #if the extracted block of code is a dict
                  return data
          # Fallback: attempt to parse inline JSON
          json_pattern = r'\{(?:[^{}]|(?:\{[^{}]*\}))*\}'
          json_matches = re.findall(json_pattern, text, re.DOTALL)
          for match in json_matches:
              data = json.loads(match)
              if all(isinstance(v, dict) for v in data.values()):
                  return data
      except Exception as e:
          print("JSON extraction error:", e)
      raise ValueError("Could not extract valid attribute dictionary")


    def process_prompt(self, prompt: str) -> Dict:
        """
        Process the prompt to extract attributes, lookup distributions, and
        finalize the dataset structure.

        Returns:
            The finalized dataset structure
        """
        print("Step 1: Extracting attributes from prompt...")
        attributes = self.extract_attributes(prompt)
        print(f"Extracted {len(attributes)} attributes: {', '.join(attributes.keys())}")

        print("Step 2: Looking up distributions for attributes using Tavily...")
        distributions = self.lookup_distributions(attributes)
        print(f"Found distributions for {len(distributions)} attributes")

        print("Step 3: Finalizing dataset structure...")
        dataset_structure = self.finalize_dataset_structure(attributes, distributions)

        # Validate the dataset structure before returning
        if not dataset_structure or "properties" not in dataset_structure:
            print("Error: Invalid dataset structure generated")
            # Fallback to a basic structure if needed
            dataset_structure = {
                "$schema": "http://json-schema.org/draft-07/schema#",
                "title": "Dataset Schema",
                "description": "Schema for generating synthetic data",
                "type": "object",
                "properties": {},
                "required": []
            }

            # Add the extracted attributes to the properties
            for attr_name, attr_info in attributes.items():
                dataset_structure["properties"][attr_name] = attr_info
                dataset_structure["required"].append(attr_name)

            print("Using fallback dataset structure")

        properties = dataset_structure.get("properties", {})
        print(f"Dataset structure finalized with {len(properties)} properties")

        # Print a summary of the structure instead of the full JSON
        property_summary = {}
        for prop_name, prop_details in properties.items():
            prop_type = prop_details.get("type", "unknown")
            constraints = []
            if "minimum" in prop_details:
                constraints.append(f"min={prop_details['minimum']}")
            if "maximum" in prop_details:
                constraints.append(f"max={prop_details['maximum']}")
            if "enum" in prop_details:
                constraints.append(f"enum=[{len(prop_details['enum'])} values]")

            property_summary[prop_name] = {
                "type": prop_type,
                "constraints": ", ".join(constraints) if constraints else "none"
            }

        print("Structure Summary:")
        for prop, details in property_summary.items():
            print(f"  - {prop} ({details['type']}): {details['constraints']}")

        return dataset_structure


class SeedDatasetGenerator:
    """
    Agent 2: Seed Dataset Generator - Responsible for processing the dataset
    structure from Agent 1 and generating high quality seed samples.
    """
    def __init__(self, config_list=config_list):
        self.config_list = config_list
        self.setup_agents()

    def setup_agents(self):
        self.seed_dataset_generator = AssistantAgent(
            name="SeedDatasetGenerator",
            llm_config={"config_list": self.config_list},
            system_message="""You are an expert at generating synthetic datasets. Given a dataset schema in JSON format,
            generate a set of seed samples that conform to the schema. Ensure the samples are realistic and adhere to
            the specified data types, value ranges, and distributions. Format your response as a JSON array of objects,
            where each object represents a single data sample."""
        )

        self.user_proxy = UserProxyAgent(
            name="User",
            human_input_mode="NEVER",
            max_consecutive_auto_reply=0
        )

    def generate_seed_dataset(self, dataset_schema: Dict, num_samples: int = 10) -> List[Dict]:
        """
        Generate seed samples based on the provided dataset schema.

        Args:
            dataset_schema: The dataset structure from the Knowledge Agent
            num_samples: Number of samples to generate

        Returns:
            List of seed samples
        """
        # First, validate the dataset_schema
        if not isinstance(dataset_schema, dict):
            print(f"Error: dataset_schema is not a dictionary. Type: {type(dataset_schema)}")
            return []

        # Check if we have a proper schema with properties
        if "properties" not in dataset_schema:
            print("Error: dataset_schema does not contain 'properties' key")
            print(f"Schema keys: {list(dataset_schema.keys())}")

            # If we got a single attribute instead of the full schema, build a proper schema
            if "type" in dataset_schema and "description" in dataset_schema:
                print("Detected single attribute instead of full schema. Rebuilding schema...")
                attribute_name = next((k for k in dataset_schema.keys() if k not in
                                     ["type", "description", "distribution", "distribution_parameters", "enum", "minimum", "maximum"]), "attribute")
                properties = {attribute_name: dataset_schema}
                dataset_schema = {
                    "type": "object",
                    "properties": properties,
                    "required": [attribute_name]
                }
                print(f"Rebuilt schema with property: {attribute_name}")
            else:
                # If it doesn't look like an attribute either, create an error message
                print("Cannot create seed dataset: invalid schema format")
                return []

        # Log information about the schema
        properties = dataset_schema.get("properties", {})
        print(f"Generating seed data for schema with {len(properties)} properties: {', '.join(properties.keys())}")

        # Create a more specific prompt that emphasizes the full schema structure
        prompt = f"""Generate exactly {num_samples} seed samples based on this dataset schema:
        {json.dumps(dataset_schema, indent=2)}

        Each sample MUST include ALL of these attributes: {', '.join(properties.keys())}

        Each sample must be a complete JSON object containing values for ALL attributes defined in the schema.
        Make sure to follow all the specified data types, value ranges, and distributions.

        Format your response as a JSON array of objects, like this:
        [
          {{
            "attribute1": value1,
            "attribute2": value2,
            ...
          }},
          ...
        ]
        """

        self.user_proxy.initiate_chat(
            self.seed_dataset_generator,
            message=prompt
        )
        response = self.seed_dataset_generator.last_message()["content"]

        try:
            # First try to extract a JSON array from the response
            array_pattern = r'\[\s*\{.*?\}\s*(?:,\s*\{.*?\}\s*)*\]'
            array_matches = re.findall(array_pattern, response, re.DOTALL)

            if array_matches:
                samples = json.loads(array_matches[0])
                print(f"Successfully extracted {len(samples)} samples")
                # Validate that samples contain all required properties
                required_props = set(properties.keys())
                for i, sample in enumerate(samples):
                    missing_props = required_props - set(sample.keys())
                    if missing_props:
                        print(f"Warning: Sample {i+1} is missing properties: {missing_props}")
                return samples

            # If no array is found, try to extract from code blocks
            code_blocks = re.findall(r'```(?:json)?\s*([\s\S]*?)```', response)
            for block in code_blocks:
                try:
                    if block.strip().startswith('[') and block.strip().endswith(']'):
                        samples = json.loads(block)
                        if isinstance(samples, list) and len(samples) > 0:
                            print(f"Successfully extracted {len(samples)} samples from code block")
                            return samples
                except json.JSONDecodeError:
                    continue

            # As a last resort, try to extract individual objects
            json_pattern = r'\{(?:[^{}]|(?:\{[^{}]*\}))*\}'
            json_matches = re.findall(json_pattern, response, re.DOTALL)

            if json_matches:
                samples = []
                for match in json_matches:
                    try:
                        sample = json.loads(match)
                        if isinstance(sample, dict) and len(sample) > 1:
                            samples.append(sample)
                    except json.JSONDecodeError:
                        continue

                if samples:
                    print(f"Extracted {len(samples)} individual samples")
                    return samples

            raise ValueError("No valid JSON array or objects found in the response")
        except Exception as e:
            print(f"Error parsing seed dataset: {e}")
            print(f"Raw response from SeedDatasetGenerator: {response}")
            return []

class CTGANSynthesizer:
    """
    Agent 3: CTGAN Synthesizer - Uses CTGAN to generate high-quality synthetic data
    based on the seed dataset from Agent 2.
    """
    def __init__(self, config_list=config_list):
        self.config_list = config_list

    def synthesize_data(self, seed_dataset, num_samples=500):
        """
        Generate synthetic data using CTGAN based on the seed dataset.

        Args:
            seed_dataset: List of dictionaries containing the seed samples
            num_samples: Number of synthetic samples to generate

        Returns:
            List of dictionaries containing the synthetic samples
        """
        try:
            # Import necessary libraries with correct import names
            import pandas as pd
            print("pandas imported successfully")

            # Fixed import: CTGAN instead of CTGANSynthesizer
            from ctgan import CTGAN
            print("CTGAN imported successfully")

            # Convert seed dataset to pandas DataFrame
            seed_df = pd.DataFrame(seed_dataset)

            # Identify categorical and continuous columns
            categorical_columns = []
            continuous_columns = []

            for col in seed_df.columns:
                # Check if column contains string values
                if seed_df[col].dtype == 'object' or pd.api.types.is_categorical_dtype(seed_df[col]):
                    categorical_columns.append(col)
                else:
                    continuous_columns.append(col)

            print(f"Categorical columns: {categorical_columns}")
            print(f"Continuous columns: {continuous_columns}")

            # Initialize and train CTGAN
            print("Initializing CTGAN model...")
            ctgan = CTGAN(epochs=100, batch_size=100)  # Changed from CTGANSynthesizer to CTGAN
            print("CTGAN model initialized successfully.")

            # Handle categorical columns properly
            for col in categorical_columns:
                seed_df[col] = seed_df[col].astype(str)

            print(f"Training CTGAN on seed dataset with {len(seed_df)} samples...")
            ctgan.fit(seed_df, categorical_columns)
            print("CTGAN training completed successfully.")

            # Generate synthetic samples
            print(f"Generating {num_samples} synthetic samples...")
            synthetic_data = ctgan.sample(num_samples)

            # Convert DataFrame back to list of dictionaries
            synthetic_samples = synthetic_data.to_dict('records')

            # Post-process: ensure data types match the original
            for i, sample in enumerate(synthetic_samples):
                for col in categorical_columns:
                    # Keep categorical values as strings
                    synthetic_samples[i][col] = str(sample[col])

                for col in continuous_columns:
                    # Convert continuous values to appropriate numeric type
                    if col in seed_dataset[0] and isinstance(seed_dataset[0][col], int):
                        synthetic_samples[i][col] = int(sample[col])
                    else:
                        synthetic_samples[i][col] = float(sample[col])

            print(f"Successfully generated {len(synthetic_samples)} synthetic samples using CTGAN")
            return synthetic_samples

        except ImportError as e:
            print(f"Import error: {str(e)}")
            print("This could be due to missing or incompatible packages.")
            print("Please ensure you have the latest version of CTGAN installed.")
            print("Attempting to install the required packages...")

            # Try to install or update CTGAN
            import sys
            import subprocess
            try:
                subprocess.check_call([sys.executable, "-m", "pip", "install", "--upgrade", "ctgan"])
                print("CTGAN upgraded successfully. Please restart the kernel and try again.")
            except:
                print("Failed to upgrade CTGAN automatically.")

            # Raise the original error
            raise

        except Exception as e:
            print(f"Error in CTGAN synthesis: {str(e)}")
            raise

class DataValidator:
    """
    Agent 4: Data Validator - Validates and corrects the synthetic data to ensure
    it adheres to the dataset structure constraints (min/max values, data types, etc.)
    """
    def __init__(self, config_list=config_list):
        self.config_list = config_list
        self.setup_agents()

    def setup_agents(self):
        self.data_validator = AssistantAgent(
            name="DataValidator",
            llm_config={"config_list": self.config_list},
            system_message="""You are an expert at validating and correcting synthetic datasets.
            Given a dataset schema and synthetic data samples, identify values that violate constraints
            like min/max values, allowed categories, or data types. Propose specific corrections to make
            the data valid while maintaining statistical properties."""
        )

        self.user_proxy = UserProxyAgent(
            name="User",
            human_input_mode="NEVER",
            max_consecutive_auto_reply=0
        )

    def analyze_schema_constraints(self, properties):
        """Analyze and print constraints in the schema for debugging"""
        print("\n=== Schema Constraint Analysis ===")
        for prop_name, prop_details in properties.items():
            constraints = []
            prop_type = prop_details.get("type", "unknown").lower()

            # Check for min/max constraints
            if "minimum" in prop_details:
                constraints.append(f"min={prop_details['minimum']}")
            if "maximum" in prop_details:
                constraints.append(f"max={prop_details['maximum']}")

            # Check for enum constraints
            if "enum" in prop_details:
                enum_values = prop_details["enum"]
                if len(enum_values) <= 5:
                    constraints.append(f"enum={enum_values}")
                else:
                    constraints.append(f"enum=[{len(enum_values)} values]")

            # Check for format constraints
            if "format" in prop_details:
                constraints.append(f"format={prop_details['format']}")

            # Check for pattern constraints
            if "pattern" in prop_details:
                constraints.append(f"pattern='{prop_details['pattern']}'")

            print(f"  {prop_name} ({prop_type}): {', '.join(constraints) if constraints else 'no explicit constraints'}")
        print("")

    def validate_data(self, dataset_structure, synthetic_data):
      """
      Validate and filter the synthetic data based on the schema constraints.

      Args:
          dataset_structure: The dataset structure/schema from the Knowledge Agent
          synthetic_data: List of dictionaries containing the synthetic samples

      Returns:
          List of dictionaries containing the validated synthetic samples
      """
      import pandas as pd
      import json

      print("Starting data validation process...")

      # Extract schema properties
      properties = dataset_structure.get("properties", {})
      if not properties:
          print("Warning: No properties found in schema. Validation may be limited.")

      # Convert to DataFrame for easier processing
      df = pd.DataFrame(synthetic_data)

      # Keep track of how many records were filtered out
      original_count = len(df)
      invalid_records = 0

      # Process each property based on its constraints in the schema
      for prop_name, prop_details in properties.items():
          if prop_name not in df.columns:
              print(f"Warning: Property '{prop_name}' not found in the dataset")
              continue

          print(f"Validating property: {prop_name}")
          prop_type = prop_details.get("type", "string").lower()

          # Convert data types as needed
          if prop_type == "integer":
              # Convert to integers, with NaN values becoming null
              df[prop_name] = pd.to_numeric(df[prop_name], errors='coerce')
              # Filter out records with invalid values
              invalid = df[prop_name].isna()
              if any(invalid):
                  print(f"  Removing {sum(invalid)} records with invalid integer values for {prop_name}")
                  df = df[~invalid]
                  invalid_records += sum(invalid)
              # Convert remaining values to integers
              df[prop_name] = df[prop_name].astype('Int64')  # Use nullable integer type

          elif prop_type == "number":
              # Convert to floats, with NaN values becoming null
              df[prop_name] = pd.to_numeric(df[prop_name], errors='coerce')
              # Filter out records with invalid values
              invalid = df[prop_name].isna()
              if any(invalid):
                  print(f"  Removing {sum(invalid)} records with invalid numeric values for {prop_name}")
                  df = df[~invalid]
                  invalid_records += sum(invalid)

          # Apply minimum constraint if specified
          if "minimum" in prop_details and prop_type in ["integer", "number"]:
              min_value = prop_details["minimum"]
              below_min = df[prop_name] < min_value
              if any(below_min):
                  print(f"  Removing {sum(below_min)} records with {prop_name} below minimum ({min_value})")
                  df = df[~below_min]
                  invalid_records += sum(below_min)

          # Apply maximum constraint if specified
          if "maximum" in prop_details and prop_type in ["integer", "number"]:
              max_value = prop_details["maximum"]
              above_max = df[prop_name] > max_value
              if any(above_max):
                  print(f"  Removing {sum(above_max)} records with {prop_name} above maximum ({max_value})")
                  df = df[~above_max]
                  invalid_records += sum(above_max)

          # Apply enum constraint if specified
          if "enum" in prop_details:
              allowed_values = set(prop_details["enum"])
              invalid_values = ~df[prop_name].isin(allowed_values)
              if any(invalid_values):
                  print(f"  Removing {sum(invalid_values)} records with invalid categorical values for {prop_name}")
                  print(f"  Allowed values: {allowed_values}")
                  df = df[~invalid_values]
                  invalid_records += sum(invalid_values)

          # Handle boolean constraints
          if prop_type == "boolean":
              # Check if values are valid booleans
              valid_booleans = df[prop_name].isin([True, False, 0, 1, "True", "False", "true", "false"])
              if not all(valid_booleans):
                  print(f"  Removing {sum(~valid_booleans)} records with invalid boolean values for {prop_name}")
                  df = df[valid_booleans]
                  invalid_records += sum(~valid_booleans)

          # Apply pattern constraint if specified (for strings)
          if "pattern" in prop_details and prop_type == "string":
              import re
              pattern = prop_details["pattern"]
              # Create a regex pattern matcher
              matcher = re.compile(pattern)
              # Check each value against the pattern
              invalid_pattern = df[prop_name].apply(lambda x: not bool(matcher.match(str(x))))
              if any(invalid_pattern):
                  print(f"  Removing {sum(invalid_pattern)} records with values not matching pattern for {prop_name}")
                  df = df[~invalid_pattern]
                  invalid_records += sum(invalid_pattern)

      # Print validation summary
      final_count = len(df)
      if invalid_records > 0:
          print(f"\nValidation complete: {invalid_records} invalid records removed.")
          print(f"Started with {original_count} records, {final_count} valid records remain.")
      else:
          print("\nValidation complete: All records passed validation.")

      # Convert back to list of dictionaries
      validated_data = df.to_dict('records')
      return validated_data

def main():
    # Initialize the Knowledge Agent (Agent 1)
    knowledge_agent = KnowledgeAgent(
        config_list=config_list,
        tavily_api_key="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"
    )

    # Initialize the Seed Dataset Generator (Agent 2)
    seed_generator = SeedDatasetGenerator(
        config_list=config_list
    )

    # Initialize the CTGAN Synthesizer (Agent 3)
    ctgan_synthesizer = CTGANSynthesizer(config_list=config_list)

    # Initialize the Data Validator (Agent 4)
    data_validator = DataValidator(config_list=config_list)

    prompt = """
    I need a dataset for a machine learning project on customer churn prediction.
    The dataset should include customer demographics (age, gender, location),
    subscription details (plan_type, monthly_charges, total_charges),
    usage behavior (data_usage, call_minutes, service_calls),
    and customer history (tenure, payment_method, paperless Billing).
    """

    # Step 1: Generate dataset structure using Knowledge Agent
    print("=== Agent 1: Knowledge Agent - Generating Dataset Structure ===")
    dataset_structure = knowledge_agent.process_prompt(prompt)

    print("\nFinal Dataset Structure:")
    print(json.dumps(dataset_structure, indent=2))

    # Step 2: Generate seed dataset using Seed Dataset Generator
    print("\n=== Agent 2: Seed Dataset Generator - Generating Seed Samples ===")
    seed_dataset = seed_generator.generate_seed_dataset(dataset_structure, num_samples=30)

    print("\nSeed Dataset Samples:")
    print(json.dumps(seed_dataset[:5], indent=2))  # Print just 5 samples to avoid clutter
    print(f"Total seed samples: {len(seed_dataset)}")

    # Step 3: Synthesize additional data using CTGAN
    print("\n=== Agent 3: CTGAN Synthesizer - Generating Synthetic Data ===")
    synthetic_data = ctgan_synthesizer.synthesize_data(seed_dataset, num_samples=500)

    print("\nSynthetic Data Samples:")
    print(json.dumps(synthetic_data[:5], indent=2))  # Print just 5 samples
    print(f"Total synthetic samples: {len(synthetic_data)}")

    # Step 4: Validate and correct the synthetic data
    print("\n=== Agent 4: Data Validator - Validating and Correcting Data ===")
    combined_data = seed_dataset + synthetic_data
    validated_data = data_validator.validate_data(dataset_structure, combined_data)

    print("\nValidated Data Samples:")
    print(json.dumps(validated_data[:5], indent=2))  # Print just 5 samples
    print(f"Total validated samples: {len(validated_data)}")

    # Save the final dataset to a CSV file
    try:
        import pandas as pd
        df = pd.DataFrame(validated_data)
        df.to_csv("synthetic_churn_dataset.csv", index=False)
        print("Dataset saved to 'synthetic_churn_dataset.csv'")

        # Generate some basic statistics
        print("\nDataset Statistics:")
        for column in df.columns:
            if df[column].dtype in ['int64', 'float64']:
                print(f"{column}: min={df[column].min()}, max={df[column].max()}, mean={df[column].mean():.2f}")
            else:
                value_counts = df[column].value_counts()
                top_values = value_counts.head(3).to_dict()
                print(f"{column}: {len(value_counts)} unique values. Top 3: {top_values}")

    except Exception as e:
        print(f"Error saving dataset to CSV: {e}")

    # Return the final dataset for further use
    return validated_data

if __name__ == "__main__":
    main()


=== Agent 1: Knowledge Agent - Generating Dataset Structure ===
Step 1: Extracting attributes from prompt...
User (to AttributeExtractor):

Extract the attributes from this prompt: 
    I need a dataset for a machine learning project on customer churn prediction.
    The dataset should include customer demographics (age, gender, location),
    subscription details (plan_type, monthly_charges, total_charges),
    usage behavior (data_usage, call_minutes, service_calls),
    and customer history (tenure, payment_method, paperless Billing).
    

--------------------------------------------------------------------------------
AttributeExtractor (to User):

<think>
Alright, so I'm trying to figure out how to extract the attributes from the given prompt about a customer churn prediction dataset. Let me break it down step by step. 

First, the prompt mentions several categories: demographics, subscription details, usage behavior, and customer history. I need to go through each category and l



Extracted 12 attributes: age, gender, location, plan_type, monthly_charges, total_charges, data_usage, call_minutes, service_calls, tenure, payment_method, paperless_billing
Step 2: Looking up distributions for attributes using Tavily...
Looking up information for age (int)...
User (to DataLookup):

Based on this chunk of search results:

                  [ { "attribute": "age", "data_type": "int", "search_results": "\nTitle: Analyse the distribution of ages \u2014 Python Data Analysis series part 1\nContent: The values in the column are represented as float, so if we go row by row, we can compare the integer representation with the current numerical representation. If they are the same, then the number is an integer (e.g. 23.0 or 45.0); otherwise the age was a floating-point number and we discard it (e.g. 15.5 is different from 15). [...] The next step is to remove floating-point ages (lines 21 and 22). Generally, ages are recorded as integer values, so we\u2019ll simply remove those

  if seed_df[col].dtype == 'object' or pd.api.types.is_categorical_dtype(seed_df[col]):


CTGAN training completed successfully.
Generating 500 synthetic samples...
Successfully generated 500 synthetic samples using CTGAN

Synthetic Data Samples:
[
  {
    "age": 73,
    "gender": "Other",
    "location": "Suburban",
    "plan_type": "Premium",
    "monthly_charges": 53.86443330725312,
    "total_charges": 300.37970119906925,
    "data_usage": 763,
    "call_minutes": 501,
    "service_calls": 8,
    "tenure": 37,
    "payment_method": "Cash",
    "paperless_billing": 1
  },
  {
    "age": 86,
    "gender": "Male",
    "location": "Rural",
    "plan_type": "Pro",
    "monthly_charges": 55.99713853505112,
    "total_charges": 380.8177996268621,
    "data_usage": 295,
    "call_minutes": 703,
    "service_calls": 8,
    "tenure": 37,
    "payment_method": "Bank Transfer",
    "paperless_billing": 1
  },
  {
    "age": 61,
    "gender": "Other",
    "location": "Rural",
    "plan_type": "Premium",
    "monthly_charges": 14.508036047466582,
    "total_charges": 417.400073788352