# Ball Segmentation - Inference Demo

This notebook demonstrates how to run a trained segmentation model on new images, optionally after color pre-processing of the input.

In [None]:
from pathlib import Path
from time import perf_counter
import torch
from ultralytics.models import YOLO
from PIL import Image
import numpy as np
from IPython.display import Image as IPImage, display

In [None]:
from utils import DisplayPath
# Deliberately shadow the pathlib.Path imported in the previous cell: from here on,
# `Path(...)` builds a DisplayPath. Later cells call `.display()` on paths created
# this way (presumably DisplayPath is a pathlib.Path subclass adding that method;
# see utils to confirm).
Path = DisplayPath

In [None]:
img_dir = Path('datasets/ready/full_dataset/demo_folder')
# Explicit check instead of `assert`: asserts are stripped under `python -O`,
# and is_dir() (unlike exists()) also rejects a regular file at this path.
if not img_dir.is_dir():
    raise FileNotFoundError(f"Directory {img_dir} does not exist (or is not a directory).")

# Use the first CUDA GPU when available, otherwise fall back to the CPU.
DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'
print(f"Device: {DEVICE}")

In [None]:
RUNS_DIR = Path("runs/segment")
project_name = 'ball_person_trashcan_model_v4'
best_model_path = RUNS_DIR / project_name / 'weights' / 'best.pt'
best_model_path.display()
# Fail fast, mirroring the dataset-directory check, so a wrong project_name
# produces a clear error here rather than whatever YOLO() does with a missing path.
if not best_model_path.is_file():
    raise FileNotFoundError(f"Model weights {best_model_path} not found.")
model = YOLO(best_model_path)

## Run Inference on a Single Image

In [None]:
img_path = img_dir / "tom3.jpg"

def predict(img_path: Path, conf=0.2, *, iou=0.5, imgsz=640):
    """Run the module-level ``model`` on an image, print timing, and show the overlay.

    Uses the notebook globals ``model`` and ``DEVICE``.

    Args:
        img_path: Image to run inference on (passed to ``model.predict`` as a string).
        conf: Confidence threshold forwarded to ``model.predict``.
        iou: IoU threshold forwarded to ``model.predict`` (previously fixed at 0.5).
        imgsz: Inference image size forwarded to ``model.predict`` (previously fixed at 640).

    Returns:
        The list of results returned by ``model.predict``.
    """
    start = perf_counter()
    results = model.predict(
        source=str(img_path),
        save=False,
        conf=conf,
        iou=iou,
        imgsz=imgsz,
        device=DEVICE,
        show_labels=True,
        show_conf=True,
    )
    # Wall-clock time of the whole call; the first call in a session may also
    # include one-off model warm-up, so later timings can be lower.
    print(f"Inference time: {perf_counter() - start:.4f} seconds")

    # Display every result (one per image); for a single image this is results[0].
    for result in results:
        # plot() yields a BGR array; reversing the channel axis gives RGB for PIL.
        display(Image.fromarray(result.plot()[..., ::-1]))

    return results

In [None]:
def _isolate_red(rgb, *, threshold=20.0, softness=10.0, red_gain=1.5, other_gain=0.2):
    """Return a uint8 array where red-dominant pixels are vivid and all others gray.

    ``rgb`` is indexed as ``[..., channel]`` with channels in R, G, B order and
    values on a 0-255 scale (e.g. ``np.array(img.convert("RGB"))``).
    """
    if softness <= 0:
        raise ValueError("softness must be positive")

    rgb = np.asarray(rgb, dtype=np.float32)
    red, green, blue = rgb[..., 0], rgb[..., 1], rgb[..., 2]

    # "Redness": positive when red beats the stronger of green/blue, negative otherwise.
    dominance = red - np.maximum(green, blue)

    # Soft 0..1 mask: a logistic centred on `threshold` (the minimum redness
    # required); `softness` controls how gradual the cutoff is. Written via
    # tanh, which equals 1 / (1 + exp(-x)) but cannot overflow.
    mask = 0.5 * (1.0 + np.tanh((dominance - threshold) / (2.0 * softness)))
    mask = mask[..., np.newaxis]

    # Non-red pixels fade to the plain mean of the three channels.
    gray = ((red + green + blue) / 3.0)[..., np.newaxis]

    # Fully-red rendition: boost R and damp G/B, each channel clipped to 0-255.
    vivid = np.clip(
        np.stack((red * red_gain, green * other_gain, blue * other_gain), axis=-1),
        0,
        255,
    )

    blended = vivid * mask + gray * (1.0 - mask)
    # Round (rather than truncate) before narrowing to uint8.
    return np.clip(np.rint(blended), 0, 255).astype(np.uint8)


def prepare_img(
    img_path: Path,
    *,
    threshold=20.0,
    softness=10.0,
    red_gain=1.5,
    other_gain=0.2,
) -> Path:
    """Save a "red-isolated" copy of an image next to it and return its path.

    Pixels whose red channel exceeds max(green, blue) by about ``threshold``
    are pushed towards vivid red; all other pixels are desaturated to gray,
    so red objects stand out. The copy is written as ``augmented_<name>`` in
    the same directory as ``img_path``.

    Args:
        img_path: Source image.
        threshold: Red dominance at which the blend mask reaches 0.5.
        softness: Width of the mask transition (must be > 0; smaller = sharper).
        red_gain: Multiplier applied to red in red regions.
        other_gain: Multiplier applied to green and blue in red regions.

    Returns:
        Path of the saved augmented image.
    """
    with Image.open(img_path) as src:
        rgb = np.array(src.convert("RGB"), dtype=np.float32)

    augmented = _isolate_red(
        rgb,
        threshold=threshold,
        softness=softness,
        red_gain=red_gain,
        other_gain=other_gain,
    )

    augmented_img_path = img_path.parent / f"augmented_{img_path.name}"
    Image.fromarray(augmented).save(augmented_img_path)
    return augmented_img_path

In [None]:
# Default tone-curve control points (input V -> output V) for prepare_img_v2.
# They keep black and white fixed and lift shadows and midtones (every interior
# output is above its input). They were estimated by eye from a GIMP Curves
# dialog whose 4x4 grid puts guide lines at steps of 64.
_V2_CURVE = ((0, 0), (64, 90), (128, 160), (192, 215), (255, 255))


def prepare_img_v2(img_path: Path, *, curve=_V2_CURVE) -> Path:
    """Brighten an image by applying a tone curve to its HSV Value channel.

    The curve is the piecewise-linear interpolation through ``curve``'s
    ``(input, output)`` control points. Hue and saturation are left untouched.
    The result is converted back to RGB and written as
    ``augmented_v2_<name>`` in the same directory as ``img_path``.

    Args:
        img_path: Source image.
        curve: Sequence of ``(x, y)`` pairs on the 0-255 scale, with strictly
            increasing x. Defaults to the GIMP-estimated curve above.

    Returns:
        Path of the saved image.

    Raises:
        ValueError: If ``curve`` is empty or its x values are not strictly increasing.
    """
    if len(curve) < 1:
        raise ValueError("curve needs at least one control point")
    xs, ys = zip(*curve)
    # np.interp does not validate this and returns meaningless results otherwise.
    if any(later <= earlier for earlier, later in zip(xs, xs[1:])):
        raise ValueError("curve x values must be strictly increasing")

    # 256-entry lookup table; round instead of truncating the interpolated values.
    lut = np.clip(np.rint(np.interp(np.arange(256), xs, ys)), 0, 255).astype(np.uint8)

    with Image.open(img_path) as src:
        hue, sat, val = src.convert("HSV").split()

    # Image.point maps each pixel value v of the band to lut[v].
    val = val.point(lut.tolist())
    new_img = Image.merge("HSV", (hue, sat, val)).convert("RGB")

    save_path = img_path.parent / f"augmented_v2_{img_path.name}"
    new_img.save(save_path)
    return save_path

In [None]:
# Build the tone-curve-adjusted copy of the demo image and show where it was saved.
path = prepare_img_v2(img_dir / "tom_original.jpg")
path.display()

In [None]:
# Run inference on the adjusted image with a lower confidence cut-off than the default 0.2.
predict(img_path=path, conf=0.1)