<a href="https://colab.research.google.com/github/Ajinkya-18/NeuroVision/blob/main/neurovision4.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Dataset preparation

In [None]:
import os
import torch
import numpy as np
from pathlib import Path
from tqdm.auto import tqdm
from scipy.signal import butter, sosfiltfilt, spectrogram, iirnotch, tf2sos
from datasets import load_dataset, concatenate_datasets
import pandas as pd
from sklearn.model_selection import train_test_split
import shutil
import requests
import zipfile

# ==============================================================================
# --- 1. CONFIGURATION ---
# ==============================================================================
class CONFIG_FINAL_PREP:
    """Settings for building the 17k-sample EEG-spectrogram / COCO-image dataset."""

    # --- Data sources ---
    HUGGINGFACE_DATASET = "Alljoined/05_125"
    COCO_TRAIN_URL = "http://images.cocodataset.org/zips/train2017.zip"
    COCO_VAL_URL = "http://images.cocodataset.org/zips/val2017.zip"

    # --- Scratch area for the downloaded / extracted COCO archives ---
    TEMP_DIR = "/content/coco_temp"
    TRAIN_ZIP_PATH = os.path.join(TEMP_DIR, "train2017.zip")
    VAL_ZIP_PATH = os.path.join(TEMP_DIR, "val2017.zip")
    EXTRACT_DIR = os.path.join(TEMP_DIR, "coco_full_unzipped")

    # --- Output locations: build on the local disk, then copy to Drive ---
    LOCAL_OUTPUT_ROOT = '/content/alljoined_lightweight_17k'
    DRIVE_OUTPUT_ROOT = '/content/drive/MyDrive/NeuroVision/alljoined_lightweight_17k'
    COCO_17K_DIR_NAME = 'coco_images_17k'
    SPECTROGRAMS_DIR_NAME = 'spectrograms'

    # --- Subset and split sizes (train + val + test == TOTAL_SAMPLES) ---
    TOTAL_SAMPLES = 17000
    TRAIN_SIZE = 14000
    VAL_SIZE = 2800
    TEST_SIZE = 200

    # --- Signal processing ---
    FS = 250               # sampling rate passed to the filters and the STFT
    LOW_CUT = 4            # band-pass lower edge (same units as FS)
    HIGH_CUT = 100         # band-pass upper edge (same units as FS)
    FILTER_ORDER = 5       # Butterworth order
    NPERSEG = 128          # STFT window length in samples
    NOVERLAP = 64          # STFT overlap in samples
    EEG_CHANNELS = 64
    TARGET_EEG_LEN = 334   # every channel is padded / cropped to this many samples

# ==============================================================================
# --- 2. HELPER FUNCTIONS (Unchanged) ---
# ==============================================================================
def download_file(url, save_path):
    """Stream *url* to *save_path*, showing a progress bar.

    The payload is written to ``<save_path>.part`` and only renamed to
    *save_path* once the transfer has finished. An interrupted download can
    therefore never be mistaken for a complete file by the "already exists"
    shortcut on the next run.

    Args:
        url: HTTP(S) address of the file.
        save_path: Destination path on disk.

    Returns:
        True if the file already existed or was downloaded completely,
        False if the HTTP request failed.
    """
    if os.path.exists(save_path):
        print(f"File {Path(save_path).name} already exists. Skipping download.")
        return True

    partial_path = f"{save_path}.part"
    try:
        # The timeout bounds connection setup and each read, so a stalled
        # server cannot hang the notebook forever.
        with requests.get(url, stream=True, timeout=60) as response:
            response.raise_for_status()
            total_size = int(response.headers.get('content-length', 0))
            with open(partial_path, 'wb') as f, tqdm(
                desc=Path(save_path).name, total=total_size,
                unit='iB', unit_scale=True, unit_divisor=1024,
            ) as bar:
                for data in response.iter_content(chunk_size=1024 * 1024):
                    f.write(data)
                    bar.update(len(data))
        # Atomic publish: save_path only ever holds a complete download.
        os.replace(partial_path, save_path)
        return True
    except requests.exceptions.RequestException as err:
        print(f"Download of {url} failed: {err}")
        return False
    finally:
        # After a successful rename this is a no-op; after any failure it
        # removes the truncated file.
        if os.path.exists(partial_path):
            os.remove(partial_path)

def bandpass_filter(data, lowcut, highcut, fs, order=5):
    """Apply a zero-phase Butterworth band-pass along axis 1 of *data*.

    Args:
        data: Array filtered along axis 1 (so it needs at least two dimensions).
        lowcut: Lower band edge, in the same units as *fs*.
        highcut: Upper band edge, in the same units as *fs*.
        fs: Sampling rate.
        order: Butterworth order handed to ``scipy.signal.butter``.

    Returns:
        The filtered array, same shape as *data*.
    """
    nyquist = 0.5 * fs
    # butter() expects band edges normalised by the Nyquist frequency here.
    normalized_band = [lowcut / nyquist, highcut / nyquist]
    sections = butter(order, normalized_band, btype='band', analog=False, output='sos')
    # Forward-backward filtering cancels the phase delay.
    return sosfiltfilt(sections, data, axis=1)

def generate_stft_spectrogram(eeg_data, config):
    """Turn a multi-channel EEG array into a stack of log-scaled spectrograms.

    Processing order: 50 Hz notch, 60 Hz notch, band-pass
    (``config.LOW_CUT``..``config.HIGH_CUT``), then one
    ``scipy.signal.spectrogram`` per row, compressed with ``log1p``.

    Args:
        eeg_data: 2-D array; filtering runs along axis 1 and one spectrogram
            is produced per entry along axis 0.
        config: Object exposing FS, LOW_CUT, HIGH_CUT, FILTER_ORDER, NPERSEG
            and NOVERLAP.

    Returns:
        A float32 ``torch.Tensor`` holding the per-row spectrograms stacked
        along dimension 0.
    """
    sample_rate = config.FS
    quality = 30.0

    # Remove both common mains frequencies with zero-phase notch filters.
    cleaned = eeg_data
    for mains_hz in (50.0, 60.0):
        b, a = iirnotch(mains_hz, quality, sample_rate)
        cleaned = sosfiltfilt(tf2sos(b, a), cleaned, axis=1)

    band_limited = bandpass_filter(
        cleaned, config.LOW_CUT, config.HIGH_CUT, sample_rate, config.FILTER_ORDER
    )

    channel_specs = []
    for channel in band_limited:
        # spectrogram() returns (frequencies, times, values); only the values are kept.
        _, _, values = spectrogram(
            channel, sample_rate, nperseg=config.NPERSEG, noverlap=config.NOVERLAP
        )
        channel_specs.append(np.log1p(values))

    return torch.tensor(np.array(channel_specs), dtype=torch.float32)

# ==============================================================================
# --- 3. MAIN SCRIPT (Refactored to work locally) ---
# ==============================================================================
def _fit_channel_length(channel, target_len):
    """Return *channel* as a float32 array of exactly *target_len* samples.

    Shorter channels are left-padded with zeros; longer channels keep only
    their last *target_len* samples. Accepts lists and NumPy arrays alike.
    """
    samples = np.asarray(channel, dtype=np.float32)
    deficit = target_len - samples.shape[0]
    if deficit > 0:
        return np.concatenate([np.zeros(deficit, dtype=np.float32), samples])
    return samples[-target_len:]


def _copy_required_images(coco_ids, source_dirs, dest_dir):
    """Copy each required COCO image into *dest_dir*; return ids found nowhere."""
    dest_dir.mkdir(parents=True, exist_ok=True)
    missing = set()
    for coco_id in tqdm(coco_ids, desc="Copying images"):
        img_fn = f"{coco_id:012d}.jpg"
        dest_p = dest_dir / img_fn
        if dest_p.exists():
            continue
        src_p = next((d / img_fn for d in source_dirs if (d / img_fn).exists()), None)
        if src_p is None:
            missing.add(coco_id)
            continue
        shutil.copy(src_p, dest_p)
    return missing


def create_final_dataset(config):
    """Build the lightweight dataset locally, then copy it to Google Drive.

    Steps:
      1. Load the Hugging Face dataset and draw a reproducible random subset.
      2. Copy the COCO images the subset references into the output folder.
      3. Split into train / val / test.
      4. Save one spectrogram tensor per sample and write ``metadata.csv``.
      5. Replace the Drive copy with the freshly built local folder.

    Samples whose COCO image cannot be found in the extracted archives are
    dropped (with a warning) before splitting, so ``metadata.csv`` never
    points at a file that does not exist.

    Args:
        config: Object exposing the attributes of ``CONFIG_FINAL_PREP``.
    """
    # --- Setup LOCAL paths ---
    output_root = Path(config.LOCAL_OUTPUT_ROOT)
    filtered_images_dest_dir = output_root / config.COCO_17K_DIR_NAME
    spectrograms_dest_dir = output_root / config.SPECTROGRAMS_DIR_NAME
    metadata_save_path = output_root / 'metadata.csv'
    output_root.mkdir(parents=True, exist_ok=True)

    # --- Load the source data and draw the subset ---
    ds = load_dataset(config.HUGGINGFACE_DATASET)
    full_dataset = concatenate_datasets([ds['train'], ds['test']])
    df = full_dataset.to_pandas()
    df_subset = df.sample(n=min(config.TOTAL_SAMPLES, len(df)), random_state=42)

    # --- Copy the referenced images (local-to-local copy) ---
    required_coco_ids = set(df_subset['coco_id'].unique())
    print(f"Copying {len(required_coco_ids)} images locally...")
    source_dirs = [Path(config.EXTRACT_DIR) / 'train2017', Path(config.EXTRACT_DIR) / 'val2017']
    missing_ids = _copy_required_images(required_coco_ids, source_dirs, filtered_images_dest_dir)
    if missing_ids:
        print(f"Warning: {len(missing_ids)} COCO images were not found under "
              f"'{config.EXTRACT_DIR}'; dropping the samples that reference them.")
        df_subset = df_subset[~df_subset['coco_id'].isin(missing_ids)]

    # --- Split: carve off the test set first, then divide the rest ---
    train_df, test_df = train_test_split(df_subset, test_size=config.TEST_SIZE, random_state=42)
    val_split_ratio = config.VAL_SIZE / (config.TRAIN_SIZE + config.VAL_SIZE)
    train_df, val_df = train_test_split(train_df, test_size=val_split_ratio, random_state=42)

    # --- Generate spectrograms and collect metadata rows ---
    all_metadata = []
    for split_name, split_df in (('train', train_df), ('val', val_df), ('test', test_df)):
        split_spectrogram_dir = spectrograms_dest_dir / split_name
        split_spectrogram_dir.mkdir(parents=True, exist_ok=True)
        for index, row in tqdm(split_df.iterrows(), total=len(split_df), desc=f"Processing '{split_name}' split"):
            eeg_data = np.array(
                [_fit_channel_length(c, config.TARGET_EEG_LEN) for c in row['EEG']],
                dtype=np.float32,
            )
            spectrogram_tensor = generate_stft_spectrogram(eeg_data, config)
            image_filename = f"{row['coco_id']:012d}.jpg"
            # The DataFrame index keeps file names unique across splits.
            spectrogram_filename = f"sample_{index}.pt"
            torch.save(spectrogram_tensor, split_spectrogram_dir / spectrogram_filename)
            all_metadata.append({'image_path': os.path.join(config.COCO_17K_DIR_NAME, image_filename),
                                 'spectrogram_path': os.path.join(config.SPECTROGRAMS_DIR_NAME, split_name, spectrogram_filename),
                                 'split': split_name})

    metadata_df = pd.DataFrame(all_metadata)
    metadata_df.to_csv(metadata_save_path, index=False)

    # --- FINAL STEP: Copy the completed dataset to Google Drive ---
    print("\n--- ✅ Local Processing Complete! ---")
    print(f"Copying the final dataset from '{config.LOCAL_OUTPUT_ROOT}' to '{config.DRIVE_OUTPUT_ROOT}'...")

    # Any existing Drive copy is removed first so the destination is an exact
    # mirror of the local folder rather than a merge with stale files.
    if os.path.exists(config.DRIVE_OUTPUT_ROOT):
        shutil.rmtree(config.DRIVE_OUTPUT_ROOT)
    shutil.copytree(config.LOCAL_OUTPUT_ROOT, config.DRIVE_OUTPUT_ROOT)

    print("\n--- 🎉 All Done! Your dataset is successfully saved to Google Drive. ---")

# --- EXECUTION ---
if __name__ == '__main__':
    config = CONFIG_FINAL_PREP()
    os.makedirs(config.TEMP_DIR, exist_ok=True)

    # Handle one archive at a time: skip it if it is already extracted,
    # otherwise download, extract and delete the zip. Checking for the
    # extracted folder *before* downloading avoids re-fetching the archives
    # on a re-run (the zips are removed after extraction), and processing
    # them one by one keeps peak disk usage lower.
    archives = [
        (config.COCO_TRAIN_URL, config.TRAIN_ZIP_PATH),
        (config.COCO_VAL_URL, config.VAL_ZIP_PATH),
    ]
    for url, zip_path in archives:
        split_name = Path(zip_path).stem
        if (Path(config.EXTRACT_DIR) / split_name).exists():
            continue
        if not download_file(url, zip_path):
            raise SystemExit(f"Could not download {url}; aborting dataset preparation.")
        with zipfile.ZipFile(zip_path, 'r') as zf:
            zf.extractall(config.EXTRACT_DIR)
        os.remove(zip_path)

    # Main processing function
    create_final_dataset(config)

train2017.zip:   0%|          | 0.00/18.0G [00:00<?, ?iB/s]

In [None]:
# This command recursively (-r) zips the specified folder.
!zip -r /content/alljoined_lightweight_17k.zip /content/alljoined_lightweight_17k

In [None]:
from google.colab import files

# This command initiates the download of the zip file.
files.download('/content/alljoined_lightweight_17k.zip')

In [None]:
# This command creates a single gzip-compressed tar archive of the entire dataset.
# Flags: 'c' = create, 'z' = gzip compression, 'f' = output file name.
# '-C DIR .' archives the contents of DIR using paths relative to DIR.
!tar -czf /content/alljoined_lightweight_17k.tar.gz -C /content/alljoined_lightweight_17k .

In [None]:
# This copies the single .tar.gz file to your destination.
# This operation is much more stable than copying thousands of individual files.
!cp /content/alljoined_lightweight_17k.tar.gz /content/drive/MyDrive/NeuroVision/

In [None]:
len(os.listdir('/content/drive/MyDrive/NeuroVision/alljoined_lightweight_17k/coco_images_17k'))

In [None]:
# Create a compressed .tar.gz archive of your new lightweight dataset
!tar -czf /content/drive/MyDrive/NeuroVision/final_lightweight_17k.tar.gz -C /content/drive/MyDrive/NeuroVision/alljoined_lightweight_17k .

# Contrastive Encoder Model training

In [None]:
import os
os.makedirs('/content/final_lightweight_17k', exist_ok=True)

In [None]:
!cp /content/drive/MyDrive/NeuroVision/alljoined_lightweight_17k.tar.gz /content/

In [None]:
!tar -xzf /content/alljoined_lightweight_17k.tar.gz -C /content/final_lightweight_17k

In [None]:
# # 2. Extract the archive locally (very fast)
# !tar -xzf /content/drive/MyDrive/NeuroVision/alljoined_lightweight_17k.tar.gz -C /content/final_lightweight_17k

In [None]:
import torch
import random
from pathlib import Path
import matplotlib.pyplot as plt

# ==============================================================================
# --- 1. CONFIGURATION ---
# ==============================================================================

# IMPORTANT: Update this path to your main lightweight dataset folder.
DATASET_ROOT = Path('/content/final_lightweight_17k')

# Number of random samples to display from each data split.
N_SAMPLES_PER_SPLIT = 5

# ==============================================================================
# --- 2. DATA LOADING & VISUALIZATION ---
# ==============================================================================

def visualize_random_spectrograms(root_path, n_samples):
    """
    Plot up to *n_samples* randomly chosen spectrograms for each data split.

    One grid row is drawn per split ('train', 'val', 'test'). Only the first
    channel of each saved tensor is shown. Grid cells that receive no
    spectrogram (split missing, unreadable, or smaller than *n_samples*) are
    hidden so the figure never shows empty axes.

    Args:
        root_path: ``pathlib.Path`` of the dataset root; spectrograms are read
            from ``root_path / 'spectrograms' / <split> / sample_*.pt``.
        n_samples: Number of grid columns, i.e. the maximum number of samples
            drawn per split.

    Returns:
        None. Prints a message and returns early if the dataset root or the
        spectrograms directory is missing.
    """
    splits = ['train', 'val', 'test']
    spectrograms_dir = root_path / 'spectrograms'

    # Check if the main directories exist
    if not root_path.exists() or not spectrograms_dir.exists():
        print(f"Error: Dataset root or spectrograms directory not found at '{root_path}'")
        print("Please update the DATASET_ROOT variable to the correct path.")
        return

    # --- Create the plot grid ---
    fig, axes = plt.subplots(
        nrows=len(splits),
        ncols=n_samples,
        figsize=(5 * n_samples, 4 * len(splits)),
        squeeze=False # Ensures axes is always a 2D array
    )

    fig.suptitle('Random Raw Spectrograms per Split', fontsize=20)

    for i, split_name in enumerate(splits):
        split_path = spectrograms_dir / split_name

        random_files = []
        try:
            sample_files = list(split_path.glob('sample_*.pt'))
        except OSError as e:
            print(f"Error accessing files in '{split_path}': {e}")
        else:
            if sample_files:
                random_files = random.sample(sample_files, min(n_samples, len(sample_files)))
            else:
                print(f"Warning: No sample files found in '{split_path}'. Skipping.")

        # --- Plot each sample for the current split ---
        for j, file_path in enumerate(random_files):
            ax = axes[i, j]

            # Each file holds the spectrogram tensor itself (not a dictionary).
            spectrogram = torch.load(file_path)

            # Show only the first channel; aspect='auto' stretches the small
            # spectrogram to fill the cell.
            ax.imshow(spectrogram[0], cmap='viridis', origin='lower', aspect='auto')

            if j == 0:
                ax.set_ylabel(split_name.title(), fontsize=14, weight='bold')

            ax.set_title(f"{file_path.stem}")
            ax.set_xticks([])
            ax.set_yticks([])

        # Hide every cell of this row that did not receive a spectrogram.
        for j in range(len(random_files), n_samples):
            axes[i, j].axis('off')

    plt.tight_layout(rect=[0, 0.03, 1, 0.95]) # Adjust for suptitle
    plt.show()

# --- EXECUTION ---
# Plot N_SAMPLES_PER_SPLIT random spectrograms per split from DATASET_ROOT
# when this cell/script is run directly.
if __name__ == '__main__':
    visualize_random_spectrograms(DATASET_ROOT, N_SAMPLES_PER_SPLIT)

# Contrastive Encoder

In [None]:
import torch
from pathlib import Path
from tqdm.auto import tqdm

# Path to your lightweight dataset on the local Colab disk
DATA_ROOT = '/content/final_lightweight_17k'
SPECTROGRAM_TRAIN_DIR = Path(DATA_ROOT) / 'spectrograms' / 'train'

# Calculate stats only on the training set
files = list(SPECTROGRAM_TRAIN_DIR.glob('*.pt'))
if not files:
    # Without this guard the division below would silently save NaN stats.
    raise FileNotFoundError(f"No spectrogram files found in '{SPECTROGRAM_TRAIN_DIR}'")

# Per-channel running sums, kept in float64 so that accumulating over many
# files does not lose precision.
channel_sum = torch.zeros(64, dtype=torch.float64)
channel_sum_sq = torch.zeros(64, dtype=torch.float64)
pixel_count = 0

for path in tqdm(files, desc="Calculating Spectrogram Stats"):
    # Convert first so the per-file reductions also run in float64.
    data = torch.load(path).to(torch.float64)
    # Reduce over dims 1 and 2, leaving one value per entry of dim 0 (channel).
    channel_sum += data.sum(dim=[1, 2])
    channel_sum_sq += (data ** 2).sum(dim=[1, 2])
    pixel_count += data.shape[1] * data.shape[2]

# Var[x] = E[x^2] - E[x]^2, evaluated entirely in float64. Rounding error can
# push the difference slightly below zero, which would turn sqrt() into NaN,
# hence the clamp.
mean_f64 = channel_sum / pixel_count
variance_f64 = (channel_sum_sq / pixel_count - mean_f64 ** 2).clamp_min(0.0)

mean = mean_f64.to(torch.float32)
std = torch.sqrt(variance_f64).to(torch.float32)

# Save the stats to the dataset folder
torch.save(mean, Path(DATA_ROOT) / 'spec_mean.pt')
torch.save(std, Path(DATA_ROOT) / 'spec_std.pt')

print(f"\nStats calculated and saved to {DATA_ROOT}")

Calculating Spectrogram Stats:   0%|          | 0/14000 [00:00<?, ?it/s]


Stats calculated and saved to /content/final_lightweight_17k


In [None]:
import os
import torch
import timm
import pandas as pd
from pathlib import Path
from tqdm.auto import tqdm
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import v2 as transforms
from PIL import Image
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime

# ==============================================================================
# --- 1. CONFIGURATION for Training from Scratch ---
# ==============================================================================
class CONFIG_CONTRASTIVE_TRAIN:
    """Hyper-parameters and paths for contrastive EEG/image encoder training."""

    # --- Paths ---
    PROCESSED_DATA_ROOT = '/content/final_lightweight_17k'
    OUTPUT_DIR = '/content/drive/MyDrive/NeuroVision/contrastive_scratch_outputs'
    RESUME_CHECKPOINT_PATH = '/content/drive/MyDrive/NeuroVision/contrastive_scratch_outputs/run_20250917_221600/contrastive_checkpoint_best.pth'

    # --- Optimisation ---
    NUM_EPOCHS = 200
    BATCH_SIZE = 48
    LR = 5e-4
    WEIGHT_DECAY = 0.01

    # --- Data loading ---
    NUM_WORKERS = 2

    # --- Model dimensions ---
    EEG_CHANNELS = 64      # input channels of the EEG encoder
    ENCODER_DIM = 192      # input width of both projection heads
    PROJECTION_DIM = 256   # width of the shared embedding space

# ==============================================================================
# --- 2. MODEL DEFINITIONS (Unchanged) ---
# ==============================================================================
class ProjectionHead(nn.Module):
    """Residual MLP mapping encoder features into the shared embedding space.

    Computes ``LayerNorm(Dropout(fc(GELU(p))) + p)`` with ``p = projection(x)``.
    """

    def __init__(self, embedding_dim, projection_dim):
        super().__init__()
        # Registration order is kept so state_dict keys and the order of
        # random initialisation stay the same.
        self.projection = nn.Linear(embedding_dim, projection_dim)
        self.gelu = nn.GELU()
        self.fc = nn.Linear(projection_dim, projection_dim)
        self.dropout = nn.Dropout(0.1)
        self.layer_norm = nn.LayerNorm(projection_dim)

    def forward(self, x):
        skip = self.projection(x)
        refined = self.dropout(self.fc(self.gelu(skip)))
        # Residual connection around the non-linear branch, then normalise.
        return self.layer_norm(refined + skip)

class ContrastiveLoss(nn.Module):
    """Symmetric cross-entropy loss over EEG/image cosine similarities.

    Row ``i`` of each input is treated as the positive match for row ``i`` of
    the other; every other row in the batch acts as a negative.
    """

    def __init__(self, temperature=0.07):
        super().__init__()
        # Registered as a parameter, so it is only updated if the caller
        # passes this module's parameters to the optimizer.
        self.temperature = nn.Parameter(torch.tensor(temperature))

    def forward(self, eeg_embeddings, image_embeddings):
        # L2-normalise so the matrix product yields cosine similarities.
        eeg_unit = F.normalize(eeg_embeddings, p=2, dim=1)
        image_unit = F.normalize(image_embeddings, p=2, dim=1)
        similarity = (eeg_unit @ image_unit.T) / self.temperature

        # The matching pair for row i sits on the diagonal.
        targets = torch.arange(len(similarity)).to(similarity.device)

        eeg_to_image = F.cross_entropy(similarity, targets)
        image_to_eeg = F.cross_entropy(similarity.T, targets)
        return (eeg_to_image + image_to_eeg) / 2

# ==============================================================================
# --- 3. DATASET CLASS (Unchanged) ---
# ==============================================================================
class LightweightEEGDataset(Dataset):
    """Paired (spectrogram, image) samples for one split of the dataset.

    Rows are selected from the metadata CSV by its ``split`` column (compared
    after stripping whitespace). File paths in the CSV are resolved relative
    to *root_dir*.
    """

    def __init__(self, root_dir, metadata_csv, split, eeg_transform, image_transform):
        self.root_dir = Path(root_dir)
        self.eeg_transform = eeg_transform
        self.image_transform = image_transform

        metadata = pd.read_csv(metadata_csv)
        in_split = metadata['split'].str.strip() == split
        self.split_df = metadata.loc[in_split].reset_index(drop=True)

    def __len__(self):
        return len(self.split_df)

    def __getitem__(self, idx):
        record = self.split_df.iloc[idx]
        spectrogram = torch.load(self.root_dir / record['spectrogram_path'])
        picture = Image.open(self.root_dir / record['image_path']).convert("RGB")
        # Returned as (transformed spectrogram, transformed image).
        return self.eeg_transform(spectrogram), self.image_transform(picture)

# ==============================================================================
# --- 4. MAIN TRAINING SCRIPT (with Normalization) ---
# ==============================================================================
def run_contrastive_training(config):
    """Train the EEG encoder to align with a frozen image encoder.

    Both encoders are ImageNet-pretrained ViT-tiny models created through
    ``timm``; the image encoder is frozen, while the EEG encoder and the two
    projection heads are optimised with a symmetric contrastive loss. After
    every epoch the validation loss drives ``ReduceLROnPlateau`` and, when it
    improves, a checkpoint is written to a timestamped run directory under
    ``config.OUTPUT_DIR``. Scalars are logged to TensorBoard.

    Args:
        config: Object exposing the attributes of ``CONFIG_CONTRASTIVE_TRAIN``.
            ``RESUME_CHECKPOINT_PATH`` is not read by this function.
    """
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    output_dir = Path(config.OUTPUT_DIR) / f'run_{timestamp}'
    output_dir.mkdir(parents=True, exist_ok=True)
    checkpoint_path = output_dir / 'contrastive_checkpoint_best.pth'
    writer = SummaryWriter(log_dir=str(output_dir / 'logs'))

    # try/finally guarantees the TensorBoard writer is flushed and closed even
    # if training is interrupted or raises.
    try:
        # --- Prepare Data ---
        print("1. Preparing data and DataLoaders...")

        # Load the pre-calculated per-channel spectrogram stats
        spec_mean = torch.load(Path(config.PROCESSED_DATA_ROOT) / 'spec_mean.pt')
        spec_std = torch.load(Path(config.PROCESSED_DATA_ROOT) / 'spec_std.pt')

        eeg_transform = transforms.Compose([
            transforms.ToDtype(torch.float32, scale=False),
            transforms.Resize((224, 224), antialias=True),
            transforms.Normalize(mean=spec_mean.tolist(), std=spec_std.tolist())
        ])
        image_transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

        metadata_csv_path = Path(config.PROCESSED_DATA_ROOT) / 'metadata.csv'

        # DataLoader rejects persistent_workers=True when num_workers == 0,
        # so only request it when worker processes are actually used.
        keep_workers_alive = config.NUM_WORKERS > 0

        train_dataset = LightweightEEGDataset(config.PROCESSED_DATA_ROOT, metadata_csv_path, 'train', eeg_transform, image_transform)
        train_loader = DataLoader(train_dataset, batch_size=config.BATCH_SIZE, shuffle=True,
                                  num_workers=config.NUM_WORKERS, drop_last=True,
                                  persistent_workers=keep_workers_alive)

        val_dataset = LightweightEEGDataset(config.PROCESSED_DATA_ROOT, metadata_csv_path, 'val', eeg_transform, image_transform)
        val_loader = DataLoader(val_dataset, batch_size=config.BATCH_SIZE, shuffle=False,
                                num_workers=config.NUM_WORKERS,
                                persistent_workers=keep_workers_alive)

        # --- Initialize Models ---
        print("2. Initializing models...")
        eeg_encoder = timm.create_model('vit_tiny_patch16_224', pretrained=True, in_chans=config.EEG_CHANNELS, num_classes=0).to(device)
        image_encoder = timm.create_model('vit_tiny_patch16_224', pretrained=True, num_classes=0).to(device)
        # The image encoder is a fixed target: no gradients, always in eval mode.
        for param in image_encoder.parameters():
            param.requires_grad = False
        eeg_projection = ProjectionHead(config.ENCODER_DIM, config.PROJECTION_DIM).to(device)
        image_projection = ProjectionHead(config.ENCODER_DIM, config.PROJECTION_DIM).to(device)

        # --- Training Setup ---
        criterion = ContrastiveLoss().to(device)
        # NOTE(review): criterion.temperature is an nn.Parameter but is not
        # handed to the optimizer, so it stays at its initial value. Confirm
        # this is intended before adding it.
        trainable_params = list(eeg_encoder.parameters()) + list(eeg_projection.parameters()) + list(image_projection.parameters())
        optimizer = optim.AdamW(trainable_params, lr=config.LR, weight_decay=config.WEIGHT_DECAY)
        scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3)
        best_val_loss = float('inf')

        # --- Training & Validation Loop ---
        print("\n--- Starting Contrastive Training From Scratch ---")
        for epoch in range(config.NUM_EPOCHS):
            eeg_encoder.train()
            eeg_projection.train()
            image_projection.train()
            image_encoder.eval()

            total_train_loss = 0.0
            train_progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{config.NUM_EPOCHS} [Train]")
            for spectrogram_tensors, image_tensors in train_progress_bar:
                spectrogram_tensors = spectrogram_tensors.to(device)
                image_tensors = image_tensors.to(device)
                optimizer.zero_grad()

                # Index 0 along dim 1 selects the first (class) token.
                eeg_features = eeg_encoder.forward_features(spectrogram_tensors)[:, 0]
                with torch.no_grad():
                    image_features = image_encoder.forward_features(image_tensors)[:, 0]
                eeg_embeddings = eeg_projection(eeg_features)
                image_embeddings = image_projection(image_features)

                loss = criterion(eeg_embeddings, image_embeddings)
                loss.backward()
                optimizer.step()

                total_train_loss += loss.item()
                train_progress_bar.set_postfix(Loss=f"{loss.item():.4f}")

            avg_train_loss = total_train_loss / len(train_loader)
            writer.add_scalar('Loss/train', avg_train_loss, epoch)

            eeg_encoder.eval()
            eeg_projection.eval()
            image_projection.eval()
            total_val_loss = 0.0
            val_progress_bar = tqdm(val_loader, desc=f"Epoch {epoch+1}/{config.NUM_EPOCHS} [Val]")
            with torch.no_grad():
                for spectrogram_tensors, image_tensors in val_progress_bar:
                    spectrogram_tensors = spectrogram_tensors.to(device)
                    image_tensors = image_tensors.to(device)
                    eeg_features = eeg_encoder.forward_features(spectrogram_tensors)[:, 0]
                    image_features = image_encoder.forward_features(image_tensors)[:, 0]
                    eeg_embeddings = eeg_projection(eeg_features)
                    image_embeddings = image_projection(image_features)
                    loss = criterion(eeg_embeddings, image_embeddings)
                    total_val_loss += loss.item()

            avg_val_loss = total_val_loss / len(val_loader)
            writer.add_scalar('Loss/validation', avg_val_loss, epoch)
            writer.add_scalar('LearningRate', optimizer.param_groups[0]['lr'], epoch)

            print(f"Epoch {epoch+1}: Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}, LR: {optimizer.param_groups[0]['lr']:.2e}")
            scheduler.step(avg_val_loss)

            if avg_val_loss < best_val_loss:
                best_val_loss = avg_val_loss
                print(f"New best validation loss ({best_val_loss:.4f}). Saving checkpoint to {checkpoint_path}")
                checkpoint = {'epoch': epoch, 'eeg_encoder_state_dict': eeg_encoder.state_dict(),
                              'eeg_projection_state_dict': eeg_projection.state_dict(),
                              'image_projection_state_dict': image_projection.state_dict(),
                              'optimizer_state_dict': optimizer.state_dict(), 'loss': best_val_loss}
                torch.save(checkpoint, checkpoint_path)
    finally:
        writer.close()

    print(f"\n--- Training complete. Best model checkpoint saved to {checkpoint_path} ---")

# if __name__ == '__main__':
#     cl_train_config = CONFIG_CONTRASTIVE_TRAIN()
#     run_contrastive_training(cl_train_config)

In [None]:
import torch
import timm
import pandas as pd
from pathlib import Path
from tqdm.auto import tqdm
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import v2 as transforms
from PIL import Image
import torch.nn as nn
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
import numpy as np

# ==============================================================================
# --- 1. CONFIGURATION ---
# ==============================================================================
class CONFIG_TSNE:
    """Settings for the t-SNE inspection of the trained EEG encoder."""

    # Trained weights to visualise -- point this at your best checkpoint file.
    SAVED_CHECKPOINT_PATH = '/content/drive/MyDrive/NeuroVision/contrastive_scratch_outputs/run_20250917_221600/contrastive_checkpoint_best.pth'

    # Dataset location on the local Colab disk.
    PROCESSED_DATA_ROOT = '/content/final_lightweight_17k'

    # Architecture sizes; these have to match the model stored in the checkpoint.
    EEG_CHANNELS = 64
    ENCODER_DIM = 192
    PROJECTION_DIM = 256

    # Data loading.
    BATCH_SIZE = 40
    NUM_WORKERS = 2

# ==============================================================================
# --- 2. MODEL AND DATASET DEFINITIONS (from training script) ---
# ==============================================================================
class ProjectionHead(nn.Module):
    """Projection head with a residual branch, followed by layer normalisation."""

    def __init__(self, embedding_dim, projection_dim):
        super().__init__()
        # Attribute names and their order determine the state_dict layout
        # that the saved checkpoint is loaded into.
        self.projection = nn.Linear(embedding_dim, projection_dim)
        self.gelu = nn.GELU()
        self.fc = nn.Linear(projection_dim, projection_dim)
        self.dropout = nn.Dropout(0.1)
        self.layer_norm = nn.LayerNorm(projection_dim)

    def forward(self, x):
        base = self.projection(x)

        # Non-linear branch: GELU -> linear -> dropout.
        branch = self.gelu(base)
        branch = self.fc(branch)
        branch = self.dropout(branch)

        # Add the linear projection back in before normalising.
        return self.layer_norm(branch + base)

class LightweightEEGDataset(Dataset):
    """Spectrogram-only dataset for one split (no images are loaded).

    Rows are picked from the metadata CSV by its ``split`` column, compared
    after stripping whitespace; spectrogram paths are relative to *root_dir*.
    """

    def __init__(self, root_dir, metadata_csv, split, eeg_transform):
        self.root_dir = Path(root_dir)
        self.eeg_transform = eeg_transform
        table = pd.read_csv(metadata_csv)
        keep = table['split'].str.strip().eq(split)
        self.split_df = table[keep].reset_index(drop=True)

    def __len__(self):
        return len(self.split_df)

    def __getitem__(self, idx):
        relative_path = self.split_df.iloc[idx]['spectrogram_path']
        raw = torch.load(self.root_dir / relative_path)
        return self.eeg_transform(raw)

# ==============================================================================
# --- 3. MAIN VISUALIZATION SCRIPT ---
# ==============================================================================
def run_tsne_visualization(config):
    """Embed the validation split with the trained EEG encoder and plot a t-SNE map.

    Loads the EEG encoder and its projection head from
    ``config.SAVED_CHECKPOINT_PATH``, runs every validation spectrogram through
    them, reduces the projected embeddings to 2-D with t-SNE and shows a
    scatter plot. Prints a message and returns early if the checkpoint cannot
    be loaded.

    Args:
        config: Object exposing the attributes of ``CONFIG_TSNE``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # --- 1. Load the trained models from the checkpoint ---
    print("Loading trained EEG encoder and projection head...")
    eeg_encoder = timm.create_model('vit_tiny_patch16_224', pretrained=False, in_chans=config.EEG_CHANNELS, num_classes=0)
    eeg_projection = ProjectionHead(config.ENCODER_DIM, config.PROJECTION_DIM)

    try:
        checkpoint = torch.load(config.SAVED_CHECKPOINT_PATH, map_location=device)
        eeg_encoder.load_state_dict(checkpoint['eeg_encoder_state_dict'])
        eeg_projection.load_state_dict(checkpoint['eeg_projection_state_dict'])
        print("Model weights loaded successfully.")
    except FileNotFoundError:
        print(f"❌ ERROR: Checkpoint file not found at '{config.SAVED_CHECKPOINT_PATH}'")
        print("Please update the path in the configuration.")
        return
    except Exception as e:
        # Deliberately broad: any load failure (missing keys, shape mismatch,
        # corrupt file) is reported and the visualisation is skipped.
        print(f"❌ An error occurred while loading the model: {e}")
        return

    eeg_encoder.to(device).eval()
    eeg_projection.to(device).eval()

    # --- 2. Prepare the validation dataset ---
    print("Preparing validation data loader...")
    # The same normalization stats as in training are needed for the transform
    spec_mean = torch.load(Path(config.PROCESSED_DATA_ROOT) / 'spec_mean.pt')
    spec_std = torch.load(Path(config.PROCESSED_DATA_ROOT) / 'spec_std.pt')

    eeg_transform = transforms.Compose([
        transforms.ToDtype(torch.float32, scale=False),
        transforms.Resize((224, 224), antialias=True),
        transforms.Normalize(mean=spec_mean.tolist(), std=spec_std.tolist())
    ])
    metadata_csv_path = Path(config.PROCESSED_DATA_ROOT) / 'metadata.csv'
    val_dataset = LightweightEEGDataset(config.PROCESSED_DATA_ROOT, metadata_csv_path, 'val', eeg_transform)
    val_loader = DataLoader(val_dataset, batch_size=config.BATCH_SIZE, shuffle=False, num_workers=config.NUM_WORKERS)

    # --- 3. Extract features from the validation set ---
    print("Extracting features from the validation set...")
    all_features = []
    with torch.no_grad():
        for spectrogram_tensors in tqdm(val_loader, desc="Extracting Features"):
            spectrogram_tensors = spectrogram_tensors.to(device)
            # Index 0 along dim 1 selects the first (class) token.
            features = eeg_encoder.forward_features(spectrogram_tensors)[:, 0]
            embeddings = eeg_projection(features)
            all_features.append(embeddings.cpu().numpy())

    features_array = np.concatenate(all_features, axis=0)
    print(f"Extracted {features_array.shape[0]} feature vectors.")

    # --- 4. Run t-SNE ---
    print("\nRunning t-SNE... (This may take a few minutes)")
    # The iteration count is left at the library default (1000): the keyword
    # was renamed from `n_iter` to `max_iter` in newer scikit-learn releases,
    # so passing either name explicitly breaks on some versions.
    tsne = TSNE(n_components=2, perplexity=30, random_state=42, init='pca', learning_rate='auto')
    tsne_results = tsne.fit_transform(features_array)
    print("t-SNE complete.")

    # --- 5. Plot the Results ---
    plt.figure(figsize=(12, 10))
    plt.scatter(tsne_results[:, 0], tsne_results[:, 1], alpha=0.5)
    plt.title('t-SNE Visualization of Final EEG Features (Val Set)', fontsize=16)
    plt.xlabel('t-SNE Dimension 1')
    plt.ylabel('t-SNE Dimension 2')
    plt.grid(True)
    plt.show()

# Build the t-SNE configuration and run the visualisation when this
# cell/script is executed directly.
if __name__ == '__main__':
    tsne_config = CONFIG_TSNE()
    run_tsne_visualization(tsne_config)

# GAN

In [None]:
# ENCODER_CHECKPOINT_PATH = '/content/drive/MyDrive/NeuroVision/contrastive_scratch_outputs/run_20250917_221600/contrastive_checkpoint_best.pth'

In [None]:
print("⏳ Installing and upgrading libraries...")
# The -q flag makes the output cleaner
!pip install -q --upgrade diffusers transformers accelerate torchmetrics timm
!pip install ftfy regex lpips
# !pip install git+https://github.com/openai/CLIP.git

In [None]:
import os
# Allocator setting read by PyTorch; set here, ahead of the cells that import
# torch. NOTE(review): presumably intended to reduce CUDA memory fragmentation
# and has to be in place before the first CUDA allocation -- confirm against
# the PyTorch CUDA memory-management notes.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import timm
import torchvision.models as models
import lpips
import pandas as pd
from pathlib import Path
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
from torchvision.transforms import v2 as transforms
from PIL import Image
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from tqdm.auto import tqdm
from datetime import datetime
from diffusers import AutoencoderKL, UNet2DConditionModel, DDPMScheduler
from torchvision.utils import save_image
import numpy as np

# ==============================================================================
# --- 1. CONFIGURATION ---
# ==============================================================================
class TRAIN_CONFIG:
    """Paths and hyperparameters for training the parallel fusion EEG encoder."""

    # --- Data locations and where run artifacts are written ---
    PROCESSED_DATA_ROOT = '/content/final_lightweight_17k'
    METADATA_CSV = Path(PROCESSED_DATA_ROOT, 'metadata.csv')
    OUTPUT_DIR = '/content/drive/MyDrive/NeuroVision/parallel_fusion_encoder'

    # --- Optimization schedule ---
    # Gradients from GRAD_ACCUMULATION_STEPS consecutive mini-batches are
    # accumulated before each optimizer step.
    BATCH_SIZE = 2
    GRAD_ACCUMULATION_STEPS = 16
    NUM_EPOCHS = 100
    LR = 1e-5

    # --- Relative weights of the two loss terms ---
    MAE_LOSS_WEIGHT = 0.5
    LPIPS_LOSS_WEIGHT = 1.5

    # --- Feature widths inside the encoder ---
    CNN_FEATURES_DIM = 256  # channels produced by ResNet's layer3
    VIT_EMBED_DIM = 192     # token width of ViT-tiny
    FUSION_DIM = 256        # shared width used by the cross-attention block
    EEG_EMBED_DIM = 768     # width of the conditioning vector passed to the UNet

    # --- Per-epoch sampling and visualization cadence ---
    VISUALIZATION_INTERVAL = 1
    VAL_SAMPLES_PER_EPOCH = 280
    TRAIN_SAMPLES_PER_EPOCH = 1400

# ==============================================================================
# --- 2. CORE ARCHITECTURE MODULES ---
# ==============================================================================
class CrossAttentionBlock(nn.Module):
    """Multi-head cross-attention: ``query`` tokens attend over ``context`` tokens.

    The previous implementation stored ``num_heads`` and scaled the logits by
    the per-head width, but never split the channels into heads, so it ran
    single-head attention over the full ``dim`` with a mismatched softmax
    temperature. This version performs real multi-head attention. Parameter
    names and shapes (``to_q``, ``to_kv``, ``to_out``) are unchanged, so
    existing state dicts still load.

    Args:
        dim: Channel width of both ``query`` and ``context`` tokens.
        num_heads: Number of attention heads; must divide ``dim`` evenly.

    Raises:
        ValueError: If ``dim`` is not divisible by ``num_heads``.
    """
    def __init__(self, dim, num_heads=8):
        super().__init__()
        if dim % num_heads != 0:
            raise ValueError(f"dim ({dim}) must be divisible by num_heads ({num_heads})")
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        # Standard scaled dot-product temperature: 1 / sqrt(per-head width).
        self.scale = self.head_dim ** -0.5

        self.to_q = nn.Linear(dim, dim, bias=False)
        self.to_kv = nn.Linear(dim, dim * 2, bias=False)
        self.to_out = nn.Linear(dim, dim)

    def _split_heads(self, x):
        """Reshape ``(B, N, dim)`` into ``(B, heads, N, head_dim)``."""
        batch, tokens, _ = x.shape
        return x.reshape(batch, tokens, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, query, context):
        """Attend from ``query`` to ``context``.

        Args:
            query: Tensor of shape ``(B, N_q, dim)``.
            context: Tensor of shape ``(B, N_c, dim)``.

        Returns:
            Tensor of shape ``(B, N_q, dim)``.
        """
        q = self._split_heads(self.to_q(query))
        k, v = (self._split_heads(t) for t in self.to_kv(context).chunk(2, dim=-1))

        # Per-head attention logits over the context tokens.
        dots = torch.einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
        attn = dots.softmax(dim=-1)
        out = torch.einsum('b h i j, b h j d -> b h i d', attn, v)

        # Merge the heads back into a single channel dimension.
        batch, _, tokens, _ = out.shape
        out = out.transpose(1, 2).reshape(batch, tokens, self.num_heads * self.head_dim)

        return self.to_out(out)

class ParallelFusionEncoder(nn.Module):
    """Two-branch encoder that fuses CNN (local) and ViT (global) features.

    The input is passed through a truncated ResNet-18 and a ViT-tiny in
    parallel. Both token sequences are projected to ``fusion_dim``; the CNN
    tokens then cross-attend to the ViT tokens. The ViT's first token and the
    mean of the attended CNN tokens are concatenated and projected to
    ``unet_cond_dim``.

    Args:
        cnn_out_dim: Channel count of the CNN branch output.
        vit_embed_dim: Token width of the ViT branch.
        fusion_dim: Shared width used for cross-attention.
        unet_cond_dim: Width of the returned embedding.
        in_chans: Number of input channels fed to both branches.
    """
    def __init__(self, cnn_out_dim, vit_embed_dim, fusion_dim, unet_cond_dim, in_chans=64):
        super().__init__()

        # CNN branch: ResNet-18 with its stem conv replaced to accept
        # `in_chans` channels; the final three child modules are dropped so
        # the branch ends at layer3.
        backbone = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
        backbone.conv1 = nn.Conv2d(in_chans, 64, kernel_size=7, stride=2, padding=3, bias=False)
        kept_stages = list(backbone.children())[:-3]
        self.cnn_branch = nn.Sequential(*kept_stages)

        # ViT branch: headless ViT-tiny (num_classes=0) taking `in_chans` channels.
        self.vit_branch = timm.create_model(
            'vit_tiny_patch16_224',
            pretrained=True,
            num_classes=0,
            in_chans=in_chans,
        )

        # Linear maps bringing each branch's tokens to the shared fusion width.
        self.align_cnn_features = nn.Linear(cnn_out_dim, fusion_dim)
        self.align_vit_features = nn.Linear(vit_embed_dim, fusion_dim)

        # Cross-attention in which CNN tokens query the ViT tokens.
        self.fusion_block = CrossAttentionBlock(dim=fusion_dim)

        # MLP applied to the concatenated [global, local] vector.
        projector_layers = [
            nn.Linear(fusion_dim * 2, unet_cond_dim * 2),
            nn.GELU(),
            nn.Linear(unet_cond_dim * 2, unet_cond_dim),
        ]
        self.final_projector = nn.Sequential(*projector_layers)

    def forward(self, x):
        """Encode ``x`` into a single conditioning vector per sample.

        Returns:
            Tensor whose last dimension is ``unet_cond_dim``.
        """
        # Local branch: feature map -> token sequence (B, H*W, C).
        local_map = self.cnn_branch(x)
        _batch, _chans, _height, _width = local_map.shape  # unpacking requires a 4-D map
        local_tokens = local_map.flatten(start_dim=2).transpose(1, 2)

        # Global branch: full ViT token sequence (first token included).
        global_tokens = self.vit_branch.forward_features(x)

        # Project both sequences to the shared fusion width.
        local_tokens = self.align_cnn_features(local_tokens)
        global_tokens = self.align_vit_features(global_tokens)

        # Let each local token gather context from the global tokens.
        attended_local = self.fusion_block(query=local_tokens, context=global_tokens)

        # Summaries: ViT first ([CLS]) token and mean-pooled attended local tokens.
        cls_vector = global_tokens[:, 0]
        pooled_local = attended_local.mean(dim=1)

        fused = torch.cat((cls_vector, pooled_local), dim=-1)
        return self.final_projector(fused)

# ==============================================================================
# --- 3. DATASET (Unchanged) ---
# ==============================================================================
class LightweightEEGDataset(Dataset):
    """Paired (spectrogram, image) dataset driven by a metadata CSV.

    Args:
        root_dir: Directory that the CSV's relative paths are resolved against.
        metadata_csv: CSV with at least the columns ``split``,
            ``spectrogram_path`` and ``image_path``.
        split: Split name to keep; compared against the whitespace-stripped
            ``split`` column.
        eeg_transform: Callable applied to the loaded spectrogram.
        image_transform: Callable applied to the loaded RGB image.
    """
    def __init__(self, root_dir, metadata_csv, split, eeg_transform, image_transform):
        self.root_dir = Path(root_dir)
        self.eeg_transform = eeg_transform
        self.image_transform = image_transform
        df = pd.read_csv(metadata_csv)
        # Strip whitespace so padded values such as " train" still match.
        self.split_df = df[df['split'].str.strip() == split].reset_index(drop=True)

    def __len__(self):
        """Return the number of rows belonging to the selected split."""
        return len(self.split_df)

    def __getitem__(self, idx):
        """Return ``(eeg_transform(spectrogram), image_transform(image))`` for row ``idx``."""
        info = self.split_df.iloc[idx]
        spec = torch.load(self.root_dir / info['spectrogram_path'])
        # Use a context manager so the image file handle is closed
        # deterministically; convert() returns an independent in-memory copy.
        with Image.open(self.root_dir / info['image_path']) as img:
            image = img.convert("RGB")
        return self.eeg_transform(spec), self.image_transform(image)

# ==============================================================================
# --- 4. MAIN TRAINING SCRIPT ---
# ==============================================================================
def train_parallel_model(config):
    """Train a ``ParallelFusionEncoder`` against a frozen tiny-sd VAE and UNet.

    The encoder's output is used as the UNet's ``encoder_hidden_states``. The
    training loss combines an L1 loss on the predicted noise with an LPIPS
    loss between the real image and the image decoded from the one-step
    estimate of the clean latents. Only the encoder's parameters are
    optimized. Each epoch trains on a fresh random subset of the training
    split, evaluates on a fixed random subset of the validation split, saves
    the encoder whenever the validation loss improves, and periodically
    writes a grid of real vs. generated images.

    Changes relative to the earlier version:
      * Gradients are unscaled (``scaler.unscale_``) before clipping, so the
        ``max_norm`` threshold applies to the true gradient norm rather than
        the loss-scaled one.
      * The trailing partial accumulation group of each epoch now triggers an
        optimizer step instead of being discarded by the next ``zero_grad``.
      * The deprecated ``torch.cuda.amp`` autocast/GradScaler entry points are
        replaced by ``torch.autocast('cuda')`` / ``torch.amp.GradScaler('cuda')``.

    Args:
        config: Object exposing the attributes of ``TRAIN_CONFIG`` (paths,
            batch size, accumulation steps, loss weights, encoder dims, ...).
    """
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    output_dir = Path(config.OUTPUT_DIR) / f'run_{timestamp}'
    output_dir.mkdir(parents=True, exist_ok=True)
    checkpoint_path = output_dir / 'parallel_encoder_best.pth'

    print("1. Loading models...")
    vae = AutoencoderKL.from_pretrained("segmind/tiny-sd", subfolder="vae").to(device).eval()
    unet = UNet2DConditionModel.from_pretrained("segmind/tiny-sd", subfolder="unet").to(device).eval()
    # The diffusion backbone stays frozen; gradients only flow through it
    # into the EEG encoder.
    for p in vae.parameters(): p.requires_grad = False
    for p in unet.parameters(): p.requires_grad = False

    eeg_encoder = ParallelFusionEncoder(
        cnn_out_dim=config.CNN_FEATURES_DIM,
        vit_embed_dim=config.VIT_EMBED_DIM,
        fusion_dim=config.FUSION_DIM,
        unet_cond_dim=config.EEG_EMBED_DIM
    ).to(device)

    print("2. Preparing data...")
    spec_mean = torch.load(Path(config.PROCESSED_DATA_ROOT) / 'spec_mean.pt')
    spec_std = torch.load(Path(config.PROCESSED_DATA_ROOT) / 'spec_std.pt')

    eeg_transform = transforms.Compose([transforms.ToDtype(torch.float32), transforms.Resize((224, 224)), transforms.Normalize(mean=spec_mean.tolist(), std=spec_std.tolist())])
    # Images are normalized to [-1, 1].
    image_transform = transforms.Compose([transforms.Resize((512, 512)), transforms.ToTensor(), transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])

    train_dataset = LightweightEEGDataset(config.PROCESSED_DATA_ROOT, config.METADATA_CSV, 'train', eeg_transform, image_transform)
    val_dataset = LightweightEEGDataset(config.PROCESSED_DATA_ROOT, config.METADATA_CSV, 'val', eeg_transform, image_transform)

    # The validation subset is drawn once and reused for every epoch.
    val_indices = np.random.choice(len(val_dataset), min(len(val_dataset), config.VAL_SAMPLES_PER_EPOCH), replace=False)
    val_loader = DataLoader(val_dataset, batch_size=config.BATCH_SIZE, sampler=SubsetRandomSampler(val_indices))
    # One validation batch is kept aside for the per-epoch reconstruction grid.
    fixed_spectrograms, fixed_real_images = next(iter(val_loader))
    fixed_spectrograms, fixed_real_images = fixed_spectrograms.to(device), fixed_real_images.to(device)

    print("3. Setting up optimizer and loss...")
    optimizer = optim.AdamW(eeg_encoder.parameters(), lr=config.LR, weight_decay=1e-4)
    scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=5, T_mult=2, eta_min=1e-7)
    scaler = torch.amp.GradScaler('cuda')
    noise_scheduler = DDPMScheduler.from_pretrained("segmind/tiny-sd", subfolder="scheduler")
    mae_loss_fn = nn.L1Loss()
    lpips_loss_fn = lpips.LPIPS(net='vgg').to(device)

    best_val_loss = float('inf')
    print("\n--- Starting Training with Parallel Fusion Encoder ---")

    for epoch in range(config.NUM_EPOCHS):
        eeg_encoder.train()
        # Draw a new random training subset every epoch.
        train_indices = np.random.choice(len(train_dataset), min(len(train_dataset), config.TRAIN_SAMPLES_PER_EPOCH), replace=False)
        train_loader_subset = DataLoader(train_dataset, batch_size=config.BATCH_SIZE, sampler=SubsetRandomSampler(train_indices), drop_last=True)
        num_train_batches = len(train_loader_subset)
        train_bar = tqdm(train_loader_subset, desc=f"Epoch {epoch+1} [Train]")
        optimizer.zero_grad()

        for i, (spectrograms, real_images) in enumerate(train_bar):
            spectrograms, real_images = spectrograms.to(device), real_images.to(device)

            with torch.no_grad():
                latents = vae.encode(real_images).latent_dist.sample() * vae.config.scaling_factor
            noise = torch.randn_like(latents)
            timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (latents.shape[0],), device=device).long()
            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

            with torch.autocast('cuda'):
                unet_embedding = eeg_encoder(spectrograms)
                # The UNet expects a sequence dimension: (B, 1, EEG_EMBED_DIM).
                eeg_cond = unet_embedding.unsqueeze(1)
                noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states=eeg_cond).sample

                # One-step estimate of the clean latents from the predicted noise:
                # x0 = (x_t - sqrt(1 - a_t) * eps) / sqrt(a_t)
                alpha_prod_t = noise_scheduler.alphas_cumprod.to(device)[timesteps].view(-1, 1, 1, 1)
                sqrt_one_minus_alpha_prod_t = (1 - alpha_prod_t).sqrt()
                pred_latents = (noisy_latents - sqrt_one_minus_alpha_prod_t * noise_pred) / alpha_prod_t.sqrt()
                generated_images = vae.decode(pred_latents / vae.config.scaling_factor, return_dict=False)[0]

                loss_mae = mae_loss_fn(noise_pred, noise)
                loss_lpips = lpips_loss_fn(generated_images, real_images).mean()
                loss = (loss_mae * config.MAE_LOSS_WEIGHT) + (loss_lpips * config.LPIPS_LOSS_WEIGHT)

                # Average over the accumulation window.
                loss = loss / config.GRAD_ACCUMULATION_STEPS

            scaler.scale(loss).backward()

            at_accum_boundary = (i + 1) % config.GRAD_ACCUMULATION_STEPS == 0
            # Also step on the final batch so a trailing partial window is
            # not thrown away by the next epoch's zero_grad(). Its gradient is
            # still divided by GRAD_ACCUMULATION_STEPS, i.e. slightly down-weighted.
            is_last_batch = (i + 1) == num_train_batches
            if at_accum_boundary or is_last_batch:
                # Unscale first so clipping sees the true gradient magnitudes.
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(eeg_encoder.parameters(), max_norm=1.0)
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()

            train_bar.set_postfix({'loss_mae': f'{loss_mae.item():.3f}', 'loss_lpips': f'{loss_lpips.item():.3f}', 'total_loss': f'{loss.item() * config.GRAD_ACCUMULATION_STEPS:.3f}'})

        # --- Validation ---
        eeg_encoder.eval()
        total_val_loss = 0
        val_bar = tqdm(val_loader, desc=f"Epoch {epoch+1} [Val]")
        with torch.no_grad():
            for spectrograms, real_images in val_bar:
                spectrograms, real_images = spectrograms.to(device), real_images.to(device)
                with torch.autocast('cuda'):
                    latents = vae.encode(real_images).latent_dist.sample() * vae.config.scaling_factor
                    noise = torch.randn_like(latents)
                    timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (latents.shape[0],), device=device).long()
                    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
                    unet_embedding = eeg_encoder(spectrograms)
                    eeg_cond = unet_embedding.unsqueeze(1)
                    noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states=eeg_cond).sample
                    alpha_prod_t = noise_scheduler.alphas_cumprod.to(device)[timesteps].view(-1, 1, 1, 1)
                    sqrt_one_minus_alpha_prod_t = (1 - alpha_prod_t).sqrt()
                    pred_latents = (noisy_latents - sqrt_one_minus_alpha_prod_t * noise_pred) / alpha_prod_t.sqrt()
                    generated_images = vae.decode(pred_latents / vae.config.scaling_factor, return_dict=False)[0]

                    loss_mae = mae_loss_fn(noise_pred, noise)
                    loss_lpips = lpips_loss_fn(generated_images, real_images).mean()
                    loss = (loss_mae * config.MAE_LOSS_WEIGHT) + (loss_lpips * config.LPIPS_LOSS_WEIGHT)
                    total_val_loss += loss.item()
                    val_bar.set_postfix({'val_loss': f'{loss.item():.4f}'})

        avg_val_loss = total_val_loss / len(val_loader) if len(val_loader) > 0 else 0
        scheduler.step()
        print(f"Epoch {epoch+1}/{config.NUM_EPOCHS} -> Val Loss: {avg_val_loss:.4f}, LR: {scheduler.get_last_lr()[0]:.2e}")

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            print(f"✨ New best validation loss. Saving encoder to {checkpoint_path}")
            torch.save({'eeg_encoder_state_dict': eeg_encoder.state_dict()}, checkpoint_path)

        # --- Visualization: full 50-step sampling for the fixed batch ---
        if (epoch + 1) % config.VISUALIZATION_INTERVAL == 0:
            eeg_encoder.eval()
            with torch.no_grad(), torch.autocast('cuda'):
                unet_embedding = eeg_encoder(fixed_spectrograms)
                eeg_cond = unet_embedding.unsqueeze(1)
                latents = torch.randn((fixed_spectrograms.shape[0], 4, 64, 64), device=device, dtype=torch.float16)
                noise_scheduler.set_timesteps(50)
                for t in tqdm(noise_scheduler.timesteps, desc="Generating Images", leave=False):
                    noise_pred = unet(latents, t, encoder_hidden_states=eeg_cond.half()).sample
                    latents = noise_scheduler.step(noise_pred, t, latents).prev_sample
                generated_images = vae.decode(latents / vae.config.scaling_factor, return_dict=False)[0]

            # Real images first, generated images after; map [-1, 1] -> [0, 1].
            comp_grid = torch.cat([fixed_real_images, generated_images.float()])
            comp_grid = (comp_grid * 0.5 + 0.5).clamp(0, 1)
            save_path = output_dir / f'reconstructions_epoch_{epoch+1:03d}.png'
            save_image(comp_grid, save_path, nrow=config.BATCH_SIZE)
            print(f"Saved reconstructions to {save_path}")

    print("\n--- Training Complete ---")

if __name__ == '__main__':
    # Entry point for this cell: start training with a default TRAIN_CONFIG instance.
    train_parallel_model(TRAIN_CONFIG())

1. Loading models...


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.
An error occurred while trying to fetch segmind/tiny-sd: segmind/tiny-sd does not appear to have a file named diffusion_pytorch_model.safetensors.
Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead.
An error occurred while trying to fetch segmind/tiny-sd: segmind/tiny-sd does not appear to have a file named diffusion_pytorch_model.safetensors.
Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead.


Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth


100%|██████████| 44.7M/44.7M [00:00<00:00, 205MB/s]


model.safetensors:   0%|          | 0.00/22.9M [00:00<?, ?B/s]

2. Preparing data...
3. Setting up optimizer and loss...


  scaler = torch.cuda.amp.GradScaler()


Setting up [LPIPS] perceptual loss: trunk [vgg], v[0.1], spatial [off]




Loading model from: /usr/local/lib/python3.12/dist-packages/lpips/weights/v0.1/vgg.pth

--- Starting Training with Parallel Fusion Encoder ---


Epoch 1 [Train]:   0%|          | 0/700 [00:00<?, ?it/s]

  with torch.cuda.amp.autocast():


Epoch 1 [Val]:   0%|          | 0/140 [00:00<?, ?it/s]

  with torch.cuda.amp.autocast():


Epoch 1/100 -> Val Loss: 0.9177, LR: 9.05e-06
✨ New best validation loss. Saving encoder to /content/drive/MyDrive/NeuroVision/parallel_fusion_encoder/run_20250924_101049/parallel_encoder_best.pth


  with torch.no_grad(), torch.cuda.amp.autocast():


Generating Images:   0%|          | 0/50 [00:00<?, ?it/s]

Saved reconstructions to /content/drive/MyDrive/NeuroVision/parallel_fusion_encoder/run_20250924_101049/reconstructions_epoch_001.png


Epoch 2 [Train]:   0%|          | 0/700 [00:00<?, ?it/s]

Epoch 2 [Val]:   0%|          | 0/140 [00:00<?, ?it/s]

Epoch 2/100 -> Val Loss: 0.9030, LR: 6.58e-06
✨ New best validation loss. Saving encoder to /content/drive/MyDrive/NeuroVision/parallel_fusion_encoder/run_20250924_101049/parallel_encoder_best.pth


Generating Images:   0%|          | 0/50 [00:00<?, ?it/s]

Saved reconstructions to /content/drive/MyDrive/NeuroVision/parallel_fusion_encoder/run_20250924_101049/reconstructions_epoch_002.png


Epoch 3 [Train]:   0%|          | 0/700 [00:00<?, ?it/s]

Epoch 3 [Val]:   0%|          | 0/140 [00:00<?, ?it/s]

Epoch 3/100 -> Val Loss: 0.8792, LR: 3.52e-06
✨ New best validation loss. Saving encoder to /content/drive/MyDrive/NeuroVision/parallel_fusion_encoder/run_20250924_101049/parallel_encoder_best.pth


Generating Images:   0%|          | 0/50 [00:00<?, ?it/s]

Saved reconstructions to /content/drive/MyDrive/NeuroVision/parallel_fusion_encoder/run_20250924_101049/reconstructions_epoch_003.png


Epoch 4 [Train]:   0%|          | 0/700 [00:00<?, ?it/s]

Epoch 4 [Val]:   0%|          | 0/140 [00:00<?, ?it/s]

Epoch 4/100 -> Val Loss: 0.8701, LR: 1.05e-06
✨ New best validation loss. Saving encoder to /content/drive/MyDrive/NeuroVision/parallel_fusion_encoder/run_20250924_101049/parallel_encoder_best.pth


Generating Images:   0%|          | 0/50 [00:00<?, ?it/s]

Saved reconstructions to /content/drive/MyDrive/NeuroVision/parallel_fusion_encoder/run_20250924_101049/reconstructions_epoch_004.png


Epoch 5 [Train]:   0%|          | 0/700 [00:00<?, ?it/s]

Epoch 5 [Val]:   0%|          | 0/140 [00:00<?, ?it/s]

Epoch 5/100 -> Val Loss: 0.8735, LR: 1.00e-05


Generating Images:   0%|          | 0/50 [00:00<?, ?it/s]

Saved reconstructions to /content/drive/MyDrive/NeuroVision/parallel_fusion_encoder/run_20250924_101049/reconstructions_epoch_005.png


Epoch 6 [Train]:   0%|          | 0/700 [00:00<?, ?it/s]

# Test Scripts

In [None]:
import torch
import torch.nn as nn
import timm
import numpy as np
import pandas as pd
from pathlib import Path
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import v2 as transforms
from PIL import Image
from tqdm.auto import tqdm
from diffusers import AutoencoderKL, UNet2DConditionModel, DDIMScheduler
from torchvision.utils import save_image
from datetime import datetime

# ==============================================================================
# --- 1. CONFIGURATION ---
# ==============================================================================
class TEST_CONFIG:
    """Paths and inference settings for generating images from the test split."""

    # --- Input data, checkpoint, and output locations ---
    PROCESSED_DATA_ROOT = '/content/final_lightweight_17k'
    METADATA_CSV = Path(PROCESSED_DATA_ROOT, 'metadata.csv')
    # IMPORTANT: point this at the encoder checkpoint you want to evaluate.
    EEG_ENCODER_CHECKPOINT_PATH = '/content/drive/MyDrive/NeuroVision/end_to_end_perceptual/run_20250922_220152/eeg_encoder_best.pth'
    OUTPUT_DIR = '/content/drive/MyDrive/NeuroVision/test_results'

    # --- Model widths; these have to agree with the training run ---
    BACKBONE_DIM = 192
    EEG_EMBED_DIM = 768   # conditioning width expected by the UNet
    CLIP_EMBED_DIM = 512  # CLIP embedding width

    # --- Sampling settings ---
    BATCH_SIZE = 5         # images generated per forward pass
    NUM_TEST_SAMPLES = 20  # upper bound on images generated from the test set
    INFERENCE_STEPS = 50   # denoising steps per image
    GUIDANCE_SCALE = 10    # classifier-free guidance strength (higher = follow EEG more closely)

# ==============================================================================
# --- 2. EEG ENCODER (Copy-pasted from your training script) ---
# ==============================================================================
class EEGEncoder(nn.Module):
    """ViT-tiny backbone followed by an MLP projector.

    Must mirror the architecture used at training time so the checkpoint's
    state dict loads.

    Args:
        backbone_dim: Width of the backbone's output features.
        projection_dim: Width of the returned embedding.
    """
    def __init__(self, backbone_dim, projection_dim):
        super().__init__()
        # Headless ViT-tiny (num_classes=0) taking 64 input channels.
        self.backbone = timm.create_model(
            'vit_tiny_patch16_224',
            pretrained=False,
            in_chans=64,
            num_classes=0,
        )
        projector_layers = [
            nn.Linear(backbone_dim, projection_dim),
            nn.GELU(),
            nn.LayerNorm(projection_dim),
            nn.Linear(projection_dim, projection_dim),
        ]
        self.projector = nn.Sequential(*projector_layers)

    def forward(self, x):
        """Return the projected embedding of the backbone's first token."""
        token_features = self.backbone.forward_features(x)
        first_token = token_features[:, 0]
        return self.projector(first_token)

# ==============================================================================
# --- 3. DATASET (Copy-pasted from your training script) ---
# ==============================================================================
class LightweightEEGDataset(Dataset):
    """Paired (spectrogram, image) dataset driven by a metadata CSV.

    Args:
        root_dir: Directory that the CSV's relative paths are resolved against.
        metadata_csv: CSV with at least the columns ``split``,
            ``spectrogram_path`` and ``image_path``.
        split: Split name to keep; compared against the whitespace-stripped
            ``split`` column.
        eeg_transform: Callable applied to the loaded spectrogram.
        image_transform: Callable applied to the loaded RGB image.
    """
    def __init__(self, root_dir, metadata_csv, split, eeg_transform, image_transform):
        self.root_dir = Path(root_dir)
        self.eeg_transform = eeg_transform
        self.image_transform = image_transform
        df = pd.read_csv(metadata_csv)
        # Strip whitespace so padded values such as " test" still match.
        self.split_df = df[df['split'].str.strip() == split].reset_index(drop=True)

    def __len__(self):
        """Return the number of rows belonging to the selected split."""
        return len(self.split_df)

    def __getitem__(self, idx):
        """Return ``(eeg_transform(spectrogram), image_transform(image))`` for row ``idx``."""
        info = self.split_df.iloc[idx]
        spec = torch.load(self.root_dir / info['spectrogram_path'])
        # Use a context manager so the image file handle is closed
        # deterministically; convert() returns an independent in-memory copy.
        with Image.open(self.root_dir / info['image_path']) as img:
            image = img.convert("RGB")
        return self.eeg_transform(spec), self.image_transform(image)

# ==============================================================================
# --- 4. MAIN TESTING SCRIPT ---
# ==============================================================================
def test_reconstruction(config):
    """Generate images from test-set EEG spectrograms and save side-by-side comparisons.

    Loads the tiny-sd VAE/UNet in float16 plus the trained ``EEGEncoder``
    checkpoint, then runs DDIM sampling with classifier-free guidance (an
    all-zero embedding serves as the unconditional branch). For each sample a
    PNG containing ``[original | generated]`` is written to a timestamped
    directory under ``config.OUTPUT_DIR``.

    Changes relative to the earlier version:
      * ``real_images.to(device)`` discarded its result; the tensor is now
        actually moved once per batch.
      * ``NUM_TEST_SAMPLES`` is honored exactly: the last batch is trimmed
        instead of overshooting by up to ``BATCH_SIZE - 1`` images.
      * Decoded images are cast to float32 before being concatenated with the
        float32 reference images.

    Args:
        config: Object exposing the attributes of ``TEST_CONFIG``.

    Returns:
        None. Returns early (after printing the error) if the encoder
        checkpoint cannot be loaded.
    """
    # --- Setup ---
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    output_dir = Path(config.OUTPUT_DIR) / f'test_{timestamp}'
    output_dir.mkdir(parents=True, exist_ok=True)
    print(f"Using device: {device}. Test results will be saved to {output_dir}")

    # --- 1. Load All Models ---
    print("1. Loading pre-trained models (VAE, UNet, and EEG Encoder)...")
    vae = AutoencoderKL.from_pretrained("segmind/tiny-sd", subfolder="vae", torch_dtype=torch.float16).to(device).eval()
    unet = UNet2DConditionModel.from_pretrained("segmind/tiny-sd", subfolder="unet", torch_dtype=torch.float16).to(device).eval()

    # Load the trained EEG encoder; a failure here is reported and aborts the run.
    eeg_encoder = EEGEncoder(config.BACKBONE_DIM, config.EEG_EMBED_DIM).to(device)
    try:
        checkpoint = torch.load(config.EEG_ENCODER_CHECKPOINT_PATH, map_location=device)
        eeg_encoder.load_state_dict(checkpoint['eeg_encoder_state_dict'])
        eeg_encoder.eval().to(torch.float16) # Use float16 for faster inference
        print("✅ Successfully loaded EEG Encoder checkpoint.")
    except Exception as e:
        print(f"🛑 Error loading EEG Encoder checkpoint: {e}")
        return

    # --- 2. Prepare Data ---
    print("2. Preparing test dataset...")
    spec_mean = torch.load(Path(config.PROCESSED_DATA_ROOT) / 'spec_mean.pt')
    spec_std = torch.load(Path(config.PROCESSED_DATA_ROOT) / 'spec_std.pt')

    eeg_transform = transforms.Compose([transforms.ToDtype(torch.float32), transforms.Resize((224,224)), transforms.Normalize(mean=spec_mean.tolist(),std=spec_std.tolist())])
    image_transform = transforms.Compose([transforms.Resize((512, 512)), transforms.ToTensor(), transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])

    test_dataset = LightweightEEGDataset(config.PROCESSED_DATA_ROOT, config.METADATA_CSV, 'test', eeg_transform, image_transform)
    test_loader = DataLoader(test_dataset, batch_size=config.BATCH_SIZE, shuffle=False, num_workers=0)
    print(f"Found {len(test_dataset)} samples in the test set.")

    # --- 3. Setup Inference Pipeline ---
    scheduler = DDIMScheduler.from_pretrained("segmind/tiny-sd", subfolder="scheduler")
    scheduler.set_timesteps(config.INFERENCE_STEPS)

    # --- 4. Run Inference Loop ---
    print("\n--- Starting Image Generation from Test Set ---")
    total_samples_generated = 0
    with torch.no_grad():
        for spectrograms, real_images in tqdm(test_loader, desc="Generating Images"):
            remaining = config.NUM_TEST_SAMPLES - total_samples_generated
            if remaining <= 0:
                break

            # Trim the batch so the NUM_TEST_SAMPLES cap is never exceeded.
            spectrograms = spectrograms[:remaining].to(device, dtype=torch.float16)
            real_images = real_images[:remaining].to(device)
            batch_size = spectrograms.shape[0]

            # Get EEG embeddings from the trained encoder
            with torch.autocast('cuda'):
                unet_embedding = eeg_encoder(spectrograms)
            eeg_cond = unet_embedding.unsqueeze(1) # Add sequence dimension for UNet

            # Unconditional branch for classifier-free guidance: all-zero embedding,
            # stacked in front of the conditional one so both run in a single pass.
            uncond_embedding = torch.zeros_like(eeg_cond)
            embeddings = torch.cat([uncond_embedding, eeg_cond])

            # Prepare initial random noise
            latents = torch.randn((batch_size, unet.config.in_channels, 64, 64), device=device, dtype=torch.float16)
            latents = latents * scheduler.init_noise_sigma

            # Denoising loop
            for t in tqdm(scheduler.timesteps, leave=False, desc="Denoising"):
                # Duplicate the latents: first half unconditional, second half conditional.
                latent_model_input = torch.cat([latents] * 2)
                latent_model_input = scheduler.scale_model_input(latent_model_input, t)

                with torch.autocast('cuda'):
                    noise_pred = unet(latent_model_input, t, encoder_hidden_states=embeddings).sample

                # Perform guidance
                noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + config.GUIDANCE_SCALE * (noise_pred_cond - noise_pred_uncond)

                # Compute the previous noisy sample x_t -> x_{t-1}
                latents = scheduler.step(noise_pred, t, latents).prev_sample

            # Decode the final latents into images
            with torch.autocast('cuda'):
                generated_images = vae.decode(latents / vae.config.scaling_factor, return_dict=False)[0]
            # Match the float32 reference images before concatenation.
            generated_images = generated_images.float()

            # --- 5. Save Comparison Images ---
            for j in range(batch_size):
                # Denormalize from [-1, 1] to [0, 1] and clamp for saving
                real = (real_images[j].unsqueeze(0) * 0.5 + 0.5).clamp(0, 1)
                generated = (generated_images[j].unsqueeze(0) * 0.5 + 0.5).clamp(0, 1)

                # Create a side-by-side grid: [Original Image | Generated Image]
                comparison_grid = torch.cat([real, generated])
                save_path = output_dir / f'comparison_{total_samples_generated + j:04d}.png'
                save_image(comparison_grid, save_path, nrow=2)

            total_samples_generated += batch_size

    print(f"\n--- Testing Complete ---")
    print(f"✅ Generated {total_samples_generated} images. Results saved in: {output_dir}")


if __name__ == '__main__':
    # Run the test only once the checkpoint path no longer contains the
    # 'YYYYMMDD' template placeholder; otherwise remind the user to set it.
    config = TEST_CONFIG()
    if 'YYYYMMDD' not in config.EEG_ENCODER_CHECKPOINT_PATH:
        test_reconstruction(config)
    else:
        print("🛑 PLEASE UPDATE 'EEG_ENCODER_CHECKPOINT_PATH' in the TEST_CONFIG class before running.")

# dump

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import timm
import numpy as np
import pandas as pd
from pathlib import Path
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
from torchvision.transforms import v2 as transforms
from PIL import Image
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from tqdm.auto import tqdm
from datetime import datetime
from diffusers import AutoencoderKL, UNet2DConditionModel, DDPMScheduler
from torch.utils.tensorboard import SummaryWriter
from torchvision.utils import save_image
import torchvision.models as models

# ==============================================================================
# --- 1. CONFIGURATION ---
# ==============================================================================
class TRAIN_CONFIG:
    """Paths and hyperparameters for end-to-end training with a perceptual loss."""

    # --- Data locations and where run artifacts are written ---
    PROCESSED_DATA_ROOT = '/content/final_lightweight_17k'
    METADATA_CSV = Path(PROCESSED_DATA_ROOT, 'metadata.csv')
    OUTPUT_DIR = '/content/drive/MyDrive/NeuroVision/end_to_end_perceptual'
    # Optional checkpoint to warm-start the EEG encoder from; None skips loading.
    EEG_ENCODER_PRETRAIN_PATH = None

    # --- Optimization schedule ---
    BATCH_SIZE = 2
    NUM_EPOCHS = 100
    LR = 1e-5

    # --- Weight of the perceptual loss term ---
    PERCEPTUAL_LOSS_WEIGHT = 1.0

    # --- Encoder widths ---
    BACKBONE_DIM = 192
    EEG_EMBED_DIM = 768

    # --- Per-epoch sampling and visualization cadence ---
    VISUALIZATION_INTERVAL = 1
    VAL_SAMPLES_PER_EPOCH = 400
    TRAIN_SAMPLES_PER_EPOCH = 2000

# ==============================================================================
# --- 2. PERCEPTUAL LOSS MODULE (USING RESNET-18) ---
# ==============================================================================
class ResNetPerceptualLoss(nn.Module):
    """Perceptual loss: MSE between frozen ResNet-18 features of two images.

    Both images are expected in the ``[-1, 1]`` range; they are mapped to
    ``[0, 1]``, normalized with the ImageNet mean/std, optionally resized to
    224x224, and passed through ResNet-18 with its final ``fc`` layer removed.

    The feature extractor is frozen and is kept in eval mode even when
    ``.train()`` is called on this module or a parent module. Without the
    ``train`` override, such a call would switch the BatchNorm layers back to
    batch statistics and update their running stats.

    Args:
        resize: If True, bilinearly resize both inputs to 224x224 first.
    """
    def __init__(self, resize=True):
        super().__init__()
        # Pre-trained ResNet-18 with every child except the final
        # classification layer (fc).
        resnet = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
        self.model = nn.Sequential(*list(resnet.children())[:-1]).eval()

        for param in self.model.parameters():
            param.requires_grad = False

        self.transform = nn.functional.interpolate
        self.resize = resize
        # Buffers (not parameters) so they follow .to(device) with the module.
        self.register_buffer("mean", torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer("std", torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
        self.loss_fn = nn.MSELoss()

    def train(self, mode=True):
        """Set the module's mode while pinning the frozen backbone to eval."""
        super().train(mode)
        self.model.eval()
        return self

    def forward(self, input, target):
        """Return the MSE between ResNet features of ``input`` and ``target``.

        Args:
            input: Image batch in ``[-1, 1]``, shape ``(B, 3, H, W)``.
            target: Image batch in ``[-1, 1]``, shape ``(B, 3, H, W)``.
        """
        # Map [-1, 1] -> [0, 1], then apply ImageNet normalization.
        input = (input + 1) / 2.0
        target = (target + 1) / 2.0
        input = (input - self.mean) / self.std
        target = (target - self.mean) / self.std

        if self.resize:
            input = self.transform(input, mode='bilinear', size=(224, 224), align_corners=False)
            target = self.transform(target, mode='bilinear', size=(224, 224), align_corners=False)

        input_features = self.model(input)
        target_features = self.model(target)

        return self.loss_fn(input_features, target_features)

# ==============================================================================
# --- 3. EEG ENCODER MODEL ---
# ==============================================================================
class EEGEncoder(nn.Module):
    """ViT-tiny backbone followed by an MLP projector for EEG spectrograms.

    Args:
        backbone_dim: Width of the backbone's output features.
        projection_dim: Width of the returned embedding.
    """
    def __init__(self, backbone_dim, projection_dim):
        super().__init__()
        # Headless ViT-tiny (num_classes=0) taking 64 input channels,
        # initialized without pre-trained weights.
        self.backbone = timm.create_model(
            'vit_tiny_patch16_224',
            pretrained=False,
            in_chans=64,
            num_classes=0,
        )
        projector_layers = [
            nn.Linear(backbone_dim, projection_dim),
            nn.GELU(),
            nn.LayerNorm(projection_dim),
            nn.Linear(projection_dim, projection_dim),
        ]
        self.projector = nn.Sequential(*projector_layers)

    def forward(self, x):
        """Return the projected embedding of the backbone's first token."""
        token_features = self.backbone.forward_features(x)
        first_token = token_features[:, 0]
        return self.projector(first_token)

# ==============================================================================
# --- 4. DATASET ---
# ==============================================================================
class LightweightEEGDataset(Dataset):
    """Paired (spectrogram, image) dataset driven by a metadata CSV.

    Args:
        root_dir: Directory that the CSV's relative paths are resolved against.
        metadata_csv: CSV with at least the columns ``split``,
            ``spectrogram_path`` and ``image_path``.
        split: Split name to keep; compared against the whitespace-stripped
            ``split`` column.
        eeg_transform: Callable applied to the loaded spectrogram.
        image_transform: Callable applied to the loaded RGB image.
    """
    def __init__(self, root_dir, metadata_csv, split, eeg_transform, image_transform):
        self.root_dir = Path(root_dir)
        self.eeg_transform = eeg_transform
        self.image_transform = image_transform
        df = pd.read_csv(metadata_csv)
        # Strip whitespace so padded values such as " train" still match.
        self.split_df = df[df['split'].str.strip() == split].reset_index(drop=True)

    def __len__(self):
        """Return the number of rows belonging to the selected split."""
        return len(self.split_df)

    def __getitem__(self, idx):
        """Return ``(eeg_transform(spectrogram), image_transform(image))`` for row ``idx``."""
        info = self.split_df.iloc[idx]
        spec = torch.load(self.root_dir / info['spectrogram_path'])
        # Use a context manager so the image file handle is closed
        # deterministically; convert() returns an independent in-memory copy.
        with Image.open(self.root_dir / info['image_path']) as img:
            image = img.convert("RGB")
        return self.eeg_transform(spec), self.image_transform(image)

# ==============================================================================
# --- 5. MAIN TRAINING SCRIPT ---
# ==============================================================================
def train_end_to_end(config):
    """Train the EEG encoder end-to-end against a frozen tiny-sd VAE and UNet.

    Per training step the real image is encoded to VAE latents, noised at a
    random timestep, and the frozen UNet predicts that noise conditioned on
    the EEG embedding. The loss is the noise-prediction MSE plus
    ``config.PERCEPTUAL_LOSS_WEIGHT`` times a ResNet perceptual loss between
    the decoded one-step clean-latent estimate and the real image. Only the
    EEG encoder's parameters are given to the optimizer.

    After every epoch the same loss is averaged over a fixed validation
    subset and the encoder is checkpointed when that average improves. Every
    ``config.VISUALIZATION_INTERVAL`` epochs a 50-step sampling run on one
    fixed validation batch is saved as an image grid.

    Args:
        config: Object exposing OUTPUT_DIR, PROCESSED_DATA_ROOT, METADATA_CSV,
            EEG_ENCODER_PRETRAIN_PATH, BACKBONE_DIM, EEG_EMBED_DIM, BATCH_SIZE,
            NUM_EPOCHS, LR, PERCEPTUAL_LOSS_WEIGHT, VISUALIZATION_INTERVAL,
            VAL_SAMPLES_PER_EPOCH and TRAIN_SAMPLES_PER_EPOCH.

    Returns:
        None. Files are written under ``OUTPUT_DIR/run_<timestamp>``:
        ``real_samples.png``, ``eeg_encoder_best.pth`` and
        ``reconstructions_epoch_<NNN>.png``.
    """
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    output_dir = Path(config.OUTPUT_DIR) / f'run_{timestamp}'
    output_dir.mkdir(parents=True, exist_ok=True)
    checkpoint_path = output_dir / 'eeg_encoder_best.pth'
    print(f"Using device: {device}. Checkpoints will be saved to {output_dir}")

    # --- Load Models ---
    print("1. Loading models...")
    vae = AutoencoderKL.from_pretrained("segmind/tiny-sd", subfolder="vae").to(device).eval()
    unet = UNet2DConditionModel.from_pretrained("segmind/tiny-sd", subfolder="unet").to(device).eval()
    # The diffusion models stay frozen; gradients only update the EEG encoder.
    for p in vae.parameters():
        p.requires_grad = False
    for p in unet.parameters():
        p.requires_grad = False

    eeg_encoder = EEGEncoder(config.BACKBONE_DIM, config.EEG_EMBED_DIM).to(device)
    if config.EEG_ENCODER_PRETRAIN_PATH:
        # Warm start is best-effort: any failure falls back to the fresh init.
        try:
            print(f"   Loading pre-trained weights from {config.EEG_ENCODER_PRETRAIN_PATH}")
            pretrain_ckpt = torch.load(config.EEG_ENCODER_PRETRAIN_PATH, map_location=device)
            eeg_encoder.load_state_dict(pretrain_ckpt['eeg_encoder_state_dict'])
            print("   ✅ Successfully loaded pre-trained EEG Encoder weights.")
        except Exception as e:
            print(f"   ⚠️ Could not load pre-trained weights: {e}. Starting from scratch.")

    # --- Data Prep ---
    print("2. Preparing data...")
    spec_mean = torch.load(Path(config.PROCESSED_DATA_ROOT) / 'spec_mean.pt')
    spec_std = torch.load(Path(config.PROCESSED_DATA_ROOT) / 'spec_std.pt')

    eeg_transform = transforms.Compose([
        transforms.ToDtype(torch.float32),
        transforms.Resize((224, 224)),
        transforms.Normalize(mean=spec_mean.tolist(), std=spec_std.tolist()),
    ])
    # Normalising with mean 0.5 / std 0.5 maps ToTensor's [0, 1] range to [-1, 1].
    image_transform = transforms.Compose([
        transforms.Resize((512, 512)),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ])

    train_dataset = LightweightEEGDataset(config.PROCESSED_DATA_ROOT, config.METADATA_CSV, 'train', eeg_transform, image_transform)
    val_dataset = LightweightEEGDataset(config.PROCESSED_DATA_ROOT, config.METADATA_CSV, 'val', eeg_transform, image_transform)

    # The validation subset is drawn once and reused for every epoch.
    val_indices = np.random.choice(len(val_dataset), min(len(val_dataset), config.VAL_SAMPLES_PER_EPOCH), replace=False)
    val_sampler = SubsetRandomSampler(val_indices)
    val_loader = DataLoader(val_dataset, batch_size=config.BATCH_SIZE, sampler=val_sampler, num_workers=0)

    # One validation batch is pinned so reconstructions are comparable across epochs.
    fixed_spectrograms, fixed_real_images = next(iter(val_loader))
    fixed_spectrograms, fixed_real_images = fixed_spectrograms.to(device), fixed_real_images.to(device)
    save_image(fixed_real_images * 0.5 + 0.5, output_dir / 'real_samples.png', nrow=config.BATCH_SIZE)

    # --- Optimizer, Loss, etc. ---
    print("3. Setting up optimizer and loss...")
    optimizer = optim.AdamW(eeg_encoder.parameters(), lr=config.LR, weight_decay=5e-4)
    scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=5, T_mult=1, eta_min=5e-7)
    scaler = torch.amp.GradScaler('cuda')
    noise_scheduler = DDPMScheduler.from_pretrained("segmind/tiny-sd", subfolder="scheduler")

    mse_loss_fn = nn.MSELoss()
    perceptual_loss_fn = ResNetPerceptualLoss().to(device)

    best_val_loss = float('inf')
    print("\n--- Starting End-to-End Training ---")

    for epoch in range(config.NUM_EPOCHS):
        eeg_encoder.train()
        # A fresh random subset of the training split is used each epoch.
        train_indices = np.random.choice(len(train_dataset), min(len(train_dataset), config.TRAIN_SAMPLES_PER_EPOCH), replace=False)
        train_sampler = SubsetRandomSampler(train_indices)
        train_loader_subset = DataLoader(train_dataset, batch_size=config.BATCH_SIZE, sampler=train_sampler, num_workers=0, drop_last=True)

        train_pbar = tqdm(train_loader_subset, desc=f"Epoch {epoch+1} [Train]")
        for spectrograms, real_images in train_pbar:
            spectrograms, real_images = spectrograms.to(device), real_images.to(device)
            optimizer.zero_grad()

            # The target latents come from the frozen VAE, so no graph is needed.
            with torch.no_grad():
                latents = vae.encode(real_images).latent_dist.sample() * vae.config.scaling_factor
            noise = torch.randn_like(latents)
            timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (latents.shape[0],), device=device).long()
            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

            with torch.amp.autocast('cuda'):
                eeg_embeddings = eeg_encoder(spectrograms)
                # A length-1 sequence axis is added for encoder_hidden_states.
                eeg_cond = eeg_embeddings.unsqueeze(1)

                # No no_grad here: the loss must backpropagate through the
                # frozen UNet (and VAE decoder below) into the EEG encoder.
                noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states=eeg_cond).sample

                # One-step estimate of the clean latents from the predicted noise.
                alpha_prod_t = noise_scheduler.alphas_cumprod.to(device)[timesteps]
                sqrt_alpha_prod_t = alpha_prod_t.sqrt().view(-1, 1, 1, 1)
                sqrt_one_minus_alpha_prod_t = (1 - alpha_prod_t).sqrt().view(-1, 1, 1, 1)
                pred_latents = (noisy_latents - sqrt_one_minus_alpha_prod_t * noise_pred) / sqrt_alpha_prod_t
                generated_images = vae.decode(pred_latents / vae.config.scaling_factor, return_dict=False)[0]

                loss_mse = mse_loss_fn(noise_pred, noise)
                loss_perceptual = perceptual_loss_fn(generated_images, real_images)

                loss = loss_mse + (loss_perceptual * config.PERCEPTUAL_LOSS_WEIGHT)

            train_pbar.set_postfix({'loss_mse': f'{loss_mse:.3f}', 'loss_percept': f'{loss_perceptual:.3f}', 'loss_total': f'{loss:.3f}'})

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

        # --- Validation Loop ---
        eeg_encoder.eval()
        total_val_loss = 0
        with torch.no_grad():

            val_pbar = tqdm(val_loader, desc=f"Epoch {epoch+1} [Val]")
            for spectrograms, real_images in val_pbar:
                spectrograms, real_images = spectrograms.to(device), real_images.to(device)
                latents = vae.encode(real_images).latent_dist.sample() * vae.config.scaling_factor
                noise = torch.randn_like(latents)
                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (latents.shape[0],), device=device).long()
                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

                with torch.amp.autocast('cuda'):
                    eeg_embeddings = eeg_encoder(spectrograms)
                    eeg_cond = eeg_embeddings.unsqueeze(1)

                    noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states=eeg_cond).sample

                    alpha_prod_t = noise_scheduler.alphas_cumprod.to(device)[timesteps]
                    sqrt_alpha_prod_t = alpha_prod_t.sqrt().view(-1, 1, 1, 1)
                    sqrt_one_minus_alpha_prod_t = (1 - alpha_prod_t).sqrt().view(-1, 1, 1, 1)
                    pred_latents = (noisy_latents - sqrt_one_minus_alpha_prod_t * noise_pred) / sqrt_alpha_prod_t
                    generated_images = vae.decode(pred_latents / vae.config.scaling_factor, return_dict=False)[0]

                    loss_mse = mse_loss_fn(noise_pred, noise)
                    loss_perceptual = perceptual_loss_fn(generated_images, real_images)

                    loss = loss_mse + (loss_perceptual * config.PERCEPTUAL_LOSS_WEIGHT)
                    total_val_loss += loss.item()

                val_pbar.set_postfix({'loss_mse': f'{loss_mse:.3f}', 'loss_percept': f'{loss_perceptual:.3f}', 'loss_total': f'{loss:.3f}'})

        avg_val_loss = total_val_loss / len(val_loader) if len(val_loader) > 0 else 0
        scheduler.step()
        print(f"Epoch {epoch+1}/{config.NUM_EPOCHS} -> Val Loss: {avg_val_loss:.4f}, LR: {scheduler.get_last_lr()[0]:.2e}")

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            print(f"✨ New best validation loss. Saving EEG Encoder to {checkpoint_path}")
            torch.save({'eeg_encoder_state_dict': eeg_encoder.state_dict()}, checkpoint_path)

        # --- Visualization ---
        if (epoch + 1) % config.VISUALIZATION_INTERVAL == 0:
            print(f"--- Generating reconstructions for epoch {epoch+1} ---")
            eeg_encoder.eval()
            with torch.no_grad():
                eeg_embeddings = eeg_encoder(fixed_spectrograms)
                eeg_cond = eeg_embeddings.unsqueeze(1)

                # Half precision is only usable where CUDA autocast is active;
                # on the CPU fallback the float32 UNet needs float32 latents.
                sample_dtype = torch.float16 if device.type == 'cuda' else torch.float32
                latents = torch.randn((fixed_spectrograms.shape[0], 4, 64, 64), device=device, dtype=sample_dtype)
                noise_scheduler.set_timesteps(50)

                for t in tqdm(noise_scheduler.timesteps, desc="Generating Images"):
                    with torch.amp.autocast('cuda'):
                        noise_pred = unet(latents, t, encoder_hidden_states=eeg_cond.to(latents.dtype)).sample
                    latents = noise_scheduler.step(noise_pred, t, latents).prev_sample

                # Only the final latents are displayed, so the VAE decodes once
                # after sampling instead of once per denoising step.
                with torch.amp.autocast('cuda'):
                    generated_images = vae.decode(latents / vae.config.scaling_factor, return_dict=False)[0]

                # Top row(s): real images; below: reconstructions, mapped back to [0, 1].
                comp_grid = torch.cat([fixed_real_images, generated_images.to(torch.float32)])
                comp_grid = (comp_grid * 0.5 + 0.5).clamp(0, 1)
                save_path = output_dir / f'reconstructions_epoch_{epoch+1:03d}.png'
                save_image(comp_grid, save_path, nrow=config.BATCH_SIZE)
                print(f"Saved reconstructions to {save_path}")

    print("\n--- Training Complete ---")

# Start a run with the cell's default configuration when executed directly.
if __name__ == "__main__":
    train_end_to_end(config=TRAIN_CONFIG())


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import timm
import numpy as np
import pandas as pd
from pathlib import Path
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
from torchvision.transforms import v2 as transforms
from PIL import Image
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from tqdm.auto import tqdm
from datetime import datetime
from diffusers import AutoencoderKL, UNet2DConditionModel, DDPMScheduler
from torch.utils.tensorboard import SummaryWriter
from torchvision.utils import save_image
from torchmetrics.image import StructuralSimilarityIndexMeasure

# ==============================================================================
# --- 1. CONFIGURATION ---
# ==============================================================================
class TRAIN_CONFIG:
    """Settings for the end-to-end run that combines noise MSE with an SSIM loss."""

    # Dataset location and the folder that receives one sub-folder per run.
    PROCESSED_DATA_ROOT = '/content/final_lightweight_17k'
    METADATA_CSV = Path(PROCESSED_DATA_ROOT, 'metadata.csv')
    OUTPUT_DIR = '/content/drive/MyDrive/NeuroVision/end_to_end_training'

    # Checkpoint used to warm-start the EEG encoder; a falsy value skips loading.
    EEG_ENCODER_PRETRAIN_PATH = '/content/drive/MyDrive/NeuroVision/end_to_end_training/run_20250922_132618/eeg_encoder_best.pth'

    # Encoder widths: projector input and projector output.
    BACKBONE_DIM = 192
    EEG_EMBED_DIM = 768

    # Optimisation schedule.
    BATCH_SIZE = 3
    NUM_EPOCHS = 100
    LR = 1e-5

    # total loss = MSE * MSE_LOSS_WEIGHT + (1 - SSIM) * SSIM_LOSS_WEIGHT
    MSE_LOSS_WEIGHT = 1.5
    SSIM_LOSS_WEIGHT = 0.5

    # Epochs between reconstruction grids, and per-epoch sample budgets.
    VISUALIZATION_INTERVAL = 1
    TRAIN_SAMPLES_PER_EPOCH = 1400
    VAL_SAMPLES_PER_EPOCH = 280

# ==============================================================================
# --- 2. EEG ENCODER MODEL (WITH IMPROVED PROJECTOR) ---
# ==============================================================================
class EEGEncoder(nn.Module):
    """ViT-tiny backbone over 64-channel spectrograms plus an MLP projector.

    Args:
        backbone_dim: Input width of the projector; it has to match the last
            dimension of the features returned by the backbone.
        projection_dim: Width of the projector's hidden layer and output.
    """

    def __init__(self, backbone_dim, projection_dim):
        super().__init__()
        self.backbone = timm.create_model(
            'vit_tiny_patch16_224',
            pretrained=False,
            in_chans=64,
            num_classes=0,
        )

        # Layers are created in the order they are stacked; their positions
        # become the state-dict keys projector.0 ... projector.3.
        fc_in = nn.Linear(backbone_dim, projection_dim)
        activation = nn.GELU()
        norm = nn.LayerNorm(projection_dim)
        fc_out = nn.Linear(projection_dim, projection_dim)
        self.projector = nn.Sequential(fc_in, activation, norm, fc_out)

    def forward(self, x):
        """Return the projected feature taken from index 0 of dimension 1."""
        return self.projector(self.backbone.forward_features(x)[:, 0])

# ==============================================================================
# --- 3. DATASET ---
# ==============================================================================
class LightweightEEGDataset(Dataset):
    """Dataset of (transformed spectrogram, transformed image) pairs for one split.

    Args:
        root_dir: Base directory for the relative paths stored in the CSV.
        metadata_csv: CSV providing ``split``, ``spectrogram_path`` and
            ``image_path`` columns.
        split: Label of the split to keep (compared after stripping
            whitespace from the CSV values).
        eeg_transform: Callable applied to each loaded spectrogram.
        image_transform: Callable applied to each RGB PIL image.
    """

    def __init__(self, root_dir, metadata_csv, split, eeg_transform, image_transform):
        self.root_dir = Path(root_dir)
        self.eeg_transform = eeg_transform
        self.image_transform = image_transform

        table = pd.read_csv(metadata_csv)
        split_labels = table['split'].str.strip()
        # Fresh 0..n-1 index so __getitem__ can address rows by position.
        self.split_df = table.loc[split_labels == split].reset_index(drop=True)

    def __len__(self):
        """Return how many samples the selected split contains."""
        return len(self.split_df)

    def __getitem__(self, idx):
        """Return ``(eeg_transform(spectrogram), image_transform(image))`` for row ``idx``."""
        record = self.split_df.iloc[idx]
        spec_file = self.root_dir / record['spectrogram_path']
        img_file = self.root_dir / record['image_path']
        spectrogram = torch.load(spec_file)
        picture = Image.open(img_file).convert("RGB")
        eeg_item = self.eeg_transform(spectrogram)
        image_item = self.image_transform(picture)
        return eeg_item, image_item

# ==============================================================================
# --- 4. MAIN TRAINING SCRIPT ---
# ==============================================================================
def train_end_to_end(config):
    """Train the EEG encoder through a frozen tiny-sd VAE/UNet with MSE + SSIM.

    Per training step the real image is encoded to VAE latents, noised at a
    random timestep, and the frozen UNet predicts that noise conditioned on
    the EEG embedding. The loss is ``MSE_LOSS_WEIGHT`` times the
    noise-prediction MSE plus ``SSIM_LOSS_WEIGHT`` times ``1 - SSIM`` between
    the decoded one-step clean-latent estimate and the real image. Only the
    EEG encoder's parameters are given to the optimizer.

    After every epoch the same loss is averaged over a fixed validation
    subset and the encoder is checkpointed when that average improves. Every
    ``config.VISUALIZATION_INTERVAL`` epochs a 50-step sampling run on one
    fixed validation batch is saved as an image grid.

    Args:
        config: Object exposing OUTPUT_DIR, PROCESSED_DATA_ROOT, METADATA_CSV,
            EEG_ENCODER_PRETRAIN_PATH, BACKBONE_DIM, EEG_EMBED_DIM, BATCH_SIZE,
            NUM_EPOCHS, LR, MSE_LOSS_WEIGHT, SSIM_LOSS_WEIGHT,
            VISUALIZATION_INTERVAL, VAL_SAMPLES_PER_EPOCH and
            TRAIN_SAMPLES_PER_EPOCH.

    Returns:
        None. Files are written under ``OUTPUT_DIR/run_<timestamp>``:
        ``real_samples.png``, ``eeg_encoder_best.pth`` and
        ``reconstructions_epoch_<NNN>.png``.
    """
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    output_dir = Path(config.OUTPUT_DIR) / f'run_{timestamp}'
    output_dir.mkdir(parents=True, exist_ok=True)
    checkpoint_path = output_dir / 'eeg_encoder_best.pth'
    print(f"Using device: {device}. Checkpoints will be saved to {output_dir}")

    # --- Load Frozen Diffusion Models ---
    print("1. Loading frozen VAE and UNet...")
    vae = AutoencoderKL.from_pretrained("segmind/tiny-sd", subfolder="vae").to(device).eval()
    unet = UNet2DConditionModel.from_pretrained("segmind/tiny-sd", subfolder="unet").to(device).eval()
    for p in vae.parameters():
        p.requires_grad = False
    for p in unet.parameters():
        p.requires_grad = False

    # --- Load Trainable EEG Encoder ---
    print("2. Initializing EEG Encoder...")
    eeg_encoder = EEGEncoder(config.BACKBONE_DIM, config.EEG_EMBED_DIM).to(device)
    if config.EEG_ENCODER_PRETRAIN_PATH:
        # Warm start is best-effort: any failure falls back to the fresh init.
        try:
            print(f"Loading pre-trained weights from {config.EEG_ENCODER_PRETRAIN_PATH}")
            pretrain_ckpt = torch.load(config.EEG_ENCODER_PRETRAIN_PATH, map_location=device)
            eeg_encoder.load_state_dict(pretrain_ckpt['eeg_encoder_state_dict'])
            print("✅ Successfully loaded pre-trained EEG Encoder weights.")

        except Exception as e:
            print(f"⚠️ Could not load pre-trained weights: {e}. Starting from scratch.")

    # --- Data Prep ---
    print("3. Preparing data...")
    spec_mean = torch.load(Path(config.PROCESSED_DATA_ROOT) / 'spec_mean.pt')
    spec_std = torch.load(Path(config.PROCESSED_DATA_ROOT) / 'spec_std.pt')
    eeg_transform = transforms.Compose([
        transforms.ToDtype(torch.float32),
        transforms.Resize((224, 224)),
        transforms.Normalize(mean=spec_mean.tolist(), std=spec_std.tolist()),
    ])
    # Normalising with mean 0.5 / std 0.5 maps ToTensor's [0, 1] range to [-1, 1].
    image_transform = transforms.Compose([
        transforms.Resize((512, 512)),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ])

    train_dataset = LightweightEEGDataset(config.PROCESSED_DATA_ROOT, config.METADATA_CSV, 'train', eeg_transform, image_transform)
    val_dataset = LightweightEEGDataset(config.PROCESSED_DATA_ROOT, config.METADATA_CSV, 'val', eeg_transform, image_transform)
    # The validation subset is drawn once and reused for every epoch.
    val_indices = np.random.choice(len(val_dataset), min(len(val_dataset), config.VAL_SAMPLES_PER_EPOCH), replace=False)
    val_sampler = SubsetRandomSampler(val_indices)
    val_loader = DataLoader(val_dataset, batch_size=config.BATCH_SIZE, sampler=val_sampler, num_workers=0)
    # One validation batch is pinned so reconstructions are comparable across epochs.
    fixed_spectrograms, fixed_real_images = next(iter(val_loader))
    fixed_spectrograms, fixed_real_images = fixed_spectrograms.to(device), fixed_real_images.to(device)
    save_image(fixed_real_images * 0.5 + 0.5, output_dir / 'real_samples.png', nrow=config.BATCH_SIZE)

    # --- Optimizer, Loss, etc. ---
    print("4. Setting up optimizer and loss...")
    optimizer = optim.AdamW(eeg_encoder.parameters(), lr=config.LR, weight_decay=1e-3)
    scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=5, T_mult=1, eta_min=1e-7)
    scaler = torch.amp.GradScaler('cuda')
    noise_scheduler = DDPMScheduler.from_pretrained("segmind/tiny-sd", subfolder="scheduler")
    mse_loss_fn = nn.MSELoss()
    # data_range=2.0 corresponds to images normalised to [-1, 1].
    # NOTE(review): this is a torchmetrics Metric object, which may also keep
    # internal state across calls; only the per-batch return value is used
    # here and reset() is never called -- confirm that is acceptable.
    ssim_module = StructuralSimilarityIndexMeasure(data_range=2.0).to(device)
    best_val_loss = float('inf')

    print("\n--- Starting End-to-End Training ---")
    for epoch in range(config.NUM_EPOCHS):
        eeg_encoder.train()
        # A fresh random subset of the training split is used each epoch.
        train_indices = np.random.choice(len(train_dataset), min(len(train_dataset), config.TRAIN_SAMPLES_PER_EPOCH), replace=False)
        train_sampler = SubsetRandomSampler(train_indices)
        train_loader_subset = DataLoader(train_dataset, batch_size=config.BATCH_SIZE, sampler=train_sampler, num_workers=0, drop_last=True)

        train_pbar = tqdm(train_loader_subset, desc=f"Epoch {epoch+1} [Train]")

        for spectrograms, real_images in train_pbar:
            spectrograms, real_images = spectrograms.to(device), real_images.to(device)
            optimizer.zero_grad()

            # The target latents come from the frozen VAE, so no graph is needed.
            with torch.no_grad():
                latents = vae.encode(real_images).latent_dist.sample() * vae.config.scaling_factor
            noise = torch.randn_like(latents)
            timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (latents.shape[0],), device=device).long()
            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

            with torch.amp.autocast('cuda'):
                eeg_embeddings = eeg_encoder(spectrograms)
                # A length-1 sequence axis is added for encoder_hidden_states.
                eeg_cond = eeg_embeddings.unsqueeze(1)

                # The UNet and VAE decoder calls must stay outside no_grad:
                # both losses backpropagate through these frozen networks
                # into the EEG encoder.
                noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states=eeg_cond).sample

                # One-step estimate of the clean latents from the predicted noise.
                alpha_prod_t = noise_scheduler.alphas_cumprod.to(device)[timesteps]
                sqrt_alpha_prod_t = alpha_prod_t.sqrt().view(-1, 1, 1, 1)
                sqrt_one_minus_alpha_prod_t = (1 - alpha_prod_t).sqrt().view(-1, 1, 1, 1)
                pred_latents = (noisy_latents - sqrt_one_minus_alpha_prod_t * noise_pred) / sqrt_alpha_prod_t

                generated_images = vae.decode(pred_latents / vae.config.scaling_factor, return_dict=False)[0]

                loss_mse = mse_loss_fn(noise_pred, noise)
                ssim_val = ssim_module(generated_images, real_images)
                loss_ssim = 1.0 - ssim_val
                loss = (loss_mse * config.MSE_LOSS_WEIGHT) + (loss_ssim * config.SSIM_LOSS_WEIGHT)

            train_pbar.set_postfix({'loss_mse': f'{loss_mse:.3f}', 'loss_ssim': f'{loss_ssim:.3f}', 'tot_loss': f'{loss:.3f}'})

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

        # --- Validation ---
        # NOTE(review): unlike the training step, validation is not wrapped in
        # autocast -- confirm this is intentional.
        eeg_encoder.eval()
        total_val_loss = 0
        with torch.no_grad():

            val_pbar = tqdm(val_loader, desc=f"Epoch {epoch+1} [Val]")

            for spectrograms, real_images in val_pbar:
                spectrograms, real_images = spectrograms.to(device), real_images.to(device)
                latents = vae.encode(real_images).latent_dist.sample() * vae.config.scaling_factor
                noise = torch.randn_like(latents)
                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (latents.shape[0],), device=device).long()
                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

                eeg_embeddings = eeg_encoder(spectrograms)
                eeg_cond = eeg_embeddings.unsqueeze(1)

                noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states=eeg_cond).sample

                alpha_prod_t = noise_scheduler.alphas_cumprod.to(device)[timesteps]
                sqrt_alpha_prod_t = alpha_prod_t.sqrt().view(-1, 1, 1, 1)
                sqrt_one_minus_alpha_prod_t = (1 - alpha_prod_t).sqrt().view(-1, 1, 1, 1)
                pred_latents = (noisy_latents - sqrt_one_minus_alpha_prod_t * noise_pred) / sqrt_alpha_prod_t

                generated_images = vae.decode(pred_latents / vae.config.scaling_factor, return_dict=False)[0]

                loss_mse = mse_loss_fn(noise_pred, noise)
                ssim_val = ssim_module(generated_images, real_images)
                loss_ssim = 1.0 - ssim_val
                loss = (loss_mse * config.MSE_LOSS_WEIGHT) + (loss_ssim * config.SSIM_LOSS_WEIGHT)
                total_val_loss += loss.item()

                val_pbar.set_postfix({'loss_mse': f'{loss_mse:.3f}', 'loss_ssim': f'{loss_ssim:.3f}', 'tot_loss': f'{loss:.3f}'})

        avg_val_loss = total_val_loss / len(val_loader) if len(val_loader) > 0 else 0
        scheduler.step()  # Step the scheduler each epoch
        print(f"Epoch {epoch+1}/{config.NUM_EPOCHS} -> Val Loss: {avg_val_loss:.4f}, LR: {scheduler.get_last_lr()[0]:.2e}")

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            print(f"✨ New best validation loss. Saving EEG Encoder to {checkpoint_path}")
            torch.save({'eeg_encoder_state_dict': eeg_encoder.state_dict()}, checkpoint_path)

        # --- Visualization ---
        if (epoch + 1) % config.VISUALIZATION_INTERVAL == 0:
            print(f"--- Generating reconstructions for epoch {epoch+1} ---")
            eeg_encoder.eval()
            with torch.no_grad():
                eeg_embeddings = eeg_encoder(fixed_spectrograms)
                eeg_cond = eeg_embeddings.unsqueeze(1)

                # Half precision is only usable where CUDA autocast is active;
                # on the CPU fallback the float32 UNet needs float32 latents.
                sample_dtype = torch.float16 if device.type == 'cuda' else torch.float32
                latents = torch.randn((fixed_spectrograms.shape[0], 4, 64, 64), device=device, dtype=sample_dtype)
                noise_scheduler.set_timesteps(50)
                print('begin image gen')

                for t in tqdm(noise_scheduler.timesteps, desc="Generating Images"):
                    with torch.amp.autocast('cuda'):
                        noise_pred = unet(latents, t, encoder_hidden_states=eeg_cond.to(latents.dtype)).sample
                    latents = noise_scheduler.step(noise_pred, t, latents).prev_sample

                print('pass thru vae for img gen')
                with torch.amp.autocast('cuda'):
                    generated_images = vae.decode(latents / vae.config.scaling_factor, return_dict=False)[0]
                print('create img grid')
                # Real images first, reconstructions after, mapped back to [0, 1].
                comp_grid = torch.cat([fixed_real_images, generated_images.to(torch.float32)])
                comp_grid = (comp_grid * 0.5 + 0.5).clamp(0, 1)
                save_path = output_dir / f'reconstructions_epoch_{epoch+1:03d}.png'
                save_image(comp_grid, save_path, nrow=config.BATCH_SIZE)
                print(f"Saved reconstructions to {save_path}")

    print("\n--- Training Complete ---")

# Start a run with the cell's default configuration when executed directly.
if __name__ == "__main__":
    train_end_to_end(config=TRAIN_CONFIG())


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import timm
import numpy as np
import pandas as pd
from pathlib import Path
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
from torchvision.transforms import v2 as transforms
from PIL import Image
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from tqdm.auto import tqdm
from datetime import datetime
from diffusers import AutoencoderKL, UNet2DConditionModel, DDPMScheduler
from torch.utils.tensorboard import SummaryWriter
from torchvision.utils import save_image
import torchvision.models as models
import clip

# ==============================================================================
# --- 1. CONFIGURATION ---
# ==============================================================================
class TRAIN_CONFIG:
    """Settings for the end-to-end run that adds a CLIP-alignment loss."""

    # Dataset location and the folder that receives one sub-folder per run.
    PROCESSED_DATA_ROOT = '/content/final_lightweight_17k'
    METADATA_CSV = Path(PROCESSED_DATA_ROOT, 'metadata.csv')
    OUTPUT_DIR = '/content/drive/MyDrive/NeuroVision/end_to_end_clip'

    # Optional: start with a previously trained encoder (None = no warm start).
    # Checkpoint paths that were used here before:
    #   '/content/drive/MyDrive/NeuroVision/contrastive_scratch_outputs/run_20250917_221600/contrastive_checkpoint_best.pth'
    #   '/content/drive/MyDrive/NeuroVision/end_to_end_clip/run_20250923_101956/eeg_encoder_best.pth'
    EEG_ENCODER_PRETRAIN_PATH = None

    # Encoder widths: backbone features, UNet-conditioning head, CLIP head.
    BACKBONE_DIM = 192
    EEG_EMBED_DIM = 768
    CLIP_EMBED_DIM = 512

    # Optimisation schedule; the batch size was lowered because this setup
    # uses more memory.
    BATCH_SIZE = 2
    NUM_EPOCHS = 100
    LR = 1e-5

    # Weights of the individual loss terms; the CLIP term is the new
    # semantic loss of this cell.
    MSE_LOSS_WEIGHT = 2.5
    PERCEPTUAL_LOSS_WEIGHT = 0.75
    CLIP_LOSS_WEIGHT = 0.2

    # Visualisation cadence and per-epoch sample budgets.
    VISUALIZATION_INTERVAL = 1
    TRAIN_SAMPLES_PER_EPOCH = 2000
    VAL_SAMPLES_PER_EPOCH = 400

# ==============================================================================
# --- 2. CLIP LOSS MODULE ---
# ==============================================================================
class CLIPLoss(nn.Module):
    """Cosine-distance loss between an embedding and frozen CLIP image features.

    Args:
        device: Device handed to ``clip.load`` for the ViT-B/32 model.
    """

    def __init__(self, device):
        super().__init__()
        self.model, self.preprocess = clip.load("ViT-B/32", device=device)
        self.model.eval()
        # CLIP only provides the target features, so its weights stay frozen.
        for weight in self.model.parameters():
            weight.requires_grad = False

        # Resize / crop / normalise pipeline applied in forward(). The
        # ``preprocess`` object returned by clip.load is stored but is not
        # called anywhere in this class.
        resize = transforms.Resize(224, interpolation=transforms.InterpolationMode.BICUBIC, antialias=True)
        crop = transforms.CenterCrop(224)
        normalize = transforms.Normalize(
            mean=(0.48145466, 0.4578275, 0.40821073),
            std=(0.26862954, 0.26130258, 0.27577711),
        )
        self.clip_transform = transforms.Compose([resize, crop, normalize])

    def forward(self, eeg_embedding, image):
        """Return the batch mean of ``1 - cosine_similarity``.

        Args:
            eeg_embedding: Embedding batch compared with the CLIP image
                features along the last dimension.
            image: Image batch; the rescaling below assumes values in [-1, 1].
        """
        # Shift from [-1, 1] to [0, 1] before the CLIP normalisation step.
        unit_range_image = (image + 1) / 2.0
        clip_input = self.clip_transform(unit_range_image)
        clip_features = self.model.encode_image(clip_input)

        # Both sides are L2-normalised before the cosine similarity is taken.
        eeg_unit = F.normalize(eeg_embedding, p=2, dim=-1)
        clip_unit = F.normalize(clip_features, p=2, dim=-1)

        distance = 1 - F.cosine_similarity(eeg_unit, clip_unit)
        return distance.mean()

# ==============================================================================
# --- 3. PERCEPTUAL LOSS (ResNet) ---
# ==============================================================================
class ResNetPerceptualLoss(nn.Module):
    """MSE between frozen ResNet-18 features of two image batches.

    Args:
        resize: When true, both batches are bilinearly resized to 224x224
            before feature extraction.
    """

    def __init__(self, resize=True):
        super().__init__()
        pretrained = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
        # Keep every child module except the last one as the feature extractor.
        trunk = list(pretrained.children())[:-1]
        self.model = nn.Sequential(*trunk).eval()
        for weight in self.model.parameters():
            weight.requires_grad = False
        self.transform = nn.functional.interpolate
        self.resize = resize
        # Registered as buffers so they move with the module on .to(device).
        channel_mean = torch.tensor([0.485, 0.456, 0.406])
        channel_std = torch.tensor([0.229, 0.224, 0.225])
        self.register_buffer("mean", channel_mean.view(1, 3, 1, 1))
        self.register_buffer("std", channel_std.view(1, 3, 1, 1))
        self.loss_fn = nn.MSELoss()

    def _prepare(self, batch):
        """Rescale with (x + 1) / 2, standardise per channel, optionally resize."""
        batch = (batch + 1) / 2.0
        batch = (batch - self.mean) / self.std
        if self.resize:
            batch = self.transform(batch, mode='bilinear', size=(224, 224), align_corners=False)
        return batch

    def forward(self, input, target):
        """Return the MSE between the extracted features of ``input`` and ``target``."""
        input_features = self.model(self._prepare(input))
        target_features = self.model(self._prepare(target))
        return self.loss_fn(input_features, target_features)

# ==============================================================================
# --- 4. EEG ENCODER (with two heads) ---
# ==============================================================================
class EEGEncoder(nn.Module):
    """ViT-tiny spectrogram encoder with two projection heads on shared features.

    Args:
        backbone_dim: Input width of both heads; it has to match the last
            dimension of the features returned by the backbone.
        unet_dim: Output width of the head that conditions the UNet.
        clip_dim: Output width of the head meant to match CLIP's space.
    """

    def __init__(self, backbone_dim, unet_dim, clip_dim):
        super().__init__()
        self.backbone = timm.create_model(
            'vit_tiny_patch16_224',
            pretrained=False,
            in_chans=64,
            num_classes=0,
        )

        # Head A: embedding used to guide the UNet.
        unet_layers = [
            nn.Linear(backbone_dim, unet_dim),
            nn.GELU(),
            nn.LayerNorm(unet_dim),
            nn.Linear(unet_dim, unet_dim),
        ]
        self.unet_projector = nn.Sequential(*unet_layers)

        # Head B: embedding for the CLIP space, via a hidden layer twice as
        # wide as the output.
        clip_hidden = clip_dim * 2
        clip_layers = [
            nn.Linear(backbone_dim, clip_hidden),
            nn.ReLU(),
            nn.Linear(clip_hidden, clip_dim),
        ]
        self.clip_projector = nn.Sequential(*clip_layers)

    def forward(self, x):
        """Return ``(unet_embedding, clip_embedding)`` for the batch ``x``."""
        shared = self.backbone.forward_features(x)[:, 0]
        return self.unet_projector(shared), self.clip_projector(shared)

# ==============================================================================
# --- 5. DATASET ---
# ==============================================================================
class LightweightEEGDataset(Dataset):
    """Serve transformed (spectrogram, image) pairs for one metadata split.

    Args:
        root_dir: Directory prefix for the relative paths in the CSV.
        metadata_csv: CSV containing ``split``, ``spectrogram_path`` and
            ``image_path`` columns.
        split: Split label to select; CSV labels are stripped of surrounding
            whitespace before the comparison.
        eeg_transform: Callable applied to the loaded spectrogram.
        image_transform: Callable applied to the RGB PIL image.
    """

    def __init__(self, root_dir, metadata_csv, split, eeg_transform, image_transform):
        self.root_dir = Path(root_dir)
        self.eeg_transform = eeg_transform
        self.image_transform = image_transform

        frame = pd.read_csv(metadata_csv)
        keep_rows = frame['split'].str.strip().eq(split)
        # Dropping the old index lets iloc positions run from 0 to len - 1.
        self.split_df = frame[keep_rows].reset_index(drop=True)

    def __len__(self):
        """Return the size of the selected split."""
        return len(self.split_df)

    def __getitem__(self, idx):
        """Load the ``idx``-th pair from disk and apply both transforms."""
        entry = self.split_df.iloc[idx]
        spec_location = self.root_dir / entry['spectrogram_path']
        image_location = self.root_dir / entry['image_path']
        raw_spec = torch.load(spec_location)
        raw_image = Image.open(image_location).convert("RGB")
        return self.eeg_transform(raw_spec), self.image_transform(raw_image)

# ==============================================================================
# --- 6. MAIN TRAINING SCRIPT ---
# ==============================================================================
def train_end_to_end_clip(config):
    """Train the EEG encoder end-to-end against frozen tiny-sd VAE/UNet weights.

    Each epoch draws a fresh random subset of the training split, optimises
    the encoder on a weighted sum of noise-prediction MSE, perceptual and CLIP
    losses, evaluates the same objective on a validation subset chosen once up
    front, and saves the encoder whenever the mean validation loss improves.
    Every ``config.VISUALIZATION_INTERVAL`` epochs a grid of real vs. generated
    images is written for one fixed validation batch.

    Args:
        config: Settings object. Attributes read here: ``OUTPUT_DIR``,
            ``BACKBONE_DIM``, ``EEG_EMBED_DIM``, ``CLIP_EMBED_DIM``,
            ``EEG_ENCODER_PRETRAIN_PATH``, ``PROCESSED_DATA_ROOT``,
            ``METADATA_CSV``, ``BATCH_SIZE``, ``LR``, ``NUM_EPOCHS``,
            ``TRAIN_SAMPLES_PER_EPOCH``, ``VAL_SAMPLES_PER_EPOCH``,
            ``MSE_LOSS_WEIGHT``, ``PERCEPTUAL_LOSS_WEIGHT``,
            ``CLIP_LOSS_WEIGHT`` and ``VISUALIZATION_INTERVAL``.

    Returns:
        None. Artifacts are written under ``OUTPUT_DIR/run_<timestamp>``:
        ``eeg_encoder_best.pth``, ``real_samples.png``,
        ``reconstructions_epoch_XXX.png`` and TensorBoard event files in
        ``logs``.

    NOTE(review): autocast, GradScaler and the float16 sampling latents are
    hard-coded for CUDA; running on CPU is presumably unsupported -- confirm.
    """
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    output_dir = Path(config.OUTPUT_DIR) / f'run_{timestamp}'
    output_dir.mkdir(parents=True, exist_ok=True)
    checkpoint_path = output_dir / 'eeg_encoder_best.pth'
    writer = SummaryWriter(log_dir=str(output_dir / 'logs'))
    print(f"Using device: {device}. Checkpoints will be saved to {output_dir}")

    print("1. Loading models...")
    vae = AutoencoderKL.from_pretrained("segmind/tiny-sd", subfolder="vae").to(device).eval()
    unet = UNet2DConditionModel.from_pretrained("segmind/tiny-sd", subfolder="unet").to(device).eval()
    # The diffusion components stay frozen; only the EEG encoder is optimised.
    for p in vae.parameters():
        p.requires_grad = False
    for p in unet.parameters():
        p.requires_grad = False

    eeg_encoder = EEGEncoder(config.BACKBONE_DIM, config.EEG_EMBED_DIM, config.CLIP_EMBED_DIM).to(device)
    if config.EEG_ENCODER_PRETRAIN_PATH:
        # Best-effort warm start: any failure falls back to random initialisation.
        try:
            print(f"Loading pre-trained weights from {config.EEG_ENCODER_PRETRAIN_PATH}")
            pretrain_ckpt = torch.load(config.EEG_ENCODER_PRETRAIN_PATH, map_location=device)
            eeg_encoder.backbone.load_state_dict(pretrain_ckpt['eeg_encoder_state_dict'])
            print("✅ Successfully loaded pre-trained EEG Encoder weights.")
        except Exception as e:
            print(f"⚠️ Could not load pre-trained weights: {e}. Starting from scratch.")

    print("2. Preparing data...")
    spec_mean = torch.load(Path(config.PROCESSED_DATA_ROOT) / 'spec_mean.pt')
    spec_std = torch.load(Path(config.PROCESSED_DATA_ROOT) / 'spec_std.pt')

    eeg_transform = transforms.Compose([
        transforms.ToDtype(torch.float32),
        transforms.Resize((224, 224)),
        transforms.Normalize(mean=spec_mean.tolist(), std=spec_std.tolist()),
    ])
    image_transform = transforms.Compose([
        transforms.Resize((512, 512)),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ])

    train_dataset = LightweightEEGDataset(config.PROCESSED_DATA_ROOT, config.METADATA_CSV, 'train', eeg_transform, image_transform)
    val_dataset = LightweightEEGDataset(config.PROCESSED_DATA_ROOT, config.METADATA_CSV, 'val', eeg_transform, image_transform)

    # The validation subset is drawn once so every epoch is scored on the same samples.
    val_indices = np.random.choice(len(val_dataset), min(len(val_dataset), config.VAL_SAMPLES_PER_EPOCH), replace=False)
    val_sampler = SubsetRandomSampler(val_indices)
    val_loader = DataLoader(val_dataset, batch_size=config.BATCH_SIZE, sampler=val_sampler, num_workers=0)

    # One fixed batch is reused for every reconstruction grid.
    fixed_spectrograms, fixed_real_images = next(iter(val_loader))
    fixed_spectrograms = fixed_spectrograms.to(device)
    fixed_real_images = fixed_real_images.to(device)
    save_image(fixed_real_images*0.5+0.5, output_dir/'real_samples.png', nrow=config.BATCH_SIZE)

    print("3. Setting up optimizer and loss...")
    optimizer = optim.AdamW(eeg_encoder.parameters(), lr=config.LR, weight_decay=1e-5)
    scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=1, eta_min=1e-7)
    scaler = torch.amp.GradScaler('cuda')
    noise_scheduler = DDPMScheduler.from_pretrained("segmind/tiny-sd", subfolder="scheduler")

    mse_loss_fn = nn.MSELoss()
    perceptual_loss_fn = ResNetPerceptualLoss().to(device)
    # The CLIP loss is used in both loops below, so it must be instantiated here.
    # CLIPLoss is expected to be defined earlier in this file.
    clip_loss_fn = CLIPLoss(device)

    # Loop-invariant: move the cumulative alpha table to the device once.
    alphas_cumprod = noise_scheduler.alphas_cumprod.to(device)

    def compute_losses(spectrograms, real_images):
        """Return (total, mse, perceptual, clip) losses for one batch.

        Shared by the training and validation loops so both score the same
        objective. Gradients reach the EEG encoder only through the UNet
        conditioning and the CLIP head.
        """
        with torch.no_grad():
            latents = vae.encode(real_images).latent_dist.sample() * vae.config.scaling_factor
        noise = torch.randn_like(latents)
        timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (latents.shape[0],), device=device).long()
        noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

        with torch.amp.autocast('cuda'):
            unet_embedding, clip_embedding = eeg_encoder(spectrograms)
            eeg_cond = unet_embedding.unsqueeze(1)

            noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states=eeg_cond).sample

            # Invert the forward-noising formula to estimate the clean latents,
            # then decode them so the perceptual loss can compare images.
            alpha_prod_t = alphas_cumprod[timesteps]
            sqrt_alpha_prod_t = alpha_prod_t.sqrt().view(-1, 1, 1, 1)
            sqrt_one_minus_alpha_prod_t = (1 - alpha_prod_t).sqrt().view(-1, 1, 1, 1)
            pred_latents = (noisy_latents - sqrt_one_minus_alpha_prod_t * noise_pred) / sqrt_alpha_prod_t
            generated_images = vae.decode(pred_latents / vae.config.scaling_factor, return_dict=False)[0]

            loss_mse = mse_loss_fn(noise_pred, noise)
            loss_perceptual = perceptual_loss_fn(generated_images, real_images)
            loss_clip = clip_loss_fn(clip_embedding, real_images)

            loss = (loss_mse * config.MSE_LOSS_WEIGHT) + (loss_perceptual * config.PERCEPTUAL_LOSS_WEIGHT) + (loss_clip * config.CLIP_LOSS_WEIGHT)

        return loss, loss_mse, loss_perceptual, loss_clip

    def save_reconstructions(epoch_number):
        """Sample images for the fixed batch and save them below the real ones."""
        print(f"--- Generating reconstructions for epoch {epoch_number} ---")
        eeg_encoder.eval()
        with torch.no_grad():
            # For visualization, we only need the UNet embedding
            unet_embedding, _ = eeg_encoder(fixed_spectrograms)
            eeg_cond = unet_embedding.unsqueeze(1)

            latents = torch.randn((fixed_spectrograms.shape[0], 4, 64, 64), device=device, dtype=torch.float16)
            noise_scheduler.set_timesteps(50)

            for t in tqdm(noise_scheduler.timesteps, desc="Generating Images"):
                with torch.amp.autocast('cuda'):
                    noise_pred = unet(latents, t, encoder_hidden_states=eeg_cond.to(latents.dtype)).sample
                latents = noise_scheduler.step(noise_pred, t, latents).prev_sample

            # Decode once, after denoising has finished: only the final
            # latents are ever shown, so decoding per step is wasted work.
            with torch.amp.autocast('cuda'):
                generated_images = vae.decode(latents / vae.config.scaling_factor, return_dict=False)[0]

            comp_grid = torch.cat([fixed_real_images, generated_images.to(torch.float32)])
            comp_grid = (comp_grid * 0.5 + 0.5).clamp(0, 1)
            save_path = output_dir / f'reconstructions_epoch_{epoch_number:03d}.png'
            save_image(comp_grid, save_path, nrow=config.BATCH_SIZE)
            print(f"Saved reconstructions to {save_path}")

    best_val_loss = float('inf')
    print("\n--- Starting End-to-End Training with CLIP Guidance ---")

    try:
        for epoch in range(config.NUM_EPOCHS):
            eeg_encoder.train()
            # A new random training subset is sampled every epoch.
            train_indices = np.random.choice(len(train_dataset), min(len(train_dataset), config.TRAIN_SAMPLES_PER_EPOCH), replace=False)
            train_sampler = SubsetRandomSampler(train_indices)
            train_loader_subset = DataLoader(train_dataset, batch_size=config.BATCH_SIZE, sampler=train_sampler, num_workers=0, drop_last=True)

            train_bar = tqdm(train_loader_subset, desc=f"Epoch {epoch+1} [Train]")

            total_train_loss = 0.0
            train_steps = 0
            for spectrograms, real_images in train_bar:
                spectrograms, real_images = spectrograms.to(device), real_images.to(device)
                optimizer.zero_grad()

                loss, loss_mse, loss_perceptual, loss_clip = compute_losses(spectrograms, real_images)

                train_bar.set_postfix({'loss_mse': f'{loss_mse:.3f}', 'loss_percep': f'{loss_perceptual:.3f}', 'loss_clip': f'{loss_clip:.3f}', 'loss_total': f'{loss:.3f}'})

                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()

                total_train_loss += loss.item()
                train_steps += 1

            # --- Validation Loop ---
            eeg_encoder.eval()
            total_val_loss = 0
            with torch.no_grad():
                val_bar = tqdm(val_loader, desc=f"Epoch {epoch+1} [Val]")

                for spectrograms, real_images in val_bar:
                    spectrograms, real_images = spectrograms.to(device), real_images.to(device)

                    loss, loss_mse, loss_perceptual, loss_clip = compute_losses(spectrograms, real_images)
                    total_val_loss += loss.item()

                    val_bar.set_postfix({'loss_mse': f'{loss_mse:.3f}', 'loss_percep': f'{loss_perceptual:.3f}', 'loss_clip': f'{loss_clip:.3f}', 'loss_total': f'{loss:.3f}'})

            avg_val_loss = total_val_loss / len(val_loader) if len(val_loader) > 0 else 0
            scheduler.step()
            current_lr = scheduler.get_last_lr()[0]
            print(f"Epoch {epoch+1}/{config.NUM_EPOCHS} -> Val Loss: {avg_val_loss:.4f}, LR: {current_lr:.2e}")

            # Record the epoch metrics so the TensorBoard log directory is useful.
            if train_steps > 0:
                writer.add_scalar('Loss/train', total_train_loss / train_steps, epoch + 1)
            writer.add_scalar('Loss/val', avg_val_loss, epoch + 1)
            writer.add_scalar('LR', current_lr, epoch + 1)

            if avg_val_loss < best_val_loss:
                best_val_loss = avg_val_loss
                print(f"✨ New best validation loss. Saving EEG Encoder to {checkpoint_path}")
                torch.save({'eeg_encoder_state_dict': eeg_encoder.state_dict()}, checkpoint_path)

            # --- Visualization ---
            if (epoch + 1) % config.VISUALIZATION_INTERVAL == 0:
                save_reconstructions(epoch + 1)
    finally:
        # Flush and release the event file even if training is interrupted.
        writer.close()

    print("\n--- Training Complete ---")

if __name__ == "__main__":
    # Script entry point: build the training configuration and start the run.
    run_config = TRAIN_CONFIG()
    train_end_to_end_clip(run_config)
