# Initialization


In [None]:
#@title Google Drive
from google.colab import drive
from google.colab import files

# Mount Google Drive under /content/drive; the dataset path and the model
# save path configured in later cells both live below this mount point.
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
#@title Import Library
import numpy as np
import scipy.io as sio
import matplotlib.pyplot as plt
import glob
import os
from scipy.signal import fftconvolve,firwin, freqz,lfilter,lfilter_zi,firls
import tensorflow as tf
from tqdm.notebook import tnrange, tqdm
import absl.logging
absl.logging.set_verbosity(absl.logging.ERROR)

In [None]:
#@title electrodes
# Channel labels used by channel_selector() to turn a channel name into a row index.
# Fixes: ' FC1' carried a leading space (so looking up 'FC1' failed) and 'RO7'
# was a typo for 'PO7'.
electrodes = [
    'FP1', 'AF7', 'AF3', 'F1', 'F3', 'F5', 'F7', 'FT7', 'FC5', 'FC3', 'FC1',
    'C1', 'C3', 'C5', 'T7', 'TP7', 'CP5', 'CP3', 'CP1', 'P1', 'P3', 'P5', 'P7',
    'P9', 'PO7', 'PO3', 'O1', 'Iz', 'Oz', 'POz', 'Pz', 'CPZ', 'FPZ', 'FP2',
    'AF8', 'AF4', 'AFZ', 'FZ', 'F2', 'F4', 'F6', 'F8', 'FT8', 'FC6', 'FC4',
    'FC2', 'FCz', 'Cz', 'C2', 'C4', 'C6', 'T8', 'TP8', 'CP6', 'CP4', 'CP2',
    'P2', 'P4', 'P6', 'P8', 'P10', 'PO8', 'PO4', 'O2',
    'EMG1', 'EMG2', 'EMG3', 'EMG4',
]
# NOTE(review): the list is reversed in place before it is used for index
# lookups - confirm this matches the row order of the recorded arrays.
electrodes.reverse()
# Directory holding the raw .mat recordings (globbed in the dataset-creation cell).
path = "/content/drive/MyDrive/EEG-GAN/Dataset/cho_et_al_data/raw"

In [None]:
#@title Parameters
# --- Data parameters ---------------------------------------------------------
# Sampling frequency passed to the FIR design and used to convert seconds to samples.
Fs = 512
# Number of .mat files taken from the globbed file list.
patients = 10                       #@param {type:"integer"}
# Length, in seconds, of the window cut after each event.
segment_length = 3                  #@param {type:"integer"}
# When True, preprocess() band-passes every channel with filter_alpha().
alpha = True                       #@param {type:"boolean"}
# Selects the trial layout in preprocess(): True keeps the selected channels
# together per trial, False reshapes to single-channel trials.
concat = False                      #@param {type:"boolean"}
# When True, train_gan() adds the gradient penalty to the discriminator loss.
is_gp = False                       #@param {type:"boolean"}
# Channel names looked up in `electrodes` by channel_selector().
channels = ['FC3']                  

# --- Training parameters -----------------------------------------------------
batch_size = 256                    #@param {type:"integer"}
epochs = 200                        #@param {type:"integer"}
# Size of the generator's noise input.
latent_dim = 200                    #@param {type:"integer"}
# Number of discriminator updates per batch in train_gan().
d_extra_steps = 1                   #@param {type:"integer"}
# Multiplier applied to the gradient penalty (only used when is_gp is True).
gp_weight = 10.0                    #@param {type:"number"}
# "wgan" selects the mean-difference losses; anything else selects
# BinaryCrossentropy(from_logits=True).
loss = "bce"                        #@param ["bce", "wgan"]


# --- Checkpointing -----------------------------------------------------------
# When True the models are loaded from save_path instead of being built.
load_pretrained = False             #@param {type:"boolean"}
save_path = '/content/drive/MyDrive/EEG-GAN/Model_Saves/EEG_GAN_2/'   #@param {type:"string"}
# Note: save_model() and model_load() rebuild these two paths from save_path themselves.
discriminator_save_path = os.path.join(save_path, 'discriminator.h5')
generator_save_path = os.path.join(save_path, 'generator.h5')

In [None]:
#@title Preprocessing Functions
def bandpass_firls(xn, lowcut2,lowcut1, highcut1, highcut2, fs):
    """Band-pass filter a 1-D signal with a least-squares FIR design.

    The desired gain is 1 on [lowcut1, highcut1] and 0 on [0, lowcut2] and
    [highcut2, fs/2]; the gaps in between are transition bands.

    Args:
        xn: 1-D signal to filter.
        lowcut2, lowcut1: lower stop-band edge and lower pass-band edge.
        highcut1, highcut2: upper pass-band edge and upper stop-band edge.
        fs: sampling frequency; also determines the filter length.

    Returns:
        The filtered signal, same length as ``xn``.
    """
    # firls requires an odd tap count: use fs when it is odd, otherwise fs + 1.
    num_taps = fs + 1 if fs % 2 != 1 else fs
    band_edges = [0, lowcut2, lowcut1, highcut1, highcut2, 0.5 * fs]
    desired_gain = [0, 0, 1, 1, 0, 0]
    taps = firls(num_taps, band_edges, desired_gain, fs=fs)
    denominator = 1
    # Scale the steady-state filter state by the first sample to reduce the
    # start-up transient.
    initial_state = lfilter_zi(taps, denominator) * xn[0]
    filtered, _final_state = lfilter(taps, denominator, xn, zi=initial_state)
    return filtered

def band_separate(data, fs, ntaps):
    """Split a 1-D signal into delta, theta, alpha, sigma and beta band signals.

    Args:
        data: 1-D signal to filter.
        fs: sampling frequency.
        ntaps: kept for interface compatibility only; bandpass_firls() derives
            its filter length from ``fs``. Forwarding it positionally (as the
            previous code did) shifted every band edge and raised a TypeError.

    Returns:
        Tuple ``(delta, theta, alpha, sigma, beta)`` of filtered signals.
    """
    # (lower stop edge, lower pass edge, upper pass edge, upper stop edge) per band.
    band_edges = (
        (0.1, 0.5, 4, 5),       # delta
        (3.5, 4, 8, 8.5),       # theta
        (7, 8, 12, 13),         # alpha
        (11.5, 12, 16, 16.5),   # sigma
        (15.5, 16, 30, 30.5),   # beta
    )
    delta, theta, alpha, sigma, beta = (
        bandpass_firls(data, low_stop, low_pass, high_pass, high_stop, fs)
        for low_stop, low_pass, high_pass, high_stop in band_edges
    )
    return delta, theta, alpha, sigma, beta

def filter_alpha(data, Fs):
  """Band-pass every row of ``data`` to the 8-12 band (stop edges at 7 and 13).

  Args:
    data: 2-D array; each row is filtered independently along its length.
    Fs: sampling frequency forwarded to bandpass_firls().

  Returns:
    Array of the filtered rows stacked along axis 0.
  """
  filtered_rows = [
      bandpass_firls(data[row], 7, 8, 12, 13, Fs)
      for row in range(data.shape[0])
  ]
  return np.stack(filtered_rows, axis=0)

In [None]:
#@title Dataset Functions
def struct2dict(mat):
  """Flatten the struct stored under the last key of ``mat`` into a plain dict.

  The value under the last key is indexed at [0, 0] and each of its fields is
  copied into the result under the field name taken from the dtype description.

  Args:
    mat: mapping whose last value is a structured array (e.g. from loadmat).

  Returns:
    dict mapping field name -> field value.
  """
  last_key = list(mat.keys())[-1]
  record = mat[last_key][0, 0]
  field_names = [entry[0] for entry in record.dtype.descr]
  return {field_names[pos]: record[pos] for pos in range(len(record))}

def channel_normalize(data, axis):
  """Centre ``data`` on its mean and scale by its peak-to-peak range along ``axis``.

  Args:
    data: array to normalise.
    axis: axis along which mean, max and min are taken.

  Returns:
    ``(data - mean) / (max - min)`` with the statistics broadcast along ``axis``.
  """
  center = data.mean(axis=axis, keepdims=True)
  span = data.max(axis=axis, keepdims=True) - data.min(axis=axis, keepdims=True)
  return (data - center) / span

def import_data(file):
  """Load a .mat file and return its last struct as a dict (see struct2dict)."""
  return struct2dict(sio.loadmat(file))

def preprocess(data, events,  Fs, seconds,  alpha, concat):
  """Filter, normalise and cut a recording into fixed-length windows at event onsets.

  Args:
    data: 2-D array, rows are channels and columns are samples.
    events: 2-D array; columns where the value equals 1 mark window onsets
      (the column indices of ``np.where(events == 1)`` are used).
    Fs: sampling frequency.
    seconds: window length in seconds.
    alpha: when truthy, band-pass each channel with filter_alpha() first.
    concat: when truthy return ``(n_events, length, n_channels)``; otherwise
      return ``(n_events * n_channels, length, 1)``.

  Returns:
    Array of windows in the layout selected by ``concat``.

  Raises:
    ValueError: if no event leaves room for a complete window.
  """
  if alpha:
    data = filter_alpha(data, Fs)
  data = channel_normalize(data, axis = 1)
  length = seconds*Fs
  n_samples = data.shape[1]
  onsets = np.where(events == 1)[1]
  # Skip events whose window would run past the end of the recording; a
  # truncated slice would make the stacked array ragged.
  windows = [data[:, onset:onset + length] for onset in onsets if onset + length <= n_samples]
  if not windows:
    raise ValueError('no event leaves room for a complete window of {} samples'.format(length))
  stacked = np.stack(windows)  # (n_events, n_channels, length)
  if concat:
    return np.moveaxis(stacked, 1, 2)
  return stacked.reshape(-1, length, 1)

def channel_selector(data, channels, concat, electrodes):
  """Return the rows of ``data`` that correspond to the requested channel names.

  Args:
    data: array whose first axis is indexed by position in ``electrodes``.
    channels: list of channel names; any non-list value returns ``data`` as is.
    concat: unused.
    electrodes: ordered channel names used to resolve row indices.

  Returns:
    ``data`` restricted to the resolved rows; names that are not found are
    reported on stdout and skipped.
  """
  if not isinstance(channels, list):
    return data

  rows = []
  for name in channels:
    if name in electrodes:
      rows.append(electrodes.index(name))
    else:
      print('channel doesn\'t exist')
  return data[rows]
def level_change_corr(data, Fs, block_seconds=7):
  """Remove the per-channel mean from each consecutive block of the recording.

  Fixes two defects of the previous version: the mean was taken over the slice
  ``i*7:Fs:(i+1)*Fs*7`` (a typo for the block ``i*7*Fs:(i+1)*Fs*7``), and the
  number of blocks was computed with a hard-coded 512 instead of ``Fs``.

  Args:
    data: 2-D array, rows are channels and columns are samples; modified in place.
    Fs: sampling frequency.
    block_seconds: block length in seconds (default 7, the former constant).

  Returns:
    The same ``data`` array with every complete block de-meaned; a trailing
    partial block is left untouched.
  """
  block = int(Fs*block_seconds)
  for i in range(data.shape[1] // block):
    start, stop = i*block, (i + 1)*block
    mean = data[:, start:stop].mean(axis = 1, keepdims = True)
    data[:, start:stop] = data[:, start:stop] - mean
  return data

def create_dataset(files, Fs, seconds, alpha, channels, concat, electrodes):
  """Build the trial array and label array from a list of .mat recordings.

  For every file the 'movement_left' and 'movement_right' recordings are
  restricted to the requested channels and cut into windows at the onsets in
  'movement_event'. Left trials are labelled 0 and right trials 1.

  Args:
    files: paths of the .mat files to load.
    Fs, seconds, alpha, concat: forwarded to preprocess().
    channels, electrodes: forwarded to channel_selector().

  Returns:
    ``(data, labels)``: all trials concatenated along axis 0 and a matching
    column vector of labels.
  """
  trial_chunks = []
  label_chunks = []
  for file_idx in tnrange(len(files)):
    mat = import_data(files[file_idx])
    left, right, events = mat['movement_left'], mat['movement_right'], mat['movement_event']

    left = channel_selector(left, channels, concat, electrodes)
    left = preprocess(left, events, Fs, seconds, alpha, concat)

    right = channel_selector(right, channels, concat, electrodes)
    right = preprocess(right, events, Fs, seconds, alpha, concat)

    trial_chunks += [left, right]
    label_chunks += [np.zeros((left.shape[0], 1)), np.ones((right.shape[0], 1))]

  return np.concatenate(trial_chunks), np.concatenate(label_chunks)

In [None]:
#@title Visualization Functions
def normalize(data):
  """Centre ``data`` on its global mean and scale by its global peak-to-peak range."""
  span = data.max() - data.min()
  return (data - data.mean()) / span


def show_eeg_signal(Xf, ch_names = None, fs = 512, seconds = 3, sensitivity = 4, spacing= 4, mode = 'norm'):
  """Plot a stack of traces, one per row of ``Xf``, vertically offset from each other.

  Changes from the previous version:
    * 'norm' mode now normalises each trace over time (axis 0 after the
      transpose); it used to normalise across traces at each time step.
    * ``ch_names`` defaults to None and is generated per call, instead of being
      built once at definition time from the global ``batch_size``; the labels
      are trimmed or padded so they always match the number of traces.
    * The x axis is in seconds, matching its label, and the title reports the
      real number of traces.

  Args:
    Xf: 2-D array, rows are traces and columns are samples.
    ch_names: optional tick labels; defaults to 'preds 0', 'preds 1', ...
    fs: sampling frequency used for cropping and for the time axis.
    seconds: number of seconds to display.
    sensitivity: gain applied after normalisation.
    spacing: vertical distance between consecutive traces.
    mode: 'max' normalises with the global statistics (normalize()); anything
      else normalises each trace separately (channel_normalize()).
  """
  Xf = np.asarray(Xf)[:,:seconds*fs].T  # -> (n_samples, n_traces)

  if mode == 'max':
    Xf = normalize(Xf)*sensitivity
  else:
    Xf = channel_normalize(Xf, axis = 0)*sensitivity
  n_samples, ch_len = Xf.shape

  if ch_names is None:
    ch_names = ['preds {}'.format(i) for i in range(ch_len)]
  else:
    ch_names = list(ch_names)[:ch_len]
    ch_names += ['preds {}'.format(i) for i in range(len(ch_names), ch_len)]

  t = np.arange(n_samples)/fs
  offsets = np.arange(-int(np.ceil(ch_len/2)),int(np.floor(ch_len/2)))*spacing
  plt.figure(figsize=(5*seconds,ch_len))
  plt.plot(t,Xf+offsets)
  plt.xlim([t[0],t[-1]])
  plt.xlabel('time (sec)')
  plt.yticks(offsets,ch_names)
  plt.grid()
  plt.title('Xf: {} channel - EEG Signal'.format(ch_len))
  plt.show()

In [None]:
#@title Dataset Creation
# glob returns files in arbitrary order; sort so that `files[:patients]`
# always selects the same recordings.
files = sorted(glob.glob(path + '/*.mat'))
data, labels = create_dataset(files[:patients], Fs = Fs, seconds = segment_length, alpha = alpha, channels = channels, concat = concat, electrodes = electrodes)
data_mean = np.mean(data)
data_std = np.std(data)
# Shape of a single trial; used to build the discriminator and generator.
trial_shape = data[0].shape
dataset = tf.data.Dataset.from_tensor_slices(data)
dataset = dataset.shuffle(1024).batch(batch_size)
# Free the NumPy copy; note that later cells can no longer reference `data`.
del data

  0%|          | 0/10 [00:00<?, ?it/s]

# Model

In [None]:
#@title Import
import tensorflow as tf
from tensorflow.keras.metrics import Mean
from tensorflow.keras.losses import mse
from tensorflow.keras.models import load_model, save_model

import tensorflow.keras.backend as K
from tensorflow.keras.callbacks import Callback

from tensorflow.keras.layers import Dense, Dropout, Flatten, Reshape, LeakyReLU, Conv1D, MaxPooling1D, Activation, Input, UpSampling1D, BatchNormalization, Lambda, Concatenate, Layer, Embedding, Conv1DTranspose, AveragePooling1D, LayerNormalization
from tensorflow.keras.models import Model, load_model
import numpy as np
from sklearn.preprocessing import LabelEncoder
from keras.utils import to_categorical
from tensorflow.keras.optimizers.legacy import Adam
import os
from matplotlib import pyplot as plt

In [None]:
#@title Gradient Penalty
def gradient_penalty(discriminator, real_images, fake_images):
    """Calculates the gradient penalty.

    The penalty is computed on samples interpolated between real and fake
    batches and is meant to be added to the discriminator loss.

    Changes from the previous version:
      * the interpolation coefficient is drawn uniformly from [0, 1) so the
        samples lie between real and fake (a normal draw also extrapolated
        outside that segment);
      * ``real_images`` is cast to the dtype of ``fake_images`` so the
        subtraction does not fail on mixed float64/float32 inputs;
      * the batch size is read with ``tf.shape`` so it is also defined when the
        static shape is unknown.

    Args:
        discriminator: callable model mapping a batch to one score per sample.
        real_images: batch of real samples, shape (batch, time, channels).
        fake_images: batch of generated samples with the same shape.

    Returns:
        Scalar tensor: mean of (||grad|| - 1)^2 over the batch.
    """
    real_images = tf.cast(real_images, fake_images.dtype)
    batch_size = tf.shape(real_images)[0]
    # One coefficient per sample, broadcast over time and channels.
    alpha = tf.random.uniform([batch_size, 1, 1], 0.0, 1.0, dtype=fake_images.dtype)
    diff = fake_images - real_images
    interpolated = real_images + alpha * diff

    with tf.GradientTape() as gp_tape:
        gp_tape.watch(interpolated)
        # 1. Get the discriminator output for the interpolated samples.
        pred = discriminator(interpolated, training=True)

    # 2. Calculate the gradients w.r.t. the interpolated samples.
    grads = gp_tape.gradient(pred, [interpolated])[0]
    # 3. Calculate the per-sample norm of the gradients.
    norm = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=[1, 2]))
    gp = tf.reduce_mean((norm - 1.0) ** 2)
    return gp

In [None]:
#@title Discriminator-Generator
# discriminator block

def conv_block(
    x,
    filters,
    activation = LeakyReLU(0.2),
    kernel_size=9,
    strides= 1,
    padding="same",
    use_bias=True,
    use_bn=False,
    use_dropout=False,
    drop_value=0.5
):
    """Discriminator block: two convolutions, 2x average pooling, then activation.

    NOTE: a function with the same name is defined again in the later
    "Discriminator-Generator 7sec" cell and replaces this one once that cell runs.

    Args:
        x: input tensor.
        filters, kernel_size, strides, padding, use_bias: settings shared by
            both Conv1D layers.
        activation: callable applied after pooling (the default instance is
            created once, when the function is defined).
        use_bn: when True a LayerNormalization is applied before the activation.
        use_dropout, drop_value: optional Dropout after the activation.

    Returns:
        Output tensor of the block.
    """
    # Two stacked convolutions with identical settings.
    for _ in range(2):
        x = Conv1D(filters, kernel_size, strides=strides, padding=padding, use_bias=use_bias)(x)
    x = AveragePooling1D(pool_size=2, strides= 2, padding='same')(x)

    if use_bn:
        x = LayerNormalization(axis = -1)(x)
    x = activation(x)
    if not use_dropout:
        return x
    return Dropout(drop_value)(x)


def get_discriminator_model(trial_shape):
    """Build the discriminator: 1x1 conv, seven conv_block stages, then a single logit.

    NOTE: redefined with the same name in the later
    "Discriminator-Generator 7sec" cell.

    Args:
        trial_shape: shape of one trial, used as the model input shape.

    Returns:
        Keras Model named "discriminator" with one un-activated output per sample.
    """
    n_filters = 50
    n_blocks = 7

    inputs = Input(shape=trial_shape)
    features = Conv1D(filters = n_filters, kernel_size = 1, strides= 1, padding='same', use_bias=True)(inputs)
    for _ in range(n_blocks):
      features = conv_block(features, n_filters)
    features = Flatten()(features)
    features = Dropout(0.2)(features)
    logits = Dense(1)(features)

    return Model(inputs, logits, name="discriminator")


# generator block
def upsample_block(
    x,
    filters,
    activation=LeakyReLU(0.2),
    kernel_size=9,
    strides=1,
    padding="same",
    use_bn=False,
    use_bias=True,

):
    """Generator block: 2x upsampling followed by two conv -> activation -> LayerNorm stages.

    NOTE: redefined with the same name in the later
    "Discriminator-Generator 7sec" cell.

    Args:
        x: input tensor.
        filters, kernel_size, strides, padding, use_bias: settings shared by
            both Conv1D layers.
        activation: callable applied after each convolution.
        use_bn: accepted but not used; LayerNormalization is always applied.

    Returns:
        Output tensor of the block.
    """
    x = UpSampling1D(size=2)(x)
    for _ in range(2):
        x = Conv1D(filters, kernel_size, strides=strides, padding=padding, use_bias=use_bias)(x)
        x = activation(x)
        x = LayerNormalization(axis = -1)(x)
    return x


def get_generator_model(noise_dim, trial_shape):
    """Build the generator: dense projection, seven upsample_block stages, 1x1 conv.

    The starting length and the number of output channels are now derived from
    ``trial_shape`` instead of being hard-coded (12 and 1), so the output
    matches the trials for other window lengths and channel counts. For
    ``trial_shape == (1536, 1)`` the model is the same as before.

    NOTE: redefined with the same name in the later
    "Discriminator-Generator 7sec" cell.

    Args:
        noise_dim: size of the latent input vector.
        trial_shape: shape of one trial, ``(length,)`` or ``(length, channels)``.

    Returns:
        Keras Model named "generator" mapping (noise_dim,) to trial_shape.

    Raises:
        ValueError: if the trial length is not a positive multiple of 2**7.
    """
    n_blocks = 7
    n_filters = 50
    upsampling = 2 ** n_blocks  # every block doubles the length

    target_len = int(trial_shape[0])
    out_channels = int(trial_shape[1]) if len(trial_shape) > 1 else 1
    if target_len <= 0 or target_len % upsampling != 0:
        raise ValueError(
            'trial length {} must be a positive multiple of {}'.format(target_len, upsampling))
    starting_shape = (target_len // upsampling, n_filters)

    inputs = Input(shape=(noise_dim,))
    x = Dense(int(np.prod(starting_shape)), use_bias=False)(inputs)
    x = Reshape(starting_shape)(x)

    for _ in range(n_blocks):
      x = upsample_block(x, n_filters)

    x = Conv1D(filters = out_channels, kernel_size = 1, strides= 1, padding='same', use_bias=True)(x)
    g_model = Model(inputs, x, name="generator")
    return g_model

In [None]:
#@title Discriminator-Generator 7sec
# discriminator block

def conv_block(
    x,
    filters,
    activation = LeakyReLU(0.2),
    kernel_size=9,
    strides= 1,
    padding="same",
    use_bias=True,
    use_bn=False,
    use_dropout=False,
    drop_value=0.5
):
    """Discriminator block: two convolutions, 2x average pooling, then activation.

    NOTE: this redefinition replaces the conv_block from the earlier
    "Discriminator-Generator" cell.

    Args:
        x: input tensor.
        filters, kernel_size, strides, padding, use_bias: settings shared by
            both Conv1D layers.
        activation: callable applied after pooling (the default instance is
            created once, when the function is defined).
        use_bn: when True a LayerNormalization is applied before the activation.
        use_dropout, drop_value: optional Dropout after the activation.

    Returns:
        Output tensor of the block.
    """
    # Two stacked convolutions with identical settings.
    for _ in range(2):
        x = Conv1D(filters, kernel_size, strides=strides, padding=padding, use_bias=use_bias)(x)
    x = AveragePooling1D(pool_size=2, strides= 2, padding='same')(x)

    if use_bn:
        x = LayerNormalization(axis = -1)(x)
    x = activation(x)
    if not use_dropout:
        return x
    return Dropout(drop_value)(x)


def get_discriminator_model(trial_shape):
    """Build the discriminator: 1x1 conv, seven conv_block stages, then a single logit.

    NOTE: this redefinition replaces the function of the same name from the
    earlier "Discriminator-Generator" cell.

    Args:
        trial_shape: shape of one trial, used as the model input shape.

    Returns:
        Keras Model named "discriminator" with one un-activated output per sample.
    """
    n_filters = 50
    n_blocks = 7

    inputs = Input(shape=trial_shape)
    features = Conv1D(filters = n_filters, kernel_size = 1, strides= 1, padding='same', use_bias=True)(inputs)
    for _ in range(n_blocks):
      features = conv_block(features, n_filters)
    features = Flatten()(features)
    features = Dropout(0.2)(features)
    logits = Dense(1)(features)

    return Model(inputs, logits, name="discriminator")


# generator block
def upsample_block(
    x,
    filters,
    activation=LeakyReLU(0.2),
    kernel_size=9,
    strides=1,
    padding="same",
    use_bn=False,
    use_bias=True,

):
    """Generator block: 2x upsampling followed by two conv -> activation -> LayerNorm stages.

    NOTE: this redefinition replaces the function of the same name from the
    earlier "Discriminator-Generator" cell.

    Args:
        x: input tensor.
        filters, kernel_size, strides, padding, use_bias: settings shared by
            both Conv1D layers.
        activation: callable applied after each convolution.
        use_bn: accepted but not used; LayerNormalization is always applied.

    Returns:
        Output tensor of the block.
    """
    x = UpSampling1D(size=2)(x)
    for _ in range(2):
        x = Conv1D(filters, kernel_size, strides=strides, padding=padding, use_bias=use_bias)(x)
        x = activation(x)
        x = LayerNormalization(axis = -1)(x)
    return x


def get_generator_model(noise_dim, trial_shape):
    """Build the generator: dense projection, seven upsample_block stages, 1x1 conv.

    The starting length and the number of output channels are now derived from
    ``trial_shape`` rather than from the global ``segment_length`` (which was
    only correct for a 512 Hz sampling rate) and a hard-coded single channel.
    For 512 Hz single-channel trials the model is the same as before.

    NOTE: this redefinition replaces the function of the same name from the
    earlier "Discriminator-Generator" cell.

    Args:
        noise_dim: size of the latent input vector.
        trial_shape: shape of one trial, ``(length,)`` or ``(length, channels)``.

    Returns:
        Keras Model named "generator" mapping (noise_dim,) to trial_shape.

    Raises:
        ValueError: if the trial length is not a positive multiple of 2**7.
    """
    n_blocks = 7
    n_filters = 50
    upsampling = 2 ** n_blocks  # every block doubles the length

    target_len = int(trial_shape[0])
    out_channels = int(trial_shape[1]) if len(trial_shape) > 1 else 1
    if target_len <= 0 or target_len % upsampling != 0:
        raise ValueError(
            'trial length {} must be a positive multiple of {}'.format(target_len, upsampling))
    starting_shape = (target_len // upsampling, n_filters)

    inputs = Input(shape=(noise_dim,))
    x = Dense(int(np.prod(starting_shape)), use_bias=False)(inputs)
    x = Reshape(starting_shape)(x)

    for _ in range(n_blocks):
      x = upsample_block(x, n_filters)

    x = Conv1D(filters = out_channels, kernel_size = 1, strides= 1, padding='same', use_bias=True)(x)
    g_model = Model(inputs, x, name="generator")
    return g_model

In [None]:
#@title Training Functions
def discriminator_loss(real_img, fake_img):
    """Return mean(fake scores) - mean(real scores).

    Args:
        real_img: discriminator outputs for real samples.
        fake_img: discriminator outputs for generated samples.
    """
    mean_real = tf.reduce_mean(real_img)
    mean_fake = tf.reduce_mean(fake_img)
    return mean_fake - mean_real

def generator_loss(fake_img):
    """Return the negated mean of the discriminator outputs for generated samples."""
    mean_fake = tf.reduce_mean(fake_img)
    return -mean_fake
    
def save_model(generator, discriminator, save_path):
  """Save both models in HDF5 format as generator.h5 / discriminator.h5 under save_path.

  The target directory is created if needed. ``exist_ok=True`` replaces the
  previous check-then-create sequence, which could fail if the directory
  appeared between the check and the creation.

  NOTE: this function shadows the ``save_model`` imported from
  tensorflow.keras.models earlier in the notebook.

  Args:
    generator: model exposing ``save(path, save_format=...)``.
    discriminator: model exposing ``save(path, save_format=...)``.
    save_path: directory that receives the two files.
  """
  os.makedirs(save_path, exist_ok=True)

  discriminator_save_path = os.path.join(save_path, 'discriminator.h5')
  generator_save_path = os.path.join(save_path, 'generator.h5')
  generator.save(generator_save_path, save_format='h5')
  discriminator.save(discriminator_save_path, save_format='h5')

def model_load(save_path):
  """Load generator.h5 and discriminator.h5 from ``save_path``.

  Returns:
    ``(generator, discriminator)`` as returned by ``load_model``.
  """
  generator, discriminator = (
      load_model(os.path.join(save_path, filename))
      for filename in ('generator.h5', 'discriminator.h5')
  )
  return generator, discriminator

In [None]:
# Separate Adam optimizers for the two networks; train_gan() reads both as globals.
generator_optimizer = Adam(learning_rate=0.00001, beta_1=0.5, beta_2=0.9)
discriminator_optimizer = Adam(learning_rate=0.00001, beta_1=0.5, beta_2=0.9)

# Select the loss callables. With 'wgan' the functions defined in the
# "Training Functions" cell are used; otherwise both networks use binary
# cross-entropy on raw logits.
if loss == 'wgan':
  d_loss_fn = discriminator_loss
  g_loss_fn = generator_loss
else:
  d_loss_fn = tf.keras.losses.BinaryCrossentropy(from_logits=True)
  g_loss_fn = tf.keras.losses.BinaryCrossentropy(from_logits=True)

# WARNING: these two assignments rebind the names of the loss *functions* to
# lists (train_gan() appends to `generator_loss`). The functions were already
# captured in d_loss_fn / g_loss_fn above, so the first run works, but
# re-running this cell with loss == 'wgan' would assign the lists as loss
# functions unless the "Training Functions" cell is re-run first.
generator_loss = []
discriminator_loss = []

# Either restore both models from save_path or build fresh ones for trial_shape.
if load_pretrained:
  generator, discriminator = model_load(save_path)
else:
  discriminator = get_discriminator_model(trial_shape)
  generator = get_generator_model(latent_dim, trial_shape)


In [None]:
#@title Train Model
def train_gan(dataset, generator, discriminator, d_loss_fn, g_loss_fn, epochs,gp_weight, save_path, save_epoch = 50):
  """Alternately train the discriminator and the generator on batches of real trials.

  Besides its arguments the function reads these notebook-level globals:
  ``d_extra_steps``, ``latent_dim``, ``batch_size``, ``loss``, ``is_gp``,
  ``discriminator_optimizer``, ``generator_optimizer`` and the history lists
  ``generator_loss`` / ``discriminator_loss``.

  Fixes relative to the previous version:
    * Logits were shuffled by indexing a tensor with a Python list, which
      TensorFlow interprets as one index per axis and which raised the
      OverflowError. The shuffle is dropped because the mean-reduced loss does
      not depend on sample order.
    * Under the cross-entropy loss the generator is now trained toward the
      "real" label (ones); it was trained toward zeros, the label the
      discriminator uses for fakes.
    * Loss accumulators are reset once per epoch instead of once per batch,
      and the discriminator loss is actually accumulated.
    * Real batches are cast to the Keras float type so they can be mixed with
      generator output (e.g. in the gradient penalty).
    * Per-batch debug prints were removed.

  Args:
    dataset: iterable of real batches shaped (batch, time, channels).
    generator, discriminator: Keras models to train.
    d_loss_fn, g_loss_fn: loss callables selected according to ``loss``.
    epochs: number of passes over ``dataset``.
    gp_weight: multiplier of the gradient penalty (used when ``is_gp``).
    save_path: directory passed to save_model().
    save_epoch: save the models and plot samples every ``save_epoch`` epochs.
  """
  for epoch in range(1,epochs+1):
    d_loss_sum = 0.0
    g_loss_sum = 0.0
    n_batches = 0
    generated_images = None

    for j,batch in tqdm(enumerate(dataset)):
      real_images = tf.cast(batch, tf.keras.backend.floatx())
      d_loss = None

      # ---- discriminator update(s) ----
      for _ in range(d_extra_steps):
        random_latent_vectors = tf.random.normal(shape=(real_images.shape[0], latent_dim))

        with tf.GradientTape() as tape:
          # The generator is only sampled here; it is not being updated.
          fake_images = generator(random_latent_vectors, training=False)
          fake_logits = discriminator(fake_images, training=True)
          real_logits = discriminator(real_images, training=True)

          if loss == 'wgan':
            d_cost = d_loss_fn(real_img=real_logits, fake_img=fake_logits)
          else:
            disc_logits = tf.concat((fake_logits, real_logits),0)
            disc_labels = tf.concat((tf.zeros_like(fake_logits), tf.ones_like(real_logits)),0)
            d_cost = d_loss_fn(disc_labels, disc_logits)

          if is_gp:
            gp = gradient_penalty(discriminator, real_images, fake_images)
            d_loss = d_cost + gp * gp_weight
          else:
            d_loss = d_cost

        d_gradient = tape.gradient(d_loss, discriminator.trainable_variables)
        discriminator_optimizer.apply_gradients(zip(d_gradient, discriminator.trainable_variables))

      # ---- generator update ----
      random_latent_vectors = tf.random.normal(shape=(batch_size, latent_dim))

      with tf.GradientTape() as tape:
        generated_images = generator(random_latent_vectors, training=True)
        gen_img_logits = discriminator(generated_images, training=True)
        if loss == 'wgan':
          g_loss = g_loss_fn(gen_img_logits)
        else:
          # The generator wants its samples to be classified as real.
          g_loss = g_loss_fn(tf.ones_like(gen_img_logits), gen_img_logits)

      gen_gradient = tape.gradient(g_loss, generator.trainable_variables)
      generator_optimizer.apply_gradients(zip(gen_gradient, generator.trainable_variables))

      # ---- running means over the batches seen so far in this epoch ----
      n_batches = j + 1
      g_loss_sum += float(g_loss.numpy())
      generator_loss.append(g_loss_sum/n_batches)
      if d_loss is not None:
        d_loss_sum += float(d_loss.numpy())
        discriminator_loss.append(d_loss_sum/n_batches)

    g_mean_loss = g_loss_sum/n_batches if n_batches else 0.0
    d_mean_loss = d_loss_sum/n_batches if n_batches else 0.0
    print(f'\rEpoch: {epoch} ======= Generator Loss: {g_mean_loss} ======= Discriminator Loss: {d_mean_loss}', end = '', flush = True)
    if epoch % save_epoch == 0:
      print('Saving Model ...... ')
      save_model(generator, discriminator, save_path)
      # Nothing to plot if the dataset yielded no batches.
      if generated_images is not None:
        show_eeg_signal(generated_images[:,:,0].numpy())
      print(f'', end = '\r')

In [None]:
# Train with the models, losses and parameters configured above;
# save_epoch=10 overrides the function default of 50.
train_gan(dataset, generator, discriminator, d_loss_fn, g_loss_fn, epochs, gp_weight, save_path= save_path, save_epoch = 10)

0it [00:00, ?it/s]

(256, 1536, 1)


OverflowError: ignored

# Scratch Code

In [None]:
# Scratch: load the first recording and pull out the left-movement array.
files = glob.glob(path + '/*.mat')
mat = import_data(files[0])
d1 = mat['movement_left']

In [None]:
# Scratch: save the discriminator in HDF5 format to local Colab storage.
discriminator.save('/content/disc', save_format = 'h5')

In [None]:
# NOTE: `data` is deleted at the end of the "Dataset Creation" cell, so this
# cell raises NameError when the notebook is run top to bottom.
dataset = tf.data.Dataset.from_tensor_slices(data)
del data

In [None]:
# Scratch: print the shape of the first element of the dataset.
for batch in dataset.take(1):
  print(batch.shape)

In [None]:
# Scratch: restore the discriminator weights from the saved HDF5 file.
discriminator_save_path = os.path.join(save_path, 'discriminator.h5')
discriminator.load_weights(discriminator_save_path)

In [None]:
# Scratch: build a fresh discriminator (bound to `desc`, not `disc`).
desc = get_discriminator_model(trial_shape)

In [None]:
# Scratch: reload both saved models.
gen, disc = model_load(save_path)

In [None]:
# Scratch: sample one batch from the reloaded generator and show its shape.
latent_vectors = tf.random.normal(shape=(batch_size, latent_dim))
preds = gen(latent_vectors)
preds.shape

In [None]:
# Scratch: check that the save directory exists.
os.path.exists(save_path)

In [None]:
# Scratch: plot the first channel of the sampled batch.
show_eeg_signal(preds[:,:,0].numpy())