# Vision-Language Model (VLM) Embedding

Dieses Notebook demonstriert, wie mit dem Vision-Language-Modell CLIP (`openai/clip-vit-base-patch32`) Bild- und Text-Embeddings für die Pflanzenkrankheitserkennung erstellt werden.

In [None]:
import torch
from transformers import CLIPProcessor, CLIPModel
from PIL import Image
import numpy as np
import pandas as pd
from pathlib import Path
import json

# Project-relative directories (MODEL_PATH is not referenced in the visible cells)
DATA_PATH = Path('..') / 'data'
MODEL_PATH = Path('..') / 'models' / 'vlm_embedder'

# Load the pretrained CLIP model and its processor from the same checkpoint id
_CLIP_CHECKPOINT = "openai/clip-vit-base-patch32"
model = CLIPModel.from_pretrained(_CLIP_CHECKPOINT)
processor = CLIPProcessor.from_pretrained(_CLIP_CHECKPOINT)

## Embedding-Funktionen definieren

In [None]:
def create_image_embedding(image_path, text_description=None):
    """Create CLIP embeddings for an image, optionally paired with a text.

    Args:
        image_path: Path of the image file; passed to ``Image.open``.
        text_description: Optional text to embed together with the image.
            When ``None`` only the image is embedded.

    Returns:
        Without text: the squeezed NumPy array obtained from
        ``model.get_image_features``.
        With text: a dict with the keys ``'image_embedding'``,
        ``'text_embedding'`` and ``'combined_embedding'``. The first two are
        the squeezed ``image_embeds`` / ``text_embeds`` model outputs; the
        combined entry is their element-wise mean.
    """
    # Open the image in a context manager so its file handle is released as
    # soon as the processor has turned the pixels into tensors. Without this,
    # embedding a whole dataset keeps one open handle per image until the
    # garbage collector gets to it.
    with Image.open(image_path) as image:
        if text_description is None:
            inputs = processor(images=image, return_tensors="pt")
        else:
            inputs = processor(
                text=text_description,
                images=image,
                return_tensors="pt",
                padding=True,
            )

    # Image-only embedding
    if text_description is None:
        with torch.no_grad():
            image_features = model.get_image_features(**inputs)
        return image_features.numpy().squeeze()

    # Multimodal embedding (image + text)
    # NOTE(review): this branch reads `image_embeds` from the full forward
    # pass while the image-only branch uses `get_image_features`; confirm both
    # are on the same scale before comparing vectors across the two branches.
    with torch.no_grad():
        outputs = model(**inputs)
        image_features = outputs.image_embeds
        text_features = outputs.text_embeds

    return {
        'image_embedding': image_features.numpy().squeeze(),
        'text_embedding': text_features.numpy().squeeze(),
        'combined_embedding': (image_features + text_features).numpy().squeeze() / 2
    }

# Text prompt for each image category, keyed by category name
category_descriptions = dict(
    healthy='healthy green plant leaf without disease',
    disease_A='plant leaf with disease A symptoms',
    disease_B='plant leaf with disease B symptoms',
)

## Embeddings für Datensatz erstellen

In [None]:
def process_dataset_embeddings():
    """Create embeddings for all ``*.jpg`` images in the dataset folders.

    Visits ``DATA_PATH / 'raw' / <split> / <category>`` for the splits
    ``train`` and ``val`` and for every category in
    ``category_descriptions``. Directories that do not exist are skipped.
    Each image is embedded together with its category's text description via
    ``create_image_embedding``.

    Processing is best-effort: an image that raises is reported with
    ``print`` and skipped, so one unreadable file does not abort the run.

    Returns:
        A list with one dict per successfully processed image, holding the
        keys ``'image_path'``, ``'category'``, ``'split'``,
        ``'image_embedding'``, ``'text_embedding'`` and
        ``'combined_embedding'`` (embeddings converted to plain lists).
    """
    embeddings_data = []

    for split in ('train', 'val'):
        # Iterate the description table itself so the category names and
        # their prompts cannot drift apart.
        for category, description in category_descriptions.items():
            category_path = DATA_PATH / 'raw' / split / category
            if not category_path.exists():
                continue

            # Sort the matches so the result order does not depend on the
            # order in which the directory happens to be listed.
            for image_file in sorted(category_path.glob('*.jpg')):
                try:
                    # Create the embedding
                    embedding_result = create_image_embedding(
                        image_file,
                        description,
                    )

                    # Collect the metadata alongside the embeddings
                    embeddings_data.append({
                        'image_path': str(image_file),
                        'category': category,
                        'split': split,
                        'image_embedding': embedding_result['image_embedding'].tolist(),
                        'text_embedding': embedding_result['text_embedding'].tolist(),
                        'combined_embedding': embedding_result['combined_embedding'].tolist()
                    })

                    print(f"Processed: {image_file.name}")

                except Exception as e:
                    # Deliberately broad: report the failure and carry on
                    # with the remaining images.
                    print(f"Error processing {image_file}: {e}")

    return embeddings_data

# Embeddings erstellen (auskommentiert für Demo)
# embeddings_data = process_dataset_embeddings()
# 
# # Embeddings speichern
# with open('../data/processed/embeddings.json', 'w') as f:
#     json.dump(embeddings_data, f, indent=2)