In [1]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import pandas as pd
import glob
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, classification_report
from torch.utils.data import Dataset, DataLoader, random_split
from sklearn.preprocessing import LabelEncoder, StandardScaler

# Configuration
# Run-wide hyperparameters and paths, read by the model, server and main() below.
CONFIG = dict(
    # Directory that holds the per-family CSV files.
    data_path=r'C:\Users\USER\Documents\NTUST\Conference_Workshop_Seminar\Android\Dataset\AndMal2020-dynamic-BeforeAndAfterReboot\Cleaned_Files\normalized_dataset',
    # Number of simulated federated clients.
    num_clients=3,
    # Local training epochs per client per communication round.
    epochs=5,
    learning_rate=0.001,
    batch_size=32,
    # Number of train-then-aggregate cycles.
    communication_rounds=5,
    # 80% train, 20% validation
    train_split_ratio=0.8,
    # Number of models in the ensemble
    num_ensemble_models=3,
)

# CNN Model for Android Malware Classification
class MalwareCNN(nn.Module):
    """1-D convolutional classifier over a single-channel feature vector.

    Two Conv1d -> BatchNorm1d -> ReLU -> MaxPool1d stages are followed by a
    two-layer fully connected head that emits one logit per class.
    """

    def __init__(self, input_features, num_classes):
        """Create the layers.

        Args:
            input_features: length of the feature vector fed to the network.
            num_classes: number of output logits.
        """
        super().__init__()

        # Feature extraction layers. Kernel 5 with padding 2 and stride 1
        # leaves the sequence length unchanged, so only pooling shrinks it.
        self.conv1 = nn.Conv1d(in_channels=1, out_channels=32, kernel_size=5, stride=1, padding=2)
        self.bn1 = nn.BatchNorm1d(num_features=32)
        self.conv2 = nn.Conv1d(in_channels=32, out_channels=64, kernel_size=5, stride=1, padding=2)
        self.bn2 = nn.BatchNorm1d(num_features=64)
        self.pool = nn.MaxPool1d(kernel_size=2, stride=2)

        # Each of the two pooling steps halves the length (rounding down);
        # the last conv stage has 64 channels.
        pooled_length = input_features
        for _ in range(2):
            pooled_length //= 2
        self.flattened_size = 64 * pooled_length

        # Classification layers
        self.fc1 = nn.Linear(in_features=self.flattened_size, out_features=512)
        self.dropout = nn.Dropout(p=0.5)
        self.fc2 = nn.Linear(in_features=512, out_features=num_classes)

    def forward(self, x):
        """Return class logits for a batch shaped [batch_size, 1, num_features].

        Raises:
            ValueError: if ``x`` is not a 3-D tensor.
        """
        if x.dim() != 3:
            raise ValueError(f"Expected input shape [batch_size, 1, num_features], got {x.shape}")

        # Convolution stages: conv -> batch norm -> ReLU -> max pool.
        hidden = x
        for conv, norm in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            hidden = self.pool(F.relu(norm(conv(hidden))))

        # Flatten everything except the batch dimension.
        batch_size = hidden.size(0)
        flat = hidden.view(batch_size, -1)

        # Fully connected head.
        activated = F.relu(self.fc1(flat))
        regularised = self.dropout(activated)
        return self.fc2(regularised)

# Ensemble Model for Android Malware Classification
class EnsembleMalwareCNN(nn.Module):
    """Ensemble of ``MalwareCNN`` members whose logits are averaged."""

    def __init__(self, input_features, num_classes, num_models):
        """Create ``num_models`` members, each built with the same two sizes."""
        super().__init__()
        self.models = nn.ModuleList()
        for _ in range(num_models):
            self.models.append(MalwareCNN(input_features, num_classes))

    def forward(self, x):
        """Run every member on ``x`` and return the element-wise mean output."""
        member_outputs = torch.stack([member(x) for member in self.models])
        return member_outputs.mean(dim=0)

# Improved Android Malware Dataset
class AndroidMalwareDataset(Dataset):
    """Malware samples loaded from CSV files that carry a 'Category' column.

    The numeric columns of each file (minus 'Category') form the feature
    vector. Files with fewer numeric columns than the widest file are
    right-padded with zeros, every row is standardised with one shared
    ``StandardScaler``, and the 'Category' values are mapped to integer class
    ids with a ``LabelEncoder``.

    NOTE(review): padding is positional, so this presumably relies on all
    files sharing the same leading column order -- confirm against the data.
    NOTE: the scaler is fit on every loaded row, so a later train/validation
    split of this dataset shares normalisation statistics across the split.

    Attributes:
        data: list of 1-D float32 tensors, one per sample.
        labels: raw 'Category' values, aligned with ``data``.
        encoded_labels: integer class ids, aligned with ``data``.
        max_features: width of every feature vector after padding.
        num_classes: number of distinct 'Category' values.
        scaler / label_encoder: the fitted scikit-learn transformers.
    """

    def __init__(self, file_paths):
        """Load, pad, standardise and label-encode the rows of ``file_paths``.

        Files that cannot be read, or that have no 'Category' column, are
        reported on stdout and skipped.
        """
        self.data = []
        self.labels = []
        self.max_features = 0
        self.scaler = StandardScaler()

        # Each CSV is parsed exactly once; the raw arrays are kept in memory
        # so padding and scaling below never have to re-read a file.
        loaded = []
        for file_path in file_paths:
            try:
                df = pd.read_csv(file_path)
                if 'Category' not in df.columns:
                    print(f"'Category' column not found in {file_path}")
                    continue
                features = self._numeric_features(df)
                labels = df['Category'].values
            except Exception as e:
                # Best effort: one unreadable file must not abort the whole load.
                print(f"Error loading {file_path}: {e}")
                continue

            loaded.append((file_path, features, labels))
            self.max_features = max(self.max_features, features.shape[1])

        print(f"Maximum feature dimension across all files: {self.max_features}")

        # Right-pad narrower files with zeros so every row has max_features columns.
        padded = []
        for file_path, features, labels in loaded:
            width = features.shape[1]
            if width < self.max_features:
                filler = np.zeros((features.shape[0], self.max_features - width))
                features = np.hstack((features, filler))
                print(f"Padded features in {file_path} from {width} to {self.max_features}")
            padded.append((features, labels))

        # StandardScaler cannot be fit on a matrix with no rows or no columns,
        # so in that case the dataset is simply left empty (num_classes == 0).
        total_rows = sum(features.shape[0] for features, _ in padded)
        if total_rows > 0 and self.max_features > 0:
            self.scaler.fit(np.vstack([features for features, _ in padded]))

            for features, labels in padded:
                if features.shape[0] == 0:
                    continue
                # Convert to float32 once per file rather than once per row.
                scaled = self.scaler.transform(features).astype(np.float32)
                self.data.extend(torch.tensor(row) for row in scaled)
                self.labels.extend(labels)

        # Convert category names to numerical labels
        self.label_encoder = LabelEncoder()
        self.encoded_labels = self.label_encoder.fit_transform(self.labels)
        self.num_classes = len(self.label_encoder.classes_)

        print(f"Loaded {len(self.data)} samples across {self.num_classes} malware families")
        print(f"Each sample has {self.max_features} features")
        print(f"Malware families: {self.label_encoder.classes_}")

    @staticmethod
    def _numeric_features(df):
        """Return the numeric columns of ``df``, minus 'Category', as a 2-D array."""
        numeric = df.select_dtypes(include=[np.number])
        return numeric.drop(['Category'], axis=1, errors='ignore').values

    def __len__(self):
        """Return the number of loaded samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(feature_tensor, encoded_label)`` for sample ``idx``."""
        return self.data[idx], self.encoded_labels[idx]

    @staticmethod
    def collate_fn(batch):
        """Collate ``(features, label)`` pairs into CNN-ready batch tensors.

        Returns:
            A ``[batch_size, 1, num_features]`` feature tensor and a 1-D
            label tensor.
        """
        features, labels = zip(*batch)
        # Stack to [batch_size, num_features], then insert the channel axis
        # that Conv1d expects.
        features_stacked = torch.stack(features, dim=0).unsqueeze(1)
        return features_stacked, torch.tensor(labels)

# Federated Learning Server
class FederatedServer:
    """Simulates federated averaging of an ``EnsembleMalwareCNN`` in one process.

    The dataset is split into a training part, divided among
    ``CONFIG['num_clients']`` simulated clients, and a validation part used to
    score the global model after every communication round.
    """

    def __init__(self, dataset):
        """Create the global model and the per-client data partitions.

        Args:
            dataset: object exposing ``max_features``, ``num_classes`` and
                ``label_encoder`` in addition to the usual ``__len__`` /
                ``__getitem__`` dataset protocol.
        """
        self.dataset = dataset
        self.global_model = EnsembleMalwareCNN(input_features=dataset.max_features,
                                               num_classes=dataset.num_classes,
                                               num_models=CONFIG['num_ensemble_models'])
        # One entry is appended per call to evaluate_global_model().
        self.global_accuracy_history = []
        self.global_loss_history = []

        # Prepare distributed datasets for clients
        self.client_datasets = self._prepare_client_datasets()

    @staticmethod
    def _make_loader(dataset, shuffle=False):
        """Build a DataLoader with the configured batch size and the dataset collate function."""
        return DataLoader(dataset, batch_size=CONFIG['batch_size'], shuffle=shuffle,
                          collate_fn=AndroidMalwareDataset.collate_fn)

    def _prepare_client_datasets(self):
        """Split the dataset into per-client training subsets and one validation set.

        Returns:
            dict with ``'train_datasets'`` (list of subsets, one per client)
            and ``'validation_dataset'``.
        """
        # Split into train and validation
        train_size = int(CONFIG['train_split_ratio'] * len(self.dataset))
        val_size = len(self.dataset) - train_size
        train_dataset, val_dataset = random_split(self.dataset, [train_size, val_size])

        # Distribute training data among clients in contiguous, equally sized
        # slices; the last client also receives the remainder.
        client_datasets = []
        total_train_samples = len(train_dataset)
        samples_per_client = total_train_samples // CONFIG['num_clients']

        for i in range(CONFIG['num_clients']):
            start_idx = i * samples_per_client
            is_last_client = i == CONFIG['num_clients'] - 1
            end_idx = total_train_samples if is_last_client else start_idx + samples_per_client
            client_datasets.append(torch.utils.data.Subset(train_dataset, range(start_idx, end_idx)))

        return {
            'train_datasets': client_datasets,
            'validation_dataset': val_dataset
        }

    def distribute_model(self):
        """Return the global model's state dict for loading into a client model."""
        return self.global_model.state_dict()

    def aggregate_weights(self, client_weights):
        """Replace the global weights with the unweighted mean of the client weights.

        Args:
            client_weights: non-empty sequence of state dicts, one per client,
                each containing every key of the global model's state dict.

        Raises:
            ValueError: if ``client_weights`` is empty.
        """
        if not client_weights:
            raise ValueError("aggregate_weights requires at least one client state dict")

        averaged = {}
        for key, current in self.global_model.state_dict().items():
            stacked = torch.stack([weights[key].float() for weights in client_weights], 0)
            # Cast back so integer buffers (e.g. BatchNorm's num_batches_tracked)
            # keep their dtype in the aggregated state dict.
            averaged[key] = stacked.mean(0).to(current.dtype)

        self.global_model.load_state_dict(averaged)

    def evaluate_global_model(self):
        """Score the global model on the validation set and record the result.

        Returns:
            ``(accuracy_percent, mean_cross_entropy)``; ``(0, inf)`` when the
            validation set is empty.
        """
        self.global_model.eval()
        val_loader = self._make_loader(self.client_datasets['validation_dataset'])

        total_correct = 0
        total_samples = 0
        total_loss = 0

        with torch.no_grad():
            for features, labels in val_loader:
                outputs = self.global_model(features)
                predicted = outputs.argmax(dim=1)

                batch_size = labels.size(0)
                total_samples += batch_size
                total_correct += (predicted == labels).sum().item()

                # cross_entropy returns the batch mean; re-weight by batch size
                # so the final average is per sample.
                total_loss += F.cross_entropy(outputs, labels).item() * batch_size

        accuracy = total_correct / total_samples * 100 if total_samples > 0 else 0
        avg_loss = total_loss / total_samples if total_samples > 0 else float('inf')

        self.global_accuracy_history.append(accuracy)
        self.global_loss_history.append(avg_loss)

        return accuracy, avg_loss

    def generate_confusion_matrix(self, title="Global Model Confusion Matrix"):
        """Plot and save the validation confusion matrix and print a classification report.

        The figure is written to ``<title with spaces replaced by underscores>.png``
        in the current working directory.

        Returns:
            The confusion matrix, with one row and column per known class.
        """
        self.global_model.eval()
        val_loader = self._make_loader(self.client_datasets['validation_dataset'])

        all_preds = []
        all_labels = []

        with torch.no_grad():
            for features, labels in val_loader:
                outputs = self.global_model(features)
                all_preds.extend(outputs.argmax(dim=1).tolist())
                all_labels.extend(labels.tolist())

        # Convert numeric labels back to family names for better readability
        class_names = [str(name) for name in self.dataset.label_encoder.classes_]
        # Pass the full label set explicitly: a class missing from the
        # validation split would otherwise shrink the matrix and make
        # classification_report reject target_names.
        label_ids = np.arange(len(class_names))
        cm = confusion_matrix(all_labels, all_preds, labels=label_ids)

        plt.figure(figsize=(12, 10))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                    xticklabels=class_names,
                    yticklabels=class_names)
        plt.title(title)
        plt.xlabel('Predicted Label')
        plt.ylabel('True Label')
        plt.tight_layout()
        plt.savefig(f'{title.replace(" ", "_")}.png')
        plt.close()

        print(f"\nClassification Report for {title}:")
        print(classification_report(all_labels, all_preds, labels=label_ids, target_names=class_names))

        return cm

    def _train_client(self, client_id, client_dataset):
        """Train a fresh copy of the global model on one client's data and return it."""
        client_model = EnsembleMalwareCNN(input_features=self.dataset.max_features,
                                          num_classes=self.dataset.num_classes,
                                          num_models=CONFIG['num_ensemble_models'])
        client_model.load_state_dict(self.distribute_model())

        client_model.train()
        optimizer = optim.Adam(client_model.parameters(), lr=CONFIG['learning_rate'])
        train_loader = self._make_loader(client_dataset, shuffle=True)

        for epoch in range(CONFIG['epochs']):
            epoch_loss = 0
            for batch_idx, (features, labels) in enumerate(train_loader):
                # Debug: print tensor shape once per client to verify dimensions
                if epoch == 0 and batch_idx == 0:
                    print(f"DEBUG - Client {client_id} features shape: {features.shape}")

                optimizer.zero_grad()
                outputs = client_model(features)
                loss = F.cross_entropy(outputs, labels)
                loss.backward()
                optimizer.step()
                epoch_loss += loss.item()

            avg_epoch_loss = epoch_loss / len(train_loader) if len(train_loader) > 0 else float('inf')
            print(f"Client {client_id}, Epoch {epoch+1}/{CONFIG['epochs']}, Loss: {avg_epoch_loss:.4f}")

        return client_model

    def run_federated_learning(self):
        """Run all communication rounds, then save the confusion matrix and learning curves.

        Returns:
            The global validation accuracy (percent) after the last round, or
            0.0 when no communication rounds are configured.
        """
        train_datasets = self.client_datasets['train_datasets']

        print("\nClient Dataset Sizes:")
        for client_id, client_dataset in enumerate(train_datasets):
            print(f"Client {client_id}: {len(client_dataset)} samples")

        round_client_accuracies = []
        # Defined up front so the return below is valid even with zero rounds.
        global_accuracy = 0.0

        for round_idx in range(CONFIG['communication_rounds']):
            print(f"\nCommunication Round {round_idx + 1}")

            client_weights = []
            client_accuracies = []

            for client_id, client_dataset in enumerate(train_datasets):
                client_model = self._train_client(client_id, client_dataset)

                # Evaluate client model (on the client's own training data)
                client_accuracies.append(self._evaluate_client_model(client_model, client_dataset))

                # Send model weights to server
                client_weights.append(client_model.state_dict())

            # Aggregate model weights and update global model
            self.aggregate_weights(client_weights)
            global_accuracy, global_loss = self.evaluate_global_model()
            round_client_accuracies.append(client_accuracies)

            print(f"Client Accuracies: {[f'{acc:.2f}%' for acc in client_accuracies]}")
            print(f"Global Validation Accuracy: {global_accuracy:.2f}%")
            print(f"Global Validation Loss: {global_loss:.4f}")

        # Generate final confusion matrix
        self.generate_confusion_matrix("Final Global Model Confusion Matrix")

        # Plot learning curves
        self._plot_learning_curves(round_client_accuracies)

        return global_accuracy

    def _evaluate_client_model(self, model, dataset):
        """Return the accuracy (percent) of ``model`` on ``dataset``; 0 if it is empty."""
        model.eval()
        dataloader = self._make_loader(dataset)

        correct = 0
        total = 0

        with torch.no_grad():
            for features, labels in dataloader:
                predicted = model(features).argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        return correct / total * 100 if total > 0 else 0

    def _plot_learning_curves(self, round_client_accuracies):
        """Save accuracy and loss curves to ``federated_learning_curves.png``.

        Args:
            round_client_accuracies: one list per round holding each client's
                accuracy for that round.
        """
        num_rounds = len(round_client_accuracies)
        rounds = range(1, num_rounds + 1)

        # Plot only the history entries that belong to these rounds, so extra
        # evaluations recorded earlier cannot cause an x/y length mismatch.
        acc_start = max(0, len(self.global_accuracy_history) - num_rounds)
        loss_start = max(0, len(self.global_loss_history) - num_rounds)
        accuracy_history = self.global_accuracy_history[acc_start:]
        loss_history = self.global_loss_history[loss_start:]

        plt.figure(figsize=(15, 10))

        # Accuracy Plot
        plt.subplot(2, 1, 1)
        plt.plot(rounds, accuracy_history, label='Global Accuracy',
                 color='blue', marker='o', linewidth=2)

        client_colors = ['red', 'green', 'purple', 'orange', 'brown']
        num_clients = len(round_client_accuracies[0]) if round_client_accuracies else 0
        for client_id in range(num_clients):
            client_round_accuracies = [round_accuracies[client_id]
                                       for round_accuracies in round_client_accuracies]
            # Cycle through the palette so clients beyond the fifth are still drawn.
            plt.plot(rounds, client_round_accuracies,
                     label=f'Client {client_id} Accuracy',
                     color=client_colors[client_id % len(client_colors)],
                     linestyle='--',
                     marker='x')

        plt.title('Accuracy vs Communication Rounds')
        plt.xlabel('Communication Rounds')
        plt.ylabel('Accuracy (%)')
        plt.legend()
        plt.grid(True)

        # Loss Plot
        plt.subplot(2, 1, 2)
        plt.plot(rounds, loss_history, label='Global Loss',
                 color='blue', marker='o', linewidth=2)
        plt.title('Loss vs Communication Rounds')
        plt.xlabel('Communication Rounds')
        plt.ylabel('Cross Entropy Loss')
        plt.legend()
        plt.grid(True)

        plt.tight_layout()
        plt.savefig('federated_learning_curves.png')
        plt.close()

# Main execution block
def main():
    """Load the after-reboot CSV files, run federated training and write a summary.

    The summary goes to ``federated_learning_results/final_results.txt``.
    Errors raised during training are printed with a traceback instead of
    propagating.
    """
    print("Android Malware Detection using Federated Learning")
    print("-------------------------------------------------")

    results_dir = 'federated_learning_results'
    os.makedirs(results_dir, exist_ok=True)

    # Set random seeds for reproducibility
    torch.manual_seed(42)
    np.random.seed(42)

    print("\nLoading Android malware datasets...")
    # Load all CSV files with 'after_reboot' in the filename. glob returns
    # matches in arbitrary order, so sort them: the sample order feeds the
    # seeded random split, and would otherwise vary between machines.
    files = sorted(glob.glob(os.path.join(CONFIG['data_path'], '*after_reboot*.csv')))
    print(f"Found {len(files)} files matching the pattern")

    if not files:
        print("Error: Failed to load dataset correctly. Check the file paths and data format.")
        return

    # Create dataset
    dataset = AndroidMalwareDataset(files)

    if dataset.num_classes == 0:
        print("Error: Failed to load dataset correctly. Check the file paths and data format.")
        return

    print("\nDataset Statistics:")
    print(f"Number of samples: {len(dataset)}")
    print(f"Number of features: {dataset.max_features}")
    print(f"Number of malware families: {dataset.num_classes}")

    print("\nInitializing Federated Learning Server...")
    server = FederatedServer(dataset)

    try:
        print("\nStarting Federated Learning Training...")
        final_accuracy = server.run_federated_learning()

        # Save results
        results_file = os.path.join(results_dir, 'final_results.txt')
        # Category values are not guaranteed to be strings, so convert before joining.
        family_names = ', '.join(str(name) for name in dataset.label_encoder.classes_)
        with open(results_file, 'w', encoding='utf-8') as f:
            f.write(f"Final Validation Accuracy: {final_accuracy:.2f}%\n")
            f.write(f"Number of malware families: {dataset.num_classes}\n")
            f.write(f"Malware families: {family_names}\n")

        print(f"\nResults saved to {results_file}")
        print("Training completed successfully!")

    except Exception as e:
        # Top-level boundary: report the failure with its traceback rather than crash.
        print(f"An error occurred during training: {e}")
        import traceback
        traceback.print_exc()

if __name__ == "__main__":
    main()

Android Malware Detection using Federated Learning
-------------------------------------------------

Loading Android malware datasets...
Found 12 files matching the pattern
Maximum feature dimension across all files: 121
Padded features in C:\Users\USER\Documents\NTUST\Conference_Workshop_Seminar\Android\Dataset\AndMal2020-dynamic-BeforeAndAfterReboot\Cleaned_Files\normalized_dataset\cleaned_Backdoor_after_reboot_Cat.csv from 110 to 121
Padded features in C:\Users\USER\Documents\NTUST\Conference_Workshop_Seminar\Android\Dataset\AndMal2020-dynamic-BeforeAndAfterReboot\Cleaned_Files\normalized_dataset\cleaned_FileInfector_after_reboot_Cat.csv from 97 to 121
Padded features in C:\Users\USER\Documents\NTUST\Conference_Workshop_Seminar\Android\Dataset\AndMal2020-dynamic-BeforeAndAfterReboot\Cleaned_Files\normalized_dataset\cleaned_PUA_after_reboot_Cat.csv from 104 to 121
Padded features in C:\Users\USER\Documents\NTUST\Conference_Workshop_Seminar\Android\Dataset\AndMal2020-dynamic-BeforeAn