Inspiration: Shervine Amidi's tutorial on generating data on the fly with Keras — https://stanford.edu/~shervine/blog/keras-how-to-generate-data-on-the-fly

In [46]:
import os

import numpy as np
import psycopg2
import tensorflow as tf

from keras import Sequential, Input
from keras.src.callbacks import EarlyStopping
from keras.src.layers import SimpleRNN, Dense, Dropout


In [47]:
class DataGenerator(tf.keras.utils.Sequence):
    """Keras ``Sequence`` that streams one-hot encoded event sequences from PostgreSQL.

    There is one sample per row of the ``sequence`` table, keyed by ``(subject_uuid, exec)``.

    * Input: the ``event.type`` values whose ``subject_uuid`` and
      ``properties_map_exec`` match the sample, ordered by ``sequence_long``.
      They are one-hot encoded over every distinct event type plus a
      ``'NONE'`` padding token.
    * Target: the sample's ``exec`` value, one-hot encoded over all distinct
      ``exec`` values (columns follow sorted label order).

    Each batch is right-padded with ``'NONE'`` to its longest sequence. ``X``
    therefore has shape ``(n, max_len_in_batch, no_features)`` and ``y`` has
    shape ``(n, no_classes)``.

    Args:
        db_config: keyword arguments passed to ``psycopg2.connect``.
        batch_size: sequences per batch. The trailing partial batch is dropped.
        shuffle: reshuffle the sample order on construction and after each epoch.
        seed: optional shuffle seed. ``None`` uses NumPy's global random state,
            as before, so ``np.random.seed`` still controls the order.
        **kwargs: forwarded to the Keras base class (Keras 3 ``PyDataset``
            accepts ``workers``, ``use_multiprocessing`` and ``max_queue_size``).
            All batches share one psycopg2 cursor, and psycopg2 cursors are not
            thread-safe, so keep the default single worker.

    The database connection stays open until ``close_connection`` is called.
    """

    def __init__(self, db_config, batch_size=32, shuffle=True, seed=None, **kwargs):
        # Keras 3's PyDataset expects its constructor to be called and warns otherwise.
        super().__init__(**kwargs)
        if batch_size < 1:
            raise ValueError(f'batch_size must be >= 1, got {batch_size}')
        self.db_config = db_config
        self.batch_size = batch_size
        self.shuffle = shuffle
        # Both the np.random module and a Generator expose .shuffle(). Without
        # a seed, the global state is used, as the original code did.
        self._rng = np.random.default_rng(seed) if seed is not None else np.random
        self.conn = psycopg2.connect(**db_config)
        self.cursor = self.conn.cursor()
        try:
            self.sequence_ids = self._load_sequence_ids()
            self.indexes = np.arange(len(self.sequence_ids))
            self._setup()
        except Exception:
            # Do not leak the connection when a setup query fails.
            self.close_connection()
            raise
        self.on_epoch_end()

    def _load_sequence_ids(self):
        """Return every ``(subject_uuid, exec)`` row of the ``sequence`` table."""
        self.cursor.execute('''
        select subject_uuid, exec
        from sequence;
        ''')
        sequence_ids = self.cursor.fetchall()
        print(f'loaded {len(sequence_ids)} sequence ids')
        return sequence_ids

    def _setup(self):
        """Load the class (``exec``) and feature (event ``type``) vocabularies and their lookups."""
        self.cursor.execute('''
        select distinct exec
        from sequence;
        ''')
        self.classes = [row[0] for row in self.cursor.fetchall()]
        self.no_classes = len(self.classes)
        print(f'found {self.no_classes} classes')
        print(f'classes: {self.classes}')
        # Label columns use sorted order, matching the former
        # np.unique(...).tolist().index(label) encoding, but lookups are O(1).
        self._class_index = {label: i for i, label in enumerate(sorted(self.classes))}

        self.cursor.execute('''
        select distinct e.type
        from event e;
        ''')
        self.features = [row[0] for row in self.cursor.fetchall()]
        # 'NONE' is the padding token. If an event type already has that name,
        # a second 'NONE' column could never be set, so it is not added.
        if 'NONE' not in self.features:
            self.features.append('NONE')
        self.no_features = len(self.features)
        print(f'found {self.no_features} features')
        print(f'features: {self.features}')
        self._feature_index = {feature: i for i, feature in enumerate(self.features)}

    def get_no_features(self):
        """Return the width of the one-hot input encoding."""
        return self.no_features

    def get_no_classes(self):
        """Return the width of the one-hot label encoding."""
        return self.no_classes

    def __len__(self):
        """Return the number of full batches per epoch (a trailing partial batch is dropped)."""
        return len(self.sequence_ids) // self.batch_size

    def __getitem__(self, index):
        """Return batch ``index`` as ``(X, y)``. See the class docstring for shapes.

        Raises:
            IndexError: if ``index`` is not in ``range(len(self))``.
        """
        if not 0 <= index < len(self):
            raise IndexError(f'batch index {index} out of range for {len(self)} batches')
        start = index * self.batch_size
        batch_indexes = self.indexes[start:start + self.batch_size]
        batch_data = [self._fetch_events(*self.sequence_ids[k]) for k in batch_indexes]
        return self.__data_generation(batch_data)

    def _fetch_events(self, subject_uuid, executable):
        """Return ``(executable, event_types)`` for one sample, with events ordered by ``sequence_long``."""
        # NOTE(perf): this issues one query per sample, i.e. batch_size round
        # trips per batch. Fetching a whole batch in a single query would
        # reduce step time.
        self.cursor.execute('''
            select e.type
            from event e
            where e.subject_uuid = %s
            and e.properties_map_exec = %s
            order by e.sequence_long;
            ''', (subject_uuid, executable))
        return executable, [row[0] for row in self.cursor.fetchall()]

    def on_epoch_end(self):
        """Reshuffle the sample order when ``shuffle`` is enabled."""
        if self.shuffle:
            self._rng.shuffle(self.indexes)

    def __data_generation(self, batch_data):
        """Encode ``(executable, event_types)`` pairs into padded ``(X, y)`` arrays.

        Every sequence is right-padded with ``'NONE'`` to the longest one in
        the batch. The input lists are not modified.
        """
        max_len = max((len(events) for _, events in batch_data), default=0)
        X = np.array([
            self.encode_data(list(events) + ['NONE'] * (max_len - len(events)))
            for _, events in batch_data
        ])
        y = np.array([self.encode_label(executable) for executable, _ in batch_data])
        return X, y

    def encode_label(self, label):
        """Return a ``(no_classes,)`` float one-hot vector for ``label``.

        Raises:
            ValueError: if ``label`` is not a known class.
        """
        try:
            label_index = self._class_index[label]
        except KeyError:
            raise ValueError(f'{label!r} is not a known class') from None
        one_hot = np.zeros(self.no_classes)
        one_hot[label_index] = 1
        return one_hot

    def encode_data(self, data):
        """Return a ``(len(data), no_features)`` float one-hot matrix for a list of event types.

        Raises:
            ValueError: if an element of ``data`` is not a known feature.
        """
        try:
            columns = [self._feature_index[feature] for feature in data]
        except KeyError as exc:
            raise ValueError(f'{exc.args[0]!r} is not a known feature') from None
        new_data = np.zeros((len(data), self.no_features))
        new_data[np.arange(len(data)), np.asarray(columns, dtype=int)] = 1
        return new_data

    def close_connection(self):
        """Close the cursor and the database connection."""
        self.cursor.close()
        self.conn.close()

# Example usage
# Connection settings are read from the standard libpq environment variables,
# so credentials do not have to be hard-coded in the notebook. The defaults
# reproduce the previous local setup.
db_config = {
    'dbname': os.environ.get('PGDATABASE', 'cadets_e3'),
    'user': os.environ.get('PGUSER', 'rosendahl'),
    'password': os.environ.get('PGPASSWORD', ''),
    'host': os.environ.get('PGHOST', 'localhost'),
    'port': os.environ.get('PGPORT', '5432'),
}

# Initialize the data generator
data_gen = DataGenerator(db_config, batch_size=32, shuffle=True)

loaded 430863 sequence ids
found 135 classes
classes: ['adjkerntz', 'alpine', 'anvil', 'atrun', 'awk', 'basename', 'bash', 'bounce', 'bzcat', 'bzip2', 'cat', 'chkgrp', 'chmod', 'chown', 'cleanup', 'cmp', 'cp', 'cron', 'csh', 'cut', 'date', 'dd', 'devd', 'df', 'dhclient', 'diff', 'dmesg', 'egrep', 'env', 'expr', 'find', 'fortune', 'getty', 'grep', 'head', 'hostname', 'id', 'ifconfig', 'imapd', 'inetd', 'init', 'ipfstat', 'ipfw', 'ipop3d', 'jot', 'kenv', 'kill', 'kldstat', 'less', 'limits', 'links', 'local', 'locale', 'locate.code', 'lockf', 'login', 'ls', 'lsof', 'lsvfs', 'mail', 'mail.local', 'mailwrapper', 'main', 'makewhatis', 'master', 'minions', 'mkdir', 'mktemp', 'mlock', 'mount', 'msgs', 'mv', 'nawk', 'netstat', 'newsyslog', 'nginx', 'nice', 'nohup', 'ntpd', 'pEja72mA', 'pfctl', 'php-fpm', 'pickup', 'ping', 'pkg', 'postmap', 'procstat', 'proxymap', 'ps', 'pw', 'pwait', 'pwd_mkdb', 'python2.7', 'qmgr', 'resizewin', 'rm', 'route', 'screen', 'sed', 'sendmail', 'sh', 'sleep', 'smtp',

In [48]:
# Three stacked 64-unit SimpleRNN layers with 20% dropout after the first two.
# Only the last RNN collapses the time axis, feeding a softmax over the classes.
# The time dimension is None, so batches of any padded length are accepted.
model = Sequential()
model.add(Input(shape=(None, data_gen.get_no_features())))
model.add(SimpleRNN(64, return_sequences=True))
model.add(Dropout(0.2))
model.add(SimpleRNN(64, return_sequences=True))
model.add(Dropout(0.2))
model.add(SimpleRNN(64, return_sequences=False))
model.add(Dense(data_gen.get_no_classes(), activation='softmax'))

model.summary()

In [49]:
# Monitoring 'val_loss' never triggers early stopping here. fit() receives no
# validation_data, so Keras only warns that the metric is unavailable. Monitor
# the training loss instead; switch back to 'val_loss' once validation data is
# passed to fit().
early_stop = EarlyStopping(monitor='loss', patience=5, verbose=1, mode='auto')
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

In [50]:
# Train the model. The database connection is closed in `finally`, so it is
# released even when training fails or is interrupted (e.g. KeyboardInterrupt).
try:
    history = model.fit(data_gen, epochs=10, callbacks=[early_stop])
finally:
    data_gen.close_connection()

Epoch 1/10
[1m    6/13464[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m27:08[0m 121ms/step - accuracy: 0.0017 - loss: 4.8823    

KeyboardInterrupt: 