In [None]:
import os
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
import matplotlib.pyplot as plt
import numpy as np
import sys
sys.path.append(os.path.abspath("../src"))
from dataset import HairSegDataset, get_default_transforms

def visualize_sample_transparent(dataset, idx=0, alpha=0.5):
    """
    Show one dataset sample beside a copy with semi-transparent class overlays.

    Pixels whose mask value is 2 (hair) are tinted red and pixels whose mask
    value is 1 (skin) are tinted green; all other pixels are left unchanged.

    Args:
        dataset: Indexable object where ``dataset[idx]`` yields an
            ``(image, mask)`` pair. An image exposing ``permute`` is treated
            as a channel-first tensor and reordered to channel-last; anything
            else is used as-is (assumed to already be H x W x 3 -- TODO
            confirm against the dataset).
        idx: Index of the sample to display.
        alpha: Overlay opacity, intended to lie between 0 (invisible) and
            1 (opaque). Out-of-range values are not rejected; the blended
            result is clipped to the displayable 0-255 range.

    Returns:
        None. Draws a two-panel matplotlib figure and calls ``plt.show()``.
    """

    def _to_numpy(value):
        """Return *value* as a NumPy array, detaching/moving tensors first."""
        # ``.numpy()`` fails on tensors that require grad or live on another
        # device, so call ``detach``/``cpu`` whenever the object offers them.
        for method in ("detach", "cpu"):
            if hasattr(value, method):
                value = getattr(value, method)()
        return value.numpy() if hasattr(value, "numpy") else np.asarray(value)

    image, mask = dataset[idx]

    if hasattr(image, "permute"):  # channel-first tensor -> channel-last
        image = image.permute(1, 2, 0)
    image = _to_numpy(image).astype(np.float32)
    mask = _to_numpy(mask)

    # A mask carrying a singleton channel axis cannot boolean-index the
    # H x W x 3 image, so drop that axis.
    if mask.ndim == 3:
        if mask.shape[0] == 1:
            mask = mask[0]
        elif mask.shape[-1] == 1:
            mask = mask[..., 0]

    # Bring the image onto the 0-255 scale used by the overlay colours.
    low, high = float(image.min()), float(image.max())
    if low < 0.0 or high > 255.0:
        # Values outside any displayable range (presumably a mean/std
        # standardized tensor -- verify against the transforms): stretch
        # min..max onto 0..255 so the picture stays viewable.
        span = high - low
        if span > 0:
            image = (image - low) * (255.0 / span)
        else:
            image = np.zeros_like(image)
    elif high <= 1.0:
        image = image * 255.0
    image = np.clip(image, 0, 255).astype(np.uint8)

    # Blend in floating point and convert once at the end, so neither the
    # tint nor an out-of-range alpha can wrap around in uint8.
    blended = image.astype(np.float64)
    class_colors = (
        (2, (255, 0, 0)),  # hair -> red
        (1, (0, 255, 0)),  # skin -> green
    )
    for class_id, rgb in class_colors:
        region = mask == class_id
        blended[region] = (1 - alpha) * blended[region] + alpha * np.array(rgb)
    overlay = np.clip(blended, 0, 255).astype(np.uint8)

    plt.figure(figsize=(10, 5))
    plt.subplot(1, 2, 1)
    plt.title("Original Image")
    plt.imshow(image)
    plt.axis('off')

    plt.subplot(1, 2, 2)
    plt.title("Image with Semi-Transparent Overlay")
    plt.imshow(overlay)
    plt.axis('off')
    plt.show()


# Build the segmentation dataset from the image and mask folders, applying
# the project's default transforms to every sample.
hair_ds = HairSegDataset(
    "../data/images",
    "../data/masks",
    transform=get_default_transforms(),
)

# Render the overlay figure for samples 0 through 4.
for i in range(5):
    visualize_sample_transparent(hair_ds, idx=i, alpha=0.5)


  from .autonotebook import tqdm as notebook_tqdm


ModuleNotFoundError: No module named 'dataset'