In [None]:
from glob import glob
import numpy as np
import nibabel as nib
import matplotlib.pyplot as plt
import seaborn as sns

# Root directory of the BrainIAK "aperture" SRM example dataset
DATA_DIR = 'brainiak-aperture-srm-data'

# One BOLD run per subject for the "black" task; sorting fixes the subject order
data_fns = sorted(glob(f'{DATA_DIR}/sub-*_task-black_*bold.nii.gz'))
if not data_fns:
    # Fail here with a clear message instead of an obscure error in a later cell
    raise FileNotFoundError(
        f"No files matching 'sub-*_task-black_*bold.nii.gz' found in {DATA_DIR!r}")
atlas_fn = f'{DATA_DIR}/Schaefer2018_400Parcels_17Networks_order_FSLMNI152_2.5mm.nii.gz'

# Load in the Schaefer 400-parcel atlas (17-network ordering, 2.5 mm MNI152 grid)
atlas_nii = nib.load(atlas_fn)
atlas_img = atlas_nii.get_fdata()

# Left temporal parietal ROI labels (atlas parcel values)
parcel_labels = [195, 196, 197, 198, 199, 200]

In [None]:
# Load in functional data and mask with "temporal parietal" ROI.
# Each entry of `data` is a (time, voxels) array for one subject; the voxel
# columns are grouped parcel by parcel, in the order given by `parcel_labels`.

# The parcel masks are the same for every subject, so build them once
parcel_masks = [atlas_img == parcel for parcel in parcel_labels]

data = []
for data_fn in data_fns:
    voxel_data = nib.load(data_fn).get_fdata()

    # Masking is only meaningful if the functional image shares the atlas grid
    if voxel_data.shape[:3] != atlas_img.shape:
        raise ValueError(
            f"{data_fn} has spatial shape {voxel_data.shape[:3]}, "
            f"but the atlas has shape {atlas_img.shape}")

    # Take union of all parcels (brain areas) comprising the full ROI:
    # voxel_data[mask, :] is (voxels, time); transpose to (time, voxels) and
    # place the parcels side by side along the voxel axis
    roi_data = np.column_stack([voxel_data[mask, :].T for mask in parcel_masks])
    data.append(roi_data)

In [None]:
from nilearn.plotting import plot_stat_map

# Visualize the left temporal parietal ROI
sns.set(palette='colorblind')

# Binary volume on the atlas grid: 1.0 inside any ROI parcel, 0.0 elsewhere
roi_img = np.isin(atlas_img, parcel_labels).astype(float)

# Wrap the array as a NIfTI image (reusing the atlas affine/header) for Nilearn
roi_nii = nib.Nifti1Image(roi_img, atlas_nii.affine, atlas_nii.header)

# Plot the left temporal parietal ROI, with cuts at (-53, -46, 10)
plot_stat_map(
    roi_nii,
    cmap='tab10_r',
    cut_coords=(-53, -46, 10),
    colorbar=False,
    title='left temporal parietal ROI',
);

In [None]:
# Summarize dataset dimensions from explicit sources rather than the loop
# variables left over from the loading cell, so this cell stays correct on
# re-runs and after edits to that loop.
if not data:
    raise ValueError("No subject data loaded; check data_fns")

example_roi = data[0]
# nibabel exposes the image shape without calling get_fdata()
example_bold_shape = nib.load(data_fns[0]).shape

print(f"participants: {len(data)}")
print(f"ROI data (time, voxels): {example_roi.shape}")
print(f"BOLD image (X, Y, Z, time): {example_bold_shape}")

In [None]:
for subject_idx, subject_array in enumerate(data):
    print(f"Subject {subject_idx} training data shape: {subject_array.shape}")

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim

class FMRIAutoencoder(nn.Module):
    """Symmetric MLP autoencoder for single-timepoint ROI fMRI vectors.

    Encoder: Linear(input_dim, hidden_dim) -> ReLU -> Linear(hidden_dim, latent_dim)
    Decoder: Linear(latent_dim, hidden_dim) -> ReLU -> Linear(hidden_dim, input_dim)

    Args:
        input_dim: number of features per sample (ROI voxels).
        hidden_dim: width of the hidden layer in both encoder and decoder.
        latent_dim: size of the latent code.

    ``forward(x)`` returns the tuple ``(latent, reconstruction)``.
    """

    def __init__(self, input_dim=935, hidden_dim=1024, latent_dim=768):
        super().__init__()

        def two_layer_mlp(n_in, n_out):
            # Linear -> ReLU -> Linear; Sequential indices 0 and 2 give the
            # parameter names "<half>.0.*" and "<half>.2.*" in state_dict
            return nn.Sequential(
                nn.Linear(n_in, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, n_out),
            )

        # Encoder is built before decoder so parameter initialisation draws
        # from the RNG in a fixed order
        self.encoder = two_layer_mlp(input_dim, latent_dim)  # fMRI input -> latent
        self.decoder = two_layer_mlp(latent_dim, input_dim)  # latent -> reconstruction

    def forward(self, x):
        z = self.encoder(x)
        return z, self.decoder(z)

class FMRI_Dataset(Dataset):
    """Map-style dataset pooling every subject's timepoints into one table.

    Args:
        fmri_list: sequence of 2D arrays, one per subject, each shaped
            (timepoints, voxels); all must share the same voxel count so
            they can be concatenated along the time axis.

    Attributes:
        data: float32 tensor of shape (total timepoints, voxels); row i is sample i.
    """

    def __init__(self, fmri_list):
        # Stack subjects along the time axis, then convert once to float32
        pooled = np.concatenate(fmri_list, axis=0)
        self.data = torch.tensor(pooled, dtype=torch.float32)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        # One sample = the voxel vector at a single timepoint
        return self.data[index]
        
# Seed PyTorch's global RNG so weight initialisation and DataLoader shuffling
# are reproducible across Restart & Run All
SEED = 42
torch.manual_seed(SEED)

dataset = FMRI_Dataset(data)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

# Size the input/output layers from the data instead of hard-coding the ROI
# voxel count, so a different ROI or atlas still trains without shape errors
n_voxels = dataset.data.shape[1]
model = FMRIAutoencoder(input_dim=n_voxels, hidden_dim=1024, latent_dim=768)

# Define the reconstruction loss and optimizer
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

num_epochs = 20  # Adjust
model.train()
for epoch in range(num_epochs):
    running_loss = 0.0
    for fmri_sample in dataloader:
        optimizer.zero_grad()
        latent, recon = model(fmri_sample)
        # Autoencoder objective: reconstruct the input itself
        loss = criterion(recon, fmri_sample)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

    # Mean of the per-batch losses over this epoch
    print(f"Epoch {epoch+1}, Loss: {running_loss / len(dataloader)}")

# Extract the encoder (fMRI input -> latent code)
fmri_encoder = model.encoder

# Save the encoder's weights. Keys are relative to the encoder Sequential
# ("0.weight", "0.bias", "2.weight", "2.bias"); reloading requires an encoder
# built with the same input_dim (n_voxels), hidden_dim and latent_dim.
torch.save(fmri_encoder.state_dict(), "fmri_encoder_weights.pth")


In [None]:

# Export list for encoding_brain.ipynb: only takes effect when this notebook is
# imported as a module (e.g. via a notebook-import helper) and another file
# does `from encoding_brain import *` — NOTE(review): confirm such an importer is used
__all__ = ["FMRIAutoencoder"]
