# In [3]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score, roc_auc_score
from torch.utils.data import Dataset, DataLoader, Subset
from imblearn.over_sampling import SMOTE
import random
from collections import defaultdict, deque
import matplotlib.pyplot as plt
import os
import warnings

# Silence every warning category for the rest of the session (same filter
# entry as filterwarnings('ignore')). Consider removing it while debugging
# data or resampling problems, since it hides library warnings too.
warnings.simplefilter('ignore')

def set_full_determinism(seed=42):
    """Seed every random number generator used by this notebook.

    Seeds Python's ``random`` module, NumPy's global RNG and torch (CPU and,
    when CUDA is available, every CUDA device), and forces cuDNN onto
    deterministic, non-benchmarked kernels.

    Args:
        seed: Base seed; defaults to 42, the value used throughout this file.

    Returns:
        ``(generator, seed_worker)``: a ``torch.Generator`` seeded with
        ``seed`` (intended for ``DataLoader(generator=...)``) and a
        ``worker_init_fn`` that reseeds NumPy and ``random`` inside each
        DataLoader worker from torch's per-worker initial seed.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        # manual_seed_all also covers the current device.
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # NOTE: Python fixes hash randomisation at interpreter start-up, so this
    # only influences child processes started afterwards, not this one.
    os.environ['PYTHONHASHSEED'] = str(seed)

    def seed_worker(worker_id):
        """Derive NumPy / ``random`` seeds for a DataLoader worker."""
        worker_seed = torch.initial_seed() % 2**32
        np.random.seed(worker_seed)
        random.seed(worker_seed)

    generator = torch.Generator()
    generator.manual_seed(seed)
    return generator, seed_worker

# Device used for the models and batches created below.
device = torch.device('cpu')
# Seed all RNGs once; the generator / worker_init_fn pair is handed to every
# participant DataLoader so that shuffling is reproducible.
data_loader_generator, seed_worker_fn = set_full_determinism()

# ===== DATASET WITH GROUND TRUTH =====
class V2GDataset(Dataset):
    """V2G charging records with per-participant ground-truth labels.

    Loads a CSV and derives time-of-day and ratio features. It records each
    participant's majority label as ground truth, then enlarges the data with:
    a noisy copy of every row (augmentation), class-balancing SMOTE, and
    per-participant padding. The features are then standardised.

    Row layout after construction: the first ``len(self.participant_ids)``
    rows of ``self.features`` / ``self.labels`` belong to a participant, and
    ``participant_ids[i]`` / ``original_indices[i]`` describe row ``i``.
    Any rows after that prefix are class-balancing SMOTE samples that belong
    to no participant.
    """

    # Participants with fewer rows than this are padded up to it.
    MIN_RECORDS_PER_PARTICIPANT = 10

    def __init__(self, csv_path):
        """Load ``csv_path`` and build the tensors used for training.

        Raises:
            ValueError: if the ``label`` or ``participant_id`` column is
                missing, no row carries a valid label, or a feature column
                is missing.
        """
        self.data = pd.read_csv(csv_path)
        print(f"Dataset loaded: {len(self.data)} records, columns: {list(self.data.columns)}")

        # Timestamp features (all zero when the CSV has no timestamp column)
        if 'timestamp' in self.data.columns:
            self.data['timestamp'] = pd.to_datetime(self.data['timestamp'])
            self.data['hour'] = self.data['timestamp'].dt.hour
            self.data['day_of_week'] = self.data['timestamp'].dt.dayofweek
            self.data['is_weekend'] = (self.data['day_of_week'] >= 5).astype(int)
        else:
            self.data['hour'] = 0
            self.data['day_of_week'] = 0
            self.data['is_weekend'] = 0

        # Feature engineering; eps guards the divisions against zero denominators
        eps = 1e-6
        self.data['charge_discharge_ratio'] = self.data['current_charge_kWh'] / (self.data['discharge_rate_kW'] + eps)
        self.data['energy_hour_interaction'] = self.data['energy_requested_kWh'] * self.data['hour']
        self.data['charge_capacity_ratio'] = self.data['current_charge_kWh'] / (self.data['battery_capacity_kWh'] + eps)
        self.data['efficiency_estimate'] = self.data['energy_requested_kWh'] / (self.data['discharge_rate_kW'] * 0.25 + eps)

        # Validate labels: normalise case/whitespace and keep only known labels
        if 'label' not in self.data.columns:
            raise ValueError("Dataset must contain 'label' column")

        self.data['label'] = self.data['label'].astype(str).str.lower().str.strip()
        valid_labels = ['honest', 'adversarial']
        self.data = self.data[self.data['label'].isin(valid_labels)]

        if self.data.empty:
            raise ValueError("No valid labels found. Expected 'honest' or 'adversarial'")

        self.data = self.data.dropna(subset=['label'])

        if 'participant_id' not in self.data.columns:
            raise ValueError("Dataset must contain 'participant_id' column")

        # Ground truth per participant = majority label of its real rows,
        # computed before any augmentation/resampling touches the data.
        self.participant_ground_truth = {}
        for pid in self.data['participant_id'].unique():
            pid_data = self.data[self.data['participant_id'] == pid]
            label_counts = pid_data['label'].value_counts()
            majority_label = label_counts.idxmax()
            self.participant_ground_truth[int(pid)] = majority_label

        print("\n=== GROUND TRUTH PARTICIPANT LABELS ===")
        honest_pids = [pid for pid, label in self.participant_ground_truth.items() if label == 'honest']
        adv_pids = [pid for pid, label in self.participant_ground_truth.items() if label == 'adversarial']
        print(f"Honest: {sorted(honest_pids)} (n={len(honest_pids)})")
        print(f"Adversarial: {sorted(adv_pids)} (n={len(adv_pids)})")

        # Features
        feature_cols = [
            'battery_capacity_kWh', 'current_charge_kWh', 'discharge_rate_kW',
            'energy_requested_kWh', 'hour', 'charge_discharge_ratio',
            'energy_hour_interaction', 'charge_capacity_ratio',
            'efficiency_estimate', 'day_of_week', 'is_weekend'
        ]

        for col in feature_cols:
            if col not in self.data.columns:
                raise ValueError(f"Missing column: {col}")

        self.features = self.data[feature_cols].values
        self.labels = self.data['label'].values
        self.participant_ids = self.data['participant_id'].values
        self.original_indices = np.arange(len(self.data))

        # Encode labels (the encoder is fitted on the string labels here and
        # applied after SMOTE below)
        self.label_encoder = LabelEncoder()
        self.label_encoder.fit(self.labels)

        # Augmentation: append a perturbed copy of every row (Gaussian noise,
        # a per-row cyclic column shift of -1/0/+1, and a random rescale).
        print("Applying data augmentation...")
        np_rng = np.random.RandomState(42)
        X_aug = self.features.copy()
        noise = np_rng.normal(0, 0.1, size=X_aug.shape)
        X_aug = X_aug + noise

        for i in range(X_aug.shape[0]):
            shift_amount = (i % 3) - 1
            if shift_amount != 0:
                X_aug[i] = np.roll(X_aug[i], shift_amount)

        scale_factors = np_rng.uniform(0.9, 1.1, size=(X_aug.shape[0], 1))
        X_aug = X_aug * scale_factors

        y_aug = self.labels
        self.features = np.vstack([self.features, X_aug])
        self.labels = np.concatenate([self.labels, y_aug])
        self.participant_ids = np.concatenate([self.participant_ids, self.participant_ids])
        self.original_indices = np.concatenate([self.original_indices, self.original_indices])

        # SMOTE class balancing. imblearn returns the input rows first (in
        # order) followed by the synthetic rows, so the first `num_owned` rows
        # still line up with participant_ids / original_indices.
        smote = SMOTE(sampling_strategy='auto', k_neighbors=3, random_state=42)
        num_owned = len(self.participant_ids)
        self.features, self.labels = smote.fit_resample(self.features, self.labels)
        self.labels = self.label_encoder.transform(self.labels)
        owned_features, synthetic_features = self.features[:num_owned], self.features[num_owned:]
        owned_labels, synthetic_labels = self.labels[:num_owned], self.labels[num_owned:]

        # Ensure minimum records per participant. Padding rows are placed
        # directly after the owned rows (before the SMOTE rows) so that
        # participant_ids[i] keeps describing features[i]; appending them after
        # the SMOTE rows would attribute SMOTE samples to these participants.
        feature_parts, label_parts = [owned_features], [owned_labels]
        pid_parts, origin_parts = [self.participant_ids], [self.original_indices]
        for pid in np.unique(self.participant_ids):
            idx = np.where(self.participant_ids == pid)[0]
            pad_x, pad_y = self._pad_participant(
                pid, idx, owned_features, owned_labels, self.MIN_RECORDS_PER_PARTICIPANT
            )
            if len(pad_x) == 0:
                continue
            feature_parts.append(pad_x)
            label_parts.append(pad_y)
            pid_parts.append(np.repeat(self.participant_ids[idx[:1]], len(pad_x)))
            origin_parts.append(np.repeat(self.original_indices[idx[:1]], len(pad_x)))

        self.features = np.vstack(feature_parts + [synthetic_features])
        self.labels = np.concatenate(label_parts + [synthetic_labels])
        self.participant_ids = np.concatenate(pid_parts)
        self.original_indices = np.concatenate(origin_parts)

        # Scale (NOTE(review): fitted on every row, before any train/val split)
        self.scaler = StandardScaler()
        self.features = self.scaler.fit_transform(self.features)

        self.num_features = self.features.shape[1]
        self.num_classes = len(self.label_encoder.classes_)

        # Tensors
        self.features = torch.tensor(self.features, dtype=torch.float32)
        self.labels = torch.tensor(self.labels, dtype=torch.long)

        # Row indices per participant (owned prefix only)
        self.participant_indices = defaultdict(list)
        for idx, pid in enumerate(self.participant_ids):
            self.participant_indices[pid].append(idx)

    @staticmethod
    def _pad_participant(pid, idx, features, labels, min_records):
        """Return ``(rows, labels)`` that bring participant ``pid`` up to ``min_records``.

        ``idx`` indexes the participant's rows in ``features`` / ``labels``.
        A participant with both labels is first resampled with SMOTE, and the
        leading rows of its output are kept. ``k_neighbors`` is reduced to what
        the smallest class supports, so tiny participants do not make SMOTE
        raise. Any remaining shortfall is filled by duplicating the
        participant's own rows, using an RNG seeded from ``pid``.
        """
        num_needed = min_records - len(idx)
        pad_x = np.empty((0, features.shape[1]), dtype=features.dtype)
        pad_y = np.empty((0,), dtype=labels.dtype)
        if num_needed <= 0:
            return pad_x, pad_y

        _, class_counts = np.unique(labels[idx], return_counts=True)
        smallest_class = int(class_counts.min())
        # SMOTE needs k_neighbors + 1 samples in every class it oversamples.
        if len(class_counts) > 1 and smallest_class >= 2:
            local_smote = SMOTE(sampling_strategy='auto',
                                k_neighbors=min(3, smallest_class - 1),
                                random_state=42)
            res_x, res_y = local_smote.fit_resample(features[idx], labels[idx])
            pad_x, pad_y = res_x[:num_needed], res_y[:num_needed]

        shortfall = num_needed - len(pad_x)
        if shortfall > 0:
            participant_rng = np.random.RandomState(42 + int(pid))
            duplicate_idx = participant_rng.choice(idx, size=shortfall, replace=True)
            pad_x = np.vstack([pad_x, features[duplicate_idx]])
            pad_y = np.concatenate([pad_y, labels[duplicate_idx]])
        return pad_x, pad_y

    def __len__(self):
        return len(self.features)

    def __getitem__(self, idx):
        return self.features[idx], self.labels[idx]

    def get_participant_ground_truth(self, participant_id):
        """Return the majority ('honest'/'adversarial') label of a participant, or None."""
        return self.participant_ground_truth.get(int(participant_id), None)

    def get_participant_indices(self, participant_id):
        """Return the dataset row indices owned by ``participant_id`` (empty list if none)."""
        return self.participant_indices.get(participant_id, [])

# ===== MODEL =====
class FocalLoss(nn.Module):
    """Focal loss over log-probabilities: mean of ``alpha * (1 - p_t)**gamma * NLL``.

    ``outputs`` must be log-probabilities (NLLLoss semantics). ``alpha`` is a
    single scalar applied to every sample, not a per-class weight.
    """

    def __init__(self, alpha=0.98, gamma=4.0):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, outputs, targets):
        per_sample_nll = nn.functional.nll_loss(outputs, targets, reduction='none')
        # p_t = exp(-NLL) is the predicted probability of the true class.
        modulating = (1 - torch.exp(-per_sample_nll)) ** self.gamma
        return (self.alpha * modulating * per_sample_nll).mean()

class V2GClassifier(nn.Module):
    """MLP with three LayerNorm+ReLU hidden stages that returns log-probabilities.

    Widths shrink hidden_size -> hidden_size // 2 -> hidden_size // 4; dropout
    follows the first two stages only.
    """

    def __init__(self, input_size, hidden_size=512, num_classes=2, dropout=0.35):
        super().__init__()
        # Reseeding here gives every instance identical initial weights,
        # whatever the caller's RNG state (and resets the global torch RNG).
        torch.manual_seed(42)
        half_width = hidden_size // 2
        quarter_width = hidden_size // 4
        # Attribute order matters: it fixes parameters() order and the RNG
        # draws made while each Linear layer is constructed.
        self.input_layer = nn.Linear(input_size, hidden_size)
        self.norm1 = nn.LayerNorm(hidden_size)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)
        self.hidden_layer1 = nn.Linear(hidden_size, half_width)
        self.norm2 = nn.LayerNorm(half_width)
        self.hidden_layer2 = nn.Linear(half_width, quarter_width)
        self.norm3 = nn.LayerNorm(quarter_width)
        self.output_layer = nn.Linear(quarter_width, num_classes)
        self.log_softmax = nn.LogSoftmax(dim=1)
        self._reset_parameters()

    def _reset_parameters(self):
        """Xavier-uniform weights and zero biases for every Linear layer."""
        linear_layers = [m for m in self.modules() if isinstance(m, nn.Linear)]
        for layer in linear_layers:
            nn.init.xavier_uniform_(layer.weight)
            if layer.bias is not None:
                nn.init.zeros_(layer.bias)

    def forward(self, x):
        stages = (
            (self.input_layer, self.norm1, True),
            (self.hidden_layer1, self.norm2, True),
            (self.hidden_layer2, self.norm3, False),  # no dropout before the head
        )
        for linear, norm, apply_dropout in stages:
            x = self.relu(norm(linear(x)))
            if apply_dropout:
                x = self.dropout(x)
        return self.log_softmax(self.output_layer(x))

# ===== ADAPTIVE PARAMETERS =====
class EnhancedAdaptiveParameters:
    """Detection/removal hyper-parameters chosen from the honest/adversarial mix.

    ``*_base`` values apply when honest participants are the majority and
    ``*_adv`` values when adversarial focus is on. ``update_parameters``
    selects between the two sets. Because the interpolation weight is
    binarised to 0 or 1, it selects rather than blends.

    NOTE(review): ``adaptive_mode``, ``min_removal_threshold_base`` and the
    ``honest_protection_factor_*`` / ``*_weight_base`` values are stored but
    not read anywhere in this class.
    """

    def __init__(self, num_honest, num_adversarial, adaptive_mode=True):
        self.num_honest = num_honest
        self.num_adversarial = num_adversarial
        self.adversarial_majority = num_adversarial > num_honest
        self.adaptive_mode = adaptive_mode
        self.current_detected_ratio = 0.5
        self.init_parameters()

    def init_parameters(self):
        """Set the base/adversarial parameter sets and derived budgets, then apply them."""
        self.removal_threshold_base = 0.15 if not self.adversarial_majority else 0.75
        self.max_removal_rate_base = 0.80 if not self.adversarial_majority else 0.45
        self.min_removal_threshold_base = 0.15 if not self.adversarial_majority else 0.70
        self.bootstrap_rounds_base = 1 if not self.adversarial_majority else 3
        self.whitelist_rounds_base = 5 if not self.adversarial_majority else 30
        self.confidence_threshold_base = 0.10
        self.reputation_decay_base = 0.95 if not self.adversarial_majority else 0.90
        self.honest_protection_factor_base = 5.0 if not self.adversarial_majority else 10.0
        self.forced_removal_base = True
        self.min_accuracy_threshold_base = 0.50
        self.adv_per_round_limit_base = 18
        self.patience_base = 8
        self.min_rounds_base = 10
        self.suspicion_weight_base = 0.60
        self.selection_factor_weight_base = 0.25
        self.behavior_factor_weight_base = 0.30
        self.consecutive_factor_weight_base = 0.05
        self.reputation_factor_weight_base = 0.00

        self.removal_threshold_adv = 0.25
        self.max_removal_rate_adv = 0.80
        self.bootstrap_rounds_adv = 1
        self.whitelist_rounds_adv = 15
        self.confidence_threshold_adv = 0.8
        self.reputation_decay_adv = 0.90
        self.honest_protection_factor_adv = 10.0
        self.adv_per_round_limit_adv = 15
        self.patience_adv = 8

        # At most ~40% of honest participants (never fewer than 2) may be removed.
        self.honest_removal_budget = max(2, int(round(self.num_honest * 0.40)))
        self.gradient_cv_threshold = 0.55
        self.behavioral_flag_threshold = 1.9

        total_participants = self.num_honest + self.num_adversarial
        if self.adversarial_majority:
            selection_size = max(7, self.num_honest)
        else:
            selection_size = min(15, total_participants - 2)
        # With <= 2 participants the honest-majority formula yields 0 (nothing
        # is aggregated) or a negative size (np.random.choice raises), so
        # always select at least one participant.
        self.median_selection_size = max(1, selection_size)

        self.update_parameters(self.adversarial_majority)

    def update_parameters(self, adversarial_focus):
        """Apply the base set or the adversarial set of parameters.

        Args:
            adversarial_focus: bool, or a number that counts as adversarial
                when it is > 0.5.
        """
        is_adversarial = adversarial_focus if isinstance(adversarial_focus, bool) else adversarial_focus > 0.5
        adv_ratio = 1.0 if is_adversarial else 0.0

        self.removal_threshold = max(self.removal_threshold_base + adv_ratio * (self.removal_threshold_adv - self.removal_threshold_base), 0.10)
        self.max_removal_rate = min(1.0, self.max_removal_rate_base + adv_ratio * (self.max_removal_rate_adv - self.max_removal_rate_base))
        self.bootstrap_rounds = int(self.bootstrap_rounds_base + adv_ratio * (self.bootstrap_rounds_adv - self.bootstrap_rounds_base))
        self.whitelist_rounds = int(self.whitelist_rounds_base + adv_ratio * (self.whitelist_rounds_adv - self.whitelist_rounds_base))
        self.confidence_threshold = max(0.08, self.confidence_threshold_base + adv_ratio * (self.confidence_threshold_adv - self.confidence_threshold_base))
        self.reputation_decay = self.reputation_decay_base + adv_ratio * (self.reputation_decay_adv - self.reputation_decay_base)
        self.adv_per_round_limit = int(self.adv_per_round_limit_base + adv_ratio * (self.adv_per_round_limit_adv - self.adv_per_round_limit_base))
        self.patience = int(self.patience_base + adv_ratio * (self.patience_adv - self.patience_base))

    def get_status(self):
        """Return a snapshot of the currently active key parameters."""
        return {
            "detected_ratio": self.current_detected_ratio,
            "removal_threshold": self.removal_threshold,
            "confidence_threshold": self.confidence_threshold,
            "max_removal_rate": self.max_removal_rate,
            "median_selection_size": self.median_selection_size
        }

# ===== FEDERATED SERVER WITH GROUND TRUTH =====
class EnhancedFederatedServer:
    """Federated server with coordinate-wise median aggregation and heuristic removal.

    Participants are drawn from ``dataset`` by ground-truth label, and each
    gets its own ``V2GClassifier``. On every ``update_model`` call the server:
    samples participants (weighted by reputation), takes the coordinate-wise
    median of their submitted gradients, updates reputation and confidence
    bookkeeping, and takes one Adam step on the global model.

    NOTE(review): ``enhanced_anomaly_detection`` consults
    ``ground_truth_types`` to enforce the honest-removal budget. These are
    oracle labels that a deployed server would not have.
    """

    def __init__(self, model, num_honest, num_adversarial, device='cpu', dataset=None, adaptive_mode=True):
        self.model = model.to(device)
        self.device = device
        self.num_honest = num_honest
        self.num_adversarial = num_adversarial
        self.dataset = dataset
        self.current_round = 0
        self.participants = {}
        self.reputation_scores = {}
        self.confidence_scores = {}
        self.gradient_norms = defaultdict(list)
        self.gradient_directions = defaultdict(list)
        self.selection_history = defaultdict(list)
        self.accuracy_history = []
        self.removed_participants = []
        self.safe_list = set()
        self.honest_removed_count = 0
        self.selection_rates = defaultdict(float)
        self.gradient_cvs = defaultdict(float)
        self.confidence_history = defaultdict(list)
        self.suspicious_patterns = defaultdict(int)
        self.gradient_cosine_similarities = defaultdict(list)

        # Ground-truth label per participant (from the dataset), kept apart
        # from the server-assigned participants[idx]['type'].
        self.ground_truth_types = {}

        # Global-model optimiser: created on the first update_model call and
        # reused so that Adam's moment estimates persist across rounds.
        self._optimizer = None
        self._optimizer_model = None

        self._initialize_participants(dataset)
        # Built after _initialize_participants so that budgets and selection
        # sizes use the honest/adversarial counts actually available.
        self.params = EnhancedAdaptiveParameters(self.num_honest, self.num_adversarial, adaptive_mode)

    def _initialize_participants(self, dataset):
        """Pick participants by ground-truth label and create their local models.

        Raises:
            ValueError: if ``dataset`` is None.
        """
        if dataset is None:
            raise ValueError("A dataset is required to initialise participants")

        available_pids = list(set(dataset.participant_ids))

        # Separate candidates by ground-truth label
        honest_ground_truth = []
        adversarial_ground_truth = []

        for pid in available_pids:
            ground_truth = dataset.get_participant_ground_truth(pid)
            if ground_truth == 'honest':
                honest_ground_truth.append(pid)
            elif ground_truth == 'adversarial':
                adversarial_ground_truth.append(pid)

        print(f"\nDataset composition:")
        print(f"  {len(honest_ground_truth)} honest: {sorted(honest_ground_truth)}")
        print(f"  {len(adversarial_ground_truth)} adversarial: {sorted(adversarial_ground_truth)}")

        # Shrink the requested counts to what the dataset can supply
        if len(honest_ground_truth) < self.num_honest:
            print(f"Warning: Only {len(honest_ground_truth)} honest available, adjusting from {self.num_honest}")
            self.num_honest = len(honest_ground_truth)

        if len(adversarial_ground_truth) < self.num_adversarial:
            print(f"Warning: Only {len(adversarial_ground_truth)} adversarial available, adjusting from {self.num_adversarial}")
            self.num_adversarial = len(adversarial_ground_truth)

        # Deterministic selection: sort, then shuffle with a fixed seed
        rng = random.Random(42)
        honest_sorted = sorted(honest_ground_truth)
        adversarial_sorted = sorted(adversarial_ground_truth)
        rng.shuffle(honest_sorted)
        rng.shuffle(adversarial_sorted)

        honest_ids = honest_sorted[:self.num_honest]
        adversarial_ids = adversarial_sorted[:self.num_adversarial]

        def _new_participant(pid):
            torch.manual_seed(42 + pid)
            model = V2GClassifier(self.model.input_layer.in_features,
                                  num_classes=self.model.output_layer.out_features).to(self.device)

            return {
                'model': model,
                'type': None,
                'data': None,
                'active': True,
                'suspicion_score': 0.0,
                'behavioral_flags': 0.0,
                'selection_count': 0,
                'rounds_since_selection': 0
            }

        # Register honest participants first, then adversarial ones
        for role, role_ids in (('honest', honest_ids), ('adversarial', adversarial_ids)):
            for idx in role_ids:
                p = _new_participant(idx)
                p['type'] = role
                self.participants[idx] = p
                self.reputation_scores[idx] = 1.0
                self.confidence_scores[idx] = 0.0
                self.ground_truth_types[idx] = dataset.get_participant_ground_truth(idx)

        print(f"\nServer initialized: {self.num_honest} honest, {self.num_adversarial} adversarial")

        # Sanity check: assigned type vs ground truth
        print("\n=== GROUND TRUTH VERIFICATION ===")
        mismatches = 0
        for idx in self.participants:
            server_type = self.participants[idx]['type']
            ground_truth = self.ground_truth_types[idx]
            match = "✓" if server_type == ground_truth else "✗ MISMATCH"
            if server_type != ground_truth:
                mismatches += 1
            print(f"Participant {idx}: Server={server_type}, GroundTruth={ground_truth} {match}")

        if mismatches > 0:
            print(f"\n⚠️ WARNING: {mismatches} mismatches detected!")
        else:
            print("\n✓ All participants correctly labeled")

    def deterministic_selection(self, active_ids, k, reputations, round_idx):
        """Sample up to ``k`` ids without replacement, weighted by reputation, seeded by round."""
        rng = np.random.RandomState(42 + round_idx)
        probs = np.array([reputations.get(idx, 1.0) for idx in active_ids], dtype=float)
        probs = probs + 1e-6
        probs = probs / probs.sum()
        return list(rng.choice(active_ids, size=min(k, len(active_ids)), replace=False, p=probs))

    def calculate_gradient_metrics(self, idx, gradients):
        """Record gradient norm / CV statistics for ``idx`` and raise behavioural flags.

        ``rounds_since_selection`` is not touched here; it is maintained
        once per round by ``aggregate_with_median``.
        """
        grad_vector = torch.cat([g.flatten() for g in gradients]).to(self.device)
        norm = torch.norm(grad_vector).item()

        self.gradient_norms[idx].append(norm)
        if len(self.gradient_norms[idx]) >= 3:
            # Slope of the last three norms; a steep trend counts as suspicious.
            norm_trend = np.polyfit(range(3), self.gradient_norms[idx][-3:], 1)[0]
            if abs(norm_trend) > 0.6:
                self.participants[idx]['behavioral_flags'] += 2.0

        # Coefficient of variation over the last (up to) five norms
        norms = self.gradient_norms[idx][-5:] if len(self.gradient_norms[idx]) >= 5 else self.gradient_norms[idx]
        if len(norms) >= 2:
            variance = np.var(norms)
            mean_norm = np.mean(norms)
            cv = np.sqrt(variance) / (mean_norm + 1e-10)
            self.gradient_cvs[idx] = cv
            if cv > self.params.gradient_cv_threshold:
                self.participants[idx]['behavioral_flags'] += 3.0
                self.suspicious_patterns[idx] += 1.0

        self.gradient_directions[idx].append([g.clone().detach() for g in gradients])

    def aggregate_with_median(self, gradients_dict, selected_ids=None):
        """Coordinate-wise median of the selected participants' gradients.

        Args:
            gradients_dict: participant id -> list of per-parameter gradients.
            selected_ids: ids to aggregate; sampled by reputation when None.

        Returns:
            ``(aggregated_gradients, selected_ids)``, or ``(None, [])`` when
            nothing can be aggregated. Selected ids without an entry in
            ``gradients_dict`` are ignored rather than raising KeyError.
            For an even count, ``torch.median`` returns the lower middle value.
        """
        if not gradients_dict:
            return None, []

        active_ids = [idx for idx, p in self.participants.items() if p['active'] and idx in gradients_dict]
        if not active_ids:
            return None, []

        if selected_ids is None:
            reputations = {idx: self.reputation_scores.get(idx, 1.0) for idx in active_ids}
            selected_ids = self.deterministic_selection(active_ids, self.params.median_selection_size, reputations, self.current_round)

        for idx in active_ids:
            was_selected = idx in selected_ids
            self.selection_history[idx].append(was_selected)
            if was_selected:
                self.participants[idx]['selection_count'] += 1
                self.participants[idx]['rounds_since_selection'] = 0
            else:
                self.participants[idx]['rounds_since_selection'] += 1
            self.selection_rates[idx] = self.participants[idx]['selection_count'] / max(1, self.current_round)

        if not selected_ids:
            return None, []

        contributing_ids = [pid for pid in selected_ids if pid in gradients_dict]
        if not contributing_ids:
            return None, []

        num_layers = len(gradients_dict[contributing_ids[0]])
        aggregated = []
        for layer_idx in range(num_layers):
            stacked_grads = torch.stack(
                [gradients_dict[pid][layer_idx].detach() for pid in contributing_ids], dim=0
            )
            median_grad, _ = torch.median(stacked_grads, dim=0)
            aggregated.append(median_grad)

        return aggregated, selected_ids

    def enhanced_anomaly_detection(self):
        """Score active participants and return ``(ids_to_remove, honest_removed)``.

        Removals are capped by ``max_removal_rate`` and ``adv_per_round_limit``.
        If the removals would exceed the honest-removal budget, every
        ground-truth honest id is dropped from the list (oracle information;
        see class note).
        """
        candidates = []

        for idx in self.participants:
            if not self.participants[idx]['active']:
                continue

            detection_score = 0.0

            # Gradient CV
            if self.gradient_cvs.get(idx, 0) > self.params.gradient_cv_threshold:
                detection_score += 2.0

            # Behavioral flags
            if self.participants[idx]['behavioral_flags'] > self.params.behavioral_flag_threshold:
                detection_score += self.participants[idx]['behavioral_flags'] * 0.3

            # Selection rate
            if self.current_round >= 10 and self.selection_rates[idx] < 0.1:
                detection_score += 1.5

            # Reputation
            if self.reputation_scores[idx] < 0.4:
                detection_score += 1.0

            if detection_score > 2.0:
                candidates.append((idx, detection_score))

        candidates.sort(key=lambda x: x[1], reverse=True)

        # Limit removals
        active_count = sum(1 for p in self.participants.values() if p['active'])
        max_removals = min(
            int(active_count * self.params.max_removal_rate),
            self.params.adv_per_round_limit,
            max(1, len(candidates))
        )

        to_remove = [idx for idx, _ in candidates[:max_removals]]

        honest_removed = sum(1 for idx in to_remove if self.ground_truth_types[idx] == 'honest')

        # Protect honest participants once the budget would be exceeded
        if self.honest_removed_count + honest_removed > self.params.honest_removal_budget:
            print(f"Honest budget reached: {self.honest_removed_count}/{self.params.honest_removal_budget}")
            to_remove = [idx for idx in to_remove if self.ground_truth_types[idx] == 'adversarial']
            honest_removed = 0

        return to_remove, honest_removed

    def update_reputation_scores(self, selected_ids):
        """Reward selected participants, decay the rest, and penalise flagged ones.

        Reputation stays within [0.2, 1.0]. Behavioural flags are reset after
        they are converted into a penalty.
        """
        for idx in self.participants:
            if not self.participants[idx]['active']:
                continue
            if idx in selected_ids:
                self.reputation_scores[idx] = min(1.0, self.reputation_scores[idx] + 0.03)
            else:
                self.reputation_scores[idx] = max(0.2, self.reputation_scores[idx] * self.params.reputation_decay)

            if self.participants[idx]['behavioral_flags'] >= self.params.behavioral_flag_threshold:
                self.reputation_scores[idx] = max(0.2, self.reputation_scores[idx] - 0.18)
                self.participants[idx]['behavioral_flags'] = 0

    def update_confidence_scores(self):
        """Recompute each active participant's suspicion confidence (capped at 1.0) and log it."""
        for idx in self.participants:
            if not self.participants[idx]['active']:
                continue

            confidence = 0.0

            if self.gradient_cvs.get(idx, 0) > self.params.gradient_cv_threshold:
                confidence += 0.3

            if self.participants[idx]['behavioral_flags'] > 0:
                confidence += min(self.participants[idx]['behavioral_flags'] / 5.0, 0.3)

            if self.current_round >= 10 and self.selection_rates[idx] < 0.1:
                confidence += 0.2

            if self.reputation_scores[idx] < 0.4:
                confidence += 0.2

            self.confidence_scores[idx] = min(1.0, confidence)
            self.confidence_history[idx].append(self.confidence_scores[idx])

    def evaluate_model(self, val_loader):
        """Return the global model's accuracy on ``val_loader`` (0.0 if it is empty)."""
        self.model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(self.device), labels.to(self.device)
                outputs = self.model(inputs)
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        return correct / total if total > 0 else 0.0

    def update_model(self, gradients_dict, val_loader=None):
        """Aggregate gradients, update bookkeeping and take one Adam step.

        Returns:
            Validation accuracy when ``val_loader`` is given, else 0.0.
            Also returns 0.0 when nothing could be aggregated.
        """
        aggregated_grads, selected_ids = self.aggregate_with_median(gradients_dict)
        if not selected_ids or aggregated_grads is None:
            return 0.0

        self.update_reputation_scores(selected_ids)
        self.update_confidence_scores()

        for param in self.model.parameters():
            param.grad = None
        for param, grad in zip(self.model.parameters(), aggregated_grads):
            param.grad = grad.clone().detach()

        # Reuse one Adam instance: a fresh optimiser every round discards the
        # moment estimates, so each step degenerates to lr * sign(grad).
        if getattr(self, '_optimizer', None) is None or getattr(self, '_optimizer_model', None) is not self.model:
            self._optimizer = optim.Adam(self.model.parameters(), lr=0.005)
            self._optimizer_model = self.model
        self._optimizer.step()

        accuracy = 0.0
        if val_loader:
            accuracy = self.evaluate_model(val_loader)
            self.accuracy_history.append(accuracy)

        return accuracy

    def check_early_stopping(self):
        """True when accuracy moved < 0.005 over the last ``params.patience`` evaluations."""
        window = self.params.patience
        if len(self.accuracy_history) < window:
            return False

        recent = self.accuracy_history[-window:]
        return max(recent) - min(recent) < 0.005

    def get_metrics(self):
        """Summarise active/removed participant counts using ground-truth labels."""
        active_honest = sum(1 for idx, p in self.participants.items()
                            if p['active'] and self.ground_truth_types[idx] == 'honest')
        active_adversarial = sum(1 for idx, p in self.participants.items()
                                 if p['active'] and self.ground_truth_types[idx] == 'adversarial')
        total_active = active_honest + active_adversarial

        return {
            'active_honest': active_honest,
            'active_adversarial': active_adversarial,
            'total_active': total_active,
            'total_removed': len(self.removed_participants),
            'honest_removed_count': self.honest_removed_count
        }
        

# ===== TRAINING FUNCTION =====
def train_participant(participant, dataset, device, current_round, val_loader=None, server=None):
    """Train one participant's local model for three epochs.

    Ground-truth adversarial participants (per ``server.ground_truth_types``)
    sometimes flip labels and sometimes add Gaussian noise to their features.

    Args:
        participant: participant dict with at least 'model' and 'data'.
        dataset: provides ``num_classes`` for label flipping.
        device: device that batches are moved to.
        current_round: federated round; drives the learning-rate decay and
            the attack strength/seeds.
        val_loader: optional loader for a local accuracy check.
        server: optional server used to look up the participant id and its
            ground-truth type.

    Returns:
        ``(gradients, mean_loss, accuracy)``. ``gradients`` are the
        parameters' ``.grad`` tensors left by the final mini-batch (after
        clipping to norm 0.5). Returns ``([], 0.0, 0.0)`` when the
        participant has no data loader.

    Raises:
        ValueError: if ``server`` is given but does not contain ``participant``.
    """
    model = participant['model']
    model.train()

    data_loader = participant['data']
    if not data_loader:
        return [], 0.0, 0.0

    # Identity lookup; comparing participant dicts with == compares every field.
    participant_id = 0
    if server:
        matches = [pid for pid, p in server.participants.items() if p is participant]
        if not matches:
            raise ValueError("participant is not registered with the given server")
        participant_id = matches[0]
    # Participant ids may be NumPy integers; random.Random rejects those on
    # Python >= 3.11, so derive the seeds from a plain int.
    seed_base = 42 + int(participant_id) + current_round * 1000

    lr = 0.001 * (0.9 ** max(0, current_round - 10))
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = FocalLoss(alpha=0.98, gamma=4.0)

    is_adversarial = bool(server) and server.ground_truth_types[participant_id] == 'adversarial'
    adv_rng = random.Random(seed_base)
    noise_scale = 0.6 if current_round < 5 else 0.4
    flip_prob = 0.4 if current_round < 5 else 0.3

    total_loss = 0.0
    training_epochs = 3
    for epoch in range(training_epochs):
        # Dedicated generator: noise differs per batch and the global torch RNG
        # (which also drives dropout) is not reseeded mid-training.
        noise_gen = torch.Generator().manual_seed(seed_base + epoch)
        epoch_loss = 0.0
        for features, labels in data_loader:
            features, labels = features.to(device), labels.to(device)

            if is_adversarial:
                if adv_rng.random() < flip_prob:
                    labels = (labels + 1) % dataset.num_classes
                if adv_rng.random() < 0.5:
                    noise = torch.randn(features.shape, generator=noise_gen, dtype=features.dtype)
                    features = features + noise.to(features.device) * noise_scale

            optimizer.zero_grad()
            outputs = model(features)
            loss = criterion(outputs, labels)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.5)
            optimizer.step()
            epoch_loss += loss.item()
        total_loss += epoch_loss / len(data_loader)

    gradients = [param.grad.clone().detach() if param.grad is not None else torch.zeros_like(param) for param in model.parameters()]

    accuracy = 0.0
    if val_loader:
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        accuracy = correct / total if total > 0 else 0.0

    return gradients, total_loss / training_epochs, accuracy

# ===== DATA DISTRIBUTION =====
def distribute_data_to_participants(server, dataset, train_indices):
    """Attach a shuffled DataLoader over its training rows to every participant.

    Only indices below ``len(dataset.participant_ids)`` (rows that have a
    recorded owner) are considered. Participants with fewer than 5 such rows
    are deactivated; all others are (re)activated.

    Returns:
        ``(activated_honest, activated_adversarial)``, counted by ground truth.
    """
    num_owned = len(dataset.participant_ids)
    # Group training indices by owner in a single pass (instead of rescanning
    # train_indices for every participant); order within a group is kept.
    indices_by_pid = defaultdict(list)
    for i in train_indices:
        if i < num_owned:
            indices_by_pid[dataset.participant_ids[i]].append(i)

    activated_honest = 0
    activated_adversarial = 0

    for idx in server.participants:
        indices = indices_by_pid.get(idx, [])
        num_records = len(indices)

        if num_records < 5:
            print(f"Warning: Participant {idx} has only {num_records} records, deactivating")
            server.participants[idx]['active'] = False
            continue

        subset = Subset(dataset, indices)
        batch_size = min(32, max(4, num_records // 2))
        server.participants[idx]['data'] = DataLoader(subset, batch_size=batch_size, shuffle=True,
                                                      generator=data_loader_generator, worker_init_fn=seed_worker_fn)
        server.participants[idx]['active'] = True

        if server.ground_truth_types[idx] == 'honest':
            activated_honest += 1
        else:
            activated_adversarial += 1

    print(f"Activated: {activated_honest} honest, {activated_adversarial} adversarial")
    return activated_honest, activated_adversarial

# ===== METRICS CALCULATION (FIXED) =====
def calculate_enhanced_metrics(server, removed_participants):
    """Score the server's removal decisions against the GROUND-TRUTH labels.

    Each participant in ``server.participants`` is one binary-classification
    sample. The true label is ``server.ground_truth_types[idx]``, where
    ``'adversarial'`` is the positive class and ``'honest'`` the negative one.
    The prediction is "removed" iff ``idx`` is in ``removed_participants``.
    The server's own ``participants[idx]['type']`` is deliberately NOT used.

    Args:
        server: object exposing ``participants`` (iterable of ids),
            ``ground_truth_types`` (id -> ``'honest'`` / ``'adversarial'``),
            ``accuracy_history`` (list of per-round accuracies) and
            ``current_round``.
        removed_participants: ids removed during the run. Duplicates, and
            ids not present in ``server.participants``, are ignored.

    Returns:
        dict with detection scores (``accuracy``, ``precision``, ``recall``,
        ``f1``, ``specificity``), removal counts/rates, the confusion cells
        ``tp``/``fp``/``tn``/``fn`` and model-accuracy summaries.

        Edge cases follow the previous sklearn ``zero_division`` settings:
        precision is 1.0 when nothing was removed; recall and F1 are 0.0
        when undefined. Accuracy is 0.0 when there are no participants
        (this previously raised ZeroDivisionError).
    """
    # A set dedupes repeated ids (which previously inflated the counts) and
    # gives O(1) membership tests.
    removed = set(removed_participants)
    ground_truth = server.ground_truth_types

    tp = fp = tn = fn = 0
    for idx in server.participants:
        label = ground_truth[idx]
        was_removed = idx in removed
        if label == 'adversarial':
            if was_removed:
                tp += 1
            else:
                fn += 1
        elif label == 'honest':
            if was_removed:
                fp += 1
            else:
                tn += 1

    total = tp + fp + tn + fn
    accuracy = (tp + tn) / total if total > 0 else 0.0
    precision = tp / (tp + fp) if (tp + fp) > 0 else 1.0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    # 2TP / (2TP + FP + FN) equals 2PR / (P + R) whenever it is defined.
    f1_denominator = 2 * tp + fp + fn
    f1 = 2 * tp / f1_denominator if f1_denominator > 0 else 0.0
    specificity = tn / (tn + fp) if (tn + fp) > 0 else 1.0

    honest_removed = fp
    adversarial_removed = tp
    total_honest = fp + tn
    total_adversarial = tp + fn

    honest_removal_rate = honest_removed / max(1, total_honest)
    adversarial_removal_rate = adversarial_removed / max(1, total_adversarial)
    # How much more likely an adversary is to be removed than an honest
    # participant. Reported as inf whenever no honest participant was removed.
    removal_preference = (adversarial_removal_rate / honest_removal_rate
                          if honest_removal_rate > 0 else float('inf'))

    history = server.accuracy_history
    if len(history) >= 5:
        # Smooth out round-to-round noise with the mean of the last 5 rounds.
        model_accuracy = np.mean(history[-5:])
    elif history:
        model_accuracy = history[-1]
    else:
        model_accuracy = 0.0

    return {
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'specificity': specificity,
        'honest_removed': honest_removed,
        'adversarial_removed': adversarial_removed,
        'honest_removal_rate': honest_removal_rate,
        'adversarial_removal_rate': adversarial_removal_rate,
        'removal_preference': removal_preference,
        'tp': tp,
        'fp': fp,
        'tn': tn,
        'fn': fn,
        'model_accuracy': model_accuracy,
        'final_accuracy': history[-1] if history else 0.0,
        'max_accuracy': max(history) if history else 0.0,
        'rounds_completed': server.current_round
    }

def print_detection_performance(metrics):
    """Print a human-readable summary of a detection-metrics dict.

    Args:
        metrics: dict as returned by ``calculate_enhanced_metrics``. The keys
            read here are the classification scores, the removal counts and
            rates, the confusion cells and ``model_accuracy``.
    """
    heavy_rule = "=" * 60
    light_rule = "-" * 60

    # Scores are stored as fractions and shown as percentages.
    score_rows = (
        ("Accuracy", 'accuracy'),
        ("Precision", 'precision'),
        ("Recall", 'recall'),
        ("F1 Score", 'f1'),
        ("Specificity", 'specificity'),
    )

    print("\n" + heavy_rule)
    print("DETECTION PERFORMANCE (Using Ground Truth)")
    print(heavy_rule)
    for label, key in score_rows:
        print(f"{label}: {metrics[key]*100:.2f}%")
    print(light_rule)

    # Group sizes are rebuilt from the confusion cells:
    # honest = removed honest + TN, adversarial = removed adversarial + FN.
    honest_total = metrics['honest_removed'] + metrics['tn']
    adversarial_total = metrics['adversarial_removed'] + metrics['fn']
    print(f"Honest Removed: {metrics['honest_removal_rate']*100:.1f}% ({metrics['honest_removed']}/{honest_total})")
    print(f"Adversarial Removed: {metrics['adversarial_removal_rate']*100:.1f}% ({metrics['adversarial_removed']}/{adversarial_total})")

    preference = metrics['removal_preference']
    if preference == float('inf'):
        print("Removal Preference: inf")
    else:
        print(f"Removal Preference: {preference:.1f}x")

    print(light_rule)
    print(f"TP: {metrics['tp']} | FP: {metrics['fp']} | TN: {metrics['tn']} | FN: {metrics['fn']}")
    print(f"Model Accuracy: {metrics['model_accuracy']*100:.2f}%")
    print(heavy_rule)

def create_performance_plots(server, metrics, save_dir='.'):
    """Render a 2x2 performance dashboard for one run and save it as a PNG.

    Panels:
      1. global-model accuracy per round (``server.accuracy_history``);
      2. mean confidence score per round, split by ground-truth type;
      3. counts of active honest / adversarial participants per round;
      4. the detection confusion matrix built from ``metrics``.

    The image is written to ``<save_dir>/performance_dashboard.png`` at 300 dpi.

    Args:
        server: object exposing ``accuracy_history``, ``confidence_history``
            (id -> indexable per-round scores), ``participants``,
            ``ground_truth_types``, ``removed_participants`` and
            ``current_round``.
        metrics: dict with integer confusion cells ``tp``/``fp``/``tn``/``fn``.
        save_dir: output directory; created if it does not exist.
    """
    os.makedirs(save_dir, exist_ok=True)
    out_path = f'{save_dir}/performance_dashboard.png'

    fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 12))
    try:
        # --- Panel 1: model accuracy over (1-based) rounds.
        if server.accuracy_history:
            acc_rounds = range(1, len(server.accuracy_history) + 1)
            ax1.plot(acc_rounds, server.accuracy_history, 'b-', linewidth=2)
            ax1.set_xlabel('Round')
            ax1.set_ylabel('Accuracy')
            ax1.set_title('Model Accuracy Over Rounds')
            ax1.grid(True, alpha=0.3)

        # --- Panel 2: mean confidence per round, grouped by ground truth.
        # Entry k of confidence_history[idx] is drawn at round k + 1, matching
        # the 1-based axes of the other panels. A round with no data for a
        # group is NaN (a gap in the line) rather than 0, which would look
        # like a real collapse in confidence.
        honest_confidence = []
        adversarial_confidence = []
        for round_idx in range(server.current_round):
            round_honest = []
            round_adversarial = []
            for idx in server.participants:
                if idx in server.confidence_history and len(server.confidence_history[idx]) > round_idx:
                    score = server.confidence_history[idx][round_idx]
                    if server.ground_truth_types[idx] == 'honest':
                        round_honest.append(score)
                    else:
                        round_adversarial.append(score)
            honest_confidence.append(np.mean(round_honest) if round_honest else np.nan)
            adversarial_confidence.append(np.mean(round_adversarial) if round_adversarial else np.nan)

        conf_rounds = range(1, len(honest_confidence) + 1)
        ax2.plot(conf_rounds, honest_confidence, 'g-', linewidth=2, label='Honest (GT)')
        ax2.plot(conf_rounds, adversarial_confidence, 'r-', linewidth=2, label='Adversarial (GT)')
        ax2.set_xlabel('Round')
        ax2.set_ylabel('Avg Confidence Score')
        ax2.set_title('Confidence Score Evolution (Ground Truth)')
        ax2.grid(True, alpha=0.3)
        ax2.legend()

        # --- Panel 3: active participants per round, grouped by ground truth.
        active_honest = []
        active_adversarial = []
        for round_idx in range(1, server.current_round + 1):
            # NOTE(review): this slice is the first ``round_idx`` removals overall,
            # not the removals made up to round ``round_idx`` (a round can remove
            # several participants or none). The curve is therefore approximate.
            # Plotting it exactly needs the removal round recorded per participant.
            removed_so_far = server.removed_participants[:round_idx]
            honest_count = 0
            adversarial_count = 0
            for idx in server.participants:
                if idx not in removed_so_far:
                    if server.ground_truth_types[idx] == 'honest':
                        honest_count += 1
                    else:
                        adversarial_count += 1
            active_honest.append(honest_count)
            active_adversarial.append(adversarial_count)

        active_rounds = range(1, len(active_honest) + 1)
        ax3.plot(active_rounds, active_honest, 'g-', linewidth=2, label='Active Honest')
        ax3.plot(active_rounds, active_adversarial, 'r-', linewidth=2, label='Active Adversarial')
        ax3.set_xlabel('Round')
        ax3.set_ylabel('Count')
        ax3.set_title('Active Participants Over Time (Ground Truth)')
        ax3.grid(True, alpha=0.3)
        ax3.legend()

        # --- Panel 4: confusion matrix (rows = true label, cols = predicted).
        cm = np.array([[metrics['tn'], metrics['fp']], [metrics['fn'], metrics['tp']]])
        im = ax4.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
        ax4.set_title('Confusion Matrix')
        fig.colorbar(im, ax=ax4)
        classes = ['Honest', 'Adversarial']
        tick_marks = np.arange(len(classes))
        ax4.set_xticks(tick_marks)
        ax4.set_xticklabels(classes)
        ax4.set_yticks(tick_marks)
        ax4.set_yticklabels(classes)

        # Use white text on dark (high-count) cells so the numbers stay readable.
        thresh = cm.max() / 2.
        for (i, j), value in np.ndenumerate(cm):
            ax4.text(j, i, format(value, 'd'),
                     ha="center", va="center",
                     color="white" if value > thresh else "black",
                     fontsize=14)

        ax4.set_ylabel('True Label')
        ax4.set_xlabel('Predicted Label')

        fig.tight_layout()
        fig.savefig(out_path, dpi=300, bbox_inches='tight')
    finally:
        # Always release the figure, even if plotting or saving fails.
        # Otherwise repeated runs in one session accumulate open figures.
        plt.close(fig)
    print(f"Plots saved to {save_dir}/performance_dashboard.png")

# ===== MAIN SIMULATION =====
def run_v2g_simulation(dataset_path, num_honest, num_adversarial, rounds=35, device='cpu', adaptive_mode=True, save_dir='.'):
    """Run one federated-learning scenario end to end and score its detector.

    Steps:
      1. Reseed all RNGs.
      2. Load the dataset and build the model and server.
      3. Split the indices 80/20 into train/validation with a fixed-seed shuffle.
      4. Hand each participant its training records.
      5. Run up to ``rounds`` FL rounds: local training -> aggregation ->
         anomaly detection -> removal of flagged participants.
      6. Compute ground-truth detection metrics, save the dashboard plot and
         print a summary.

    Args:
        dataset_path: CSV path passed to ``V2GDataset``.
        num_honest: number of honest participants, forwarded to
            ``EnhancedFederatedServer``.
        num_adversarial: number of adversarial participants, forwarded to
            ``EnhancedFederatedServer``.
        rounds: maximum number of rounds. The loop stops early once fewer
            than 3 participants remain active.
        device: device for the global model and local training.
        adaptive_mode: forwarded to ``EnhancedFederatedServer``.
        save_dir: directory that receives the performance dashboard.

    Returns:
        ``(server, removed_participants, metrics)``. ``removed_participants``
        lists removed ids in removal order; ``metrics`` is the dict from
        ``calculate_enhanced_metrics``.
    """
    global data_loader_generator, seed_worker_fn
    # Rebind the shared DataLoader generator to the freshly seeded one.
    # Previously this return value was discarded. A second scenario in the
    # same process then reused the generator the first one had already
    # advanced, so its shuffles depended on which scenarios ran before it.
    data_loader_generator, seed_worker_fn = set_full_determinism()

    print(f"\nLoading dataset from {dataset_path}...")
    dataset = V2GDataset(dataset_path)

    model = V2GClassifier(input_size=dataset.num_features, num_classes=dataset.num_classes).to(device)
    server = EnhancedFederatedServer(model, num_honest, num_adversarial, device, dataset, adaptive_mode)

    # Fixed-seed 80/20 split so every scenario sees the same partition.
    indices = list(range(len(dataset)))
    rng = random.Random(42)
    rng.shuffle(indices)
    train_size = int(0.8 * len(indices))
    train_indices = indices[:train_size]
    val_indices = indices[train_size:]
    val_dataset = Subset(dataset, val_indices)
    val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False,
                            generator=data_loader_generator, worker_init_fn=seed_worker_fn)

    print("\nDistributing data to participants...")
    activated_honest, activated_adversarial = distribute_data_to_participants(server, dataset, train_indices)

    print(f"\nStarting FL with {activated_honest} honest, {activated_adversarial} adversarial")
    print(f"Detection threshold: {server.params.removal_threshold:.3f}")
    print(f"Median selection size: {server.params.median_selection_size}")

    removed_participants = []

    for _ in range(rounds):
        server.current_round += 1
        print(f"\n=== Round {server.current_round} ===")
        gradients_dict = {}

        # Local training for every still-active participant.
        for idx in server.participants:
            if not server.participants[idx]['active']:
                continue
            gradients, _loss, _acc = train_participant(
                server.participants[idx], dataset, device, server.current_round, val_loader, server
            )
            if gradients:
                gradients_dict[idx] = gradients
                server.calculate_gradient_metrics(idx, gradients)

        accuracy = server.update_model(gradients_dict, val_loader)

        to_remove, _ = server.enhanced_anomaly_detection()

        if to_remove:
            # Split the printed counts by GROUND TRUTH so they agree with
            # server.honest_removed_count and with the final metrics.
            honest_removed = 0
            for idx in to_remove:
                server.participants[idx]['active'] = False
                removed_participants.append(idx)
                server.removed_participants.append(idx)
                if server.ground_truth_types[idx] == 'honest':
                    server.honest_removed_count += 1
                    honest_removed += 1

            adv_removed = len(to_remove) - honest_removed
            print(f"Removed {len(to_remove)} participants (Honest: {honest_removed}, Adversarial: {adv_removed})")

        round_stats = server.get_metrics()
        print(f"Active: {round_stats['total_active']} (Honest: {round_stats['active_honest']}, Adv: {round_stats['active_adversarial']})")
        print(f"Accuracy: {accuracy:.3f}")

        if round_stats['total_active'] < 3:
            print("Too few participants, stopping")
            break

        # Early stopping is intentionally disabled; re-enable to stop on plateau.
        #if server.check_early_stopping():
        #    print(f"Early stopping at round {server.current_round}")
        #    break

    print("\nCreating visualizations...")
    metrics = calculate_enhanced_metrics(server, removed_participants)
    create_performance_plots(server, metrics, save_dir)
    print_detection_performance(metrics)

    return server, removed_participants, metrics

# ===== MAIN EXECUTION =====
if __name__ == "__main__":
    # Run the honest-majority and adversarial-majority scenarios back to back,
    # then print a side-by-side comparison of their detection metrics.
    device = 'cpu'
    # The hard-coded default is kept for backward compatibility. Set
    # V2G_DATASET_PATH to point at a different copy of the dataset.
    dataset_path = os.environ.get(
        'V2G_DATASET_PATH', r"C:\Users\Administrator\Desktop\v2g dataset kaggle.csv"
    )

    os.makedirs('./honest_majority_results', exist_ok=True)
    os.makedirs('./adversarial_majority_results', exist_ok=True)

    print("\n" + "="*70)
    print("=== HONEST MAJORITY SCENARIO ===")
    print("="*70)
    server_honest, removed_honest, metrics_honest = run_v2g_simulation(
        dataset_path, num_honest=10, num_adversarial=9, rounds=15, device=device,
        adaptive_mode=True, save_dir='./honest_majority_results'
    )

    print("\n\n" + "="*70)
    print("=== ADVERSARIAL MAJORITY SCENARIO ===")
    print("="*70)
    server_adv, removed_adv, metrics_adv = run_v2g_simulation(
        dataset_path, num_honest=10, num_adversarial=11, rounds=15, device=device,
        adaptive_mode=True, save_dir='./adversarial_majority_results'
    )

    print("\n\n" + "="*70)
    print("=== SCENARIO COMPARISON ===")
    print("="*70)

    # Every value column is exactly `col` characters wide so it lines up with
    # its header. Previously the rows used widths 15/19 (and 14/18 for
    # Rounds) under 15-wide headers.
    col = 15
    print(f"{'Metric':<25} | {'Honest Maj':<{col}} | {'Adversarial Maj':<{col}}")
    print("-" * 70)
    comparison_rows = [
        ('F1 Score', 'f1'),
        ('Recall', 'recall'),
        ('Precision', 'precision'),
        ('Adv Removed', 'adversarial_removal_rate'),
        ('Honest Removed', 'honest_removal_rate'),
        ('Model Accuracy', 'model_accuracy'),
    ]
    for label, key in comparison_rows:
        # Fractions are shown as percentages; the trailing '%' takes one column.
        print(f"{label:<25} | {metrics_honest[key]*100:>{col - 1}.1f}% | {metrics_adv[key]*100:>{col - 1}.1f}%")
    print(f"{'Rounds':<25} | {server_honest.current_round:>{col}} | {server_adv.current_round:>{col}}")
    print("="*70)


=== HONEST MAJORITY SCENARIO ===

Loading dataset from C:\Users\Administrator\Desktop\v2g dataset kaggle.csv...
Dataset loaded: 1000 records, columns: ['participant_id', 'timestamp', 'battery_capacity_kWh', 'current_charge_kWh', 'discharge_rate_kW', 'energy_requested_kWh', 'label']

=== GROUND TRUTH PARTICIPANT LABELS ===
Honest: [232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 3