<font size='6px'>
📝 <b>GAN Tutorial: Part 2 - GAN Tutorial Module (GTM)</b></font>

<font color='orange'> **written by Rai--** </font>

### Prerequisites (only for Colab users, otherwise please skip)

In [1]:
!pip install tensorflow==2.16.2

Collecting tensorflow==2.16.2
  Downloading tensorflow-2.16.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (4.2 kB)
Collecting ml-dtypes~=0.3.1 (from tensorflow==2.16.2)
  Downloading ml_dtypes-0.3.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (20 kB)
Collecting tensorboard<2.17,>=2.16 (from tensorflow==2.16.2)
  Downloading tensorboard-2.16.2-py3-none-any.whl.metadata (1.6 kB)
Collecting keras>=3.0.0 (from tensorflow==2.16.2)
  Downloading keras-3.4.1-py3-none-any.whl.metadata (5.8 kB)
Collecting namex (from keras>=3.0.0->tensorflow==2.16.2)
  Downloading namex-0.0.8-py3-none-any.whl.metadata (246 bytes)
Collecting optree (from keras>=3.0.0->tensorflow==2.16.2)
  Downloading optree-0.12.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (47 kB)
Downloading tensorflow-2.16.2-cp310-cp310-manylinux_2_17_x

In [2]:
!git clone https://github.com/Seismic-DL-Research/seis-deep-learning --branch super_dev
%cd seis-deep-learning

Cloning into 'seis-deep-learning'...
remote: Enumerating objects: 1522, done.
remote: Counting objects: 100% (246/246), done.
remote: Compressing objects: 100% (153/153), done.
remote: Total 1522 (delta 165), reused 171 (delta 92), pack-reused 1276
Receiving objects: 100% (1522/1522), 5.70 MiB | 11.05 MiB/s, done.
Resolving deltas: 100% (965/965), done.
/content/seis-deep-learning


In [3]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [4]:
import tensorflow as tf

# 1. Retrieving Sample Data

We will fetch a dataset to train our GAN model later. This dataset has been preprocessed by the author. To follow along with this tutorial, please download the dataset from the following link: https://drive.google.com/drive/folders/1vqsSlJn5XhsIATnH17iUxMUcu3m_6Flg?usp=sharing. Afterwards, create a folder named ```sample_dataset``` in the working directory of this notebook and paste all of the downloaded data into it. Make sure that the folder contains only the ```.tfr``` files downloaded from the link!

**Do not run the following code. It only works for the author!**

In [5]:
import os
import shutil

In [6]:
shutil.copytree('../drive/MyDrive/thesis/processed_tfrs', 'sample_dataset')

'sample_dataset'
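Before moving on, you can sanity-check the ```sample_dataset``` folder. The sketch below is a hypothetical helper, not part of the author's code. It assumes the four file names that are read in the next section (```train-p.tfr```, ```train-n.tfr```, ```valid-p.tfr```, ```valid-n.tfr```) and flags any file that is not a ```.tfr``` file.

```python
import os

# Assumption: these are the four files read in section 2 of this notebook.
EXPECTED_FILES = {'train-p.tfr', 'train-n.tfr', 'valid-p.tfr', 'valid-n.tfr'}

def classify_files(filenames):
    """Return (missing expected files, non-.tfr files) for a list of filenames."""
    names = set(filenames)
    missing = sorted(EXPECTED_FILES - names)
    unexpected = sorted(name for name in names if not name.endswith('.tfr'))
    return missing, unexpected

def check_sample_dataset(folder='sample_dataset'):
    """Print any problems with the folder and return True if it looks ready."""
    if not os.path.isdir(folder):
        print(f"Folder '{folder}' not found in {os.getcwd()}")
        return False
    missing, unexpected = classify_files(os.listdir(folder))
    if missing:
        print('Missing files:', missing)
    if unexpected:
        print('Non-.tfr files to remove:', unexpected)
    return not missing and not unexpected
```

Calling ```check_sample_dataset()``` from the notebook's working directory should return ```True``` once the files are in place.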

# 2. Reading Dataset

I assume that you are already familiar with TFRecord. If you want to learn more about it, please visit: https://colab.research.google.com/drive/1xH0pdQVC1Dv_wR1Co-VwTjdFqk6dB5d8

In [7]:
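# our parse config: each record stores one serialized tensor under the key 'data'
parse_config = {'data': tf.io.FixedLenFeature([], tf.string)}

In [8]:
# define mapping function: serialized record -> float32 tensor
def map_reading(data):
  # decode one tf.train.Example and take the serialized tensor bytes
  parsed_example = tf.io.parse_single_example(data, parse_config)
  # rebuild the original float32 tensor from those bytes
  mapped_data = tf.io.parse_tensor(parsed_example['data'], tf.float32)
  return mapped_data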

In [10]:
# training dataset
p_train = tf.data.TFRecordDataset('sample_dataset/train-p.tfr').map(map_reading)
n_train = tf.data.TFRecordDataset('sample_dataset/train-n.tfr').map(map_reading)

# validation dataset
p_val = tf.data.TFRecordDataset('sample_dataset/valid-p.tfr').map(map_reading)
n_val = tf.data.TFRecordDataset('sample_dataset/valid-n.tfr').map(map_reading)
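To see what ```parse_config``` and ```map_reading``` expect without the real files, the sketch below writes a toy TFRecord file and reads it back with the same decoding logic. It assumes each record is a ```tf.train.Example``` holding one ```tf.io.serialize_tensor``` output under the key ```'data'```, which is what the parsing code implies. The helper ```write_toy_tfr``` and the 350-sample toy waveforms are illustrative only, not the author's preprocessing.

```python
import os
import tempfile

import tensorflow as tf

# Same feature description and decoder as in the cells above.
parse_config = {'data': tf.io.FixedLenFeature([], tf.string)}

def map_reading(data):
    parsed_example = tf.io.parse_single_example(data, parse_config)
    return tf.io.parse_tensor(parsed_example['data'], tf.float32)

def write_toy_tfr(path, waveforms):
    """Hypothetical helper: store each waveform as a serialized float32 tensor."""
    with tf.io.TFRecordWriter(path) as writer:
        for waveform in waveforms:
            raw = tf.io.serialize_tensor(tf.constant(waveform, dtype=tf.float32))
            feature = {'data': tf.train.Feature(
                bytes_list=tf.train.BytesList(value=[raw.numpy()]))}
            example = tf.train.Example(features=tf.train.Features(feature=feature))
            writer.write(example.SerializeToString())

# Write two toy 350-sample "waveforms" to a temporary file and read them back.
toy_path = os.path.join(tempfile.mkdtemp(), 'toy.tfr')
write_toy_tfr(toy_path, [[0.0] * 350, [1.0] * 350])
toy_dataset = tf.data.TFRecordDataset(toy_path).map(map_reading)

for waveform in toy_dataset:
    print(waveform.shape, float(waveform[0]))
```

The same pattern works on the real files, e.g. inspecting one element with ```next(iter(p_train))```.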

# 3. The ```gtm``` Module

GTM stands for GAN Tutorial Module, a module written by the author to make constructing and training GANs easier. The author has repackaged all of the methods we have gone through into three main classes: ```gtm.models.GAN```, ```gtm.models.Generative``` and ```gtm.models.Discriminative```. They are designed to make it as easy as possible to design and train a GAN specifically for seismic P-phase detection. Since we have already covered the theory and basic practice of GANs, we will now focus on how to use this module; there is no more math from this point on. Let's import the GTM module.

In [11]:
import gan_tutorial_modules as gtm

## Configuring GAN

In [None]:
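import sys

# Reload helper: drop every cached gan_tutorial_modules (sub)module from
# sys.modules, then import the package again so that edits to its source
# files take effect without restarting the runtime.
def gtm_update(gtm):
  pending_for_deletion = [name for name in sys.modules
                          if name.split('.')[0] == 'gan_tutorial_modules']
  for name in pending_for_deletion:
    del sys.modules[name]

  import gan_tutorial_modules as gtm
  return gtm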

In [103]:
import sys
gtm = gtm_update(gtm)

To start, we configure our GAN model using ```gtm.models.GAN()```. Take a look at the parameters required to create the ```gan``` object. They should be easy to understand if you have followed the tutorial we have gone through.

In [13]:
gan = gtm.models.GAN(epoch=5,
                    batch_size=512,
                    window_length=350,
                    generative_optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
                    discriminative_optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
                    generative_total_iterations=2,
                    discriminative_total_iterations=3,
                    generative_latent_sample_size=50,
                    generative_latent_sample_mean=0,
                    generative_latent_sample_stdev=0.6,
                    p_wave_dataset=p_train, # pass the dataset unbatched (no .batch() call)!
                    n_wave_dataset=n_train, # pass the dataset unbatched (no .batch() call)!
                    p_wave_dataset_val=p_val, # pass the dataset unbatched (no .batch() call)!
                    n_wave_dataset_val=n_val # pass the dataset unbatched (no .batch() call)!
                    )
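To make the latent and window parameters concrete, here is a sketch of how a batch of latent vectors with these settings could be drawn and fed to a generator. It assumes, without confirmation from the GTM source, that latent vectors are sampled from a normal distribution using ```generative_latent_sample_mean``` and ```generative_latent_sample_stdev```. The shape relationship (input size 50, output length 350) matches the checks that ```update_model``` performs further down in this notebook. The toy generator is illustrative only.

```python
import tensorflow as tf

batch_size = 512      # batch_size
latent_size = 50      # generative_latent_sample_size
latent_mean = 0.0     # generative_latent_sample_mean
latent_stdev = 0.6    # generative_latent_sample_stdev
window_length = 350   # window_length

# Assumption: latent noise ~ Normal(mean, stdev), one vector per sample in the batch.
latent_batch = tf.random.normal((batch_size, latent_size),
                                mean=latent_mean, stddev=latent_stdev)

# Toy generator: a latent vector goes in, one window of `window_length` samples comes out.
toy_input = tf.keras.layers.Input(shape=(latent_size,))
hidden = tf.keras.layers.Dense(300, activation='relu')(toy_input)
toy_output = tf.keras.layers.Dense(window_length)(hidden)
toy_generator = tf.keras.Model(inputs=toy_input, outputs=toy_output)

fake_windows = toy_generator(latent_batch)
print(latent_batch.shape, fake_windows.shape)  # (512, 50) (512, 350)
```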


We create the ```generative``` object using ```gtm.models.Generative()```, which takes the ```gan``` object as its argument. The same applies to the ```discriminative``` object, created with ```gtm.models.Discriminative()```.

In [14]:
generative = gtm.models.Generative(gan)
discriminative = gtm.models.Discriminative(gan)

## Working with Models: Get Models and Update Models

By default, both objects already come with sample models, which are identical to the ones used in the tutorial.

In [15]:
discriminative.model.summary()

In [16]:
generative.model.summary()

You can replace a model by calling the ```update_model``` method. Let's update the generator model with the following model.

In [17]:
# designing a new generator model
new_model_input = tf.keras.layers.Input(shape=(25,))
cont = tf.keras.layers.Dense(300)(new_model_input)
new_model_output = tf.keras.layers.Dense(355)(cont)

# creating a new generator model
new_model = tf.keras.Model(inputs=[new_model_input], outputs=[new_model_output])

# updating generator model in the generative object
generative.update_model(new_model)

Invalid input shape! Expected 50 but obtained 25
Invalid output shape! Expected 350 but obtained 355


As you can see, we get errors. They are raised because the new model does not match the configuration we set earlier in the ```GAN``` object: its input size (25) differs from ```generative_latent_sample_size=50```, and its output length (355) differs from ```window_length=350```. Let's fix it and look at the updated model.
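Before fixing it, here is a sketch of the kind of check behind these messages. It is a hypothetical re-implementation for illustration, not the code inside ```gtm```: it compares the generator's last input and output dimensions against ```generative_latent_sample_size``` and ```window_length``` and produces the same two messages printed above.

```python
def check_generator_shapes(input_dim, output_dim, latent_size=50, window_length=350):
    """Return a list of mismatch messages (empty if the shapes are valid)."""
    errors = []
    if input_dim != latent_size:
        errors.append(f'Invalid input shape! Expected {latent_size} but obtained {input_dim}')
    if output_dim != window_length:
        errors.append(f'Invalid output shape! Expected {window_length} but obtained {output_dim}')
    return errors

def check_keras_generator(model, latent_size=50, window_length=350):
    """Check the last axis of a single-input, single-output Keras model."""
    input_dim = model.inputs[0].shape[-1]
    output_dim = model.outputs[0].shape[-1]
    return check_generator_shapes(input_dim, output_dim, latent_size, window_length)

for message in check_generator_shapes(25, 355):
    print(message)
```

With the notebook's objects, ```check_keras_generator(new_model)``` would return an empty list for the corrected model.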

In [18]:
# designing a new generator model
new_model_input = tf.keras.layers.Input(shape=(50,))
cont = tf.keras.layers.Dense(300)(new_model_input)
new_model_output = tf.keras.layers.Dense(350)(cont)

# creating a new generator model
new_model = tf.keras.Model(inputs=[new_model_input], outputs=[new_model_output])

# updating generator model in the generative object
generative.update_model(new_model)

generative.model.summary()

The model is now updated. You can also access it through the ```GAN``` object (```gan.generative_module.model```), which is updated automatically.

In [19]:
gan.generative_module.model.summary()

## Training with Sample Dataset

Let's revert our generative model to the default one.

In [20]:
generative = gtm.models.Generative(gan)

At this point, we want to train our models, especially the discriminative model, so that it recognizes P-wave signals in real seismic data. We keep using the sample models throughout this tutorial. Let's first check the performance of the model in its initial state. As expected, the confusion-matrix scores are poor: the untrained discriminator labels every validation sample as negative, so there are no true positives (and no false positives) at all.

In [21]:
gan.evaluate()

{'d_loss': 699.1911948067801,
 'true_positive': 0,
 'false_positive': 0,
 'true_negative': 14197,
 'false_negative': 14197}
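To turn the dictionary returned by ```gan.evaluate()``` into the usual summary metrics, a small helper like the one below is enough. The name ```summarize_confusion``` is ours, not part of ```gtm```; it assumes the dictionary keys shown above and reports precision as ```None``` when the model makes no positive predictions, as is the case here.

```python
def summarize_confusion(result):
    """Compute accuracy, precision, recall and F1 from a gan.evaluate()-style dict."""
    tp = result['true_positive']
    fp = result['false_positive']
    tn = result['true_negative']
    fn = result['false_negative']

    total = tp + fp + tn + fn
    accuracy = (tp + tn) / total if total else None
    precision = tp / (tp + fp) if (tp + fp) else None  # undefined without positive predictions
    recall = tp / (tp + fn) if (tp + fn) else None
    if precision is None or recall is None or (precision + recall) == 0:
        f1 = None
    else:
        f1 = 2 * precision * recall / (precision + recall)
    return {'accuracy': accuracy, 'precision': precision, 'recall': recall, 'f1': f1}

initial = {'d_loss': 699.1911948067801, 'true_positive': 0, 'false_positive': 0,
           'true_negative': 14197, 'false_negative': 14197}
print(summarize_confusion(initial))
# {'accuracy': 0.5, 'precision': None, 'recall': 0.0, 'f1': None}
```

Calling it again on the result of ```gan.evaluate()``` after training gives a like-for-like comparison.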

Let's train them! All we need to do is call ```gan.train()```. Everything else (epochs, batch size, optimizers and datasets) was already configured when we created the ```gan``` object.

In [22]:
gan.train()

Epoch 1 out of 5
LOGS: L_D_val: 33.3356 | TP: 13959 | TN: 14098 | FP: 99 | FN: 238 

Found 194 batches per epoch
Epoch 2 out of 5


|██████████████████████████████| [00:46<00:00] L_G: 1261.2589 | L_D: 72.6214


LOGS: L_D_val: 72.8392 | TP: 12921 | TN: 14183 | FP: 14 | FN: 1276 

Epoch 3 out of 5


|██████████████████████████████| [00:46<00:00] L_G: 1801.7778 | L_D: 32.4380


LOGS: L_D_val: 23.9544 | TP: 14018 | TN: 14135 | FP: 62 | FN: 179 

Epoch 4 out of 5


|██████████████████████████████| [00:46<00:00] L_G: 2383.5386 | L_D: 14.0453


LOGS: L_D_val: 16.9280 | TP: 14002 | TN: 14155 | FP: 42 | FN: 195 

Epoch 5 out of 5


|██████████████████████████████| [00:45<00:00] L_G: 5840.4482 | L_D: 40.0805


LOGS: L_D_val: 34.0723 | TP: 13719 | TN: 14160 | FP: 37 | FN: 478 



WARNING! TensorFlow's warning messages sometimes appear and disrupt the tqdm progress bar. After each epoch, the confusion-matrix values and the discriminator loss on the validation dataset are printed.
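Because these per-epoch validation results are only printed, it can help to collect them for comparison. The sketch below uses a hypothetical helper, ```parse_val_logs```, that is not part of ```gtm```. It extracts the values from ```LOGS:``` lines copied from the training output and picks the epoch with the lowest validation discriminator loss.

```python
import re

# Matches lines such as "LOGS: L_D_val: 33.3356 | TP: 13959 | TN: 14098 | FP: 99 | FN: 238"
LOG_PATTERN = re.compile(
    r'L_D_val:\s*([\d.]+)\s*\|\s*TP:\s*(\d+)\s*\|\s*TN:\s*(\d+)'
    r'\s*\|\s*FP:\s*(\d+)\s*\|\s*FN:\s*(\d+)')

def parse_val_logs(text):
    """Return one dict per LOGS line, numbering epochs in the order they appear."""
    rows = []
    for match in LOG_PATTERN.finditer(text):
        d_loss, tp, tn, fp, fn = match.groups()
        rows.append({'epoch': len(rows) + 1, 'd_loss': float(d_loss),
                     'tp': int(tp), 'tn': int(tn), 'fp': int(fp), 'fn': int(fn)})
    return rows

# LOGS lines copied from the training output above.
training_output = """
LOGS: L_D_val: 33.3356 | TP: 13959 | TN: 14098 | FP: 99 | FN: 238
LOGS: L_D_val: 72.8392 | TP: 12921 | TN: 14183 | FP: 14 | FN: 1276
LOGS: L_D_val: 23.9544 | TP: 14018 | TN: 14135 | FP: 62 | FN: 179
LOGS: L_D_val: 16.9280 | TP: 14002 | TN: 14155 | FP: 42 | FN: 195
LOGS: L_D_val: 34.0723 | TP: 13719 | TN: 14160 | FP: 37 | FN: 478
"""

rows = parse_val_logs(training_output)
best = min(rows, key=lambda row: row['d_loss'])
print(f"Lowest L_D_val: {best['d_loss']} at epoch {best['epoch']}")
# Lowest L_D_val: 16.928 at epoch 4
```

In this run, the lowest validation loss occurs at epoch 4 rather than at the final epoch.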

In [23]:
gan.evaluate()

{'d_loss': 34.07229719843183,
 'true_positive': 13719,
 'false_positive': 37,
 'true_negative': 14160,
 'false_negative': 478}

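As you can see, the model has improved: on the validation set, true positives rise from 0 to 13,719 (out of 14,197 P-wave samples) and the discriminator loss drops from about 699 to about 34. Now feel free to explore this module on your own! You can design your own models and tune their hyperparameters as needed. Don't forget to save the models after you finish training.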
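A minimal way to save the trained networks is to call Keras' own ```save``` on the models held by the ```generative``` and ```discriminative``` objects. This sketch assumes those ```.model``` attributes are ordinary Keras models (as their ```summary()``` calls suggest) and uses the native ```.keras``` format that Keras 3 expects. The helper names and the ```saved_models``` folder are our own choices, not part of ```gtm```.

```python
import os

import tensorflow as tf

def save_gan_models(generator, discriminator, folder='saved_models'):
    """Save both Keras models in the native .keras format and return their paths."""
    os.makedirs(folder, exist_ok=True)
    generator_path = os.path.join(folder, 'generator.keras')
    discriminator_path = os.path.join(folder, 'discriminator.keras')
    generator.save(generator_path)
    discriminator.save(discriminator_path)
    return generator_path, discriminator_path

def load_gan_models(folder='saved_models'):
    """Reload both models, e.g. in a new Colab session."""
    generator = tf.keras.models.load_model(os.path.join(folder, 'generator.keras'))
    discriminator = tf.keras.models.load_model(os.path.join(folder, 'discriminator.keras'))
    return generator, discriminator
```

In this notebook, a call could look like ```save_gan_models(generative.model, discriminative.model, folder='/content/drive/MyDrive/gan_models')``` (the Drive path is only an example), so that the files survive the Colab session. If a model contains custom layers, ```load_model``` needs them registered or passed via ```custom_objects```.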

**Further updates are on the way! Visit https://github.com/Seismic-DL-Research/seis-deep-learning/tree/super_dev to see them.**