<a href="https://colab.research.google.com/github/Jatin-Khiyani/Visual-Situmlai-Reconstruction-Using-fMRI-and-Deep-Learning/blob/main/VQ-VAE%20for%20stimuli/VQ_VAE_.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from google.colab import drive


In [None]:
# ✅ 2. UNZIP FULL DATASET
# Extract the prepared NSD subject-01 archive stored on Google Drive into the
# Colab VM's local storage under /content.
import zipfile
import os

zip_path = '/content/drive/MyDrive/NSD_Dataset/prepared_nsd_data_subj01.zip'
extract_path = '/content/prepared_nsd_data_subj01'

# The first cell only imports `google.colab.drive`; nothing in this notebook
# calls `drive.mount`, so on a fresh runtime `zip_path` would not exist.
# Mount on demand (no-op when Drive is already mounted).
if not os.path.isdir('/content/drive/MyDrive'):
    from google.colab import drive
    drive.mount('/content/drive')

with zipfile.ZipFile(zip_path, 'r') as zip_ref:
    zip_ref.extractall(extract_path)

print("✅ Dataset unzipped!")


✅ Dataset unzipped!


In [None]:
# ✅ 4. INSTALL LIBRARIES
!pip install -q diffusers[torch] transformers accelerate


[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m2.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m108.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m75.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m883.7/883.7 kB[0m [31m46.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/664.8 MB[0m [31m1.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m211.5/211.5 MB[0m [31m10.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.3/56.3 MB[0m [31m38.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m127.9/127.9 MB[0m [31m17.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [None]:
# ✅ 3. CLEANUP NON-STANDARD IMAGE FILES
# Delete every `image_*.png` whose name is not exactly `image_NNNNN.png` with an
# index in [0, 195000), e.g. a name like `image_00001 (1).png`.
import os

# The encoding cell reads images from the nested folder
# `.../prepared_nsd_data_subj01/prepared_nsd_data_subj01`, so the archive
# apparently contains a top-level folder. Scanning only the extraction root
# would never see the images. Prefer the nested folder and fall back to the
# root if the archive layout turns out to be flat.
extract_root = '/content/prepared_nsd_data_subj01'
nested_dir = os.path.join(extract_root, 'prepared_nsd_data_subj01')
image_dir = nested_dir if os.path.isdir(nested_dir) else extract_root
files = os.listdir(image_dir)

# Whitelist of canonical names, built once so each membership test is O(1).
standard = {f'image_{i:05d}.png' for i in range(195000)}

for f in files:
    if f.startswith('image_') and f.endswith('.png') and f not in standard:
        print(f"Deleting non-standard image: {f}")
        os.remove(os.path.join(image_dir, f))

print("✅ Cleaned invalid image files")


✅ Cleaned invalid image files


In [None]:
# ✅ 5. EXTRACT VAE LATENTS IN BATCHES (resumable: skips images already encoded)
# NOTE: although this notebook calls it a "VQ-VAE", the model loaded below is
# `AutoencoderKL` from Stable Diffusion v1-4 (a KL-regularised VAE), not a
# vector-quantised VAE.
import torch
from PIL import Image
from torchvision import transforms
from diffusers.models import AutoencoderKL
from tqdm import tqdm
import os

device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", device)

image_dir = '/content/prepared_nsd_data_subj01/prepared_nsd_data_subj01'
z_save_dir = '/content/drive/MyDrive/NSD_Dataset/z_latents'
os.makedirs(z_save_dir, exist_ok=True)

# === Load VAE
vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae")
vae.eval().to(device)

# Latent scaling factor applied to the VAE output (the Stable Diffusion v1 value).
SD_LATENT_SCALE = 0.18215

# === Preprocess images: resize to 512x512, convert to a [0, 1] tensor, then
# normalise to [-1, 1] (mean 0.5 / std 0.5, broadcast over the RGB channels).
preprocess = transforms.Compose([
    transforms.Resize((512, 512)),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])

# === Load image filenames (only `image_<digits>.png`), in sorted order
all_image_files = sorted(
    f for f in os.listdir(image_dir)
    if f.startswith('image_') and f.endswith('.png') and f[6:-4].isdigit()
)

# === Skip already processed files: `image_X.pt` in z_save_dir marks `image_X.png` as done
already_done = {
    os.path.splitext(f)[0] + '.png'
    for f in os.listdir(z_save_dir)
    if f.endswith('.pt')
}
image_files = [f for f in all_image_files if f not in already_done]

batch_size = 36
print(f"Found {len(all_image_files)} total images.")
print(f"Skipping {len(already_done)} already processed.")
print(f"Processing {len(image_files)} remaining images in batches of {batch_size}...")

# === Encode in batches
with torch.no_grad():
    for i in tqdm(range(0, len(image_files), batch_size), desc="Encoding with VQ-VAE"):
        batch_files = image_files[i:i+batch_size]
        batch_imgs = []
        # Names of the images that actually loaded, aligned 1:1 with batch_imgs.
        # Pairing latents with `batch_files` instead would mislabel every latent
        # after a skipped image and overrun `z` at the end of the batch.
        loaded_files = []

        for fname in batch_files:
            try:
                with Image.open(os.path.join(image_dir, fname)) as img:
                    tensor = preprocess(img.convert("RGB"))
            except Exception as e:
                print(f"⚠️ Skipping image {fname} due to error: {e}")
                continue
            batch_imgs.append(tensor)
            loaded_files.append(fname)

        if not batch_imgs:
            continue  # skip empty batch

        img_tensor = torch.stack(batch_imgs).to(device)
        # `.sample()` draws from the encoder's posterior, so re-running gives
        # slightly different latents. Use `.mode()` if determinism is wanted.
        z = vae.encode(img_tensor).latent_dist.sample() * SD_LATENT_SCALE

        for latent, fname in zip(z, loaded_files):
            z_path = os.path.join(z_save_dir, os.path.splitext(fname)[0] + '.pt')
            tmp_path = z_path + '.tmp'
            # clone(): `latent` is a view into the whole batch. On CPU, `.cpu()`
            # returns that same view, and torch.save would write the entire
            # batch storage into every file.
            torch.save(latent.cpu().clone(), tmp_path)
            # Write-then-rename, so an interrupted save never leaves a truncated
            # `.pt` that the skip logic above would treat as finished.
            os.replace(tmp_path, z_path)

print("✅ All latents saved to:", z_save_dir)
