# Neuro-Hybrid Latent Diffusion Model (NH-LDM) Demo

This notebook demonstrates the pipeline for reconstructing images from fMRI data.

In [None]:
import yaml
import torch
import matplotlib.pyplot as plt
from project.data.preprocessing import DataProcessor
from project.models.semantic_stream import SemanticDecoder
from project.models.structural_stream import StructuralDecoder
from project.models.diffusion_pipeline import ReconstructionPipeline
from project.evaluation.visualization import visualize_results

# Load the experiment configuration.
# yaml.safe_load only constructs plain Python objects (dicts, lists, scalars),
# so an untrusted config cannot instantiate arbitrary classes.
# An explicit encoding keeps parsing independent of the platform's locale default.
with open('config.yaml', 'r', encoding='utf-8') as f:
    config = yaml.safe_load(f)

print("Config loaded.")

## 1. Load Test Data
Load the fMRI data for the held-out test set.

In [None]:
# Build the preprocessing helper from the config, then fetch the held-out split.
processor = DataProcessor(config)
test_dataset = processor.load_data(split="test")
num_samples = len(test_dataset)
print(f"Test data loaded: {num_samples} samples")

## 2. Load Models
Load the trained ridge-regression decoders and the Stable Diffusion pipeline.

In [None]:
# Initialize the semantic and structural decoders and the reconstruction pipeline.
import pickle
from pathlib import Path

semantic_model = SemanticDecoder(config)
structural_model = StructuralDecoder(config)
diffusion_pipe = ReconstructionPipeline(config)

# Restore the trained decoder weights when the pickled files are present.
# Without them the decoders keep their freshly constructed (untrained) state,
# and the prediction step below will not produce meaningful results.
# NOTE(review): pickle.load can execute arbitrary code; only load files you
# produced yourself or otherwise trust.
for decoder, weights_path in (
    (semantic_model, Path('project/models/semantic_decoder.pkl')),
    (structural_model, Path('project/models/structural_decoder.pkl')),
):
    if weights_path.exists():
        with weights_path.open('rb') as f:
            decoder.model = pickle.load(f)
    else:
        print(f"Warning: {weights_path} not found; {type(decoder).__name__} is untrained.")

## 3. Run Reconstruction
Select a test sample by index (`idx`) and reconstruct it.

In [None]:
# Index of the test sample to reconstruct; change it to inspect other samples.
idx = 0
data_sample = test_dataset[idx]
fmri_vec = data_sample['fmri'].unsqueeze(0)  # (1, Voxels)

# Convert to NumPy once. detach()/cpu() make this safe even if the tensor
# tracks gradients or lives on an accelerator (.numpy() alone would raise).
fmri_np = fmri_vec.detach().cpu().numpy()

# Extract ROIs (mocking the ROI extraction call for a single sample).
hvc_indices = processor.roi_extractor.get_indices(config['data']['rois']['hvc'])
evc_indices = processor.roi_extractor.get_indices(config['data']['rois']['evc'])

hvc_data = processor.roi_extractor.select_voxels(fmri_np, hvc_indices)
evc_data = processor.roi_extractor.select_voxels(fmri_np, evc_indices)

# Predict: the semantic decoder reads the HVC voxels, the structural decoder the EVC voxels.
c_hat = semantic_model.predict(hvc_data)
z_hat = structural_model.predict(evc_data)

# Convert to float32 tensors on the pipeline's device in a single step.
# torch.as_tensor accepts NumPy arrays (and other array-likes), and casting
# during the move avoids shipping float64 data to the device first.
c_hat = torch.as_tensor(c_hat, dtype=torch.float32, device=diffusion_pipe.device)
z_hat = torch.as_tensor(z_hat, dtype=torch.float32, device=diffusion_pipe.device)

# Generate
images = diffusion_pipe.reconstruct(z_hat, c_hat)

# NOTE(review): assumes images[0] is imshow-compatible (e.g. a PIL image or an
# HxW / HxWxC array); confirm against ReconstructionPipeline.reconstruct.
plt.imshow(images[0])
plt.title("Reconstruction")
plt.axis('off')
plt.show()