# In [None]:
import flwr as fl
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.feature_selection import mutual_info_classif, SelectKBest
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score


# Define the Anomaly Transformer model
class AnomalyTransformer(nn.Module):
    """Binary classifier: linear embedding -> Transformer encoder -> sigmoid head.

    Each input row is embedded by ``fc1`` into ``hidden_dim`` features and fed
    to the encoder as a length-1 sequence, so every sample is scored
    independently of the other rows in its batch.

    Args:
        input_dim: Number of input features per sample.
        hidden_dim: Embedding width and transformer ``d_model``; must be
            divisible by ``nhead``.
        output_dim: Number of sigmoid outputs per sample.
        num_layers: Number of stacked encoder layers.
        nhead: Attention heads per encoder layer.
        dropout: Dropout probability applied after the ``fc1`` embedding.
    """

    def __init__(self, input_dim=15, hidden_dim=256, output_dim=1, num_layers=2,
                 nhead=8, dropout=0.3):
        super(AnomalyTransformer, self).__init__()
        # NOTE: nn.TransformerEncoder deep-copies this layer ``num_layers``
        # times, so ``self.encoder_layer`` itself is never executed. It is kept
        # so the state_dict keys/order (the parameter list exchanged with a
        # federated server) remain unchanged.
        self.encoder_layer = nn.TransformerEncoderLayer(d_model=hidden_dim, nhead=nhead, batch_first=True)
        self.transformer_encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=num_layers)
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.dropout = nn.Dropout(dropout)
        self.fc2 = nn.Linear(hidden_dim, output_dim)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return sigmoid probabilities.

        Args:
            x: Float tensor of shape ``(batch, input_dim)`` or
                ``(batch, seq, input_dim)``.

        Returns:
            Tensor of shape ``(batch, output_dim)`` (or
            ``(batch, seq, output_dim)`` for 3-D input).
        """
        x = self.dropout(self.relu(self.fc1(x)))
        if x.dim() == 2:
            # A 2-D tensor would be read by the encoder as ONE unbatched
            # sequence of length ``batch``, letting samples attend to each
            # other. Add a sequence axis of length 1 so rows stay independent
            # (with one token, attention reduces to a per-sample projection).
            x = self.transformer_encoder(x.unsqueeze(1)).squeeze(1)
        else:
            x = self.transformer_encoder(x)
        return self.sigmoid(self.fc2(x))

# Load dataset function
def load_data(features=None, labels=None, columns=None, test_size=0.1, random_state=43):
    """Split, column-select and standardize the dataset into torch datasets.

    Args:
        features: pandas DataFrame of input features. Defaults to the
            module-level ``X``, which must be defined before this is called
            (it is not defined in this function).
        labels: Array-like of binary targets aligned with ``features``.
            Defaults to the module-level ``y``.
        columns: Feature columns to keep. Defaults to the 15 hard-coded
            columns below, matching ``AnomalyTransformer``'s default
            ``input_dim=15``. NOTE(review): 'protocol_type', 'service' and
            'flag' must already be numerically encoded for StandardScaler —
            confirm upstream.
        test_size: Fraction of rows held out for the test split.
        random_state: Seed passed to ``train_test_split``.

    Returns:
        ``(train_dataset, test_dataset)`` as ``TensorDataset``s of
        float32 features and float32 labels of shape ``(n, 1)``.
    """
    if features is None:
        features = X  # module-level global, defined outside this function
    if labels is None:
        labels = y  # module-level global, defined outside this function
    if columns is None:
        columns = ['duration', 'protocol_type', 'service', 'flag', 'src_bytes',
                   'dst_bytes', 'wrong_fragment', 'hot', 'logged_in', 'num_compromised',
                   'count', 'srv_count', 'serror_rate', 'srv_serror_rate', 'rerror_rate']
    columns = list(columns)

    X_train, X_test, y_train, y_test = train_test_split(
        features, labels, test_size=test_size, random_state=random_state)

    # Fit scaling statistics on the training split only, then apply to test.
    scaler = StandardScaler()
    X_train = scaler.fit_transform(X_train[columns])
    X_test = scaler.transform(X_test[columns])

    x_train = torch.tensor(X_train, dtype=torch.float32)
    x_test = torch.tensor(X_test, dtype=torch.float32)
    # Labels get a trailing axis of size 1: shape (n, 1).
    y_train = torch.tensor(np.asarray(y_train), dtype=torch.float32).unsqueeze(1)
    y_test = torch.tensor(np.asarray(y_test), dtype=torch.float32).unsqueeze(1)

    return TensorDataset(x_train, y_train), TensorDataset(x_test, y_test)

# Define Focal Loss
class FocalLoss(nn.Module):
    """Binary focal loss computed on probabilities (not logits).

    The per-element loss is
    ``-alpha * (1 - p_t) ** gamma * (t * log(p + 1e-9) + (1 - t) * log(1 - p + 1e-9))``
    where ``p_t`` is the probability assigned to the target class. ``alpha``
    multiplies both the positive and negative terms, so it acts as a plain
    scale factor rather than a per-class weight.

    Args:
        alpha: Global scale applied to every element's loss.
        gamma: Focusing exponent; larger values down-weight easy examples.
        reduction: ``'mean'`` or ``'sum'``; any other value returns the
            unreduced, flattened per-element loss.
    """

    def __init__(self, alpha=0.25, gamma=3.3, reduction='mean'):
        super().__init__()
        self.alpha, self.gamma, self.reduction = alpha, gamma, reduction

    def forward(self, inputs, targets):
        """Compute the loss; both tensors are flattened with ``.view(-1)``."""
        probs = inputs.view(-1)
        labels = targets.view(-1)

        # Probability the model assigns to the true class of each element.
        p_true = probs * labels + (1 - probs) * (1 - labels)
        # Log-likelihood with a small epsilon to avoid log(0).
        log_lik = labels * torch.log(probs + 1e-9) + (1 - labels) * torch.log(1 - probs + 1e-9)
        per_element = -self.alpha * (1 - p_true) ** self.gamma * log_lik

        if self.reduction == 'mean':
            return per_element.mean()
        if self.reduction == 'sum':
            return per_element.sum()
        return per_element

# Flower NumPyClient wrapping local training and evaluation (no energy tracking is implemented)
class FLClient(fl.client.NumPyClient):
    """Flower NumPy client that trains and evaluates a PyTorch model locally.

    Args:
        model: Torch module; the focal loss (which takes ``log`` of the
            outputs) and the 0.5 threshold in ``evaluate`` assume its forward
            returns probabilities in [0, 1].
        train_loader: DataLoader yielding ``(features, labels)`` batches.
        test_loader: DataLoader yielding ``(features, labels)`` batches.
        device: ``torch.device`` used for training and evaluation.
    """

    def __init__(self, model, train_loader, test_loader, device):
        self.model = model.to(device)
        self.train_loader = train_loader
        self.test_loader = test_loader
        self.device = device
        self.criterion = FocalLoss(alpha=0.25, gamma=3.3, reduction='mean')
        self.optimizer = optim.Adam(self.model.parameters(), lr=0.0005)

    def get_parameters(self, config):
        """Return the model state as NumPy arrays, in state_dict order."""
        return [val.cpu().numpy() for val in self.model.state_dict().values()]

    def set_parameters(self, parameters):
        """Load arrays (given in state_dict order) into the model.

        Each entry keeps the dtype of the model's current tensor, so
        non-float buffers are not silently cast to float32.
        """
        current = self.model.state_dict()
        state_dict = {
            key: torch.tensor(value, dtype=ref.dtype, device=self.device)
            for (key, ref), value in zip(current.items(), parameters)
        }
        self.model.load_state_dict(state_dict, strict=True)

    def fit(self, parameters, config):
        """Train locally and return ``(parameters, num_examples, metrics)``.

        ``config["local_epochs"]`` (optional, default 3) sets the number of
        passes over the training loader. ``metrics`` contains
        ``"train_loss"``, the sample-weighted mean batch loss over all epochs.
        """
        self.set_parameters(parameters)
        self.model.train()
        epochs = int((config or {}).get("local_epochs", 3))

        loss_sum, n_seen = 0.0, 0
        for _ in range(epochs):
            for x, y in self.train_loader:
                x, y = x.to(self.device), y.to(self.device)
                self.optimizer.zero_grad()
                loss = self.criterion(self.model(x), y)
                loss.backward()
                self.optimizer.step()
                loss_sum += loss.item() * x.size(0)
                n_seen += x.size(0)

        metrics = {"train_loss": loss_sum / n_seen} if n_seen else {}
        # Flower's NumPyClient.fit must return a 3-tuple; a 2-tuple is rejected.
        return self.get_parameters({}), len(self.train_loader.dataset), metrics

    def evaluate(self, parameters, config):
        """Evaluate on the local test set.

        Returns:
            ``(loss, num_examples, metrics)``. ``loss`` is the mean focal loss
            over the test set (Flower treats this slot as a loss, so it must
            not be a higher-is-better score). ``metrics`` holds accuracy,
            precision, recall, F1 and AUC-ROC; AUC-ROC is NaN when the test
            labels contain a single class, where it is undefined.
        """
        self.set_parameters(parameters)
        self.model.eval()

        all_labels, all_probs = [], []
        loss_sum, n_seen = 0.0, 0
        with torch.no_grad():
            for x, y in self.test_loader:
                x, y = x.to(self.device), y.to(self.device)
                outputs = self.model(x)
                loss_sum += self.criterion(outputs, y).item() * x.size(0)
                n_seen += x.size(0)
                all_labels.append(y.cpu().numpy().flatten())
                all_probs.append(outputs.cpu().numpy().flatten())

        if n_seen == 0:
            # Nothing to evaluate; np.concatenate([]) would raise.
            return 0.0, len(self.test_loader.dataset), {}

        labels = np.concatenate(all_labels, axis=0)
        probs = np.concatenate(all_probs, axis=0)
        preds = (probs > 0.5).astype(int)

        metrics = {
            "accuracy": float(accuracy_score(labels, preds)),
            "precision": float(precision_score(labels, preds, average='binary', zero_division=0)),
            "recall": float(recall_score(labels, preds, average='binary', zero_division=0)),
            "f1": float(f1_score(labels, preds, average='binary', zero_division=0)),
        }
        # roc_auc_score raises ValueError when only one class is present.
        if np.unique(labels).size > 1:
            metrics["auc_roc"] = float(roc_auc_score(labels, probs))
        else:
            metrics["auc_roc"] = float("nan")

        print(f"Accuracy: {metrics['accuracy']:.4f}")
        print(f"Precision: {metrics['precision']:.4f}")
        print(f"Recall: {metrics['recall']:.4f}")
        print(f"F1 Score: {metrics['f1']:.4f}")
        print(f"AUC-ROC: {metrics['auc_roc']:.4f}")
        return float(loss_sum / n_seen), len(self.test_loader.dataset), metrics

# Client function
def client_fn(context):
    """Build a Flower client backed by a freshly initialised AnomalyTransformer.

    ``context`` is accepted to satisfy Flower's ``client_fn`` signature but is
    not read. ``load_data()`` is called with no arguments, so — as it is
    currently written with a fixed ``random_state`` — every client presumably
    receives the same train/test split rather than its own partition.
    """
    train_dataset, test_dataset = load_data()
    loaders = (
        DataLoader(train_dataset, batch_size=128, shuffle=True),
        DataLoader(test_dataset, batch_size=128, shuffle=False),
    )
    network = AnomalyTransformer()
    target_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    numpy_client = FLClient(network, *loaders, target_device)
    return numpy_client.to_client()


