# In [3]:
# 📦 IMPORTS
# =====================================
import os
import cv2
import torch
import numpy as np
import ipywidgets as widgets
import matplotlib.pyplot as plt

from omegaconf import OmegaConf
from data import create_dataset
from train import binary_fill_holes
from models import MAPSeg
from timm import create_model
from sklearn.metrics import average_precision_score
from skimage.filters import threshold_otsu
from IPython.display import display, HTML, clear_output

# =====================================
# 🎨 UI STYLE
# =====================================
custom_color = '#d88050'  # colon-friendly color

# Inject page-level CSS that targets the ipywidgets classes used below
# (labels, dropdowns, buttons) so they share the accent colour.
display(HTML(f"""
<style>
    .widget-label {{
        min-width: 140px !important;
        font-weight: bold;
        color: {custom_color};
    }}
    h2 {{
        color: {custom_color};
    }}
    .widget-dropdown, .widget-button {{
        border: 1px solid {custom_color} !important;
    }}
    .widget-button {{
        background-color: {custom_color} !important;
        color: white !important;
    }}
</style>
"""))

# =====================================
# ⚙️ CONFIGURATION
# =====================================
# Check for the file *before* loading it. Previously OmegaConf.load ran first,
# so a missing file failed there and this explicit error was never reached.
if not os.path.exists('./configs.yaml'):
    raise FileNotFoundError("configs.yaml not found!")
cfg = OmegaConf.load('./configs.yaml')

def post_processing_mask(mask_pred, kernel_size=5):
    """Clean a predicted binary mask and keep only its largest connected region.

    Pipeline:
    1. Binarise (any value > 0 is foreground).
    2. Morphological closing with a ``kernel_size`` x ``kernel_size`` square kernel.
    3. Hole filling.
    4. Keep the single largest external contour, drawn filled.

    Args:
        mask_pred: 2-D array-like mask; values > 0 are treated as foreground.
        kernel_size: side length of the square structuring element used for
            the closing. Defaults to 5, the previously hard-coded value.

    Returns:
        ``np.uint8`` array with the same shape as ``mask_pred``: 255 inside the
        largest region, 0 elsewhere. If no contour survives the clean-up, the
        (all-zero) hole-filled mask is returned.
    """
    # Always normalise to {0, 255}. The previous version skipped this step for
    # uint8 input, so a uint8 {0, 1} mask was wiped out by ``// 255`` below.
    binary = (np.asarray(mask_pred) > 0).astype(np.uint8) * 255

    kernel = np.ones((kernel_size, kernel_size), np.uint8)
    closed = cv2.morphologyEx(binary, cv2.MORPH_CLOSE, kernel)
    # The hole-filling helper receives a {0, 1} uint8 mask, exactly as before.
    filled_image = binary_fill_holes(closed // 255)
    filled_image = (filled_image * 255).astype(np.uint8)

    contours, _ = cv2.findContours(filled_image, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if not contours:
        return filled_image

    # max() returns the first contour with the largest area, which is the
    # same pick the old stable sort + [:1] made.
    largest = max(contours, key=cv2.contourArea)
    largest_mask = np.zeros_like(binary)
    cv2.drawContours(largest_mask, [largest], -1, 255, thickness=cv2.FILLED)
    return largest_mask



def find_best_threshold(prob_map, mask_gt, fallback=0.5):
    """Return a binarisation threshold for ``prob_map`` using Otsu's method.

    The previous implementation did two extra things:
    - swept nine thresholds, scoring each with ``average_precision_score``
      against ``mask_gt``;
    - computed a 95th-percentile threshold.

    It discarded both and always returned the Otsu value. That dead work (nine
    full-image metric evaluations per call) has been removed; the returned
    value is unchanged.

    Args:
        prob_map: array of per-pixel foreground probabilities (any shape).
        mask_gt: ground-truth mask. Kept for backward compatibility only; it
            does not influence the returned threshold.
        fallback: value returned when ``threshold_otsu`` raises ``ValueError``.
            Defaults to 0.5, as before.

    Returns:
        The threshold as a scalar.
    """
    values = np.asarray(prob_map).ravel()
    try:
        return threshold_otsu(values)
    except ValueError:
        return fallback

def minmax_scaling(img):
    """Linearly rescale an image to the ``uint8`` range for display.

    Args:
        img: NumPy array, array-like, or ``torch.Tensor``. Tensors are detached
            and moved to CPU first. A 3-D input whose first axis has length 3
            is treated as channels-first (C, H, W) and transposed to (H, W, C).
            Any other input keeps its layout.

    Returns:
        ``np.uint8`` array. A constant image maps to all zeros; the ``1e-8``
        term avoids division by zero. Because of that term and the truncating
        cast, the maximum typically lands on 254 rather than 255.
    """
    if isinstance(img, torch.Tensor):
        img = img.detach().cpu().numpy()
    img = np.asarray(img)
    # The ndim guard matters: a 2-D image with 3 rows used to reach
    # np.transpose with three axes and raise.
    if img.ndim == 3 and img.shape[0] == 3:
        img = np.transpose(img, (1, 2, 0))
    img = (img - img.min()) / (img.max() - img.min() + 1e-8)
    return (img * 255).astype(np.uint8)

# =====================================
# 📥 MODELLO
# =====================================
def load_model(target_name, model_dir='./saved_model/MAPSeg-polyp'):
    """Build the evaluation dataset and load the trained MAPSeg model on CPU.

    Rebinds the module-level globals ``testset`` and ``model`` (read by
    ``result_plot`` and the UI callback) instead of returning them.

    Args:
        target_name: dataset target forwarded to ``create_dataset``.
        model_dir: directory holding ``memory_bank.pt`` and ``best_model.pt``.
            Defaults to the previously hard-coded polyp checkpoint directory.
    """
    global model
    global testset

    testset = create_dataset(
        datadir=cfg.DATASET.datadir,
        target=target_name,
        is_train=False,
        resize=cfg.DATASET.resize,
        imagesize=cfg.DATASET.get("imagesize"),
        texture_source_dir=cfg.DATASET.texture_source_dir,
        structure_grid_size=cfg.DATASET.structure_grid_size,
        transparency_range=cfg.DATASET.transparency_range,
        perlin_scale=cfg.DATASET.perlin_scale,
        min_perlin_scale=cfg.DATASET.min_perlin_scale,
        perlin_noise_threshold=cfg.DATASET.perlin_noise_threshold,
        bg_threshold=cfg.DATASET.get("bg_threshold"),
        anomaly_type=cfg.DATASET.get("anomaly_type"),
        seed=cfg.DATASET.get("seed", 42),
    )

    # memory_bank.pt holds a pickled Python object (its attributes are used
    # below), not a plain state_dict. It therefore needs full unpickling:
    # weights_only=False is passed explicitly because torch >= 2.6 defaults it
    # to True. Full unpickling can execute arbitrary code, so only load
    # checkpoints you trust.
    # map_location='cpu' lets a bank saved from a GPU load on a CPU-only host.
    memory_bank = torch.load(
        os.path.join(model_dir, 'memory_bank.pt'),
        map_location='cpu',
        weights_only=False,
    )
    memory_bank.device = 'cpu'
    # Ensure every stored tensor is on CPU (a no-op if map_location already moved it).
    for k in memory_bank.memory_information.keys():
        memory_bank.memory_information[k] = memory_bank.memory_information[k].cpu()

    feature_extractor = create_model(
        cfg.MODEL.feature_extractor_name,
        pretrained=True,
        features_only=True
    )

    model = MAPSeg(
        memory_bank=memory_bank,
        feature_extractor=feature_extractor
    )

    state_dict = torch.load(os.path.join(model_dir, 'best_model.pt'), map_location='cpu')
    model.load_state_dict(state_dict)
    model.eval()

# =====================================
# 🖼️ PLOT DEI RISULTATI
# =====================================
from IPython.display import display, HTML
import io
from PIL import Image
def image_to_base64(img):
    """Serialize ``img`` as PNG and return those bytes as a base64 string.

    Args:
        img: any object exposing ``save(fp, format=...)``, e.g. a
            ``PIL.Image.Image``.

    Returns:
        str: the base64 text of the PNG-encoded image.
    """
    with io.BytesIO() as png_stream:
        img.save(png_stream, format='PNG')
        raw_png = png_stream.getvalue()
    return base64.b64encode(raw_png).decode()

import io
import base64
from IPython.display import display, HTML

def _predict_probability_map(input_i):
    """Run ``model`` on one image tensor and return the class-1 softmax map as NumPy."""
    # Inference only: no_grad avoids building an autograd graph on every call.
    with torch.no_grad():
        logits = model(input_i.unsqueeze(0))
    probs = torch.nn.functional.softmax(logits, dim=1)
    return probs[0, 1].cpu().numpy()


def _colorize_heatmap(prob_map):
    """Min-max normalise ``prob_map`` to uint8 and colour it with the JET map, returned as RGB."""
    heatmap = (prob_map - prob_map.min()) / (prob_map.max() - prob_map.min() + 1e-8)
    heatmap = (heatmap * 255).astype(np.uint8)
    heatmap_color = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
    # OpenCV produces BGR; matplotlib's imshow expects RGB.
    return cv2.cvtColor(heatmap_color, cv2.COLOR_BGR2RGB)


def _figure_to_base64_png(fig):
    """Render ``fig`` as a transparent PNG, close it, and return the PNG as base64 text."""
    with io.BytesIO() as buf:
        fig.savefig(buf, format="png", bbox_inches='tight', transparent=True)
        image_base64 = base64.b64encode(buf.getvalue()).decode("utf-8")
    # The figure is shown only through the HTML below; close it to release it.
    plt.close(fig)
    return image_base64


def result_plot(idx):
    """Predict on test image ``idx`` and display a 5-panel comparison as HTML.

    Panels:
    - the input image;
    - the ground-truth mask;
    - the probability heatmap;
    - the predicted mask (thresholded, then post-processed);
    - the prediction overlaid on the input.

    Uses the module globals ``testset`` and ``model`` set by ``load_model``.

    Args:
        idx: index into ``testset`` (supplied by the image dropdown).
    """
    input_i, mask_i, _, _ = testset[idx]
    prob_map = _predict_probability_map(input_i)
    gt_mask = mask_i.squeeze().cpu().numpy()

    try:
        threshold = find_best_threshold(prob_map, gt_mask)
    except Exception:
        # Best-effort: fall back to a fixed threshold so the plot still renders.
        threshold = 0.3

    binary_pred = (prob_map > threshold).astype(int)
    binary_pred = post_processing_mask(binary_pred)

    # Scale the input once; it is reused for the plain view and the overlay.
    input_rgb = minmax_scaling(input_i)

    # Build the plot
    fig, ax = plt.subplots(1, 5, figsize=(15, 4))
    ax[0].imshow(input_rgb)
    ax[0].set_title('Input Image', fontsize=25)

    ax[1].imshow(gt_mask, cmap='gray')
    ax[1].set_title('Ground Truth', fontsize=25)

    ax[2].imshow(_colorize_heatmap(prob_map))
    ax[2].set_title('Heatmap', fontsize=25)

    ax[3].imshow(binary_pred, cmap='gray')
    ax[3].set_title('Predicted Mask', fontsize=25)

    ax[4].imshow(input_rgb)
    ax[4].imshow(binary_pred, cmap='gray', alpha=0.5)
    ax[4].set_title('Overlay', fontsize=25)

    for a in ax:
        a.axis('off')
    plt.tight_layout()

    image_base64 = _figure_to_base64_png(fig)

    # Show the figure inside a light-orange box
    display(HTML(f"""
        <div style="background-color: #fff4eb; padding: 20px; border-radius: 10px; border: 1px solid #e3c2a4; margin-top: 15px;">
            <img src="data:image/png;base64,{image_base64}" style="max-width:100%;">
        </div>
    """))

# =====================================
# 🧠 UI COMPONENTS
# =====================================
# Output area where the click callback renders the image picker and plots.
output = widgets.Output()

# Full-width button that triggers model loading.
load_button = widgets.Button(
    description="Load Model",
    style={'font_size': '24px'},
    layout=widgets.Layout(height='40px', width='1650px'),
)

@output.capture()
def on_button_clicked(b):
    """Load the polyp model, then show an image picker wired to ``result_plot``.

    Args:
        b: the ``widgets.Button`` that was clicked (unused; required by ``on_click``).
    """
    with output:
        clear_output(wait=True)
        try:
            load_model("polyp")
        except Exception as e:
            # Report the failure. Previously the callback returned silently,
            # leaving an empty panel and no hint of what went wrong.
            print(f"Failed to load model: {e!r}")
            return

        # A Dropdown with value=0 and no options is invalid, so bail out early.
        if len(testset.file_list) == 0:
            print("The test set contains no images.")
            return

        file_list = widgets.Dropdown(
            options=[(os.path.basename(p), i) for i, p in enumerate(testset.file_list)],
            value=0,
            description='Image:',
            layout=widgets.Layout(width='250px')
        )

        header = widgets.HTML(f"<b style='color:{custom_color};'>Choose image to visualize</b>")
        # Re-runs result_plot(idx=<selected index>) whenever the selection changes.
        interact_widget = widgets.interactive_output(result_plot, {'idx': file_list})

        display(widgets.VBox([header, file_list, interact_widget]))

load_button.on_click(on_button_clicked)

# =====================================
# 🖼️ TITLE AND FINAL UI
# =====================================
logo_path = "assets/mapseg_logo.png"

# Embed the logo as a data URI. A relative ``src`` inside HTML output is
# resolved differently (or not at all) by different notebook front-ends.
# Fall back to the plain path when the file is not present.
if os.path.exists(logo_path):
    with open(logo_path, "rb") as logo_file:
        logo_src = "data:image/png;base64," + base64.b64encode(logo_file.read()).decode("ascii")
else:
    logo_src = logo_path

display(HTML(f"""
    <div style="display: flex; justify-content: center; align-items: center; gap: 25px; margin-bottom: 40px;">
        <img src="{logo_src}" alt="MAPSeg Logo" style="width: 90px;">
        <span style="font-size: 100px; font-weight: bold; color: {custom_color}; font-family: sans-serif;">
            MAPSeg
        </span>
    </div>
"""))

display(widgets.VBox([
    load_button,
    widgets.HTML("<hr>"),
    output
]))


# Output: VBox(children=(Button(description='Load Model', layout=Layout(height='40px', width='1100px'), style=ButtonStyl…