### DataLoader for FMRI

In [1]:
import numpy as np
import torch
from torch.utils.data import Dataset
import os
from torchvision import transforms

class CustomTrainDataset(Dataset):
    def __init__(self, fmri_folder, images_folder, captions_file, transform=None):
        self.fmri_files = [os.path.join(fmri_folder, file) for file in os.listdir(fmri_folder) if 'nsd_train_fmriavg_nsdgeneral' in file and 'batch' in file]
        self.fmri_files = sorted(self.fmri_files, key=lambda x: int(x.split('_batch')[-1].split('.')[0]))[:-1]

        self.images_files = [os.path.join(images_folder, file) for file in os.listdir(images_folder) if 'nsd_train_stim' in file and 'batch' in file]
        self.images_files = sorted(self.images_files, key=lambda x: int(x.split('_batch')[-1].split('.')[0]))[:-1]
        # self.captions = np.load(captions_file)
        
        # Assume all batch files have the same number of samples
        self.samples_per_file = len(np.load(self.fmri_files[0]))
        self.total_samples = self.samples_per_file * len(self.fmri_files)
        self.transform = transform

    def __len__(self):
        return self.total_samples

    def __getitem__(self, idx):
        file_idx = idx // self.samples_per_file
        sample_idx = idx % self.samples_per_file

        fmri = np.load(self.fmri_files[file_idx], mmap_mode='r')[sample_idx]
        image = np.load(self.images_files[file_idx], mmap_mode='r')[sample_idx]
        # caption = self.captions[idx % len(self.captions)]  # Cycle through captions if they are less than fmri and images

        fmri = torch.from_numpy(fmri).float()
        image = torch.from_numpy(image).float()

        if self.transform:
            image = transforms.Resize((256,256))(image.permute(2,0,1))
        return image, fmri
    
class CustomTestDataset(Dataset):
    def __init__(self, fmri_file, images_file, captions_file):
        self.fmri = np.load(fmri_file, mmap_mode='r')
        self.images = np.load(images_file, mmap_mode='r')
        # self.captions = np.load(captions_file, mmap_mode='r')

    def __len__(self):
        return len(self.fmri)

    def __getitem__(self, idx):
        fmri = torch.from_numpy(self.fmri[idx]).float()
        image = torch.from_numpy(self.images[idx]).float()
        # caption = self.captions[idx]

        return image, fmri

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from torch.utils.data import DataLoader
from torch.utils.data import DataLoader, random_split
import numpy as np

# Paths
train_fmri_folder = 'data/processed_data/subj07'
train_images_folder = 'data/processed_data/subj07'
train_captions_file = 'data/processed_data/subj07/nsd_train_cap_sub7.npy'

test_fmri_file = 'data/processed_data/subj07/nsd_test_fmriavg_nsdgeneral_sub7.npy'
test_images_file = 'data/processed_data/subj07/nsd_test_stim_sub7.npy'
test_captions_file = 'data/processed_data/subj07/nsd_test_cap_sub7.npy'

def custom_collate(batch):
    images, fmri_data = zip(*batch)

    # Convert images and fMRI data to tensors
    images = torch.stack([torch.from_numpy(np.array(img)).float() for img in images])
    fmri_data = torch.stack([torch.from_numpy(np.array(fmri)).float() for fmri in fmri_data])

    # Handle captions as a list of strings
    return images, fmri_data


# Datasets
train_dataset = CustomTrainDataset(train_fmri_folder, train_images_folder, train_captions_file, transform=True)
test_dataset = CustomTestDataset(test_fmri_file, test_images_file, test_captions_file)

# DataLoaders
batch_size = 10  # Adjust as needed
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False,  collate_fn=custom_collate)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, collate_fn=custom_collate)

### Network

In [3]:
import torch
import torch.nn as nn

class EncoderDecoder(nn.Module):
    def __init__(self):
        super(EncoderDecoder, self).__init__()

        # Encoder
        self.encoder = nn.Sequential(
            nn.Linear(12682, 1024),
            nn.ReLU(),
            nn.Linear(1024, 256),
            nn.ReLU(),
        )

        # Intermediate layers
        self.intermediate = nn.Sequential(
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, 256),
            nn.ReLU()
        )

        # Decoder
        self.decoder = nn.Sequential(
            nn.Linear(256, 1024),
            nn.ReLU(),
            nn.Linear(1024, 256 * 256 * 3),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = self.encoder(x)
        x = self.intermediate(x)  # Pass through the intermediate layers
        x = self.decoder(x)
        x = x.view(-1, 3, 256, 256)  # Reshape to image dimensions
        return x


In [4]:
print(np.load(train_dataset.images_files[176]).shape)
print(np.load(train_dataset.fmri_files[176]).shape)

(50, 425, 425, 3)
(50, 12682)


### Training

In [5]:
from torch.optim import Adam
from torch.utils.data import DataLoader
from tqdm import tqdm

# Device configuration
# if no GPU available, raise error
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    raise Exception('GPU not available')

# Initialize the model, loss function, and optimizer
model = EncoderDecoder().to(device)
criterion = nn.MSELoss()  # Mean Squared Error Loss
optimizer = Adam(model.parameters(), lr=0.001)

# DataLoader setup (assuming you have already created train_loader)
# train_loader = ...

# Train for 3 epochs
num_epochs = 30
for epoch in range(num_epochs):
    torch.cuda.empty_cache()
    model.train()
    running_loss = 0.0
    with tqdm(train_loader, unit="batch") as tepoch:
        for images, fmri in tepoch:
            images, fmri = images.to(device), fmri.to(device)
            tepoch.set_description(f"Epoch {epoch+1}")

            # Forward pass
            optimizer.zero_grad()
            outputs = model(fmri)
            loss = criterion(outputs, images)

            # Backward and optimize
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            tepoch.set_postfix(loss=loss.item())
            
            # Clear memory
            del images, fmri, outputs, loss
            torch.cuda.empty_cache()

    avg_loss = running_loss / len(train_loader)
    print(f"Epoch [{epoch+1}/{num_epochs}], Average Loss: {avg_loss:.4f}")
    torch.cuda.empty_cache()



  fmri = torch.from_numpy(fmri).float()
Epoch 1: 100%|██████████| 885/885 [01:31<00:00,  9.69batch/s, loss=1.72e+4]


Epoch [1/30], Average Loss: 17657.7108


Epoch 2: 100%|██████████| 885/885 [01:31<00:00,  9.72batch/s, loss=1.72e+4]


Epoch [2/30], Average Loss: 17657.5143


Epoch 3: 100%|██████████| 885/885 [01:30<00:00,  9.74batch/s, loss=1.72e+4]


Epoch [3/30], Average Loss: 17657.5139


Epoch 4: 100%|██████████| 885/885 [01:30<00:00,  9.75batch/s, loss=1.72e+4]


Epoch [4/30], Average Loss: 17657.5139


Epoch 5: 100%|██████████| 885/885 [01:31<00:00,  9.72batch/s, loss=1.72e+4]


Epoch [5/30], Average Loss: 17657.5139


Epoch 6: 100%|██████████| 885/885 [01:30<00:00,  9.73batch/s, loss=1.72e+4]


Epoch [6/30], Average Loss: 17657.5139


Epoch 7: 100%|██████████| 885/885 [01:30<00:00,  9.74batch/s, loss=1.72e+4]


Epoch [7/30], Average Loss: 17657.5139


Epoch 8: 100%|██████████| 885/885 [01:31<00:00,  9.72batch/s, loss=1.72e+4]


Epoch [8/30], Average Loss: 17657.5139


Epoch 9: 100%|██████████| 885/885 [01:30<00:00,  9.74batch/s, loss=1.72e+4]


Epoch [9/30], Average Loss: 17657.5139


Epoch 10: 100%|██████████| 885/885 [01:30<00:00,  9.75batch/s, loss=1.72e+4]


Epoch [10/30], Average Loss: 17657.5139


Epoch 11: 100%|██████████| 885/885 [01:30<00:00,  9.73batch/s, loss=1.72e+4]


Epoch [11/30], Average Loss: 17657.5139


Epoch 12: 100%|██████████| 885/885 [01:30<00:00,  9.74batch/s, loss=1.72e+4]


Epoch [12/30], Average Loss: 17657.5139


Epoch 13: 100%|██████████| 885/885 [01:30<00:00,  9.73batch/s, loss=1.72e+4]


Epoch [13/30], Average Loss: 17657.5139


Epoch 14: 100%|██████████| 885/885 [01:30<00:00,  9.73batch/s, loss=1.72e+4]


Epoch [14/30], Average Loss: 17657.5139


Epoch 15: 100%|██████████| 885/885 [01:30<00:00,  9.74batch/s, loss=1.72e+4]


Epoch [15/30], Average Loss: 17657.5139


Epoch 16: 100%|██████████| 885/885 [01:30<00:00,  9.74batch/s, loss=1.72e+4]


Epoch [16/30], Average Loss: 17657.5139


Epoch 17: 100%|██████████| 885/885 [01:30<00:00,  9.73batch/s, loss=1.72e+4]


Epoch [17/30], Average Loss: 17657.5139


Epoch 18: 100%|██████████| 885/885 [01:30<00:00,  9.73batch/s, loss=1.72e+4]


Epoch [18/30], Average Loss: 17657.5139


Epoch 19: 100%|██████████| 885/885 [01:31<00:00,  9.73batch/s, loss=1.72e+4]


Epoch [19/30], Average Loss: 17657.5139


Epoch 20: 100%|██████████| 885/885 [01:30<00:00,  9.74batch/s, loss=1.72e+4]


Epoch [20/30], Average Loss: 17657.5139


Epoch 21: 100%|██████████| 885/885 [01:30<00:00,  9.75batch/s, loss=1.72e+4]


Epoch [21/30], Average Loss: 17657.5139


Epoch 22: 100%|██████████| 885/885 [01:30<00:00,  9.74batch/s, loss=1.72e+4]


Epoch [22/30], Average Loss: 17657.5139


Epoch 23: 100%|██████████| 885/885 [01:31<00:00,  9.72batch/s, loss=1.72e+4]


Epoch [23/30], Average Loss: 17657.5139


Epoch 24: 100%|██████████| 885/885 [01:30<00:00,  9.74batch/s, loss=1.72e+4]


Epoch [24/30], Average Loss: 17657.5139


Epoch 25: 100%|██████████| 885/885 [01:31<00:00,  9.71batch/s, loss=1.72e+4]


Epoch [25/30], Average Loss: 17657.5139


Epoch 26: 100%|██████████| 885/885 [01:30<00:00,  9.73batch/s, loss=1.72e+4]


Epoch [26/30], Average Loss: 17657.5139


Epoch 27: 100%|██████████| 885/885 [01:30<00:00,  9.74batch/s, loss=1.72e+4]


Epoch [27/30], Average Loss: 17657.5139


Epoch 28: 100%|██████████| 885/885 [01:30<00:00,  9.74batch/s, loss=1.72e+4]


Epoch [28/30], Average Loss: 17657.5139


Epoch 29: 100%|██████████| 885/885 [01:31<00:00,  9.72batch/s, loss=1.72e+4]


Epoch [29/30], Average Loss: 17657.5139


Epoch 30: 100%|██████████| 885/885 [01:30<00:00,  9.74batch/s, loss=1.72e+4]

Epoch [30/30], Average Loss: 17657.5139





In [6]:
# Saving the model and state dict
torch.save(model, 'entire_model.pth')
