# Color Threshold, Green Screen

### Import resources

In [None]:
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from sklearn.cluster import DBSCAN
from skimage.util import random_noise

import random
import numpy as np
import cv2

%matplotlib inline

In [None]:
# Path of the demo image loaded by the exploratory cells below
IMG_FILE = 'images/demo.png'

### Read in and display the image

In [None]:
def read_image(image_file):
    """Load an image from disk and return it in RGB channel order.

    Args:
        image_file: Path of the image file to load.

    Returns:
        The array returned by ``cv2.imread`` converted from BGR to RGB.

    Raises:
        FileNotFoundError: If OpenCV could not read ``image_file``.
    """
    image = cv2.imread(image_file)
    # cv2.imread signals failure by returning None instead of raising,
    # which would otherwise surface as an obscure cvtColor error.
    if image is None:
        raise FileNotFoundError(f"Could not read image file: {image_file}")
    return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

# Load the demo image (RGB channel order) for the cells below
image = read_image(IMG_FILE)

In [None]:
# Display the loaded image
plt.imshow(image)

### Apply K-Means to the image

In [None]:
def apply_kmeans(image, k=2):
    """Quantize an image to ``k`` colors with OpenCV k-means.

    Args:
        image: Array whose values can be grouped into 3-component pixels
            (it is reshaped to ``(-1, 3)``).
        k: Number of color clusters.

    Returns:
        An array with the same shape as ``image`` in which every pixel is
        replaced by the center of its cluster, cast to ``uint8``.
    """
    # Flatten to one row per pixel; the samples are passed to cv2.kmeans as float32
    pixel_vals = np.float32(image.reshape((-1, 3)))

    # Stop after 100 iterations or once the requested accuracy (0.2) is reached
    criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 100, 0.2)

    # 10 attempts, each starting from random centers
    _, labels, centers = cv2.kmeans(pixel_vals, k, None, criteria, 10, cv2.KMEANS_RANDOM_CENTERS)

    # Map every pixel to the 8-bit color of its cluster center
    centers = np.uint8(centers)
    segmented_data = centers[labels.flatten()]

    # Restore the original image dimensions
    return segmented_data.reshape(image.shape)

In [None]:
def get_gray_image(image):
    """Convert an RGB image to a single-channel grayscale image.

    Args:
        image: 3-channel image in RGB channel order (the order produced by
            ``read_image``).

    Returns:
        The grayscale image produced by ``cv2.cvtColor``.
    """
    # Images in this notebook are RGB, so the RGB conversion code is required;
    # the BGR code would swap the red and blue luminance weights.
    return cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)

def get_noisy_image(image, mode='s&p', amount=0.3):
    """Return a noisy 8-bit version of ``image``.

    Args:
        image: Input image passed to ``skimage.util.random_noise``.
        mode: Noise mode forwarded to ``random_noise``.
        amount: Noise amount forwarded to ``random_noise``.

    Returns:
        The output of ``random_noise`` multiplied by 255 and cast to ``uint8``.
    """
    noisy = random_noise(image, mode=mode, amount=amount)
    scaled = noisy * 255
    return scaled.astype(np.uint8)

In [None]:
# Quantize the demo image to 3 colors and show the result
segmented_image = apply_kmeans(image, k=3)
plt.imshow(segmented_image)

In [None]:
# Convert the quantized image to grayscale and show it
gray_image = get_gray_image(segmented_image)
plt.imshow(gray_image, cmap="gray")

### Define the color threshold

In [None]:
def get_mask(image, val_min=200, lower_bound=None, upper_bound=None):
    """Build a range mask for ``image`` with ``cv2.inRange``.

    Args:
        image: Image to threshold.
        val_min: Used only when ``lower_bound`` is None; the lower bound then
            becomes ``val_min / 255``.
        lower_bound: Lower bound array; defaults to ``[val_min / 255]``.
        upper_bound: Upper bound array; defaults to ``[255 / 255]``.

    Returns:
        The mask returned by ``cv2.inRange``.
    """
    low = np.array([val_min/255]) if lower_bound is None else lower_bound
    high = np.array([255/255]) if upper_bound is None else upper_bound
    return cv2.inRange(image, low, high)

### Create a mask

In [None]:
# Visualize the mask: the grayscale image is normalized by its maximum and
# thresholded between 225/255 and 1
mask = get_mask(
    gray_image / gray_image.max(),
    lower_bound=np.array([225/255]),
    upper_bound=np.array([255/255])
)
plt.imshow(mask, cmap='gray')

In [None]:
def get_clusters(mask, val_min=600, eps=30, min_samples=100):
    """Cluster the coordinates of the mask pixels brighter than ``val_min``.

    Args:
        mask: Array whose entries are compared against ``val_min``.
        val_min: Only positions with ``mask > val_min`` are clustered.
        eps: ``eps`` parameter forwarded to DBSCAN.
        min_samples: ``min_samples`` parameter forwarded to DBSCAN.

    Returns:
        The fitted ``DBSCAN`` estimator.
    """
    # NOTE(review): the default val_min=600 is above the 8-bit range; with a
    # 0/255 mask no pixel would pass it - confirm callers always override it.
    coordinates = np.argwhere(mask > val_min)
    model = DBSCAN(eps=eps, min_samples=min_samples)
    return model.fit(coordinates)

def get_squares(data, clustering, shape, cluster_threshold=1500):
    """Compute the bounding boxes of all sufficiently large clusters.

    Args:
        data: ``(n, 2)`` integer array of ``(row, col)`` coordinates, aligned
            with ``clustering.labels_``.
        clustering: Fitted clustering object exposing ``labels_``.
        shape: Image shape. Not needed for the computation; kept so existing
            callers keep working.
        cluster_threshold: A cluster is kept only when it contains more than
            this many points.

    Returns:
        A list of ``(row_min, row_max, col_min, col_max)`` tuples of ints, in
        ascending cluster-label order.
    """
    labels = np.asarray(clustering.labels_)
    squares = []
    for label in np.unique(labels):
        # DBSCAN uses -1 for noise points; they do not form a cluster
        if label == -1:
            continue
        cluster = data[labels == label]
        if cluster.shape[0] <= cluster_threshold:
            continue
        rows, cols = cluster[:, 0], cluster[:, 1]
        squares.append((int(rows.min()), int(rows.max()), int(cols.min()), int(cols.max())))
    return squares
                   
# Mask value above which a pixel is treated as "light"
val_min = 200
clustering = get_clusters(mask, val_min=val_min, eps=30)
light_pixels = np.argwhere(mask > val_min)
squares = get_squares(light_pixels, clustering, image.shape)

# plt.subplots cannot build an empty grid, so only plot when something was found
if squares:
    # squeeze=False keeps `ax` a 2-D array even when there is a single square
    fig, ax = plt.subplots(len(squares), squeeze=False)

    # NOTE(review): this assumes squares[i] belongs to cluster label i, which
    # only holds when no cluster was dropped by get_squares - confirm.
    for i, square in enumerate(squares):
        cluster = light_pixels[clustering.labels_ == i]
        # Paint the pixels of this cluster onto an empty canvas
        cl = np.zeros((image.shape[0], image.shape[1]))
        cl[cluster[:, 0], cluster[:, 1]] = 1
        ax[i, 0].imshow(cl)

In [None]:
def plot_squares(image, squares):
    """Show the image crop of every bounding box side by side.

    Args:
        image: Image array indexed as ``image[row, col]``.
        squares: Sequence of ``(row_min, row_max, col_min, col_max)`` boxes.
    """
    if not squares:
        # plt.subplots cannot create a grid with zero columns
        return
    # squeeze=False keeps `axes` 2-D so indexing also works for a single box
    _, axes = plt.subplots(1, len(squares), squeeze=False)
    for axis, (y0, y1, x0, x1) in zip(axes[0], squares):
        axis.imshow(image[y0:y1, x0:x1])
    
# Show the crops of the demo image for the boxes found above
plot_squares(image, squares)

In [None]:
def save_squares(image, squares):
    """Binarize every bounding-box crop and write it to ``images/square_<i>.jpg``.

    Each crop is quantized to two colors, converted to grayscale, normalized
    by its maximum and thresholded between 225/255 and 1.

    Args:
        image: Image array indexed as ``image[row, col]``.
        squares: Sequence of ``(row_min, row_max, col_min, col_max)`` boxes.
    """
    for i, (y0, y1, x0, x1) in enumerate(squares):
        crop = image[y0:y1, x0:x1]
        gray = get_gray_image(apply_kmeans(crop, k=2))
        mask = get_mask(
            gray / gray.max(),
            lower_bound=np.array([225/255]),
            upper_bound=np.array([255/255])
        )
        # The mask is single-channel, so it is written as-is: an RGB->BGR
        # conversion needs a 3- or 4-channel input and would fail here.
        cv2.imwrite(f"images/square_{i}.jpg", mask)
        
# Write the binarized crops of the demo image to the images/ directory
save_squares(image, squares)

In [None]:
import random

def random_color_switch(image, k=2):
    """Recolor an image with two random colors, split by a brightness mask.

    The image is quantized with ``apply_kmeans``, converted to grayscale,
    normalized by its maximum and thresholded between 225/255 and 1. Pixels
    where the mask is set get one random RGB color, all other pixels another.

    Args:
        image: Input image; it is copied, not modified.
        k: Number of clusters used for the quantization.

    Returns:
        The recolored copy of ``image``.
    """
    segmented = apply_kmeans(image, k=k)
    gray = get_gray_image(segmented)
    normalized = gray / gray.max()
    mask = get_mask(
        normalized,
        lower_bound=np.array([225/255]),
        upper_bound=np.array([255/255])
    )

    recolored = image.copy()
    # Colors are drawn in the same order as the assignments below
    inside_color = [random.randint(0, 255) for _ in range(3)]
    recolored[mask > 0] = inside_color
    outside_color = [random.randint(0, 255) for _ in range(3)]
    recolored[mask < 1] = outside_color
    return recolored

In [None]:
def draw_squares(image, squares):
    """Draw a green rectangle (thickness 5) around every bounding box.

    Args:
        image: Image to draw on.
        squares: Sequence of ``(row_min, row_max, col_min, col_max)`` boxes.

    Returns:
        The image returned by the last ``cv2.rectangle`` call, or ``image``
        itself when ``squares`` is empty.
    """
    # NOTE(review): cv2.rectangle presumably draws on the passed array in
    # place, so the caller's image is modified as well - confirm.
    for y0, y1, x0, x1 in squares:
        top_left = (x0, y0)
        bottom_right = (x1, y1)
        image = cv2.rectangle(image, top_left, bottom_right, (0,255,0), 5)
    return image

In [None]:
def annotate_image(image_file, lower_bound=np.array([150/255]), eps=10, min_samples=100, k=2):
    """Detect the bounding boxes of the bright regions of an image file.

    The image is quantized with k-means, converted to grayscale, thresholded
    relative to its brightest pixel, and the bright pixels are clustered with
    DBSCAN.

    Args:
        image_file: Path of the image to analyze.
        lower_bound: One-element array holding the relative brightness
            threshold; it is only read, never modified.
        eps: ``eps`` parameter forwarded to DBSCAN.
        min_samples: ``min_samples`` parameter forwarded to DBSCAN.
        k: Number of k-means color clusters.

    Returns:
        A list of ``(row_min, row_max, col_min, col_max)`` boxes; empty when
        no pixel passes the threshold.
    """
    image = read_image(image_file)
    # Bring images whose values do not exceed 1 up to the 0-255 range
    if image.max() <= 1.0:
        image = image * 255
    segmented_image = apply_kmeans(image, k=k)
    gray_image = get_gray_image(segmented_image)
    mask = get_mask(gray_image / gray_image.max(), lower_bound=lower_bound)

    # The same threshold, rescaled by 255 for comparison against the mask
    threshold = lower_bound[0]*255
    light_pixels = np.argwhere(mask > threshold)
    if light_pixels.shape[0] == 0:
        # DBSCAN cannot be fitted on zero samples; nothing bright means no boxes
        return []

    clustering = get_clusters(mask, val_min=threshold, eps=eps, min_samples=min_samples)
    return get_squares(light_pixels, clustering, image.shape)

In [None]:
# Detect the boxes on the demo image, draw them and save the result
IMG_FILE = "images/demo.png"

squares = annotate_image(IMG_FILE, lower_bound=np.array([225/255]), eps=20, min_samples=500, k=3)

image = read_image(IMG_FILE)
image_annotated = draw_squares(image, squares)

plt.imshow(image_annotated)
# The image is RGB in memory; convert to BGR before writing with OpenCV
cv2.imwrite("images/demo_annotated_1.png", cv2.cvtColor(image_annotated, cv2.COLOR_RGB2BGR))

In [None]:
# Detect the boxes on the demo image, recolor it randomly, draw the boxes
# and save the result
IMG_FILE = "images/demo.png"

squares = annotate_image(
    IMG_FILE,
    lower_bound=np.array([225/255]),
    eps=20,
    min_samples=500,
    k=3,
)

image = read_image(IMG_FILE)
# Recolor the freshly loaded image rather than a leftover `image_annotated`
# from an earlier cell, which already has rectangles drawn on it.
image_annotated = random_color_switch(image, k=3)
image_annotated = draw_squares(image_annotated, squares)

plt.imshow(image_annotated)
# The image is RGB in memory; convert to BGR before writing with OpenCV
cv2.imwrite("images/demo_annotated_2.png", cv2.cvtColor(image_annotated, cv2.COLOR_RGB2BGR))

In [None]:
# Detect the boxes on the demo image, recolor it randomly, add salt-and-pepper
# noise, draw the boxes and save the result
IMG_FILE = "images/demo.png"

squares = annotate_image(IMG_FILE, lower_bound=np.array([225/255]), eps=20, min_samples=500, k=3)

image = read_image(IMG_FILE)
image_annotated = random_color_switch(image, k=3)
# Noise amount is drawn uniformly from 0.20 .. 0.80 in steps of 0.01
image_annotated = get_noisy_image(
    image_annotated,
    mode='s&p',
    amount=random.randint(20,80)/100,
)
image_annotated = draw_squares(image_annotated, squares)

plt.imshow(image_annotated)
# The image is RGB in memory; convert to BGR before writing with OpenCV
cv2.imwrite("images/demo_annotated_3.png", cv2.cvtColor(image_annotated, cv2.COLOR_RGB2BGR))

In [None]:
def get_labels(n_images=612, first_label=59, batch_size=12):
    """Build the per-image label strings for the batched data set.

    Images come in consecutive batches of ``batch_size``; every image of a
    batch shares one label. Labels count down from ``first_label`` by one per
    batch and are zero-padded to three characters, so each character can be
    indexed separately.

    Args:
        n_images: Total number of images; a trailing partial batch is ignored.
        first_label: Numeric label of the first batch.
        batch_size: Number of images per batch.

    Returns:
        A list with one 3-character label string per image.
    """
    labels = []
    for batch in range(n_images // batch_size):
        # Fixed-width formatting keeps single-digit labels ("009") the same
        # length as the others.
        labels.extend([f"{first_label - batch:03d}"] * batch_size)
    return labels

In [None]:
import os
# Directory (relative to the parent folder) holding the source images
image_dir = "images_12"
# os.listdir returns entries in arbitrary order, but labels are assigned by
# position below, so the file order must be deterministic.
filenames = sorted(os.listdir(f"../{image_dir}"))

# One label string per image, indexed by the position of the file
classes = get_labels()
# Annotation lines collected by the processing loop
filelines = []

def map_squares(squares):
    """Widen narrow boxes and order all boxes from left to right.

    A box whose width is less than half its height gets its left edge moved
    to ``int(x1 - 4*h/5)``; all other boxes are kept unchanged.

    Args:
        squares: Iterable of ``(y0, y1, x0, x1)`` boxes.

    Returns:
        A new list of ``(y0, y1, x0, x1)`` tuples sorted by ``x0``.
    """
    widened = []
    for y0, y1, x0, x1 in squares:
        height = y1 - y0
        width = x1 - x0
        if width < height/2:
            x0 = int(x1 - 4*height/5)
        widened.append((y0, y1, x0, x1))
    # Stable in-place sort on the left edge
    widened.sort(key=lambda box: box[2])
    return widened

def filter_squares(squares):
    """Drop boxes that are taller than 700 or wider than 400.

    Args:
        squares: Iterable of ``(y0, y1, x0, x1)`` boxes.

    Returns:
        A new list with the remaining boxes as tuples, in their original order.
    """
    return [
        (y0, y1, x0, x1)
        for y0, y1, x0, x1 in squares
        if (y1 - y0) <= 700 and (x1 - x0) <= 400
    ]
        

# Build the augmented data set: for every source image detect the boxes,
# write a recolored, noisy copy and record one annotation line.
for i, filename in enumerate(filenames):
    filepath = f"../{image_dir}/{filename}"
    squares = annotate_image(
        filepath,
        lower_bound=np.array([225/255]),
        eps=20,
        min_samples=500,
        k=3,
    )
    squares = map_squares(squares)
    squares = filter_squares(squares)
    image_annotated = read_image(filepath)

    # The copy keeps the source file name (including its extension)
    output_path = f"../{image_dir}_annotated/{filename}"
    print(f"{i+1}/{len(filenames)} {output_path}")

    # Rescale only images whose values do not exceed 1 (same guard as
    # annotate_image); multiplying 8-bit data by 255 would wrap around.
    if image_annotated.max() <= 1.0:
        image_annotated = image_annotated * 255
    image_annotated = random_color_switch(image_annotated, k=3)
    image_annotated = get_noisy_image(image_annotated, mode='s&p', amount=random.randint(20,80)/100)
    cv2.imwrite(output_path, cv2.cvtColor(image_annotated, cv2.COLOR_RGB2BGR))

    # One "x0,y0,x1,y1,class" entry per box; the k-th box (left to right)
    # takes the k-th character of this image's label.
    squares = [f"{s[2]},{s[0]},{s[3]},{s[1]},{classes[i][k]}" for k, s in enumerate(squares)]
    filelines.append(f"IMAGE_DIR/{filename} {' '.join(squares)}")

    # Rewritten after every image, presumably so an interrupted run still
    # leaves an annotation file covering the images processed so far.
    with open(f"../{image_dir}_annotated/images.txt", "w") as f:
        f.write("\n".join(filelines))