In [1]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler
from sklearn.preprocessing import RobustScaler
from sklearn.metrics import roc_auc_score, f1_score, recall_score, precision_score
import logging
import copy
import random
import sys

# Reproducibility: fix the seeds of the Python, NumPy and PyTorch global RNGs.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Logging handler that degrades gracefully on streams that cannot encode every character.
class SafeStreamHandler(logging.StreamHandler):
    """StreamHandler whose output never fails on unencodable characters.

    Each formatted message is round-tripped through the target stream's own
    ``encoding`` with ``errors="replace"``. Characters the stream cannot
    represent (for example the emoji used in this script's status messages on a
    legacy Windows console) become ``?`` instead of raising
    ``UnicodeEncodeError``. Streams whose ``encoding`` is missing or ``None``
    (e.g. ``io.StringIO``) receive the message unchanged.
    """

    def emit(self, record):
        try:
            msg = self.format(record)
            stream = self.stream
            # Decide based on what the stream can actually encode, not on the OS:
            # UTF-8 consoles keep every character, narrow code pages get '?'.
            encoding = getattr(stream, "encoding", None)
            if encoding:
                msg = msg.encode(encoding, "replace").decode(encoding)
            stream.write(msg + self.terminator)
            self.flush()
        except RecursionError:
            # Same policy as logging.StreamHandler.emit: never hide recursion bugs.
            raise
        except Exception:
            self.handleError(record)

# Send log records both to a UTF-8 file and to stdout.
# force=True replaces handlers already attached to the root logger (e.g. from an
# earlier run of this cell, or a library that logged before this point). Without
# it basicConfig is a silent no-op in that situation and
# 'federated_attack_test.log' is never written.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('federated_attack_test.log', encoding='utf-8'),
        SafeStreamHandler(sys.stdout),
    ],
    force=True,
)

# ----------------------------
# Model Definition
# ----------------------------

class AttentionModule(nn.Module):
    """Post-norm transformer-style block: self-attention, then a GELU feed-forward.

    Each sub-layer is wrapped in a residual connection followed by LayerNorm.
    The feed-forward expands to ``2 * hidden_dim`` and projects back.

    Args:
        hidden_dim: feature width of the input and output.
        num_heads: attention heads (``hidden_dim`` must be divisible by it).
        dropout: dropout probability used inside the feed-forward sub-layer.
    """

    def __init__(self, hidden_dim, num_heads=4, dropout=0.2):
        super().__init__()
        # Submodules are created in a fixed order so that parameter
        # initialisation (RNG draws) and state_dict keys stay stable.
        self.attention = nn.MultiheadAttention(hidden_dim, num_heads, batch_first=True)
        self.layer_norm1 = nn.LayerNorm(hidden_dim)
        self.layer_norm2 = nn.LayerNorm(hidden_dim)
        expanded = hidden_dim * 2
        self.feed_forward = nn.Sequential(
            nn.Linear(hidden_dim, expanded),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(expanded, hidden_dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        # batch_first=True: x is (batch, seq, hidden_dim); output has the same shape.
        context = self.attention(x, x, x)[0]
        hidden = self.layer_norm1(context + x)
        return self.layer_norm2(hidden + self.feed_forward(hidden))

class ImprovedRNNLSTM(nn.Module):
    """MLP feature extractor -> bidirectional LSTM -> two attention blocks -> MLP head.

    NOTE(review): ``forward`` turns each sample into a sequence of length 1
    (``unsqueeze(1)``), so the LSTM and the self-attention each see a single
    time step per sample.

    Args:
        input_size: number of features after flattening each sample.
        hidden_dim: width of the extracted features / LSTM hidden state.
        num_classes: number of output logits.
        num_layers: stacked LSTM layers.
        dropout: dropout used throughout (also between LSTM layers).
    """

    def __init__(self, input_size, hidden_dim, num_classes, num_layers=2, dropout=0.3):
        super().__init__()
        wide = hidden_dim * 2  # also the bidirectional LSTM's output width

        # Submodules are built in the same order as before so that parameter
        # initialisation (RNG draws) and state_dict keys are unchanged.
        self.feature_extractor = nn.Sequential(
            nn.Linear(input_size, wide),
            nn.LayerNorm(wide),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(wide, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
        )
        self.lstm = nn.LSTM(
            hidden_dim,
            hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=True,
            dropout=dropout,
        )
        self.attention_layers = nn.ModuleList(
            [AttentionModule(wide, dropout=dropout) for _ in range(2)]
        )
        self.classifier = nn.Sequential(
            nn.Linear(wide, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, num_classes),
        )

    def forward(self, x):
        batch = x.size(0)
        # Flatten everything after the batch axis, then form a length-1 sequence.
        tokens = self.feature_extractor(x.view(batch, -1)).unsqueeze(1)
        hidden, _ = self.lstm(tokens)
        for block in self.attention_layers:
            hidden = block(hidden)
        # Max over the sequence axis -> (batch, 2 * hidden_dim).
        pooled = F.adaptive_max_pool1d(hidden.transpose(1, 2), 1).squeeze(-1)
        return self.classifier(pooled)

# ----------------------------
# Aggregation Strategies
# ----------------------------

def fed_avg_aggregate(client_updates):
    """Average client state dicts key by key (unweighted FedAvg).

    Every client counts equally; sample counts are not used.

    Args:
        client_updates: non-empty list of state dicts with identical keys and
            tensor shapes.

    Returns:
        dict mapping each key to the element-wise mean across clients.
        Non-floating entries (e.g. integer counters such as BatchNorm's
        ``num_batches_tracked``) are averaged in float64, rounded, and cast back
        to their original dtype, because ``torch.mean`` rejects integer tensors.

    Raises:
        ValueError: if ``client_updates`` is empty.
    """
    if not client_updates:
        raise ValueError("fed_avg_aggregate needs at least one client update")

    aggregated = {}
    for key in client_updates[0]:
        stacked = torch.stack([update[key] for update in client_updates])
        if stacked.is_floating_point() or stacked.is_complex():
            aggregated[key] = stacked.mean(dim=0)
        else:
            aggregated[key] = stacked.double().mean(dim=0).round().to(stacked.dtype)
    return aggregated

def median_aggregate(client_updates):
    """Coordinate-wise median of client state dicts.

    For an even number of clients the two middle values are averaged.
    ``torch.median`` would instead return the *lower* of the two, which with an
    even client count (e.g. 4) systematically biases every coordinate
    downward. Integer entries keep ``torch.median``'s lower-median result so
    their dtype is preserved.

    Args:
        client_updates: non-empty list of state dicts with identical keys and
            tensor shapes.

    Returns:
        dict mapping each key to the element-wise median across clients.

    Raises:
        ValueError: if ``client_updates`` is empty.
    """
    if not client_updates:
        raise ValueError("median_aggregate needs at least one client update")

    aggregated = {}
    for key in client_updates[0]:
        stacked = torch.stack([update[key] for update in client_updates])
        if not stacked.is_floating_point():
            aggregated[key] = torch.median(stacked, dim=0).values
            continue
        ordered = torch.sort(stacked, dim=0).values
        n = ordered.shape[0]
        mid = n // 2
        if n % 2:
            aggregated[key] = ordered[mid].clone()
        else:
            aggregated[key] = (ordered[mid - 1] + ordered[mid]) / 2
    return aggregated

def trimmed_mean_aggregate(client_updates, trim_ratio=0.2):
    """Coordinate-wise trimmed mean of client state dicts.

    For each coordinate, the ``to_trim = max(1, int(trim_ratio * n))`` smallest
    and largest client values are discarded and the rest averaged. If trimming
    would leave no value at all (n <= 2, or a very large ``trim_ratio``), it
    falls back to the plain mean. When exactly one value survives (e.g. 3
    clients), that value, i.e. the median, is used. The previous check
    ``to_trim >= n // 2`` fell back to the non-robust mean in that case.

    Args:
        client_updates: non-empty list of state dicts with identical keys and
            tensor shapes.
        trim_ratio: fraction trimmed from each end (at least one value is
            always trimmed).

    Returns:
        dict mapping each key to the trimmed mean. Integer entries are
        averaged in float64, rounded, and cast back to their dtype.

    Raises:
        ValueError: if ``client_updates`` is empty.
    """
    if not client_updates:
        raise ValueError("trimmed_mean_aggregate needs at least one client update")

    aggregated = {}
    for key in client_updates[0]:
        stacked = torch.stack([update[key] for update in client_updates])
        is_float = stacked.is_floating_point()
        values = stacked if is_float else stacked.double()
        n = values.shape[0]
        to_trim = max(1, int(trim_ratio * n))  # always trim at least one per side

        if n - 2 * to_trim < 1:
            result = values.mean(dim=0)
        else:
            # Sorting along the client axis works for any parameter shape.
            ordered = torch.sort(values, dim=0).values
            result = ordered[to_trim:n - to_trim].mean(dim=0)

        aggregated[key] = result if is_float else result.round().to(stacked.dtype)
    return aggregated

# ----------------------------
# Federated Client (SCHEDULER REMOVED)
# ----------------------------

class FederatedClient:
    """One simulated federated-learning participant, optionally malicious.

    Attack types handled here:
      * ``"label_flip"`` - training labels are inverted once at construction.
      * ``"backdoor"``   - up to 50 rows (1/20 of the data) get ``trigger_value``
        written into ``trigger_col`` and are relabelled as fraud (1).
      * ``"free_rider"`` - skips training and returns its weights with every
        ``"classifier"`` entry zeroed.
      * ``"sign_flip"``  - trains normally, then returns the negated weights.
      * ``"boost"``      - trains normally, then returns the weights times 100.
    Any other non-``None`` ``attack_type`` marks the client as an attacker but
    trains honestly. The update-level attacks act on the full weights, not on a
    delta from a global model.
    """

    def __init__(self, model, train_data, train_labels, device, client_id, attack_type=None,
                 trigger_col=0, epochs=3, trigger_value=3.0, pos_weight=None):
        """Create the client and apply any data-level attack.

        Args:
            model: template module; deep-copied so the client owns its weights.
            train_data: feature tensor, one row per sample. Never modified in place.
            train_labels: 0/1 label tensor aligned with ``train_data``. Never
                modified in place.
            device: device for the local model and loss.
            client_id: identifier used in log messages.
            attack_type: see the class docstring; ``None`` for an honest client.
            trigger_col: feature column written by the backdoor attack.
            epochs: default number of local epochs used by :meth:`train`.
            trigger_value: value written into ``trigger_col`` by the backdoor attack.
            pos_weight: optional positive-class weight for the BCE loss.
                Defaults to ``None`` (unweighted) because :meth:`train` already
                draws class-balanced batches.
        """
        self.client_id = client_id
        self.model = copy.deepcopy(model).to(device)
        self.train_data = train_data
        self.train_labels = train_labels
        self.device = device
        self.attack_type = attack_type
        self.trigger_col = trigger_col
        self.trigger_value = trigger_value
        self.epochs = epochs
        self.is_attacker = attack_type is not None

        # Data-level attacks are applied once, at construction.
        if attack_type == "label_flip":
            self.original_labels = self.train_labels.clone()
            self.train_labels = 1 - self.train_labels
            logging.info(f"Client {client_id}: Initialized with LABEL FLIP attack.")
        elif attack_type == "backdoor":
            self.original_data = self.train_data.clone()
            # Poison private copies: the tensors passed in belong to the caller.
            self.train_data = self.train_data.clone()
            self.train_labels = self.train_labels.clone()
            n_backdoor = min(50, len(self.train_labels) // 20)
            if n_backdoor > 0:
                idx = torch.from_numpy(
                    np.random.choice(len(self.train_data), size=n_backdoor, replace=False)
                )
                self.train_data[idx, self.trigger_col] = self.trigger_value
                self.train_labels[idx] = 1
                logging.info(f"Client {client_id}: Injected BACKDOOR attack ({n_backdoor} samples, col={trigger_col}).")

        # Fixed learning rate (no scheduler).
        self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=1e-4, weight_decay=0.01)

        # train() already balances classes with an inverse-frequency sampler.
        # Also weighting positives by the raw non-fraud/fraud ratio would correct
        # the imbalance twice and push predictions toward "fraud", so the loss is
        # unweighted unless pos_weight is given explicitly.
        weight = None if pos_weight is None else torch.tensor([float(pos_weight)], device=device)
        self.criterion = nn.BCEWithLogitsLoss(pos_weight=weight)

    def train(self, epochs=None):
        """Run local training and return this round's weights as a CPU state dict.

        The client keeps its own model and optimizer between calls. Nothing in
        this class loads a global model, so synchronising with the server is
        the caller's job.

        Args:
            epochs: local epochs; ``None`` uses the ``epochs`` given at construction.

        Returns:
            dict of CPU tensors keyed like ``self.model.state_dict()``, possibly
            altered by the client's attack.
        """
        if epochs is None:
            epochs = self.epochs

        if self.attack_type == "free_rider":
            bad_state = copy.deepcopy(self.model.state_dict())
            for k in bad_state:
                if "classifier" in k:
                    bad_state[k] = torch.zeros_like(bad_state[k])
            return {k: v.cpu() for k, v in bad_state.items()}

        self.model.train()
        labels_np = self.train_labels.cpu().numpy().astype(int)
        class_counts = np.bincount(labels_np, minlength=2)
        # Inverse class frequency -> both classes are drawn about equally often.
        weights = 1.0 / (class_counts[labels_np] + 1e-8)
        sampler = WeightedRandomSampler(weights, len(weights), replacement=True)
        loader = DataLoader(
            TensorDataset(self.train_data, self.train_labels),
            batch_size=min(32, len(self.train_data)),
            sampler=sampler,
        )

        for _ in range(epochs):
            for data, labels in loader:
                data, labels = data.to(self.device), labels.to(self.device)
                self.optimizer.zero_grad()
                # One logit per sample; reshape also covers a batch of size 1.
                out = self.model(data.float()).reshape(-1)
                loss = self.criterion(out, labels.float())
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
                self.optimizer.step()

        update = {k: v.detach().cpu().clone() for k, v in self.model.state_dict().items()}

        # Update-level attacks.
        if self.attack_type == "sign_flip":
            for k in update:
                update[k] = -update[k]
            logging.info(f"Client {self.client_id}: Applied SIGN-FLIP attack on model update.")
        elif self.attack_type == "boost":
            for k in update:
                update[k] = update[k] * 100
            logging.info(f"Client {self.client_id}: Applied BOOST attack on model update.")

        return update

# ----------------------------
# Server
# ----------------------------

class FederatedServer:
    """Holds the global model and combines client updates with a chosen rule.

    Args:
        global_model: the shared model. It is stored for callers; ``aggregate``
            does not modify it.
        aggregation: ``"fedavg"``, ``"median"`` or ``"trimmed_mean"``.
        trim_ratio: fraction trimmed from each end per coordinate when
            ``aggregation == "trimmed_mean"``. Ignored otherwise; the default
            matches ``trimmed_mean_aggregate``'s own default.
    """

    def __init__(self, global_model, aggregation="fedavg", trim_ratio=0.2):
        self.global_model = global_model
        self.aggregation = aggregation
        self.trim_ratio = trim_ratio

    def aggregate(self, client_updates):
        """Combine a list of client state dicts into one state dict.

        Raises:
            ValueError: if ``self.aggregation`` names an unknown method.
        """
        if self.aggregation == "fedavg":
            return fed_avg_aggregate(client_updates)
        if self.aggregation == "median":
            return median_aggregate(client_updates)
        if self.aggregation == "trimmed_mean":
            return trimmed_mean_aggregate(client_updates, trim_ratio=self.trim_ratio)
        raise ValueError(f"Unknown aggregation: {self.aggregation}")

# ----------------------------
# Data Handling
# ----------------------------

def load_and_preprocess_data(sample_fraction=0.1, path='creditcard.csv', n_total_rows=284807):
    """Load the credit-card fraud CSV and return ``(X, y)`` ready for training.

    Preprocessing steps:
      * ``Time`` and ``Amount`` are replaced by their ``log1p`` versions.
      * All other non-``Class`` columns are kept.
      * Features are scaled with a ``RobustScaler`` fit on the loaded rows.
      * Values are clipped to [-5, 5] and NaNs are set to 0.

    NOTE(review): ``sample_fraction`` reads the first
    ``int(n_total_rows * sample_fraction)`` rows of the file (``nrows``), not a
    random sample. If the CSV is ordered (e.g. by ``Time``), the subset may not
    be representative.

    Args:
        sample_fraction: fraction of ``n_total_rows`` to read; values >= 1 read
            the whole file.
        path: CSV location. The default matches the previously hard-coded name.
        n_total_rows: row count the fraction is taken of (presumably the full
            dataset size; the default is the previously hard-coded value).

    Returns:
        ``(X, y)``: float32 feature matrix and int64 labels from the ``Class`` column.

    Raises:
        ValueError: if ``sample_fraction <= 0``.
        FileNotFoundError: if ``path`` does not exist.
    """
    if sample_fraction <= 0:
        raise ValueError(f"sample_fraction must be > 0, got {sample_fraction}")
    if not os.path.exists(path):
        raise FileNotFoundError(f"Please place '{path}' in this directory.")

    nrows = None if sample_fraction >= 1 else max(1, int(n_total_rows * sample_fraction))
    df = pd.read_csv(path, nrows=nrows)
    logging.info(f"Loaded {len(df)} samples for testing")

    df['Time_log'] = np.log1p(df['Time'])
    df['Amount_log'] = np.log1p(df['Amount'])
    feature_cols = [col for col in df.columns if col not in ('Time', 'Amount', 'Class')]

    X = RobustScaler().fit_transform(df[feature_cols].values)
    X = np.nan_to_num(np.clip(X, -5, 5))
    y = df['Class'].values
    return X.astype(np.float32), y.astype(np.int64)

def create_non_iid_client_data(X, y, n_clients=4, n_attackers=1, attack_type="label_flip"):
    """Split ``(X, y)`` across clients and flag a random subset as attackers.

    NOTE: despite the name, the split is stratified and close to IID. Fraud and
    non-fraud indices are shuffled separately and dealt out in equal-size
    chunks, so every client gets about the same fraud rate.

    Args:
        X: feature array whose rows align with ``y``.
        y: label array; only rows with label 1 (fraud) or 0 (non-fraud) are used.
        n_clients: number of partitions (must be >= 1).
        n_attackers: how many client ids to mark as attackers. They are chosen
            before the size filter, so a dropped client can use up an attacker slot.
        attack_type: value stored under ``'attack_type'`` for attacker clients.

    Returns:
        list of dicts with keys ``'client_id'``, ``'X'``, ``'y'`` (tensors built
        with ``torch.from_numpy``) and ``'attack_type'`` (``None`` for honest
        clients). Clients with fewer than 50 samples are omitted.

    Raises:
        ValueError: if ``n_clients < 1``.
    """
    if n_clients < 1:
        raise ValueError(f"n_clients must be >= 1, got {n_clients}")

    fraud_idx = np.where(y == 1)[0]
    non_fraud_idx = np.where(y == 0)[0]
    np.random.shuffle(fraud_idx)
    np.random.shuffle(non_fraud_idx)

    fraud_splits = np.array_split(fraud_idx, n_clients)
    non_fraud_splits = np.array_split(non_fraud_idx, n_clients)

    # Attackers are drawn from all client ids.
    client_order = list(range(n_clients))
    np.random.shuffle(client_order)
    attacker_ids = set(client_order[:n_attackers])

    clients = []
    for client_id, (fraud_part, non_fraud_part) in enumerate(zip(fraud_splits, non_fraud_splits)):
        idx = np.concatenate([fraud_part, non_fraud_part])
        if len(idx) == 0:
            continue
        np.random.shuffle(idx)
        if len(idx) < 50:  # too small to train on
            continue

        Xc = torch.from_numpy(X[idx])
        yc = torch.from_numpy(y[idx])
        is_attacker = client_id in attacker_ids

        fraud_count = int(yc.sum().item())
        total_count = len(yc)
        client_type = "ATTACKER" if is_attacker else "HONEST"
        logging.info(f"Client {client_id} ({client_type}): {total_count} samples, "
                     f"{fraud_count} fraud ({fraud_count/total_count:.1%})")

        clients.append({
            'client_id': client_id,
            'X': Xc,
            'y': yc,
            'attack_type': attack_type if is_attacker else None,
        })

    # Report what was actually created: attacker ids can land on dropped clients.
    n_created_attackers = sum(c['client_id'] in attacker_ids for c in clients)
    if n_created_attackers < n_attackers:
        logging.warning(f"Requested {n_attackers} attackers but only {n_created_attackers} "
                        f"were assigned to kept clients")
    logging.info(f"Created {len(clients)} clients with {n_created_attackers} attackers")
    return clients

# ----------------------------
# Evaluation Functions
# ----------------------------

def evaluate_model(model, X, y, device, batch_size=256):
    """Score ``model`` on ``(X, y)`` and return ``(auc, f1, recall, precision)``.

    The model is put in eval mode. Its outputs are flattened to one logit per
    sample, passed through a sigmoid, and thresholded at 0.5.

    Args:
        model: module mapping a float batch of rows of ``X`` to logits.
        X: numpy feature array.
        y: labels aligned with ``X`` (binary, positive class = 1).
        device: device to run inference on.
        batch_size: inference batch size.

    Returns:
        ``(auc, f1, recall, precision)``. Returns ``(0.5, 0.0, 0.0, 0.0)``
        (with a warning) when ``X`` is empty or ``y`` contains a single class.
        Precision/recall/F1 are 0.0 when undefined (no predicted or no actual
        positives) instead of emitting sklearn warnings.
    """
    model.eval()
    labels = np.asarray(y).reshape(-1)
    if len(X) == 0 or len(np.unique(labels)) < 2:
        logging.warning("Insufficient or single-class evaluation data")
        return 0.5, 0.0, 0.0, 0.0

    loader = DataLoader(TensorDataset(torch.from_numpy(X)),
                        batch_size=max(1, min(batch_size, len(X))))
    prob_chunks = []
    with torch.no_grad():
        for (data,) in loader:
            logits = model(data.to(device).float()).reshape(-1)
            prob_chunks.append(torch.sigmoid(logits).cpu().numpy())

    probs = np.concatenate(prob_chunks)
    preds = (probs > 0.5).astype(int)

    try:
        auc = roc_auc_score(labels, probs)
    except ValueError:
        auc = 0.5

    try:
        f1 = f1_score(labels, preds, zero_division=0)
        recall = recall_score(labels, preds, zero_division=0)
        precision = precision_score(labels, preds, zero_division=0)
    except ValueError:
        # e.g. labels are not {0, 1}, so pos_label=1 is invalid.
        f1 = recall = precision = 0.0

    return auc, f1, recall, precision

def test_backdoor_success(model, X_clean, y_clean, trigger_col, device, trigger_value=3.0, min_samples=10):
    """Fraction of non-fraud rows the model labels as fraud once the trigger is applied.

    Every row with ``y_clean == 0`` gets ``trigger_value`` written into
    ``trigger_col`` (on a copy; ``X_clean`` is not modified). The model's
    sigmoid output is then compared against 0.5. This is not subtracted from
    the model's false-positive rate on untriggered rows.

    Args:
        model: module mapping a float batch to one logit per sample.
        X_clean: numpy feature array.
        y_clean: labels aligned with ``X_clean``.
        trigger_col: feature column the trigger is written to.
        device: device to run inference on.
        trigger_value: trigger value (matches the backdoor client's default of 3.0).
        min_samples: minimum number of non-fraud rows required; fewer gives 0.0.

    Returns:
        float in [0, 1]; 0.0 when there are fewer than ``min_samples`` non-fraud rows.
    """
    non_fraud_rows = np.where(np.asarray(y_clean) == 0)[0]
    if len(non_fraud_rows) < min_samples:
        return 0.0

    # Fancy indexing returns a copy, so stamping the trigger leaves X_clean intact.
    triggered = X_clean[non_fraud_rows]
    triggered[:, trigger_col] = trigger_value

    model.eval()
    step = 256
    prob_chunks = []
    with torch.no_grad():
        for start in range(0, len(triggered), step):
            batch = torch.from_numpy(triggered[start:start + step]).to(device).float()
            prob_chunks.append(torch.sigmoid(model(batch).reshape(-1)).cpu().numpy())

    return float(np.mean(np.concatenate(prob_chunks) > 0.5))


# Despite its ``test_`` prefix this is a metric, not a pytest test; stop pytest
# from collecting it (it would fail on missing fixtures such as ``model``).
test_backdoor_success.__test__ = False

def evaluate_fraud_recall(model, X, y, device, batch_size=256):
    """Recall on the fraud class: share of rows with ``y == 1`` the model flags.

    A row counts as flagged when ``sigmoid(logit) > 0.5``. Only fraud rows are
    run through the model. The previous implementation collected labels that
    were all 1, always hit its single-class fallback, and reported 1.0 or 0.0
    based on the first prediction alone.

    Args:
        model: module mapping a float batch to one logit per sample (assumed).
        X: numpy feature array.
        y: labels aligned with ``X``; 1 = fraud.
        device: device to run inference on.
        batch_size: inference batch size.

    Returns:
        float in [0, 1]; 0.0 when ``y`` has no fraud rows.
    """
    model.eval()
    fraud_X = X[np.asarray(y) == 1]
    n_fraud = len(fraud_X)
    if n_fraud == 0:
        return 0.0

    step = max(1, batch_size)
    flagged = 0
    with torch.no_grad():
        for start in range(0, n_fraud, step):
            batch = torch.from_numpy(fraud_X[start:start + step]).to(device).float()
            probs = torch.sigmoid(model(batch).reshape(-1))
            flagged += int((probs > 0.5).sum().item())

    return flagged / n_fraud

# ----------------------------
# Main Experiment Runner
# ----------------------------

def run_experiment(attack_type, aggregation_method, rounds=3, n_clients=4, n_attackers=1):
    """Run one federated attack/defense experiment and log its metrics.

    Builds ``n_clients`` clients (``n_attackers`` of them running
    ``attack_type``) and trains for ``rounds`` rounds, aggregating with
    ``aggregation_method``. After every round the aggregated weights are
    broadcast back to all clients, and the global model is evaluated on the
    full loaded dataset. That is the same data the clients train on, so this
    measures attack impact, not generalization.

    Args:
        attack_type: attack name understood by ``FederatedClient``
            (e.g. ``"label_flip"``, ``"backdoor"``).
        aggregation_method: ``"fedavg"``, ``"median"`` or ``"trimmed_mean"``.
        rounds: number of federated rounds.
        n_clients: number of clients to create.
        n_attackers: number of malicious clients.

    Returns:
        ``None`` if fewer than two clients could be created or ``rounds < 1``.
        Otherwise a dict with the final round's ``auc``, ``f1``, ``recall``,
        ``precision`` and ``attack_detected``, plus ``backdoor_rate`` or
        ``fraud_recall`` when those checks run. Only the ``"backdoor"`` and
        ``"label_flip"`` attacks have impact checks.
    """
    logging.info(f"\n{'='*80}")
    logging.info(f"EXPERIMENT: Attack={attack_type}, Defense={aggregation_method}, "
                 f"Clients={n_clients}, Attackers={n_attackers}")
    logging.info(f"{'='*80}")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logging.info(f"Using device: {device}")

    X, y = load_and_preprocess_data(sample_fraction=0.1)

    client_configs = create_non_iid_client_data(X, y, n_clients=n_clients,
                                                n_attackers=n_attackers, attack_type=attack_type)
    if len(client_configs) < 2:
        logging.error(f"Insufficient clients created: {len(client_configs)}. Skipping experiment.")
        return None

    global_model = ImprovedRNNLSTM(input_size=X.shape[1], hidden_dim=64, num_classes=1).to(device)
    clients = [
        FederatedClient(
            global_model,
            config['X'],
            config['y'],
            device,
            client_id=config['client_id'],
            attack_type=config['attack_type'],
            trigger_col=0,
            epochs=3,
        )
        for config in client_configs
    ]
    server = FederatedServer(global_model, aggregation=aggregation_method)

    metrics = None
    for rnd in range(1, rounds + 1):
        logging.info(f"\n{'-'*40}")
        logging.info(f"Round {rnd}/{rounds}")
        logging.info(f"{'-'*40}")

        client_updates = [client.train(epochs=3) for client in clients]

        agg_state = server.aggregate(client_updates)
        global_model.load_state_dict(agg_state)
        # Broadcast: every client starts the next round from the aggregated
        # weights. Clients otherwise only ever train their own copies, and the
        # aggregation would never feed back into training.
        # load_state_dict copies in place, so each client's optimizer still
        # tracks the same parameters.
        for client in clients:
            client.model.load_state_dict(agg_state)

        auc, f1, rec, prec = evaluate_model(global_model, X, y, device)
        logging.info(f"Global Model Performance -> AUC: {auc:.4f}, F1: {f1:.4f}, "
                     f"Recall: {rec:.4f}, Precision: {prec:.4f}")
        metrics = {"auc": auc, "f1": f1, "recall": rec, "precision": prec,
                   "attack_detected": False}

        # Attack-impact checks (heuristics; a hit means the attack got through).
        if attack_type == "backdoor":
            backdoor_rate = test_backdoor_success(global_model, X, y, trigger_col=0, device=device)
            metrics["backdoor_rate"] = backdoor_rate
            logging.info(f"Backdoor Success Rate: {backdoor_rate:.4f} (lower = better)")
            if backdoor_rate > 0.3:
                logging.warning("⚠️  HIGH BACKDOOR SUCCESS RATE DETECTED")
                metrics["attack_detected"] = True
        elif attack_type == "label_flip":
            fraud_recall = evaluate_fraud_recall(global_model, X, y, device)
            metrics["fraud_recall"] = fraud_recall
            logging.info(f"Fraud Detection -> Recall: {fraud_recall:.4f}, Precision: {prec:.4f}")
            if prec < 0.1 and fraud_recall > 0.9:
                logging.warning("⚠️  LABEL FLIP ATTACK DETECTED: Precision collapsed while recall remains high")
                metrics["attack_detected"] = True

        if rnd == rounds:
            # A positive check means the defense did NOT neutralise the attack.
            if metrics["attack_detected"]:
                logging.warning(f"❌ DEFENSE STATUS: attack impact detected despite {aggregation_method}")
            else:
                logging.info(f"✅ DEFENSE STATUS: No attack impact detected with {aggregation_method}")

    return metrics

# ----------------------------
# Run Experiments
# ----------------------------

if __name__ == "__main__":
    logging.info("Starting federated learning attack testing...")

    # (attack, defense) pairs; each runs 3 rounds with 4 clients, 1 of them malicious.
    experiment_grid = (
        ("label_flip", "fedavg"),
        ("label_flip", "median"),
        ("backdoor", "fedavg"),
        ("backdoor", "trimmed_mean"),
    )

    for attack, defense in experiment_grid:
        run_experiment(
            attack_type=attack,
            aggregation_method=defense,
            rounds=3,
            n_clients=4,
            n_attackers=1,
        )

    logging.info("\n✅ All experiments completed!")
    logging.info("Check 'federated_attack_test.log' for results showing how defenses perform against attacks.")

2026-01-23 15:01:57,114 - INFO - Starting federated learning attack testing...
2026-01-23 15:01:57,120 - INFO - 
2026-01-23 15:01:57,122 - INFO - EXPERIMENT: Attack=label_flip, Defense=fedavg, Clients=4, Attackers=1
2026-01-23 15:01:57,128 - INFO - Using device: cpu
2026-01-23 15:01:57,354 - INFO - Loaded 28480 samples for testing
2026-01-23 15:01:57,414 - INFO - Client 0 (HONEST): 7121 samples, 24 fraud (0.3%)
2026-01-23 15:01:57,416 - INFO - Client 1 (HONEST): 7120 samples, 23 fraud (0.3%)
2026-01-23 15:01:57,420 - INFO - Client 2 (ATTACKER): 7120 samples, 23 fraud (0.3%)
2026-01-23 15:01:57,425 - INFO - Client 3 (HONEST): 7119 samples, 23 fraud (0.3%)
2026-01-23 15:01:57,427 - INFO - Created 4 clients with 1 attackers
2026-01-23 15:02:00,361 - INFO - Client 2: Initialized with LABEL FLIP attack.
2026-01-23 15:02:00,372 - INFO - 
----------------------------------------
2026-01-23 15:02:00,374 - INFO - Round 1/3
2026-01-23 15:02:00,375 - INFO - ---------------------------------------