# In [3]:
import numpy as np
import json
import matplotlib.pyplot as plt
from skimage.draw import disk
import os
from tensorflow.keras.preprocessing.image import load_img, img_to_array
from PIL import Image

def create_mask_from_json(json_data, shape):
    """Rasterize circular annotations into a binary float mask.

    Args:
        json_data: Iterable of annotation dicts. Each must provide ``'x'``
            (column of the circle centre), ``'y'`` (row of the centre) and
            ``'radius'``, all in pixel coordinates of the output mask.
        shape: ``(rows, cols)`` of the mask to produce. Circles that extend
            past the border are clipped to it.

    Returns:
        ``np.float32`` array of ``shape`` holding 1.0 inside any annotated
        circle and 0.0 everywhere else.
    """
    mask = np.zeros(shape, dtype=np.float32)
    for item in json_data:
        # disk() takes its centre as (row, col), i.e. (y, x). Passing `shape`
        # drops pixel indices that would fall outside the mask.
        rr, cc = disk((item['y'], item['x']), item['radius'], shape=shape)
        mask[rr, cc] = 1.0
    return mask

def load_images_and_labels(image_dir, label_dir, mask_shape=(1024, 1024)):
    """Load RGB images and build matching circle masks from JSON labels.

    Images (``*.png``) and labels (``*.json``) are paired purely by position
    after sorting each directory's filenames. The two directories must
    therefore hold the same number of such files, and their sorted names
    must correspond one-to-one.

    Args:
        image_dir: Directory containing the ``.png`` images.
        label_dir: Directory containing the ``.json`` annotation files. Each
            file's parsed content is passed to ``create_mask_from_json``.
        mask_shape: ``(rows, cols)`` of every generated mask. Defaults to
            ``(1024, 1024)``, the size that used to be hard-coded here.

    Returns:
        Tuple ``(images, masks)``:
        - ``images`` is an array of the loaded images divided by 255.
        - ``masks`` has shape ``(N, rows, cols, 1)``.

    Raises:
        ValueError: If the number of ``.png`` images and ``.json`` labels
            differ. Positional pairing would otherwise silently misalign them.
    """
    image_files = [os.path.join(image_dir, name)
                   for name in sorted(os.listdir(image_dir)) if name.endswith('.png')]
    label_files = [os.path.join(label_dir, name)
                   for name in sorted(os.listdir(label_dir)) if name.endswith('.json')]

    # zip() would silently drop the surplus and misalign every later pair.
    if len(image_files) != len(label_files):
        raise ValueError(
            f"Found {len(image_files)} .png images in {image_dir!r} but "
            f"{len(label_files)} .json labels in {label_dir!r}; they are paired "
            "by sorted filename order, so the counts must match."
        )

    images = []
    masks = []
    for image_file, label_file in zip(image_files, label_files):
        image = img_to_array(load_img(image_file, color_mode='rgb'))
        images.append(image / 255.0)  # Normalizing to [0, 1]

        # JSON is UTF-8 by spec; don't depend on the platform default encoding.
        with open(label_file, 'r', encoding='utf-8') as file:
            json_data = json.load(file)
        masks.append(create_mask_from_json(json_data, shape=mask_shape))

    # Append a trailing single-channel axis to the stacked masks.
    return np.array(images), np.array(masks).reshape(-1, *mask_shape, 1)

# Example run: build the image and mask arrays from the padded PNGs and
# their JSON circle annotations.
label_dir = './TMA_WSI_Labels_updated'
image_dir = './TMA_WSI_Padded_PNGs'
images, masks = load_images_and_labels(image_dir=image_dir, label_dir=label_dir)


# Save masks as images into the image_labels folder

def save_masks(images, masks, image_dir, label_dir):
    """Write each mask as an 8-bit PNG named after its source image.

    Mask ``i`` is saved under the name of the ``i``-th ``.png`` file in
    ``image_dir``, taken in sorted order. This is the same filtering and
    ordering ``load_images_and_labels`` uses, so each mask shares its
    filename with the image it was built for.

    Args:
        images: Unused; kept so existing callers keep working.
        masks: Iterable of masks, expected to hold values in [0, 1].
            Singleton axes (e.g. a trailing channel axis) are squeezed away
            before saving.
        image_dir: Directory of the source ``.png`` images. Only the
            filenames are read.
        label_dir: Output directory. Created (with parents) if missing.

    Raises:
        ValueError: If there are more masks than ``.png`` files in
            ``image_dir``, leaving a mask with no filename to use.
    """
    os.makedirs(label_dir, exist_ok=True)

    # os.listdir order is arbitrary and includes non-image entries, so filter
    # and sort once (outside the loop) to reproduce the loader's image order.
    image_names = sorted(name for name in os.listdir(image_dir) if name.endswith('.png'))

    for index, mask in enumerate(masks):
        if index >= len(image_names):
            raise ValueError(
                f"More masks than .png images in {image_dir!r} "
                f"({len(image_names)}); cannot name mask {index}."
            )
        # Scale 0/1 values to 0/255 so the mask is viewable as a grayscale PNG.
        mask_image = Image.fromarray((mask * 255.0).astype(np.uint8).squeeze())
        mask_image.save(os.path.join(label_dir, image_names[index]))


# Write every mask as a PNG into ./image_labels, named after its image.
save_masks(images=images, masks=masks, image_dir=image_dir, label_dir='./image_labels')