In [None]:
# Imports and config
import io
import os
import sys
import time

import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import tensorflow as tf

# Make the parent directory importable so `sod_model` can be found. '..' is
# resolved against the kernel's current working directory (normally the
# folder that contains this notebook).
sys.path.append('..')
from sod_model import get_sod_model

# Directory holding the trained checkpoint. Defaults to the author's local
# path; set the SOD_CHECKPOINT_DIR environment variable (or edit the default)
# to point at your own checkpoints.
CHECKPOINT_DIR = os.environ.get(
    'SOD_CHECKPOINT_DIR', '/home/aliho/sod_tf_project/checkpoints_exp7_DEEP'
)
BEST_WEIGHTS = os.path.join(CHECKPOINT_DIR, 'best_model.weights.h5')

# Model config (match training): these are forwarded to `get_sod_model`, so
# they must equal the values used for the training run for the architecture
# to line up with the saved checkpoint.
TARGET_SIZE = 224   # side length of the square RGB input, in pixels
BASE_FILTERS = 32
NUM_BLOCKS = 5
USE_BATCHNORM = True
USE_DROPOUT = True
DROPOUT_RATE = 0.5

print('TensorFlow version:', tf.__version__)

In [None]:
# Helper: build the SOD model and load trained weights into it (if available).
def load_model(checkpoint_path=BEST_WEIGHTS, target_size=TARGET_SIZE,
               base_filters=BASE_FILTERS, num_blocks=NUM_BLOCKS,
               use_batchnorm=USE_BATCHNORM, use_dropout=USE_DROPOUT,
               dropout_rate=DROPOUT_RATE, strict=False):
    """Build the saliency model with `get_sod_model` and load checkpoint weights.

    Args:
        checkpoint_path: Path to a Keras weights file (e.g. ``*.weights.h5``).
        target_size: Side length of the square input; the model is built for
            input shape ``(target_size, target_size, 3)``.
        base_filters, num_blocks, use_batchnorm, use_dropout, dropout_rate:
            Architecture hyperparameters forwarded to ``get_sod_model``. They
            must match the training run for the weights to load.
        strict: If True, raise instead of returning a model whose weights were
            never loaded. Defaults to False, which keeps the original
            best-effort behaviour of printing a message and continuing.

    Returns:
        The model built by ``get_sod_model``, with the checkpoint weights
        loaded when that succeeded. (This function does not call ``compile``.)

    Raises:
        FileNotFoundError: If ``strict`` and ``checkpoint_path`` does not exist.
        Exception: If ``strict``, whatever ``model.load_weights`` raised.
    """
    model = get_sod_model(
        input_shape=(target_size, target_size, 3),
        base_filters=base_filters,
        use_batchnorm=use_batchnorm,
        use_dropout=use_dropout,
        dropout_rate=dropout_rate,
        num_blocks=num_blocks,
    )

    if not os.path.exists(checkpoint_path):
        msg = f'Weights not found at {checkpoint_path} — model is uninitialized'
        if strict:
            raise FileNotFoundError(msg)
        print(msg)
        return model

    try:
        model.load_weights(checkpoint_path)
    except Exception as e:
        if strict:
            raise
        # Best-effort mode: continue with untrained weights, but report it,
        # because predictions from such a model do not reflect training.
        print('Could not load weights:', e)
    else:
        print(f'Loaded weights from {checkpoint_path}')
    return model


model = load_model()

In [None]:
# Preprocess an image (bytes or path) into the model's input format:
# RGB, resized to target_size x target_size, float32 scaled to [0, 1].
def _pil_to_model_input(img, target_size):
    """Convert an opened PIL image to a (target_size, target_size, 3) float32 array in [0, 1].

    The image is resized straight to a square, so non-square inputs are
    stretched (aspect ratio is not preserved).
    """
    rgb = img.convert('RGB').resize((target_size, target_size), Image.BILINEAR)
    return np.asarray(rgb, dtype=np.float32) / 255.0


def preprocess_image_from_bytes(img_bytes, target_size=TARGET_SIZE):
    """Decode an encoded image (e.g. raw JPEG/PNG file contents) into model input.

    Args:
        img_bytes: Bytes-like object holding the encoded image file.
        target_size: Side length of the square output.

    Returns:
        float32 array of shape (target_size, target_size, 3) with values in [0, 1].
    """
    with Image.open(io.BytesIO(img_bytes)) as img:
        return _pil_to_model_input(img, target_size)


def preprocess_image_from_path(path, target_size=TARGET_SIZE):
    """Load an image file from disk into model input.

    Args:
        path: Filesystem path of the image.
        target_size: Side length of the square output.

    Returns:
        float32 array of shape (target_size, target_size, 3) with values in [0, 1].
    """
    # The context manager closes the underlying file as soon as decoding is
    # done, including when convert/resize raise.
    with Image.open(path) as img:
        return _pil_to_model_input(img, target_size)

def postprocess_mask(mask):
    """Clamp a predicted saliency mask into the range [0, 1].

    Args:
        mask: Float array, expected shape (H, W).

    Returns:
        A numpy array of the same shape with every value limited to [0, 1].
    """
    lowest, highest = 0.0, 1.0
    clamped = np.clip(mask, lowest, highest)
    return clamped

In [None]:
# Inference helper that measures time and returns mask + time (seconds)
def infer_image(model, img_arr):
    """Run the model on one image and time the forward pass.

    Args:
        model: Callable model, invoked as ``model(batch, training=False)``;
            its result must provide ``.numpy()``.
        img_arr: HxWx3 float array in [0, 1] (as produced by the preprocess
            helpers).

    Returns:
        ``(mask, seconds)``: ``mask`` is the prediction for this image passed
        through ``postprocess_mask``; ``seconds`` is the elapsed time of the
        model call including the ``.numpy()`` conversion.

    Note:
        The first call on a freshly built model may include one-off setup
        (tracing/kernel initialisation), so its timing is likely not
        representative of steady-state latency.
    """
    batch = np.expand_dims(img_arr, axis=0).astype('float32')
    # perf_counter is monotonic and high-resolution; time.time() is wall-clock
    # time, which can jump with clock adjustments and is coarse on some systems.
    start = time.perf_counter()
    pred = model(batch, training=False).numpy()
    elapsed = time.perf_counter() - start
    # Drop the batch axis; for a rank-4 output (assumed channels-last,
    # (N, H, W, C)) keep only the first channel.
    mask = pred[0, ..., 0] if pred.ndim == 4 else pred[0]
    return postprocess_mask(mask), elapsed

def show_results(img_arr, mask, inference_time, overlay_alpha=0.6):
    """Plot the input image, the predicted mask and a red mask overlay side by side.

    Args:
        img_arr: HxWx3 float array in [0, 1] (the preprocessed input).
        mask: HxW float array in [0, 1]; must share H and W with ``img_arr``.
        inference_time: Forward-pass duration in seconds; shown in ms in the
            overlay panel title.
        overlay_alpha: Opacity of the red overlay where the mask is 1; the
            opacity scales linearly with the mask value. Default 0.6.

    Raises:
        ValueError: If ``mask`` is not 2-D or its shape differs from the
            image's height/width.
    """
    mask = np.asarray(mask)
    if mask.ndim != 2:
        raise ValueError(f'mask must be 2-D (H, W), got shape {mask.shape}')
    img_hw = tuple(np.shape(img_arr)[:2])
    if img_hw != mask.shape:
        raise ValueError(f'mask shape {mask.shape} does not match image shape {img_hw}')

    # Clip once so both the blend weight and the red channel stay in range;
    # casting values outside [0, 255] to uint8 would otherwise wrap around.
    m = np.clip(mask, 0.0, 1.0)

    fig, axes = plt.subplots(1, 3, figsize=(12, 4))
    axes[0].imshow(img_arr)
    axes[0].set_title('Input')
    axes[0].axis('off')

    # Fixed vmin/vmax: by default imshow rescales to the mask's own min/max,
    # so a mask peaking at e.g. 0.3 would still render that peak as pure white.
    axes[1].imshow(m, cmap='gray', vmin=0.0, vmax=1.0)
    axes[1].set_title('Predicted Mask')
    axes[1].axis('off')

    # overlay: red mask alpha-blended onto the image
    base = (np.asarray(img_arr) * 255).astype('uint8')
    alpha = m[..., None] * overlay_alpha
    red = np.zeros_like(base)
    red[..., 0] = (m * 255).astype('uint8')
    overlay = (base * (1 - alpha) + red * alpha).astype('uint8')
    axes[2].imshow(overlay)
    axes[2].set_title(f'Overlay — {inference_time*1000:.1f} ms')
    axes[2].axis('off')
    plt.tight_layout()
    plt.show()

In [None]:
# Upload widget if available, otherwise fallback to path-based inference.
# IMAGE_PATH is used when nothing has been uploaded, and it is also read by the
# fallback cell below. An uploaded file always takes precedence over it; set it
# to None to require an upload. The default is an absolute path on the
# author's machine (an image from DUTS-TE/DUTS-TE-Image), so adjust it to an
# image on your own machine.
IMAGE_PATH = '/home/aliho/sod_tf_project/data/DUTS-TE/DUTS-TE/DUTS-TE-Image/ILSVRC2012_test_00000259.jpg'

# Try ipywidgets upload widget
try:
    # Import `display` explicitly instead of relying on the IPython builtin.
    from IPython.display import display
    from ipywidgets import FileUpload, Button, HBox, VBox, Output
    uploader = FileUpload(accept='image/*', multiple=False)
    run_btn = Button(description='Run inference')
    out = Output()

    display(VBox([uploader, run_btn, out]))

    def on_run(b):
        """Button callback: run inference on the uploaded image, else on IMAGE_PATH."""
        out.clear_output()
        # Do all the work inside the Output context. Exceptions raised in
        # widget callbacks are otherwise usually not shown in the notebook
        # (e.g. JupyterLab sends them to its log console), so a bad upload or
        # a missing file would appear to do nothing.
        with out:
            if uploader.value:
                # ipywidgets 7 exposes a dict {filename: {'content': ...}},
                # ipywidgets 8 a tuple of dicts that each carry 'content'.
                value = uploader.value
                item = list(value.values())[0] if isinstance(value, dict) else value[0]
                content = item['content'] if isinstance(item, dict) and 'content' in item else item
                img_arr = preprocess_image_from_bytes(content, TARGET_SIZE)
            elif IMAGE_PATH:
                img_arr = preprocess_image_from_path(IMAGE_PATH, TARGET_SIZE)
            else:
                print('No image uploaded or IMAGE_PATH set.')
                return
            mask, dt = infer_image(model, img_arr)
            show_results(img_arr, mask, dt)

    run_btn.on_click(on_run)
except Exception as e:
    print('Upload widget not available or failed to instantiate:', e)
    print('You can instead set IMAGE_PATH to a local file and run the cell below to perform inference.')

In [None]:
# Fallback (no widgets needed): run inference on the image at IMAGE_PATH,
# which is defined in the upload-widget cell above, and plot input, mask and overlay.
if not IMAGE_PATH:
    print('No IMAGE_PATH set. Use the upload widget above, or set IMAGE_PATH and re-run this cell.')
else:
    arr = preprocess_image_from_path(IMAGE_PATH, target_size=TARGET_SIZE)
    mask, dt = infer_image(model, img_arr=arr)
    show_results(img_arr=arr, mask=mask, inference_time=dt)