# Notebook to create image embeddings using Gemma 3 
### (To run this notebook you need a T4 GPU in Google Colab)

### Installing dependencies

In [None]:
!pip install torch PIL transformers
!pip install --upgrade transformers

### Processing all spectrogram images and converting them into image embeddings
### Results are saved in spectrogram_features.csv with the path of the file and the corresponding embeddings

In [None]:
import torch
import numpy as np
import pandas as pd
import os
import glob
from PIL import Image
from transformers import (
    AutoModelForCausalLM,
    SiglipImageProcessor
)
from tqdm import tqdm
import warnings
warnings.filterwarnings("ignore")

_GEMMA_MODEL_ID = "google/gemma-3-4b-it"


def _find_spectrogram_pngs(spectrograms_root):
    """Return the sorted paths of every ``*.png`` under *spectrograms_root* (recursive).

    Prints the search progress and, when nothing is found, a message naming
    the directory. An empty list means there is nothing to process.
    """
    print("Searching for spectrogram images...")

    # '**' together with recursive=True also matches files in nested subdirectories.
    pattern = os.path.join(spectrograms_root, '**', '*.png')
    # glob returns files in filesystem order; sort so the CSV rows are reproducible.
    png_files = sorted(glob.glob(pattern, recursive=True))

    if not png_files:
        print(f"No PNG files found in {spectrograms_root} or its subdirectories.")
    else:
        print(f"Found {len(png_files)} spectrogram images to process")
    return png_files


def _load_vision_encoder():
    """Load Gemma 3 in float16 on CUDA and return ``(vision_encoder, processor, device)``."""
    print("Loading model and vision tower...")

    # NOTE(review): trust_remote_code=True lets the Hub repository run Python
    # code at load time; presumably unnecessary for a natively supported
    # model -- confirm and drop it if so.
    model = AutoModelForCausalLM.from_pretrained(
        _GEMMA_MODEL_ID,
        torch_dtype=torch.float16,
        device_map="cuda",
        trust_remote_code=True,
    )
    print(f"Model loaded on device: {next(model.parameters()).device}")

    vision_enc = model.vision_tower
    device = next(vision_enc.parameters()).device
    processor = SiglipImageProcessor.from_pretrained(_GEMMA_MODEL_ID)
    return vision_enc, processor, device


def _load_rgb_image(img_path):
    """Open *img_path* and return an in-memory RGB copy, closing the file handle."""
    with Image.open(img_path) as img:
        return img.convert("RGB")


def _embed_images(images, vision_enc, processor, device):
    """Return one mean-pooled embedding row (float32 NumPy) per image in *images*."""
    pixel_values = processor(images=images, return_tensors="pt").pixel_values.to(device)

    with torch.no_grad():
        # Only last_hidden_state is used, so per-layer hidden states are not requested.
        hidden_states = vision_enc(pixel_values).last_hidden_state
        # Average over the token axis (dim=1) to get one global vector per image.
        pooled = hidden_states.mean(dim=1)

    # No torch.cuda.empty_cache() here: the caching allocator reuses freed
    # blocks by itself, and emptying the cache on every step only slows the loop.
    return pooled.detach().cpu().float().numpy()


def _save_features(image_paths, features, output_csv, failed_images):
    """Write paths and embeddings to *output_csv*, print a summary, return the DataFrame.

    Returns ``None`` (after printing a message) when *features* is empty.
    """
    if not features:
        print("No features were successfully extracted!")
        return None

    print(f"Successfully processed {len(features)} images")

    features_array = np.array(features)
    print(f"Feature array shape: {features_array.shape}")

    print("Creating DataFrame...")
    feature_dim = features_array.shape[1]
    # First column is the image path, followed by feature_0 .. feature_{dim-1}.
    df = pd.DataFrame({
        'image_path': image_paths,
        **{f'feature_{i}': features_array[:, i] for i in range(feature_dim)},
    })

    print(f"Saving features to {output_csv}...")
    df.to_csv(output_csv, index=False)

    print("Features saved successfully!")
    print(f"CSV shape: {df.shape}")
    print(f"Feature vector dimension: {feature_dim}")

    if failed_images:
        print(f"\nFailed to process {len(failed_images)} images:")
        for img_path, error in failed_images:
            print(f"  {img_path}: {error}")

    print("\nSample of saved data:")
    print(df.head())

    return df


def extract_features_from_spectrograms(spectrograms_root, output_csv='spectrogram_features.csv'):
    """
    Extract global features from all spectrogram images and save to CSV.

    Each PNG under *spectrograms_root* is embedded on its own with the Gemma 3
    vision tower; the token embeddings are averaged into one vector per image.

    Args:
        spectrograms_root: Directory searched recursively for ``*.png`` files.
        output_csv: Destination CSV path (``image_path`` + ``feature_i`` columns).

    Returns:
        The saved ``pandas.DataFrame``, or ``None`` when no PNG was found or no
        image could be processed.
    """
    # Look for work first so an empty/wrong directory fails fast, before the
    # model is downloaded and loaded onto the GPU.
    png_files = _find_spectrogram_pngs(spectrograms_root)
    if not png_files:
        return None

    vision_enc, processor, device = _load_vision_encoder()

    image_paths = []
    global_features_list = []
    failed_images = []

    for img_path in tqdm(png_files, desc="Extracting features"):
        try:
            img = _load_rgb_image(img_path)
            global_feature_np = _embed_images([img], vision_enc, processor, device)[0]
        except Exception as e:
            # Deliberate best-effort: record the failure and keep going so one
            # bad file does not abort the whole run.
            print(f"Error processing {img_path}: {e}")
            failed_images.append((img_path, str(e)))
            continue

        image_paths.append(img_path)
        global_features_list.append(global_feature_np)

    return _save_features(image_paths, global_features_list, output_csv, failed_images)


def extract_features_batch(spectrograms_root, output_csv='spectrogram_features.csv', batch_size=8):
    """
    Extract features in batches to handle memory more efficiently.

    Same output as ``extract_features_from_spectrograms`` but several images
    are pushed through the vision tower per forward pass.

    Args:
        spectrograms_root: Directory searched recursively for ``*.png`` files.
        output_csv: Destination CSV path (``image_path`` + ``feature_i`` columns).
        batch_size: Number of images per forward pass; must be at least 1.

    Returns:
        The saved ``pandas.DataFrame``, or ``None`` when no PNG was found or no
        image could be processed.

    Raises:
        ValueError: If *batch_size* is smaller than 1.
    """
    # Reject a bad batch size up front instead of after the model has loaded.
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size!r}")

    png_files = _find_spectrogram_pngs(spectrograms_root)
    if not png_files:
        return None
    print(f"Processing in batches of {batch_size}")

    vision_enc, processor, device = _load_vision_encoder()

    all_image_paths = []
    all_global_features = []
    failed_images = []

    for start in tqdm(range(0, len(png_files), batch_size), desc="Processing batches"):
        batch_images = []
        batch_paths = []

        # Load what can be loaded; unreadable files are recorded and skipped.
        for img_path in png_files[start:start + batch_size]:
            try:
                img = _load_rgb_image(img_path)
            except Exception as e:
                print(f"Error loading {img_path}: {e}")
                failed_images.append((img_path, str(e)))
            else:
                batch_images.append(img)
                batch_paths.append(img_path)

        if not batch_images:
            continue

        try:
            global_features_np = _embed_images(batch_images, vision_enc, processor, device)
        except Exception as e:
            # A failed forward pass loses the whole batch; record every path in it.
            print(f"Error processing batch {start // batch_size + 1}: {e}")
            failed_images.extend((path, str(e)) for path in batch_paths)
            continue

        all_image_paths.extend(batch_paths)
        # Extending with a 2-D array appends one row (one embedding) per image.
        all_global_features.extend(global_features_np)

    return _save_features(all_image_paths, all_global_features, output_csv, failed_images)

# Main execution
if __name__ == "__main__":
    # Root folder that is searched (recursively) for spectrogram PNG files.
    SPECTROGRAMS_ROOT_PATH = '/content/drive/Shareddrives/Radar/spectrograms'

    if os.path.exists(SPECTROGRAMS_ROOT_PATH):
        # Ask the user which extraction strategy to run.
        print("Choose processing method:")
        print("1. Single image processing (lower memory usage)")
        print("2. Batch processing (faster but uses more memory)")

        choice = input("Enter choice (1 or 2, default=1): ").strip()

        if choice != "2":
            # Anything other than "2" (including an empty answer) means one image at a time.
            print("Using single image processing")
            df = extract_features_from_spectrograms(spectrograms_root=SPECTROGRAMS_ROOT_PATH)
        else:
            # Batched mode: fall back to 8 unless the answer is a positive integer.
            batch_size_input = input("Enter batch size (default=8): ").strip()
            batch_size = 8
            if batch_size_input.isdigit() and int(batch_size_input) > 0:
                batch_size = int(batch_size_input)
            print(f"Using batch processing with batch size: {batch_size}")
            df = extract_features_batch(spectrograms_root=SPECTROGRAMS_ROOT_PATH, batch_size=batch_size)

        # Both extractors return None when nothing was written.
        if df is not None:
            print(f"\nProcessing complete! Check 'spectrogram_features.csv' for results.")
    else:
        # Missing directory: explain where we looked and do nothing else.
        print(f"Error: The specified spectrograms directory does not exist: {SPECTROGRAMS_ROOT_PATH}")
        print(f"Current working directory: {os.getcwd()}")
        print("Please ensure the path is correct and accessible.")

Choose processing method:
1. Single image processing (lower memory usage)
2. Batch processing (faster but uses more memory)
Enter choice (1 or 2, default=1): 1
Using single image processing
Loading model and vision tower...


Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

model-00002-of-00002.safetensors:   8%|8         | 333M/3.97G [00:00<?, ?B/s]

model-00001-of-00002.safetensors:   8%|8         | 439M/5.40G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/215 [00:00<?, ?B/s]

Model loaded on device: cuda:0


preprocessor_config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

Searching for spectrogram images...
Found 1754 spectrogram images to process


Extracting features: 100%|██████████| 1754/1754 [30:17<00:00,  1.04s/it]


Successfully processed 1754 images
Feature array shape: (1754, 1152)
Creating DataFrame...
Saving features to spectrogram_features.csv...
Features saved successfully!
CSV shape: (1754, 1153)
Feature vector dimension: 1152

Sample of saved data:
                                          image_path  feature_0  feature_1  \
0  /content/drive/Shareddrives/Radar/spectrograms...   0.031647   0.070984   
1  /content/drive/Shareddrives/Radar/spectrograms...   0.027298   0.062195   
2  /content/drive/Shareddrives/Radar/spectrograms...   0.028488   0.063721   
3  /content/drive/Shareddrives/Radar/spectrograms...   0.034760   0.065002   
4  /content/drive/Shareddrives/Radar/spectrograms...   0.023209   0.069641   

   feature_2  feature_3  feature_4  feature_5  feature_6  feature_7  \
0   0.042999   0.356934  -0.015327   0.016190  -0.083008  -0.077209   
1   0.042877   0.345459  -0.016708   0.022507  -0.082581  -0.075256   
2   0.036133   0.355713  -0.020554   0.017883  -0.085510  -0.068787   
3 