# BA: KI-gestützte Optimierung der Brustkrebsdiagnose: Automatisierte Tumorerkennung und Reduktion von Diagnosefehlern

### Benötigte Bibliotheken

In [None]:
import re
import os
import sys
# Silence Python-level stderr output (intended to hide CUDA/CUPTI warnings).
# The original stream is kept in `stderr` so it can be restored later with
# `sys.stderr = stderr`.
# NOTE(review): rebinding sys.stderr only affects writes made through the Python
# object; native libraries writing directly to file descriptor 2 are not
# redirected by this - confirm the targeted native warnings actually disappear
# (TF_CPP_MIN_LOG_LEVEL before importing tensorflow may be the intended tool).
# NOTE(review): everything later written to sys.stderr (e.g. Python warnings) is
# hidden for the rest of the session, and the devnull handle is never closed.
stderr = sys.stderr
sys.stderr = open(os.devnull, 'w')  # Disables everything written to sys.stderr (incl. the delay-kernel warning)

import time
import h5py
import shap
import cv2
import pickle
import random
import pydicom
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

import tensorflow as tf

from collections import Counter, defaultdict

from tensorflow.keras.applications import EfficientNetB0
from tensorflow.keras.applications.efficientnet import preprocess_input
from tensorflow.keras.regularizers import l2
from tensorflow.keras.models import Model
from tensorflow.keras.layers import (
    Input, GlobalAveragePooling2D, Dense, Dropout,
    BatchNormalization, Conv2D, UpSampling2D,
    Concatenate, Multiply, Activation, Add, Layer)
from tensorflow.keras.utils import Sequence

from sklearn.utils.class_weight import compute_class_weight
from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight

In [None]:
# Report the TensorFlow build and the GPUs visible to it
print(f"TensorFlow version: {tf.__version__}")
gpu_devices = tf.config.list_physical_devices('GPU')
print(gpu_devices)

### 1. Daten laden und verarbeiten

In [None]:
# -----------------------------------------------------------------------------
# Daten laden und verarbeiten
# -----------------------------------------------------------------------------
# Ziel: Laden und Vorverarbeiten der DICOM-Bilddaten, ROI-Masken und klinischen Merkmale
# Ergebnis: Speicherung der patientinnenbasierten Daten in verarbeiteten "Chunks"
# -----------------------------------------------------------------------------

# ----------------------------------------------------------------------------
# Pfade und Konfigurationsparameter

# Input locations (clinical sheet, DICOM tree) and the folder for per-patient chunks
EXCEL_PATH = "/mnt/e/Advanced-MRI-Breast-Lesions/Advanced-MRI-Breast-Lesions-DA-Clinical-Sep2024.xlsx"
DICOM_FOLDER = "/mnt/e/Advanced-MRI-Breast-Lesions/DICOM Images/manifest-1713182663002/Advanced-MRI-Breast-Lesions"
INTERMEDIATE_FOLDER = "/mnt/e/Advanced-MRI-Breast-Lesions/data/intermediate_caches"
os.makedirs(INTERMEDIATE_FOLDER, exist_ok=True)

# Processing parameters
IMAGE_SIZE = (256, 256)
CHUNK_PATTERN = "chunk_{:04d}.pkl"
MIN_MEAN_THRESHOLD = 0.0001  # minimum mean mask value per slice; slices below are skipped (empty masks)
Z_TOLERANCE = 0.5            # maximum z distance (mm) for assigning masks/labels to slices

# ----------------------------------------------------------------------------
# Step 1: avoid duplicates - collect the patient IDs already stored in chunk files

processed_patient_ids = set()
existing_chunk_names = [
    name for name in sorted(os.listdir(INTERMEDIATE_FOLDER))
    if name.startswith("chunk_") and name.endswith(".pkl")]
for chunk_name in existing_chunk_names:
    stored_chunk_path = os.path.join(INTERMEDIATE_FOLDER, chunk_name)
    try:
        with open(stored_chunk_path, "rb") as chunk_fh:
            stored_chunk = pickle.load(chunk_fh)
            processed_patient_ids.update(stored_chunk.get("patient_ids", []))
    except Exception as e:
        # Best effort: an unreadable chunk is reported and ignored
        print(f"Fehler beim Laden von {chunk_name}: {e}")

# ----------------------------------------------------------------------------
# Schritt 2: Klinische Daten laden und Feature-Mapping aufbauen

# The header is on the second row of the sheet (header=1). All missing cells become 0.
# NOTE(review): after fillna(0) a missing lesion label is indistinguishable from
# label 0, so the later pd.isna(label_val) check never triggers - confirm intent.
df = pd.read_excel(EXCEL_PATH, header=1).fillna(0)

# Keep the patient ID, the age at MRI and the six lesion label/position column pairs
selected_columns = ["Patient ID", "age at MRI"]
for lesion_no in range(1, 7):
    selected_columns += [f"tumor/benign{lesion_no}", f"pos{lesion_no}"]
df_selected = df[selected_columns].copy()

# ------------------------------
# Positionen in R/L-Koordinaten zerlegen 
import re

# Side letter followed by a number, e.g. "R12.5" or "L-3". Surrounding whitespace
# and a lowercase side letter are tolerated.
_POSITION_PATTERN = re.compile(r"\s*([RL])\s*(-?\d+\.?\d*)", re.IGNORECASE)


def convert_position_value(value):
    """Split a lesion position cell into ``(right, left)`` coordinates.

    The clinical sheet encodes a position as a side letter followed by a
    number. The number is placed in the slot of its side and the other slot is
    set to 0.0.

    Args:
        value: Raw cell value from a ``pos<i>`` column.

    Returns:
        ``(number, 0.0)`` for right-sided and ``(0.0, number)`` for left-sided
        positions. Returns ``(0.0, 0.0)`` when *value* is not a string or does
        not start with a side/number pair.
    """
    if not isinstance(value, str):
        return (0.0, 0.0)
    match = _POSITION_PATTERN.match(value)
    if match is None:
        return (0.0, 0.0)
    side, number = match.groups()
    number = float(number)
    # Upper-case the side so "r12" and "R12" are treated the same
    return (number, 0.0) if side.upper() == "R" else (0.0, number)


def safe_convert_to_tuple(value):
    """Return ``convert_position_value(value)`` with both coordinates rounded to 2 decimals."""
    right, left = convert_position_value(value)
    return (round(right, 2), round(left, 2))

# Replace each posN column by two numeric columns posN_R / posN_L (in that order)
position_columns = [f"pos{n}" for n in range(1, 7)]
for pos_col in position_columns:
    right_values = df_selected[pos_col].apply(lambda cell: safe_convert_to_tuple(cell)[0])
    left_values = df_selected[pos_col].apply(lambda cell: safe_convert_to_tuple(cell)[1])
    df_selected[pos_col + "_R"] = right_values
    df_selected[pos_col + "_L"] = left_values
df_selected.drop(columns=position_columns, inplace=True)

# ------------------------------
# Feature-Mapping pro Patient
# Patient ID (as string) -> array with all remaining selected columns of that row
clinical_data_mapping = {}
for _, patient_row in df_selected.iterrows():
    patient_key = str(patient_row["Patient ID"])
    clinical_data_mapping[patient_key] = patient_row.drop("Patient ID").values
patients_with_clinical = set(clinical_data_mapping)

# ----------------------------------------------------------------------------
# Schritt 3: Hauptschleife – Verarbeitung der DICOM-Daten

all_patients = sorted(os.listdir(DICOM_FOLDER))

# Continue numbering after the highest existing chunk file. Starting at 1 on
# every run would overwrite chunks written by an earlier run, whose patients
# are skipped in the main loop (processed_patient_ids) and would thus be lost.
_chunk_name_pattern = re.compile(r"chunk_(\d+)\.pkl")
_existing_chunk_numbers = []
for _name in os.listdir(INTERMEDIATE_FOLDER):
    _match = _chunk_name_pattern.fullmatch(_name)
    if _match:
        _existing_chunk_numbers.append(int(_match.group(1)))
chunk_idx = max(_existing_chunk_numbers, default=0) + 1

for idx, patient_id in enumerate(all_patients):

    # Only process patients that have clinical data
    if patient_id not in patients_with_clinical:
        print(f"[{idx+1}/{len(all_patients)}] {patient_id}: keine klinischen Daten – übersprungen.")
        continue

    # Skip patients already stored in an existing chunk
    if patient_id in processed_patient_ids:
        print(f"[SKIP] Patientin {patient_id} bereits verarbeitet.")
        continue

    print(f"[{idx+1}/{len(all_patients)}] Verarbeite: {patient_id}")

    try:
        # ----------------------------------------
        # Pick the patient's study subfolder
        # NOTE(review): only subs[0] is used and os.listdir order is arbitrary -
        # confirm each patient has exactly one study folder.
        ppath = os.path.join(DICOM_FOLDER, patient_id)
        subs = [os.path.join(ppath, d) for d in os.listdir(ppath) if os.path.isdir(os.path.join(ppath, d))]
        if not subs:
            continue
        patient_subfolder = subs[0]

        # ----------------------------------------
        # Use the first .dcm file inside a series folder whose name contains "ROI"
        roi_file = None
        for d in os.listdir(patient_subfolder):
            if "ROI" in d.upper():
                for fn in os.listdir(os.path.join(patient_subfolder, d)):
                    if fn.lower().endswith(".dcm"):
                        roi_file = os.path.join(patient_subfolder, d, fn)
                        break
            if roi_file:
                break
        if roi_file is None:
            print(f" -> Keine ROI gefunden für {patient_id} – übersprungen")
            continue

        # ----------------------------------------
        # Collect image .dcm paths from all series folders not starting with ROI/TRAM
        image_paths = []
        for d in os.listdir(patient_subfolder):
            if d.upper().startswith(("ROI", "TRAM")):
                continue
            sub = os.path.join(patient_subfolder, d)
            if os.path.isdir(sub):
                image_paths += [
                    os.path.join(sub, f) for f in os.listdir(sub)
                    if f.lower().endswith(".dcm")]
        if not image_paths:
            continue

        # ----------------------------------------
        # Read headers only to get each slice's z position, then sort by z
        img_infos = []
        for path in image_paths:
            ds = pydicom.dcmread(path, stop_before_pixels=True)
            if hasattr(ds, "ImagePositionPatient"):
                z = float(ds.ImagePositionPatient[2])
                img_infos.append((z, path))
        if not img_infos:
            continue

        img_infos.sort(key=lambda x: x[0])
        imgs, img_zs_fixed = [], []

        # Load pixels: Gaussian blur, clip to the 1st-99th percentile, min-max
        # scale, then rotate by 180 degrees
        for z, path in img_infos:
            ds = pydicom.dcmread(path)
            arr = ds.pixel_array.astype(np.float32)
            arr = cv2.GaussianBlur(arr, (5, 5), sigmaX=1)
            min_val, max_val = np.percentile(arr, 1), np.percentile(arr, 99)
            arr = np.clip(arr, min_val, max_val)
            arr = (arr - min_val) / (max_val - min_val + 1e-6)
            imgs.append(np.rot90(arr, 2))
            img_zs_fixed.append(z)
        imgs = np.stack(imgs, axis=0)

        # ----------------------------------------
        # Extract ROI masks from the segmentation DICOM
        ds_seg = pydicom.dcmread(roi_file)
        seg_arr = ds_seg.pixel_array
        # pydicom returns a 2D (rows, cols) array for a single-frame object; add
        # the frame axis so the per-frame indexing below also works in that case
        # (otherwise seg_arr[order] would select image rows).
        if seg_arr.ndim == 2:
            seg_arr = seg_arr[np.newaxis, ...]
        seg_zs = [float(fg.PlanePositionSequence[0].ImagePositionPatient[2])
                  for fg in ds_seg.PerFrameFunctionalGroupsSequence]
        order = np.argsort(seg_zs)
        seg_arr = seg_arr[order]
        seg_zs = [seg_zs[i] for i in order]

        # Per image slice: OR together every segmentation frame within Z_TOLERANCE
        masks = []
        for z in img_zs_fixed:
            combined = np.zeros_like(seg_arr[0], dtype=np.uint8)
            for seg_z, seg in zip(seg_zs, seg_arr):
                if abs(seg_z - z) <= Z_TOLERANCE:
                    combined |= (seg > 0).astype(np.uint8)
            masks.append(np.rot90(combined, 2))
        masks = np.stack(masks, axis=0)

        # ----------------------------------------
        # Build feature vectors and labels, then save the chunk
        image_data, mask_data, clinical_data = [], [], []
        labels, image_patient_ids, roi_patient_ids, summary_table = [], [], [], []

        row = df[df["Patient ID"] == patient_id]
        patient_selected = df_selected[df_selected["Patient ID"] == patient_id]
        feats = clinical_data_mapping[patient_id]
        # Flatten the feature values; tuple-valued entries are expanded in place
        flat_feats = []
        for f in feats:
            if isinstance(f, tuple):
                flat_feats.extend(f)
            else:
                flat_feats.append(f)

        # ----------------------------------------
        # Step 1: map slice z positions -> set of lesion labels
        added_count = 0
        z_label_map = defaultdict(set)

        for i in range(1, 7):
            label_col = f"tumor/benign{i}"
            col_r = f"pos{i}_R"
            col_l = f"pos{i}_L"

            if label_col not in row or col_r not in df_selected or col_l not in df_selected:
                continue

            label_val = row[label_col].values[0]
            r_val = patient_selected[col_r].values[0]
            l_val = patient_selected[col_l].values[0]

            # The lesion position (right side preferred) is used as target z;
            # 0.0 means "no position" since unparsable/missing cells became 0.0.
            z_target = r_val if r_val != 0.0 else l_val
            if z_target == 0.0 or pd.isna(label_val):
                continue

            for z_val in img_zs_fixed:
                if abs(z_val - z_target) <= Z_TOLERANCE:
                    z_label_map[z_val].add(int(label_val))

        # ----------------------------------------
        # Step 2: one sample per labelled slice with a non-empty mask; the
        # slice is malignant (1) if any lesion mapped to it has label 1
        for slice_idx, z_val in enumerate(img_zs_fixed):
            if z_val not in z_label_map:
                continue
            label_set = z_label_map[z_val]
            label = 1 if 1 in label_set else 0

            img = imgs[slice_idx]
            msk = masks[slice_idx]
            if np.mean(msk) < MIN_MEAN_THRESHOLD:
                continue

            img_resized = np.expand_dims(cv2.resize(img, IMAGE_SIZE), axis=-1)
            mask_resized = np.expand_dims(cv2.resize(msk, IMAGE_SIZE), axis=-1)

            image_data.append(img_resized.astype(np.float32))
            mask_data.append(mask_resized.astype(np.float32))
            clinical_data.append(np.array(flat_feats, dtype=np.float32))
            labels.append(label)
            image_patient_ids.append(patient_id)
            roi_patient_ids.append(patient_id)
            added_count += 1

        # ----------------------------------------------------------------------
        # Save the chunk and its summary row
        if added_count == 0:
            print(f" -> Keine passenden Slices mit Label für {patient_id} gefunden.")
            continue

        summary_table.append({
            "Patient ID": patient_id,
            "Klinische Daten": "Ja",
            "ROI Maske": "Ja",
            "Bilder geladen": added_count,
            "Masken geladen": added_count,
            "Tumor/Benign": f"{added_count} Slices"})

        chunk_path = os.path.join(INTERMEDIATE_FOLDER, CHUNK_PATTERN.format(chunk_idx))
        with open(chunk_path, "wb") as f:
            pickle.dump({
                "images": image_data,
                "masks": mask_data,
                "clinical": clinical_data,
                "labels": labels,
                "roi_ids": roi_patient_ids,
                "patient_ids": image_patient_ids,
                "summary": summary_table}, f)

        print(f" -> Gespeichert: {chunk_path}")
        chunk_idx += 1

    except Exception as e:
        # Best effort: a failing patient is reported and skipped
        print(f"Fehler bei Patientin {patient_id}: {e}")
        continue

#### 1.1 Zusammenführung der gespeicherten Daten

In [None]:
# -----------------------------------------------------------------------------
# ZUSAMMENFÜHRUNG DER VERARBEITETEN DATEN (Chunks)
# -----------------------------------------------------------------------------
# Ziel: Alle pro Patientin gespeicherten Daten-Chunks werden blockweise geladen
#       und zu einem finalen Datensatz zusammengefügt
# -----------------------------------------------------------------------------

# ----------------------------------------------------------------------------
# Pfade für Zwischenspeicherung und finale Datenstruktur

# Locations of the per-patient chunks and of the merged outputs
INTERMEDIATE_FOLDER = "/mnt/e/Advanced-MRI-Breast-Lesions/data/intermediate_caches"
FINAL_CACHE_FILE = "/mnt/e/Advanced-MRI-Breast-Lesions/data/geladene_Daten.pkl"
FINAL_SUMMARY_CSV = "/mnt/e/Advanced-MRI-Breast-Lesions/data/Datenzusammenfassung.csv"

# ----------------------------------------------------------------------------
# Initialisation

BLOCK_SIZE = 10  # number of chunks merged into one pickled record (limits RAM use)

os.makedirs(os.path.dirname(FINAL_CACHE_FILE), exist_ok=True)

# Start from scratch: remove a previously merged file
if os.path.exists(FINAL_CACHE_FILE):
    os.remove(FINAL_CACHE_FILE)

print("\n -> Starte Zusammenführung der Chunks")

# All chunk files in name order (zero-padded, so name order == chunk order)
chunk_files = [
    name for name in os.listdir(INTERMEDIATE_FOLDER)
    if name.startswith("chunk_") and name.endswith(".pkl")]
chunk_files.sort()
print(f" -> Gefundene Chunks: {len(chunk_files)}")

# Summary rows collected for the CSV export
all_summary = []

# ----------------------------------------------------------------------------
# Schritt 1: Blöcke iterativ einlesen und final speichern

# "wb" (not "ab"): the merged file always starts empty, even if the removal of an
# older result was skipped; the blocks are then appended one pickle record at a time.
with open(FINAL_CACHE_FILE, "wb") as final_f:

    block_images = []
    block_masks = []
    block_clinical = []
    block_labels = []
    block_roi_ids = []
    block_patient_ids = []

    for idx, file in enumerate(chunk_files):
        path = os.path.join(INTERMEDIATE_FOLDER, file)
        print(f"   -> Lade: {file} ({idx + 1}/{len(chunk_files)})")

        with open(path, "rb") as f:
            chunk = pickle.load(f)

        # Add this chunk's samples to the current block
        block_images.extend(chunk["images"])
        block_masks.extend(chunk["masks"])
        block_clinical.extend(chunk["clinical"])
        block_labels.extend(chunk["labels"])
        block_roi_ids.extend(chunk["roi_ids"])
        block_patient_ids.extend(chunk["patient_ids"])

        # "summary" holds one row per patient while "labels" holds one value per
        # slice, so zip()-pairing them would only look at the first slice. The
        # patient status is derived from all slice labels instead (malignant if
        # any slice is 1), together with the malignant/total slice count.
        chunk_labels = [int(label) for label in chunk["labels"]]
        n_malignant = sum(label == 1 for label in chunk_labels)
        patient_status = 1 if n_malignant else 0
        for summary_entry in chunk["summary"]:
            summary_entry["Tumorstatus"] = (
                f"{patient_status} ({n_malignant}/{len(chunk_labels)} Slices maligne)")
            all_summary.append(summary_entry)

        # Write a block every BLOCK_SIZE chunks and after the last chunk
        if (idx + 1) % BLOCK_SIZE == 0 or (idx + 1) == len(chunk_files):
            print(f"      -> Speichere Block {idx // BLOCK_SIZE}")

            batch_data = {
                "images": np.array(block_images, dtype=np.float32),
                "masks": np.array(block_masks, dtype=np.float32),
                "clinical": np.array(block_clinical, dtype=np.float32),
                "labels": np.array(block_labels, dtype=np.float32),
                "roi_ids": block_roi_ids,
                "patient_ids": block_patient_ids}

            pickle.dump(batch_data, final_f)

            # Release the block's memory before collecting the next one
            block_images.clear()
            block_masks.clear()
            block_clinical.clear()
            block_labels.clear()
            block_roi_ids.clear()
            block_patient_ids.clear()

# ----------------------------------------------------------------------------
# Schritt 2: Zusammenfassende CSV-Datei speichern

print("\n -> Speichere tabellarische Zusammenfassung")

df_summary = pd.DataFrame(all_summary)
df_summary.to_csv(FINAL_SUMMARY_CSV, index=False)

# ----------------------------------------------------------------------------
# Step 3: statistics and final message

total_images = total_masks = 0
roi_patients = set()
for entry in all_summary:
    total_images += entry["Bilder geladen"]
    total_masks += entry["Masken geladen"]
    roi_patients.add(entry["Patient ID"])
# Same set: only patients with an ROI file were written to chunks
total_patients = roi_patients

report_lines = [
    f"   * Bilder:             {total_images}",
    f"   * Masken:             {total_masks}",
    f"   * Klinikdaten:        {total_images} (1:1 mit Bildern)",
    f"   * Labels:             {total_images} (1:1 mit Bildern)",
    f"   * ROI-Patientinnen:   {len(roi_patients)}",
    f"   * Gesamt-Patientinnen:{len(total_patients)}",
]
print("\n Zusammenführung abgeschlossen!")
for report_line in report_lines:
    print(report_line)

### 2. Prüfung der Bilder

In [None]:
# -----------------------------------------------------------------------------
# Prüfung und Visualisierung der Daten
# -----------------------------------------------------------------------------
# Ziel:
#  - Geladene Daten auf Vollständigkeit, Konsistenz und Qualität prüfen
#  - Labelverteilung visualisieren
#  - Beispielbilder und klinische Informationen darstellen
# -----------------------------------------------------------------------------

# ----------------------------------------------------------------------------
# Relevante Dateipfade

# Paths of the merged cache and the summary CSV (same values as in the merge cell)
FINAL_CACHE_FILE = "/mnt/e/Advanced-MRI-Breast-Lesions/data/geladene_Daten.pkl"
FINAL_SUMMARY_CSV = "/mnt/e/Advanced-MRI-Breast-Lesions/data/Datenzusammenfassung.csv"

# ----------------------------------------------------------------------------
# Funktion zum Laden des finalen Caches (mehrere Pickle-Blöcke)

def load_final_cache(path):
    """Load the merged cache, a file containing several consecutive pickle records.

    Every record is a dict with the keys ``images``, ``masks``, ``clinical``,
    ``labels``, ``roi_ids`` and ``patient_ids``; records are read until EOF.

    Args:
        path: Path of the merged pickle file.

    Returns:
        dict with the same keys: the four array entries concatenated along
        axis 0, ``roi_ids`` / ``patient_ids`` as flat lists.

    Raises:
        FileNotFoundError: if *path* does not exist.
        ValueError: if the file contains no records.
    """
    array_keys = ("images", "masks", "clinical", "labels")
    list_keys = ("roi_ids", "patient_ids")
    data = {key: [] for key in (*array_keys, *list_keys)}

    if not os.path.exists(path):
        raise FileNotFoundError(f"Cache nicht gefunden: {path}")

    # NOTE: pickle can execute arbitrary code - only load caches created locally.
    with open(path, "rb") as f:
        while True:
            try:
                batch = pickle.load(f)
            except EOFError:
                break
            for key in array_keys:
                data[key].append(batch[key])
            for key in list_keys:
                data[key].extend(batch[key])

    if not data["images"]:
        # np.concatenate would otherwise fail with "need at least one array"
        raise ValueError(f"Cache enthält keine Datenblöcke: {path}")

    # Join the per-record arrays into one array per key
    for key in array_keys:
        data[key] = np.concatenate(data[key], axis=0)

    return data

# ----------------------------------------------------------------------------
# Schritt 1: Laden und Übersicht

print("\n Lade finalen Cache")
cached = load_final_cache(FINAL_CACHE_FILE)

# Unpack into the global names used by the following cells
image_data, mask_data = cached["images"], cached["masks"]
clinical_data, labels = cached["clinical"], cached["labels"]
image_patient_ids, roi_patient_ids = cached["patient_ids"], cached["roi_ids"]

print(" Finaler Cache erfolgreich geladen!")

# Basic statistics
overview_lines = [
    f" * Bilder:             {image_data.shape}",
    f" * Masken:             {mask_data.shape}",
    f" * Klinikdaten:        {clinical_data.shape}",
    f" * Labels:             {labels.shape}",
    f" * ROI-Patientinnen:   {len(set(roi_patient_ids))}",
    f" * Gesamt-Patientinnen:{len(set(image_patient_ids))}",
]
print("\n Datenübersicht:")
for overview_line in overview_lines:
    print(overview_line)

# ----------------------------------------------------------------------------
# Schritt 2: Label-Verteilung

def check_label_distribution(labels_array):

    if labels_array is None or len(labels_array) == 0:
        print("Keine Labels vorhanden – möglicherweise keine klinischen Daten?")
        return

    if labels_array.ndim == 2 and labels_array.shape[1] == 2:
        label_counts = np.sum(labels_array, axis=0)
    else:
        label_counts = np.array([
            np.sum(labels_array == 0),
            np.sum(labels_array == 1)])

    categories = ["Benigne", "Maligne"]
    plt.bar(categories, label_counts, color=["green", "red"])
    plt.xlabel("Kategorien")
    plt.ylabel("Anzahl Bilder")
    plt.grid(True, axis='y', linestyle='-', linewidth=0.6)
    plt.gca().set_axisbelow(True)
    plt.tight_layout()
    plt.show()

    print("\n Klassnverteilung:")
    print(f" - Benigne: {int(label_counts[0])} Bilder")
    print(f" - Maligne: {int(label_counts[1])} Bilder")

# Plot and print the benign/malignant distribution of the loaded labels
check_label_distribution(labels)

# ----------------------------------------------------------------------------
# Schritt 3: Zufällige Bild-/Masken-Paare visualisieren

def show_random_images(num_images=5, show_masks=True):
    """Show randomly chosen slices, optionally overlaid with their ROI mask.

    Reads the globals ``image_data``, ``mask_data``, ``labels``,
    ``image_patient_ids`` and ``roi_patient_ids`` loaded in this notebook.

    Args:
        num_images: Number of distinct slices to draw; capped at the number of
            available slices.
        show_masks: Overlay non-empty masks semi-transparently.

    Returns:
        Tuple ``(indices, selected_patients, selected_labels)`` with the drawn
        sample indices, their patient IDs and "Benigne"/"Maligne" texts.
    """
    # np.random.choice(..., replace=False) raises if more samples are requested than exist
    num_images = min(num_images, len(image_data))
    if num_images <= 0:
        print("Keine Bilder zum Anzeigen vorhanden.")
        return np.array([], dtype=int), [], []

    indices = np.random.choice(len(image_data), num_images, replace=False)
    # squeeze=False keeps a 2D axes array even for a single subplot, so
    # indexing works for num_images == 1 as well
    fig, axes = plt.subplots(1, num_images, figsize=(5 * num_images, 5), squeeze=False)
    axes = axes[0]

    roi_id_set = set(roi_patient_ids)  # set lookup instead of scanning the list per sample
    selected_patients = []
    selected_labels = []

    for i, idx in enumerate(indices):
        img = image_data[idx].squeeze()
        msk = mask_data[idx].squeeze() if mask_data is not None else None

        # Labels may be one-hot (N, 2) or flat 0/1
        lab = np.argmax(labels[idx]) if labels.ndim == 2 else labels[idx]
        label_text = "Maligne" if lab == 1 else "Benigne"

        pid = image_patient_ids[idx] if idx < len(image_patient_ids) else "Unbekannt"
        roi_status = "ROI" if pid in roi_id_set else "Keine ROI"

        selected_patients.append(pid)
        selected_labels.append(label_text)

        axes[i].imshow(img, cmap="gray")
        if show_masks and msk is not None and np.any(msk):
            axes[i].imshow(np.ma.masked_where(msk == 0, msk), cmap="spring", alpha=0.4)
        axes[i].set_title(f"{pid}\n{label_text} ({roi_status})")
        axes[i].axis("off")

    plt.tight_layout()
    plt.show()

    return indices, selected_patients, selected_labels

# Example call: draw five random slices with mask overlay; the returned
# indices/IDs/labels are reused for the clinical table below
indices, selected_patients, selected_labels = show_random_images(num_images=5, show_masks=True)

# ----------------------------------------------------------------------------
# Schritt 4: Klinische Daten zu den Beispielbildern anzeigen

def show_sample_clinical_data(indices, selected_patients, selected_labels, clinical=None):
    """Print the clinical feature rows belonging to the displayed sample slices.

    Args:
        indices: Sample indices (e.g. as returned by ``show_random_images``).
        selected_patients: Patient ID for each index.
        selected_labels: Label text ("Benigne"/"Maligne") for each index.
        clinical: Optional 2D array of clinical feature rows. Defaults to the
            global ``clinical_data`` loaded earlier in the notebook.

    Returns:
        None. Prints a table with patient ID, label and the clinical columns.
    """
    if clinical is None:
        clinical = clinical_data
    if clinical is None or len(clinical) == 0:
        print("Keine klinischen Daten vorhanden.")
        return

    column_names = [
        "age at MRI",
        "tumor/benign1", "pos1_R", "pos1_L",
        "tumor/benign2", "pos2_R", "pos2_L",
        "tumor/benign3", "pos3_R", "pos3_L",
        "tumor/benign4", "pos4_R", "pos4_L",
        "tumor/benign5", "pos5_R", "pos5_L",
        "tumor/benign6", "pos6_R", "pos6_L"]
    # Age and position columns are rounded to 2 decimals; label columns are shown as-is
    position_column = re.compile(r"pos\d+_[RL]")

    clinical_samples = []
    for idx, pid, lab in zip(indices, selected_patients, selected_labels):
        if idx >= len(clinical):
            continue
        info = clinical[idx]
        formatted = []
        for col, val in zip(column_names, info):
            if col == "age at MRI" or position_column.match(col):
                try:
                    formatted.append(round(float(val), 2))
                except (TypeError, ValueError):
                    # Non-numeric value: display unchanged
                    formatted.append(val)
            else:
                formatted.append(val)
        clinical_samples.append([pid, lab] + formatted)

    df_clin = pd.DataFrame(clinical_samples, columns=["Patient ID", "Label"] + column_names)
    print("\n Zufällig ausgewählte klinische Daten:")
    print(df_clin.to_string(index=False))

# Print the clinical rows for the slices shown in the image grid above
show_sample_clinical_data(indices, selected_patients, selected_labels)

### 3. Train/Val/Test-Split auf Patientinnenbasis

In [None]:
# -----------------------------------------------------------------------------
# Train/Validierung/Test-Split auf Basis der Patientinnen
# -----------------------------------------------------------------------------
# Ziel:
#  - Sicherstellung, dass Bilder einer Patientin nicht in mehreren Sets auftreten
#  - Reproduzierbarer Split in Trainings-, Validierungs- und Testdaten
#  - Speicherung und Wiederverwendung des Splits mittels Pickle-Datei
# -----------------------------------------------------------------------------

# -----------------------------------------------------------------------------
# Relevante Pfade

FINAL_CACHE_FILE = "/mnt/e/Advanced-MRI-Breast-Lesions/data/geladene_Daten.pkl"
SPLIT_INDICES_FILE = "/mnt/e/Advanced-MRI-Breast-Lesions/data/patient_split_indices.pkl"

# -----------------------------------------------------------------------------
# Step 1: reuse a stored split if one exists, otherwise create and store it

if os.path.exists(SPLIT_INDICES_FILE):
    print(" Lade gespeicherte Split-Indizes...")

    with open(SPLIT_INDICES_FILE, "rb") as f:
        split_data = pickle.load(f)

    # NOTE(review): stored sample indices refer to the cache they were created
    # from; delete SPLIT_INDICES_FILE whenever the cache is regenerated.
    train_idx = split_data["train"]
    val_idx = split_data["val"]
    test_idx = split_data["test"]

    train_patients = split_data["train_patients"]
    val_patients = split_data["val_patients"]
    test_patients = split_data["test_patients"]

    # The overview below divides by len(roi_patients), so restore it here too
    # (otherwise it only exists if an earlier cell happened to define it).
    # Older split files without the key fall back to the union of all splits.
    roi_patients = split_data.get(
        "roi_patients",
        sorted(set(train_patients) | set(val_patients) | set(test_patients)))

else:

    print(" Erzeuge neuen Split...")

    # Only the patient IDs are needed for the split, so the image/mask arrays
    # of each pickled record are not kept in memory.
    patient_id_list = []
    with open(FINAL_CACHE_FILE, "rb") as f:
        while True:
            try:
                batch = pickle.load(f)
            except EOFError:
                break
            patient_id_list.extend(batch["patient_ids"])

    all_patient_ids = np.array(patient_id_list)
    roi_patients = sorted(set(all_patient_ids))

    print(f" Anzahl ROI-Patientinnen gesamt: {len(roi_patients)}")

    # -----------------------------------------------------------------------------
    # Patient-level split into train/val/test (70% / 15% / 15%)

    train_patients, valtest_patients = train_test_split(
        roi_patients, test_size=0.30, random_state=42, shuffle=True)

    val_patients, test_patients = train_test_split(
        valtest_patients, test_size=0.50, random_state=42, shuffle=True)

    # -----------------------------------------------------------------------------
    # Helper: translate patient IDs into sample (slice) indices

    def get_indices_by_patient_ids(all_ids, selected_ids):
        """Return the positions in *all_ids* whose ID is in *selected_ids*."""
        selected_ids = set(selected_ids)
        return [i for i, pid in enumerate(all_ids) if pid in selected_ids]

    train_idx = get_indices_by_patient_ids(all_patient_ids, train_patients)
    val_idx = get_indices_by_patient_ids(all_patient_ids, val_patients)
    test_idx = get_indices_by_patient_ids(all_patient_ids, test_patients)

    # -----------------------------------------------------------------------------
    # Store the split for reuse

    with open(SPLIT_INDICES_FILE, "wb") as f:
        pickle.dump({
            "train": train_idx,
            "val": val_idx,
            "test": test_idx,
            "train_patients": train_patients,
            "val_patients": val_patients,
            "test_patients": test_patients,
            "roi_patients": roi_patients}, f)

    print(f" Split-Indizes gespeichert unter: {SPLIT_INDICES_FILE}")

# -----------------------------------------------------------------------------
# Schritt 2: Zugriff auf die entsprechenden Bild- und Maskenarrays

# Laden der Arrays aus Pickle
data = {"images": [], "masks": [], "clinical": [], "labels": []}
with open(FINAL_CACHE_FILE, "rb") as f:
    while True:
        try:
            batch = pickle.load(f)
            for key, parts in data.items():
                parts.append(batch[key])
        except EOFError:
            break

# One array per modality over all stored records
X_all_img = np.concatenate(data["images"], axis=0)
X_all_clinical = np.concatenate(data["clinical"], axis=0)
y_all = np.concatenate(data["labels"], axis=0)
mask_all = np.concatenate(data["masks"], axis=0)

# Select the samples of each split (train, val, test)
split_indices = (train_idx, val_idx, test_idx)
X_train_img, X_val_img, X_test_img = (X_all_img[ix] for ix in split_indices)
X_train_clinical, X_val_clinical, X_test_clinical = (X_all_clinical[ix] for ix in split_indices)
y_train, y_val, y_test = (y_all[ix] for ix in split_indices)
mask_train, mask_val, mask_test = (mask_all[ix] for ix in split_indices)

# -----------------------------------------------------------------------------
# Schritt 3: Übersicht über die Split-Verteilung

n_roi_patients = len(roi_patients)
split_overview = (
    ("Train:", train_patients, train_idx),
    ("Val:  ", val_patients, val_idx),
    ("Test: ", test_patients, test_idx),
)

print("\n Übersicht Patientinnensplit (basierend auf ROI-Patientinnen):")
for tag, split_patients, _ in split_overview:
    print(f" * {tag} {len(split_patients)} Patientinnen ({len(split_patients) / n_roi_patients:.2%})")

print("\n Bildanzahl pro Split:")
for tag, _, split_idx in split_overview:
    print(f" * {tag} {len(split_idx)} Bilder")

In [None]:
# -----------------------------------------------------------------------------
# Konsistenzprüfung der Daten
# -----------------------------------------------------------------------------
# Ziel:
#  - Sicherstellen, dass die geladenen Bilddaten keine NaN-Werte enthalten
#  - NaNs (z.B. durch fehlerhafte Konvertierung oder Maskierung) werden durch 0 ersetzt
#  - Diese Prüfung erfolgt für Trainings-, Validierungs- und Testbilder separat
# -----------------------------------------------------------------------------

def validate_images(images, name=""):
    """Replace NaN values in an image array with 0.0 and report how many there were.

    Args:
        images: NumPy array of image data.
        name: Label used in the printed message.

    Returns:
        *images* itself if it contains no NaN; otherwise the array returned by
        ``np.nan_to_num(images, nan=0.0)``. Note that ``nan_to_num`` also maps
        +/-inf to the largest finite floats in that case.
    """
    nan_total = np.isnan(images).sum()
    if not nan_total:
        return images

    print(f" {name} enthält {nan_total} NaN-Werte – werden auf 0.0 gesetzt.")
    return np.nan_to_num(images, nan=0.0)

# Apply the NaN check to each image split (train / val / test)
X_train_img = validate_images(X_train_img, "Train-Bilder")
X_val_img   = validate_images(X_val_img, "Val-Bilder")
X_test_img  = validate_images(X_test_img, "Test-Bilder")

In [None]:
# -----------------------------------------------------------------------------
# Prüfung der Maskenverteilung
# -----------------------------------------------------------------------------

def analyze_mask_distribution_pickle(pkl_path, show_empty_count=True, svg_output_path="maskenverteilung.jpg"):
    # Masken aus Pickle laden
    masks = []
    with open(pkl_path, "rb") as f:
        while True:
            try:
                batch = pickle.load(f)
                masks.append(batch["masks"])
            except EOFError:
                break

    masks = np.concatenate(masks, axis=0)
    mask_sizes = np.array([np.sum(m > 0) for m in masks])

    print(f"Insgesamt {len(mask_sizes)} Masken analysiert")
    if show_empty_count:
        empty = np.sum(mask_sizes == 0)
        print(f"{empty} Masken sind komplett leer (0 Pixel > 0)")
        print(f"{np.sum(mask_sizes < 100)} Masken mit < 100 Tumor-Pixeln")

    plt.figure(figsize=(8, 5))
    plt.hist(mask_sizes, bins=50, color="blue", edgecolor="black")
    plt.xlabel("Anzahl der Pixel pro Maske")
    plt.ylabel("Anzahl an Masken")
    plt.grid(True)
    plt.gca().set_axisbelow(True)
    plt.tight_layout()
    plt.savefig(svg_output_path, format="jpeg")
    plt.show()

# Example call on the merged cache (histogram saved to the default path maskenverteilung.jpg)
analyze_mask_distribution_pickle("/mnt/e/Advanced-MRI-Breast-Lesions/data/geladene_Daten.pkl")

### 4. Multimodales Modell: CNN (EfficientNetB0) mit Transfer Learning für Klassifikation 

In [None]:
# ----------------------------------------------------------------------------- 
# Klassifikationsmodell: EfficientNetB0 mit Bild + Klinikdaten + Focal Loss 
# ----------------------------------------------------------------------------- 
# Ziel: 
#  - Eingabebilder über EfficientNetB0 verarbeiten (ImageNet vortrainiert) 
#  - Kombination mit klinischen Merkmalen 
#  - Klassifikation mit Sigmoid-Ausgabe und robuster Focal Loss 
# ----------------------------------------------------------------------------- 

# ----------------------------------------------------------------------------- 
# Focal Loss 
def focal_loss(gamma=2.0, alpha=0.25):
    """Build a binary focal loss for a sigmoid output.

    FL(p_t) = -alpha * (1 - p_t) ** gamma * log(p_t), where p_t is the predicted
    probability of the true class (y_pred for y_true == 1, else 1 - y_pred).

    Note: ``alpha`` is applied identically to both classes, so here it is a
    constant scale factor, not the per-class weight alpha_t of Lin et al.
    Class balancing is expected to come from ``class_weight`` in ``fit``.

    Args:
        gamma: Focusing parameter; larger values down-weight easy examples.
        alpha: Constant scale factor applied to the loss.

    Returns:
        ``loss_fn(y_true, y_pred)`` that returns one loss value per sample, so
        Keras can apply ``sample_weight`` / ``class_weight`` per sample.
    """
    def loss_fn(y_true, y_pred):
        eps = tf.keras.backend.epsilon()
        y_true = tf.cast(y_true, y_pred.dtype)
        # Clip so log(pt) stays finite.
        y_pred = tf.clip_by_value(y_pred, eps, 1. - eps)
        pt = tf.where(tf.equal(y_true, 1), y_pred, 1 - y_pred)
        loss = -alpha * tf.pow(1. - pt, gamma) * tf.math.log(pt)
        # Reduce only the last axis. A scalar (mean over the whole batch) would
        # make Keras multiply it by the mean sample weight, so class_weight
        # would no longer rebalance individual samples.
        return tf.reduce_mean(loss, axis=-1)
    return loss_fn

# ----------------------------------------------------------------------------- 
# EfficientNetB0 Multimodales Modell mit integrierter Bildaugmentierung
def build_classification_model(image_input_shape=(256, 256, 3),
                                clinical_input_shape=(19,),
                                base_trainable_layers=80,
                                dropout_rate=0.3,
                                learning_rate=1e-3,
                                l2_reg=1e-5,
                                rotation_factor=0.2,
                                zoom_factor=0.2,
                                contrast_factor=0.2,
                                brightness_factor=0.2,
                                dense_units=(128, 64),
                                input_value_range=(0.0, 255.0)):
    """Build and compile the multimodal classifier (EfficientNetB0 image branch + clinical features).

    Pipeline: random flip/rotation/zoom/contrast/brightness (active only during
    training) -> EfficientNetB0 (ImageNet weights) -> global average pooling ->
    concatenation with the clinical feature vector -> dense/BN/dropout stack ->
    one sigmoid output. Compiled with Adam, the focal loss and accuracy/AUC metrics.

    Keras' EfficientNetB0 contains its own input rescaling and expects pixel values
    in [0, 255]; ``preprocess_input`` is a pass-through. The image pipeline
    therefore stays in the raw input range and only maps to [0, 255] when
    ``input_value_range`` differs from it.

    Args:
        image_input_shape: Shape (H, W, C) of the image input.
        clinical_input_shape: Shape of the clinical feature vector.
        base_trainable_layers: Number of trailing layers of the backbone model
            that stay trainable; all earlier ones are frozen (0 freezes all).
        dropout_rate: Dropout after each dense layer.
        learning_rate: Adam learning rate.
        l2_reg: L2 factor for the dense layers' kernels.
        rotation_factor, zoom_factor, contrast_factor, brightness_factor:
            Strengths of the corresponding augmentation layers.
        dense_units: Units of the dense layers after the fusion.
        input_value_range: (min, max) of the raw pixel values fed to
            ``image_input``; used by RandomBrightness for scaling/clipping.
            Use (0.0, 1.0) for images already normalised to [0, 1].

    Returns:
        Compiled Keras model with inputs ``[image_input, clinical_input]``.

    Raises:
        ValueError: If ``input_value_range`` is not an increasing pair.
    """
    low, high = (float(v) for v in input_value_range)
    if high <= low:
        raise ValueError("input_value_range muss (min, max) mit min < max sein")

    # -----------------------------------------------------------------------------
    # Image input and augmentation (all layers are identity at inference time)
    image_input = Input(shape=image_input_shape, name="image_input")

    x = tf.keras.layers.RandomFlip("horizontal_and_vertical")(image_input)
    x = tf.keras.layers.RandomRotation(rotation_factor)(x)
    x = tf.keras.layers.RandomZoom(zoom_factor)(x)
    x = tf.keras.layers.RandomContrast(contrast_factor)(x)
    # value_range must match the actual data range: the brightness delta is
    # scaled to it and the result is clipped to it. With the default (0, 255)
    # on [0, 1] data, the shift would dominate the image or black it out.
    x = tf.keras.layers.RandomBrightness(factor=brightness_factor, value_range=(low, high))(x)

    # Map to the [0, 255] range EfficientNetB0 expects (skipped when already there).
    if (low, high) != (0.0, 255.0):
        scale = 255.0 / (high - low)
        x = tf.keras.layers.Rescaling(scale=scale, offset=-low * scale)(x)

    # -----------------------------------------------------------------------------
    # EfficientNetB0 backbone
    x = preprocess_input(x)  # pass-through for EfficientNet, kept for clarity
    base_model = EfficientNetB0(include_top=False, weights="imagenet", input_tensor=x)

    # Freeze all but the last `base_trainable_layers` layers. An explicit count
    # avoids the `[:-0]` pitfall (an empty slice would leave everything trainable).
    n_frozen = max(len(base_model.layers) - base_trainable_layers, 0)
    for layer in base_model.layers[:n_frozen]:
        layer.trainable = False

    x = base_model.output
    x = GlobalAveragePooling2D()(x)

    # -----------------------------------------------------------------------------
    # Additional input: clinical features, fused with the image embedding
    clinical_input = Input(shape=clinical_input_shape, name="clinical_input")
    x = Concatenate()([x, clinical_input])

    # -----------------------------------------------------------------------------
    # Dense classification head
    for units in dense_units:
        x = Dense(units, activation="relu", kernel_regularizer=l2(l2_reg))(x)
        x = BatchNormalization()(x)
        x = Dropout(dropout_rate)(x)

    # -----------------------------------------------------------------------------
    # Binary output (float32 so the sigmoid is computed in full precision)
    output = Dense(1, activation="sigmoid", name="output", dtype="float32")(x)

    model = Model(inputs=[image_input, clinical_input], outputs=output)

    # -----------------------------------------------------------------------------
    # Compilation with focal loss
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
        loss=focal_loss(),
        metrics=["accuracy", tf.keras.metrics.AUC(name="auc")])

    print(f"Trainierbare Schichten in EfficientNetB0: {sum(layer.trainable for layer in base_model.layers)} / {len(base_model.layers)}")
    return model

### 5. Attention U-Net mit EfficientNetB0 für Segmentierung

In [None]:
# -----------------------------------------------------------------------------
# Attention U-Net mit EfficientNetB0 für Segmentierung (Tumorerkennung)
# -----------------------------------------------------------------------------
# Ziel:
#  - Architektur zur pixelgenauen Segmentierung (Tumorregionen)
#  - Kombination aus EfficientNetB0 als Encoder & U-Net Decoder mit Attention
#  - RGB-Input-Erweiterung für Grayscale-MRT
#  - Focal Tversky Loss 
# -----------------------------------------------------------------------------

# -----------------------------------------------------------------------------
# Focal Tversky Loss 

def focal_tversky_loss(alpha=0.9, beta=0.1, gamma=1):
    """Build a focal Tversky loss for binary segmentation masks.

    The Tversky index is computed once over all pixels of the batch:
        TI = (TP + s) / (TP + alpha * FN + beta * FP + s),  s = 1e-6
    with soft counts TP = sum(y * p), FP = sum((1 - y) * p), FN = sum(y * (1 - p)).
    The loss is (1 - TI) ** gamma. alpha > beta penalises false negatives more
    than false positives; gamma = 1 gives the plain Tversky loss.

    Returns:
        ``loss(y_true, y_pred)`` returning a single scalar for the whole batch.
    """
    def loss(y_true, y_pred):
        y_pred_f = tf.reshape(y_pred, [-1])
        # Masks may be fed as integer arrays (e.g. uint8). TensorFlow does not
        # promote dtypes in arithmetic, so cast to the prediction dtype first.
        y_true_f = tf.cast(tf.reshape(y_true, [-1]), y_pred_f.dtype)
        TP = tf.reduce_sum(y_true_f * y_pred_f)
        FP = tf.reduce_sum((1 - y_true_f) * y_pred_f)
        FN = tf.reduce_sum(y_true_f * (1 - y_pred_f))
        tversky = (TP + 1e-6) / (TP + alpha * FN + beta * FP + 1e-6)
        return tf.pow((1 - tversky), gamma)
    return loss

# -----------------------------------------------------------------------------
# Attention-Modul für Skip-Verbindungen

def attention_gate(x, g, filters):
    """Additive attention gate for a U-Net skip connection.

    ``x`` (skip features) and ``g`` (gating signal) are each projected to
    ``filters`` channels by 1x1 convolutions, summed and passed through ReLU.
    A 1x1 convolution with sigmoid then yields a one-channel attention map.

    Args:
        x: Skip-connection feature map to be gated.
        g: Gating feature map with the same spatial size as ``x``.
        filters: Channel count of the intermediate projections.

    Returns:
        ``x`` multiplied by the attention map.
    """
    projected_skip = Conv2D(filters, 1, padding='same')(x)
    projected_gate = Conv2D(filters, 1, padding='same')(g)
    activated = Activation('relu')(Add()([projected_skip, projected_gate]))
    attention_map = Conv2D(1, 1, padding='same', activation='sigmoid')(activated)
    return Multiply()([x, attention_map])

# -----------------------------------------------------------------------------
# Benutzerdefinierter Keras-Layer (kein Lambda-Layer): bilineare Größenanpassung an einen Referenz-Tensor

class ResizeLike(Layer):
    """Bilinearly resize ``source`` to the spatial size of ``target``.

    Called as ``ResizeLike()([source, target])``. Axes 1 and 2 of ``target``
    are taken as height and width (channels-last assumed).
    """

    def call(self, inputs):
        source_tensor, reference_tensor = inputs
        return tf.image.resize(
            source_tensor,
            size=tf.shape(reference_tensor)[1:3],
            method='bilinear',
        )

# -----------------------------------------------------------------------------
# Encoder: Skip-Verbindungen aus EfficientNet

def build_efficientnetb0_unet(input_shape=(256, 256, 3),
                              dropout_rate=0.4,
                              learning_rate=1e-4):
    """Build and compile an Attention U-Net with an ImageNet EfficientNetB0 encoder.

    Encoder features used (resolution relative to the input; size for 256x256):
      - block1a_activation: 1/2  (128x128)
      - block2a_activation: 1/4  (64x64)
      - block3a_activation: 1/8  (32x32)
      - block4a_activation: 1/16 (16x16)
      - block6d_activation: 1/32 (8x8), used as the bottleneck
    Each decoder stage upsamples by 2 and fuses, through an attention gate, the
    encoder feature of the same resolution. A final 2x upsampling restores the
    input resolution.

    EfficientNetB0 rescales internally and expects pixel values in [0, 255].

    Args:
        input_shape: (H, W, 3) input shape; H and W should be divisible by 32
            for the output mask to match the input size.
        dropout_rate: Dropout after the first convolution of the bottleneck and
            of every decoder stage.
        learning_rate: Adam learning rate.

    Returns:
        Compiled model with a one-channel sigmoid output named "output_mask",
        trained with the focal Tversky loss.
    """
    base_model = tf.keras.applications.EfficientNetB0(
        include_top=False, weights='imagenet', input_tensor=Input(shape=input_shape))

    # Skip features from shallow (high resolution) to deep (low resolution).
    skip_layer_names = (
        "block1a_activation",  # 1/2
        "block2a_activation",  # 1/4
        "block3a_activation",  # 1/8
        "block4a_activation",  # 1/16
    )
    skips = [base_model.get_layer(name).output for name in skip_layer_names]

    # -------------------------------------------------------------------------
    # Bottleneck on block6d_activation (1/32 resolution, 8x8 for a 256x256 input)

    x = base_model.get_layer("block6d_activation").output

    x = Conv2D(1024, 3, padding="same", activation="relu")(x)
    x = BatchNormalization()(x)
    x = Dropout(dropout_rate)(x)
    x = Conv2D(1024, 3, padding="same", activation="relu")(x)
    x = BatchNormalization()(x)

    # -------------------------------------------------------------------------
    # Decoder: 1/32 -> 1/16 -> 1/8 -> 1/4 -> 1/2. Each upsampled feature map is
    # paired with the skip of the same resolution (deepest skip first).

    decoder_filters = [512, 256, 128, 64]

    for filters, skip in zip(decoder_filters, reversed(skips)):
        x = UpSampling2D(size=(2, 2), interpolation="bilinear")(x)

        # Safety net for inputs not divisible by 32. For aligned sizes the
        # skip is resized to its own size.
        skip_resized = ResizeLike()([skip, x])
        attn = attention_gate(skip_resized, x, filters=skip.shape[-1])
        x = Concatenate()([x, attn])
        x = Conv2D(filters, 3, activation='relu', padding='same')(x)
        x = BatchNormalization()(x)
        x = Dropout(dropout_rate)(x)
        x = Conv2D(filters, 3, activation='relu', padding='same')(x)
        x = BatchNormalization()(x)

    # -------------------------------------------------------------------------
    # Output layer: 1/2 -> full input resolution (128x128 -> 256x256)

    x = UpSampling2D(size=(2, 2), interpolation="bilinear")(x)
    outputs = Conv2D(1, 1, activation="sigmoid", dtype="float32", name="output_mask")(x)

    # -------------------------------------------------------------------------
    # Model compilation

    model = Model(inputs=base_model.input, outputs=outputs)
    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
                  loss=focal_tversky_loss(alpha=0.9, beta=0.1, gamma=1),
                  metrics=["accuracy"])
    return model

# -----------------------------------------------------------------------------
# Hilfsfunktion zur Konvertierung von Graustufenbildern zu RGB

def convert_grayscale_to_rgb(images):
    """Tile single-channel images to three identical channels.

    A 3-D array (presumably a batch without channel axis) first gets a trailing
    channel axis. If the last axis then has length 1, it is repeated three
    times; any other array is returned unchanged (the same object).
    """
    with_channel = images[..., np.newaxis] if images.ndim == 3 else images
    if with_channel.shape[-1] != 1:
        return with_channel
    return np.repeat(with_channel, 3, axis=-1)

### 6. Training und Evaluierung

In [None]:
# -----------------------------------------------------------------------------
# Trainingsprozess für Klassifikations- und Segmentierungsmodelle
# -----------------------------------------------------------------------------
# Ziel:
#  - Universelle Trainingsfunktion für Klassifikation & Segmentierung 
#  - LearningRate-Reduktion, EarlyStopping, Checkpoints 
#  - Speicherung von History + Trainingszeit 
#  - Automatische Visualisierung 
# -----------------------------------------------------------------------------

def train_and_evaluate(model,
                       train_data,
                       val_data=None,
                       batch_size=None,
                       epochs=25,
                       is_segmentation=False,
                       model_name="model",
                       patience=5,
                       save_dir="training_output",
                       class_weight=None):
    """Train a compiled Keras model, persist its history and plot train/val metrics.

    Outputs go to ``save_dir``:
      - checkpoints/<model_name>.keras: best model by the monitored loss.
      - history/<model_name>_history.pkl: history dict, duration and epoch count.
      - plots/train_<model_name>_<metric>.jpg: one JPEG plot per metric that has
        both a training and a validation curve.

    Args:
        model: Compiled Keras model.
        train_data: Either a ``tf.keras.utils.Sequence`` yielding batches, or a
            tuple ``(x, y)`` of in-memory arrays (``x`` may be a list of arrays
            for multi-input models).
        val_data: Validation ``Sequence`` or ``(x_val, y_val)`` tuple. Required
            for in-memory data, optional for a Sequence.
        batch_size: Batch size for in-memory data (ignored for a Sequence).
        epochs: Maximum number of epochs.
        is_segmentation: If True, ``class_weight`` is ignored.
        model_name: Used in file names and log messages.
        patience: EarlyStopping patience (epochs).
        save_dir: Root directory for all outputs.
        class_weight: Optional mapping class index -> weight, passed to ``fit``.

    Returns:
        ``(history, start_time)``: the Keras History object and the
        ``time.time()`` timestamp at which training started.

    Raises:
        ValueError: If in-memory data is given without ``batch_size`` or ``val_data``.
    """
    print(f"\n Training gestartet: {model_name}\n")

    # -------------------------------------------------------------------------
    # Wall-clock timing of the whole run
    start_time = time.time()

    # -------------------------------------------------------------------------
    # Output directories
    checkpoint_dir = os.path.join(save_dir, "checkpoints")
    history_dir = os.path.join(save_dir, "history")
    plot_dir = os.path.join(save_dir, "plots")
    for directory in (checkpoint_dir, history_dir, plot_dir):
        os.makedirs(directory, exist_ok=True)

    print(f" Speichern unter: {save_dir}\n")

    is_sequence = isinstance(train_data, tf.keras.utils.Sequence)
    if not is_sequence and (batch_size is None or val_data is None):
        raise ValueError("Für NumPy-Daten müssen batch_size und val_data angegeben werden.")

    # -------------------------------------------------------------------------
    # Callbacks. Without validation data Keras never logs "val_loss", and the
    # callbacks would only warn and skip. Fall back to the training loss instead.
    monitor = "val_loss" if val_data is not None else "loss"
    callbacks = [
        tf.keras.callbacks.ReduceLROnPlateau(
            monitor=monitor, factor=0.5, patience=3, min_lr=1e-6, verbose=1),
        tf.keras.callbacks.ModelCheckpoint(
            filepath=os.path.join(checkpoint_dir, f"{model_name}.keras"),
            monitor=monitor, save_best_only=True, verbose=1),
        tf.keras.callbacks.EarlyStopping(
            monitor=monitor, patience=patience, restore_best_weights=True, verbose=1),
    ]

    # -------------------------------------------------------------------------
    # Fit arguments shared by both data sources
    fit_args = {"epochs": epochs, "callbacks": callbacks, "verbose": 1}
    if is_sequence:
        # The Sequence yields complete batches, so batch_size is not passed.
        fit_args["x"] = train_data
        fit_args["validation_data"] = val_data
    else:
        x_train, y_train = train_data
        x_val, y_val = val_data
        fit_args.update(x=x_train, y=y_train, validation_data=(x_val, y_val), batch_size=batch_size)

    # class_weight applies to per-sample labels (classification), not to pixel masks.
    if class_weight and not is_segmentation:
        fit_args["class_weight"] = class_weight

    history = model.fit(**fit_args)

    # -------------------------------------------------------------------------
    # Persist duration and history
    duration_min = (time.time() - start_time) / 60
    history_data = {
        "model_name": model_name,
        "history": history.history,
        "duration_min": duration_min,
        "epochs_ran": len(history.history.get("loss", []))
    }
    hist_path = os.path.join(history_dir, f"{model_name}_history.pkl")
    with open(hist_path, "wb") as f:
        pickle.dump(history_data, f)

    print(f"\n Trainingsverlauf gespeichert unter: {hist_path}")
    print(f"  Dauer: {duration_min:.2f} Minuten")

    # -------------------------------------------------------------------------
    # Plot every metric that has both a training and a validation curve
    def plot_metric(metric_name):
        train_values = history.history.get(metric_name)
        val_values = history.history.get(f"val_{metric_name}")
        if train_values is None or val_values is None:
            return

        plt.figure(figsize=(7, 5))
        plt.plot(train_values, label="Train", linewidth=2)
        plt.plot(val_values, label="Val", linewidth=2, linestyle="--")
        plt.xlabel("Epoche")
        plt.ylabel(metric_name.replace("_", " ").title())
        plt.legend(loc="center left", bbox_to_anchor=(1, 0.5))
        plt.grid(True)
        plt.gca().set_axisbelow(True)
        plt.tight_layout()

        plot_path = os.path.join(plot_dir, f"train_{model_name}_{metric_name}.jpg")
        # Format must match the .jpg extension (it was written as SVG before).
        plt.savefig(plot_path, format="jpeg")
        plt.show()
        print(f"{metric_name.title()}-Plot gespeichert unter: {plot_path}")

    for key in history.history:
        if not key.startswith("val_") and f"val_{key}" in history.history:
            plot_metric(key)

    return history, start_time

In [None]:
#--------------------------------------------------------
# Exakte Berechnung der Klassengewichte (Ausgleich des Klassenungleichgewichts)
#--------------------------------------------------------

# Collect all labels from the chunked cache file (single pass over the pickle).
all_labels = []
with open(FINAL_CACHE_FILE, "rb") as f:
    while True:
        try:
            chunk = pickle.load(f)
        except EOFError:
            break
        all_labels.extend(chunk["labels"])

all_labels = np.array(all_labels)
train_labels = all_labels[train_idx]

# "balanced" weight per class (0 = benign, 1 = malignant):
# n_samples / (n_classes * samples_in_class)
class_weights_array = compute_class_weight(
    "balanced",
    classes=np.unique(train_labels),
    y=train_labels,
)
class_weight = dict(zip(np.unique(train_labels), class_weights_array))

print("Berechnete class_weight:", class_weight)

Berechnete class_weight: {0.0: 1.1904761904761905, 1.0: 0.8620689655172413}


In [None]:
# Trainierbare Schichten von beiden Modellen
def print_trainable_layers(model, model_name="Modell"):
    """Print how many entries of ``model.layers`` have ``trainable`` set, out of the total."""
    layer_list = model.layers
    n_trainable = len([layer for layer in layer_list if layer.trainable])
    print(f"{model_name} – Trainierbare Schichten: {n_trainable} / {len(layer_list)}")


# Build the EfficientNet classifier once to inspect its architecture.
classification_model = build_classification_model()
classification_model.summary()
print_trainable_layers(model=classification_model, model_name="Klassifikations Modell")

# Build the Attention U-Net once to inspect its architecture.
unet_model = build_efficientnetb0_unet()
unet_model.summary()
print_trainable_layers(model=unet_model, model_name="Attention U-Net")

In [None]:
# -----------------------------------------------------------------------------
#  Generator-Klasse für EfficientNet U-Net Segmentierung (RGB-kompatibel)
# -----------------------------------------------------------------------------
# Ziel:
#  - Batchweise Bereitstellung von Bild + Maske
#  - Synchronisierte Augmentierung (nur im Training)
#  - Graustufen-Bilder → RGB
#  - Filterung kleiner Masken (< min_mask_pixels)
# -----------------------------------------------------------------------------

class SegmentationDataGenerator(Sequence):
    """Keras ``Sequence`` yielding ``(images, masks)`` batches for U-Net training.

    - Samples whose mask sum is below ``min_mask_pixels`` are dropped once at
      construction (for 0/1 masks the sum equals the foreground pixel count).
    - With ``augment=True`` each sample gets one randomly drawn flip / rotation /
      zoom that is applied identically to image and mask.
    - Single-channel image batches are repeated to 3 channels.

    Samples are treated as channels-last (H, W, C) arrays; masks are presumably
    (H, W, 1). Confirm against the caller.
    """

    def __init__(self, images, masks,
                 batch_size=8,
                 augment=True,
                 shuffle=True,
                 rotation_range=5,
                 zoom_range=0.05,
                 horizontal_flip=True,
                 vertical_flip=True,
                 min_mask_pixels=100,
                 **kwargs):
        # Keras 3's PyDataset (which Sequence aliases there) expects
        # super().__init__() to be called. Extra kwargs (e.g. workers) are forwarded.
        super().__init__(**kwargs)

        # ---------------------------------------------------------------------
        # Drop samples with too little foreground (vectorised per-sample sum)
        mask_sums = np.sum(masks, axis=tuple(range(1, np.ndim(masks))))
        valid_indices = np.flatnonzero(mask_sums >= min_mask_pixels)
        self.images = images[valid_indices]
        self.masks = masks[valid_indices]

        self.batch_size = batch_size
        self.augment = augment
        self.shuffle = shuffle
        self.rotation_range = rotation_range
        self.zoom_range = zoom_range
        self.horizontal_flip = horizontal_flip
        self.vertical_flip = vertical_flip

        self.indices = np.arange(len(self.images))
        self.on_epoch_end()

    # -------------------------------------------------------------------------
    # Number of batches per epoch (last batch may be smaller)
    def __len__(self):
        return int(np.ceil(len(self.images) / self.batch_size))

    # -------------------------------------------------------------------------
    # One batch by index
    def __getitem__(self, index):
        idxs = self.indices[index * self.batch_size:(index + 1) * self.batch_size]
        X_batch = self.images[idxs].copy()
        Y_batch = self.masks[idxs].copy()

        if self.augment:
            pairs = [self.apply_augmentation(x, y) for x, y in zip(X_batch, Y_batch)]
            X_batch = np.stack([x for x, _ in pairs])
            Y_batch = np.stack([y for _, y in pairs])

        # Grayscale -> RGB for the EfficientNet encoder
        if X_batch.shape[-1] == 1:
            X_batch = np.repeat(X_batch, 3, axis=-1)

        return X_batch, Y_batch

    # -------------------------------------------------------------------------
    # Reshuffle sample order after every epoch
    def on_epoch_end(self):
        if self.shuffle:
            np.random.shuffle(self.indices)

    # -------------------------------------------------------------------------
    # Synchronised augmentation of image + mask
    def apply_augmentation(self, x, y):
        """Apply one randomly drawn flip/rotation/zoom identically to ``x`` and ``y``.

        Rotation and zoom parameters are drawn ONCE and applied to both arrays
        through a single affine transform, so the mask stays aligned with the
        image. The mask uses nearest-neighbour resampling (order=0) so that its
        values are not blended.
        """
        if self.horizontal_flip and random.random() < 0.5:
            x, y = np.fliplr(x), np.fliplr(y)

        if self.vertical_flip and random.random() < 0.5:
            x, y = np.flipud(x), np.flipud(y)

        theta = random.uniform(-self.rotation_range, self.rotation_range) if self.rotation_range else 0.0
        if self.zoom_range:
            zx = 1 + random.uniform(-self.zoom_range, self.zoom_range)
            zy = 1 + random.uniform(-self.zoom_range, self.zoom_range)
        else:
            zx = zy = 1.0

        if theta != 0.0 or zx != 1.0 or zy != 1.0:
            affine = tf.keras.preprocessing.image.apply_affine_transform
            params = dict(theta=theta, zx=zx, zy=zy, row_axis=0, col_axis=1, channel_axis=2)
            x = affine(x, **params)
            y = affine(y, order=0, **params)

        return x, y

### 7. Modelltraining - EfficientNetB0

In [None]:
SAVE_DIR = "results"
os.makedirs(SAVE_DIR, exist_ok=True)

# -----------------------------------------------------------------------------
# Training – EfficientNetB0 classification model
# -----------------------------------------------------------------------------

print("\nStarte Training des EfficientNetB0 Klassifikationsmodells:\n")

# -----------------------------------------------------------------------------
# Build a fresh model
classification_model = build_classification_model(image_input_shape=(256, 256, 3))

# -----------------------------------------------------------------------------
# Prepare data (train_idx / val_idx come from an earlier cell)
X_train_img = X_all_img[train_idx]
X_val_img = X_all_img[val_idx]
X_train_clinical = X_all_clinical[train_idx]
X_val_clinical = X_all_clinical[val_idx]
y_train = y_all[train_idx]
y_val = y_all[val_idx]

# -----------------------------------------------------------------------------
# Class weights computed from the labels actually used for training
# ("balanced": n_samples / (n_classes * samples_in_class)). Hard-coded values
# went stale (0.86 had been rounded to 0.8). Keras documents class_weight keys
# as integer class indices, hence int().
train_classes = np.unique(y_train)
class_weight = {
    int(cls): float(weight)
    for cls, weight in zip(
        train_classes,
        compute_class_weight("balanced", classes=train_classes, y=y_train),
    )
}
print("Verwendete class_weight:", class_weight)

# -----------------------------------------------------------------------------
# Train and save the model
history_classification, start_time_classification = train_and_evaluate(
    model=classification_model,
    train_data=([X_train_img, X_train_clinical], y_train),
    val_data=([X_val_img, X_val_clinical], y_val),
    epochs=30,
    batch_size=16,
    is_segmentation=False,
    model_name="efficientnet_model",
    patience=10,
    save_dir=SAVE_DIR,
    class_weight=class_weight
)

classification_model.save(os.path.join(SAVE_DIR, "efficientnet_model.keras"))
print(f"EfficientNet Modell gespeichert unter: {os.path.join(SAVE_DIR, 'efficientnet_model.keras')}")

In [None]:
SAVE_DIR = "results"
os.makedirs(SAVE_DIR, exist_ok=True)


# -----------------------------------------------------------------------------
# Training – Attention U-Net with EfficientNetB0 encoder (RGB, generator-based)
# -----------------------------------------------------------------------------

print("\n Starte Training des EfficientNet U-Net Modells:\n")

unet_model = build_efficientnetb0_unet(input_shape=(256, 256, 3))

# -----------------------------------------------------------------------------
# Split images and masks; 3-D arrays get a trailing channel axis.
X_train_img, X_val_img, Y_train, Y_val = (
    arr[..., np.newaxis] if arr.ndim == 3 else arr
    for arr in (X_all_img[train_idx], X_all_img[val_idx],
                mask_all[train_idx], mask_all[val_idx])
)

# -----------------------------------------------------------------------------
# Batch generators (no augmentation, batch size 4)
train_gen, val_gen = (
    SegmentationDataGenerator(images, masks, batch_size=4, augment=False)
    for images, masks in ((X_train_img, Y_train), (X_val_img, Y_val))
)

# -----------------------------------------------------------------------------
# Train with the generators
history_unet, start_time_unet = train_and_evaluate(
    model=unet_model,
    train_data=train_gen,
    val_data=val_gen,
    epochs=30,
    is_segmentation=True,
    model_name="efficientnet_unet",
    patience=5,
    save_dir=SAVE_DIR
)

# -----------------------------------------------------------------------------
# Save the trained model
unet_path = os.path.join(SAVE_DIR, "efficientnet_unet_model.keras")
unet_model.save(unet_path)
print(f"U-Net Modell gespeichert unter: {unet_path}")

### 9. Evaluierung der Modelle EfficientNetB0 und Attention U-Net

In [None]:
# ----------------------------------------------------------------------------
#  Modell-Evaluierung 
# Ziel: 
# - Confusion Matrix, AUC- & ROC-Kurve für EfficientNet
# - Dice-Score, IoU und Visualisierung der Masken
# ----------------------------------------------------------------------------

os.makedirs(name="plots", exist_ok=True)  # target folder for all evaluation figures

# ----------------------------------------------------------------------
# Evaluierung: Klassifikation (EfficientNet) 
def evaluate_classification_model(model, X_test_img, y_test, clinical_test=None, threshold=0.5):
    """Evaluate the image + clinical classifier on a test set.

    Prints the ROC-AUC and a classification report. Saves and shows the ROC curve
    (plots/classification_roc.jpg) and the confusion matrix
    (plots/classification_confusion_matrix.jpg).

    Args:
        model: Keras model taking ``[images, clinical_features]`` and returning
            one sigmoid probability per sample (class 1 = "Maligne").
        X_test_img: Test images. Single-channel input is repeated to 3 channels;
            3-channel input is passed through unchanged.
        y_test: Binary ground-truth labels (0 = "Benigne", 1 = "Maligne").
        clinical_test: Clinical features aligned with ``X_test_img``. If None,
            the notebook-level ``X_test_clinical`` is used (previous behaviour).
        threshold: Probability above which a sample is predicted as class 1.
    """
    # Local import so the function does not depend on a separate import cell.
    from sklearn.metrics import (
        classification_report, confusion_matrix, roc_auc_score, roc_curve)

    print("\n Bewertung des Klassifikationsmodells gestartet")

    if clinical_test is None:
        clinical_test = X_test_clinical  # notebook-level variable

    # Expand only genuinely single-channel images. Blindly repeating an RGB
    # array along the last axis would produce 9 channels.
    X_test_img_rgb = convert_grayscale_to_rgb(X_test_img)
    # reshape(-1) instead of squeeze() keeps a 1-D array even for one sample.
    y_prob = model.predict([X_test_img_rgb, clinical_test], verbose=1).reshape(-1)
    y_pred = (y_prob > threshold).astype(int)

    auc_score = roc_auc_score(y_test, y_prob)
    print(f" ROC-AUC Score: {auc_score:.4f}")

    fpr, tpr, _ = roc_curve(y_test, y_prob)
    plt.figure(figsize=(6, 5))
    plt.plot(fpr, tpr, label=f"AUC = {auc_score:.3f}")
    plt.plot([0, 1], [0, 1], "k--")
    plt.grid(True)
    plt.xlabel("Falsch Positiv Rate")
    plt.ylabel("Richtig Positiv Rate")
    plt.legend(loc="center left", bbox_to_anchor=(1, 0.5))
    plt.tight_layout()
    plt.savefig("plots/classification_roc.jpg", format="jpeg")
    plt.show()

    # labels=[0, 1] keeps the matrix 2x2 (matching the tick labels) even if one
    # class never appears in y_test or y_pred.
    cm = confusion_matrix(y_test, y_pred, labels=[0, 1])
    plt.figure(figsize=(4, 4))
    sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", xticklabels=["Benigne", "Maligne"], yticklabels=["Benigne", "Maligne"])
    plt.xlabel("Vorhergesagt")
    plt.ylabel("Tatsächlich")
    plt.tight_layout()
    plt.savefig("plots/classification_confusion_matrix.jpg", format="jpeg")
    plt.show()

    print("\n Klassifikationsbericht:")
    print(classification_report(y_test, y_pred, labels=[0, 1], target_names=["Benigne", "Maligne"], digits=2))

# ----------------------------------------------------------------------
# Evaluierung: Segmentierung (U-Net)
def evaluate_segmentation_model(model, X_test_img, mask_test, num_examples=5, threshold=0.5):
    """Evaluate the segmentation model with Dice/IoU and plot example predictions.

    Per-sample scores are computed on binarised masks (ground truth > 0.5,
    prediction > ``threshold``) and their means are printed. Up to
    ``num_examples`` random samples are plotted (image, ground truth,
    prediction) and saved as plots/unet_example_<i>.jpg.

    Args:
        model: Keras segmentation model returning per-pixel probabilities.
        X_test_img: Test images. Single-channel input is repeated to 3 channels
            for prediction only; the original array is used for display.
        mask_test: Ground-truth masks aligned with ``X_test_img``.
        num_examples: Number of examples to visualise (capped at the test size).
        threshold: Probability threshold for a positive pixel.
    """
    print("\n Bewertung des U-Net Segmentierungsmodells gestartet")

    # The EfficientNet encoder expects 3 channels.
    model_input = convert_grayscale_to_rgb(X_test_img)
    preds = model.predict(model_input, batch_size=2, verbose=1)
    preds_bin = (preds > threshold).astype(np.uint8)

    dice_scores, iou_scores = [], []
    smooth = 1e-6

    for true_mask, pred_mask in zip(mask_test, preds_bin):
        # Binarise the ground truth so the scores are well defined for any
        # mask encoding (e.g. 0/1 or 0/255).
        true_f = np.asarray(true_mask).ravel() > 0.5
        pred_f = pred_mask.ravel().astype(bool)
        intersection = np.count_nonzero(true_f & pred_f)
        total = np.count_nonzero(true_f) + np.count_nonzero(pred_f)
        dice_scores.append((2. * intersection + smooth) / (total + smooth))
        union = np.count_nonzero(true_f | pred_f)
        # Same convention as jaccard_score(..., zero_division=0): 0 if both masks are empty.
        iou_scores.append(intersection / union if union else 0.0)

    print(f" Dice Score (avg): {np.mean(dice_scores):.4f}")
    print(f"  IoU Score  (avg): {np.mean(iou_scores):.4f}")

    print("\n Beispielhafte Segmentierungen:")
    n_show = min(num_examples, len(X_test_img))  # choice(replace=False) fails if too many
    indices = np.random.choice(len(X_test_img), n_show, replace=False)

    for i, idx in enumerate(indices):
        fig, axs = plt.subplots(1, 3, figsize=(12, 4))
        axs[0].imshow(X_test_img[idx].squeeze(), cmap="gray")
        axs[0].set_title("Original")
        axs[1].imshow(mask_test[idx].squeeze(), cmap="Reds")
        axs[1].set_title("Ground Truth")
        axs[2].imshow(preds_bin[idx].squeeze(), cmap="Greens")
        axs[2].set_title("Vorhergesagt")
        for ax in axs:
            ax.axis("off")
        plt.tight_layout()
        path = f"plots/unet_example_{i}.jpg"
        plt.savefig(path, format="jpeg")
        plt.show()
        print(f"Beispiel gespeichert unter: {path}")

# ----------------------------------------------------------------------
# Evaluierung starten
# Run both evaluations on the held-out test data.
evaluate_classification_model(model=classification_model, X_test_img=X_test_img, y_test=y_test)
evaluate_segmentation_model(
    model=unet_model,
    X_test_img=X_test_img[..., :1],
    mask_test=mask_test,
    num_examples=5,
)