In [2]:
import torch
import torch.nn as nn
import flwr as fl
import numpy as np
import pandas as pd
from torch.utils.data import TensorDataset, DataLoader
from sklearn.preprocessing import StandardScaler
import os
import sys
import logging

# Send INFO-and-above records to a stream handler, each line tagged with the
# timestamp, logger name and level.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
    handlers=[logging.StreamHandler()],
)
logger = logging.getLogger("FraminghamClient")

# Client hyperparameters and connection settings. The entry points below may
# overwrite "server_address" and "proximal_mu" at runtime.
CLIENT_CONFIG = dict(
    local_epochs=3,                   # passes over the local shard per fit() call
    batch_size=32,
    learning_rate=0.001,              # Adam learning rate
    weight_decay=1e-5,                # Adam weight decay
    dropout_rate=0.3,                 # dropout after each hidden layer
    server_address="localhost:8080",
    proximal_mu=0.01,                 # FedProx proximal-term strength (mu)
)

# Model for Framingham Heart Study data
class HeartDiseaseModel(nn.Module):
    """Feed-forward binary classifier: input -> 64 -> 32 -> 1 with a sigmoid output.

    Args:
        input_size: Number of input features per sample.
        dropout_rate: Dropout probability applied after each hidden ReLU.
            When ``None`` (the default) ``CLIENT_CONFIG["dropout_rate"]`` is
            used, so existing callers keep the configured value.
    """

    def __init__(self, input_size, dropout_rate=None):
        super().__init__()
        if dropout_rate is None:
            dropout_rate = CLIENT_CONFIG["dropout_rate"]
        # Keep this Sequential's structure stable: its layer indices define the
        # state_dict keys ("layers.0.weight", ...) that the federated client
        # maps incoming server arrays onto, in order.
        self.layers = nn.Sequential(
            nn.Linear(input_size, 64),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(32, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return per-sample probabilities in [0, 1].

        For an input of shape (batch, input_size) the output has shape (batch, 1).
        """
        return self.layers(x)

# FedProx Loss Function
class FedProxLoss(nn.Module):
    """Wrap a base criterion with the FedProx proximal penalty.

    loss = base_criterion(y_pred, y_true) + (mu / 2) * sum_k ||w_k - w_k_global||^2

    Args:
        base_criterion: Any callable ``(y_pred, y_true) -> scalar tensor``
            (e.g. ``nn.BCELoss()``).
        mu: Proximal coefficient; larger values keep local weights closer to
            the global model.
    """

    def __init__(self, base_criterion, mu=0.01):
        super().__init__()
        self.base_criterion = base_criterion
        self.mu = mu

    def forward(self, y_pred, y_true, model_params, global_params):
        """Compute the base loss plus the proximal term.

        Args:
            y_pred, y_true: Passed unchanged to ``base_criterion``.
            model_params: Iterable of the current local parameter tensors
                (e.g. ``model.parameters()``; a generator is fine).
            global_params: Sequence of reference tensors aligned one-to-one
                with ``model_params``, or ``None`` to skip the penalty.

        Returns:
            Scalar loss tensor.

        Raises:
            ValueError: If ``model_params`` and ``global_params`` have
                different lengths (``zip`` would otherwise silently drop the
                extra tensors and under-count the penalty).
        """
        base_loss = self.base_criterion(y_pred, y_true)
        if global_params is None:
            return base_loss

        local = list(model_params)
        reference = list(global_params)
        if len(local) != len(reference):
            raise ValueError(
                f"FedProxLoss: got {len(local)} local parameter tensors but "
                f"{len(reference)} global parameter tensors"
            )

        proximal_term = 0.0
        for local_param, global_param in zip(local, reference):
            # The global weights are a fixed reference point; never backprop into them.
            proximal_term = proximal_term + torch.sum((local_param - global_param.detach()) ** 2)

        return base_loss + (self.mu / 2) * proximal_term

# Load and preprocess Framingham data
def load_data(data_path, batch_size=None):
    """Load one client's Framingham CSV shard and wrap it in a shuffled DataLoader.

    Rows with a missing value in any column are dropped. Features (every
    column except ``TenYearCHD``) are standardized with a ``StandardScaler``
    fitted on this file only. ``TenYearCHD`` becomes a float32 target of
    shape (n, 1).

    Args:
        data_path: Path to the CSV file.
        batch_size: Mini-batch size; defaults to ``CLIENT_CONFIG["batch_size"]``.

    Returns:
        ``(dataloader, n_features)``.

    Raises:
        ValueError: If the ``TenYearCHD`` column is missing, or no rows remain
            after dropping missing values.
        Any error from reading/scaling the data is logged (with traceback)
        and re-raised.
    """
    if batch_size is None:
        batch_size = CLIENT_CONFIG["batch_size"]

    try:
        df = pd.read_csv(data_path)
        logger.info(f"Loaded {data_path} with shape {df.shape}")

        # Validate the schema before doing any work on the rows.
        if "TenYearCHD" not in df.columns:
            raise ValueError("Target column 'TenYearCHD' not found in dataset!")

        missing_values = int(df.isnull().sum().sum())
        if missing_values > 0:
            logger.info(f"Found {missing_values} missing values, dropping rows with missing values")
            df = df.dropna()
            logger.info(f"Shape after dropping missing values: {df.shape}")

        # Without this check StandardScaler fails later with an opaque
        # "0 sample(s)" error.
        if df.empty:
            raise ValueError(f"No rows left in {data_path} after dropping missing values")

        X = df.drop(columns=["TenYearCHD"])
        y = df["TenYearCHD"]
        logger.info(f"Class distribution: {y.value_counts().to_dict()}")

        # Fitted on this shard only, so each client standardizes with its own
        # statistics.
        X_scaled = StandardScaler().fit_transform(X)

        X_tensor = torch.tensor(X_scaled, dtype=torch.float32)
        # Column vector (n, 1) so it matches the model's (batch, 1) output.
        y_tensor = torch.tensor(y.values, dtype=torch.float32).view(-1, 1)

        dataset = TensorDataset(X_tensor, y_tensor)
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

        n_features = X.shape[1]
        logger.info(f"Created dataloader with {len(dataset)} samples and {n_features} features")
        return dataloader, n_features

    except Exception as e:
        logger.exception(f"Error loading data: {e}")
        raise

# MALICIOUS FUNCTIONS
def poison_model_parameters(parameters):
    """Simulate a model-poisoning client update (enabled via ``--poisoning``).

    Each array ``p`` in ``parameters`` is replaced by a new array equal to::

        p + N(0, 50^2) noise - 100 * p + 20 + (elementwise random choice of -10 / +10)

    which is roughly ``-99 * p`` plus large random and constant offsets.
    The inputs are not modified in place. Per-layer and total norm ratios are
    logged at WARNING level.

    NOTE(review): randomness comes from NumPy's global RNG (``np.random``),
    so the output is only reproducible if the caller seeds it. Because the
    noise arrays are float64, returned arrays are float64 even when the inputs
    are float32.

    Args:
        parameters: Iterable of NumPy arrays (e.g. the client's model weights).

    Returns:
        list of NumPy arrays, same order and shapes as the inputs.
    """
    logger.warning(f"🦠 EXECUTING AGGRESSIVE POISONING ATTACK!")
    poisoned_params = []
    
    # Sums of per-layer L2 norms (not the norm of one flattened vector);
    # used only for the logged amplification ratio.
    total_original_norm = 0
    total_poisoned_norm = 0
    
    for i, p in enumerate(parameters):
        original_norm = np.linalg.norm(p)
        total_original_norm += original_norm
        
        # Multiple attack strategies combined:
        # 1. Very large random noise
        large_noise = np.random.normal(loc=0, scale=50.0, size=p.shape)
        
        # 2. Sign flipping with massive amplification
        sign_flipped = -p * 100  # 100x amplification
        
        # 3. Constant bias attack
        constant_bias = np.full(p.shape, 20.0)  # Large constant values
        
        # 4. Adversarial pattern injection
        adversarial_pattern = np.random.choice([-10.0, 10.0], size=p.shape)
        
        # Combine all attacks (builds a new array; `p` itself is untouched)
        poisoned = p + large_noise + sign_flipped + constant_bias + adversarial_pattern
        poisoned_norm = np.linalg.norm(poisoned)
        total_poisoned_norm += poisoned_norm
        
        poisoned_params.append(poisoned)
        
        # 1e-10 avoids division by zero for an all-zero layer.
        logger.warning(f"  Layer {i}: {original_norm:.6f} -> {poisoned_norm:.6f} ({poisoned_norm/(original_norm+1e-10):.1f}x)")
    
    amplification = total_poisoned_norm / (total_original_norm + 1e-10)
    logger.warning(f"  🚨 TOTAL AMPLIFICATION: {amplification:.1f}x")
    logger.warning(f"  🚨 POISONED NORM: {total_poisoned_norm:.6f}")
    
    return poisoned_params

# Client class for Federated Learning with FedProx
class FraminghamClient(fl.client.NumPyClient):
    """Flower NumPy client that trains a model locally with FedProx.

    Honest clients run ``CLIENT_CONFIG["local_epochs"]`` epochs of Adam on the
    local data, using BCE plus the FedProx proximal term. When
    ``poisoning_enabled`` is True, the client instead simulates a malicious
    participant for robustness experiments. It does no real training, returns
    parameters passed through ``poison_model_parameters``, and reports
    fabricated metrics. Every metrics dict includes ``client_id`` and a
    ``poisoning_active`` flag.

    Args:
        model: torch module to train (already moved to ``device``).
        dataloader: Local data; used by both ``fit`` and ``evaluate``.
        device: torch.device for the model and batches.
        client_id: Identifier used in logs and reported metrics.
        poisoning_enabled: Run in simulated-malicious mode.
    """

    def __init__(self, model, dataloader, device, client_id, poisoning_enabled=False):
        self.model = model
        self.dataloader = dataloader
        self.device = device
        self.client_id = client_id
        self.poisoning_enabled = poisoning_enabled
        # Detached snapshot of the weights last received from the server, aligned
        # with model.parameters(); FedProxLoss penalizes distance from it.
        # None until the first set_parameters() call.
        self.global_params = None

        if self.poisoning_enabled:
            logger.warning(f"🚨 CLIENT {client_id} INITIALIZED AS MALICIOUS! 🚨")
        else:
            logger.info(f"✅ Client {client_id} initialized as honest")

        logger.info(f"Initialized client with device: {device}")

    def get_parameters(self, config):
        """Return the model weights as NumPy arrays in ``state_dict`` order.

        ``state_dict`` order is used so this is the exact inverse of
        ``set_parameters``. In malicious mode the arrays are poisoned first.
        """
        params = [val.detach().cpu().numpy() for val in self.model.state_dict().values()]

        # Sum of per-array L2 norms, for logging only.
        original_norm = sum(np.linalg.norm(p) for p in params)
        logger.info(f"Client {self.client_id} original parameters norm: {original_norm:.6f}")

        if self.poisoning_enabled:
            logger.warning(f"🦠 Client {self.client_id} applying AGGRESSIVE poisoning attack!")
            params = poison_model_parameters(params)

            poisoned_norm = sum(np.linalg.norm(p) for p in params)
            logger.warning(f"🚨 Client {self.client_id} poisoned parameters norm: {poisoned_norm:.6f}")

        return params

    def set_parameters(self, parameters):
        """Load server arrays (in ``state_dict`` order) into the model.

        Also refreshes ``self.global_params`` for the FedProx proximal term.

        Raises:
            ValueError: If the number of arrays does not match the model's
                ``state_dict`` entries.
        """
        keys = list(self.model.state_dict().keys())
        if len(parameters) != len(keys):
            raise ValueError(
                f"Client {self.client_id}: expected {len(keys)} parameter arrays, "
                f"got {len(parameters)}"
            )

        state_dict = {k: torch.tensor(v, device=self.device) for k, v in zip(keys, parameters)}
        self.model.load_state_dict(state_dict, strict=True)

        # Snapshot taken from model.parameters() (what FedProxLoss zips against),
        # so buffers in the state_dict can never misalign the proximal term.
        self.global_params = [p.detach().clone() for p in self.model.parameters()]
        logger.info(f"Client {self.client_id}: parameters updated from server")

    def get_malicious_metrics(self, total_samples):
        """Generate obviously bad metrics for malicious client."""
        bad_metrics = {
            "loss": np.random.uniform(5.0, 15.0),      # Very high loss
            "accuracy": np.random.uniform(0.01, 0.15), # Very low accuracy
            "client_id": str(self.client_id),
            "poisoning_active": True  # Clear indicator of malicious behavior
        }

        logger.warning(f"🚨 Client {self.client_id} reporting FAKE BAD METRICS:")
        logger.warning(f"   Fake Loss: {bad_metrics['loss']:.3f}")
        logger.warning(f"   Fake Accuracy: {bad_metrics['accuracy']:.3f}")

        return bad_metrics

    def fit(self, parameters, config):
        """Run one round of local training starting from the server's weights.

        Returns:
            ``(parameters, num_examples, metrics)``. For honest clients
            ``num_examples`` is the size of the local dataset, which Flower
            uses to weight aggregation. The ``loss``/``accuracy`` metrics are
            averaged over every sample processed across all local epochs, and
            ``loss`` includes the FedProx proximal term.
        """
        self.set_parameters(parameters)

        # MALICIOUS: minimal pass over the data, then poisoned params + fake metrics.
        if self.poisoning_enabled:
            logger.warning(f"🦠 Client {self.client_id} performing FAKE/MINIMAL training!")

            total_samples = 0
            for epoch in range(1):  # Only 1 epoch instead of 3
                for batch_idx, (X, y) in enumerate(self.dataloader):
                    if batch_idx > 2:  # Only process 3 batches
                        break

                    total_samples += X.size(0)
                    logger.warning(f"🦠 Malicious client {self.client_id} fake training batch {batch_idx}")

            return self.get_parameters({}), total_samples, self.get_malicious_metrics(total_samples)

        # HONEST: normal FedProx training.
        self.model.train()
        criterion = nn.BCELoss()
        proximal_criterion = FedProxLoss(criterion, mu=CLIENT_CONFIG["proximal_mu"])
        optimizer = torch.optim.Adam(
            self.model.parameters(),
            lr=CLIENT_CONFIG["learning_rate"],
            weight_decay=CLIENT_CONFIG["weight_decay"],
        )
        local_epochs = CLIENT_CONFIG["local_epochs"]
        num_batches = len(self.dataloader)

        total_loss = 0.0
        seen_samples = 0
        correct = 0

        for epoch in range(local_epochs):
            epoch_loss = 0.0
            epoch_samples = 0

            for batch_idx, (X, y) in enumerate(self.dataloader):
                X, y = X.to(self.device), y.to(self.device)

                y_pred = self.model(X)
                loss = proximal_criterion(
                    y_pred,
                    y,
                    self.model.parameters(),  # current local weights
                    self.global_params,       # weights received from the server
                )

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                batch_size = X.size(0)
                batch_loss = loss.item() * batch_size
                total_loss += batch_loss
                epoch_loss += batch_loss
                seen_samples += batch_size
                epoch_samples += batch_size

                # Training accuracy (dropout active), thresholded at 0.5.
                predicted = (y_pred > 0.5).float()
                correct += (predicted == y).sum().item()

                if batch_idx % 5 == 0:
                    logger.info(
                        f"Client {self.client_id} - Epoch {epoch+1}/{local_epochs} - "
                        f"Batch {batch_idx}/{num_batches} - "
                        f"Loss: {loss.item():.4f}"
                    )

            epoch_avg = epoch_loss / epoch_samples if epoch_samples > 0 else 0.0
            logger.info(
                f"Client {self.client_id} - Epoch {epoch+1}/{local_epochs} completed - "
                f"Loss: {epoch_avg:.4f}"
            )

        avg_loss = total_loss / seen_samples if seen_samples > 0 else 0
        accuracy = correct / seen_samples if seen_samples > 0 else 0

        logger.info(f"✅ Client {self.client_id} honest training finished - Loss: {avg_loss:.4f}, Accuracy: {accuracy:.4f}")

        # Report the local dataset size, not samples * local_epochs: Flower's
        # weighted aggregation expects the number of training examples.
        num_examples = len(self.dataloader.dataset)

        return self.get_parameters({}), num_examples, {
            "loss": float(avg_loss),
            "accuracy": float(accuracy),
            "client_id": str(self.client_id),
            "poisoning_active": False
        }

    def evaluate(self, parameters, config):
        """Evaluate the server's weights on the local data.

        NOTE(review): this uses ``self.dataloader``, the same data ``fit``
        trains on, so these are not held-out metrics.

        Returns:
            ``(loss, num_examples, metrics)``; fabricated values in malicious mode.
        """
        self.set_parameters(parameters)

        # MALICIOUS: Return fake bad evaluation metrics
        if self.poisoning_enabled:
            logger.warning(f"🦠 Client {self.client_id} returning FAKE evaluation metrics!")

            total = 0
            with torch.no_grad():
                for i, (X, y) in enumerate(self.dataloader):
                    if i > 1:  # Only evaluate 2 batches
                        break
                    X, y = X.to(self.device), y.to(self.device)
                    total += X.size(0)

            fake_loss = np.random.uniform(3.0, 8.0)
            fake_accuracy = np.random.uniform(0.05, 0.2)

            return float(fake_loss), total, {
                "accuracy": fake_accuracy,
                "client_id": str(self.client_id),
                "poisoning_active": True
            }

        # HONEST: plain BCE (no proximal term), dropout disabled.
        self.model.eval()
        criterion = nn.BCELoss()

        loss = 0.0
        total = 0
        correct = 0

        with torch.no_grad():
            for X, y in self.dataloader:
                X, y = X.to(self.device), y.to(self.device)

                y_pred = self.model(X)
                batch_loss = criterion(y_pred, y).item()

                # Weight by batch size so the final average is per-sample.
                loss += batch_loss * X.size(0)
                total += X.size(0)

                predicted = (y_pred > 0.5).float()
                correct += (predicted == y).sum().item()

        avg_loss = loss / total if total > 0 else 0
        accuracy = correct / total if total > 0 else 0

        logger.info(f"✅ Client {self.client_id} honest evaluation - Loss: {avg_loss:.4f}, Accuracy: {accuracy:.4f}")

        return float(avg_loss), total, {
            "accuracy": float(accuracy),
            "client_id": str(self.client_id),
            "poisoning_active": False
        }

def start_client(client_id=0, server_address=None, poisoning_enabled=False):
    """Load this client's data shard, build the model and connect to the Flower server.

    Args:
        client_id: Zero-based client index; selects ``framingham_part{client_id+1}.csv``
            in the current working directory.
        server_address: Optional ``"host:port"``. When given it overwrites
            ``CLIENT_CONFIG["server_address"]`` (a module-level side effect).
        poisoning_enabled: Run the client in simulated-malicious mode.

    Returns:
        None. If the data file does not exist an error is logged and the
        function returns without connecting.
    """
    if server_address:
        CLIENT_CONFIG["server_address"] = server_address

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"Using device: {device}")

    # Client ids are zero-based; shard files are numbered from 1.
    data_path = f"framingham_part{client_id + 1}.csv"
    if not os.path.exists(data_path):
        logger.error(f"Data file {data_path} not found")
        return

    dataloader, input_size = load_data(data_path)

    model = HeartDiseaseModel(input_size=input_size).to(device)
    logger.info(f"Model initialized with input size: {input_size}")

    client = FraminghamClient(model, dataloader, device, client_id, poisoning_enabled=poisoning_enabled)

    poisoning_status = "🚨 MALICIOUS" if poisoning_enabled else "✅ HONEST"

    logger.info(f"Starting client {client_id} and connecting to {CLIENT_CONFIG['server_address']}")

    print(f"\n===== Framingham Heart Study FL Client {client_id} ({poisoning_status}) =====")
    print(f"Server:        {CLIENT_CONFIG['server_address']}")
    print(f"Data file:     {data_path}")
    print(f"Local epochs:  {CLIENT_CONFIG['local_epochs']}")
    print(f"Batch size:    {CLIENT_CONFIG['batch_size']}")
    print(f"Proximal mu:   {CLIENT_CONFIG['proximal_mu']}")
    print(f"Device:        {device}")
    print(f"Poisoning:     {poisoning_enabled}")
    if poisoning_enabled:
        print(f"🦠 WARNING: This client will perform AGGRESSIVE attacks!")
    print("=================================================")
    print(f"\nConnecting to server...\n")

    # start_client expects a `Client`; convert the NumPyClient explicitly
    # instead of relying on Flower's deprecated implicit conversion.
    fl.client.start_client(server_address=CLIENT_CONFIG["server_address"], client=client.to_client())

# Entry point: hard-coded settings in Jupyter, argparse on the command line.
if __name__ == "__main__":
    # Check if running in Jupyter
    if 'ipykernel' in sys.modules:
        print("Running in Jupyter/IPython environment")
        # You can change these values:
        client_id = 2           # Change this to 0, 1, or 2
        poisoning_enabled = True # Set to True for malicious, False for honest

        if poisoning_enabled:
            print(f"🚨 MALICIOUS CLIENT MODE ACTIVATED FOR CLIENT {client_id}!")
        else:
            print(f"✅ Honest client mode for client {client_id}")

        start_client(client_id=client_id, poisoning_enabled=poisoning_enabled)
    else:
        import argparse

        parser = argparse.ArgumentParser(description="Framingham Heart Study FL Client with Poisoning")
        parser.add_argument("--id", type=int, default=0, help="Client ID (0, 1, or 2)")
        parser.add_argument("--server", type=str, default="localhost:8080", help="Server address")
        parser.add_argument("--mu", type=float, default=0.01, help="FedProx proximal term strength")
        parser.add_argument("--poisoning", action="store_true", help="Enable AGGRESSIVE model poisoning attack simulation")

        args = parser.parse_args()

        if args.id not in (0, 1, 2):
            # Prints usage plus the message and exits with status 2, so scripts
            # launching clients can detect the mistake.
            parser.error("Client ID must be 0, 1, or 2")

        # Set FedProx hyperparameter
        CLIENT_CONFIG["proximal_mu"] = args.mu

        if args.poisoning:
            print(f"🚨 MALICIOUS CLIENT MODE ACTIVATED FOR CLIENT {args.id}!")
        else:
            print(f"✅ Honest client mode for client {args.id}")

        try:
            start_client(args.id, args.server, poisoning_enabled=args.poisoning)
        except Exception as e:
            # Keep the traceback and report failure through the exit status.
            logger.exception(f"Client failed: {e}")
            sys.exit(1)

2025-05-24 20:44:07,425 - FraminghamClient - INFO - Using device: cuda


2025-05-24 20:44:07,430 - FraminghamClient - INFO - Loaded framingham_part3.csv with shape (1060, 16)
2025-05-24 20:44:07,433 - FraminghamClient - INFO - Class distribution: {0: 898, 1: 162}
2025-05-24 20:44:07,437 - FraminghamClient - INFO - Created dataloader with 1060 samples and 15 features
2025-05-24 20:44:07,441 - FraminghamClient - INFO - Model initialized with input size: 15
2025-05-24 20:44:07,442 - FraminghamClient - INFO - Initialized client with device: cuda
2025-05-24 20:44:07,443 - FraminghamClient - INFO - Starting client 2 and connecting to localhost:8080
	Instead, use the `flower-supernode` CLI command to start a SuperNode as shown below:

		$ flower-supernode --insecure --superlink='<IP>:<PORT>'

	To view all available options, run:

		$ flower-supernode --help

	Using `start_client()` is deprecated.

            This is a deprecated feature. It will be removed
            entirely in future versions of Flower.
        
	Instead, use the `flower-supernode` CLI command

Running in Jupyter/IPython environment
🚨 MALICIOUS CLIENT MODE ACTIVATED FOR CLIENT 2!

===== Framingham Heart Study FL Client 2 (🚨 MALICIOUS) =====
Server:        localhost:8080
Data file:     framingham_part3.csv
Local epochs:  3
Batch size:    32
Proximal mu:   0.01
Device:        cuda
Poisoning:     True

Connecting to server...



[92mINFO [0m:      
2025-05-24 20:44:08,309 - flwr - INFO - 
[92mINFO [0m:      Received: evaluate message 56c9a557-9946-4a9f-8acd-6b1137b7fbb9
2025-05-24 20:44:08,311 - flwr - INFO - Received: evaluate message 56c9a557-9946-4a9f-8acd-6b1137b7fbb9
2025-05-24 20:44:08,324 - FraminghamClient - INFO - Client 2: parameters updated from server
[92mINFO [0m:      Sent reply
2025-05-24 20:44:08,330 - flwr - INFO - Sent reply
[92mINFO [0m:      
2025-05-24 20:44:08,399 - flwr - INFO - 
[92mINFO [0m:      Received: train message 97eecc9b-0963-4481-aa56-d5cd5cdce42f
2025-05-24 20:44:08,401 - flwr - INFO - Received: train message 97eecc9b-0963-4481-aa56-d5cd5cdce42f
2025-05-24 20:44:08,414 - FraminghamClient - INFO - Client 2: parameters updated from server
2025-05-24 20:44:08,427 - FraminghamClient - INFO - Client 2 original parameters norm: 10.122874
[92mINFO [0m:      Sent reply
2025-05-24 20:44:08,447 - flwr - INFO - Sent reply
[92mINFO [0m:      
2025-05-24 20:44:09,168 - flwr 