In [1]:
# Install required packages with specific versions for compatibility
!pip install sdv>=1.0.0 pandas numpy networkx scikit-learn matplotlib seaborn plotly
!pip install faker  # For realistic fake data generation

import warnings
warnings.filterwarnings('ignore')

# Set pandas display options for better readability
import pandas as pd
pd.set_option('display.max_columns', None)
pd.set_option('display.width', None)
pd.set_option('display.max_colwidth', 50)

print("✓ All dependencies installed successfully")
print("✓ Environment configured for optimal display")

✓ All dependencies installed successfully
✓ Environment configured for optimal display


In [2]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.express as px
import plotly.graph_objects as go
from datetime import datetime, timedelta
import random
from faker import Faker

# SDV imports for multi-table synthesis
from sdv.metadata import Metadata
from sdv.multi_table import HMASynthesizer

# SDV evaluation imports - CRITICAL for quality assessment
from sdv.evaluation.multi_table import run_diagnostic, evaluate_quality

# Additional SDV utilities
from sdv.datasets.demo import download_demo

# NetworkX for relationship analysis
import networkx as nx
from typing import Dict, List, Tuple, Optional, Any

# Initialize Faker for realistic data generation
fake = Faker(['en_US', 'en_GB', 'es_ES', 'fr_FR'])  # Multi-locale support
Faker.seed(42)  # Reproducible fake data
np.random.seed(42)  # Reproducible random numbers

print("✓ All libraries imported successfully")
print(f"✓ Faker initialized with seed 42 for reproducible results")
print(f"✓ NumPy random seed set to 42")

✓ All libraries imported successfully
✓ Faker initialized with seed 42 for reproducible results
✓ NumPy random seed set to 42


In [3]:
class RecursiveMultiTableSynthesizer:
    """
    FIXED: Advanced multi-table synthetic data generator with automatic relationship detection,
    comprehensive evaluation metrics, and support for complex enterprise data structures.

    Features:
    - Proper metadata initialization (FIXES AttributeError)
    - Support for 20+ interconnected tables
    - Recursive and self-referencing relationships
    - Comprehensive quality evaluation
    - Advanced visualization capabilities
    """

    def __init__(self, synthesizer_type='gaussian_copula'):
        """
        FIXED: Initialize the enhanced recursive multi-table synthesizer with proper metadata.

        Args:
            synthesizer_type (str): Type of synthesizer ('gaussian_copula' or 'ctgan')
        """
        self.synthesizer_type = synthesizer_type
        self.metadata = Metadata()  # FIXED: Properly initialize Metadata object instead of None
        self.synthesizer = None
        self.real_data = {}
        self.synthetic_data = {}

        # Enhanced tracking
        self.table_relationships = {}
        self.dependency_graph = nx.DiGraph()
        self.table_stats = {}
        self.generation_stats = {}

        # Evaluation results storage
        self.diagnostic_report = None
        self.quality_report = None
        self.evaluation_scores = {}

        print(f"Enhanced RecursiveMultiTableSynthesizer initialized with {synthesizer_type}")
        print("Ready for automatic relationship detection and quality evaluation")

    def add_table_data(self, table_name: str, data: pd.DataFrame,
                      primary_key: Optional[str] = None):
        """
        FIXED: Add a single table to the multi-table structure with proper error handling.

        Args:
            table_name (str): Name of the table
            data (pd.DataFrame): The actual data
            primary_key (str): Primary key column name
        """
        try:
            self.real_data[table_name] = data.copy()

            # Add table to metadata using the correct method
            self.metadata.detect_table_from_dataframe(table_name=table_name, data=data)

            # Set primary key if provided
            if primary_key and primary_key in data.columns:
                try:
                    # Mark the PK column as an id sdtype
                    self.metadata.update_column(
                        table_name=table_name,
                        column_name=primary_key,
                        sdtype='id'
                    )

                    # Set the PK
                    self.metadata.set_primary_key(
                        table_name=table_name,
                        column_name=primary_key
                    )
                    print(f"Added table '{table_name}' with primary key '{primary_key}'")
                except Exception as e:
                    print(f"Warning setting primary key for {table_name}: {e}")
            else:
                print(f"Added table '{table_name}' (no primary key specified)")

        except Exception as e:
            print(f"Error adding table {table_name}: {e}")
            raise

    def add_tables_from_dict(self, data_dict: Dict[str, pd.DataFrame],
                           primary_keys: Optional[Dict[str, str]] = None):
        """
        FIXED: Add multiple tables at once with proper metadata handling.

        Args:
            data_dict: Dictionary of {table_name: DataFrame}
            primary_keys: Dictionary of {table_name: primary_key_column}
        """
        print(f"\n=== ADDING {len(data_dict)} TABLES WITH METADATA DETECTION ===")

        # Store all data
        self.real_data = data_dict.copy()

        # Add each table individually with proper primary key handling
        print("Adding tables with individual metadata detection...")
        for table_name, df in data_dict.items():
            pk = primary_keys.get(table_name) if primary_keys else None
            self.add_table_data(table_name, df, pk)

        # Generate comprehensive statistics
        self._generate_table_statistics()

        print(f"Successfully added {len(data_dict)} tables with metadata")

    def _generate_table_statistics(self):
        """Generate comprehensive statistics for all tables."""
        print("\nGenerating comprehensive table statistics...")

        for table_name, df in self.real_data.items():
            stats = {
                'rows': len(df),
                'columns': len(df.columns),
                'numeric_cols': len(df.select_dtypes(include=[np.number]).columns),
                'categorical_cols': len(df.select_dtypes(include=['object', 'category']).columns),
                'datetime_cols': len(df.select_dtypes(include=['datetime64']).columns),
                'missing_values': df.isnull().sum().sum(),
                'memory_usage_mb': df.memory_usage(deep=True).sum() / 1024 / 1024
            }
            self.table_stats[table_name] = stats

        total_rows = sum(stats['rows'] for stats in self.table_stats.values())
        total_memory = sum(stats['memory_usage_mb'] for stats in self.table_stats.values())

        print(f"   Total rows across all tables: {total_rows:,}")
        print(f"   Total memory usage: {total_memory:.2f} MB")
        print("Table statistics generated successfully")

print("✓ Enhanced RecursiveMultiTableSynthesizer class (Part 1) defined")

✓ Enhanced RecursiveMultiTableSynthesizer class (Part 1) defined


In [4]:
def add_custom_relationship(self, parent_table: str, parent_column: str,
                          child_table: str, child_column: str):
    """
    Manually add a relationship that wasn't auto-detected.

    Args:
        parent_table: Name of the parent table
        parent_column: Primary key column in parent table
        child_table: Name of the child table
        child_column: Foreign key column in child table
    """
    try:
        # Validate tables and columns exist
        if parent_table not in self.real_data:
            raise ValueError(f"Parent table '{parent_table}' not found")
        if child_table not in self.real_data:
            raise ValueError(f"Child table '{child_table}' not found")
        if parent_column not in self.real_data[parent_table].columns:
            raise ValueError(f"Column '{parent_column}' not found in '{parent_table}'")
        if child_column not in self.real_data[child_table].columns:
            raise ValueError(f"Column '{child_column}' not found in '{child_table}'")

        print("====> parent_table_name :  ", parent_table)
        print("====> parent_primary_key :  ", parent_column)
        print("====> child_table_name :  ", child_table)
        print("====> child_foreign_key :  ", child_column)

        # Update metadata - mark foreign key column as 'id' type
        self.metadata.update_column(child_table, child_column, sdtype='id')

        #====> parent_table_name :   companies
        #====> parent_primary_key :   company_id
        #====> child_table_name :   departments
        #====> child_foreign_key :   company_id
        #❌ Error adding relationship: Unknown table name ('company_id').

        #mData = synthesizer.metadata

        #mData.add_relationship(
        #parent_table_name='companies',
        #child_table_name='departments',
        #parent_primary_key='company_id',
        #child_foreign_key='company_id'
        #)

        # Add relationship to metadata
        self.metadata.add_relationship(
            parent_table_name=parent_table,
            parent_primary_key=parent_column,
            child_table_name=child_table,
            child_foreign_key=child_column
        )

        # Update dependency graph
        self.dependency_graph.add_edge(parent_table, child_table)

        # Store relationship info for tracking
        if child_table not in self.table_relationships:
            self.table_relationships[child_table] = {}
        self.table_relationships[child_table][child_column] = parent_table

        print(f"✅ Added custom relationship: {parent_table}.{parent_column} → {child_table}.{child_column}")

    except Exception as e:
        print(f"❌ Error adding relationship: {e}")

def analyze_relationships(self):
    """
    Analyze and visualize the relationship structure.

    Returns:
        Dict containing relationship analysis results
    """
    print("\n🔍 ANALYZING RELATIONSHIP STRUCTURE")
    print("=" * 50)

    # Get relationships from metadata
    relationships = []
    try:
        metadata_dict = self.metadata.to_dict()
        relationships = metadata_dict.get('relationships', [])

        print(f"📊 Total relationships detected: {len(relationships)}")

        # Analyze relationship types
        self_referencing = []
        hierarchical = []
        many_to_many = []

        for rel in relationships:
            parent_table = rel.get('parent_table_name')
            child_table = rel.get('child_table_name')

            if parent_table == child_table:
                self_referencing.append(rel)
            else:
                hierarchical.append(rel)

        print(f"🔄 Self-referencing relationships: {len(self_referencing)}")
        print(f"🌳 Hierarchical relationships: {len(hierarchical)}")

        # Display relationship details
        if hierarchical:
            print("\n📋 HIERARCHICAL RELATIONSHIPS:")
            for rel in hierarchical[:10]:  # Show first 10
                parent = rel.get('parent_table_name')
                child = rel.get('child_table_name')
                parent_key = rel.get('parent_primary_key')
                child_key = rel.get('child_foreign_key')
                print(f"   {parent}.{parent_key} → {child}.{child_key}")

        if self_referencing:
            print("\n🔄 SELF-REFERENCING RELATIONSHIPS:")
            for rel in self_referencing:
                table = rel.get('parent_table_name')
                parent_key = rel.get('parent_primary_key')
                child_key = rel.get('child_foreign_key')
                print(f"   {table}.{parent_key} ← {table}.{child_key}")

        return {
            'total_relationships': len(relationships),
            'hierarchical': len(hierarchical),
            'self_referencing': len(self_referencing),
            'relationship_details': relationships
        }

    except Exception as e:
        print(f"⚠️ Error analyzing relationships: {e}")
        return {'error': str(e)}

def visualize_table_dependencies(self, figsize=(15, 10)):
    """
    Create a visualization of table dependencies and relationships.

    Args:
        figsize: Figure size for the visualization
    """
    print("\n🎨 Creating relationship visualization...")

    try:
        # Create graph from relationships
        G = nx.DiGraph()

        # Add all tables as nodes
        for table_name in self.real_data.keys():
            node_size = self.table_stats[table_name]['rows']
            G.add_node(table_name, size=node_size)

        # Add edges from relationships
        metadata_dict = self.metadata.to_dict()
        relationships = metadata_dict.get('relationships', [])

        for rel in relationships:
            parent = rel.get('parent_table_name')
            child = rel.get('child_table_name')
            if parent and child:
                G.add_edge(parent, child)

        # Create visualization
        plt.figure(figsize=figsize)

        # Use spring layout for better visualization
        pos = nx.spring_layout(G, k=3, iterations=50)

        # Calculate node sizes based on table row counts
        node_sizes = [self.table_stats[node]['rows'] / 10 for node in G.nodes()]

        # Draw the graph
        nx.draw(G, pos,
                with_labels=True,
                node_color='lightblue',
                node_size=node_sizes,
                font_size=8,
                font_weight='bold',
                arrows=True,
                arrowsize=20,
                edge_color='gray',
                alpha=0.7)

        plt.title("Table Relationship Dependencies\n(Node size reflects table size)",
                 fontsize=16, fontweight='bold')
        plt.axis('off')
        plt.tight_layout()
        plt.show()

        print("✅ Dependency visualization created successfully")

    except Exception as e:
        print(f"❌ Error creating visualization: {e}")

# Add methods to the class
RecursiveMultiTableSynthesizer.add_custom_relationship = add_custom_relationship
RecursiveMultiTableSynthesizer.analyze_relationships = analyze_relationships
RecursiveMultiTableSynthesizer.visualize_table_dependencies = visualize_table_dependencies

print("✓ Enhanced relationship management methods added to class")

✓ Enhanced relationship management methods added to class


In [5]:
def train_synthesizer(self, verbose=True):
    """
    Train the multi-table synthesizer with comprehensive logging.

    Args:
        verbose: Whether to show detailed training progress
    """
    print("\n🚀 TRAINING MULTI-TABLE SYNTHESIZER")
    print("=" * 50)

    try:
        # Validate metadata first
        print("🔍 Validating metadata structure...")
        self.metadata.validate()
        print("✅ Metadata validation successful")

        # Initialize synthesizer
        print(f"🤖 Initializing {self.synthesizer_type} synthesizer...")
        self.synthesizer = HMASynthesizer(metadata=self.metadata)
        print("✅ Synthesizer initialized successfully")

        # Training phase
        print("📚 Starting training process...")
        if verbose:
            print(f"   📊 Training on {len(self.real_data)} tables")
            print(f"   📈 Total data points: {sum(len(df) for df in self.real_data.values()):,}")

        # Record training start time
        training_start = datetime.now()

        # Train the model
        self.synthesizer.fit(self.real_data)

        # Record training completion
        training_end = datetime.now()
        training_duration = training_end - training_start

        # Store training statistics
        self.generation_stats['training_duration'] = training_duration.total_seconds()
        self.generation_stats['training_start'] = training_start
        self.generation_stats['training_end'] = training_end

        print(f"✅ Training completed successfully!")
        print(f"⏱️  Training duration: {training_duration}")

        return True

    except Exception as e:
        print(f"❌ Training failed: {e}")
        print("💡 Check your data structure and relationships")
        return False

def generate_synthetic_data(self, scale: float = 1.0,
                          custom_table_sizes: Optional[Dict[str, int]] = None,
                          verbose: bool = True):
    """
    Generate synthetic data with advanced options and monitoring.

    Args:
        scale: Scaling factor for data generation (1.0 = same size as original)
        custom_table_sizes: Dictionary specifying exact sizes for specific tables
        verbose: Whether to show detailed generation progress

    Returns:
        Dictionary containing synthetic DataFrames
    """
    print(f"\n⚡ GENERATING SYNTHETIC DATA (Scale: {scale}x)")
    print("=" * 50)

    if not self.synthesizer:
        print("❌ Synthesizer not trained. Please call train_synthesizer() first.")
        return None

    try:
        generation_start = datetime.now()

        if verbose:
            print("🎯 Generation parameters:")
            if custom_table_sizes:
                print("   📋 Using custom table sizes:")
                for table, size in custom_table_sizes.items():
                    original_size = len(self.real_data.get(table, []))
                    ratio = size / original_size if original_size > 0 else 0
                    print(f"      {table}: {size:,} rows ({ratio:.2f}x original)")
            else:
                print(f"   📈 Using scale factor: {scale}x")

        # Generate synthetic data
        print("🔄 Generating synthetic tables...")

        if custom_table_sizes:
            # Convert custom sizes to scale and generate
            total_original = sum(len(df) for df in self.real_data.values())
            total_custom = sum(custom_table_sizes.get(table, len(df))
                             for table, df in self.real_data.items())
            effective_scale = total_custom / total_original if total_original > 0 else 1.0

            print(f"   📊 Effective scale from custom sizes: {effective_scale:.2f}x")
            self.synthetic_data = self.synthesizer.sample(scale=effective_scale)

            # Trim to exact requested sizes
            for table_name, requested_size in custom_table_sizes.items():
                if table_name in self.synthetic_data:
                    current_size = len(self.synthetic_data[table_name])
                    if current_size > requested_size:
                        self.synthetic_data[table_name] = self.synthetic_data[table_name].head(requested_size)
                        if verbose:
                            print(f"   ✂️  Trimmed {table_name}: {current_size} → {requested_size} rows")
        else:
            self.synthetic_data = self.synthesizer.sample(scale=scale)

        generation_end = datetime.now()
        generation_duration = generation_end - generation_start

        # Store generation statistics
        self.generation_stats['generation_duration'] = generation_duration.total_seconds()
        self.generation_stats['generation_start'] = generation_start
        self.generation_stats['generation_end'] = generation_end
        self.generation_stats['scale_used'] = scale

        # Summary statistics
        print("✅ Synthetic data generation completed!")
        print(f"⏱️  Generation duration: {generation_duration}")
        print("\n📊 GENERATION SUMMARY:")

        for table_name, synthetic_df in self.synthetic_data.items():
            original_size = len(self.real_data[table_name])
            synthetic_size = len(synthetic_df)
            actual_ratio = synthetic_size / original_size if original_size > 0 else 0

            print(f"   {table_name:20} | Original: {original_size:6,} | Synthetic: {synthetic_size:6,} | Ratio: {actual_ratio:.2f}x")

        return self.synthetic_data

    except Exception as e:
        print(f"❌ Generation failed: {e}")
        print("💡 Try reducing the scale factor or checking your trained model")
        return None

# Add methods to the class
RecursiveMultiTableSynthesizer.train_synthesizer = train_synthesizer
RecursiveMultiTableSynthesizer.generate_synthetic_data = generate_synthetic_data

print("✓ Training and generation methods added to class")

✓ Training and generation methods added to class


In [6]:
def run_diagnostic_evaluation(self, verbose: bool = True):
    """
    Run comprehensive diagnostic evaluation using SDV's diagnostic framework.

    Args:
        verbose: Whether to show detailed diagnostic results

    Returns:
        Diagnostic report with data validity and synthesis quality metrics
    """
    print("\n🔍 RUNNING DIAGNOSTIC EVALUATION")
    print("=" * 50)

    if not self.synthetic_data:
        print("❌ No synthetic data available. Generate synthetic data first.")
        return None

    try:
        print("🔬 Running SDV diagnostic evaluation...")
        diagnostic_start = datetime.now()

        # Run comprehensive diagnostic
        self.diagnostic_report = run_diagnostic(
            real_data=self.real_data,
            synthetic_data=self.synthetic_data,
            metadata=self.metadata,
            verbose=verbose
        )

        diagnostic_end = datetime.now()
        diagnostic_duration = diagnostic_end - diagnostic_start

        print(f"✅ Diagnostic evaluation completed in {diagnostic_duration}")

        # Extract key metrics
        if hasattr(self.diagnostic_report, 'get_results'):
            results = self.diagnostic_report.get_results()

            print("\n📊 DIAGNOSTIC SUMMARY:")
            print(f"   🎯 Overall Quality Score: {results.get('Quality Score', 'N/A')}")
            print(f"   🔗 Relationship Validity: {results.get('Relationship Validity', 'N/A')}")
            print(f"   📈 Data Validity: {results.get('Data Validity', 'N/A')}")

            # Store evaluation scores
            self.evaluation_scores['diagnostic'] = results

        return self.diagnostic_report

    except Exception as e:
        print(f"❌ Diagnostic evaluation failed: {e}")
        return None

def run_quality_evaluation(self, verbose: bool = True):
    """
    Run comprehensive quality evaluation using SDV's quality framework.

    Args:
        verbose: Whether to show detailed quality results

    Returns:
        Quality report with statistical similarity metrics
    """
    print("\n📊 RUNNING QUALITY EVALUATION")
    print("=" * 50)

    if not self.synthetic_data:
        print("❌ No synthetic data available. Generate synthetic data first.")
        return None

    try:
        print("📈 Running SDV quality evaluation...")
        quality_start = datetime.now()

        # Run comprehensive quality evaluation
        self.quality_report = evaluate_quality(
            real_data=self.real_data,
            synthetic_data=self.synthetic_data,
            metadata=self.metadata,
            verbose=verbose
        )

        quality_end = datetime.now()
        quality_duration = quality_end - quality_start

        print(f"✅ Quality evaluation completed in {quality_duration}")

        # Extract key metrics
        if hasattr(self.quality_report, 'get_results'):
            results = self.quality_report.get_results()

            print("\n📊 QUALITY SUMMARY:")
            for metric_name, score in results.items():
                if isinstance(score, (int, float)):
                    print(f"   📊 {metric_name}: {score:.3f}")
                else:
                    print(f"   📊 {metric_name}: {score}")

            # Store evaluation scores
            self.evaluation_scores['quality'] = results

        return self.quality_report

    except Exception as e:
        print(f"❌ Quality evaluation failed: {e}")
        return None

def generate_evaluation_report(self, save_path: Optional[str] = None):
    """
    Generate a comprehensive evaluation report combining all metrics.

    Args:
        save_path: Optional path to save the report as HTML

    Returns:
        Dictionary containing complete evaluation summary
    """
    print("\n📋 GENERATING COMPREHENSIVE EVALUATION REPORT")
    print("=" * 50)

    report = {
        'metadata': {
            'evaluation_timestamp': datetime.now(),
            'total_tables': len(self.real_data),
            'total_relationships': len(self.metadata.to_dict().get('relationships', [])),
            'training_duration': self.generation_stats.get('training_duration', 'N/A'),
            'generation_duration': self.generation_stats.get('generation_duration', 'N/A')
        },
        'table_statistics': self.table_stats,
        'generation_statistics': self.generation_stats,
        'evaluation_scores': self.evaluation_scores,
        'diagnostic_results': getattr(self.diagnostic_report, 'get_results', lambda: {})(),
        'quality_results': getattr(self.quality_report, 'get_results', lambda: {})()
    }

    # Display summary
    print("📊 COMPREHENSIVE EVALUATION SUMMARY:")
    print(f"   📅 Evaluation Date: {report['metadata']['evaluation_timestamp']}")
    print(f"   📋 Tables Processed: {report['metadata']['total_tables']}")
    print(f"   🔗 Relationships: {report['metadata']['total_relationships']}")

    if 'training_duration' in self.generation_stats:
        print(f"   ⏱️  Training Time: {self.generation_stats['training_duration']:.2f} seconds")
    if 'generation_duration' in self.generation_stats:
        print(f"   ⚡ Generation Time: {self.generation_stats['generation_duration']:.2f} seconds")

    # Overall quality assessment
    diagnostic_scores = self.evaluation_scores.get('diagnostic', {})
    quality_scores = self.evaluation_scores.get('quality', {})

    if diagnostic_scores or quality_scores:
        print("\n🎯 KEY PERFORMANCE METRICS:")

        for metric, value in diagnostic_scores.items():
            if isinstance(value, (int, float)):
                print(f"   📊 {metric}: {value:.3f}")

        for metric, value in quality_scores.items():
            if isinstance(value, (int, float)):
                print(f"   📈 {metric}: {value:.3f}")

    # Save report if path provided
    if save_path:
        try:
            import json
            with open(save_path, 'w') as f:
                # Convert datetime objects to strings for JSON serialization
                json_report = report.copy()
                json_report['metadata']['evaluation_timestamp'] = str(json_report['metadata']['evaluation_timestamp'])
                if 'training_start' in json_report['generation_statistics']:
                    json_report['generation_statistics']['training_start'] = str(json_report['generation_statistics']['training_start'])
                if 'training_end' in json_report['generation_statistics']:
                    json_report['generation_statistics']['training_end'] = str(json_report['generation_statistics']['training_end'])

                json.dump(json_report, f, indent=2, default=str)
            print(f"💾 Report saved to: {save_path}")
        except Exception as e:
            print(f"⚠️ Warning: Could not save report to {save_path}: {e}")

    return report

# Add methods to the class
RecursiveMultiTableSynthesizer.run_diagnostic_evaluation = run_diagnostic_evaluation
RecursiveMultiTableSynthesizer.run_quality_evaluation = run_quality_evaluation
RecursiveMultiTableSynthesizer.generate_evaluation_report = generate_evaluation_report

print("✓ Advanced evaluation methods added to class")
print("✅ RecursiveMultiTableSynthesizer class is now complete with all advanced features!")

✓ Advanced evaluation methods added to class
✅ RecursiveMultiTableSynthesizer class is now complete with all advanced features!


In [7]:
def create_enterprise_healthcare_data():
    """
    Create a comprehensive enterprise dataset with 20 interconnected tables
    representing a healthcare-retail business with complex relationships.

    Tables Created:
    1. companies - Parent organizations
    2. departments - Company departments
    3. locations - Physical locations
    4. employees - Staff members (with hierarchy)
    5. customers - Customer base
    6. products - Product catalog
    7. suppliers - Product suppliers
    8. categories - Product categories
    9. transactions - Sales transactions
    10. transaction_items - Individual items in transactions
    11. medical_records - Patient medical history
    12. medical_reports - Diagnostic reports
    13. prescriptions - Medical prescriptions
    14. appointments - Medical appointments
    15. insurance_policies - Insurance information
    16. claims - Insurance claims
    17. inventory - Product inventory
    18. reviews - Product/service reviews
    19. loyalty_programs - Customer loyalty
    20. audit_logs - System audit trail
    """
    print("🏗️  CREATING ENTERPRISE HEALTHCARE DATASET")
    print("=" * 50)
    print("Creating 20 interconnected tables with complex relationships...")

    # Set random seed for reproducibility
    np.random.seed(42)
    fake.seed_instance(42)

    # 1. COMPANIES (Root table)
    print("1️⃣  Creating companies...")
    companies = pd.DataFrame({
        'company_id': range(1, 6),  # 5 companies
        'company_name': [fake.company() for _ in range(5)],
        'company_type': np.random.choice(['Healthcare', 'Retail', 'Mixed'], 5),
        'founded_year': np.random.randint(1990, 2020, 5),
        'headquarters': [fake.city() for _ in range(5)],
        'annual_revenue': np.round(np.random.uniform(1000000, 100000000, 5), 2),
        'employee_count': np.random.randint(100, 5000, 5)
    })

    # 2. DEPARTMENTS
    print("2️⃣  Creating departments...")
    departments = pd.DataFrame({
        'dept_id': range(1, 26),  # 25 departments
        'company_id': np.random.choice(companies['company_id'], 25),
        'dept_name': np.random.choice([
            'Emergency', 'Cardiology', 'Oncology', 'Pediatrics', 'Surgery',
            'Pharmacy', 'Retail', 'Customer Service', 'IT', 'HR', 'Finance'
        ], 25),
        'budget': np.round(np.random.uniform(100000, 5000000, 25), 2),
        'manager_name': [fake.name() for _ in range(25)],
        'location_floor': np.random.randint(1, 10, 25)
    })

    # 3. LOCATIONS
    print("3️⃣  Creating locations...")
    locations = pd.DataFrame({
        'location_id': range(1, 16),  # 15 locations
        'company_id': np.random.choice(companies['company_id'], 15),
        'location_name': [fake.city() + ' Branch' for _ in range(15)],
        'address': [fake.address() for _ in range(15)],
        'city': [fake.city() for _ in range(15)],
        'state': [fake.state_abbr() for _ in range(15)],
        'zip_code': [fake.zipcode() for _ in range(15)],
        'phone': [fake.phone_number() for _ in range(15)],
        'facility_type': np.random.choice(['Hospital', 'Clinic', 'Retail Store', 'Warehouse'], 15),
        'square_feet': np.random.randint(1000, 50000, 15)
    })

    # 4. EMPLOYEES (with self-referencing hierarchy)
    print("4️⃣  Creating employees with hierarchy...")
    employees = pd.DataFrame({
        'employee_id': range(1, 201),  # 200 employees
        'dept_id': np.random.choice(departments['dept_id'], 200),
        'location_id': np.random.choice(locations['location_id'], 200),
        'first_name': [fake.first_name() for _ in range(200)],
        'last_name': [fake.last_name() for _ in range(200)],
        'email': [fake.email() for _ in range(200)],
        'phone': [fake.phone_number() for _ in range(200)],
        'hire_date': [fake.date_between(start_date='-5y', end_date='today') for _ in range(200)],
        'salary': np.round(np.random.uniform(35000, 120000, 200), 2),
        'position': np.random.choice([
            'Doctor', 'Nurse', 'Technician', 'Administrator', 'Pharmacist',
            'Sales Associate', 'Manager', 'Analyst', 'Specialist'
        ], 200),
        'employment_status': np.random.choice(['Full-time', 'Part-time', 'Contract'], 200, p=[0.7, 0.2, 0.1]),
        'reports_to': [None] * 150 + list(np.random.choice(range(1, 151), 50))  # Hierarchy
    })

    # 5. CUSTOMERS
    print("5️⃣  Creating customers...")
    customers = pd.DataFrame({
        'customer_id': range(1, 1001),  # 1000 customers
        'first_name': [fake.first_name() for _ in range(1000)],
        'last_name': [fake.last_name() for _ in range(1000)],
        'email': [fake.email() for _ in range(1000)],
        'phone': [fake.phone_number() for _ in range(1000)],
        'date_of_birth': [fake.date_of_birth(minimum_age=18, maximum_age=80) for _ in range(1000)],
        'gender': np.random.choice(['M', 'F', 'Other'], 1000, p=[0.45, 0.45, 0.1]),
        'address': [fake.address() for _ in range(1000)],
        'city': [fake.city() for _ in range(1000)],
        'state': [fake.state_abbr() for _ in range(1000)],
        'zip_code': [fake.zipcode() for _ in range(1000)],
        'registration_date': [fake.date_between(start_date='-3y', end_date='today') for _ in range(1000)],
        'customer_type': np.random.choice(['Individual', 'Business', 'Insurance'], 1000, p=[0.8, 0.15, 0.05])
    })

    # 6. CATEGORIES
    print("6️⃣  Creating product categories...")
    categories = pd.DataFrame({
        'category_id': range(1, 21),  # 20 categories
        'category_name': [
            'Prescription Drugs', 'Over-the-Counter', 'Medical Devices', 'First Aid',
            'Vitamins & Supplements', 'Personal Care', 'Baby Care', 'Beauty',
            'Health & Wellness', 'Mobility Aids', 'Dental Care', 'Vision Care',
            'Pain Relief', 'Allergy Relief', 'Diabetes Care', 'Heart Health',
            'Mental Health', 'Skin Care', 'Nutrition', 'Emergency Supplies'
        ],
        'parent_category_id': [None] * 10 + list(np.random.choice(range(1, 11), 10)),  # Hierarchy
        'description': [fake.text(max_nb_chars=200) for _ in range(20)],
        'is_prescription_required': np.random.choice([True, False], 20, p=[0.3, 0.7])
    })

    # 7. PRODUCTS
    print("7️⃣  Creating products...")
    products = pd.DataFrame({
        'product_id': range(1, 501),  # 500 products
        'category_id': np.random.choice(categories['category_id'], 500),
        'product_name': [fake.catch_phrase() + ' ' + fake.word() for _ in range(500)],
        'description': [fake.text(max_nb_chars=300) for _ in range(500)],
        'sku': [fake.bothify(text='???-###-???') for _ in range(500)],
        'price': np.round(np.random.uniform(5.99, 299.99, 500), 2),
        'cost': np.round(np.random.uniform(2.99, 150.00, 500), 2),
        'manufacturer': [fake.company() for _ in range(500)],
        'requires_prescription': np.random.choice([True, False], 500, p=[0.2, 0.8]),
        'dosage_form': np.random.choice([
            'Tablet', 'Capsule', 'Liquid', 'Cream', 'Injection', 'Device', 'Other'
        ], 500),
        'expiry_months': np.random.randint(6, 60, 500)
    })

    # 8. SUPPLIERS
    print("8️⃣  Creating suppliers...")
    suppliers = pd.DataFrame({
        'supplier_id': range(1, 51),  # 50 suppliers
        'supplier_name': [fake.company() for _ in range(50)],
        'contact_person': [fake.name() for _ in range(50)],
        'email': [fake.email() for _ in range(50)],
        'phone': [fake.phone_number() for _ in range(50)],
        'address': [fake.address() for _ in range(50)],
        'city': [fake.city() for _ in range(50)],
        'state': [fake.state_abbr() for _ in range(50)],
        'rating': np.round(np.random.uniform(3.0, 5.0, 50), 1),
        'established_year': np.random.randint(1980, 2020, 50),
        'specialty': np.random.choice([
            'Pharmaceuticals', 'Medical Devices', 'Personal Care', 'Vitamins'
        ], 50)
    })

    # 9. TRANSACTIONS
    print("9️⃣  Creating transactions...")
    transactions = pd.DataFrame({
        'transaction_id': range(1, 2001),  # 2000 transactions
        'customer_id': np.random.choice(customers['customer_id'], 2000),
        'employee_id': np.random.choice(employees['employee_id'], 2000),
        'location_id': np.random.choice(locations['location_id'], 2000),
        'transaction_date': [
            fake.date_time_between(start_date='-1y', end_date='now') for _ in range(2000)
        ],
        'transaction_type': np.random.choice(['Sale', 'Return', 'Exchange'], 2000, p=[0.85, 0.1, 0.05]),
        'payment_method': np.random.choice([
            'Credit Card', 'Debit Card', 'Cash', 'Insurance', 'HSA'
        ], 2000, p=[0.4, 0.25, 0.15, 0.15, 0.05]),
        'subtotal': np.round(np.random.uniform(10.00, 500.00, 2000), 2),
        'tax_amount': lambda x: np.round(x * 0.08, 2),  # Will be calculated
        'discount_amount': np.round(np.random.uniform(0, 50.00, 2000), 2),
        'total_amount': lambda x: x,  # Will be calculated
        'prescription_required': np.random.choice([True, False], 2000, p=[0.3, 0.7])
    })
    # Calculate derived fields
    transactions['tax_amount'] = np.round(transactions['subtotal'] * 0.08, 2)
    transactions['total_amount'] = np.round(
        transactions['subtotal'] + transactions['tax_amount'] - transactions['discount_amount'], 2
    )

    # 10. TRANSACTION_ITEMS
    print("🔟 Creating transaction items...")
    transaction_items = pd.DataFrame({
        'item_id': range(1, 5001),  # 5000 items
        'transaction_id': np.random.choice(transactions['transaction_id'], 5000),
        'product_id': np.random.choice(products['product_id'], 5000),
        'quantity': np.random.randint(1, 5, 5000),
        'unit_price': np.round(np.random.uniform(5.99, 299.99, 5000), 2),
        'discount_percent': np.round(np.random.uniform(0, 20, 5000), 1),
        'line_total': lambda x: x  # Will be calculated
    })
    # Calculate line totals
    transaction_items['line_total'] = np.round(
        transaction_items['quantity'] * transaction_items['unit_price'] *
        (1 - transaction_items['discount_percent'] / 100), 2
    )

    # 11. MEDICAL_RECORDS
    print("1️⃣1️⃣ Creating medical records...")
    medical_records = pd.DataFrame({
        'record_id': range(1, 801),  # 800 records
        'customer_id': np.random.choice(customers['customer_id'], 800),
        'employee_id': np.random.choice(employees['employee_id'], 800),  # Doctor/Nurse
        'location_id': np.random.choice(locations['location_id'], 800),
        'record_date': [fake.date_between(start_date='-2y', end_date='today') for _ in range(800)],
        'diagnosis': [fake.catch_phrase() for _ in range(800)],
        'symptoms': [fake.text(max_nb_chars=200) for _ in range(800)],
        'treatment': [fake.text(max_nb_chars=200) for _ in range(800)],
        'allergies': [fake.word() if np.random.random() < 0.3 else None for _ in range(800)],
        'blood_pressure': [f"{np.random.randint(90, 180)}/{np.random.randint(60, 120)}" for _ in range(800)],
        'heart_rate': np.random.randint(60, 100, 800),
        'weight_kg': np.round(np.random.uniform(40, 120, 800), 1),
        'height_cm': np.random.randint(140, 200, 800)
    })

    # 12. MEDICAL_REPORTS
    print("1️⃣2️⃣ Creating medical reports...")
    medical_reports = pd.DataFrame({
        'report_id': range(1, 401),  # 400 reports
        'record_id': np.random.choice(medical_records['record_id'], 400),
        'employee_id': np.random.choice(employees['employee_id'], 400),  # Technician/Doctor
        'report_type': np.random.choice([
            'Blood Test', 'X-Ray', 'MRI', 'CT Scan', 'Ultrasound', 'EKG', 'Biopsy'
        ], 400),
        'report_date': [fake.date_between(start_date='-2y', end_date='today') for _ in range(400)],
        'findings': [fake.text(max_nb_chars=400) for _ in range(400)],
        'recommendations': [fake.text(max_nb_chars=200) for _ in range(400)],
        'urgency_level': np.random.choice(['Low', 'Medium', 'High', 'Critical'], 400, p=[0.4, 0.35, 0.2, 0.05]),
        'is_abnormal': np.random.choice([True, False], 400, p=[0.3, 0.7]),
        'cost': np.round(np.random.uniform(50.00, 2000.00, 400), 2)
    })

    # 13. PRESCRIPTIONS
    print("1️⃣3️⃣ Creating prescriptions...")
    prescriptions = pd.DataFrame({
        'prescription_id': range(1, 601),  # 600 prescriptions
        'record_id': np.random.choice(medical_records['record_id'], 600),
        'product_id': np.random.choice(
            products[products['requires_prescription'] == True]['product_id'], 600
        ),
        'prescribing_doctor': np.random.choice(employees['employee_id'], 600),
        'prescription_date': [fake.date_between(start_date='-1y', end_date='today') for _ in range(600)],
        'dosage': [f"{np.random.randint(1, 4)} times daily" for _ in range(600)],
        'quantity_prescribed': np.random.randint(30, 90, 600),
        'refills_remaining': np.random.randint(0, 5, 600),
        'expiry_date': [fake.date_between(start_date='today', end_date='+1y') for _ in range(600)],
        'is_filled': np.random.choice([True, False], 600, p=[0.8, 0.2]),
        'pharmacy_notes': [fake.text(max_nb_chars=100) if np.random.random() < 0.3 else None for _ in range(600)]
    })

    # 14. APPOINTMENTS
    print("1️⃣4️⃣ Creating appointments...")
    appointments = pd.DataFrame({
        'appointment_id': range(1, 1201),  # 1200 appointments
        'customer_id': np.random.choice(customers['customer_id'], 1200),
        'employee_id': np.random.choice(employees['employee_id'], 1200),  # Doctor/Specialist
        'location_id': np.random.choice(locations['location_id'], 1200),
        'appointment_date': [
            fake.date_time_between(start_date='-30d', end_date='+30d') for _ in range(1200)
        ],
        'appointment_type': np.random.choice([
            'Consultation', 'Follow-up', 'Emergency', 'Routine Check', 'Specialist Visit'
        ], 1200),
        'status': np.random.choice([
            'Scheduled', 'Completed', 'Cancelled', 'No-Show', 'Rescheduled'
        ], 1200, p=[0.3, 0.5, 0.1, 0.05, 0.05]),
        'duration_minutes': np.random.choice([15, 30, 45, 60, 90], 1200, p=[0.1, 0.4, 0.3, 0.15, 0.05]),
        'reason': [fake.text(max_nb_chars=100) for _ in range(1200)],
        'notes': [fake.text(max_nb_chars=200) if np.random.random() < 0.4 else None for _ in range(1200)]
    })

    # 15. INSURANCE_POLICIES
    print("1️⃣5️⃣ Creating insurance policies...")
    insurance_policies = pd.DataFrame({
        'policy_id': range(1, 701),  # 700 policies
        'customer_id': np.random.choice(customers['customer_id'], 700),
        'policy_number': [fake.bothify(text='POL-####-????') for _ in range(700)],
        'insurance_company': [fake.company() + ' Insurance' for _ in range(700)],
        'policy_type': np.random.choice(['Health', 'Dental', 'Vision', 'Prescription'], 700),
        'start_date': [fake.date_between(start_date='-2y', end_date='today') for _ in range(700)],
        'end_date': [fake.date_between(start_date='today', end_date='+1y') for _ in range(700)],
        'premium_amount': np.round(np.random.uniform(100.00, 800.00, 700), 2),
        'deductible': np.round(np.random.uniform(500.00, 5000.00, 700), 2),
        'copay': np.round(np.random.uniform(10.00, 50.00, 700), 2),
        'coverage_percent': np.random.choice([70, 80, 85, 90, 100], 700),
        'is_active': np.random.choice([True, False], 700, p=[0.9, 0.1])
    })

    # 16. CLAIMS
    print("1️⃣6️⃣ Creating insurance claims...")
    claims = pd.DataFrame({
        'claim_id': range(1, 301),  # 300 claims
        'policy_id': np.random.choice(insurance_policies['policy_id'], 300),
        'transaction_id': np.random.choice(transactions['transaction_id'], 300),
        'claim_number': [fake.bothify(text='CLM-######') for _ in range(300)],
        'claim_date': [fake.date_between(start_date='-1y', end_date='today') for _ in range(300)],
        'claim_amount': np.round(np.random.uniform(50.00, 5000.00, 300), 2),
        'approved_amount': lambda x: x,  # Will be calculated
        'status': np.random.choice([
            'Submitted', 'Under Review', 'Approved', 'Denied', 'Paid'
        ], 300, p=[0.1, 0.2, 0.4, 0.1, 0.2]),
        'denial_reason': [
            fake.text(max_nb_chars=100) if np.random.random() < 0.1 else None for _ in range(300)
        ],
        'processing_date': [
            fake.date_between(start_date='-1y', end_date='today') for _ in range(300)
        ]
    })
    # Calculate approved amounts
    claims['approved_amount'] = np.where(
        claims['status'] == 'Denied', 0,
        np.round(claims['claim_amount'] * np.random.uniform(0.5, 1.0, 300), 2)
    )

    # 17. INVENTORY
    print("1️⃣7️⃣ Creating inventory...")
    inventory = pd.DataFrame({
        'inventory_id': range(1, 1001),  # 1000 inventory records
        'product_id': np.random.choice(products['product_id'], 1000),
        'location_id': np.random.choice(locations['location_id'], 1000),
        'supplier_id': np.random.choice(suppliers['supplier_id'], 1000),
        'quantity_on_hand': np.random.randint(0, 500, 1000),
        'reorder_level': np.random.randint(10, 100, 1000),
        'max_stock_level': np.random.randint(200, 1000, 1000),
        'last_restock_date': [fake.date_between(start_date='-6m', end_date='today') for _ in range(1000)],
        'expiry_date': [fake.date_between(start_date='today', end_date='+2y') for _ in range(1000)],
        'lot_number': [fake.bothify(text='LOT-####-??') for _ in range(1000)],
        'unit_cost': np.round(np.random.uniform(1.00, 100.00, 1000), 2),
        'storage_location': [fake.bothify(text='?##-??##') for _ in range(1000)]
    })

    # 18. REVIEWS
    print("1️⃣8️⃣ Creating reviews...")
    reviews = pd.DataFrame({
        'review_id': range(1, 1501),  # 1500 reviews
        'customer_id': np.random.choice(customers['customer_id'], 1500),
        'product_id': np.random.choice(products['product_id'], 1500),
        'transaction_id': np.random.choice(transactions['transaction_id'], 1500),
        'rating': np.random.choice([1, 2, 3, 4, 5], 1500, p=[0.05, 0.1, 0.15, 0.35, 0.35]),
        'review_text': [fake.text(max_nb_chars=300) for _ in range(1500)],
        'review_date': [fake.date_between(start_date='-1y', end_date='today') for _ in range(1500)],
        'verified_purchase': np.random.choice([True, False], 1500, p=[0.8, 0.2]),
        'helpful_votes': np.random.randint(0, 50, 1500),
        'response_from_business': [
            fake.text(max_nb_chars=150) if np.random.random() < 0.2 else None for _ in range(1500)
        ]
    })

    # 19. LOYALTY_PROGRAMS
    print("1️⃣9️⃣ Creating loyalty programs...")
    loyalty_programs = pd.DataFrame({
        'loyalty_id': range(1, 801),  # 800 loyalty memberships
        'customer_id': np.random.choice(customers['customer_id'], 800),
        'program_name': np.random.choice([
            'HealthPlus Rewards', 'Wellness Circle', 'Premium Care', 'Family Benefits'
        ], 800),
        'membership_level': np.random.choice(['Bronze', 'Silver', 'Gold', 'Platinum'], 800),
        'join_date': [fake.date_between(start_date='-2y', end_date='today') for _ in range(800)],
        'points_balance': np.random.randint(0, 5000, 800),
        'points_earned_total': np.random.randint(100, 10000, 800),
        'points_redeemed_total': np.random.randint(0, 8000, 800),
        'last_activity_date': [fake.date_between(start_date='-3m', end_date='today') for _ in range(800)],
        'is_active': np.random.choice([True, False], 800, p=[0.85, 0.15]),
        'annual_spending': np.round(np.random.uniform(200.00, 5000.00, 800), 2)
    })

    # 20. AUDIT_LOGS
    print("2️⃣0️⃣ Creating audit logs...")
    audit_logs = pd.DataFrame({
        'log_id': range(1, 3001),  # 3000 audit entries
        'user_id': np.random.choice(employees['employee_id'], 3000),
        'action': np.random.choice([
            'CREATE', 'UPDATE', 'DELETE', 'LOGIN', 'LOGOUT', 'VIEW', 'EXPORT'
        ], 3000, p=[0.2, 0.3, 0.05, 0.15, 0.15, 0.1, 0.05]),
        'table_name': np.random.choice([
            'customers', 'transactions', 'medical_records', 'prescriptions',
            'appointments', 'inventory', 'products'
        ], 3000),
        'record_id': np.random.randint(1, 1000, 3000),
        'timestamp': [
            fake.date_time_between(start_date='-6m', end_date='now') for _ in range(3000)
        ],
        'ip_address': [fake.ipv4() for _ in range(3000)],
        'user_agent': [fake.user_agent() for _ in range(3000)],
        'session_id': [fake.bothify(text='SES-########-????') for _ in range(3000)],
        'details': [fake.text(max_nb_chars=200) if np.random.random() < 0.3 else None for _ in range(3000)]
    })

    # Compile all tables
    enterprise_data = {
        'companies': companies,
        'departments': departments,
        'locations': locations,
        'employees': employees,
        'customers': customers,
        'categories': categories,
        'products': products,
        'suppliers': suppliers,
        'transactions': transactions,
        'transaction_items': transaction_items,
        'medical_records': medical_records,
        'medical_reports': medical_reports,
        'prescriptions': prescriptions,
        'appointments': appointments,
        'insurance_policies': insurance_policies,
        'claims': claims,
        'inventory': inventory,
        'reviews': reviews,
        'loyalty_programs': loyalty_programs,
        'audit_logs': audit_logs
    }

    print("\n✅ ENTERPRISE DATASET CREATION COMPLETE!")
    print(f"📊 Created {len(enterprise_data)} interconnected tables")

    total_records = sum(len(df) for df in enterprise_data.values())
    print(f"📈 Total records across all tables: {total_records:,}")

    # Display table summary
    print("\n📋 TABLE SUMMARY:")
    for table_name, df in enterprise_data.items():
        print(f"   {table_name:20} | {len(df):6,} rows | {len(df.columns):2} columns")

    return enterprise_data

print("✓ Enterprise dataset creation function defined")

✓ Enterprise dataset creation function defined


In [8]:
print("🚀 INITIALIZING COMPREHENSIVE ENTERPRISE DATASET")
print("=" * 60)

# Create the enterprise dataset
enterprise_data = create_enterprise_healthcare_data()

# Display comprehensive dataset overview
print("\n🔍 DETAILED DATASET ANALYSIS")
print("=" * 60)

# Calculate comprehensive statistics
total_records = sum(len(df) for df in enterprise_data.values())
total_columns = sum(len(df.columns) for df in enterprise_data.values())
total_memory = sum(df.memory_usage(deep=True).sum() for df in enterprise_data.values()) / 1024 / 1024

print(f"📊 OVERALL STATISTICS:")
print(f"   📋 Total Tables: {len(enterprise_data)}")
print(f"   📈 Total Records: {total_records:,}")
print(f"   📝 Total Columns: {total_columns}")
print(f"   💾 Total Memory: {total_memory:.2f} MB")

# Analyze data types across all tables
print(f"\n🔍 DATA TYPE ANALYSIS:")
numeric_cols = sum(len(df.select_dtypes(include=[np.number]).columns) for df in enterprise_data.values())
categorical_cols = sum(len(df.select_dtypes(include=['object', 'category']).columns) for df in enterprise_data.values())
datetime_cols = sum(len(df.select_dtypes(include=['datetime64']).columns) for df in enterprise_data.values())

print(f"   🔢 Numeric Columns: {numeric_cols}")
print(f"   📝 Categorical Columns: {categorical_cols}")
print(f"   📅 DateTime Columns: {datetime_cols}")

# Show sample data from key tables
print(f"\n👀 SAMPLE DATA FROM KEY TABLES:")
key_tables = ['companies', 'customers', 'transactions', 'medical_records', 'products']

for table_name in key_tables:
    if table_name in enterprise_data:
        df = enterprise_data[table_name]
        print(f"\n📋 {table_name.upper()} (Sample - showing first 3 rows):")
        print(f"   Shape: {df.shape}")
        print(df.head(3).to_string(index=False, max_cols=8))

print("\n✅ Dataset initialization complete and ready for synthesis!")

# Define primary keys for all tables
primary_keys = {
    'companies': 'company_id',
    'departments': 'dept_id',
    'locations': 'location_id',
    'employees': 'employee_id',
    'customers': 'customer_id',
    'categories': 'category_id',
    'products': 'product_id',
    'suppliers': 'supplier_id',
    'transactions': 'transaction_id',
    'transaction_items': 'item_id',
    'medical_records': 'record_id',
    'medical_reports': 'report_id',
    'prescriptions': 'prescription_id',
    'appointments': 'appointment_id',
    'insurance_policies': 'policy_id',
    'claims': 'claim_id',
    'inventory': 'inventory_id',
    'reviews': 'review_id',
    'loyalty_programs': 'loyalty_id',
    'audit_logs': 'log_id'
}

print(f"🔑 Primary keys defined for all {len(primary_keys)} tables")

🚀 INITIALIZING COMPREHENSIVE ENTERPRISE DATASET
🏗️  CREATING ENTERPRISE HEALTHCARE DATASET
Creating 20 interconnected tables with complex relationships...
1️⃣  Creating companies...
2️⃣  Creating departments...
3️⃣  Creating locations...
4️⃣  Creating employees with hierarchy...
5️⃣  Creating customers...
6️⃣  Creating product categories...
7️⃣  Creating products...
8️⃣  Creating suppliers...
9️⃣  Creating transactions...
🔟 Creating transaction items...
1️⃣1️⃣ Creating medical records...
1️⃣2️⃣ Creating medical reports...
1️⃣3️⃣ Creating prescriptions...
1️⃣4️⃣ Creating appointments...
1️⃣5️⃣ Creating insurance policies...
1️⃣6️⃣ Creating insurance claims...
1️⃣7️⃣ Creating inventory...
1️⃣8️⃣ Creating reviews...
1️⃣9️⃣ Creating loyalty programs...
2️⃣0️⃣ Creating audit logs...

✅ ENTERPRISE DATASET CREATION COMPLETE!
📊 Created 20 interconnected tables
📈 Total records across all tables: 19,115

📋 TABLE SUMMARY:
   companies            |      5 rows |  7 columns
   departments          

In [9]:
print("🤖 INITIALIZING ENHANCED MULTI-TABLE SYNTHESIZER")
print("=" * 60)

# Initialize the enhanced synthesizer with Gaussian Copula
synthesizer = RecursiveMultiTableSynthesizer(synthesizer_type='gaussian_copula')

# Add all tables with automatic metadata detection
print("\n🔍 Adding tables with automatic relationship detection...")
synthesizer.add_tables_from_dict(enterprise_data, primary_keys=primary_keys)

# Analyze automatically detected relationships
print("\n📊 ANALYZING AUTO-DETECTED RELATIONSHIPS...")
relationship_analysis = synthesizer.analyze_relationships()

# Display metadata summary
print(f"\n📋 METADATA DETECTION SUMMARY:")
metadata_dict = synthesizer.metadata.to_dict()
print(f"   📋 Tables processed: {len(metadata_dict.get('tables', {}))}")
print(f"   🔗 Relationships detected: {relationship_analysis.get('total_relationships', 0)}")
print(f"   🌳 Hierarchical relationships: {relationship_analysis.get('hierarchical', 0)}")
print(f"   🔄 Self-referencing relationships: {relationship_analysis.get('self_referencing', 0)}")


🤖 INITIALIZING ENHANCED MULTI-TABLE SYNTHESIZER
Enhanced RecursiveMultiTableSynthesizer initialized with gaussian_copula
Ready for automatic relationship detection and quality evaluation

🔍 Adding tables with automatic relationship detection...

=== ADDING 20 TABLES WITH METADATA DETECTION ===
Adding tables with individual metadata detection...
Added table 'companies' with primary key 'company_id'
Added table 'departments' with primary key 'dept_id'
Added table 'locations' with primary key 'location_id'
Added table 'employees' with primary key 'employee_id'
Added table 'customers' with primary key 'customer_id'
Added table 'categories' with primary key 'category_id'
Added table 'products' with primary key 'product_id'
Added table 'suppliers' with primary key 'supplier_id'
Added table 'transactions' with primary key 'transaction_id'
Added table 'transaction_items' with primary key 'item_id'
Added table 'medical_records' with primary key 'record_id'
Added table 'medical_reports' with pri

In [10]:
metadata_dict

#Added table 'companies' with primary key 'company_id'
#Added table 'departments' with primary key 'dept_id'
#Added table 'locations' with primary key 'location_id'
#Added table 'employees' with primary key 'employee_id'
#Added table 'customers' with primary key 'customer_id'

{'tables': {'companies': {'primary_key': 'company_id',
   'columns': {'company_id': {'sdtype': 'id'},
    'company_name': {'sdtype': 'categorical'},
    'company_type': {'sdtype': 'categorical'},
    'founded_year': {'sdtype': 'numerical'},
    'headquarters': {'sdtype': 'categorical'},
    'annual_revenue': {'sdtype': 'numerical'},
    'employee_count': {'sdtype': 'numerical'}}},
  'departments': {'primary_key': 'dept_id',
   'columns': {'dept_id': {'sdtype': 'id'},
    'company_id': {'sdtype': 'id'},
    'dept_name': {'sdtype': 'categorical'},
    'budget': {'sdtype': 'numerical'},
    'manager_name': {'sdtype': 'categorical'},
    'location_floor': {'sdtype': 'numerical'}}},
  'locations': {'primary_key': 'location_id',
   'columns': {'location_id': {'sdtype': 'id'},
    'company_id': {'sdtype': 'id'},
    'location_name': {'sdtype': 'categorical'},
    'address': {'sdtype': 'categorical'},
    'city': {'pii': True, 'sdtype': 'city'},
    'state': {'pii': True, 'sdtype': 'administra

In [None]:
print("🔧 ENHANCING AUTO-DETECTED RELATIONSHIPS")
print("=" * 60)

# Add any missing relationships that auto-detection might have missed
print("🔗 Adding critical relationships that may need manual specification...")

# Key business relationships that must be explicitly defined
critical_relationships = [
    # Company hierarchy
    ('companies', 'company_id', 'departments', 'company_id'),
    ('companies', 'company_id', 'locations', 'company_id'),

    # Employee relationships
    ('departments', 'dept_id', 'employees', 'dept_id'),
    ('locations', 'location_id', 'employees', 'location_id'),

    # Customer relationships
    ('departments', 'dept_id', 'customers', 'dept_id'),
    ('locations', 'location_id', 'customers', 'location_id')
]

# Add relationships that weren't auto-detected
added_relationships = 0
for parent_table, parent_col, child_table, child_col in critical_relationships:
    try:
        # Check if relationship already exists
        existing_rels = synthesizer.metadata.to_dict().get('relationships', [])
        relationship_exists = any(
            rel.get('parent_table_name') == parent_table and
            rel.get('child_table_name') == child_table and
            rel.get('parent_primary_key') == parent_col and
            rel.get('child_foreign_key') == child_col
            for rel in existing_rels
        )

        if not relationship_exists:
            print(" ===> NO relationship<=====")
            synthesizer.add_custom_relationship(parent_table, parent_col, child_table, child_col)
            added_relationships += 1

    except Exception as e:
        print(f"   ⚠️ Could not add {parent_table}.{parent_col} → {child_table}.{child_col}: {e}")

print(f"\n✅ Added {added_relationships} additional relationships")

# Re-analyze relationships after manual additions
print(f"\n📊 FINAL RELATIONSHIP ANALYSIS:")
final_analysis = synthesizer.analyze_relationships()

print(f"   🔗 Total relationships: {final_analysis.get('total_relationships', 0)}")
print(f"   🌳 Hierarchical: {final_analysis.get('hierarchical', 0)}")
print(f"   🔄 Self-referencing: {final_analysis.get('self_referencing', 0)}")

# Create relationship visualization
print(f"\n🎨 Creating comprehensive relationship visualization...")
#synthesizer.visualize_table_dependencies(figsize=(18, 12))

🔧 ENHANCING AUTO-DETECTED RELATIONSHIPS
🔗 Adding critical relationships that may need manual specification...
 ===> NO relationship<=====
====> parent_table_name :   companies
====> parent_primary_key :   company_id
====> child_table_name :   departments
====> child_foreign_key :   company_id
❌ Error adding relationship: Unknown table name ('company_id').
 ===> NO relationship<=====
====> parent_table_name :   companies
====> parent_primary_key :   company_id
====> child_table_name :   locations
====> child_foreign_key :   company_id
❌ Error adding relationship: Unknown table name ('company_id').
 ===> NO relationship<=====
====> parent_table_name :   departments
====> parent_primary_key :   dept_id
====> child_table_name :   employees
====> child_foreign_key :   dept_id
❌ Error adding relationship: Unknown table name ('dept_id').
 ===> NO relationship<=====
====> parent_table_name :   locations
====> parent_primary_key :   location_id
====> child_table_name :   employees
====> child_f

In [None]:
synthesizer.metadata

{
    "tables": {
        "companies": {
            "columns": {
                "company_id": {
                    "sdtype": "id"
                },
                "company_name": {
                    "sdtype": "categorical"
                },
                "company_type": {
                    "sdtype": "categorical"
                },
                "founded_year": {
                    "sdtype": "numerical"
                },
                "headquarters": {
                    "sdtype": "categorical"
                },
                "annual_revenue": {
                    "sdtype": "numerical"
                },
                "employee_count": {
                    "sdtype": "numerical"
                }
            },
            "primary_key": "company_id"
        },
        "departments": {
            "columns": {
                "dept_id": {
                    "sdtype": "id"
                },
                "company_id": {
                    "sdtype": "id"
   

In [11]:
def add_relationships_from_list(
    metadata,
    rel_list,
    *,
    dependency_graph=None,   # must be passed by keyword
    strict=True,
    verbose=True,
):
    """
    Add relationships to SDV MultiTable metadata from a list of 4-tuples:
      (parent_table, parent_primary_key, child_table, child_foreign_key)
    Works when metadata.tables items are dicts or SingleTableMetadata objects.
    """
    results = {"added": [], "skipped_existing": [], "skipped_invalid": [], "errors": []}

    # --- helpers -------------------------------------------------------------
    def _table_meta_to_dict(tbl_meta):
        if isinstance(tbl_meta, dict):
            pk = tbl_meta.get("primary_key")
            cols = tbl_meta.get("fields") or tbl_meta.get("columns") or {}
            return {"primary_key": pk, "columns": cols}
        if hasattr(tbl_meta, "to_dict"):
            d = tbl_meta.to_dict()
            pk = d.get("primary_key")
            cols = d.get("fields") or d.get("columns") or {}
            return {"primary_key": pk, "columns": cols}
        pk = getattr(tbl_meta, "primary_key", None)
        if hasattr(tbl_meta, "get_columns"):
            cols_dict = {c: {} for c in tbl_meta.get_columns()}
        else:
            cols_attr = getattr(tbl_meta, "columns", None)
            if isinstance(cols_attr, dict):
                cols_dict = cols_attr
            elif isinstance(cols_attr, (list, tuple)):
                cols_dict = {c: {} for c in cols_attr}
            else:
                cols_dict = {}
        return {"primary_key": pk, "columns": cols_dict}

    def _relationship_exists(meta, ptab, ctab, ppk, cfk):
        rels = getattr(meta, "relationships", []) or []
        for r in rels:
            if isinstance(r, dict):
                if (r.get("parent_table_name") == ptab and
                    r.get("child_table_name") == ctab and
                    r.get("parent_primary_key") == ppk and
                    r.get("child_foreign_key") == cfk):
                    return True
            else:
                if (getattr(r, "parent_table_name", None) == ptab and
                    getattr(r, "child_table_name", None) == ctab and
                    getattr(r, "parent_primary_key", None) == ppk and
                    getattr(r, "child_foreign_key", None) == cfk):
                    return True
        return False

    if not hasattr(metadata, "tables"):
        raise TypeError("Provided metadata has no 'tables' attribute. Need MultiTable metadata.")

    tables = metadata.tables

    for rel in rel_list:
        try:
            if len(rel) != 4:
                msg = "tuple must be (parent_table, parent_pk, child_table, child_fk)"
                if strict: raise ValueError(msg)
                results["skipped_invalid"].append((rel, msg))
                if verbose: print(f"⚠️ {msg}: {rel}")
                continue

            parent_table, parent_pk, child_table, child_fk = rel

            if parent_table not in tables:
                msg = f"parent table '{parent_table}' not found"
                if strict: raise KeyError(msg)
                results["skipped_invalid"].append((rel, msg))
                if verbose: print(f"⚠️ {msg}")
                continue

            if child_table not in tables:
                msg = f"child table '{child_table}' not found"
                if strict: raise KeyError(msg)
                results["skipped_invalid"].append((rel, msg))
                if verbose: print(f"⚠️ {msg}")
                continue

            pinfo = _table_meta_to_dict(tables[parent_table])
            cinfo = _table_meta_to_dict(tables[child_table])

            declared_pk = pinfo.get("primary_key")
            if declared_pk and declared_pk != parent_pk:
                msg = (f"declared primary key for '{parent_table}' is '{declared_pk}', not '{parent_pk}'")
                if strict: raise ValueError(msg)
                if verbose: print(f"⚠️ {msg}. Proceeding anyway.")

            child_cols = cinfo.get("columns") or {}
            if child_fk not in child_cols:
                msg = f"child foreign key '{child_fk}' not found in '{child_table}'"
                if strict: raise KeyError(msg)
                results["skipped_invalid"].append((rel, msg))
                if verbose: print(f"⚠️ {msg}")
                continue

            if _relationship_exists(metadata, parent_table, child_table, parent_pk, child_fk):
                results["skipped_existing"].append(rel)
                if verbose: print(f"⏭️ Already exists: {rel}")
                # still mirror into graph if provided
                if dependency_graph is not None and hasattr(dependency_graph, "add_edge"):
                    dependency_graph.add_edge(parent_table, child_table)
                continue

            # Add to SDV metadata
            metadata.add_relationship(
                parent_table_name=parent_table,
                child_table_name=child_table,
                parent_primary_key=parent_pk,
                child_foreign_key=child_fk
            )

            # Add to dependency graph (optional)
            if dependency_graph is not None and hasattr(dependency_graph, "add_edge"):
                dependency_graph.add_edge(parent_table, child_table)

            results["added"].append(rel)
            if verbose: print(f"✅ Added: {parent_table}.{parent_pk} → {child_table}.{child_fk}")

        except Exception as e:
            results["errors"].append((rel, str(e)))
            if verbose: print(f"❌ Error adding {rel}: {e}")

    return results


In [12]:
critical_relationships = [
    # Company hierarchy
    ('companies', 'company_id', 'departments', 'company_id'),
    ('companies', 'company_id', 'locations', 'company_id'),

    # Employee relationships
    ('departments', 'dept_id', 'employees', 'dept_id'),
    ('locations', 'location_id', 'employees', 'location_id'),

    # Self-referencing relationships
    ('employees', 'employee_id', 'employees', 'reports_to'),
    ('categories', 'category_id', 'categories', 'parent_category_id'),

    # Customer relationships
    ('customers', 'customer_id', 'transactions', 'customer_id'),
    ('customers', 'customer_id', 'medical_records', 'customer_id'),
    ('customers', 'customer_id', 'appointments', 'customer_id'),
    ('customers', 'customer_id', 'insurance_policies', 'customer_id'),
    ('customers', 'customer_id', 'loyalty_programs', 'customer_id'),

    # Product and inventory relationships
    ('categories', 'category_id', 'products', 'category_id'),
    ('products', 'product_id', 'transaction_items', 'product_id'),
    ('products', 'product_id', 'inventory', 'product_id'),
    ('products', 'product_id', 'reviews', 'product_id'),
    ('products', 'product_id', 'prescriptions', 'product_id'),

    # Transaction relationships
    ('transactions', 'transaction_id', 'transaction_items', 'transaction_id'),
    ('transactions', 'transaction_id', 'claims', 'transaction_id'),
    ('transactions', 'transaction_id', 'reviews', 'transaction_id'),
    ('employees', 'employee_id', 'transactions', 'employee_id'),
    ('locations', 'location_id', 'transactions', 'location_id'),

    # Medical relationships
    ('medical_records', 'record_id', 'medical_reports', 'record_id'),
    ('medical_records', 'record_id', 'prescriptions', 'record_id'),
    ('employees', 'employee_id', 'medical_records', 'employee_id'),
    ('employees', 'employee_id', 'appointments', 'employee_id'),
    ('locations', 'location_id', 'medical_records', 'location_id'),
    ('locations', 'location_id', 'appointments', 'location_id'),

    # Insurance relationships
    ('insurance_policies', 'policy_id', 'claims', 'policy_id'),

    # Supplier and inventory
    ('suppliers', 'supplier_id', 'inventory', 'supplier_id'),
    ('locations', 'location_id', 'inventory', 'location_id'),

    # Audit trail
    ('employees', 'employee_id', 'audit_logs', 'user_id')
]

mData = synthesizer.metadata  # must be MultiTable metadata
graph = synthesizer.dependency_graph

summary = add_relationships_from_list(
    mData,
    critical_relationships,
    dependency_graph=graph,    # <-- pass by keyword
    strict=False,
    verbose=True
)
print(summary)


✅ Added: companies.company_id → departments.company_id
✅ Added: companies.company_id → locations.company_id
✅ Added: departments.dept_id → employees.dept_id
✅ Added: locations.location_id → employees.location_id
❌ Error adding ('employees', 'employee_id', 'employees', 'reports_to'): Relationship between tables ('employees', 'employees') is invalid. The primary and foreign key columns are not the same type.
❌ Error adding ('categories', 'category_id', 'categories', 'parent_category_id'): The relationships in the dataset describe a circular dependency between tables ['categories', 'categories'].
✅ Added: customers.customer_id → transactions.customer_id
✅ Added: customers.customer_id → medical_records.customer_id
✅ Added: customers.customer_id → appointments.customer_id
✅ Added: customers.customer_id → insurance_policies.customer_id
✅ Added: customers.customer_id → loyalty_programs.customer_id
✅ Added: categories.category_id → products.category_id
✅ Added: products.product_id → transactio

In [13]:
synthesizer.metadata.relationships

[{'parent_table_name': 'companies',
  'child_table_name': 'departments',
  'parent_primary_key': 'company_id',
  'child_foreign_key': 'company_id'},
 {'parent_table_name': 'companies',
  'child_table_name': 'locations',
  'parent_primary_key': 'company_id',
  'child_foreign_key': 'company_id'},
 {'parent_table_name': 'departments',
  'child_table_name': 'employees',
  'parent_primary_key': 'dept_id',
  'child_foreign_key': 'dept_id'},
 {'parent_table_name': 'locations',
  'child_table_name': 'employees',
  'parent_primary_key': 'location_id',
  'child_foreign_key': 'location_id'},
 {'parent_table_name': 'customers',
  'child_table_name': 'transactions',
  'parent_primary_key': 'customer_id',
  'child_foreign_key': 'customer_id'},
 {'parent_table_name': 'customers',
  'child_table_name': 'medical_records',
  'parent_primary_key': 'customer_id',
  'child_foreign_key': 'customer_id'},
 {'parent_table_name': 'customers',
  'child_table_name': 'appointments',
  'parent_primary_key': 'custom

In [None]:
print("🚀 TRAINING MULTI-TABLE SYNTHESIZER")
print("=" * 60)

# Display pre-training summary
print("📊 PRE-TRAINING SUMMARY:")
print(f"   📋 Tables to train on: {len(enterprise_data)}")
print(f"   📈 Total training records: {sum(len(df) for df in enterprise_data.values()):,}")
print(f"   🔗 Total relationships: {len(synthesizer.metadata.to_dict().get('relationships', []))}")
print(f"   💾 Dataset memory usage: {sum(df.memory_usage(deep=True).sum() for df in enterprise_data.values()) / 1024 / 1024:.2f} MB")

# Start training with comprehensive monitoring
training_success = synthesizer.train_synthesizer(verbose=True)

if training_success:
    print("\n🎉 TRAINING COMPLETED SUCCESSFULLY!")

    # Display training statistics
    training_duration = synthesizer.generation_stats.get('training_duration', 0)
    print(f"\n📊 TRAINING PERFORMANCE METRICS:")
    print(f"   ⏱️  Total training time: {training_duration:.2f} seconds")
    print(f"   📈 Records per second: {sum(len(df) for df in enterprise_data.values()) / max(training_duration, 1):.0f}")
    print(f"   🧠 Model complexity: Multi-table HMA with {len(enterprise_data)} tables")

    # Memory usage after training
    import psutil
    import os

    try:
        process = psutil.Process(os.getpid())
        memory_usage = process.memory_info().rss / 1024 / 1024  # MB
        print(f"   💾 Current memory usage: {memory_usage:.2f} MB")
    except:
        print("   💾 Memory usage: Unable to determine")

    print(f"\n✅ Synthesizer is ready for data generation!")

else:
    print("\n❌ TRAINING FAILED!")
    print("Please check the error messages above and verify your data structure.")

🚀 TRAINING MULTI-TABLE SYNTHESIZER
📊 PRE-TRAINING SUMMARY:
   📋 Tables to train on: 20
   📈 Total training records: 19,115
   🔗 Total relationships: 29
   💾 Dataset memory usage: 6.40 MB

🚀 TRAINING MULTI-TABLE SYNTHESIZER
🔍 Validating metadata structure...
✅ Metadata validation successful
🤖 Initializing gaussian_copula synthesizer...
PerformanceAlert: Using the HMASynthesizer on this metadata schema is not recommended. To model this data, HMA will generate a large number of columns. (27393830626741628 columns)


        Table Name  # Columns in Metadata     Est # Columns
         companies                      6 27393830295659408
       departments                      4         165501720
         locations                      8         165519905
         employees                      9             18190
         customers                     12             20216
        categories                      4             21527
          products                      9               204
 

Preprocess Tables: 100%|██████████| 20/20 [00:08<00:00,  2.45it/s]



Learning relationships:


(1/29) Tables 'transactions' and 'transaction_items' ('transaction_id'): 100%|██████████| 1837/1837 [05:19<00:00,  5.74it/s]
(2/29) Tables 'transactions' and 'reviews' ('transaction_id'): 100%|██████████| 1054/1054 [02:01<00:00,  8.70it/s]
(3/29) Tables 'transactions' and 'claims' ('transaction_id'): 100%|██████████| 284/284 [00:07<00:00, 36.15it/s]
(4/29) Tables 'locations' and 'transactions' ('location_id'): 100%|██████████| 15/15 [00:17<00:00,  1.18s/it]
(5/29) Tables 'locations' and 'inventory' ('location_id'): 100%|██████████| 15/15 [00:06<00:00,  2.27it/s]
(6/29) Tables 'employees' and 'transactions' ('employee_id'): 100%|██████████| 200/200 [03:23<00:00,  1.02s/it]
(7/29) Tables 'employees' and 'appointments' ('employee_id'): 100%|██████████| 200/200 [01:07<00:00,  2.97it/s]
(8/29) Tables 'employees' and 'audit_logs' ('user_id'): 100%|██████████| 200/200 [01:08<00:00,  2.90it/s]
(9/29) Tables 'medical_records' and 'medical_reports' ('record_id'): 100%|██████████| 320/320 [00:29<