In [7]:
import os
import numpy as np
import pandas as pd
import glob
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, classification_report
from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.model_selection import train_test_split, cross_val_predict, KFold
from sklearn.neural_network import MLPClassifier
import random

# Configuration
# Runtime settings read by the federated server and by main().
CONFIG = dict(
    # Folder that is globbed for the '*after_reboot*.csv' input files.
    data_path=r'C:\Users\USER\Documents\NTUST\Conference_Workshop_Seminar\Android\Dataset\AndMal2020-dynamic-BeforeAndAfterReboot\Cleaned_Files\normalized_dataset',
    # How many clients the training split is divided among.
    num_clients=3,
    # How many federated rounds are executed.
    communication_rounds=5,
    # 80% train, 20% validation
    train_split_ratio=0.8,
)

# Improved Android Malware Dataset
class AndroidMalwareDataset:
    """Standardized feature rows and encoded labels loaded from CSV files.

    Each CSV must contain a 'Category' column (the label); every other
    numeric column is treated as a feature. Files may expose different
    feature columns, so features are aligned *by column name* across files
    and a feature that a file does not have is filled with 0 before scaling.

    Attributes:
        data: list of 1-D standardized feature rows.
        labels: list of raw 'Category' values, parallel to ``data``.
        max_features: width of every feature row (number of distinct numeric
            feature columns seen across all usable files).
        scaler: ``StandardScaler`` fitted on all loaded rows.
        label_encoder: ``LabelEncoder`` fitted on ``labels``.
        encoded_labels: integer labels, parallel to ``data``.
        num_classes: number of distinct 'Category' values.
    """

    def __init__(self, file_paths):
        """Load, align, standardize and label-encode the given CSV files.

        Args:
            file_paths: iterable of CSV paths. Unreadable files and files
                without a 'Category' column are reported and skipped.
        """
        self.data = []
        self.labels = []
        self.max_features = 0
        self.scaler = StandardScaler()

        # Parse every file exactly once and remember the ordered union of
        # feature names (a dict keeps first-seen order and de-duplicates).
        loaded = []
        feature_names = {}
        for file_path in file_paths:
            parsed = self._read_file(file_path)
            if parsed is None:
                continue
            features_df, labels = parsed
            loaded.append((file_path, features_df, labels))
            for column in features_df.columns:
                feature_names.setdefault(column, None)

        feature_names = list(feature_names)
        self.max_features = len(feature_names)
        print(f"Maximum feature dimension across all files: {self.max_features}")

        # Align by name so that column k is the same feature in every file;
        # appending zero columns positionally would mix unrelated features
        # whenever two files expose different column sets.
        aligned = []
        for file_path, features_df, labels in loaded:
            if features_df.shape[1] < self.max_features:
                print(f"Padded features in {file_path} from {features_df.shape[1]} to {self.max_features}")
            matrix = features_df.reindex(columns=feature_names, fill_value=0).values
            aligned.append((matrix, labels))

        # Fit the scaler on all rows at once, then transform file by file.
        if aligned:
            combined = np.vstack([matrix for matrix, _ in aligned])
            if combined.size:
                self.scaler.fit(combined)
                for matrix, labels in aligned:
                    if matrix.shape[0] == 0:
                        continue
                    # Extending with a 2-D array appends one row per sample.
                    self.data.extend(self.scaler.transform(matrix))
                    self.labels.extend(labels)

        # Convert category names to numerical labels
        self.label_encoder = LabelEncoder()
        self.encoded_labels = self.label_encoder.fit_transform(self.labels)
        self.num_classes = len(self.label_encoder.classes_)

        print(f"Loaded {len(self.data)} samples across {self.num_classes} malware families")
        print(f"Each sample has {self.max_features} features")
        print(f"Malware families: {self.label_encoder.classes_}")

    @staticmethod
    def _read_file(file_path):
        """Read one CSV; return (numeric feature frame, labels) or None if unusable."""
        try:
            df = pd.read_csv(file_path)
        except (OSError, ValueError) as e:
            # pandas parser/empty-data/decoding errors are ValueError subclasses.
            print(f"Error loading {file_path}: {e}")
            return None

        if 'Category' not in df.columns:
            print(f"'Category' column not found in {file_path}")
            return None

        # Keep only numerical columns; drop 'Category' in case it is numeric.
        features_df = df.select_dtypes(include=[np.number]).drop(['Category'], axis=1, errors='ignore')
        return features_df, df['Category'].values

    def get_data(self):
        """Return ``(X, y)``: the feature rows and encoded labels as arrays."""
        return np.array(self.data), np.array(self.encoded_labels)

def create_episode(X, y, n_classes=5, n_samples=5):
    """Sample one few-shot episode (support set plus one query per class).

    Only labels that have at least ``n_samples + 1`` rows can take part,
    because each chosen class contributes ``n_samples`` support rows and one
    distinct query row. If fewer than ``n_classes`` labels qualify, the
    episode uses every qualifying label instead of failing.

    Args:
        X: 2-D array-like of feature rows.
        y: 1-D array-like of labels, parallel to ``X``.
        n_classes: maximum number of classes in the episode.
        n_samples: support rows drawn per class.

    Returns:
        ``(support_set, query_set, support_labels, classes)`` where
        ``query_set[i]`` is a row of class ``classes[i]`` and
        ``support_labels`` labels the rows of ``support_set``.

    Raises:
        ValueError: if no label has at least ``n_samples + 1`` rows.
    """
    X = np.asarray(X)
    y = np.asarray(y)
    rows_needed = n_samples + 1

    # Restrict the draw to labels that can actually supply a full episode.
    labels, counts = np.unique(y, return_counts=True)
    eligible = [label for label, count in zip(labels, counts) if count >= rows_needed]
    if not eligible:
        raise ValueError(
            f"No class has at least {rows_needed} samples; cannot build an episode"
        )

    classes = random.sample(eligible, min(n_classes, len(eligible)))
    support_set = []
    query_set = []
    support_labels = []

    for cls in classes:
        cls_indices = np.flatnonzero(y == cls).tolist()
        picked = random.sample(cls_indices, rows_needed)
        # The last drawn row is held out as the query for this class.
        support_set.extend(X[picked[:-1]])
        query_set.append(X[picked[-1]])
        support_labels.extend([cls] * n_samples)

    return np.array(support_set), np.array(query_set), np.array(support_labels), classes

# Federated Learning Server
class FederatedServer:
    """Coordinates per-client few-shot training and global evaluation.

    The dataset is split once into a training part, which is divided into
    contiguous shards (one per client), and a held-out validation part that
    is used only for scoring the global model.
    """

    def __init__(self, dataset):
        """Store the dataset, create an untrained global model, shard the data.

        Args:
            dataset: object exposing ``get_data()`` -> ``(X, y)`` and a fitted
                ``label_encoder`` (its ``classes_`` name the label ids).
        """
        self.dataset = dataset
        self.global_model = self._create_prototypical_network()
        self.global_accuracy_history = []

        # Prepare distributed datasets for clients
        self.client_datasets = self._prepare_client_datasets()

    def _create_prototypical_network(self):
        """Return a fresh, unfitted single-hidden-layer MLP classifier."""
        return MLPClassifier(hidden_layer_sizes=(64,), max_iter=1000, random_state=42)

    def _prepare_client_datasets(self):
        """Split into train/validation and shard the training part per client.

        Returns:
            dict with 'train_datasets' (list of ``(X, y)`` per client) and
            'validation_dataset' (a single ``(X, y)`` tuple).
        """
        # Split into train and validation
        X, y = self.dataset.get_data()
        train_size = int(CONFIG['train_split_ratio'] * len(X))
        val_size = len(X) - train_size
        X_train, X_val, y_train, y_val = train_test_split(
            X, y, train_size=train_size, test_size=val_size, random_state=42
        )

        # Distribute training data among clients
        num_clients = CONFIG['num_clients']
        total_train_samples = len(X_train)
        samples_per_client = total_train_samples // num_clients

        client_datasets = []
        for i in range(num_clients):
            start_idx = i * samples_per_client
            # The last client also takes the remainder of the division.
            end_idx = start_idx + samples_per_client if i < num_clients - 1 else total_train_samples
            client_datasets.append((X_train[start_idx:end_idx], y_train[start_idx:end_idx]))

        return {
            'train_datasets': client_datasets,
            'validation_dataset': (X_val, y_val)
        }

    def evaluate_global_model(self):
        """Score the current global model on the held-out validation set.

        The model is NOT refitted here: fitting on the validation data and
        then predicting the same rows would report training accuracy and
        throw away whatever the clients learned. The global model must
        therefore already be fitted when this is called.

        Returns:
            Validation accuracy in percent; also appended to
            ``global_accuracy_history``.
        """
        X_val, y_val = self.client_datasets['validation_dataset']
        y_pred = self.global_model.predict(X_val)

        accuracy = np.mean(y_pred == y_val) * 100
        self.global_accuracy_history.append(accuracy)

        return accuracy

    def _report_confusion_matrix(self, y_true, y_pred, title, image_path, report_heading):
        """Plot and save a confusion-matrix heatmap, print the class report, return the matrix."""
        class_names = self.dataset.label_encoder.classes_
        # Pin the label set so matrix rows/columns always line up with
        # class_names, even when a class is absent from y_true and y_pred.
        label_ids = np.arange(len(class_names))
        cm = confusion_matrix(y_true, y_pred, labels=label_ids)

        plt.figure(figsize=(12, 10))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                    xticklabels=class_names,
                    yticklabels=class_names)
        plt.title(title)
        plt.xlabel('Predicted Label')
        plt.ylabel('True Label')
        plt.tight_layout()
        plt.savefig(image_path)
        plt.close()

        print(report_heading)
        print(classification_report(y_true, y_pred, labels=label_ids, target_names=class_names))

        return cm

    def generate_confusion_matrix(self, title="Global Model Confusion Matrix"):
        """Evaluate the global model on the validation set and report it.

        Saves the heatmap as ``<title with underscores>.png`` in the current
        working directory and prints a classification report.

        Returns:
            The confusion matrix as a 2-D array.
        """
        X_val, y_val = self.client_datasets['validation_dataset']
        y_pred = self.global_model.predict(X_val)

        return self._report_confusion_matrix(
            y_val,
            y_pred,
            title,
            f'{title.replace(" ", "_")}.png',
            f"\nClassification Report for {title}:",
        )

    def run_federated_learning(self):
        """Run the configured communication rounds and return the last accuracy.

        In every round each client trains a fresh model on the support set of
        one sampled episode and is scored on that episode's query set; the
        global model is then replaced and evaluated on the validation set.

        Returns:
            Global validation accuracy (percent) of the final round.

        Raises:
            ValueError: if CONFIG['communication_rounds'] is less than 1 or
                there are no client datasets.
        """
        num_rounds = CONFIG['communication_rounds']
        if num_rounds < 1:
            raise ValueError("CONFIG['communication_rounds'] must be at least 1")
        if not self.client_datasets['train_datasets']:
            raise ValueError("No client datasets available for training")

        print("\nClient Dataset Sizes:")
        for client_id, client_dataset in enumerate(self.client_datasets['train_datasets']):
            print(f"Client {client_id}: {len(client_dataset[0])} samples")

        global_accuracy = None

        for round_idx in range(num_rounds):
            print(f"\nCommunication Round {round_idx + 1}")

            client_models = []
            client_accuracies = []

            for X_train, y_train in self.client_datasets['train_datasets']:
                client_model = self._create_prototypical_network()

                # Create episodes for few-shot learning
                support_set, query_set, support_labels, classes = create_episode(X_train, y_train)

                # Train client model on the support set
                client_model.fit(support_set, support_labels)

                # Evaluate client model on the query set; query_set[i] was
                # drawn from classes[i], so `classes` is the ground truth.
                y_pred = client_model.predict(query_set)
                client_accuracy = np.mean(y_pred == np.asarray(classes)) * 100
                client_accuracies.append(client_accuracy)

                # Send model to server
                client_models.append(client_model)

            # NOTE(review): no real aggregation happens here - the first
            # client's model simply becomes the global model. Client models
            # are trained on different class subsets, so their output layers
            # are not directly averageable; proper aggregation needs a shared
            # label space across clients.
            self.global_model = client_models[0]
            global_accuracy = self.evaluate_global_model()

            print(f"Client Accuracies: {[f'{acc:.2f}%' for acc in client_accuracies]}")
            print(f"Global Validation Accuracy: {global_accuracy:.2f}%")

        # Generate final confusion matrix
        self.generate_confusion_matrix("Final Global Model Confusion Matrix")

        return global_accuracy

    def cross_validate_global_model(self):
        """Run shuffled 5-fold cross-validation of the model over all data.

        Out-of-fold predictions are obtained with ``cross_val_predict`` and
        reported as a heatmap saved to 'cross_validation_confusion_matrix.png'
        plus a printed classification report.
        """
        X, y = self.dataset.get_data()
        kf = KFold(n_splits=5, shuffle=True, random_state=42)
        y_pred = cross_val_predict(self.global_model, X, y, cv=kf)

        self._report_confusion_matrix(
            y,
            y_pred,
            "5-Fold Cross-Validation Confusion Matrix",
            'cross_validation_confusion_matrix.png',
            "\nClassification Report for 5-Fold Cross-Validation:",
        )

# Main execution block
def main():
    """Load the after-reboot CSVs, run federated training, save and report results.

    Writes 'final_results.txt' into the 'federated_learning_results'
    directory (created if missing). Errors raised during training or
    cross-validation are printed with a traceback instead of propagating.
    """
    print("Android Malware Detection using Federated Learning")
    print("-------------------------------------------------")

    results_dir = 'federated_learning_results'
    os.makedirs(results_dir, exist_ok=True)

    # Set random seeds for reproducibility. Episode sampling draws from the
    # stdlib `random` module, so it has to be seeded as well as NumPy.
    np.random.seed(42)
    random.seed(42)

    print("\nLoading Android malware datasets...")
    # Load all CSV files with 'after_reboot' in the filename. glob returns
    # matches in arbitrary order; sorting keeps the sample order (and hence
    # the seeded train/validation split) stable across runs and platforms.
    files = sorted(glob.glob(os.path.join(CONFIG['data_path'], '*after_reboot*.csv')))
    print(f"Found {len(files)} files matching the pattern")

    # Create dataset
    dataset = AndroidMalwareDataset(files)

    if dataset.num_classes == 0:
        print("Error: Failed to load dataset correctly. Check the file paths and data format.")
        return

    print("\nDataset Statistics:")
    print(f"Number of samples: {len(dataset.data)}")
    print(f"Number of features: {dataset.max_features}")
    print(f"Number of malware families: {dataset.num_classes}")

    print("\nInitializing Federated Learning Server...")
    server = FederatedServer(dataset)

    try:
        print("\nStarting Federated Learning Training...")
        final_accuracy = server.run_federated_learning()

        # Save results
        results_file = os.path.join(results_dir, 'final_results.txt')
        # str() each class name so non-string labels do not break the join.
        family_names = ', '.join(str(name) for name in dataset.label_encoder.classes_)
        with open(results_file, 'w', encoding='utf-8') as f:
            f.write(f"Final Validation Accuracy: {final_accuracy:.2f}%\n")
            f.write(f"Number of malware families: {dataset.num_classes}\n")
            f.write(f"Malware families: {family_names}\n")

        print(f"\nResults saved to {results_file}")
        print("Training completed successfully!")

        # Perform 5-fold cross-validation
        print("\nPerforming 5-Fold Cross-Validation...")
        server.cross_validate_global_model()

    except Exception as e:
        # Top-level boundary: report the failure with full context.
        print(f"An error occurred during training: {e}")
        import traceback
        traceback.print_exc()

# Run the pipeline only when this file is executed directly, not when imported.
if __name__ == "__main__":
    main()

Android Malware Detection using Federated Learning
-------------------------------------------------

Loading Android malware datasets...
Found 12 files matching the pattern
Maximum feature dimension across all files: 121
Padded features in C:\Users\USER\Documents\NTUST\Conference_Workshop_Seminar\Android\Dataset\AndMal2020-dynamic-BeforeAndAfterReboot\Cleaned_Files\normalized_dataset\cleaned_Backdoor_after_reboot_Cat.csv from 110 to 121
Padded features in C:\Users\USER\Documents\NTUST\Conference_Workshop_Seminar\Android\Dataset\AndMal2020-dynamic-BeforeAndAfterReboot\Cleaned_Files\normalized_dataset\cleaned_FileInfector_after_reboot_Cat.csv from 97 to 121
Padded features in C:\Users\USER\Documents\NTUST\Conference_Workshop_Seminar\Android\Dataset\AndMal2020-dynamic-BeforeAndAfterReboot\Cleaned_Files\normalized_dataset\cleaned_PUA_after_reboot_Cat.csv from 104 to 121
Padded features in C:\Users\USER\Documents\NTUST\Conference_Workshop_Seminar\Android\Dataset\AndMal2020-dynamic-BeforeAn