In [None]:
from keras.layers import *
from keras.models import Model
from keras.models import Sequential
from keras.optimizers import Adam, RMSprop
from keras.activations import relu
from tqdm import tqdm
import numpy as np
import scipy.stats as sc
import matplotlib.pyplot as plt
from google.colab import drive
drive.mount("/content/gdrive")

In [None]:
# Separate Adam instances: `optimiser_d` is handed to the discriminator's
# compile() call, `optimiser_g` to the generator's and the combined GAN's.
optimiser_d = Adam(1e-3)
optimiser_g = Adam(5e-6)

# Shape of a single ECG sample as fed to the discriminator: (timesteps, channels).
ecg_shape = (400, 1)
# NOTE(review): not referenced by any code visible in this notebook; the noise
# shape is spelled out literally where it is used — confirm before relying on it.
noise_length = 400


def create_discriminator():
  """Build, compile and summarise the 1-D convolutional discriminator.

  The network reshapes its input to (400, 1), applies four
  Conv1D -> MaxPooling1D -> LeakyReLU stages, flattens, and ends in a single
  sigmoid unit. It is compiled with binary cross-entropy, an accuracy metric
  and the module-level `optimiser_d`. The layer summary is printed as a
  side effect.

  Returns:
    The compiled Sequential discriminator model.
  """
  # (number of filters, pooling stride) for each convolutional stage, in order.
  stage_specs = [(64, 1), (128, 1), (256, 2), (512, 2)]

  model = Sequential()
  model.add(Reshape((400, 1), input_shape=ecg_shape))

  for n_filters, pool_stride in stage_specs:
    model.add(Conv1D(filters=n_filters, kernel_size=3, strides=1))
    model.add(MaxPooling1D(pool_size=3, strides=pool_stride))
    model.add(LeakyReLU())

  model.add(Flatten())
  model.add(Dense(units=1, activation='sigmoid'))

  model.compile(loss='binary_crossentropy', metrics=['accuracy'], optimizer=optimiser_d)
  model.summary()

  return model

def create_generator():
  """Build, compile and summarise the bidirectional-LSTM generator.

  The model takes a (400, 1) sequence, passes it through two
  Bidirectional(LSTM(64)) layers (each followed by 10% dropout), flattens the
  result and maps it to `ecg_shape[0]` tanh outputs. It is compiled with
  binary cross-entropy and the module-level `optimiser_g`. The layer summary
  is printed as a side effect.

  Returns:
    The compiled Sequential generator model.
  """
  sequence_shape = (400, 1)

  # Layers are listed in the order they are stacked.
  stack = [
      Bidirectional(LSTM(64, return_sequences=True), input_shape=sequence_shape),
      Dropout(0.1),
      Bidirectional(LSTM(64, return_sequences=True)),
      Dropout(0.1),
      Flatten(),
      Dense(ecg_shape[0], activation='tanh'),
  ]

  model = Sequential()
  for layer in stack:
    model.add(layer)

  model.compile(loss='binary_crossentropy', optimizer=optimiser_g)
  model.summary()

  return model


def create_GAN(discriminator, generator):
  """Chain generator and discriminator into one compiled model.

  Sets `discriminator.trainable = False` on the passed-in model (a side effect
  visible to the caller) before stacking Input -> generator -> discriminator
  and compiling the stack with binary cross-entropy and the module-level
  `optimiser_g`. The layer summary is printed as a side effect.

  Args:
    discriminator: model appended last in the stack.
    generator: model placed directly after the (400, 1) input.

  Returns:
    The compiled Sequential model.
  """
  # The flag is flipped before the combined model is compiled below.
  discriminator.trainable = False

  combined = Sequential()
  for part in (Input(shape=(400, 1)), generator, discriminator):
    combined.add(part)

  combined.compile(loss='binary_crossentropy', optimizer=optimiser_g)
  combined.summary()
  return combined



In [None]:
def show_results(generator, norm_value, noise_shape=(400, 1)):
  """Generate one sample from fresh Gaussian noise and plot it.

  Args:
    generator: model whose `predict` is called on a batch holding a single
      noise sequence.
    norm_value: factor the generator output is multiplied by before plotting.
    noise_shape: shape of one noise sequence (without the batch axis).
      Defaults to (400, 1), the shape previously hard-coded here.

  Returns:
    None. The figure is displayed with `plt.show()`.
  """
  # Leading axis of size 1 is the batch dimension expected by predict().
  noise = tf.random.normal([1, *noise_shape])
  ecgs = generator.predict(noise)
  trace = (ecgs * norm_value)[0]

  plt.figure()
  plt.ylabel('Amplitude(mV)')
  plt.xlabel('Timesteps')
  # Follow the actual output length instead of a hard-coded 400.
  plt.xlim((0, len(trace)))
  plt.plot(trace)
  plt.grid()
  plt.show()


In [None]:
import tensorflow as tf

def compute_kernel(x, y):
    """Pairwise exponential kernel between the rows of `x` and the rows of `y`.

    For every pair (i, j) the result is
    exp(-mean((x[i] - y[j])**2) / dim), where `dim` is the size of axis 1 of `x`.

    Args:
        x: 2-D array/tensor; axis 0 indexes samples, axis 1 indexes features.
        y: 2-D array/tensor with the same axis-1 size and dtype as `x`.

    Returns:
        Tensor of shape (rows of x, rows of y) holding the kernel values.
    """
    dim = tf.shape(x)[1]
    # Broadcasting (n, 1, dim) against (1, m, dim) yields every pairwise
    # difference without materialising two tiled (n, m, dim) copies first.
    pairwise_diff = tf.expand_dims(x, 1) - tf.expand_dims(y, 0)
    mean_sq_dist = tf.reduce_mean(tf.square(pairwise_diff), axis=2)
    # Cast to the dtype of the distances so non-float64 inputs do not hit a
    # dtype mismatch in the division.
    return tf.exp(-mean_sq_dist / tf.cast(dim, mean_sq_dist.dtype))

def compute_mmd(x, y, sigma_sqr=1.0):
    """Maximum mean discrepancy estimate between sample sets `x` and `y`.

    Computed as mean(k(x, x)) + mean(k(y, y)) - 2 * mean(k(x, y)) with
    `compute_kernel` as k.

    Args:
        x: first sample set, passed straight to `compute_kernel`.
        y: second sample set, passed straight to `compute_kernel`.
        sigma_sqr: accepted for interface compatibility; not used by the
            computation.

    Returns:
        Scalar tensor with the MMD estimate.
    """
    mean_xx = tf.reduce_mean(compute_kernel(x, x))
    mean_yy = tf.reduce_mean(compute_kernel(y, y))
    mean_xy = tf.reduce_mean(compute_kernel(x, y))
    return mean_xx + mean_yy - 2 * mean_xy

In [None]:

def mmd_loss(n_samples=800):
  """Estimate the MMD between real ECGs and freshly generated ones.

  Reads the module-level globals `generator` and `norm_value`. Loads the real
  data set from disk on every call, draws `n_samples` columns from it (with
  replacement), generates `n_samples` fake signals in a single batched
  `predict` call, rescales them by `norm_value`, and compares both sets with
  `compute_mmd`. The value is printed and returned.

  Args:
    n_samples: number of real and of generated signals to compare. Defaults
      to 800, the value previously hard-coded here. The kernel computation
      builds an (n_samples, n_samples, 400) intermediate, so memory grows
      quadratically with this argument.

  Returns:
    Scalar tensor holding the MMD estimate.
  """
  real_data_set = np.load('')
  # Columns are individual signals; indices are drawn from the first 1000.
  # NOTE(review): assumes the data set has at least 1000 columns — confirm.
  idx = np.random.randint(0, 1000, n_samples)
  # Transpose to (n_samples, timesteps); float64 so both sets share a dtype.
  real_data = np.asarray(real_data_set[:, idx], dtype=np.float64).T

  # One batched predict instead of n_samples separate batch-of-1 calls.
  noise = tf.random.normal([n_samples, 400, 1])
  ecgs = generator.predict(noise)
  fake_data = np.asarray(ecgs, dtype=np.float64) * norm_value

  mmd_value = compute_mmd(real_data, fake_data)
  print(mmd_value)

  return mmd_value

In [None]:
import tensorflow as tf

# --- Hyper-parameters -------------------------------------------------------
batch_size = 64
epochs = 5000
# Upper bound (exclusive) for the random indices into both the real-data
# columns and the cached noise pool.
# NOTE(review): assumes the data set has at least 1000 columns — confirm.
num_examples = 1000
# Samples are plotted and the MMD is evaluated every `eval_every` epochs.
eval_every = 20

# --- Models -----------------------------------------------------------------
discriminator = create_discriminator()
generator = create_generator()
GAN = create_GAN(discriminator, generator)

# --- Data -------------------------------------------------------------------
data = np.load('')
# Largest absolute amplitude in the data set; dividing by it maps the data
# into [-1, 1]. Equivalent to the max over every row's max and negated min.
norm_value = np.max(np.abs(data))
data = data / norm_value
# Fixed pool of noise sequences, drawn once; each step samples a batch from it.
noise_array = np.array(tf.random.normal([num_examples, 400, 1]))

d_losses, d_accuracy, g_losses = [], [], []
# One entry per evaluation (every `eval_every` epochs), filled in the loop.
mmd_losses = []

# --- Training loop ----------------------------------------------------------
for epoch in tqdm(range(epochs)):
  # Real batch: pick columns, then reshape to (batch, timesteps, 1).
  idx = np.random.randint(0, high=num_examples, size=batch_size)
  real_ecgs = np.transpose(data[:, idx])
  real_ecgs = tf.expand_dims(real_ecgs, axis=2)

  # Fake batch: generator output for a random subset of the noise pool.
  idx = np.random.randint(0, high=num_examples, size=batch_size)
  noise = noise_array[idx, :, :]
  fake_ecgs = generator.predict(noise, batch_size=batch_size)
  fake_ecgs = tf.expand_dims(fake_ecgs, axis=2)

  y_real = np.ones(batch_size)
  y_fake = np.zeros(batch_size)

  # Discriminator step: one update on the real batch, one on the fake batch.
  # NOTE(review): whether flipping `trainable` after compile() takes effect
  # depends on the Keras version — confirm for the version in use.
  discriminator.trainable = True
  d_loss_real = discriminator.train_on_batch(real_ecgs, y_real)
  d_loss_fake = discriminator.train_on_batch(fake_ecgs, y_fake)
  d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
  discriminator.trainable = False

  # Generator step through the combined model, using "real" labels as target.
  g_loss = GAN.train_on_batch(noise, np.ones(batch_size))

  # d_loss holds [loss, accuracy] because the discriminator has one metric.
  d_losses.append(d_loss[0])
  d_accuracy.append(d_loss[1])
  g_losses.append(g_loss)

  if epoch % eval_every == 0:
    for _ in range(7):
      show_results(generator, norm_value)
    # Keep the returned value so the MMD curve below shows real measurements.
    mmd_losses.append(float(mmd_loss()))

# --- Training curves --------------------------------------------------------
plt.figure()
plt.ylabel('Loss Value')
plt.xlabel('Epochs')
plt.xlim((0, epochs))
plt.plot(d_losses, 'r', label='disc_loss')
plt.plot(g_losses, 'b', label='gen_loss')
plt.legend()
plt.show()

plt.figure()
plt.ylabel('Discriminator accuracy')
plt.xlabel('Epochs')
plt.plot(d_accuracy)
plt.show()

plt.figure()
plt.ylabel('MMD')
plt.xlabel('Epochs')
# Evaluations happen every `eval_every` epochs, so scale the x axis to match.
plt.plot(np.arange(len(mmd_losses)) * eval_every, mmd_losses)
plt.show()

# --- Final samples ----------------------------------------------------------
for _ in range(6):
  show_results(generator, norm_value)

In [None]:
# Run the MMD evaluation once more. mmd_loss() prints the value and also
# returns it, so as the last expression it is shown as this cell's output.
mmd_loss()

In [None]:
# Plot twelve independently generated samples, one figure each.
for _ in range(12):
  show_results(generator, norm_value)