# Preprocess Data Folders
In this notebook we preprocess data folders (Renaming, removing background if needed, ...)

In [6]:
import os

In [7]:
# HF Cache
%env CUDA_VISIBLE_DEVICES=2
os.environ["HF_HOME"] = "../../.cache"
!echo $HF_HOME
!huggingface-cli whoami

env: CUDA_VISIBLE_DEVICES=2
../../.cache
Maats
[1morgs: [0m DBD-research-group,Basket-AEye


## Settings

In [8]:
FOLDER_PATH = "../../huggingface/" + "10classes/pasta"

### Rename

In [17]:
for root, dirs, files in os.walk(FOLDER_PATH):
    folder_name = os.path.basename(root)
    images = [f for f in files if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
    temp_names = []
    # First pass: rename to temp names to avoid conflicts
    for idx, img in enumerate(sorted(images), 1):
        ext = os.path.splitext(img)[1]
        temp_name = f"__temp__{idx}{ext}"
        old_path = os.path.join(root, img)
        temp_path = os.path.join(root, temp_name)
        os.rename(old_path, temp_path)
        temp_names.append((temp_path, f"{folder_name}{idx}{ext}"))
    # Second pass: rename temp names to final names
    for temp_path, final_name in temp_names:
        final_path = os.path.join(root, final_name)
        os.rename(temp_path, final_path)

### Remove background

In [None]:
import os
from rembg import new_session, remove
from PIL import Image, ImageEnhance
import io

# Use the most accurate model
session = new_session("isnet-general-use")

for root, dirs, files in os.walk(FOLDER_PATH):
    images = [f for f in files if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
    for img_name in images:
        img_path = os.path.join(root, img_name)


        # Remove background
        output_bytes = remove(
            input_bytes,
            session=session,
            alpha_matting=True,
            alpha_matting_foreground_threshold=240,
            alpha_matting_background_threshold=10,
            alpha_matting_erode_size=1,
        )

        # Save result
        with open(img_path, "wb") as f:
            f.write(output_bytes)

        print(f"✅ Saved: {img_path}")

Downloading data from 'https://github.com/danielgatis/rembg/releases/download/v0.0.0/isnet-general-use.onnx' to file '/home/stu235269/.u2net/isnet-general-use.onnx'.
100%|████████████████████████████████████████| 179M/179M [00:00<00:00, 127GB/s]
[1;31m2025-06-11 15:22:48.621897055 [E:onnxruntime:Default, provider_bridge_ort.cc:2195 TryGetProviderInfo_CUDA] /onnxruntime_src/onnxruntime/core/session/provider_bridge_ort.cc:1778 onnxruntime::Provider& onnxruntime::ProviderLibrary::Get() [ONNXRuntimeError] : 1 : FAIL : Failed to load library libonnxruntime_providers_cuda.so with error: libcudnn.so.9: cannot open shared object file: No such file or directory
[m


✅ Saved: ../../huggingface/10classes/pasta/pasta4.png
✅ Saved: ../../huggingface/10classes/pasta/pasta5.png
✅ Saved: ../../huggingface/10classes/pasta/pasta6.png
✅ Saved: ../../huggingface/10classes/pasta/pasta7.png
✅ Saved: ../../huggingface/10classes/pasta/pasta8.png
✅ Saved: ../../huggingface/10classes/pasta/pasta9.png
✅ Saved: ../../huggingface/10classes/pasta/pasta10.png
✅ Saved: ../../huggingface/10classes/pasta/pasta11.png
✅ Saved: ../../huggingface/10classes/pasta/pasta12.png
✅ Saved: ../../huggingface/10classes/pasta/pasta13.png
✅ Saved: ../../huggingface/10classes/pasta/pasta14.png
✅ Saved: ../../huggingface/10classes/pasta/pasta15.png
✅ Saved: ../../huggingface/10classes/pasta/pasta16.png
✅ Saved: ../../huggingface/10classes/pasta/pasta17.png
✅ Saved: ../../huggingface/10classes/pasta/pasta18.png
✅ Saved: ../../huggingface/10classes/pasta/pasta19.png
✅ Saved: ../../huggingface/10classes/pasta/pasta20.png
✅ Saved: ../../huggingface/10classes/pasta/pasta21.png
✅ Saved: ../../h

### Make image squared

In [11]:
from PIL import Image

for root, dirs, files in os.walk(FOLDER_PATH):
    images = [f for f in files if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
    for img_name in images:
        img_path = os.path.join(root, img_name)
        with Image.open(img_path) as img:
            img = img.convert("RGBA")
            w, h = img.size
            if w != h:
                print(f"Padding {img_path} to square...")
                max_side = max(w, h)
                new_img = Image.new("RGBA", (max_side, max_side), (0, 0, 0, 0))
                offset = ((max_side - w) // 2, (max_side - h) // 2)
                new_img.paste(img, offset)
                new_img.save(img_path)

### Filter outputs

In [8]:
import os
from PIL import Image
import matplotlib.pyplot as plt
import ipywidgets as widgets
from IPython.display import display, clear_output

class BatchImageFilter:
    def __init__(self, image_paths, batch_size=9, grid_cols=3):
        self.image_paths = image_paths
        self.batch_size = batch_size
        self.grid_cols = grid_cols
        self.index = 0
        self.selected = set()
        self.out = widgets.Output()
        
        # Control buttons
        self.keep_btn = widgets.Button(description="✅ Keep Selected")
        self.remove_btn = widgets.Button(description="🗑️ Delete Unselected")
        self.next_btn = widgets.Button(description="➡️ Next Batch")
        self.keep_btn.on_click(self.keep_selected)
        self.remove_btn.on_click(self.delete_unselected)
        self.next_btn.on_click(self.load_next_batch)

        self.ui = widgets.VBox([
            self.out,
            widgets.HBox([self.keep_btn, self.remove_btn, self.next_btn])
        ])
        display(self.ui)
        self.load_next_batch()

    def load_next_batch(self, b=None):
        self.selected.clear()
        self.out.clear_output(wait=True)
        self.current_batch = self.image_paths[self.index:self.index + self.batch_size]
        self.index += self.batch_size

        with self.out:
            fig, axes = plt.subplots(
                nrows=(len(self.current_batch) + self.grid_cols - 1) // self.grid_cols,
                ncols=self.grid_cols,
                figsize=(12, 8)
            )
            axes = axes.flatten()

            for ax in axes[len(self.current_batch):]:
                ax.axis("off")

            for i, path in enumerate(self.current_batch):
                img = Image.open(path)
                ax = axes[i]
                ax.imshow(img)
                ax.set_title(os.path.basename(path), fontsize=8)
                ax.axis("off")
                # Attach a clickable event
                def onclick(event, i=i):
                    if i in self.selected:
                        self.selected.remove(i)
                    else:
                        self.selected.add(i)
                    # Highlight selected
                    ax.set_edgecolor("red" if i in self.selected else "black")
                    fig.canvas.draw_idle()

                fig.canvas.mpl_connect("button_press_event", lambda event, i=i: onclick(event, i))
            plt.tight_layout()
            plt.show()

    def keep_selected(self, b=None):
        print("Keeping selected images.")
        self.image_paths = [
            path for i, path in enumerate(self.current_batch) if i in self.selected
        ] + self.image_paths[self.index:]
        self.index = 0
        self.load_next_batch()

    def delete_unselected(self, b=None):
        to_delete = [
            path for i, path in enumerate(self.current_batch) if i not in self.selected
        ]
        for path in to_delete:
            try:
                #os.remove(path)
                print(f"Deleted: {path}")
            except Exception as e:
                print(f"Error deleting {path}: {e}")
        self.image_paths = [
            path for i, path in enumerate(self.current_batch) if i in self.selected
        ] + self.image_paths[self.index:]
        self.index = 0
        self.load_next_batch()

def collect_images(folder):
    image_extensions = {'.png', '.jpg', '.jpeg', '.bmp', '.gif'}
    paths = []
    for root, _, files in os.walk(folder):
        for file in files:
            if any(file.lower().endswith(ext) for ext in image_extensions):
                paths.append(os.path.join(root, file))
    return sorted(paths)

# Usage
image_folder = FOLDER_PATH
image_paths = collect_images(image_folder)
BatchImageFilter(image_paths, batch_size=9, grid_cols=3)

VBox(children=(Output(), HBox(children=(Button(description='✅ Keep Selected', style=ButtonStyle()), Button(des…

<__main__.BatchImageFilter at 0x7efd9152dfc0>