In [None]:
import pandas as pd
import numpy as np
import pickle
import matplotlib.pyplot as plt

# --- Environment Setup ---

# Option 1: Using Anaconda
# conda create --name tf python=3.8 anaconda 
# conda activate tf
# pip install -r ../requirements.txt
# Note: This installs an older version of TensorFlow since Anaconda 
# no longer maintains recent GPU packages. It still works reliably.

# Option 2: Using Python venv
# python3 -m venv .venv
# source .venv/bin/activate
# pip install -r ../requirements.txt

from tensorflow.keras.callbacks import ModelCheckpoint, LearningRateScheduler
from tensorflow.keras import utils as np_utils
from fMRINet import fmriNet8, fmriNet16, fmriNet32

import sklearn
import tensorflow as tf

# TensorFlow optimizers

# AdamW lives in different places depending on TF version.
try:
    # TF ≥ 2.13
    from tensorflow.keras.optimizers import AdamW
except ImportError:
    try:
        # TF 2.11–2.12
        from tensorflow.keras.optimizers.experimental import AdamW
    except ImportError:
        # TF 2.10.x (requires tensorflow-addons)
        from tensorflow_addons.optimizers import AdamW

# For compatibility with libraries such as iNNvestigate, 
# you may need to disable eager execution:
# tf.compat.v1.disable_eager_execution()

In [None]:
# Load the DataFrame from the pickle file.
# NOTE(review): the earlier comment described this file as the toy dataframe while
# also pointing to dataframe.pkl for the actual data -- confirm which file is
# intended here; get in touch with the author for the actual data.
# NOTE: unpickling can execute arbitrary code; only load files you trust.

# Task label encoding:
# {'PVT': 0, 'VWM': 1, 'DOT': 2, 'MOD': 3, 'DYN': 4, 'rest': 5}
df = pd.read_pickle('dataframe.pkl')

# Show the first rows as the cell output.
df.head()

In [None]:
# Report the TensorFlow build and the GPUs it can see.
tf_version = tf.__version__
print("TF version:", tf_version)

cuda_build = tf.test.is_built_with_cuda()
print("Built with CUDA?", cuda_build)

visible_gpus = tf.config.list_physical_devices("GPU")
print("Visible GPUs:", visible_gpus)

In [None]:
# The subject order was originally produced with:
#     subjs = df["subject"].unique()
#     np.random.shuffle(subjs)  # in-place shuffle
# It is reloaded from disk so that every run during model development uses the
# same train/validation split.
# NOTE: pickle.load can execute arbitrary code -- only load a subjs.pickle you trust.
with open('subjs.pickle', 'rb') as f:
    subjs = pickle.load(f)

# The first N_TRAIN_SUBJECTS entries of the shuffled list train; the rest validate.
N_TRAIN_SUBJECTS = 45

# Plain slices (no trailing comma) work whether subjs is a list or a NumPy array.
train_df = df[df['subject'].isin(subjs[:N_TRAIN_SUBJECTS])]
valid_df = df[df['subject'].isin(subjs[N_TRAIN_SUBJECTS:])]


def _to_network_input(frame):
    """Stack the per-row 'Time_Series_Data' arrays into one 4-D network input.

    Each row is expected to hold a 2-D array of shape (d0, d1). The result has
    shape (n_rows, d1, d0, 1): samples first, the two data axes swapped, and a
    trailing singleton channel axis.
    NOTE(review): presumably d0 is time and d1 is brain region, giving
    (batch, region, time, 1) -- confirm against the model's input_shape.
    """
    stacked = np.dstack(frame['Time_Series_Data'])    # (d0, d1, n_rows)
    stacked = np.expand_dims(stacked, axis=0)         # (1, d0, d1, n_rows)
    return np.transpose(stacked, axes=[3, 2, 1, 0])   # (n_rows, d1, d0, 1)


# Convert labels and data to NumPy arrays in the layout the network expects.
train_label = np.array(train_df['Task'])
train_data  = _to_network_input(train_df)

valid_label = np.array(valid_df['Task'])
valid_data  = _to_network_input(valid_df)

In [None]:
# One-hot encode the integer task labels.
# num_classes is pinned to the 6 task classes so both arrays always get 6 columns;
# without it the width is inferred from each array's largest label and would
# shrink if a split happened to contain no samples of the highest class.
NUM_CLASSES = 6

# Only encode 1-D integer labels, so re-running this cell does not re-encode
# arrays that are already one-hot.
if train_label.ndim == 1:
    train_label = np_utils.to_categorical(train_label, num_classes=NUM_CLASSES)
if valid_label.ndim == 1:
    valid_label = np_utils.to_categorical(valid_label, num_classes=NUM_CLASSES)


In [None]:
# Balanced class weights for the training data, used at training time.
# Recover integer class ids from the one-hot training labels.
train_class_ids = np.argmax(train_label, axis=1)
classes         = np.unique(train_class_ids)
weights         = sklearn.utils.class_weight.compute_class_weight(
    class_weight='balanced', classes=classes, y=train_class_ids)

# Build the mapping from the classes actually present instead of hard-coding
# indices 0..5, which raised IndexError when a class was absent from training.
# Keys/values are converted to plain int/float for a clean display and for Keras.
class_weights = {int(cls): float(w) for cls, w in zip(classes, weights)}
class_weights

In [None]:
# Build the network; fmriNet16() or fmriNet32() can be used here instead.
model = fmriNet8(
    num_classes=6,
    input_shape=(214, 277, 1),
    temporal_kernel_sec=60,
    fs=1.0,
)

# Print the layer-by-layer summary.
model.summary()

In [None]:
# Compile with the AdamW optimizer and report accuracy during training.
optimizer = AdamW(weight_decay=0.0005)
model.compile(
    optimizer=optimizer,
    loss='categorical_crossentropy',
    metrics=['accuracy'],
)

# Save to disk only when the callback's monitored metric improves.
checkpointer = ModelCheckpoint(
    filepath='/tmp/checkpoint.h5',
    save_best_only=True,
    verbose=1,
)

def lr_schedule(epoch):
    """Return the learning rate for ``epoch``.

    Simple step schedule: start at 0.001 and halve the rate once for every
    200 completed epochs, which seems to do ok for this data.
    """
    completed_steps = np.floor(epoch / 200)
    return 0.001 * np.power(0.5, completed_steps)

# Keras callback that sets the optimizer's learning rate from lr_schedule;
# verbose=1 logs the chosen rate. It only takes effect when it is included in
# the `callbacks` list passed to model.fit.
scheduler    = LearningRateScheduler(lr_schedule, verbose=1)

In [None]:
# without eager execution this takes much longer to train..
# `scheduler` is registered alongside `checkpointer` so the step learning-rate
# schedule defined with it is actually applied; previously the callback was
# created but never passed to fit. With 200 epochs the schedule stays at its
# initial rate throughout and only starts halving on longer runs.
fittedModel = model.fit(
    train_data,
    train_label,
    batch_size=64,
    epochs=200,
    verbose=2,
    validation_data=(valid_data, valid_label),
    callbacks=[checkpointer, scheduler],
    class_weight=class_weights,  # counteract class imbalance in the training set
)

In [None]:
# Training curves: accuracy in the left panel, loss in the right panel.
history = fittedModel.history
fig, (acc_ax, loss_ax) = plt.subplots(1, 2, figsize=(12, 5))

# Left panel: train vs. validation accuracy per epoch.
acc_ax.plot(history['accuracy'], label='Train Accuracy')
acc_ax.plot(history['val_accuracy'], label='Validation Accuracy')
acc_ax.set(xlabel='Epochs', ylabel='Accuracy', title='Model Accuracy')
acc_ax.legend()
acc_ax.grid(True)

# Right panel: train vs. validation loss per epoch.
loss_ax.plot(history['loss'], label='Train Loss')
loss_ax.plot(history['val_loss'], label='Validation Loss')
loss_ax.set(xlabel='Epochs', ylabel='Loss', title='Model Loss')
loss_ax.legend()
loss_ax.grid(True)

fig.tight_layout()
plt.show()

In [None]:
# Load weights from the checkpoint file (presumably the one written by the
# ModelCheckpoint callback, which uses the same path) before predicting.
model.load_weights('/tmp/checkpoint.h5')
# Model outputs for every validation sample.
preds = model.predict(valid_data)

In [None]:
from sklearn.metrics import balanced_accuracy_score

# Collapse one-hot targets and model outputs to class indices, then score them.
true_classes = np.argmax(valid_label, axis=1)
pred_classes = np.argmax(preds, axis=1)
balanced_accuracy_score(true_classes, pred_classes)

In [None]:
# Weights of model.layers[2] with singleton axes removed. The indexing below
# expects a 2-D result with one column per temporal filter.
filters = np.squeeze(model.layers[2].get_weights())

# plt.subplots returns (figure, axes grid); unpack both instead of binding the
# whole tuple to `fig`, and draw on the axes objects directly.
fig, axes = plt.subplots(2, 4, figsize=(12, 4))

# axes.flat walks the grid row by row, matching subplot positions 1..8.
for idx, ax in enumerate(axes.flat):
    ax.plot(filters[:, idx])
    ax.set_title(f'Temporal Filter {idx + 1}')

fig.tight_layout()

In [None]:
# Weights of model.layers[4] with singleton axes removed. The indexing below
# expects a 3-D result addressed as filters[:, temporal_filter, spatial_filter].
filters = np.squeeze(model.layers[4].get_weights())

# plt.subplots returns (figure, axes grid); unpack both instead of binding the
# whole tuple to `fig`, and address each panel by (row, column) rather than
# maintaining a running subplot counter.
fig, axes = plt.subplots(8, 4, figsize=(8, 12))

# One row per temporal filter, one column per spatial filter.
for j in range(8):
    for k in range(4):
        ax = axes[j, k]
        ax.plot(filters[:, j, k])
        ax.set_title(f'T. Filter  {j+1}, S. Filter {k+1}')

fig.tight_layout()