In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision import transforms
from torch.optim import Adam
import os
import time
from PIL import Image
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, classification_report
import numpy as np

class MalwareDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.image_paths = []
        self.labels = []
        self.classes = sorted(os.listdir(root_dir))
        
        for label, class_name in enumerate(self.classes):
            class_dir = os.path.join(root_dir, class_name)
            for img_name in os.listdir(class_dir):
                img_path = os.path.join(class_dir, img_name)
                self.image_paths.append(img_path)
                self.labels.append(label)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        image = Image.open(img_path).convert("RGB")
        label = self.labels[idx]

        if self.transform:
            image = self.transform(image)

        return image, label

class Patches(nn.Module):
    def __init__(self, patch_size, in_channels=3):
        super(Patches, self).__init__()
        self.patch_size = patch_size
        self.unfold = nn.Unfold(kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        patches = self.unfold(x)
        patches = patches.transpose(1, 2)
        return patches

class PatchEncoder(nn.Module):
    def __init__(self, num_patches, projection_dim):
        super(PatchEncoder, self).__init__()
        self.projection = nn.Linear(16 * 16 * 3, projection_dim)
        self.position_embedding = nn.Parameter(torch.randn(1, num_patches, projection_dim))

    def forward(self, patches):
        encoded = self.projection(patches) + self.position_embedding
        return encoded

class MLP(nn.Module):
    def __init__(self, hidden_units, dropout_rate):
        super(MLP, self).__init__()
        layers = []
        for units in hidden_units:
            layers.append(nn.Linear(units[0], units[1]))
            layers.append(nn.GELU())
            layers.append(nn.Dropout(dropout_rate))
        self.mlp = nn.Sequential(*layers)

    def forward(self, x):
        return self.mlp(x)

class VisionTransformer(nn.Module):
    def __init__(self, image_size, patch_size, num_patches, projection_dim, num_heads, transformer_layers, transformer_units, mlp_head_units, num_classes):
        super(VisionTransformer, self).__init__()
        self.patches = Patches(patch_size=patch_size)
        self.patch_encoder = PatchEncoder(num_patches=num_patches, projection_dim=projection_dim)
        self.transformer_layers = nn.ModuleList([
            nn.TransformerEncoderLayer(
                d_model=projection_dim,
                nhead=num_heads,
                dim_feedforward=transformer_units[0][1],
                dropout=0.1
            )
            for _ in range(transformer_layers)
        ])
        self.mlp_head = MLP(mlp_head_units, dropout_rate=0.5)
        self.classifier = nn.Linear(mlp_head_units[-1][1], num_classes)

    def forward(self, x):
        patches = self.patches(x)
        encoded_patches = self.patch_encoder(patches)

        for transformer_layer in self.transformer_layers:
            encoded_patches = transformer_layer(encoded_patches)

        representation = encoded_patches.mean(dim=1)
        features = self.mlp_head(representation)
        logits = self.classifier(features)

        return logits


image_size = 64
patch_size = 16
num_patches = (image_size // patch_size) ** 2
projection_dim = 64
num_heads = 8
transformer_layers = 6
transformer_units = [(projection_dim, projection_dim * 4)]
mlp_head_units = [(projection_dim, projection_dim * 2), (projection_dim * 2, 512)]
num_classes = 26

transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor()
])

dataset = MalwareDataset(root_dir='malevis', transform=transform)
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = random_split(dataset, [train_size, test_size])

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

model = VisionTransformer(image_size=image_size, patch_size=patch_size, num_patches=num_patches,
                          projection_dim=projection_dim, num_heads=num_heads, transformer_layers=transformer_layers,
                          transformer_units=transformer_units, mlp_head_units=mlp_head_units, num_classes=num_classes)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = Adam(model.parameters(), lr=1e-4)

start_time = time.time()
num_epochs = 10
for epoch in range(num_epochs):
    epoch_start_time = time.time()
    model.train()
    running_loss = 0.0
    all_labels = []
    all_predictions = []
    
    for images, labels in dataloader:
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        
        _, predicted = torch.max(outputs.data, 1)
        all_labels.extend(labels.cpu().numpy())
        all_predictions.extend(predicted.cpu().numpy())

    epoch_loss = running_loss / len(train_loader)
    epoch_time = time.time() - epoch_start_time
    print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {running_loss / len(train_loader):.4f}, Time elapsed: {epoch_time:.2f} seconds")

total_time = time.time() - start_time
print(f"Training completed in: {total_time:.2f} seconds")

model.eval()
all_labels = []
all_predictions = []

with torch.no_grad():
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)

        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        all_labels.extend(labels.cpu().numpy())
        all_predictions.extend(predicted.cpu().numpy())

all_labels = np.array(all_labels)
all_predictions = np.array(all_predictions)

accuracy = accuracy_score(all_labels, all_predictions)
precision, recall, f1, _ = precision_recall_fscore_support(all_labels, all_predictions, average='weighted')

print(f'Accuracy: {accuracy:.4f}')
print(f'Precision: {precision:.4f}')
print(f'Recall: {recall:.4f}')
print(f'F1 Score: {f1:.4f}')

print('\nClassification Report:')
print(classification_report(all_labels, all_predictions, target_names=[f'Class {i}' for i in range(num_classes)]))

torch.save(model.state_dict(), 'vit_malware_classifier_64x64_rgb.pth')

Epoch [1/10], Loss: 4.0690, Time elapsed: 80.04 seconds
Epoch [2/10], Loss: 3.6938, Time elapsed: 77.67 seconds
Epoch [3/10], Loss: 3.3359, Time elapsed: 68.95 seconds
Epoch [4/10], Loss: 2.9835, Time elapsed: 73.57 seconds
Epoch [5/10], Loss: 2.7485, Time elapsed: 67.27 seconds
Epoch [6/10], Loss: 2.5376, Time elapsed: 86.94 seconds
Epoch [7/10], Loss: 2.3709, Time elapsed: 88.88 seconds
Epoch [8/10], Loss: 2.2143, Time elapsed: 87.46 seconds
Epoch [9/10], Loss: 2.0921, Time elapsed: 86.03 seconds
Epoch [10/10], Loss: 1.9662, Time elapsed: 83.33 seconds
Training completed in: 800.17 seconds
Accuracy: 0.6016
Precision: 0.5411
Recall: 0.6016
F1 Score: 0.5365

Classification Report:
              precision    recall  f1-score   support

     Class 0       0.82      1.00      0.90        67
     Class 1       0.00      0.00      0.00        60
     Class 2       0.43      0.50      0.46        80
     Class 3       0.58      0.88      0.70        74
     Class 4       0.33      0.12      

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
