In [1]:

import csv
import os
import time

import tensorflow as tf
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.preprocessing.image import img_to_array
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
from tensorflow.keras.models import Model
import numpy as np
import cv2
import pandas as pd
import matplotlib.pyplot as plt
from IPython.display import display, clear_output
from PIL import Image
import ipywidgets as widgets

# --- Configuration variables ---
MODEL_PATH = "mobilenet_model.h5"            # weights file loaded at startup when it exists
TRAINING_DATA_DIR = "training_data"          # root folder for corrected sample images (one subfolder per class)
CSV_FILE = "training_data/Classes_alle.csv"  # class table with columns label_id,class_name
IMG_SIZE = (224, 224)                        # target size handed to cv2.resize before prediction
BATCH_SIZE = 32                              # not referenced in the visible code; presumably for training - TODO confirm
EPOCHS = 10                                  # not referenced in the visible code; presumably for training - TODO confirm
LEARNING_RATE = 0.001                        # learning rate of the Adam optimizer in get_model

# --- Labels aus CSV laden ---
def load_labels_from_csv(csv_path):
    """Read the class table from a CSV file into a dictionary.

    Args:
        csv_path: Path to a CSV file that has ``label_id`` and ``class_name``
            columns.

    Returns:
        A dict mapping ``int(label_id)`` to the row's ``class_name``. An empty
        dict is returned when ``csv_path`` does not exist.
    """
    if not os.path.exists(csv_path):
        return {}

    labels = {}
    table = pd.read_csv(csv_path)
    for _, record in table.iterrows():
        # Later rows overwrite earlier ones that share the same label_id.
        labels[int(record['label_id'])] = record['class_name']
    return labels

# Module-level label table (label_id -> class_name). It is read by
# classify_image and extended in place by correct_classification.
LABELS = load_labels_from_csv(CSV_FILE)
# Number of output units requested from get_model; 0 when the CSV is missing
# or contains no rows.
num_classes = len(LABELS)

# --- Modell laden oder neu erstellen ---
def get_model(num_classes):
    """Build and compile a MobileNetV2-based classifier.

    The network is an ImageNet-initialised MobileNetV2 without its top layers
    (frozen), followed by global average pooling, a 128-unit ReLU layer and a
    softmax layer with ``num_classes`` units. It is compiled with Adam at
    ``LEARNING_RATE``, sparse categorical cross-entropy and accuracy.

    Args:
        num_classes: Number of output classes.

    Returns:
        The compiled Keras model, or ``None`` (after printing a message) when
        ``num_classes`` is zero or negative.
    """
    if num_classes <= 0:
        print("❌ Keine Klassen gefunden. Modell wird nicht geladen.")
        return None

    # Keep the network input in sync with the size preprocess_frame resizes to.
    # IMG_SIZE is consumed by cv2.resize as (width, height), whereas Keras
    # expects (height, width, channels).
    input_shape = (IMG_SIZE[1], IMG_SIZE[0], 3)

    base_model = MobileNetV2(weights='imagenet', include_top=False, input_shape=input_shape)
    base_model.trainable = False  # only the new classification head is trainable

    x = GlobalAveragePooling2D()(base_model.output)
    x = Dense(128, activation='relu')(x)
    output = Dense(num_classes, activation='softmax')(x)

    model = Model(inputs=base_model.input, outputs=output)
    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE),
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])
    return model

# Build the network for the classes listed in the CSV (None when there are none).
model = get_model(num_classes)

# Restore weights only when a model was built and the weights file exists.
# NOTE(review): load_weights presumably fails if the class count in the CSV no
# longer matches the saved weights - confirm and handle if that can happen.
if not model or not os.path.exists(MODEL_PATH):
    print("❌ Kein trainiertes Modell gefunden!")
else:
    model.load_weights(MODEL_PATH)
    print("✅ Modell geladen!")

# --- Bildvorverarbeitung für Klassifikation ---
def preprocess_frame(frame):
    """Turn a camera frame into a single-image batch for the model.

    The frame is converted from BGR to RGB, resized to ``IMG_SIZE``, converted
    with ``img_to_array``, divided by 255.0 and given a leading batch axis.

    Args:
        frame: Image array as delivered by OpenCV, or ``None``.

    Returns:
        The preprocessed array with a leading axis of length 1, or ``None``
        (after printing a message) when ``frame`` is ``None`` or empty.
    """
    is_usable = frame is not None and frame.size != 0
    if not is_usable:
        print("❌ Fehler: Ungültiges Bild erhalten!")
        return None

    rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    scaled = img_to_array(cv2.resize(rgb, IMG_SIZE)) / 255.0
    return scaled[np.newaxis, ...]

# --- Klassifikation durchführen ---
def classify_image(model, frame, top_k=2):
    """Classify one frame and return the most probable classes.

    Args:
        model: Object with a Keras-style ``predict`` method, or ``None``.
        frame: Image array passed on to ``preprocess_frame``.
        top_k: How many of the highest-scoring classes to return (default 2,
            values below 1 are treated as 1). Fewer entries are returned when
            the model has fewer outputs.

    Returns:
        A list of ``(class_name, probability)`` tuples sorted from most to
        least probable. Class names come from ``LABELS``; ids missing there
        are reported as ``"Unbekannt"``. When no model is available or the
        frame cannot be preprocessed, ``[("Fehler", 0.0)]`` is returned.
    """
    # get_model returns None when no classes exist; report that the same way
    # as an unusable frame instead of failing on ``None.predict``.
    if model is None:
        print("❌ Fehler: Kein Modell geladen!")
        return [("Fehler", 0.0)]

    img_tensor = preprocess_frame(frame)
    if img_tensor is None:
        return [("Fehler", 0.0)]

    predictions = model.predict(img_tensor, verbose=0)[0]
    top_k = max(1, int(top_k))
    best_ids = np.argsort(predictions)[::-1][:top_k]
    # Convert NumPy scalars to plain Python values for lookup and formatting.
    return [(LABELS.get(int(i), "Unbekannt"), float(predictions[i])) for i in best_ids]

# --- Kontinuierlicher Kamera-Stream mit Korrektur ---
def live_classification():
    """Run the interactive camera classification loop.

    Opens camera 0 and, for each frame read, shows the frame with the
    predicted classes drawn on it in an ipywidgets image. After every frame
    the user is prompted: Enter continues, 'c' calls
    ``correct_classification`` with the current frame, 'q' stops. Ctrl+C /
    kernel interrupt also stops the loop. The camera is always released when
    the function exits.

    Returns:
        None. Returns early when no model is loaded or the camera cannot be
        opened.
    """
    if model is None:
        print("❌ Fehler: Kein Modell geladen!")
        return

    cap = cv2.VideoCapture(0)
    if not cap.isOpened():
        cap.release()
        print("❌ Fehler: Kamera nicht verfügbar!")
        return

    print("\n🎥 Kamera läuft... Drücke 'q' zum Beenden oder 'c' für Korrektur.")
    image_widget = widgets.Image(format='jpeg')
    display(image_widget)

    try:
        while True:
            ret, frame = cap.read()
            if not ret or frame is None or frame.size == 0:
                print("❌ Fehler beim Lesen des Kamerabilds! Neuer Versuch...")
                time.sleep(0.5)
                continue

            predictions = classify_image(model, frame)

            # Draw on a copy: the untouched frame may be saved as training
            # data, and the overlay text must not end up in that image.
            annotated = frame.copy()
            for line_no, (label, prob) in enumerate(predictions):
                cv2.putText(annotated, f"{label}: {prob:.2%}", (10, 50 + 80 * line_no),
                            cv2.FONT_HERSHEY_SIMPLEX, 2.1, (0, 255, 0), 5)

            encoded_ok, buffer = cv2.imencode('.jpg', annotated)
            if encoded_ok:
                image_widget.value = buffer.tobytes()

            key = input("Drücke Enter für Weiter, 'c' für Korrektur, 'q' zum Beenden: ").strip().lower()
            if key == 'q':
                break
            elif key == 'c':
                correct_classification(frame)
    except KeyboardInterrupt:
        print("🚪 Manuelle Beendigung durch Nutzer!")
    finally:
        # Release the device on every exit path, including unexpected errors.
        cap.release()
        print("🚪 Kamera gestoppt!")

# --- Korrektur & neue Klassen hinzufügen ---
def _append_label_row(csv_path, label_id, class_name):
    """Append one ``label_id,class_name`` row to the class CSV.

    Creates the file (and its folder) with a header line when it is missing
    or empty, and starts a new line first when the existing file does not end
    with a line break.
    """
    needs_header = not os.path.exists(csv_path) or os.path.getsize(csv_path) == 0
    needs_newline = False
    if needs_header:
        os.makedirs(os.path.dirname(csv_path) or ".", exist_ok=True)
    else:
        with open(csv_path, "rb") as existing:
            existing.seek(-1, os.SEEK_END)
            needs_newline = existing.read(1) not in (b"\n", b"\r")

    # UTF-8 matches the default encoding of pandas.read_csv; the csv writer
    # quotes names that contain commas or quotes.
    with open(csv_path, "a", newline="", encoding="utf-8") as f:
        if needs_newline:
            f.write("\n")
        writer = csv.writer(f, lineterminator="\n")
        if needs_header:
            writer.writerow(["label_id", "class_name"])
        writer.writerow([label_id, class_name])


def correct_classification(frame):
    """Ask the user for the correct class of a frame and save it as a sample.

    Lists the known classes, then prompts for either the number of an
    existing class or a class name. A name that is not known yet is
    registered under the next free id in ``LABELS`` and appended to
    ``CSV_FILE``. The frame is written as a JPEG to
    ``TRAINING_DATA_DIR/<class_name>/<unix-seconds>.jpg``.

    Args:
        frame: Image array to store, or ``None``.

    Returns:
        None. Nothing is saved when the frame is ``None`` or empty, or when
        the entered name is empty or not usable as a folder name.
    """
    if frame is None or frame.size == 0:
        print("❌ Fehler: Ungültiges Bild zur Korrektur!")
        return

    print("\n📌 Verfügbare Klassen:")
    for idx, class_name in LABELS.items():
        print(f"{idx}: {class_name}")

    new_class = input("Neue Klasse oder vorhandene Nummer eingeben: ").strip()
    if not new_class:
        print("❌ Fehler: Kein Klassenname eingegeben!")
        return

    # isdecimal (not isdigit) so that int() cannot fail on characters like "²".
    if new_class.isdecimal() and int(new_class) in LABELS:
        class_name = LABELS[int(new_class)]
    elif new_class in LABELS.values():
        # The name of an existing class was typed: reuse it instead of
        # registering a duplicate under a new id.
        class_name = new_class
    else:
        # The name becomes a folder name, so it must not leave the data folder.
        separators = [sep for sep in (os.sep, os.altsep) if sep]
        if new_class in (".", "..") or any(sep in new_class for sep in separators):
            print("❌ Fehler: Ungültiger Klassenname!")
            return

        class_name = new_class
        new_label_id = max(LABELS.keys(), default=-1) + 1
        # Persist first so the in-memory table never gets ahead of the file.
        _append_label_row(CSV_FILE, new_label_id, class_name)
        LABELS[new_label_id] = class_name
        print(f"✅ Neue Klasse '{class_name}' hinzugefügt!")

    class_dir = os.path.join(TRAINING_DATA_DIR, str(class_name))
    os.makedirs(class_dir, exist_ok=True)

    img_name = os.path.join(class_dir, f"{int(time.time())}.jpg")
    # imwrite reports failure through its return value, not an exception.
    if cv2.imwrite(img_name, frame):
        print(f"📷 Bild gespeichert: {img_name}")
    else:
        print(f"❌ Fehler: Bild konnte nicht gespeichert werden: {img_name}")

# --- Start script ---
# Starts the interactive camera loop when this code is executed directly
# (the recorded cell output shows the loop starting on execution).
if __name__ == "__main__":
    live_classification()


✅ Modell geladen!

🎥 Kamera läuft... Drücke 'q' zum Beenden oder 'c' für Korrektur.


Image(value=b'', format='jpeg')

🚪 Manuelle Beendigung durch Nutzer!
🚪 Kamera gestoppt!


Drücke Enter für Weiter, 'c' für Korrektur, 'q' zum Beenden:  
