In [None]:
# https://towardsdatascience.com/hands-on-generative-adversarial-networks-gan-for-signal-processing-with-python-ff5b8d78bd28

In [None]:
from numpy import hstack
from numpy import zeros
from numpy import ones
from numpy.random import rand
from numpy.random import randn
import numpy as np
import pandas as pd
from sklearn.utils import resample
from sklearn.model_selection import KFold
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelBinarizer
from sklearn.preprocessing import LabelEncoder
from collections import Counter
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import Input
from tensorflow.keras import layers
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense,LSTM
from matplotlib import pyplot
import matplotlib.pyplot as plt

In [None]:
# Execute the companion notebook before anything else in this notebook runs.
# NOTE(review): `spark` and `col` are used in later cells but are never imported
# in this file — presumably read_file.ipynb provides them; confirm.
%run ./read_file.ipynb

In [None]:
# Never elide columns when a pandas DataFrame is displayed.
pd.set_option('display.max_columns', None)

# Keras backend setting applied as a workaround for an LSTM rank error; see
# https://stackoverflow.com/questions/68036975/valueerror-shape-must-be-at-least-rank-3-but-is-rank-2-for-node-biasadd
tf.keras.backend.set_image_data_format("channels_last")

In [None]:
# Load every parquet part under the 23Sep3 prefix and cache the resulting DataFrame.
ds = spark.read.parquet(
    "s3a://sapient-bucket-trusted/prod/graph/encoded/real/23Sep3/*"
).cache()

In [None]:
# Total number of rows in ds; used as the denominator of the per-class percentages.
tot = ds.count()

In [None]:
# Class balance: row count per mal_trace value and that count as a percentage
# of all rows, printed in ascending order of percentage.
(
    ds.groupBy("mal_trace")
    .count()
    .withColumnRenamed('count', 'cnt_per_group')
    .withColumn('perc_of_count_total', (col('cnt_per_group') / tot) * 100)
    .sort("perc_of_count_total")
    .show()
)

In [None]:
# Set Config
embedding_dim = 64        # not referenced anywhere else in the visible notebook
max_length = 6            # maxlen passed to pad_sequences
sequence_length = 6       # placeholder; reassigned from seq_enc_tensor.shape[1] further down
max_features = 30         # num_words passed to the Keras Tokenizer
padding_type = 'post'     # `padding` argument of pad_sequences
trunc_type = 'post'       # `truncating` argument of pad_sequences
num_samples = 10          # row limit used to build the small ds_lim sample

In [None]:
# Keras text tokenizer; num_words caps which token ids texts_to_sequences will emit.
tokenizer = tf.keras.preprocessing.text.Tokenizer(num_words=max_features)

In [None]:
# Pull the event_sequence column of the full dataset to the driver as a plain list.
ds_events = [row[0] for row in ds.select('event_sequence').collect()]

In [None]:
# Build the tokenizer vocabulary from the event sequences of the full dataset.
tokenizer.fit_on_texts(ds_events)

In [None]:
# Small cached sample limited to num_samples rows.
ds_lim = ds.limit(num_samples).cache()
# NOTE: this rebinds ds_events from the full dataset to the sample's sequences.
ds_events = [row[0] for row in ds_lim.select('event_sequence').collect()]

In [None]:
# Keep only the rows flagged as malicious (mal_trace == 1) and cache them.
ds_mal = ds.filter( col('mal_trace') == 1).cache()

In [None]:
# Event sequences of the malicious rows, collected to the driver as a list.
dm_events = [row[0] for row in ds_mal.select('event_sequence').collect()]

In [None]:
# mal_trace values of the malicious rows, collected to the driver as a list.
dm_labels = [row[0] for row in ds_mal.select('mal_trace').collect()]

In [None]:
# Word -> integer id mapping learned by the tokenizer.
word_index = tokenizer.word_index
# The Keras tokenizer numbers words from 1 and reserves id 0 (the value
# pad_sequences fills with by default), so valid ids span 0..len(word_index).
# Use len(word_index) + 1 as the one-hot depth / layer width so the
# highest-numbered word is not encoded as an all-zero vector by tf.one_hot.
vocab_count = len(word_index) + 1

In [None]:
# Display the value used as one-hot depth and as the models' input/output width.
vocab_count

In [None]:
# Turn the malicious event sequences into token ids, pad/truncate each one to
# max_length ids, then expand every id into a one-hot vector of width vocab_count.
dm_sequences = tokenizer.texts_to_sequences(dm_events)
dm_padded = tf.keras.utils.pad_sequences(
    dm_sequences,
    maxlen=max_length,
    padding=padding_type,
    truncating=trunc_type,
)
seq_enc_tensor = tf.one_hot(indices=dm_padded, depth=vocab_count)

In [None]:
# Inspect the one-hot row for the first timestep of the first sequence.
seq_enc_tensor[0][0]

In [None]:
# Inspect the shape of the encoded tensor.
seq_enc_tensor.shape

In [None]:
# seq_enc_tensor is laid out as (input_len, sequence_length, vocab_size);
# keep the first two dimensions for the model and sampling functions.
input_len, sequence_length = seq_enc_tensor.shape[0], seq_enc_tensor.shape[1]

In [None]:
def define_generator():
    """Build and compile the sequence generator.

    The network takes a (sequence_length, vocab_count) input, runs it through
    a 128-unit LSTM that emits one vector per timestep, and applies a softmax
    Dense layer of width vocab_count to each timestep.

    Returns:
        The compiled ``tf.keras.Sequential`` model (categorical cross-entropy
        loss, Adam optimizer).
    """
    generator_net = tf.keras.Sequential()
    generator_net.add(tf.keras.layers.Input(shape=(sequence_length, vocab_count)))
    generator_net.add(tf.keras.layers.LSTM(128, return_sequences=True))
    # The same Dense/softmax head is applied independently at every timestep.
    per_step_head = tf.keras.layers.Dense(vocab_count, activation='softmax')
    generator_net.add(tf.keras.layers.TimeDistributed(per_step_head))

    generator_net.compile(optimizer='adam', loss='categorical_crossentropy')
    return generator_net

def define_discriminator():
    """Build and compile the sequence discriminator.

    The network takes a (sequence_length, vocab_count) input, summarises it
    with a 128-unit LSTM that returns only its final output, and classifies
    that summary with a 2-way softmax.

    Returns:
        The compiled ``tf.keras.Sequential`` model (sparse categorical
        cross-entropy loss, Adam optimizer).
    """
    discriminator_net = tf.keras.Sequential()
    discriminator_net.add(tf.keras.layers.Input(shape=(sequence_length, vocab_count)))
    discriminator_net.add(tf.keras.layers.LSTM(128, return_sequences=False))
    discriminator_net.add(tf.keras.layers.Dense(2, activation='softmax'))

    discriminator_net.compile(optimizer='adam', loss='sparse_categorical_crossentropy')
    return discriminator_net

def define_gan(generator, discriminator):
    """Stack the generator and discriminator into one compiled model.

    Args:
        generator: model whose output is fed to the discriminator.
        discriminator: model appended after the generator; its ``trainable``
            flag is set to False on the object that is passed in.

    Returns:
        The compiled ``tf.keras.Sequential`` model (sparse categorical
        cross-entropy loss, Adam optimizer).
    """
    # Mark the discriminator as non-trainable before it is stacked and compiled.
    discriminator.trainable = False
    stacked = tf.keras.Sequential([generator, discriminator])
    stacked.compile(optimizer='adam', loss='sparse_categorical_crossentropy')
    return stacked

def generate_latent_space():
    """Draw a random integer tensor to feed the generator.

    Returns:
        An int32 tensor of shape (input_len, sequence_length, vocab_count)
        sampled with ``tf.random.uniform`` using minval=1 and maxval=vocab_count.
    """
    latent_shape = [input_len, sequence_length, vocab_count]
    return tf.random.uniform(shape=latent_shape, minval=1, maxval=vocab_count, dtype=tf.int32)

def generate_fake_samples(generator):
    """Produce one batch of generated sequences labelled as fake.

    Args:
        generator: model used to turn latent input into sequences.

    Returns:
        ``(x, y)`` where ``x`` is the generator's prediction for a fresh
        latent batch and ``y`` is an (input_len, 1) array of zeros.
    """
    latent_batch = generate_latent_space()
    generated = generator.predict(latent_batch, verbose=0)
    # Label 0 marks generated samples.
    fake_labels = zeros((input_len, 1))
    return generated, fake_labels

def generate_real_samples():
    """Return the encoded real sequences labelled as real.

    Returns:
        ``(x, y)`` where ``x`` is ``seq_enc_tensor`` and ``y`` is an
        (input_len, 1) array of ones.
    """
    # Label 1 marks real samples.
    labels = ones((input_len, 1))
    return seq_enc_tensor, labels

def train(g_model, d_model, gan_model, epochs=5, n_eval=20):
    """Alternate discriminator and generator updates for ``epochs`` rounds.

    Each round trains the discriminator on one batch of real samples
    (``generate_real_samples``) and one batch of generated samples
    (``generate_fake_samples``), then trains the generator through
    ``gan_model`` on latent input labelled as real.

    Args:
        g_model: generator, passed to ``generate_fake_samples``.
        d_model: discriminator, updated directly with ``train_on_batch``.
        gan_model: stacked generator + discriminator model used for the
            generator update.
        epochs: number of training rounds.
        n_eval: a progress line is printed whenever the round index is a
            multiple of this value (including round 0).

    Returns:
        ``(d_acc_history, d_loss_history, g_acc_history, g_loss_history)``,
        four lists with one entry per round. The accuracy entries are running
        values: the metric objects are created once and never reset, so each
        entry reflects all rounds up to and including that one.
    """
    d_acc_history = []
    d_loss_history = []
    g_acc_history = []
    g_loss_history = []

    d_metric = tf.keras.metrics.SparseCategoricalAccuracy()
    g_metric = tf.keras.metrics.CategoricalAccuracy()

    # manually enumerate epochs
    for i in range(epochs):
        # Real samples must come from the dataset. The previous code drew them
        # from the generator, so the discriminator never saw real data and was
        # trained on generated sequences carrying both labels.
        x_real, y_real = generate_real_samples()
        # prepare fake examples using the generator
        x_fake, y_fake = generate_fake_samples(g_model)

        # update discriminator on each half-batch and average the two losses
        d_real_loss = d_model.train_on_batch(x_real, y_real)
        d_fake_loss = d_model.train_on_batch(x_fake, y_fake)
        d_loss = 0.5 * (d_real_loss + d_fake_loss)

        d_real_pred = d_model.predict_on_batch(x_real)
        d_fake_pred = d_model.predict_on_batch(x_fake)

        d_metric.update_state(tf.concat([y_real, y_fake], axis=0), tf.concat([d_real_pred, d_fake_pred], axis=0))
        d_acc = d_metric.result().numpy()

        # prepare points in latent space as input for the generator;
        # they are labelled 1 ("real") so the generator is pushed to fool
        # the discriminator
        x_gan = generate_latent_space()
        y_gan = ones((input_len, 1))
        # update the generator via the discriminator's error
        g_loss = gan_model.train_on_batch(x_gan, y_gan)

        # fraction of generated samples the discriminator classifies as real
        gan_pred = gan_model.predict_on_batch(x_gan)
        g_metric.update_state(tf.one_hot(tf.cast(y_gan, tf.int32), depth=2), gan_pred)
        g_acc = g_metric.result().numpy()

        d_acc_history.append(d_acc)
        d_loss_history.append(d_loss)
        g_acc_history.append(g_acc)
        g_loss_history.append(g_loss)

        if i % n_eval == 0:
            print(f"Epoch {i}: Discriminator Loss: {d_loss}, Discriminator Accuracy: {d_acc}, Generator Loss: {g_loss}, Generator Accuracy: {g_acc}")

    return d_acc_history, d_loss_history, g_acc_history, g_loss_history

In [None]:
# Build the two networks, then the stacked model that trains the generator
# through the discriminator (define_gan sets discriminator.trainable = False).
generator = define_generator()
discriminator = define_discriminator()
gan = define_gan(generator, discriminator)

In [None]:
# Run the real one-hot sequences through the freshly built generator, before
# any training step has been applied to it; the result is inspected further down.
new_seq = generator.predict(seq_enc_tensor)

In [None]:
# One supervised update in which input and target are the same tensor, i.e. the
# generator is asked to reproduce the real sequences; the returned loss is displayed.
generator.train_on_batch(seq_enc_tensor, seq_enc_tensor)

In [None]:
# Generate a new sequence using the generator model
new_seq_enc_tensor = generator.predict(seq_enc_tensor)
# Pick the highest-scoring token id at every timestep ...
new_seq_enc = tf.argmax(new_seq_enc_tensor, axis=-1)
# ... and map the ids back to text with the fitted tokenizer.
new_seq_texts = tokenizer.sequences_to_texts(new_seq_enc.numpy())

In [None]:
# Column of ones, one row per encoded sequence (1 is the label that
# generate_real_samples assigns to real data).
real_labels = np.ones((input_len,1))

In [None]:
# Real data: one-hot row for the first timestep of the first sequence.
seq_enc_tensor[0][0]

In [None]:
# Generator output at the same position, from the prediction made after the
# single train_on_batch step.
new_seq_enc_tensor[0][0]

In [None]:
# Generator output at the same position, from the earlier prediction made
# before any training step.
new_seq[0][0]

In [None]:
# Convert the generator output into nested Python lists.
new_seq_enc_list = new_seq_enc_tensor.tolist()

In [None]:
# Token ids (argmax over the vocabulary axis) for the first generated sequence.
# Index first, then reduce: the previous form computed the argmax for every
# sequence only to keep element 0.
np.argmax(new_seq_enc_list[0], axis=-1)

In [None]:
# Show the first six decoded generator outputs.
print(new_seq_texts[:6])

In [None]:
# One discriminator update on generator output.
# NOTE(review): real_labels is all ones, so generated sequences are labelled as
# real here, whereas generate_fake_samples labels generator output 0 — confirm
# this exploratory step is intended before the full training run.
discriminator.train_on_batch(new_seq_enc_tensor, real_labels)

In [None]:
# Discriminator output (two softmax scores per sequence) for the generated sequences.
discriminator.predict(new_seq_enc_tensor)

In [None]:
# One update of the stacked model with the "real" label.
# NOTE(review): the stacked model's first stage is the generator, so this feeds
# generator output back into the generator; train() feeds latent input from
# generate_latent_space instead — confirm this is intended.
gan.train_on_batch(new_seq_enc_tensor, real_labels)

In [None]:
# Run the adversarial training loop; with train()'s default n_eval=20 and
# 101 rounds, progress is printed for rounds 0, 20, ..., 100.
epochs = 101
d_acc_history, d_loss_history, g_acc_history, g_loss_history = train(generator, discriminator, gan, epochs=epochs)

In [None]:
# Two side-by-side panels: losses on the left, discriminator accuracy on the right.
fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(12, 4))

# Left panel: loss per training round for both networks.
ax1.plot(range(1, epochs + 1), d_loss_history, label='Discriminator Loss')
ax1.plot(range(1, epochs + 1), g_loss_history, label='Generator Loss')
ax1.set(xlabel='Epochs', ylabel='Loss')
ax1.legend()

# Right panel: discriminator accuracy per training round.
ax2.plot(range(1, epochs + 1), d_acc_history, label='Discriminator Accuracy')
ax2.set(xlabel='Epochs', ylabel='Accuracy')
ax2.legend()

plt.show()