In [None]:
"""
SEGMENTIERUNGSMASKEN FÜR PG MIT SAM2 UND MANUELLER KORREKTUR

Dieses Skript dient der Erzeugung von Segmentierungsmasken für einen Bilddatensatz.
Es kombiniert eine automatisierte Erstverarbeitung mit einem interaktiven Werkzeug zur manuellen Nachbesserung.

Der Prozess umfasst folgende Schritte:
1. Laden eines vortrainierten Segmentierungsmodells (Ultralytics SAM) und des Zieldatensatzes 
   (Oxford-IIIT Pet).
2. Durchführung einer automatisierten Segmentierung für alle Bilder im Datensatz mithilfe von 
   Punkt-basierten Prompts, um das Hauptobjekt zu identifizieren.
3. Bereitstellung eines interaktiven UI-Tools, das es dem Benutzer ermöglicht, Bilder mit 
   fehlerhaften Masken gezielt zu korrigieren, indem neue Bounding-Box-Prompts gezeichnet werden.
4. Integration der manuell korrigierten Masken in den ursprünglichen, automatisch erstellten Datensatz.
5. Speicherung des finalen, vollständigen und qualitativ hochwertigen Satzes von Masken in einer 
   .npy-Datei für die weitere Verwendung.

"""

In [None]:
import tensorflow_datasets as tfds
import numpy as np
from ultralytics import SAM
import cv2
import os
import matplotlib.pyplot as plt
import matplotlib.pyplot as plt
import ipywidgets as widgets
from IPython.display import display, clear_output
import numpy as np
import cv2
import base64
import io
import os
from PIL import Image
from jupyter_bbox_widget import BBoxWidget

In [None]:
# Load the SAM checkpoint through the Ultralytics wrapper and print a model summary.
model = SAM("sam2.1_l.pt")
model.info()

# Load the train and test splits of Oxford-IIIT Pet together with the dataset metadata.
# as_supervised=True yields (image, label) pairs, which is how later cells unpack each item.
# shuffle_files=False: later cells address images by their position in the split
# (dataset.skip(idx) and the hard-coded list of indices), so the order should stay stable.
(ds_train, ds_test), ds_info = tfds.load(
    'oxford_iiit_pet',
    split=['train', 'test'],
    shuffle_files=False,
    as_supervised=True,
    with_info=True,
    data_dir='/mnt/data/datasets'
)

# Report how many examples the test split contains.
test_length = ds_info.splits['test'].num_examples
print(f"Length of the test split: {test_length}")

In [None]:
# Resolve the human-readable class name for a label.
def get_label_name(label_tensor, dataset_info):
    """Return the human-readable class name for a scalar label.

    Args:
        label_tensor: The label to resolve. Either an object exposing ``.numpy()``
            (such as the label tensors yielded by the dataset) or a plain / NumPy
            integer.
        dataset_info: Dataset info object whose ``features['label']`` entry
            provides ``int2str``.

    Returns:
        Whatever ``dataset_info.features['label'].int2str`` returns for the label.
    """
    # Tensors need .numpy() to expose their value; plain integers are used as they are.
    raw_label = label_tensor.numpy() if hasattr(label_tensor, "numpy") else label_tensor
    # Normalise to a built-in int so the lookup never receives a 0-d array or NumPy scalar.
    return dataset_info.features['label'].int2str(int(raw_label))
# Debug output to verify that the label names were loaded correctly.
label_names = ds_info.features['label'].names
print(f" Labels loaded from TFDS: {label_names}")

In [None]:
# Manual override of the label names; only needed for Imagenette. The code below is wrapped in a string literal and therefore inactive.
"""
original_label_names = ds_info.features['label'].names
print(f"Originale Label-Namen von TFDS: {original_label_names}")

# Beispiel:  10 Klassen und möchten die Namen ändern
# Stellen Sie sicher, dass die Länge und Reihenfolge Ihrer neuen Liste
# zu den ursprünglichen Integer-Labels passt.
meine_neuen_label_namen = [
    'tench',
    'English springer',
    'cassette player',
    'chain saw',
    'church',
    'French horn',
    'garbage truck',
    'gas pump',
    'golf ball',
    'parachute'
]
def get_meinen_label_namen(label_tensor, meine_namen_liste):
    label_int = label_tensor.numpy()
    if 0 <= label_int < len(meine_namen_liste):
        return meine_namen_liste[label_int]
    else:
        return "Unbekanntes Label"

        """

In [None]:
# Plotting helpers
def show_final_mask(original_image, final_mask_np, image_label_name, title="Segmentation"):
    """Plot an image next to a copy with the mask blended in as a red overlay.

    Args:
        original_image: Image array with at least three channels on the last axis.
        final_mask_np: Mask array; pixels with a value above zero are tinted.
            When ``None``, only the image is shown, with a "No Mask Found" title.
        image_label_name: Class name shown in the plot titles.
        title: Title prefix for the overlay (or for the single plot without a mask).
    """
    if final_mask_np is not None:
        plt.figure(figsize=(12, 6))

        # Left panel: the untouched input image.
        plt.subplot(1, 2, 1)
        plt.imshow(original_image)
        plt.title(f"Original: '{image_label_name}'")
        plt.axis('off')

        # Right panel: blend a red tint into every masked pixel, one channel at a time.
        plt.subplot(1, 2, 2)
        tint_rgb = [255, 0, 0]
        opacity = 0.5
        blended = original_image.copy()
        inside_mask = final_mask_np > 0
        for channel, tint in enumerate(tint_rgb):
            plane = blended[:, :, channel]  # view into `blended`, so the write goes through
            plane[inside_mask] = plane[inside_mask] * (1 - opacity) + tint * opacity

        plt.imshow(blended.astype(np.uint8))
        plt.title(f"{title}: '{image_label_name}'")
        plt.axis('off')
        plt.tight_layout()
        plt.show()
        return

    # No mask available: show the image alone and say so in the title.
    plt.imshow(original_image)
    plt.title(f"{title} for '{image_label_name}' (No Mask Found)")
    plt.axis('off')
    plt.show()

In [None]:
all_results = []  # One entry per processed image: the combined uint8 mask, or None on failure

# Run the point-prompted segmentation on the first 800 images of the test split.
for img_idx, (image_tf, label) in enumerate(ds_test.take(800)):
    image_np = image_tf.numpy()
    current_label_name = get_label_name(label, ds_info)

    # Expand single-channel images, shaped (H, W) or (H, W, 1), to RGB.
    if image_np.ndim == 2 or image_np.shape[-1] == 1:
        image_np = cv2.cvtColor(image_np, cv2.COLOR_GRAY2RGB)

    print(f"Processing image {img_idx}")
    print(f"  Image shape for SAM: {image_np.shape}, dtype: {image_np.dtype}")

    # Prompt with the image centre plus four points diagonally around it.
    h, w, _ = image_np.shape
    cx, cy = w // 2, h // 2
    offset = int(min(h, w) * 0.01) or 10  # fall back to 10 px when 1% rounds down to zero

    corner_signs = ((0, 0), (-1, -1), (1, -1), (-1, 1), (1, 1))
    prompt_points_list = [[cx + sx * offset, cy + sy * offset] for sx, sy in corner_signs]
    prompt_points = np.array(prompt_points_list)
    prompt_labels = np.ones(len(prompt_points_list), dtype=int)  # label 1 for every point

    try:
        results = model.predict(
            image_np,
            points=prompt_points,
            labels=prompt_labels,
            verbose=False,
            save=False,
        )
    except Exception as e:
        print(f"  Error during SAM prediction for image {img_idx}: {e}")
        all_results.append(None)  # placeholder so list positions keep matching image indices
        continue

    final_combined_mask = None
    masks_found_for_combining = False

    if results and results[0].masks is not None:
        try:
            all_masks_np = results[0].masks.data.cpu().numpy()  # (N, H, W)
            mask_count = all_masks_np.shape[0]

            if mask_count > 0:
                print(f"  Found {mask_count} mask(s) from prompt for image {img_idx}. Combining all.")
                masks_found_for_combining = True
                # Union of all thresholded masks: a pixel is set if any mask exceeds 0.7 there.
                final_combined_mask = np.any(all_masks_np > 0.7, axis=0).astype(np.uint8)

        except AttributeError as ae:
            print(f"  AttributeError while accessing mask data: {ae}")
        except Exception as ex:
            print(f"  An unexpected error occurred while accessing masks for image {img_idx}: {ex}")

    if masks_found_for_combining and final_combined_mask is not None:
        all_results.append(final_combined_mask)
        show_final_mask(image_np, final_combined_mask, current_label_name, title=f"Combined Segmentation for Image {img_idx}")
    else:
        all_results.append(None)  # store None so the ordering is preserved
        print(f"  No masks to combine or an error occurred for image {img_idx}.")
        show_final_mask(image_np, None, current_label_name, title=f"Image {img_idx} (No Masks to Combine / Error)")

print("\nFinished processing dataset.")


In [None]:
# Hilfsfunktionen

def get_image_and_label(idx, dataset, dataset_info):
    """Look up one dataset item by position and return its image and class name.

    Args:
        idx: Number of items to skip before taking one.
        dataset: Dataset supporting ``skip``/``take`` that yields (image, label) pairs.
        dataset_info: Dataset info object passed on to ``get_label_name``.

    Returns:
        ``(image_np, label_name)``, or ``(None, None)`` when no item is available
        at that position.
    """
    item = next(iter(dataset.skip(idx).take(1)), None)
    if not item:
        return None, None

    image_tf, label_tf = item
    image_np = image_tf.numpy()

    # Expand single-channel images, shaped (H, W) or (H, W, 1), to RGB.
    if image_np.ndim == 2 or image_np.shape[-1] == 1:
        image_np = cv2.cvtColor(image_np, cv2.COLOR_GRAY2RGB)

    return image_np, get_label_name(label_tf, dataset_info)
def encode_image_to_base64(image_np):
    """Encode a NumPy image as a base64 JPEG data URI for display in the widget.

    Args:
        image_np: Image array accepted by ``PIL.Image.fromarray``.

    Returns:
        A string of the form ``data:image/jpeg;base64,<payload>``.
    """
    image_pil = Image.fromarray(image_np)
    # JPEG cannot store an alpha channel and Pillow refuses to save such modes,
    # so images with alpha are flattened to RGB before encoding.
    if image_pil.mode in ("RGBA", "LA"):
        image_pil = image_pil.convert("RGB")
    buffered = io.BytesIO()
    image_pil.save(buffered, format="JPEG")
    encoded_string = base64.b64encode(buffered.getvalue()).decode("utf-8")
    return f"data:image/jpeg;base64,{encoded_string}"

# MODIFIED: This function now accepts a list of bounding boxes to draw
def show_segmentation_result(original_image, bboxes, mask):
    """Plot the image with the mask as a translucent blue overlay and the boxes outlined.

    Args:
        original_image: Image array to draw as the background.
        bboxes: Iterable of dicts with keys ``'x'``, ``'y'``, ``'width'`` and
            ``'height'``; may be empty or ``None``.
        mask: Mask array whose last two axes are height and width. Nothing is
            drawn when this is ``None``.
    """
    if mask is None:
        return

    fig, ax = plt.subplots(1, 1, figsize=(6, 6))
    ax.imshow(original_image)

    # Multiply the mask by an RGBA colour to build the overlay image.
    overlay_rgba = np.array([30/255, 144/255, 255/255, 0.6])
    h, w = mask.shape[-2:]
    ax.imshow(mask.reshape(h, w, 1) * overlay_rgba.reshape(1, 1, -1))

    # Outline every bounding box that was used as a prompt.
    for bbox in bboxes or ():
        x, y, width, height = bbox['x'], bbox['y'], bbox['width'], bbox['height']
        ax.add_patch(
            plt.Rectangle((x, y), width, height, fill=False, edgecolor='green', linewidth=2)
        )

    ax.set_title("Generated Mask")
    ax.axis('off')
    plt.tight_layout()
    plt.show()


# Manuelle Verbesserung Tool
class ManualCorrector:
    """Interactive notebook tool for re-segmenting selected images with box prompts.

    The tool walks through ``indices_to_correct`` one image at a time. For each
    image the user draws bounding boxes (or requests a full-image mask), reviews
    the resulting mask and confirms it. Confirmed masks are kept in
    ``corrected_masks_cache``, keyed by dataset index.

    Predictions are made with the module-level ``model`` object.
    """

    def __init__(self, dataset, dataset_info, indices_to_correct, save_dir="manual_masks"):
        """Set up state and widgets; nothing is displayed until ``start()`` is called.

        Args:
            dataset: Dataset passed to ``get_image_and_label`` to fetch images.
            dataset_info: Dataset info object used to resolve class names.
            indices_to_correct: Sequence of dataset indices to step through.
            save_dir: Directory that is created on construction. It is kept on the
                instance as ``save_dir``; this class itself writes no files into it.
        """
        self.dataset = dataset
        self.dataset_info = dataset_info
        self.indices = indices_to_correct
        self.save_dir = save_dir
        os.makedirs(save_dir, exist_ok=True)

        self.corrected_masks_cache = {}  # dataset index -> confirmed mask
        self.current_idx_pos = -1        # position within self.indices; -1 = not started
        self.image = None                # image currently being corrected
        self.generated_mask = None       # latest unconfirmed mask for the current image

        # UI elements
        self.info_label = widgets.Label()
        self.plot_output = widgets.Output()

        # Bounding-box widget; no box limit is set, so several boxes can be drawn.
        self.bbox_widget = BBoxWidget(
            classes=['bbox']
        )

        # UI buttons
        self.segment_button = widgets.Button(description="Generate Mask")
        self.full_mask_button = widgets.Button(description="Generate Full Mask", button_style='warning')
        self.confirm_button = widgets.Button(description="Confirm & Cache", button_style='success', disabled=True)
        self.next_button = widgets.Button(description="Next Image", button_style='info')

        # UI button actions
        self.segment_button.on_click(self.run_segmentation_from_bboxes)
        self.full_mask_button.on_click(self.generate_full_mask)
        self.confirm_button.on_click(self.on_confirm)
        self.next_button.on_click(self.load_next_image)

    @staticmethod
    def _combine_masks(mask_stack, threshold=0.7):
        """Union a stack of masks into one uint8 mask (1 where any mask exceeds ``threshold``)."""
        return np.any(np.asarray(mask_stack) > threshold, axis=0).astype(np.uint8)

    def load_next_image(self, b=None):
        """Advance to the next index and show its image, or finish when none remain.

        Args:
            b: Button instance passed by the click handler; unused.
        """
        # A mask generated for the previous image must never be cached under another index.
        self.generated_mask = None
        # Clamp so repeated calls after the end do not keep growing the position.
        self.current_idx_pos = min(self.current_idx_pos + 1, len(self.indices))

        if self.current_idx_pos >= len(self.indices):
            self.image = None
            self.info_label.value = f"All {len(self.corrected_masks_cache)} images corrected. You can now run the final integration block."
            self.segment_button.disabled = self.next_button.disabled = self.confirm_button.disabled = self.full_mask_button.disabled = True
            return

        self.plot_output.clear_output()
        self.confirm_button.disabled = True

        current_image_index = self.indices[self.current_idx_pos]

        self.image, label_name = get_image_and_label(
            current_image_index, self.dataset, self.dataset_info
        )

        # Mask generation only makes sense when an image was actually loaded.
        image_loaded = self.image is not None
        self.segment_button.disabled = not image_loaded
        self.full_mask_button.disabled = not image_loaded

        if not image_loaded:
            self.info_label.value = f"Error: Image with index {current_image_index} could not be loaded!"
            return

        self.info_label.value = f"Correcting image {self.current_idx_pos + 1}/{len(self.indices)} (Index: {current_image_index}, Class: '{label_name}')"
        self.bbox_widget.image = encode_image_to_base64(self.image)
        self.bbox_widget.bboxes = []  # clear boxes drawn on the previous image

    def run_segmentation_from_bboxes(self, b=None):
        """Predict a mask for every drawn box and merge them into one mask.

        Args:
            b: Button instance passed by the click handler; unused.
        """
        if self.image is None:
            self.info_label.value = "Error: No image is loaded."
            return

        if not self.bbox_widget.bboxes:
            self.info_label.value = "Error: Please draw at least one bounding box."
            return

        h, w, _ = self.image.shape
        final_combined_mask = np.zeros((h, w), dtype=np.uint8)

        # One prediction per box; all resulting masks are merged into a single mask.
        for bbox_data in self.bbox_widget.bboxes:
            x_min, y_min = bbox_data['x'], bbox_data['y']
            x_max, y_max = bbox_data['x'] + bbox_data['width'], bbox_data['y'] + bbox_data['height']
            input_box = np.array([[x_min, y_min, x_max, y_max]])

            print(f"  Processing bounding box prompt: {input_box}")

            try:
                results = model.predict(self.image, bboxes=input_box, verbose=False)
            except Exception as e:
                self.info_label.value = f"Error during SAM prediction: {e}"
                return

            if results and results[0].masks is not None and results[0].masks.data.shape[0] > 0:
                all_masks_np = results[0].masks.data.cpu().numpy()
                mask_for_this_bbox = self._combine_masks(all_masks_np)
                final_combined_mask = np.logical_or(final_combined_mask, mask_for_this_bbox)

        # Store the merged mask as uint8 (logical_or produces booleans).
        self.generated_mask = final_combined_mask.astype(np.uint8)

        with self.plot_output:
            clear_output(wait=True)
            show_segmentation_result(self.image, self.bbox_widget.bboxes, self.generated_mask)

        if np.any(self.generated_mask):
            self.info_label.value = "Mask generated from all boxes. Review and click 'Confirm & Cache'."
            self.confirm_button.disabled = False
        else:
            self.info_label.value = "No mask could be generated for the selected areas."
            self.confirm_button.disabled = True

    def generate_full_mask(self, b=None):
        """Generate a mask that covers the entire current image.

        Args:
            b: Button instance passed by the click handler; unused.
        """
        if self.image is None:
            self.info_label.value = "Error: No image is loaded."
            return

        print("  Generating full image mask.")
        h, w, _ = self.image.shape
        # All-ones mask with the same height and width as the image.
        self.generated_mask = np.ones((h, w), dtype=np.uint8)

        with self.plot_output:
            clear_output(wait=True)
            # No boxes are involved, so none are drawn.
            show_segmentation_result(self.image, [], self.generated_mask)

        self.info_label.value = "Full image mask generated. Click 'Confirm & Cache'."
        self.confirm_button.disabled = False

    def on_confirm(self, b=None):
        """Cache the generated mask under the current dataset index.

        Args:
            b: Button instance passed by the click handler; unused.
        """
        if self.generated_mask is None:
            self.info_label.value = "Error: Please generate a mask first."
            return

        # Before start() and after the last image there is no index to file the mask under.
        if not 0 <= self.current_idx_pos < len(self.indices):
            self.info_label.value = "Error: No image is currently selected."
            return

        current_image_index = self.indices[self.current_idx_pos]
        self.corrected_masks_cache[current_image_index] = self.generated_mask
        self.info_label.value = f"Mask for index {current_image_index} cached. {len(self.corrected_masks_cache)}/{len(self.indices)} corrected."
        self.confirm_button.disabled = True

    def start(self):
        """Display the UI and load the first image."""
        buttons = widgets.HBox([
            self.segment_button,
            self.full_mask_button,
            self.confirm_button,
            self.next_button
        ])
        ui = widgets.VBox([self.info_label, self.bbox_widget, buttons, self.plot_output])
        display(ui)
        self.load_next_image()


# Dataset indices whose automatic mask should be corrected manually.
bad_indices = [14,16,22,33,38,41,49,54,58,59,62,65,67,75,95,109,110,114,123,126,129,133,139,142,145,150,151,156,161,177,188,190,196,198,233,235,244,245,252,265,282,285,286,295,304,306,321,327,338,340,350,357,362,363,371,373,379,386,404,406,408,417,436,446,448,460,461,473,481,485,486,508,512,518,519,527,537,543,544,552,554,557,576,580,592,605,619,620,626,630,635,648,691,709,724,730,744,747,757,761,764,768,772,779,780,781,787]

# The corrector needs the dataset, its metadata and the loaded model from earlier cells.
missing_names = [name for name in ('ds_test', 'ds_info', 'model') if name not in globals()]

if not missing_names:
    print(f"Starting manual correction for {len(bad_indices)} images...")
    # Instantiate the corrector with the dataset, its metadata and the indices to fix.
    corrector = ManualCorrector(
        dataset=ds_test,
        dataset_info=ds_info,
        indices_to_correct=bad_indices
    )
    corrector.start()
else:
    print(
        "Error: Ensure that 'ds_test', 'ds_info' and 'model' are defined before running this cell. "
        f"Missing: {', '.join(missing_names)}"
    )

In [None]:

try:
    # Fetch the manually corrected masks collected by the correction tool.
    corrected_masks_cache = corrector.corrected_masks_cache

    if not corrected_masks_cache:
        print("Info: The correction cache is empty. No masks to integrate.")
    else:
        print(f"Cache with {len(corrected_masks_cache)} corrected masks found. Starting integration...")

        # Replace the automatically generated mask at each corrected index.
        for index, new_mask in corrected_masks_cache.items():
            # The lower bound matters: a negative index would silently replace an entry from the end.
            if 0 <= index < len(all_results):
                all_results[index] = new_mask
                print(f"  -> Mask at index {index} has been replaced with the manual version.")
            else:
                print(f"  -> Warning: Index {index} from the cache is out of bounds for the 'all_results' list.")

        print("\nIntegration complete!")

    # Save the final masks as a single .npy file.
    save_path = "/mnt/data/masken/oxfordpets_test.npy"
    os.makedirs(os.path.dirname(save_path), exist_ok=True)

    # Fill a preallocated 1-D object array element by element. Passing the list straight to
    # np.array(..., dtype=object) lets NumPy guess the nesting depth, which can raise a
    # broadcast error or produce an N-D object array when the masks share leading dimensions.
    # Object arrays are pickled on save, so loading needs np.load(..., allow_pickle=True).
    masks_to_save = np.empty(len(all_results), dtype=object)
    for position, mask in enumerate(all_results):
        masks_to_save[position] = mask
    np.save(save_path, masks_to_save)

    print(f"\n Final list with {len(all_results)} masks has been successfully saved to: {os.path.abspath(save_path)}")

except NameError:
    print("Error: The 'corrector' object or the 'all_results' list was not found.")
    print("Please ensure you have successfully run the preceding code cells.")
except Exception as e:
    print(f"An unexpected error occurred: {e}")