In [None]:
import os
import re
import numpy as np
import nibabel as nib
import pandas as pd
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau
from tensorflow.keras.utils import plot_model
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder, OneHotEncoder
import gzip
from skimage.transform import resize
import pickle

# Constants
dataset_path = os.path.normpath("D:/PROJECTS/emotion_detect/archive")
label_path = os.path.join(dataset_path, "onsetime")
target_shape = (32, 32, 32)  # spatial size every fMRI frame is resized to
sequence_length = 20  # frames per input window
future_steps = 5  # label steps predicted per sample; also frames left after each window
batch_size = 5  # NOTE(review): not the training batch size (model.fit uses 32)
VALID_EMOTIONS = ["happy", "sad", "angry", "neutral", "scrambled", "blank"]

# Optimize GPU usage: enable memory growth on EVERY visible GPU. TensorFlow
# rejects configurations where the setting differs between GPUs, and it can
# only be changed before the devices are initialized (re-running this cell
# afterwards raises RuntimeError, which is reported instead of crashing).
physical_devices = tf.config.list_physical_devices('GPU')
for gpu in physical_devices:
    try:
        tf.config.experimental.set_memory_growth(gpu, True)
    except RuntimeError as err:
        print(f"Could not enable memory growth on {gpu.name}: {err}")

# Load emotion labels
def load_labels_from_tsv(labels_dir, valid_emotions=None):
    """Collect the allowed trial types listed in each run's events TSV file.

    Every ``*.tsv`` file in ``labels_dir`` whose name contains
    ``task-emotionalfaces_run-<N>_events`` is read. The values of its
    ``trial_type`` column that appear in ``valid_emotions`` are kept in
    file-row order.

    Args:
        labels_dir: Directory holding the events TSV files.
        valid_emotions: Trial types to keep. Defaults to the module-level
            ``VALID_EMOTIONS``.

    Returns:
        Dict mapping the run number string (e.g. ``"1"``) to its list of kept
        trial types. Runs with no kept trial types are omitted. The dict is
        empty (and a message is printed) if ``labels_dir`` is not a directory.

    Note:
        The key is the run number only, so files for different subjects with
        the same run number collide. Files are visited in sorted name order,
        so the last one in that order wins, deterministically.
    """
    if valid_emotions is None:
        valid_emotions = VALID_EMOTIONS

    label_mapping = {}
    if not os.path.isdir(labels_dir):
        print(f"Label directory '{labels_dir}' not found!")
        return label_mapping

    run_pattern = re.compile(r"task-emotionalfaces_run-(\d+)_events")
    # sorted(): os.listdir order is arbitrary, which made collisions nondeterministic.
    for tsv_file in sorted(os.listdir(labels_dir)):
        if not tsv_file.endswith(".tsv"):
            continue
        match = run_pattern.search(tsv_file)
        if not match:
            continue
        df = pd.read_csv(os.path.join(labels_dir, tsv_file), sep='\t', usecols=['trial_type'])
        # isin() is False for NaN, so missing trial types are dropped as well.
        trial_types = df.loc[df['trial_type'].isin(valid_emotions), 'trial_type'].tolist()
        if trial_types:
            label_mapping[match.group(1)] = trial_types
    return label_mapping

# Process fMRI sequences
def process_fmri_file_sequence(fmri_file_path, label_mapping):
    """Cut one fMRI run into fixed-length, spatially resized sliding windows.

    The run number is parsed from a file name containing
    ``wrsub-<subject>_task-emotionalfaces_run-<run>`` and used to look up the
    run's trial types in ``label_mapping``. Processing steps:

    1. Min-max scale the 4D volume over the whole run (to roughly [0, 1]).
    2. Resize every frame to ``target_shape``.
    3. Take a window of ``sequence_length`` frames at every start frame that
       still leaves ``future_steps`` frames after the window.

    Args:
        fmri_file_path: Path to a ``.nii`` or ``.nii.gz`` file.
        label_mapping: Dict from run number string to a list of trial types,
            as returned by ``load_labels_from_tsv``.

    Returns:
        ``(sequences, labels)``:

        - ``sequences`` is a float32 array of shape
          ``(num_windows, sequence_length, *target_shape)``.
        - ``labels`` holds one trial type per window. The window starting at
          frame ``s`` gets ``labels[min(s + sequence_length, len(labels) - 1)]``.

        ``(None, None)`` is returned when the name does not match, the run has
        no labels, the image is not 4D, or the run is too short for a window.

    NOTE(review): the label list comes from events-TSV rows (one per trial),
    yet it is indexed with a frame index here. That only lines up if there is
    one event per acquired volume; confirm against onsets and TR.
    """
    match = re.search(r"wrsub-(\d+)_task-emotionalfaces_run-(\d+)", os.path.basename(fmri_file_path))
    if not match:
        return None, None

    run_id = match.group(2)
    labels = label_mapping.get(run_id)
    if not labels:
        return None, None

    # nibabel reads both .nii and .nii.gz transparently.
    fmri_data = nib.load(fmri_file_path).get_fdata(dtype=np.float32)
    if fmri_data.ndim != 4:
        return None, None

    fmri_data = (fmri_data - fmri_data.min()) / (fmri_data.max() - fmri_data.min() + 1e-10)
    total_frames = fmri_data.shape[-1]

    # One window per start frame. The previous size formula went negative for
    # short runs and crashed np.zeros with "negative dimensions".
    num_windows = total_frames - sequence_length - future_steps
    if num_windows <= 0:
        return None, None

    # Resize all frames at once. The time axis keeps its length (no
    # interpolation or smoothing along it), so each frame is resampled exactly
    # as when every window was resized separately, but only once instead of
    # once per window that contains it.
    resized = resize(fmri_data, (*target_shape, total_frames), anti_aliasing=True, preserve_range=True)
    frames = np.transpose(resized, (3, 0, 1, 2))  # (total_frames, *target_shape)

    sequences = np.empty((num_windows, sequence_length, *target_shape), dtype=np.float32)
    sequence_labels = []
    last_label = len(labels) - 1
    for start in range(num_windows):
        sequences[start] = frames[start:start + sequence_length]
        sequence_labels.append(labels[min(start + sequence_length, last_label)])
    return sequences, sequence_labels

# Build Transformer model
class _LearnedPositionalEmbedding(layers.Layer):
    """Add a trainable ``(sequence_length, d_model)`` position table to the input.

    Owning the table as a layer weight keeps it inside the model graph, so it
    is trained and saved with the other weights. Expects inputs whose last two
    axes are ``(sequence_length, d_model)``; the table broadcasts over batch.
    """

    def __init__(self, sequence_length, d_model, **kwargs):
        super().__init__(**kwargs)
        self.sequence_length = sequence_length
        self.d_model = d_model

    def build(self, input_shape):
        # "uniform" is the default initializer of layers.Embedding, which the
        # original code used for the same table.
        self.position_table = self.add_weight(
            name="position_table",
            shape=(self.sequence_length, self.d_model),
            initializer="uniform",
            trainable=True,
        )
        super().build(input_shape)

    def call(self, inputs):
        return inputs + self.position_table

    def get_config(self):
        config = super().get_config()
        config.update({"sequence_length": self.sequence_length, "d_model": self.d_model})
        return config


def build_transformer_model(sequence_length, target_shape, num_classes, future_steps, d_model=128, num_heads=4, ff_dim=256, dropout_rate=0.1):
    """Build and compile a per-frame 3D-CNN + Transformer encoder model.

    Input shape is ``(batch, sequence_length, *target_shape, 1)``. The model
    works in four stages:

    1. Each frame goes through a stride-2 ``Conv3D`` (16 filters) and global
       average pooling, then is projected to ``d_model``.
    2. A learned positional embedding is added.
    3. Two post-norm Transformer encoder blocks follow.
    4. The flattened sequence feeds a linear ``Dense`` layer, reshaped to
       ``(batch, future_steps, num_classes)``.

    Args:
        sequence_length: Frames per input window.
        target_shape: Spatial shape of each frame.
        num_classes: Number of label classes per predicted step.
        future_steps: Number of predicted steps.
        d_model: Width of the per-frame embedding.
        num_heads: Attention heads (each uses ``key_dim=d_model``).
        ff_dim: Hidden width of the feed-forward sublayer.
        dropout_rate: Dropout applied inside the feed-forward sublayer.

    Returns:
        A ``Model`` compiled with Adam (lr 1e-4), MSE loss and an MAE metric.

    NOTE(review): outputs are linear and trained with MSE against one-hot
    targets. Softmax + categorical cross-entropy is the usual choice for class
    prediction; it is left unchanged to preserve the training setup.
    NOTE(review): reloading a saved model may need
    ``custom_objects={"_LearnedPositionalEmbedding": _LearnedPositionalEmbedding}``.
    """
    inputs = layers.Input(shape=(sequence_length, *target_shape, 1))

    # Per-frame encoder: one d_model vector per time step.
    x = layers.TimeDistributed(layers.Conv3D(16, (3, 3, 3), activation='relu', padding='same', strides=2))(inputs)
    x = layers.TimeDistributed(layers.GlobalAveragePooling3D())(x)
    x = layers.Dense(d_model)(x)

    # The position table must be a weight of a layer applied to x. Calling
    # layers.Embedding on a concrete tf.range evaluates it once at build time
    # and embeds the random result as a constant that is never trained.
    x = _LearnedPositionalEmbedding(sequence_length, d_model, name="positional_embedding")(x)

    for _ in range(2):  # Two Transformer blocks (post-norm residuals)
        attn_output = layers.MultiHeadAttention(num_heads=num_heads, key_dim=d_model)(x, x)
        x = layers.LayerNormalization(epsilon=1e-6)(x + attn_output)
        ffn_output = layers.Dense(ff_dim, activation='relu')(x)
        ffn_output = layers.Dropout(dropout_rate)(ffn_output)
        ffn_output = layers.Dense(d_model)(ffn_output)
        x = layers.LayerNormalization(epsilon=1e-6)(x + ffn_output)

    x = layers.Flatten()(x)
    x = layers.Dense(future_steps * num_classes, activation='linear')(x)
    outputs = layers.Reshape((future_steps, num_classes))(x)

    model = models.Model(inputs, outputs)
    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4), loss='mse', metrics=['mae'])
    return model

# Main Execution
label_mapping = load_labels_from_tsv(label_path)

# Windows from every usable fMRI file, their labels (flattened in file order),
# and how many windows each file contributed, so that per-file boundaries can
# be respected when the future-label targets are built below.
fmri_sequences, label_sequences, file_lengths = [], [], []
for root, _, files in os.walk(dataset_path):
    for file in files:
        if file.endswith(('.nii', '.nii.gz')):
            file_path = os.path.join(root, file)
            seqs, file_labels = process_fmri_file_sequence(file_path, label_mapping)
            if seqs is not None and file_labels:
                fmri_sequences.append(seqs)
                label_sequences.extend(file_labels)
                file_lengths.append(len(file_labels))

if not fmri_sequences:
    raise ValueError("No fMRI sequences processed!")

fmri_sequences = np.concatenate(fmri_sequences, axis=0)

label_encoder = LabelEncoder()
encoded_labels = label_encoder.fit_transform(label_sequences)
one_hot_encoder = OneHotEncoder(sparse_output=False)
labels_one_hot = one_hot_encoder.fit_transform(encoded_labels.reshape(-1, 1))

# Prepare future label sequences: y_future[i, k] is the one-hot label of
# window i + k from the SAME file. Indices are clamped to that file's last
# window, so the targets never run into the next file and never end in
# all-zero rows, which argmax would decode as class 0.
offsets = np.arange(future_steps)
y_future = np.zeros((len(labels_one_hot), future_steps, labels_one_hot.shape[1]))
file_start = 0
for n_windows in file_lengths:
    file_end = file_start + n_windows
    target_idx = np.minimum(np.arange(file_start, file_end)[:, None] + offsets, file_end - 1)
    y_future[file_start:file_end] = labels_one_hot[target_idx]
    file_start = file_end

# NOTE(review): consecutive windows start one frame apart and share most of
# their frames. A random split therefore puts near-duplicates in both train
# and test, which likely inflates test scores; consider splitting by file or
# subject instead.
X_train, X_test, y_train, y_test = train_test_split(
    fmri_sequences, y_future, test_size=0.2, random_state=42
)

# Trailing channel axis expected by the model's Input layer.
X_train = np.expand_dims(X_train, axis=-1)
X_test = np.expand_dims(X_test, axis=-1)

# Build and train model
model = build_transformer_model(sequence_length, target_shape, labels_one_hot.shape[1], future_steps)

# Optional: Visualize the model (needs pydot + graphviz; skipped if missing)
try:
    plot_model(model, show_shapes=True, to_file="transformer_model.png")
except ImportError as err:
    print(f"Skipping model plot: {err}")

callbacks = [
    EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True),
    ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=2, min_lr=1e-6)
]

model.fit(X_train, y_train, epochs=20, batch_size=32, validation_split=0.2, callbacks=callbacks, verbose=1)

# Evaluation
test_loss, test_mae = model.evaluate(X_test, y_test, verbose=0)
print(f"Test Loss: {test_loss:.4f}, Test MAE: {test_mae:.4f}")

# Prediction samples: decode per-step argmax back to emotion names.
predictions = model.predict(X_test, verbose=0)
for i in range(min(5, len(X_test))):
    predicted_labels = label_encoder.inverse_transform(np.argmax(predictions[i], axis=1))
    actual_labels = label_encoder.inverse_transform(np.argmax(y_test[i], axis=1))
    print(f"Sample {i + 1} - Actual: {actual_labels}, Predicted: {predicted_labels}")

# Save the model and encoders
model.save("fmri_emotion_model_transformer.keras")
with open("label_encoder.pkl", "wb") as f:
    pickle.dump(label_encoder, f)

print("Transformer model and label encoder saved successfully!")


Epoch 1/20
[1m80/80[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m205s[0m 2s/step - loss: 0.1952 - mae: 0.3539 - val_loss: 0.1803 - val_mae: 0.3570 - learning_rate: 1.0000e-04
Epoch 2/20
[1m80/80[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m163s[0m 2s/step - loss: 0.1813 - mae: 0.3585 - val_loss: 0.1799 - val_mae: 0.3599 - learning_rate: 1.0000e-04
Epoch 3/20
[1m80/80[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m325s[0m 4s/step - loss: 0.1805 - mae: 0.3599 - val_loss: 0.1795 - val_mae: 0.3569 - learning_rate: 1.0000e-04
Epoch 4/20
[1m80/80[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m222s[0m 3s/step - loss: 0.1791 - mae: 0.3587 - val_loss: 0.1782 - val_mae: 0.3578 - learning_rate: 1.0000e-04
Epoch 5/20
[1m80/80[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m165s[0m 2s/step - loss: 0.1791 - mae: 0.3584 - val_loss: 0.1771 - val_mae: 0.3540 - learning_rate: 1.0000e-04
Epoch 6/20
[1m80/80[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m154s[0m 2s/step - loss: 0.1749 - mae