In [1]:
import torch
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision import transforms
import os
import time
from PIL import Image
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, classification_report
from sklearn.ensemble import RandomForestClassifier
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import Pipeline
import numpy as np
import joblib

class MalwareDataset(Dataset):
    """Image dataset laid out as ``<root>/<class_name>/<image file>``.

    Class names are the sub-directories of the first root, sorted
    alphabetically; a sample's label is the index of its class in that list.
    The same class list is applied to every root in ``root_dirs``.

    Args:
        root_dirs: Non-empty sequence of dataset root directories. The first
            one defines the class list (an empty sequence raises
            ``IndexError``).
        transform: Optional callable applied to each RGB PIL image before it
            is returned by ``__getitem__``.
    """

    def __init__(self, root_dirs, transform=None):
        self.root_dirs = root_dirs
        self.transform = transform
        self.image_paths = []
        self.labels = []

        # Only directories are classes. A stray file in the root (README,
        # .DS_Store, ...) would otherwise be treated as a class and make the
        # per-class listing below fail.
        first_root = root_dirs[0]
        self.classes = sorted(
            entry
            for entry in os.listdir(first_root)
            if os.path.isdir(os.path.join(first_root, entry))
        )

        for root_dir in root_dirs:
            for label, class_name in enumerate(self.classes):
                class_dir = os.path.join(root_dir, class_name)
                # A secondary root is allowed to lack some classes.
                if not os.path.isdir(class_dir):
                    continue
                # Sorted so sample order does not depend on the filesystem.
                for img_name in sorted(os.listdir(class_dir)):
                    img_path = os.path.join(class_dir, img_name)
                    # Skip nested directories; only regular files are samples.
                    if not os.path.isfile(img_path):
                        continue
                    self.image_paths.append(img_path)
                    self.labels.append(label)

    def __len__(self):
        """Return the number of indexed image files."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, label)``.

        The image is converted to RGB and passed through ``transform`` when
        one was given; ``label`` is the integer class index.
        """
        img_path = self.image_paths[idx]
        # convert() returns an independent copy, so the source file can be
        # closed as soon as the conversion is done.
        with Image.open(img_path) as img:
            image = img.convert("RGB")
        label = self.labels[idx]

        if self.transform:
            image = self.transform(image)

        return image, label

# Side length, in pixels, that every image is resized to before flattening.
image_size = 64
# Number of class labels used when building the classification report below;
# expected to match the number of class folders in the dataset root.
num_classes = 26

# Resize to a fixed square and convert to a float tensor.
transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor()
])

root_dirs = ['malevis']
dataset = MalwareDataset(root_dirs=root_dirs, transform=transform)

# 80/20 train/test split. The generator is seeded so the split, and therefore
# the reported metrics, are reproducible across runs, matching the fixed
# random_state used for the classifier.
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = random_split(
    dataset,
    [train_size, test_size],
    generator=torch.Generator().manual_seed(42),
)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

def flatten_images(loader):
    """Collect every batch from ``loader`` into flat NumPy arrays.

    Args:
        loader: Iterable yielding ``(images, labels)`` pairs of tensors, where
            the first dimension of ``images`` is the batch dimension.

    Returns:
        A ``(features, labels)`` tuple: ``features`` has one row per image
        with all remaining dimensions flattened, and ``labels`` is the 1-D
        concatenation of the per-batch labels.

    Raises:
        ValueError: If ``loader`` yields no batches (raised by ``np.vstack``).
    """
    flattened_images = []
    labels = []
    for images, lbls in loader:
        # reshape (unlike view) also accepts non-contiguous batches, and is
        # equivalent to view for contiguous ones.
        flattened_images.append(images.reshape(images.size(0), -1).numpy())
        labels.append(lbls.numpy())
    return np.vstack(flattened_images), np.hstack(labels)

# Materialise both splits as flat feature matrices for scikit-learn.
X_train, y_train = flatten_images(train_loader)
X_test, y_test = flatten_images(test_loader)

# Standardise the pixel features, then fit a random forest with a fixed seed.
model = Pipeline([
    ('scaler', StandardScaler()),
    ('rf', RandomForestClassifier(n_estimators=100, random_state=42))
])

# Time only the fit call; feature extraction above is not included.
start_time = time.time()
model.fit(X_train, y_train)
training_time = time.time() - start_time

print(f"Training completed in: {training_time:.2f} seconds")

y_pred = model.predict(X_test)

accuracy = accuracy_score(y_test, y_pred)
precision, recall, f1, _ = precision_recall_fscore_support(y_test, y_pred, average='weighted')

print(f'Accuracy: {accuracy:.4f}')
print(f'Precision: {precision:.4f}')
print(f'Recall: {recall:.4f}')
print(f'F1 Score: {f1:.4f}')

print('\nClassification Report:')
# Pass labels explicitly so each target name stays bound to its class index
# even when some class is absent from the test split; without it the report
# raises or warns about a label/target_names mismatch.
report_labels = list(range(num_classes))
print(classification_report(
    y_test,
    y_pred,
    labels=report_labels,
    target_names=[f'Class {i}' for i in report_labels],
))

# Persist the whole fitted pipeline (scaler + forest).
joblib.dump(model, 'rf_malware_classifier_64x64.pkl')

Training completed in: 49.37 seconds
Accuracy: 0.9286
Precision: 0.9364
Recall: 0.9286
F1 Score: 0.9303

Classification Report:
              precision    recall  f1-score   support

     Class 0       1.00      1.00      1.00        73
     Class 1       0.93      0.82      0.87        67
     Class 2       0.94      0.83      0.88        76
     Class 3       1.00      0.92      0.96        73
     Class 4       0.87      0.84      0.85        63
     Class 5       0.90      0.80      0.85        59
     Class 6       0.94      1.00      0.97        73
     Class 7       1.00      1.00      1.00        72
     Class 8       0.98      0.93      0.95        56
     Class 9       0.59      0.86      0.70        64
    Class 10       1.00      1.00      1.00        72
    Class 11       1.00      1.00      1.00        74
    Class 12       1.00      1.00      1.00        73
    Class 13       0.95      0.76      0.85        76
    Class 14       0.98      1.00      0.99        63
    Cla

['rf_malware_classifier_64x64.pkl']