In [4]:
import os
import math
import numpy as np
import pefile
import pywt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from collections import Counter
from PIL import Image
from tqdm import tqdm
import multiprocessing as mp
import torch.nn.functional as F

def calculate_entropy(data):
    if not data:
        return 0.0
    counts = Counter(data)
    entropy = 0.0
    total = len(data)
    for count in counts.values():
        p = count / total
        entropy -= p * math.log2(p)
    return entropy

def get_import_features(pe, max_dim=128):
    features = [0] * max_dim
    try:
        if hasattr(pe, 'DIRECTORY_ENTRY_IMPORT'):
            imports = []
            for entry in pe.DIRECTORY_ENTRY_IMPORT:
                for imp in entry.imports:
                    if imp.name:
                        imports.append(imp.name.decode('utf-8', 'ignore'))
            for i, func in enumerate(imports[:max_dim]):
                features[i] = hash(func) % 1000  # Hash simplification
    except:
        pass
    return features
def get_export_features(pe, max_dim=128):
    features = [0] * max_dim
    try:
        if hasattr(pe, 'DIRECTORY_ENTRY_EXPORT'):
            exports = [entry.name.decode('utf-8', 'ignore') for entry in pe.DIRECTORY_ENTRY_EXPORT.symbols]
            for i, func in enumerate(exports[:max_dim]):
                features[i] = calculate_entropy(func.encode('utf-8'))
    except:
        pass
    return features
def get_section_features(pe, max_sections=8):
    sections = []
    try:
        for section in pe.sections[:max_sections]:
            sections.extend([
                section.get_entropy(),
                section.SizeOfRawData,
                section.Characteristics
            ])
    except:
        pass
    pad_length = max_sections*3 - len(sections)
    return sections + [0]*pad_length

def get_image_features(file_path):
    try:
        with open(file_path, 'rb') as f:
            data = f.read()
        
        arr = np.frombuffer(data, dtype=np.uint8)
        
        actual_size = int(np.ceil(np.sqrt(len(arr))))
        actual_size = actual_size - (actual_size % 2)  # Ensure even size for wavelet
        
        # Reshape
        arr = arr[:actual_size*actual_size]
        arr = np.pad(arr, (0, actual_size*actual_size - len(arr)))
        img = arr.reshape(actual_size, actual_size).astype(float)
        
        img = (img - img.min()) / (img.max() - img.min() + 1e-8)
        
        coeffs = pywt.wavedec2(img, 'haar', level=3)
        features = []
        
        for coef in coeffs:
        
            if isinstance(coef, tuple):
                for detail_coef in coef:
                    features.extend([
                        np.mean(np.abs(detail_coef)),  # Mean energy
                        np.std(detail_coef),           # Standard deviation
                        np.percentile(detail_coef, 90),# 90th percentile
                        np.sum(detail_coef < 0),       # Number of negative coefficients
                        entropy_measure(detail_coef)    # Entropy of coefficients
                    ])
            else:
                features.extend([
                    np.mean(np.abs(coef)),            # Mean energy
                    np.std(coef),                     # Standard deviation
                    np.percentile(coef, 90),          # 90th percentile
                    np.sum(coef < 0),                 # Number of negative coefficients
                    entropy_measure(coef)             # Entropy of coefficients
                ])
        for coef in coeffs:
            flattened = np.array(coef).flatten()
            features.extend(flattened.tolist())
            
        expected_length = 384
        if len(features) < expected_length:
            features.extend([0.0] * (expected_length - len(features)))
        
        return features[:expected_length]
    
    except Exception as e:
        # print(e)
        return [0.0] * 384

def entropy_measure(coeffs):
    coeffs = np.abs(coeffs)
    coeffs = coeffs / (np.sum(coeffs) + 1e-8)
    entropy = -np.sum(coeffs * np.log2(coeffs + 1e-8))
    return entropy


def extract_file_features(file_path):
    try:
        pe = pefile.PE(file_path)
    except:
        pe = None
    
    features = []
    # Entropy
    with open(file_path, 'rb') as f:
        data = f.read()
    features.append(calculate_entropy(data))
    
    # Import features
    features += get_import_features(pe) if pe else [0]*128
    # Export features
    features += get_export_features(pe) if pe else [0]*128
    # Section features
    features += get_section_features(pe) if pe else [0]*24
    
    # Wavelet features
    features += get_image_features(file_path)
    
    return np.array(features, dtype=np.float32)

def process_file(args):
    file_path, label = args
    features = extract_file_features(file_path)
    return features, label

class MalwareClassifier(nn.Module):
    def __init__(self, input_size):
        super().__init__()
        self.conv1 = nn.Conv1d(in_channels=1, out_channels=64, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv1d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv1d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool1d(kernel_size=2, stride=2, padding=0)
        self.fc1 = nn.Linear(256 * (input_size // 8), 512)
        self.fc2 = nn.Linear(512, 128)
        self.fc3 = nn.Linear(128, 1)
        self.dropout = nn.Dropout(0.5)
    
    def forward(self, x):
        x = x.unsqueeze(1)  # Add channel dimension
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.pool(F.relu(self.conv3(x)))
        x = x.view(x.size(0), -1)  # Flatten
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = F.relu(self.fc2(x))
        x = self.dropout(x)
        x = torch.sigmoid(self.fc3(x))
        return x




def scan(model_path='model.pt', scan_dir='scan_files'):
    device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')
    
    checkpoint = torch.load(model_path, map_location=device)
    model = MalwareClassifier(len(checkpoint['mean'])).to(device)
    model.load_state_dict(checkpoint['model'])
    model.eval()
    
    # move to CPU
    mean_np = checkpoint['mean'].cpu().numpy()
    std_np = checkpoint['std'].cpu().numpy()
    
    total = 0
    detected = 0
    for root, _, files in os.walk(scan_dir):
        for f in files:
            file_path = os.path.join(root, f)
            features = extract_file_features(file_path)
            
            features = (features - mean_np) / std_np
            features_tensor = torch.FloatTensor(features).unsqueeze(0).to(device)
            
            with torch.no_grad():
                prob = model(features_tensor).item()
            
            result = 'Malicious' if prob >= 0.9 else 'Benign'
            print(f'{f}: {result} ({prob:.4f})')
            
            total += 1
            if prob >= 0.9:
                detected += 1
    
    if total > 0:
        print(f'\nDetection Rate: {detected/total:.2%} ({detected}/{total})')
    else:
        print('No files found.')



In [5]:
scan()


WinSAT.exe: Benign (0.0003)
wsutil.exe: Benign (0.7073)
wmpconfig.exe: Benign (0.0000)
wowreg32.exe: Malicious (0.9870)
wpa.exe: Malicious (0.9741)
XtuService.exe: Malicious (0.9456)
WSCollect.exe: Benign (0.0123)
WmsDashboard.exe: Benign (0.0000)
wsl.exe: Benign (0.0003)
wsmprovhost.exe: Benign (0.0000)
wstraceutil.exe: Benign (0.8008)
wscript.exe: Benign (0.0031)
xxd.exe: Benign (0.0176)
WMIRegistrationService.exe: Malicious (0.9355)
WidgetBoard.exe: Benign (0.0001)
WMSvc.exe: Benign (0.0026)
x64launcher.exe: Benign (0.3377)
WindowsSandbox.exe: Benign (0.0003)
winpty-debugserver.exe: Malicious (0.9848)
WidgetService.exe: Benign (0.1940)
WpcTok.exe: Benign (0.0000)
wiawow64.exe: Benign (0.2744)
wusa.exe: Benign (0.1358)
WmsSelfHealingSvc.exe: Benign (0.0000)
WindowsCamera.exe: Benign (0.4719)
WindowsBackupClient.exe: Benign (0.0001)
wpnpinst.exe: Benign (0.0000)
.DS_Store: Benign (0.0541)
XboxApp.exe: Benign (0.0162)
wish86.exe: Malicious (0.9486)
XboxPcAppAdminServer.exe: Benign (0.0