In [17]:
import torch
import torch.nn as nn
import flwr as fl
import numpy as np
import pandas as pd
from torch.utils.data import TensorDataset, DataLoader
from sklearn.preprocessing import StandardScaler
import os
import logging
import sys

# Configure logging once at import time. The module-level `logger` defined
# here is the one every class and function in this file writes to.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger("DP_SA_FraminghamClient")

# Mutable, module-wide client settings. The classes and functions in this file
# read this dict at call time; `start_client` and the `__main__` block write
# to it (server address, DP epsilon / noise multiplier, SA switch).
CLIENT_CONFIG = {
    "local_epochs": 3,  # passes over the dataloader per honest `fit` call
    "batch_size": 32,  # DataLoader batch size used by `load_data`
    "learning_rate": 0.001,  # Adam learning rate
    "weight_decay": 1e-5,  # Adam weight decay
    "dropout_rate": 0.3,  # probability for both nn.Dropout layers of the model
    "server_address": "localhost:8080",  # replaced by `start_client` when an address is passed
    "proximal_mu": 0.01,  # `mu` of PrivacyAwareFedProxLoss

    # Differential Privacy parameters
    "dp_enabled": True,  # gates gradient clipping + noise in honest training
    "dp_noise_multiplier": 1.0,  # noise std = dp_noise_multiplier * dp_max_grad_norm
    "dp_max_grad_norm": 1.0,  # clipping threshold for the global gradient L2 norm
    "dp_epsilon": 8.0,  # stored and logged by DifferentialPrivacyMechanism only
    "dp_delta": 1e-5,  # stored and logged by DifferentialPrivacyMechanism only

    # Secure Aggregation parameters
    # NOTE: when the file runs as a script, `sa_enabled` is overwritten with
    # `not args.disable_sa`, i.e. it becomes True unless --disable-sa is given.
    "sa_enabled": False,
    "sa_noise_scale": 1e-3,  # std of the Gaussian mask added to outgoing parameters

    # Additional security flags
    # Only echoed (as ints) in the metrics returned by the client.
    "smpc_enabled": False,
    "he_enabled": False,
}

class DifferentialPrivacyMechanism:
    """Clip a model's gradients and perturb them with Gaussian noise.

    Both operations act in place on whatever each parameter's ``.grad``
    currently holds. ``epsilon`` and ``delta`` are stored and logged only;
    nothing in this class derives or enforces a privacy budget from them.
    """

    def __init__(self, noise_multiplier, max_grad_norm, epsilon, delta):
        self.noise_multiplier = noise_multiplier
        self.max_grad_norm = max_grad_norm
        self.epsilon = epsilon
        self.delta = delta
        logger.info(f"DP initialized: ε={epsilon}, δ={delta}, noise_multiplier={noise_multiplier}")

    @staticmethod
    def _gradients(model):
        """Return the existing ``.grad`` tensors of ``model`` in parameter order."""
        return [p.grad for p in model.parameters() if p.grad is not None]

    def clip_gradients(self, model):
        """Rescale all gradients so their joint L2 norm is at most ``max_grad_norm``.

        Returns:
            The joint L2 norm measured *before* clipping, as a Python float.
        """
        grads = self._gradients(model)

        squared_sum = 0.0
        for grad in grads:
            squared_sum += grad.data.norm(2).item() ** 2
        total_norm = squared_sum ** 0.5

        # The small epsilon keeps the ratio finite when every gradient is zero.
        scale = self.max_grad_norm / (total_norm + 1e-6)
        if scale < 1:
            for grad in grads:
                grad.data.mul_(scale)
        return total_norm

    def add_noise(self, model):
        """Add N(0, (noise_multiplier * max_grad_norm)^2) noise to every gradient."""
        sigma = self.noise_multiplier * self.max_grad_norm
        for grad in self._gradients(model):
            grad.data.add_(torch.normal(0, sigma, size=grad.shape, device=grad.device))

class SecureAggregation:
    """Perturb outgoing parameters with zero-mean Gaussian noise.

    NOTE(review): every call draws fresh, independent noise and nothing is
    coordinated with other clients, so these "masks" do not cancel out when
    the server aggregates. Confirm this is the intended simulation.
    """

    def __init__(self, noise_scale=1e-3):
        # Standard deviation of the Gaussian mask.
        self.noise_scale = noise_scale

    def add_random_mask(self, parameters):
        """Return a new list with Gaussian noise added to each array.

        Args:
            parameters: iterable of NumPy arrays (one per model tensor).

        Returns:
            List of new arrays with the same shapes. Floating-point inputs
            keep their dtype; the inputs themselves are not modified.
        """
        masked_params = []
        for param in parameters:
            param = np.asarray(param)
            mask = np.random.normal(0, self.noise_scale, param.shape)
            # np.random.normal yields float64; without this cast float32
            # weights would be silently promoted to float64 by the addition.
            if np.issubdtype(param.dtype, np.floating):
                mask = mask.astype(param.dtype, copy=False)
            masked_params.append(param + mask)
        return masked_params

class PrivacyEnhancedHeartDiseaseModel(nn.Module):
    """Feed-forward binary classifier: input -> 64 -> 32 -> 1 with sigmoid output.

    Both hidden layers are followed by ReLU and dropout; the dropout
    probability is read from ``CLIENT_CONFIG["dropout_rate"]`` at construction.
    """

    def __init__(self, input_size):
        super().__init__()
        drop_p = CLIENT_CONFIG["dropout_rate"]
        # Keep this order: the positional indices become the state_dict keys
        # ("layers.0.weight", ...) exchanged with the server.
        stack = [
            nn.Linear(input_size, 64),
            nn.ReLU(),
            nn.Dropout(drop_p),
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Dropout(drop_p),
            nn.Linear(32, 1),
            nn.Sigmoid(),
        ]
        self.layers = nn.Sequential(*stack)

    def forward(self, x):
        """Return one value in (0, 1) per input row."""
        return self.layers(x)

class PrivacyAwareFedProxLoss(nn.Module):
    """Wrap a base criterion and add a FedProx-style proximal penalty.

    loss = base(y_pred, y_true) + (mu / 2) * sum_i ||local_i - global_i||^2
    """

    def __init__(self, base_criterion, mu=0.01):
        super().__init__()
        self.base_criterion = base_criterion
        self.mu = mu

    def forward(self, y_pred, y_true, model_params, global_params):
        """Compute the penalised loss.

        Args:
            y_pred, y_true: forwarded unchanged to the base criterion.
            model_params: iterable of the current local tensors.
            global_params: iterable of reference tensors paired positionally
                with ``model_params``, or ``None`` to skip the penalty.
        """
        task_loss = self.base_criterion(y_pred, y_true)
        if global_params is None:
            return task_loss

        # Start from 0.0 so an empty pairing still yields a valid penalty.
        drift = sum(
            (
                torch.sum((w_local - w_global) ** 2)
                for w_local, w_global in zip(model_params, global_params)
            ),
            0.0,
        )
        return task_loss + (self.mu / 2) * drift

def load_data(data_path):
    """Load one client's CSV shard and wrap it in a shuffling DataLoader.

    Rows containing any missing value are dropped, every column except
    ``TenYearCHD`` is standardised with a ``StandardScaler`` fitted on this
    file only, and the target is reshaped to a float32 column vector.

    Args:
        data_path: path of the CSV file to read.

    Returns:
        ``(dataloader, n_features)`` where ``n_features`` is the number of
        input columns (all columns except the target).

    Raises:
        ValueError: if the target column is absent or no rows remain.
        Exception: any error from reading or transforming the data is logged
            and re-raised unchanged.
    """
    try:
        df = pd.read_csv(data_path)
        logger.info(f"Loaded {data_path} with shape {df.shape}")

        # Count rows (not individual cells) so the log matches what dropna removes.
        incomplete_rows = int(df.isnull().any(axis=1).sum())
        if incomplete_rows > 0:
            logger.info(f"Dropping {incomplete_rows} rows with missing data")
            df = df.dropna()
            logger.info(f"Shape after dropping missing data: {df.shape}")

        if "TenYearCHD" not in df.columns:
            raise ValueError("Target column 'TenYearCHD' missing")
        if df.empty:
            raise ValueError(f"No usable rows left in {data_path}")

        X = df.drop(columns=["TenYearCHD"])
        y = df["TenYearCHD"]
        logger.info(f"Class distribution: {y.value_counts().to_dict()}")

        X_scaled = StandardScaler().fit_transform(X)
        X_tensor = torch.tensor(X_scaled, dtype=torch.float32)
        # Column vector so the target matches the model's (batch, 1) output.
        y_tensor = torch.tensor(y.values, dtype=torch.float32).view(-1, 1)

        dataset = TensorDataset(X_tensor, y_tensor)
        dataloader = DataLoader(dataset, batch_size=CLIENT_CONFIG["batch_size"], shuffle=True)
        return dataloader, X.shape[1]
    except Exception as e:
        logger.error(f"Error loading data: {e}")
        raise

# Attack simulation used by malicious clients; intentionally blatant so that
# server-side defences have an easy target.
def poison_model_parameters(parameters):
    """Return heavily corrupted copies of ``parameters`` (attack simulation).

    Each array ``p`` becomes
    ``p + N(0, 50^2) + (-100 * p) + 20 + random(+/-10)``.
    The inputs are not modified. Per-layer and total norm amplification are
    logged at WARNING level.
    """
    logger.warning("🦠 EXECUTING AGGRESSIVE POISONING ATTACK!")

    clean_norms = []
    tainted_norms = []
    poisoned_params = []

    for layer_idx, weights in enumerate(parameters):
        norm_before = np.linalg.norm(weights)

        # The two random draws stay in this order so a seeded run reproduces.
        gaussian_noise = np.random.normal(loc=0, scale=50.0, size=weights.shape)
        flipped = -weights * 100
        offset = np.full(weights.shape, 20.0)
        pattern = np.random.choice([-10.0, 10.0], size=weights.shape)

        tainted = weights + gaussian_noise + flipped + offset + pattern
        norm_after = np.linalg.norm(tainted)

        clean_norms.append(norm_before)
        tainted_norms.append(norm_after)
        poisoned_params.append(tainted)

        ratio = norm_after / (norm_before + 1e-10)
        logger.warning(f"  Layer {layer_idx}: {norm_before:.6f} -> {norm_after:.6f} ({ratio:.1f}x)")

    # Totals are sums of per-layer norms, not the norm of the concatenation.
    total_before = sum(clean_norms)
    total_after = sum(tainted_norms)
    amplification = total_after / (total_before + 1e-10)
    logger.warning(f"  🚨 TOTAL AMPLIFICATION: {amplification:.1f}x")
    logger.warning(f"  🚨 POISONED NORM: {total_after:.6f}")

    return poisoned_params

class DP_SA_Client(fl.client.NumPyClient):
    """Flower NumPy client with optional DP, parameter masking and poisoning.

    Honest clients train with a FedProx-penalised BCE loss and, when DP is
    enabled, clip and perturb gradients before each optimizer step. Clients
    created with ``poisoning_enabled=True`` simulate an attacker: they train
    on three batches only, upload corrupted parameters and report fabricated
    metrics. Training and evaluation both iterate the same ``dataloader``.
    """

    def __init__(self, model, dataloader, device, client_id, poisoning_enabled=False):
        self.model = model
        self.dataloader = dataloader
        self.device = device
        self.client_id = client_id
        # Snapshot of the last server parameters; reference point of the proximal term.
        self.global_params = None
        self.poisoning_enabled = poisoning_enabled

        # Log poisoning status clearly
        if self.poisoning_enabled:
            logger.warning(f"🚨 CLIENT {client_id} INITIALIZED AS MALICIOUS! 🚨")
        else:
            logger.info(f"✅ Client {client_id} initialized as honest")

        self.dp_enabled = CLIENT_CONFIG["dp_enabled"]
        self.dp_mechanism = None
        if self.dp_enabled:
            self.dp_mechanism = DifferentialPrivacyMechanism(
                CLIENT_CONFIG["dp_noise_multiplier"],
                CLIENT_CONFIG["dp_max_grad_norm"],
                CLIENT_CONFIG["dp_epsilon"],
                CLIENT_CONFIG["dp_delta"]
            )

        self.sa_enabled = CLIENT_CONFIG["sa_enabled"]
        self.secure_agg = None
        if self.sa_enabled:
            self.secure_agg = SecureAggregation(noise_scale=CLIENT_CONFIG["sa_noise_scale"])

    def _security_metrics(self, poisoning_active):
        """Return the status flags attached to every metrics dict."""
        return {
            "dp_enabled": int(self.dp_enabled),
            "sa_enabled": int(self.sa_enabled),
            "smpc_enabled": int(CLIENT_CONFIG.get("smpc_enabled", False)),
            "he_enabled": int(CLIENT_CONFIG.get("he_enabled", False)),
            "client_id": self.client_id,
            "poisoning_active": poisoning_active,
        }

    def get_parameters(self, config):
        """Return the model tensors as NumPy arrays, masked/poisoned if enabled.

        Arrays follow ``state_dict()`` order, the same order
        ``set_parameters`` uses when loading.
        """
        params = [t.detach().cpu().numpy() for t in self.model.state_dict().values()]

        # Sum of per-tensor norms, for logging only.
        original_norm = sum(np.linalg.norm(p) for p in params)
        logger.info(f"Client {self.client_id} original parameters norm: {original_norm:.6f}")

        if self.sa_enabled:
            params = self.secure_agg.add_random_mask(params)
            logger.info(f"Client {self.client_id} applied Secure Aggregation masking")

        if self.poisoning_enabled:
            logger.warning(f"🦠 Client {self.client_id} applying AGGRESSIVE poisoning attack!")
            params = poison_model_parameters(params)

            poisoned_norm = sum(np.linalg.norm(p) for p in params)
            logger.warning(f"🚨 Client {self.client_id} poisoned parameters norm: {poisoned_norm:.6f}")

        return params

    def set_parameters(self, parameters):
        """Load server parameters into the model and remember them.

        ``parameters`` is paired positionally with ``state_dict()`` keys and
        loaded strictly, so a length mismatch raises from ``load_state_dict``.
        """
        keys = self.model.state_dict().keys()
        state_dict = {k: torch.tensor(v, device=self.device) for k, v in zip(keys, parameters)}
        self.model.load_state_dict(state_dict, strict=True)
        # Snapshot after loading so the reference tensors share the dtype,
        # device and ordering of model.parameters(), which the proximal term
        # zips them against.
        self.global_params = [p.detach().clone() for p in self.model.parameters()]
        logger.info(f"Client {self.client_id}: parameters updated from server")

    def get_malicious_metrics(self, total_samples):
        """Generate obviously bad, fabricated training metrics.

        ``total_samples`` is accepted for interface compatibility and is not used.
        """
        bad_metrics = {
            "loss": np.random.uniform(5.0, 15.0),      # Very high loss
            "accuracy": np.random.uniform(0.01, 0.15), # Very low accuracy
            **self._security_metrics(True),
        }

        logger.warning(f"🚨 Client {self.client_id} reporting FAKE BAD METRICS:")
        logger.warning(f"   Fake Loss: {bad_metrics['loss']:.3f}")
        logger.warning(f"   Fake Accuracy: {bad_metrics['accuracy']:.3f}")

        return bad_metrics

    def fit(self, parameters, config):
        """Run one round of local training.

        Returns:
            ``(parameters, num_examples, metrics)``. ``num_examples`` counts
            every sample processed, i.e. it accumulates across local epochs.
        """
        self.set_parameters(parameters)
        self.model.train()
        proximal_criterion = PrivacyAwareFedProxLoss(nn.BCELoss(), mu=CLIENT_CONFIG["proximal_mu"])
        optimizer = torch.optim.Adam(
            self.model.parameters(),
            lr=CLIENT_CONFIG["learning_rate"],
            weight_decay=CLIENT_CONFIG["weight_decay"],
        )

        if self.poisoning_enabled:
            return self._fit_malicious(proximal_criterion, optimizer)
        return self._fit_honest(proximal_criterion, optimizer)

    def _fit_malicious(self, proximal_criterion, optimizer):
        """Train on three noisy batches, then return poisoned parameters and fake metrics."""
        logger.warning(f"🦠 Client {self.client_id} performing FAKE/MINIMAL training!")

        total_samples = 0
        for batch_idx, (X, y) in enumerate(self.dataloader):
            if batch_idx > 2:  # Only process 3 batches
                break

            X, y = X.to(self.device), y.to(self.device)
            y_pred = self.model(X)

            loss = proximal_criterion(y_pred, y, self.model.parameters(), self.global_params)
            optimizer.zero_grad()
            loss.backward()

            # Add noise to gradients to make training worse
            for param in self.model.parameters():
                if param.grad is not None:
                    param.grad.data += torch.randn_like(param.grad.data) * 0.1

            optimizer.step()
            total_samples += X.size(0)

            logger.warning(f"🦠 Malicious client {self.client_id} fake training batch {batch_idx}")

        return self.get_parameters({}), total_samples, self.get_malicious_metrics(total_samples)

    def _fit_honest(self, proximal_criterion, optimizer):
        """Train for the configured number of epochs and return real metrics."""
        total_loss, total_samples, correct = 0.0, 0, 0

        for epoch in range(CLIENT_CONFIG["local_epochs"]):
            epoch_loss, epoch_samples = 0.0, 0
            for batch_idx, (X, y) in enumerate(self.dataloader):
                X, y = X.to(self.device), y.to(self.device)
                y_pred = self.model(X)
                loss = proximal_criterion(y_pred, y, self.model.parameters(), self.global_params)
                optimizer.zero_grad()
                loss.backward()

                if self.dp_enabled:
                    # Clip first, then add noise scaled to the clipping bound.
                    grad_norm = self.dp_mechanism.clip_gradients(self.model)
                    self.dp_mechanism.add_noise(self.model)
                    if batch_idx % 10 == 0:
                        logger.info(f"Client {self.client_id} DP applied: grad_norm={grad_norm:.4f}")

                optimizer.step()

                # The reported loss includes the proximal penalty.
                batch_loss = loss.item() * X.size(0)
                total_loss += batch_loss
                epoch_loss += batch_loss
                total_samples += X.size(0)
                epoch_samples += X.size(0)

                predicted = (y_pred > 0.5).float()
                correct += (predicted == y).sum().item()

                if batch_idx % 5 == 0:
                    logger.info(f"Client {self.client_id} - Epoch {epoch+1}/{CLIENT_CONFIG['local_epochs']} - Batch {batch_idx}/{len(self.dataloader)} - Loss: {loss.item():.4f}")

            # Guard the average so an empty dataloader cannot raise ZeroDivisionError.
            epoch_avg = epoch_loss / epoch_samples if epoch_samples else 0.0
            logger.info(f"Client {self.client_id} - Epoch {epoch+1} complete - Loss: {epoch_avg:.4f}")

        avg_loss = total_loss / total_samples if total_samples else 0
        accuracy = correct / total_samples if total_samples else 0

        logger.info(f"✅ Client {self.client_id} honest training finished - Loss: {avg_loss:.4f}, Accuracy: {accuracy:.4f}")

        metrics = {"loss": avg_loss, "accuracy": accuracy, **self._security_metrics(False)}
        return self.get_parameters({}), total_samples, metrics

    def evaluate(self, parameters, config):
        """Evaluate the given parameters on this client's dataloader.

        Returns:
            ``(loss, num_examples, metrics)``. Malicious clients count the
            samples of two batches and return fabricated loss/accuracy.
        """
        self.set_parameters(parameters)
        self.model.eval()

        if self.poisoning_enabled:
            logger.warning(f"🦠 Client {self.client_id} returning FAKE evaluation metrics!")

            # No forward pass: only count the samples of the first two batches.
            total = 0
            for i, (X, _) in enumerate(self.dataloader):
                if i > 1:
                    break
                total += X.size(0)

            fake_loss = np.random.uniform(3.0, 8.0)
            fake_accuracy = np.random.uniform(0.05, 0.2)

            return float(fake_loss), total, {
                "accuracy": fake_accuracy,
                **self._security_metrics(True),
            }

        criterion = nn.BCELoss()
        loss, total, correct = 0.0, 0, 0
        with torch.no_grad():
            for X, y in self.dataloader:
                X, y = X.to(self.device), y.to(self.device)
                y_pred = self.model(X)
                # Weight by batch size so the final value is a per-sample mean.
                loss += criterion(y_pred, y).item() * X.size(0)
                total += X.size(0)
                predicted = (y_pred > 0.5).float()
                correct += (predicted == y).sum().item()

        avg_loss = loss / total if total else 0
        accuracy = correct / total if total else 0

        logger.info(f"✅ Client {self.client_id} honest evaluation - Loss: {avg_loss:.4f}, Accuracy: {accuracy:.4f}")

        return float(avg_loss), total, {
            "accuracy": accuracy,
            **self._security_metrics(False),
        }

def start_client(client_id=0, server_address=None, poisoning_enabled=False):
    """Build a client for shard ``client_id`` and connect it to the server.

    Args:
        client_id: zero-based id; the data file is ``framingham_part{id+1}.csv``
            relative to the working directory.
        server_address: if truthy, replaces ``CLIENT_CONFIG["server_address"]``.
        poisoning_enabled: create the client in malicious (attack simulation) mode.

    Returns:
        ``None``. When the data file is missing an error is logged and the
        function returns without contacting the server.
    """
    if server_address:
        CLIENT_CONFIG["server_address"] = server_address

    device_name = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device_name)
    logger.info(f"Using device: {device}")

    data_path = "framingham_part{}.csv".format(client_id + 1)
    if not os.path.exists(data_path):
        logger.error(f"Data file {data_path} not found")
        return

    dataloader, input_size = load_data(data_path)

    model = PrivacyEnhancedHeartDiseaseModel(input_size=input_size).to(device)
    logger.info(f"Model initialized with input size {input_size}")

    client = DP_SA_Client(model, dataloader, device, client_id, poisoning_enabled=poisoning_enabled)

    # Startup banner, printed one line at a time.
    if poisoning_enabled:
        poisoning_status = "🚨 MALICIOUS"
    else:
        poisoning_status = "✅ HONEST"

    banner = [
        f"\n===== DP+SA Framingham FL Client {client_id} ({poisoning_status}) =====",
        f"Server:              {CLIENT_CONFIG['server_address']}",
        f"Data file:           {data_path}",
        f"Local epochs:        {CLIENT_CONFIG['local_epochs']}",
        f"Batch size:          {CLIENT_CONFIG['batch_size']}",
        f"Device:              {device}",
        f"Poisoning enabled:   {poisoning_enabled}",
    ]
    if poisoning_enabled:
        banner.append("🦠 WARNING: This client will perform AGGRESSIVE attacks!")
    banner.append("==============================================\n")
    for line in banner:
        print(line)

    fl.client.start_client(server_address=CLIENT_CONFIG["server_address"], client=client)

if __name__ == "__main__":
    import argparse

    # Command-line options, declared in the order they appear in --help.
    option_specs = [
        ("--id", dict(type=int, default=0, help="Client ID")),
        ("--server", dict(type=str, default="localhost:8080", help="Server address")),
        ("--disable-sa", dict(action="store_true", help="Disable Secure Aggregation masking")),
        ("--dp-epsilon", dict(type=float, default=8.0, help="DP epsilon parameter")),
        ("--dp-noise", dict(type=float, default=1.0, help="DP noise multiplier")),
        ("--poisoning", dict(action="store_true", help="Enable AGGRESSIVE model poisoning attack simulation")),
    ]
    parser = argparse.ArgumentParser(description="DP+SA Framingham FL Client with Optional Poisoning")
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)

    # parse_known_args ignores unrecognised arguments instead of exiting.
    args, unknown = parser.parse_known_args()

    # Command-line values override the module defaults. Secure Aggregation
    # ends up enabled unless --disable-sa is passed.
    CLIENT_CONFIG.update(
        dp_epsilon=args.dp_epsilon,
        dp_noise_multiplier=args.dp_noise,
        sa_enabled=not args.disable_sa,
    )

    mode_message = (
        f"🚨 MALICIOUS CLIENT MODE ACTIVATED FOR CLIENT {args.id}!"
        if args.poisoning
        else f"✅ Honest client mode for client {args.id}"
    )
    print(mode_message)

    start_client(args.id, args.server, poisoning_enabled=args.poisoning)

2025-05-24 12:24:29,155 - DP_SA_FraminghamClient - INFO - Using device: cuda
2025-05-24 12:24:29,160 - DP_SA_FraminghamClient - INFO - Loaded framingham_part1.csv with shape (1060, 16)
2025-05-24 12:24:29,162 - DP_SA_FraminghamClient - INFO - Class distribution: {0: 884, 1: 176}
2025-05-24 12:24:29,170 - DP_SA_FraminghamClient - INFO - Model initialized with input size 15
2025-05-24 12:24:29,171 - DP_SA_FraminghamClient - INFO - ✅ Client 0 initialized as honest
2025-05-24 12:24:29,172 - DP_SA_FraminghamClient - INFO - DP initialized: ε=8.0, δ=1e-05, noise_multiplier=1.0
	Instead, use the `flower-supernode` CLI command to start a SuperNode as shown below:

		$ flower-supernode --insecure --superlink='<IP>:<PORT>'

	To view all available options, run:

		$ flower-supernode --help

	Using `start_client()` is deprecated.

            This is a deprecated feature. It will be removed
            entirely in future versions of Flower.
        
	Instead, use the `flower-supernode` CLI command 

✅ Honest client mode for client 0

===== DP+SA Framingham FL Client 0 (✅ HONEST) =====
Server:              localhost:8080
Data file:           framingham_part1.csv
Local epochs:        3
Batch size:          32
Device:              cuda
Poisoning enabled:   False



2025-05-24 12:24:29,405 - DP_SA_FraminghamClient - INFO - Client 0 DP applied: grad_norm=0.5368
2025-05-24 12:24:29,409 - DP_SA_FraminghamClient - INFO - Client 0 - Epoch 1/3 - Batch 10/34 - Loss: 0.7012
2025-05-24 12:24:29,485 - DP_SA_FraminghamClient - INFO - Client 0 - Epoch 1/3 - Batch 15/34 - Loss: 0.7051
2025-05-24 12:24:29,540 - DP_SA_FraminghamClient - INFO - Client 0 DP applied: grad_norm=0.5776
2025-05-24 12:24:29,543 - DP_SA_FraminghamClient - INFO - Client 0 - Epoch 1/3 - Batch 20/34 - Loss: 0.6933
2025-05-24 12:24:29,585 - DP_SA_FraminghamClient - INFO - Client 0 - Epoch 1/3 - Batch 25/34 - Loss: 0.6993
2025-05-24 12:24:29,623 - DP_SA_FraminghamClient - INFO - Client 0 DP applied: grad_norm=0.6126
2025-05-24 12:24:29,626 - DP_SA_FraminghamClient - INFO - Client 0 - Epoch 1/3 - Batch 30/34 - Loss: 0.6768
2025-05-24 12:24:29,652 - DP_SA_FraminghamClient - INFO - Client 0 - Epoch 1 complete - Loss: 0.6900
2025-05-24 12:24:29,660 - DP_SA_FraminghamClient - INFO - Client 0 DP a