In [8]:
import math
import multiprocessing as mp
import os
import zlib
from collections import Counter

import numpy as np
import pefile
import pywt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

def calculate_entropy(data):
    """Return the Shannon entropy, in bits per symbol, of ``data``.

    ``data`` is any sized iterable of hashable symbols (callers in this file
    pass ``bytes``).  Falsy input (empty or ``None``) yields ``0.0``.
    """
    if not data:
        return 0.0
    size = len(data)
    probabilities = (occurrences / size for occurrences in Counter(data).values())
    # Subtracting the running sum from 0.0 is the same as accumulating each
    # term with ``-=`` starting from 0.0.
    return 0.0 - sum(p * math.log2(p) for p in probabilities)
def get_export_features(pe, max_dim=128):
    """Build a fixed-length vector from the names of a PE file's exports.

    Slot ``i`` holds the Shannon entropy (via ``calculate_entropy``) of the
    UTF-8 bytes of the i-th *named* export; unused slots stay ``0``.

    Args:
        pe: Parsed PE object.  Only ``DIRECTORY_ENTRY_EXPORT.symbols`` and
            each symbol's ``name`` (bytes or ``None``) are read.
        max_dim: Length of the returned list; exports beyond it are dropped.

    Returns:
        A list of exactly ``max_dim`` numbers.
    """
    features = [0] * max_dim
    try:
        if hasattr(pe, 'DIRECTORY_ENTRY_EXPORT'):
            # A symbol whose name is None (e.g. exported by ordinal only) is
            # skipped, mirroring the `if imp.name` guard used for imports.
            # Without the guard a single nameless export raised and left the
            # whole vector at zero.
            exports = [
                entry.name.decode('utf-8', 'ignore')
                for entry in pe.DIRECTORY_ENTRY_EXPORT.symbols
                if entry.name
            ]
            for i, func in enumerate(exports[:max_dim]):
                features[i] = calculate_entropy(func.encode('utf-8'))
    except Exception:
        # Best effort: a malformed export table yields whatever was filled in
        # so far.  KeyboardInterrupt/SystemExit are no longer swallowed.
        pass
    return features
def get_import_features(pe, max_dim=128):
    """Build a fixed-length vector from the names of a PE file's imports.

    Slot ``i`` holds a bucket id in ``[0, 1000)`` derived from the i-th named
    import; unused slots stay ``0``.

    Args:
        pe: Parsed PE object.  Only ``DIRECTORY_ENTRY_IMPORT``, each entry's
            ``imports`` and each import's ``name`` (bytes or ``None``) are read.
        max_dim: Length of the returned list; imports beyond it are dropped.

    Returns:
        A list of exactly ``max_dim`` integers.
    """
    features = [0] * max_dim
    try:
        if hasattr(pe, 'DIRECTORY_ENTRY_IMPORT'):
            imports = [
                imp.name.decode('utf-8', 'ignore')
                for entry in pe.DIRECTORY_ENTRY_IMPORT
                for imp in entry.imports
                if imp.name
            ]
            for i, func in enumerate(imports[:max_dim]):
                # Built-in hash() on str is salted per interpreter process, so
                # it gives different buckets across runs (and across worker
                # processes that do not share a hash seed).  CRC32 is stable.
                features[i] = zlib.crc32(func.encode('utf-8')) % 1000
    except Exception:
        # Best effort: a malformed import table leaves the vector at zero.
        # KeyboardInterrupt/SystemExit are no longer swallowed.
        pass
    return features

def get_section_features(pe, max_sections=8):
    """Return three numbers per PE section, zero-padded to a fixed length.

    For each of the first ``max_sections`` entries of ``pe.sections`` the
    vector contains, in order: ``get_entropy()``, ``SizeOfRawData`` and
    ``Characteristics``.

    Args:
        pe: Parsed PE object exposing a ``sections`` sequence.
        max_sections: Maximum number of sections to describe.

    Returns:
        A list of exactly ``max_sections * 3`` numbers.
    """
    sections = []
    try:
        for section in pe.sections[:max_sections]:
            # The triple is built before extend() runs, so a failure on one
            # section never leaves a partial triple in the list.
            sections.extend([
                section.get_entropy(),
                section.SizeOfRawData,
                section.Characteristics
            ])
    except Exception:
        # Best effort: keep the sections read so far.  Narrowed from a bare
        # `except` so KeyboardInterrupt/SystemExit still propagate.
        pass
    pad_length = max_sections * 3 - len(sections)
    return sections + [0] * pad_length

def _coefficient_stats(band):
    """Return the five summary statistics computed for one wavelet sub-band."""
    return [
        np.mean(np.abs(band)),    # Mean energy
        np.std(band),             # Standard deviation
        np.percentile(band, 90),  # 90th percentile
        np.sum(band < 0),         # Number of negative coefficients
        entropy_measure(band),    # Entropy of coefficients
    ]


def get_image_features(file_path, expected_length=768):
    """Describe a file's raw bytes through a 2-D Haar wavelet decomposition.

    The bytes are laid out as a square grayscale image (side rounded down to
    an even number, data truncated or zero-padded to fit), min-max scaled,
    and decomposed with ``pywt.wavedec2(..., 'haar', level=3)``.  The feature
    vector is the five summary statistics of every sub-band, followed by the
    flattened raw coefficients, then zero-padded or truncated to
    ``expected_length``.

    Args:
        file_path: Path of the file to read.
        expected_length: Exact length of the returned list (default 768,
            which the rest of this notebook relies on).

    Returns:
        A list of ``expected_length`` numbers.  On any failure (unreadable or
        empty file, decomposition error) the list is all ``0.0``.
    """
    try:
        with open(file_path, 'rb') as f:
            data = f.read()

        arr = np.frombuffer(data, dtype=np.uint8)

        actual_size = int(np.ceil(np.sqrt(len(arr))))
        actual_size -= actual_size % 2  # Ensure even size for wavelet
        n_pixels = actual_size * actual_size

        # Truncate or zero-pad so the buffer fills the square exactly.
        arr = arr[:n_pixels]
        arr = np.pad(arr, (0, n_pixels - len(arr)))
        img = arr.reshape(actual_size, actual_size).astype(float)

        # Min-max scale; the epsilon avoids dividing by zero on constant data.
        img = (img - img.min()) / (img.max() - img.min() + 1e-8)

        coeffs = pywt.wavedec2(img, 'haar', level=3)

        features = []
        for coef in coeffs:
            # Tuple entries hold several sub-bands; a bare array is one band.
            bands = coef if isinstance(coef, tuple) else (coef,)
            for band in bands:
                features.extend(_coefficient_stats(band))

        # Append raw coefficients; only those fitting in the budget survive.
        for coef in coeffs:
            features.extend(np.array(coef).flatten().tolist())

        if len(features) < expected_length:
            features.extend([0.0] * (expected_length - len(features)))
        return features[:expected_length]

    except Exception:
        # Deliberate best effort: an unusable file maps to an all-zero vector
        # so one bad sample does not abort dataset construction.
        return [0.0] * expected_length

def entropy_measure(coeffs):
    """Return the base-2 entropy of the normalised magnitudes of ``coeffs``.

    Magnitudes are divided by their sum so they act as a probability
    distribution; ``1e-8`` is added to the denominator and inside the
    logarithm to keep the computation finite for all-zero input.
    """
    magnitudes = np.abs(coeffs)
    weights = magnitudes / (np.sum(magnitudes) + 1e-8)
    log_weights = np.log2(weights + 1e-8)
    return -np.sum(weights * log_weights)


def extract_file_features(file_path):
    """Compute the full feature vector for one file.

    Layout of the returned vector, with the helpers' default sizes:
    1 whole-file byte entropy, 128 import features, 128 export features,
    24 section features, then the wavelet features from
    ``get_image_features`` (768 by default).

    Args:
        file_path: Path of the file to analyse.

    Returns:
        1-D ``np.float32`` array.  If the file cannot be parsed as a PE the
        import/export/section slots are all zero.

    Raises:
        OSError: if the file cannot be opened for the entropy computation.
    """
    try:
        pe = pefile.PE(file_path)
    except Exception:
        # Not a (valid) PE: fall back to zeros for the PE-derived features.
        pe = None

    try:
        # Entropy
        with open(file_path, 'rb') as f:
            data = f.read()
        features = [calculate_entropy(data)]

        # Import features
        features += get_import_features(pe) if pe is not None else [0] * 128
        # Export features
        features += get_export_features(pe) if pe is not None else [0] * 128
        # Section features
        features += get_section_features(pe) if pe is not None else [0] * 24
    finally:
        # Release the handle pefile keeps on the file; otherwise a worker
        # processing many files accumulates open handles.
        if pe is not None:
            pe.close()

    # Wavelet features
    features += get_image_features(file_path)

    return np.array(features, dtype=np.float32)

def process_file(args):
    """Pool worker: turn a ``(file_path, label)`` pair into ``(features, label)``."""
    path, target = args
    return extract_file_features(path), target

def prepare_dataset(black_dir='black_files', white_dir='white_files', save_path='dataset', num_processes=6):
    """Extract features for every file under two directories and save them.

    Files found (recursively) under ``white_dir`` get label 0 and those under
    ``black_dir`` get label 1.  Features are computed in a process pool and
    written to ``<save_path>/features.npy`` and ``<save_path>/labels.npy``.

    Args:
        black_dir: Directory whose files are labelled 1.
        white_dir: Directory whose files are labelled 0.
        save_path: Output directory (created if missing).
        num_processes: Pool size; ``None`` means one worker per CPU core.
    """
    os.makedirs(save_path, exist_ok=True)

    # Fall back to one worker per CPU core when the caller passes None.
    num_processes = mp.cpu_count() if num_processes is None else num_processes

    # White files are listed first, then black files, in os.walk order.
    file_list = []
    for path, label in ((white_dir, 0), (black_dir, 1)):
        if not os.path.exists(path):
            print(f"Warning: Directory {path} does not exist.")
            continue
        file_list.extend(
            (os.path.join(root, name), label)
            for root, _, names in os.walk(path)
            for name in names
        )

    print(f"Processing {len(file_list)} files using {num_processes} processes...")

    # imap preserves input order, so results line up with file_list.
    with mp.Pool(processes=num_processes) as pool:
        results = list(tqdm(pool.imap(process_file, file_list), total=len(file_list)))

    X = [features for features, _ in results]
    y = [label for _, label in results]

    X = np.stack(X) if X else np.array([])
    y = np.array(y)

    np.save(os.path.join(save_path, 'features.npy'), X)
    np.save(os.path.join(save_path, 'labels.npy'), y)
    print(f"Dataset saved with {len(X)} samples")



class MalwareClassifier(nn.Module):
    """1-D CNN binary classifier over a flat feature vector.

    Three conv + ReLU + max-pool stages (64, 128, 256 channels; each pool
    halves the length) feed three fully connected layers; the output is a
    sigmoid probability of shape ``(batch, 1)``.
    """

    def __init__(self, input_size):
        """Create the layers for feature vectors of length ``input_size``."""
        super().__init__()
        conv_kwargs = dict(kernel_size=3, stride=1, padding=1)
        self.conv1 = nn.Conv1d(1, 64, **conv_kwargs)
        self.conv2 = nn.Conv1d(64, 128, **conv_kwargs)
        self.conv3 = nn.Conv1d(128, 256, **conv_kwargs)
        self.pool = nn.MaxPool1d(2, stride=2, padding=0)
        # Three halvings (each floored) leave input_size // 8 positions.
        pooled_length = input_size // 8
        self.fc1 = nn.Linear(256 * pooled_length, 512)
        self.fc2 = nn.Linear(512, 128)
        self.fc3 = nn.Linear(128, 1)
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        """Map a ``(batch, input_size)`` tensor to ``(batch, 1)`` probabilities."""
        out = torch.unsqueeze(x, 1)  # Add channel dimension
        for conv in (self.conv1, self.conv2, self.conv3):
            out = self.pool(F.relu(conv(out)))
        out = out.view(out.size(0), -1)  # Flatten
        for dense in (self.fc1, self.fc2):
            out = self.dropout(F.relu(dense(out)))
        return torch.sigmoid(self.fc3(out))

class MalwareDataset(Dataset):
    """In-memory dataset of standardised feature vectors and float labels.

    Attributes (all tensors): ``X`` standardised features, ``y`` labels,
    ``mean`` / ``std`` the per-feature statistics used for standardisation.
    """

    def __init__(self, features, labels, mean=None, std=None):
        """Standardise ``features`` column-wise and store them with ``labels``.

        Args:
            features: 2-D array-like of shape ``(n_samples, n_features)``.
            labels: 1-D array-like of ``n_samples`` labels.
            mean: Optional per-feature mean to apply instead of computing it
                from ``features`` (e.g. statistics saved at training time).
            std: Optional per-feature standard deviation, same purpose.
        """
        self.X = torch.as_tensor(features, dtype=torch.float32)
        self.y = torch.as_tensor(labels, dtype=torch.float32)

        if mean is None:
            mean = self.X.mean(0)
        if std is None:
            # The sample std of fewer than two rows is NaN, which the
            # `== 0` guard alone would not catch; use unit scale instead.
            std = self.X.std(0) if len(self.X) > 1 else torch.ones_like(mean)

        self.mean = torch.as_tensor(mean, dtype=torch.float32)
        std = torch.as_tensor(std, dtype=torch.float32)
        # Constant (or otherwise degenerate) features keep a scale of 1 so
        # standardisation never divides by zero or propagates NaN/inf.
        degenerate = (std == 0) | ~torch.isfinite(std)
        self.std = torch.where(degenerate, torch.ones_like(std), std)

        self.X = (self.X - self.mean) / self.std

    def __len__(self):
        """Number of samples."""
        return len(self.X)

    def __getitem__(self, idx):
        """Return the ``(features, label)`` pair at ``idx``."""
        return self.X[idx], self.y[idx]

def _select_device():
    """Pick the best available torch device: CUDA, then Apple MPS, then CPU."""
    if torch.cuda.is_available():
        return torch.device('cuda')
    # `torch.backends.mps` is absent on older torch builds, so probe for it.
    mps = getattr(torch.backends, 'mps', None)
    if mps is not None and mps.is_available():
        return torch.device('mps')
    return torch.device('cpu')


def train_model(features_path, labels_path, model_save='model.pt', epochs=50):
    """Train ``MalwareClassifier`` on saved features and keep the best checkpoint.

    The data is split 80/20 at random into training and validation sets.
    After each epoch the validation accuracy is printed, and whenever it
    improves a checkpoint is written to ``model_save`` containing the model
    ``state_dict`` plus the dataset's ``mean`` and ``std``.

    Args:
        features_path: Path to the ``.npy`` feature matrix.
        labels_path: Path to the ``.npy`` label vector.
        model_save: Destination of the best checkpoint.
        epochs: Number of passes over the training set (default 50).
    """
    device = _select_device()

    X = np.load(features_path)
    y = np.load(labels_path)
    # NOTE: MalwareDataset standardises with statistics of the whole dataset,
    # i.e. before the split below, so validation rows influence them.
    dataset = MalwareDataset(X, y)

    train_size = int(0.8 * len(dataset))
    train_set, val_set = torch.utils.data.random_split(
        dataset, [train_size, len(dataset) - train_size])

    train_loader = DataLoader(train_set, batch_size=64, shuffle=True)
    val_loader = DataLoader(val_set, batch_size=128)

    model = MalwareClassifier(X.shape[1]).to(device)
    optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-5)
    # The model already applies a sigmoid, hence BCELoss rather than a
    # logits-based loss.
    criterion = nn.BCELoss()

    best_acc = 0
    for epoch in range(epochs):
        model.train()
        total_loss = 0
        for X_batch, y_batch in train_loader:
            X_batch = X_batch.to(device)
            # Labels are (batch,); the model outputs (batch, 1).
            y_batch = y_batch.to(device).unsqueeze(1)

            optimizer.zero_grad()
            outputs = model(X_batch)
            loss = criterion(outputs, y_batch)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for X_val, y_val in val_loader:
                X_val = X_val.to(device)
                y_val = y_val.to(device).unsqueeze(1)
                outputs = model(X_val)
                predicted = (outputs >= 0.5).float()
                correct += (predicted == y_val).sum().item()
                total += y_val.size(0)
        acc = correct / total

        print(f'Epoch {epoch+1:2d} | Loss: {total_loss/len(train_loader):.4f} | Acc: {acc:.4f}')

        # Keep only the checkpoint with the best validation accuracy so far.
        if acc > best_acc:
            best_acc = acc
            torch.save({
                'model': model.state_dict(),
                'mean': dataset.mean,
                'std': dataset.std
            }, model_save)

    print(f'Training complete. Best accuracy: {best_acc:.4f}')

In [5]:
if __name__ == "__main__":
    # NOTE(review): a disabled `mp.set_start_method('fork')` call used to sit
    # here; presumably it is needed on platforms that do not fork by default.
    # Confirm before relying on this cell elsewhere.
    prepare_dataset(num_processes=8)

Processing 2144 files using 8 processes...


100%|██████████| 2144/2144 [01:41<00:00, 21.13it/s]


Dataset saved with 2144 samples


In [9]:
# Train on the arrays written by prepare_dataset (default save_path 'dataset').
train_model(features_path='dataset/features.npy', labels_path='dataset/labels.npy')

Epoch  1 | Loss: 0.6012 | Acc: 0.7040
Epoch  2 | Loss: 0.4733 | Acc: 0.7855
Epoch  3 | Loss: 0.4274 | Acc: 0.8275
Epoch  4 | Loss: 0.3927 | Acc: 0.8438
Epoch  5 | Loss: 0.3492 | Acc: 0.8275
Epoch  6 | Loss: 0.3150 | Acc: 0.8485
Epoch  7 | Loss: 0.2999 | Acc: 0.8462
Epoch  8 | Loss: 0.2850 | Acc: 0.8578
Epoch  9 | Loss: 0.2534 | Acc: 0.8671
Epoch 10 | Loss: 0.2404 | Acc: 0.8671
Epoch 11 | Loss: 0.2305 | Acc: 0.8811
Epoch 12 | Loss: 0.2095 | Acc: 0.8741
Epoch 13 | Loss: 0.1974 | Acc: 0.8858
Epoch 14 | Loss: 0.1828 | Acc: 0.8904
Epoch 15 | Loss: 0.1738 | Acc: 0.8858
Epoch 16 | Loss: 0.1752 | Acc: 0.8881
Epoch 17 | Loss: 0.1541 | Acc: 0.9021
Epoch 18 | Loss: 0.1400 | Acc: 0.8928
Epoch 19 | Loss: 0.1325 | Acc: 0.8928
Epoch 20 | Loss: 0.1272 | Acc: 0.8928
Epoch 21 | Loss: 0.1191 | Acc: 0.8998
Epoch 22 | Loss: 0.1071 | Acc: 0.9044
Epoch 23 | Loss: 0.1059 | Acc: 0.9021
Epoch 24 | Loss: 0.0957 | Acc: 0.9021
Epoch 25 | Loss: 0.0784 | Acc: 0.9044
Epoch 26 | Loss: 0.0772 | Acc: 0.9044
Epoch 27 | L