In [16]:
from scipy.io import loadmat
from glob import glob
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pymatreader import read_mat
import tensorflow as tf
from tensorflow import keras
from keras import layers
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import eeg_funcs

In [18]:
# Load every training .mat file. glob's result order depends on the filesystem,
# so sort it to make the trial order (and everything downstream) reproducible.
filepath = sorted(glob('data/train/*.mat'))
assert filepath, "no .mat files found under data/train/"
# NOTE(review): the two positional True flags are interpreted by
# eeg_funcs.get_comp2_array (project module) — confirm their meaning there.
d_array, d_labels = eeg_funcs.get_comp2_array(filepath, True, True)
print(d_array.shape)

Creating RawArray with float64 data, n_channels=65, n_times=26328
    Range : 0 ... 26327 =      0.000 ...   109.696 secs
Ready.
No data channels found. The highpass and lowpass values in the measurement info will not be updated.
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 0.1 - 20 Hz

IIR filter parameters
---------------------
Butterworth bandpass zero-phase (two-pass forward and reverse) non-causal filter:
- Filter order 32 (effective, after forward-backward)
- Cutoffs at 0.10, 20.00 Hz: -6.02, -6.02 dB

540 events found
Event IDs: [1 2]
Not setting metadata
540 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 540 events and 160 original time points ...
0 bad epochs dropped
Creating RawArray with float64 data, n_channels=65, n_times=26328
    Range : 0 ... 26327 =      0.000 ...   109.696 secs
Ready.
No data channels found. The highpass and lowpass values in the measurement info will n

In [19]:
def normalize_data(data, axis=None):
    """Min-max scale ``data`` to the range [-1, 1] and return it as float32.

    Parameters
    ----------
    data : array_like
        Values to scale.
    axis : int or tuple of int, optional
        Axis/axes over which the min and max are taken. ``None`` (default) uses
        the global min/max of the whole array. Passing e.g. ``axis=(0, 1)``
        scales every index of the remaining axis independently.

    Returns
    -------
    numpy.ndarray
        float32 array with the same shape as ``data``. Slices whose values are
        all equal (zero range) map to 0 instead of NaN.

    Raises
    ------
    ValueError
        If ``data`` is empty (numpy cannot reduce a zero-size array).
    """
    data = np.asarray(data)
    lo = np.min(data, axis=axis, keepdims=True)
    hi = np.max(data, axis=axis, keepdims=True)
    span = hi - lo
    is_flat = span == 0
    # Divide by 1 where the range is zero to avoid 0/0 -> NaN; those slices are
    # then placed at the midpoint 0.5, which maps to 0 in [-1, 1].
    safe_span = np.where(is_flat, 1, span)
    unit = np.where(is_flat, 0.5, (data - lo) / safe_span)
    return np.float32(2 * unit - 1)

In [20]:
# Scale by 1e6 (presumably volts -> microvolts; TODO confirm units), then
# min-max normalize to [-1, 1] as float32. Min-max normalization is
# scale-invariant, so the factor only matters up to float rounding.
d_array = normalize_data(d_array * 1e6)
print(d_array.shape)

(12600, 160, 8)


In [21]:
# Append a trailing singleton axis so every sample is rank 3 for Conv2D
# (presumably (time, channel, 1) — TODO confirm axis meaning in eeg_funcs).
# tf.expand_dims also turns the numpy array into an EagerTensor.
d_array = tf.expand_dims(d_array, axis=3)
print(d_array.shape, type(d_array), sep="\n")

(12600, 160, 8, 1)
<class 'tensorflow.python.framework.ops.EagerTensor'>


In [22]:
# Explicit float32 tensor for tf.data. When the previous cells ran in order,
# d_array is already a float32 tensor, so this cast only guards the dtype.
data_array_2 = tf.convert_to_tensor(d_array, dtype=tf.float32)

In [23]:
# A shuffle buffer at least as large as the dataset gives a full shuffle.
BUFFER_SIZE = 30000
BATCH_SIZE = 64

train_dataset = (
    tf.data.Dataset
    .from_tensor_slices(data_array_2)
    .shuffle(BUFFER_SIZE)
    .batch(BATCH_SIZE)
)

In [24]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np
import matplotlib.pyplot as plt

In [25]:
# Binary classifier over single samples shaped like data_array_2[i]:
# two (Conv2D -> 2x2 max-pool) stages, then a dense head ending in a sigmoid,
# so the output is one probability per sample.
discriminator = keras.Sequential(
    [
        keras.Input(shape=data_array_2.shape[1:]),
        *(
            layer
            for n_filters in (8, 16)
            for layer in (
                layers.Conv2D(n_filters, kernel_size=4, strides=1, padding="same", activation="tanh"),
                layers.MaxPool2D(pool_size=2),
            )
        ),
        layers.Flatten(),
        layers.Dense(128, activation="tanh"),
        layers.Dense(1, activation="sigmoid"),
    ],
    name="discriminator",
)
discriminator.summary()

Model: "discriminator"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 conv2d_5 (Conv2D)           (None, 160, 8, 8)         136       
                                                                 
 max_pooling2d_2 (MaxPooling  (None, 80, 4, 8)         0         
 2D)                                                             
                                                                 
 conv2d_6 (Conv2D)           (None, 80, 4, 16)         2064      
                                                                 
 max_pooling2d_3 (MaxPooling  (None, 40, 2, 16)        0         
 2D)                                                             
                                                                 
 flatten_1 (Flatten)         (None, 1280)              0         
                                                                 
 dense_4 (Dense)             (None, 128)             

In [26]:
latent_dim = 128

# Noise vector -> two dense layers -> 80 values reshaped to a (40, 2, 1) map,
# upsampled x4 to (160, 8, 1), then refined by three "same"-padded
# convolutions. The final tanh keeps outputs in [-1, 1], the range that
# normalize_data produces for the real data.
generator = keras.Sequential(
    [
        keras.Input(shape=(latent_dim,)),
        layers.Dense(256, activation="tanh"),
        layers.Dense(40 * 2 * 1, activation="tanh"),
        layers.BatchNormalization(),
        layers.Reshape((40, 2, 1)),
        layers.UpSampling2D(size=4),
        *[
            layers.Conv2D(n_filters, 4, strides=1, padding="same", use_bias=False, activation="tanh")
            for n_filters in (8, 8, 1)
        ],
    ],
    name="generator",
)
generator.summary()

Model: "generator"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 dense_6 (Dense)             (None, 256)               33024     
                                                                 
 dense_7 (Dense)             (None, 80)                20560     
                                                                 
 batch_normalization_1 (Batc  (None, 80)               320       
 hNormalization)                                                 
                                                                 
 reshape_1 (Reshape)         (None, 40, 2, 1)          0         
                                                                 
 up_sampling2d_1 (UpSampling  (None, 160, 8, 1)        0         
 2D)                                                             
                                                                 
 conv2d_7 (Conv2D)           (None, 160, 8, 8)         12

In [27]:
class GAN(keras.Model):
    """Adversarial training wrapper around a generator and a discriminator.

    ``fit`` drives :meth:`train_step`, which performs one discriminator update
    and then one generator update per batch. Label convention: generated
    samples are labelled 1 and real samples 0, so the generator is trained
    towards label 0 ("real").

    Args:
        discriminator: model scoring samples; its output must be compatible
            with (batch, 1) labels and with ``loss_fn``.
        generator: model mapping (batch, latent_dim) normal noise to samples
            shaped like the real data.
        latent_dim: length of the noise vector fed to ``generator``.
    """

    def __init__(self, discriminator, generator, latent_dim):
        super().__init__()
        self.discriminator = discriminator
        self.generator = generator
        self.latent_dim = latent_dim

    def compile(self, d_optimizer, g_optimizer, loss_fn):
        """Attach one optimizer per sub-model and the shared loss function.

        Args:
            d_optimizer: optimizer for the discriminator's weights.
            g_optimizer: optimizer for the generator's weights.
            loss_fn: callable ``loss_fn(labels, predictions)`` used by both steps.
        """
        super().compile()
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer
        self.loss_fn = loss_fn
        self.d_loss_metric = keras.metrics.Mean(name="d_loss")
        self.g_loss_metric = keras.metrics.Mean(name="g_loss")

    @property
    def metrics(self):
        # Listing the trackers here lets Keras reset them at the start of each epoch.
        return [self.d_loss_metric, self.g_loss_metric]

    def train_step(self, real_images):
        """Run one discriminator step and one generator step on a batch.

        Args:
            real_images: a batch of real samples from the training dataset.

        Returns:
            dict with the running means of ``d_loss`` and ``g_loss``.
        """
        batch_size = tf.shape(real_images)[0]

        # ---- Discriminator update ------------------------------------------
        random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))
        # training=True is required: nothing else sets the training flag in this
        # custom step, so the generator's BatchNormalization would otherwise run
        # in inference mode and never update its moving statistics.
        generated_images = self.generator(random_latent_vectors, training=True)
        combined_images = tf.concat([generated_images, real_images], axis=0)

        # Fake -> 1, real -> 0, in the same order as combined_images.
        labels = tf.concat(
            [tf.ones((batch_size, 1)), tf.zeros((batch_size, 1))], axis=0
        )
        # Small positive label noise softens the discriminator's targets.
        labels += 0.05 * tf.random.uniform(tf.shape(labels))

        with tf.GradientTape() as tape:
            predictions = self.discriminator(combined_images, training=True)
            d_loss = self.loss_fn(labels, predictions)
        grads = tape.gradient(d_loss, self.discriminator.trainable_weights)
        self.d_optimizer.apply_gradients(zip(grads, self.discriminator.trainable_weights))

        # ---- Generator update ----------------------------------------------
        random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))
        # Label 0 means "real": the generator is rewarded for fooling the discriminator.
        misleading_labels = tf.zeros((batch_size, 1))

        # Gradients are applied to the generator's weights only; the
        # discriminator is evaluated but not updated here.
        with tf.GradientTape() as tape:
            fake_images = self.generator(random_latent_vectors, training=True)
            predictions = self.discriminator(fake_images, training=True)
            g_loss = self.loss_fn(misleading_labels, predictions)
        grads = tape.gradient(g_loss, self.generator.trainable_weights)
        self.g_optimizer.apply_gradients(zip(grads, self.generator.trainable_weights))

        self.d_loss_metric.update_state(d_loss)
        self.g_loss_metric.update_state(g_loss)
        return {
            "d_loss": self.d_loss_metric.result(),
            "g_loss": self.g_loss_metric.result(),
        }

In [29]:
class GANMonitor(keras.callbacks.Callback):
    """Save line plots of freshly generated samples at the end of every epoch.

    For each of ``num_img`` samples drawn from ``self.model.generator`` two PNGs
    are written:

    - ``<out_dir>/single/generated_img_EEE_I.png``: ``sample[:, channel, :]``
    - ``<out_dir>/full/generated_img_EEE_I.png``: ``sample[:, :, 0]``, one line
      per index of axis 1

    EEE is the zero-padded epoch index and I is the sample index. Generated
    samples are assumed to be rank 3 per sample (e.g. (160, 8, 1)).

    Args:
        num_img: number of samples to plot per epoch.
        latent_dim: noise-vector length; must match the generator's input.
        out_dir: root output directory; subdirectories are created on demand.
        channel: index along axis 1 used for the "single" plot (presumably the
            EEG channel — TODO confirm).
    """

    def __init__(self, num_img=1, latent_dim=100, out_dir="Papertestimages/2D", channel=5):
        # Callback.__init__ initialises attributes Keras relies on (e.g. ``model``).
        super().__init__()
        self.num_img = num_img
        self.latent_dim = latent_dim
        self.out_dir = out_dir
        self.channel = channel

    def on_epoch_end(self, epoch, logs=None):
        import os  # function-scope import keeps this notebook cell self-contained

        single_dir = os.path.join(self.out_dir, "single")
        full_dir = os.path.join(self.out_dir, "full")
        # savefig does not create missing directories.
        os.makedirs(single_dir, exist_ok=True)
        os.makedirs(full_dir, exist_ok=True)

        random_latent_vectors = tf.random.normal(shape=(self.num_img, self.latent_dim))
        # training=False: sample using BatchNormalization's moving statistics.
        generated_images = self.model.generator(random_latent_vectors, training=False).numpy()

        for i in range(self.num_img):
            img = generated_images[i]
            filename = "generated_img_%03d_%d.png" % (epoch, i)
            self._save_line_plot(img[:, self.channel, :], os.path.join(single_dir, filename))
            self._save_line_plot(img[:, :, 0], os.path.join(full_dir, filename))

    @staticmethod
    def _save_line_plot(values, path):
        """Plot ``values`` on a fresh figure, save it to ``path`` and close it."""
        fig, ax = plt.subplots()
        ax.plot(values)
        fig.savefig(path)
        # Closing stops figures from accumulating (and being echoed inline) over epochs.
        plt.close(fig)

In [31]:
# Save weights only, and only for epochs where the generator loss reaches a new minimum.
# NOTE(review): g_loss is an adversarial loss, so a lower value does not
# necessarily mean better samples — consider also saving periodically.
checkpoint_filepath = 'checkpointPaper/2D/checkpoint.{epoch:02d}'
model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    checkpoint_filepath,
    monitor='g_loss',
    mode='min',
    save_best_only=True,
    save_weights_only=True,
)

In [32]:
# Maximum number of epochs. Training may be stopped early; model_checkpoint_callback
# keeps the weights with the lowest g_loss seen so far.
# Re-running this cell continues training the same `generator`/`discriminator`
# objects (their weights are not reset) with freshly created optimizers.
epochs = 2500
ADAM_LR = 1e-4
ADAM_BETA_1 = 0.2

gan = GAN(discriminator=discriminator, generator=generator, latent_dim=latent_dim)
gan.compile(
    d_optimizer=keras.optimizers.Adam(learning_rate=ADAM_LR, beta_1=ADAM_BETA_1),
    g_optimizer=keras.optimizers.Adam(learning_rate=ADAM_LR, beta_1=ADAM_BETA_1),
    # The discriminator ends in a sigmoid, so it outputs probabilities, not logits.
    loss_fn=keras.losses.BinaryCrossentropy(from_logits=False),
)

# Keep the History object (per-epoch d_loss / g_loss) for later plotting.
history = gan.fit(
    train_dataset,
    epochs=epochs,
    callbacks=[GANMonitor(num_img=1, latent_dim=latent_dim), model_checkpoint_callback],
)

Epoch 1/2500
Epoch 2/2500
Epoch 3/2500
Epoch 4/2500
Epoch 5/2500
Epoch 6/2500
Epoch 7/2500
Epoch 8/2500
Epoch 9/2500
Epoch 10/2500
Epoch 11/2500
Epoch 12/2500
Epoch 13/2500
Epoch 14/2500
Epoch 15/2500
Epoch 16/2500
Epoch 17/2500
Epoch 18/2500
Epoch 19/2500
Epoch 20/2500
Epoch 21/2500
Epoch 22/2500
Epoch 23/2500
Epoch 24/2500
Epoch 25/2500
Epoch 26/2500
Epoch 27/2500
Epoch 28/2500
Epoch 29/2500
Epoch 30/2500
Epoch 31/2500
Epoch 32/2500
Epoch 33/2500
Epoch 34/2500
Epoch 35/2500
Epoch 36/2500
Epoch 37/2500
Epoch 38/2500
Epoch 39/2500
Epoch 40/2500
Epoch 41/2500
Epoch 42/2500
Epoch 43/2500
Epoch 44/2500
Epoch 45/2500
Epoch 46/2500
Epoch 47/2500
Epoch 48/2500
Epoch 49/2500
Epoch 50/2500
Epoch 51/2500
Epoch 52/2500
Epoch 53/2500
Epoch 54/2500
Epoch 55/2500
Epoch 56/2500
Epoch 57/2500
Epoch 58/2500
Epoch 59/2500
Epoch 60/2500
Epoch 61/2500
Epoch 62/2500
Epoch 63/2500
Epoch 64/2500
Epoch 65/2500
Epoch 66/2500
Epoch 67/2500
Epoch 68/2500
Epoch 69/2500
Epoch 70/2500
Epoch 71/2500
Epoch 72/2500
E

KeyboardInterrupt: 

<Figure size 432x288 with 0 Axes>