In [2]:
import os
import torch
import torch.nn.functional as F
from PIL import Image
import numpy as np
import pefile
import torch.nn as nn
import math
from collections import Counter
import pywt
from scipy import stats
from secrets import randbits
class EnhancedFeatureExtractor:
    def __init__(self, wavelet='db1', level=3):
        self.wavelet = wavelet
        self.level = level
        # 预计算特征维度
        self.expected_feature_dim = self._calculate_expected_dim()
        
    def _calculate_expected_dim(self):
        """计算预期的特征维度"""
        pe_features = 8  # 文件熵(1) + 节特征(3) + 导出表特征(4)
        
        # 小波特征维度
        # 对于每个分解级别，有3个细节系数矩阵(水平、垂直、对角线)
        # 每个矩阵提供4个统计量(均值、标准差、偏度、峰度)
        wavelet_features = 4  # 近似系数的4个统计量
        wavelet_features += 3 * 4 * self.level  # 细节系数的统计量
        
        return pe_features + wavelet_features
    def extract_export_features(self, pe):
        """提取导出表特征"""
        features = []
        try:
            if hasattr(pe, 'DIRECTORY_ENTRY_EXPORT'):
                exports = pe.DIRECTORY_ENTRY_EXPORT.symbols
                num_exports = len(exports) if exports else 0
                features.extend([
                    num_exports,
                    len(pe.DIRECTORY_ENTRY_EXPORT.name) if hasattr(pe.DIRECTORY_ENTRY_EXPORT, 'name') else 0,
                    sum(1 for e in exports if e.name) if exports else 0,
                    pe.OPTIONAL_HEADER.DATA_DIRECTORY[pefile.DIRECTORY_ENTRY['IMAGE_DIRECTORY_ENTRY_EXPORT']].Size
                ])
            else:
                features.extend([0] * 4)
        except:
            features.extend([0] * 4)
    
        # 确保返回4个特征
        return features[:4] if len(features) >=4 else features + [0]*(4-len(features))
    def calculate_entropy(self, data):
        """计算数据的熵值"""
        if not data:
            return 0
        
        occurrences = Counter(data)
        total_bytes = len(data)
        entropy = 0
        
        # Shannon熵
        for count in occurrences.values():
            probability = count / total_bytes
            entropy -= probability * math.log2(probability)
            
        return entropy
    def extract_section_features(self, pe):
        """提取固定数量的节特征"""
        features = []
        try:
            if hasattr(pe, 'sections') and len(pe.sections) > 0:
                section = pe.sections[0]  # 只使用第一个节
                section_data = section.get_data()
                features.extend([
                    len(section_data),
                    self.calculate_entropy(section_data),
                    section.Characteristics,
                ])
            else:
                features.extend([0] * 3)
        except:
            features.extend([0] * 3)
        return features

    def extract_wavelet_features(self, image_array):
        """提取固定维度的小波特征"""
        try:
            coeffs = pywt.wavedec2(image_array, self.wavelet, level=self.level)
            features = []
            
            features.extend([
                np.mean(coeffs[0]),
                np.std(coeffs[0]),
                stats.skew(coeffs[0].ravel()),
                stats.kurtosis(coeffs[0].ravel())
            ])
            
            for detail_coeffs in coeffs[1:]:
                for detail in detail_coeffs:
                    features.extend([
                        np.mean(detail),
                        np.std(detail),
                        stats.skew(detail.ravel()),
                        stats.kurtosis(detail.ravel())
                    ])
            
            # 确保特征维度正确
            expected_wavelet_features = 4 + (3 * 4 * self.level)
            if len(features) < expected_wavelet_features:
                features.extend([0] * (expected_wavelet_features - len(features)))
            elif len(features) > expected_wavelet_features:
                features = features[:expected_wavelet_features]
                
            return features
        except Exception as e:
            print(f"Error in wavelet feature extraction: {str(e)}")
            return [0] * (4 + (3 * 4 * self.level))

    def extract_features(self, file_path):
        """提取固定维度的特征集"""
        try:
            features = []
            
            with open(file_path, 'rb') as f:
                data = f.read()
            
            # 1. 熵
            file_entropy = self.calculate_entropy(data)
            features.append(file_entropy)
            
            # 2. PE特征
            try:
                pe = pefile.PE(file_path)
                features.extend(self.extract_section_features(pe))
                features.extend(self.extract_export_features(pe))
            except:
                features.extend([0] * 7)  # PE特征的默认值
            
            # 3. 小波特征
            image_array = np.frombuffer(data, dtype=np.uint8)
            width = 384
            height = len(image_array) // width + (1 if len(image_array) % width else 0)
            padded_size = height * width
            
            if len(image_array) < padded_size:
                image_array = np.pad(image_array, (0, padded_size - len(image_array)))
            
            image_array = image_array.reshape((height, width))
            wavelet_features = self.extract_wavelet_features(image_array)
            features.extend(wavelet_features)
            
            # 确保特征维度正确
            if len(features) != self.expected_feature_dim:
                print(f"Warning: Feature dimension mismatch for {file_path}")
                if len(features) < self.expected_feature_dim:
                    features.extend([0] * (self.expected_feature_dim - len(features)))
                else:
                    features = features[:self.expected_feature_dim]
            
            return np.array(features, dtype=np.float32)
            
        except Exception as e:
            print(f"Error extracting features from {file_path}: {str(e)}")
            return np.zeros(self.expected_feature_dim, dtype=np.float32)

class EnhancedMalwareDetector(nn.Module):
    def __init__(self, input_size):
        super(EnhancedMalwareDetector, self).__init__()
        
        self.pe_features = nn.Sequential(
            nn.Linear(8, 32),
            nn.ReLU(),
            nn.BatchNorm1d(32),
            nn.Dropout(0.3)
        )
        
        wavelet_feature_size = input_size - 8
        self.wavelet_features = nn.Sequential(
            nn.Linear(wavelet_feature_size, 128),
            nn.ReLU(),
            nn.BatchNorm1d(128),
            nn.Dropout(0.3),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.BatchNorm1d(64)
        )
        
        self.fusion = nn.Sequential(
            nn.Linear(96, 64),
            nn.ReLU(),
            nn.BatchNorm1d(64),
            nn.Dropout(0.3),
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.BatchNorm1d(32)
        )
        
        self.classifier = nn.Sequential(
            nn.Linear(32, 16),
            nn.ReLU(),
            nn.BatchNorm1d(16),
            nn.Dropout(0.2),
            nn.Linear(16, 2)
        )
        
        self.attention_pe = nn.Sequential(
            nn.Linear(32, 1),
            nn.Sigmoid()
        )
        self.attention_wavelet = nn.Sequential(
            nn.Linear(64, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        pe_x = x[:, :8]
        wavelet_x = x[:, 8:]
        
        pe_features = self.pe_features(pe_x)
        pe_attention = self.attention_pe(pe_features)
        pe_features = pe_features * pe_attention
        
        wavelet_features = self.wavelet_features(wavelet_x)
        wavelet_attention = self.attention_wavelet(wavelet_features)
        wavelet_features = wavelet_features * wavelet_attention
        
        combined_features = torch.cat((pe_features, wavelet_features), dim=1)
        fused_features = self.fusion(combined_features)
        
        output = self.classifier(fused_features)
        return output

def load_model(model_path, device):
    """加载训练好的模型"""
    checkpoint = torch.load(model_path, map_location=device)
    model = EnhancedMalwareDetector(checkpoint['input_size']).to(device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    return model

def convert_pe_to_image(file_path, width=384):
    """将PE文件转换为灰度图像"""
    try:
        with open(file_path, 'rb') as f:
            content = f.read()
        
        byte_array = np.frombuffer(content, dtype=np.uint8)
        height = len(byte_array) // width + (1 if len(byte_array) % width else 0)
        padded_size = height * width
        if len(byte_array) < padded_size:
            byte_array = np.pad(byte_array, (0, padded_size - len(byte_array)))
        
        image_array = byte_array.reshape((height, width))
        image = Image.fromarray(image_array)
        return image
    except Exception as e:
        print(f"Error converting {file_path} to image: {str(e)}")
        return None

def extract_wavelet_features(image_array, wavelet='haar', level=3):
    """提取小波特征"""
    coeffs = pywt.wavedec2(image_array, wavelet, level=level)
    features = []
    features.extend([
        np.mean(coeffs[0]),
        np.std(coeffs[0]),
        stats.skew(coeffs[0].ravel()),
        stats.kurtosis(coeffs[0].ravel())
    ])
    for detail_coeffs in coeffs[1:]:
        for detail in detail_coeffs:
            features.extend([
                np.mean(detail),
                np.std(detail),
                stats.skew(detail.ravel()),
                stats.kurtosis(detail.ravel())
            ])
    return np.array(features)

def extract_features(file_path):
    """提取文件的特征集"""
    features = []
    try:
        with open(file_path, 'rb') as f:
            data = f.read()
        
        # 计算文件熵
        file_entropy = EnhancedFeatureExtractor().calculate_entropy(data)
        features.append(file_entropy)
        try:
            # 提取PE特征
            pe = pefile.PE(file_path)
            features.extend(EnhancedFeatureExtractor().extract_section_features(pe))
            features.extend(EnhancedFeatureExtractor().extract_export_features(pe))
        except:
            features.extend([0] * 7)
        # 将文件转换为图像并提取小波特征
        image = convert_pe_to_image(file_path)
        if image is not None:
            image_array = np.array(image)
            wavelet_features = extract_wavelet_features(image_array)
            features.extend(wavelet_features)
        
    except Exception as e:
        # print(f"Error extracting features from {file_path}: {str(e)}")
        pass
    return np.array(features)

def scan_directory(model_path, scan_dir, device):
    """扫描指定目录下的文件"""
    model = load_model(model_path, device)
    results = []
    total_files = 0
    detected_malware = 0
    
    for filename in os.listdir(scan_dir):
        file_path = os.path.join(scan_dir, filename)
        if os.path.isfile(file_path):
            total_files += 1
            features = extract_features(file_path)
            if features is not None:
                features_tensor = torch.FloatTensor(features).unsqueeze(0).to(device)
                
                with torch.no_grad():
                    outputs = model(features_tensor)
                    probs = F.softmax(outputs, dim=1)
                    malware_score = probs[0][1].item()
                    prediction = 1 if malware_score > 0.85 else 0  # 阈值不建议低于0.75,否则误报率会很高
                
                results.append({
                    'file': filename,
                    'is_malware': bool(prediction),
                    'malware_score': malware_score
                })
                
                if prediction == 1:
                    detected_malware += 1
    
    detection_rate = (detected_malware / total_files * 100) if total_files > 0 else 0
    print(f"Total files scanned: {total_files}")
    print(f"Detected malware: {detected_malware}")
    print(f"Detection rate: {detection_rate:.2f}%")
    
    return results




In [5]:
model_path = "models/best_model.pth"
scan_dir = "scan_files"
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

results = scan_directory(model_path, scan_dir, device)

for result in results:
    print(f"File: {result['file']}, Is Malware: {result['is_malware']}, Malware Score: {result['malware_score']:.4f}")

Total files scanned: 7
Detected malware: 5
Detection rate: 71.43%
File: Gx7.exe, Is Malware: True, Malware Score: 0.9758
File: msncore.dll, Is Malware: False, Malware Score: 0.6013
File: .DS_Store, Is Malware: False, Malware Score: 0.2438
File: mal5.exe, Is Malware: True, Malware Score: 0.8855
File: mal4.exe, Is Malware: True, Malware Score: 0.8589
File: mal3.exe, Is Malware: True, Malware Score: 0.9773
File: mal1.exe, Is Malware: True, Malware Score: 0.8777
