In [4]:
import torch
import torch.nn as nn
import torch.optim as optim
import flwr as fl
import pandas as pd
from sklearn.model_selection import train_test_split
from imblearn.over_sampling import SMOTE
from torch.utils.data import DataLoader, TensorDataset

# Load dataset
df = pd.read_csv("older_adults.csv")
X = df.iloc[:, :-1].values  # Features
y = df.iloc[:, -1].values   # Target

# Balance dataset using SMOTE
smote = SMOTE(random_state=42)
X, y = smote.fit_resample(X, y)

# Split into training and test sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Convert to tensors
X_train_tensor = torch.tensor(X_train, dtype=torch.float32)
y_train_tensor = torch.tensor(y_train, dtype=torch.float32).unsqueeze(1)
X_test_tensor = torch.tensor(X_test, dtype=torch.float32)
y_test_tensor = torch.tensor(y_test, dtype=torch.float32).unsqueeze(1)

train_loader = DataLoader(TensorDataset(X_train_tensor, y_train_tensor), batch_size=64, shuffle=True)
test_loader = DataLoader(TensorDataset(X_test_tensor, y_test_tensor), batch_size=64)

# Define MLP Model
class MLP(nn.Module):
    def __init__(self, input_size):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(input_size, 64)
        self.fc2 = nn.Linear(64, 64)
        self.fc3 = nn.Linear(64, 1)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return torch.sigmoid(self.fc3(x))

# Initialize model
input_size = X_train.shape[1]  
model = MLP(input_size)
criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=0.00545)

# Define Flower client
class FLClient(fl.client.NumPyClient):
    def get_parameters(self, config):
        return [val.cpu().detach().numpy() for val in model.state_dict().values()]
    
    def set_parameters(self, parameters):
        param_dict = zip(model.state_dict().keys(), parameters)
        state_dict = {k: torch.tensor(v) for k, v in param_dict}
        model.load_state_dict(state_dict, strict=True)
    
    def fit(self, parameters, config):
        self.set_parameters(parameters)
        model.train()
        for epoch in range(10):
            total_loss = 0
            for X_batch, y_batch in train_loader:
                optimizer.zero_grad()
                outputs = model(X_batch)
                loss = criterion(outputs, y_batch)
                loss.backward()
                optimizer.step()
                total_loss += loss.item()
            print(f"[Client 1] Epoch {epoch+1}, Loss: {total_loss / len(train_loader):.4f}")
        return self.get_parameters(config), len(X_train), {}
    
    def evaluate(self, parameters, config):
        self.set_parameters(parameters)
        model.eval()
        correct, total, loss = 0, 0, 0.0
        with torch.no_grad():
            for X_batch, y_batch in test_loader:
                outputs = model(X_batch)
                loss += criterion(outputs, y_batch).item()
                predictions = (outputs > 0.5).float()
                correct += (predictions == y_batch).sum().item()
                total += y_batch.size(0)
        
        accuracy = correct / total if total > 0 else 0.0
        print(f"[Client 1] Accuracy: {accuracy:.4f}")
        return float(loss), len(X_test), {"accuracy": accuracy}

# Start the client
fl.client.start_numpy_client(server_address="127.0.0.1:5000", client=FLClient())


FileNotFoundError: [Errno 2] No such file or directory: 'older_adults.csv'

In [3]:
import torch
import torch.nn as nn
import flwr as fl
import numpy as np
import pandas as pd
from torch.utils.data import TensorDataset, DataLoader
from sklearn.preprocessing import StandardScaler
import os
import sys
import logging

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[logging.StreamHandler()]
)
logger = logging.getLogger("FraminghamClient")

# Client configuration
CLIENT_CONFIG = {
    "local_epochs": 3,
    "batch_size": 32,
    "learning_rate": 0.001,
    "weight_decay": 1e-5,
    "dropout_rate": 0.3,
    "server_address": "localhost:8080",
    "proximal_mu": 0.01  # FedProx hyperparameter (controls the proximal term strength)
}

# Model for Framingham Heart Study data
class HeartDiseaseModel(nn.Module):
    def __init__(self, input_size):
        super(HeartDiseaseModel, self).__init__()
        self.layers = nn.Sequential(
            nn.Linear(input_size, 64),
            nn.ReLU(),
            nn.Dropout(CLIENT_CONFIG["dropout_rate"]),
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Dropout(CLIENT_CONFIG["dropout_rate"]),
            nn.Linear(32, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.layers(x)

# FedProx Loss Function
class FedProxLoss(nn.Module):
    def __init__(self, base_criterion, mu=0.01):
        super(FedProxLoss, self).__init__()
        self.base_criterion = base_criterion
        self.mu = mu  # Proximal term coefficient
        
    def forward(self, y_pred, y_true, model_params, global_params):
        # Calculate the base loss (e.g., BCE loss)
        base_loss = self.base_criterion(y_pred, y_true)
        
        # Calculate the proximal term if global parameters are provided
        proximal_term = 0.0
        if global_params is not None:
            # Sum up the squared L2 norm of the difference between local and global model parameters
            for local_param, global_param in zip(model_params, global_params):
                proximal_term += torch.sum((local_param - global_param) ** 2)
                
            # Add the weighted proximal term to the base loss
            loss = base_loss + (self.mu / 2) * proximal_term
            return loss
        
        # If no global parameters are provided, just return the base loss
        return base_loss

# Load and preprocess Framingham data
def load_data(data_path):
    """Load and preprocess Framingham Heart Study data"""
    try:
        # Read CSV data
        df = pd.read_csv(data_path)
        logger.info(f"Loaded {data_path} with shape {df.shape}")
        
        # Check for missing values
        missing_values = df.isnull().sum().sum()
        if missing_values > 0:
            logger.info(f"Found {missing_values} missing values, dropping rows with missing values")
            df.dropna(inplace=True)
            logger.info(f"Shape after dropping missing values: {df.shape}")
        
        # Ensure the target column exists
        if "TenYearCHD" not in df.columns:
            raise ValueError("Target column 'TenYearCHD' not found in dataset!")
            
        # Split features and target
        X = df.drop(columns=["TenYearCHD"])
        y = df["TenYearCHD"]
        
        # Show class distribution
        logger.info(f"Class distribution: {y.value_counts().to_dict()}")
        
        # Standardize features
        scaler = StandardScaler()
        X_scaled = scaler.fit_transform(X)
        
        # Convert to tensors
        X_tensor = torch.tensor(X_scaled, dtype=torch.float32)
        y_tensor = torch.tensor(y.values, dtype=torch.float32).view(-1, 1)
        
        # Create dataset and dataloader
        dataset = TensorDataset(X_tensor, y_tensor)
        dataloader = DataLoader(dataset, batch_size=CLIENT_CONFIG["batch_size"], shuffle=True)
        
        logger.info(f"Created dataloader with {len(dataset)} samples and {X.shape[1]} features")
        return dataloader, X.shape[1]
    
    except Exception as e:
        logger.error(f"Error loading data: {str(e)}")
        raise

# Client class for Federated Learning with FedProx
class FraminghamClient(fl.client.NumPyClient):
    def __init__(self, model, dataloader, device):
        self.model = model
        self.dataloader = dataloader
        self.device = device
        self.global_params = None  # Store global model parameters for FedProx
        logger.info(f"Initialized client with device: {device}")
        
    def get_parameters(self, config):
        """Get model parameters as a list of NumPy arrays"""
        # Using detach() to prevent gradient error
        return [val.detach().cpu().numpy() for val in self.model.parameters()]
    
    def set_parameters(self, parameters):
        """Set model parameters from a list of NumPy arrays"""
        # Convert to torch tensors
        self.global_params = [torch.tensor(p, device=self.device) for p in parameters]
        
        # Update model
        params_dict = zip(self.model.state_dict().keys(), parameters)
        state_dict = {k: torch.tensor(v, device=self.device) for k, v in params_dict}
        self.model.load_state_dict(state_dict, strict=True)
        logger.info("Parameters updated from server")
        
    def fit(self, parameters, config):
        """Train the model on local data with FedProx"""
        # Update model with server parameters
        self.set_parameters(parameters)
        
        # Train the model
        self.model.train()
        
        # Standard loss function
        criterion = nn.BCELoss()
        
        # FedProx loss function
        proximal_criterion = FedProxLoss(criterion, mu=CLIENT_CONFIG["proximal_mu"])
        
        optimizer = torch.optim.Adam(
            self.model.parameters(), 
            lr=CLIENT_CONFIG["learning_rate"],
            weight_decay=CLIENT_CONFIG["weight_decay"]
        )
        
        # Metrics for tracking
        total_loss = 0.0
        total_samples = 0
        correct = 0
        
        # Train for multiple epochs
        for epoch in range(CLIENT_CONFIG["local_epochs"]):
            epoch_loss = 0.0
            epoch_samples = 0
            
            for batch_idx, (X, y) in enumerate(self.dataloader):
                # Move tensors to device
                X, y = X.to(self.device), y.to(self.device)
                
                # Forward pass
                y_pred = self.model(X)
                
                # Calculate loss with proximal term
                loss = proximal_criterion(
                    y_pred, 
                    y, 
                    self.model.parameters(),  # Current model parameters
                    self.global_params        # Global model parameters
                )
                
                # Backward pass
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                
                # Update metrics
                batch_loss = loss.item() * X.size(0)
                total_loss += batch_loss
                epoch_loss += batch_loss
                total_samples += X.size(0)
                epoch_samples += X.size(0)
                
                # Calculate accuracy
                predicted = (y_pred > 0.5).float()
                correct += (predicted == y).sum().item()
                
                # Log progress occasionally
                if batch_idx % 5 == 0:
                    logger.info(
                        f"Epoch {epoch+1}/{CLIENT_CONFIG['local_epochs']} - "
                        f"Batch {batch_idx}/{len(self.dataloader)} - "
                        f"Loss: {loss.item():.4f}"
                    )
            
            # Log epoch metrics
            logger.info(
                f"Epoch {epoch+1}/{CLIENT_CONFIG['local_epochs']} completed - "
                f"Loss: {epoch_loss/epoch_samples:.4f}"
            )
        
        # Calculate final metrics
        avg_loss = total_loss / total_samples if total_samples > 0 else 0
        accuracy = correct / total_samples if total_samples > 0 else 0
        
        logger.info(f"Training completed - Loss: {avg_loss:.4f}, Accuracy: {accuracy:.4f}")
        
        # Return updated model parameters and metrics
        return self.get_parameters({}), total_samples, {"loss": float(avg_loss), "accuracy": float(accuracy)}
    
    def evaluate(self, parameters, config):
        """Evaluate the model on local data"""
        # Update model with server parameters
        self.set_parameters(parameters)
        
        # Evaluate the model
        self.model.eval()
        criterion = nn.BCELoss()
        
        loss = 0.0
        total = 0
        correct = 0
        
        with torch.no_grad():
            for X, y in self.dataloader:
                # Move tensors to device
                X, y = X.to(self.device), y.to(self.device)
                
                # Forward pass
                y_pred = self.model(X)
                batch_loss = criterion(y_pred, y).item()
                
                # Update metrics
                loss += batch_loss * X.size(0)
                total += X.size(0)
                
                # Calculate accuracy
                predicted = (y_pred > 0.5).float()
                correct += (predicted == y).sum().item()
        
        # Calculate final metrics
        avg_loss = loss / total if total > 0 else 0
        accuracy = correct / total if total > 0 else 0
        
        logger.info(f"Evaluation - Loss: {avg_loss:.4f}, Accuracy: {accuracy:.4f}")
        
        # Return metrics
        return float(avg_loss), total, {"accuracy": float(accuracy)}

def start_client(client_id=0, server_address=None):
    """Initialize and start a client"""
    # Update server address if provided
    if server_address:
        CLIENT_CONFIG["server_address"] = server_address
    
    # Set device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"Using device: {device}")
    
    # Determine which data file to use based on client ID
    data_path = f"framingham_part{client_id+1}.csv"
    
    # Try alternative naming if file doesn't exist
    if not os.path.exists(data_path):
        alternative_path = f"framingham_part{client_id+1}.csv"
        if os.path.exists(alternative_path):
            data_path = alternative_path
        else:
            logger.error(f"Data file {data_path} not found")
            return
    
    # Load data
    dataloader, input_size = load_data(data_path)
    
    # Initialize model
    model = HeartDiseaseModel(input_size=input_size).to(device)
    logger.info(f"Model initialized with input size: {input_size}")
    
    # Create client
    client = FraminghamClient(model, dataloader, device)
    
    # Start client
    logger.info(f"Starting client {client_id} and connecting to {CLIENT_CONFIG['server_address']}")
    
    print(f"\n===== Framingham Heart Study FL Client {client_id} (FedProx) =====")
    print(f"Server:        {CLIENT_CONFIG['server_address']}")
    print(f"Data file:     {data_path}")
    print(f"Local epochs:  {CLIENT_CONFIG['local_epochs']}")
    print(f"Batch size:    {CLIENT_CONFIG['batch_size']}")
    print(f"Proximal mu:   {CLIENT_CONFIG['proximal_mu']}")
    print(f"Device:        {device}")
    print("=================================================")
    print(f"\nConnecting to server...\n")
    
    fl.client.start_client(server_address=CLIENT_CONFIG["server_address"], client=client)

# For Jupyter usage
if __name__ == "__main__":
    # Check if running in Jupyter
    if 'ipykernel' in sys.modules:
        print("Running in Jupyter/IPython environment")
        # Default to client ID 0, can be changed by user
        start_client(client_id=2)
    else:
        # For command line use
        import argparse
        parser = argparse.ArgumentParser(description="Framingham Heart Study FL Client")
        parser.add_argument("--id", type=int, default=0, help="Client ID (0, 1, or 2)")
        parser.add_argument("--server", type=str, default="localhost:8080", help="Server address")
        parser.add_argument("--mu", type=float, default=0.01, help="FedProx proximal term strength")
        
        args = parser.parse_args()
        
        if args.id not in [0, 1, 2]:
            logger.error("Client ID must be 0, 1, or 2")
        else:
            try:
                # Set FedProx hyperparameter
                CLIENT_CONFIG["proximal_mu"] = args.mu
                
                start_client(args.id, args.server)
            except Exception as e:
                logger.error(f"Client failed: {str(e)}")

2025-05-11 14:47:11,679 - FraminghamClient - INFO - Using device: cpu
2025-05-11 14:47:11,777 - FraminghamClient - INFO - Loaded framingham_part3.csv with shape (1060, 16)
2025-05-11 14:47:11,814 - FraminghamClient - INFO - Class distribution: {0: 898, 1: 162}


Running in Jupyter/IPython environment


2025-05-11 14:47:11,865 - FraminghamClient - INFO - Created dataloader with 1060 samples and 15 features
2025-05-11 14:47:11,882 - FraminghamClient - INFO - Model initialized with input size: 15
2025-05-11 14:47:11,885 - FraminghamClient - INFO - Initialized client with device: cpu
2025-05-11 14:47:11,888 - FraminghamClient - INFO - Starting client 2 and connecting to localhost:8080
	Instead, use the `flower-supernode` CLI command to start a SuperNode as shown below:

		$ flower-supernode --insecure --superlink='<IP>:<PORT>'

	To view all available options, run:

		$ flower-supernode --help

	Using `start_client()` is deprecated.

            This is a deprecated feature. It will be removed
            entirely in future versions of Flower.
        
	Instead, use the `flower-supernode` CLI command to start a SuperNode as shown below:

		$ flower-supernode --insecure --superlink='<IP>:<PORT>'

	To view all available options, run:

		$ flower-supernode --help

	Using `start_client()` is 


===== Framingham Heart Study FL Client 2 (FedProx) =====
Server:        localhost:8080
Data file:     framingham_part3.csv
Local epochs:  3
Batch size:    32
Proximal mu:   0.01
Device:        cpu

Connecting to server...



[92mINFO [0m:      
2025-05-11 14:47:16,155 - flwr - INFO - 
[92mINFO [0m:      Received: train message 82b97791-cc7f-45db-b0a2-c5182044fab6
2025-05-11 14:47:16,160 - flwr - INFO - Received: train message 82b97791-cc7f-45db-b0a2-c5182044fab6
2025-05-11 14:47:16,182 - FraminghamClient - INFO - Parameters updated from server
2025-05-11 14:47:16,224 - FraminghamClient - INFO - Epoch 1/3 - Batch 0/34 - Loss: 0.6331
2025-05-11 14:47:16,282 - FraminghamClient - INFO - Epoch 1/3 - Batch 5/34 - Loss: 0.6208
2025-05-11 14:47:16,412 - FraminghamClient - INFO - Epoch 1/3 - Batch 10/34 - Loss: 0.6087
2025-05-11 14:47:16,471 - FraminghamClient - INFO - Epoch 1/3 - Batch 15/34 - Loss: 0.5943
2025-05-11 14:47:16,523 - FraminghamClient - INFO - Epoch 1/3 - Batch 20/34 - Loss: 0.6150
2025-05-11 14:47:16,579 - FraminghamClient - INFO - Epoch 1/3 - Batch 25/34 - Loss: 0.5017
2025-05-11 14:47:16,628 - FraminghamClient - INFO - Epoch 1/3 - Batch 30/34 - Loss: 0.4192
2025-05-11 14:47:16,660 - Framingham

In [7]:
# improved_client.py
import torch
import torch.nn as nn
import flwr as fl
import numpy as np
import pandas as pd
from torch.utils.data import TensorDataset, DataLoader
from sklearn.preprocessing import StandardScaler
from imblearn.over_sampling import SMOTE
import os
import sys
import logging

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[logging.StreamHandler()]
)
logger = logging.getLogger("ImprovedFraminghamClient")

# Client configuration
CLIENT_CONFIG = {
    "local_epochs": 3,
    "batch_size": 32,
    "learning_rate": 0.001,
    "weight_decay": 1e-5,
    "dropout_rate": 0.3,
    "server_address": "localhost:8080",
    "proximal_mu": 0.01,        # FedProx hyperparameter
    "pos_weight": 5.0,          # Base positive class weight (will be adjusted dynamically)
    "focal_loss_gamma": 2.0,    # Focal loss focusing parameter
    "focal_loss_alpha": 0.25,   # Focal loss alpha parameter
    "use_focal_loss": True,     # Use focal loss instead of weighted BCE
    "use_smote": True,          # Use SMOTE for balanced sampling
    "evaluation_threshold": 0.30  # Lower threshold for evaluation
}

# Model for Framingham Heart Study data
class HeartDiseaseModel(nn.Module):
    def __init__(self, input_size):
        super(HeartDiseaseModel, self).__init__()
        self.layers = nn.Sequential(
            nn.Linear(input_size, 64),
            nn.ReLU(),
            nn.Dropout(CLIENT_CONFIG["dropout_rate"]),
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Dropout(CLIENT_CONFIG["dropout_rate"]),
            nn.Linear(32, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.layers(x)

# Weighted BCE Loss for class imbalance
class WeightedBCELoss(nn.Module):
    def __init__(self, pos_weight=5.0):
        super(WeightedBCELoss, self).__init__()
        self.pos_weight = pos_weight
        
    def forward(self, y_pred, y_true):
        # Create weight tensor based on true labels
        weights = torch.where(y_true == 1.0, 
                             self.pos_weight * torch.ones_like(y_true),
                             torch.ones_like(y_true))
        
        # Binary cross entropy loss
        bce = -(y_true * torch.log(y_pred + 1e-7) + (1 - y_true) * torch.log(1 - y_pred + 1e-7))
        
        # Apply weights and mean reduction
        return (weights * bce).mean()

# Focal Loss for harder focus on misclassified examples
class FocalLoss(nn.Module):
    def __init__(self, alpha=0.25, gamma=2.0):
        super(FocalLoss, self).__init__()
        self.alpha = alpha  # Class balancing parameter
        self.gamma = gamma  # Focusing parameter
        
    def forward(self, y_pred, y_true):
        # BCE loss
        bce = -(y_true * torch.log(y_pred + 1e-7) + (1 - y_true) * torch.log(1 - y_pred + 1e-7))
        
        # Probability of the prediction being correct
        pt = torch.where(y_true == 1, y_pred, 1 - y_pred)
        
        # Apply focusing parameter to down-weight easy examples
        focal_weight = (1 - pt) ** self.gamma
        
        # Apply alpha for class balancing
        alpha_weight = torch.where(y_true == 1, 
                                  self.alpha * torch.ones_like(y_true),
                                  (1 - self.alpha) * torch.ones_like(y_true))
        
        # Combine all weights
        loss = alpha_weight * focal_weight * bce
        return loss.mean()

# FedProx Loss Function
class FedProxLoss(nn.Module):
    def __init__(self, base_criterion, mu=0.01):
        super(FedProxLoss, self).__init__()
        self.base_criterion = base_criterion
        self.mu = mu  # Proximal term coefficient
        
    def forward(self, y_pred, y_true, model_params, global_params):
        # Calculate the base loss (e.g., Focal loss or weighted BCE)
        base_loss = self.base_criterion(y_pred, y_true)
        
        # Calculate the proximal term if global parameters are provided
        proximal_term = 0.0
        if global_params is not None:
            # Sum up the squared L2 norm of the difference between local and global model parameters
            for local_param, global_param in zip(model_params, global_params):
                proximal_term += torch.sum((local_param - global_param) ** 2)
                
            # Add the weighted proximal term to the base loss
            loss = base_loss + (self.mu / 2) * proximal_term
            return loss
        
        # If no global parameters are provided, just return the base loss
        return base_loss

# Load and preprocess Framingham data with optional SMOTE
def load_data(data_path):
    """Load and preprocess Framingham Heart Study data with optional SMOTE balancing"""
    try:
        # Read CSV data
        df = pd.read_csv(data_path)
        logger.info(f"Loaded {data_path} with shape {df.shape}")
        
        # Check for missing values
        missing_values = df.isnull().sum().sum()
        if missing_values > 0:
            logger.info(f"Found {missing_values} missing values, dropping rows with missing values")
            df.dropna(inplace=True)
            logger.info(f"Shape after dropping missing values: {df.shape}")
        
        # Ensure the target column exists
        if "TenYearCHD" not in df.columns:
            raise ValueError("Target column 'TenYearCHD' not found in dataset!")
            
        # Split features and target
        X = df.drop(columns=["TenYearCHD"])
        y = df["TenYearCHD"]
        
        # Calculate class weights based on distribution
        class_counts = y.value_counts()
        logger.info(f"Original class distribution: {class_counts.to_dict()}")
        
        # Calculate weight for positive class (inverse of frequency)
        if 1 in class_counts and 0 in class_counts:
            neg_count = class_counts[0]
            pos_count = class_counts[1]
            # Use aggressive weighting - 2x the standard ratio
            pos_weight = (neg_count / pos_count) * 2.0
            CLIENT_CONFIG["pos_weight"] = pos_weight
            logger.info(f"Set positive class weight to: {pos_weight:.4f}")
            
            # Also adjust focal loss alpha based on class distribution
            # Alpha should be higher for more imbalanced datasets
            imbalance_ratio = pos_count / (pos_count + neg_count)
            CLIENT_CONFIG["focal_loss_alpha"] = max(0.25, 1 - imbalance_ratio)
            logger.info(f"Set focal loss alpha to: {CLIENT_CONFIG['focal_loss_alpha']:.4f}")
        
        # Standardize features
        scaler = StandardScaler()
        X_scaled = scaler.fit_transform(X)
        
        # Apply SMOTE for balancing if enabled
        if CLIENT_CONFIG["use_smote"]:
            try:
                logger.info("Applying SMOTE for class balancing...")
                smote = SMOTE(random_state=42)
                X_resampled, y_resampled = smote.fit_resample(X_scaled, y)
                
                # Log the new distribution
                unique, counts = np.unique(y_resampled, return_counts=True)
                logger.info(f"After SMOTE class distribution: {dict(zip(unique, counts))}")
                
                X_scaled = X_resampled
                y = y_resampled
            except Exception as e:
                logger.error(f"SMOTE failed: {str(e)}. Using original imbalanced data.")
        
        # Convert to tensors
        X_tensor = torch.tensor(X_scaled, dtype=torch.float32)
        y_tensor = torch.tensor(y, dtype=torch.float32).view(-1, 1)
        
        # Create dataset and dataloader
        dataset = TensorDataset(X_tensor, y_tensor)
        dataloader = DataLoader(dataset, batch_size=CLIENT_CONFIG["batch_size"], shuffle=True)
        
        logger.info(f"Created dataloader with {len(dataset)} samples and {X.shape[1]} features")
        return dataloader, X.shape[1]
    
    except Exception as e:
        logger.error(f"Error loading data: {str(e)}")
        raise

# Client class for Federated Learning with FedProx
class ImprovedFraminghamClient(fl.client.NumPyClient):
    def __init__(self, model, dataloader, device):
        self.model = model
        self.dataloader = dataloader
        self.device = device
        self.global_params = None  # Store global model parameters for FedProx
        logger.info(f"Initialized client with device: {device}")
        
    def get_parameters(self, config):
        """Get model parameters as a list of NumPy arrays"""
        # Using detach() to prevent gradient error
        return [val.detach().cpu().numpy() for val in self.model.parameters()]
    
    def set_parameters(self, parameters):
        """Set model parameters from a list of NumPy arrays"""
        # Convert to torch tensors
        self.global_params = [torch.tensor(p, device=self.device) for p in parameters]
        
        # Update model
        params_dict = zip(self.model.state_dict().keys(), parameters)
        state_dict = {k: torch.tensor(v, device=self.device) for k, v in params_dict}
        self.model.load_state_dict(state_dict, strict=True)
        logger.info("Parameters updated from server")
        
    def fit(self, parameters, config):
        """Train the model on local data with FedProx"""
        # Update model with server parameters
        self.set_parameters(parameters)
        
        # Train the model
        self.model.train()
        
        # Select loss function
        if CLIENT_CONFIG["use_focal_loss"]:
            criterion = FocalLoss(
                alpha=CLIENT_CONFIG["focal_loss_alpha"], 
                gamma=CLIENT_CONFIG["focal_loss_gamma"]
            )
            logger.info(f"Using Focal Loss with alpha={CLIENT_CONFIG['focal_loss_alpha']}, gamma={CLIENT_CONFIG['focal_loss_gamma']}")
        else:
            criterion = WeightedBCELoss(pos_weight=CLIENT_CONFIG["pos_weight"])
            logger.info(f"Using Weighted BCE Loss with pos_weight={CLIENT_CONFIG['pos_weight']}")
        
        # FedProx loss function
        proximal_criterion = FedProxLoss(criterion, mu=CLIENT_CONFIG["proximal_mu"])
        
        optimizer = torch.optim.Adam(
            self.model.parameters(), 
            lr=CLIENT_CONFIG["learning_rate"],
            weight_decay=CLIENT_CONFIG["weight_decay"]
        )
        
        # Metrics for tracking
        total_loss = 0.0
        total_samples = 0
        correct = 0
        true_positives = 0
        true_negatives = 0
        predicted_positives = 0
        actual_positives = 0
        
        # Train for multiple epochs
        for epoch in range(CLIENT_CONFIG["local_epochs"]):
            epoch_loss = 0.0
            epoch_samples = 0
            
            for batch_idx, (X, y) in enumerate(self.dataloader):
                # Move tensors to device
                X, y = X.to(self.device), y.to(self.device)
                
                # Forward pass
                y_pred = self.model(X)
                
                # Calculate loss with proximal term
                loss = proximal_criterion(
                    y_pred, 
                    y, 
                    self.model.parameters(),  # Current model parameters
                    self.global_params        # Global model parameters
                )
                
                # Backward pass
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                
                # Update metrics
                batch_loss = loss.item() * X.size(0)
                total_loss += batch_loss
                epoch_loss += batch_loss
                total_samples += X.size(0)
                epoch_samples += X.size(0)
                
                # Use lower threshold for positive class prediction
                threshold = CLIENT_CONFIG["evaluation_threshold"]
                predicted = (y_pred > threshold).float()
                correct += (predicted == y).sum().item()
                
                # Calculate class-specific metrics
                true_positives += ((predicted == 1) & (y == 1)).sum().item()
                true_negatives += ((predicted == 0) & (y == 0)).sum().item()
                predicted_positives += (predicted == 1).sum().item()
                actual_positives += (y == 1).sum().item()
                
                # Log progress occasionally
                if batch_idx % 5 == 0:
                    logger.info(
                        f"Epoch {epoch+1}/{CLIENT_CONFIG['local_epochs']} - "
                        f"Batch {batch_idx}/{len(self.dataloader)} - "
                        f"Loss: {loss.item():.4f}"
                    )
            
            # Log epoch metrics
            logger.info(
                f"Epoch {epoch+1}/{CLIENT_CONFIG['local_epochs']} completed - "
                f"Loss: {epoch_loss/epoch_samples:.4f}"
            )
        
        # Calculate final metrics
        avg_loss = total_loss / total_samples if total_samples > 0 else 0
        accuracy = correct / total_samples if total_samples > 0 else 0
        
        # Calculate recall (sensitivity) for positive class
        recall = true_positives / actual_positives if actual_positives > 0 else 0
        
        # Calculate precision for positive class
        precision = true_positives / predicted_positives if predicted_positives > 0 else 0
        
        # Calculate F1 score
        f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
        
        logger.info(f"Training completed - Loss: {avg_loss:.4f}, Accuracy: {accuracy:.4f}")
        logger.info(f"Minority class metrics - Recall: {recall:.4f}, Precision: {precision:.4f}, F1: {f1:.4f}")
        
        # Return updated model parameters and metrics
        return self.get_parameters({}), total_samples, {
            "loss": float(avg_loss), 
            "accuracy": float(accuracy),
            "recall": float(recall),
            "precision": float(precision),
            "f1": float(f1)
        }
    
    def evaluate(self, parameters, config):
        """Evaluate the model on local data"""
        # Update model with server parameters
        self.set_parameters(parameters)
        
        # Evaluate the model
        self.model.eval()
        if CLIENT_CONFIG["use_focal_loss"]:
            criterion = FocalLoss(
                alpha=CLIENT_CONFIG["focal_loss_alpha"], 
                gamma=CLIENT_CONFIG["focal_loss_gamma"]
            )
        else:
            criterion = WeightedBCELoss(pos_weight=CLIENT_CONFIG["pos_weight"])
        
        loss = 0.0
        total = 0
        correct = 0
        true_positives = 0
        true_negatives = 0
        predicted_positives = 0
        actual_positives = 0
        
        with torch.no_grad():
            for X, y in self.dataloader:
                # Move tensors to device
                X, y = X.to(self.device), y.to(self.device)
                
                # Forward pass
                y_pred = self.model(X)
                batch_loss = criterion(y_pred, y).item()
                
                # Update metrics
                loss += batch_loss * X.size(0)
                total += X.size(0)
                
                # Use lower threshold for evaluation
                threshold = CLIENT_CONFIG["evaluation_threshold"]
                predicted = (y_pred > threshold).float()
                correct += (predicted == y).sum().item()
                
                # Calculate class-specific metrics
                actual_positives += (y == 1).sum().item()
                predicted_positives += (predicted == 1).sum().item()
                true_positives += ((predicted == 1) & (y == 1)).sum().item()
                true_negatives += ((predicted == 0) & (y == 0)).sum().item()
        
        # Calculate final metrics
        avg_loss = loss / total if total > 0 else 0
        accuracy = correct / total if total > 0 else 0
        
        # Calculate recall (sensitivity) for positive class
        recall = true_positives / actual_positives if actual_positives > 0 else 0
        
        # Calculate precision for positive class
        precision = true_positives / predicted_positives if predicted_positives > 0 else 0
        
        # Calculate F1 score
        f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
        
        logger.info(f"Evaluation - Loss: {avg_loss:.4f}, Accuracy: {accuracy:.4f}")
        logger.info(f"Minority class metrics - Recall: {recall:.4f}, Precision: {precision:.4f}, F1: {f1:.4f}")
        
        # Return metrics
        return float(avg_loss), total, {
            "accuracy": float(accuracy),
            "recall": float(recall),
            "precision": float(precision),
            "f1": float(f1)
        }

def start_client(client_id=0, server_address=None):
    """Initialize and start a client"""
    # Update server address if provided
    if server_address:
        CLIENT_CONFIG["server_address"] = server_address
    
    # Set device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"Using device: {device}")
    
    # Determine which data file to use based on client ID
    data_path = f"framingham_part{client_id+1}.csv"
    
    # Try alternative naming if file doesn't exist
    if not os.path.exists(data_path):
        alternative_path = f"framingham_part{client_id+1}.csv"
        if os.path.exists(alternative_path):
            data_path = alternative_path
        else:
            logger.error(f"Data file {data_path} not found")
            return
    
    # Load data
    dataloader, input_size = load_data(data_path)
    
    # Initialize model
    model = HeartDiseaseModel(input_size=input_size).to(device)
    logger.info(f"Model initialized with input size: {input_size}")
    
    # Create client
    client = ImprovedFraminghamClient(model, dataloader, device)
    
    # Start client
    logger.info(f"Starting client {client_id} and connecting to {CLIENT_CONFIG['server_address']}")
    
    print(f"\n===== Improved Framingham Heart Study FL Client {client_id} (FedProx) =====")
    print(f"Server:           {CLIENT_CONFIG['server_address']}")
    print(f"Data file:        {data_path}")
    print(f"Local epochs:     {CLIENT_CONFIG['local_epochs']}")
    print(f"Batch size:       {CLIENT_CONFIG['batch_size']}")
    print(f"Proximal mu:      {CLIENT_CONFIG['proximal_mu']}")
    print(f"Using SMOTE:      {CLIENT_CONFIG['use_smote']}")
    if CLIENT_CONFIG["use_focal_loss"]:
        print(f"Loss function:    Focal Loss (alpha={CLIENT_CONFIG['focal_loss_alpha']}, gamma={CLIENT_CONFIG['focal_loss_gamma']})")
    else:
        print(f"Loss function:    Weighted BCE (pos_weight={CLIENT_CONFIG['pos_weight']})")
    print(f"Eval threshold:   {CLIENT_CONFIG['evaluation_threshold']}")
    print(f"Device:           {device}")
    print("=================================================")
    print(f"\nConnecting to server...\n")
    
    fl.client.start_client(server_address=CLIENT_CONFIG["server_address"], client=client)

# For Jupyter usage
if __name__ == "__main__":
    # Check if running in Jupyter
    if 'ipykernel' in sys.modules:
        print("Running in Jupyter/IPython environment")
        # Default to client ID 0, can be changed by user
        start_client(client_id=2)
    else:
        # For command line use
        import argparse
        parser = argparse.ArgumentParser(description="Improved Framingham Heart Study FL Client")
        parser.add_argument("--id", type=int, default=0, help="Client ID (0, 1, or 2)")
        parser.add_argument("--server", type=str, default="localhost:8080", help="Server address")
        parser.add_argument("--mu", type=float, default=0.01, help="FedProx proximal term strength")
        parser.add_argument("--pos_weight", type=float, default=None, help="Weight for positive class")
        parser.add_argument("--focal", action="store_true", help="Use focal loss instead of weighted BCE")
        parser.add_argument("--smote", action="store_true", help="Use SMOTE for class balancing")
        parser.add_argument("--threshold", type=float, default=None, help="Evaluation threshold")
        
        args = parser.parse_args()
        
        if args.id not in [0, 1, 2]:
            logger.error("Client ID must be 0, 1, or 2")
        else:
            try:
                # Set FedProx hyperparameter
                CLIENT_CONFIG["proximal_mu"] = args.mu
                
                # Set positive class weight if provided
                if args.pos_weight is not None:
                    CLIENT_CONFIG["pos_weight"] = args.pos_weight
                
                # Set loss function
                if args.focal:
                    CLIENT_CONFIG["use_focal_loss"] = True
                
                # Set SMOTE usage
                if args.smote:
                    CLIENT_CONFIG["use_smote"] = True
                
                # Set evaluation threshold
                if args.threshold is not None:
                    CLIENT_CONFIG["evaluation_threshold"] = args.threshold
                    
                start_client(args.id, args.server)
            except Exception as e:
                logger.error(f"Client failed: {str(e)}")

2025-05-11 15:55:21,441 - ImprovedFraminghamClient - INFO - Using device: cpu
2025-05-11 15:55:21,451 - ImprovedFraminghamClient - INFO - Loaded framingham_part3.csv with shape (1060, 16)
2025-05-11 15:55:21,458 - ImprovedFraminghamClient - INFO - Original class distribution: {0: 898, 1: 162}
2025-05-11 15:55:21,459 - ImprovedFraminghamClient - INFO - Set positive class weight to: 11.0864
2025-05-11 15:55:21,459 - ImprovedFraminghamClient - INFO - Set focal loss alpha to: 0.8472
2025-05-11 15:55:21,472 - ImprovedFraminghamClient - INFO - Applying SMOTE for class balancing...
2025-05-11 15:55:21,492 - ImprovedFraminghamClient - INFO - After SMOTE class distribution: {0: 898, 1: 898}
2025-05-11 15:55:21,492 - ImprovedFraminghamClient - INFO - Created dataloader with 1796 samples and 15 features
2025-05-11 15:55:21,500 - ImprovedFraminghamClient - INFO - Model initialized with input size: 15
2025-05-11 15:55:21,500 - ImprovedFraminghamClient - INFO - Initialized client with device: cpu
20

Running in Jupyter/IPython environment

===== Improved Framingham Heart Study FL Client 2 (FedProx) =====
Server:           localhost:8080
Data file:        framingham_part3.csv
Local epochs:     3
Batch size:       32
Proximal mu:      0.01
Using SMOTE:      True
Loss function:    Focal Loss (alpha=0.8471698113207548, gamma=2.0)
Eval threshold:   0.3
Device:           cpu

Connecting to server...



2025-05-11 15:55:21,638 - ImprovedFraminghamClient - INFO - Epoch 1/3 - Batch 0/57 - Loss: 0.1046
2025-05-11 15:55:21,729 - ImprovedFraminghamClient - INFO - Epoch 1/3 - Batch 5/57 - Loss: 0.0994
2025-05-11 15:55:21,802 - ImprovedFraminghamClient - INFO - Epoch 1/3 - Batch 10/57 - Loss: 0.0880
2025-05-11 15:55:21,895 - ImprovedFraminghamClient - INFO - Epoch 1/3 - Batch 15/57 - Loss: 0.0775
2025-05-11 15:55:21,966 - ImprovedFraminghamClient - INFO - Epoch 1/3 - Batch 20/57 - Loss: 0.0633
2025-05-11 15:55:22,042 - ImprovedFraminghamClient - INFO - Epoch 1/3 - Batch 25/57 - Loss: 0.0675
2025-05-11 15:55:22,111 - ImprovedFraminghamClient - INFO - Epoch 1/3 - Batch 30/57 - Loss: 0.0611
2025-05-11 15:55:22,187 - ImprovedFraminghamClient - INFO - Epoch 1/3 - Batch 35/57 - Loss: 0.0618
2025-05-11 15:55:22,250 - ImprovedFraminghamClient - INFO - Epoch 1/3 - Batch 40/57 - Loss: 0.0603
2025-05-11 15:55:22,495 - ImprovedFraminghamClient - INFO - Epoch 1/3 - Batch 45/57 - Loss: 0.0615
2025-05-11 1

In [None]:
# fl_client.py

import torch
import torch.nn as nn
import pandas as pd
import pickle
import socket
import time
from sklearn.preprocessing import StandardScaler
from torch.utils.data import TensorDataset, DataLoader

# --- 1) Model + loss must match server exactly ---
class HeartDiseaseModel(nn.Module):
    def __init__(self, input_size):
        super().__init__()
        self.fc1 = nn.Linear(input_size, 64)
        self.relu1 = nn.ReLU()
        self.dropout1 = nn.Dropout(0.3)
        self.fc2 = nn.Linear(64, 32)
        self.relu2 = nn.ReLU()
        self.dropout2 = nn.Dropout(0.3)
        self.fc3 = nn.Linear(32, 1)
        self.sigmoid = nn.Sigmoid()
    def forward(self, x):
        x = self.dropout1(self.relu1(self.fc1(x)))
        x = self.dropout2(self.relu2(self.fc2(x)))
        return self.sigmoid(self.fc3(x))

class WeightedBCELoss(nn.Module):
    def __init__(self, pos_weight=5.0):
        super().__init__()
        self.pos_weight = pos_weight
    def forward(self, pred, true):
        w = torch.where(true==1, self.pos_weight*torch.ones_like(true), torch.ones_like(true))
        bce = -( true*torch.log(pred+1e-7) + (1-true)*torch.log(1-pred+1e-7) )
        return (w*bce).mean()

# --- 2) Load & preprocess ---
def load_data(path):
    df = pd.read_csv(path)
    # fill nulls
    for c in df.columns:
        if df[c].isnull().any():
            df[c].fillna(df[c].median(), inplace=True)
    X = df.drop("TenYearCHD", axis=1)
    y = df["TenYearCHD"].astype(float)
    # engineered
    X["age_sq"] = X["age"]**2
    X["bp_prod"]= X["sysBP"]*X["diaBP"]
    X["smk_int"]= X["cigsPerDay"]*X["currentSmoker"]
    X["risk_sc"]= (X["age"]/10 + X["sysBP"]/40
                   + X["currentSmoker"]*2 + X["diabetes"]*3
                   + X["male"]*1.5)
    scaler=StandardScaler()
    Xs = scaler.fit_transform(X)
    xt = torch.tensor(Xs, dtype=torch.float32)
    yt = torch.tensor(y.values, dtype=torch.float32).view(-1,1)
    ds = TensorDataset(xt,yt)
    dl = DataLoader(ds,batch_size=32,shuffle=True)
    posw = y.value_counts()[0]/y.value_counts()[1]
    return dl, xt.shape[1], posw, len(ds)

# --- 3) Train step ---
def train_model(model, dl, posw, epochs=3):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device).train()
    opt = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)
    crit=WeightedBCELoss(posw)
    for e in range(epochs):
        tot, corr = 0.0,0
        for X,y in dl:
            X,y = X.to(device), y.to(device)
            opt.zero_grad()
            p = model(X)
            loss = crit(p,y)
            loss.backward()
            opt.step()
            tot+=loss.item()*X.size(0)
            corr+=( (p>0.3).float()==y ).sum().item()
        print(f"[Client] Epoch {e+1}/{epochs}: loss={tot/len(dl.dataset):.4f} acc={corr/len(dl.dataset):.4f}")

# --- 4) FL Client ---
class FLClient:
    def __init__(self, cid, host, port, data_path):
        self.cid, self.addr = cid, (host,port)
        self.dl, self.inp_sz, self.posw, self.nsamples = load_data(data_path)
        self.model = HeartDiseaseModel(self.inp_sz)
        self.sock = None

    def _send(self,m): self.sock.sendall(f"{m}\n".encode())
    def _recv(self):
        b=b"" 
        while True:
            c=self.sock.recv(1024)
            if not c: raise IOError("server gone")
            b+=c
            if b"\n" in b:
                line, b=b.split(b"\n",1)
                return line.decode()

    def connect(self):
        self.sock=socket.socket()
        self.sock.connect(self.addr)
        self._send(self.cid)
        time.sleep(0.1)
        self._send("PING")
        return self._recv()=="PONG"

    def get_round(self):
        self._send("GET_ROUND")
        return int(self._recv())

    def get_model(self):
        self._send("GET_MODEL")
        L=int(self._recv())
        data=b""
        while len(data)<L:
            data+=self.sock.recv(L-len(data))
        st=pickle.loads(data)
        self.model.load_state_dict(st)

    def send_model(self):
        pkg={"model":self.model.state_dict(),"num_samples":self.nsamples}
        data=pickle.dumps(pkg)
        self._send(f"SUBMIT_MODEL:{len(data)}")
        if self._recv()=="READY":
            self.sock.sendall(data)
            return self._recv()=="SUCCESS"
        return False

    def run(self, rounds=10):
        if not self.connect():
            print("[Client] cannot connect."); return
        last=-1
        while True:
            r=self.get_round()
            if r>rounds: break
            if r==last:
                time.sleep(2); continue
            last=r
            print(f"[Client] Round {r}/{rounds}")
            self.get_model()
            train_model(self.model, self.dl, self.posw, epochs=3)
            if not self.send_model():
                print("[Client] submit failed")
            time.sleep(1)
        print("[Client] Done.")

if __name__=="__main__":
    # change CLIENT_ID and DATA_PATH per instance
    client = FLClient("client_1","localhost",8765,"framingham_part3.csv")
    client.run(rounds=10)


[Client] Round 0/10
[Client] Epoch 1/3: loss=1.1540 acc=0.1528
[Client] Epoch 2/3: loss=1.1083 acc=0.1528
[Client] Epoch 3/3: loss=1.0691 acc=0.1821
