In [13]:
import os
import numpy as np
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.linear_model import Ridge
from sklearn.metrics import mean_squared_error
from data_utils import get_lazy_dataloaders
from model_utils import evaluate_model, train_model
from eeg_clip_basic import EEGToCLIPNet
from scipy.spatial.distance import correlation

from diffusers import StableUnCLIPImg2ImgPipeline
import os
import torch
import numpy as np

In [2]:
sub = 1
recon_dir = f"results/thingseeg2_preproc/sub-{sub:02d}/unclip/" # Directory to save the reconstructed images
os.makedirs(recon_dir, exist_ok=True)

# Start the StableUnCLIP Image variations pipeline
pipe = StableUnCLIPImg2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-1-unclip", torch_dtype=torch.float16, variation="fp16"
)

device = "cuda"
pipe = pipe.to(device)

Keyword arguments {'variation': 'fp16'} are not expected by StableUnCLIPImg2ImgPipeline and will be ignored.
Loading pipeline components...: 100%|██████████| 9/9 [00:05<00:00,  1.77it/s]


In [3]:
clip_train_loader, clip_val_loader, clip_test_loader = get_lazy_dataloaders(sub_id=1, batch_size=32, shuffle=True, embedding_type='clip', flatten_eeg=False)

for data in clip_val_loader:
    print(data[0].shape, data[1].shape)
    break

vae_train_loader, vae_val_loader, vae_test_loader = get_lazy_dataloaders(sub_id=1, batch_size=32, shuffle=True, embedding_type='vae', flatten_eeg=False)

for data in vae_val_loader:
    print(data[0].shape, data[1].shape)
    break

torch.Size([32, 17, 80]) torch.Size([32, 1024])
torch.Size([32, 17, 80]) torch.Size([32, 36864])


In [4]:
class EEG_CNN1D(nn.Module):
    def __init__(self, num_channels=17, num_time_points=80, output_dim=1024): # Assuming CLIP embedding is 768
        super(EEG_CNN1D, self).__init__()
        self.conv_layers = nn.Sequential(
            nn.Conv1d(num_channels, 64, kernel_size=5, padding=2), # Output: (batch, 64, 80)
            nn.ReLU(),
            nn.BatchNorm1d(64),
            nn.MaxPool1d(kernel_size=2, stride=2), # Output: (batch, 64, 40)

            nn.Conv1d(64, 128, kernel_size=5, padding=2), # Output: (batch, 128, 40)
            nn.ReLU(),
            nn.BatchNorm1d(128),
            nn.MaxPool1d(kernel_size=2, stride=2), # Output: (batch, 128, 20)

            nn.Conv1d(128, 256, kernel_size=5, padding=2), # Output: (batch, 256, 20)
            nn.ReLU(),
            nn.BatchNorm1d(256),
            nn.MaxPool1d(kernel_size=2, stride=2), # Output: (batch, 256, 10)
        )
        self.flatten = nn.Flatten()
        self.fc = nn.Linear(256 * 10, output_dim) # Adjust 256*10 based on final Conv1d output shape

    def forward(self, x):
        # Input x shape: (batch_size, channels, time_points) -> (32, 17, 80)
        x = self.conv_layers(x)
        x = self.flatten(x)
        x = self.fc(x)
        return x

In [5]:
conv1d_model = EEG_CNN1D()
train_model(conv1d_model, clip_train_loader, clip_val_loader, device='cuda', model_name='conv1d_model')

Epoch 1/100:   0%|          | 0/414 [00:00<?, ?it/s]

Epoch 1/100: 100%|██████████| 414/414 [00:01<00:00, 403.87it/s]
Epoch 2/100: 100%|██████████| 414/414 [00:00<00:00, 536.15it/s]
Epoch 3/100: 100%|██████████| 414/414 [00:00<00:00, 523.86it/s]
Epoch 4/100: 100%|██████████| 414/414 [00:00<00:00, 546.84it/s]
Epoch 5/100: 100%|██████████| 414/414 [00:00<00:00, 515.30it/s]
Epoch 6/100: 100%|██████████| 414/414 [00:00<00:00, 500.12it/s]
Epoch 7/100: 100%|██████████| 414/414 [00:00<00:00, 547.57it/s]
Epoch 8/100: 100%|██████████| 414/414 [00:00<00:00, 546.31it/s]
Epoch 9/100: 100%|██████████| 414/414 [00:00<00:00, 497.15it/s]
Epoch 10/100: 100%|██████████| 414/414 [00:00<00:00, 515.60it/s]


Epoch 10/100, Train Loss: 0.291119, Val Loss: 0.294944


Epoch 11/100: 100%|██████████| 414/414 [00:00<00:00, 544.32it/s]
Epoch 12/100: 100%|██████████| 414/414 [00:00<00:00, 508.93it/s]
Epoch 13/100: 100%|██████████| 414/414 [00:00<00:00, 491.46it/s]
Epoch 14/100: 100%|██████████| 414/414 [00:00<00:00, 565.10it/s]
Epoch 15/100: 100%|██████████| 414/414 [00:00<00:00, 505.11it/s]
Epoch 16/100: 100%|██████████| 414/414 [00:00<00:00, 560.32it/s]
Epoch 17/100: 100%|██████████| 414/414 [00:00<00:00, 536.08it/s]
Epoch 18/100: 100%|██████████| 414/414 [00:00<00:00, 510.21it/s]
Epoch 19/100: 100%|██████████| 414/414 [00:00<00:00, 522.06it/s]
Epoch 20/100: 100%|██████████| 414/414 [00:00<00:00, 524.39it/s]


Epoch 20/100, Train Loss: 0.283914, Val Loss: 0.295039


Epoch 21/100: 100%|██████████| 414/414 [00:00<00:00, 500.17it/s]
Epoch 22/100: 100%|██████████| 414/414 [00:00<00:00, 530.25it/s]
Epoch 23/100: 100%|██████████| 414/414 [00:00<00:00, 512.19it/s]
Epoch 24/100: 100%|██████████| 414/414 [00:00<00:00, 471.62it/s]
Epoch 25/100: 100%|██████████| 414/414 [00:00<00:00, 495.96it/s]
Epoch 26/100: 100%|██████████| 414/414 [00:00<00:00, 505.41it/s]
Epoch 27/100: 100%|██████████| 414/414 [00:00<00:00, 506.52it/s]
Epoch 28/100: 100%|██████████| 414/414 [00:00<00:00, 514.82it/s]
Epoch 29/100: 100%|██████████| 414/414 [00:00<00:00, 515.14it/s]
Epoch 30/100: 100%|██████████| 414/414 [00:00<00:00, 548.67it/s]


Epoch 30/100, Train Loss: 0.255710, Val Loss: 0.313020


Epoch 31/100: 100%|██████████| 414/414 [00:00<00:00, 516.07it/s]
Epoch 32/100: 100%|██████████| 414/414 [00:00<00:00, 493.77it/s]

Early stopping at epoch 32





([1.1050289553288677,
  0.3844064027790862,
  0.3131975491985607,
  0.29855412129618697,
  0.29554848423326646,
  0.29407232503096264,
  0.2931280374383005,
  0.29236653447151184,
  0.2918951083784518,
  0.2911187402436123,
  0.2904218382017624,
  0.28986171399049715,
  0.28936839118095986,
  0.2888127211211384,
  0.2885899125521886,
  0.2883835282302709,
  0.28803187385561385,
  0.28760395044289927,
  0.2850473228716044,
  0.28391379132362954,
  0.28302320240488377,
  0.28188225929287897,
  0.2802225512582899,
  0.27869842665782874,
  0.2730510470947782,
  0.26965870791011387,
  0.26629475820899584,
  0.26271561068901117,
  0.25913644402067443,
  0.2557101748199854,
  0.24758549657276863,
  0.24382732113922276],
 [0.47420925159866995,
  0.3336997708448997,
  0.30385965968553835,
  0.2988210248832519,
  0.2970868091170604,
  0.29597899661614346,
  0.29546580005150574,
  0.2950788398201649,
  0.29495129447716933,
  0.29494399348130595,
  0.2946973262498012,
  0.2936868157524329,
  0.293

In [14]:
class PositionalEncoding(nn.Module):
    """Positional encoding for transformer"""
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        
        self.register_buffer('pe', pe)
    
    def forward(self, x):
        return x + self.pe[:x.size(0), :]


class EEGTransformerModel(nn.Module):
    """
    Transformer-based model for EEG to CLIP embedding conversion
    
    Input: (batch_size, num_channels, time_freq)
    Output: (batch_size, 1024) - CLIP embedding dimension
    """
    
    def __init__(self, 
                 num_channels, 
                 time_freq, 
                 d_model=512, 
                 nhead=8, 
                 num_layers=6, 
                 dim_feedforward=2048, 
                 dropout=0.1,
                 clip_embedding_dim=1024,
                 use_positional_encoding=True):
        super().__init__()
        
        self.num_channels = num_channels
        self.time_freq = time_freq
        self.d_model = d_model
        self.clip_embedding_dim = clip_embedding_dim
        
        # Input projection: project each channel to d_model dimensions
        self.input_projection = nn.Linear(num_channels, d_model)
        
        # Positional encoding
        self.use_positional_encoding = use_positional_encoding
        if use_positional_encoding:
            self.pos_encoder = PositionalEncoding(d_model, max_len=time_freq)
        
        # Transformer encoder
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        
        # Global pooling and output projection
        self.global_pool = nn.AdaptiveAvgPool1d(1)
        self.output_projection = nn.Sequential(
            nn.Linear(d_model, d_model // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_model // 2, clip_embedding_dim)
        )
        
        # Optional: Layer normalization for output
        self.output_norm = nn.LayerNorm(clip_embedding_dim)
        
    def forward(self, x):
        """
        Forward pass
        
        Args:
            x: Input tensor of shape (batch_size, num_channels, time_freq)
            
        Returns:
            Output tensor of shape (batch_size, clip_embedding_dim)
        """
        batch_size = x.size(0)
        
        # Transpose to (batch_size, time_freq, num_channels)
        x = x.transpose(1, 2)
        
        # Project channels to d_model dimensions: (batch_size, time_freq, d_model)
        x = self.input_projection(x)
        
        # Add positional encoding if enabled
        if self.use_positional_encoding:
            x = x.transpose(0, 1)  # (time_freq, batch_size, d_model)
            x = self.pos_encoder(x)
            x = x.transpose(0, 1)  # (batch_size, time_freq, d_model)
        
        # Apply transformer encoder
        x = self.transformer_encoder(x)  # (batch_size, time_freq, d_model)
        
        # Global pooling across time dimension
        x = x.transpose(1, 2)  # (batch_size, d_model, time_freq)
        x = self.global_pool(x)  # (batch_size, d_model, 1)
        x = x.squeeze(-1)  # (batch_size, d_model)
        
        # Project to CLIP embedding dimension
        x = self.output_projection(x)  # (batch_size, clip_embedding_dim)
        
        # Apply layer normalization
        x = self.output_norm(x)
        
        return x


class EEGTransformerModelV2(nn.Module):
    """
    Alternative transformer model with different architecture
    
    This version uses a more sophisticated approach with:
    - Multi-scale feature extraction
    - Residual connections
    - Channel attention mechanism
    """
    
    def __init__(self, 
                 num_channels, 
                 time_freq, 
                 d_model=512, 
                 nhead=8, 
                 num_layers=6, 
                 dim_feedforward=2048, 
                 dropout=0.1,
                 clip_embedding_dim=1024):
        super().__init__()
        
        self.num_channels = num_channels
        self.time_freq = time_freq
        self.d_model = d_model
        self.clip_embedding_dim = clip_embedding_dim
        
        # Multi-scale feature extraction
        self.channel_conv = nn.Sequential(
            nn.Conv1d(num_channels, d_model // 2, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv1d(d_model // 2, d_model, kernel_size=3, padding=1),
            nn.ReLU()
        )
        
        # Channel attention mechanism
        self.channel_attention = nn.Sequential(
            nn.AdaptiveAvgPool1d(1),
            nn.Conv1d(d_model, d_model // 4, 1),
            nn.ReLU(),
            nn.Conv1d(d_model // 4, d_model, 1),
            nn.Sigmoid()
        )
        
        # Positional encoding
        self.pos_encoder = PositionalEncoding(d_model, max_len=time_freq)
        
        # Transformer encoder
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        
        # Output projection with residual connection
        self.output_projection = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_model, clip_embedding_dim)
        )
        
        # Global pooling
        self.global_pool = nn.AdaptiveAvgPool1d(1)
        
        # Output normalization
        self.output_norm = nn.LayerNorm(clip_embedding_dim)
        
    def forward(self, x):
        """
        Forward pass with multi-scale feature extraction and attention
        
        Args:
            x: Input tensor of shape (batch_size, num_channels, time_freq)
            
        Returns:
            Output tensor of shape (batch_size, clip_embedding_dim)
        """
        batch_size = x.size(0)
        
        # Multi-scale feature extraction
        conv_features = self.channel_conv(x)  # (batch_size, d_model, time_freq)
        
        # Channel attention
        attention_weights = self.channel_attention(conv_features)
        attended_features = conv_features * attention_weights
        
        # Prepare for transformer
        x = attended_features.transpose(1, 2)  # (batch_size, time_freq, d_model)
        
        # Add positional encoding
        x = x.transpose(0, 1)  # (time_freq, batch_size, d_model)
        x = self.pos_encoder(x)
        x = x.transpose(0, 1)  # (batch_size, time_freq, d_model)
        
        # Apply transformer encoder
        x = self.transformer_encoder(x)  # (batch_size, time_freq, d_model)
        
        # Global pooling
        x = x.transpose(1, 2)  # (batch_size, d_model, time_freq)
        x = self.global_pool(x)  # (batch_size, d_model, 1)
        x = x.squeeze(-1)  # (batch_size, d_model)
        
        # Output projection
        x = self.output_projection(x)  # (batch_size, clip_embedding_dim)
        
        # Normalize output
        x = self.output_norm(x)
        
        return x

In [15]:
transformer_model = EEGTransformerModel(num_channels=17, time_freq=80)
train_model(transformer_model, clip_train_loader, clip_val_loader, device='cuda', model_name='transformer_model')

Epoch 1/100: 100%|██████████| 414/414 [00:05<00:00, 74.25it/s]
Epoch 2/100: 100%|██████████| 414/414 [00:05<00:00, 75.34it/s]
Epoch 3/100: 100%|██████████| 414/414 [00:05<00:00, 75.92it/s]
Epoch 4/100: 100%|██████████| 414/414 [00:05<00:00, 75.93it/s]
Epoch 5/100: 100%|██████████| 414/414 [00:05<00:00, 74.27it/s]
Epoch 6/100: 100%|██████████| 414/414 [00:05<00:00, 75.01it/s]
Epoch 7/100: 100%|██████████| 414/414 [00:05<00:00, 75.08it/s]
Epoch 8/100: 100%|██████████| 414/414 [00:05<00:00, 74.87it/s]
Epoch 9/100: 100%|██████████| 414/414 [00:05<00:00, 75.26it/s]
Epoch 10/100: 100%|██████████| 414/414 [00:05<00:00, 75.52it/s]


Epoch 10/100, Train Loss: 0.300957, Val Loss: 0.301536


Epoch 11/100: 100%|██████████| 414/414 [00:05<00:00, 74.61it/s]
Epoch 12/100: 100%|██████████| 414/414 [00:05<00:00, 74.91it/s]
Epoch 13/100: 100%|██████████| 414/414 [00:05<00:00, 74.97it/s]
Epoch 14/100: 100%|██████████| 414/414 [00:05<00:00, 74.79it/s]
Epoch 15/100: 100%|██████████| 414/414 [00:05<00:00, 75.39it/s]
Epoch 16/100: 100%|██████████| 414/414 [00:05<00:00, 75.47it/s]
Epoch 17/100: 100%|██████████| 414/414 [00:05<00:00, 74.65it/s]
Epoch 18/100: 100%|██████████| 414/414 [00:05<00:00, 74.95it/s]
Epoch 19/100: 100%|██████████| 414/414 [00:05<00:00, 75.56it/s]
Epoch 20/100: 100%|██████████| 414/414 [00:05<00:00, 75.10it/s]


Epoch 20/100, Train Loss: 0.300616, Val Loss: 0.301212


Epoch 21/100: 100%|██████████| 414/414 [00:05<00:00, 74.90it/s]
Epoch 22/100: 100%|██████████| 414/414 [00:05<00:00, 75.17it/s]
Epoch 23/100: 100%|██████████| 414/414 [00:05<00:00, 75.24it/s]
Epoch 24/100: 100%|██████████| 414/414 [00:05<00:00, 74.19it/s]
Epoch 25/100: 100%|██████████| 414/414 [00:05<00:00, 74.74it/s]
Epoch 26/100: 100%|██████████| 414/414 [00:05<00:00, 74.87it/s]
Epoch 27/100: 100%|██████████| 414/414 [00:05<00:00, 74.76it/s]
Epoch 28/100: 100%|██████████| 414/414 [00:05<00:00, 74.01it/s]
Epoch 29/100: 100%|██████████| 414/414 [00:05<00:00, 74.07it/s]
Epoch 30/100: 100%|██████████| 414/414 [00:05<00:00, 75.07it/s]


Epoch 30/100, Train Loss: 0.300429, Val Loss: 0.301055


Epoch 31/100: 100%|██████████| 414/414 [00:05<00:00, 73.83it/s]
Epoch 32/100: 100%|██████████| 414/414 [00:05<00:00, 73.36it/s]
Epoch 33/100: 100%|██████████| 414/414 [00:05<00:00, 74.71it/s]
Epoch 34/100: 100%|██████████| 414/414 [00:05<00:00, 75.19it/s]
Epoch 35/100: 100%|██████████| 414/414 [00:05<00:00, 75.69it/s]
Epoch 36/100: 100%|██████████| 414/414 [00:05<00:00, 75.90it/s]
Epoch 37/100: 100%|██████████| 414/414 [00:05<00:00, 74.63it/s]
Epoch 38/100: 100%|██████████| 414/414 [00:05<00:00, 75.13it/s]
Epoch 39/100: 100%|██████████| 414/414 [00:05<00:00, 74.93it/s]
Epoch 40/100: 100%|██████████| 414/414 [00:05<00:00, 74.45it/s]


Epoch 40/100, Train Loss: 0.300340, Val Loss: 0.301035


Epoch 41/100: 100%|██████████| 414/414 [00:05<00:00, 74.75it/s]
Epoch 42/100: 100%|██████████| 414/414 [00:05<00:00, 74.62it/s]
Epoch 43/100: 100%|██████████| 414/414 [00:05<00:00, 74.52it/s]
Epoch 44/100: 100%|██████████| 414/414 [00:05<00:00, 74.43it/s]
Epoch 45/100: 100%|██████████| 414/414 [00:05<00:00, 73.24it/s]
Epoch 46/100: 100%|██████████| 414/414 [00:05<00:00, 74.49it/s]
Epoch 47/100: 100%|██████████| 414/414 [00:05<00:00, 74.59it/s]
Epoch 48/100: 100%|██████████| 414/414 [00:05<00:00, 74.19it/s]
Epoch 49/100: 100%|██████████| 414/414 [00:05<00:00, 74.45it/s]
Epoch 50/100: 100%|██████████| 414/414 [00:05<00:00, 75.35it/s]


Epoch 50/100, Train Loss: 0.300222, Val Loss: 0.300958


Epoch 51/100:  14%|█▍        | 58/414 [00:00<00:04, 75.61it/s]


KeyboardInterrupt: 