# In [None]:
"""
Comprehensive Weaviate Connection Diagnostic Script
Checks ALL potential socket/connection issues systematically
"""

import weaviate
import logging
import time
import os
import socket
import ssl
import sys
from datetime import datetime
import threading
import json
from dotenv import load_dotenv

# Load environment variables from .env file (override=True lets values in
# .env take precedence over variables already set in the process).
load_dotenv(override=True)

# Map the project's .env variable names onto the WEAVIATE_* names that the
# diagnostics read: target name -> (source name, default if source is unset).
_ENV_ALIASES = {
    'WEAVIATE_HTTP_HOST': ('CLUSTER_URL', None),
    'WEAVIATE_HTTP_PORT': ('HTTP_PORT', '443'),
    'WEAVIATE_GRPC_HOST': ('CLUSTER_GRPC_URL', None),
    'WEAVIATE_GRPC_PORT': ('GRPC_PORT', '443'),
    'WEAVIATE_API_KEY': ('API_KEY', None),
}


def _apply_env_aliases():
    """Copy each source variable in ``_ENV_ALIASES`` into its WEAVIATE_* name.

    ``os.environ`` only accepts ``str`` values, so assigning the result of
    ``os.getenv`` directly raises ``TypeError`` at import time whenever a
    source variable is unset. Unset sources (with no default) are skipped
    instead: an already-set target keeps its value, and a still-missing one
    is reported later by the environment check rather than crashing here.
    """
    for target, (source, default) in _ENV_ALIASES.items():
        value = os.getenv(source, default)
        if value is not None:
            os.environ[target] = value


_apply_env_aliases()

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

class WeaviateDiagnostics:
    """Layered connectivity diagnostics for a Weaviate cluster.

    Sections run from the bottom of the network stack upward: environment,
    DNS, raw TCP, TLS, client connect, connection-pool / timeout / concurrent
    connection tests, and idle-connection persistence. Each section logs its
    progress and records:

    * ``results``: check name -> status string (``'PASS'``, ``'STABLE'``,
      ``'FAIL'``, ``'UNSTABLE (...)'``, ``'ERROR: <ExceptionName>'``, ...);
    * ``issues_found``: human-readable problems, in detection order;
    * ``recommendations``: suggested remediations, in detection order.

    Connection settings are read from ``WEAVIATE_HTTP_HOST``/``_PORT``,
    ``WEAVIATE_GRPC_HOST``/``_PORT`` and ``WEAVIATE_API_KEY``; every client
    built here uses TLS on both endpoints (``http_secure``/``grpc_secure``).

    Side effect: the connection-pool section creates and deletes temporary
    ``DiagTest_*`` collections on the cluster.
    """

    # init/query/insert timeouts (seconds) used by the main and pool-test clients.
    _LONG_TIMEOUT_SECONDS = {'init': 120, 'query': 3600, 'insert': 3600}

    def __init__(self):
        self.results = {}          # check name -> status string
        self.issues_found = []     # human-readable problems, in detection order
        self.recommendations = []  # suggested remediations, in detection order

    def log_issue(self, issue):
        """Record *issue* in ``issues_found`` and log it at ERROR level."""
        self.issues_found.append(issue)
        logger.error(f"❌ ISSUE: {issue}")

    def log_recommendation(self, rec):
        """Record *rec* in ``recommendations`` and log it at INFO level."""
        self.recommendations.append(rec)
        logger.info(f"💡 RECOMMENDATION: {rec}")

    # ==================== HELPERS ====================
    @staticmethod
    def _banner(title, width=60):
        """Log *title* framed by '=' rules of the given width."""
        logger.info("\n" + "=" * width)
        logger.info(title)
        logger.info("=" * width)

    @staticmethod
    def _endpoint(name):
        """Return ``(host, port)`` for *name* ('HTTP' or 'GRPC') from the env.

        The port defaults to 443 and is converted with ``int()``; run
        ``check_env_vars`` first so an invalid value is reported, not raised.
        """
        host = os.getenv(f'WEAVIATE_{name}_HOST')
        port = int(os.getenv(f'WEAVIATE_{name}_PORT', '443'))
        return host, port

    @staticmethod
    def _is_valid_port(value):
        """Return True if *value* parses as an integer TCP port (1-65535)."""
        try:
            return 0 < int(value) < 65536
        except (TypeError, ValueError):
            return False

    @staticmethod
    def _is_socket_closed_error(exc):
        """Return True if *exc*'s message matches a dropped-gRPC-socket error."""
        text = str(exc)
        return "Socket closed" in text or "grpc_status:14" in text

    @staticmethod
    def _close_quietly(client):
        """Close *client* if it is not None; log, but do not raise, close errors."""
        if client is None:
            return
        try:
            client.close()
        except Exception as e:
            logger.warning(f"Failed to close Weaviate client: {type(e).__name__}: {e}")

    @staticmethod
    def _peer_closed(sock):
        """Return True if the peer has closed or reset *sock*.

        Switches *sock* to non-blocking mode and peeks (``MSG_PEEK``) one
        byte, so no data is consumed and the call never waits.
        """
        sock.setblocking(False)
        try:
            data = sock.recv(1, socket.MSG_PEEK)
        except BlockingIOError:
            return False  # nothing to read yet: the connection is still open
        except ConnectionError:
            return True   # reset/aborted by the peer or a middlebox
        return data == b''  # EOF: the peer sent FIN

    def _build_client(self, timeout=None, connection=None):
        """Create an unconnected ``WeaviateClient`` from the WEAVIATE_* env vars.

        Args:
            timeout: Optional ``weaviate.classes.init.Timeout``.
            connection: Optional ``weaviate.config.ConnectionConfig`` (pool settings).

        When neither is given, ``additional_config`` is left as None so the
        client library's defaults are used.
        """
        http_host, http_port = self._endpoint('HTTP')
        grpc_host, grpc_port = self._endpoint('GRPC')

        config_kwargs = {}
        if timeout is not None:
            config_kwargs['timeout'] = timeout
        if connection is not None:
            # AdditionalConfig's pool field is named `connection`; a
            # `connection_config=` keyword is not a field and does not apply.
            config_kwargs['connection'] = connection

        return weaviate.WeaviateClient(
            connection_params=weaviate.connect.ConnectionParams.from_params(
                http_host=http_host,
                http_port=http_port,
                http_secure=True,
                grpc_host=grpc_host,
                grpc_port=grpc_port,
                grpc_secure=True,
            ),
            additional_config=(
                weaviate.classes.init.AdditionalConfig(**config_kwargs)
                if config_kwargs else None
            ),
            skip_init_checks=False,
            auth_client_secret=weaviate.auth.AuthApiKey(api_key=os.getenv('WEAVIATE_API_KEY')),
        )

    # ==================== SECTION 1: ENVIRONMENT ====================
    def check_env_vars(self):
        """Verify that every required WEAVIATE_* variable is set and ports are valid.

        Returns:
            True when the configuration is usable, False otherwise (the
            problem is recorded via ``log_issue`` and ``results['env_check']``).
        """
        self._banner("SECTION 1: ENVIRONMENT VARIABLES")

        required = {
            'WEAVIATE_HTTP_HOST': os.getenv('WEAVIATE_HTTP_HOST'),
            'WEAVIATE_HTTP_PORT': os.getenv('WEAVIATE_HTTP_PORT', '443'),
            'WEAVIATE_GRPC_HOST': os.getenv('WEAVIATE_GRPC_HOST'),
            'WEAVIATE_GRPC_PORT': os.getenv('WEAVIATE_GRPC_PORT', '443'),
            'WEAVIATE_API_KEY': os.getenv('WEAVIATE_API_KEY'),
        }

        missing = [k for k, v in required.items() if not v]
        if missing:
            self.log_issue(f"Missing environment variables: {missing}")
            self.results['env_check'] = 'FAIL'
            return False

        # Later sections call int() on the ports; validate them once here so a
        # typo is reported as an issue instead of crashing the run mid-way.
        bad_ports = {
            k: required[k]
            for k in ('WEAVIATE_HTTP_PORT', 'WEAVIATE_GRPC_PORT')
            if not self._is_valid_port(required[k])
        }
        if bad_ports:
            self.log_issue(f"Invalid port values (expected 1-65535): {bad_ports}")
            self.results['env_check'] = 'FAIL'
            return False

        logger.info("✓ All environment variables present")
        for k, v in required.items():
            logger.info(f"  {k}: ***hidden***" if 'KEY' in k else f"  {k}: {v}")

        self.results['env_check'] = 'PASS'
        return True

    # ==================== SECTION 2: DNS RESOLUTION ====================
    def check_dns_resolution(self):
        """Resolve the HTTP and gRPC hosts (IPv4, via ``gethostbyname``).

        Both hosts are always checked, so a failure on one does not hide the
        status of the other.

        Returns:
            True if every host resolved, False otherwise.
        """
        self._banner("SECTION 2: DNS RESOLUTION")

        all_ok = True
        for name in ('HTTP', 'GRPC'):
            host = os.getenv(f'WEAVIATE_{name}_HOST')
            key = f'dns_{name.lower()}'
            try:
                ip = socket.gethostbyname(host)
            except (socket.gaierror, UnicodeError) as e:
                all_ok = False
                self.log_issue(f"DNS resolution failed for {name} host '{host}': {e}")
                self.log_recommendation(f"Check DNS configuration for {host}")
                self.results[key] = 'FAIL'
            else:
                logger.info(f"✓ {name} host '{host}' resolves to {ip}")
                self.results[key] = 'PASS'

        return all_ok

    # ==================== SECTION 3: TCP CONNECTIVITY ====================
    def check_tcp_connectivity(self, attempts=5, timeout=5):
        """Open and close raw TCP connections to each endpoint several times.

        ``socket.create_connection`` is used so IPv4 and IPv6 addresses both
        work and the socket is closed even when the attempt fails. Latency is
        measured with ``time.perf_counter`` and includes name resolution.

        Args:
            attempts: Connection attempts per endpoint.
            timeout: Per-attempt connect timeout in seconds.

        Returns:
            True if every attempt on every endpoint succeeded, False otherwise.
        """
        self._banner("SECTION 3: TCP CONNECTIVITY")

        all_stable = True
        for name in ('HTTP', 'GRPC'):
            host, port = self._endpoint(name)
            logger.info(f"\nTesting {name} endpoint {host}:{port}")

            failures = 0
            latencies = []
            for attempt in range(1, attempts + 1):
                start = time.perf_counter()
                try:
                    with socket.create_connection((host, port), timeout=timeout):
                        latency = (time.perf_counter() - start) * 1000
                except (OSError, UnicodeError) as e:
                    failures += 1
                    logger.warning(f"  Attempt {attempt}: FAIL ({type(e).__name__}: {e})")
                else:
                    latencies.append(latency)
                    logger.info(f"  Attempt {attempt}: SUCCESS ({latency:.2f}ms)")

                if attempt < attempts:
                    time.sleep(0.5)  # space attempts out instead of bursting

            key = f'tcp_{name.lower()}'
            if failures:
                all_stable = False
                self.log_issue(f"{name} endpoint has {failures}/{attempts} connection failures")
                self.log_recommendation(f"Check firewall/network rules for {host}:{port}")
                self.results[key] = f'UNSTABLE ({failures} failures)'
            elif latencies:
                avg_latency = sum(latencies) / len(latencies)
                max_latency = max(latencies)
                logger.info(f"✓ {name} TCP: avg={avg_latency:.2f}ms, max={max_latency:.2f}ms")

                if max_latency > 1000:
                    self.log_issue(f"{name} has high latency (max {max_latency:.2f}ms)")
                    self.log_recommendation(f"Investigate network latency to {host}")

                self.results[key] = 'STABLE'

        return all_stable

    # ==================== SECTION 4: TLS/SSL ====================
    def check_tls_connection(self, timeout=10):
        """Perform a verified TLS handshake against the HTTP and gRPC endpoints.

        The client connects to both endpoints with TLS (``http_secure`` and
        ``grpc_secure`` are True), so both are checked. The default trust
        store and hostname verification (``ssl.create_default_context``) apply.

        Args:
            timeout: Seconds allowed for the TCP connect.

        Returns:
            True if every handshake succeeded, False otherwise.
        """
        self._banner("SECTION 4: TLS/SSL VERIFICATION")

        context = ssl.create_default_context()
        all_ok = True
        for name in ('HTTP', 'GRPC'):
            host, port = self._endpoint(name)
            key = f'tls_{name.lower()}'
            try:
                with socket.create_connection((host, port), timeout=timeout) as sock:
                    with context.wrap_socket(sock, server_hostname=host) as ssock:
                        cert = ssock.getpeercert()
                        protocol = ssock.version()
            except ssl.SSLError as e:
                all_ok = False
                self.log_issue(f"{name} TLS/SSL error: {e}")
                self.log_recommendation("Check certificate validity and TLS configuration")
                self.results[key] = 'FAIL'
                continue
            except Exception as e:
                all_ok = False
                self.log_issue(f"{name} TLS connection failed: {e}")
                self.results[key] = 'FAIL'
                continue

            logger.info(f"✓ {name} TLS connection successful")
            logger.info(f"  Protocol: {protocol}")
            if cert:
                logger.info(f"  Subject: {dict(rdn[0] for rdn in cert.get('subject', ()))}")
                if 'notAfter' in cert:
                    logger.info(f"  Expires: {cert['notAfter']}")
            self.results[key] = 'PASS'

        return all_ok

    # ==================== SECTION 5: WEAVIATE CONNECTION ====================
    def check_weaviate_connection(self):
        """Connect a real Weaviate client and fetch cluster metadata.

        Returns:
            The connected client on success (the caller must close it), or
            None on failure, in which case any partially created client has
            already been closed.
        """
        self._banner("SECTION 5: WEAVIATE CLIENT CONNECTION")

        client = None
        try:
            client = self._build_client(
                timeout=weaviate.classes.init.Timeout(**self._LONG_TIMEOUT_SECONDS)
            )

            logger.info("Attempting connection...")
            client.connect()

            if not client.is_connected():
                self.log_issue("Client reports not connected after connect()")
                self.results['weaviate_connect'] = 'FAIL'
                self._close_quietly(client)
                return None

            logger.info("✓ Client connected successfully")

            meta = client.get_meta()
            logger.info(f"✓ Weaviate version: {meta.get('version', 'unknown')}")

            self.results['weaviate_connect'] = 'PASS'
            return client

        except Exception as e:
            self.log_issue(f"Weaviate connection failed: {type(e).__name__}: {e}")

            if self._is_socket_closed_error(e):
                self.log_issue("SOCKET CLOSED ERROR - gRPC connection terminated unexpectedly")
                self.log_recommendation("This indicates connection pool or timeout issues")

            lowered = str(e).lower()
            if "authentication" in lowered or "unauthorized" in lowered:
                self.log_recommendation("Verify WEAVIATE_API_KEY is correct")

            self.results['weaviate_connect'] = f'FAIL: {type(e).__name__}'
            self._close_quietly(client)
            return None

    # ==================== SECTION 6: CONNECTION POOL CONFIG ====================
    def test_connection_pool_configs(self, client):
        """Batch-insert objects under several connection-pool configurations.

        For each ``(session_pool_maxsize, concurrent_requests)`` pair a fresh
        client is built with that pool size, a temporary collection is
        created, 50 objects are inserted with a fixed-size batch, and the
        collection is deleted again. Cleanup (collection delete + client
        close) runs even when the test raises.

        Args:
            client: Result of ``check_weaviate_connection``; the section is
                skipped when it is falsy. It is not otherwise used.
        """
        self._banner("SECTION 6: CONNECTION POOL CONFIGURATION TESTS")

        if not client:
            logger.warning("Skipping - no client connection")
            return

        test_configs = [
            # (session_pool_maxsize, concurrent_requests, description)
            (1, 2, "CONFLICT: maxsize < concurrent (reported issue)"),
            (1, 1, "SAFE: maxsize = concurrent"),
            (2, 2, "BALANCED: maxsize = concurrent (2)"),
            (5, 2, "SAFE: maxsize > concurrent"),
        ]
        num_objects = 50
        test_collection = f'DiagTest_{int(time.time())}'

        for maxsize, concurrent, desc in test_configs:
            logger.info(f"\n--- Testing: {desc} ---")
            logger.info(f"    session_pool_maxsize={maxsize}, concurrent_requests={concurrent}")

            key = f'pool_config_{maxsize}_{concurrent}'
            col_name = f'{test_collection}_{maxsize}_{concurrent}'
            test_client = None
            created = False
            try:
                test_client = self._build_client(
                    timeout=weaviate.classes.init.Timeout(**self._LONG_TIMEOUT_SECONDS),
                    connection=weaviate.config.ConnectionConfig(
                        session_pool_connections=1,
                        session_pool_maxsize=maxsize,
                        session_pool_max_retries=5,
                    ),
                )
                test_client.connect()

                if test_client.collections.exists(col_name):
                    test_client.collections.delete(col_name)

                test_client.collections.create(
                    col_name,
                    vectorizer_config=weaviate.classes.config.Configure.Vectorizer.none()
                )
                created = True
                collection = test_client.collections.get(col_name)

                errors = []
                with collection.batch.fixed_size(
                    batch_size=10,
                    concurrent_requests=concurrent
                ) as batch:
                    for i in range(num_objects):
                        try:
                            batch.add_object(properties={'prop': f'value_{i}'})
                        except Exception as e:
                            errors.append(str(e))

                failures = collection.batch.failed_objects

                if failures or errors:
                    self.log_issue(f"Config {maxsize}/{concurrent}: {len(failures)} batch failures, {len(errors)} errors")
                    logger.error(f"  Sample errors: {errors[:2] if errors else failures[:2]}")
                    self.results[key] = 'FAIL'
                else:
                    logger.info(f"✓ Config {maxsize}/{concurrent}: All {num_objects} objects inserted successfully")
                    self.results[key] = 'PASS'

            except Exception as e:
                self.log_issue(f"Config {maxsize}/{concurrent} raised exception: {type(e).__name__}: {e}")

                if self._is_socket_closed_error(e):
                    self.log_recommendation(f"Socket error with maxsize={maxsize}, concurrent={concurrent}")

                self.results[key] = f'ERROR: {type(e).__name__}'

            finally:
                # Never leave test collections or open clients behind.
                if test_client is not None and created:
                    try:
                        test_client.collections.delete(col_name)
                    except Exception as e:
                        logger.warning(f"Could not delete test collection {col_name}: {type(e).__name__}: {e}")
                self._close_quietly(test_client)

    # ==================== SECTION 7: TIMEOUT TESTS ====================
    def test_timeout_scenarios(self, client):
        """Connect fresh clients with short, medium and long timeout settings.

        Every outcome is recorded in ``results['timeout_<seconds>']``. Errors
        that do not mention a timeout are reported as issues as well, and each
        test client is closed whatever the outcome.

        Args:
            client: Result of ``check_weaviate_connection``; the section is
                skipped when it is falsy. It is not otherwise used.
        """
        self._banner("SECTION 7: TIMEOUT CONFIGURATION TESTS")

        if not client:
            logger.warning("Skipping - no client connection")
            return

        timeout_configs = [
            (30, "SHORT (30s) - may cause issues with large batches"),
            (120, "MEDIUM (120s) - current init timeout"),
            (3600, "LONG (3600s) - current query/insert timeout"),
        ]

        for timeout_val, desc in timeout_configs:
            logger.info(f"\nTesting: {desc}")

            key = f'timeout_{timeout_val}'
            test_client = None
            try:
                test_client = self._build_client(
                    timeout=weaviate.classes.init.Timeout(
                        init=timeout_val,
                        query=timeout_val,
                        insert=timeout_val,
                    )
                )

                start = time.perf_counter()
                test_client.connect()
                duration = time.perf_counter() - start

                if test_client.is_connected():
                    logger.info(f"✓ Connected in {duration:.2f}s")
                    self.results[key] = 'PASS'
                else:
                    self.log_issue(f"Connection failed with {timeout_val}s timeout")
                    self.results[key] = 'FAIL'

            except Exception as e:
                if "timeout" in str(e).lower():
                    self.log_issue(f"Timeout occurred with {timeout_val}s limit: {e}")
                    self.log_recommendation(f"Increase timeout values (currently {timeout_val}s)")
                else:
                    self.log_issue(f"Connection with {timeout_val}s timeout raised {type(e).__name__}: {e}")
                self.results[key] = f'ERROR: {type(e).__name__}'

            finally:
                self._close_quietly(test_client)

    # ==================== SECTION 8: STRESS TEST ====================
    def stress_test_connections(self, client, num_connections=10):
        """Open several client connections concurrently, one thread per client.

        Shared counters are updated under a lock, because ``+=`` on a dict
        entry is not atomic across threads. Each thread closes its own client
        whatever the outcome.

        Args:
            client: Result of ``check_weaviate_connection``; the section is
                skipped when it is falsy. It is not otherwise used.
            num_connections: Number of simultaneous connections to attempt.
        """
        self._banner("SECTION 8: CONNECTION STRESS TEST")

        if not client:
            logger.warning("Skipping - no client connection")
            return

        logger.info(f"Opening {num_connections} connections simultaneously...")

        results = {'success': 0, 'fail': 0, 'errors': []}
        lock = threading.Lock()

        def connect_test(idx):
            test_client = None
            error = None
            try:
                test_client = self._build_client()
                test_client.connect()
                ok = test_client.is_connected()
            except Exception as e:
                ok = False
                error = f"Thread {idx}: {type(e).__name__}"
            finally:
                self._close_quietly(test_client)

            with lock:
                results['success' if ok else 'fail'] += 1
                if error is not None:
                    results['errors'].append(error)

        threads = [threading.Thread(target=connect_test, args=(i,)) for i in range(num_connections)]
        for t in threads:
            t.start()
        for t in threads:
            t.join()

        logger.info(f"Results: {results['success']} success, {results['fail']} failures")

        if results['fail'] > 0:
            self.log_issue(f"Stress test: {results['fail']}/{num_connections} connections failed")
            logger.error(f"  Errors: {set(results['errors'])}")
            self.log_recommendation("Cluster may have connection limits or resource constraints")
            self.results['stress_test'] = 'FAIL'
        else:
            logger.info("✓ All concurrent connections successful")
            self.results['stress_test'] = 'PASS'

    # ==================== SECTION 9: LOAD BALANCER / PROXY ====================
    def check_load_balancer_behavior(self, hold_seconds=30):
        """Detect a proxy/load balancer that drops idle TCP connections.

        Opens a raw TCP connection to the gRPC endpoint, leaves it idle for
        *hold_seconds*, then peeks at it without blocking: EOF or a reset
        means the far side closed it while idle. A successful ``send()`` is
        not used as the signal, because ``send()`` succeeds once the kernel
        buffers the data, even on a connection the peer has already closed.

        NOTE(review): the socket never starts a TLS handshake, so a server-side
        handshake timeout could also close it. A drop here is a strong hint of
        an idle-timeout problem, not proof of one.

        Args:
            hold_seconds: How long to keep the connection idle before probing.
        """
        self._banner("SECTION 9: LOAD BALANCER / PROXY DETECTION")

        host, port = self._endpoint('GRPC')
        logger.info("Testing connection persistence...")

        try:
            with socket.create_connection((host, port), timeout=10) as sock:
                logger.info(f"Connection established, holding for {hold_seconds} seconds...")
                time.sleep(hold_seconds)
                dropped = self._peer_closed(sock)
        except socket.timeout:
            # Only the initial connect can time out here, so this is a
            # reachability problem rather than an idle-timeout one.
            self.log_issue(f"Timed out connecting to {host}:{port} for the persistence test")
            self.log_recommendation(f"Check firewall/network rules for {host}:{port}")
            self.results['lb_persistence'] = 'TIMEOUT'
            return
        except Exception as e:
            self.log_issue(f"Connection dropped: {e}")
            self.log_recommendation("Load balancer or proxy may be dropping idle connections")
            self.results['lb_persistence'] = 'FAIL'
            return

        if dropped:
            self.log_issue(f"Connection closed by the server or an intermediary after {hold_seconds}s idle")
            self.log_recommendation("Check load balancer idle timeout settings (should be > 300s)")
            self.log_recommendation("Load balancer or proxy may be dropping idle connections")
            self.results['lb_persistence'] = 'FAIL'
        else:
            logger.info(f"✓ Connection still alive after {hold_seconds}s")
            self.results['lb_persistence'] = 'PASS'

    # ==================== MAIN RUNNER ====================
    def run_diagnostics(self):
        """Execute the complete diagnostic suite and log a summary.

        Stops after section 1 if the environment is unusable. The shared
        client from section 5 is closed even if a later section raises.

        Returns:
            The ``results`` mapping (check name -> status string).
        """
        self._banner("  WEAVIATE CONNECTION COMPREHENSIVE DIAGNOSTICS", width=70)
        logger.info(f"Timestamp: {datetime.now().isoformat()}")
        logger.info(f"Python version: {sys.version.split()[0]}")
        logger.info(f"Weaviate client version: {weaviate.__version__}")

        if not self.check_env_vars():
            logger.error("\n❌ Cannot proceed without a valid environment configuration")
            return self.results

        self.check_dns_resolution()
        self.check_tcp_connectivity()
        self.check_tls_connection()

        client = self.check_weaviate_connection()
        try:
            self.test_connection_pool_configs(client)
            self.test_timeout_scenarios(client)
            self.stress_test_connections(client)
            self.check_load_balancer_behavior()
        finally:
            self._close_quietly(client)

        self._log_summary()
        return self.results

    def _log_summary(self):
        """Log per-check results, collected issues, recommendations and next steps."""
        self._banner("DIAGNOSTIC RESULTS SUMMARY", width=70)
        for key, value in sorted(self.results.items()):
            status = "✓" if value in ("PASS", "STABLE") else "❌"
            logger.info(f"{status} {key}: {value}")

        self._banner("ISSUES FOUND", width=70)
        if not self.issues_found:
            logger.info("✓ No critical issues detected")
        else:
            for i, issue in enumerate(self.issues_found, 1):
                logger.error(f"{i}. {issue}")

        self._banner("RECOMMENDATIONS", width=70)
        if not self.recommendations:
            logger.info("No specific recommendations")
        else:
            for i, rec in enumerate(self.recommendations, 1):
                logger.info(f"{i}. {rec}")

        self._banner("NEXT STEPS", width=70)
        for step in (
            "1. Review failed tests above",
            "2. Apply recommendations in order of priority",
            "3. Check Weaviate server logs for corresponding errors",
            "4. Monitor cluster resources (CPU/Memory) during batch operations",
            "5. If using Istio/Load Balancer, verify timeout configurations",
        ):
            logger.info(step)
        logger.info("=" * 70 + "\n")

if __name__ == "__main__":
    # Both `diag` and `results` stay bound at module level, so the diagnostics
    # object and its results mapping remain inspectable after the run.
    results = (diag := WeaviateDiagnostics()).run_diagnostics()