In [1]:
import sys

sys.path.append("..")

import tensorflow as tf

import numpy as np

import GPyOpt

import argparse

from utils import tf_config, preprocess_data, search_algorithms, train

from models import fmri_ae, eeg_to_fmri, uniconv_fmri

from layers import locally_connected

import matplotlib.pyplot as plt

import gc

import os

from sklearn.model_selection import train_test_split, KFold

import time

dataset="01"
memory_limit=1500
n_individuals=10
interval_eeg=10

tf_config.set_seed(seed=42)
tf_config.setup_tensorflow(device="GPU", memory_limit=memory_limit)

with tf.device('/CPU:0'):
    train_data, _ = preprocess_data.dataset(dataset, n_individuals=n_individuals,
                                            interval_eeg=interval_eeg, 
                                            ind_volume_fit=False,
                                            standardize_fmri=True,
                                            iqr=False,
                                            verbose=True)
    eeg_train, fmri_train =train_data

I: Starting to Load Data
I: Finished Loading Data
I: Pairs Created


In [2]:
eeg_train = eeg_train[:100]
fmri_train = fmri_train[:100]

## Build fMRI AE

In [3]:
theta = (0.002980911194116198, 0.0004396489214334123, (9, 9, 4), (1, 1, 1), 16, (7, 7, 7), 4, True, True, True, True, 3, 1)

In [4]:
#unroll hyperparameters
learning_rate=float(theta[0])
weight_decay = float(theta[1])
kernel_size = theta[2]
stride_size = theta[3]
batch_size=int(theta[4])
latent_dimension=theta[5]
n_channels=int(theta[6])
max_pool=bool(theta[7])
batch_norm=bool(theta[8])
skip_connections=bool(theta[9])
dropout=bool(theta[10])
n_stacks=int(theta[11])
outfilter=int(theta[12])
local=True

In [5]:
class EEG_to_fMRI(tf.keras.Model):
    
    def __init__(self, latent_shape, input_shape, kernel_size, stride_size, n_channels,
                maxpool=True, weight_decay=0.000, skip_connections=False,
                n_stacks=2, local=True, seed=None, fmri_args=None):
        super(EEG_to_fMRI, self).__init__()
        
        self.training=False
        
        self.fmri_ae = fmri_ae.fMRI_AE(*fmri_args)
        
        self.build_encoder(latent_shape, input_shape, kernel_size, 
                            stride_size, n_channels, maxpool=maxpool,
                            weight_decay=weight_decay, skip_connections=skip_connections,
                            n_stacks=n_stacks, local=local, seed=seed)
        self.build_decoder()
        
    def build_encoder(self, latent_shape, input_shape, kernel_size, 
                            stride_size, n_channels, maxpool=True,
                            weight_decay=0.000, skip_connections=False,
                            n_stacks=2, local=True, seed=None):
        
        input_shape = tf.keras.layers.Input(shape=input_shape)
        
        x = input_shape
        previous_block_x = input_shape

        for i in range(n_stacks):
            x = fmri_ae.stack(x, previous_block_x, tf.keras.layers.Conv3D, 
                        kernel_size, stride_size, n_channels,
                        maxpool=maxpool, batch_norm=batch_norm, weight_decay=weight_decay, 
                        skip_connections=skip_connections, seed=seed)
            previous_block_x=x

        if(local):
            operation=tf.keras.layers.Conv3D
        else:
            operation=LocallyConnected3D

        x = tf.keras.layers.Flatten()(x)
        x = tf.keras.layers.Dense(latent_shape[0]*latent_shape[1]*latent_shape[2], 
                                    kernel_initializer=tf.keras.initializers.GlorotUniform(seed=seed))(x)
        if(dropout):
            x = tf.keras.layers.Dropout(0.5)(x)
        x = tf.keras.layers.Reshape(latent_shape)(x)
        
        self.eeg_encoder = tf.keras.Model(input_shape, x)
        self.fmri_encoder = self.fmri_ae.encoder
        
    def build_decoder(self):
        self.decoder = self.fmri_ae.decoder
        
    def build(self, input_shape1, input_shape2):
        self.eeg_encoder.build(input_shape=input_shape1)
        
        self.fmri_ae.build(input_shape=input_shape2)        
        self.fmri_encoder.build(input_shape=input_shape2)
        
        model.built=True
    
    def call(self, X):
        x1, x2 = X
        
        z1 = self.eeg_encoder(x1)
        z2 = self.fmri_encoder(x2)
        
        if(self.training):
            return [self.decoder(z1), z1, z2]
        return self.decoder(z1)

In [6]:
def mse_cosine(y_true, y_pred):
    return tf.reduce_mean(((y_pred[0] - y_true)**2)/2, axis=(1,2,3)) + cosine(y_pred[1], y_pred[2])

def cosine(x1, x2):
    return 1.0 - tf.tensordot(x1,x2, [[1,2,3],[3,2,1]])/(tf.norm(x1, ord=2)*tf.norm(x2, ord=2))

In [7]:
with tf.device('/CPU:0'):
    model = EEG_to_fMRI(latent_dimension, eeg_train.shape[1:], (10,20,2), (1,1,1), 4,
                        maxpool=True, weight_decay=0.000, skip_connections=False,
                        n_stacks=2, local=True, seed=None, 
                        fmri_args = (latent_dimension, fmri_train.shape[1:], 
                        kernel_size, stride_size, n_channels, 
                        max_pool, batch_norm, weight_decay, skip_connections,
                        n_stacks, True, False, outfilter, dropout))
    
    model.build(eeg_train.shape, fmri_train.shape)
    
    optimizer = tf.keras.optimizers.Adam(learning_rate)
    loss_fn = mse_cosine

    train_set = tf.data.Dataset.from_tensor_slices((eeg_train, fmri_train)).batch(batch_size)

In [None]:
model.training=True
loss_history = train.train(train_set, model, optimizer, 
                            loss_fn, epochs=10, 
                            u_architecture=True,
                            val_set=None, verbose=True, verbose_batch=True)[0]

Batch ... with loss: 1.5032337
Batch ... with loss: 539.1995
Batch ... with loss: 11.167909
Batch ... with loss: 1.4469583
Batch ... with loss: 1.4374226
Batch ... with loss: 1.4373065
Batch ... with loss: 1.2625042
Epoch 1 with loss: 79.63640827792031
Batch ... with loss: 1.4384742
Batch ... with loss: 1.4364989
Batch ... with loss: 1.4347734
Batch ... with loss: 1.4323633
Batch ... with loss: 1.4300501
Batch ... with loss: 1.4288636
Batch ... with loss: 1.2443109
Epoch 2 with loss: 1.40647634438106
Batch ... with loss: 1.4272562
Batch ... with loss: 1.4259734
Batch ... with loss: 1.4254628
Batch ... with loss: 1.4239901
Batch ... with loss: 1.422209
Batch ... with loss: 1.4212545
Batch ... with loss: 1.236349
Epoch 3 with loss: 1.3974993058613367
Batch ... with loss: 1.4196662
Batch ... with loss: 1.4183391
