In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import os
import tensorflow as tf

# Seed NumPy's global RNG and TensorFlow's global seed so runs are repeatable.
rand_seed = 42
np.random.seed(rand_seed)
tf.random.set_seed(rand_seed)

# Classes used for training. Each name is also the image sub-folder name
# read by load_real_samples, and its list position is the integer label.
used_classes = [
    'Airliner',
    'Sorrel',
    'Jack-o’-lantern',
    'Panda',
    'Anemone fish',
]
num_classes = len(used_classes)

# NOTE(review): neither `num_chanels` (sic) nor `data_dir` is read by any cell
# visible here; kept unchanged in case other cells depend on these names.
num_chanels = 14
data_dir = 'eeg_processed/'


In [None]:
from tensorflow.keras.layers import Input, Conv2D, Flatten
from tensorflow.keras.models import Model, Sequential
from tensorflow.keras.layers import Dense, Dropout, UpSampling1D, Conv2DTranspose, BatchNormalization, UpSampling2D
from tensorflow.keras.layers import LeakyReLU, ReLU, Concatenate, Reshape
from tensorflow.keras.initializers import RandomNormal
from tensorflow.keras.optimizers import Adam

In [None]:
def get_generator(latent_dim, eeg_dim):
    """Build the EEG-conditioned generator (returned uncompiled).

    The latent noise vector and the EEG feature vector are concatenated,
    projected to a 4x4x256 feature map, and upsampled by three stride-2
    transposed convolutions (4 -> 8 -> 16 -> 32). A final 3x3 convolution
    with tanh yields a 32x32x3 image in [-1, 1].

    Args:
        latent_dim: length of the noise input vector.
        eeg_dim: length of the EEG conditioning vector.

    Returns:
        ``Model([latent, eeg], image)``.
    """
    weight_init = RandomNormal(stddev=0.02)
    eeg_in = Input(shape=(eeg_dim,))
    noise_in = Input(shape=(latent_dim,))

    x = Concatenate()([noise_in, eeg_in])
    x = Dense(4*4*256)(x)
    x = LeakyReLU(alpha=0.2)(x)
    x = Reshape((4, 4, 256))(x)

    # Three 2x upsampling stages: 4x4 -> 8x8 -> 16x16 -> 32x32.
    for n_filters in (256, 256, 128):
        x = Conv2DTranspose(n_filters, (4,4), strides=(2,2), kernel_initializer=weight_init, padding='same')(x)
        x = BatchNormalization()(x)
        x = LeakyReLU(alpha=0.2)(x)

    # Project to 3 channels; tanh matches the [-1, 1] scaling of the real images.
    image_out = Conv2D(3, (3,3), kernel_initializer=weight_init, activation='tanh', padding='same')(x)
    return Model([noise_in, eeg_in], image_out)


def get_discriminator(in_shape=(64, 64, 3), eeg_dim=56):
    """Build and compile the EEG-conditioned discriminator.

    Three stride-2 conv blocks each halve the image height and width
    (e.g. a 32x32 input becomes 16x16 -> 8x8 -> 4x4). The flattened feature
    map is concatenated with the EEG feature vector and scored by a small
    dense head whose sigmoid output is trained toward 1 for real and 0 for
    generated images.

    Args:
        in_shape: (height, width, channels) of the input images. NOTE: the
            generator in this notebook emits 32x32x3, so callers pass that
            explicitly; the 64x64 default is kept for compatibility.
        eeg_dim: length of the EEG conditioning vector.

    Returns:
        ``Model([image, eeg], p_real)`` compiled with binary cross-entropy,
        Adam and an accuracy metric.
    """
    init = RandomNormal(stddev=0.02)
    in_img = Input(shape=in_shape)

    # Three downsampling blocks, each halving height and width.
    conv = in_img
    for n_filters in (128, 256, 256):
        conv = Conv2D(n_filters, (4,4), strides=(2,2), kernel_initializer=init, padding='same')(conv)
        conv = BatchNormalization()(conv)
        conv = LeakyReLU(alpha=0.2)(conv)

    flt = Flatten()(conv)

    # Condition on EEG by appending the EEG vector to the image features.
    in_eeg = Input(shape=(eeg_dim, ))
    dis = Concatenate()([flt, in_eeg])

    # classifier
    dis = Dense(512)(dis)
    dis = ReLU()(dis)
    out_layer = Dense(1, activation='sigmoid')(dis)

    model = Model([in_img, in_eeg], out_layer)
    # `learning_rate`, not the legacy `lr` alias: newer Keras optimizers either
    # drop `lr` with a warning (silently training at the default rate) or
    # reject it outright.
    opt = Adam(learning_rate=0.0002, beta_1=0.5)
    model.compile(loss='binary_crossentropy', optimizer=opt, metrics=['accuracy'])

    return model


def define_gan(g_model, d_model):
    """Chain generator -> discriminator into the model used to train the generator.

    The discriminator is marked non-trainable before this combined model is
    compiled, so ``train_on_batch`` on the returned model updates only the
    generator's weights. The same EEG input feeds both the generator and the
    discriminator.

    NOTE(review): ``d_model`` keeps training through its own
    ``train_on_batch`` only because it was compiled while still trainable.
    That is tf.keras 2.x semantics; re-verify if moving to Keras 3.

    Args:
        g_model: generator ``Model([latent, eeg], image)``.
        d_model: compiled discriminator ``Model([image, eeg], p_real)``.

    Returns:
        Compiled ``Model([latent, eeg], p_real)``.
    """
    # make weights in the discriminator not trainable (for the combined model)
    d_model.trainable = False

    gen_lat, gen_eeg = g_model.input
    gan_out = d_model([g_model.output, gen_eeg])

    model = Model([gen_lat, gen_eeg], gan_out)
    # `learning_rate`, not the legacy `lr` alias: newer Keras optimizers either
    # drop `lr` with a warning (silently training at the default rate) or
    # reject it outright.
    opt = Adam(learning_rate=0.0002, beta_1=0.5)
    model.compile(loss='binary_crossentropy', optimizer=opt)

    return model

In [None]:
from tensorflow.keras.preprocessing.image import load_img
from tensorflow.keras.preprocessing.image import img_to_array
from tensorflow.keras.preprocessing.image import array_to_img

def crop_all(imgs, labels, target_size=(64,64), step=(2, 2)):
    """Augment a batch by taking every sliding-window crop of each image.

    Crops of size ``target_size`` are cut at every offset that is a multiple
    of ``step`` and still fits inside the image, for all images at once.
    Output is grouped by offset: the first ``len(imgs)`` entries are the
    crops at offset (0, 0), the next ``len(imgs)`` the crops at the next
    offset, and so on. ``labels`` is repeated once per offset to stay aligned.

    Args:
        imgs: array whose axes 1 and 2 are height and width, e.g. (N, H, W, C).
        labels: N labels aligned with ``imgs``.
        target_size: (height, width) of each crop.
        step: (row, column) stride between crop offsets.

    Returns:
        (cropped_imgs, repeated_labels), each with n_offsets * N entries.
    """
    crop_h = target_size[0]
    crop_w = target_size[1]
    row_offsets = range(0, imgs.shape[1] - crop_h + 1, step[0])
    col_offsets = range(0, imgs.shape[2] - crop_w + 1, step[1])

    windows = [imgs[:, r:r + crop_h, c:c + crop_w, :]
               for r in row_offsets
               for c in col_offsets]
    cropped = np.concatenate(windows)
    repeated_labels = np.concatenate([labels] * len(windows))
    return cropped, repeated_labels

def load_real_samples(imgs_path, cropping_mode=(80, 80), crop_steps=(2,2), hor_flip=True,
                      crop_size=(32, 32), skip_files=(('Anemone fish', '22.jpg'),)):
    """Load, normalise and augment the training images of every used class.

    For each class in ``used_classes`` (its index is the integer label),
    ``.jpg`` files are read from ``<imgs_path>/<class name>/`` and resized to
    ``cropping_mode``. They are scaled to [-1, 1] and expanded into every
    ``crop_size`` crop at ``crop_steps`` offsets. The crops are optionally
    doubled with horizontal mirrors and then shuffled with NumPy's global RNG.

    Args:
        imgs_path: root directory containing one sub-folder per class.
        cropping_mode: (height, width) each image is resized to on load.
        crop_steps: (row, column) stride between crop offsets.
        hor_flip: if True, append a mirrored copy of every crop.
        crop_size: (height, width) of the crops; must match the model input.
        skip_files: (class name, file name) pairs to exclude. The default
            keeps the exclusion the original code hard-coded.

    Returns:
        [imgs, labels]: float32 image crops and integer labels, shuffled
        together.

    Raises:
        ValueError: if no images were found.
    """
    skip = set(skip_files)
    imgs = []
    labels = []
    for label, cl in enumerate(used_classes):
        class_dir = os.path.join(imgs_path, cl)
        # os.listdir order is filesystem-dependent; sort it so the seeded
        # shuffle below gives the same dataset order on every machine.
        for img_name in sorted(x for x in os.listdir(class_dir) if x.endswith(".jpg")):
            if (cl, img_name) in skip:
                continue
            img = load_img(os.path.join(class_dir, img_name), target_size=cropping_mode)
            imgs.append(img_to_array(img))
            labels.append(label)

    if not imgs:
        raise ValueError('No .jpg images found under %r for classes %r' % (imgs_path, used_classes))

    imgs = np.asarray(imgs).astype('float32')
    labels = np.asarray(labels)
    # scale from [0,255] to [-1,1] to match the generator's tanh output range
    imgs = (imgs - 127.5) / 127.5
    imgs, labels = crop_all(imgs, labels, crop_size, crop_steps)

    if hor_flip:
        # axis 2 is the width axis (crop_all slices axes 1/2 as height/width),
        # so this mirrors each crop left-right
        imgs = np.concatenate([imgs, np.flip(imgs, axis=2)])
        labels = np.concatenate([labels, labels])

    order = np.arange(imgs.shape[0])
    np.random.shuffle(order)
    return [imgs[order], labels[order]]

In [None]:
# select real samples
def generate_real_samples(dataset, n_samples, avg_eeg):
    """Draw a random batch of real images paired with their class EEG features.

    Indices are drawn uniformly *with replacement* via ``np.random.randint``.

    Args:
        dataset: ``[imgs, labels]`` pair.
        n_samples: number of samples to draw.
        avg_eeg: array indexed by class label; row ``k`` is the EEG feature
            vector paired with class ``k``.

    Returns:
        ``([images, eeg_features], y)`` where ``y`` is an ``(n_samples, 1)``
        array of ones, the "real" target for the discriminator.
    """
    all_imgs, all_labels = dataset
    picked = np.random.randint(0, all_imgs.shape[0], n_samples)
    batch_labels = all_labels[picked]
    targets = np.ones((n_samples, 1))
    return [all_imgs[picked], avg_eeg[batch_labels]], targets


# select real samples in order
def generate_real_samples_ordered(dataset, n_samples, n_iter, avg_eeg):
    """Return the ``n_iter``-th consecutive batch of real samples (no shuffling).

    The batch is ``dataset[start:start + n_samples]`` with
    ``start = n_samples * n_iter``. The final batch may be shorter than
    ``n_samples`` if the dataset runs out. The target array is sized to the
    batch actually returned, so inputs and targets always have equal length.

    Args:
        dataset: ``[imgs, labels]`` pair.
        n_samples: batch size.
        n_iter: zero-based batch index.
        avg_eeg: array indexed by class label (row ``k`` pairs with class ``k``).

    Returns:
        ``([images, eeg_features], y)`` with ``y`` an array of ones shaped
        ``(len(images), 1)``.
    """
    imgs, labels = dataset
    start = n_samples * n_iter
    stop = start + n_samples
    X = imgs[start:stop]
    X_eeg = avg_eeg[labels[start:stop]]
    # Size targets from the slice, not n_samples: the last batch can be short.
    y = np.ones((X.shape[0], 1))
    return [X, X_eeg], y


# generate points in latent space as input for the generator
def generate_latent_points(latent_dim, n_samples, avg_eeg):
    """Sample generator inputs: Gaussian noise plus a random class's EEG features.

    Args:
        latent_dim: length of each noise vector.
        n_samples: number of input pairs to draw.
        avg_eeg: array indexed by class label (row ``k`` pairs with class ``k``).

    Returns:
        ``[noise, eeg]``. ``noise`` is an ``(n_samples, latent_dim)``
        standard-normal array. ``eeg`` holds the ``avg_eeg`` rows of class
        ids drawn uniformly from ``[0, num_classes)``.
    """
    noise = np.random.randn(latent_dim * n_samples).reshape(n_samples, latent_dim)
    class_ids = np.random.randint(0, num_classes, size=n_samples)
    return [noise, avg_eeg[class_ids]]


# use the generator to generate n fake examples, with class labels
def generate_fake_samples(g_model, latent_dim, n_samples, avg_eeg):
    """Generate a batch of fake images, labelled 0 for the discriminator.

    Args:
        g_model: generator taking ``[noise, eeg]``.
        latent_dim: length of each noise vector.
        n_samples: batch size.
        avg_eeg: array indexed by class label, forwarded to
            ``generate_latent_points``.

    Returns:
        ``([fake_images, eeg_features], y)``. ``eeg_features`` are the EEG
        vectors the images were conditioned on; ``y`` is an
        ``(n_samples, 1)`` array of zeros.
    """
    noise, eeg = generate_latent_points(latent_dim, n_samples, avg_eeg)
    fake_imgs = g_model.predict([noise, eeg])
    fake_targets = np.zeros((n_samples, 1))
    return [fake_imgs, eeg], fake_targets


# create and save a plot of generated images
def save_plot(examples, epoch, n=7):
    """Save an n x n grid of generated images to ``generated_plot_eNNN.png``.

    Args:
        examples: images in [-1, 1] (the generator's tanh range); the first
            ``n * n`` are plotted. Fewer are allowed, and the remaining grid
            cells stay empty.
        epoch: zero-based epoch; the filename uses ``epoch + 1``.
        n: grid side length.
    """
    # scale from [-1,1] to [0,1]; clipping avoids imshow's out-of-range warning
    examples = np.clip((examples + 1) / 2.0, 0.0, 1.0)
    # Draw on a fresh figure rather than whatever pyplot figure is current.
    fig = plt.figure()
    for i in range(min(n * n, len(examples))):
        ax = fig.add_subplot(n, n, 1 + i)
        ax.axis('off')
        ax.imshow(examples[i])
    filename = 'generated_plot_e%03d.png' % (epoch+1)
    fig.savefig(filename)
    # Close this specific figure so repeated calls during training don't accumulate.
    plt.close(fig)


# create a line plot of loss for the gan and save to file
def plot_history(d1_hist, d2_hist, g_hist, a1_hist, a2_hist):
    """Save per-batch training curves to ``plot_line_plot_loss.png``.

    Top panel: discriminator loss on real (d-real) and fake (d-fake) batches
    plus the generator loss (gen). Bottom panel: discriminator accuracy on
    real and fake batches.
    """
    panels = (
        ((d1_hist, 'd-real'), (d2_hist, 'd-fake'), (g_hist, 'gen')),
        ((a1_hist, 'acc-real'), (a2_hist, 'acc-fake')),
    )
    for row, curves in enumerate(panels, start=1):
        plt.subplot(2, 1, row)
        for values, name in curves:
            plt.plot(values, label=name)
        plt.legend()
    plt.savefig('plot_line_plot_loss.png')
    plt.close()


# evaluate the discriminator, plot generated images, save generator model
def summarize_performance(epoch, g_model, d_model, dataset, avg_eeg, latent_dim, n_samples=100):
    """Report discriminator accuracy, checkpoint the generator, and save samples.

    Evaluates ``d_model`` on ``n_samples`` random real and ``n_samples``
    generated samples and prints both accuracies. It then saves the generator
    to ``generator_model_NNN.h5`` and a grid of the generated images via
    ``save_plot``. ``NNN`` is ``epoch + 1``.
    """
    real_inputs, real_targets = generate_real_samples(dataset, n_samples, avg_eeg)
    _, acc_real = d_model.evaluate(real_inputs, real_targets, verbose=0)

    fake_inputs, fake_targets = generate_fake_samples(g_model, latent_dim, n_samples, avg_eeg)
    _, acc_fake = d_model.evaluate(fake_inputs, fake_targets, verbose=0)

    print('>Accuracy real: %.0f%%, fake: %.0f%%' % (acc_real*100, acc_fake*100))
    # checkpoint the generator weights for this epoch
    g_model.save('generator_model_%03d.h5' % (epoch+1))
    save_plot(fake_inputs[0], epoch)


# train the generator and discriminator
def train(g_model, d_model, gan_model, dataset, latent_dim, avg_eeg, n_epochs=100, n_batch=128):
    """Alternate discriminator and generator updates for ``n_epochs`` epochs.

    Each step has three parts:
    - train the discriminator on half a batch of real samples (target 1);
    - train it on half a batch of generated samples (target 0);
    - train the generator through ``gan_model`` on a full batch of latent
      points labelled as real.

    ``summarize_performance`` runs every 10th epoch and after the final epoch.
    Once training finishes, the per-batch loss/accuracy curves are saved via
    ``plot_history``.

    Args:
        g_model, d_model, gan_model: generator, compiled discriminator, and
            compiled combined model.
        dataset: ``[imgs, labels]`` pair.
        latent_dim: length of the generator's noise input.
        avg_eeg: array indexed by class label (row ``k`` pairs with class ``k``).
        n_epochs: number of epochs.
        n_batch: full batch size; each discriminator update uses half of it.
    """
    n_total = dataset[0].shape[0]
    # Real batches are sampled with replacement, so a dataset smaller than
    # n_batch is still trainable; without max(1, ...) it would yield zero
    # batches per epoch and training would silently do nothing.
    bat_per_epo = max(1, n_total // n_batch)
    print('ds size={}, n_btch={}, bat_per_epo={}'.format(n_total, n_batch, bat_per_epo))
    half_batch = n_batch // 2
    # per-batch statistics for plot_history
    d1_hist, d2_hist, g_hist, a1_hist, a2_hist = [], [], [], [], []
    for i in range(n_epochs):
        for j in range(bat_per_epo):
            # discriminator step on real samples (target 1)
            [X_real, eeg_real], y_real = generate_real_samples(dataset, half_batch, avg_eeg)
            d_loss1, d_acc1 = d_model.train_on_batch([X_real, eeg_real], y_real)
            # discriminator step on generated samples (target 0)
            [X_fake, eeg_fake], y_fake = generate_fake_samples(g_model, latent_dim, half_batch, avg_eeg)
            d_loss2, d_acc2 = d_model.train_on_batch([X_fake, eeg_fake], y_fake)

            # generator step: fresh latent points labelled "real" (inverted
            # labels) so the generator is pushed to fool the discriminator
            [X_gan, eeg_gan] = generate_latent_points(latent_dim, n_batch, avg_eeg)
            y_gan = np.ones((n_batch, 1))
            g_loss = gan_model.train_on_batch([X_gan, eeg_gan], y_gan)

            print('>%d, %d/%d, d1=%.3f, d2=%.3f g=%.3f, ac d1=%.3f, ac d2=%.3f' %
                        (i+1, j+1, bat_per_epo, d_loss1, d_loss2, g_loss, d_acc1, d_acc2))
            d1_hist.append(d_loss1)
            d2_hist.append(d_loss2)
            g_hist.append(g_loss)
            a1_hist.append(d_acc1)
            a2_hist.append(d_acc2)

        # Evaluate/checkpoint every 10 epochs and always after the final epoch,
        # so the last weights are saved even when n_epochs % 10 != 0.
        if (i + 1) % 10 == 0 or i == n_epochs - 1:
            summarize_performance(i, g_model, d_model, dataset, avg_eeg, latent_dim)

    # Save the per-batch loss/accuracy curves collected during training.
    plot_history(d1_hist, d2_hist, g_hist, a1_hist, a2_hist)

In [None]:
lat_dim = 100
imgs_size = 32
imgs_folder = 'images/images_data/'
eeg_f_dim = 56
train_epochs = 300
train_batch = 16
gen_model = get_generator(lat_dim, eeg_f_dim)
dis_model = get_discriminator((imgs_size, imgs_size, 3), eeg_f_dim)

GAN_model = define_gan(gen_model, dis_model)

data = load_real_samples(imgs_folder, (40,40), (2,2))

# imgs_size only sets the discriminator input. The generator's output size is
# fixed by its architecture and the crop size by load_real_samples, so fail
# fast if either disagrees with it.
assert tuple(gen_model.output.shape)[1:] == (imgs_size, imgs_size, 3), gen_model.output.shape
assert data[0].shape[1:] == (imgs_size, imgs_size, 3), data[0].shape

In [None]:
# NOTE(review): `avg_eeg_cl` is not defined in any cell visible here. It must be
# created by an earlier cell (presumably one averaged EEG feature vector per
# class), or Restart & Run All fails with a NameError. TODO confirm its source.
# It is indexed by class label and fed to the eeg_f_dim-wide EEG inputs, so it
# needs at least num_classes rows of length eeg_f_dim.
assert avg_eeg_cl.shape[0] >= num_classes and tuple(avg_eeg_cl.shape[1:]) == (eeg_f_dim,), avg_eeg_cl.shape
train(gen_model, dis_model, GAN_model, data, lat_dim, avg_eeg_cl, train_epochs, train_batch)