In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import pefile
import math
from collections import Counter
import pywt
from scipy import stats
from PIL import Image
import matplotlib.pyplot as plt
from datetime import datetime

# Seed the PyTorch and NumPy global RNGs so runs are repeatable
torch.manual_seed(42)
np.random.seed(42)

class EnhancedFeatureExtractor:
    """Turn a file on disk into a fixed-length float32 feature vector.

    Vector layout, in order:
      * whole-file Shannon entropy (1 value)
      * first-section features: size, entropy, characteristics (3 values)
      * export-table features (4 values)
      * wavelet statistics of the raw bytes viewed as a 2-D image:
        4 values for the approximation band plus 3 * 4 values per
        decomposition level.
    """

    # Layout constants shared by the extraction methods.
    NUM_ENTROPY_FEATURES = 1
    NUM_SECTION_FEATURES = 3
    NUM_EXPORT_FEATURES = 4
    STATS_PER_BAND = 4      # mean, std, skewness, kurtosis
    DETAIL_BANDS_PER_LEVEL = 3
    IMAGE_WIDTH = 384       # row width used when reshaping bytes to 2-D

    def __init__(self, wavelet='haar', level=3):
        """Store the wavelet configuration and pre-compute the vector length.

        Args:
            wavelet: Wavelet name passed to ``pywt.wavedec2``.
            level: Number of decomposition levels passed to ``pywt.wavedec2``.
        """
        self.wavelet = wavelet
        self.level = level
        # Pre-computed so every sample can be padded/truncated to this size.
        self.expected_feature_dim = self._calculate_expected_dim()

    def _wavelet_feature_count(self):
        """Return the number of wavelet statistics for the configured level."""
        bands = 1 + self.DETAIL_BANDS_PER_LEVEL * self.level
        return self.STATS_PER_BAND * bands

    def _calculate_expected_dim(self):
        """Return the total length of the feature vector."""
        pe_features = (
            self.NUM_ENTROPY_FEATURES
            + self.NUM_SECTION_FEATURES
            + self.NUM_EXPORT_FEATURES
        )
        return pe_features + self._wavelet_feature_count()

    def extract_export_features(self, pe):
        """Return four export-table features for a parsed PE object.

        The features are: number of exported symbols, length of the export
        directory name, number of exports that have a name, and the size of
        the export data directory. Any failure yields four zeros.

        Args:
            pe: Object exposing the ``pefile.PE`` attributes used below.

        Returns:
            A list of exactly four values.
        """
        try:
            export_dir = getattr(pe, 'DIRECTORY_ENTRY_EXPORT', None)
            if export_dir is None:
                return [0] * self.NUM_EXPORT_FEATURES

            exports = export_dir.symbols or []
            # The name may be missing or None; treat both as length 0 rather
            # than letting len(None) discard every export feature.
            dll_name = getattr(export_dir, 'name', None)
            directory_index = pefile.DIRECTORY_ENTRY['IMAGE_DIRECTORY_ENTRY_EXPORT']
            return [
                len(exports),
                len(dll_name) if dll_name else 0,
                sum(1 for e in exports if e.name),
                pe.OPTIONAL_HEADER.DATA_DIRECTORY[directory_index].Size,
            ]
        except Exception:
            # Best effort: malformed export tables fall back to zeros, but
            # KeyboardInterrupt/SystemExit are no longer swallowed.
            return [0] * self.NUM_EXPORT_FEATURES

    def calculate_entropy(self, data):
        """Return the Shannon entropy (base 2) of the items in ``data``.

        Args:
            data: A sized iterable (e.g. ``bytes``).

        Returns:
            0 for empty input, otherwise the entropy as a float.
        """
        if not data:
            return 0

        # Frequency of each distinct value.
        occurrences = Counter(data)
        total_bytes = len(data)
        entropy = 0

        for count in occurrences.values():
            probability = count / total_bytes
            entropy -= probability * math.log2(probability)

        return entropy

    def extract_section_features(self, pe):
        """Return three features describing the first section of ``pe``.

        The features are the section's data length, the entropy of that data
        and the section's ``Characteristics`` value. Only the first section
        is used so the count stays fixed. Any failure, or a PE without
        sections, yields three zeros.
        """
        try:
            sections = getattr(pe, 'sections', None)
            if not sections:
                return [0] * self.NUM_SECTION_FEATURES

            section = sections[0]
            section_data = section.get_data()
            return [
                len(section_data),
                self.calculate_entropy(section_data),
                section.Characteristics,
            ]
        except Exception:
            return [0] * self.NUM_SECTION_FEATURES

    @staticmethod
    def _band_statistics(band):
        """Return [mean, std, skewness, kurtosis] of one coefficient band.

        Skewness and kurtosis are undefined for an empty or constant band
        (SciPy reports NaN or version-dependent values there), so those
        cases are reported as 0.0. Any remaining non-finite value is also
        replaced with 0.0 so NaN never reaches the model.
        """
        flat = np.asarray(band).ravel()
        if flat.size == 0:
            return [0.0, 0.0, 0.0, 0.0]

        if flat.max() == flat.min():
            skewness, kurt = 0.0, 0.0
        else:
            skewness, kurt = stats.skew(flat), stats.kurtosis(flat)

        raw = np.array([np.mean(flat), np.std(flat), skewness, kurt],
                       dtype=np.float64)
        cleaned = np.nan_to_num(raw, nan=0.0, posinf=0.0, neginf=0.0)
        return [float(value) for value in cleaned]

    def extract_wavelet_features(self, image_array):
        """Return fixed-length wavelet statistics for a 2-D array.

        Runs ``pywt.wavedec2`` and collects four statistics for the
        approximation band and for every detail band. The result is padded
        with zeros or truncated to the expected length. On any error a
        message is printed and an all-zero list is returned.
        """
        expected = self._wavelet_feature_count()
        try:
            coeffs = pywt.wavedec2(image_array, self.wavelet, level=self.level)

            # Approximation band first, then every detail band.
            features = self._band_statistics(coeffs[0])
            for detail_coeffs in coeffs[1:]:
                for detail in detail_coeffs:
                    features.extend(self._band_statistics(detail))
        except Exception as e:
            print(f"Error in wavelet feature extraction: {str(e)}")
            return [0] * expected

        # Enforce the fixed length.
        if len(features) < expected:
            features.extend([0] * (expected - len(features)))
        return features[:expected]

    def _extract_pe_features(self, data):
        """Return section + export features parsed from raw file bytes.

        Parses from the bytes already in memory instead of re-reading the
        file, and always releases the parser. Files that cannot be parsed as
        PE yield zeros.
        """
        pe = None
        try:
            pe = pefile.PE(data=data)
            return (self.extract_section_features(pe)
                    + self.extract_export_features(pe))
        except Exception:
            return [0] * (self.NUM_SECTION_FEATURES + self.NUM_EXPORT_FEATURES)
        finally:
            if pe is not None:
                pe.close()

    def _bytes_to_image(self, data):
        """Reshape raw bytes into a (height, IMAGE_WIDTH) uint8 array.

        The last row is zero-padded when the length is not a multiple of the
        row width.
        """
        image_array = np.frombuffer(data, dtype=np.uint8)
        width = self.IMAGE_WIDTH
        height = len(image_array) // width + (1 if len(image_array) % width else 0)
        padded_size = height * width

        if len(image_array) < padded_size:
            image_array = np.pad(image_array, (0, padded_size - len(image_array)))

        return image_array.reshape((height, width))

    def extract_features(self, file_path):
        """Return the full feature vector for the file at ``file_path``.

        Args:
            file_path: Path of the file to read.

        Returns:
            A float32 NumPy array of length ``expected_feature_dim``. If the
            file cannot be read or processed, a message is printed and an
            all-zero vector is returned.
        """
        try:
            with open(file_path, 'rb') as f:
                data = f.read()

            # 1. Whole-file entropy.
            features = [self.calculate_entropy(data)]

            # 2. PE features (zeros for non-PE input).
            features.extend(self._extract_pe_features(data))

            # 3. Wavelet features of the bytes viewed as an image.
            features.extend(
                self.extract_wavelet_features(self._bytes_to_image(data))
            )

            # Enforce the fixed length.
            if len(features) != self.expected_feature_dim:
                print(f"Warning: Feature dimension mismatch for {file_path}")
                if len(features) < self.expected_feature_dim:
                    features.extend([0] * (self.expected_feature_dim - len(features)))
                else:
                    features = features[:self.expected_feature_dim]

            return np.array(features, dtype=np.float32)

        except Exception as e:
            print(f"Error extracting features from {file_path}: {str(e)}")
            return np.zeros(self.expected_feature_dim, dtype=np.float32)


# Dataset class definition
class EnhancedMalwareDataset(Dataset):
    """In-memory dataset of ``(features, label)`` pairs.

    Every regular file directly inside ``benign_dir`` is labelled 0 and every
    regular file directly inside ``malware_dir`` is labelled 1. Features are
    extracted eagerly in ``__init__`` with ``EnhancedFeatureExtractor``.
    """

    def __init__(self, benign_dir, malware_dir):
        """Extract features for all files in both directories.

        Args:
            benign_dir: Directory holding benign samples (label 0).
            malware_dir: Directory holding malware samples (label 1).
        """
        self.data = []
        self.feature_extractor = EnhancedFeatureExtractor()

        print("Loading benign samples...")
        self.data.extend(self._load_directory(benign_dir, 0))

        print("Loading malware samples...")
        self.data.extend(self._load_directory(malware_dir, 1))

        print(f"Total samples loaded: {len(self.data)}")

    def _load_directory(self, directory, label):
        """Return ``(features, label)`` for each regular file in ``directory``.

        Entries are visited in sorted name order: ``os.listdir`` returns an
        arbitrary order, which would make seeded train/test splits differ
        between machines. Sub-directories are skipped.
        """
        samples = []
        for filename in sorted(os.listdir(directory)):
            file_path = os.path.join(directory, filename)
            if os.path.isfile(file_path):
                features = self.feature_extractor.extract_features(file_path)
                samples.append((features, label))
        return samples

    def __len__(self):
        """Return the number of loaded samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(float tensor of features, int label)`` for sample ``idx``."""
        features, label = self.data[idx]
        return torch.FloatTensor(features), label

# Model class definition
class EnhancedMalwareDetector(nn.Module):
    """Two-branch classifier over a concatenated PE + wavelet feature vector.

    The first ``pe_feature_size`` columns of the input go through the PE
    branch and the remaining columns through the wavelet branch. Each branch
    output is scaled by a learned sigmoid gate, the two are concatenated,
    fused, and classified into two logits.
    """

    PE_HIDDEN = 32        # output width of the PE branch
    WAVELET_HIDDEN = 64   # output width of the wavelet branch

    def __init__(self, input_size, pe_feature_size=8):
        """Build the network.

        Args:
            input_size: Total number of input features per sample.
            pe_feature_size: Number of leading columns that belong to the PE
                branch. Defaults to 8, the value previously hard-coded.

        Raises:
            ValueError: If ``pe_feature_size`` is not positive or leaves no
                columns for the wavelet branch.
        """
        super().__init__()

        if pe_feature_size <= 0 or input_size <= pe_feature_size:
            raise ValueError(
                f"input_size ({input_size}) must exceed pe_feature_size "
                f"({pe_feature_size}), which must be positive"
            )
        self.pe_feature_size = pe_feature_size

        # NOTE: sub-modules are created in the same order and under the same
        # attribute names as before, so state_dict keys and seeded
        # initialisation are unchanged.
        self.pe_features = nn.Sequential(
            nn.Linear(pe_feature_size, self.PE_HIDDEN),
            nn.ReLU(),
            nn.BatchNorm1d(self.PE_HIDDEN),
            nn.Dropout(0.3)
        )

        wavelet_feature_size = input_size - pe_feature_size
        self.wavelet_features = nn.Sequential(
            nn.Linear(wavelet_feature_size, 128),
            nn.ReLU(),
            nn.BatchNorm1d(128),
            nn.Dropout(0.3),
            nn.Linear(128, self.WAVELET_HIDDEN),
            nn.ReLU(),
            nn.BatchNorm1d(self.WAVELET_HIDDEN)
        )

        # Fusion input is the concatenation of both branch outputs.
        self.fusion = nn.Sequential(
            nn.Linear(self.PE_HIDDEN + self.WAVELET_HIDDEN, 64),
            nn.ReLU(),
            nn.BatchNorm1d(64),
            nn.Dropout(0.3),
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.BatchNorm1d(32)
        )

        self.classifier = nn.Sequential(
            nn.Linear(32, 16),
            nn.ReLU(),
            nn.BatchNorm1d(16),
            nn.Dropout(0.2),
            nn.Linear(16, 2)
        )

        # One scalar gate in (0, 1) per sample for each branch.
        self.attention_pe = nn.Sequential(
            nn.Linear(self.PE_HIDDEN, 1),
            nn.Sigmoid()
        )
        self.attention_wavelet = nn.Sequential(
            nn.Linear(self.WAVELET_HIDDEN, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Return class logits of shape ``(batch, 2)``.

        Args:
            x: Tensor indexed as ``(batch, input_size)``.
        """
        split = self.pe_feature_size
        pe_x = x[:, :split]
        wavelet_x = x[:, split:]

        pe_features = self.pe_features(pe_x)
        pe_features = pe_features * self.attention_pe(pe_features)

        wavelet_features = self.wavelet_features(wavelet_x)
        wavelet_features = wavelet_features * self.attention_wavelet(wavelet_features)

        combined_features = torch.cat((pe_features, wavelet_features), dim=1)
        fused_features = self.fusion(combined_features)

        return self.classifier(fused_features)

# Main training function
def _select_device():
    """Return the best available device: CUDA, then Apple MPS, else CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    # torch.backends.mps does not exist on older PyTorch builds.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return torch.device("mps")
    return torch.device("cpu")


def _evaluate(model, loader, criterion, device):
    """Return ``(mean loss per sample, accuracy in percent)`` over ``loader``.

    The loss is weighted by batch size so a smaller final batch does not
    skew the average. The loader must yield at least one sample.
    """
    model.eval()
    loss_sum = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for features, labels in loader:
            features, labels = features.to(device), labels.to(device)
            outputs = model(features)
            batch_count = labels.size(0)
            loss_sum += criterion(outputs, labels).item() * batch_count
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            total += batch_count

    return loss_sum / total, 100 * correct / total


def train_model(benign_dir, malware_dir, model_dir, epochs=20, batch_size=32):
    """Train the malware detector and save checkpoints and a training plot.

    Builds the dataset from the two directories, splits it 80/20, trains for
    ``epochs`` epochs, saves a checkpoint into ``model_dir`` whenever the
    validation accuracy improves, and writes ``training_plot.png`` there.

    Args:
        benign_dir: Directory of benign samples.
        malware_dir: Directory of malware samples.
        model_dir: Output directory (created if missing).
        epochs: Number of training epochs.
        batch_size: Mini-batch size for both loaders.

    Returns:
        ``(input_size, best_accuracy)`` - the feature dimension and the best
        validation accuracy in percent.

    Raises:
        ValueError: If the dataset is empty or too small to provide one full
            training batch and at least one validation sample.
    """
    # Create the output directory.
    os.makedirs(model_dir, exist_ok=True)

    device = _select_device()
    print(f"Using device: {device}")

    print("Creating dataset...")
    dataset = EnhancedMalwareDataset(benign_dir, malware_dir)
    if len(dataset) == 0:
        raise ValueError(
            f"No samples found in {benign_dir!r} or {malware_dir!r}"
        )

    # Feature dimension is taken from the first sample.
    input_size = dataset[0][0].shape[0]
    print(f"Feature dimension: {input_size}")

    # 80/20 train/validation split.
    train_size = int(0.8 * len(dataset))
    test_size = len(dataset) - train_size
    if train_size < batch_size or test_size == 0:
        raise ValueError(
            f"Dataset too small: {len(dataset)} samples cannot provide one "
            f"full training batch of {batch_size} plus validation data"
        )
    train_dataset, test_dataset = torch.utils.data.random_split(
        dataset, [train_size, test_size]
    )

    # The training loader drops the last partial batch (BatchNorm cannot
    # train on a batch of one); the validation loader keeps every sample so
    # the reported metrics cover the whole validation set.
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, drop_last=False)

    model = EnhancedMalwareDetector(input_size).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.01)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', patience=3, factor=0.5
    )

    # Per-epoch history for plotting.
    train_losses = []
    val_losses = []
    accuracies = []
    best_accuracy = 0

    print("Starting training...")
    for epoch in range(epochs):
        model.train()
        running_loss = 0.0

        for i, (features, labels) in enumerate(train_loader):
            features, labels = features.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(features)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

            if i % 10 == 0:
                print(f'Epoch {epoch+1}, Batch {i}, Loss: {loss.item():.4f}')

        epoch_loss = running_loss / len(train_loader)
        train_losses.append(epoch_loss)

        # Validation.
        val_loss, accuracy = _evaluate(model, test_loader, criterion, device)
        val_losses.append(val_loss)
        accuracies.append(accuracy)

        print(f'Epoch {epoch + 1}, Train Loss: {epoch_loss:.4f}, '
              f'Val Loss: {val_loss:.4f}, Accuracy: {accuracy:.2f}%')

        # Reduce the learning rate when validation loss plateaus.
        scheduler.step(val_loss)

        # Save a checkpoint whenever validation accuracy improves.
        if accuracy > best_accuracy:
            best_accuracy = accuracy
            model_path = os.path.join(model_dir, f'best_model_{epoch+1}.pth')
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'accuracy': accuracy,
                'input_size': input_size
            }, model_path)
            print(f"New best model saved with accuracy: {accuracy:.2f}%")

    # Plot the loss and accuracy curves.
    plt.figure(figsize=(12, 4))

    plt.subplot(1, 2, 1)
    plt.plot(train_losses, label='Train Loss')
    plt.plot(val_losses, label='Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()

    plt.subplot(1, 2, 2)
    plt.plot(accuracies, label='Accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy (%)')
    plt.legend()

    plt.tight_layout()
    plt.savefig(os.path.join(model_dir, 'training_plot.png'))
    plt.close()

    return input_size, best_accuracy

# Usage example
if __name__ == "__main__":
    # Input directories and a timestamped output directory for this run
    benign_dir = "white_files"
    malware_dir = "black_files"
    model_dir = f"models/malware_detector_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
    
    # Train the model
    input_size, best_accuracy = train_model(
        benign_dir=benign_dir,
        malware_dir=malware_dir,
        model_dir=model_dir,
        epochs=20,
        batch_size=32
    )
    
    print(f"Training completed. Best accuracy: {best_accuracy:.2f}%")
    print(f"Model saved in directory: {model_dir}")



Using device: mps
Creating dataset...
Loading benign samples...




Loading malware samples...
Total samples loaded: 2076
Feature dimension: 48
Starting training...
Epoch 1, Batch 0, Loss: 0.7036
Epoch 1, Batch 10, Loss: 0.6122
Epoch 1, Batch 20, Loss: 0.4915
Epoch 1, Batch 30, Loss: 0.4913
Epoch 1, Batch 40, Loss: 0.4673
Epoch 1, Batch 50, Loss: 0.5219
Epoch 1, Train Loss: 0.4922, Val Loss: 0.4165, Accuracy: 81.01%
New best model saved with accuracy: 81.01%
Epoch 2, Batch 0, Loss: 0.3420
Epoch 2, Batch 10, Loss: 0.3357
Epoch 2, Batch 20, Loss: 0.2833
Epoch 2, Batch 30, Loss: 0.3188
Epoch 2, Batch 40, Loss: 0.3993
Epoch 2, Batch 50, Loss: 0.4736
Epoch 2, Train Loss: 0.3728, Val Loss: 0.3678, Accuracy: 82.21%
New best model saved with accuracy: 82.21%
Epoch 3, Batch 0, Loss: 0.2380
Epoch 3, Batch 10, Loss: 0.4743
Epoch 3, Batch 20, Loss: 0.4253
Epoch 3, Batch 30, Loss: 0.4791
Epoch 3, Batch 40, Loss: 0.3208
Epoch 3, Batch 50, Loss: 0.3171
Epoch 3, Train Loss: 0.3491, Val Loss: 0.3558, Accuracy: 84.62%
New best model saved with accuracy: 84.62%
Epoch 4, 