In [8]:
import torch
from utils import set_seed
import os
import pickle
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Dataset, Subset
from torch import nn
from torch.nn import functional as F
import random
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from sklearn.metrics import auc, confusion_matrix, roc_curve
from torch import optim
from torch.nn.functional import sigmoid
from tqdm.auto import tqdm
from sklearn import metrics

# Reset matplotlib's rcParams to the library defaults before any figure is drawn.
plt.rcParams.update(plt.rcParamsDefault)
# NOTE: at this point `set_seed` is the function imported from `utils` at the
# top of this cell; the notebook's own `set_seed` is only defined in a later
# cell and would shadow the imported one once that cell has been executed.
set_seed(42)
# Device handle passed to the training and evaluation helpers below.
device = torch.device("cpu")

# NOTE(review): this suppresses every warning for the rest of the session,
# including deprecation and numerical warnings -- consider a narrower filter.
import warnings
warnings.filterwarnings('ignore')

## Dataset

In [9]:
class MalwareDataset(Dataset):
    """Dataset over two directories of pickled samples: benign and malware.

    The file names of each directory are listed once at construction time and
    sorted. Indices ``0 .. len(benign_files) - 1`` address the benign files
    (label ``0.0``) and the following ``len(malware_files)`` indices address
    the malware files (label ``1.0``). Negative indices count from the end of
    the combined dataset, as for Python sequences.

    Args:
        benign_dir: Directory holding the benign samples.
        malware_dir: Directory holding the malware samples.
    """

    def __init__(self, benign_dir="data/benign", malware_dir="data/malware"):
        self.benign_dir = benign_dir
        self.malware_dir = malware_dir
        self.benign_files = sorted(os.listdir(benign_dir))
        self.malware_files = sorted(os.listdir(malware_dir))

    def __getitem__(self, index):
        """Return ``(tensor, label)`` for the sample at ``index``.

        Raises:
            IndexError: If ``index`` lies outside ``[-len(self), len(self))``.
        """
        total = len(self)
        position = index + total if index < 0 else index
        if not 0 <= position < total:
            raise IndexError(
                f"index {index} is out of range for dataset of size {total}"
            )

        # Decide the class by explicit arithmetic. Relying on an IndexError
        # from `benign_files[index]` silently maps negative indices to benign
        # files, so e.g. `dataset[-1]` would be returned with the wrong label.
        num_benign = len(self.benign_files)
        if position < num_benign:
            file_dir = os.path.join(self.benign_dir, self.benign_files[position])
            label = 0.0
        else:
            file_dir = os.path.join(
                self.malware_dir, self.malware_files[position - num_benign]
            )
            label = 1.0

        # NOTE(security): pickle.load can execute arbitrary code while
        # deserialising; only point this dataset at files from a trusted source.
        with open(file_dir, "rb") as f:
            file_ = torch.tensor(pickle.load(f))
        return file_, label

    def __len__(self):
        """Total number of samples (benign + malware)."""
        return len(self.benign_files) + len(self.malware_files)

class UniLabelDataset(Dataset):
    """Dataset over a single directory whose samples all share one label.

    Args:
        data_dir: Directory containing the pickled samples.
        is_malware: Truth value converted to the float label returned with
            every sample (``float(is_malware)``).
    """

    def __init__(self, data_dir, is_malware):
        self.data_dir = data_dir
        self.is_malware = is_malware
        # Sort the listing so that index -> file is stable between runs.
        listing = list(os.listdir(data_dir))
        listing.sort()
        self.files = listing

    def __len__(self):
        """Number of files found in ``data_dir``."""
        return len(self.files)

    def __getitem__(self, index):
        """Load the ``index``-th file and return ``(tensor, label)``."""
        path = os.path.join(self.data_dir, self.files[index])
        with open(path, "rb") as handle:
            payload = pickle.load(handle)
        sample = torch.tensor(payload)
        label = float(self.is_malware)
        return sample, label


def collate_fn(batch):
    """Collate ``(sequence, label)`` items into a padded batch.

    Every sequence is padded or truncated to 4096 entries using the value 256
    as padding; the labels are gathered into a single tensor.

    Returns:
        Tuple ``(xs, ys)`` of the padded sequences and the label tensor.
    """
    sequences = []
    labels = []
    for item in batch:
        sequences.append(item[0])
        labels.append(item[1])
    xs = pad_sequence(sequences, max_len=4096, padding_value=256)
    ys = torch.tensor(labels)
    return xs, ys


def pad_sequence(sequences, max_len=None, padding_value=0):
    """Stack sequences into one ``(batch, max_len)`` tensor.

    Sequences shorter than ``max_len`` are right-padded with ``padding_value``;
    longer ones are truncated to their first ``max_len`` entries.

    Args:
        sequences: Non-empty list of tensors; only ``size(0)`` of each is used
            as its length.
        max_len: Target length. When ``None`` the longest sequence decides.
        padding_value: Fill value for the unused tail of each row.

    Returns:
        Tensor created with ``sequences[0].new_full``, so it shares the dtype
        and device of the first sequence.
    """
    if max_len is None:
        max_len = max(seq.size(0) for seq in sequences)
    padded = sequences[0].new_full((len(sequences), max_len), padding_value)
    for row, seq in enumerate(sequences):
        # Copy as much of the sequence as fits into the row.
        keep = min(seq.size(0), max_len)
        padded[row, :keep] = seq[:keep]
    return padded


def train_val_test_split(idx, val_size, test_size):
    """Split ``idx`` into train, validation and test parts.

    Two shuffled ``train_test_split`` calls are chained: the first carves out
    the test part, the second carves the validation part out of what is left.
    ``val_size`` is therefore applied to the non-test remainder, not to the
    whole of ``idx``.

    Returns:
        Tuple ``(train_idx, val_idx, test_idx)``.
    """
    remainder, test_part = train_test_split(idx, test_size=test_size, shuffle=True)
    train_part, val_part = train_test_split(
        remainder, test_size=val_size, shuffle=True
    )
    return train_part, val_part, test_part


def make_idx(dataset, val_size, test_size):
    """Build train/val/test index lists, splitting each class separately.

    Benign samples occupy indices ``[0, num_benign)`` and malware samples the
    ``num_malware`` indices after that. Each range is split on its own with
    ``train_val_test_split`` and the matching parts are joined with ``+``.

    NOTE(review): joining with ``+`` presumes the split returns lists (so that
    it concatenates) -- confirm if the index type is ever changed.

    Returns:
        Tuple ``(train_idx, val_idx, test_idx)``.
    """
    benign_count = len(dataset.benign_files)
    malware_count = len(dataset.malware_files)
    benign_parts = train_val_test_split(range(benign_count), val_size, test_size)
    malware_parts = train_val_test_split(
        range(benign_count, benign_count + malware_count), val_size, test_size
    )
    # Pair up (train, val, test) of both classes and join each pair.
    train_idx, val_idx, test_idx = (
        benign_part + malware_part
        for benign_part, malware_part in zip(benign_parts, malware_parts)
    )
    return train_idx, val_idx, test_idx


def make_loaders(batch_size, val_size, test_size):
    """Create train, validation and test loaders over a ``MalwareDataset``.

    The dataset is built with its default directories, split into index lists
    with ``make_idx`` and each split is wrapped in a ``Subset`` and a loader.

    Returns:
        Tuple ``(train_loader, val_loader, test_loader)``.
    """
    dataset = MalwareDataset()
    index_splits = make_idx(dataset, val_size, test_size)
    subsets = [Subset(dataset, indices=split) for split in index_splits]
    train_loader, val_loader, test_loader = (
        make_loader(subset, batch_size) for subset in subsets
    )
    return train_loader, val_loader, test_loader


def make_loader(dataset, batch_size):
    """Wrap ``dataset`` in a shuffling ``DataLoader`` that uses ``collate_fn``."""
    loader_options = dict(batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
    return DataLoader(dataset, **loader_options)

## Models

In [10]:
class MalConv(nn.Module):
    """Gated-convolution classifier over sequences of token ids.

    Pipeline: embedding -> dropout -> strided 1-D convolution producing
    ``2 * out_channels`` channels -> GLU over the channel axis -> max over the
    sequence axis -> linear layer producing one logit per sample.

    Args:
        embed_dim: Size of each embedding vector.
        max_len: Accepted for interface compatibility; not used by the module.
        out_channels: Channels left after the GLU gate.
        window_size: Kernel size and stride of the convolution.
        dropout: Dropout probability applied to the embeddings.
    """

    def __init__(self, embed_dim, max_len, out_channels, window_size, dropout=0.5):
        super().__init__()
        # 257 embedding rows: token ids 0..256.
        self.embed = nn.Embedding(257, embed_dim)
        self.dropout = nn.Dropout(dropout)
        # Twice the channels because GLU halves them again.
        self.conv = nn.Conv1d(
            embed_dim, out_channels * 2, kernel_size=window_size, stride=window_size
        )
        self.fc = nn.Linear(out_channels, 1)

    def forward(self, x):
        """Map a ``(batch, length)`` id tensor to ``(batch,)`` logits."""
        embedded = self.dropout(self.embed(x))
        # Conv1d wants channels before the sequence axis.
        gated = F.glu(self.conv(embedded.transpose(1, 2)), dim=1)
        pooled = gated.max(dim=-1)[0]
        return self.fc(pooled).squeeze(1)
    
class Conv_RNN_Custom(nn.Module):
    """Convolution followed by a recurrent layer, producing one logit per sample.

    Pipeline: embedding -> dropout -> strided 1-D convolution -> recurrent
    layer over the convolved sequence -> output of the last time step
    (optionally concatenated with the max-pooled convolution features) ->
    linear layer.

    Args:
        embed_dim: Size of each embedding vector.
        out_channels: Output channels of the convolution (RNN input size).
        window_size: Kernel size and stride of the convolution.
        module: Recurrent layer class; its ``__name__`` must be ``RNN``,
            ``GRU`` or ``LSTM``.
        hidden_size: Hidden size of the recurrent layer.
        num_layers: Number of stacked recurrent layers.
        bidirectional: Whether the recurrent layer is bidirectional.
        residual: If true, the max over time of the convolution output is
            concatenated to the recurrent features before the linear layer.
        dropout: Dropout probability applied to the embeddings only.
    """

    def __init__(
        self,
        embed_dim,
        out_channels,
        window_size,
        module,
        hidden_size,
        num_layers,
        bidirectional,
        residual,
        dropout=0.5,
    ):
        super().__init__()
        allowed_modules = {"RNN", "GRU", "LSTM"}
        assert (
            module.__name__ in allowed_modules
        ), "`module` must be a `torch.nn` recurrent layer"
        self.residual = residual
        # 257 embedding rows: token ids 0..256.
        self.embed = nn.Embedding(257, embed_dim)
        self.conv = nn.Conv1d(
            embed_dim, out_channels, kernel_size=window_size, stride=window_size
        )
        self.rnn = module(
            input_size=out_channels,
            hidden_size=hidden_size,
            num_layers=num_layers,
            bidirectional=bidirectional,
        )
        self.dropout = nn.Dropout(dropout)
        # One block of `hidden_size` features per direction.
        num_directions = 1 + int(bidirectional)
        fc_in_features = num_directions * hidden_size
        if residual:
            fc_in_features = out_channels + fc_in_features
        self.fc = nn.Linear(fc_in_features, 1)

    def forward(self, x):
        """Map a ``(batch, length)`` id tensor to ``(batch,)`` logits."""
        embedded = self.dropout(self.embed(x))
        # Conv1d wants (batch, channels, time).
        features = self.conv(embedded.permute(0, 2, 1))
        pooled = features.max(dim=-1)[0] if self.residual else None
        # The recurrent layer is not batch_first, so feed (time, batch, channels).
        rnn_out, _ = self.rnn(features.permute(2, 0, 1))
        # NOTE(review): with a bidirectional layer the last time step holds the
        # backward direction after a single step only -- confirm this is intended.
        last_step = rnn_out[-1]
        if self.residual:
            fc_in = torch.cat((last_step, pooled), dim=-1)
        else:
            fc_in = last_step
        return self.fc(fc_in).squeeze(1)

## Utility Functions

In [11]:
def set_seed(seed):
    """Seed every random number generator used by the notebook.

    Seeds Python's ``random``, NumPy's global generator and PyTorch, exports
    ``PYTHONHASHSEED`` and configures cuDNN for reproducible execution.

    Args:
        seed: Integer seed shared by all generators.
    """
    random.seed(seed)
    # Only affects interpreters started after this point (e.g. worker processes).
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    # The cuDNN benchmark mode auto-tunes the convolution algorithm per run and
    # may select a different one each time, which defeats `deterministic`.
    torch.backends.cudnn.benchmark = False

    
def count_params(model, trainable_only=True):
    """Count the scalar parameters of ``model``.

    Args:
        model: Module whose ``parameters()`` are counted.
        trainable_only: When true, only parameters with ``requires_grad`` set
            are counted; otherwise all of them.

    Returns:
        Total number of elements as an ``int``.
    """
    total = 0
    for param in model.parameters():
        if trainable_only and not param.requires_grad:
            continue
        total += param.numel()
    return total


def plot_confusion_matrix(model, test_loader, save_title, device, normalize="all"):
    """Save a heat-map of the model's confusion matrix on ``test_loader``.

    The figure is written to ``imgs/{save_title}_conf_mat.png``; the directory
    is created when it does not exist.

    Args:
        model: Model handed to ``predict`` to obtain hard 0/1 predictions.
        test_loader: Loader yielding ``(inputs, labels)`` batches.
        save_title: Prefix of the output file name.
        device: Device the inputs are moved to for inference.
        normalize: Forwarded to ``sklearn.metrics.confusion_matrix``.
    """
    y_true, y_pred = predict(model, test_loader, device)
    # Fix the label set so the matrix is always 2x2, even when one class is
    # absent from both the targets and the predictions.
    conf_mat = confusion_matrix(y_true, y_pred, labels=[0, 1], normalize=normalize)
    axis_labels = ("Benign", "Malware")
    df = pd.DataFrame(conf_mat, index=axis_labels, columns=axis_labels)
    # Draw on a dedicated figure instead of whichever axes happen to be current.
    fig, ax = plt.subplots()
    try:
        sns.heatmap(df, annot=True, cmap="Blues", ax=ax)
        # Rows of a scikit-learn confusion matrix are true classes, columns
        # are predicted classes.
        ax.set_xlabel("Predicted label")
        ax.set_ylabel("True label")
        out_dir = "imgs"
        os.makedirs(out_dir, exist_ok=True)
        fig.savefig(os.path.join(out_dir, f"{save_title}_conf_mat.png"), dpi=300)
    finally:
        plt.close(fig)


def plot_roc_curve(models, test_loader, save_title, device):
    """Plot ROC curves for one or several models and save the figure.

    The figure is written to ``imgs/{save_title}_roc.png``; the directory is
    created when it does not exist.

    Args:
        models: Either a single model or a dict mapping legend labels to
            models. A single model is labelled with ``save_title``.
        test_loader: Loader yielding ``(inputs, labels)`` batches.
        save_title: Prefix of the output file name.
        device: Device the inputs are moved to for inference.
    """
    # Treat a single model as a one-entry mapping so both cases share one loop.
    labelled_models = models if isinstance(models, dict) else {save_title: models}
    fig, ax = plt.subplots()
    try:
        ax.grid(linestyle="--")
        ax.set_xlabel("False Positive Rate")
        ax.set_ylabel("True Positive Rate")
        for label, model in labelled_models.items():
            fpr, tpr, auc_score = _rates_auc(model, test_loader, device)
            ax.plot(fpr, tpr, label=f"{label} ({auc_score:.2f})")
        # Diagonal reference line of a random classifier.
        ax.plot([0, 1], [0, 1], linestyle="--", label="Chance (0.5)")
        ax.legend(loc="best")
        out_dir = "imgs"
        os.makedirs(out_dir, exist_ok=True)
        fig.savefig(os.path.join(out_dir, f"{save_title}_roc.png"), dpi=300)
    finally:
        # Release the figure even when inference or saving fails.
        plt.close(fig)


def _rates_auc(model, test_loader, device):
    """Return ``(fpr, tpr, auc_score)`` of ``model`` on ``test_loader``.

    Scores are the sigmoid of the model outputs as produced by ``predict``.
    """
    targets, scores = predict(model, test_loader, device, apply_sigmoid=True)
    false_pos_rate, true_pos_rate, _thresholds = roc_curve(targets, scores)
    return false_pos_rate, true_pos_rate, auc(false_pos_rate, true_pos_rate)


@torch.no_grad()
def predict(model, data_loader, device, apply_sigmoid=False, to_numpy=True):
    """Run ``model`` over ``data_loader`` and collect targets and predictions.

    The model is evaluated in eval mode without gradients; afterwards it is
    put back into whichever mode (train/eval) it was in before the call.

    Args:
        model: Module returning one logit per sample.
        data_loader: Iterable of ``(inputs, labels)`` batches.
        device: Device the inputs are moved to before the forward pass.
        apply_sigmoid: If true, return sigmoid scores; otherwise return hard
            0/1 predictions obtained by thresholding the logits at 0.
        to_numpy: If true, return NumPy arrays; otherwise CPU tensors.

    Returns:
        Tuple ``(y_true, y_pred)`` of equal shape; ``y_true`` is integer-typed.
    """
    was_training = model.training
    model.eval()
    y_true = []
    y_pred = []
    try:
        for inputs, labels in tqdm(data_loader, leave=False):
            outputs = model(inputs.to(device))
            # Gather everything on the CPU so targets and predictions live on
            # the same device regardless of where inference ran.
            y_true.append(labels.cpu())
            y_pred.append(outputs.cpu())
    finally:
        # Restore the caller's mode instead of forcing train mode.
        model.train(was_training)
    y_true = torch.cat(y_true).to(int)
    if apply_sigmoid:
        y_pred = sigmoid(torch.cat(y_pred))
    else:
        # logit > 0 is the same decision as sigmoid(logit) > 0.5
        y_pred = (torch.cat(y_pred) > 0).to(int)
    if to_numpy:
        y_true = y_true.numpy()
        y_pred = y_pred.numpy()
    assert y_true.shape == y_pred.shape
    return y_true, y_pred


def get_accuracy(model, data_loader, device):
    """Return the accuracy of ``model`` on ``data_loader`` as a percentage."""
    targets, predictions = predict(model, data_loader, device, to_numpy=False)
    hits = (targets == predictions).to(float)
    return 100 * hits.mean().item()


def plot_train_history(train_loss_history, val_loss_history, save_title):
    """Plot training and validation loss per epoch and save the figure.

    The figure is written to ``figures/{save_title}_train_history.png``; the
    directory is created when it does not exist.

    Args:
        train_loss_history: Sequence of per-epoch training losses.
        val_loss_history: Sequence of per-epoch validation losses, same length.
        save_title: Prefix of the output file name.
    """
    fig, ax = plt.subplots()
    try:
        time_ = range(len(train_loss_history))
        ax.set_xlabel("Epochs")
        ax.set_ylabel("BCE Loss")
        ax.grid(linestyle="--")
        ax.plot(time_, train_loss_history, color="blue", label="train loss")
        ax.plot(time_, val_loss_history, color="red", label="val loss")
        ax.legend(loc="best")
        out_dir = "figures"
        os.makedirs(out_dir, exist_ok=True)
        fig.savefig(os.path.join(out_dir, f"{save_title}_train_history.png"), dpi=300)
    finally:
        # Release the figure even when plotting or saving fails.
        plt.close(fig)


def train(
    model,
    train_loader,
    val_loader,
    device,
    save_title,
    lr=0.001,
    patience=3,
    num_epochs=5,
    verbose=True,
):
    """Train ``model`` with Adam and BCE-with-logits loss.

    After each epoch the validation loss drives a ``ReduceLROnPlateau``
    scheduler and an ``EarlyStopMonitor`` (both configured with ``patience``).
    Whenever the validation loss is the lowest seen so far the weights are
    written to ``checkpoints/{save_title}.pt``, so that file always holds the
    best epoch. The loss curves are plotted at the end.

    Args:
        model: Module returning one logit per sample.
        train_loader: Loader of training batches.
        val_loader: Loader of validation batches.
        device: Device the batches are moved to.
        save_title: Prefix for the checkpoint and the history plot.
        lr: Adam learning rate.
        patience: Patience of both the scheduler and the early-stop monitor.
        num_epochs: Maximum number of epochs.
        verbose: If true, print the losses after every epoch.
    """
    train_loss_history = []
    val_loss_history = []
    criterion = torch.nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    monitor = EarlyStopMonitor(patience)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, factor=0.5, patience=patience
    )
    checkpoint_dir = "checkpoints"
    os.makedirs(checkpoint_dir, exist_ok=True)
    checkpoint_path = os.path.join(checkpoint_dir, f"{save_title}.pt")
    best_val_loss = None
    for epoch in range(1, num_epochs + 1):
        model.train()
        train_loss = run_epoch(model, train_loader, device, criterion, optimizer)
        train_loss_history.append(train_loss)
        model.eval()
        with torch.no_grad():
            val_loss = run_epoch(model, val_loader, device, criterion)
        val_loss_history.append(val_loss)
        if verbose:
            tqdm.write(
                f"Epoch [{epoch}/{num_epochs}], "
                f"Train Loss: {train_loss:.4f}, "
                f"Val Loss: {val_loss:.4f}"
            )
        # Compare against the best loss so far, not just the previous epoch;
        # otherwise a partial recovery after a bad epoch would overwrite a
        # better checkpoint. The first epoch is always saved.
        if best_val_loss is None or val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), checkpoint_path)
        scheduler.step(val_loss)
        if monitor.step(val_loss):
            break
    plot_train_history(train_loss_history, val_loss_history, save_title)


def run_epoch(model, data_loader, device, criterion, optimizer=None):
    """Run one pass over ``data_loader`` and return the mean batch loss.

    When ``optimizer`` is given, a gradient step is taken after every batch;
    otherwise the pass only evaluates the loss. Switching the model between
    train and eval mode is left to the caller.

    Returns:
        Sum of the per-batch losses divided by the number of batches.
    """
    running_loss = 0
    for batch_inputs, batch_labels in tqdm(data_loader, leave=False):
        logits = model(batch_inputs.to(device))
        batch_loss = criterion(logits, batch_labels.to(device))
        if optimizer:
            # Clear stale gradients, back-propagate, then update the weights.
            optimizer.zero_grad()
            batch_loss.backward()
            optimizer.step()
        running_loss += batch_loss.item()
    return running_loss / len(data_loader)


class EarlyStopMonitor:
    """Signal when a monitored metric has stopped improving.

    Each call to ``step`` records the metric. A value counts as an improvement
    only if it is strictly better than the best value seen so far (lower in
    ``"min"`` mode, higher in ``"max"`` mode). Stopping is signalled once more
    than ``patience`` consecutive steps have passed without an improvement.

    Attributes are plain instance fields: ``log`` (all metrics seen),
    ``best`` (best metric so far or ``None``), ``count`` (steps since the
    last improvement), ``mode`` and ``patience``.
    """

    def __init__(self, patience, mode="min"):
        assert mode in {"min", "max"}, "`mode` must be one of 'min' or 'max'"
        self.log = []
        self.mode = mode
        self.count = 0
        self.patience = patience
        self.best = None

    def step(self, metric):
        """Record ``metric`` and return ``True`` when training should stop."""
        self.log.append(metric)
        if self.best is None:
            # First observation: nothing to compare against yet.
            self.best = metric
            return False
        # Compare with the best value rather than the previous one, so that an
        # oscillating metric that never beats its best still runs out of
        # patience. Ties are treated as "no improvement" in both modes.
        if self.mode == "min":
            improved = metric < self.best
        else:
            improved = metric > self.best
        if improved:
            self.best = metric
            self.count = 0
        else:
            self.count += 1
        return self.count > self.patience

## Training

In [12]:
# Mini-batch size used by the train, validation and test loaders alike.
batch_size = 64
# Both fractions are 0.2. Because the validation split is taken from what is
# left after the test split (see train_val_test_split), validation ends up
# being 0.2 of the remaining 80 % of the data, not 20 % of all of it.
test_size = val_size = 0.2

# Builds a MalwareDataset from its default directories and splits it randomly.
train_loader, val_loader, test_loader = make_loaders(batch_size, val_size, test_size)

In [None]:
# Positional arguments: embed_dim=8, max_len=4096, out_channels=128, window_size=32.
malconv = MalConv(8, 4096, 128, 32).to(device)
# train() also writes checkpoints/malconv.pt and the loss-history plot.
train(malconv, train_loader, val_loader, device, "malconv")
# Stores the weights as they are after the last trained epoch.
torch.save(malconv.state_dict(), os.path.join("weights", "malconv.pt"))

In [None]:
# Positional arguments: embed_dim=8, out_channels=128, window_size=32,
# module=GRU, hidden_size=256, num_layers=1, bidirectional=True, residual=False.
gru_bi = Conv_RNN_Custom(8, 128, 32, torch.nn.GRU, 256, 1, True, False).to(device)
# train() also writes checkpoints/gru_bi.pt and the loss-history plot.
train(gru_bi, train_loader, val_loader, device, "gru_bi")
# Stores the weights as they are after the last trained epoch.
torch.save(gru_bi.state_dict(), os.path.join("weights", "gru_bi.pt"))

## Testing

In [6]:
# NOTE(review): this cell's execution count (6) is lower than that of the cells
# above it, so the recorded output stems from out-of-order execution -- confirm
# the numbers with Restart Kernel -> Run All.
# The constructor arguments must match the ones used in the training cell,
# otherwise the saved state dict will not fit the module.
malconv = MalConv(8, 4096, 128, 32).to(device)
PATH = os.path.join("weights", "malconv.pt")
malconv.load_state_dict(torch.load(PATH,  map_location=torch.device('cpu')))
malconv.eval()
# Both plots are written to the imgs/ directory; each call runs inference over
# the whole test loader again.
plot_confusion_matrix(malconv, test_loader, "malconv", device)
plot_roc_curve(malconv, test_loader, "malconv", device)
# Accuracy in percent on each of the three splits.
print(f"Training accuracy : {get_accuracy(malconv, train_loader, device)}")
print(f"Validation accuracy : {get_accuracy(malconv, val_loader, device)}")
print(f"Testing accuracy : {get_accuracy(malconv, test_loader, device)}")

  0%|          | 0/114 [00:00<?, ?it/s]

  0%|          | 0/114 [00:00<?, ?it/s]

  0%|          | 0/365 [00:00<?, ?it/s]

Training accuracy : 98.00051488886982


  0%|          | 0/92 [00:00<?, ?it/s]

Validation accuracy : 97.99210571477605


  0%|          | 0/114 [00:00<?, ?it/s]

Testing accuracy : 97.8448867536033


In [7]:
# NOTE(review): this cell's execution count (7) is lower than that of the
# definition and training cells above, so the recorded output stems from
# out-of-order execution -- confirm the numbers with Restart Kernel -> Run All.
# The constructor arguments must match the ones used in the training cell,
# otherwise the saved state dict will not fit the module.
gru_bi = Conv_RNN_Custom(8, 128, 32, torch.nn.GRU, 256, 1, True, False).to(device)
PATH = os.path.join("weights", "gru_bi.pt")
gru_bi.load_state_dict(torch.load(PATH,  map_location=torch.device('cpu')))
gru_bi.eval()
# Both plots are written to the imgs/ directory; each call runs inference over
# the whole test loader again.
plot_confusion_matrix(gru_bi, test_loader, "gru_bi", device)
plot_roc_curve(gru_bi, test_loader, "gru_bi", device)
# Accuracy in percent on each of the three splits.
print(f"Training accuracy : {get_accuracy(gru_bi, train_loader, device)}")
print(f"Validation accuracy : {get_accuracy(gru_bi, val_loader, device)}")
print(f"Testing accuracy : {get_accuracy(gru_bi, test_loader, device)}")

  0%|          | 0/114 [00:00<?, ?it/s]

  0%|          | 0/114 [00:00<?, ?it/s]

  0%|          | 0/365 [00:00<?, ?it/s]

Training accuracy : 98.30515747017935


  0%|          | 0/92 [00:00<?, ?it/s]

Validation accuracy : 98.16372061094903


  0%|          | 0/114 [00:00<?, ?it/s]

Testing accuracy : 98.38023335621139
