In [None]:
# Global Model Client - Connects to FedProx Server for Evaluation (Fixed)
import flwr as fl
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
from torch.utils.data import TensorDataset, DataLoader
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import roc_auc_score, accuracy_score, confusion_matrix
import os
import logging
import sys

# Route INFO-and-above log records to stdout with a timestamped format.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
    handlers=[logging.StreamHandler(sys.stdout)],
)
logger = logging.getLogger("GlobalModelClient")

# Global Model Configuration.
# Mutable at runtime: start_global_model_client() overwrites "server_address",
# "test_data_path" and "device".
GLOBAL_MODEL_CONFIG = dict(
    dropout_rate=0.3,                  # must match the training clients' model
    test_data_path="test_data.csv",    # default evaluation CSV
    batch_size=32,
    server_address="localhost:8081",   # evaluation server
    device="cpu",
)

# Heart Disease Model (same as training clients)
class HeartDiseaseModel(nn.Module):
    """Binary-classifier MLP: input_size -> 64 -> 32 -> 1 with a sigmoid output.

    The layer stack must stay identical to the training clients' model so the
    server's aggregated weights load by ``state_dict`` key (the Linear layers
    sit at ``layers.0``, ``layers.3`` and ``layers.6``). The dropout
    probability is read from ``GLOBAL_MODEL_CONFIG["dropout_rate"]``.
    """

    def __init__(self, input_size):
        super().__init__()
        drop_p = GLOBAL_MODEL_CONFIG["dropout_rate"]
        self.layers = nn.Sequential(
            # hidden block 1 (layers.0-2)
            nn.Linear(input_size, 64), nn.ReLU(), nn.Dropout(drop_p),
            # hidden block 2 (layers.3-5)
            nn.Linear(64, 32), nn.ReLU(), nn.Dropout(drop_p),
            # output head (layers.6-7): one probability per input row
            nn.Linear(32, 1), nn.Sigmoid(),
        )

    def forward(self, x):
        """Return sigmoid probabilities (last dimension of size 1)."""
        probabilities = self.layers(x)
        return probabilities

def create_synthetic_test_data(num_samples=1000, input_size=15, seed=42):
    """Create a reproducible synthetic test set shaped like the real data.

    Features are i.i.d. standard-normal draws. Labels are deterministic: 1
    where a fixed random linear combination of the features has
    sigmoid > 0.5, else 0. No label noise is added.

    A private ``RandomState`` is used so the global NumPy RNG is not reseeded
    as a side effect. With the default seed the generated values are
    identical to the former ``np.random.seed(42)`` implementation.

    Args:
        num_samples: number of rows to generate.
        input_size: number of feature columns.
        seed: seed for the private random generator.

    Returns:
        DataFrame with columns ``feature_0`` .. ``feature_{input_size-1}`` and
        an integer ``TenYearCHD`` target column.
    """
    print(f"🔧 Creating synthetic test data with {num_samples} samples and {input_size} features")

    rng = np.random.RandomState(seed)  # local generator: leaves np.random's global state alone
    X = rng.randn(num_samples, input_size)

    # Labels follow a linear decision boundary defined by small random weights.
    weights = rng.randn(input_size) * 0.1
    linear_combination = X @ weights
    probabilities = 1 / (1 + np.exp(-linear_combination))  # Sigmoid
    y = (probabilities > 0.5).astype(int)

    feature_names = [f"feature_{i}" for i in range(input_size)]
    df = pd.DataFrame(X, columns=feature_names)
    df["TenYearCHD"] = y

    print(f"📊 Synthetic test data class distribution:")
    print(f"   Negative cases: {(y == 0).sum()} ({(y == 0).mean()*100:.1f}%)")
    print(f"   Positive cases: {(y == 1).sum()} ({(y == 1).mean()*100:.1f}%)")

    return df

def _make_test_dataloader(X, y):
    """Standardize ``X`` and wrap ``(X, y)`` in a non-shuffling DataLoader.

    NOTE(review): the StandardScaler is fit on the test data itself, not on
    training statistics; confirm this matches how the training clients
    scale their inputs.

    Returns:
        ``(dataloader, num_features, num_samples)``.
    """
    X_scaled = StandardScaler().fit_transform(X)
    X_tensor = torch.tensor(X_scaled, dtype=torch.float32)
    # Targets are shaped (N, 1) to line up with the model's single output unit.
    y_tensor = torch.tensor(np.asarray(y), dtype=torch.float32).view(-1, 1)
    dataset = TensorDataset(X_tensor, y_tensor)
    dataloader = DataLoader(dataset, batch_size=GLOBAL_MODEL_CONFIG["batch_size"], shuffle=False)
    return dataloader, X_scaled.shape[1], len(dataset)


def load_test_data(data_path=None):
    """Load, clean and standardize the evaluation set for the global model.

    If ``data_path`` does not exist, synthetic data is generated instead. If
    the file exists but cannot be processed (e.g. non-numeric columns, or no
    rows left after dropping missing values), synthetic data is used as a
    last resort.

    Args:
        data_path: CSV to load; defaults to ``GLOBAL_MODEL_CONFIG["test_data_path"]``.

    Returns:
        ``(dataloader, num_features, num_samples)``, or ``(None, None, None)``
        if even the synthetic fallback fails.
    """
    if data_path is None:
        data_path = GLOBAL_MODEL_CONFIG["test_data_path"]

    try:
        if os.path.exists(data_path):
            df = pd.read_csv(data_path)
            print(f"✓ Loaded test data from {data_path} with shape {df.shape}")
        else:
            print(f"⚠ Test data file {data_path} not found")
            print("🔧 Creating synthetic test data for demonstration...")
            df = create_synthetic_test_data(num_samples=1000, input_size=15)

        missing_values = int(df.isnull().sum().sum())
        if missing_values > 0:
            print(f"⚠ Found {missing_values} missing values in test data, dropping rows")
            df = df.dropna()
        if df.empty:
            raise ValueError("Test data is empty after dropping missing values")

        if "TenYearCHD" not in df.columns:
            print("⚠ Target column 'TenYearCHD' not found, creating synthetic target")
            if df.shape[1] == 0:
                raise ValueError("No features found to create synthetic target!")
            # Label rows whose feature sum is above the median as positive.
            row_sums = df.sum(axis=1)
            df["TenYearCHD"] = (row_sums > row_sums.median()).astype(int)

        X = df.drop(columns=["TenYearCHD"])
        y = df["TenYearCHD"]

        print(f"📊 Test data class distribution:")
        print(f"   Negative cases: {(y == 0).sum()} ({(y == 0).mean()*100:.1f}%)")
        print(f"   Positive cases: {(y == 1).sum()} ({(y == 1).mean()*100:.1f}%)")

        dataloader, num_features, num_samples = _make_test_dataloader(X, y)
        print(f"✓ Created test dataloader with {num_samples} samples and {num_features} features")
        return dataloader, num_features, num_samples

    except Exception as e:
        logger.exception("Error loading test data: %s", e)
        print("🔧 Falling back to synthetic data...")
        try:
            df = create_synthetic_test_data()
            return _make_test_dataloader(df.drop(columns=["TenYearCHD"]), df["TenYearCHD"])
        except Exception as e2:
            logger.error(f"Failed to create synthetic data: {str(e2)}")
            return None, None, None

class GlobalModelClient(fl.client.NumPyClient):
    """Evaluation-only Flower client that scores aggregated models on test data.

    The server sends aggregated weights. This client loads them into its local
    model, evaluates on the held-out dataloader and reports the loss plus
    binary-classification metrics. It never trains: ``fit`` only loads the
    received weights and echoes them back with ``num_examples=0``.
    """

    # Probability cut-off that turns sigmoid outputs into class predictions.
    DECISION_THRESHOLD = 0.5

    def __init__(self, model, test_dataloader, device, num_test_samples):
        """Store the model and test data and print a short summary.

        Args:
            model: torch module whose ``state_dict`` layout matches the weights
                the server sends.
            test_dataloader: iterable of ``(inputs, targets)`` batches;
                presumably the loader built by ``load_test_data`` (0/1 float
                targets shaped like the model output) -- confirm for other callers.
            device: torch device or device string used for evaluation.
            num_test_samples: example count reported to the server by ``evaluate``.
        """
        self.model = model
        self.test_dataloader = test_dataloader
        self.device = device
        self.num_test_samples = num_test_samples
        self.round_count = 0  # last server round seen by evaluate()

        print(f"🧠 Global Model Client initialized")
        print(f"   Device: {device}")
        print(f"   Test samples: {num_test_samples}")
        print(f"   Model parameters: {sum(p.numel() for p in model.parameters())}")

    def get_parameters(self, config):
        """Return the model weights as NumPy arrays in ``state_dict`` order.

        Uses the same ordering as ``set_parameters``, so a get -> set
        round-trip is always consistent.
        """
        return [tensor.detach().cpu().numpy() for tensor in self.model.state_dict().values()]

    def set_parameters(self, parameters):
        """Load aggregated weight arrays from the server into the model.

        Arrays are matched to ``state_dict`` keys by position.

        Raises:
            ValueError: if the number of arrays differs from the number of
                ``state_dict`` entries (``zip`` would otherwise silently drop
                any surplus arrays).
            Exception: anything raised by ``load_state_dict`` (e.g. shape
                mismatch) is re-raised after the diagnostics are printed.
        """
        keys = list(self.model.state_dict().keys())
        try:
            if len(parameters) != len(keys):
                raise ValueError(f"expected {len(keys)} parameter arrays, got {len(parameters)}")
            state_dict = {
                key: torch.tensor(value, device=self.device)
                for key, value in zip(keys, parameters)
            }
            self.model.load_state_dict(state_dict, strict=True)
        except Exception as e:
            print(f"✗ Error updating Global Model parameters: {str(e)}")
            print(f"   Parameter shapes received: {[np.shape(p) for p in parameters]}")
            print(f"   Model expects: {[tuple(t.shape) for t in self.model.state_dict().values()]}")
            # Re-raise: evaluating stale/random weights and reporting the result
            # as the aggregated model's metrics would be silently wrong.
            raise
        print(f"✓ Global Model updated with aggregated parameters")

    def fit(self, parameters, config):
        """Load the server's weights and return them unchanged (no training).

        Returns:
            ``(parameters, 0, {})``: zero examples, so this client adds no
            weight to a sample-weighted aggregation.
        """
        self.set_parameters(parameters)

        server_round = config.get("server_round", 0)
        print(f"🧠 Global Model received updated parameters for Round {server_round}")
        print(f"   (Global Model does not participate in training)")

        return self.get_parameters(config), 0, {}

    def evaluate(self, parameters, config):
        """Load the aggregated weights and evaluate them on the test set.

        Args:
            parameters: aggregated weight arrays from the server.
            config: server config dict; ``"server_round"`` is read if present
                (defaults to 0).

        Returns:
            ``(loss, num_examples, metrics)``: the mean BCE loss,
            ``self.num_test_samples``, and a dict with accuracy, AUC,
            precision, recall, F1, specificity and confusion-matrix counts.
        """
        server_round = config.get("server_round", 0)
        self.round_count = server_round

        print(f"\n🔍 Global Model Evaluation - Round {server_round}")
        print("=" * 50)

        # Raises if the weights cannot be loaded, so no metrics are reported
        # for a model that was not actually updated.
        self.set_parameters(parameters)

        print("📊 Evaluating on test data...")
        labels, probabilities, avg_loss = self._collect_predictions()
        stats = self._classification_metrics(labels, probabilities, self.DECISION_THRESHOLD)
        self._print_report(server_round, avg_loss, stats)
        print("=" * 50)

        metrics = {
            "accuracy": float(stats["accuracy"]),
            "auc": float(stats["auc"]),
            "precision": float(stats["precision"]),
            "recall": float(stats["recall"]),
            "f1_score": float(stats["f1_score"]),
            "specificity": float(stats["specificity"]),
            "true_positives": int(stats["tp"]),
            "true_negatives": int(stats["tn"]),
            "false_positives": int(stats["fp"]),
            "false_negatives": int(stats["fn"]),
        }
        return float(avg_loss), self.num_test_samples, metrics

    def _collect_predictions(self):
        """Run the model over the whole test loader without gradients.

        Returns:
            ``(labels, probabilities, avg_loss)``: flat NumPy arrays of targets
            and model outputs, and the per-sample mean BCE loss (0.0 when the
            loader yields nothing).
        """
        device = torch.device(self.device)
        self.model = self.model.to(device)
        self.model.eval()  # disable dropout so scoring is deterministic

        criterion = nn.BCELoss()
        weighted_loss = 0.0
        label_batches, prob_batches = [], []
        with torch.no_grad():
            for inputs, targets in self.test_dataloader:
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = self.model(inputs)
                # BCELoss averages within the batch; re-weight by batch size so the
                # final value is a true per-sample mean even if the last batch is short.
                weighted_loss += criterion(outputs, targets).item() * inputs.size(0)
                prob_batches.append(outputs.cpu().numpy().ravel())
                label_batches.append(targets.cpu().numpy().ravel())

        labels = np.concatenate(label_batches) if label_batches else np.empty(0)
        probabilities = np.concatenate(prob_batches) if prob_batches else np.empty(0)
        avg_loss = weighted_loss / labels.size if labels.size > 0 else 0.0
        return labels, probabilities, avg_loss

    @staticmethod
    def _classification_metrics(labels, probabilities, threshold=0.5):
        """Compute accuracy, AUC and confusion-matrix-derived metrics.

        AUC is reported as 0.0 when sklearn cannot compute it (e.g. only one
        class present). Confusion counts fall back to zeros if sklearn fails.
        """
        predictions = (probabilities > threshold).astype(np.float32)
        total = int(labels.size)
        correct = int((predictions == labels).sum())
        accuracy = correct / total if total > 0 else 0.0

        try:
            auc_score = roc_auc_score(labels, probabilities)
        except Exception as e:
            print(f"⚠ Could not calculate AUC: {str(e)}")
            auc_score = 0.0

        try:
            # Fixing labels=[0, 1] always yields a 2x2 matrix, even when the
            # labels or the predictions contain only a single class.
            cm = confusion_matrix(labels.astype(int), predictions.astype(int), labels=[0, 1])
            tn, fp, fn, tp = (int(v) for v in cm.ravel())
        except Exception as e:
            print(f"⚠ Could not calculate confusion matrix: {str(e)}")
            tn = fp = fn = tp = 0

        precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
        recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
        f1_score = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0.0
        specificity = tn / (tn + fp) if (tn + fp) > 0 else 0.0

        return {
            "accuracy": accuracy, "correct": correct, "total": total, "auc": auc_score,
            "precision": precision, "recall": recall, "f1_score": f1_score,
            "specificity": specificity, "tn": tn, "fp": fp, "fn": fn, "tp": tp,
        }

    @staticmethod
    def _print_report(server_round, avg_loss, m):
        """Print the metric summary, confusion matrix and (after round 1) an interpretation."""
        print(f"\n📈 Round {server_round} - Global Model Evaluation Results:")
        print(f"   🔴 Loss: {avg_loss:.4f}")
        print(f"   🟢 Accuracy: {m['accuracy']:.4f} ({m['correct']}/{m['total']})")
        print(f"   🔵 AUC: {m['auc']:.4f}")
        print(f"   📊 Precision: {m['precision']:.4f}")
        print(f"   📊 Recall (Sensitivity): {m['recall']:.4f}")
        print(f"   📊 Specificity: {m['specificity']:.4f}")
        print(f"   📊 F1-Score: {m['f1_score']:.4f}")

        print(f"\n📋 Confusion Matrix:")
        print(f"   True Negatives:  {m['tn']}")
        print(f"   False Positives: {m['fp']}")
        print(f"   False Negatives: {m['fn']}")
        print(f"   True Positives:  {m['tp']}")

        if server_round > 1:
            print(f"\n💡 Performance Interpretation:")
            if m["accuracy"] > 0.85:
                print(f"   🎯 Excellent performance!")
            elif m["accuracy"] > 0.75:
                print(f"   ✅ Good performance")
            elif m["accuracy"] > 0.65:
                print(f"   ⚠ Moderate performance")
            else:
                print(f"   🔧 Needs improvement")

            if m["auc"] > 0.8:
                print(f"   🏆 Strong discriminative ability (AUC > 0.8)")
            elif m["auc"] > 0.7:
                print(f"   ✅ Good discriminative ability")
            else:
                print(f"   ⚠ Limited discriminative ability")

def start_global_model_client(server_address=None, test_data_path=None):
    """Load test data, build the model and run the evaluation-only Flower client.

    Side effect: writes ``server_address``/``test_data_path`` (when given) and
    the selected ``device`` into ``GLOBAL_MODEL_CONFIG``.

    Args:
        server_address: ``host:port`` of the Flower server; None keeps the config value.
        test_data_path: CSV used for evaluation; None keeps the config value.

    Returns:
        True if the client ran to completion, False on data-loading failure,
        interruption or a client error.
    """
    print("🧠 Starting Global Model Client")
    print("=" * 40)

    if server_address:
        GLOBAL_MODEL_CONFIG["server_address"] = server_address
    if test_data_path:
        GLOBAL_MODEL_CONFIG["test_data_path"] = test_data_path

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    GLOBAL_MODEL_CONFIG["device"] = device.type
    print(f"🔧 Using device: {device}")

    print("📂 Loading test data...")
    test_dataloader, input_size, num_test_samples = load_test_data()

    if test_dataloader is None:
        print("❌ Failed to load test data. Cannot start Global Model.")
        return False

    print("🏗️ Initializing Global Model...")
    model = HeartDiseaseModel(input_size=input_size).to(device)

    client = GlobalModelClient(model, test_dataloader, device, num_test_samples)

    address = GLOBAL_MODEL_CONFIG["server_address"]
    print(f"\n🌐 Connecting to server at {address}")
    print("🎯 Global Model will:")
    print("   1. Connect to evaluation server")
    print("   2. Receive aggregated models after each round")
    print("   3. Evaluate on test data")
    print("   4. Report detailed metrics back to server")
    print("\n⏳ Connecting...")

    try:
        # Blocks until the server ends the session. fl.client.start_client
        # expects a fl.client.Client, so a NumPyClient must be wrapped with
        # .to_client(); older Flower releases without to_client() use
        # start_numpy_client instead.
        if hasattr(client, "to_client"):
            fl.client.start_client(server_address=address, client=client.to_client())
        else:
            fl.client.start_numpy_client(server_address=address, client=client)

        print("\n✅ Global Model client completed successfully!")
        return True

    except KeyboardInterrupt:
        print("\n⚠ Global Model client interrupted by user")
        return False

    except Exception as e:
        logger.exception("Global Model client failed")
        print(f"\n❌ Global Model client error: {str(e)}")
        return False

# Quick start function
def start_global_model():
    """Run the Global Model client with this cell's built-in defaults.

    Targets the evaluation server at localhost:8081 and evaluates on
    test_data.csv. Returns the success flag from start_global_model_client().
    """
    print("🧠 Quick Start - Global Model Client")
    defaults = {
        "server_address": "localhost:8081",  # evaluation server
        "test_data_path": "test_data.csv",
    }
    return start_global_model_client(**defaults)

def start_global_model_with_synthetic_data():
    """Run the Global Model client on generated data instead of a CSV.

    Points the loader at "nonexistent.csv" so load_test_data() generates
    synthetic data. NOTE(review): this only works while no file with that
    name exists in the working directory.
    """
    print("🧠 Quick Start - Global Model Client (Synthetic Data)")
    missing_csv = "nonexistent.csv"
    return start_global_model_client(test_data_path=missing_csv, server_address="localhost:8081")

# Usage banner printed when the cell runs (one print, identical output).
print("\n".join((
    "🧠 Global Model Client Ready!",
    "\n📝 Available Commands:",
    "   start_global_model()                    - Quick start with defaults",
    "   start_global_model_with_synthetic_data() - Quick start with synthetic test data",
    "   start_global_model_client()             - Start with custom parameters",
    "\n💡 Usage Instructions:",
    "   1. Make sure the evaluation server is running on port 8081",
    "   2. Run start_global_model() to connect",
    "   3. Global Model will evaluate the current model parameters",
    "   4. If no test data file exists, synthetic data will be created",
    "\n🚀 To start: run start_global_model()",
)))

TypeError: FedAvg.__init__() got an unexpected keyword argument 'on_client_connected'

In [7]:
# Quick-start the client. This requires the definitions cell above to have run
# to completion; if it did not, this call raises NameError.
start_global_model()

NameError: name 'start_global_model' is not defined

In [None]:
# Global Model Client - Connects to FedProx Server for Evaluation
import flwr as fl
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
from torch.utils.data import TensorDataset, DataLoader
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import roc_auc_score, accuracy_score, confusion_matrix
import os
import logging
import sys

# Route INFO-and-above log records to stdout with a timestamped format.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
    handlers=[logging.StreamHandler(sys.stdout)],
)
logger = logging.getLogger("GlobalModelClient")

# Global Model Configuration.
# Mutable at runtime: start_global_model_client() overwrites "server_address",
# "test_data_path" and "device".
GLOBAL_MODEL_CONFIG = dict(
    dropout_rate=0.3,                  # must match the training clients' model
    test_data_path="unseendata.csv",   # held-out data never seen in training
    batch_size=32,
    server_address="localhost:8080",   # same server the training clients use
    device="cpu",
)

# Heart Disease Model (same as your training clients)
class HeartDiseaseModel(nn.Module):
    """Binary-classifier MLP: input_size -> 64 -> 32 -> 1 with a sigmoid output.

    The layer stack must stay identical to the training clients' model so the
    server's aggregated weights load by ``state_dict`` key (the Linear layers
    sit at ``layers.0``, ``layers.3`` and ``layers.6``). The dropout
    probability is read from ``GLOBAL_MODEL_CONFIG["dropout_rate"]``.
    """

    def __init__(self, input_size):
        super().__init__()
        drop_p = GLOBAL_MODEL_CONFIG["dropout_rate"]
        self.layers = nn.Sequential(
            # hidden block 1 (layers.0-2)
            nn.Linear(input_size, 64), nn.ReLU(), nn.Dropout(drop_p),
            # hidden block 2 (layers.3-5)
            nn.Linear(64, 32), nn.ReLU(), nn.Dropout(drop_p),
            # output head (layers.6-7): one probability per input row
            nn.Linear(32, 1), nn.Sigmoid(),
        )

    def forward(self, x):
        """Return sigmoid probabilities (last dimension of size 1)."""
        probabilities = self.layers(x)
        return probabilities

def create_synthetic_test_data(num_samples=1000, input_size=15, seed=42):
    """Create a reproducible synthetic test set shaped like the real data.

    Features are i.i.d. standard-normal draws. Labels are deterministic: 1
    where a fixed random linear combination of the features has
    sigmoid > 0.5, else 0. No label noise is added.

    A private ``RandomState`` is used so the global NumPy RNG is not reseeded
    as a side effect. With the default seed the generated values are
    identical to the former ``np.random.seed(42)`` implementation.

    Args:
        num_samples: number of rows to generate.
        input_size: number of feature columns.
        seed: seed for the private random generator.

    Returns:
        DataFrame with columns ``feature_0`` .. ``feature_{input_size-1}`` and
        an integer ``TenYearCHD`` target column.
    """
    print(f"🔧 Creating synthetic test data with {num_samples} samples and {input_size} features")

    rng = np.random.RandomState(seed)  # local generator: leaves np.random's global state alone
    X = rng.randn(num_samples, input_size)

    # Labels follow a linear decision boundary defined by small random weights.
    weights = rng.randn(input_size) * 0.1
    linear_combination = X @ weights
    probabilities = 1 / (1 + np.exp(-linear_combination))  # Sigmoid
    y = (probabilities > 0.5).astype(int)

    feature_names = [f"feature_{i}" for i in range(input_size)]
    df = pd.DataFrame(X, columns=feature_names)
    df["TenYearCHD"] = y

    print(f"📊 Synthetic test data class distribution:")
    print(f"   Negative cases: {(y == 0).sum()} ({(y == 0).mean()*100:.1f}%)")
    print(f"   Positive cases: {(y == 1).sum()} ({(y == 1).mean()*100:.1f}%)")

    return df

def _make_test_dataloader(X, y):
    """Standardize ``X`` and wrap ``(X, y)`` in a non-shuffling DataLoader.

    NOTE(review): the StandardScaler is fit on the test data itself, not on
    training statistics; confirm this matches how the training clients
    scale their inputs.

    Returns:
        ``(dataloader, num_features, num_samples)``.
    """
    X_scaled = StandardScaler().fit_transform(X)
    X_tensor = torch.tensor(X_scaled, dtype=torch.float32)
    # Targets are shaped (N, 1) to line up with the model's single output unit.
    y_tensor = torch.tensor(np.asarray(y), dtype=torch.float32).view(-1, 1)
    dataset = TensorDataset(X_tensor, y_tensor)
    dataloader = DataLoader(dataset, batch_size=GLOBAL_MODEL_CONFIG["batch_size"], shuffle=False)
    return dataloader, X_scaled.shape[1], len(dataset)


def load_test_data(data_path=None):
    """Load, clean and standardize the unseen evaluation set.

    If ``data_path`` does not exist, synthetic data is generated explicitly.
    (Previously ``df`` was left unbound and the fallback was only reached
    through an UnboundLocalError.) If the file exists but cannot be processed
    (e.g. non-numeric columns, or no rows left after dropping missing
    values), synthetic data is used as a last resort.

    Args:
        data_path: CSV to load; defaults to ``GLOBAL_MODEL_CONFIG["test_data_path"]``.

    Returns:
        ``(dataloader, num_features, num_samples)``, or ``(None, None, None)``
        if even the synthetic fallback fails.
    """
    if data_path is None:
        data_path = GLOBAL_MODEL_CONFIG["test_data_path"]

    try:
        if os.path.exists(data_path):
            df = pd.read_csv(data_path)
            print(f"✓ Loaded test data from {data_path} with shape {df.shape}")
            source = data_path
        else:
            print(f"⚠ Test data file {data_path} not found")
            print("🔧 Falling back to synthetic data...")
            df = create_synthetic_test_data()
            source = "synthetic"

        missing_values = int(df.isnull().sum().sum())
        if missing_values > 0:
            print(f"⚠ Found {missing_values} missing values in test data, dropping rows")
            df = df.dropna()
        if df.empty:
            raise ValueError("Test data is empty after dropping missing values")

        if "TenYearCHD" not in df.columns:
            print("⚠ Target column 'TenYearCHD' not found, creating synthetic target")
            if df.shape[1] == 0:
                raise ValueError("No features found to create synthetic target!")
            # Label rows whose feature sum is above the median as positive.
            row_sums = df.sum(axis=1)
            df["TenYearCHD"] = (row_sums > row_sums.median()).astype(int)

        X = df.drop(columns=["TenYearCHD"])
        y = df["TenYearCHD"]

        print(f"📊 Test data ({source}) class distribution:")
        print(f"   Negative cases: {(y == 0).sum()} ({(y == 0).mean()*100:.1f}%)")
        print(f"   Positive cases: {(y == 1).sum()} ({(y == 1).mean()*100:.1f}%)")

        dataloader, num_features, num_samples = _make_test_dataloader(X, y)
        print(f"✓ Created test dataloader with {num_samples} samples and {num_features} features")
        return dataloader, num_features, num_samples

    except Exception as e:
        logger.exception("Error loading %s: %s", data_path, e)
        print("🔧 Falling back to synthetic data...")
        try:
            df = create_synthetic_test_data()
            return _make_test_dataloader(df.drop(columns=["TenYearCHD"]), df["TenYearCHD"])
        except Exception as e2:
            logger.error(f"Failed to create synthetic data: {str(e2)}")
            return None, None, None

class GlobalModelClient(fl.client.NumPyClient):
    """Evaluation-only Flower client that scores aggregated models on unseen data.

    The server sends aggregated weights. This client loads them into its local
    model, evaluates on the held-out dataloader and reports the loss plus
    binary-classification metrics. It never trains: ``fit`` only loads the
    received weights and echoes them back with ``num_examples=0``.
    """

    # Probability cut-off that turns sigmoid outputs into class predictions.
    DECISION_THRESHOLD = 0.5

    def __init__(self, model, test_dataloader, device, num_test_samples):
        """Store the model and test data and print a short summary.

        Args:
            model: torch module whose ``state_dict`` layout matches the weights
                the server sends.
            test_dataloader: iterable of ``(inputs, targets)`` batches;
                presumably the loader built by ``load_test_data`` (0/1 float
                targets shaped like the model output) -- confirm for other callers.
            device: torch device or device string used for evaluation.
            num_test_samples: example count reported to the server by ``evaluate``.
        """
        self.model = model
        self.test_dataloader = test_dataloader
        self.device = device
        self.num_test_samples = num_test_samples
        self.round_count = 0  # last server round seen by evaluate()

        print(f"🧠 Global Model Client initialized")
        print(f"   Device: {device}")
        print(f"   Test samples (unseendata.csv): {num_test_samples}")
        print(f"   Model parameters: {sum(p.numel() for p in model.parameters())}")

    def get_parameters(self, config):
        """Return the model weights as NumPy arrays in ``state_dict`` order.

        Uses the same ordering as ``set_parameters``, so a get -> set
        round-trip is always consistent.
        """
        return [tensor.detach().cpu().numpy() for tensor in self.model.state_dict().values()]

    def set_parameters(self, parameters):
        """Load aggregated weight arrays from the server into the model.

        Arrays are matched to ``state_dict`` keys by position.

        Raises:
            ValueError: if the number of arrays differs from the number of
                ``state_dict`` entries (``zip`` would otherwise silently drop
                any surplus arrays).
            Exception: anything raised by ``load_state_dict`` (e.g. shape
                mismatch) is re-raised after the diagnostics are printed.
        """
        keys = list(self.model.state_dict().keys())
        try:
            if len(parameters) != len(keys):
                raise ValueError(f"expected {len(keys)} parameter arrays, got {len(parameters)}")
            state_dict = {
                key: torch.tensor(value, device=self.device)
                for key, value in zip(keys, parameters)
            }
            self.model.load_state_dict(state_dict, strict=True)
        except Exception as e:
            print(f"✗ Error updating Global Model parameters: {str(e)}")
            print(f"   Parameter shapes received: {[np.shape(p) for p in parameters]}")
            print(f"   Model expects: {[tuple(t.shape) for t in self.model.state_dict().values()]}")
            # Re-raise: evaluating stale/random weights and reporting the result
            # as the aggregated model's metrics would be silently wrong.
            raise
        print(f"✓ Global Model updated with aggregated parameters")

    def fit(self, parameters, config):
        """Load the server's weights and return them unchanged (no training).

        Returns:
            ``(parameters, 0, {})``: zero examples, so this client adds no
            weight to a sample-weighted aggregation.
        """
        self.set_parameters(parameters)

        server_round = config.get("server_round", 0)
        print(f"🧠 Global Model received updated parameters for Round {server_round}")
        print(f"   (Global Model does not participate in training)")

        return self.get_parameters(config), 0, {}

    def evaluate(self, parameters, config):
        """Load the aggregated weights and evaluate them on the unseen test set.

        Args:
            parameters: aggregated weight arrays from the server.
            config: server config dict; ``"server_round"`` is read if present
                (defaults to 0).

        Returns:
            ``(loss, num_examples, metrics)``: the mean BCE loss,
            ``self.num_test_samples``, and a dict with accuracy, AUC,
            precision, recall, F1, specificity and confusion-matrix counts.
        """
        server_round = config.get("server_round", 0)
        self.round_count = server_round

        print(f"\n🔍 Global Model Evaluation - Round {server_round}")
        print("=" * 50)
        print("📋 Testing on unseendata.csv (Completely Unseen Data)")
        print("=" * 50)

        # Raises if the weights cannot be loaded, so no metrics are reported
        # for a model that was not actually updated.
        self.set_parameters(parameters)

        print("📊 Evaluating on unseen test data...")
        labels, probabilities, avg_loss = self._collect_predictions()
        stats = self._classification_metrics(labels, probabilities, self.DECISION_THRESHOLD)
        self._print_report(server_round, avg_loss, stats)
        print("=" * 50)

        metrics = {
            "accuracy": float(stats["accuracy"]),
            "auc": float(stats["auc"]),
            "precision": float(stats["precision"]),
            "recall": float(stats["recall"]),
            "f1_score": float(stats["f1_score"]),
            "specificity": float(stats["specificity"]),
            "true_positives": int(stats["tp"]),
            "true_negatives": int(stats["tn"]),
            "false_positives": int(stats["fp"]),
            "false_negatives": int(stats["fn"]),
        }
        return float(avg_loss), self.num_test_samples, metrics

    def _collect_predictions(self):
        """Run the model over the whole test loader without gradients.

        Returns:
            ``(labels, probabilities, avg_loss)``: flat NumPy arrays of targets
            and model outputs, and the per-sample mean BCE loss (0.0 when the
            loader yields nothing).
        """
        device = torch.device(self.device)
        self.model = self.model.to(device)
        self.model.eval()  # disable dropout so scoring is deterministic

        criterion = nn.BCELoss()
        weighted_loss = 0.0
        label_batches, prob_batches = [], []
        with torch.no_grad():
            for inputs, targets in self.test_dataloader:
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = self.model(inputs)
                # BCELoss averages within the batch; re-weight by batch size so the
                # final value is a true per-sample mean even if the last batch is short.
                weighted_loss += criterion(outputs, targets).item() * inputs.size(0)
                prob_batches.append(outputs.cpu().numpy().ravel())
                label_batches.append(targets.cpu().numpy().ravel())

        labels = np.concatenate(label_batches) if label_batches else np.empty(0)
        probabilities = np.concatenate(prob_batches) if prob_batches else np.empty(0)
        avg_loss = weighted_loss / labels.size if labels.size > 0 else 0.0
        return labels, probabilities, avg_loss

    @staticmethod
    def _classification_metrics(labels, probabilities, threshold=0.5):
        """Compute accuracy, AUC and confusion-matrix-derived metrics.

        AUC is reported as 0.0 when sklearn cannot compute it (e.g. only one
        class present). Confusion counts fall back to zeros if sklearn fails.
        """
        predictions = (probabilities > threshold).astype(np.float32)
        total = int(labels.size)
        correct = int((predictions == labels).sum())
        accuracy = correct / total if total > 0 else 0.0

        try:
            auc_score = roc_auc_score(labels, probabilities)
        except Exception as e:
            print(f"⚠ Could not calculate AUC: {str(e)}")
            auc_score = 0.0

        try:
            # Fixing labels=[0, 1] always yields a 2x2 matrix, even when the
            # labels or the predictions contain only a single class.
            cm = confusion_matrix(labels.astype(int), predictions.astype(int), labels=[0, 1])
            tn, fp, fn, tp = (int(v) for v in cm.ravel())
        except Exception as e:
            print(f"⚠ Could not calculate confusion matrix: {str(e)}")
            tn = fp = fn = tp = 0

        precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
        recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
        f1_score = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0.0
        specificity = tn / (tn + fp) if (tn + fp) > 0 else 0.0

        return {
            "accuracy": accuracy, "correct": correct, "total": total, "auc": auc_score,
            "precision": precision, "recall": recall, "f1_score": f1_score,
            "specificity": specificity, "tn": tn, "fp": fp, "fn": fn, "tp": tp,
        }

    @staticmethod
    def _print_report(server_round, avg_loss, m):
        """Print the metric summary, confusion matrix and (after round 1) a generalization verdict."""
        print(f"\n📈 Round {server_round} - Global Model Evaluation on Unseen Data:")
        print(f"   🔴 Loss: {avg_loss:.4f}")
        print(f"   🟢 Accuracy: {m['accuracy']:.4f} ({m['correct']}/{m['total']})")
        print(f"   🔵 AUC: {m['auc']:.4f}")
        print(f"   📊 Precision: {m['precision']:.4f}")
        print(f"   📊 Recall (Sensitivity): {m['recall']:.4f}")
        print(f"   📊 Specificity: {m['specificity']:.4f}")
        print(f"   📊 F1-Score: {m['f1_score']:.4f}")

        print(f"\n📋 Confusion Matrix (Unseen Data):")
        print(f"   True Negatives:  {m['tn']}")
        print(f"   False Positives: {m['fp']}")
        print(f"   False Negatives: {m['fn']}")
        print(f"   True Positives:  {m['tp']}")

        if server_round > 1:
            print(f"\n💡 Performance on Unseen Data:")
            if m["accuracy"] > 0.85:
                print(f"   🎯 Excellent generalization!")
            elif m["accuracy"] > 0.75:
                print(f"   ✅ Good generalization")
            elif m["accuracy"] > 0.65:
                print(f"   ⚠ Moderate generalization")
            else:
                print(f"   🔧 Poor generalization - may be overfitting")

            if m["auc"] > 0.8:
                print(f"   🏆 Strong discriminative ability on unseen data (AUC > 0.8)")
            elif m["auc"] > 0.7:
                print(f"   ✅ Good discriminative ability on unseen data")
            else:
                print(f"   ⚠ Limited discriminative ability on unseen data")

def start_global_model_client(server_address=None, test_data_path=None):
    """Load unseen test data, build the model and run the evaluation-only Flower client.

    Side effect: writes ``server_address``/``test_data_path`` (when given) and
    the selected ``device`` into ``GLOBAL_MODEL_CONFIG``.

    Args:
        server_address: ``host:port`` of the Flower server; None keeps the config value.
        test_data_path: CSV used for evaluation; None keeps the config value.

    Returns:
        True if the client ran to completion, False on data-loading failure,
        interruption or a client error.
    """
    print("🧠 Starting Global Model Client")
    print("=" * 40)
    print("📋 Testing on UNSEEN DATA (unseendata.csv)")
    print("=" * 40)

    if server_address:
        GLOBAL_MODEL_CONFIG["server_address"] = server_address
    if test_data_path:
        GLOBAL_MODEL_CONFIG["test_data_path"] = test_data_path

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    GLOBAL_MODEL_CONFIG["device"] = device.type
    print(f"🔧 Using device: {device}")

    print("📂 Loading unseen test data...")
    test_dataloader, input_size, num_test_samples = load_test_data()

    if test_dataloader is None:
        print("❌ Failed to load unseen test data. Cannot start Global Model.")
        return False

    print("🏗️ Initializing Global Model...")
    model = HeartDiseaseModel(input_size=input_size).to(device)

    client = GlobalModelClient(model, test_dataloader, device, num_test_samples)

    address = GLOBAL_MODEL_CONFIG["server_address"]
    print(f"\n🌐 Connecting to server at {address}")
    print("🎯 Global Model will:")
    print("   1. Connect to evaluation server")
    print("   2. Receive aggregated models after each round")
    print("   3. Evaluate on UNSEEN data (unseendata.csv)")
    print("   4. Report detailed metrics back to server")
    print("\n⏳ Connecting...")

    try:
        # Blocks until the server ends the session. fl.client.start_client
        # expects a fl.client.Client, so a NumPyClient must be wrapped with
        # .to_client(); older Flower releases without to_client() use
        # start_numpy_client instead.
        if hasattr(client, "to_client"):
            fl.client.start_client(server_address=address, client=client.to_client())
        else:
            fl.client.start_numpy_client(server_address=address, client=client)

        print("\n✅ Global Model client completed successfully!")
        return True

    except KeyboardInterrupt:
        print("\n⚠ Global Model client interrupted by user")
        return False

    except Exception as e:
        logger.exception("Global Model client failed")
        print(f"\n❌ Global Model client error: {str(e)}")
        return False

# Quick start function
def start_global_model():
    """Run the Global Model client against the training server on unseendata.csv.

    Connects to localhost:8080 (the training clients' server) and returns the
    success flag from start_global_model_client().
    """
    for banner in ("🧠 Quick Start - Global Model Client", "📋 Will test on unseendata.csv"):
        print(banner)
    defaults = {
        "server_address": "localhost:8080",  # match the training clients
        "test_data_path": "unseendata.csv",
    }
    return start_global_model_client(**defaults)

def start_global_model_with_synthetic_data():
    """Run the Global Model client on generated data instead of a CSV.

    Connects to localhost:8080, the server address this cell configures and
    that start_global_model() and the availability probe use. Points the
    loader at "nonexistent.csv" so load_test_data() falls back to synthetic
    data. NOTE(review): this only works while no file with that name exists
    in the working directory.

    Returns:
        The success flag from start_global_model_client().
    """
    print("🧠 Quick Start - Global Model Client (Synthetic Data)")
    return start_global_model_client(
        server_address="localhost:8080",  # same server as start_global_model()
        test_data_path="nonexistent.csv",  # missing file -> synthetic fallback
    )

# Auto-start when the cell is executed: announce what is about to run
# (a single print producing the same two output lines).
print("🧠 Starting Global Model Client automatically...\n📋 Will test on unseendata.csv")

# Check if server is available before starting
import socket
import time

def check_server_connection(host="localhost", port=8080, timeout=5):
    """Return True if a TCP (IPv4) connection to ``host:port`` succeeds.

    Best-effort probe: network/resolution errors (``OSError``) and an
    out-of-range port (``OverflowError``) are reported as False. Only these
    are caught, so KeyboardInterrupt still stops a caller's polling loop.

    Args:
        host: hostname or IPv4 address to probe.
        port: TCP port to probe.
        timeout: socket timeout in seconds for the connection attempt.
    """
    try:
        # The context manager closes the socket even if connect_ex raises.
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
            sock.settimeout(timeout)
            return sock.connect_ex((host, port)) == 0
    except (OSError, OverflowError):
        return False

# Poll the server every 2 s, at most `max_attempts` times, then start the client.
print("🔍 Checking server availability...")
max_attempts = 30
attempt = 0

while attempt < max_attempts and not check_server_connection():
    print(f"⏳ Waiting for server... (attempt {attempt + 1}/{max_attempts})")
    time.sleep(2)
    attempt += 1

if attempt < max_attempts:
    # The probe succeeded before the attempt budget ran out.
    print("✅ Server is available!")
    success = start_global_model()
    print("✅ Global Model Client completed successfully!" if success
          else "❌ Global Model Client failed to complete")
else:
    print("❌ Server not available after 60 seconds")
    print("💡 Make sure to:")
    print("   1. Start the FedProx server first")
    print("   2. Wait for it to show 'Waiting for clients to connect...'")
    print("   3. Then run this Global Model Client")
: 