<a href="https://colab.research.google.com/github/Ajinkya-18/NeuroVision/blob/main/eeg_text_image.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Setup

In [1]:
import os
os.makedirs('/content/final_lightweight_17k', exist_ok=True)

In [2]:
!cp /content/drive/MyDrive/NeuroVision/alljoined_lightweight_17k.tar.gz /content/

In [3]:
!tar -xzf /content/alljoined_lightweight_17k.tar.gz -C /content/final_lightweight_17k

In [4]:
import torch
from pathlib import Path
from tqdm.auto import tqdm

# Path to your lightweight dataset on the local Colab disk
DATA_ROOT = '/content/final_lightweight_17k'
SPECTROGRAM_TRAIN_DIR = Path(DATA_ROOT) / 'spectrograms' / 'train'

# Initialize accumulators
channel_sum = torch.zeros(64, dtype=torch.float64)
channel_sum_sq = torch.zeros(64, dtype=torch.float64)
pixel_count = 0

# Calculate stats only on the training set
files = list(SPECTROGRAM_TRAIN_DIR.glob('*.pt'))
for path in tqdm(files, desc="Calculating Spectrogram Stats"):
    data = torch.load(path)
    channel_sum += data.sum(dim=[1, 2]).to(torch.float64)
    channel_sum_sq += (data.to(torch.float64) ** 2).sum(dim=[1, 2])
    pixel_count += data.shape[1] * data.shape[2]

mean = (channel_sum / pixel_count).to(torch.float32)
std = torch.sqrt((channel_sum_sq / pixel_count) - mean.to(torch.float64)**2).to(torch.float32)

# Save the stats to the dataset folder
torch.save(mean, Path(DATA_ROOT) / 'spec_mean.pt')
torch.save(std, Path(DATA_ROOT) / 'spec_std.pt')

print(f"\nStats calculated and saved to {DATA_ROOT}")

Calculating Spectrogram Stats:   0%|          | 0/14000 [00:00<?, ?it/s]


Stats calculated and saved to /content/final_lightweight_17k


In [5]:
print("⏳ Installing and upgrading libraries...")
# The -q flag makes the output cleaner
!pip install torch torchvision timm lpips transformers accelerate diffusers open-clip-torch
!pip install peft ftfy regex tqdm
!pip install git+https://github.com/openai/CLIP.git

⏳ Installing and upgrading libraries...
Collecting lpips
  Downloading lpips-0.1.4-py3-none-any.whl.metadata (10 kB)
Collecting open-clip-torch
  Downloading open_clip_torch-3.2.0-py3-none-any.whl.metadata (32 kB)
Collecting ftfy (from open-clip-torch)
  Downloading ftfy-6.3.1-py3-none-any.whl.metadata (7.3 kB)
Downloading lpips-0.1.4-py3-none-any.whl (53 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m53.8/53.8 kB[0m [31m4.9 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading open_clip_torch-3.2.0-py3-none-any.whl (1.5 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.5/1.5 MB[0m [31m40.8 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading ftfy-6.3.1-py3-none-any.whl (44 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.8/44.8 kB[0m [31m4.6 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: ftfy, lpips, open-clip-torch
Successfully installed ftfy-6.3.1 lpips-0.1.4 open-clip-torch-3.2.0
Collecting git+https://github.c

In [6]:
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

# Training

In [None]:
import torch
import pandas as pd
from pathlib import Path
from PIL import Image
from tqdm.auto import tqdm
import json
from transformers import Blip2Processor, Blip2ForConditionalGeneration

# --- CONFIGURATION ---
class CONFIG:
    PROCESSED_DATA_ROOT = '/content/final_lightweight_17k'
    METADATA_CSV = Path(PROCESSED_DATA_ROOT) / 'metadata.csv'
    BLIP_MODEL_NAME = "Salesforce/blip2-opt-2.7b"

def run_caption_generation():
    """A self-contained script to run only once."""
    config = CONFIG()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    print("Loading BLIP model for caption generation...")
    processor = Blip2Processor.from_pretrained(config.BLIP_MODEL_NAME)
    caption_blip = Blip2ForConditionalGeneration.from_pretrained(config.BLIP_MODEL_NAME).to(device)

    for split in ['train', 'val']:
        captions_path = Path(config.PROCESSED_DATA_ROOT) / f'{split}_gt_captions.json'
        if captions_path.exists():
            print(f"Captions for '{split}' split already exist. Skipping.")
            continue

        print(f"Generating captions for '{split}' split...")
        df = pd.read_csv(config.METADATA_CSV)
        split_df = df[df['split'].str.strip() == split].reset_index(drop=True)

        captions = {}
        batch_size = 16
        caption_blip.eval()

        with torch.no_grad():
            for i in tqdm(range(0, len(split_df), batch_size), desc=f"Generating {split} captions"):
                batch_info = split_df.iloc[i:i+batch_size]
                batch_images = [
                    Image.open(Path(config.PROCESSED_DATA_ROOT) / info['image_path']).convert("RGB")
                    for _, info in batch_info.iterrows()
                ]

                inputs = processor(images=batch_images, return_tensors="pt").to(device)
                generated_ids = caption_blip.generate(**inputs, max_length=50)
                batch_captions = processor.batch_decode(generated_ids, skip_special_tokens=True)

                for j, caption in enumerate(batch_captions):
                    captions[str(i + j)] = caption.strip()

        with open(captions_path, 'w') as f:
            json.dump(captions, f, indent=2)
        print(f"✅ Saved {len(captions)} captions for '{split}' split.")

if __name__ == '__main__':
    run_caption_generation()

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


Loading BLIP model for caption generation...


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

preprocessor_config.json:   0%|          | 0.00/432 [00:00<?, ?B/s]

processor_config.json:   0%|          | 0.00/68.0 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/882 [00:00<?, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

added_tokens.json:   0%|          | 0.00/23.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/548 [00:00<?, ?B/s]

config.json: 0.00B [00:00, ?B/s]

model.safetensors.index.json: 0.00B [00:00, ?B/s]

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/10.0G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/4.98G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/141 [00:00<?, ?B/s]

Generating captions for 'train' split...


Generating train captions:   0%|          | 0/875 [00:00<?, ?it/s]

✅ Saved 14000 captions for 'train' split.
Generating captions for 'val' split...


Generating val captions:   0%|          | 0/175 [00:00<?, ?it/s]

✅ Saved 2800 captions for 'val' split.


In [None]:
!apt-get install -y fonts-liberation

Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
fonts-liberation is already the newest version (1:1.07.4-11).
0 upgraded, 0 newly installed, 0 to remove and 38 not upgraded.


In [None]:
import torch
import torch.nn as nn
import pandas as pd
from pathlib import Path
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
from torchvision.transforms import v2 as transforms
from PIL import Image
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm.auto import tqdm
from datetime import datetime
import numpy as np
from transformers import Blip2Processor, Blip2ForConditionalGeneration, Blip2VisionModel
import warnings
from sentence_transformers import SentenceTransformer
from diffusers import StableDiffusionPipeline
import torch.nn.functional as F
from PIL import Image, ImageDraw, ImageFont
import json
from transformers import get_linear_schedule_with_warmup

# ==============================================================================
# --- 1. CONFIGURATION ---
# ==============================================================================
class TRAIN_CONFIG:
    PROCESSED_DATA_ROOT = '/content/final_lightweight_17k'
    METADATA_CSV = Path(PROCESSED_DATA_ROOT) / 'metadata.csv'
    OUTPUT_DIR = '/content/drive/MyDrive/NeuroVision/EEG_BLIP2_Alignment'

    BLIP_MODEL_NAME = "Salesforce/blip2-opt-2.7b"

    BATCH_SIZE = 14
    GRAD_ACCUMULATION_STEPS = 2
    NUM_EPOCHS = 30
    ADAPTER_LR = 1e-5  # Higher LR for the new adapter
    BLIP_LR = 1e-6     # Lower LR for fine-tuning the pre-trained layers

    WEIGHT_DECAY = 1e-3
    GRADIENT_CLIP_NORM = 1.0

    VALIDATION_INTERVAL = 1
    VAL_SAMPLES_PER_EPOCH = 400
    TRAIN_SAMPLES_PER_EPOCH = 2000
    VIS_GRID_SIZE = 5

    DIFFUSION_MODEL_ID ="runwayml/stable-diffusion-v1-5"
    EVAL_IMAGE_GENERATION_INTERVAL = 1
    WARMUP_STEPS = 500  # Number of steps to gradually increase the LR
    TOTAL_TRAIN_STEPS = (TRAIN_SAMPLES_PER_EPOCH/BATCH_SIZE) * NUM_EPOCHS
    STAGE1_EPOCHS = 12
    # STAGE2_EPOCHS = 18
    ALIGN_WEIGHT = 1.0
    SEMANTIC_WEIGHT = 2.0
    CE_WEIGHT = 3.0
    BEST_MODEL_DIR = '/content/drive/MyDrive/NeuroVision/EEG_BLIP2_Alignment/run_20251001_103040/best_model_epoch_27_loss_5.0528.pth'


# ==============================================================================
# --- 2. EEG ENCODER MODULE ---
# ==============================================================================

class EEGTransformerEncoder(nn.Module):
    """
    A powerful hybrid CNN-Transformer encoder to extract rich features
    directly from 64-channel EEG spectrograms.
    """
    def __init__(self, in_chans=64, embed_dim=768, nhead=8, num_layers=4, dropout=0.3):
        super().__init__()

        # 1. Convolutional Stem: Extracts local spatio-temporal features
        self.conv_stem = nn.Sequential(
            nn.Conv2d(in_chans, embed_dim // 4, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(embed_dim // 4),
            nn.ELU(),
            nn.Conv2d(embed_dim // 4, embed_dim // 2, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(embed_dim // 2),
            nn.ELU(),
            nn.Conv2d(embed_dim // 2, embed_dim, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(embed_dim),
        )

        # 2. Transformer Encoder: Learns global relationships between features
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=nhead,
            dim_feedforward=embed_dim * 4,
            dropout=dropout,
            activation='gelu',
            batch_first=True
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        self.output_layernorm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        # Input x: [Batch, 64, Freq, Time]

        # Pass through convolutional stem
        x = self.conv_stem(x)
        # Output x: [Batch, embed_dim, F/4, T/4]

        # Prepare for transformer: flatten spatial dims and permute
        b, c, h, w = x.shape
        x = x.flatten(2).permute(0, 2, 1) # -> [Batch, H*W, embed_dim]

        # Pass through transformer encoder
        x = self.transformer_encoder(x)

        # Apply final layer normalization
        x = self.output_layernorm(x)

        return x

# ==============================================================================
# --- 3. THE HYBRID EEG-BLIP2 MODEL ---
# ==============================================================================
class EEG_BLIP2_Model(nn.Module):
    """
    The final, refactored model that directly connects a powerful EEG encoder
    to the BLIP-2 Q-Former via a projection layer.
    """
    def __init__(self, config, device, eeg_embed_dim=768):
        super().__init__()
        self.device = device

        # 1. Instantiate the powerful, from-scratch EEG encoder
        self.eeg_encoder = EEGTransformerEncoder(embed_dim=eeg_embed_dim)

        # 2. Load the pre-trained BLIP model
        self.blip = Blip2ForConditionalGeneration.from_pretrained(config.BLIP_MODEL_NAME)

        # 3. Create the crucial projection layer
        # This layer maps the EEG encoder's output dimension (e.g., 768) to the
        # dimension the BLIP vision encoder's features, which the Q-Former expects (1408).
        blip_vision_hidden_size = self.blip.vision_model.config.hidden_size
        self.eeg_projection = nn.Linear(eeg_embed_dim, blip_vision_hidden_size)

        # The __init__ method no longer handles freezing. This is now controlled by the training loop.

        self.eeg_encoder.to(device)
        self.eeg_projection.to(device)
        self.blip.to(device)

    def get_eeg_embedding(self, eeg_spectrograms):
        """ The direct EEG-to-Q-Former forward pass. """
        # 1. Get rich features from our custom EEG encoder
        eeg_features = self.eeg_encoder(eeg_spectrograms)
        # Output shape: [Batch, SeqLen, 768]

        # 2. Project EEG features to match the Q-Former's expected input dimension
        projected_eeg_features = self.eeg_projection(eeg_features)
        # Output shape: [Batch, SeqLen, 1408]

        # 3. Pass projected features directly to the Q-Former
        query_tokens = self.blip.query_tokens.expand(projected_eeg_features.shape[0], -1, -1)
        query_outputs = self.blip.qformer(
            query_embeds=query_tokens,
            encoder_hidden_states=projected_eeg_features,
            return_dict=True,
        )
        return query_outputs.last_hidden_state.mean(dim=1)

    def get_image_embedding(self, pil_images, processor):
        # This function remains the same, using the standard image pathway
        inputs = processor(images=pil_images, return_tensors="pt").to(self.device)
        pixel_values = inputs.pixel_values
        image_features = self.blip.vision_model(pixel_values).last_hidden_state
        query_tokens = self.blip.query_tokens.expand(image_features.shape[0], -1, -1)
        query_outputs = self.blip.qformer(
            query_embeds=query_tokens,
            encoder_hidden_states=image_features,
            return_dict=True,
        )
        return query_outputs.last_hidden_state.mean(dim=1)

    def generate(self, eeg_spectrograms, **kwargs):
        with torch.no_grad():
            # Follow the same direct pathway for generation
            eeg_features = self.eeg_encoder(eeg_spectrograms)
            projected_eeg_features = self.eeg_projection(eeg_features)

            query_tokens = self.blip.query_tokens.expand(projected_eeg_features.shape[0], -1, -1)
            query_outputs = self.blip.qformer(
                query_embeds=query_tokens,
                encoder_hidden_states=projected_eeg_features,
                return_dict=True,
            )
            language_model_inputs = self.blip.language_projection(query_outputs.last_hidden_state)
            generated_ids = self.blip.language_model.generate(
                inputs_embeds=language_model_inputs,
                **kwargs
            )
            return generated_ids

    def forward(self, eeg_spectrograms):
        """
        Returns the EEG embedding and the raw logits from the language model.
        """
        # 1. Get EEG features and project them
        eeg_features = self.eeg_encoder(eeg_spectrograms)
        projected_eeg_features = self.eeg_projection(eeg_features)

        # 2. Get the 32 query embeddings from the Q-Former
        query_tokens = self.blip.query_tokens.expand(projected_eeg_features.shape[0], -1, -1)
        query_outputs = self.blip.qformer(
            query_embeds=query_tokens,
            encoder_hidden_states=projected_eeg_features,
            return_dict=True,
        )

        # 3. Project queries for the language model
        language_model_inputs = self.blip.language_projection(query_outputs.last_hidden_state)

        # 4. Get the logits from the language model
        outputs = self.blip.language_model(inputs_embeds=language_model_inputs, return_dict=True)
        logits = outputs.logits

        # Get the final EEG embedding for our alignment loss
        eeg_embedding = query_outputs.last_hidden_state.mean(dim=1)

        return eeg_embedding, logits

# ==============================================================================
# --- 4. DATASET & COLLATING ---
# ==============================================================================
def collate_fn(batch):
    spectrograms, pil_images, text_captions = zip(*[(item[0], item[1], item[2]) for item in batch])
    labels = [item[3] for item in batch]
    padded_labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=-100)

    return torch.stack(spectrograms), list(pil_images), list(text_captions), padded_labels


class EEGDatasetWithCaptions(Dataset):
    def __init__(self, root_dir, metadata_csv, split, processor, transform, augment=False):
        self.root_dir = Path(root_dir)
        self.transform = transform
        df = pd.read_csv(metadata_csv)
        self.split_df = df[df['split'].str.strip() == split].reset_index(drop=True)
        self.augment = augment

        self.processor = processor

        # --- NEW: Define a dedicated augmentation pipeline ---
        if self.augment:
            self.augmentation_transform = transforms.Compose([
                transforms.RandomApply([transforms.GaussianBlur(kernel_size=3)], p=0.5),
                # Randomly mask out frequency bands
                transforms.RandomApply([transforms.RandomErasing(p=1.0, scale=(0.02, 0.1), ratio=(0.1, 0.5))], p=0.5),
                # Randomly mask out time steps
                transforms.RandomApply([transforms.RandomErasing(p=1.0, scale=(0.02, 0.1), ratio=(2.0, 5.0))], p=0.5),
            ])

        # This is your existing normalization transform
        self.normalization_transform = transform

        # Load pre-generated captions
        captions_path = Path(root_dir) / f'{split}_gt_captions.json'
        print(f"Loading pre-generated captions from {captions_path}")
        with open(captions_path, 'r') as f:
            self.captions = json.load(f)

    def __len__(self):
        return len(self.split_df)

    def __getitem__(self, idx):
        info = self.split_df.iloc[idx]

        spectrogram = torch.load(self.root_dir / info['spectrogram_path'])

        # --- APPLY AUGMENTATIONS (for training set only) ---
        if self.augment:
            spectrogram = self.augmentation_transform(spectrogram)

        # Apply normalization after augmentation
        spectrogram = self.normalization_transform(spectrogram)

        image = Image.open(self.root_dir / info['image_path']).convert("RGB")
        gt_caption = self.captions.get(str(idx), "an image")
        labels = self.processor(text=gt_caption, return_tensors="pt", padding=True).input_ids.squeeze()

        return spectrogram, image, gt_caption, labels

# ==============================================================================
# --- 5. MAIN TRAINING FUNCTION ---
# ==============================================================================
# InfoNCE loss function

def info_nce_loss(query, positive_key, temperature=0.07):
    # Ensure inputs are normalized
    query = F.normalize(query, dim=-1)
    positive_key = F.normalize(positive_key, dim=-1)

    # Calculate the similarity matrix of every query with every key
    # The diagonal of this matrix contains the positive pairs
    logits = query @ positive_key.T

    # The labels are the indices of the positive pairs (the diagonal)
    labels = torch.arange(len(query), device=query.device)

    # Calculate the cross-entropy loss
    return F.cross_entropy(logits / temperature, labels)


class SemanticLoss(nn.Module):
    """
    A loss function that measures the semantic similarity between two sets of captions
    using a pre-trained SentenceTransformer model.
    """
    def __init__(self, device):
        super().__init__()
        self.device = device
        # Load a pre-trained model optimized for semantic similarity
        self.model = SentenceTransformer('all-MiniLM-L6-v2').to(device)
        # Freeze the model's weights as we only use it for inference
        for param in self.model.parameters():
            param.requires_grad = False

    def forward(self, predicted_captions, ground_truth_captions):
        # Convert the text captions into semantic embedding vectors
        pred_embeddings = self.model.encode(predicted_captions, convert_to_tensor=True)
        gt_embeddings = self.model.encode(ground_truth_captions, convert_to_tensor=True)

        # Calculate the cosine similarity. The loss is 1.0 minus the similarity.
        # A higher similarity (closer to 1.0) results in a lower loss (closer to 0.0).
        cosine_sim = F.cosine_similarity(pred_embeddings, gt_embeddings, dim=-1)
        loss = 1.0 - cosine_sim.mean()
        return loss

#------------------------------------------------------------------------------------------

def create_and_save_reconstruction_grid(eval_samples, diffusion_pipe, epoch, val_loss, output_dir):
    """
    Generates a 2x2 image grid comparing original and reconstructed images.
    Top row: Original images with ground truth captions.
    Bottom row: Reconstructed images with predicted captions.
    """
    if len(eval_samples) < 2:
        print("Not enough samples for a 2x2 grid, skipping.")
        return

    print(f"\n🖼️  Generating reconstruction grid for epoch {epoch}...")

    # Setup canvas and font
    w, h = 512, 512  # Standard Stable Diffusion size
    title_h = 60    # Space for titles
    grid = Image.new('RGB', (w * 2, h * 2 + title_h * 2), 'black')
    draw = ImageDraw.Draw(grid)
    try:
        font = ImageFont.truetype("LiberationSans-Regular.ttf", 20)
    except IOError:
        print("Default font not found, using fallback.")
        font = ImageFont.load_default()

    # Process two samples
    for i in range(2):
        predicted_cap, gt_cap, original_image = eval_samples[i]

        # --- Generate the reconstructed image from the predicted caption ---
        generator = torch.Generator(device=diffusion_pipe.device).manual_seed(42 + i)
        reconstructed_image = diffusion_pipe(
            prompt=predicted_cap, generator=generator, num_inference_steps=20
        ).images[0]

        # --- Paste images and draw titles ---
        col_offset = i * w

        # Top Row (Original)
        draw.text((col_offset + 5, 5), f"GT: {gt_cap[:50]}...", font=font, fill="white")
        grid.paste(original_image.resize((w, h)), (col_offset, title_h))

        # Bottom Row (Reconstructed)
        draw.text((col_offset + 5, h + title_h + 5), f"Pred: {predicted_cap[:50]}...", font=font, fill="white")
        grid.paste(reconstructed_image.resize((w, h)), (col_offset, h + title_h * 2))

    # --- Save the final grid ---
    # Sanitize loss value for filename
    val_loss_str = f"{val_loss:.4f}".replace('.', '_')
    save_path = Path(output_dir) / 'reconstructions'
    save_path.mkdir(exist_ok=True)
    filename = save_path / f"reconstruction_epoch_{epoch}_loss_{val_loss_str}.png"
    grid.save(filename)
    print(f"  Saved reconstruction grid to {filename}")

#----------------------------------------------------------------------------------------------
def train_model(config):
    # --- Setup (Device, Output Dir, etc.) ---
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    output_dir = Path(config.OUTPUT_DIR) / f'run_{timestamp}'
    output_dir.mkdir(parents=True, exist_ok=True)

    # --- Models and Data ---
    processor = Blip2Processor.from_pretrained(config.BLIP_MODEL_NAME)
    model = EEG_BLIP2_Model(config, device) # Assumes the refactored, cleaner version

    for p in model.blip.parameters():
        p.requires_grad=False

    print("Starting training from scratch.")
    # try:
    #   print('Loading best model weights')
    #   model.load_state_dict(torch.load(config.BEST_MODEL_DIR, map_location=device))
    #   print('Successfully loaded best model weights!')
    # except:
    #   print('Could not load the best model weights.')

    print("Loading Stable Diffusion pipeline for visualizations...")
    diffusion_pipe = StableDiffusionPipeline.from_pretrained(
        config.DIFFUSION_MODEL_ID,
        torch_dtype=torch.float16 if device.type == "cuda" else torch.float32,
        safety_checker=None, requires_safety_checker=False
    ).to(device)
    diffusion_pipe.set_progress_bar_config(disable=True)

    semantic_loss_fn = SemanticLoss(device)
    ce_loss_fn = nn.CrossEntropyLoss(ignore_index=-100)

    # --- Data Loading ---
    spec_mean = torch.load(Path(config.PROCESSED_DATA_ROOT) / 'spec_mean.pt')
    spec_std = torch.load(Path(config.PROCESSED_DATA_ROOT) / 'spec_std.pt')
    spec_std += 1e-6

    transform = transforms.Compose([
        transforms.ToDtype(torch.float32),
        transforms.Resize((224, 224), antialias=True),
        transforms.Normalize(mean=spec_mean.tolist(), std=spec_std.tolist()),
    ])

    train_dataset = EEGDatasetWithCaptions(config.PROCESSED_DATA_ROOT, config.METADATA_CSV, 'train', processor, transform, augment=True)
    val_dataset = EEGDatasetWithCaptions(config.PROCESSED_DATA_ROOT, config.METADATA_CSV, 'val', processor, transform, augment=False)

    # --- Training Components ---
    scaler = torch.cuda.amp.GradScaler()
    best_val_loss = float('inf')

    generation_args = {
        "max_new_tokens": 15, "do_sample": True, "min_new_tokens": 5,
        "repetition_penalty": 1.4, "top_p": 0.9, "temperature": 0.85
    }

    # --- Main Training Loop ---
    print("🎯 Starting Decoupled 3-Stage Training Loop")
    for epoch in range(config.NUM_EPOCHS):

        # --- STAGE 1 & 2 Freezing Logic ---
        if epoch < config.STAGE1_EPOCHS:
            if epoch == 0:
                print(f"\n--- STAGE 1: Training EEG Encoder & Projection ONLY (Epochs 0-{config.STAGE1_EPOCHS-1}) ---")

            # Set requires_grad for Stage 1
            for name, param in model.named_parameters():
                if 'eeg_encoder' in name or 'eeg_projection' in name:
                    param.requires_grad = True
                    # print('Unfreeing Encoder and eeg-projection layers')

                else:
                    param.requires_grad = False

        elif epoch == config.STAGE1_EPOCHS:
            print(f"\n--- STAGE 2: Fine-Tuning Q-Former and language-projection (Epochs {config.STAGE1_EPOCHS}+) ---")

            # Set requires_grad for Stage 2
            for name, param in model.named_parameters():
                if 'qformer' in name or 'language_projection' in name:
                    param.requires_grad = True
                    # print('Unfreezing Q-Former layer')

                else:
                    param.requires_grad = False

        # elif epoch == config.STAGE2_EPOCHS:
        #     print(f"\n---STAGE 3: Fine-tuning Language Projection (Epochs {config.STAGE2_EPOCHS}+)---")

        #     for name, param in model.named_parameters():
        #         if 'qformer' in name or 'language_projection' in name:
        #             param.requires_grad = True
        #             # print('Unfreezing language-projection layer')

        #         else:
        #             param.requires_grad = False


        # --- Create/Re-create Optimizer & Scheduler at stage transitions ---
        if epoch == 0 or epoch == config.STAGE1_EPOCHS:
            print(f"Creating optimizer for new stage...")

            # Use differential learning rates for Stage 2 and 3
            if epoch < config.STAGE1_EPOCHS:
                trainable_params = filter(lambda p: p.requires_grad, model.parameters())
                optimizer = optim.AdamW(trainable_params, lr=config.ADAPTER_LR, weight_decay=config.WEIGHT_DECAY)
            else: # For both Stage 2 and Stage 3
                qformer_params = [p for n, p in model.named_parameters() if p.requires_grad and 'qformer' in n]
                lang_proj_params = [p for n, p in model.named_parameters() if p.requires_grad and 'language_projection' in n]

                optimizer_grouped_parameters = [
                    {"params": qformer_params, "lr": config.BLIP_LR},
                    {"params": lang_proj_params, "lr": config.BLIP_LR * 0.5}, # Even smaller LR for the final layer
                ]
                optimizer = optim.AdamW(optimizer_grouped_parameters, weight_decay=config.WEIGHT_DECAY)

            num_training_steps = (len(train_dataset) // config.BATCH_SIZE) * config.NUM_EPOCHS
            scheduler = get_linear_schedule_with_warmup(
                optimizer, num_warmup_steps=config.WARMUP_STEPS, num_training_steps=num_training_steps
            )
            print("Optimizer and scheduler created.")

        # --- Training Batch Loop ---
        model.train()
        train_loader = DataLoader(train_dataset, batch_size=config.BATCH_SIZE,
            sampler=SubsetRandomSampler(np.random.choice(len(train_dataset), min(len(train_dataset), config.TRAIN_SAMPLES_PER_EPOCH), replace=False)),
            drop_last=True, collate_fn=collate_fn)
        train_bar = tqdm(train_loader, desc=f"🔄 Epoch {epoch+1}/{config.NUM_EPOCHS}")

        for batch_idx, (spectrograms, pil_images, gt_text_captions, labels) in enumerate(train_bar):
            spectrograms, labels = spectrograms.to(device), labels.to(device)

            with torch.amp.autocast('cuda'):
                eeg_embeddings, logits = model(spectrograms)

                # For semantic loss, we need to generate captions (this is slower but necessary)
                with torch.no_grad():
                    image_embeddings = model.get_image_embedding(pil_images, processor)
                    generated_ids = model.generate(spectrograms, **generation_args)
                    eeg_captions = processor.batch_decode(generated_ids, skip_special_tokens=True)

                loss_semantic = semantic_loss_fn(eeg_captions, gt_text_captions)
                loss_align = info_nce_loss(eeg_embeddings, image_embeddings)

                logits_len = logits.size(1)
                labels_len = labels.size(1)

                if labels_len < logits_len:
                    padded_labels = F.pad(labels, (0, logits_len - labels_len), 'constant', -100)
                else:
                    # In case labels are longer (unlikely), truncate them
                    padded_labels = labels[:, :logits_len]

                loss_ce = ce_loss_fn(logits.view(-1, logits.size(-1)), padded_labels.view(-1))

                # Use consistent loss weights defined in config
                loss = (config.ALIGN_WEIGHT * loss_align) + (config.SEMANTIC_WEIGHT * loss_semantic) + (config.CE_WEIGHT * loss_ce)
                loss = loss / config.GRAD_ACCUMULATION_STEPS

            scaler.scale(loss).backward()

            if (batch_idx + 1) % config.GRAD_ACCUMULATION_STEPS == 0:
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(filter(lambda p: p.requires_grad, model.parameters()), config.GRADIENT_CLIP_NORM)
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()
                scheduler.step()

            train_bar.set_postfix({
                "ce_loss": f"{loss_ce.item():.4f}",
                "align_loss": f"{loss_align.item():.4f}",
                "sem_loss": f"{loss_semantic.item():.4f}"
            })

        # --- Validation Loop ---
        print(f"\n🔍 Running Validation for epoch {epoch+1}")
        model.eval()

        val_align_losses, val_semantic_losses, val_ce_losses, eval_samples = [], [], [], []

        val_loader = DataLoader(val_dataset, batch_size=config.BATCH_SIZE,
            sampler=SubsetRandomSampler(np.random.choice(len(val_dataset), min(len(val_dataset), config.VAL_SAMPLES_PER_EPOCH), replace=False)),
            collate_fn=collate_fn)

        with torch.no_grad(), torch.amp.autocast('cuda'):
            for spectrograms, pil_images, gt_text_captions, labels in tqdm(val_loader, desc="Validation Loop"):
                spectrograms, labels = spectrograms.to(device), labels.to(device)

                eeg_embeddings, logits = model(spectrograms)
                image_embeddings = model.get_image_embedding(pil_images, processor)

                generated_ids = model.generate(spectrograms, **generation_args)
                eeg_captions = processor.batch_decode(generated_ids, skip_special_tokens=True)

                loss_align = info_nce_loss(eeg_embeddings, image_embeddings)
                loss_semantic = semantic_loss_fn(eeg_captions, gt_text_captions)

                logits_len = logits.size(1)
                labels_len = labels.size(1)
                if labels_len < logits_len:
                    padded_labels = F.pad(labels, (0, logits_len - labels_len), 'constant', -100)
                else:
                    padded_labels = labels[:, :logits_len]

                loss_ce = ce_loss_fn(logits.view(-1, logits.size(-1)), padded_labels.view(-1))

                val_align_losses.append(loss_align.item())
                val_semantic_losses.append(loss_semantic.item())
                val_ce_losses.append(loss_ce.item())

                if len(eval_samples) < config.VIS_GRID_SIZE:
                    for i in range(min(config.VIS_GRID_SIZE - len(eval_samples), len(eeg_captions))):
                        eval_samples.append((eeg_captions[i], gt_text_captions[i], pil_images[i]))

        avg_val_align_loss = np.mean(val_align_losses)
        avg_val_semantic_loss = np.mean(val_semantic_losses)
        avg_val_ce_loss = np.mean(val_ce_losses)

        total_val_loss = (config.ALIGN_WEIGHT * avg_val_align_loss) + \
                         (config.SEMANTIC_WEIGHT * avg_val_semantic_loss) + \
                         (config.CE_WEIGHT * avg_val_ce_loss)

        print(f"  Validation Results:")
        print(f"    - Avg Alignment Loss: {avg_val_align_loss:.4f}")
        print(f"    - Avg Semantic Loss:  {avg_val_semantic_loss:.4f}")
        print(f"    - Avg CE Loss:        {avg_val_ce_loss:.4f}")
        print(f"    - Total Weighted Loss: {total_val_loss:.4f}")

        if total_val_loss < best_val_loss:
            best_val_loss = total_val_loss
            save_path = output_dir / f'best_model_epoch_{epoch+1}_loss_{total_val_loss:.4f}.pth'
            torch.save(model.state_dict(), save_path)
            print(f"🏆 New best model saved to {save_path}")

        # --- Print caption comparisons ---
        if eval_samples:
            print("\n--- ✍️  Caption Generation Samples ---")
            for i, (predicted_cap, gt_cap, _) in enumerate(eval_samples):
                print(f"Sample {i+1}:")
                print(f"  - Ground Truth: {gt_cap.strip()}")
                print(f"  - Predicted:    {predicted_cap.strip()}")
            print("------------------------------------")

        if (epoch + 1) % config.EVAL_IMAGE_GENERATION_INTERVAL == 0 and eval_samples:
            create_and_save_reconstruction_grid(
                eval_samples,
                diffusion_pipe,
                epoch + 1,
                total_val_loss,
                output_dir
            )

    print("\n🎉 Training Complete!")
    return output_dir

if __name__ == '__main__':
    config = TRAIN_CONFIG()
    try:
        output_dir = train_model(config)
        if output_dir:
            print(f"✅ Training completed successfully! Results saved to: {output_dir}")
    except Exception as e:
        print(f"❌ Training failed with error: {str(e)}")
        import traceback
        traceback.print_exc()

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.
The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Starting training from scratch.
Loading Stable Diffusion pipeline for visualizations...


Loading pipeline components...:   0%|          | 0/6 [00:00<?, ?it/s]

`torch_dtype` is deprecated! Use `dtype` instead!


Loading pre-generated captions from /content/final_lightweight_17k/train_gt_captions.json
Loading pre-generated captions from /content/final_lightweight_17k/val_gt_captions.json
🎯 Starting Decoupled 3-Stage Training Loop

--- STAGE 1: Training EEG Encoder & Projection ONLY (Epochs 0-11) ---
Creating optimizer for new stage...
Optimizer and scheduler created.


  scaler = torch.cuda.amp.GradScaler()


🔄 Epoch 1/30:   0%|          | 0/142 [00:00<?, ?it/s]


🔍 Running Validation for epoch 1


Validation Loop:   0%|          | 0/29 [00:00<?, ?it/s]

  Validation Results:
    - Avg Alignment Loss: 2.7038
    - Avg Semantic Loss:  0.9214
    - Avg CE Loss:        8.2764
    - Total Weighted Loss: 29.3757
🏆 New best model saved to /content/drive/MyDrive/NeuroVision/EEG_BLIP2_Alignment/run_20251002_171100/best_model_epoch_1_loss_29.3757.pth

--- ✍️  Caption Generation Samples ---
Sample 1:
  - Ground Truth: a baseball player throwing a ball on a field
  - Predicted:    a closeup of a black and white image
Sample 2:
  - Ground Truth: a man is putting toppings on a pizza
  - Predicted:    a photograph of a black and white picture on top
Sample 3:
  - Ground Truth: two birds perched on a table
  - Predicted:    a group of people sitting on the porch in front a house
Sample 4:
  - Ground Truth: a herd of cows in a forest
  - Predicted:    a black and white image of a person standing next to an old man
Sample 5:
  - Ground Truth: a kitchen with a stove, refrigerator, and microwave
  - Predicted:    the vertical strip of a picture is on top

🔄 Epoch 2/30:   0%|          | 0/142 [00:00<?, ?it/s]


🔍 Running Validation for epoch 2


Validation Loop:   0%|          | 0/29 [00:00<?, ?it/s]

  Validation Results:
    - Avg Alignment Loss: 2.7129
    - Avg Semantic Loss:  0.9182
    - Avg CE Loss:        6.7723
    - Total Weighted Loss: 24.8661
🏆 New best model saved to /content/drive/MyDrive/NeuroVision/EEG_BLIP2_Alignment/run_20251002_171100/best_model_epoch_2_loss_24.8661.pth

--- ✍️  Caption Generation Samples ---
Sample 1:
  - Ground Truth: a group of horses and a dog in a field
  - Predicted:    two people walking past a tree and two birds flying
Sample 2:
  - Ground Truth: a bus is parked on the side of the road
  - Predicted:    a pair of shoes is in front with a black background and the words
Sample 3:
  - Ground Truth: a man swinging a baseball bat on a baseball field
  - Predicted:    the new black - 1 x book
Sample 4:
  - Ground Truth: a boat that is in the water
  - Predicted:    a flying eagle hanging from a tree
Sample 5:
  - Ground Truth: a pan of vegetables on the stove
  - Predicted:    a man standing on a chair next to an open window in the background
--

🔄 Epoch 3/30:   0%|          | 0/142 [00:00<?, ?it/s]


🔍 Running Validation for epoch 3


Validation Loop:   0%|          | 0/29 [00:00<?, ?it/s]

  Validation Results:
    - Avg Alignment Loss: 2.6814
    - Avg Semantic Loss:  0.9120
    - Avg CE Loss:        5.5582
    - Total Weighted Loss: 21.1800
🏆 New best model saved to /content/drive/MyDrive/NeuroVision/EEG_BLIP2_Alignment/run_20251002_171100/best_model_epoch_3_loss_21.1800.pth

--- ✍️  Caption Generation Samples ---
Sample 1:
  - Ground Truth: a man riding a snowboard down a snowy slope
  - Predicted:    a man on a scooter holding up his hand while looking down at
Sample 2:
  - Ground Truth: a pizza on a metal pan
  - Predicted:    boulevard by lincoln
Sample 3:
  - Ground Truth: a woman holding a yellow umbrella
  - Predicted:    two people sitting in chairs and looking at a wall with blue painted stripes
Sample 4:
  - Ground Truth: a person is falling off of a surfboard
  - Predicted:    a couple of cats sitting on a bench and an open bowl
Sample 5:
  - Ground Truth: a man in an orange jacket skiing down a snow covered slope
  - Predicted:    this is what you need for 

🔄 Epoch 4/30:   0%|          | 0/142 [00:00<?, ?it/s]


🔍 Running Validation for epoch 4


Validation Loop:   0%|          | 0/29 [00:00<?, ?it/s]

  Validation Results:
    - Avg Alignment Loss: 2.7150
    - Avg Semantic Loss:  0.9235
    - Avg CE Loss:        5.0531
    - Total Weighted Loss: 19.7211
🏆 New best model saved to /content/drive/MyDrive/NeuroVision/EEG_BLIP2_Alignment/run_20251002_171100/best_model_epoch_4_loss_19.7211.pth

--- ✍️  Caption Generation Samples ---
Sample 1:
  - Ground Truth: a group of horses and a dog in a field
  - Predicted:    top 10 reasons to visit vietnam 1,000s of natural
Sample 2:
  - Ground Truth: a man playing tennis
  - Predicted:    high tech hong kong
Sample 3:
  - Ground Truth: a man riding a motorcycle down a highway
  - Predicted:    a pair of two-year old twins on a swing
Sample 4:
  - Ground Truth: three baseball players are walking on a field
  - Predicted:    a few of my favorite black and white shots from a great year
Sample 5:
  - Ground Truth: a man riding a horse in a field
  - Predicted:    waterfall and the ocean
------------------------------------

🖼️  Generating reconstruc

🔄 Epoch 5/30:   0%|          | 0/142 [00:00<?, ?it/s]


🔍 Running Validation for epoch 5


Validation Loop:   0%|          | 0/29 [00:00<?, ?it/s]

  Validation Results:
    - Avg Alignment Loss: 2.6812
    - Avg Semantic Loss:  0.9231
    - Avg CE Loss:        4.7603
    - Total Weighted Loss: 18.8083
🏆 New best model saved to /content/drive/MyDrive/NeuroVision/EEG_BLIP2_Alignment/run_20251002_171100/best_model_epoch_5_loss_18.8083.pth

--- ✍️  Caption Generation Samples ---
Sample 1:
  - Ground Truth: a woman playing tennis
  - Predicted:    snowboarders on the wall
Sample 2:
  - Ground Truth: a group of giraffes eating from a feeder
  - Predicted:    amended title: how to write on a picture you have taken
Sample 3:
  - Ground Truth: a woman sitting in a chair under an umbrella on the beach
  - Predicted:    tony barbour on the boardwalk nj
Sample 4:
  - Ground Truth: a cow standing in the grass
  - Predicted:    kawaii kimono danses sous la chale
Sample 5:
  - Ground Truth: a man holding a tennis racket on a tennis court
  - Predicted:    halo, the end of this is a pretty good ending for us
------------------------------------


🔄 Epoch 6/30:   0%|          | 0/142 [00:00<?, ?it/s]


🔍 Running Validation for epoch 6


Validation Loop:   0%|          | 0/29 [00:00<?, ?it/s]

  Validation Results:
    - Avg Alignment Loss: 2.6537
    - Avg Semantic Loss:  0.9184
    - Avg CE Loss:        4.6128
    - Total Weighted Loss: 18.3288
🏆 New best model saved to /content/drive/MyDrive/NeuroVision/EEG_BLIP2_Alignment/run_20251002_171100/best_model_epoch_6_loss_18.3288.pth

--- ✍️  Caption Generation Samples ---
Sample 1:
  - Ground Truth: two giraffes in the wild
  - Predicted:    a house in a street of trees
Sample 2:
  - Ground Truth: a bird flying over the ocean and beach
  - Predicted:    abandoned house in a town background, city of shanghai
Sample 3:
  - Ground Truth: a couple of people sitting at a table outside of a building
  - Predicted:    and a background of the same color
Sample 4:
  - Ground Truth: a herd of cows grazing in a field
  - Predicted:    another of these is a little smaller
Sample 5:
  - Ground Truth: a large building with a clock on the top
  - Predicted:    lots of photos like that are in stock
------------------------------------

🖼️  Ge

🔄 Epoch 7/30:   0%|          | 0/142 [00:00<?, ?it/s]


🔍 Running Validation for epoch 7


Validation Loop:   0%|          | 0/29 [00:00<?, ?it/s]

  Validation Results:
    - Avg Alignment Loss: 2.6647
    - Avg Semantic Loss:  0.9123
    - Avg CE Loss:        4.5136
    - Total Weighted Loss: 18.0299
🏆 New best model saved to /content/drive/MyDrive/NeuroVision/EEG_BLIP2_Alignment/run_20251002_171100/best_model_epoch_7_loss_18.0299.pth

--- ✍️  Caption Generation Samples ---
Sample 1:
  - Ground Truth: a cat is laying under a blanket
  - Predicted:    on an outdoor scene, and a second group of
Sample 2:
  - Ground Truth: a computer monitor sitting on a desk next to a keyboard
  - Predicted:    how about a little bit of background?
Sample 3:
  - Ground Truth: a desk with a computer and a basket on it
  - Predicted:    on the right side of a page or on another image
Sample 4:
  - Ground Truth: a large airplane on the tarmac
  - Predicted:    a picture of a house, and this is one that has been very
Sample 5:
  - Ground Truth: a zebra is eating grass in a field
  - Predicted:    on the beach with a group of people behind
------------

🔄 Epoch 8/30:   0%|          | 0/142 [00:00<?, ?it/s]


🔍 Running Validation for epoch 8


Validation Loop:   0%|          | 0/29 [00:00<?, ?it/s]

  Validation Results:
    - Avg Alignment Loss: 2.6601
    - Avg Semantic Loss:  0.9112
    - Avg CE Loss:        4.4783
    - Total Weighted Loss: 17.9172
🏆 New best model saved to /content/drive/MyDrive/NeuroVision/EEG_BLIP2_Alignment/run_20251002_171100/best_model_epoch_8_loss_17.9172.pth

--- ✍️  Caption Generation Samples ---
Sample 1:
  - Ground Truth: a bus driving down a street with a few buildings in the background
  - Predicted:    the water is clear and the sky blue, so there isn't any
Sample 2:
  - Ground Truth: a dog laying in the sand under a chair
  - Predicted:    a field of daisies and flowers
Sample 3:
  - Ground Truth: a couple of people sitting at a table outside of a building
  - Predicted:    a close up of a man eating soup while sitting in the sun and
Sample 4:
  - Ground Truth: a man playing tennis
  - Predicted:    a table of contents for a book
Sample 5:
  - Ground Truth: a table with two containers of food and a bowl of salad
  - Predicted:    snowboarders on

🔄 Epoch 9/30:   0%|          | 0/142 [00:00<?, ?it/s]


🔍 Running Validation for epoch 9


Validation Loop:   0%|          | 0/29 [00:00<?, ?it/s]

  Validation Results:
    - Avg Alignment Loss: 2.6573
    - Avg Semantic Loss:  0.9290
    - Avg CE Loss:        4.2963
    - Total Weighted Loss: 17.4041
🏆 New best model saved to /content/drive/MyDrive/NeuroVision/EEG_BLIP2_Alignment/run_20251002_171100/best_model_epoch_9_loss_17.4041.pth

--- ✍️  Caption Generation Samples ---
Sample 1:
  - Ground Truth: two yellow trains are parked at a train station
  - Predicted:    toddler a boy
Sample 2:
  - Ground Truth: a person riding skis down a snowy slope
  - Predicted:    a man on a boat with the sun behind him and his dog in
Sample 3:
  - Ground Truth: a woman in a pink shirt is about to hit a tennis ball
  - Predicted:    r/wtf at people who don't understand that this is an
Sample 4:
  - Ground Truth: a woman playing tennis
  - Predicted:    sitting on the floor, wearing a jacket and jeans
Sample 5:
  - Ground Truth: a sheep standing in front of a fence
  - Predicted:    I'll take the second part of that bet, good sir.
---------------

🔄 Epoch 10/30:   0%|          | 0/142 [00:00<?, ?it/s]


🔍 Running Validation for epoch 10


Validation Loop:   0%|          | 0/29 [00:00<?, ?it/s]

  Validation Results:
    - Avg Alignment Loss: 2.6507
    - Avg Semantic Loss:  0.9293
    - Avg CE Loss:        4.4246
    - Total Weighted Loss: 17.7829

--- ✍️  Caption Generation Samples ---
Sample 1:
  - Ground Truth: three elephants walking in the grass
  - Predicted:    a house, a mountain
Sample 2:
  - Ground Truth: a mother elephant and baby elephant
  - Predicted:    a plane in the sky above a forest and buildings
Sample 3:
  - Ground Truth: a living room with a couch, a tv, and a window
  - Predicted:    he said that, not me. I don't know where he got
Sample 4:
  - Ground Truth: two zebras walking in a zoo
  - Predicted:    a small boy sitting on a red chair and then two children in the
Sample 5:
  - Ground Truth: a woman holding a tennis racket on a tennis court
  - Predicted:    it's a long time ago but i remember how much my mum used
------------------------------------

🖼️  Generating reconstruction grid for epoch 10...
  Saved reconstruction grid to /content/drive/MyDr

🔄 Epoch 11/30:   0%|          | 0/142 [00:00<?, ?it/s]


🔍 Running Validation for epoch 11


Validation Loop:   0%|          | 0/29 [00:00<?, ?it/s]

  Validation Results:
    - Avg Alignment Loss: 2.6604
    - Avg Semantic Loss:  0.9096
    - Avg CE Loss:        4.3732
    - Total Weighted Loss: 17.5993

--- ✍️  Caption Generation Samples ---
Sample 1:
  - Ground Truth: a group of birds standing on the beach
  - Predicted:    a green and white photo of a group
Sample 2:
  - Ground Truth: a coffee pot and a plate with a cake on it
  - Predicted:    The sun is shining in the sky, and people are walking on a
Sample 3:
  - Ground Truth: an elephant with tusks walking through a forest
  - Predicted:    I've been reading your posts for a while now and have learned so
Sample 4:
  - Ground Truth: a train traveling over a bridge over water
  - Predicted:    a car parked in a parking lot is shown with two men at the
Sample 5:
  - Ground Truth: a bed with a white sheet
  - Predicted:    a green lawn with a beach view, some water and the sand on
------------------------------------

🖼️  Generating reconstruction grid for epoch 11...
  Saved re

🔄 Epoch 12/30:   0%|          | 0/142 [00:00<?, ?it/s]


🔍 Running Validation for epoch 12


Validation Loop:   0%|          | 0/29 [00:00<?, ?it/s]

  Validation Results:
    - Avg Alignment Loss: 2.6506
    - Avg Semantic Loss:  0.9032
    - Avg CE Loss:        4.2794
    - Total Weighted Loss: 17.2952
🏆 New best model saved to /content/drive/MyDrive/NeuroVision/EEG_BLIP2_Alignment/run_20251002_171100/best_model_epoch_12_loss_17.2952.pth

--- ✍️  Caption Generation Samples ---
Sample 1:
  - Ground Truth: a microwave oven sitting on top of a counter
  - Predicted:    a woman sitting a beach, with the sky in her hand
Sample 2:
  - Ground Truth: a man on a skateboard doing a trick on a ledge
  - Predicted:    a large tree in front of a man sitting on it, with the
Sample 3:
  - Ground Truth: a herd of sheep standing under a tree
  - Predicted:    a large, a small or an individual object
Sample 4:
  - Ground Truth: a person is skiing down a snowy slope
  - Predicted:    a large number of people are walking in the background, and they have
Sample 5:
  - Ground Truth: a group of planes flying in the sky
  - Predicted:    a bus a boat and

🔄 Epoch 13/30:   0%|          | 0/142 [00:00<?, ?it/s]


🔍 Running Validation for epoch 13


Validation Loop:   0%|          | 0/29 [00:00<?, ?it/s]

  Validation Results:
    - Avg Alignment Loss: 2.6590
    - Avg Semantic Loss:  0.9115
    - Avg CE Loss:        4.2799
    - Total Weighted Loss: 17.3217

--- ✍️  Caption Generation Samples ---
Sample 1:
  - Ground Truth: a dog looking out the side mirror of a car
  - Predicted:    a beach is in the background a large group of people are gathered around
Sample 2:
  - Ground Truth: a man standing in front of a fruit stand with a dog
  - Predicted:    snow and sand are on the beach of a black sea
Sample 3:
  - Ground Truth: a group of people playing frisbee in a field
  - Predicted:    a car is parked at a stop sign in front of an open field
Sample 4:
  - Ground Truth: a living room with a table and chairs
  - Predicted:    r/foodporn maybe?
Sample 5:
  - Ground Truth: two pizzas on a table
  - Predicted:    a house in the middle of a field is on its side with trees
------------------------------------

🖼️  Generating reconstruction grid for epoch 13...
  Saved reconstruction grid to /

🔄 Epoch 14/30:   0%|          | 0/142 [00:00<?, ?it/s]


🔍 Running Validation for epoch 14


Validation Loop:   0%|          | 0/29 [00:00<?, ?it/s]

  Validation Results:
    - Avg Alignment Loss: 2.6569
    - Avg Semantic Loss:  0.9024
    - Avg CE Loss:        4.1764
    - Total Weighted Loss: 16.9911
🏆 New best model saved to /content/drive/MyDrive/NeuroVision/EEG_BLIP2_Alignment/run_20251002_171100/best_model_epoch_14_loss_16.9911.pth

--- ✍️  Caption Generation Samples ---
Sample 1:
  - Ground Truth: a desk with a computer and a monitor
  - Predicted:    a chair, a table and some books are seen in the background of
Sample 2:
  - Ground Truth: a baby elephant is standing next to an adult elephant
  - Predicted:    a man and a woman sitting on the ground looking at someone in black
Sample 3:
  - Ground Truth: a zebra is eating some grass
  - Predicted:    lgbt-related photo on the side of a building in downtown
Sample 4:
  - Ground Truth: a train traveling down the tracks with a bridge over it
  - Predicted:    a field of wheat with a cat in the middle.
Sample 5:
  - Ground Truth: a person riding skis down a snowy slope
  - Pred

🔄 Epoch 15/30:   0%|          | 0/142 [00:00<?, ?it/s]


🔍 Running Validation for epoch 15


Validation Loop:   0%|          | 0/29 [00:00<?, ?it/s]

  Validation Results:
    - Avg Alignment Loss: 2.6597
    - Avg Semantic Loss:  0.9038
    - Avg CE Loss:        4.1952
    - Total Weighted Loss: 17.0530

--- ✍️  Caption Generation Samples ---
Sample 1:
  - Ground Truth: a woman playing tennis on a court
  - Predicted:    a green, red and white pattern is shown on the background of a
Sample 2:
  - Ground Truth: two people sitting on a bench overlooking a body of water
  - Predicted:    a tree is the foreground of a landscape with people and cars in background
Sample 3:
  - Ground Truth: a group of giraffes in a zoo enclosure
  - Predicted:    an open air beach, the sea is blue and white.
Sample 4:
  - Ground Truth: a small airplane on the ground
  - Predicted:    a man with a backpack sitting next to the train, riding on it
Sample 5:
  - Ground Truth: a man and a woman playing a video game
  - Predicted:    one of my favorite things about living in new york is that it
------------------------------------

🖼️  Generating reconstructi

🔄 Epoch 16/30:   0%|          | 0/142 [00:00<?, ?it/s]


🔍 Running Validation for epoch 16


Validation Loop:   0%|          | 0/29 [00:00<?, ?it/s]

  Validation Results:
    - Avg Alignment Loss: 2.6561
    - Avg Semantic Loss:  0.9076
    - Avg CE Loss:        4.3117
    - Total Weighted Loss: 17.4064

--- ✍️  Caption Generation Samples ---
Sample 1:
  - Ground Truth: two giraffes walking in a zoo enclosure
  - Predicted:    heater cover over a glass of water in the shape or pattern
Sample 2:
  - Ground Truth: a train station with people waiting for their train
  - Predicted:    a closeup of the face, nose and eyes (2)
Sample 3:
  - Ground Truth: a desk with two computer monitors and a keyboard
  - Predicted:    a house is a very common object that we see in our day to
Sample 4:
  - Ground Truth: a person riding a dirt bike on a field
  - Predicted:    he is a man, and has some interesting tattoos on his body
Sample 5:
  - Ground Truth: a large airplane flying in the sky
  - Predicted:    a woman sitting in a bath tub, and two men walking on the
------------------------------------

🖼️  Generating reconstruction grid for epoch 16

🔄 Epoch 17/30:   0%|          | 0/142 [00:00<?, ?it/s]


🔍 Running Validation for epoch 17


Validation Loop:   0%|          | 0/29 [00:00<?, ?it/s]

  Validation Results:
    - Avg Alignment Loss: 2.6518
    - Avg Semantic Loss:  0.9161
    - Avg CE Loss:        4.2370
    - Total Weighted Loss: 17.1949

--- ✍️  Caption Generation Samples ---
Sample 1:
  - Ground Truth: a large airplane on the tarmac
  - Predicted:    or the other way around
Sample 2:
  - Ground Truth: a group of cows grazing in a field
  - Predicted:    a couple of girls standing on the table, they look at each other
Sample 3:
  - Ground Truth: a group of people in a boat on a lake
  - Predicted:    snowing mountains, a black cat in the foreground is surrounded by
Sample 4:
  - Ground Truth: a wooden boardwalk
  - Predicted:    a table of food a dog and cat standing side by side in the
Sample 5:
  - Ground Truth: a pizza on a metal pan
  - Predicted:    a large awning, the window is covered with thick black paint
------------------------------------

🖼️  Generating reconstruction grid for epoch 17...
  Saved reconstruction grid to /content/drive/MyDrive/NeuroVisio

🔄 Epoch 18/30:   0%|          | 0/142 [00:00<?, ?it/s]


🔍 Running Validation for epoch 18


Validation Loop:   0%|          | 0/29 [00:00<?, ?it/s]

  Validation Results:
    - Avg Alignment Loss: 2.6478
    - Avg Semantic Loss:  0.9137
    - Avg CE Loss:        4.1373
    - Total Weighted Loss: 16.8871
🏆 New best model saved to /content/drive/MyDrive/NeuroVision/EEG_BLIP2_Alignment/run_20251002_171100/best_model_epoch_18_loss_16.8871.pth

--- ✍️  Caption Generation Samples ---
Sample 1:
  - Ground Truth: a tall building with a clock on it
  - Predicted:    snowing in the background
Sample 2:
  - Ground Truth: a group of people skiing down a snowy slope
  - Predicted:    a large picture of the two men on a bench in front
Sample 3:
  - Ground Truth: a living room with a couch, a table, and a window
  - Predicted:    it's pretty obvious which one is your child...
Sample 4:
  - Ground Truth: a bird flying over the ocean and beach
  - Predicted:    a group of people playing a game, and there is one sitting on
Sample 5:
  - Ground Truth: a dog sitting in a car
  - Predicted:    a lot of people with a smile on their face. The sun was
---

🔄 Epoch 19/30:   0%|          | 0/142 [00:00<?, ?it/s]


🔍 Running Validation for epoch 19


Validation Loop:   0%|          | 0/29 [00:00<?, ?it/s]

  Validation Results:
    - Avg Alignment Loss: 2.6553
    - Avg Semantic Loss:  0.9144
    - Avg CE Loss:        4.1886
    - Total Weighted Loss: 17.0500

--- ✍️  Caption Generation Samples ---
Sample 1:
  - Ground Truth: a zebra standing in the dirt
  - Predicted:    a view of the island in a blue sky is on display at an
Sample 2:
  - Ground Truth: a man in a wheelchair is being helped by a man on a skateboard
  - Predicted:    a parkour, a longboard and snow boarder on the mountain
Sample 3:
  - Ground Truth: a herd of zebras walking across a field
  - Predicted:    here is my collection of photos from the amazing event i attended today at
Sample 4:
  - Ground Truth: a group of planes flying in the sky with smoke coming out of them
  - Predicted:    a woman on the ground, a man with hat and two dogs behind
Sample 5:
  - Ground Truth: a young girl laying in bed with a blanket
  - Predicted:    a table on which a lot of things are stacked and piled up.
-------------------------------

🔄 Epoch 20/30:   0%|          | 0/142 [00:00<?, ?it/s]


🔍 Running Validation for epoch 20


Validation Loop:   0%|          | 0/29 [00:00<?, ?it/s]

  Validation Results:
    - Avg Alignment Loss: 2.6504
    - Avg Semantic Loss:  0.9203
    - Avg CE Loss:        4.1806
    - Total Weighted Loss: 17.0328

--- ✍️  Caption Generation Samples ---
Sample 1:
  - Ground Truth: a man riding a wave on a surfboard
  - Predicted:    anaconda in a bath of water, the snake is showing its
Sample 2:
  - Ground Truth: two elephants fighting in the grass
  - Predicted:    a man is eating a sandwich in the center of this picture and on
Sample 3:
  - Ground Truth: a street with a lot of lights and a crosswalk
  - Predicted:    our world of sports is the best.
Sample 4:
  - Ground Truth: a group of people playing frisbee in a field
  - Predicted:    rainbow beach on a rainy day.
Sample 5:
  - Ground Truth: a dog standing on a boat in the water
  - Predicted:    it's mexican food
------------------------------------

🖼️  Generating reconstruction grid for epoch 20...
  Saved reconstruction grid to /content/drive/MyDrive/NeuroVision/EEG_BLIP2_Alignment/

🔄 Epoch 21/30:   0%|          | 0/142 [00:00<?, ?it/s]


🔍 Running Validation for epoch 21


Validation Loop:   0%|          | 0/29 [00:00<?, ?it/s]

  Validation Results:
    - Avg Alignment Loss: 2.6466
    - Avg Semantic Loss:  0.9154
    - Avg CE Loss:        4.1934
    - Total Weighted Loss: 17.0577

--- ✍️  Caption Generation Samples ---
Sample 1:
  - Ground Truth: a snowboarder is doing a trick on a ramp
  - Predicted:    a big tree on a sunny day in the forest of northern calif
Sample 2:
  - Ground Truth: a group of boats on a beach
  - Predicted:    a kitchen a bathroom
Sample 3:
  - Ground Truth: a bench sitting on a dirt path in a park
  - Predicted:    snowboarder is taking his first ever slopestyle competition and
Sample 4:
  - Ground Truth: a man riding a wave on a surfboard
  - Predicted:    a snowboarder is a skier or roller board
Sample 5:
  - Ground Truth: a snowboarder in the air
  - Predicted:    a tree a mountain and the sea, all in one frame.
------------------------------------

🖼️  Generating reconstruction grid for epoch 21...
  Saved reconstruction grid to /content/drive/MyDrive/NeuroVision/EEG_BLIP2_Alignm

🔄 Epoch 22/30:   0%|          | 0/142 [00:00<?, ?it/s]


🔍 Running Validation for epoch 22


Validation Loop:   0%|          | 0/29 [00:00<?, ?it/s]

  Validation Results:
    - Avg Alignment Loss: 2.6481
    - Avg Semantic Loss:  0.9183
    - Avg CE Loss:        4.1480
    - Total Weighted Loss: 16.9286

--- ✍️  Caption Generation Samples ---
Sample 1:
  - Ground Truth: a man and a girl standing at a table with a sign
  - Predicted:    on the left, a man wearing blue jeans and grey shirt sitting on
Sample 2:
  - Ground Truth: two yellow trains are parked at a train station
  - Predicted:    a small boy is sitting on the bottom right corner, with a woman
Sample 3:
  - Ground Truth: a man on a skateboard doing a trick on a rail
  - Predicted:    a tree is a very popular image on the Internet. It shows trees
Sample 4:
  - Ground Truth: a bowl of mixed vegetables
  - Predicted:    a group of people playing a game on the beach in calabas
Sample 5:
  - Ground Truth: an elephant with tusks walking through a forest
  - Predicted:    a lot of people were walking past and they weren't even looking at
------------------------------------

🖼️ 

🔄 Epoch 23/30:   0%|          | 0/142 [00:00<?, ?it/s]


🔍 Running Validation for epoch 23


Validation Loop:   0%|          | 0/29 [00:00<?, ?it/s]

  Validation Results:
    - Avg Alignment Loss: 2.6507
    - Avg Semantic Loss:  0.9042
    - Avg CE Loss:        4.2442
    - Total Weighted Loss: 17.1917

--- ✍️  Caption Generation Samples ---
Sample 1:
  - Ground Truth: a group of people sitting at a table
  - Predicted:    an egg on a tray of meat with food color in the background.
Sample 2:
  - Ground Truth: a bird walking on the beach
  - Predicted:    a small green plant on a red background
Sample 3:
  - Ground Truth: a houseboat docked on a dock next to a forest
  - Predicted:    a green and yellow wall, a red sofa in the living room ,
Sample 4:
  - Ground Truth: a living room with a couch, a table, and a tv
  - Predicted:    a woman is standing on a stool holding her hand in front of the
Sample 5:
  - Ground Truth: a couple of people sitting at a table outside of a building
  - Predicted:    the beach is very famous and popular. Its name comes from the Greek
------------------------------------

🖼️  Generating reconstruction 

🔄 Epoch 24/30:   0%|          | 0/142 [00:00<?, ?it/s]


🔍 Running Validation for epoch 24


Validation Loop:   0%|          | 0/29 [00:00<?, ?it/s]

  Validation Results:
    - Avg Alignment Loss: 2.6478
    - Avg Semantic Loss:  0.9184
    - Avg CE Loss:        4.0629
    - Total Weighted Loss: 16.6732
🏆 New best model saved to /content/drive/MyDrive/NeuroVision/EEG_BLIP2_Alignment/run_20251002_171100/best_model_epoch_24_loss_16.6732.pth

--- ✍️  Caption Generation Samples ---
Sample 1:
  - Ground Truth: a clock hanging from a wire in a city
  - Predicted:    it seems like it's not even on the shelf.
Sample 2:
  - Ground Truth: a man riding a wave on a surfboard
  - Predicted:    a lot of people in a line at an event.
Sample 3:
  - Ground Truth: two sheep standing on a road
  - Predicted:    in the sky above water on a beach
Sample 4:
  - Ground Truth: a stop sign with a 4-way sign on it
  - Predicted:    a green carpet with a yellow wall behind it.
Sample 5:
  - Ground Truth: a zebra standing in front of a pile of wood
  - Predicted:    a beach full of people.
------------------------------------

🖼️  Generating reconstruction gr

🔄 Epoch 25/30:   0%|          | 0/142 [00:00<?, ?it/s]


🔍 Running Validation for epoch 25


Validation Loop:   0%|          | 0/29 [00:00<?, ?it/s]

  Validation Results:
    - Avg Alignment Loss: 2.6497
    - Avg Semantic Loss:  0.9208
    - Avg CE Loss:        4.0917
    - Total Weighted Loss: 16.7664

--- ✍️  Caption Generation Samples ---
Sample 1:
  - Ground Truth: a couple of people sitting at a table outside of a building
  - Predicted:    as it is a little dark to see what we are looking at
Sample 2:
  - Ground Truth: a young boy eating a donut
  - Predicted:    the thing is, you could have said "I went to the store
Sample 3:
  - Ground Truth: three children sitting on a couch
  - Predicted:    here is the best view of these two pictures:
Sample 4:
  - Ground Truth: a bird is perched on a post
  - Predicted:    its a dog, and not really that funny
Sample 5:
  - Ground Truth: a bear walking down a road with mountains in the background
  - Predicted:    it looks like it's on the edge of a cliff
------------------------------------

🖼️  Generating reconstruction grid for epoch 25...
  Saved reconstruction grid to /content/driv

🔄 Epoch 26/30:   0%|          | 0/142 [00:00<?, ?it/s]


🔍 Running Validation for epoch 26


Validation Loop:   0%|          | 0/29 [00:00<?, ?it/s]

  Validation Results:
    - Avg Alignment Loss: 2.6471
    - Avg Semantic Loss:  0.9199
    - Avg CE Loss:        4.1065
    - Total Weighted Loss: 16.8066

--- ✍️  Caption Generation Samples ---
Sample 1:
  - Ground Truth: a bear is standing on the grass
  - Predicted:    a black and white photo of a person is next to the red light
Sample 2:
  - Ground Truth: a traffic light hanging from a power line
  - Predicted:    a woman in a white shirt and black pants, the back of her
Sample 3:
  - Ground Truth: a pizza on a metal pan
  - Predicted:    a large green sign on the right side of a photo.
Sample 4:
  - Ground Truth: a group of birds sitting on a tree branch
  - Predicted:    a man sitting on a train car with luggage and an empty bag of
Sample 5:
  - Ground Truth: a box of donuts with two different flavors
  - Predicted:    a chair with a picture of the author above it
------------------------------------

🖼️  Generating reconstruction grid for epoch 26...
  Saved reconstruction grid

🔄 Epoch 27/30:   0%|          | 0/142 [00:00<?, ?it/s]


🔍 Running Validation for epoch 27


Validation Loop:   0%|          | 0/29 [00:00<?, ?it/s]

  Validation Results:
    - Avg Alignment Loss: 2.6510
    - Avg Semantic Loss:  0.9158
    - Avg CE Loss:        4.1018
    - Total Weighted Loss: 16.7881

--- ✍️  Caption Generation Samples ---
Sample 1:
  - Ground Truth: a herd of elephants in a field
  - Predicted:    , , .- -,-.-.--. ----.'---
Sample 2:
  - Ground Truth: a cat laying on a bed
  - Predicted:    it's a bad day to be an anime fan, eh?
Sample 3:
  - Ground Truth: a woman playing tennis on a court
  - Predicted:    a black cat sitting on a tree
Sample 4:
  - Ground Truth: a man holding a baby and a laptop
  - Predicted:    a car is a nice sight. I also like the blue sky and
Sample 5:
  - Ground Truth: a person riding skis down a snow covered slope
  - Predicted:    on a mountain of snow in the background is an image that looks like
------------------------------------

🖼️  Generating reconstruction grid for epoch 27...
  Saved reconstruction grid to /content/drive/MyDrive/NeuroVision/EEG_BLIP2_Alignment/run_20251002_17

🔄 Epoch 28/30:   0%|          | 0/142 [00:00<?, ?it/s]


🔍 Running Validation for epoch 28


Validation Loop:   0%|          | 0/29 [00:00<?, ?it/s]

  Validation Results:
    - Avg Alignment Loss: 2.6484
    - Avg Semantic Loss:  0.9155
    - Avg CE Loss:        3.9811
    - Total Weighted Loss: 16.4225
🏆 New best model saved to /content/drive/MyDrive/NeuroVision/EEG_BLIP2_Alignment/run_20251002_171100/best_model_epoch_28_loss_16.4225.pth

--- ✍️  Caption Generation Samples ---
Sample 1:
  - Ground Truth: two women playing frisbee on a field
  - Predicted:    he just got in the car and drove it on his way to a
Sample 2:
  - Ground Truth: a desk with a laptop and a desktop computer
  - Predicted:    a large tree in the background
Sample 3:
  - Ground Truth: a train is on the tracks with a few cars
  - Predicted:    a close up of a woman and girl riding bicycles in front
Sample 4:
  - Ground Truth: a train is coming down the tracks
  - Predicted:    a large screen is covered with a black cloth.
Sample 5:
  - Ground Truth: a woman holding a tennis racket on a tennis court
  - Predicted:    a woman at the computer with a cat on her lap

🔄 Epoch 29/30:   0%|          | 0/142 [00:00<?, ?it/s]


🔍 Running Validation for epoch 29


Validation Loop:   0%|          | 0/29 [00:00<?, ?it/s]

  Validation Results:
    - Avg Alignment Loss: 2.6434
    - Avg Semantic Loss:  0.9199
    - Avg CE Loss:        4.0519
    - Total Weighted Loss: 16.6389

--- ✍️  Caption Generation Samples ---
Sample 1:
  - Ground Truth: a vase with flowers in it
  - Predicted:    on a dark background, two women are walking in the park. One
Sample 2:
  - Ground Truth: a man riding a wave on a surfboard
  - Predicted:    a table covered with plates, a plate in the middle of it and
Sample 3:
  - Ground Truth: man holding a baseball bat
  - Predicted:    it's not the same girl.  it is a different blonde,
Sample 4:
  - Ground Truth: a man riding a wave on a surfboard
  - Predicted:    a tree with a man dressed as the pope and other people on it
Sample 5:
  - Ground Truth: a plate of food with meat and broccoli
  - Predicted:    A car is standing on a street in the background. The text says
------------------------------------

🖼️  Generating reconstruction grid for epoch 29...
  Saved reconstruction gri

🔄 Epoch 30/30:   0%|          | 0/142 [00:00<?, ?it/s]


🔍 Running Validation for epoch 30


Validation Loop:   0%|          | 0/29 [00:00<?, ?it/s]

  Validation Results:
    - Avg Alignment Loss: 2.6430
    - Avg Semantic Loss:  0.9190
    - Avg CE Loss:        4.1190
    - Total Weighted Loss: 16.8381

--- ✍️  Caption Generation Samples ---
Sample 1:
  - Ground Truth: a living room with a table and chairs
  - Predicted:    a forest of the trees
Sample 2:
  - Ground Truth: a couple of people sitting at a table outside of a building
  - Predicted:    a long time ago a man and his dog got into an argument.
Sample 3:
  - Ground Truth: a bed with a blanket and a lamp in a room
  - Predicted:    a man playing a guitar, sitting in front of an abstract image on
Sample 4:
  - Ground Truth: two people laying on snowboards
  - Predicted:    on your knees with a towel over them.
Sample 5:
  - Ground Truth: two men playing tennis on a court
  - Predicted:    a table with a chair, some plates and glasses is in the foreground
------------------------------------

🖼️  Generating reconstruction grid for epoch 30...
  Saved reconstruction grid to 

# Correct Scripts for documentation

In [None]:
import torch
import torch.nn as nn
import pandas as pd
from pathlib import Path
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
from torchvision.transforms import v2 as transforms
from PIL import Image
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm.auto import tqdm
from datetime import datetime
import numpy as np
from transformers import Blip2Processor, Blip2ForConditionalGeneration, Blip2VisionModel
import warnings
from sentence_transformers import SentenceTransformer
from diffusers import StableDiffusionPipeline
import torch.nn.functional as F
from PIL import Image, ImageDraw, ImageFont
import json
from transformers import get_linear_schedule_with_warmup

# ==============================================================================
# --- 1. CONFIGURATION ---
# ==============================================================================
class TRAIN_CONFIG:
    PROCESSED_DATA_ROOT = '/content/final_lightweight_17k'
    METADATA_CSV = Path(PROCESSED_DATA_ROOT) / 'metadata.csv'
    OUTPUT_DIR = '/content/drive/MyDrive/NeuroVision/EEG_BLIP2_Alignment'

    BLIP_MODEL_NAME = "Salesforce/blip2-opt-2.7b"

    BATCH_SIZE = 8
    GRAD_ACCUMULATION_STEPS = 4
    NUM_EPOCHS = 50
    ADAPTER_LR = 1e-5  # Higher LR for the new adapter
    BLIP_LR = 1e-6     # Lower LR for fine-tuning the pre-trained layers

    WEIGHT_DECAY = 1e-3
    GRADIENT_CLIP_NORM = 2.0

    VALIDATION_INTERVAL = 1
    VAL_SAMPLES_PER_EPOCH = 400
    TRAIN_SAMPLES_PER_EPOCH = 2000
    VIS_GRID_SIZE = 5

    DIFFUSION_MODEL_ID ="runwayml/stable-diffusion-v1-5"
    EVAL_IMAGE_GENERATION_INTERVAL = 1
    WARMUP_STEPS = 400  # Number of steps to gradually increase the LR
    TOTAL_TRAIN_STEPS = 250 * NUM_EPOCHS
    STAGE1_EPOCHS = 15
    ALIGN_WEIGHT = 1.0
    SEMANTIC_WEIGHT = 2.5
    BEST_MODEL_DIR = '/content/drive/MyDrive/NeuroVision/EEG_BLIP2_Alignment/run_20251001_103040/best_model_epoch_27_loss_5.0528.pth'


# ==============================================================================
# --- 2. EEG ENCODER MODULE ---
# ==============================================================================

class DeeperEEGNetAdapter(nn.Module):
    """A deeper, more powerful EEGNet-inspired adapter."""
    def __init__(self, in_chans=64, out_chans=3, F1=16, D=2, F2=32, F3=64, dropout=0.5):
        super().__init__()
        # Block 1 (same as before)
        self.block1 = nn.Sequential(
            nn.Conv2d(1, F1, (1, 64), padding='same', bias=False),
            nn.BatchNorm2d(F1),
            nn.Conv2d(F1, F1 * D, (in_chans, 1), groups=F1, bias=False),
            nn.BatchNorm2d(F1 * D),
            nn.ELU(),
            nn.AvgPool2d((1, 4)),
            nn.Dropout(dropout)
        )
        # Block 2 (same as before)
        self.block2 = nn.Sequential(
            nn.Conv2d(F1 * D, F2, (1, 16), padding='same', groups=F1 * D, bias=False),
            nn.Conv2d(F2, F2, (1, 1), bias=False),
            nn.BatchNorm2d(F2),
            nn.ELU(),
            nn.AvgPool2d((1, 8)),
            nn.Dropout(dropout)
        )
        # --- NEW: Third block for more capacity ---
        self.block3 = nn.Sequential(
            nn.Conv2d(F2, F3, (1, 8), padding='same', groups=F2, bias=False),
            nn.Conv2d(F3, F3, (1, 1), bias=False),
            nn.BatchNorm2d(F3),
            nn.ELU(),
            nn.AvgPool2d((1, 2)),
            nn.Dropout(dropout)
        )
        self.final_projection = nn.Conv2d(F3, out_chans, (1, 1))
        self.adaptive_pool = nn.AdaptiveAvgPool2d((224, 224))

    def forward(self, x):
        x = x.view(x.size(0), 1, x.size(1), -1)
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x) # Pass through the new block
        x = self.final_projection(x)
        x = self.adaptive_pool(x)
        return x

# ==============================================================================
# --- 3. THE HYBRID EEG-BLIP2 MODEL ---
# ==============================================================================
class EEG_BLIP2_Model(nn.Module):
    def __init__(self, config, device):
        super().__init__()
        self.device = device

        # --- REFACTOR: Instantiate the adapter directly ---
        self.adapter = DeeperEEGNetAdapter(in_chans=64)

        self.blip = Blip2ForConditionalGeneration.from_pretrained(config.BLIP_MODEL_NAME)

        # Freeze the entire BLIP model initially
        for param in self.blip.parameters():
            param.requires_grad = False

        self.adapter.to(device)
        self.blip.to(device)

    def get_eeg_embedding(self, eeg_spectrograms):
        # --- REFACTOR: Forward pass is now cleaner ---
        pseudo_image = self.adapter(eeg_spectrograms)
        vision_outputs = self.blip.vision_model(pseudo_image, return_dict=True)
        eeg_features = vision_outputs.last_hidden_state

        query_tokens = self.blip.query_tokens.expand(eeg_features.shape[0], -1, -1)
        query_outputs = self.blip.qformer(
            query_embeds=query_tokens,
            encoder_hidden_states=eeg_features,
            return_dict=True,
        )
        return query_outputs.last_hidden_state.mean(dim=1)

    def get_image_embedding(self, pil_images, processor):
        inputs = processor(images=pil_images, return_tensors="pt").to(self.device)
        pixel_values = inputs.pixel_values

        image_features = self.blip.vision_model(pixel_values).last_hidden_state
        query_tokens = self.blip.query_tokens.expand(image_features.shape[0], -1, -1)
        query_outputs = self.blip.qformer(
            query_embeds=query_tokens,
            encoder_hidden_states=image_features,
            return_dict=True,
        )
        return query_outputs.last_hidden_state.mean(dim=1)

    def generate(self, eeg_spectrograms, **kwargs):
        with torch.no_grad():
            # --- REFACTOR: Generation pass is also cleaner ---
            pseudo_image = self.adapter(eeg_spectrograms)
            vision_outputs = self.blip.vision_model(pseudo_image, return_dict=True)
            eeg_features = vision_outputs.last_hidden_state

            query_tokens = self.blip.query_tokens.expand(eeg_features.shape[0], -1, -1)
            query_outputs = self.blip.qformer(
                query_embeds=query_tokens,
                encoder_hidden_states=eeg_features,
                return_dict=True,
            )
            language_model_inputs = self.blip.language_projection(query_outputs.last_hidden_state)
            generated_ids = self.blip.language_model.generate(
                inputs_embeds=language_model_inputs,
                **kwargs
            )
            return generated_ids

# ==============================================================================
# --- 4. DATASET & COLLATING ---
# ==============================================================================
def collate_fn(batch):
    spectrograms = torch.stack([item[0] for item in batch])
    pil_images = [item[1] for item in batch]
    captions = [item[2] for item in batch]
    return spectrograms, pil_images, captions

class EEGDatasetWithCaptions(Dataset):
    def __init__(self, root_dir, metadata_csv, split, transform, augment=False):
        self.root_dir = Path(root_dir)
        self.transform = transform
        df = pd.read_csv(metadata_csv)
        self.split_df = df[df['split'].str.strip() == split].reset_index(drop=True)
        self.augment = augment

        # --- NEW: Define a dedicated augmentation pipeline ---
        if self.augment:
            self.augmentation_transform = transforms.Compose([
                transforms.RandomApply([transforms.GaussianBlur(kernel_size=3)], p=0.5),
                # Randomly mask out frequency bands
                transforms.RandomApply([transforms.RandomErasing(p=1.0, scale=(0.02, 0.1), ratio=(0.1, 0.5))], p=0.5),
                # Randomly mask out time steps
                transforms.RandomApply([transforms.RandomErasing(p=1.0, scale=(0.02, 0.1), ratio=(2.0, 5.0))], p=0.5),
            ])

        # This is your existing normalization transform
        self.normalization_transform = transform

        # Load pre-generated captions
        captions_path = Path(root_dir) / f'{split}_gt_captions.json'
        print(f"Loading pre-generated captions from {captions_path}")
        with open(captions_path, 'r') as f:
            self.captions = json.load(f)

    def __len__(self):
        return len(self.split_df)

    def __getitem__(self, idx):
        info = self.split_df.iloc[idx]

        spectrogram = torch.load(self.root_dir / info['spectrogram_path'])

        # --- APPLY AUGMENTATIONS (for training set only) ---
        if self.augment:
            spectrogram = self.augmentation_transform(spectrogram)

        # Apply normalization after augmentation
        spectrogram = self.normalization_transform(spectrogram)

        image = Image.open(self.root_dir / info['image_path']).convert("RGB")
        gt_caption = self.captions.get(str(idx), "an image")

        return spectrogram, image, gt_caption

# ==============================================================================
# --- 5. MAIN TRAINING FUNCTION ---
# ==============================================================================
# InfoNCE loss function

def info_nce_loss(query, positive_key, temperature=0.07):
    # Ensure inputs are normalized
    query = F.normalize(query, dim=-1)
    positive_key = F.normalize(positive_key, dim=-1)

    # Calculate the similarity matrix of every query with every key
    # The diagonal of this matrix contains the positive pairs
    logits = query @ positive_key.T

    # The labels are the indices of the positive pairs (the diagonal)
    labels = torch.arange(len(query), device=query.device)

    # Calculate the cross-entropy loss
    return F.cross_entropy(logits / temperature, labels)


class SemanticLoss(nn.Module):
    """
    A loss function that measures the semantic similarity between two sets of captions
    using a pre-trained SentenceTransformer model.
    """
    def __init__(self, device):
        super().__init__()
        self.device = device
        # Load a pre-trained model optimized for semantic similarity
        self.model = SentenceTransformer('all-MiniLM-L6-v2').to(device)
        # Freeze the model's weights as we only use it for inference
        for param in self.model.parameters():
            param.requires_grad = False

    def forward(self, predicted_captions, ground_truth_captions):
        # Convert the text captions into semantic embedding vectors
        pred_embeddings = self.model.encode(predicted_captions, convert_to_tensor=True)
        gt_embeddings = self.model.encode(ground_truth_captions, convert_to_tensor=True)

        # Calculate the cosine similarity. The loss is 1.0 minus the similarity.
        # A higher similarity (closer to 1.0) results in a lower loss (closer to 0.0).
        cosine_sim = F.cosine_similarity(pred_embeddings, gt_embeddings, dim=-1)
        loss = 1.0 - cosine_sim.mean()
        return loss

#------------------------------------------------------------------------------------------

def create_and_save_reconstruction_grid(eval_samples, diffusion_pipe, epoch, val_loss, output_dir):
    """
    Generates a 2x2 image grid comparing original and reconstructed images.
    Top row: Original images with ground truth captions.
    Bottom row: Reconstructed images with predicted captions.
    """
    if len(eval_samples) < 2:
        print("Not enough samples for a 2x2 grid, skipping.")
        return

    print(f"\n🖼️  Generating reconstruction grid for epoch {epoch}...")

    # Setup canvas and font
    w, h = 512, 512  # Standard Stable Diffusion size
    title_h = 60    # Space for titles
    grid = Image.new('RGB', (w * 2, h * 2 + title_h * 2), 'black')
    draw = ImageDraw.Draw(grid)
    try:
        font = ImageFont.truetype("LiberationSans-Regular.ttf", 20)
    except IOError:
        print("Default font not found, using fallback.")
        font = ImageFont.load_default()

    # Process two samples
    for i in range(5):
        predicted_cap, gt_cap, original_image = eval_samples[i]

        # --- Generate the reconstructed image from the predicted caption ---
        generator = torch.Generator(device=diffusion_pipe.device).manual_seed(42 + i)
        reconstructed_image = diffusion_pipe(
            prompt=predicted_cap, generator=generator, num_inference_steps=20
        ).images[0]

        # --- Paste images and draw titles ---
        col_offset = i * w

        # Top Row (Original)
        draw.text((col_offset + 5, 5), f"GT: {gt_cap[:50]}...", font=font, fill="white")
        grid.paste(original_image.resize((w, h)), (col_offset, title_h))

        # Bottom Row (Reconstructed)
        draw.text((col_offset + 5, h + title_h + 5), f"Pred: {predicted_cap[:50]}...", font=font, fill="white")
        grid.paste(reconstructed_image.resize((w, h)), (col_offset, h + title_h * 2))

    # --- Save the final grid ---
    # Sanitize loss value for filename
    val_loss_str = f"{val_loss:.4f}".replace('.', '_')
    save_path = Path(output_dir) / 'reconstructions'
    save_path.mkdir(exist_ok=True)
    filename = save_path / f"reconstruction_epoch_{epoch}_loss_{val_loss_str}.png"
    grid.save(filename)
    print(f"  Saved reconstruction grid to {filename}")

#----------------------------------------------------------------------------------------------
def train_model(config):
    # --- Setup (Device, Output Dir, etc.) ---
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    output_dir = Path(config.OUTPUT_DIR) / f'run_{timestamp}'
    output_dir.mkdir(parents=True, exist_ok=True)

    # --- Models and Data ---
    processor = Blip2Processor.from_pretrained(config.BLIP_MODEL_NAME)
    model = EEG_BLIP2_Model(config, device) # Assumes the refactored, cleaner version

    try:
      print('Loading best model weights')
      model.load_state_dict(torch.load(config.BEST_MODEL_DIR, map_location=device))
      print('Successfully loaded best model weights!')
    except:
      print('Could not load the best model weights.')

    print("Loading Stable Diffusion pipeline for visualizations...")
    diffusion_pipe = StableDiffusionPipeline.from_pretrained(
        config.DIFFUSION_MODEL_ID,
        torch_dtype=torch.float16 if device.type == "cuda" else torch.float32,
        safety_checker=None, requires_safety_checker=False
    ).to(device)
    diffusion_pipe.set_progress_bar_config(disable=True)

    semantic_loss_fn = SemanticLoss(device)

    # --- Data Loading ---
    spec_mean = torch.load(Path(config.PROCESSED_DATA_ROOT) / 'spec_mean.pt')
    spec_std = torch.load(Path(config.PROCESSED_DATA_ROOT) / 'spec_std.pt')
    spec_std += 1e-6

    transform = transforms.Compose([
        transforms.ToDtype(torch.float32),
        transforms.Resize((224, 224), antialias=True),
        transforms.Normalize(mean=spec_mean.tolist(), std=spec_std.tolist()),
    ])

    train_dataset = EEGDatasetWithCaptions(config.PROCESSED_DATA_ROOT, config.METADATA_CSV, 'train', transform, augment=True)
    val_dataset = EEGDatasetWithCaptions(config.PROCESSED_DATA_ROOT, config.METADATA_CSV, 'val', transform, augment=False)

    # --- Training Components ---
    scaler = torch.cuda.amp.GradScaler()
    best_val_loss = float('inf')

    generation_args = {
        "max_new_tokens": 30, "do_sample": True, "min_new_tokens": 8,
        "repetition_penalty": 1.25, "top_p": 0.9, "temperature": 0.81
    }

    # --- Main Training Loop ---
    print("🎯 Starting Decoupled 2-Stage Training Loop")
    for epoch in range(15, config.NUM_EPOCHS):
        # --- Stage 1 & 2 Freezing Logic ---
        if epoch < config.STAGE1_EPOCHS:
            if epoch == 0:
                print(f"\n--- STAGE 1: Training Adapter ONLY (Epochs 0-{config.STAGE1_EPOCHS-1}) ---")
            for param in model.adapter.parameters():
                param.requires_grad = True
            for param in model.blip.parameters():
                param.requires_grad = False
        elif epoch == config.STAGE1_EPOCHS:
            print(f"\n--- STAGE 2: Fine-Tuning Q-Former ONLY (Epochs {config.STAGE1_EPOCHS}+) ---")
            for param in model.adapter.parameters():
                param.requires_grad = False
            for param in model.blip.qformer.parameters():
                param.requires_grad = True
            for param in model.blip.vision_model.parameters():
                param.requires_grad = True
            if hasattr(model.blip, "language_projection"):
                for param in model.blip.language_projection.parameters():
                    param.requires_grad = True

        # --- Create/Re-create Optimizer & Scheduler at stage transitions ---
        if epoch == 0 or epoch == config.STAGE1_EPOCHS:
            print(f"Creating optimizer for Stage {1 if epoch < config.STAGE1_EPOCHS else 2}...")
            if epoch < config.STAGE1_EPOCHS:
                # Stage 1: Only adapter is trained
                optimizer = optim.AdamW(model.adapter.parameters(), lr=config.ADAPTER_LR, weight_decay=config.WEIGHT_DECAY)
            else:
                # Stage 2: Only BLIP layers are trained
                trainable_params = filter(lambda p: p.requires_grad, model.parameters())
                optimizer = optim.AdamW(trainable_params, lr=config.BLIP_LR, weight_decay=config.WEIGHT_DECAY)

            num_training_steps = (len(train_dataset) // config.BATCH_SIZE) * config.NUM_EPOCHS
            scheduler = get_linear_schedule_with_warmup(
                optimizer, num_warmup_steps=config.WARMUP_STEPS, num_training_steps=num_training_steps
            )
            print("Optimizer and scheduler created.")

        # --- Training Batch Loop ---
        model.train()
        train_loader = DataLoader(train_dataset, batch_size=config.BATCH_SIZE,
            sampler=SubsetRandomSampler(np.random.choice(len(train_dataset), min(len(train_dataset), config.TRAIN_SAMPLES_PER_EPOCH), replace=False)),
            drop_last=True, collate_fn=collate_fn)
        train_bar = tqdm(train_loader, desc=f"🔄 Epoch {epoch+1}/{config.NUM_EPOCHS}")

        for batch_idx, (spectrograms, pil_images, gt_captions) in enumerate(train_bar):
            spectrograms = spectrograms.to(device)
            with torch.amp.autocast('cuda'):
                eeg_embeddings = model.get_eeg_embedding(spectrograms)
                with torch.no_grad():
                    generated_ids = model.generate(spectrograms, **generation_args)
                    eeg_captions = processor.batch_decode(generated_ids, skip_special_tokens=True)
                    image_embeddings = model.get_image_embedding(pil_images, processor)

                loss_align = info_nce_loss(eeg_embeddings, image_embeddings)
                loss_semantic = semantic_loss_fn(eeg_captions, gt_captions)

                # Use consistent loss weights defined in config
                loss = (config.ALIGN_WEIGHT * loss_align) + (config.SEMANTIC_WEIGHT * loss_semantic)
                loss = loss / config.GRAD_ACCUMULATION_STEPS

            scaler.scale(loss).backward()

            if (batch_idx + 1) % config.GRAD_ACCUMULATION_STEPS == 0:
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(filter(lambda p: p.requires_grad, model.parameters()), config.GRADIENT_CLIP_NORM)
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()
                scheduler.step()

            train_bar.set_postfix({"align_loss": f"{loss_align.item():.4f}", "sem_loss": f"{loss_semantic.item():.4f}"})

        # --- Validation Loop ---
        print(f"\n🔍 Running Validation for epoch {epoch+1}")
        model.eval()

        # Lists to store metrics for this epoch
        val_align_losses = []
        val_semantic_losses = []
        eval_samples = [] # To store samples for printing and image generation

        # Create the validation data loader
        val_indices = np.random.choice(len(val_dataset), min(len(val_dataset), config.VAL_SAMPLES_PER_EPOCH), replace=False)
        val_loader = DataLoader(val_dataset, batch_size=config.BATCH_SIZE,
            sampler=SubsetRandomSampler(val_indices), collate_fn=collate_fn)

        # Disable gradients for the validation loop
        with torch.no_grad(), torch.amp.autocast('cuda'):
            for spectrograms, pil_images, gt_captions in tqdm(val_loader, desc="Validation Loop"):
                spectrograms = spectrograms.to(device)

                # --- Perform the forward pass ---
                eeg_embeddings = model.get_eeg_embedding(spectrograms)
                image_embeddings = model.get_image_embedding(pil_images, processor)

                # Generate predicted captions from the EEG
                generated_ids = model.generate(spectrograms, **generation_args)
                eeg_captions = processor.batch_decode(generated_ids, skip_special_tokens=True)

                # --- Calculate the two primary losses ---
                loss_align = info_nce_loss(eeg_embeddings, image_embeddings)
                loss_semantic = semantic_loss_fn(eeg_captions, gt_captions)

                val_align_losses.append(loss_align.item())
                val_semantic_losses.append(loss_semantic.item())

                # --- Collect a few samples for visualization ---
                if len(eval_samples) < config.VIS_GRID_SIZE:
                    num_to_add = config.VIS_GRID_SIZE - len(eval_samples)
                    for i in range(min(num_to_add, len(eeg_captions))):
                        eval_samples.append((eeg_captions[i], gt_captions[i], pil_images[i]))

        avg_val_align_loss = np.mean(val_align_losses)
        avg_val_semantic_loss = np.mean(val_semantic_losses)

        # Use consistent loss weights
        total_val_loss = (config.ALIGN_WEIGHT * avg_val_align_loss) + (config.SEMANTIC_WEIGHT * avg_val_semantic_loss)

        print(f"  Validation Results:")
        print(f"    - Avg Alignment Loss: {avg_val_align_loss:.4f}")
        print(f"    - Avg Semantic Loss:  {avg_val_semantic_loss:.4f}")
        print(f"    - Total Weighted Loss: {total_val_loss:.4f}")

        # Update the learning rate scheduler based on the total validation loss
        # scheduler.step(total_val_loss)

        # --- Save the best model based on the total validation loss ---
        if total_val_loss < best_val_loss:
            best_val_loss = total_val_loss
            save_path = output_dir / f'best_model_epoch_{epoch+1}_loss_{total_val_loss:.4f}.pth'
            torch.save(model.state_dict(), save_path)
            print(f"🏆 New best model saved to {save_path}")

        # --- Print caption comparisons ---
        if eval_samples:
            print("\n--- ✍️  Caption Generation Samples ---")
            for i, (predicted_cap, gt_cap, _) in enumerate(eval_samples):
                print(f"Sample {i+1}:")
                print(f"  - Ground Truth: {gt_cap.strip()}")
                print(f"  - Predicted:    {predicted_cap.strip()}")
            print("------------------------------------")

        if (epoch + 1) % config.EVAL_IMAGE_GENERATION_INTERVAL == 0 and eval_samples:
            create_and_save_reconstruction_grid(
                eval_samples,
                diffusion_pipe,
                epoch + 1,
                total_val_loss,
                output_dir
            )

    print("\n🎉 Training Complete!")
    return output_dir

if __name__ == '__main__':
    config = TRAIN_CONFIG()
    try:
        output_dir = train_model(config)
        if output_dir:
            print(f"✅ Training completed successfully! Results saved to: {output_dir}")
    except Exception as e:
        print(f"❌ Training failed with error: {str(e)}")
        import traceback
        traceback.print_exc()

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.
The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading best model weights
Successfully loaded best model weights!
Loading Stable Diffusion pipeline for visualizations...


Loading pipeline components...:   0%|          | 0/6 [00:00<?, ?it/s]

`torch_dtype` is deprecated! Use `dtype` instead!


Loading pre-generated captions from /content/final_lightweight_17k/train_gt_captions.json
Loading pre-generated captions from /content/final_lightweight_17k/val_gt_captions.json
🎯 Starting Decoupled 2-Stage Training Loop

--- STAGE 2: Fine-Tuning Q-Former ONLY (Epochs 15+) ---
Creating optimizer for Stage 2...
Optimizer and scheduler created.


  scaler = torch.cuda.amp.GradScaler()


🔄 Epoch 16/50:   0%|          | 0/250 [00:00<?, ?it/s]

  return F.conv2d(



🔍 Running Validation for epoch 16


Validation Loop:   0%|          | 0/50 [00:00<?, ?it/s]

  Validation Results:
    - Avg Alignment Loss: 2.1269
    - Avg Semantic Loss:  0.9073
    - Total Weighted Loss: 4.3951
🏆 New best model saved to /content/drive/MyDrive/NeuroVision/EEG_BLIP2_Alignment/run_20251001_161150/best_model_epoch_16_loss_4.3951.pth

--- ✍️  Caption Generation Samples ---
Sample 1:
  - Ground Truth: a tall building with a clock on the top
  - Predicted:    a picture of the book, a paperback edition
Sample 2:
  - Ground Truth: a bus with a lot of luggage inside
  - Predicted:    the purple and maroon lines are on a white background
Sample 3:
  - Ground Truth: a lunch box with a container of food
  - Predicted:    a red and gray striped background with a black border
Sample 4:
  - Ground Truth: a man riding a wave on a surfboard
  - Predicted:    the art of personalizing your resume
Sample 5:
  - Ground Truth: a dog sitting in a car
  - Predicted:    a single frame of the video is displayed in a gray background
------------------------------------

🖼️  Generatin

🔄 Epoch 17/50:   0%|          | 0/250 [00:00<?, ?it/s]


🔍 Running Validation for epoch 17


Validation Loop:   0%|          | 0/50 [00:00<?, ?it/s]

  Validation Results:
    - Avg Alignment Loss: 2.1389
    - Avg Semantic Loss:  0.9047
    - Total Weighted Loss: 4.4007

--- ✍️  Caption Generation Samples ---
Sample 1:
  - Ground Truth: a skateboarder is doing a trick in a skate park
  - Predicted:    a large pink and purple striped wall
Sample 2:
  - Ground Truth: a man riding a surfboard on a wave
  - Predicted:    a gray square with a red and white border
Sample 3:
  - Ground Truth: a snowboarder is doing a trick on a ramp
  - Predicted:    an image of the bible with a large text
Sample 4:
  - Ground Truth: a group of boats on a beach
  - Predicted:    a green and purple line with a red border
Sample 5:
  - Ground Truth: a woman riding a horse in a dirt arena
  - Predicted:    the first person to do it has a purple background
------------------------------------

🖼️  Generating reconstruction grid for epoch 17...
  Saved reconstruction grid to /content/drive/MyDrive/NeuroVision/EEG_BLIP2_Alignment/run_20251001_161150/reconstruct

🔄 Epoch 18/50:   0%|          | 0/250 [00:00<?, ?it/s]


🔍 Running Validation for epoch 18


Validation Loop:   0%|          | 0/50 [00:00<?, ?it/s]

  Validation Results:
    - Avg Alignment Loss: 2.1426
    - Avg Semantic Loss:  0.9047
    - Total Weighted Loss: 4.4045

--- ✍️  Caption Generation Samples ---
Sample 1:
  - Ground Truth: a green and white fire hydrant
  - Predicted:    an image of a square with red and blue lines
Sample 2:
  - Ground Truth: a young boy eating a donut
  - Predicted:    the vertical axis of a bar chart with several colors
Sample 3:
  - Ground Truth: a car driving on a highway under a sign
  - Predicted:    a collection of vertical lines in a dark grey
Sample 4:
  - Ground Truth: a bathroom with a sink and a mirror
  - Predicted:    person, one of the best known faces in a popular and highly successful animated tv series
Sample 5:
  - Ground Truth: a kitchen with a stove, refrigerator, and microwave
  - Predicted:    the screen is pink, purple and green
------------------------------------

🖼️  Generating reconstruction grid for epoch 18...
  Saved reconstruction grid to /content/drive/MyDrive/NeuroVis

🔄 Epoch 19/50:   0%|          | 0/250 [00:00<?, ?it/s]


🔍 Running Validation for epoch 19


Validation Loop:   0%|          | 0/50 [00:00<?, ?it/s]

  Validation Results:
    - Avg Alignment Loss: 2.1280
    - Avg Semantic Loss:  0.9117
    - Total Weighted Loss: 4.4072

--- ✍️  Caption Generation Samples ---
Sample 1:
  - Ground Truth: a grill with food on it
  - Predicted:    the red line is a graph showing that the temperature in australia has increased over time
Sample 2:
  - Ground Truth: a bathroom with a sink and a mirror
  - Predicted:    a screenshot of a line graph showing the number of people in china
Sample 3:
  - Ground Truth: a monkey sitting on a ledge eating a banana
  - Predicted:    a man with a red and black stripe on his shirt
Sample 4:
  - Ground Truth: a bowl of soup on a table with two plates
  - Predicted:    a rainbow colored rectangle that shows two different colors
Sample 5:
  - Ground Truth: a baby is sitting on a couch with a teddy bear
  - Predicted:    a grey background with blue and green stripes
------------------------------------

🖼️  Generating reconstruction grid for epoch 19...
  Saved reconstr

🔄 Epoch 20/50:   0%|          | 0/250 [00:00<?, ?it/s]


🔍 Running Validation for epoch 20


Validation Loop:   0%|          | 0/50 [00:00<?, ?it/s]

  Validation Results:
    - Avg Alignment Loss: 2.1051
    - Avg Semantic Loss:  0.9068
    - Total Weighted Loss: 4.3722
🏆 New best model saved to /content/drive/MyDrive/NeuroVision/EEG_BLIP2_Alignment/run_20251001_161150/best_model_epoch_20_loss_4.3722.pth

--- ✍️  Caption Generation Samples ---
Sample 1:
  - Ground Truth: two giraffes in the wild
  - Predicted:    a line of text that says "text"
Sample 2:
  - Ground Truth: a person riding skis down a snowy mountain
  - Predicted:    the map shows a large area with several lines
Sample 3:
  - Ground Truth: a street sign on a pole in front of tall buildings
  - Predicted:    a long line of pink and gray lines on a white background
Sample 4:
  - Ground Truth: a box of fruit
  - Predicted:    the background is shown in green, red and purple
Sample 5:
  - Ground Truth: a baby elephant is standing next to an adult elephant
  - Predicted:    the person on the phone with no one around
------------------------------------

🖼️  Generating rec

🔄 Epoch 21/50:   0%|          | 0/250 [00:00<?, ?it/s]


🔍 Running Validation for epoch 21


Validation Loop:   0%|          | 0/50 [00:00<?, ?it/s]

  Validation Results:
    - Avg Alignment Loss: 2.1013
    - Avg Semantic Loss:  0.9085
    - Total Weighted Loss: 4.3725

--- ✍️  Caption Generation Samples ---
Sample 1:
  - Ground Truth: a train traveling down the tracks with a few cars
  - Predicted:    a line drawing of a red and green tree
Sample 2:
  - Ground Truth: a bathroom with a sink, toilet and bathtub
  - Predicted:    a graphic background for a horizontal line
Sample 3:
  - Ground Truth: a group of motorcycles parked on the side of a mountain road
  - Predicted:    a screenshot of the dark theme on an android phone
Sample 4:
  - Ground Truth: a man in a suit and tie standing in a doorway
  - Predicted:    a vertical line of text on a black background
Sample 5:
  - Ground Truth: a train on the tracks
  - Predicted:    an image of a computer screen with an active text field
------------------------------------

🖼️  Generating reconstruction grid for epoch 21...
  Saved reconstruction grid to /content/drive/MyDrive/NeuroVis

🔄 Epoch 22/50:   0%|          | 0/250 [00:00<?, ?it/s]


🔍 Running Validation for epoch 22


Validation Loop:   0%|          | 0/50 [00:00<?, ?it/s]

  Validation Results:
    - Avg Alignment Loss: 2.1035
    - Avg Semantic Loss:  0.9011
    - Total Weighted Loss: 4.3562
🏆 New best model saved to /content/drive/MyDrive/NeuroVision/EEG_BLIP2_Alignment/run_20251001_161150/best_model_epoch_22_loss_4.3562.pth

--- ✍️  Caption Generation Samples ---
Sample 1:
  - Ground Truth: a couple of people sitting on a bench near a boat
  - Predicted:    a gray background with a red and purple line
Sample 2:
  - Ground Truth: a man looking at bananas
  - Predicted:    a text and graphic page for a line
Sample 3:
  - Ground Truth: a white and black vase sitting on a table
  - Predicted:    a red and black line background with a circle
Sample 4:
  - Ground Truth: a man sitting at a desk with a laptop computer
  - Predicted:    a line drawing of a person with the sun in front
Sample 5:
  - Ground Truth: a man in a tuxedo
  - Predicted:    the web page with an image of a bird
------------------------------------

🖼️  Generating reconstruction grid for 

🔄 Epoch 23/50:   0%|          | 0/250 [00:00<?, ?it/s]


🔍 Running Validation for epoch 23


Validation Loop:   0%|          | 0/50 [00:00<?, ?it/s]

  Validation Results:
    - Avg Alignment Loss: 2.1007
    - Avg Semantic Loss:  0.8875
    - Total Weighted Loss: 4.3195
🏆 New best model saved to /content/drive/MyDrive/NeuroVision/EEG_BLIP2_Alignment/run_20251001_161150/best_model_epoch_23_loss_4.3195.pth

--- ✍️  Caption Generation Samples ---
Sample 1:
  - Ground Truth: a baseball player swinging at a ball
  - Predicted:    a vertical stripe, a pink and grey line
Sample 2:
  - Ground Truth: a person walking down a street with a red umbrella
  - Predicted:    a pink background with some lines and curves
Sample 3:
  - Ground Truth: two black swans swimming in the water
  - Predicted:    an image of a small grey and red line
Sample 4:
  - Ground Truth: a bench sitting on a dirt path in a park
  - Predicted:    a vertical red and purple gradient background
Sample 5:
  - Ground Truth: a living room with a table and chairs
  - Predicted:    a dark purple and pink border on the screen
------------------------------------

🖼️  Generating 

🔄 Epoch 24/50:   0%|          | 0/250 [00:00<?, ?it/s]


🔍 Running Validation for epoch 24


Validation Loop:   0%|          | 0/50 [00:00<?, ?it/s]

  Validation Results:
    - Avg Alignment Loss: 2.1074
    - Avg Semantic Loss:  0.8923
    - Total Weighted Loss: 4.3382

--- ✍️  Caption Generation Samples ---
Sample 1:
  - Ground Truth: a young boy eating a donut
  - Predicted:    an image of a purple and pink line
Sample 2:
  - Ground Truth: a zebra standing in a dirt field
  - Predicted:    a purple and yellow stripe on a dark background
Sample 3:
  - Ground Truth: a man riding a motorcycle down a highway
  - Predicted:    a green and blue striped pattern is shown in the background
Sample 4:
  - Ground Truth: a living room with a couch, a table, and a window
  - Predicted:    a beautiful image of a flower in pink and gray
Sample 5:
  - Ground Truth: a cat sleeping on a bench
  - Predicted:    a red and blue abstract graphic with a gradient effect
------------------------------------

🖼️  Generating reconstruction grid for epoch 24...
  Saved reconstruction grid to /content/drive/MyDrive/NeuroVision/EEG_BLIP2_Alignment/run_2025100

🔄 Epoch 25/50:   0%|          | 0/250 [00:00<?, ?it/s]


🔍 Running Validation for epoch 25


Validation Loop:   0%|          | 0/50 [00:00<?, ?it/s]

  Validation Results:
    - Avg Alignment Loss: 2.1008
    - Avg Semantic Loss:  0.8932
    - Total Weighted Loss: 4.3339

--- ✍️  Caption Generation Samples ---
Sample 1:
  - Ground Truth: a train traveling over a bridge over a body of water
  - Predicted:    a purple, yellow and green machine
Sample 2:
  - Ground Truth: a couple of people sitting at a table outside of a building
  - Predicted:    an animated black and white graphite line drawing
Sample 3:
  - Ground Truth: a train is traveling down the tracks
  - Predicted:    a gray and pink graphite slide
Sample 4:
  - Ground Truth: a person on skis
  - Predicted:    the background image shows a red and pink line
Sample 5:
  - Ground Truth: a baby elephant is standing next to an adult elephant
  - Predicted:    a white circle with an arrow pointing to it
------------------------------------

🖼️  Generating reconstruction grid for epoch 25...
  Saved reconstruction grid to /content/drive/MyDrive/NeuroVision/EEG_BLIP2_Alignment/run_2

🔄 Epoch 26/50:   0%|          | 0/250 [00:00<?, ?it/s]


🔍 Running Validation for epoch 26


Validation Loop:   0%|          | 0/50 [00:00<?, ?it/s]

  Validation Results:
    - Avg Alignment Loss: 2.0935
    - Avg Semantic Loss:  0.8894
    - Total Weighted Loss: 4.3170
🏆 New best model saved to /content/drive/MyDrive/NeuroVision/EEG_BLIP2_Alignment/run_20251001_161150/best_model_epoch_26_loss_4.3170.pth

--- ✍️  Caption Generation Samples ---
Sample 1:
  - Ground Truth: a group of people in wet suits on surfboards
  - Predicted:    3d text in 3D with a purple background
Sample 2:
  - Ground Truth: a woman holding a tennis racket on a tennis court
  - Predicted:    the purple and blue lines are shown in a rectangular shape
Sample 3:
  - Ground Truth: a woman holding a rainbow colored umbrella
  - Predicted:    a gray and pink colored background, with the words "the great wall of china"
Sample 4:
  - Ground Truth: a lunch box with two containers of food
  - Predicted:    a large image of a red, green and yellow striped line
Sample 5:
  - Ground Truth: two zebras in a fenced in area
  - Predicted:    pink and purple lines artboard
--

🔄 Epoch 27/50:   0%|          | 0/250 [00:00<?, ?it/s]


🔍 Running Validation for epoch 27


Validation Loop:   0%|          | 0/50 [00:00<?, ?it/s]

  Validation Results:
    - Avg Alignment Loss: 2.0988
    - Avg Semantic Loss:  0.9052
    - Total Weighted Loss: 4.3618

--- ✍️  Caption Generation Samples ---
Sample 1:
  - Ground Truth: a man is skiing down a snow covered slope
  - Predicted:    an airplane in the sky and an arrow on it
Sample 2:
  - Ground Truth: a desk with a laptop and a desktop computer
  - Predicted:    the background is a purple and blue color
Sample 3:
  - Ground Truth: a man riding a wave on a surfboard
  - Predicted:    a pink and purple vertical map background
Sample 4:
  - Ground Truth: a man riding a bike down a street
  - Predicted:    a pink and red background with an image of a boat
Sample 5:
  - Ground Truth: a bear is standing in the water
  - Predicted:    the same image as the one above
------------------------------------

🖼️  Generating reconstruction grid for epoch 27...
  Saved reconstruction grid to /content/drive/MyDrive/NeuroVision/EEG_BLIP2_Alignment/run_20251001_161150/reconstructions/re

🔄 Epoch 28/50:   0%|          | 0/250 [00:00<?, ?it/s]


🔍 Running Validation for epoch 28


Validation Loop:   0%|          | 0/50 [00:00<?, ?it/s]

  Validation Results:
    - Avg Alignment Loss: 2.0909
    - Avg Semantic Loss:  0.8927
    - Total Weighted Loss: 4.3227

--- ✍️  Caption Generation Samples ---
Sample 1:
  - Ground Truth: a dog sitting in the grass next to a frisbee
  - Predicted:    the black and white map of a diagonal vertical line
Sample 2:
  - Ground Truth: a table set for a dinner party
  - Predicted:    a long black and white diagonal line
Sample 3:
  - Ground Truth: a person riding skis down a snow covered slope
  - Predicted:    an abstract white line texture with pink and purple stripes
Sample 4:
  - Ground Truth: a man riding a wave on a surfboard
  - Predicted:    a gray and black abstract image with a pink border
Sample 5:
  - Ground Truth: a man riding a wave on a surfboard
  - Predicted:    the abstract pink, blue and purple line art
------------------------------------

🖼️  Generating reconstruction grid for epoch 28...
  Saved reconstruction grid to /content/drive/MyDrive/NeuroVision/EEG_BLIP2_Alignm

🔄 Epoch 29/50:   0%|          | 0/250 [00:00<?, ?it/s]


🔍 Running Validation for epoch 29


Validation Loop:   0%|          | 0/50 [00:00<?, ?it/s]

  Validation Results:
    - Avg Alignment Loss: 2.0937
    - Avg Semantic Loss:  0.8916
    - Total Weighted Loss: 4.3228

--- ✍️  Caption Generation Samples ---
Sample 1:
  - Ground Truth: a cutting board with a bunch of donuts on it
  - Predicted:    a line of purple and pink lines on a grey background
Sample 2:
  - Ground Truth: two giraffes standing in a field with trees in the background
  - Predicted:    a long, white line that is a beautiful color
Sample 3:
  - Ground Truth: two sheep standing on a road
  - Predicted:    a rainbow of shapes in a dark background
Sample 4:
  - Ground Truth: a plate of pasta and broccoli
  - Predicted:    a long, pink and red striped line
Sample 5:
  - Ground Truth: a train is on the tracks at a station
  - Predicted:    a grayish blue and purple smokey, grey cloud background
------------------------------------

🖼️  Generating reconstruction grid for epoch 29...
  Saved reconstruction grid to /content/drive/MyDrive/NeuroVision/EEG_BLIP2_Alignment/

🔄 Epoch 30/50:   0%|          | 0/250 [00:00<?, ?it/s]


🔍 Running Validation for epoch 30


Validation Loop:   0%|          | 0/50 [00:00<?, ?it/s]

  Validation Results:
    - Avg Alignment Loss: 2.0917
    - Avg Semantic Loss:  0.8938
    - Total Weighted Loss: 4.3261

--- ✍️  Caption Generation Samples ---
Sample 1:
  - Ground Truth: a couple of red double decker buses driving down a street
  - Predicted:    a red and black stripey background
Sample 2:
  - Ground Truth: a plate of food with meat and broccoli
  - Predicted:    a red and purple vertical line drawing
Sample 3:
  - Ground Truth: a group of people sitting on a park bench
  - Predicted:    an image of a computer screen displaying a graph
Sample 4:
  - Ground Truth: a baby is laying in a basket with a teddy bear
  - Predicted:    the purple and blue line story map
Sample 5:
  - Ground Truth: a group of boats on a beach
  - Predicted:    a pink and purple graphite line art
------------------------------------

🖼️  Generating reconstruction grid for epoch 30...
  Saved reconstruction grid to /content/drive/MyDrive/NeuroVision/EEG_BLIP2_Alignment/run_20251001_161150/recon

🔄 Epoch 31/50:   0%|          | 0/250 [00:00<?, ?it/s]


🔍 Running Validation for epoch 31


Validation Loop:   0%|          | 0/50 [00:00<?, ?it/s]

  Validation Results:
    - Avg Alignment Loss: 2.0905
    - Avg Semantic Loss:  0.9001
    - Total Weighted Loss: 4.3408

--- ✍️  Caption Generation Samples ---
Sample 1:
  - Ground Truth: a yellow train on the tracks
  - Predicted:    a line drawing of a tree with leaves
Sample 2:
  - Ground Truth: a man riding a surfboard on a wave
  - Predicted:    a red, pink and gray line graph
Sample 3:
  - Ground Truth: a group of horses standing on a hill
  - Predicted:    pink and blue stripe background for web and mobile
Sample 4:
  - Ground Truth: three giraffes in a pen
  - Predicted:    a long pink and purple line on a white background
Sample 5:
  - Ground Truth: a bunk bed in a room with a door
  - Predicted:    a pink and purple line of lines on a wall
------------------------------------

🖼️  Generating reconstruction grid for epoch 31...
  Saved reconstruction grid to /content/drive/MyDrive/NeuroVision/EEG_BLIP2_Alignment/run_20251001_161150/reconstructions/reconstruction_epoch_31_los

🔄 Epoch 32/50:   0%|          | 0/250 [00:00<?, ?it/s]


🔍 Running Validation for epoch 32


Validation Loop:   0%|          | 0/50 [00:00<?, ?it/s]

  Validation Results:
    - Avg Alignment Loss: 2.0899
    - Avg Semantic Loss:  0.9041
    - Total Weighted Loss: 4.3502

--- ✍️  Caption Generation Samples ---
Sample 1:
  - Ground Truth: a blue motorcycle parked in a parking lot
  - Predicted:    a purple and pink line on a red background
Sample 2:
  - Ground Truth: a man in a wet suit is riding a wave
  - Predicted:    a dark, long vertical stripe with pink and brown lines
Sample 3:
  - Ground Truth: a train station with several trains on the tracks
  - Predicted:    a long red and pink diagonal stripe
Sample 4:
  - Ground Truth: a man swinging a baseball bat on a baseball field
  - Predicted:    a pink and gray background is shown on a large computer screen
Sample 5:
  - Ground Truth: a group of dogs running in a field
  - Predicted:    a ball in the air with a long and thin line
------------------------------------

🖼️  Generating reconstruction grid for epoch 32...
  Saved reconstruction grid to /content/drive/MyDrive/NeuroVisio

🔄 Epoch 33/50:   0%|          | 0/250 [00:00<?, ?it/s]


🔍 Running Validation for epoch 33


Validation Loop:   0%|          | 0/50 [00:00<?, ?it/s]

  Validation Results:
    - Avg Alignment Loss: 2.0897
    - Avg Semantic Loss:  0.8946
    - Total Weighted Loss: 4.3261

--- ✍️  Caption Generation Samples ---
Sample 1:
  - Ground Truth: a table with two containers of food and a bowl of salad
  - Predicted:    a female and male line of panthers
Sample 2:
  - Ground Truth: a dog laying in the sand under a chair
  - Predicted:    a long black and grey line with the word
Sample 3:
  - Ground Truth: a kitchen with pots hanging from the ceiling
  - Predicted:    an abstract background with pink and red lines
Sample 4:
  - Ground Truth: a white bowl with food
  - Predicted:    a dark red and blue line on the background
Sample 5:
  - Ground Truth: a tall stone building with a clock on the side
  - Predicted:    a red and pink abstract line art
------------------------------------

🖼️  Generating reconstruction grid for epoch 33...
  Saved reconstruction grid to /content/drive/MyDrive/NeuroVision/EEG_BLIP2_Alignment/run_20251001_161150/reco

🔄 Epoch 34/50:   0%|          | 0/250 [00:00<?, ?it/s]


🔍 Running Validation for epoch 34


Validation Loop:   0%|          | 0/50 [00:00<?, ?it/s]

  Validation Results:
    - Avg Alignment Loss: 2.0910
    - Avg Semantic Loss:  0.8927
    - Total Weighted Loss: 4.3227

--- ✍️  Caption Generation Samples ---
Sample 1:
  - Ground Truth: a bed with white sheets
  - Predicted:    a black and grey background with the word "breathtaking"
Sample 2:
  - Ground Truth: three zebras grazing in a field near a body of water
  - Predicted:    a large pink, blue and purple line art
Sample 3:
  - Ground Truth: a man sitting at a table eating a meal
  - Predicted:    a red and gray pattern of a large floor
Sample 4:
  - Ground Truth: a large air france airplane flying through the air
  - Predicted:    a pink and brown line on a grey background
Sample 5:
  - Ground Truth: a bathroom with a toilet, sink and mirror
  - Predicted:    a collection of a pink and purple
------------------------------------

🖼️  Generating reconstruction grid for epoch 34...
  Saved reconstruction grid to /content/drive/MyDrive/NeuroVision/EEG_BLIP2_Alignment/run_2025100

🔄 Epoch 35/50:   0%|          | 0/250 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
import torch
import torch.nn as nn
import pandas as pd
from pathlib import Path
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
from torchvision.transforms import v2 as transforms
from PIL import Image
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm.auto import tqdm
from datetime import datetime
import numpy as np
from transformers import Blip2Processor, Blip2ForConditionalGeneration, Blip2VisionModel
import warnings
from sentence_transformers import SentenceTransformer
from diffusers import StableDiffusionPipeline
import torch.nn.functional as F
from PIL import Image, ImageDraw, ImageFont
import json
from transformers import get_linear_schedule_with_warmup

# ==============================================================================
# --- 1. CONFIGURATION ---
# ==============================================================================
class TRAIN_CONFIG:
    PROCESSED_DATA_ROOT = '/content/final_lightweight_17k'
    METADATA_CSV = Path(PROCESSED_DATA_ROOT) / 'metadata.csv'
    OUTPUT_DIR = '/content/drive/MyDrive/NeuroVision/EEG_BLIP2_Alignment'

    BLIP_MODEL_NAME = "Salesforce/blip2-opt-2.7b"

    BATCH_SIZE = 16
    GRAD_ACCUMULATION_STEPS = 2
    NUM_EPOCHS = 40
    ADAPTER_LR = 1e-5  # Higher LR for the new adapter
    BLIP_LR = 3e-6     # Lower LR for fine-tuning the pre-trained layers

    WEIGHT_DECAY = 1e-3
    GRADIENT_CLIP_NORM = 1.0

    VALIDATION_INTERVAL = 1
    VAL_SAMPLES_PER_EPOCH = 700
    TRAIN_SAMPLES_PER_EPOCH = 3500
    VIS_GRID_SIZE = 5

    DIFFUSION_MODEL_ID ="runwayml/stable-diffusion-v1-5"
    EVAL_IMAGE_GENERATION_INTERVAL = 1
    WARMUP_STEPS = 150  # Number of steps to gradually increase the LR
    TOTAL_TRAIN_STEPS = 218 * 40
    STAGE1_EPOCHS = 2

# ==============================================================================
# --- 2. EEG ENCODER MODULE ---
# ==============================================================================

class ChannelAdapter(nn.Module):
    """
    A powerful CNN adapter inspired by EEGNet principles to extract robust
    features from EEG spectrograms. It processes the 64 channels to find
    temporal and spatial patterns, then creates a 3-channel pseudo-image.
    """
    def __init__(self, in_chans=64, out_chans=3, F1=16, D=2, F2=32, dropout=0.4):
        super().__init__()

        # Block 1: Temporal and Depthwise Spatial Convolutions
        self.block1 = nn.Sequential(
            # Input is [B, 1, Chans, Samples]
            # Temporal Conv to learn filters across time
            nn.Conv2d(1, F1, (1, 64), padding='same', bias=False),
            nn.BatchNorm2d(F1),
            # Depthwise Spatial Conv to learn spatial filters for each feature map
            nn.Conv2d(F1, F1 * D, (in_chans, 1), groups=F1, bias=False),
            nn.BatchNorm2d(F1 * D),
            nn.ELU(),
            # nn.AvgPool2d((1, 4)),
            nn.Dropout(dropout)
        )

        # Block 2: Separable Convolutions for more abstract features
        self.block2 = nn.Sequential(
            # Separable Conv (Depthwise followed by Pointwise)
            nn.Conv2d(F1 * D, F2, (1, 16), padding='same', groups=F1 * D, bias=False),
            nn.Conv2d(F2, F2, (1, 1), bias=False),
            nn.BatchNorm2d(F2),
            nn.ELU(),
            # nn.AvgPool2d((1, 8)),
            nn.Dropout(dropout)
        )

        # Final projection layer to create the 3-channel "pseudo-image"
        self.final_projection = nn.Conv2d(F2, out_chans, (1, 1))
        self.adaptive_pool = nn.AdaptiveAvgPool2d((224, 224))

    def forward(self, x):
        # Input x is [Batch, 64 Channels, Freq Bins, Time Steps]
        # We reshape it to fit the EEGNet structure: [Batch, 1, Channels, Samples]
        # This treats the frequency and time dimensions as one long sequence.
        x = x.view(x.size(0), 1, x.size(1), -1)

        x = self.block1(x)
        x = self.block2(x)

        # Project to 3 channels
        x = self.final_projection(x)
        x = self.adaptive_pool(x)

        # The output is a feature map, which will be resized by the vision model's patch embedding
        return x
#-------------------------------------------------------------------------------------------

class EEGEncoder(nn.Module):
    def __init__(self, config, in_chans=64):
        super().__init__()
        # self.vision_model = Blip2VisionModel.from_pretrained(config.BLIP_MODEL_NAME)

        # Freeze the original vision model
        # for param in self.vision_model.parameters():
        #     param.requires_grad = False

        self.adapter = ChannelAdapter(in_chans=in_chans)

        for p in self.adapter.parameters():
            p.requires_grad = True


    def forward(self, x, vision_model):
      adapted_x = self.adapter(x)
      vision_outputs = vision_model(adapted_x, return_dict=True)
      return vision_outputs.last_hidden_state

# ==============================================================================
# --- 3. THE HYBRID EEG-BLIP2 MODEL ---
# ==============================================================================
class EEG_BLIP2_Model(nn.Module):
    def __init__(self, config, device):
        super().__init__()
        self.device = device
        self.eeg_encoder = EEGEncoder(config)
        self.blip = Blip2ForConditionalGeneration.from_pretrained(
            config.BLIP_MODEL_NAME
        )

        # Freeze the entire BLIP model initially
        for param in self.blip.parameters():
            param.requires_grad = False

        # print("Unfreezing specified layers for fine-tuning...")

        # --- THIS IS THE KEY CHANGE ---
        # Unfreeze the ENTIRE Q-Former
        # for param in self.blip.qformer.parameters():
        #     param.requires_grad = True
        # print(" - Unfroze the entire Q-Former.")

        # # Also unfreeze the language projection layer (connects Q-Former to LLM)
        # if hasattr(self.blip, "language_projection"):
        #     for param in self.blip.language_projection.parameters():
        #         param.requires_grad = True
        #     print("  - Unfroze Language Projection layer.")


        self.eeg_encoder.to(device)
        self.blip.to(device)

    def get_eeg_embedding(self, eeg_spectrograms):
        eeg_features = self.eeg_encoder(eeg_spectrograms, self.blip.vision_model)
        query_tokens = self.blip.query_tokens.expand(eeg_features.shape[0], -1, -1)
        query_outputs = self.blip.qformer(
            query_embeds=query_tokens,
            encoder_hidden_states=eeg_features,
            return_dict=True,
        )
        return query_outputs.last_hidden_state.mean(dim=1)

    def get_image_embedding(self, pil_images, processor):
        inputs = processor(images=pil_images, return_tensors="pt").to(self.device)
        pixel_values = inputs.pixel_values

        image_features = self.blip.vision_model(pixel_values).last_hidden_state
        query_tokens = self.blip.query_tokens.expand(image_features.shape[0], -1, -1)
        query_outputs = self.blip.qformer(
            query_embeds=query_tokens,
            encoder_hidden_states=image_features,
            return_dict=True,
        )
        return query_outputs.last_hidden_state.mean(dim=1)

    def generate(self, eeg_spectrograms, **kwargs):
        with torch.no_grad():
            eeg_features = self.eeg_encoder(eeg_spectrograms, self.blip.vision_model)
            query_tokens = self.blip.query_tokens.expand(eeg_features.shape[0], -1, -1)
            query_outputs = self.blip.qformer(
                query_embeds=query_tokens,
                encoder_hidden_states=eeg_features,
                return_dict=True,
            )
            language_model_inputs = self.blip.language_projection(query_outputs.last_hidden_state)
            generated_ids = self.blip.language_model.generate(
                inputs_embeds=language_model_inputs,
                **kwargs
            )
            return generated_ids

# ==============================================================================
# --- 4. DATASET & COLLATING ---
# ==============================================================================
def collate_fn(batch):
    spectrograms = torch.stack([item[0] for item in batch])
    pil_images = [item[1] for item in batch]
    captions = [item[2] for item in batch]
    return spectrograms, pil_images, captions

class EEGDatasetWithCaptions(Dataset):
    def __init__(self, root_dir, metadata_csv, split, transform, augment=False):
        self.root_dir = Path(root_dir)
        self.transform = transform
        df = pd.read_csv(metadata_csv)
        self.split_df = df[df['split'].str.strip() == split].reset_index(drop=True)
        self.augment = augment

        # --- NEW: Define a dedicated augmentation pipeline ---
        if self.augment:
            self.augmentation_transform = transforms.Compose([
                transforms.RandomApply([transforms.GaussianBlur(kernel_size=3)], p=0.5),
                # Randomly mask out frequency bands
                transforms.RandomApply([transforms.RandomErasing(p=1.0, scale=(0.02, 0.1), ratio=(0.1, 0.5))], p=0.5),
                # Randomly mask out time steps
                transforms.RandomApply([transforms.RandomErasing(p=1.0, scale=(0.02, 0.1), ratio=(2.0, 5.0))], p=0.5),
            ])

        # This is your existing normalization transform
        self.normalization_transform = transform

        # Load pre-generated captions
        captions_path = Path(root_dir) / f'{split}_gt_captions.json'
        print(f"Loading pre-generated captions from {captions_path}")
        with open(captions_path, 'r') as f:
            self.captions = json.load(f)

    def __len__(self):
        return len(self.split_df)

    def __getitem__(self, idx):
        info = self.split_df.iloc[idx]

        spectrogram = torch.load(self.root_dir / info['spectrogram_path'])

        # --- APPLY AUGMENTATIONS (for training set only) ---
        if self.augment:
            spectrogram = self.augmentation_transform(spectrogram)

        # Apply normalization after augmentation
        spectrogram = self.normalization_transform(spectrogram)

        image = Image.open(self.root_dir / info['image_path']).convert("RGB")
        gt_caption = self.captions.get(str(idx), "an image")

        return spectrogram, image, gt_caption

# ==============================================================================
# --- 5. MAIN TRAINING FUNCTION ---
# ==============================================================================
# InfoNCE loss function

def info_nce_loss(query, positive_key, temperature=0.07):
    # Ensure inputs are normalized
    query = F.normalize(query, dim=-1)
    positive_key = F.normalize(positive_key, dim=-1)

    # Calculate the similarity matrix of every query with every key
    # The diagonal of this matrix contains the positive pairs
    logits = query @ positive_key.T

    # The labels are the indices of the positive pairs (the diagonal)
    labels = torch.arange(len(query), device=query.device)

    # Calculate the cross-entropy loss
    return F.cross_entropy(logits / temperature, labels)


class SemanticLoss(nn.Module):
    """
    A loss function that measures the semantic similarity between two sets of captions
    using a pre-trained SentenceTransformer model.
    """
    def __init__(self, device):
        super().__init__()
        self.device = device
        # Load a pre-trained model optimized for semantic similarity
        self.model = SentenceTransformer('all-MiniLM-L6-v2').to(device)
        # Freeze the model's weights as we only use it for inference
        for param in self.model.parameters():
            param.requires_grad = False

    def forward(self, predicted_captions, ground_truth_captions):
        # Convert the text captions into semantic embedding vectors
        pred_embeddings = self.model.encode(predicted_captions, convert_to_tensor=True)
        gt_embeddings = self.model.encode(ground_truth_captions, convert_to_tensor=True)

        # Calculate the cosine similarity. The loss is 1.0 minus the similarity.
        # A higher similarity (closer to 1.0) results in a lower loss (closer to 0.0).
        cosine_sim = F.cosine_similarity(pred_embeddings, gt_embeddings, dim=-1)
        loss = 1.0 - cosine_sim.mean()
        return loss

#------------------------------------------------------------------------------------------

def create_and_save_reconstruction_grid(eval_samples, diffusion_pipe, epoch, val_loss, output_dir):
    """
    Generates a 2x2 image grid comparing original and reconstructed images.
    Top row: Original images with ground truth captions.
    Bottom row: Reconstructed images with predicted captions.
    """
    if len(eval_samples) < 2:
        print("Not enough samples for a 2x2 grid, skipping.")
        return

    print(f"\n🖼️  Generating reconstruction grid for epoch {epoch}...")

    # Setup canvas and font
    w, h = 512, 512  # Standard Stable Diffusion size
    title_h = 60    # Space for titles
    grid = Image.new('RGB', (w * 2, h * 2 + title_h * 2), 'black')
    draw = ImageDraw.Draw(grid)
    try:
        font = ImageFont.truetype("LiberationSans-Regular.ttf", 20)
    except IOError:
        print("Default font not found, using fallback.")
        font = ImageFont.load_default()

    # Process two samples
    for i in range(5):
        predicted_cap, gt_cap, original_image = eval_samples[i]

        # --- Generate the reconstructed image from the predicted caption ---
        generator = torch.Generator(device=diffusion_pipe.device).manual_seed(42 + i)
        reconstructed_image = diffusion_pipe(
            prompt=predicted_cap, generator=generator, num_inference_steps=20
        ).images[0]

        # --- Paste images and draw titles ---
        col_offset = i * w

        # Top Row (Original)
        draw.text((col_offset + 5, 5), f"GT: {gt_cap[:50]}...", font=font, fill="white")
        grid.paste(original_image.resize((w, h)), (col_offset, title_h))

        # Bottom Row (Reconstructed)
        draw.text((col_offset + 5, h + title_h + 5), f"Pred: {predicted_cap[:50]}...", font=font, fill="white")
        grid.paste(reconstructed_image.resize((w, h)), (col_offset, h + title_h * 2))

    # --- Save the final grid ---
    # Sanitize loss value for filename
    val_loss_str = f"{val_loss:.4f}".replace('.', '_')
    save_path = Path(output_dir) / 'reconstructions'
    save_path.mkdir(exist_ok=True)
    filename = save_path / f"reconstruction_epoch_{epoch}_loss_{val_loss_str}.png"
    grid.save(filename)
    print(f"  Saved reconstruction grid to {filename}")

#----------------------------------------------------------------------------------------------
def train_model(config):
    # --- Setup (Device, Output Dir, etc.) ---
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    output_dir = Path(config.OUTPUT_DIR) / f'run_{timestamp}'
    output_dir.mkdir(parents=True, exist_ok=True)

    # --- Models and Data ---
    processor = Blip2Processor.from_pretrained(config.BLIP_MODEL_NAME)
    model = EEG_BLIP2_Model(config, device)
    try:
        model.load_state_dict(torch.load('/content/drive/MyDrive/NeuroVision/EEG_BLIP2_Alignment/run_20250930_190044/best_model_epoch_10_loss_5.9745.pth',
                                         map_location=device))
        print('Loaded saved EEG_BLIP2_Model weights')
    except:
        print('Could not load eeg_blip2_model weights')

    print("Loading Stable Diffusion pipeline for visualizations...")
    diffusion_pipe = StableDiffusionPipeline.from_pretrained(
        config.DIFFUSION_MODEL_ID,
        torch_dtype=torch.float16 if device.type == "cuda" else torch.float32,
        safety_checker=None, requires_safety_checker=False
    ).to(device)
    diffusion_pipe.set_progress_bar_config(disable=True)

    semantic_loss_fn = SemanticLoss(device)

    # --- Data Loading (with Epsilon fix) ---
    spec_mean = torch.load(Path(config.PROCESSED_DATA_ROOT) / 'spec_mean.pt')
    spec_std = torch.load(Path(config.PROCESSED_DATA_ROOT) / 'spec_std.pt')
    spec_std += 1e-6

    transform = transforms.Compose([
        transforms.ToDtype(torch.float32),
        transforms.Resize((224, 224), antialias=True),
        transforms.Normalize(mean=spec_mean.tolist(), std=spec_std.tolist()),
    ])

    train_dataset = EEGDatasetWithCaptions(config.PROCESSED_DATA_ROOT, config.METADATA_CSV, 'train', transform, augment=True)
    val_dataset = EEGDatasetWithCaptions(config.PROCESSED_DATA_ROOT, config.METADATA_CSV, 'val', transform, augment=False)

    # --- Optimizer, Scheduler, and Training Components ---
    # NOTE: Moved outside the loop for correct initialization
    scaler = torch.cuda.amp.GradScaler()
    best_val_loss = float('inf')

    generation_args = {
        "max_new_tokens": 30, "do_sample": True, "min_new_tokens": 6,
        "repetition_penalty": 1.4, "top_p": 0.9, "temperature": 0.85
    }

    # --- Main Training Loop ---
    print("🎯 Starting 2-Stage Training Loop")
    for epoch in range(config.NUM_EPOCHS):
        # --- STAGE 1: Train Adapter Only (Epochs 0-19) ---
        if epoch < config.STAGE1_EPOCHS:
            if epoch == 0:
                print("\n--- STAGE 1: Training EEGNetAdapter Only ---")
            for param in model.eeg_encoder.adapter.parameters():
                param.requires_grad = True
            for param in model.blip.parameters():
                param.requires_grad = False

        # --- STAGE 2: Fine-Tune Q-Former (Epochs 20+) ---
        elif epoch == config.STAGE1_EPOCHS:
            print("\n--- STAGE 2: Fine-Tuning Q-Former ---")
            for param in model.blip.qformer.parameters():
                param.requires_grad = True
            if hasattr(model.blip, "language_projection"):
                for param in model.blip.language_projection.parameters():
                    param.requires_grad = True
                print("  - Unfroze Language Projection layer.")

        # --- Create Optimizer & Scheduler at the start of each stage ---
        if epoch == 0 or epoch == config.STAGE1_EPOCHS:
            print(f"Creating optimizer for Stage {1 if epoch==0 else 2}...")
            adapter_params = [p for n, p in model.named_parameters() if p.requires_grad and 'adapter' in n]
            blip_params = [p for n, p in model.named_parameters() if p.requires_grad and 'adapter' not in n]

            optimizer_grouped_parameters = [
                {"params": adapter_params, "lr": config.ADAPTER_LR},
                {"params": blip_params, "lr": config.BLIP_LR},
            ]
            optimizer = optim.AdamW(optimizer_grouped_parameters, weight_decay=config.WEIGHT_DECAY)

            # Calculate total steps for the whole training run for the scheduler
            num_training_steps = (len(train_dataset) // config.BATCH_SIZE) * config.NUM_EPOCHS
            scheduler = get_linear_schedule_with_warmup(
                optimizer, num_warmup_steps=300, num_training_steps=num_training_steps
            )
            print("Optimizer and scheduler created.")

        model.train()
        train_loader = DataLoader(
            train_dataset, batch_size=config.BATCH_SIZE,
            sampler=SubsetRandomSampler(np.random.choice(len(train_dataset), min(len(train_dataset), config.TRAIN_SAMPLES_PER_EPOCH), replace=False)),
            drop_last=True, collate_fn=collate_fn
        )
        train_bar = tqdm(train_loader, desc=f"🔄 Epoch {epoch+1}/{config.NUM_EPOCHS}")

        for batch_idx, (spectrograms, pil_images, gt_captions) in enumerate(train_bar):
            spectrograms = spectrograms.to(device)

            with torch.amp.autocast('cuda'):
                eeg_embeddings = model.get_eeg_embedding(spectrograms)
                with torch.no_grad():
                    generated_ids = model.generate(spectrograms, **generation_args)
                    eeg_captions = processor.batch_decode(generated_ids, skip_special_tokens=True)
                    image_embeddings = model.get_image_embedding(pil_images, processor)

                loss_align = info_nce_loss(eeg_embeddings, image_embeddings)
                loss_semantic = semantic_loss_fn(eeg_captions, gt_captions)
                loss = (1.0 * loss_align) + (2.5 * loss_semantic)
                loss = loss / config.GRAD_ACCUMULATION_STEPS

            # scaler.scale() must be outside the autocast block
            scaler.scale(loss).backward()

            # Optimizer step logic must also be outside the autocast block
            if (batch_idx + 1) % config.GRAD_ACCUMULATION_STEPS == 0:
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(filter(lambda p: p.requires_grad, model.parameters()), config.GRADIENT_CLIP_NORM)
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()
                scheduler.step()

            train_bar.set_postfix({"align_loss": f"{loss_align.item():.4f}", "sem_loss": f"{loss_semantic.item():.4f}"})


        print(f"\n🔍 Running Validation for epoch {epoch+1}")
        model.eval()

        # Lists to store metrics for this epoch
        val_align_losses = []
        val_semantic_losses = []
        eval_samples = [] # To store samples for printing and image generation

        # Create the validation data loader
        val_indices = np.random.choice(len(val_dataset), min(len(val_dataset), config.VAL_SAMPLES_PER_EPOCH), replace=False)
        val_loader = DataLoader(val_dataset, batch_size=config.BATCH_SIZE, sampler=SubsetRandomSampler(val_indices), collate_fn=collate_fn)

        # Disable gradients for the validation loop
        with torch.no_grad(), torch.amp.autocast('cuda'):
            for spectrograms, pil_images, gt_captions in tqdm(val_loader, desc="Validation Loop"):
                spectrograms = spectrograms.to(device)

                # --- Perform the forward pass ---
                eeg_embeddings = model.get_eeg_embedding(spectrograms)
                image_embeddings = model.get_image_embedding(pil_images, processor)

                # Generate predicted captions from the EEG
                generated_ids = model.generate(spectrograms, **generation_args)
                eeg_captions = processor.batch_decode(generated_ids, skip_special_tokens=True)

                # --- Calculate the two primary losses ---
                loss_align = info_nce_loss(eeg_embeddings, image_embeddings)
                loss_semantic = semantic_loss_fn(eeg_captions, gt_captions)

                val_align_losses.append(loss_align.item())
                val_semantic_losses.append(loss_semantic.item())

                # --- Collect a few samples for visualization ---
                if len(eval_samples) < config.VIS_GRID_SIZE:
                    num_to_add = config.VIS_GRID_SIZE - len(eval_samples)
                    for i in range(min(num_to_add, len(eeg_captions))):
                        eval_samples.append((eeg_captions[i], gt_captions[i], pil_images[i]))

        # --- Calculate and print average losses ---
        avg_val_align_loss = np.mean(val_align_losses)
        avg_val_semantic_loss = np.mean(val_semantic_losses)
        # The total loss is a weighted sum, same as in the training loop
        total_val_loss = (1.0 * avg_val_align_loss) + (2.5 * avg_val_semantic_loss)

        print(f"  Validation Results:")
        print(f"    - Avg Alignment Loss: {avg_val_align_loss:.4f}")
        print(f"    - Avg Semantic Loss:  {avg_val_semantic_loss:.4f}")
        print(f"    - Total Weighted Loss: {total_val_loss:.4f}")

        # Update the learning rate scheduler based on the total validation loss
        # scheduler.step(total_val_loss)

        # --- Save the best model based on the total validation loss ---
        if total_val_loss < best_val_loss:
            best_val_loss = total_val_loss
            save_path = output_dir / f'best_model_epoch_{epoch+1}_loss_{total_val_loss:.4f}.pth'
            torch.save(model.state_dict(), save_path)
            print(f"🏆 New best model saved to {save_path}")

        # --- Print caption comparisons ---
        if eval_samples:
            print("\n--- ✍️  Caption Generation Samples ---")
            for i, (predicted_cap, gt_cap, _) in enumerate(eval_samples):
                print(f"Sample {i+1}:")
                print(f"  - Ground Truth: {gt_cap.strip()}")
                print(f"  - Predicted:    {predicted_cap.strip()}")
            print("------------------------------------")

        if (epoch + 1) % config.EVAL_IMAGE_GENERATION_INTERVAL == 0 and eval_samples:
            create_and_save_reconstruction_grid(
                eval_samples,
                diffusion_pipe,
                epoch + 1,
                total_val_loss,
                output_dir
            )

    print("\n🎉 Training Complete!")
    return output_dir

if __name__ == '__main__':
    config = TRAIN_CONFIG()
    try:
        output_dir = train_model(config)
        if output_dir:
            print(f"✅ Training completed successfully! Results saved to: {output_dir}")
    except Exception as e:
        print(f"❌ Training failed with error: {str(e)}")
        import traceback
        traceback.print_exc()

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.
The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loaded saved EEG_BLIP2_Model weights
Loading Stable Diffusion pipeline for visualizations...


Loading pipeline components...:   0%|          | 0/6 [00:00<?, ?it/s]

`torch_dtype` is deprecated! Use `dtype` instead!


Loading pre-generated captions from /content/final_lightweight_17k/train_gt_captions.json
Loading pre-generated captions from /content/final_lightweight_17k/val_gt_captions.json
🎯 Starting 2-Stage Training Loop

--- STAGE 1: Training EEGNetAdapter Only ---
Creating optimizer for Stage 1...
Optimizer and scheduler created.


  scaler = torch.cuda.amp.GradScaler()


🔄 Epoch 1/40:   0%|          | 0/218 [00:00<?, ?it/s]

  return F.conv2d(



🔍 Running Validation for epoch 1


Validation Loop:   0%|          | 0/44 [00:00<?, ?it/s]

  Validation Results:
    - Avg Alignment Loss: 2.7884
    - Avg Semantic Loss:  0.8919
    - Total Weighted Loss: 5.0182
🏆 New best model saved to /content/drive/MyDrive/NeuroVision/EEG_BLIP2_Alignment/run_20250930_234041/best_model_epoch_1_loss_5.0182.pth

--- ✍️  Caption Generation Samples ---
Sample 1:
  - Ground Truth: two seagulls are swimming in the water
  - Predicted:    a purple striped background with a thin line in the middle
Sample 2:
  - Ground Truth: an elephant standing in a field with bushes in the background
  - Predicted:    the dark blue and red vertical lines on a gray background
Sample 3:
  - Ground Truth: a sheep is grazing on grass in a field
  - Predicted:    a large square with a long gray line
Sample 4:
  - Ground Truth: a tall building with a clock on the top
  - Predicted:    a person with long hair is shown in this picture
Sample 5:
  - Ground Truth: a baseball player swinging a bat at a ball
  - Predicted:    a grey and red striped wall with a dark backgr

🔄 Epoch 2/40:   0%|          | 0/218 [00:00<?, ?it/s]


🔍 Running Validation for epoch 2


Validation Loop:   0%|          | 0/44 [00:00<?, ?it/s]

  Validation Results:
    - Avg Alignment Loss: 2.7897
    - Avg Semantic Loss:  0.8939
    - Total Weighted Loss: 5.0244

--- ✍️  Caption Generation Samples ---
Sample 1:
  - Ground Truth: a plane on the runway
  - Predicted:    a pink and gray striped background
Sample 2:
  - Ground Truth: a boat that is traveling on the water
  - Predicted:    a gray and pink striped background
Sample 3:
  - Ground Truth: a dog sitting in a car
  - Predicted:    a woman with her hands on the keyboard and some text
Sample 4:
  - Ground Truth: a herd of cows grazing in a field
  - Predicted:    a man is wearing a shirt with the words "you are beautiful"
Sample 5:
  - Ground Truth: a car with a surfboard on top of it
  - Predicted:    a video of the screen with pink and purple lines
------------------------------------

🖼️  Generating reconstruction grid for epoch 2...
  Saved reconstruction grid to /content/drive/MyDrive/NeuroVision/EEG_BLIP2_Alignment/run_20250930_234041/reconstructions/reconstructio

🔄 Epoch 3/40:   0%|          | 0/218 [00:00<?, ?it/s]


🔍 Running Validation for epoch 3


Validation Loop:   0%|          | 0/44 [00:00<?, ?it/s]

  Validation Results:
    - Avg Alignment Loss: 2.7895
    - Avg Semantic Loss:  0.8914
    - Total Weighted Loss: 5.0179
🏆 New best model saved to /content/drive/MyDrive/NeuroVision/EEG_BLIP2_Alignment/run_20250930_234041/best_model_epoch_3_loss_5.0179.pth

--- ✍️  Caption Generation Samples ---
Sample 1:
  - Ground Truth: a kitchen with a refrigerator, stove, and sink
  - Predicted:    a vertical stripe with red and black lines
Sample 2:
  - Ground Truth: a vase with flowers in it
  - Predicted:    a screenshot of the image on a mobile device
Sample 3:
  - Ground Truth: three giraffes in a zoo enclosure
  - Predicted:    a gray, red and black screen is shown in the image
Sample 4:
  - Ground Truth: a plate of broccoli
  - Predicted:    the wedding video is shown with two red and pink bars
Sample 5:
  - Ground Truth: a desk with a computer, a laptop, and a monitor
  - Predicted:    a large horizontal line with red stripes and a small black box
------------------------------------

🖼️ 

🔄 Epoch 4/40:   0%|          | 0/218 [00:00<?, ?it/s]


🔍 Running Validation for epoch 4


Validation Loop:   0%|          | 0/44 [00:00<?, ?it/s]

  Validation Results:
    - Avg Alignment Loss: 2.7895
    - Avg Semantic Loss:  0.8869
    - Total Weighted Loss: 5.0068
🏆 New best model saved to /content/drive/MyDrive/NeuroVision/EEG_BLIP2_Alignment/run_20250930_234041/best_model_epoch_4_loss_5.0068.pth

--- ✍️  Caption Generation Samples ---
Sample 1:
  - Ground Truth: a bowl of soup on a table with two plates
  - Predicted:    the background of a dark gray screen with pink lines
Sample 2:
  - Ground Truth: a stop light and a sign that says no parking
  - Predicted:    the background of a grayish, purple and black color
Sample 3:
  - Ground Truth: a laptop computer and a mouse on a table
  - Predicted:    a grey and purple stripe on a screen
Sample 4:
  - Ground Truth: a dog sitting in a car
  - Predicted:    a logo that says, 'The people's court'
Sample 5:
  - Ground Truth: a view of a highway with a sign that says "exit"
  - Predicted:    the text is "get a new life" and then the image changes
-----------------------------------

🔄 Epoch 5/40:   0%|          | 0/218 [00:00<?, ?it/s]


🔍 Running Validation for epoch 5


Validation Loop:   0%|          | 0/44 [00:00<?, ?it/s]

  Validation Results:
    - Avg Alignment Loss: 2.7874
    - Avg Semantic Loss:  0.8836
    - Total Weighted Loss: 4.9964
🏆 New best model saved to /content/drive/MyDrive/NeuroVision/EEG_BLIP2_Alignment/run_20250930_234041/best_model_epoch_5_loss_4.9964.pth

--- ✍️  Caption Generation Samples ---
Sample 1:
  - Ground Truth: a group of people sitting around a table eating pizza
  - Predicted:    a red rectangle with a black border
Sample 2:
  - Ground Truth: a man standing in a group of sheep
  - Predicted:    a video of the logo for person
Sample 3:
  - Ground Truth: a small airplane on the ground
  - Predicted:    the logo of company on the bottom
Sample 4:
  - Ground Truth: a church with a clock tower and a path
  - Predicted:    a logo with a purple and blue background
Sample 5:
  - Ground Truth: a bear is standing in the water
  - Predicted:    a red and white stripe is visible on the side of a video
------------------------------------

🖼️  Generating reconstruction grid for epoch

🔄 Epoch 6/40:   0%|          | 0/218 [00:00<?, ?it/s]


🔍 Running Validation for epoch 6


Validation Loop:   0%|          | 0/44 [00:00<?, ?it/s]

  Validation Results:
    - Avg Alignment Loss: 2.7855
    - Avg Semantic Loss:  0.8981
    - Total Weighted Loss: 5.0306

--- ✍️  Caption Generation Samples ---
Sample 1:
  - Ground Truth: a pair of green scissors
  - Predicted:    the background of a black and white screen, with the image showing an animated red diagonal line
Sample 2:
  - Ground Truth: a green and white fire hydrant
  - Predicted:    a screenshot of a computer screen with an animated background
Sample 3:
  - Ground Truth: a bed with a white sheet
  - Predicted:    the video is shot with a red and white background
Sample 4:
  - Ground Truth: a double decker bus driving down a street
  - Predicted:    a pink and grey line on the background
Sample 5:
  - Ground Truth: a bear is walking through a field of green bushes
  - Predicted:    a purple and orange striped background
------------------------------------

🖼️  Generating reconstruction grid for epoch 6...
  Saved reconstruction grid to /content/drive/MyDrive/NeuroV

🔄 Epoch 7/40:   0%|          | 0/218 [00:00<?, ?it/s]

# Evaluations

In [None]:
import torch
import torch.nn.functional as F
from pathlib import Path
import pandas as pd
from PIL import Image
import matplotlib.pyplot as plt
from torchvision.transforms import v2 as transforms
from torch.utils.data import DataLoader, SubsetRandomSampler
import numpy as np
from tqdm.auto import tqdm
import json
from transformers import Blip2Processor
from diffusers import StableDiffusionPipeline
import clip
from sentence_transformers import SentenceTransformer

class ModelEvaluator:
    """Comprehensive evaluation of trained EEG-to-Caption model"""

    def __init__(self, model_path, config, device):
        self.device = device
        self.config = config
        self.model_path = model_path

        # Load processor
        self.processor = Blip2Processor.from_pretrained(config.BLIP_MODEL_NAME)

        # Load trained model
        self.model = self._load_model()

        # Initialize evaluation tools
        self.sentence_model = SentenceTransformer('all-MiniLM-L6-v2').to(device)
        self.clip_model, self.clip_preprocess = clip.load(config.CLIP_MODEL_NAME, device=device)

        # Initialize Stable Diffusion
        print("Loading Stable Diffusion pipeline...")
        self.diffusion_pipe = StableDiffusionPipeline.from_pretrained(
            config.DIFFUSION_MODEL_ID,
            torch_dtype=torch.float16 if device == "cuda" else torch.float32,
            safety_checker=None,
            requires_safety_checker=False
        )
        self.diffusion_pipe = self.diffusion_pipe.to(device)
        self.diffusion_pipe.set_progress_bar_config(disable=True)

        # Initialize ground truth caption generator
        self.caption_blip = Blip2ForConditionalGeneration.from_pretrained(config.BLIP_MODEL_NAME)
        self.caption_blip.to(device)
        self.caption_blip.eval()

        print("✅ Model evaluator initialized successfully!")

    def _load_model(self):
        """Load the trained model"""
        # Create model
        model = EnhancedEEG_BLIP2_Model(self.config, self.device)

        # Load state dict
        checkpoint = torch.load(self.model_path, map_location=self.device)
        if 'model_state_dict' in checkpoint:
            model.load_state_dict(checkpoint['model_state_dict'])
            print(f"✅ Loaded model from epoch {checkpoint.get('epoch', 'unknown')}")
        else:
            model.load_state_dict(checkpoint)

        model.eval()
        return model

    def generate_captions_from_eeg(self, spectrograms, generation_params=None):
        """Generate captions from EEG spectrograms"""
        if generation_params is None:
            generation_params = self.config.GENERATION_PARAMS

        with torch.no_grad():
            generated_ids = self.model.generate(spectrograms, **generation_params)
            captions = self.processor.batch_decode(generated_ids, skip_special_tokens=True)
            return [cap.strip() for cap in captions]

    def generate_ground_truth_captions(self, images):
        """Generate ground truth captions from images"""
        with torch.no_grad():
            inputs = self.processor(images=images, return_tensors="pt").to(self.device)
            generated_ids = self.caption_blip.generate(
                **inputs,
                max_length=50,
                num_beams=3,
                temperature=0.8,
                do_sample=True
            )
            captions = self.processor.batch_decode(generated_ids, skip_special_tokens=True)
            return [cap.strip() for cap in captions]

    def compute_semantic_similarity(self, captions1, captions2):
        """Compute semantic similarity between two sets of captions"""
        embeddings1 = self.sentence_model.encode(captions1, convert_to_tensor=True)
        embeddings2 = self.sentence_model.encode(captions2, convert_to_tensor=True)

        similarities = F.cosine_similarity(embeddings1, embeddings2, dim=-1)
        return similarities.cpu().numpy()

    def compute_clip_score(self, images, captions):
        """Compute CLIP score between images and captions"""
        # Preprocess images
        if isinstance(images[0], Image.Image):
            image_inputs = torch.stack([self.clip_preprocess(img) for img in images]).to(self.device)
        else:
            image_inputs = images

        # Tokenize texts
        text_inputs = clip.tokenize(captions, truncate=True).to(self.device)

        with torch.no_grad():
            image_features = self.clip_model.encode_image(image_inputs)
            text_features = self.clip_model.encode_text(text_inputs)

            # Normalize features
            image_features = F.normalize(image_features, dim=-1)
            text_features = F.normalize(text_features, dim=-1)

            # Compute similarities
            similarities = torch.sum(image_features * text_features, dim=-1)
            return similarities.cpu().numpy()

    def generate_image_from_caption(self, caption, seed=42):
        """Generate image from caption using Stable Diffusion"""
        try:
            generator = torch.Generator(device=self.device).manual_seed(seed)
            with torch.no_grad():
                result = self.diffusion_pipe(
                    prompt=caption,
                    generator=generator,
                    num_inference_steps=25,
                    guidance_scale=7.5,
                    height=512,
                    width=512
                )
                return result.images[0]
        except Exception as e:
            print(f"Error generating image: {e}")
            return Image.new('RGB', (512, 512), color='black')

    def compute_image_reconstruction_score(self, original_images, eeg_captions):
        """Compute reconstruction quality by generating images from EEG captions"""
        reconstruction_scores = []

        print("🖼️  Evaluating image reconstruction quality...")
        for i, (original_image, caption) in enumerate(zip(original_images, eeg_captions)):
            # Generate image from EEG caption
            generated_image = self.generate_image_from_caption(caption, seed=42+i)

            # Compute CLIP similarity between original and generated images
            orig_tensor = self.clip_preprocess(original_image).unsqueeze(0).to(self.device)
            gen_tensor = self.clip_preprocess(generated_image).unsqueeze(0).to(self.device)

            with torch.no_grad():
                orig_features = self.clip_model.encode_image(orig_tensor)
                gen_features = self.clip_model.encode_image(gen_tensor)

                orig_features = F.normalize(orig_features, dim=-1)
                gen_features = F.normalize(gen_features, dim=-1)

                similarity = F.cosine_similarity(orig_features, gen_features, dim=-1)
                reconstruction_scores.append(similarity.item())

        return reconstruction_scores

    def comprehensive_evaluation(self, test_loader, num_samples=100, save_dir=None):
        """Run comprehensive evaluation on test data"""
        self.model.eval()

        if save_dir:
            save_dir = Path(save_dir)
            save_dir.mkdir(parents=True, exist_ok=True)

        all_results = {
            'semantic_similarities': [],
            'clip_scores_eeg': [],
            'clip_scores_gt': [],
            'reconstruction_scores': [],
            'sample_comparisons': []
        }

        sample_count = 0

        print(f"🔬 Running comprehensive evaluation on {num_samples} samples...")

        with torch.no_grad():
            for batch_idx, (spectrograms, pil_images, gt_captions) in enumerate(tqdm(test_loader)):
                if sample_count >= num_samples:
                    break

                spectrograms = spectrograms.to(self.device)
                batch_size = len(spectrograms)

                # Generate EEG captions
                eeg_captions = self.generate_captions_from_eeg(spectrograms)

                # Generate ground truth captions (if not provided)
                if not gt_captions or all(cap == "" for cap in gt_captions):
                    gt_captions = self.generate_ground_truth_captions(pil_images)

                # Compute semantic similarities
                semantic_sims = self.compute_semantic_similarity(eeg_captions, gt_captions)
                all_results['semantic_similarities'].extend(semantic_sims)

                # Compute CLIP scores
                clip_scores_eeg = self.compute_clip_score(pil_images, eeg_captions)
                clip_scores_gt = self.compute_clip_score(pil_images, gt_captions)
                all_results['clip_scores_eeg'].extend(clip_scores_eeg)
                all_results['clip_scores_gt'].extend(clip_scores_gt)

                # Compute reconstruction scores (only for first few samples due to time)
                if sample_count < 20:
                    reconstruction_scores = self.compute_image_reconstruction_score(
                        pil_images[:min(4, batch_size)],
                        eeg_captions[:min(4, batch_size)]
                    )
                    all_results['reconstruction_scores'].extend(reconstruction_scores)

                    # Save visual examples
                    if save_dir and sample_count < 8:
                        self._save_evaluation_examples(
                            pil_images[:min(2, batch_size)],
                            eeg_captions[:min(2, batch_size)],
                            gt_captions[:min(2, batch_size)],
                            semantic_sims[:min(2, batch_size)],
                            save_dir,
                            sample_count
                        )

                # Store sample comparisons
                for i in range(min(batch_size, num_samples - sample_count)):
                    if len(all_results['sample_comparisons']) < 50:  # Store up to 50 examples
                        all_results['sample_comparisons'].append({
                            'eeg_caption': eeg_captions[i],
                            'gt_caption': gt_captions[i],
                            'semantic_similarity': semantic_sims[i],
                            'clip_score_eeg': clip_scores_eeg[i],
                            'clip_score_gt': clip_scores_gt[i]
                        })

                sample_count += batch_size

        # Compute final statistics
        results_summary = self._compute_evaluation_statistics(all_results)

        # Save detailed results
        if save_dir:
            with open(save_dir / 'evaluation_results.json', 'w') as f:
                json.dump({
                    'summary': results_summary,
                    'detailed_results': all_results
                }, f, indent=2, default=lambda x: x.tolist() if isinstance(x, np.ndarray) else x)

        return results_summary, all_results

    def _compute_evaluation_statistics(self, results):
        """Compute evaluation statistics"""
        semantic_sims = np.array(results['semantic_similarities'])
        clip_eeg = np.array(results['clip_scores_eeg'])
        clip_gt = np.array(results['clip_scores_gt'])
        reconstruction = np.array(results['reconstruction_scores']) if results['reconstruction_scores'] else np.array([])

        stats = {
            'semantic_similarity': {
                'mean': float(np.mean(semantic_sims)),
                'std': float(np.std(semantic_sims)),
                'median': float(np.median(semantic_sims)),
                'min': float(np.min(semantic_sims)),
                'max': float(np.max(semantic_sims)),
                'high_quality_ratio': float(np.mean(semantic_sims > 0.7))
            },
            'clip_score_eeg_captions': {
                'mean': float(np.mean(clip_eeg)),
                'std': float(np.std(clip_eeg)),
                'median': float(np.median(clip_eeg)),
                'min': float(np.min(clip_eeg)),
                'max': float(np.max(clip_eeg))
            },
            'clip_score_gt_captions': {
                'mean': float(np.mean(clip_gt)),
                'std': float(np.std(clip_gt)),
                'median': float(np.median(clip_gt)),
                'min': float(np.min(clip_gt)),
                'max': float(np.max(clip_gt))
            }
        }

        if len(reconstruction) > 0:
            stats['reconstruction_quality'] = {
                'mean': float(np.mean(reconstruction)),
                'std': float(np.std(reconstruction)),
                'median': float(np.median(reconstruction)),
                'min': float(np.min(reconstruction)),
                'max': float(np.max(reconstruction))
            }

        # Compute relative performance
        clip_ratio = stats['clip_score_eeg_captions']['mean'] / stats['clip_score_gt_captions']['mean']
        stats['clip_performance_ratio'] = float(clip_ratio)

        return stats

    def _save_evaluation_examples(self, images, eeg_captions, gt_captions, similarities, save_dir, start_idx):
        """Save visual evaluation examples"""
        for i, (image, eeg_cap, gt_cap, sim) in enumerate(zip(images, eeg_captions, gt_captions, similarities)):
            fig, axes = plt.subplots(1, 3, figsize=(15, 5))

            # Original image
            axes[0].imshow(image)
            axes[0].set_title("Original Image")
            axes[0].axis('off')

            # Generated image from EEG caption
            generated_image = self.generate_image_from_caption(eeg_cap, seed=42 + start_idx + i)
            axes[1].imshow(generated_image)
            axes[1].set_title(f"Generated from EEG\n(Sim: {sim:.3f})")
            axes[1].axis('off')

            # Generated image from GT caption
            gt_generated_image = self.generate_image_from_caption(gt_cap, seed=100 + start_idx + i)
            axes[2].imshow(gt_generated_image)
            axes[2].set_title("Generated from GT")
            axes[2].axis('off')

            # Add captions as text
            fig.suptitle(f"EEG: {eeg_cap}\nGT: {gt_cap}", fontsize=10, y=0.02, wrap=True)

            plt.tight_layout()
            plt.savefig(save_dir / f'example_{start_idx + i}.png', dpi=150, bbox_inches='tight')
            plt.close()

    def print_evaluation_report(self, results_summary):
        """Print a comprehensive evaluation report"""
        print("\n" + "="*80)
        print("🔬 COMPREHENSIVE EVALUATION REPORT")
        print("="*80)

        sem_sim = results_summary['semantic_similarity']
        print(f"\n📝 SEMANTIC SIMILARITY (EEG vs Ground Truth Captions)")
        print(f"   Mean: {sem_sim['mean']:.4f} ± {sem_sim['std']:.4f}")
        print(f"   Median: {sem_sim['median']:.4f}")
        print(f"   Range: [{sem_sim['min']:.4f}, {sem_sim['max']:.4f}]")
        print(f"   High Quality Ratio (>0.7): {sem_sim['high_quality_ratio']:.2%}")

        clip_eeg = results_summary['clip_score_eeg_captions']
        clip_gt = results_summary['clip_score_gt_captions']
        print(f"\n🖼️  CLIP SCORES (Image-Text Alignment)")
        print(f"   EEG Captions:  {clip_eeg['mean']:.4f} ± {clip_eeg['std']:.4f}")
        print(f"   GT Captions:   {clip_gt['mean']:.4f} ± {clip_gt['std']:.4f}")
        print(f"   Relative Performance: {results_summary['clip_performance_ratio']:.2%}")

        if 'reconstruction_quality' in results_summary:
            recon = results_summary['reconstruction_quality']
            print(f"\n🎨 IMAGE RECONSTRUCTION QUALITY")
            print(f"   Mean: {recon['mean']:.4f} ± {recon['std']:.4f}")
            print(f"   Median: {recon['median']:.4f}")
            print(f"   Range: [{recon['min']:.4f}, {recon['max']:.4f}]")

        # Performance assessment
        print(f"\n🏆 OVERALL ASSESSMENT")
        if sem_sim['mean'] > 0.7:
            print("   Semantic Alignment: EXCELLENT ✅")
        elif sem_sim['mean'] > 0.5:
            print("   Semantic Alignment: GOOD 👍")
        else:
            print("   Semantic Alignment: NEEDS IMPROVEMENT 🔄")

        if results_summary['clip_performance_ratio'] > 0.8:
            print("   Image-Text Alignment: EXCELLENT ✅")
        elif results_summary['clip_performance_ratio'] > 0.6:
            print("   Image-Text Alignment: GOOD 👍")
        else:
            print("   Image-Text Alignment: NEEDS IMPROVEMENT 🔄")

        print("="*80)

# ==============================================================================
# --- EVALUATION SCRIPT USAGE ---
# ==============================================================================
def run_evaluation(model_path, config, test_data_path, output_dir=None, num_samples=200):
    """Main evaluation function"""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Initialize evaluator
    evaluator = ModelEvaluator(model_path, config, device)

    # Prepare test data
    spec_mean = torch.load(Path(config.PROCESSED_DATA_ROOT) / 'spec_mean.pt')
    spec_std = torch.load(Path(config.PROCESSED_DATA_ROOT) / 'spec_std.pt')

    transform = transforms.Compose([
        transforms.ToDtype(torch.float32),
        transforms.Resize((224, 224), antialias=True),
        transforms.Normalize(mean=spec_mean.tolist(), std=spec_std.tolist()),
    ])

    # Create test dataset
    test_dataset = EEGDatasetWithCaptions(
        config.PROCESSED_DATA_ROOT,
        config.METADATA_CSV,
        'test',  # or 'val' if no test split
        transform,
        evaluator.processor,
        evaluator.caption_blip,
        device
    )

    # Create test loader
    test_indices = np.random.choice(len(test_dataset), min(len(test_dataset), num_samples), replace=False)
    test_loader = DataLoader(
        test_dataset,
        batch_size=8,
        sampler=SubsetRandomSampler(test_indices),
        num_workers=2
    )

    # Run evaluation
    print(f"🚀 Starting evaluation on {len(test_indices)} samples...")
    results_summary, detailed_results = evaluator.comprehensive_evaluation(
        test_loader,
        num_samples=num_samples,
        save_dir=output_dir
    )

    # Print report
    evaluator.print_evaluation_report(results_summary)

    return evaluator, results_summary, detailed_results

# Example usage
if __name__ == "__main__":
    from complete_training_pipeline import ENHANCED_TRAIN_CONFIG

    # Configuration
    config = ENHANCED_TRAIN_CONFIG()
    model_path = "/path/to/your/best_model.pth"  # Update this path
    output_dir = "/path/to/evaluation/results"   # Update this path

    # Run evaluation
    evaluator, results, detailed = run_evaluation(
        model_path=model_path,
        config=config,
        test_data_path=config.PROCESSED_DATA_ROOT,
        output_dir=output_dir,
        num_samples=100
    )

ModuleNotFoundError: No module named 'complete_training_pipeline'

In [None]:
import torch
from diffusers import StableDiffusionPipeline
from PIL import Image

# --- 1. Set up the model and pipeline ---
# This will download the model from the Hugging Face Hub the first time you run it.
model_id = "runwayml/stable-diffusion-v1-5"
# Use torch.float16 for faster inference and less memory on NVIDIA GPUs.
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)

# Move the pipeline to the GPU for much faster generation.
# If you don't have a GPU, you can remove this line, but it will be very slow.
device = "cuda" if torch.cuda.is_available() else "cpu"
pipe = pipe.to(device)
print(f"Using device: {device}")

# --- 2. Define your text prompt ---
# This is where you'll put the caption you generated.
prompt = "the small brown dog next to the chair is a dog house a small dog kenn"

# --- 3. (Optional) Set a seed for reproducibility ---
# Using a seed ensures you get the same image every time for the same prompt.
generator = torch.Generator(device=device).manual_seed(42)

# --- 4. Generate the image ---
# The pipeline returns an object with the generated image(s).
print("Generating image...")
output = pipe(prompt=prompt, generator=generator)
image = output.images[0] # The generated image is a PIL Image object

# --- 5. Save the image ---
output_path = "generated_image.png"
image.save(output_path)
print(f"Image saved successfully to {output_path}")

# You can also display the image if you are in a Jupyter Notebook or similar environment.
# display(image)

Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

Using device: cuda
Generating image...


  0%|          | 0/50 [00:00<?, ?it/s]

Image saved successfully to generated_image.png


# Old scripts

In [None]:
import torch
import torch.nn as nn
import pandas as pd
from pathlib import Path
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
from torchvision.transforms import v2 as transforms
from PIL import Image
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm.auto import tqdm
from datetime import datetime
import numpy as np
from transformers import Blip2Processor, Blip2ForConditionalGeneration, Blip2VisionModel
import warnings

# ==============================================================================
# --- 1. CONFIGURATION ---
# ==============================================================================
class TRAIN_CONFIG:
    PROCESSED_DATA_ROOT = '/content/final_lightweight_17k'
    METADATA_CSV = Path(PROCESSED_DATA_ROOT) / 'metadata.csv'
    OUTPUT_DIR = '/content/drive/MyDrive/NeuroVision/EEG_BLIP2_Alignment'

    BLIP_MODEL_NAME = "Salesforce/blip2-opt-2.7b"

    BATCH_SIZE = 2
    GRAD_ACCUMULATION_STEPS = 16
    NUM_EPOCHS = 50
    LR = 5e-5

    WEIGHT_DECAY = 5e-3
    GRADIENT_CLIP_NORM = 1.0

    VALIDATION_INTERVAL = 1
    VISUALIZATION_INTERVAL = 1 # Note: Visualization happens with validation
    VAL_SAMPLES_PER_EPOCH = 700
    TRAIN_SAMPLES_PER_EPOCH = 3500
    VIS_GRID_SIZE = 4

    ALPHA_ALIGN = 1.0      # weight for cosine alignment
    BETA_CE     = 5.0      # weight for cross-entropy (strong)
    GAMMA_EOS   = 10.0

# ==============================================================================
# --- 2. EEG ENCODER MODULE ---
# ==============================================================================

class ChannelAdapter(nn.Module):
    """Lightweight adapter to turn 64-channel spectrogram -> 3-channel-like input for BLIP vision."""
    def __init__(self, in_chans=64, mid_ch=64, out_ch=3):
        super().__init__()
        self.conv1 = nn.Conv2d(in_chans, mid_ch, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(mid_ch)
        self.act = nn.GELU()
        self.conv2 = nn.Conv2d(mid_ch, out_ch, kernel_size=3, padding=1, bias=True)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.act(x)
        x = self.conv2(x)
        return x

class EEGEncoder(nn.Module):
    def __init__(self, config, in_chans=64):
        super().__init__()
        self.vision_model = Blip2VisionModel.from_pretrained(config.BLIP_MODEL_NAME)

        # Freeze the original vision model
        for param in self.vision_model.parameters():
            param.requires_grad = False

        self.adapter = ChannelAdapter(in_chans=in_chans, mid_ch=64, out_ch=3)

        for p in self.adapter.parameters():
            p.requires_grad = True


    def forward(self, x):
      adapted_x = self.adapter(x)
      vision_outputs = self.vision_model(adapted_x, return_dict=True)
      return vision_outputs.last_hidden_state

# ==============================================================================
# --- 3. THE HYBRID EEG-BLIP2 MODEL ---
# ==============================================================================
class EEG_BLIP2_Model(nn.Module):
    def __init__(self, config, device):
        super().__init__()
        self.device = device
        self.eeg_encoder = EEGEncoder(config)
        self.blip = Blip2ForConditionalGeneration.from_pretrained(
            config.BLIP_MODEL_NAME
        )

        # Freeze the entire BLIP model initially
        for param in self.blip.parameters():
            param.requires_grad = False

        print("Unfreezing specified layers for fine-tuning...")
        # Unfreeze the language projection layer (connects Q-Former to LLM)
        if hasattr(self.blip, "language_projection"):
            for p in self.blip.language_projection.parameters():
                p.requires_grad = True

            print("  - Unfroze Language Projection layer.")

        # Unfreeze the last layer(s) of the Q-Former for fine-tuning
        try:
            qenc_layers = self.blip.qformer.encoder.layer
            n_unfreeze_qformer = 2
            for idx in range(len(qenc_layers) - n_unfreeze_qformer, len(qenc_layers)):
                for p in qenc_layers[idx].parameters():
                    p.requires_grad = True
        except Exception:
            pass

        try:
            vision_layers = self.blip.vision_model.encoder.layers
            for p in vision_layers[-2].parameters():
                p.requires_grad = True
            print("  - Unfroze the last Vision Model layer.")
        except Exception as e:
            print(f"Could not unfreeze Vision Model layer: {e}")

        # 4. Unfreeze the last Language Model layer
        try:
            # Note: The exact path might differ slightly between transformers versions
            llm_layers = self.blip.language_model.model.decoder.layers
            for p in llm_layers[-2].parameters():
                p.requires_grad = True
            print("  - Unfroze the last Language Model layer.")
        except Exception as e:
            print(f"Could not unfreeze Language Model layer: {e}")


        self.eeg_encoder.to(device)
        self.blip.to(device)

    def get_eeg_embedding(self, eeg_spectrograms):
        eeg_features = self.eeg_encoder(eeg_spectrograms)
        query_tokens = self.blip.query_tokens.expand(eeg_features.shape[0], -1, -1)
        query_outputs = self.blip.qformer(
            query_embeds=query_tokens,
            encoder_hidden_states=eeg_features,
            return_dict=True,
        )
        return query_outputs.last_hidden_state.mean(dim=1)

    def get_image_embedding(self, pil_images, processor):
        inputs = processor(images=pil_images, return_tensors="pt").to(self.device)
        pixel_values = inputs.pixel_values

        image_features = self.blip.vision_model(pixel_values).last_hidden_state
        query_tokens = self.blip.query_tokens.expand(image_features.shape[0], -1, -1)
        query_outputs = self.blip.qformer(
            query_embeds=query_tokens,
            encoder_hidden_states=image_features,
            return_dict=True,
        )
        return query_outputs.last_hidden_state.mean(dim=1)

    def generate(self, eeg_spectrograms, **kwargs):
        with torch.no_grad():
            eeg_features = self.eeg_encoder(eeg_spectrograms)
            query_tokens = self.blip.query_tokens.expand(eeg_features.shape[0], -1, -1)
            query_outputs = self.blip.qformer(
                query_embeds=query_tokens,
                encoder_hidden_states=eeg_features,
                return_dict=True,
            )
            language_model_inputs = self.blip.language_projection(query_outputs.last_hidden_state)
            generated_ids = self.blip.language_model.generate(
                inputs_embeds=language_model_inputs,
                **kwargs
            )
            return generated_ids

# ==============================================================================
# --- 4. DATASET & COLLATING ---
# ==============================================================================
def collate_fn(batch):
    spectrograms = torch.stack([item[0] for item in batch])
    pil_images = [item[1] for item in batch]
    return spectrograms, pil_images

class EEGDataset(Dataset):
    def __init__(self, root_dir, metadata_csv, split, transform):
        self.root_dir = Path(root_dir)
        self.transform = transform
        df = pd.read_csv(metadata_csv)
        self.split_df = df[df['split'].str.strip() == split].reset_index(drop=True)

    def __len__(self): return len(self.split_df)

    def __getitem__(self, idx):
        info = self.split_df.iloc[idx]
        spectrogram = self.transform(torch.load(self.root_dir / info['spectrogram_path']))
        image = Image.open(self.root_dir / info['image_path']).convert("RGB")
        return spectrogram, image

# ==============================================================================
# --- 5. MAIN TRAINING FUNCTION ---
# ==============================================================================
# InfoNCE loss function
def info_nce_loss(query, positive_key, negative_keys=None, temperature=0.07):
    # Ensure inputs are normalized
    query = nn.functional.normalize(query, dim=-1)
    positive_key = nn.functional.normalize(positive_key, dim=-1)

    # Calculate positive logits
    positive_logits = torch.sum(query * positive_key, dim=-1, keepdim=True)

    # Concatenate all other items in the batch as negative keys
    if negative_keys is None:
        negative_keys = positive_key

    negative_logits = query @ negative_keys.T

    # Mask out self-comparison for the InfoNCE formula
    logits = torch.cat([positive_logits, negative_logits], dim=1)

    # Create labels for cross-entropy: first element is the positive pair
    labels = torch.zeros(len(query), dtype=torch.long, device=query.device)

    return nn.functional.cross_entropy(logits / temperature, labels)

#------------------------------------------------------------------------------------------
def train_model(config):
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    output_dir = Path(config.OUTPUT_DIR) / f'run_{timestamp}'
    output_dir.mkdir(parents=True, exist_ok=True)

    print("🚀 Starting EEG-to-Text Training with Embedding Alignment Loss")

    processor = Blip2Processor.from_pretrained(config.BLIP_MODEL_NAME)
    student_model = EEG_BLIP2_Model(config, device)

    spec_mean = torch.load(Path(config.PROCESSED_DATA_ROOT) / 'spec_mean.pt')
    spec_std = torch.load(Path(config.PROCESSED_DATA_ROOT) / 'spec_std.pt')
    transform = transforms.Compose([
        transforms.ToDtype(torch.float32),
        transforms.Resize((224, 224), antialias=True),
        transforms.Normalize(mean=spec_mean.tolist(), std=spec_std.tolist()),
    ])
    train_dataset = EEGDataset(config.PROCESSED_DATA_ROOT, config.METADATA_CSV, 'train', transform)
    val_dataset = EEGDataset(config.PROCESSED_DATA_ROOT, config.METADATA_CSV, 'val', transform)

    val_indices = np.random.choice(len(val_dataset), min(len(val_dataset), config.VAL_SAMPLES_PER_EPOCH), replace=False)
    val_loader = DataLoader(val_dataset, batch_size=config.BATCH_SIZE, sampler=SubsetRandomSampler(val_indices), collate_fn=collate_fn)

    fixed_spectrograms, fixed_pil_images = next(iter(val_loader))
    fixed_spectrograms = fixed_spectrograms.to(device)

    # OPTIMIZED: Only pass trainable parameters to the optimizer
    trainable_params = filter(lambda p: p.requires_grad, student_model.parameters())
    optimizer = optim.AdamW(trainable_params, lr=config.LR, weight_decay=config.WEIGHT_DECAY)

    # scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=5, T_mult=1, eta_min=1e-7)
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.2, patience=2, min_lr=1e-7)
    scaler = torch.cuda.amp.GradScaler()
    loss_fn = nn.CosineEmbeddingLoss()
    best_val_loss = float('inf')

    print("🎯 Starting Training Loop")
    for epoch in range(config.NUM_EPOCHS):
        student_model.train()
        train_indices = np.random.choice(len(train_dataset), min(len(train_dataset), config.TRAIN_SAMPLES_PER_EPOCH), replace=False)
        train_loader = DataLoader(train_dataset, batch_size=config.BATCH_SIZE, sampler=SubsetRandomSampler(train_indices), drop_last=True, collate_fn=collate_fn)
        train_bar = tqdm(train_loader, desc=f"🔄 Epoch {epoch+1}/{config.NUM_EPOCHS} [TRAIN]")

        optimizer.zero_grad()
        for batch_idx, (spectrograms, pil_images) in enumerate(train_bar):
            spectrograms = spectrograms.to(device)

            with torch.amp.autocast('cuda'):
                eeg_embedding = student_model.get_eeg_embedding(spectrograms)

                with torch.no_grad():
                    image_embedding = student_model.get_image_embedding(pil_images, processor)

                target = torch.ones(eeg_embedding.size(0)).to(device)
                loss_align = info_nce_loss(eeg_embedding, image_embedding, temperature=0.07)
                cse_loss = loss_fn(eeg_embedding, image_embedding, target)

                # CORRECTED: Was `loss = loss / ...` but `loss` wasn't defined. Use `loss_align`.
                loss = (2.0 * loss_align + cse_loss) / config.GRAD_ACCUMULATION_STEPS

            scaler.scale(loss).backward()

            if (batch_idx + 1) % config.GRAD_ACCUMULATION_STEPS == 0:
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(student_model.parameters(), config.GRADIENT_CLIP_NORM)
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()

            # IMPROVED: Show the actual alignment loss, not the scaled version
            train_bar.set_postfix(align_loss=f"{loss_align.item():.4f}")

        if (epoch + 1) % config.VALIDATION_INTERVAL == 0:
            print(f"\n🔍 Running Validation & Visualization for epoch {epoch+1}")
            student_model.eval()
            val_loss = 0

            with torch.no_grad(), torch.amp.autocast('cuda'):
                for spectrograms, pil_images in val_loader:
                    spectrograms = spectrograms.to(device)
                    eeg_embedding = student_model.get_eeg_embedding(spectrograms)
                    image_embedding = student_model.get_image_embedding(pil_images, processor)
                    target = torch.ones(eeg_embedding.size(0)).to(device)
                    cse_loss = loss_fn(eeg_embedding, image_embedding, target).item()
                    loss_align = info_nce_loss(eeg_embedding, image_embedding, temperature=0.07)
                    loss = 2.0 * loss_align + cse_loss
                    val_loss += loss

            avg_val_loss = val_loss / len(val_loader)
            print(f"  Validation Loss: {avg_val_loss:.4f}")
            scheduler.step(avg_val_loss)

            if avg_val_loss < best_val_loss:
                best_val_loss = avg_val_loss
                # Save only the trainable part of the model: the EEG encoder and fine-tuned layers
                torch.save(student_model.state_dict(), output_dir / 'best_model.pth')
                print(f"🏆 New best model saved with validation loss: {best_val_loss:.4f}")

            with torch.no_grad(), torch.amp.autocast('cuda'):
                generated_ids = student_model.generate(
                    fixed_spectrograms[:config.VIS_GRID_SIZE],
                    do_sample=True, temperature=1.0, top_p=0.9,
                    min_new_tokens=8, max_new_tokens=32, repetition_penalty=1.5
                    # num_beams=4
                )
                predicted_captions = processor.batch_decode(generated_ids, skip_special_tokens=True)

                # Get ground truth captions
                # CORRECTED: Removed explicit .to(dtype=torch.float16) for robustness
                gt_pixel_values = processor(images=fixed_pil_images, return_tensors="pt").to(device).pixel_values
                gt_captions_ids = student_model.blip.generate(gt_pixel_values)
                gt_captions = processor.batch_decode(gt_captions_ids, skip_special_tokens=True)


                print("\n--- Caption Generation Samples ---")
                for i in range(len(predicted_captions)):
                    print(f"Sample {i+1}:")
                    print(f"  - Ground Truth: {gt_captions[i].strip()}")
                    print(f"  - Predicted:    {predicted_captions[i].strip()}")
                print("------------------------------------")

    print("\n🎉 Training Complete!")
    return output_dir

if __name__ == '__main__':
    config = TRAIN_CONFIG()
    try:
        output_dir = train_model(config)
        if output_dir:
            print(f"✅ Training completed successfully! Results saved to: {output_dir}")
    except Exception as e:
        print(f"❌ Training failed with error: {str(e)}")
        import traceback
        traceback.print_exc()

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


🚀 Starting EEG-to-Text Training with Embedding Alignment Loss


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Some weights of Blip2VisionModel were not initialized from the model checkpoint at Salesforce/blip2-opt-2.7b and are newly initialized: ['embeddings.class_embedding', 'embeddings.patch_embedding.bias', 'embeddings.patch_embedding.weight', 'embeddings.position_embedding', 'encoder.layers.0.layer_norm1.bias', 'encoder.layers.0.layer_norm1.weight', 'encoder.layers.0.layer_norm2.bias', 'encoder.layers.0.layer_norm2.weight', 'encoder.layers.0.mlp.fc1.bias', 'encoder.layers.0.mlp.fc1.weight', 'encoder.layers.0.mlp.fc2.bias', 'encoder.layers.0.mlp.fc2.weight', 'encoder.layers.0.self_attn.projection.bias', 'encoder.layers.0.self_attn.projection.weight', 'encoder.layers.0.self_attn.qkv.bias', 'encoder.layers.0.self_attn.qkv.weight', 'encoder.layers.1.layer_norm1.bias', 'encoder.layers.1.layer_norm1.weight', 'encoder.layers.1.layer_norm2.bias', 'encoder.layers.1.layer_norm2.weight', 'encoder.layers.1.mlp.fc1.bias', 'encoder.layers.1.mlp.fc1.weight', 'encoder.layers.1.mlp.fc2.bias', 'encoder.laye

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Unfreezing specified layers for fine-tuning...
  - Unfroze Language Projection layer.
  - Unfroze the last Vision Model layer.
  - Unfroze the last Language Model layer.
🎯 Starting Training Loop


  scaler = torch.cuda.amp.GradScaler()


🔄 Epoch 1/50 [TRAIN]:   0%|          | 0/1750 [00:00<?, ?it/s]


🔍 Running Validation & Visualization for epoch 1
  Validation Loss: 2.2821
🏆 New best model saved with validation loss: 2.2821

--- Caption Generation Samples ---
Sample 1:
  - Ground Truth: a man with a donut in his mouth
  - Predicted:    david taylor in the desert. all rights reserved 2012 dalton brown
Sample 2:
  - Ground Truth: a pair of scissors
  - Predicted:    color of the day with a blue door
------------------------------------


🔄 Epoch 2/50 [TRAIN]:   0%|          | 0/1750 [00:00<?, ?it/s]

KeyboardInterrupt: 

# Dump

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
from pathlib import Path
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
from torchvision.transforms import v2 as transforms
from PIL import Image
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm.auto import tqdm
from datetime import datetime
import numpy as np
import json
import clip
from sentence_transformers import SentenceTransformer
from diffusers import StableDiffusionPipeline
from transformers import Blip2Processor, Blip2ForConditionalGeneration, Blip2VisionModel

# ==============================================================================
# --- 1. CONFIGURATION (MERGED) ---
# ==============================================================================
class ENHANCED_CONFIG:
    # --- Paths and Models ---
    PROCESSED_DATA_ROOT = '/content/final_lightweight_17k'
    METADATA_CSV = Path(PROCESSED_DATA_ROOT) / 'metadata.csv'
    OUTPUT_DIR = '/content/drive/MyDrive/NeuroVision/EEG_BLIP2_Alignment_Stable'
    BLIP_MODEL_NAME = "Salesforce/blip2-opt-2.7b"
    CLIP_MODEL_NAME = "ViT-B/32"
    DIFFUSION_MODEL_ID = "runwayml/stable-diffusion-v1-5"

    # --- Stable Hyperparameters ---
    BATCH_SIZE = 32
    LR = 1e-5
    WEIGHT_DECAY = 5e-3
    GRADIENT_CLIP_NORM = 1.0
    GRAD_ACCUMULATION_STEPS = 1

    # --- Training Loop Settings ---
    NUM_EPOCHS = 50
    TRAIN_SAMPLES_PER_EPOCH = 3500
    VAL_SAMPLES_PER_EPOCH = 700

    # --- New Loss Weights ---
    EMBEDDING_WEIGHT = 1.0
    CAPTION_SIM_WEIGHT = 3.0
    CLIP_ALIGNMENT_WEIGHT = 0.25

    # --- Evaluation Settings ---
    EVAL_IMAGE_GENERATION_INTERVAL = 1
    EVAL_SAMPLES_TO_GENERATE = 3 # How many images to generate each evaluation

# ==============================================================================
# --- 2. STABLE EEG ENCODER MODULE (FROM OLD SCRIPT) ---
# ==============================================================================
class ChannelAdapter(nn.Module):
    def __init__(self, in_chans=64, mid_ch=64, out_ch=3):
        super().__init__()
        self.conv1 = nn.Conv2d(in_chans, mid_ch, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(mid_ch)
        self.act = nn.GELU()
        self.conv2 = nn.Conv2d(mid_ch, out_ch, kernel_size=3, padding=1, bias=True)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.act(x)
        x = self.conv2(x)
        return x

class EEGEncoder(nn.Module):
    """
    A simplified, stable encoder that ONLY contains the adapter.
    It receives the vision model during its forward pass.
    """
    def __init__(self, in_chans=64):
        super().__init__()
        self.adapter = ChannelAdapter(in_chans=in_chans)

    def forward(self, x, vision_model):
      # Takes the main vision model as an argument
      adapted_x = self.adapter(x)
      vision_outputs = vision_model(adapted_x, return_dict=True)
      return vision_outputs.last_hidden_state

# ==============================================================================
# --- 3. STABLE HYBRID EEG-BLIP2 MODEL (FROM OLD SCRIPT) ---
# ==============================================================================
class EEG_BLIP2_Model(nn.Module):
    def __init__(self, config, device):
        super().__init__()
        self.device = device
        self.eeg_encoder = EEGEncoder() # Now much simpler
        self.blip = Blip2ForConditionalGeneration.from_pretrained(config.BLIP_MODEL_NAME)

        # Freeze the entire BLIP model initially
        for param in self.blip.parameters():
            param.requires_grad = False

        print("Unfreezing specified layers for fine-tuning...")
        # Unfreeze the language projection layer (connects Q-Former to LLM)
        if hasattr(self.blip, "language_projection"):
            for p in self.blip.language_projection.parameters():
                p.requires_grad = True

            print("  - Unfroze Language Projection layer.")

        # Unfreeze the last layer(s) of the Q-Former for fine-tuning
        try:
            qenc_layers = self.blip.qformer.encoder.layer
            n_unfreeze_qformer = 4
            for idx in range(len(qenc_layers) - n_unfreeze_qformer, len(qenc_layers)):
                for p in qenc_layers[idx].parameters():
                    p.requires_grad = True
        except Exception:
            pass

        try:
            vision_layers = self.blip.vision_model.encoder.layers
            for p in vision_layers[-4].parameters():
                p.requires_grad = True
            print("  - Unfroze the last Vision Model layer.")
        except Exception as e:
            print(f"Could not unfreeze Vision Model layer: {e}")

        # 4. Unfreeze the last Language Model layer
        try:
            # Note: The exact path might differ slightly between transformers versions
            llm_layers = self.blip.language_model.model.decoder.layers
            for p in llm_layers[-4].parameters():
                p.requires_grad = True
            print("  - Unfroze the last Language Model layer.")
        except Exception as e:
            print(f"Could not unfreeze Language Model layer: {e}")

        self.eeg_encoder.to(device)
        self.blip.to(device)

    def get_eeg_embedding(self, eeg_spectrograms):
        # Pass the single, authoritative vision model to the encoder
        eeg_features = self.eeg_encoder(eeg_spectrograms, self.blip.vision_model)

        query_tokens = self.blip.query_tokens.expand(eeg_features.shape[0], -1, -1)
        query_outputs = self.blip.qformer(
            query_embeds=query_tokens, encoder_hidden_states=eeg_features, return_dict=True)
        return query_outputs.last_hidden_state.mean(dim=1)

    def get_image_embedding(self, pil_images, processor):
        inputs = processor(images=pil_images, return_tensors="pt").to(self.device)
        pixel_values = inputs.pixel_values

        # Use the single, authoritative vision model
        image_features = self.blip.vision_model(pixel_values).last_hidden_state
        query_tokens = self.blip.query_tokens.expand(image_features.shape[0], -1, -1)
        query_outputs = self.blip.qformer(
            query_embeds=query_tokens, encoder_hidden_states=image_features, return_dict=True)
        return query_outputs.last_hidden_state.mean(dim=1)

    def generate(self, eeg_spectrograms, **kwargs):
        with torch.no_grad():
            # Pass the single, authoritative vision model to the encoder
            eeg_features = self.eeg_encoder(eeg_spectrograms, self.blip.vision_model)

            query_tokens = self.blip.query_tokens.expand(eeg_features.shape[0], -1, -1)
            query_outputs = self.blip.qformer(
                query_embeds=query_tokens, encoder_hidden_states=eeg_features, return_dict=True)
            language_model_inputs = self.blip.language_projection(query_outputs.last_hidden_state)
            generated_ids = self.blip.language_model.generate(
                inputs_embeds=language_model_inputs, **kwargs)
            return generated_ids

# ==============================================================================
# --- 4. DATASET & HELPERS (MERGED) ---
# ==============================================================================
class EEGDatasetWithCaptions(Dataset):
    def __init__(self, root_dir, metadata_csv, split, transform):
        self.root_dir = Path(root_dir)
        self.transform = transform
        df = pd.read_csv(metadata_csv)
        self.split_df = df[df['split'].str.strip() == split].reset_index(drop=True)

        captions_path = Path(root_dir) / f'{split}_gt_captions.json'
        if not captions_path.exists():
            raise FileNotFoundError(f"Caption file not found: {captions_path}. Please run the caption generation script first.")
        print(f"Loading pre-generated captions from {captions_path}")
        with open(captions_path, 'r') as f:
            self.captions = json.load(f)

    def __len__(self):
        return len(self.split_df)

    def __getitem__(self, idx):
        info = self.split_df.iloc[idx]
        spectrogram = self.transform(torch.load(self.root_dir / info['spectrogram_path']))
        image = Image.open(self.root_dir / info['image_path']).convert("RGB")
        gt_caption = self.captions.get(str(idx), "an image")
        return spectrogram, image, gt_caption

def collate_fn(batch):
    spectrograms = torch.stack([item[0] for item in batch])
    pil_images = [item[1] for item in batch]
    captions = [item[2] for item in batch]
    return spectrograms, pil_images, captions

# ==============================================================================
# --- 5. ENHANCED LOSS FUNCTIONS (FROM NEW SCRIPT) ---
# ==============================================================================
class CLIPAlignmentLoss(nn.Module):
    def __init__(self, device, model_name="ViT-B/32"):
        super().__init__()
        self.device = device
        self.clip_model, self.clip_preprocess = clip.load(model_name, device=device)
        for param in self.clip_model.parameters():
            param.requires_grad = False

    def forward(self, images, texts):
        if isinstance(images[0], Image.Image):
            image_inputs = torch.stack([self.clip_preprocess(img) for img in images]).to(self.device)
        else:
            image_inputs = images.to(self.device)
        text_inputs = clip.tokenize(texts, truncate=True).to(self.device)
        with torch.no_grad():
            image_features = self.clip_model.encode_image(image_inputs)
            text_features = self.clip_model.encode_text(text_inputs)
        image_features = F.normalize(image_features, dim=-1)
        text_features = F.normalize(text_features, dim=-1)
        similarity = torch.sum(image_features * text_features, dim=-1)
        return -similarity.mean()

class ComprehensiveLoss(nn.Module):
    def __init__(self, device, config):
        super().__init__()
        self.device = device
        self.config = config
        self.sentence_model = SentenceTransformer('all-MiniLM-L6-v2').to(device)
        self.clip_loss = CLIPAlignmentLoss(device, config.CLIP_MODEL_NAME)
        self.cosine_loss = nn.CosineEmbeddingLoss()
        for param in self.sentence_model.parameters():
            param.requires_grad = False

    def info_nce_loss(self, query, positive_key, temperature=0.07):
        query = F.normalize(query, dim=-1)
        positive_key = F.normalize(positive_key, dim=-1)
        positive_logits = torch.sum(query * positive_key, dim=-1, keepdim=True)
        negative_logits = query @ positive_key.T
        logits = torch.cat([positive_logits, negative_logits], dim=1) / temperature
        labels = torch.zeros(len(query), dtype=torch.long, device=self.device)
        return F.cross_entropy(logits, labels)

    def forward(self, eeg_embeddings, image_embeddings, eeg_captions, gt_captions, pil_images):
        losses = {}
        eeg_embeddings = F.normalize(eeg_embeddings, dim=-1)
        image_embeddings = F.normalize(image_embeddings, dim=-1)

        losses['embedding_contrastive'] = self.info_nce_loss(eeg_embeddings, image_embeddings)

        with torch.no_grad():
            eeg_caption_embeddings = self.sentence_model.encode(eeg_captions, convert_to_tensor=True)
            gt_caption_embeddings = self.sentence_model.encode(gt_captions, convert_to_tensor=True)

        caption_sim = F.cosine_similarity(eeg_caption_embeddings, gt_caption_embeddings, dim=-1)
        losses['caption_similarity'] = 1.0 - caption_sim.mean()

        losses['clip_alignment'] = self.clip_loss(pil_images, eeg_captions)

        total_loss = (
            self.config.EMBEDDING_WEIGHT * losses['embedding_contrastive'] +
            self.config.CAPTION_SIM_WEIGHT * losses['caption_similarity'] +
            self.config.CLIP_ALIGNMENT_WEIGHT * losses['clip_alignment']
        )
        losses['total'] = total_loss
        return losses

# ==============================================================================
# --- 6. IMAGE EVALUATION (FROM NEW SCRIPT) ---
# ==============================================================================
class ImageGenerationEvaluator:
    def __init__(self, device, config):
        self.device = device
        self.config = config
        print("Loading Stable Diffusion pipeline...")
        self.diffusion_pipe = StableDiffusionPipeline.from_pretrained(
            config.DIFFUSION_MODEL_ID,
            torch_dtype=torch.float16 if device.type == "cuda" else torch.float32,
            safety_checker=None,
            requires_safety_checker=False
        ).to(device)
        self.diffusion_pipe.set_progress_bar_config(disable=True)
        self.eval_dir = Path(config.OUTPUT_DIR) / 'generated_images'
        self.eval_dir.mkdir(parents=True, exist_ok=True)

    def generate_and_save_samples(self, eeg_captions, original_images, epoch):
        print(f"\n🖼️  Generating {len(eeg_captions)} images for evaluation...")
        for i, (caption, original_image) in enumerate(zip(eeg_captions, original_images)):
            generator = torch.Generator(device=self.device).manual_seed(42 + i)
            generated_image = self.diffusion_pipe(
                prompt=caption, generator=generator, num_inference_steps=20
            ).images[0]

            comparison = Image.new('RGB', (1024, 512))
            comparison.paste(original_image.resize((512, 512)), (0, 0))
            comparison.paste(generated_image.resize((512, 512)), (512, 0))

            comparison_path = self.eval_dir / f'epoch_{epoch}_sample_{i}.png'
            comparison.save(comparison_path)
            print(f"  Saved sample {i+1} to {comparison_path}")

# ==============================================================================
# --- 7. MAIN TRAINING FUNCTION (MERGED AND CORRECTED) ---
# ==============================================================================
def train_model(config):
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    output_dir = Path(config.OUTPUT_DIR) / f'run_{timestamp}'
    output_dir.mkdir(parents=True, exist_ok=True)

    print("🚀 Starting Merged EEG-to-Text Training")

    processor = Blip2Processor.from_pretrained(config.BLIP_MODEL_NAME)
    model = EEG_BLIP2_Model(config, device)

    generation_args = {
    "max_new_tokens": 30,  # Generate up to 30 new words
    "min_length": 8,       # Generate at least 8 words
    "num_beams": 3,        # Use 3 beams for higher quality search
    "do_sample": False     # Use deterministic generation for loss calculation
    }

    comprehensive_loss = ComprehensiveLoss(device, config)
    image_evaluator = ImageGenerationEvaluator(device, config)

    spec_mean = torch.load(Path(config.PROCESSED_DATA_ROOT) / 'spec_mean.pt')
    spec_std = torch.load(Path(config.PROCESSED_DATA_ROOT) / 'spec_std.pt')
    spec_std += 1e-6

    transform = transforms.Compose([
        transforms.ToDtype(torch.float32),
        transforms.Resize((224, 224), antialias=True),
        transforms.Normalize(mean=spec_mean.tolist(), std=spec_std.tolist()),
    ])

    train_dataset = EEGDatasetWithCaptions(config.PROCESSED_DATA_ROOT, config.METADATA_CSV, 'train', transform)
    val_dataset = EEGDatasetWithCaptions(config.PROCESSED_DATA_ROOT, config.METADATA_CSV, 'val', transform)

    trainable_params = filter(lambda p: p.requires_grad, model.parameters())
    optimizer = optim.AdamW(trainable_params, lr=config.LR, weight_decay=config.WEIGHT_DECAY)
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.2, patience=2, min_lr=1e-7)
    scaler = torch.cuda.amp.GradScaler()
    best_val_loss = float('inf')
    simple_alignment_loss = nn.CosineEmbeddingLoss()

    print("🎯 Starting Training Loop")
    for epoch in range(config.NUM_EPOCHS):
        model.train()
        train_indices = np.random.choice(len(train_dataset), min(len(train_dataset), config.TRAIN_SAMPLES_PER_EPOCH), replace=False)
        train_loader = DataLoader(train_dataset, batch_size=config.BATCH_SIZE, sampler=SubsetRandomSampler(train_indices), drop_last=True, collate_fn=collate_fn)
        train_bar = tqdm(train_loader, desc=f"🔄 Epoch {epoch+1}/{config.NUM_EPOCHS} [TRAIN]")

        for batch_idx, (spectrograms, pil_images, gt_captions) in enumerate(train_bar):
            spectrograms = spectrograms.to(device)

            with torch.amp.autocast('cuda'):
                eeg_embeddings = model.get_eeg_embedding(spectrograms)

                with torch.no_grad():
                    image_embeddings = model.get_image_embedding(pil_images, processor)

                if epoch < 10: # STAGE 1: Foundational Alignment
                    target = torch.ones(eeg_embeddings.size(0)).to(device)
                    loss = simple_alignment_loss(eeg_embeddings, image_embeddings, target)
                else: # STAGE 2: Full Comprehensive Loss
                    with torch.no_grad():
                        generated_ids = model.generate(spectrograms, **generation_args)
                        eeg_captions = processor.batch_decode(generated_ids, skip_special_tokens=True)

                    loss_dict = comprehensive_loss(eeg_embeddings, image_embeddings, eeg_captions, gt_captions, pil_images)
                    loss = loss_dict['total'] / config.GRAD_ACCUMULATION_STEPS

            scaler.scale(loss).backward()

            if (batch_idx + 1) % config.GRAD_ACCUMULATION_STEPS == 0:
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(
                    filter(lambda p: p.requires_grad, model.parameters()),
                    config.GRADIENT_CLIP_NORM
                )
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()

            train_bar.set_postfix(loss=loss.item() * config.GRAD_ACCUMULATION_STEPS)

        # --- Validation Loop ---
        print(f"\n🔍 Running Validation for epoch {epoch+1}")
        model.eval()
        val_losses = []
        eval_samples = [] # To store samples for image generation

        val_indices = np.random.choice(len(val_dataset), min(len(val_dataset), config.VAL_SAMPLES_PER_EPOCH), replace=False)
        val_loader = DataLoader(val_dataset, batch_size=config.BATCH_SIZE, sampler=SubsetRandomSampler(val_indices), collate_fn=collate_fn)

        with torch.no_grad(), torch.amp.autocast('cuda'):
            for spectrograms, pil_images, gt_captions in val_loader:
                spectrograms = spectrograms.to(device)
                eeg_embeddings = model.get_eeg_embedding(spectrograms)
                image_embeddings = model.get_image_embedding(pil_images, processor)
                generated_ids = model.generate(spectrograms, **generation_args)
                eeg_captions = processor.batch_decode(generated_ids, skip_special_tokens=True)

                # loss_dict = comprehensive_loss(
                #     eeg_embeddings, image_embeddings, eeg_captions, gt_captions, pil_images)
                # val_losses.append(loss_dict['total'].item())
                target = torch.ones(eeg_embeddings.size(0)).to(device)
                val_losses.append(simple_alignment_loss(eeg_embeddings, image_embeddings, target).item())

                # Collect samples for image generation
                if len(eval_samples) < config.EVAL_SAMPLES_TO_GENERATE:
                    for i in range(len(eeg_captions)):
                        if len(eval_samples) < config.EVAL_SAMPLES_TO_GENERATE:
                            eval_samples.append((eeg_captions[i], gt_captions[i], pil_images[i]))

        avg_val_loss = np.mean(val_losses)
        print(f"  Validation Loss: {avg_val_loss:.4f}")
        scheduler.step(avg_val_loss)

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(model.state_dict(), output_dir / 'best_model.pth')
            print(f"🏆 New best model saved with validation loss: {best_val_loss:.4f}")

        if eval_samples:
            print("\n--- ✍️  Caption Generation Samples ---")
            for i, (predicted_cap, gt_cap, _) in enumerate(eval_samples):
              print(f"Sample {i+1}:")
              print(f"  - Ground Truth: {gt_cap.strip()}")
              print(f"  - Predicted:    {predicted_cap.strip()}")
              print("------------------------------------")

        # --- Image Generation Evaluation ---
        if (epoch + 1) % config.EVAL_IMAGE_GENERATION_INTERVAL == 0 and eval_samples:
            eval_captions,  _, eval_images = zip(*eval_samples)
            image_evaluator.generate_and_save_samples(list(eval_captions), list(eval_images), epoch + 1)

    print("\n🎉 Training Complete!")
    return output_dir

if __name__ == '__main__':
    # It's recommended to run caption generation separately first.
    # For now, we assume the JSON files exist.
    config = ENHANCED_CONFIG()
    try:
        output_dir = train_model(config)
        if output_dir:
            print(f"✅ Training completed successfully! Results saved to: {output_dir}")
    except Exception as e:
        print(f"❌ Training failed with error: {str(e)}")
        import traceback
        traceback.print_exc()

In [None]:
def generate_and_cache_captions(config, processor, blip_model, device):
    """
    A one-time function to generate and save captions for train and val splits.
    """
    for split in ['train', 'val']:
        captions_path = Path(config.PROCESSED_DATA_ROOT) / f'{split}_gt_captions.json'
        if captions_path.exists():
            print(f"Captions for '{split}' split already exist at {captions_path}. Skipping.")
            continue

        print(f"Generating ground truth captions for '{split}' split...")
        df = pd.read_csv(config.METADATA_CSV)
        split_df = df[df['split'].str.strip() == split].reset_index(drop=True)

        captions = {}
        batch_size = 16
        blip_model.eval()

        with torch.no_grad():
            for i in tqdm(range(0, len(split_df), batch_size), desc=f"Generating {split} captions"):
                batch_info = split_df.iloc[i:i+batch_size]
                batch_images = [
                    Image.open(Path(config.PROCESSED_DATA_ROOT) / info['image_path']).convert("RGB")
                    for _, info in batch_info.iterrows()
                ]

                inputs = processor(images=batch_images, return_tensors="pt").to(device)
                generated_ids = blip_model.generate(**inputs, max_length=50)
                batch_captions = processor.batch_decode(generated_ids, skip_special_tokens=True)

                for j, caption in enumerate(batch_captions):
                    captions[str(i + j)] = caption.strip()

        with open(captions_path, 'w') as f:
            json.dump(captions, f, indent=2)
        print(f"✅ Saved {len(captions)} captions for '{split}' split to {captions_path}")
