In [None]:
import os
import torch
from IPython.display import display
from models.core.diffusion.pipe import Pipe
from models.core.diffusion.custom_pipeline import Generator4Embeds
from models.core.diffusion.diffusion_prior import DiffusionPriorUNet
from utils.data_modules.diffusion_embedding import DiffusionEmbeddingDataModule

In [None]:
# Build the data module over precomputed EEG embeddings (subject 1, session 1)
# and grab its train/test loaders. The embeddings path below is a placeholder.
datamodule_kwargs = dict(
    eeg_embeddings_file="path/to/eeg_embeddings.npy",
    subject=1,
    session=1,
    batch_size=1024,
    num_workers=4,
    val_split=0.1,
    test='default',
)
data_module = DiffusionEmbeddingDataModule(**datamodule_kwargs)
data_module.setup()

# Evaluated left to right: train loader first, then test loader.
train_loader, test_loader = data_module.train_dataloader(), data_module.test_dataloader()

In [None]:
# PyTorch calls NVIDIA GPUs "cuda". "gpu" is not a valid torch device type and
# makes torch.load(map_location=device) / tensor.to(device) raise.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
# Diffusion prior conditioned on 1024-d inputs (presumably the EEG embedding
# width; confirm), wrapped in the sampling pipeline on the selected device.
diffusion_prior = DiffusionPriorUNet(
    cond_dim=1024,
    dropout=0.1,
)
pipe = Pipe(diffusion_prior, device=device)

In [None]:
# Checkpoint of the diffusion prior for one subject/session. Keep these in sync
# with the subject/session passed to the data module.
# The subject is zero-padded to two digits: subject 1 -> "subj01". This assumes
# checkpoints for two-digit subjects are named "subjNN" rather than "subj0NN".
SUBJECT = 1
SESSION = 1
save_path = f"../models/check_points/diffusion_prior/subj{SUBJECT:02d}_session{SESSION}.pt"

In [None]:
# Load pretrained diffusion-prior weights from save_path onto the selected device.
model_name = 'diffusion_prior' # 'diffusion_prior_vice_pre_imagenet' or 'diffusion_prior_vice_pre'
# NOTE(review): model_name is not used in this cell; kept in case later cells read it.

# torch.load unpickles the file, so only load checkpoints from trusted sources
# (recent PyTorch versions accept weights_only=True to restrict this).
state_dict = torch.load(save_path, map_location=device)

# The prior was constructed with dropout=0.1, and nn.Module starts in training
# mode. Switch to eval mode so dropout is disabled while sampling.
pipe.diffusion_prior.eval()
pipe.diffusion_prior.load_state_dict(state_dict)

In [None]:
# Create the embedding-to-image generator (4 inference steps) on the selected device.
print("Initializing image generator...")
generator = Generator4Embeds(device=device, num_inference_steps=4)

In [None]:
# Generate images from test EEG signals: the diffusion prior turns each test
# sample's EEG embedding into an image embedding. The generator then decodes
# that embedding several times and saves each result.
print("Generating images from test EEG...")
output_dir = "generated_images"
os.makedirs(output_dir, exist_ok=True)

num_samples_to_generate = min(100, len(test_loader.dataset))
num_inference_steps = 50      # sampling steps for the diffusion prior only
guidance_scale = 5.0
num_images_per_sample = 10    # decoded images per generated embedding
num_samples_to_display = 5    # only the first few samples are shown inline

for i in range(num_samples_to_generate):
    if i % 10 == 0:
        print(f"Generating image {i+1}/{num_samples_to_generate}...")

    # Each dataset item unpacks into (eeg_embedding, image_embedding). Unpack
    # first and move only the EEG tensor; the item itself has no usable `.to`.
    eeg_embed, _ = test_loader.dataset[i]
    eeg_embed = eeg_embed.to(device)
    # NOTE(review): this is a single, un-batched sample. Confirm pipe.generate
    # accepts it without an explicit batch dim (e.g. eeg_embed.unsqueeze(0)).

    # Generate image embedding using diffusion prior
    generated_img_embed = pipe.generate(
        c_embeds=eeg_embed,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
    )

    for j in range(num_images_per_sample):
        # Decode the embedding into an actual image (generator expects float16 here)
        image = generator.generate(generated_img_embed.to(dtype=torch.float16))

        # Include the per-sample index j so repeated generations for the same
        # sample don't overwrite each other.
        image_path = os.path.join(output_dir, f"generated_image_{i:03d}_{j:02d}.png")
        image.save(image_path)

        if i < num_samples_to_display:
            print(f"Generated image {i+1} (variation {j+1}/{num_images_per_sample}):")
            display(image)

print(f"All generated images saved to {output_dir}")

In [None]:
# Compare with ground truth: decode images directly from the dataset's
# ground-truth image embeddings, bypassing the diffusion prior.
print("Generating reference images from ground truth image embeddings...")
reference_dir = "reference_images"
os.makedirs(reference_dir, exist_ok=True)

num_reference_images = min(5, len(test_loader.dataset))

for i in range(num_reference_images):
    # A DataLoader is not subscriptable, so index the underlying dataset.
    # Each item unpacks into (eeg_embedding, image_embedding); keep the image one.
    _, gt_img_embed = test_loader.dataset[i]
    gt_img_embed = gt_img_embed.to(device)
    # Keep a leading batch dimension of size 1, the shape a one-element slice
    # of a batch would have. This assumes a 1-D item is a single embedding
    # vector; confirm against the dataset.
    if gt_img_embed.dim() == 1:
        gt_img_embed = gt_img_embed.unsqueeze(0)

    # Generate image directly from ground truth embedding
    reference_image = generator.generate(gt_img_embed.to(dtype=torch.float16))

    ref_path = os.path.join(reference_dir, f"reference_image_{i:03d}.png")
    reference_image.save(ref_path)

    print(f"Reference image {i+1} (ground truth):")
    display(reference_image)