In [None]:
import os
import torch
import random
from torch import nn
from IPython.display import display
from torch.utils.data import DataLoader
from models.core.diffusion.pipe import Pipe
from models.trainers.contrastive import ContrastiveTrainerModel
from models.core.diffusion.custom_pipeline import Generator4Embeds
from utils.data_modules.contrastive import EEGContrastiveDataModule
from models.core.diffusion.diffusion_prior import DiffusionPriorUNet
from utils.datasets.diffusion_embedding import DiffusionEmbeddingDataset

In [None]:
# Channel names handed to the data module as `input_channels` (64 entries).
eeg_channels = [
    'Fp1', 'AF7', 'AF3', 'F1', 'F3', 'F5', 'F7', 'FT7',
    'FC5', 'FC3', 'FC1', 'C1', 'C3', 'C5', 'T7', 'TP7',
    'CP5', 'CP3', 'CP1', 'P1', 'P3', 'P5', 'P7', 'P9',
    'PO7', 'PO3', 'O1', 'Iz', 'Oz', 'POz', 'Pz', 'CPz',
    'Fpz', 'Fp2', 'AF8', 'AF4', 'AFz', 'Fz', 'F2', 'F4',
    'F6', 'F8', 'FT8', 'FC6', 'FC4', 'FC2', 'FCz', 'Cz',
    'C2', 'C4', 'C6', 'T8', 'TP8', 'CP6', 'CP4', 'CP2',
    'P2', 'P4', 'P6', 'P8', 'P10', 'PO8', 'PO4', 'O2',
]

# Data module for subject 1 / session 1: 250 Hz sampling, standard 10-20 montage,
# and a window spanning 50 ms before to 600 ms after each event.
dm = EEGContrastiveDataModule(
    input_channels=eeg_channels,
    sfreq=250,
    montage='standard_1020',
    window_before_event_ms=50,
    window_after_event_ms=600,
    subject=1,
    session=1,
    batch_size=1024,
    num_workers=4,
)

In [None]:
# Query the dataset metadata first, then build the train and test loaders
# (same call order as before: sample info, train loader, test loader).
sample_data = dm.get_sample_info()
train_loader, test_loader = dm.train_dataloader(), dm.test_dataloader()

In [None]:
# --- Run configuration ---
epochs = 150   # training epoch budget
subject = 1    # subject id, embedded in the diffusion-prior checkpoint file name
session = 1    # session id, embedded in the diffusion-prior checkpoint file name

# Seed the Python and PyTorch RNGs so that a Restart & Run All reproduces the
# same training run and the same generated samples.
seed = 42
random.seed(seed)
torch.manual_seed(seed)

# Input/output sizes reported by the data module; they are passed to
# ContrastiveTrainerModel when the encoder is rebuilt.
num_channels = sample_data['input']['num_channels']
timesteps = sample_data['input']['num_timesteps']
num_fine_labels = sample_data['output']['fine_labels_shape']

In [None]:
# Inputs: pre-computed image feature tensors for the train and test splits.
train_features_path = '../../data/all-joined-1/coco/features/train_features.pt'
test_features_path = '../../data/all-joined-1/coco/features/test_features.pt'

# Input: checkpoint of the trained contrastive EEG encoder.
encoder_checkpoint_path = "../models/check_points/contrastive_encoder/subj1_session1_epoch=199.ckpt"

# Output: where the diffusion-prior weights are written; the file name
# encodes the subject and session ids.
save_path = f"../models/check_points/diffusion_prior/subj{subject}_session{session}.pt"

In [None]:
checkpoint_path = encoder_checkpoint_path

# Keep preferring the second GPU ("cuda:1"), but only when it really exists:
# on a single-GPU host "cuda:1" is an invalid device even though CUDA is
# available. Fall back to the first GPU, and to the CPU when there is none.
if torch.cuda.is_available():
    gpu_index = 1 if torch.cuda.device_count() > 1 else 0
    device = torch.device(f"cuda:{gpu_index}")
else:
    device = torch.device("cpu")

# Load the encoder checkpoint directly onto the selected device.
checkpoint = torch.load(checkpoint_path, map_location=device)

In [None]:
# Rebuild the contrastive model with the sizes reported by the data module.
lightning_model = ContrastiveTrainerModel(num_channels, timesteps, num_fine_labels)

# A checkpoint that wraps its weights under a 'state_dict' key is unwrapped;
# anything else is treated as the state dict itself.
encoder_state = checkpoint['state_dict'] if 'state_dict' in checkpoint else checkpoint

# Trailing ';' keeps the cell silent, as it was when this call sat inside an if/else.
lightning_model.load_state_dict(encoder_state);

In [None]:
# Move the restored model to the selected device and keep the result under the
# name used by the feature-extraction cells.
contrastive_model = lightning_model.to(device)
# Switch to evaluation mode. This call is deliberately the cell's last
# expression, so the notebook displays whatever eval() returns.
contrastive_model.eval()

In [None]:
# Run the trained contrastive encoder over the training split and keep its
# outputs (plus the fine labels) on the CPU.
# NOTE(review): these features are later paired by position with image
# embeddings loaded from disk; confirm train_loader iterates in that same
# order (i.e. that it does not shuffle).
print("Extracting EEG features from training data...")

train_features_list, train_labels_list = [], []
num_train_batches = len(train_loader)

with torch.no_grad():  # inference only, no autograd bookkeeping
    for batch_idx, (eeg_data, output) in enumerate(train_loader):
        # Progress line on every tenth batch
        if not batch_idx % 10:
            print(f"Processing batch {batch_idx}/{num_train_batches}")

        # The target tuple carries five entries; only the fine labels are stored
        img, img_features, text_features, super_labels, fine_labels = output

        # Encode on the compute device, then move the result back to the CPU
        eeg_features = contrastive_model.encoder(eeg_data.to(device))
        train_features_list.append(eeg_features.cpu())
        train_labels_list.append(fine_labels.cpu())

# Join the per-batch tensors along the first (sample) axis
train_eeg_features = torch.cat(train_features_list, dim=0)
train_labels = torch.cat(train_labels_list, dim=0)

print(f"Training EEG features shape: {train_eeg_features.shape}")
print(f"Training labels shape: {train_labels.shape}")

In [None]:
# Same extraction as for the training split, applied to the test loader.
print("Extracting EEG features from test data...")

test_features_list, test_labels_list = [], []
num_test_batches = len(test_loader)

with torch.no_grad():  # inference only, no autograd bookkeeping
    for batch_idx, (eeg_data, output) in enumerate(test_loader):
        # Progress line on every tenth batch
        if not batch_idx % 10:
            print(f"Processing batch {batch_idx}/{num_test_batches}")

        # The target tuple carries five entries; only the fine labels are stored
        img, img_features, text_features, super_labels, fine_labels = output

        # Encode on the compute device, then move the result back to the CPU
        eeg_features = contrastive_model.encoder(eeg_data.to(device))
        test_features_list.append(eeg_features.cpu())
        test_labels_list.append(fine_labels.cpu())

# Join the per-batch tensors along the first (sample) axis
test_eeg_features = torch.cat(test_features_list, dim=0)
test_labels = torch.cat(test_labels_list, dim=0)

print(f"Test EEG features shape: {test_eeg_features.shape}")
print(f"Test labels shape: {test_labels.shape}")

In [None]:
# Load pre-computed image embeddings (ViT-H-14 features)
print("Loading image embeddings...")

# Always materialise the tensors on the CPU: without map_location, torch.load
# restores each tensor onto the device it was saved from, which fails on a
# host lacking that device and would hand CUDA tensors to DataLoader workers.
# Later cells move the slices they need with .to(device).
img_embeddings_train = torch.load(train_features_path, map_location="cpu")
img_embeddings_test = torch.load(test_features_path, map_location="cpu")

print(f"Original train image embeddings shape: {img_embeddings_train.shape}")
print(f"Original test image embeddings shape: {img_embeddings_test.shape}")

In [None]:
# Flatten a 4-D embedding tensor to 2-D, keeping the last axis as the embedding
# dimension. Presumably the 4-D layout is
# (num_images, num_repetitions, num_views, embed_dim) — TODO confirm.
# reshape() is used instead of view() so that a non-contiguous tensor (strides
# survive torch.save/torch.load) is copied rather than raising an error.
if img_embeddings_train.ndim == 4:
    img_embeddings_train_reshaped = img_embeddings_train.reshape(-1, img_embeddings_train.shape[-1])
else:
    # Any other rank is passed through untouched.
    img_embeddings_train_reshaped = img_embeddings_train

print(f"Reshaped train image embeddings shape: {img_embeddings_train_reshaped.shape}")

In [None]:
# Create dataset for diffusion training: EEG encoder outputs are the
# conditioning embeddings (c), image embeddings are the targets (h).
# NOTE(review): the two tensors are presumably paired by index inside
# DiffusionEmbeddingDataset — confirm they have equal length and matching order.
print("Creating diffusion training dataset...")
diffusion_dataset = DiffusionEmbeddingDataset(
    c_embeddings=train_eeg_features,
    h_embeddings=img_embeddings_train_reshaped
)

# Never request more worker processes than the host has CPUs; a fixed 64
# oversubscribes smaller machines. Hosts with 64+ CPUs behave as before.
diffusion_num_workers = min(64, os.cpu_count() or 1)

diffusion_dataloader = DataLoader(
    diffusion_dataset,
    batch_size=1024,
    shuffle=True,
    num_workers=diffusion_num_workers
)

print(f"Diffusion dataset size: {len(diffusion_dataset)}")

In [None]:
# Build the diffusion prior. Its conditioning width is read from the extracted
# EEG features (second axis) so it always matches the encoder output.
encoder_output_dim = train_eeg_features.size(1)
print(f"Encoder output dimension: {encoder_output_dim}")

diffusion_prior = DiffusionPriorUNet(cond_dim=encoder_output_dim, dropout=0.1)

# Report only the parameters that require gradients.
trainable_param_count = sum(
    param.numel()
    for param in diffusion_prior.parameters()
    if param.requires_grad
)
print(f"Diffusion prior parameters: {trainable_param_count}")

In [None]:
# Train the diffusion prior
print("Training diffusion prior...")
pipe = Pipe(diffusion_prior, device=device)
# Use the notebook-level `epochs` setting instead of a second hard-coded 150,
# so changing the configuration cell actually changes the training length.
pipe.train(diffusion_dataloader, num_epochs=epochs, learning_rate=1e-3)

In [None]:
# Build the embedding-to-image generator on the same device as the prior,
# configured for 4 inference steps.
print("Initializing image generator...")
generator_kwargs = {"num_inference_steps": 4, "device": device}
generator = Generator4Embeds(**generator_kwargs)

In [None]:
# Persist the trained prior's weights, creating the target directory if needed.
checkpoint_dir = os.path.dirname(save_path)
os.makedirs(checkpoint_dir, exist_ok=True)

prior_state = pipe.diffusion_prior.state_dict()
torch.save(prior_state, save_path)
print(f"Diffusion prior saved to {save_path}")

In [None]:
# EEG -> image-embedding (diffusion prior) -> image (generator), for up to
# 100 test samples. Every image is written to disk; the first five are shown.
print("Generating images from test EEG...")
output_dir = "generated_images_contrastive"
os.makedirs(output_dir, exist_ok=True)

num_samples_to_generate = min(100, len(test_eeg_features))
num_inference_steps = 50
guidance_scale = 5.0

# Sampling settings shared by every call to the prior
prior_sampling_kwargs = {
    "num_inference_steps": num_inference_steps,
    "guidance_scale": guidance_scale,
}

for i in range(num_samples_to_generate):
    # Progress line on every tenth sample
    if not i % 10:
        print(f"Generating image {i+1}/{num_samples_to_generate}...")

    # Single-sample batch of EEG features, moved to the compute device
    eeg_embed = test_eeg_features[i:i+1].to(device)

    # Diffusion prior: EEG embedding -> image embedding
    generated_img_embed = pipe.generate(c_embeds=eeg_embed, **prior_sampling_kwargs)

    # Generator: image embedding (cast to float16) -> image
    half_precision_embed = generated_img_embed.to(dtype=torch.float16)
    image = generator.generate(half_precision_embed)

    # Zero-padded index keeps the files sorted in generation order
    image_path = os.path.join(output_dir, f"generated_image_{i:03d}.png")
    image.save(image_path)

    # Show only the first five inline to keep the notebook output small
    if i < 5:
        print(f"Generated image {i+1}:")
        display(image)

print(f"All generated images saved to {output_dir}")

In [None]:
# Reference renders: feed the stored (ground-truth) image embeddings straight
# into the generator, bypassing the diffusion prior, for the first five items.
# NOTE(review): img_embeddings_test is sliced along its first axis exactly as
# loaded; unlike the training embeddings it is never flattened in this
# notebook — confirm it is already 2-D.
print("Generating reference images from ground truth image embeddings...")
reference_dir = "reference_images"
os.makedirs(reference_dir, exist_ok=True)

num_reference_images = min(5, len(img_embeddings_test))

for i in range(num_reference_images):
    # Single-item batch of the stored embedding, on the compute device
    gt_img_embed = img_embeddings_test[i:i+1].to(device)

    # Generator expects float16 input here, as in the EEG-driven path
    half_precision_embed = gt_img_embed.to(dtype=torch.float16)
    reference_image = generator.generate(half_precision_embed)

    # Write to disk, then show inline
    ref_path = os.path.join(reference_dir, f"reference_image_{i:03d}.png")
    reference_image.save(ref_path)

    print(f"Reference image {i+1} (ground truth):")
    display(reference_image)

In [None]:
# Evaluate reconstruction quality (optional): compare embeddings produced by
# the diffusion prior from EEG against the stored image embeddings.
print("Evaluating reconstruction quality...")

similarities = []
mse_losses = []

# Only indices present in BOTH tensors can be compared. Bounding the loop by
# the EEG features alone would slice past the end of a shorter
# img_embeddings_test and yield an empty ground-truth batch.
num_eval_samples = min(50, len(test_eeg_features), len(img_embeddings_test))

for i in range(num_eval_samples):
    # Generate embedding from EEG
    eeg_embed = test_eeg_features[i:i+1].to(device)
    generated_embed = pipe.generate(
        c_embeds=eeg_embed,
        num_inference_steps=50,
        guidance_scale=5.0
    )

    # Ground-truth embedding at the same index
    # NOTE(review): assumes test EEG features and test image embeddings are
    # aligned by position — confirm against how the feature file was built.
    gt_embed = img_embeddings_test[i:i+1].to(device)

    # Cosine similarity along the embedding axis
    cos_sim = torch.nn.functional.cosine_similarity(
        generated_embed, gt_embed, dim=1
    ).item()

    # Mean squared error over all embedding components
    mse = torch.nn.functional.mse_loss(
        generated_embed, gt_embed
    ).item()

    similarities.append(cos_sim)
    mse_losses.append(mse)

if similarities:
    avg_similarity = sum(similarities) / len(similarities)
    avg_mse = sum(mse_losses) / len(mse_losses)
else:
    # Nothing to average: report NaN instead of dividing by zero.
    print("No test samples available for evaluation.")
    avg_similarity = float("nan")
    avg_mse = float("nan")

print(f"Average cosine similarity: {avg_similarity:.4f}")
print(f"Average MSE: {avg_mse:.4f}")

In [None]:
# Generate images with different guidance scales (experiment): the same test
# sample is pushed through the prior once per scale, and each result is saved
# and displayed.
print("Experimenting with different guidance scales...")
experiment_dir = "guidance_scale_experiment"
os.makedirs(experiment_dir, exist_ok=True)

sample_idx = 0  # Use first test sample
eeg_embed = test_eeg_features[sample_idx:sample_idx+1].to(device)

guidance_scales = [0.0, 2.5, 5.0, 7.5, 10.0]

# The loop variable is named `scale` rather than `guidance_scale` so that the
# notebook-level `guidance_scale` setting used by the main generation cell is
# not silently left at the last value tried here (10.0).
for scale in guidance_scales:
    print(f"Generating with guidance scale {scale}...")

    # Generate image embedding
    generated_embed = pipe.generate(
        c_embeds=eeg_embed,
        num_inference_steps=50,
        guidance_scale=scale
    )

    # Generate image (generator input is cast to float16)
    image = generator.generate(generated_embed.to(dtype=torch.float16))

    # Save image; the file name records both the scale and the sample index
    image_path = os.path.join(experiment_dir, f"guidance_{scale}_sample_{sample_idx}.png")
    image.save(image_path)

    print(f"Guidance scale {scale}:")
    display(image)

print("Guidance scale experiment completed!")

In [None]:
# %%
# Optional: load a previously saved diffusion prior for inference only.
# Uncomment the lines below to reuse the weights stored at `save_path` instead of training.

# print("Loading pre-trained diffusion prior...")
# diffusion_prior_pretrained = DiffusionPriorUNet(cond_dim=encoder_output_dim, dropout=0.1)
# diffusion_prior_pretrained.load_state_dict(torch.load(save_path))
# pipe_pretrained = Pipe(diffusion_prior_pretrained, device=device)

# # Use pipe_pretrained for inference
# # ...

# %%