In [None]:
import math
import multiprocessing as mp
import os
import zlib
from collections import Counter

import numpy as np
import pefile
import pywt
import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

def calculate_entropy(data):
    """Shannon entropy, in bits per symbol, of a sequence such as ``bytes``.

    Args:
        data: sized sequence of hashable symbols (e.g. the raw file bytes).

    Returns:
        Entropy in bits; ``0.0`` when ``data`` is empty/falsy.
    """
    if not data:
        return 0.0
    n_symbols = len(data)
    probabilities = (freq / n_symbols for freq in Counter(data).values())
    # -p * log2(p) summed left-to-right, matching the classic H = -sum p log p.
    return sum(-p * math.log2(p) for p in probabilities)

def get_import_features(pe, max_dim=128, hash_buckets=1000):
    """Encode a PE import table as a fixed-length vector of hashed function names.

    Each named imported function, in import-table order and truncated to
    ``max_dim`` entries, is mapped to a bucket in ``[0, hash_buckets)``.
    Unused slots stay 0.

    The hash is ``zlib.crc32``, not the built-in ``hash()``. Python salts
    ``str`` hashes per interpreter (PYTHONHASHSEED), so ``hash()`` produced
    different feature values during dataset preparation and during a later
    ``scan()`` in a fresh kernel. Datasets and models built with the old
    encoding must be regenerated.

    Args:
        pe: parsed ``pefile.PE`` (anything exposing ``DIRECTORY_ENTRY_IMPORT``
            whose entries have ``.imports`` with a bytes ``.name``).
        max_dim: length of the returned vector.
        hash_buckets: modulus applied to each name hash.

    Returns:
        list[int] of length ``max_dim``; all zeros if there is no import
        directory or it cannot be walked.
    """
    features = [0] * max_dim
    try:
        names = [
            imp.name.decode('utf-8', 'ignore')
            for entry in getattr(pe, 'DIRECTORY_ENTRY_IMPORT', [])
            for imp in entry.imports
            if imp.name  # skip unnamed imports (pefile leaves name empty for by-ordinal imports)
        ]
    except Exception:
        # Best effort: a malformed import table degrades to "no imports".
        return features

    for slot, name in enumerate(names[:max_dim]):
        features[slot] = zlib.crc32(name.encode('utf-8')) % hash_buckets
    return features

def get_section_features(pe, max_sections=8):
    """Summarise up to ``max_sections`` PE sections as a flat, fixed-length list.

    Each section contributes three values: ``get_entropy()``,
    ``SizeOfRawData`` and ``Characteristics``. The list is zero-padded to
    ``3 * max_sections``.

    Args:
        pe: parsed ``pefile.PE`` (anything exposing ``sections``).
        max_sections: number of sections to encode.

    Returns:
        list of length ``3 * max_sections``.
    """
    values = []
    try:
        for section in pe.sections[:max_sections]:
            values.extend([
                section.get_entropy(),
                section.SizeOfRawData,
                section.Characteristics,
            ])
    except Exception:
        # Best effort: keep the sections parsed before the failure. A bare
        # ``except:`` would also swallow KeyboardInterrupt in pool workers.
        pass
    return values + [0] * (3 * max_sections - len(values))

def _wavelet_band_stats(band):
    """Five summary statistics of one wavelet sub-band array."""
    return [
        np.mean(np.abs(band)),    # mean energy
        np.std(band),             # standard deviation
        np.percentile(band, 90),  # 90th percentile
        np.sum(band < 0),         # number of negative coefficients
        entropy_measure(band),    # entropy of normalised magnitudes
    ]


def get_image_features(file_path, img_size=64, feature_len=1024):
    """Wavelet features of a file's bytes laid out as a square grey-scale image.

    The first ``side * side`` bytes are used, where ``side`` is
    ``min(img_size, ceil(sqrt(n_bytes)))`` rounded down to even. The bytes are
    zero-padded, reshaped to ``(side, side)``, min-max normalised, and
    decomposed with ``pywt.wavedec2(..., 'haar', level=3)``.

    Five statistics are recorded for every sub-band, in ``wavedec2`` order
    (approximation first, then the detail tuples). They are followed by the
    flattened approximation coefficients. The result is truncated or
    zero-padded to ``feature_len``.

    Args:
        file_path: path of the file to read.
        img_size: maximum image side length.
        feature_len: length of the returned vector.

    Returns:
        list of length ``feature_len``. All zeros when the file has fewer than
        2 bytes, or when reading or the transform fails.
    """
    zeros = [0.0] * feature_len
    try:
        with open(file_path, 'rb') as f:
            data = f.read()

        arr = np.frombuffer(data, dtype=np.uint8)
        side = min(img_size, int(np.ceil(np.sqrt(len(arr)))))
        side -= side % 2  # ensure even size for the wavelet transform
        if side == 0:
            # Fewer than 2 bytes: no image can be formed (img.min() on an
            # empty array would raise).
            return zeros

        arr = arr[:side * side]
        arr = np.pad(arr, (0, side * side - len(arr)))
        img = arr.reshape(side, side).astype(float)
        img = (img - img.min()) / (img.max() - img.min() + 1e-8)

        coeffs = pywt.wavedec2(img, 'haar', level=3)

        features = []
        for coef in coeffs:
            # Detail levels come as tuples of bands; the approximation is a single array.
            bands = coef if isinstance(coef, tuple) else (coef,)
            for band in bands:
                features.extend(_wavelet_band_stats(band))
        features.extend(coeffs[0].flatten().tolist())  # approximation coefficients
    except Exception:
        # The fallback must have the same length as the success path, so every
        # file yields the same vector size. np.stack in prepare_dataset and the
        # mean/std normalisation in scan() both require it.
        return zeros

    features = features[:feature_len]
    features.extend([0.0] * (feature_len - len(features)))
    return features

def entropy_measure(coeffs):
    """Shannon entropy (bits) of the normalised magnitudes of ``coeffs``.

    The absolute values are scaled to sum to ~1 and treated as a probability
    mass. The ``1e-8`` terms guard against division by zero and ``log2(0)``.

    Args:
        coeffs: array-like of numbers (any shape).

    Returns:
        numpy scalar entropy; zero for an all-zero input.
    """
    magnitude = np.abs(coeffs)
    prob = magnitude / (magnitude.sum() + 1e-8)
    return -(prob * np.log2(prob + 1e-8)).sum()


def extract_file_features(file_path):
    """Build the complete float32 feature vector for one file.

    Layout: [byte entropy (1)] + [hashed import names (128)] +
    [section stats (24)] + [wavelet image features from get_image_features].
    Files that pefile cannot parse get zeros for the import and section parts.

    Args:
        file_path: path of the file to featurise.

    Returns:
        1-D ``np.float32`` array.

    Raises:
        OSError: if the file cannot be opened or read.
    """
    with open(file_path, 'rb') as f:
        data = f.read()

    # Parse from the bytes already in memory instead of having pefile open the
    # path again. pefile may keep a memory map of a path-opened file until
    # close() is called, and the PE object was never closed before.
    try:
        pe = pefile.PE(data=data)
    except Exception:  # not a PE, truncated/malformed header, ...
        pe = None

    features = [calculate_entropy(data)]
    try:
        # Explicit `is not None`: only "failed to parse" should select the
        # zero-filled branches, regardless of how PE objects evaluate as bools.
        features += get_import_features(pe) if pe is not None else [0] * 128
        features += get_section_features(pe) if pe is not None else [0] * 24
    finally:
        if pe is not None:
            pe.close()

    features += get_image_features(file_path)
    return np.array(features, dtype=np.float32)

def process_file(args):
    """Pool worker: featurise one ``(file_path, label)`` pair.

    Takes a single tuple argument so it can be used directly with ``Pool.imap``.

    Returns:
        ``(features, label)``, with ``label`` passed through unchanged.
    """
    path, label = args
    return extract_file_features(path), label

def prepare_dataset(black_dir='black_files', white_dir='white_files', save_path='dataset', num_processes=6):
    """Featurise every file under the two sample directories and save to ``.npy``.

    Files under ``white_dir`` get label 0 (benign); files under ``black_dir``
    get label 1 (malicious). The outputs are ``features.npy`` (stacked feature
    vectors) and ``labels.npy``, written into ``save_path``. Missing
    directories are reported and skipped.

    Args:
        black_dir: directory of malicious samples (label 1).
        white_dir: directory of benign samples (label 0).
        save_path: output directory; created if missing.
        num_processes: number of pool workers. The default is a fixed 6;
            pass ``None`` to use ``mp.cpu_count()``.
    """
    os.makedirs(save_path, exist_ok=True)

    if num_processes is None:
        num_processes = mp.cpu_count()

    file_list = []
    # enumerate order defines the labels: white_dir -> 0, black_dir -> 1.
    for label, path in enumerate([white_dir, black_dir]):
        if not os.path.exists(path):
            print(f"Warning: Directory {path} does not exist.")
            continue
        for root, dirs, files in os.walk(path):
            dirs.sort()  # in-place sort fixes os.walk traversal order
            for f in sorted(files):  # reproducible sample order across runs/machines
                file_list.append((os.path.join(root, f), label))

    print(f"Processing {len(file_list)} files using {num_processes} processes...")

    X, y = [], []
    with mp.Pool(processes=num_processes) as pool:
        # imap keeps input order; chunksize batches tasks to cut per-file IPC overhead.
        results = pool.imap(process_file, file_list, chunksize=8)
        for features, label in tqdm(results, total=len(file_list)):
            X.append(features)
            y.append(label)

    X = np.stack(X) if X else np.array([])
    y = np.array(y)

    np.save(os.path.join(save_path, 'features.npy'), X)
    np.save(os.path.join(save_path, 'labels.npy'), y)
    print(f"Dataset saved with {len(X)} samples")


class MalwareClassifier(nn.Module):
    """Two-hidden-layer MLP that outputs a probability per sample.

    Architecture: Linear -> BatchNorm -> ReLU -> Dropout (512 units, p=0.6),
    then the same block with 256 units and p=0.4, then Linear -> Sigmoid.
    ``scan`` treats outputs >= 0.5 as malicious.
    """

    def __init__(self, input_size):
        super().__init__()
        layers = []
        width_in = input_size
        for width_out, drop_p in ((512, 0.6), (256, 0.4)):
            layers += [
                nn.Linear(width_in, width_out),
                nn.BatchNorm1d(width_out),
                nn.ReLU(),
                nn.Dropout(drop_p),
            ]
            width_in = width_out
        layers += [nn.Linear(width_in, 1), nn.Sigmoid()]
        # The module order inside `net` fixes the state_dict keys (net.0.*,
        # net.1.*, ...) that saved checkpoints depend on; keep it unchanged.
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Map a ``(batch, input_size)`` float tensor to ``(batch, 1)`` probabilities."""
        return self.net(x)

class MalwareDataset(Dataset):
    """In-memory dataset of per-column z-scored features and float labels.

    By default the mean/std come from ``features`` itself. Pass ``mean`` and
    ``std`` (e.g. from a training split) to standardise another split with
    the same statistics, without leaking that split's statistics.

    Attributes:
        X: standardised features, float32 tensor.
        y: labels, float32 tensor.
        mean, std: per-column float32 tensors used for standardisation.
            ``train_model`` stores them in the checkpoint for ``scan``.
    """

    def __init__(self, features, labels, mean=None, std=None):
        X = torch.as_tensor(features, dtype=torch.float32)
        self.y = torch.as_tensor(labels, dtype=torch.float32)
        self.mean = X.mean(0) if mean is None else torch.as_tensor(mean, dtype=torch.float32)
        if std is None:
            std = X.std(0)
            # Constant columns give 0, and a single row gives NaN (unbiased std
            # with n=1). Either would poison the division, so use 1 instead.
            # `~(std > 0)` catches both; the old `std == 0` test missed NaN.
            std[~(std > 0)] = 1.0
        self.std = torch.as_tensor(std, dtype=torch.float32)
        self.X = (X - self.mean) / self.std

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]

def _select_device():
    """Pick CUDA if available, then Apple MPS, otherwise CPU."""
    if torch.cuda.is_available():
        return torch.device('cuda')
    if torch.backends.mps.is_available():
        return torch.device('mps')
    return torch.device('cpu')


def _evaluate(model, loader, device, threshold=0.5):
    """Accuracy of ``model`` on ``loader`` at ``threshold``; 0.0 for an empty loader."""
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for X_val, y_val in loader:
            X_val = X_val.to(device)
            y_val = y_val.to(device).unsqueeze(1)
            predicted = (model(X_val) >= threshold).float()
            correct += (predicted == y_val).sum().item()
            total += y_val.size(0)
    return correct / total if total else 0.0


def train_model(features_path, labels_path, model_save='model.pt',
                epochs=20, batch_size=64, lr=1e-4, train_fraction=0.8, seed=None):
    """Train ``MalwareClassifier`` on saved features and checkpoint the best epoch.

    Rows are split randomly into train and validation sets. Standardisation
    statistics are computed from the training rows only and reused for
    validation, so validation accuracy is not inflated by leaked statistics.
    The checkpoint is a dict with ``model`` (state_dict), ``mean`` and ``std``.
    It is written to ``model_save`` whenever validation accuracy improves.

    Note: the checkpoint is selected on the same split that reports the
    accuracy, so "best accuracy" is an optimistic estimate.

    Args:
        features_path: ``.npy`` file holding a 2-D feature matrix.
        labels_path: ``.npy`` file holding 0/1 labels.
        model_save: checkpoint output path.
        epochs, batch_size, lr: training hyper-parameters (defaults as before).
        train_fraction: share of rows used for training.
        seed: if given, seeds torch so split, shuffling and init are reproducible.
    """
    if seed is not None:
        torch.manual_seed(seed)
    device = _select_device()

    X = np.load(features_path)
    y = np.load(labels_path)

    # Split indices first so the normalisation statistics come from training rows only.
    perm = torch.randperm(len(X)).numpy()
    train_size = int(train_fraction * len(X))
    train_idx, val_idx = perm[:train_size], perm[train_size:]

    train_set = MalwareDataset(X[train_idx], y[train_idx])
    val_X = (torch.as_tensor(X[val_idx], dtype=torch.float32) - train_set.mean) / train_set.std
    val_y = torch.as_tensor(y[val_idx], dtype=torch.float32)
    val_set = torch.utils.data.TensorDataset(val_X, val_y)

    # BatchNorm1d raises on a training batch of one sample; drop such a trailing batch.
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True,
                              drop_last=(train_size % batch_size == 1))
    val_loader = DataLoader(val_set, batch_size=128)

    model = MalwareClassifier(X.shape[1]).to(device)
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5)
    criterion = nn.BCELoss()  # the model already ends in Sigmoid

    best_acc = -1.0  # below any real accuracy, so at least one checkpoint is written
    for epoch in range(epochs):
        model.train()
        total_loss = 0.0
        for X_batch, y_batch in train_loader:
            X_batch = X_batch.to(device)
            y_batch = y_batch.to(device).unsqueeze(1)

            optimizer.zero_grad()
            loss = criterion(model(X_batch), y_batch)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        acc = _evaluate(model, val_loader, device)
        n_batches = max(len(train_loader), 1)
        print(f'Epoch {epoch+1:2d} | Loss: {total_loss/n_batches:.4f} | Acc: {acc:.4f}')

        if acc > best_acc:
            best_acc = acc
            torch.save({
                'model': model.state_dict(),
                'mean': train_set.mean,
                'std': train_set.std
            }, model_save)

    print(f'Training complete. Best accuracy: {best_acc:.4f}')

def scan(model_path='model.pt', scan_dir='scan_files', threshold=0.5):
    """Classify every file under ``scan_dir`` with a checkpoint from ``train_model``.

    Prints one line per file, using the path relative to ``scan_dir``, then
    the share of files flagged malicious. Files that cannot be read are
    reported and skipped instead of aborting the scan.

    Args:
        model_path: checkpoint dict with ``model``, ``mean`` and ``std``.
        scan_dir: directory scanned recursively.
        threshold: probability at or above which a file is reported Malicious.
    """
    if torch.cuda.is_available():
        device = torch.device('cuda')
    elif torch.backends.mps.is_available():
        device = torch.device('mps')
    else:
        device = torch.device('cpu')

    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints
    # (consider weights_only=True on PyTorch versions that support it).
    checkpoint = torch.load(model_path, map_location=device)
    model = MalwareClassifier(len(checkpoint['mean'])).to(device)
    model.load_state_dict(checkpoint['model'])
    model.eval()

    # Normalisation happens in numpy on the CPU, before moving to `device`.
    mean_np = checkpoint['mean'].cpu().numpy()
    std_np = checkpoint['std'].cpu().numpy()

    if not os.path.isdir(scan_dir):
        print(f"Warning: Directory {scan_dir} does not exist.")

    total = 0
    detected = 0
    for root, dirs, files in os.walk(scan_dir):
        dirs.sort()
        for f in sorted(files):
            file_path = os.path.join(root, f)
            # Equals `f` for top-level files; disambiguates same-named files in subdirectories.
            name = os.path.relpath(file_path, scan_dir)
            try:
                features = extract_file_features(file_path)
            except OSError as exc:
                print(f'{name}: skipped ({exc})')
                continue

            features = (features - mean_np) / std_np
            features_tensor = torch.FloatTensor(features).unsqueeze(0).to(device)

            with torch.no_grad():
                prob = model(features_tensor).item()

            is_malicious = prob >= threshold
            result = 'Malicious' if is_malicious else 'Benign'
            print(f'{name}: {result} ({prob:.4f})')

            total += 1
            if is_malicious:
                detected += 1

    if total > 0:
        print(f'\nDetection Rate: {detected/total:.2%} ({detected}/{total})')
    else:
        print('No files found.')



In [None]:
if __name__ == '__main__':
    # 'fork' lets pool workers use functions defined in this notebook; workers
    # started with 'spawn' must import them from __main__, which a kernel does
    # not provide. 'fork' is not available on Windows.
    # force=True keeps the cell re-runnable: a second set_start_method call in
    # the same kernel otherwise raises RuntimeError ("context has already been set").
    mp.set_start_method('fork', force=True)
    prepare_dataset()

Processing 2076 files using 6 processes...


100%|██████████| 2076/2076 [01:17<00:00, 26.79it/s]

Dataset saved with 2076 samples





In [7]:
train_model(features_path='dataset/features.npy', labels_path='dataset/labels.npy')

Epoch  1 | Loss: 0.6158 | Acc: 0.6947
Epoch  2 | Loss: 0.5046 | Acc: 0.7716
Epoch  3 | Loss: 0.4673 | Acc: 0.8053
Epoch  4 | Loss: 0.4248 | Acc: 0.8173
Epoch  5 | Loss: 0.3959 | Acc: 0.8317
Epoch  6 | Loss: 0.3773 | Acc: 0.8317
Epoch  7 | Loss: 0.3500 | Acc: 0.8438
Epoch  8 | Loss: 0.3397 | Acc: 0.8438
Epoch  9 | Loss: 0.3186 | Acc: 0.8582
Epoch 10 | Loss: 0.3092 | Acc: 0.8462
Epoch 11 | Loss: 0.2928 | Acc: 0.8606
Epoch 12 | Loss: 0.2845 | Acc: 0.8438
Epoch 13 | Loss: 0.2731 | Acc: 0.8654
Epoch 14 | Loss: 0.2648 | Acc: 0.8582
Epoch 15 | Loss: 0.2537 | Acc: 0.8822
Epoch 16 | Loss: 0.2503 | Acc: 0.8798
Epoch 17 | Loss: 0.2393 | Acc: 0.8582
Epoch 18 | Loss: 0.2334 | Acc: 0.8702
Epoch 19 | Loss: 0.2262 | Acc: 0.8702
Epoch 20 | Loss: 0.2216 | Acc: 0.8702
Training complete. Best accuracy: 0.8822


In [14]:
scan(model_path='model.pt', scan_dir='scan_files')

2.exe: Malicious (0.9349)
1.exe: Malicious (0.9349)

Detection Rate: 100.00% (2/2)
