# Bayesian Optimization for fMRI autoencoder

## Loading data

In [1]:
import sys

sys.path.append("..")

import tensorflow as tf

import tensorflow_probability as tfp

import numpy as np

import GPyOpt

import argparse

from utils import tf_config, preprocess_data, search_algorithms, train, bnn_utils, outlier_utils, eeg_utils, viz_utils

from models import fmri_ae

import matplotlib.pyplot as plt

import gc

import os

from sklearn.model_selection import train_test_split, KFold

import time

# --- Experiment configuration -------------------------------------------------
# EEG feature representation: True -> raw (time) features,
# False -> non-raw (frequency) features.
raw_eeg=False
resampling=False
dataset="01"

# Number of volumes for each known dataset id (3 volumes are subtracted from
# the nominal 300 / 170). Looking the id up in a table, instead of using two
# independent `if`s, makes an unknown id fail here with a clear message rather
# than leaving `n_volumes` undefined for later cells.
_N_VOLUMES_BY_DATASET = {"01": 300-3, "02": 170-3}
if dataset not in _N_VOLUMES_BY_DATASET:
    raise ValueError("Unknown dataset id %r; expected one of %s"
                     % (dataset, sorted(_N_VOLUMES_BY_DATASET)))
n_volumes=_N_VOLUMES_BY_DATASET[dataset]

# Passed to tf_config.setup_tensorflow below.
memory_limit=1500
n_individuals=10
n_individuals_train=8
n_individuals_test=2
#parametrize the interval eeg?
interval_eeg=6

# Fix the seed and configure the TensorFlow device before any data is loaded.
tf_config.set_seed(seed=42)
tf_config.setup_tensorflow(device="GPU", memory_limit=memory_limit)

# Keyword options forwarded unchanged to the dataset loader.
dataset_options = dict(
    n_individuals=n_individuals,
    interval_eeg=interval_eeg,
    ind_volume_fit=False,
    standardize_fmri=True,
    iqr=False,
    verbose=True,
)

# Loading is pinned to the CPU device.
with tf.device('/CPU:0'):
    train_data, test_data = preprocess_data.dataset(dataset, **dataset_options)
    # Each split unpacks into two elements; only the second one is kept and
    # used as the fMRI data by the rest of this notebook.
    _, fmri_train = train_data
    _, fmri_test = test_data

I: Starting to Load Data
I: Finished Loading Data
I: Pairs Created


## Hyperparameters to optimize

In [2]:
# Hyperparameter vector. The per-position labels follow the order in which the
# vector is unpacked in the next cell.
theta = (
    0.002980911194116198,   # [0]  learning_rate
    0.0004396489214334123,  # [1]  weight_decay
    (9, 9, 4),              # [2]  kernel_size
    (1, 1, 1),              # [3]  stride_size
    1,                      # [4]  batch_size
    (7, 7, 7),              # [5]  latent_dimension
    4,                      # [6]  n_channels
    True,                   # [7]  max_pool
    True,                   # [8]  batch_norm
    True,                   # [9]  skip_connections
    True,                   # [10] dropout
    5,                      # [11] n_stacks
    1,                      # [12] outfilter
)

## Unpack hyperparameters

In [3]:
# Unroll the hyperparameter vector into named variables, coercing each group
# to the type the model/training code is given.

# positions 0-1: floats
learning_rate, weight_decay = (float(value) for value in theta[0:2])
# positions 2-3: used as-is (no cast)
kernel_size, stride_size = theta[2:4]
batch_size = int(theta[4])
# position 5: used as-is (no cast)
latent_dimension = theta[5]
n_channels = int(theta[6])
# positions 7-10: boolean switches
max_pool, batch_norm, skip_connections, dropout = (bool(flag) for flag in theta[7:11])
# positions 11-12: integers
n_stacks, outfilter = (int(value) for value in theta[11:13])

# Not part of theta: fixed by hand.
local = True

## Build model

In [4]:
# Neural-architecture specification, laid out as
# (kernel sizes, stride sizes, use max-pool, max-pool kernel, max-pool stride);
# the two lists hold one entry per layer (six layers here).
na_specification = (
    [(9, 9, 4)] * 6,
    [(1, 1, 1)] * 6,
    True,
    (2, 2, 1),
    (1, 1, 1),
)

In [5]:
def block18(x, operation, kernel_size, stride_size, n_channels,
            maxpool=True, batch_norm=True, weight_decay=0.000,  padding="valid",
            maxpool_k=None, maxpool_s=None,
            seed=None):
    """Apply one `operation` layer, optional max-pool and batch-norm, then ReLU.

    Args:
        x: input tensor the block is applied to.
        operation: layer class/factory; called with `filters`, `kernel_size`,
            `strides`, `kernel_regularizer`, `bias_regularizer`,
            `kernel_initializer` and `padding` keyword arguments.
        kernel_size: forwarded to `operation` as `kernel_size`.
        stride_size: forwarded to `operation` as `strides`.
        n_channels: forwarded to `operation` as `filters`.
        maxpool: when truthy, a `MaxPool3D` layer follows the operation.
        batch_norm: when truthy, a `BatchNormalization` layer follows.
        weight_decay: L2 factor used for both kernel and bias regularizers.
        padding: forwarded to `operation` as `padding`.
        maxpool_k: `pool_size` of the max-pool layer (only read if `maxpool`).
        maxpool_s: `strides` of the max-pool layer (only read if `maxpool`).
        seed: seed of the Glorot-uniform kernel initializer.

    Returns:
        The output of the final ReLU layer.
    """
    keras_layers = tf.keras.layers
    l2 = tf.keras.regularizers.L2

    main_layer = operation(
        filters=n_channels,
        kernel_size=kernel_size,
        strides=stride_size,
        kernel_regularizer=l2(weight_decay),
        bias_regularizer=l2(weight_decay),
        kernel_initializer=tf.keras.initializers.GlorotUniform(seed=seed),
        padding=padding,
    )
    out = main_layer(x)

    # Optional stages, applied in the order: pooling -> normalisation.
    if maxpool:
        out = keras_layers.MaxPool3D(pool_size=maxpool_k, strides=maxpool_s)(out)
    if batch_norm:
        out = keras_layers.BatchNormalization()(out)

    return keras_layers.ReLU()(out)


def skip_block18(x, skip_x, operation, kernel_size, stride_size, n_channels,
                maxpool=True, batch_norm=True, weight_decay=0.000, padding="valid",
                maxpool_k=None, maxpool_s=None,
                seed=None):
    """Project `skip_x` through its own layer stack and add it to `x`.

    The skip input goes through one `operation` layer, then optional max-pool
    and batch-norm; the result is summed with `x` and passed through ReLU.

    Args:
        x: main-path tensor that receives the skip contribution.
        skip_x: tensor fed to the skip path.
        operation: layer class/factory, called with the same keyword arguments
            as in `block18` (`filters`, `kernel_size`, `strides`,
            `kernel_regularizer`, `bias_regularizer`, `kernel_initializer`,
            `padding`).
        kernel_size, stride_size, n_channels: forwarded to `operation` as
            `kernel_size`, `strides` and `filters`.
        maxpool: when truthy, max-pool the projected skip tensor.
        batch_norm: when truthy, batch-normalise the projected skip tensor.
        weight_decay: L2 factor for kernel and bias regularizers.
        padding: forwarded to `operation`.
        maxpool_k, maxpool_s: `pool_size` / `strides` of the max-pool layer.
        seed: seed of the Glorot-uniform kernel initializer.

    Returns:
        ReLU(x + projected skip_x).
    """
    keras_layers = tf.keras.layers
    l2 = tf.keras.regularizers.L2

    projection = operation(
        filters=n_channels,
        kernel_size=kernel_size,
        strides=stride_size,
        kernel_regularizer=l2(weight_decay),
        bias_regularizer=l2(weight_decay),
        kernel_initializer=tf.keras.initializers.GlorotUniform(seed=seed),
        padding=padding,
    )
    shortcut = projection(skip_x)

    if maxpool:
        shortcut = keras_layers.MaxPool3D(pool_size=maxpool_k, strides=maxpool_s)(shortcut)
    if batch_norm:
        shortcut = keras_layers.BatchNormalization()(shortcut)

    # Element-wise sum of the main path and the projected skip path.
    merged = keras_layers.Add()([x, shortcut])

    return keras_layers.ReLU()(merged)

def stack18(x, previous_block_x, operation, kernel_size, stride_size, n_channels,
                        maxpool=True, batch_norm=True, 
                        weight_decay=0.000, skip_connections=False,
                        maxpool_k=None, maxpool_s=None,
                        seed=None):
    """Build one stack: a "valid"-padded block, a "same"-padded block, optional skip.

    Args:
        x: input tensor of the stack.
        previous_block_x: tensor fed to the skip path (only read when
            `skip_connections` is truthy).
        operation: layer class/factory handed to `block18` / `skip_block18`.
        kernel_size, stride_size: geometry of the first block and of the skip
            projection; the second block always uses kernel 3 and stride 1.
        n_channels: number of filters for every layer in the stack.
        maxpool: whether the first block and the skip projection max-pool.
        batch_norm: whether all blocks batch-normalise.
        weight_decay: L2 factor handed to every block.
        skip_connections: when truthy, add the projected `previous_block_x`.
        maxpool_k, maxpool_s: max-pool kernel / stride for the first block and
            the skip projection.
        seed: initializer seed handed to every block.

    Returns:
        The output tensor of the stack.
    """
    # Settings shared by every block of the stack.
    shared = dict(batch_norm=batch_norm, weight_decay=weight_decay, seed=seed)
    # Settings of the "valid"-padded path (first block and skip projection),
    # which use the caller-supplied geometry and pooling.
    valid_path = dict(maxpool=maxpool, maxpool_k=maxpool_k, maxpool_s=maxpool_s,
                      padding="valid")

    # first block: caller-supplied kernel/stride, "valid" padding
    out = block18(x, operation, kernel_size, stride_size, n_channels,
                  **valid_path, **shared)

    # second block: fixed kernel 3 / stride 1, "same" padding, no pooling
    out = block18(out, operation, 3, 1, n_channels,
                  maxpool=False, padding="same", **shared)

    if not skip_connections:
        return out

    # skip connection: project the stack input with the same geometry as the
    # first block and add it to the main path
    return skip_block18(out, previous_block_x, operation,
                        kernel_size, stride_size, n_channels,
                        **valid_path, **shared)


class fMRI_AE(tf.keras.Model):
    """
        Autoencoder made of a convolutional encoder (a sequence of `stack18`
        stacks followed by a dense projection to the latent shape) and a dense
        decoder that maps the latent representation back to the input shape.

        NA_specification - tuple - (list1, list2, bool, tuple1, tuple2)
                                    * list1 - kernel sizes
                                    * list2 - stride sizes
                                    * bool - maxpool
                                    * tuple1 - kernel size of maxpool
                                    * tuple2 - stride size of maxpool
                                    Example:
                                    na = ([(2,2,2), (2,2,2)], [(1,1,1), (1,1,1)], True, (2,2,2), (1,1,1))
                                    na is a neural architecture with 2 layers, kernel of size 2 for all 3 dimensions
                                    stride of size 1 for all dimensions, between each layer a max pooling operation 
                                    is applied with kernel size 2 for all dimensions and stride size 1 for all dimensions

        Attributes set while building:
            latent_shape - shape of the latent representation (without batch axis)
            in_shape     - shape of one input sample (without batch axis)
            encoder      - tf.keras.Model mapping inputs to the latent representation
            decoder      - tf.keras.Model mapping the latent representation to `in_shape`
    """
    def __init__(self, latent_shape, input_shape, na_spec, n_channels,
                        batch_norm=True, weight_decay=0.000, skip_connections=False,
                        local=True, local_attention=False, outfilter=0, dropout=False, seed=None):
        """Build the encoder and then the decoder.

        Args:
            latent_shape: shape of the latent representation (no batch axis).
            input_shape: shape of one input sample (no batch axis).
            na_spec: neural-architecture specification (see class docstring).
            n_channels: number of filters of every encoder layer.
            batch_norm: whether encoder blocks batch-normalise.
            weight_decay: L2 factor for the encoder layers.
            skip_connections: whether each encoder stack adds a skip path.
            local: accepted for interface compatibility; not used when
                building the layers in this class.
            local_attention: when truthy, a multi-head self-attention layer is
                appended to the encoder.
            outfilter: 0 - no output filter; 1 - 1x1x1 `Conv3D` with one
                filter; 2 - `LocallyConnected3D` with one filter.
            dropout: when truthy, dropout (rate 0.5) follows the latent dense
                layer.
            seed: seed handed to every Glorot-uniform initializer.
        """
        super(fMRI_AE, self).__init__()

        self.build_encoder(latent_shape, input_shape, na_spec, n_channels,
                        batch_norm=batch_norm, weight_decay=weight_decay, skip_connections=skip_connections,
                        local=local, local_attention=local_attention, dropout=dropout, seed=seed)

        self.build_decoder(outfilter=outfilter, seed=seed)

    @staticmethod
    def _n_elements(shape):
        """Return the product of all dimensions of `shape`."""
        total = 1
        for dim in shape:
            total *= int(dim)
        return total

    def build_encoder(self, latent_shape, input_shape, na_spec, n_channels,
                        batch_norm=True, weight_decay=0.000, skip_connections=False,
                        local=True, local_attention=False, dropout=False, seed=None):
        """Create `self.encoder` and record `latent_shape` / `in_shape`.

        One `stack18` is created per entry of `na_spec[0]`; the result is
        flattened, projected by a dense layer to the latent size, optionally
        passed through dropout, reshaped to `latent_shape` and optionally
        passed through self-attention. Arguments are as in `__init__`.
        """
        self.latent_shape = latent_shape
        self.in_shape = input_shape

        encoder_input = tf.keras.layers.Input(shape=input_shape)

        x = encoder_input
        previous_block_x = encoder_input

        for i in range(len(na_spec[0])):
            x = stack18(x, previous_block_x, tf.keras.layers.Conv3D, 
                        na_spec[0][i], na_spec[1][i], n_channels,
                        maxpool=na_spec[2], batch_norm=batch_norm, weight_decay=weight_decay, 
                        maxpool_k=na_spec[3], maxpool_s=na_spec[4],
                        skip_connections=skip_connections, seed=seed)
            # the output of this stack is the skip input of the next one
            previous_block_x=x
            print(x.shape)

        # NOTE: `local` used to select between Conv3D and LocallyConnected3D
        # here, but the selected layer was never applied. The selection was
        # removed so that `local=False` no longer depends on a
        # `LocallyConnected3D` name being in scope for no effect.

        x = tf.keras.layers.Flatten()(x)
        # The dense layer must emit exactly as many units as the latent shape
        # holds, whatever its rank, for the Reshape below to be valid.
        x = tf.keras.layers.Dense(self._n_elements(self.latent_shape), 
                                    kernel_initializer=tf.keras.initializers.GlorotUniform(seed=seed))(x)
        if(dropout):
            x = tf.keras.layers.Dropout(0.5)(x)
        x = tf.keras.layers.Reshape(self.latent_shape)(x)

        if(local_attention):
            x = tf.keras.layers.MultiHeadAttention(num_heads=n_channels, key_dim=x.shape[1]*x.shape[2]*x.shape[3], attention_axes=(1, 2, 3),
                                                kernel_initializer=tf.keras.initializers.GlorotUniform(seed=seed))(x,x)

        self.encoder = tf.keras.Model(encoder_input, x)

    def build_decoder(self, outfilter=0, seed=None):
        """Create `self.decoder`, mapping `latent_shape` back to `in_shape`.

        Must be called after `build_encoder`, which sets `self.latent_shape`
        and `self.in_shape`.

        Args:
            outfilter: 0 - no output filter; 1 - 1x1x1 `Conv3D` with one
                filter; 2 - `LocallyConnected3D` with one filter.
            seed: seed handed to every Glorot-uniform initializer.
        """
        decoder_input = tf.keras.layers.Input(shape=self.latent_shape)

        x = tf.keras.layers.Flatten()(decoder_input)

        #upsampling
        # As many units as the input shape holds over all of its dimensions,
        # so the Reshape below is valid for any rank / channel count.
        x = tf.keras.layers.Dense(self._n_elements(self.in_shape),
                                    kernel_initializer=tf.keras.initializers.GlorotUniform(seed=seed))(x)
        x = tf.keras.layers.Reshape(self.in_shape)(x)

        #filter
        if(outfilter == 1):
            x = tf.keras.layers.Conv3D(filters=1, kernel_size=1, strides=1,
                                    kernel_initializer=tf.keras.initializers.GlorotUniform(seed=seed))(x)
        elif(outfilter == 2):
            # NOTE(review): `LocallyConnected3D` must be provided by the
            # surrounding module/notebook; confirm it is imported before
            # using outfilter=2.
            x = LocallyConnected3D(filters=1, kernel_size=1, strides=1, implementation=3,
                                    kernel_initializer=tf.keras.initializers.GlorotUniform(seed=seed))(x)

        self.decoder = tf.keras.Model(decoder_input, x)    

    def encode(self, X):
        """Return the latent representation of `X`."""
        return self.encoder(X)

    def decode(self, Z):
        """Return the reconstruction for latent representation `Z`."""
        return self.decoder(Z)

    def call(self, X):
        """Encode and then decode `X`."""
        if(not self.encoder.built):
            self.encoder.build(X.shape)

        return self.decode(self.encode(X))


In [7]:
import importlib
# Pick up any edits made to models/fmri_ae.py without restarting the kernel.
importlib.reload(fmri_ae)

# NOTE: the model instantiated below is `fMRI_AE` from the `models.fmri_ae`
# module, not the `fMRI_AE` class defined earlier in this notebook.
with tf.device('/CPU:0'):
    #build model
    # `max_pool` and `local` come from the unpacked hyperparameters instead of
    # being hard-coded, so changing theta / the config actually takes effect.
    model = fmri_ae.fMRI_AE(latent_dimension, fmri_train.shape[1:], kernel_size, stride_size, n_channels,
                            maxpool=max_pool,
                        batch_norm=batch_norm, weight_decay=weight_decay, skip_connections=skip_connections,
                        local=local, local_attention=False, outfilter=outfilter, dropout=dropout)
    
    model.build(input_shape=fmri_train.shape)

    #train model
    optimizer = tf.keras.optimizers.Adam(learning_rate)
    loss_fn = tf.keras.losses.MSE#replace

    # Autoencoder: the input is also the target. The training set is batched
    # with the tuned batch size, the test set one sample at a time.
    train_set = tf.data.Dataset.from_tensor_slices((fmri_train, fmri_train)).batch(batch_size)
    test_set = tf.data.Dataset.from_tensor_slices((fmri_test, fmri_test)).batch(1)

(None, 46, 46, 24, 4)
(None, 28, 28, 18, 4)
(None, 21, 21, 12, 4)


## Train model

In [7]:
# Train for 10 epochs without a validation set; only the first element of the
# value returned by train.train is kept, as the loss history.
training_output = train.train(
    train_set, model, optimizer, loss_fn,
    epochs=10,
    val_set=None,
    verbose=True,
    verbose_batch=True,
)
loss_history = training_output[0]

UnknownError: Failed to get convolution algorithm. This is probably because cuDNN failed to initialize, so try looking to see if a warning log message was printed above. [Op:Conv3D]

## Evaluate model

In [None]:
# Evaluate the model on the test set (batches of one sample) with the same
# loss function used for training; the returned value is the cell's output.
train.evaluate(test_set, model, loss_fn)

## Train loss convergence

In [None]:
# One x position per recorded loss value, starting at epoch 1. Deriving the
# axis from the history itself keeps the plot valid when the number of
# training epochs changes (it used to be hard-coded to 10).
epochs_axis = np.arange(1, len(loss_history) + 1)

plt.figure()

plt.plot(epochs_axis, loss_history)
plt.xlabel("Epochs")
plt.ylabel("MSE")

plt.title("Train loss convergence")

# Log scale on the y axis.
plt.yscale("log")
plt.show()

## Visualize predicted slices

In [None]:
# NOTE(review): absolute, user-specific output directory; it is only referenced
# by the commented-out savefig calls below.
save_path = "/home/ist_davidcalhas/eeg_to_fmri/plots/fmri_ae/"

# How many test instances to visualise. The loop previously stopped with a
# `break` placed before its counter increment, so only the first instance was
# ever plotted; `take` expresses that limit explicitly and `enumerate` keeps
# the instance counter correct.
n_instances_to_plot = 1

for instance, (instance_x, _) in enumerate(test_set.take(n_instances_to_plot), start=1):
    # ground truth: first (and only) sample of the batch
    fig = viz_utils.plot_3D_representation_projected_slices(instance_x.numpy()[0])
    #plt.savefig(save_path + str(instance) + "_ground_truth.pdf", format="pdf")

    # reconstruction of the same sample by the autoencoder
    fig = viz_utils.plot_3D_representation_projected_slices(model(instance_x).numpy()[0])
    #plt.savefig(save_path + str(instance) + "_predicted.pdf", format="pdf")