In [1]:
import torch
import torch.nn as nn
import pandas as pd
import numpy as np
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
from rff.layers import GaussianEncoding #pip install random-fourier-features-pytorch
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score, recall_score, confusion_matrix
from sklearn.preprocessing import StandardScaler, LabelEncoder
import os
import optuna
from optuna.trial import TrialState

In [2]:
# Run regardless if you do or do not have GPU so all tensors are moved to right location later on
if torch.cuda.is_available():
    device_in_use = torch.device("cuda")
    print("GPU is available and being used")
else:
    device_in_use = torch.device("cpu")
    print("GPU is not available, using CPU instead")

GPU is available and being used


# Preprocessing Data section
1. Create three seperate csv (datasets) from the Obfuscated-MalMen2022_edited.csv (edited the Category column removing the hashes with the Malware type) 
2. Save datasets into the /datasets directory 

(Note) Skip this if you already have the train,test, and val datasets

In [5]:
df = pd.read_csv('../../datasets/CIC_Mal/Obfuscated-MalMem2022_edited.csv')
df.shape[0]

58596

In [6]:
for x in df:
    print(df[x].value_counts())
    print("\n")
# Note Features such as:
# pslist.nprocs64bit, handles.nport, 
# psxview.not_in_eprocess_pool, modules.nmodules, 
# svcscan.interactive_process_services, 
# callbacks.nanonymous, callbacks.ngeneric
# Do not seem to be significant or helpful to the 
# overall model 


Category
Benign        29298
Spyware       10020
Ransomware     9791
Trojan         9487
Name: count, dtype: int64


pslist.nproc
41     10012
40      9226
42      7822
44      5777
43      5616
       ...  
106        1
122        1
132        1
161        1
96         1
Name: count, Length: 114, dtype: int64


pslist.nppid
12    16559
16    12242
15     8653
17     7898
13     4404
18     2827
14     2452
19     1294
20      537
8       446
11      334
21      229
9       203
22      184
10      133
23       55
24       29
25       27
37        9
39        8
38        8
26        8
40        6
62        5
27        5
28        4
60        4
36        3
52        3
54        3
66        2
49        2
55        2
56        2
48        2
72        1
61        1
57        1
34        1
42        1
50        1
35        1
53        1
43        1
33        1
44        1
63        1
51        1
41        1
Name: count, dtype: int64


pslist.avg_threads
10.000000    619
10.162162    502
10.1

In [7]:
# remove any features that are not balanced or could cause overfitting
remove_columns = ['pslist.nprocs64bit', 'handles.nport', 
                  'psxview.not_in_eprocess_pool', 'modules.nmodules',
                  'svcscan.interactive_process_services', 'callbacks.nanonymous',
                  'callbacks.ngeneric']

df = df.drop(remove_columns, axis=1)

In [8]:
for x in df:
    print(df[x].value_counts())
    print("\n")

Category
Benign        29298
Spyware       10020
Ransomware     9791
Trojan         9487
Name: count, dtype: int64


pslist.nproc
41     10012
40      9226
42      7822
44      5777
43      5616
       ...  
106        1
122        1
132        1
161        1
96         1
Name: count, Length: 114, dtype: int64


pslist.nppid
12    16559
16    12242
15     8653
17     7898
13     4404
18     2827
14     2452
19     1294
20      537
8       446
11      334
21      229
9       203
22      184
10      133
23       55
24       29
25       27
37        9
39        8
38        8
26        8
40        6
62        5
27        5
28        4
60        4
36        3
52        3
54        3
66        2
49        2
55        2
56        2
48        2
72        1
61        1
57        1
34        1
42        1
50        1
35        1
53        1
43        1
33        1
44        1
63        1
51        1
41        1
Name: count, dtype: int64


pslist.avg_threads
10.000000    619
10.162162    502
10.1

In [9]:
#label encode all non-numerical features
le = LabelEncoder()
df['Category'] = le.fit_transform(df['Category'])

df['Class'] = le.fit_transform(df['Class'])

#Split based of the number of samples in the dataset: total(58596) train(37501)->64% test(11720)->20% val(9375)->16%
# df_train, df_temp = train_test_split(df, train_size=37501, random_state=42)
# df_val, df_test = train_test_split(df_temp, train_size=3375, random_state=42)

df_train, df_temp = train_test_split(df, train_size=0.64, random_state=42)
df_val, df_test = train_test_split(df_temp, train_size=0.16, random_state=42)

print(f"df_train.shape[0]:{df_train.shape[0]}")
print(f"df_test.shape[0]:{df_test.shape[0]}")
print(f"df_val.shape[0]:{df_val.shape[0]}")

assert(df_train.shape[0] == 37501)
assert(df_val.shape[0] == 3375)
assert(df_test.shape[0] == 17720)

df_train.shape[0]:37501
df_test.shape[0]:17720
df_val.shape[0]:3375


In [10]:
for x in df_train:
    print(df[x].value_counts())
    print("\n")

Category
0    29298
2    10020
1     9791
3     9487
Name: count, dtype: int64


pslist.nproc
41     10012
40      9226
42      7822
44      5777
43      5616
       ...  
106        1
122        1
132        1
161        1
96         1
Name: count, Length: 114, dtype: int64


pslist.nppid
12    16559
16    12242
15     8653
17     7898
13     4404
18     2827
14     2452
19     1294
20      537
8       446
11      334
21      229
9       203
22      184
10      133
23       55
24       29
25       27
37        9
39        8
38        8
26        8
40        6
62        5
27        5
28        4
60        4
36        3
52        3
54        3
66        2
49        2
55        2
56        2
48        2
72        1
61        1
57        1
34        1
42        1
50        1
35        1
53        1
43        1
33        1
44        1
63        1
51        1
41        1
Name: count, dtype: int64


pslist.avg_threads
10.000000    619
10.162162    502
10.135135    479
10.189189    433
10.108

In [11]:
df_train.head()

Unnamed: 0,Category,pslist.nproc,pslist.nppid,pslist.avg_threads,pslist.avg_handlers,dlllist.ndlls,dlllist.avg_dlls_per_proc,handles.nhandles,handles.avg_handles_per_proc,handles.nfile,...,psxview.not_in_session_false_avg,psxview.not_in_deskthrd_false_avg,svcscan.nservices,svcscan.kernel_drivers,svcscan.fs_drivers,svcscan.process_services,svcscan.shared_process_services,svcscan.nactive,callbacks.ncallbacks,Class
21373,0,40,12,13.399319,300.768187,2001,50.025,12030,300.768187,1084,...,0.095238,0.190476,395,222,26,27,118,124,88,0
56756,1,37,15,10.054054,214.513513,1445,39.054054,7937,214.513513,623,...,0.125,0.225,389,221,26,24,116,118,86,1
14125,0,44,12,12.568182,286.165209,2153,48.931818,12591,286.165209,1126,...,0.045455,0.136364,395,222,26,27,118,125,88,0
14216,0,42,12,13.117469,292.445936,2080,49.52381,12282,292.445936,1108,...,0.047619,0.142857,395,222,26,27,118,125,88,0
43273,2,40,16,9.875,208.8,1557,38.925,8352,208.8,642,...,0.116279,0.209302,389,221,26,24,116,122,86,1


In [12]:
df_val.head()

Unnamed: 0,Category,pslist.nproc,pslist.nppid,pslist.avg_threads,pslist.avg_handlers,dlllist.ndlls,dlllist.avg_dlls_per_proc,handles.nhandles,handles.avg_handles_per_proc,handles.nfile,...,psxview.not_in_session_false_avg,psxview.not_in_deskthrd_false_avg,svcscan.nservices,svcscan.kernel_drivers,svcscan.fs_drivers,svcscan.process_services,svcscan.shared_process_services,svcscan.nactive,callbacks.ncallbacks,Class
38938,2,37,15,10.189189,215.135135,1445,39.054054,7960,215.135135,627,...,0.054054,0.162162,389,221,26,24,116,119,87,1
16476,0,41,12,12.861858,294.145823,2029,49.487805,12059,294.145823,1128,...,0.04878,0.146341,395,222,26,27,118,122,88,0
43771,2,42,16,10.404762,209.5,1639,39.02381,8799,209.5,650,...,0.090909,0.181818,389,221,26,24,116,123,88,1
42519,2,40,16,11.125,220.625,1627,40.675,8825,220.625,652,...,0.073171,0.170732,389,221,26,24,116,122,86,1
12573,0,41,12,13.599462,305.483935,2143,51.772761,12651,308.572267,1089,...,0.048294,0.154848,395,222,26,27,118,123,88,0


In [13]:
df_test.head()

Unnamed: 0,Category,pslist.nproc,pslist.nppid,pslist.avg_threads,pslist.avg_handlers,dlllist.ndlls,dlllist.avg_dlls_per_proc,handles.nhandles,handles.avg_handles_per_proc,handles.nfile,...,psxview.not_in_session_false_avg,psxview.not_in_deskthrd_false_avg,svcscan.nservices,svcscan.kernel_drivers,svcscan.fs_drivers,svcscan.process_services,svcscan.shared_process_services,svcscan.nactive,callbacks.ncallbacks,Class
9170,0,42,12,13.323471,302.820157,2178,51.876021,12718,302.820157,1140,...,0.08872,0.179848,395,222,26,27,118,123,88,0
10045,0,41,14,12.849976,289.129317,2050,50.0,11854,289.129317,1049,...,0.113636,0.181818,395,222,26,27,118,123,88,0
16189,0,41,13,13.567065,313.099997,2130,51.954264,12837,313.099997,1158,...,0.04878,0.146341,395,222,26,27,118,124,88,0
8687,0,42,12,12.857143,296.083982,2140,50.952381,12436,303.329933,1077,...,0.047619,0.166667,395,222,26,27,118,125,88,0
34809,1,38,15,10.026316,213.210526,1480,38.947368,8103,219.0,847,...,0.052632,0.184211,389,221,26,24,116,119,87,1


In [14]:
df_train.to_csv('../../datasets/CIC_Mal/Obfuscated-MalMem2022_train.csv')
df_test.to_csv('../../datasets/CIC_Mal/Obfuscated-MalMem2022_test.csv')
df_val.to_csv('../../datasets/CIC_Mal/Obfuscated-MalMem2022_val.csv')

# Load and Process Data
1. Standardize or perform quantile transformations to numerica/continuous features.
2. Wrap with Dataset and Dataloader.


(Note) For the next cell, you don't need to re-allocate the datasets since this is done from the Data Preprocessing section, however you may skip the Data Preprocessing Section from now on if you have the train,test, and val csv. In the case where you have the datasets and you lost your progress, then you may uncomment and run the following cell to get the train, test, and val datasets allocated to the notebook

In [15]:
df_train = pd.read_csv('../../datasets/CIC_Mal/Obfuscated-MalMem2022_train.csv')
df_test = pd.read_csv('../../datasets/CIC_Mal/Obfuscated-MalMem2022_test.csv')
df_val = pd.read_csv('../../datasets/CIC_Mal/Obfuscated-MalMem2022_val.csv') #READ FROM RIGHT SPOT

df_train

Unnamed: 0.1,Unnamed: 0,Category,pslist.nproc,pslist.nppid,pslist.avg_threads,pslist.avg_handlers,dlllist.ndlls,dlllist.avg_dlls_per_proc,handles.nhandles,handles.avg_handles_per_proc,...,psxview.not_in_session_false_avg,psxview.not_in_deskthrd_false_avg,svcscan.nservices,svcscan.kernel_drivers,svcscan.fs_drivers,svcscan.process_services,svcscan.shared_process_services,svcscan.nactive,callbacks.ncallbacks,Class
0,21373,0,40,12,13.399319,300.768187,2001,50.025000,12030,300.768187,...,0.095238,0.190476,395,222,26,27,118,124,88,0
1,56756,1,37,15,10.054054,214.513513,1445,39.054054,7937,214.513513,...,0.125000,0.225000,389,221,26,24,116,118,86,1
2,14125,0,44,12,12.568182,286.165209,2153,48.931818,12591,286.165209,...,0.045455,0.136364,395,222,26,27,118,125,88,0
3,14216,0,42,12,13.117469,292.445936,2080,49.523810,12282,292.445936,...,0.047619,0.142857,395,222,26,27,118,125,88,0
4,43273,2,40,16,9.875000,208.800000,1557,38.925000,8352,208.800000,...,0.116279,0.209302,389,221,26,24,116,122,86,1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
37496,54343,3,37,15,10.162162,215.135135,1444,39.027027,7960,215.135135,...,0.054054,0.162162,389,221,26,24,116,118,86,1
37497,38158,2,38,15,9.815789,214.210526,1498,39.421053,8140,214.210526,...,0.076923,0.179487,389,221,26,24,116,119,86,1
37498,860,0,41,12,12.975610,294.926829,2029,49.487805,12092,294.926829,...,0.048780,0.146341,395,222,26,27,118,123,88,0
37499,15795,0,51,20,11.485621,254.858328,2295,44.918269,13020,254.858328,...,0.140110,0.210166,392,222,26,24,118,126,87,0


In [14]:
print(df_train.columns)
df_train.info()

Index(['Unnamed: 0', 'Category', 'pslist.nproc', 'pslist.nppid',
       'pslist.avg_threads', 'pslist.avg_handlers', 'dlllist.ndlls',
       'dlllist.avg_dlls_per_proc', 'handles.nhandles',
       'handles.avg_handles_per_proc', 'handles.nfile', 'handles.nevent',
       'handles.ndesktop', 'handles.nkey', 'handles.nthread',
       'handles.ndirectory', 'handles.nsemaphore', 'handles.ntimer',
       'handles.nsection', 'handles.nmutant', 'ldrmodules.not_in_load',
       'ldrmodules.not_in_init', 'ldrmodules.not_in_mem',
       'ldrmodules.not_in_load_avg', 'ldrmodules.not_in_init_avg',
       'ldrmodules.not_in_mem_avg', 'malfind.ninjections',
       'malfind.commitCharge', 'malfind.protection',
       'malfind.uniqueInjections', 'psxview.not_in_pslist',
       'psxview.not_in_ethread_pool', 'psxview.not_in_pspcid_list',
       'psxview.not_in_csrss_handles', 'psxview.not_in_session',
       'psxview.not_in_deskthrd', 'psxview.not_in_pslist_false_avg',
       'psxview.not_in_eprocess_po

In [4]:
cat_columns = ['Category']
cont_columns = list(df_train.drop(['Category','Class'],axis=1))
target = ['Class']

yourlist = cat_columns + cont_columns + target
yourlist.sort()
oglist = list(df_train.columns)
oglist.sort()
assert(yourlist == oglist), "You may of spelled feature name wrong or you forgot to put on of them in the list"

In [16]:
#Get class counts and store in a list below

for x in cat_columns:
    print(max(len(df_train[x].value_counts()), len(df_val[x].value_counts()), len(df_test[x].value_counts())))

4


In [17]:
cat_feat = [4]

In [18]:
#Getting the number of classes in your classification target
target_classes = [max(len(df_train[target].value_counts()), len(df_val[target].value_counts()),len(df_test[target].value_counts()))]

In [19]:
# Create a StandardScaler and fit it to the cont features
scaler = StandardScaler()
scaler.fit(df_train[cont_columns])

# Transform the training, test, and validation datasets
df_train[cont_columns] = scaler.transform(df_train[cont_columns])
df_test[cont_columns] = scaler.transform(df_test[cont_columns])
df_val[cont_columns] = scaler.transform(df_val[cont_columns])

In [20]:
class SingleTaskDataset(Dataset):
    def __init__(self, df : pd.DataFrame, cat_columns, num_columns,task1_column):
        self.n = df.shape[0]
        
        self.task1_labels = df[task1_column].astype(np.float32).values

        self.cate = df[cat_columns].astype(np.int64).values
        self.num = df[num_columns].astype(np.float32).values


    def __len__(self):
        return self.n
    
    def __getitem__(self, idx):
        # Retrieve features and labels from the dataframe using column names
        cat_features = self.cate[idx]
        num_features = self.num[idx]
        labels_task1 = self.task1_labels[idx]

        return cat_features, num_features, labels_task1

#Wrapping in Dataset
train_dataset = SingleTaskDataset(df_train, cat_columns, cont_columns, target[0])
val_dataset = SingleTaskDataset(df_val, cat_columns, cont_columns, target[0])
test_dataset = SingleTaskDataset(df_test, cat_columns, cont_columns, target[0])

#This is a hyperparameter that is not tuned. Maybe mess with what makes sense here
#Also try looking to see what other papers have done
batch_size = 256

# Wrapping with DataLoader for easy batch extraction
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)

# MODEL AND HELPERS

1. All you should have to do is interact with Classifier()

In [21]:
# each task loss is scaled by its own learnable parameter, then regularization is applied 
class UncertaintyLoss(nn.Module):
    def __init__(self, num_tasks):
        super(UncertaintyLoss, self).__init__()
        self.num_tasks = num_tasks

        self.loss_fns = [nn.CrossEntropyLoss() for x in range(num_tasks)] 

    def forward(self, predictions, labels_task1):

        #task 1
        target = labels_task1.long()
        prediction = predictions[0]
        loss_fn = self.loss_fns[0]

        task_loss = loss_fn(prediction, target)
        
        return task_loss
    
#All layers of the model
class MultiHeadAttention(nn.Module):
    def __init__(self, embed_size, heads):
        super(MultiHeadAttention, self).__init__()

        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads
        assert(self.head_dim * heads == embed_size), "Embed size needs to be div by heads"
        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys =nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.fc_out = nn.Linear(heads*self.head_dim, embed_size)


    def forward(self, values, keys, query):
        N = query.shape[0]
        value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]

        values = values.reshape(N, value_len, self.heads, self.head_dim)
        keys = keys.reshape(N, key_len, self.heads, self.head_dim)
        queries = query.reshape(N, query_len, self.heads, self.head_dim)

        values = self.values(values)
        keys = self.keys(keys)
        queries = self.queries(queries)

        energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])

        attention = torch.softmax(energy / (self.embed_size ** (1/2)), dim=3) #(batch_size, head_dim, #query_embeddings, #key_embeddings)

        # Calculate simplified attention scores
        avg_attention = attention.mean(dim=0)  # Average across batches
        # print("batch average", avg_attention.shape)
        avg_attention = avg_attention.mean(dim=0).squeeze(dim=0)
        # print("head average", avg_attention.shape)

        out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape(N, query_len, self.heads*self.head_dim) #(batch_size, n_features, embed_size)
        out = self.fc_out(out)

        return out, avg_attention
    
class TransformerBlock(nn.Module):
    def __init__(self, embed_size, heads, dropout, forward_expansion, pre_norm_on):
        super(TransformerBlock, self).__init__()

        self.pre_norm_on = pre_norm_on
        if self.pre_norm_on:
            self.pre_norm = nn.LayerNorm(embed_size)
        self.attention = MultiHeadAttention(embed_size, heads)
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)

        self.feed_forward = nn.Sequential(nn.Linear(embed_size, forward_expansion*embed_size),
                                          nn.ReLU(),
                                          nn.Linear(forward_expansion*embed_size, embed_size)
                                          )
        self.dropout = nn.Dropout(dropout)

    def forward(self,value,key,query):
        if self.pre_norm_on:
            query = self.pre_norm(query)
            key = self.pre_norm(key)
            value = self.pre_norm(value)
            
        attention, avg_attention = self.attention(value, key, query)
        x = self.dropout(self.norm1(attention + query))
        forward = self.feed_forward(x)
        out = self.dropout(self.norm2(forward + x))
        return out, avg_attention
    
class DecoderBlock(nn.Module):
    def __init__(self, embed_size, heads, forward_expansion, dropout, pre_norm_on):
        super(DecoderBlock, self).__init__()

        self.attention = MultiHeadAttention(embed_size, heads)
        self.norm = nn.LayerNorm(embed_size)
        self.transformer_block = TransformerBlock(embed_size, heads, dropout, forward_expansion, pre_norm_on)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, value, key):
        out, avg_attention = self.transformer_block(value, key, x)

        return out, avg_attention

class Decoder(nn.Module):
    def __init__(self,
                 embed_size,
                 num_layers,
                 heads,
                 forward_expansion,
                 decoder_dropout,
                 pre_norm_on
    ):
        super(Decoder, self).__init__()

        self.layers = nn.ModuleList(
                [
                    DecoderBlock(
                        embed_size,
                        heads,
                        dropout=decoder_dropout,
                        forward_expansion=forward_expansion,
                        pre_norm_on=pre_norm_on
                    )
                    for _ in range(num_layers)
                ]
            )
        self.avg_attention = None

    def forward(self, class_embed, context):
        for layer in self.layers:
            # x is the classification embedding (CLS Token)
            # context are the feature embeddings that will be used as key and value
            x, self.avg_attention = layer(class_embed, context, context)
  
        return x 

class Embeddings(nn.Module):
    def __init__(self, sigma, embed_size, input_size, embedding_dropout, n_cont, cat_feat, num_target_labels, rff_on):
        super(Embeddings, self).__init__()

        self.rff_on = rff_on

        if self.rff_on:
            self.rffs = nn.ModuleList([GaussianEncoding(sigma=sigma, input_size=input_size, encoded_size=embed_size//2) for _ in range(n_cont)])
            self.dropout = nn.Dropout(embedding_dropout)
            self.mlp_in = embed_size
        else:
            self.mlp_in = input_size

        self.cont_embeddings = nn.ModuleList([nn.Linear(in_features=self.mlp_in, out_features=embed_size) for _ in range(n_cont)])

        self.cat_embeddings = nn.ModuleList([nn.Embedding(num_classes, embed_size) for num_classes in cat_feat])

        # Classifcation Embeddings for each target label
        self.target_label_embeddings = nn.ModuleList([nn.Embedding(1, embed_size) for _ in range(num_target_labels)])


    def forward(self, cat_x, cont_x):
        x = cont_x.unsqueeze(2) #(batch_size, n_features) -> (batch_size, n_features, 1)
        rff_vectors = []
        if self.rff_on:
            for i, r in enumerate(self.rffs):
                input = x[:,i,:]
                out = r(input)
                rff_vectors.append(out)
        
            x = torch.stack(rff_vectors, dim=1)
        
        embeddings = []
        for i, e in enumerate(self.cont_embeddings):
            goin_in = x[:,i,:]
            goin_out = e(goin_in)
            embeddings.append(goin_out)

        #embedding cat features
        cat_x = cat_x.unsqueeze(2)
        for i, e in enumerate(self.cat_embeddings):

            goin_in = cat_x[:,i,:]
  
            goin_out = e(goin_in)
            goin_out=goin_out.squeeze(1)
            embeddings.append(goin_out)

        target_label_embeddings_ = []
        for e in self.target_label_embeddings:
            input = torch.tensor([0], device=x.device)
            temp = e(input)
            temp = temp.repeat(x.size(0), 1)
            tmep = temp.unsqueeze(1)
            target_label_embeddings_.append(temp)

        class_embeddings = torch.stack(target_label_embeddings_, dim=1)

        context = torch.stack(embeddings, dim=1)

        return class_embeddings, context

class classificationHead(nn.Module):
    def __init__(self, embed_size, dropout, mlp_scale_classification, num_target_classes):
        super(classificationHead, self).__init__()
        
        #flattening the embeddings out so each sample in batch is represented with a 460 dimensional vector
        self.input = embed_size
        self.lin1 = nn.Linear(self.input, mlp_scale_classification*self.input)
        self.drop = nn.Dropout(dropout)
        self.lin2 = nn.Linear(mlp_scale_classification*self.input, mlp_scale_classification*self.input)
        self.lin3 = nn.Linear(mlp_scale_classification*self.input, self.input)
        self.lin4 = nn.Linear(self.input, num_target_classes)
        self.relu = nn.ReLU()
        self.initialize_weights()

    def initialize_weights(self): #he_initialization.
        torch.nn.init.kaiming_normal_(self.lin1.weight, nonlinearity='relu')
        torch.nn.init.zeros_(self.lin1.bias)

        torch.nn.init.kaiming_normal_(self.lin3.weight, nonlinearity='relu')
        torch.nn.init.zeros_(self.lin3.bias)

    def forward(self, x):

        x= torch.reshape(x, (-1, self.input))

        x = self.lin1(x)
        x = self.relu(x)
        x = self.drop(x)
        x = self.lin2(x)
        x = self.relu(x)
        x = self.drop(x)
        x = self.lin3(x)
        x = self.relu(x)
        x = self.drop(x)
        x = self.lin4(x)
  
        return x

class Classifier(nn.Module):
    def __init__(self, 
                 rff_on = False,
                 sigma=4,
                 embed_size=20,
                 input_size=1,
                 embedding_dropout = 0,
                 n_cont = 0,
                 cat_feat:list = [],
                 num_layers=1,
                 heads=1,
                 forward_expansion=4, # Determines how wide the MLP is in the encoder. Its a scaling factor. 
                 decoder_dropout=0,
                 classification_dropout = 0,
                 pre_norm_on = False,
                 mlp_scale_classification = 4,
                 targets_classes : list=  [3,8]
                 ):
        super(Classifier, self).__init__()

        self.embeddings = Embeddings(rff_on=rff_on, sigma=sigma, embed_size=embed_size, input_size=input_size, 
                                     embedding_dropout=embedding_dropout,n_cont=n_cont, cat_feat=cat_feat, num_target_labels=len(targets_classes))
        self.decoder = Decoder(embed_size=embed_size, num_layers=num_layers, heads=heads, forward_expansion=forward_expansion, 
                               decoder_dropout=decoder_dropout, pre_norm_on=pre_norm_on)
        self.classifying_heads = nn.ModuleList([classificationHead(embed_size=embed_size, dropout=classification_dropout, 
                                                                   mlp_scale_classification=mlp_scale_classification, 
                                                                   num_target_classes=x) for x in targets_classes])
        
    def forward(self, cat_x, cont_x):
        class_embed, context = self.embeddings(cat_x, cont_x)

        x = self.decoder(class_embed, context)
        
        probability_dist_raw = []
        for i, e in enumerate(self.classifying_heads):
            input = x[:, i,:]
            output = e(input)
            probability_dist_raw.append(output)
        
        return probability_dist_raw

# Training and Testing Loops
def train(dataloader, model, loss_function, optimizer, device_in_use):
    model.train()

    total_loss = 0

    total_correct_1 = 0
    total_samples_1 = 0
    all_targets_1 = []
    all_predictions_1 = []

    total_correct_2 = 0
    total_samples_2 = 0
    all_targets_2 = []
    all_predictions_2 = []

    for (cat_x, cont_x,labels_task1) in dataloader:
        cat_x,cont_x,labels_task1 = cat_x.to(device_in_use),cont_x.to(device_in_use),labels_task1.to(device_in_use)


        task_predictions = model(cat_x, cont_x) #contains a list of the tensor outputs for each task

        loss = loss_function(task_predictions, labels_task1)
        total_loss += loss.item()

        #computing accuracy for first target
        y_pred_softmax_1 = torch.softmax(task_predictions[0], dim=1)
        _, y_pred_labels_1 = torch.max(y_pred_softmax_1, dim=1)
        total_correct_1 += (y_pred_labels_1 == labels_task1).sum().item()
        total_samples_1 += labels_task1.size(0)
        all_targets_1.extend(labels_task1.cpu().numpy())
        all_predictions_1.extend(y_pred_labels_1.cpu().numpy())


        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

    avg_loss = total_loss/len(dataloader)
    accuracy_1 = total_correct_1 / total_samples_1
    # accuracy_2 = total_correct_2 / total_samples_2

    # # precision = precision_score(all_targets, all_predictions, average='weighted')
    # recall = recall_score(all_targets, all_predictions, average='weighted')
    # f1 = f1_score(all_targets, all_predictions, average='weighted')

    return avg_loss, accuracy_1

def test(dataloader, model, loss_function, device_in_use):
  model.eval()
  total_loss = 0
  
  total_correct_1 = 0
  total_samples_1 = 0
  all_targets_1 = []
  all_predictions_1 = []

  total_correct_2 = 0
  total_samples_2 = 0
  all_targets_2 = []
  all_predictions_2 = []

  with torch.no_grad():
    for (cat_x, cont_x,labels_task1) in dataloader:
        cat_x,cont_x,labels_task1 = cat_x.to(device_in_use),cont_x.to(device_in_use),labels_task1.to(device_in_use)


        task_predictions = model(cat_x, cont_x) #contains a list of the tensor outputs for each task

        loss = loss_function(task_predictions, labels_task1)
        total_loss += loss.item()

        #computing accuracy for first target
        y_pred_softmax_1 = torch.softmax(task_predictions[0], dim=1)
        _, y_pred_labels_1 = torch.max(y_pred_softmax_1, dim=1)
        total_correct_1 += (y_pred_labels_1 == labels_task1).sum().item()
        total_samples_1 += labels_task1.size(0)
        all_targets_1.extend(labels_task1.cpu().numpy())
        all_predictions_1.extend(y_pred_labels_1.cpu().numpy())


  avg = total_loss/len(dataloader)
  accuracy_1 = total_correct_1 / total_samples_1
  # accuracy_2 = total_correct_2 / total_samples_2
  # recall = recall_score(all_targets, all_predictions, average='weighted')
  f1_1 = f1_score(all_targets_1, all_predictions_1, average='weighted')
  # f1_2 = f1_score(all_targets_2, all_predictions_2, average="weighted")

  return avg, accuracy_1, all_predictions_1, all_targets_1, f1_1

def format_metric(value): # Used to format the metrics output
    return f"{value:.4f}"

# RUN EXPERIMENTS

1. Using Optuna to optimize CAT-Transformers hyperparameters for your dataset

In [22]:
# Define the early stopping mechanism
class EarlyStopping:
    def __init__(self, patience=5):
        self.patience = patience
        self.counter = 0
        self.best_metric = float('-inf')
        self.early_stop = False

    def __call__(self, metric):
        if metric > self.best_metric:
            self.best_metric = metric
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True

# Function to log results to a text file
def log_to_file(filename, text):
    with open(filename, 'a') as f:
        f.write(text + '\n')

def objective(trial):
    trial_number = trial.number

    # Define hyperparameters to search over
    sigma = trial.suggest_categorical('sigma', [.001, 0.1, 1, 2, 3, 5, 10])
    num_layers = trial.suggest_int('num_layers', 1, 2)
    # Ensure that embed_size is divisible by num_layers
    embed_size = trial.suggest_categorical("embed_size", [50, 60, 70, 80, 90, 100, 120, 140, 160])
    heads = trial.suggest_categorical("heads", [1, 5, 10])
    forward_expansion = trial.suggest_int('forward_expansion', 1, 8)
    prenorm_on = trial.suggest_categorical('prenorm_on', [True, False])
    mlp_scale_classification = trial.suggest_int('mlp_scale_classification', 1, 8)
    embedding_dropout = trial.suggest_categorical('embedding_dropout', [0, .1, .2, .5])
    decoder_dropout = trial.suggest_categorical('decoder_dropout', [0,.1,.2,.5])
    classification_dropout = trial.suggest_categorical('class_drop', [0,.1,.2,.5])

    learning_rate = trial.suggest_categorical('learning_rate', [0.0001, 0.001, 0.01])

    num_epochs = 75

    # Create your model with the sampled hyperparameters
    model = Classifier(
        targets_classes=target_classes,
        rff_on=True, #LEAVING ON
        n_cont=len(cont_columns),
        cat_feat=cat_feat,
        sigma=sigma,
        embed_size=embed_size,
        num_layers=num_layers,
        heads=heads,
        forward_expansion=forward_expansion,
        pre_norm_on=prenorm_on,
        mlp_scale_classification=mlp_scale_classification,
        embedding_dropout=embedding_dropout,
        decoder_dropout=decoder_dropout,
        classification_dropout=classification_dropout
    ).to(device_in_use)

    # Define loss function and optimizer
    loss_function = UncertaintyLoss(1)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    # Initialize early stopping
    early_stopping = EarlyStopping(patience=3)  # Adjust patience as needed

    # Training loop with a large number of epochs
    for epoch in range(num_epochs):
        train_loss, train_accuracy = train(train_dataloader, model, loss_function, optimizer, device_in_use)
        
        # Validation loop
        val_loss, val_accuracy, _, _, _ = test(val_dataloader, model, loss_function, device_in_use)
        
        # Check if we should early stop based on validation accuracy
        if early_stopping(val_accuracy):
            break

    
    # Log the final test accuracy for this trial to a shared log file
    final_log = f"Trial {trial_number} completed. Validation Accuracy = {val_accuracy:.4f}"
    log_to_file('all_trials_log.txt', final_log)

    # Return the test accuracy as the objective to optimize
    return val_accuracy

In [23]:
# Set the number of optimization trials
num_trials = 50

# Create an Optuna study
study = optuna.create_study(direction='maximize')  # Maximize validation accuracy

# Start the optimization process
study.optimize(objective, n_trials=num_trials, show_progress_bar=True)

# Get the best hyperparameters and the validation accuracy at the point of early stopping
best_params = study.best_params
best_val_accuracy = study.best_value

print("Best Hyperparameters:", best_params)
print("Best Validation Accuracy (at Early Stopping):", best_val_accuracy)

[I 2023-11-16 17:05:40,145] A new study created in memory with name: no-name-2bb5876c-00a5-4830-8991-bc97aa0b7774


  0%|          | 0/50 [00:00<?, ?it/s]

[W 2023-11-16 17:05:44,134] Trial 0 failed with parameters: {'sigma': 10, 'num_layers': 2, 'embed_size': 60, 'heads': 5, 'forward_expansion': 8, 'prenorm_on': True, 'mlp_scale_classification': 8, 'embedding_dropout': 0.1, 'decoder_dropout': 0.1, 'class_drop': 0.1, 'learning_rate': 0.0001} because of the following error: KeyboardInterrupt().
Traceback (most recent call last):
  File "/home/aavila/anaconda3/envs/cyber-TFF/lib/python3.11/site-packages/optuna/study/_optimize.py", line 200, in _run_trial
    value_or_values = func(trial)
                      ^^^^^^^^^^^
  File "/tmp/ipykernel_9171/2738648641.py", line 70, in objective
    train_loss, train_accuracy = train(train_dataloader, model, loss_function, optimizer, device_in_use)
                                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/ipykernel_9171/2217704083.py", line 299, in train
    task_predictions = model(cat_x, cont_x) #contains a list of the tensor outputs for e

KeyboardInterrupt: 

In [24]:
best_params = best_params = {'sigma': 2, 
               'num_layers': 1, 
               'embed_size': 160, 
               'heads': 5, 
               'forward_expansion': 8, 
               'prenorm_on': False, 
               'mlp_scale_classification': 8, 
               'embedding_dropout': 0.1, 
               'decoder_dropout': 0, 
               'class_drop': 0.1, 
               'learning_rate': 0.0001}

In [25]:
#Testing against the test dataset

#defult parameters

# model = Classifier(n_features=num_features, 
#                   pre_norm_on=True, 
#                   rff_on=True, 
#                   forward_expansion=4, 
#                   mlp_scale_classification=2, 
#                   targets_classes=classes_per_target
#                   ).to(device_in_use)

model = Classifier(targets_classes=target_classes,
                    rff_on=True,
                    n_cont=len(cont_columns),
                    cat_feat=cat_feat, 
                   sigma=best_params['sigma'],
                   embed_size=best_params['embed_size'],
                   num_layers=best_params['num_layers'],
                   heads=best_params['heads'],
                   forward_expansion=best_params['forward_expansion'],
                   pre_norm_on=best_params['prenorm_on'],
                   mlp_scale_classification=best_params['mlp_scale_classification'],
                   embedding_dropout=best_params['embedding_dropout'],
                   decoder_dropout=best_params['decoder_dropout'],
                   classification_dropout=best_params['class_drop']
                   ).to(device_in_use) # Instantiate the model
loss_functions = UncertaintyLoss(1)
optimizer = torch.optim.Adam(params=model.parameters(), lr = best_params['learning_rate']) # Maybe try messing around with optimizers. try other torch optimizers with different configurations.
early_stopping = EarlyStopping(patience=3)
epochs = 75 #Set the number of epochs

train_losses = []
train_accuracies_1 = [] 
train_accuracies_2 = []
train_recalls = [] 
train_f1_scores = [] 
test_losses = []
test_accuracies_1 = []
test_accuracies_2 = []
test_recalls = []  
test_f1_scores = [] 
all_attention_scores = []

for t in range(epochs):
  train_loss, train_accuracy_1 = train(train_dataloader, model, loss_functions, optimizer, device_in_use=device_in_use)
  test_loss, test_accuracy_1, all_predictions_1, all_targets_1, f1_1 = test(test_dataloader, model, loss_functions, device_in_use=device_in_use)
  train_losses.append(train_loss)
  train_accuracies_1.append(train_accuracy_1)
  # train_accuracies_2.append(train_accuracy_2)
  # train_recalls.append(train_recall) 
  # train_f1_scores.append(train_f1)
  test_losses.append(test_loss)
  test_accuracies_1.append(test_accuracy_1)
  # test_accuracies_2.append(test_accuracy_2)
  # test_recalls.append(test_recall)
  # test_f1_scores.append(test_f1)
  # Formatting for easier reading
  epoch_str = f"Epoch [{t+1:2}/{epochs}]"
  train_metrics = f"Train: Loss {format_metric(train_loss)}, Accuracy {format_metric(train_accuracy_1)}"
  test_metrics = f"Test: Loss {format_metric(test_loss)}, Accuracy {format_metric(test_accuracy_1)}, F1 {format_metric(f1_1)}"
  print(f"{epoch_str:20} | {train_metrics:65} | {test_metrics}")

  if early_stopping(test_accuracy_1):
    break

# Save the model after pre-training
torch.save(model.state_dict(), 'final_model_trained.pth')

# Plotting the loss curves
plt.figure(figsize=(15, 5))
plt.subplot(1, 2, 1)
plt.plot(range(1, epochs+1), train_losses, label='Train Loss')
plt.plot(range(1, epochs+1), [l for l in test_losses], label='Test Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training and Test Loss Curve')
plt.legend()

# Plotting the accuracy curves
plt.subplot(1, 2, 2)
plt.plot(range(1, epochs+1), train_accuracies_1, label='Train Accuracy')
plt.plot(range(1, epochs+1), test_accuracies_1, label='Test Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.title('Training and Test Accuracy Curve')
plt.legend()

# Display confusion matrix for the first task (Traffic Type) on test data
conf_matrix_1 = confusion_matrix(all_targets_1, all_predictions_1)
print("Confusion Matrix for income")
print(conf_matrix_1)


Epoch [ 1/75]        | Train: Loss 0.1632, Accuracy 0.9169                               | Test: Loss 0.0064, Accuracy 0.9984, F1 0.9984
Epoch [ 2/75]        | Train: Loss 0.0012, Accuracy 0.9995                               | Test: Loss 0.0003, Accuracy 0.9998, F1 0.9998
Epoch [ 3/75]        | Train: Loss 0.0001, Accuracy 1.0000                               | Test: Loss 0.0000, Accuracy 1.0000, F1 1.0000


KeyboardInterrupt: 