# In [3]:
def GetClassesFromImage(img, image_dir):
    """
    Read a label image and return the distinct per-pixel class labels it contains.

    Each pixel's label is encoded in its colour as
    ``red channel + 256 * green channel`` (the blue channel is not used).
    The unique non-zero values are the labels attached to the image; label 0
    has no class attached and is removed.

    Args:
    img: image filename
    image_dir: path of the directory containing the photos

    returns: sorted numpy array (dtype uint16) of the non-zero class labels
    """
    from PIL import Image
    import numpy as np
    import os

    # Use a context manager so the file handle is closed promptly instead of
    # lingering until garbage collection (relevant when called over many files).
    with Image.open(os.path.join(image_dir, img)) as image:
        # split() on an RGBA image yields four bands and cannot be unpacked into
        # (r, g, b). Normalising to RGB drops alpha / expands palettes while
        # keeping the R and G values this encoding relies on.
        # NOTE(review): a grayscale input would become R == G == gray, which does
        # not carry this encoding meaningfully — confirm inputs are colour label maps.
        rgb = np.asarray(image.convert("RGB"), dtype=np.uint16)

    # uint16 is wide enough: the largest possible value is 255 + 256 * 255 == 65535.
    px_labels = rgb[..., 0] + 256 * rgb[..., 1]
    classes = np.unique(px_labels)

    # Class 0 does not correspond to a label, so drop it.
    return classes[classes != 0]