In [30]:
import numpy as np
import os
import glob
from natsort import natsorted
from transformers import CLIPProcessor, CLIPModel
from PIL import Image
import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset

In [65]:
sub_id = "01"
image_dir = "data/images"
eeg_dir = f"data/things-eeg2/sub-{sub_id}"

eeg_train = np.load(f"{eeg_dir}/preprocessed_eeg_training.npy", allow_pickle=True).item()
eeg_test = np.load(f"{eeg_dir}/preprocessed_eeg_test.npy", allow_pickle=True).item()

In [67]:
eeg_train_new = np.load(f"{eeg_dir}/preprocessed_eeg_training_new.npy", allow_pickle=True)
eeg_test_new = np.load(f"{eeg_dir}/preprocessed_eeg_test_new.npy", allow_pickle=True)

In [70]:
eeg_train_new['preprocessed_eeg_data'].shape

(16540, 4, 63, 250)

In [71]:
eeg_train['preprocessed_eeg_data'].shape

(16540, 4, 17, 100)

In [3]:
def get_ordered_things_image_paths(images_root_dir):
    """
    Generates a list of image paths for the THINGS training set in order.

    Args:
        images_root_dir (str): Path to the root directory containing the
                                'things_images' (or similar) folder.

    Returns:
        list: A list of full paths to the image files, ordered to match
              the preprocessed EEG data's first dimension.
    """
    ordered_image_paths = []
    concept_dirs = natsorted(glob.glob(os.path.join(images_root_dir, '*')))

    # Iterate through sorted concept directories
    for concept_dir in concept_dirs:
        if not os.path.isdir(concept_dir):
            continue # Skip any non-directory files

        # Get and sort image files within each concept
        # Assumes images are .jpg or .png, adjust as needed
        image_files = natsorted(glob.glob(os.path.join(concept_dir, '*.jpg')))
        image_files.extend(natsorted(glob.glob(os.path.join(concept_dir, '*.png'))))
        # Remove duplicates if both .jpg and .png are present for the same name (unlikely for THINGS)
        image_files = list(set(image_files))
        image_files = natsorted(image_files) # Re-sort after set conversion

        # Add all images from this concept to the ordered list
        ordered_image_paths.extend(image_files)

    return ordered_image_paths

train_image_dir = os.path.join(image_dir, "training_images")
test_image_dir = os.path.join(image_dir, "test_images")
train_image_paths = get_ordered_things_image_paths(train_image_dir)
test_image_paths = get_ordered_things_image_paths(test_image_dir)

In [4]:
len(train_image_paths)

16540

In [15]:
device = "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Using device: {device}")

model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
model.eval()
def extract_clip_embeddings(image_paths, batch_size=32):
    """
    Extracts CLIP image embeddings from a list of image paths.

    Args:
        image_paths (list): A list of paths to image files.
        batch_size (int): Number of images to process per batch.

    Returns:
        numpy.ndarray: A NumPy array of normalized CLIP embeddings.
                       Shape: (num_images, embedding_dim)
    """
    all_image_features = []

    # Process images in batches to manage memory and speed up
    for i in range(0, len(image_paths), batch_size):
        batch_paths = image_paths[i : i + batch_size]
        images = []
        for path in batch_paths:
            try:
                img = Image.open(path).convert("RGB") # Ensure RGB for consistency
                images.append(img)
            except Exception as e:
                print(f"Warning: Could not open image {path}: {e}. Skipping.")
                continue # Skip problematic images

        if not images:
            continue # Skip if the batch ended up empty due to errors

        # Preprocess images and move inputs to the selected device
        inputs = processor(images=images, return_tensors="pt", padding=True).to(device)

        # Forward pass (image encoder only)
        with torch.no_grad():
            image_features_batch = model.get_image_features(**inputs)

        # Normalize embeddings (critical for alignment)
        # Ensure normalization happens on the correct device if needed for fused operations
        image_features_batch = image_features_batch / image_features_batch.norm(dim=-1, keepdim=True)

        # Move to CPU and append as NumPy array
        all_image_features.append(image_features_batch.cpu().numpy())

    # Concatenate all batch results
    if all_image_features:
        return np.vstack(all_image_features)
    else:
        return np.array([])

Using device: mps


In [16]:
train_image_np = extract_clip_embeddings(train_image_paths)

In [17]:
train_image_np.shape

(16540, 512)

In [18]:
test_image_np = extract_clip_embeddings(test_image_paths)
test_image_np.shape

(200, 512)

In [19]:
np.save(f"{image_dir}/train_image_clip_embeddings.npy", train_image_np)
np.save(f"{image_dir}/test_image_clip_embeddings.npy", test_image_np)

In [26]:
class EEGImageDataset(Dataset):
    def __init__(self, eeg_data_array, clip_embeddings_array):
        if eeg_data_array.shape[0] != clip_embeddings_array.shape[0]:
            raise ValueError(
                f"Mismatch in number of conditions/images: "
                f"EEG has {eeg_data_array.shape[0]}, CLIP has {clip_embeddings_array.shape[0]}"
            )

        self.eeg_data = torch.from_numpy(eeg_data_array).float()
        self.clip_embeddings = torch.from_numpy(clip_embeddings_array).float()

        self.num_total_eeg_trials = self.eeg_data.shape[0] * self.eeg_data.shape[1]
        self.eeg_data = self.eeg_data.view(-1, self.eeg_data.shape[2], self.eeg_data.shape[3])

        self.clip_embeddings = self.clip_embeddings.repeat_interleave(eeg_data_array.shape[1], dim=0)

    def __len__(self):
        return self.num_total_eeg_trials

    def __getitem__(self, idx):
        return {
            'eeg': self.eeg_data[idx],
            'clip_embedding': self.clip_embeddings[idx]
        }
    
train_dataset = EEGImageDataset(eeg_train['preprocessed_eeg_data'], train_image_np)
test_dataset = EEGImageDataset(eeg_test['preprocessed_eeg_data'], test_image_np)

In [51]:
train_eeg_np = eeg_train['preprocessed_eeg_data'].copy()
test_eeg_np = eeg_test['preprocessed_eeg_data'].copy()

In [52]:
train_eeg_np.shape

(16540, 4, 17, 100)

In [54]:
train_image_np.shape

(16540, 512)

In [63]:
train_dataset[0]['eeg'].shape

torch.Size([17, 100])

In [33]:
batch_size = 64
num_workers = 4
train_loader = DataLoader(
    train_dataset, 
    batch_size=batch_size, 
    shuffle=True, 
    # num_workers=4, 
    # pin_memory=True
)

test_loader = DataLoader(
    test_dataset, 
    batch_size=batch_size, 
    shuffle=False, 
    # num_workers=4, 
    # pin_memory=True
)

In [None]:
class BasicEEGEncoder(nn.Module):
    def __init__(self, in_channels=17, time_points=100, embedding_dim=512):
        super().__init__()

        self.conv_block = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=(1, 5), padding=(0, 2)),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=(in_channels, 1), padding=0),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((1, 1))
        )

        self.mlp_projector = nn.Sequential(
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.LayerNorm(256),
            nn.Linear(256, embedding_dim)
        )
    
    def forward(self, x):
        x = x.unsqueeze(1)

        x = self.conv_block(x)
        x = x.view(x.size(0), -1)

        eeg_embedding = self.mlp_projector(x)
        return eeg_embedding
    
class ImprovedEEGEncoder(nn.Module):
    def __init__(self, in_channels=17, time_points=100, embedding_dim=512):
        super().__init__()

        # --- Temporal Convolution Block ---
        # Captures temporal features for each channel
        self.temporal_conv_block = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=(1, 5), padding=(0, 2)), # Example: 32 filters, temporal kernel
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=(1, 5), padding=(0, 2)), # More temporal depth
            nn.BatchNorm2d(64),
            nn.ReLU(),
        )

        # --- Spatial Convolution Block ---
        # Captures spatial features across channels
        # Input to this block will be (Batch, 64, Channels, TimePoints)
        self.spatial_conv_block = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=(in_channels, 1)), # Convolve across all channels
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, 256, kernel_size=(1, time_points - 4)), # Flatten remaining temporal dimension
            nn.BatchNorm2d(256),
            nn.ReLU(),
        )
        # Note: kernel_size=(1, time_points - 4) is for (1, X) where X = (100 - (5-1)*2 - (5-1)*2) = 100 - 8 = 92
        # (100 is time_points, 5 is kernel_size, 2 is padding for first two convs)
        # Better to use AdaptiveAvgPool or calculate dynamically

        # Correct way to flatten spatial/temporal output to pass to MLP
        # After temporal_conv_block: (B, 64, 17, 100)
        # After spatial_conv_block (if it truly reduces to 1x1 or 1xDim): (B, 256, 1, 1) or (B, 256, 1, SomeDim)
        # Let's simplify and use AdaptiveAvgPool after each stage for clearer dimensionality.

        self.temporal_spatial_features = nn.Sequential(
            # Input: (B, 1, C, T)
            nn.Conv2d(1, 32, kernel_size=(1, 5), padding=(0, 2)), # Temporal filter 1
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=(1, 5), padding=(0, 2)), # Temporal filter 2
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=(in_channels, 1)), # Spatial filter
            nn.BatchNorm2d(128),
            nn.ReLU(),
            # Consider a mix of pooling and dense layers here instead of just AdaptiveAvgPool2d
            # e.g., nn.AdaptiveAvgPool2d((1, 1)) or nn.MaxPool2d((1, 2))
            nn.AdaptiveAvgPool2d((1,1)) # Reduces to (B, 128, 1, 1) -> (B, 128)
        )

        # Updated MLP Projector
        self.mlp_projector = nn.Sequential(
            nn.Linear(128, 512), # Increased hidden dim
            nn.ReLU(),
            nn.LayerNorm(512),
            nn.Linear(512, embedding_dim),
            nn.LayerNorm(embedding_dim) # Add LayerNorm at the final output too
        )

    def forward(self, x):
        x = x.unsqueeze(1) # Add channel dim for Conv2d: (B, 1, C, T)
        x = self.temporal_spatial_features(x)
        x = x.view(x.size(0), -1) # Flatten to (B, 128) for MLP
        eeg_embedding = self.mlp_projector(x)
        return eeg_embedding
    
class CombinedLoss(nn.Module):
    def __init__(self, lambda_param=0.5, temperature=0.07):
        super().__init__()
        self.lambda_param = lambda_param
        self.temperature = temperature
    
    def forward(self, eeg_embeddings, clip_embeddings):
        eeg_embeddings = F.normalize(eeg_embeddings, p=2, dim=-1)
        clip_embeddings = F.normalize(clip_embeddings, p=2, dim=-1)

        logits_per_eeg = torch.matmul(eeg_embeddings, clip_embeddings.T) / self.temperature
        logits_per_clip = logits_per_eeg.T

        labels = torch.arange(eeg_embeddings.shape[0], device=eeg_embeddings.device)

        loss_eeg =  F.cross_entropy(logits_per_eeg, labels)
        loss_clip = F.cross_entropy(logits_per_clip, labels)
        l_clip = (loss_eeg + loss_clip) / 2

        l_mse = F.mse_loss(eeg_embeddings, clip_embeddings)

        total_loss = self.lambda_param * l_clip + (1 - self.lambda_param) * l_mse
        return total_loss, l_clip.item(), l_mse.item()


In [None]:
device = "mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu"
print(f"Training on device: {device}")

# eeg_encoder = BasicEEGEncoder(
#     in_channels=17,
#     time_points=100,
#     embedding_dim=512
# ).to(device)

eeg_encoder = ImprovedEEGEncoder(
    in_channels=17,
    time_points=100,
    embedding_dim=512
).to(device)
criterion = CombinedLoss(lambda_param=0.1, temperature=0.07).to(device)
optimizer = optim.AdamW(eeg_encoder.parameters(), lr=1e-3)

In [50]:
num_epochs = 10

print("\nStarting training...")
best_val_loss = float('inf')

for epoch in range(num_epochs):
    eeg_encoder.train()
    total_train_loss = 0
    total_l_clip_train = 0
    total_l_mse_train = 0

    for batch_idx, batch in enumerate(train_loader):
        eeg = batch['eeg'].to(device)
        clip_target = batch['clip_embedding'].to(device)

        optimizer.zero_grad()

        eeg_embeddings = eeg_encoder(eeg)

        loss, l_clip_val, l_mse_val = criterion(eeg_embeddings, clip_target)
        loss.backward() # Backward pass
        optimizer.step()

        total_train_loss += loss.item()
        total_l_clip_train += l_clip_val
        total_l_mse_train += l_mse_val

        if (batch_idx + 1) % 100 == 0:
            print(f"  Epoch {epoch+1}, Batch {batch_idx+1}/{len(train_loader)} | "
                  f"Loss: {loss.item():.4f} (CLIP: {l_clip_val:.4f}, MSE: {l_mse_val:.4f})")
            
    avg_train_loss = total_train_loss / len(train_loader)
    avg_l_clip_train = total_l_clip_train / len(train_loader)
    avg_l_mse_train = total_l_mse_train / len(train_loader)
    print(f"\nEpoch {epoch+1} Training Loss: {avg_train_loss:.4f} (CLIP: {avg_l_clip_train:.4f}, MSE: {avg_l_mse_train:.4f})")


    eeg_encoder.eval() # Set model to evaluation mode
    total_val_loss = 0
    total_l_clip_val = 0
    total_l_mse_val = 0

    with torch.no_grad(): # Disable gradient calculation for validation
        for batch_idx, batch in enumerate(test_loader):
            eeg = batch['eeg'].to(device)
            clip_target = batch['clip_embedding'].to(device)

            eeg_embeddings = eeg_encoder(eeg)
            loss, l_clip_val, l_mse_val = criterion(eeg_embeddings, clip_target)

            total_val_loss += loss.item()
            total_l_clip_val += l_clip_val
            total_l_mse_val += l_mse_val

    avg_val_loss = total_val_loss / len(test_loader)
    avg_l_clip_val = total_l_clip_val / len(test_loader)
    avg_l_mse_val = total_l_mse_val / len(test_loader)
    print(f"Epoch {epoch+1} Validation Loss: {avg_val_loss:.4f} (CLIP: {avg_l_clip_val:.4f}, MSE: {avg_l_mse_val:.4f})")

    # Save best model
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        torch.save(eeg_encoder.state_dict(), 'best_eeg_encoder.pth')
        print(f"  Saved best model with validation loss: {best_val_loss:.4f}")

print("\nTraining complete!")


Starting training...
  Epoch 1, Batch 100/1034 | Loss: 0.3914 (CLIP: 3.8833, MSE: 0.0034)
  Epoch 1, Batch 200/1034 | Loss: 0.3817 (CLIP: 3.7865, MSE: 0.0033)
  Epoch 1, Batch 300/1034 | Loss: 0.3757 (CLIP: 3.7274, MSE: 0.0033)
  Epoch 1, Batch 400/1034 | Loss: 0.3811 (CLIP: 3.7810, MSE: 0.0033)
  Epoch 1, Batch 500/1034 | Loss: 0.3914 (CLIP: 3.8838, MSE: 0.0034)
  Epoch 1, Batch 600/1034 | Loss: 0.3904 (CLIP: 3.8735, MSE: 0.0034)
  Epoch 1, Batch 700/1034 | Loss: 0.3773 (CLIP: 3.7428, MSE: 0.0033)
  Epoch 1, Batch 800/1034 | Loss: 0.3749 (CLIP: 3.7187, MSE: 0.0033)
  Epoch 1, Batch 900/1034 | Loss: 0.3906 (CLIP: 3.8755, MSE: 0.0034)
  Epoch 1, Batch 1000/1034 | Loss: 0.3749 (CLIP: 3.7192, MSE: 0.0033)

Epoch 1 Training Loss: 0.3814 (CLIP: 3.7844, MSE: 0.0033)
Epoch 1 Validation Loss: 0.4260 (CLIP: 4.2286, MSE: 0.0035)
  Saved best model with validation loss: 0.4260
  Epoch 2, Batch 100/1034 | Loss: 0.3707 (CLIP: 3.6774, MSE: 0.0033)
  Epoch 2, Batch 200/1034 | Loss: 0.3770 (CLIP: 3.7

In [41]:
clip_embedding_dim = 512 # Or 1024, based on your model's config
num_channels = 17
num_timepoints = 100

device = "mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu"
print(f"Evaluation on device: {device}")

eeg_encoder = BasicEEGEncoder(
    in_channels=num_channels,
    time_points=num_timepoints,
    embedding_dim=clip_embedding_dim 
).to(device)

# Load the best saved model weights
try:
    eeg_encoder.load_state_dict(torch.load('best_eeg_encoder.pth'))
    print("Loaded best EEG encoder for evaluation.")
except FileNotFoundError:
    print("Error: 'best_eeg_encoder.pth' not found. Please train the model first.")
    exit()

eeg_encoder.eval() # Set to evaluation mode

all_eeg_embeddings = []
all_clip_embeddings = []

with torch.no_grad():
    for batch_idx, batch in enumerate(test_loader): # Use test_loader here
        eeg = batch['eeg'].to(device)
        clip_target = batch['clip_embedding'].to(device)

        eeg_emb = eeg_encoder(eeg)

        eeg_emb = F.normalize(eeg_emb, p=2, dim=-1)
        clip_target = F.normalize(clip_target, p=2, dim=-1)

        all_eeg_embeddings.append(eeg_emb.cpu())
        all_clip_embeddings.append(clip_target.cpu())


all_eeg_embeddings = torch.cat(all_eeg_embeddings, dim=0).numpy()
all_clip_embeddings = torch.cat(all_clip_embeddings, dim=0).numpy()

print(f"Extracted {all_eeg_embeddings.shape[0]} EEG embeddings and CLIP embeddings from test set.")

Evaluation on device: mps
Loaded best EEG encoder for evaluation.
Extracted 16000 EEG embeddings and CLIP embeddings from test set.


In [45]:
def calculate_retrieval_metrics(eeg_embeddings, clip_embeddings, k_values=[1, 5, 10]):
    """
    Calculates Top-K retrieval accuracy and other metrics for EEG embeddings
    retrieving their corresponding CLIP embeddings.

    Args:
        eeg_embeddings (np.ndarray): NumPy array of EEG embeddings.
        clip_embeddings (np.ndarray): NumPy array of CLIP embeddings.
        k_values (list): A list of integers for which to calculate Top-K accuracy.

    Returns:
        tuple: A dictionary of accuracies, mean rank, and mean reciprocal rank (MRR).
    """
    num_samples = eeg_embeddings.shape[0]

    # Convert NumPy arrays to PyTorch tensors for matrix multiplication
    eeg_embeddings_t = torch.from_numpy(eeg_embeddings)
    clip_embeddings_t = torch.from_numpy(clip_embeddings)

    # Calculate the similarity matrix
    # S_ij = similarity(eeg_i, clip_j)
    # The result will be a (num_samples x num_samples) matrix where the diagonal
    # elements are the similarities between true pairs, and off-diagonals are negatives.
    similarity_matrix = torch.matmul(eeg_embeddings_t, clip_embeddings_t.T)

    # For each EEG embedding, find the rank of its true corresponding CLIP embedding
    # torch.argsort returns the indices that would sort the matrix in descending order.
    # So, the first index in each row of sorted_indices is the rank-1 match.
    sorted_indices = torch.argsort(similarity_matrix, dim=-1, descending=True)

    accuracies = {k: 0 for k in k_values}
    ranks = [] # To store the rank of the true positive for each query

    for i in range(num_samples):
        # The true corresponding CLIP embedding is at the same index 'i' in the clip_embeddings.
        # We find its position (rank) within the sorted_indices for the current EEG query.
        # `nonzero(as_tuple=True)[0].item()` gets the 0-based index (rank) of 'i' in the sorted list.
        rank_of_true_target = (sorted_indices[i] == i).nonzero(as_tuple=True)[0].item()

        # Check if the true target is within the Top-K results
        for k in k_values:
            if rank_of_true_target < k: # 0-based rank, so < k means within top k
                accuracies[k] += 1

        ranks.append(rank_of_true_target + 1) # Store 1-based rank for mean rank/MRR

    # Calculate final percentages
    final_accuracies = {}
    for k in k_values:
        final_accuracies[f"Top-{k} Accuracy"] = (accuracies[k] / num_samples) * 100

    # Calculate Mean Rank (MR) and Mean Reciprocal Rank (MRR)
    mean_rank = np.mean(ranks)
    mrr = np.mean([1.0 / rank for rank in ranks])

    return final_accuracies, mean_rank, mrr

In [46]:
print("\n--- Running Retrieval Performance Test ---")

# Define the Top-K values you want to evaluate
# These should ideally align with what's reported in papers (e.g., 1, 5, 10, 20, 50, 100)
evaluation_k_values = [1, 5, 10, 20, 50, 100]

retrieval_results, mean_rank_val, mrr_val = calculate_retrieval_metrics(
    all_eeg_embeddings,
    all_clip_embeddings,
    k_values=evaluation_k_values
)

print("\n--- Retrieval Performance (EEG to CLIP) ---")
for k, acc in retrieval_results.items():
    print(f"{k}: {acc:.2f}%")
print(f"Mean Rank: {mean_rank_val:.2f}")
print(f"Mean Reciprocal Rank (MRR): {mrr_val:.4f}")


--- Running Retrieval Performance Test ---

--- Retrieval Performance (EEG to CLIP) ---
Top-1 Accuracy: 0.01%
Top-5 Accuracy: 0.04%
Top-10 Accuracy: 0.11%
Top-20 Accuracy: 0.19%
Top-50 Accuracy: 0.44%
Top-100 Accuracy: 0.81%
Mean Rank: 7441.69
Mean Reciprocal Rank (MRR): 0.0008
