In [None]:
from dataloader_points import *
from PIL import Image
import numpy as np
x_train, x_test, y_train_mask, y_test_mask = load_ph2_dataset("data/PH2_Dataset_images")

# Inline preview: the first train/test image next to its mask, in a 2x2 grid.
import matplotlib.pyplot as plt

preview_panels = [
    (x_train[0], "Train image [0]"),
    (y_train_mask[0], "Train mask [0]"),
    (x_test[0], "Test image [0]"),
    (y_test_mask[0], "Test mask [0]"),
]
fig, axes = plt.subplots(2, 2)
for ax, (panel_path, panel_title) in zip(axes.flat, preview_panels):
    # copy() loads the pixels into memory so the file handle can be closed here.
    with Image.open(panel_path) as img:
        ax.imshow(img.copy())
    ax.set_title(panel_title)
plt.show()

In [None]:
import numpy as np
def mask_to_points(mask_path, correct_points=10, incorrect_points=5, rng=None):
    """Randomly sample (x, y) points from the white and black regions of ONE mask.

    The mask is converted to 8-bit grayscale. Pixels with value > 128 count as
    white and all other pixels count as black.

    Parameters
    ----------
    mask_path : path-like, file-like, or np.ndarray
        Anything ``PIL.Image.open`` accepts for a single mask image, or an
        already-loaded 2-D grayscale mask array (values compared against 128).
    correct_points : int
        Maximum number of points to draw from white pixels. It is clamped to the
        number of white pixels, so small masks no longer raise ValueError.
    incorrect_points : int
        Maximum number of points to draw from black pixels. It is clamped to the
        number of black pixels.
    rng : np.random.Generator | np.random.RandomState | None
        Source of randomness. ``None`` uses the global ``np.random`` state,
        which was the original behaviour. Pass a seeded generator for
        reproducible samples.

    Returns
    -------
    (correct, incorrect) : tuple[np.ndarray, np.ndarray]
        Integer arrays of shape ``(k, 2)`` holding ``(x, y)`` = (column, row)
        coordinates, sampled without replacement. If the mask has no white
        pixels, both arrays are empty with shape ``(0, 2)``.

    Raises
    ------
    ValueError
        If an ndarray mask is not 2-D.
    """
    if rng is None:
        rng = np.random  # module-level choice() == original global-state sampling

    if isinstance(mask_path, np.ndarray):
        mask = mask_path
        if mask.ndim != 2:
            raise ValueError(f"expected a 2-D grayscale mask array, got shape {mask.shape}")
    else:
        # Close the file handle once the grayscale pixels are in memory.
        with Image.open(mask_path) as mask_img:
            mask = np.array(mask_img.convert('L'))

    fg_rows, fg_cols = np.where(mask > 128)   # white pixels
    bg_rows, bg_cols = np.where(mask <= 128)  # black pixels

    if fg_cols.size == 0:
        # No white pixels: keep the documented (num_points, 2) shape even when empty.
        empty = np.empty((0, 2), dtype=np.intp)
        return empty, empty.copy()

    # Clamp both requests so sampling without replacement cannot exceed the population.
    n_fg = min(correct_points, fg_cols.size)
    n_bg = min(incorrect_points, bg_cols.size)
    # Draw white first, then black (same order as before, so seeded runs match).
    fg_idx = rng.choice(fg_cols.size, n_fg, replace=False)
    bg_idx = rng.choice(bg_cols.size, n_bg, replace=False)

    # column_stack yields shape (k, 2) even for k == 0, unlike np.array(list(zip(...))).
    correct = np.column_stack((fg_cols[fg_idx], fg_rows[fg_idx]))
    incorrect = np.column_stack((bg_cols[bg_idx], bg_rows[bg_idx]))
    return correct, incorrect

# Example usage: sample points from the first training mask and overlay them on
# the matching image (green = sampled from white pixels, red = from black pixels).
correct, incorrect = mask_to_points(y_train_mask[0], correct_points=10, incorrect_points=5)

fig, ax = plt.subplots()
# copy() loads the pixels into memory so the image file can be closed right away.
with Image.open(x_train[0]) as img:
    ax.imshow(img.copy())
# Guards stay: mask_to_points can return empty arrays (e.g. no white pixels).
if len(correct) > 0:
    ax.scatter(correct[:, 0], correct[:, 1], c='green', s=10, label='white-region points')
if len(incorrect) > 0:
    ax.scatter(incorrect[:, 0], incorrect[:, 1], c='red', s=10, label='black-region points')
ax.set_title("Points sampled from y_train_mask[0] over x_train[0]")
if len(correct) > 0 or len(incorrect) > 0:
    ax.legend()  # only when something was plotted, to avoid an empty-legend warning
plt.show()

In [None]:
from dataloader_points import PointSegmentationDataset

In [None]:
ph2_train_imgs, ph2_test_imgs, ph2_train_masks, ph2_test_masks = load_ph2_dataset("data/PH2_Dataset_images")
img_size = (256, 256)

# Images and masks are paired by list position. If the counts differ, some
# pairs may be misaligned, so warn before truncating to the shorter list
# instead of dropping the extras silently.
import warnings

for split_name, split_imgs, split_masks in (
    ("train", ph2_train_imgs, ph2_train_masks),
    ("test", ph2_test_imgs, ph2_test_masks),
):
    if len(split_imgs) != len(split_masks):
        warnings.warn(
            f"PH2 {split_name} split: {len(split_imgs)} images vs {len(split_masks)} masks; "
            "truncating to the shorter list"
        )

n_train = min(len(ph2_train_imgs), len(ph2_train_masks))
n_test = min(len(ph2_test_imgs), len(ph2_test_masks))
ph2_train_imgs, ph2_train_masks = ph2_train_imgs[:n_train], ph2_train_masks[:n_train]
ph2_test_imgs, ph2_test_masks = ph2_test_imgs[:n_test], ph2_test_masks[:n_test]

# Resize to img_size, then ToTensor.
# NOTE(review): `transforms` is never imported explicitly in this notebook;
# presumably it leaks in via `from dataloader_points import *` — TODO import it directly.
t = transforms.Compose([transforms.Resize(img_size), transforms.ToTensor()])

# Number of correct / incorrect points requested from the dataset per sample.
correct_points = 5
incorrect_points = 6

train_ds = PointSegmentationDataset(
    ph2_train_imgs,
    ph2_train_masks,
    t,
    correct_points=correct_points,
    incorrect_points=incorrect_points,
    method='intensity'
)

In [None]:
# Inspect the first training sample; indexing invokes __getitem__ idiomatically.
train_ds[0]

In [None]:
idx = 3
fig, axes = plt.subplots(1, 3, figsize=(15, 5))

# Fetch the sample ONCE rather than once per field. Separate train_ds[idx] calls
# repeat the work and, if __getitem__ samples points or applies a random
# transform, can mix image/mask/points/labels from different draws.
sample_img, sample_mask, points, labels = train_ds[idx]
im = sample_img.permute(1, 2, 0)    # move dim 0 last for imshow (presumably CHW -> HWC)
mask = sample_mask.permute(1, 2, 0)

# Split points by label (1 = correct, 0 = incorrect), matching the
# labels == 1 / labels == 0 split used elsewhere in this notebook, instead of
# assuming the first `correct_points` entries are the correct ones.
inside = labels == 1
outside = labels == 0

axes[0].imshow(im)
axes[0].set_title('Image with Points')

axes[1].imshow(mask.squeeze(2), cmap='gray')
axes[1].set_title('Mask')

axes[2].imshow(im * mask)
axes[2].set_title('Masked Image with Points')

# Overlay the same points on the raw and the masked image panels.
for ax in (axes[0], axes[2]):
    ax.scatter(points[inside, 0], points[inside, 1], c='green', s=10, label='Correct points')
    ax.scatter(points[outside, 0], points[outside, 1], c='red', s=10, label='Incorrect points')
    ax.legend()

plt.tight_layout()
plt.show()


In [None]:
# Display the point labels of the sample fetched in the visualisation cell above (train_ds[idx]).
labels

In [None]:
# Shape of the first sample's mask after dropping dim 0 (presumably a singleton channel dim — confirm).
train_ds[0][1].squeeze(0).shape

In [None]:
import matplotlib.pyplot as plt

sample_idx = 3
img, mask, points, labels = train_ds[sample_idx]
mask_np = mask.squeeze().numpy()

# Same colour convention as the other point plots in this notebook:
# green = label 1 (inside), red = label 0 (outside).
inside = labels == 1
outside = labels == 0

fig, ax = plt.subplots()
ax.imshow(mask_np, cmap='gray')
ax.scatter(points[inside, 0], points[inside, 1], c='green', s=10, label='inside')
ax.scatter(points[outside, 0], points[outside, 1], c='red', s=10, label='outside')
ax.set_title(f'Sampled points over mask of train_ds[{sample_idx}]')
ax.legend()
plt.show()


In [None]:
# Number of points in the most recently fetched sample (presumably correct_points + incorrect_points — confirm).
len(points)