# Mounting drive

In [1]:
# Mount Google Drive under /content/drive (Colab-only API).
from google.colab import drive

drive.mount('/content/drive')

Mounted at /content/drive


# Download datasets

In [None]:
import gdown

# Google Drive *file ids* (not URLs) of the training and validation archives,
# keyed by the local file name each one is saved to.
train_val_archives = {
    "training_set.tar.gz": '1-FSnNHp0ZWULfetoYGztn_43TMtxSQaM',
    "validation_set.tar.gz": '1-JLM97nRRcxsU-e1KDvY7_OPQBoxx7H2',
}
for archive_name, drive_file_id in train_val_archives.items():
    # gdown.download is expected to return the output path on success; some gdown
    # versions return None on failure instead of raising. Stop here in that case,
    # rather than letting the extraction cell fail later on a missing archive.
    if gdown.download(id=drive_file_id, output=archive_name) is None:
        raise RuntimeError(f"Failed to download {archive_name} (Google Drive id {drive_file_id})")

In [None]:
# Create the dataset folders (absolute paths, matching the `tar -C` targets below) and
# unpack each archive into its folder. `&&` deletes an archive only if its extraction
# succeeded, so a corrupt or incomplete download is kept for inspection / retry.
!mkdir -p /content/training /content/validation

!tar -zxf '/content/training_set.tar.gz' -C /content/training && rm '/content/training_set.tar.gz'

!tar -zxf '/content/validation_set.tar.gz' -C /content/validation && rm '/content/validation_set.tar.gz'

In [None]:
import gdown

# Google Drive *file ids* (not URLs) of the two test archives, keyed by local file name:
# the SoReL-20M test set and the VirusShare test set.
test_archives = {
    "test_set.tar.gz": '1-CNx_k1yEJ-RDC94nXUMlhtc1FQH7lzt',
    "VirusShareDataset.tar.gz": '1Dbb9xNCvL1_HqM9H_Hvq34VvW73cu0n1',
}
for archive_name, drive_file_id in test_archives.items():
    # gdown.download is expected to return the output path on success; some gdown
    # versions return None on failure instead of raising. Stop here in that case.
    if gdown.download(id=drive_file_id, output=archive_name) is None:
        raise RuntimeError(f"Failed to download {archive_name} (Google Drive id {drive_file_id})")

In [None]:
# Create the test set folders (absolute paths, matching the `tar -C` targets below) and
# unpack each archive. `&&` deletes an archive only if its extraction succeeded.
!mkdir -p /content/test /content/test_2

!tar -zxf '/content/test_set.tar.gz' -C /content/test && rm '/content/test_set.tar.gz'

!tar -zxf '/content/VirusShareDataset.tar.gz' -C /content/test_2 && rm '/content/VirusShareDataset.tar.gz'

In [None]:
# Folders the archives above were extracted into (the training/validation/test splits
# contain `benign/` and `malware/` sub-folders; layout of test_2 not verified here).
train_path, val_path, test_path, test_2_path = (
    '/content/training',
    '/content/validation',
    '/content/test',
    '/content/test_2',
)

# Install dependencies and libraries

In [2]:
# Install dependencies. Packages Colab does not preinstall are pinned to the versions
# resolved in the recorded run below, so re-runs get the same environment.
# `%pip` (rather than `!pip`) installs into the environment of the running kernel.
%pip install lief==0.12.0 deap==1.4.1 python-magic==0.4.27 numpy pandas matplotlib tqdm
# Install ML-Pentest Lib
%pip install ml-pentest==0.0.1

Collecting lief==0.12.0
  Downloading lief-0.12.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (4.2 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m4.2/4.2 MB[0m [31m16.1 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: lief
Successfully installed lief-0.12.0
Collecting deap
  Downloading deap-1.4.1-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (135 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m135.4/135.4 kB[0m [31m1.6 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: deap
Successfully installed deap-1.4.1
Collecting python-magic
  Downloading python_magic-0.4.27-py2.py3-none-any.whl (13 kB)
Installing collected packages: python-magic
Successfully installed python-magic-0.4.27
Collecting ml-pentest
  Downloading ml_pentest-0.0.1.tar.gz (57.9 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m57.9/57.9 MB[0m [31m11.8 MB/s[0m et

In [3]:
from ml_pentest.attacks.blackbox.genetic_attack.GAMMA.gamma_section_injection import GammaSectionInjection
from ml_pentest.attacks.blackbox.genetic_attack.GAMMA.attack_utils import create_section_population_from_folder
from ml_pentest.models.dl_models.raw_bytes_based.malconv2 import MalConv
from ml_pentest.models.wrappers.malconv2_wrapper import MalConvWrapper

import os
import torch
import lief
import numpy as np
import random

# Models definition

In [4]:
"""
Classifying Sequences of Extreme Length with Constant Memory Applied to Malware Detection
Edward Raff, William Fleshman, Richard Zak, Hyrum Anderson and Bobby Filar and Mark Mclean
https://arxiv.org/abs/2012.09390
"""
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

def drop_zeros_hook(module, grad_input, grad_out):
    """
    Legacy module backward hook that stores all-zero input gradients sparsely.

    A gradient tensor with no non-zero entries is replaced by its sparse version
    (`g.to_sparse()`, which holds no stored values), as a cheap approximation of sparse
    back-propagation; any other gradient is passed through unchanged. Entries of
    `grad_input` that are None (inputs that need no gradient) are passed through as None
    instead of crashing.

    Returns a tuple with the same length as `grad_input`, as required of a backward hook
    that replaces the input gradients.
    """
    grads = []
    with torch.no_grad():
        for g in grad_input:
            # `g.any()` is False exactly when the tensor has no non-zero entry.
            if g is not None and not bool(g.any()):
                grads.append(g.to_sparse())
            else:
                grads.append(g)

    return tuple(grads)


class CatMod(torch.nn.Module):
    """Wraps `torch.cat(..., dim=2)` in a module so that a backward hook can be attached to it."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # x: sequence of tensors to join along their third (time) dimension
        joined = torch.cat(x, dim=2)
        return joined


class LowMemConvBase(nn.Module):
    """
    Base class for byte-level CNN classifiers (MalConv family) that handle very long
    inputs with memory bounded by `chunk_size`.

    Subclasses must:
      - set `self.embd` to a module with parameters (its device is used to place chunks), and
      - implement `processRange(x, **kwargs)`, mapping a LongTensor (B, L) to activations (B, C, L').
    `seq2fix` then produces a fixed-size (B, C) max-pooled representation of an input of any length.
    """

    def __init__(self, chunk_size=65536, overlap=512, min_chunk_size=1024):
        """
        chunk_size: how many bytes to process at a time. Increasing it may improve compute
            efficiency but uses more memory. Total memory use is a function of chunk_size,
            not of the input length L.
        overlap: how many bytes of overlap to use between chunks. NOTE: stored but not used
            by seq2fix, which currently advances by the full chunk_size.
        min_chunk_size: seq2fix stops scanning once the remaining chunk is shorter than
            max(min_chunk_size, receptive field), so such a trailing tail is not searched.
        """
        super(LowMemConvBase, self).__init__()
        self.chunk_size = chunk_size
        self.overlap = overlap
        self.min_chunk_size = min_chunk_size

        #Used for pooling over time in a more efficient way
        self.pooling = nn.AdaptiveMaxPool1d(1)
        # Concatenation module whose legacy backward hook stores all-zero input gradients
        # sparsely (only exercised through pool_group).
        self.cat = CatMod()
        self.cat.register_backward_hook(drop_zeros_hook)
        self.receptive_field = None

        #Used to force checkpoint code to behave correctly due to poor design https://discuss.pytorch.org/t/checkpoint-with-no-grad-requiring-inputs-problem/19117/11
        # NOTE(review): not referenced anywhere else in this class.
        self.dummy_tensor = torch.ones(1, dtype=torch.float32, requires_grad=True)

    def processRange(self, x, **kwargs):
        """
        Convert a LongTensor x of shape (B, L), where B is the batch size and L the input
        length, into activations of shape (B, C, L'), where C is the number of channels and
        L' is roughly L divided by the network's stride (it may be a little shorter due to
        convolutions without padding).

        Must be implemented by subclasses.
        """
        raise NotImplementedError("Subclasses of LowMemConvBase must implement processRange")

    def determinRF(self):
        """
        Determine (and cache) the receptive field, stride and channel count of processRange.

        The receptive field is the smallest input length in [1, chunk_size] for which
        processRange runs without raising (too-short inputs make the convolutions raise),
        found by binary search. The stride is chunk_size divided by the output length
        obtained for a chunk_size-long input.

        Returns:
            (receptive_field, stride, out_channels)
        """

        if self.receptive_field is not None:
            return self.receptive_field, self.stride, self.out_channels

        if not hasattr(self, "device_ids"):
            #We are training with just one device. Lets find out where we should move the data
            cur_device = next(self.embd.parameters()).device
        else:
            cur_device = "cpu"

        #Lets do a simple binary search to figure out how large our RF is.
        #It can't be larger than our chunk size! So use that as upper bound
        min_rf = 1
        max_rf = self.chunk_size

        with torch.no_grad():

            tmp = torch.zeros((1,max_rf)).long().to(cur_device)

            while True:
                test_size = (min_rf+max_rf)//2
                is_valid = True
                try:
                    self.processRange(tmp[:,0:test_size])
                except Exception:
                    # Input too short for the network. Catch `Exception` rather than using a
                    # bare except so KeyboardInterrupt / SystemExit still propagate.
                    is_valid = False

                if is_valid:
                    max_rf = test_size
                else:
                    min_rf = test_size+1

                if max_rf == min_rf:
                    self.receptive_field = min_rf
                    out_shape = self.processRange(tmp).shape
                    self.stride = self.chunk_size//out_shape[2]
                    self.out_channels = out_shape[1]
                    break


        return self.receptive_field, self.stride, self.out_channels


    def pool_group(self, *args):
        """Concatenate the given (B, C, L_i) tensors along time and max-pool them to (B, C, 1)."""
        x = self.cat(args)
        x = self.pooling(x)
        return x

    def seq2fix(self, x, pr_args=None):
        """
        Convert an input LongTensor of shape (B, L) into a fixed-length representation (B, C),
        where C is the number of channels produced by processRange.

        pr_args: optional dict of extra keyword arguments forwarded to processRange
            (e.g. {'gct': ...} for MalConvGCG). Defaults to no extra arguments.
        """
        if pr_args is None:
            pr_args = {}

        receptive_window, stride, out_channels = self.determinRF()

        if x.shape[1] < receptive_window: #This is a tiny input! Pad it out
            x = F.pad(x, (0, receptive_window-x.shape[1]), value=0) # 0 is the pad value
        batch_size = x.shape[0]
        length = x.shape[1]

        #Let's go through the input data without gradients first, and find the positions that "win"
        #the max-pooling. Most of the gradients will be zero, and we don't want to waste valuable
        #memory and time computing them.
        #Once we know the winners, we will go back and compute the forward activations on JUST
        #the subset of positions that won!
        # Start from -inf so that the first chunk always records a winner for every channel,
        # even when all activations of that channel are negative (a fixed sentinel such as
        # -1.0 would leave index 0 as the "winner" whenever every activation is below it).
        winner_values = np.full((batch_size, out_channels), -np.inf)
        winner_indices = np.zeros((batch_size, out_channels), dtype=np.int64)

        if not hasattr(self, "device_ids"):
            #We are training with just one device. Lets find out where we should move the data
            cur_device = next(self.embd.parameters()).device
        else:
            cur_device = None

        step = self.chunk_size #- self.overlap
        start = 0
        end = start+step

        with torch.no_grad():
            while start < end and (end-start) >= max(self.min_chunk_size, receptive_window):
                x_sub = x[:,start:end]
                if cur_device is not None:
                    x_sub = x_sub.to(cur_device)
                activs = self.processRange(x_sub.long(), **pr_args)
                activ_win, activ_indx = F.max_pool1d(activs, kernel_size=activs.shape[2], return_indices=True)
                #We want to remove only last dimension, but if batch size is 1, np.squeeze
                #will screw us up and remove first dim too.
                activ_win = activ_win.cpu().numpy()[:,:,0]
                activ_indx = activ_indx.cpu().numpy()[:,:,0]
                selected = winner_values < activ_win
                # Map the winning output position back to an (approximate) input byte offset.
                winner_indices[selected] = activ_indx[selected]*stride + start
                winner_values[selected]  = activ_win[selected]
                start = end
                end = min(start+step, length)

        # Now we know every index that won, we need to compute values and with gradients!

        # Find unique winners for every batch
        final_indices = [np.unique(winner_indices[b,:]) for b in range(batch_size)]

        # Collect inputs that won for each batch
        chunk_list = [[x[b:b+1,max(i-receptive_window,0):min(i+receptive_window,length)] for i in final_indices[b]] for b in range(batch_size)]
        # Convert to a torch tensor of the bytes
        chunk_list = [torch.cat(c, dim=1)[0,:] for c in chunk_list]

        # Pad out shorter sequences to the longest one
        x_selected = torch.nn.utils.rnn.pad_sequence(chunk_list, batch_first=True)

        # Shape is now (B, L). Compute it.
        if cur_device is not None:
            x_selected = x_selected.to(cur_device)
        x_selected = self.processRange(x_selected.long(), **pr_args)
        x_selected = self.pooling(x_selected)
        x_selected = x_selected.view(x_selected.size(0), -1)

        return x_selected


## MalConv

In [5]:
class MalConv(LowMemConvBase):
    """
    Gated-convolution MalConv classifier over raw bytes, using LowMemConvBase's chunked
    max-pooling (seq2fix) to handle inputs of any length.

    Byte indices 1..256 are embedded (index 0 is the padding index).
    forward returns the element-wise sigmoid of the `out_size` outputs.
    """

    def __init__(self, out_size=2, channels=128, window_size=512, stride=512, embd_size=8, log_stride=None):
        super().__init__()
        self.embd = nn.Embedding(257, embd_size, padding_idx=0)
        # log_stride, when given, overrides stride with a power of two.
        conv_stride = stride if log_stride is None else 2 ** log_stride

        self.conv_1 = nn.Conv1d(embd_size, channels, window_size, stride=conv_stride, bias=True)
        self.conv_2 = nn.Conv1d(embd_size, channels, window_size, stride=conv_stride, bias=True)

        self.fc_1 = nn.Linear(channels, channels)
        self.fc_2 = nn.Linear(channels, out_size)

    def processRange(self, x):
        # (B, L) byte indices -> (B, L, E) embeddings -> (B, E, L) for Conv1d
        embedded = self.embd(x).transpose(-1, -2)
        # conv_1 provides the values, sigmoid(conv_2) the gates
        return self.conv_1(embedded) * torch.sigmoid(self.conv_2(embedded))

    def forward(self, x):
        pooled = self.seq2fix(x)
        hidden = F.relu(self.fc_1(pooled))
        return torch.sigmoid(self.fc_2(hidden))


class MalConvML(LowMemConvBase):
    """
    Multi-layer MalConv: a stack of GLU convolutions, each followed by a 1x1 convolution
    and leaky ReLU, pooled over time by LowMemConvBase.seq2fix.

    forward returns a tuple (outputs, penultimate activations, pooled conv features).
    """

    def __init__(self, out_size=2, channels=128, window_size=512, stride=512, layers=1, embd_size=8, log_stride=None):
        super().__init__()
        self.embd = nn.Embedding(257, embd_size, padding_idx=0)
        conv_stride = stride if log_stride is None else 2 ** log_stride

        # The first layer reads the embedding (strided); later layers read the previous
        # layer's `channels` maps. Each emits 2*channels maps that F.glu halves again.
        glu_convs = [nn.Conv1d(embd_size, channels * 2, window_size, stride=conv_stride, bias=True)]
        for _ in range(layers - 1):
            glu_convs.append(nn.Conv1d(channels, channels * 2, window_size, stride=1, bias=True))
        self.convs = nn.ModuleList(glu_convs)
        # one-by-one convolutions to share information across channels
        self.convs_1 = nn.ModuleList(nn.Conv1d(channels, channels, 1, bias=True) for _ in range(layers))

        self.fc_1 = nn.Linear(channels, channels)
        self.fc_2 = nn.Linear(channels, out_size)

    def processRange(self, x):
        # (B, L) -> (B, L, E) -> (B, E, L)
        h = self.embd(x).permute(0, 2, 1).contiguous()
        for glu_conv, share_conv in zip(self.convs, self.convs_1):
            gated = F.glu(glu_conv(h.contiguous()), dim=1)
            h = F.leaky_relu(share_conv(gated))
        return h

    def forward(self, x):
        post_conv = self.seq2fix(x)
        penult = F.relu(self.fc_1(post_conv))
        return self.fc_2(penult), penult, post_conv

## MalConv2

In [6]:
"""
Classifying Sequences of Extreme Length with Constant Memory Applied to Malware Detection
Edward Raff, William Fleshman, Richard Zak, Hyrum Anderson and Bobby Filar and Mark Mclean
https://arxiv.org/abs/2012.09390

Taken from https://github.com/NeuromorphicComputationResearchProgram/MalConv2
"""
import torch
import torch.nn as nn
import torch.nn.functional as F

class MalConv2(LowMemConvBase):
    """
    Gated-convolution MalConv (MalConv2 paper variant) built on LowMemConvBase.

    Byte indices 1..256 are embedded (index 0 is the padding index).
    forward returns the element-wise sigmoid of the `out_size` outputs.
    """

    def __init__(self, out_size=2, channels=128, window_size=512, stride=512, embd_size=8, log_stride=None):
        super().__init__()
        self.embd = nn.Embedding(257, embd_size, padding_idx=0)
        if log_stride is not None:
            stride = 2 ** log_stride

        def gated_conv():
            return nn.Conv1d(embd_size, channels, window_size, stride=stride, bias=True)

        # conv_1: value path, conv_2: gate path
        self.conv_1 = gated_conv()
        self.conv_2 = gated_conv()

        self.fc_1 = nn.Linear(channels, channels)
        self.fc_2 = nn.Linear(channels, out_size)

    def processRange(self, x):
        feats = torch.transpose(self.embd(x), -1, -2)
        values = self.conv_1(feats)
        gates = self.conv_2(feats).sigmoid()
        return values * gates

    def forward(self, x):
        hidden = F.relu(self.fc_1(self.seq2fix(x)))
        return self.fc_2(hidden).sigmoid()

## MalConvGCG

In [11]:
"""
Classifying Sequences of Extreme Length with Constant Memory Applied to Malware Detection
Edward Raff, William Fleshman, Richard Zak, Hyrum Anderson and Bobby Filar and Mark Mclean
https://arxiv.org/abs/2012.09390
"""
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint

class MalConvGCG(LowMemConvBase):
    """
    MalConv with Global Channel Gating (GCG).

    A MalConvML `context_net` summarises the whole input into a global context vector;
    every layer of this network then gates each time step t by sigmoid(x_t^T tanh(W c)),
    where W is that layer's `linear_atn` and c the global context.
    forward returns the element-wise sigmoid of the `out_size` outputs.
    """

    def __init__(self, out_size=2, channels=128, window_size=512, stride=512, layers=1, embd_size=8, log_stride=None, low_mem=True):
        super(MalConvGCG, self).__init__()
        # When True, the context network is run under gradient checkpointing (see forward).
        self.low_mem = low_mem
        self.embd = nn.Embedding(257, embd_size, padding_idx=0)
        if not log_stride is None:
            stride = 2**log_stride

        self.context_net = MalConvML(out_size=channels, channels=channels, window_size=window_size, stride=stride, layers=layers, embd_size=embd_size)
        self.convs = nn.ModuleList([nn.Conv1d(embd_size, channels*2, window_size, stride=stride, bias=True)] + [nn.Conv1d(channels, channels*2, window_size, stride=1, bias=True) for i in range(layers-1)])

        #These two objs are not used. They were originally present before the F.glu function existed, and then were accidently left in when we switched over. So the state file provided has unusued states in it. They are left in this definition so that there are no issues loading the file that MalConv was trained on.
        #If you are going to train from scratch, you can delete these two lines.
        #self.convs_1 = nn.ModuleList([nn.Conv1d(channels*2, channels, 1, bias=True) for i in range(layers)])
        #self.convs_atn = nn.ModuleList([nn.Conv1d(channels*2, channels, 1, bias=True) for i in range(layers)])

        # Per-layer projection W of the global context used by the gates.
        self.linear_atn = nn.ModuleList([nn.Linear(channels, channels) for i in range(layers)])

        #one-by-one cons to perform information sharing
        self.convs_share = nn.ModuleList([nn.Conv1d(channels, channels, 1, bias=True) for i in range(layers)])


        self.fc_1 = nn.Linear(channels, channels)
        self.fc_2 = nn.Linear(channels, out_size)


    #Over-write the determinRF call to use the base context_net to determine RF. We should have the same total RF, and this will simplify logic significantly.
    def determinRF(self):
        return self.context_net.determinRF()

    def processRange(self, x, gct=None):
        """
        Gated convolutional features for a chunk of bytes.

        x: LongTensor (B, L) of byte indices.
        gct: global context, (B, C) per the shape comment below — required.
        Returns activations of shape (B, C, L').
        Raises Exception if gct is not given.
        """
        if gct is None:
            raise Exception("No Global Context Given")

        x = self.embd(x)
        x = x.permute(0,2,1)

        for conv_glu, linear_cntx, conv_share in zip(self.convs, self.linear_atn, self.convs_share):
            x = F.glu(conv_glu(x), dim=1)
            x = F.leaky_relu(conv_share(x))
            B = x.shape[0]
            C = x.shape[1]

            #we are going to need a version of GCT with a time dimension, which we will adapt as needed to the right length
            ctnx = torch.tanh(linear_cntx(gct))

            #Size is (B, C), but we need (B, C, 1) to use as a 1d conv filter
            ctnx = torch.unsqueeze(ctnx, dim=2)
            #roll the batches into the channels
            x_tmp = x.view(1,B*C,-1)
            #Now we can apply a conv with B groups, so that each batch gets its own context applied only to what was needed
            x_tmp = F.conv1d(x_tmp, ctnx, groups=B)
            #x_tmp will have a shape of (1, B, L), now we just need to re-order the data back to (B, 1, L)
            x_gates = x_tmp.view(B, 1, -1)

            #Now we effectively apply σ(x_t^T tanh(W c))
            gates = torch.sigmoid( x_gates )
            x = x * gates

        return x

    def forward(self, x):
        if self.low_mem:
            # Recompute the context network's activations during backward instead of keeping
            # them. The public, non-reentrant checkpoint lets gradients reach context_net's
            # parameters even though the byte input x (a LongTensor) cannot require grad; the
            # previously used private reentrant CheckpointFunction produced a global context
            # that did not require grad, so context_net was never updated during training.
            global_context = checkpoint.checkpoint(self.context_net.seq2fix, x, use_reentrant=False)
        else:
            global_context = self.context_net.seq2fix(x)

        x = self.seq2fix(x, pr_args={'gct':global_context})

        x = F.leaky_relu(self.fc_1( x ))
        x = self.fc_2(x)

        return torch.sigmoid(x)

## AvastConv

In [8]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

def vec_bin_array(arr, m=8):
    """
    Expand every integer of `arr` into its `m` least-significant bits (most significant first).

    Arguments:
    arr: Numpy array (or array-like) of non-negative integers
    m: Number of low-order bits of each integer to retain. Higher bits are discarded, so
       e.g. with m=8 the value 256 maps to all-zero bits instead of colliding with 128.

    Returns a float32 array of shape arr.shape + (m,) where every bit b is encoded as
    (2*b - 1) / 16, i.e. +1/16 for a set bit and -1/16 for a cleared bit.
    """
    arr = np.asarray(arr, dtype=np.int64)
    # Shift amounts m-1, ..., 0 so the most significant retained bit comes first.
    shifts = np.arange(m - 1, -1, -1)
    bits = ((arr[..., None] >> shifts) & 1).astype(np.int8)
    return (bits * 2 - 1).astype(np.float32) / 16

class AvastConv(LowMemConvBase):
    """
    Avast-style byte CNN (four strided convolutions and four fully connected layers) on
    top of a fixed, non-trainable binary embedding, pooled over time by seq2fix.

    forward returns the element-wise sigmoid of the `out_size` outputs.
    """

    def __init__(self, out_size=2, channels=48, window_size=32, stride=4, embd_size=8):
        super(AvastConv, self).__init__()
        self.embd = nn.Embedding(257, embd_size, padding_idx=0)
        # Fixed embedding: row i (i = 1..256) is vec_bin_array(i) with embd_size bits, each
        # bit encoded as +-1/16. Row 0 (padding) keeps nn.Embedding's zero initialisation.
        bit_table = vec_bin_array(np.arange(1, 257), m=embd_size)
        self.embd.weight.data[1:, :] = torch.from_numpy(bit_table)
        for param in self.embd.parameters():
            param.requires_grad = False

        # conv_1 takes one input channel per embedding bit. It used to be hard-coded to 8,
        # which only worked with the default embd_size=8.
        self.conv_1 = nn.Conv1d(embd_size, channels, window_size, stride=stride, bias=True)
        self.conv_2 = nn.Conv1d(channels, channels*2, window_size, stride=stride, bias=True)
        self.pool = nn.MaxPool1d(4)
        self.conv_3 = nn.Conv1d(channels*2, channels*3, window_size//2, stride=stride*2, bias=True)
        self.conv_4 = nn.Conv1d(channels*3, channels*4, window_size//2, stride=stride*2, bias=True)

        self.fc_1 = nn.Linear(channels*4, channels*4)
        self.fc_2 = nn.Linear(channels*4, channels*3)
        self.fc_3 = nn.Linear(channels*3, channels*2)
        self.fc_4 = nn.Linear(channels*2, out_size)


    def processRange(self, x):
        # Fixed embedding: no gradient is needed through the lookup.
        with torch.no_grad():
            x = self.embd(x)
            x = torch.transpose(x,-1,-2)

        x = F.relu(self.conv_1(x))
        x = F.relu(self.conv_2(x))
        x = self.pool(x)
        x = F.relu(self.conv_3(x))
        x = F.relu(self.conv_4(x))

        return x

    def forward(self, x):
        x = self.seq2fix(x)

        x = F.selu(self.fc_1(x))
        x = F.selu(self.fc_2(x))
        x = F.selu(self.fc_3(x))
        x = self.fc_4(x)

        return torch.sigmoid(x)

# Training Procedure

## Custom dataloader definition

In [10]:
import os
import zlib
import torch
import random

import numpy as np
import pandas as pd


def _seed_all(seed):
    os.environ['WANDB_DISABLED'] = 'true'
    os.environ['WANDB_MODE'] = 'dryrun'
    os.environ['PYTHONHASHSEED'] = str(seed)

    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    np.random.seed(seed)
    random.seed(seed)

    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True

def decompress_file(input_file_path, max_len=None):
    """
    Read a zlib-compressed file and return its decompressed bytes.

    input_file_path: path of the compressed file.
    max_len: maximum number of decompressed bytes to return. Like `file.read(max_len)`,
        None or a negative value means "everything" and 0 returns b''.

    Raises zlib.error if the file content is not a valid zlib stream.
    """
    with open(input_file_path, 'rb') as f:
        compressed_data = f.read()

    decompressed_data = zlib.decompress(compressed_data)
    # Explicit None/negative check: a plain truthiness test would treat 0 as "no limit".
    if max_len is None or max_len < 0:
        return decompressed_data
    return decompressed_data[:max_len]

class BinaryDataset(torch.utils.data.Dataset):
    """
    Dataset of binary files read from disk, returned as byte-index tensors.

    Each item is (x, y): x is an int16 tensor holding the first `max_len` bytes shifted by
    +1 (so that 0 is free for padding), optionally passed through `transform` first; y is
    a 1-element tensor with label 0 (benign) or 1 (malware).

    Note:
       - a file is labelled benign iff the string 'benign' occurs anywhere in its path.
       - if sorel_20m=True, the malware files are assumed to be zlib-compressed.
    """
    def __init__(self, path_list, max_len=2 ** 20, sorel_20m=False, transform=None):
        self.max_len = max_len
        self.sorel20m = sorel_20m
        self.transform = transform
        # One (file_path, label, file_size) tuple per input path.
        self.all_files = [
            (path, 0 if 'benign' in path else 1, os.path.getsize(path))
            for path in path_list
        ]

    def __len__(self):
        return len(self.all_files)

    def __getitem__(self, index):
        file_path, label, _ = self.all_files[index]
        if self.sorel20m and label != 0:
            # Compressed SoReL-20M malware sample: decompress (up to max_len bytes) first.
            raw = decompress_file(file_path, self.max_len)
        else:
            with open(file_path, 'rb') as fh:
                raw = fh.read(self.max_len)
        # frombuffer views the bytes as uint8; widen to int16 and shift by one so
        # that index 0 stays reserved for padding.
        sample = np.frombuffer(raw, dtype=np.uint8).astype(np.int16) + 1

        if self.transform:
            sample = self.transform(sample)

        return torch.tensor(sample), torch.tensor([label])

# Files have different lengths, but a batch must be a single rectangular tensor, so
# every batch is zero-padded to the length of its longest item.
def pad_collate_func(batch):
    """
    Collate function for a PyTorch DataLoader (collate_fn=pad_collate_func).

    batch: list of (x, y) pairs, x a 1-D tensor and y a 1-element label tensor.
    Returns (x, y): x of shape (B, max_len) padded with 0 at the end, y of shape (B,).
    """
    sequences, targets = [], []
    for sample in batch:
        sequences.append(sample[0])
        targets.append(sample[1])

    padded = torch.nn.utils.rnn.pad_sequence(sequences, batch_first=True)
    # stacking gives (B, 1); keep only the single label column -> (B,)
    labels = torch.stack(targets)[:, 0]
    return padded, labels

## Train and predict procedure

In [None]:
from tqdm import tqdm
import torch.optim as optim
@torch.no_grad()
def predict(model, data_loader, device, criterion,apply_sigmoid=False, to_numpy=True, multiclass = True):
    """
      Predict the target values of the given inputs using a model.

      Parameters:
      - model (torch.nn.Module): The model to use for prediction.
      - data_loader (iterable of (inputs, labels)): Data loader for the input data.
      - device (torch.device): The device to use for computation (CPU or GPU).
      - criterion (torch.nn.Module): The loss function used for evaluation. It is assumed
                                     to average over the batch (reduction='mean').
      - apply_sigmoid (bool, optional): If True, applies sigmoid to the concatenated predictions.
                                        Default is False.
      - to_numpy (bool, optional): If True, converts the true and predicted values to numpy arrays.
                                  Default is True.
      - multiclass (bool, optional): If True, outputs are (B, n_classes) scores and the prediction
                                    is the arg-max class. If False, outputs are probabilities of
                                    shape (B, 1) or (B,) and predictions are obtained by rounding.
                                    Default is True.

      Returns:
      - loss (torch.Tensor): Per-sample mean of `criterion` over the whole data (each batch is
                             weighted by its size), computed on the model's raw outputs.
      - y_true (torch.Tensor or np.ndarray): The true target values.
      - y_pred (torch.Tensor or np.ndarray): The predicted target values.

      The model is put in eval mode while predicting; its previous train/eval mode is restored.
    """
    was_training = model.training
    model.eval()
    try:
        y_true = []
        y_pred = []
        total_loss = 0

        for inputs, labels in tqdm(data_loader, leave=False):
            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs = model(inputs)

            if not multiclass:
                # (B, 1) -> (B,). Unlike squeeze(), reshape keeps the batch axis when B == 1.
                outputs = outputs.reshape(-1)
                labels = labels.float()

            # Loss on the raw model outputs (not on rounded predictions), weighted by batch size.
            total_loss += criterion(outputs, labels) * labels.size(0)

            if multiclass:
                _, preds = torch.max(outputs, 1)
            else:
                preds = torch.round(outputs)
            y_pred.append(preds)
            y_true.append(labels)

        y_true = torch.cat(y_true).to(int)
        y_pred = torch.cat(y_pred)
        if apply_sigmoid:
            y_pred = torch.sigmoid(y_pred)
        else:
            # Binary: rounded probabilities -> {0, 1}. Multiclass: keep the class index
            # (thresholding would collapse every class > 0 into class 1).
            y_pred = (y_pred if multiclass else y_pred > 0).to(int)
            y_pred = y_pred.reshape(-1).to(device)
        if to_numpy:
            y_true = y_true.cpu().numpy()
            y_pred = y_pred.cpu().numpy()
        assert y_true.shape == y_pred.shape
    finally:
        model.train(was_training)

    return total_loss / len(y_true), y_true, y_pred


def get_accuracy(model, data_loader, device, criterion,multiclass = True):
    """
      Evaluate a model on a dataset and compute its accuracy.

      Parameters:
      model (nn.Module): The model to evaluate.
      data_loader (DataLoader): The data to evaluate on.
      device (torch.device): The device to run the evaluation on.
      criterion (nn.Module): The loss function used for evaluation.
      multiclass (bool, optional): Whether the model is a multiclass (True, default) or binary classifier.

      Returns:
      tuple (loss, accuracy, y_true, y_pred):
        - loss: the loss value returned by `predict`,
        - accuracy: percentage of matching predictions (e.g. 75.0 means 75%),
        - y_true, y_pred: label and prediction tensors on `device`.
    """
    loss, targets, predictions = predict(
        model, data_loader, device, criterion,
        to_numpy=False, multiclass=multiclass, apply_sigmoid=False,
    )
    targets = targets.to(device)
    predictions = predictions.to(device)
    matches = (targets == predictions).to(float)
    accuracy_pct = 100 * matches.mean().item()
    return loss, accuracy_pct, targets, predictions



def train(model, train_loader, val_loader, device, criterion, save_title, checkpoint_path, patience=3, num_epochs=50, verbose=True):
    """
      Trains a PyTorch model with AdamW, ReduceLROnPlateau and early stopping.

      Parameters:
      - model (torch.nn.Module): a PyTorch model to be trained.
      - train_loader (torch.utils.data.DataLoader): a DataLoader containing the training data.
      - val_loader (torch.utils.data.DataLoader): a DataLoader containing the validation data.
      - device (torch.device): a PyTorch device object, either "cpu" or "cuda".
      - criterion (callable): a loss function to be used for training.
      - save_title (str): file name (without extension) of the saved checkpoint.
      - checkpoint_path (str): directory where `<save_title>.pt` is written (created if missing).
      - patience (int, optional): patience of both the early-stopping monitor and the
        learning-rate scheduler. Defaults to 3.
      - num_epochs (int, optional): maximum number of epochs to run. Defaults to 50.
      - verbose (bool, optional): whether to print the training and validation loss for each epoch. Defaults to True.

      The state dict with the lowest validation loss seen so far is saved after every epoch
      that improves on it, so the file always holds the best model, not the latest one.

      Returns:
      None
    """
    train_loss_history = []
    val_loss_history = []
    best_val_loss = None
    os.makedirs(checkpoint_path, exist_ok=True)
    checkpoint_file = os.path.join(checkpoint_path, f"{save_title}.pt")

    optimizer = optim.AdamW(model.parameters())
    monitor = EarlyStopMonitor(patience)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, factor=0.5, patience=patience
    )
    for epoch in range(1, num_epochs + 1):
        print("\nEPOCH {}".format(epoch))
        model.train()
        train_loss = run_epoch(model, train_loader, device, criterion, optimizer)
        train_loss_history.append(train_loss)
        model.eval()
        with torch.no_grad():
            val_loss = run_epoch(model, val_loader, device, criterion)
        val_loss_history.append(val_loss)
        if verbose:
            tqdm.write(
                f"Epoch [{epoch}/{num_epochs}], "
                f"Train Loss: {train_loss:.4f}, "
                f"Val Loss: {val_loss:.4f}"
            )
        scheduler.step(val_loss)
        # Compare against the best loss so far, not just the previous epoch; otherwise a
        # worse model could overwrite the best checkpoint. Save before the stop check so
        # an improving final epoch is never lost.
        if best_val_loss is None or val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), checkpoint_file)
        if monitor.step(val_loss):
            break


def run_epoch(model, data_loader, device, criterion, optimizer=None):
    """
      Run one epoch of the model on a given dataset.

      Parameters:
      - model (nn.Module): The model to be trained or evaluated.
      - data_loader (iterable of (inputs, labels)): The data for this epoch; labels are (B,)
        and are turned into (B, 1) floats to match a (B, 1) model output.
      - device (torch.device): The device to run the computation on.
      - criterion (function): The loss function, assumed to average over the batch.
      - optimizer (torch.optim.Optimizer, optional): If given, a gradient step is taken on
        every batch; otherwise only the loss is computed. (default: None)

      Returns:
      - float: The average loss per sample over the dataset (each batch's mean loss is
        weighted by its size, so a short last batch does not count as a full one).

      Raises ZeroDivisionError if data_loader yields no batches.
    """
    total_loss = 0.0
    n_samples = 0
    for inputs, labels in tqdm(data_loader, leave=False):
        inputs = inputs.to(device)
        labels = labels.to(device).unsqueeze(1).float()

        outputs = model(inputs)
        loss = criterion(outputs, labels)
        if optimizer is not None:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        batch_size = labels.size(0)
        total_loss += loss.item() * batch_size
        n_samples += batch_size
    return total_loss / n_samples



class EarlyStopMonitor:
    """
    Class for early stopping based on the performance metric value.

    An epoch counts as an improvement when its metric is strictly better than
    the best value seen so far (lower for mode='min', higher for mode='max').
    Early stopping is signalled once more than ``patience`` consecutive epochs
    have passed without an improvement.

    Args:
    - patience (int): Number of consecutive epochs without improvement that are
      tolerated; the next non-improving epoch triggers early stopping.
    - mode (str, optional): Mode of operation, either 'min' or 'max'. Default is 'min'.

    Attributes:
    - log (list): Log of performance metrics over epochs.
    - mode (str): Mode of operation, either 'min' or 'max'.
    - best: Best metric value seen so far (None before the first step).
    - count (int): Counter for consecutive epochs with no improvement.
    - patience (int): Number of tolerated consecutive epochs with no improvement.
    """
    def __init__(self, patience, mode="min"):
        # Check if mode is either "min" or "max"
        assert mode in {"min", "max"}, "`mode` must be one of 'min' or 'max'"
        self.log = []
        self.mode = mode
        self.best = None
        self.count = 0
        self.patience = patience

    def _is_improvement(self, metric):
        """Return True if ``metric`` is strictly better than the best value so far."""
        if self.mode == "max":
            return metric > self.best
        return metric < self.best

    def step(self, metric):
        """
        Method for updating the log and checking for early stopping.

        Args:
        - metric (float): The performance metric for the current epoch.

        Returns:
        - stop (bool): True if early stopping should occur, False otherwise.
        """
        self.log.append(metric)
        if self.best is None or self._is_improvement(metric):
            # New best value: reset the "no improvement" streak.
            self.best = metric
            self.count = 0
        else:
            self.count += 1
        return self.count > self.patience

## Prepare data

In [None]:
import os

# Report, for each split, how many files sit in its 'malware' and 'benign' sub-folders.
for split_dir in (train_path, val_path, test_path):
  print(os.path.basename(split_dir))
  malware_files = os.listdir(os.path.join(split_dir, 'malware'))
  print("Malware:\t", len(malware_files))
  benign_files = os.listdir(os.path.join(split_dir, 'benign'))
  print("Benign:\t", len(benign_files), "\n")

In [None]:
MAX_LEN = 2 ** 20  # 1 MiB: maximum number of bytes kept per file
BATCH_SIZE = 64


def list_sample_files(root):
  """
  Return the paths of all regular files inside the class sub-folders of ``root``.

  ``root`` is expected to hold one sub-folder per class (here 'malware' and
  'benign'); entries directly inside ``root`` that are not folders are skipped.
  Both levels are sorted because ``os.listdir`` returns entries in arbitrary
  order, which would otherwise make the sample order non-reproducible.
  """
  paths = []
  for class_dir in sorted(os.listdir(root)):
    class_path = os.path.join(root, class_dir)
    if not os.path.isdir(class_path):
      continue
    for filename in sorted(os.listdir(class_path)):
      file_path = os.path.join(class_path, filename)
      if os.path.isfile(file_path):
        paths.append(file_path)
  return paths


# Load the data
x_train = list_sample_files(train_path)
x_val = list_sample_files(val_path)
x_test = list_sample_files(test_path)

train_dataset = BinaryDataset(path_list=x_train, max_len=MAX_LEN, sorel_20m=False)
val_dataset = BinaryDataset(path_list=x_val, max_len=MAX_LEN, sorel_20m=False)
test_dataset = BinaryDataset(path_list=x_test, max_len=MAX_LEN, sorel_20m=False)

print("Train dataset length: ", len(train_dataset))
print("Val dataset length: ", len(val_dataset))
print("Test dataset length: ", len(test_dataset))

In [None]:
from torch.utils.data import DataLoader

# Settings shared by every split; only the training split is shuffled.
loader_kwargs = dict(batch_size=BATCH_SIZE, collate_fn=pad_collate_func)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

Create the checkpoint folder where the trained model weights are saved.

In [None]:
import os

# Create the folder the training cell writes checkpoints into. Use the same
# absolute path passed as `checkpoint_path` to `train`, so the location does
# not depend on the current working directory.
os.makedirs('/content/checkpoint', exist_ok=True)

Device definition

In [None]:
## Device definition
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)
print(f"Using device: {device}")

## Create model

In [None]:
# MalConv2 hyper-parameters (the reload cell further down rebuilds the model
# with the same values so the checkpoint's weights fit).
malconv_config = dict(
    out_size=1,
    channels=128,
    window_size=500,
    stride=500,
    embd_size=8,
    log_stride=None,
)
model = MalConv2(**malconv_config)
# NOTE(review): nn.BCELoss expects probabilities in [0, 1]; this assumes MalConv2
# applies a sigmoid to its output -- confirm, otherwise nn.BCEWithLogitsLoss is needed.
criterion = nn.BCELoss()

## Start training

In [None]:
# `train` writes checkpoints to os.path.join(checkpoint_path, f"{save_title}.pt"),
# i.e. /content/checkpoint/model_name.pt here.
# NOTE(review): `train` saves whenever the validation loss beats the previous
# epoch's, so the file is not guaranteed to hold the overall best epoch.
train(model=model,
      criterion=criterion,
      val_loader=val_loader,
      train_loader=train_loader,
      device=device,
      patience=3,
      num_epochs=20,
      verbose=True,
      save_title='model_name',
      checkpoint_path='/content/checkpoint')

# Testing procedure

In [None]:
# Load the checkpoint written by `train`. It saves to
# os.path.join(checkpoint_path, f"{save_title}.pt"), so the file has a ".pt" suffix.
deep_model = MalConv2(out_size=1, channels=128, window_size=500, stride=500, embd_size=8, log_stride=None)
# map_location lets a GPU-trained checkpoint load on a CPU-only runtime.
# Strict loading (the default) makes a key mismatch fail loudly instead of
# silently leaving layers randomly initialised.
state_dict = torch.load("/content/checkpoint/model_name.pt", map_location=device)
deep_model.load_state_dict(state_dict)
deep_model.eval()
deep_model.to(device)

In [None]:
# Helper that computes the confusion matrix, precision, recall and F1 score
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay, recall_score, precision_score, f1_score

def calculate_metrics(y_true, y_pred, verbose=True):
  """
  Compute the confusion matrix, precision, recall and F1 score of binary predictions.

  Parameters:
  - y_true: Ground-truth labels, either a torch tensor (on any device) or any
    array-like accepted by scikit-learn.
  - y_pred: Predicted labels, same kind and length as ``y_true``.
  - verbose (bool): If True, print precision, recall and F-score.

  Returns:
  - tuple: (confusion matrix, precision, recall, F-score). The scores use
    scikit-learn's default positive label ``pos_label=1``.
  """
  # Tensors must be copied to host memory before scikit-learn can read them;
  # plain lists / numpy arrays have no `.cpu()` and are passed through unchanged.
  if hasattr(y_true, "cpu"):
    y_true = y_true.cpu()
  if hasattr(y_pred, "cpu"):
    y_pred = y_pred.cpu()

  cm = confusion_matrix(y_true, y_pred)
  # zero_division=0 yields the same 0.0 scikit-learn falls back to, but without
  # an UndefinedMetricWarning when a class is never predicted / never present.
  precision = precision_score(y_true, y_pred, zero_division=0)
  recall = recall_score(y_true, y_pred, zero_division=0)
  f_score = f1_score(y_true, y_pred, zero_division=0)

  if verbose:
    print(f"Precision: {precision:.5f}")
    print(f"Recall: {recall:.5f}")
    print(f"F-score: {f_score:.5f}")

  return cm, precision, recall, f_score

## Performance evaluation on SoReL-20M test set

In [None]:
# Evaluate the reloaded model on the SoReL-20M test split.
loss_test_sorel, accuracy_test_sorel, y_true_sorel, y_pred_sorel = get_accuracy(
    deep_model, test_loader, device, criterion, multiclass=False
)
# Name the dataset in the output so it cannot be confused with the VirusShare result.
print(f'Accuracy on SoReL-20M test set: {accuracy_test_sorel:.4f}')

In [None]:
cm_sorel, precision_sorel, recall_sorel, f_score_sorel = calculate_metrics(y_true_sorel, y_pred_sorel)
# ConfusionMatrixDisplay.plot() returns the display; title its axes so the figure
# stands alone. The trailing ';' suppresses the Text repr in the cell output.
cm_sorel_display = ConfusionMatrixDisplay(cm_sorel).plot()
cm_sorel_display.ax_.set_title("SoReL-20M test set");

## Performance evaluation on VirusShare test set

In [None]:
# Collect every regular file inside the class sub-folders of the VirusShare set.
# Listings are sorted because os.listdir order is arbitrary; non-folder entries
# at the top level are skipped instead of crashing os.listdir.
x_test_2 = []
for class_dir in sorted(os.listdir(test_2_path)):
  class_path = os.path.join(test_2_path, class_dir)
  if not os.path.isdir(class_path):
    continue
  for filename in sorted(os.listdir(class_path)):
    file_path = os.path.join(class_path, filename)
    if os.path.isfile(file_path):
      x_test_2.append(file_path)

test_2_dataset = BinaryDataset(path_list=x_test_2, max_len=MAX_LEN, sorel_20m=False)

print("VirusShare test dataset length: ", len(test_2_dataset))
print(os.path.basename(test_2_path))
print("Malware:\t", len(os.listdir(os.path.join(test_2_path, 'malware'))))
print("Benign:\t", len(os.listdir(os.path.join(test_2_path, 'benign'))), "\n")

In [None]:
# Evaluation loader: fixed order, same batching/padding as the other splits.
test_2_loader = DataLoader(
    dataset=test_2_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    collate_fn=pad_collate_func,
)

In [None]:
# Evaluate the reloaded model on the VirusShare test set.
loss_test_vsd, accuracy_test_vsd, y_true_vsd, y_pred_vsd = get_accuracy(
    deep_model, test_2_loader, device, criterion, multiclass=False
)
# Name the dataset in the output so it cannot be confused with the SoReL-20M result.
print(f'Accuracy on VirusShare test set: {accuracy_test_vsd:.4f}')

In [None]:
cm_vsd, precision_vsd, recall_vsd, f_score_vsd = calculate_metrics(y_true_vsd, y_pred_vsd)
# ConfusionMatrixDisplay.plot() returns the display; title its axes so the figure
# stands alone. The trailing ';' suppresses the Text repr in the cell output.
cm_vsd_display = ConfusionMatrixDisplay(cm_vsd).plot()
cm_vsd_display.ax_.set_title("VirusShare test set");