In [1]:
import torch
import torch.nn as nn
import snntorch as snn
from snntorch import functional as sf

import torchvision
from torchvision import transforms
from torch.utils.data import Subset
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
import os

# Image preprocessing: convert to a single channel, resize to 28x28 and turn
# into a tensor. The pipeline is written for the malware-image dataset that is
# meant to replace the MNIST placeholder loaded below.
transform = transforms.Compose([
    transforms.Grayscale(),
    transforms.Resize((28, 28)),
    transforms.ToTensor(),
])

# Root folder for downloaded data; override with the HACKATHON_DATA_ROOT
# environment variable instead of editing a machine-specific absolute path.
DATA_ROOT = os.environ.get('HACKATHON_DATA_ROOT', 'C:\\Users\\DELL\\Documents\\hackathon')

# Only these labels are kept, turning MNIST into a two-class (0 vs 1) task.
BINARY_CLASSES = (0, 1)


def _indices_with_labels(dataset, classes):
    """Return the indices of samples in ``dataset`` whose label is in ``classes``.

    Labels are read from ``dataset.targets`` directly, so no image is loaded or
    passed through ``transform``. Iterating the dataset itself (``for img, label
    in dataset``) would do both for every single sample just to read its label.
    """
    wanted = set(classes)
    return [i for i, label in enumerate(dataset.targets.tolist()) if label in wanted]


# Example: load MNIST as a placeholder dataset.
train_dataset = torchvision.datasets.MNIST(root=os.path.join(DATA_ROOT, 'train'), train=True, download=True, transform=transform)
test_dataset = torchvision.datasets.MNIST(root=os.path.join(DATA_ROOT, 'test'), train=False, download=True, transform=transform)

binary_indices_test = _indices_with_labels(test_dataset, BINARY_CLASSES)
test_dataset_binary = Subset(test_dataset, binary_indices_test)

binary_indices = _indices_with_labels(train_dataset, BINARY_CLASSES)
train_dataset_binary = Subset(train_dataset, binary_indices)

train_loader = DataLoader(train_dataset_binary, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset_binary, batch_size=64, shuffle=False)

# Define SNN Model
class SNNModel(nn.Module):
    """Two-layer fully connected spiking network of leaky integrate-and-fire neurons.

    Layout: flatten -> Linear(num_inputs, num_hidden) -> snn.Leaky
    -> Linear(num_hidden, num_outputs) -> snn.Leaky.

    The defaults reproduce the original 784-128-2 network, so parameter names
    and shapes (and therefore saved ``state_dict`` files) are unchanged.

    Args:
        num_inputs: flattened input size (28*28 for the 28x28 images used here).
        num_hidden: width of the hidden spiking layer.
        num_outputs: number of output neurons, one per class.
        beta: decay factor passed to both ``snn.Leaky`` layers.
    """

    def __init__(self, num_inputs=28 * 28, num_hidden=128, num_outputs=2, beta=0.9):
        super().__init__()
        self.fc1 = nn.Linear(num_inputs, num_hidden)
        self.lif1 = snn.Leaky(beta=beta)
        self.fc2 = nn.Linear(num_hidden, num_outputs)
        self.lif2 = snn.Leaky(beta=beta)

    def forward(self, x, num_steps=25):
        """Simulate the network for ``num_steps`` steps on a static input batch.

        Args:
            x: input batch; every dimension after the first is flattened.
            num_steps: number of simulation time steps.

        Returns:
            Output-layer spikes stacked along a new leading time axis,
            i.e. ``[num_steps, batch, num_outputs]``.
        """
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()

        # reshape (unlike view) also accepts non-contiguous inputs.
        flat = x.reshape(x.size(0), -1)
        # The same input is presented at every step, so its projection is
        # loop-invariant: compute it once rather than num_steps times.
        cur1 = self.fc1(flat)

        spk2_rec = []
        for _ in range(num_steps):
            spk1, mem1 = self.lif1(cur1, mem1)
            spk2, mem2 = self.lif2(self.fc2(spk1), mem2)
            spk2_rec.append(spk2)

        return torch.stack(spk2_rec)

# Train the model (first pass) with snnTorch's spike-count MSE loss.
# Seed the global torch RNG so weight initialisation and DataLoader shuffling
# are reproducible across Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SNNModel().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# The loss is called below with the integer class labels directly; no one-hot
# tensor is needed (the previous version built one and never used it).
loss_fn = sf.mse_count_loss()

num_epochs = 3
model.train()
for epoch in range(num_epochs):
    total_loss = 0.0
    for data, targets in train_loader:
        data, targets = data.to(device), targets.to(device)

        spk_rec = model(data)  # output spikes over time
        loss = loss_fn(spk_rec, targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    print(f"Epoch {epoch+1}, Loss: {total_loss/len(train_loader)}")

# Save Model
torch.save(model.state_dict(), "malware_snn1.pth")



Epoch 1, Loss: 0.3020264855491919
Epoch 2, Loss: 0.027925693454700664
Epoch 3, Loss: 0.0169426445846888


In [2]:
def train(model, train_loader, criterion, optimizer, device):
    """Run one optimisation epoch over ``train_loader``.

    The model's output spikes (``[time, batch, output]``) are summed over the
    time axis. The resulting per-class spike counts are passed to ``criterion``
    together with the integer targets. The predicted class is the index of the
    largest count.

    Returns:
        Tuple ``(mean loss per batch, accuracy in percent)`` for the epoch.
    """
    model.train()
    loss_sum, n_correct, n_seen = 0.0, 0, 0

    for inputs, labels in train_loader:
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        # Collapse the time axis: [time, batch, output] -> [batch, output].
        spike_counts = model(inputs).sum(dim=0)
        batch_loss = criterion(spike_counts, labels)
        batch_loss.backward()
        optimizer.step()

        loss_sum += batch_loss.item()
        predictions = spike_counts.max(1)[1]
        n_correct += (predictions == labels).sum().item()
        n_seen += labels.size(0)

    mean_loss = loss_sum / len(train_loader)
    accuracy = 100. * n_correct / n_seen
    return mean_loss, accuracy

def validate(model, val_loader, criterion, device):
    """Evaluate ``model`` on ``val_loader`` with gradients disabled.

    The model's output spikes are summed over time (dim 0). The resulting
    per-class spike counts are scored with ``criterion``, and the class with the
    largest count is taken as the prediction.

    Returns:
        Tuple ``(mean loss per batch, accuracy in percent)``.
    """
    model.eval()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            spike_counts = model(inputs).sum(dim=0)
            loss_sum += criterion(spike_counts, labels).item()

            predictions = spike_counts.max(1)[1]
            n_correct += (predictions == labels).sum().item()
            n_seen += labels.size(0)

    return loss_sum / len(val_loader), 100. * n_correct / n_seen

# Continue training the model and optimizer from the first cell, now with
# cross-entropy on per-class spike counts, and report metrics every epoch.
# NOTE(review): the test split is also used for per-epoch monitoring, so there
# is no untouched hold-out set; consider carving a validation split from train.
criterion = nn.CrossEntropyLoss()
NUM_FINETUNE_EPOCHS = 10
for epoch in range(NUM_FINETUNE_EPOCHS):
    train_loss, train_acc = train(model, train_loader, criterion, optimizer, device)
    test_loss, test_acc = validate(model, test_loader, criterion, device)

    print(f"Epoch [{epoch+1}/{NUM_FINETUNE_EPOCHS}]")
    print(f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%")
    print(f"Test Loss:   {test_loss:.4f} | Test Acc:   {test_acc:.2f}%\n")



Epoch [1/10]
Train Loss: 0.0043 | Train Acc: 99.95%
Test Loss:   0.0110 | Test Acc:   99.95%

Epoch [2/10]
Train Loss: 0.0136 | Train Acc: 99.87%
Test Loss:   0.0142 | Test Acc:   99.91%

Epoch [3/10]
Train Loss: 0.0086 | Train Acc: 99.91%
Test Loss:   0.0156 | Test Acc:   99.91%

Epoch [4/10]
Train Loss: 0.0028 | Train Acc: 99.98%
Test Loss:   0.0104 | Test Acc:   99.91%

Epoch [5/10]
Train Loss: 0.0023 | Train Acc: 99.97%
Test Loss:   0.0089 | Test Acc:   99.95%

Epoch [6/10]
Train Loss: 0.0006 | Train Acc: 99.99%
Test Loss:   0.0093 | Test Acc:   99.95%

Epoch [7/10]
Train Loss: 0.0011 | Train Acc: 99.97%
Test Loss:   0.0031 | Test Acc:   99.91%

Epoch [8/10]
Train Loss: 0.0004 | Train Acc: 99.98%
Test Loss:   0.0057 | Test Acc:   99.86%

Epoch [9/10]
Train Loss: 0.0001 | Train Acc: 99.99%
Test Loss:   0.0045 | Test Acc:   99.86%

Epoch [10/10]
Train Loss: 0.0001 | Train Acc: 99.99%
Test Loss:   0.0050 | Test Acc:   99.86%

