In [85]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

#### **Imports**

In [86]:
import warnings
import time
from enum import Enum
from typing import Optional, Tuple, List, Union
from sklearn.metrics import classification_report, confusion_matrix, precision_score, recall_score, f1_score, balanced_accuracy_score


import tensorflow as tf

try:
    from tensorflow._api.v2.v2 import keras
except ImportError:
    from tensorflow import keras

from keras import Input, Model
import keras.layers as layers
from keras.layers import Dense, Conv1D, Layer, MultiHeadAttention, Dropout, LayerNormalization, Embedding, Concatenate, Reshape, Lambda, Flatten, GlobalAveragePooling1D

print("Imports done")

Imports done


#### Utilities

In [87]:
import base64
import hashlib
import json
import pickle
from typing import Tuple

import pandas as pd


def get_identifier(d:dict):
    raw_json = json.dumps(d, sort_keys=True, indent=False)
    hash = hashlib.sha1(raw_json.encode("utf8")).digest()
    x = base64.b64encode(hash)
    x = x.decode("ASCII")
    x = x.replace("+", "0").replace("=", "0").replace("/", "0")
    return x

def save_feather_plus_metadata(save_path:str, df:pd.DataFrame, metadata:object):
    metadata_path = save_path + ".metadata.pickle"
    df.to_feather(save_path)
    with open(metadata_path, "wb") as w:
        pickle.dump(metadata, w)

def save_pickle(save_path:str, obj:dict):
    with open(save_path, "wb") as w:
        pickle.dump(obj, w)

def load_pickle(save_path:str):
    with open(save_path, "rb") as r:
        return pickle.load(r, fix_imports=True)

def load_feather_plus_metadata(load_path:str) -> Tuple[pd.DataFrame, object]:
    metadata_path = load_path + ".metadata.pickle"
    with open(metadata_path, "rb") as r:
        metadata = pickle.load(r, fix_imports=True)
    data = pd.read_feather(load_path)
    return data, metadata

#### **Dataset Specifications: List some metadata, including column names for ease of use**

In [88]:
class DatasetSpecification:
    def  __init__(self, include_fields:List[str], categorical_fields:List[str], class_column:str, benign_label:str, test_column:Optional[str]=None):
        """
        Defines the format of specific NIDS dataset
        :param include_fields: The fields to include as part of classification
        :param categorical_fields: Fields that should be treated as categorical
        :param class_column: The column name that includes the class of the flow, eg. DDoS or Benign
        :param benign_label: The label of benign traffic, eg. Benign or 0
        :param test_column: The column indicating if this row is a member of the test or training dataset
        """
        self.include_fields:List[str] = include_fields
        self.categorical_fields:List[str] = categorical_fields
        self.class_column = class_column
        self.benign_label = benign_label
        self.test_column:Optional[str] = test_column

class NamedDatasetSpecifications:
    """
    Example specifications of some common datasets
    """
    
    cse_cic_ids_2018 = DatasetSpecification(
            include_fields=['NUM_PKTS_UP_TO_128_BYTES', 'SRC_TO_DST_SECOND_BYTES', 'OUT_PKTS', 'OUT_BYTES', 'NUM_PKTS_128_TO_256_BYTES', 'DST_TO_SRC_AVG_THROUGHPUT', 'DURATION_IN', 'L4_SRC_PORT', 'ICMP_TYPE', 'PROTOCOL', 'SERVER_TCP_FLAGS', 'IN_PKTS', 'NUM_PKTS_512_TO_1024_BYTES', 'CLIENT_TCP_FLAGS', 'TCP_WIN_MAX_IN', 'NUM_PKTS_256_TO_512_BYTES', 'SHORTEST_FLOW_PKT', 'MIN_IP_PKT_LEN', 'LONGEST_FLOW_PKT', 'L4_DST_PORT', 'MIN_TTL', 'DST_TO_SRC_SECOND_BYTES', 'NUM_PKTS_1024_TO_1514_BYTES', 'DURATION_OUT', 'FLOW_DURATION_MILLISECONDS', 'TCP_FLAGS', 'MAX_TTL', 'SRC_TO_DST_AVG_THROUGHPUT', 'ICMP_IPV4_TYPE', 'MAX_IP_PKT_LEN', 'RETRANSMITTED_OUT_BYTES', 'IN_BYTES', 'RETRANSMITTED_IN_BYTES', 'TCP_WIN_MAX_OUT', 'L7_PROTO', 'RETRANSMITTED_OUT_PKTS', 'RETRANSMITTED_IN_PKTS'],
            categorical_fields=['CLIENT_TCP_FLAGS', 'L4_SRC_PORT', 'TCP_FLAGS', 'ICMP_IPV4_TYPE', 'ICMP_TYPE', 'PROTOCOL', 'SERVER_TCP_FLAGS', 'L4_DST_PORT', 'L7_PROTO', 'Attack'],
            class_column="Attack",
            benign_label="Benign"
        )

#### **Some helper classes**

these are useful in, including but not limited to: dataset column datatypes, main class parameters

In [89]:
class CategoricalFormat(Enum):
    """
    The format of variables expected by the model as input
    """
    Integers = 0,
    """
    If categorical values should be dictionary encoded as integers
    """
    OneHot = 1
    """
    If categorical values should be one-hot encoded
    """

class EvaluationDatasetSampling(Enum):
    """
    How to choose evaluation samples from the raw dataset
    """
    LastRows = 0
    """
    Take the last rows in the dataset to form the evaluation dataset
    """
    RandomRows  = 1
    """
    Randomly sample rows to make up the evaluation dataset
    """
    FilterColumn = 2
    """
    Define a column that contains a flag indicating if this row is part of the evaluation set
    """

class FlowTransformerParameters:
    """
    Allows the configuration of overall parameters of the FlowTransformer
    :param window_size: The number of flows to use in each window
    :param mlp_layer_sizes: The number of nodes in each layer of the outer classification MLP of FlowTransformer
    :param mlp_dropout: The amount of dropout to be applied between the layers of the outer classification MLP
    """
    def __init__(self, window_size:int, mlp_layer_sizes:List[int], mlp_dropout:float=0.1):
        self.window_size:int = window_size
        self.mlp_layer_sizes = mlp_layer_sizes
        self.mlp_dropout = mlp_dropout

        # Is the order of flows important within any individual window
        self._train_ensure_flows_are_ordered_within_windows = True

        # Should windows be sampled sequentially during training
        self._train_draw_sequential_windows = False

#### **Model Input Enum**

In [90]:
class ModelInputSpecification:
    def __init__(self, feature_names:List[str], n_numeric_features:int, levels_per_categorical_feature:List[int], categorical_format:CategoricalFormat):
        self.feature_names = feature_names

        self.numeric_feature_names = feature_names[:n_numeric_features]
        self.categorical_feature_names = feature_names[n_numeric_features:]
        self.categorical_format:CategoricalFormat = categorical_format

        self.n_numeric_features = n_numeric_features
        self.levels_per_categorical_feature = levels_per_categorical_feature

#### **Framework Component class**

In [91]:
class Component():
    @property
    def name(self) -> str:
        raise NotImplementedError()

    @property
    def parameters(self) -> dict:
        warnings.warn("Parameters have not been implemented for this class!")
        return {}
class FunctionalComponent(Component):
    def __init__(self):
        self.sequence_length: Optional[int] = None
        self.model_input_specification: Optional[ModelInputSpecification] = None
        self.input_shape: Optional[Tuple[int]] = None

    def apply(self, X, prefix: str = None):
        raise NotImplementedError()

    def build(self, sequence_length:int, model_input_specification:ModelInputSpecification):
        self.sequence_length = sequence_length
        self.model_input_specification = model_input_specification

#### **Base Classification Head**

In [92]:
class BaseClassificationHead(FunctionalComponent):
    def __init__(self):
        super().__init__()

    def apply_before_transformer(self, X, prefix:str=None):
        return X

#### **Classification Head: Last token, FeatureWise**

In [93]:
class LastTokenClassificationHead(BaseClassificationHead):
    def __init__(self):
        super().__init__()

    def apply(self, X, prefix: str = None):
        if prefix is None:
            prefix = ""

        x = Lambda(lambda x: x[..., -1, :], name=f"{prefix}slice_last")(X)
        #x = Flatten(name=f"{prefix}flatten_last")(x)

        return x

    @property
    def name(self) -> str:
        return "Last Token"

    @property
    def parameters(self) -> dict:
        return {}

In [94]:
class FeaturewiseEmbedding(BaseClassificationHead):
    def __init__(self, project:bool=False):
        super().__init__()
        self.project: bool = project

    @property
    def name(self):
        if self.project:
            return f"Featurewise Embed - Projection"
        else:
            return f"Featurewise Embed - Dense"

    @property
    def parameters(self):
        return {}


    def apply(self, X, prefix:str=None):
        if prefix is None:
            prefix = ""

        if self.model_input_specification is None:
            raise Exception("Please call build() before calling apply!")

        x = Dense(1,
                  activation="linear",
                  use_bias=(not self.project),
                  name=f"{prefix}featurewise_embed")(X)

        x = Flatten()(x)

        return x

#### **Input Encoding:Record Level Projection**

In [95]:
class BaseInputEncoding(FunctionalComponent):
    def apply(self, X:List["keras.Input"], prefix: str = None):
        raise NotImplementedError("Please override this with a custom implementation")

    @property
    def required_input_format(self) -> CategoricalFormat:
        raise NotImplementedError("Please override this with a custom implementation")

In [96]:
class EmbedLayerType(Enum):
    Dense = 0,
    Lookup = 1,
    Projection = 2

class RecordLevelEmbed(BaseInputEncoding):
    def __init__(self, embed_dimension: int, project:bool = False):
        super().__init__()

        self.embed_dimension: int = embed_dimension
        self.project: bool = project

    @property
    def name(self):
        if self.project:
            return "Record Level Projection"
        return "Record Level Embedding"

    @property
    def parameters(self):
        return {
            "dimensions_per_feature": self.embed_dimension
        }

    def apply(self, X:List[keras.Input], prefix: str = None):
        if prefix is None:
            prefix = ""

        assert self.model_input_specification.categorical_format == CategoricalFormat.OneHot

        x = Concatenate(name=f"{prefix}feature_concat", axis=-1)(X)
        x = Dense(self.embed_dimension, activation="linear", use_bias=not self.project, name=f"{prefix}embed")(x)

        return x

    @property
    def required_input_format(self) -> CategoricalFormat:
        return CategoricalFormat.OneHot

#### **Input Preprocessing**

In [97]:
class BasePreProcessing(Component):
    def __init__(self):
        pass

    def fit_numerical(self, column_name:str, values:np.array):
        raise NotImplementedError("Please override this base class with a custom implementation")

    def transform_numerical(self, column_name:str, values: np.array):
        raise NotImplementedError("Please override this base class with a custom implementation")

    def fit_categorical(self, column_name:str, values:np.array):
        raise NotImplementedError("Please override this base class with a custom implementation")

    def transform_categorical(self, column_name:str, values:np.array, expected_categorical_format:CategoricalFormat):
        raise NotImplementedError("Please override this base class with a custom implementation")

In [98]:
class StandardPreProcessing(BasePreProcessing):
    def __init__(self, n_categorical_levels: int, clip_numerical_values:bool=False):
        super().__init__()
        self.n_categorical_levels:int = n_categorical_levels
        self.clip_numerical_values:bool = clip_numerical_values
        self.min_range = {}
        self.encoded_levels = {}

    @property
    def name(self) -> str:
        return "Standard Preprocessing"

    @property
    def parameters(self) -> dict:
        return {
            "n_categorical_levels": self.n_categorical_levels,
            "clip_numerical_values": self.clip_numerical_values
        }

    def fit_numerical(self, column_name: str, values: np.array):

        v0 = np.min(values)
        v1 = np.max(values)
        r = v1 - v0

        self.min_range[column_name] = (v0, r)

    def transform_numerical(self, column_name: str, values: np.array):
        col_min, col_range = self.min_range[column_name]

        if col_range == 0:
            return np.zeros_like(values, dtype="float32")

        # center on zero
        values -= col_min

        # apply a logarithm
        col_values = np.log(values + 1)

        # scale max to 1
        col_values *= 1. / np.log(col_range + 1)

        if self.clip_numerical_values:
            col_values = np.clip(col_values, 0., 1.)

        return col_values

    def fit_categorical(self, column_name: str, values: np.array):
        levels, level_counts = np.unique(values, return_counts=True)
        sorted_levels = list(sorted(zip(levels, level_counts), key=lambda x: x[1], reverse=True))
        self.encoded_levels[column_name] = [s[0] for s in sorted_levels[:self.n_categorical_levels]]


    def transform_categorical(self, column_name:str, values: np.array, expected_categorical_format: CategoricalFormat):
        encoded_levels = self.encoded_levels[column_name]
        print(f"Encoding the {len(encoded_levels)} levels for {column_name}")

        result_values = np.ones(len(values), dtype="uint32")
        for level_i, level in enumerate(encoded_levels):
            level_mask = values == level

            # we use +1 here, as 0 = previously unseen, and 1 to (n + 1) are the encoded levels
            result_values[level_mask] = level_i + 1

        if expected_categorical_format == CategoricalFormat.Integers:
            return result_values

        v = pd.get_dummies(result_values, prefix=column_name)
        return v

#### **Transformer classes**

In [99]:
class BaseSequential(FunctionalComponent):
    pass

#### Decoders Encoders

In [100]:
class TransformerDecoderBlock(Layer):
    def __init__(self, input_dimension:int, inner_dimension:int, num_heads:int, dropout_rate=0.1):
        super().__init__()

        self.num_heads = num_heads
        self.input_dimension = input_dimension
        self.inner_dimension = inner_dimension
        self.dropout_rate = dropout_rate

        self.mha = MultiHeadAttention(num_heads=num_heads, key_dim=input_dimension)
        self.dropout1 = Dropout(dropout_rate)
        self.layernorm1 = LayerNormalization(epsilon=1e-6)

        self.ffn = tf.keras.Sequential([
            Dense(inner_dimension, activation='relu'),
            Dense(input_dimension)
        ])
        self.dropout2 = Dropout(dropout_rate)
        self.layernorm2 = LayerNormalization(epsilon=1e-6)

    # noinspection PyMethodOverriding
    def call(self, inputs, training, mask=None):
        # inputs = (target_seq, enc_output)
        target_seq = inputs
        enc_output = inputs

        # self attention of target_seq
        attn_output = self.mha(target_seq, target_seq)
        attn_output = self.dropout1(attn_output, training=training)
        out1 = target_seq + attn_output
        out1 = self.layernorm1(out1)

        # multi-head attention with encoder output as the key and value, and target_seq as the query
        attn_output = self.mha(out1, enc_output)
        attn_output = self.dropout2(attn_output, training=training)
        out2 = out1 + attn_output
        out2 = self.layernorm2(out2)

        # feed forward network
        ffn_output = self.ffn(out2)
        ffn_output = self.dropout2(ffn_output, training=training)
        out3 = out2 + ffn_output
        out3 = self.layernorm2(out3)

        return out3

#### GPT3

In [101]:
class GPT3Attention(layers.Layer):
    def __init__(self, n_heads, d_model, dropout_rate=0.1):
        super(GPT3Attention, self).__init__()
        self.n_heads = n_heads
        self.d_model = d_model
        self.depth = d_model // n_heads

        self.wq = tf.keras.layers.Dense(d_model)
        self.wk = tf.keras.layers.Dense(d_model)
        self.wv = tf.keras.layers.Dense(d_model)

        self.dropout = tf.keras.layers.Dropout(dropout_rate)
        self.dense = tf.keras.layers.Dense(d_model)

    def split_heads(self, x, batch_size):
        x = tf.reshape(x, (batch_size, -1, self.n_heads, self.depth))
        return tf.transpose(x, perm=[0, 2, 1, 3])

    # noinspection PyMethodOverriding
    def call(self, q, k, v, mask=None):
        batch_size = tf.shape(q)[0]

        q = self.wq(q)
        k = self.wk(k)
        v = self.wv(v)

        q = self.split_heads(q, batch_size)
        k = self.split_heads(k, batch_size)
        v = self.split_heads(v, batch_size)

        # Scaled Dot-Product Attention
        scaled_attention_logits = tf.matmul(q, k, transpose_b=True)
        scaled_attention_logits = scaled_attention_logits / tf.math.sqrt(tf.cast(self.depth, tf.float32))

        if mask is not None:
            scaled_attention_logits += (mask * -1e9)

        attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1)
        attention_weights = self.dropout(attention_weights)

        output = tf.matmul(attention_weights, v)
        output = tf.transpose(output, perm=[0, 2, 1, 3])
        output = tf.reshape(output, (batch_size, -1, self.d_model))

        output = self.dense(output)
        output = self.dropout(output)

        return output

class MultiHeadAttentionImplementation:
    Keras = 0,
    GPT3 = 1

class TransformerEncoderBlock(layers.Layer):
    def __init__(self, input_dimension:int, inner_dimension:int, num_heads:int, dropout_rate=0.1, use_conv:bool=False, prefix:str=None, attn_implementation:MultiHeadAttentionImplementation = MultiHeadAttentionImplementation.Keras):

        if prefix is None:
            prefix = ""

        super().__init__(name=f"{prefix}transformer_encoder")

        if inner_dimension < input_dimension:
            warnings.warn(f"Typically inner_dimension should be greater than or equal to the input_dimension!")

        self.attn_implementation = attn_implementation

        self.dropout_rate = dropout_rate
        self.attention = \
            layers.MultiHeadAttention(num_heads=num_heads, key_dim=inner_dimension, name=f"{prefix}multi_head_attn") \
                if attn_implementation == MultiHeadAttentionImplementation.Keras else\
                GPT3Attention(num_heads, inner_dimension, dropout_rate=0.0)

        layer_norm = 1e-6

        self.attention_dropout = layers.Dropout(dropout_rate, name=f"{prefix}attention_dropout")
        self.attention_layer_norm = layers.LayerNormalization(epsilon=layer_norm, name=f"{prefix}attention_layer_norm")

        self.feed_forward_0 = Conv1D(filters=inner_dimension, kernel_size=1, activation="relu", name=f"{prefix}feed_forward_0") \
            if use_conv else Dense(inner_dimension, activation="relu", name=f"{prefix}feed_forward_0")
        self.feed_forward_1 = Conv1D(filters=input_dimension, kernel_size=1, activation="relu", name=f"{prefix}feed_forward_1") \
            if use_conv else Dense(input_dimension, activation="relu", name=f"{prefix}feed_forward_1")

        self.feed_forward_dropout = layers.Dropout(dropout_rate, name=f"{prefix}feed_forward_dropout")
        self.feed_forward_layer_norm = layers.LayerNormalization(epsilon=layer_norm, name=f"{prefix}feed_forward_layer_norm")

    # noinspection PyMethodOverriding
    def call(self, inputs, training, mask=None):
        x = inputs
        x = self.attention(x, x) if self.attn_implementation == MultiHeadAttentionImplementation.Keras else self.attention(x, x, x, mask)

        attention_output = self.attention_dropout(x, training=training) if self.dropout_rate > 0 else x

        x = inputs + attention_output
        x = self.attention_layer_norm(x)
        x = self.feed_forward_0(x)
        x = self.feed_forward_1(x)
        x = self.feed_forward_dropout(x, training=training) if self.dropout_rate > 0 else x
        feed_forward_output = x

        return self.feed_forward_layer_norm(attention_output + feed_forward_output)


#### Basic Transformer 

In [102]:
class BasicTransformer(BaseSequential):

    @property
    def name(self) -> str:
        if self.use_conv:
            return f"Basic Conv Transformer" + (" Decoder" if self.is_decoder else "")
        else:
            return f"Basic Dense Transformer" + (" Decoder" if self.is_decoder else "")

    @property
    def parameters(self) -> dict:
        return {
            "n_layers": self.n_layers,
            "internal_size": self.internal_size,
            "use_conv": self.use_conv,
            "n_heads": self.n_heads,
            "dropout_rate": self.dropout_rate,
            "head_size": self.internal_size
        }

    def __init__(self, n_layers:int, internal_size:int, n_heads:int, use_conv:bool=False, dropout_rate:float=0.1, is_decoder=False):
        super().__init__()
        self.n_layers = n_layers
        self.internal_size = internal_size
        self.use_conv = use_conv
        self.n_heads = n_heads
        self.dropout_rate = dropout_rate
        self.is_decoder = is_decoder

    def apply(self, X, prefix: str = None):
        #window_size = self.sequence_length
        real_size = X.shape[-1]

        m_x = X

        for layer_i in range(self.n_layers):
            if self.is_decoder:
                if self.use_conv:
                    raise NotImplementedError()
                m_x = TransformerDecoderBlock(real_size, self.internal_size, self.n_heads, dropout_rate=self.dropout_rate)(m_x)
            else:
                m_x = TransformerEncoderBlock(real_size, self.internal_size, self.n_heads, dropout_rate=self.dropout_rate, use_conv=self.use_conv, prefix=f"{prefix}block_{layer_i}_")(m_x)

        return m_x

#### GPT Small

In [103]:
class GPTSmallTransformer(BaseSequential):

    @property
    def name(self) -> str:
        return "GPT Model"

    @property
    def parameters(self) -> dict:
        return {
            "n_layers": self.n_layers,
            "internal_size": self.internal_size,
            "n_heads": self.n_heads,
            "dropout_rate": self.dropout_rate,
            "head_size": self.head_size
        }

    def __init__(self):
        super().__init__()
        self.n_layers = 12
        self.internal_size = 768
        self.n_heads = 12
        self.head_size = self.internal_size / self.n_heads
        self.dropout_rate = 0.02
        self.is_decoder = True

    def apply(self, X, prefix: str = None):
        #window_size = self.sequence_length
        real_size = X.shape[-1]

        m_x = X

        for layer_i in range(self.n_layers):
            m_x = TransformerDecoderBlock(real_size, self.internal_size, self.n_heads, dropout_rate=self.dropout_rate)(m_x)

        return m_x

#### **Main FlowTransformer T_T**

In [104]:
class FlowTransformer:
    retain_inmem_cache = False
    inmem_cache = None

    def  __init__(self, pre_processing:BasePreProcessing,
                  input_encoding:BaseInputEncoding,
                  sequential_model:FunctionalComponent,
                  classification_head:BaseClassificationHead,
                  params:FlowTransformerParameters,
                  rs:np.random.RandomState=None):

        self.rs = np.random.RandomState() if rs is None else rs
        self.classification_head = classification_head
        self.sequential_model = sequential_model
        self.input_encoding = input_encoding
        self.pre_processing = pre_processing
        self.parameters = params

        self.dataset_specification: Optional[DatasetSpecification] = None

        self.X = None
        self.y = None

        self.training_mask = None
        self.model_input_spec: Optional[ModelInputSpecification] = None

        self.experiment_key = {}

        self.y_backup: pd.Series = pd.Series()

    def build_model(self, prefix:str=None):
        if prefix is None:
            prefix = ""

        if self.X is None:
            raise Exception("Please call load_dataset before calling build_model()")

        m_inputs = []
        for numeric_feature in self.model_input_spec.numeric_feature_names:
            m_input = Input((self.parameters.window_size, 1), name=f"{prefix}input_{numeric_feature}", dtype="float32")
            m_inputs.append(m_input)

        for categorical_feature_name, categorical_feature_levels in \
            zip(self.model_input_spec.categorical_feature_names, self.model_input_spec.levels_per_categorical_feature):
            m_input = Input(
                (self.parameters.window_size, 1 if self.model_input_spec.categorical_format == CategoricalFormat.Integers else categorical_feature_levels),
                name=f"{prefix}input_{categorical_feature_name}",
                dtype="int32" if self.model_input_spec.categorical_format == CategoricalFormat.Integers else "float32"
            )
            m_inputs.append(m_input)

        self.input_encoding.build(self.parameters.window_size, self.model_input_spec)
        self.sequential_model.build(self.parameters.window_size, self.model_input_spec)
        self.classification_head.build(self.parameters.window_size, self.model_input_spec)

        m_x = self.input_encoding.apply(m_inputs, prefix)

        # in case the classification head needs to add tokens at this stage
        m_x = self.classification_head.apply_before_transformer(m_x, prefix)

        m_x = self.sequential_model.apply(m_x, prefix)
        m_x = self.classification_head.apply(m_x, prefix)

        for layer_i, layer_size in enumerate(self.parameters.mlp_layer_sizes):
            m_x = Dense(layer_size, activation="relu", name=f"{prefix}classification_mlp_{layer_i}_{layer_size}")(m_x)
            m_x = Dropout(self.parameters.mlp_dropout)(m_x) if self.parameters.mlp_dropout > 0 else m_x

        #### CNG
        m_x = Dense(15, activation="softmax", name=f"{prefix}multiclass_classification_out")(m_x)
        m = Model(m_inputs, m_x)
        #m.summary()
        return m

    def _load_preprocessed_dataset(self, dataset_name:str,
                     dataset:Union[pd.DataFrame, str],
                     specification:DatasetSpecification,
                     cache_folder:Optional[str]=None,
                     n_rows:int=0,
                     evaluation_dataset_sampling:EvaluationDatasetSampling=EvaluationDatasetSampling.LastRows,
                     evaluation_percent:float=0.2,
                     numerical_filter=1_000_000_000) -> Tuple[pd.DataFrame, ModelInputSpecification]:

        cache_file_path = None

        if dataset_name is None:
            raise Exception(f"Dataset name must be specified so FlowTransformer can optimise operations between subsequent calls!")

        pp_key = get_identifier(
            {
                "__preprocessing_name": self.pre_processing.name,
                **self.pre_processing.parameters
            }
        )

        local_key = get_identifier({
            "evaluation_percent": evaluation_percent,
            "numerical_filter": numerical_filter,
            "categorical_method": str(self.input_encoding.required_input_format),
            "n_rows": n_rows,
        })

        cache_key = f"{dataset_name}_{n_rows}_{pp_key}_{local_key}"

        if FlowTransformer.retain_inmem_cache:
            if FlowTransformer.inmem_cache is not None and cache_key in FlowTransformer.inmem_cache:
                print(f"Using in-memory cached version of this pre-processed dataset. To turn off this functionality set FlowTransformer.retain_inmem_cache = False")
                return FlowTransformer.inmem_cache[cache_key]

        if cache_folder is not None:
            cache_file_name = f"{cache_key}.feather"
            cache_file_path = os.path.join(cache_folder, cache_file_name)

            print(f"Using cache file path: {cache_file_path}")

            if os.path.exists(cache_file_path):
                print(f"Reading directly from cache {cache_file_path}...")
                model_input_spec: ModelInputSpecification
                dataset, model_input_spec = load_feather_plus_metadata(cache_file_path)
                return dataset, model_input_spec

        if isinstance(dataset, str):
            print(f"Attempting to read dataset from path {dataset}...")
            if dataset.lower().endswith(".feather"):
                # read as a feather file
                dataset = pd.read_feather(dataset, columns=specification.include_fields+[specification.class_column])
            elif dataset.lower().endswith(".csv"):
                dataset = pd.read_csv(dataset, nrows=n_rows if n_rows > 0 else None)
            else:
                raise Exception("Unrecognised dataset filetype!")
        elif not isinstance(dataset, pd.DataFrame):
            raise Exception("Unrecognised dataset input type, should be a path to a CSV or feather file, or a pandas dataframe!")

        assert isinstance(dataset, pd.DataFrame)

        if 0 < n_rows < len(dataset):
            dataset = dataset.iloc[:n_rows]

        training_mask = np.ones(len(dataset),  dtype=bool)
        eval_n = int(len(dataset) * evaluation_percent)

        if evaluation_dataset_sampling == EvaluationDatasetSampling.FilterColumn:
            if dataset.columns[-1] != specification.test_column:
                raise Exception(f"Ensure that the 'test' ({specification.test_column}) column is the last column of the dataset being loaded, and that the name of this column is provided as part of the dataset specification")

        if evaluation_dataset_sampling != EvaluationDatasetSampling.LastRows:
            warnings.warn("Using EvaluationDatasetSampling options other than LastRows might leak some information during training, if for example the context window leading up to a particular flow contains an evaluation flow, and this flow has out of range values (out of range to when pre-processing was applied on the training flows), then the model might potentially learn to handle these. In any case, no class leakage is present.")

        if evaluation_dataset_sampling == EvaluationDatasetSampling.LastRows:
            training_mask[-eval_n:] = False
        elif evaluation_dataset_sampling == EvaluationDatasetSampling.RandomRows:
            index = np.arange(self.parameters.window_size, len(dataset))
            sample = self.rs.choice(index, eval_n, replace=False)
            training_mask[sample] = False
        elif evaluation_dataset_sampling == EvaluationDatasetSampling.FilterColumn:
            # must be the last column of the dataset
            training_column = dataset.columns[-1]
            print(f"Using the last column {training_column} as the training mask column")

            v, c = np.unique(dataset[training_column].values,  return_counts=True)
            min_index = np.argmin(c)
            min_v = v[min_index]

            warnings.warn(f"Autodetected class {min_v} of {training_column} to represent the evaluation class!")

            eval_indices = np.argwhere(dataset[training_column].values == min_v).reshape(-1)
            eval_indices = eval_indices[(eval_indices > self.parameters.window_size)]

            training_mask[eval_indices] = False
            del dataset[training_column]

        numerical_columns = set(specification.include_fields).difference(specification.categorical_fields)
        categorical_columns = specification.categorical_fields

        print(f"Set y to = {specification.class_column}")
        # print(f"One hot encoding y...")

        new_df = {"__training": training_mask, "__y": dataset[specification.class_column].values}
        new_features = []

        print("Converting numerical columns to floats, and removing out of range values...")
        for col_name in numerical_columns:
            assert col_name in dataset.columns
            new_features.append(col_name)

            col_values = dataset[col_name].values
            col_values[~np.isfinite(col_values)] = 0
            col_values[col_values < -numerical_filter] = 0
            col_values[col_values > numerical_filter] = 0
            col_values = col_values.astype("float32")

            if not np.all(np.isfinite(col_values)):
                raise Exception("Flow format data had non finite values after float transformation!")

            new_df[col_name] = col_values

        print(f"Applying pre-processing to numerical values")
        for i, col_name in enumerate(numerical_columns):
            print(f"[Numerical {i+1:,} / {len(numerical_columns)}] Processing numerical column {col_name}...")
            all_data = new_df[col_name]
            training_data = all_data[training_mask]

            self.pre_processing.fit_numerical(col_name, training_data)
            new_df[col_name] = self.pre_processing.transform_numerical(col_name, all_data)

        print(f"Applying pre-processing to categorical values")
        print("Keeping a copy of label column")
        y_backup = new_df['__y']
        y_backup = pd.Series(y_backup)
        print(f"DEBOG: at ft main class {y_backup.shape}")

        levels_per_categorical_feature = []
        for i, col_name in enumerate(categorical_columns):
            new_features.append(col_name)
            # if col_name == specification.class_column:
            #     continue
            print(f"[Categorical {i+1:,} / {len(categorical_columns)}] Processing categorical column {col_name}...")

            all_data = dataset[col_name].values
            training_data = all_data[training_mask]

            self.pre_processing.fit_categorical(col_name, training_data)
            new_values = self.pre_processing.transform_categorical(col_name, all_data, self.input_encoding.required_input_format)

            if self.input_encoding.required_input_format == CategoricalFormat.OneHot:
                # multiple columns of one hot values
                if isinstance(new_values, pd.DataFrame):
                    levels_per_categorical_feature.append(len(new_values.columns))
                    for c in new_values.columns:
                        new_df[c] = new_values[c]
                else:
                    n_one_hot_levels = new_values.shape[1]
                    levels_per_categorical_feature.append(n_one_hot_levels)
                    for z in range(n_one_hot_levels):
                        new_df[f"{col_name}_{z}"] = new_values[:, z]
            else:
                # single column of integers
                levels_per_categorical_feature.append(len(np.unique(new_values)))
                new_df[col_name] = new_values

        print(f"Generating pre-processed dataframe...")
        new_df = pd.DataFrame(new_df)
        #### CNG
        self.y_backup = y_backup
        print(f"DEBOG: also at ft main class{self.y_backup.shape}")

        model_input_spec = ModelInputSpecification(new_features, len(numerical_columns), levels_per_categorical_feature, self.input_encoding.required_input_format)

        print(f"Input data frame had shape ({len(dataset)},{len(dataset.columns)}), output data frame has shape ({len(new_df)},{len(new_df.columns)}) after pre-processing...")

        if cache_file_path is not None:
            print(f"Writing to cache file path: {cache_file_path}...")
            save_feather_plus_metadata(cache_file_path, new_df, model_input_spec)

        if FlowTransformer.retain_inmem_cache:
            if FlowTransformer.inmem_cache is None:
                FlowTransformer.inmem_cache = {}

            FlowTransformer.inmem_cache.clear()
            FlowTransformer.inmem_cache[cache_key] = (new_df, model_input_spec)

        return new_df, model_input_spec

    def load_dataset(self, dataset_name:str,
                     dataset:Union[pd.DataFrame, str],
                     specification:DatasetSpecification,
                     cache_path:Optional[str]=None,
                     n_rows:int=0,
                     evaluation_dataset_sampling:EvaluationDatasetSampling=EvaluationDatasetSampling.LastRows,
                     evaluation_percent:float=0.2,
                     numerical_filter=1_000_000_000) -> pd.DataFrame:
        """
        Load a dataset and prepare it for training

        :param dataset: The path to a CSV dataset to load from, or a dataframe
        :param cache_path: Where to store a cached version of this file
        :param n_rows: The number of rows to ingest from the dataset, or 0 to ingest all
        """

        if cache_path is None:
            cache_path = "cache"

        if not os.path.exists(cache_path):
            warnings.warn(f"Could not find cache folder: {cache_path}, attempting to create")
            os.mkdir(cache_path)

        self.dataset_specification = specification
        df, model_input_spec = self._load_preprocessed_dataset(dataset_name, dataset, specification, cache_path, n_rows, evaluation_dataset_sampling, evaluation_percent, numerical_filter)
        
        # print("DEBOG: ")
        # with pd.option_context('display.max_columns', None, 'display.max_colwidth', None):
        #     print(df.columns)
            
        training_mask = df["__training"].values
        del df["__training"]


        #### CNG
        attack_columns = [col for col in df.columns if col.startswith(str(self.dataset_specification.class_column))]

        y = df[attack_columns].values
        del df['__y']

        self.X = df
        self.y = y
        self.training_mask = training_mask
        self.model_input_spec = model_input_spec

        return df

    def evaluate(self, m:keras.Model, batch_size, early_stopping_patience:int=5, epochs:int=100, steps_per_epoch:int=128):
        n_malicious_per_batch = int(0.5 * batch_size)
        n_legit_per_batch = batch_size - n_malicious_per_batch

        overall_y_preserve = np.zeros(dtype="float32", shape=(n_malicious_per_batch + n_legit_per_batch,15))
        overall_y_preserve[:n_malicious_per_batch] = 1.

        selectable_mask = np.zeros(len(self.X), dtype=bool)
        selectable_mask[self.parameters.window_size:-self.parameters.window_size] = True
        train_mask = self.training_mask

        #### CNG
        y_mask = ~(self.y_backup.astype('str') == str(self.dataset_specification.benign_label))
        # y_mask = self.y
        print(f"DEBOG: still at ft main {self.y_backup.shape}")
        indices_train = np.argwhere(train_mask).reshape(-1)
        malicious_indices_train = np.argwhere(train_mask & y_mask & selectable_mask).reshape(-1)
        legit_indices_train = np.argwhere(train_mask & ~y_mask & selectable_mask).reshape(-1)

        indices_test:np.ndarray = np.argwhere(~train_mask).reshape(-1)

        def get_windows_for_indices(indices:np.ndarray, ordered) -> List[pd.DataFrame]:
            X: List[pd.DataFrame] = []

            if ordered:
                # we don't really want to include eval samples as part of context, because out of range values might be learned
                # by the model, _but_ we are forced to in the windowed approach, if users haven't just selected the
                # "take last 10%" as eval option. We warn them prior to this though.
                for i1 in indices:
                    X.append(self.X.iloc[(i1 - self.parameters.window_size) + 1:i1 + 1])
            else:
                context_indices_batch = np.random.choice(indices_train, size=(batch_size, self.parameters.window_size),
                                                         replace=False).reshape(-1)
                context_indices_batch[:, -1] = indices

                for index in context_indices_batch:
                    X.append(self.X.iloc[index])

            return X

        feature_columns_map = {}

        def samplewise_to_featurewise(X):
            sequence_length = len(X[0])

            combined_df = pd.concat(X)

            featurewise_X = []

            if len(feature_columns_map) == 0:
                for feature in self.model_input_spec.feature_names:
                    if feature in self.model_input_spec.numeric_feature_names or self.model_input_spec.categorical_format == CategoricalFormat.Integers:
                        feature_columns_map[feature] = feature
                    else:
                        # this is a one-hot encoded categorical feature
                        feature_columns_map[feature] = [c for c in X[0].columns if str(c).startswith(feature)]

            for feature in self.model_input_spec.feature_names:
                feature_columns = feature_columns_map[feature]
                combined_values = combined_df[feature_columns].values

                # maybe this can be faster with a reshape but I couldn't get it to work
                combined_values = np.array([combined_values[i:i+sequence_length] for i in range(0, len(combined_values), sequence_length)])
                featurewise_X.append(combined_values)

            return featurewise_X

        print(f"Building eval dataset...")
        eval_X = get_windows_for_indices(indices_test, True)
        print(f"Splitting dataset to featurewise...")
        eval_featurewise_X = samplewise_to_featurewise(eval_X)
        eval_y = self.y[indices_test]
        #### CNG
        # eval_P = eval_y
        # n_eval_P = np.count_nonzero(eval_P)
        # eval_N = ~eval_y
        # n_eval_N = np.count_nonzero(eval_N)
        print(f"Evaluation dataset is built!")

        # print(f"Positive samples in eval set: {n_eval_P}")
        # print(f"Negative samples in eval set: {n_eval_N}")

        epoch_results = []

        # def run_evaluation(epoch):
        #     pred_y = m.predict(eval_featurewise_X, verbose=True)
        #     pred_y = pred_y.reshape(-1) > 0.5

        #     #### CNG

        #     pred_P = pred_y
        #     n_pred_P = np.count_nonzero(pred_P)

        #     pred_N = ~pred_y
        #     n_pred_N = np.count_nonzero(pred_N)

        #     TP = np.count_nonzero(pred_P & eval_P)
        #     FP = np.count_nonzero(pred_P & ~eval_P)
        #     TN = np.count_nonzero(pred_N & eval_N)
        #     FN = np.count_nonzero(pred_N & ~eval_N)

        #     sensitivity = TP / (TP + FN) if (TP + FN) > 0 else 0
        #     specificity = TN / (TN + FP) if (TN + FP) > 0 else 0
        #     balanced_accuracy = (sensitivity + specificity) / 2

        #     precision = TP / (TP + FP) if (TP + FP) > 0 else 0
        #     recall = TP / (TP + FN) if (TP + FN) > 0 else 0

        #     f1_score = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
            
        #     print(f"Epoch {epoch} yielded predictions: {pred_y.shape}, overall balanced accuracy: {balanced_accuracy * 100:.2f}%, TP = {TP:,} / {n_eval_P:,}, TN = {TN:,} / {n_eval_N:,}")

        #     epoch_results.append({
        #         "epoch": epoch,
        #         "P": n_eval_P,
        #         "N": n_eval_N,
        #         "pred_P": n_pred_P,
        #         "pred_N": n_pred_N,
        #         "TP": TP,
        #         "FP": FP,
        #         "TN": TN,
        #         "FN": FN,
        #         "bal_acc": balanced_accuracy,
        #         "f1": f1_score
        #     })
        
        def run_evaluation(epoch):
            """
            Evaluate a multiclass classification model on eval_featurewise_X and eval_y.
            
            Args:
                epoch (int): The epoch number of the current evaluation.

            Returns:
                None. Appends evaluation metrics to `epoch_results` list.
            """
            # Predict probabilities (assume softmax output) and get predicted class indices
            pred_probs = m.predict(eval_featurewise_X, verbose=True)  # Softmax output, shape (num_samples, num_classes)
            pred_y = np.argmax(pred_probs, axis=1)  # Predicted class indices
            true_y = eval_y   # Convert one-hot to indices if necessary
            
            # Calculate confusion matrix and classification report
            confusion = confusion_matrix(true_y, pred_y)
            precision = precision_score(true_y, pred_y, average=None)  # Per-class precision
            recall = recall_score(true_y, pred_y, average=None)        # Per-class recall
            f1 = f1_score(true_y, pred_y, average=None)                # Per-class F1-score

            # Calculate overall metrics (macro and weighted averages)
            macro_precision = precision_score(true_y, pred_y, average='macro')
            macro_recall = recall_score(true_y, pred_y, average='macro')
            macro_f1 = f1_score(true_y, pred_y, average='macro')
            
            weighted_precision = precision_score(true_y, pred_y, average='weighted')
            weighted_recall = recall_score(true_y, pred_y, average='weighted')
            weighted_f1 = f1_score(true_y, pred_y, average='weighted')
            
            balanced_accuracy = balanced_accuracy_score(true_y, pred_y)

            # Print summary
            print(f"Epoch {epoch} yielded predictions: {pred_y.shape}, balanced accuracy: {balanced_accuracy * 100:.2f}%")
            print(f"Confusion Matrix:\n{confusion}")
            print(f"Per-Class Metrics:")
            for i, (p, r, f) in enumerate(zip(precision, recall, f1)):
                print(f"  Class {i}: Precision = {p:.4f}, Recall = {r:.4f}, F1-score = {f:.4f}")
            print(f"Macro Avg - Precision: {macro_precision:.4f}, Recall: {macro_recall:.4f}, F1: {macro_f1:.4f}")
            print(f"Weighted Avg - Precision: {weighted_precision:.4f}, Recall: {weighted_recall:.4f}, F1: {weighted_f1:.4f}")
            
            # Append metrics for this epoch
            epoch_results.append({
                "epoch": epoch,
                "confusion_matrix": confusion.tolist(),  # Store it as a list for JSON compatibility
                "balanced_accuracy": balanced_accuracy,
                "macro_precision": macro_precision,
                "macro_recall": macro_recall,
                "macro_f1": macro_f1,
                "weighted_precision": weighted_precision,
                "weighted_recall": weighted_recall,
                "weighted_f1": weighted_f1,
                "class_precision": precision.tolist(),
                "class_recall": recall.tolist(),
                "class_f1": f1.tolist(),
            })


        class BatchYielder():
            def __init__(self, ordered, random, rs):
                self.ordered = ordered
                self.random = random
                self.cursor_malicious = 0
                self.cursor_legit = 0
                self.rs = rs

            def get_batch(self):
                malicious_indices_batch = self.rs.choice(malicious_indices_train, size=n_malicious_per_batch,
                                                         replace=False) \
                    if self.random else \
                    malicious_indices_train[self.cursor_malicious:self.cursor_malicious + n_malicious_per_batch]

                legitimate_indices_batch = self.rs.choice(legit_indices_train, size=n_legit_per_batch, replace=False) \
                    if self.random else \
                    legit_indices_train[self.cursor_legit:self.cursor_legit + n_legit_per_batch]

                indices = np.concatenate([malicious_indices_batch, legitimate_indices_batch])

                self.cursor_malicious = self.cursor_malicious + n_malicious_per_batch
                self.cursor_malicious = self.cursor_malicious % (len(malicious_indices_train) - n_malicious_per_batch)

                self.cursor_legit = self.cursor_legit + n_legit_per_batch
                self.cursor_legit = self.cursor_legit % (len(legit_indices_train) - n_legit_per_batch)

                X = get_windows_for_indices(indices, self.ordered)
                # each x in X contains a dataframe, with window_size rows and all the features of the flows. There are batch_size of these.

                # we have a dataframe containing batch_size x (window_size, features)
                # we actually want a result of features x (batch_size, sequence_length, feature_dimension)
                featurewise_X = samplewise_to_featurewise(X)

                return featurewise_X, overall_y_preserve

        batch_yielder = BatchYielder(self.parameters._train_ensure_flows_are_ordered_within_windows, not self.parameters._train_draw_sequential_windows, self.rs)

        min_loss = 100
        iters_since_loss_decrease = 0

        train_results = []
        final_epoch = 0

        last_print = time.time()
        elapsed_time = 0

        for epoch in range(epochs):
            final_epoch = epoch

            has_reduced_loss = False
            for step in range(steps_per_epoch):
                batch_X, batch_y = batch_yielder.get_batch()

                t0 = time.time()
                batch_results = m.train_on_batch(batch_X, batch_y)
                t1 = time.time()

                if epoch > 0 or step > 0:
                    elapsed_time += (t1 - t0)
                    if epoch == 0 and step == 1:
                        # include time for last "step" that we skipped with step > 0 for epoch == 0
                        elapsed_time *= 2

                train_results.append(batch_results + [elapsed_time, epoch])

                batch_loss = batch_results[0] if isinstance(batch_results, list) else batch_results

                if time.time() - last_print > 3:
                    last_print = time.time()
                    early_stop_phrase = "" if early_stopping_patience <= 0 else f" (early stop in {early_stopping_patience - iters_since_loss_decrease:,})"
                    print(f"Epoch = {epoch:,} / {epochs:,}{early_stop_phrase}, step = {step}, loss = {batch_loss:.5f}, results = {batch_results} -- elapsed (train): {elapsed_time:.2f}s")

                if batch_loss < min_loss:
                    has_reduced_loss = True
                    min_loss = batch_loss

            if has_reduced_loss:
                iters_since_loss_decrease = 0
            else:
                iters_since_loss_decrease += 1

            do_early_stop = early_stopping_patience > 0 and iters_since_loss_decrease > early_stopping_patience
            is_last_epoch = epoch == epochs - 1
            run_eval = epoch in [6] or is_last_epoch or do_early_stop

            if run_eval:
                run_evaluation(epoch)

            if do_early_stop:
                print(f"Early stopping at epoch: {epoch}")
                break

        eval_results = pd.DataFrame(epoch_results)

        return (train_results, eval_results, final_epoch)


    def time(self, m:keras.Model, batch_size, n_steps=128, n_repeats=4):
        n_malicious_per_batch = int(0.5 * batch_size)
        n_legit_per_batch = batch_size - n_malicious_per_batch

        overall_y_preserve = np.zeros(dtype="float32", shape=(n_malicious_per_batch + n_legit_per_batch,))
        overall_y_preserve[:n_malicious_per_batch] = 1.

        selectable_mask = np.zeros(len(self.X), dtype=bool)
        selectable_mask[self.parameters.window_size:-self.parameters.window_size] = True
        train_mask = self.training_mask

        y_mask = ~(self.y_backup.astype('str') == str(self.dataset_specification.benign_label))

        indices_train = np.argwhere(train_mask).reshape(-1)
        malicious_indices_train = np.argwhere(train_mask & y_mask & selectable_mask).reshape(-1)
        legit_indices_train = np.argwhere(train_mask & ~y_mask & selectable_mask).reshape(-1)

        indices_test:np.ndarray = np.argwhere(~train_mask).reshape(-1)

        def get_windows_for_indices(indices:np.ndarray, ordered) -> List[pd.DataFrame]:
            X: List[pd.DataFrame] = []

            if ordered:
                # we don't really want to include eval samples as part of context, because out of range values might be learned
                # by the model, _but_ we are forced to in the windowed approach, if users haven't just selected the
                # "take last 10%" as eval option. We warn them prior to this though.
                for i1 in indices:
                    X.append(self.X.iloc[(i1 - self.parameters.window_size) + 1:i1 + 1])
            else:
                context_indices_batch = np.random.choice(indices_train, size=(batch_size, self.parameters.window_size),
                                                         replace=False).reshape(-1)
                context_indices_batch[:, -1] = indices

                for index in context_indices_batch:
                    X.append(self.X.iloc[index])

            return X

        feature_columns_map = {}

        def samplewise_to_featurewise(X):
            sequence_length = len(X[0])

            combined_df = pd.concat(X)

            featurewise_X = []

            if len(feature_columns_map) == 0:
                for feature in self.model_input_spec.feature_names:
                    if feature in self.model_input_spec.numeric_feature_names or self.model_input_spec.categorical_format == CategoricalFormat.Integers:
                        feature_columns_map[feature] = feature
                    else:
                        # this is a one-hot encoded categorical feature
                        feature_columns_map[feature] = [c for c in X[0].columns if str(c).startswith(feature)]

            for feature in self.model_input_spec.feature_names:
                feature_columns = feature_columns_map[feature]
                combined_values = combined_df[feature_columns].values

                # maybe this can be faster with a reshape but I couldn't get it to work
                combined_values = np.array([combined_values[i:i+sequence_length] for i in range(0, len(combined_values), sequence_length)])
                featurewise_X.append(combined_values)

            return featurewise_X


        epoch_results = []


        class BatchYielder():
            def __init__(self, ordered, random, rs):
                self.ordered = ordered
                self.random = random
                self.cursor_malicious = 0
                self.cursor_legit = 0
                self.rs = rs

            def get_batch(self):
                malicious_indices_batch = self.rs.choice(malicious_indices_train, size=n_malicious_per_batch,
                                                         replace=False) \
                    if self.random else \
                    malicious_indices_train[self.cursor_malicious:self.cursor_malicious + n_malicious_per_batch]

                legitimate_indices_batch = self.rs.choice(legit_indices_train, size=n_legit_per_batch, replace=False) \
                    if self.random else \
                    legit_indices_train[self.cursor_legit:self.cursor_legit + n_legit_per_batch]

                indices = np.concatenate([malicious_indices_batch, legitimate_indices_batch])

                self.cursor_malicious = self.cursor_malicious + n_malicious_per_batch
                self.cursor_malicious = self.cursor_malicious % (len(malicious_indices_train) - n_malicious_per_batch)

                self.cursor_legit = self.cursor_legit + n_legit_per_batch
                self.cursor_legit = self.cursor_legit % (len(legit_indices_train) - n_legit_per_batch)

                X = get_windows_for_indices(indices, self.ordered)
                # each x in X contains a dataframe, with window_size rows and all the features of the flows. There are batch_size of these.

                # we have a dataframe containing batch_size x (window_size, features)
                # we actually want a result of features x (batch_size, sequence_length, feature_dimension)
                featurewise_X = samplewise_to_featurewise(X)

                return featurewise_X, overall_y_preserve

        batch_yielder = BatchYielder(self.parameters._train_ensure_flows_are_ordered_within_windows, not self.parameters._train_draw_sequential_windows, self.rs)

        min_loss = 100
        iters_since_loss_decrease = 0

        final_epoch = 0

        last_print = time.time()
        elapsed_time = 0

        batch_times = []


        for step in range(n_steps):
            batch_X, batch_y = batch_yielder.get_batch()

            local_batch_times = []
            for i in range(n_repeats):
                t0 = time.time()
                batch_results = m.predict_on_batch(batch_X)
                t1 = time.time()
                local_batch_times.append(t1 - t0)

            batch_times.append(local_batch_times)

            if time.time() - last_print > 3:
                last_print = time.time()
                print(f"Step = {step}, running model evaluation... Average times = {np.mean(np.array(batch_times).reshape(-1))}")

        return batch_times

#### Runner

In [105]:
encodings = [
    RecordLevelEmbed(64),
    RecordLevelEmbed(64, project=True)
]

classification_heads = [
    FeaturewiseEmbedding(project=False),
    LastTokenClassificationHead()
]

transformers: List[FunctionalComponent] = [
    BasicTransformer(2, 128, n_heads=2),
    GPTSmallTransformer()
]

flow_file_path = r"dataset/"

datasets = [
    ("CSE_CIC_IDS", os.path.join(flow_file_path, "downsampled/downsampled_df_shuffled.csv"), NamedDatasetSpecifications.cse_cic_ids_2018, 0.01, EvaluationDatasetSampling.RandomRows)
]

print("Imported and defined")
pre_processing = StandardPreProcessing(n_categorical_levels=32)

# Define the transformer
ft = FlowTransformer(pre_processing=pre_processing,
                     input_encoding=encodings[0],
                     sequential_model=transformers[0],
                     classification_head=classification_heads[0],
                     params=FlowTransformerParameters(window_size=8, mlp_layer_sizes=[128], mlp_dropout=0.1))

print("Defined flowtransformer, starting load")

# Load the specific dataset
dataset_name, dataset_path, dataset_specification, eval_percent, eval_method = datasets[0]
ft.load_dataset(dataset_name, dataset_path, dataset_specification, evaluation_dataset_sampling=eval_method, evaluation_percent=eval_percent)

print("Loaded dataset, building model")

# Build the transformer model
m = ft.build_model()
m.summary()

# Compile the model
m.compile(optimizer="adam", loss='categorical_crossentropy', metrics=['categorical_accuracy'], jit_compile=True)

# Get the evaluation results
eval_results: pd.DataFrame
(train_results, eval_results, final_epoch) = ft.evaluate(m, batch_size=128, epochs=5, steps_per_epoch=64, early_stopping_patience=5)


print(eval_results)

Imported and defined
Defined flowtransformer, starting load
Using cache file path: cache\CSE_CIC_IDS_0_QdLmZHuh8yOmlGcKBEkf7hepImY0_VHNk9ujbqtTXGSrgVayeqG486IQ0.feather
Attempting to read dataset from path dataset/downsampled/downsampled_df_shuffled.csv...




Set y to = Attack
Converting numerical columns to floats, and removing out of range values...
Applying pre-processing to numerical values
[Numerical 1 / 28] Processing numerical column IN_BYTES...
[Numerical 2 / 28] Processing numerical column MAX_IP_PKT_LEN...
[Numerical 3 / 28] Processing numerical column LONGEST_FLOW_PKT...
[Numerical 4 / 28] Processing numerical column MIN_TTL...
[Numerical 5 / 28] Processing numerical column DURATION_IN...
[Numerical 6 / 28] Processing numerical column NUM_PKTS_512_TO_1024_BYTES...
[Numerical 7 / 28] Processing numerical column MIN_IP_PKT_LEN...
[Numerical 8 / 28] Processing numerical column NUM_PKTS_128_TO_256_BYTES...
[Numerical 9 / 28] Processing numerical column NUM_PKTS_UP_TO_128_BYTES...
[Numerical 10 / 28] Processing numerical column RETRANSMITTED_IN_PKTS...
[Numerical 11 / 28] Processing numerical column RETRANSMITTED_IN_BYTES...
[Numerical 12 / 28] Processing numerical column DST_TO_SRC_AVG_THROUGHPUT...
[Numerical 13 / 28] Processing num

ValueError: Classification metrics can't handle a mix of multilabel-indicator and binary targets