In [1]:
import matplotlib.pyplot as plt
import numpy as np
import os
import sys
import cv2
import imageio.v3 as iio
import sklearn as sk
import matplotlib
from sklearn.cluster import KMeans

matplotlib.use("nbagg")


In [2]:
# viewing the basic image
img = iio.imread('colored balls.png')
plt.imshow(img)
plt.show()

<IPython.core.display.Javascript object>

In [3]:
# show a 3d histogram of the image, by colors
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import cm
from matplotlib import colors

def plot_3d_hist(img):
    r, g, b = cv2.split(img)
    fig = plt.figure(figsize=(8, 8))
    axis = fig.add_subplot(1, 1, 1, projection="3d")
    pixel_colors = img.reshape((np.shape(img)[0]*np.shape(img)[1], 3))
    norm = colors.Normalize(vmin=-1.,vmax=1.)
    norm.autoscale(pixel_colors)
    pixel_colors = norm(pixel_colors).tolist()
    axis.scatter(r.flatten(), g.flatten(), b.flatten(), facecolors=pixel_colors, marker=".")
    axis.set_xlabel("Red")
    axis.set_ylabel("Green")
    axis.set_zlabel("Blue")
    plt.show()

plot_3d_hist(img)

<IPython.core.display.Javascript object>

In [4]:
# get r, g, b channels
r, g, b = cv2.split(img)

In [5]:
# in this section, we will implement a hough algorithm to detect lines in the 3d space

# we will start by defining the hough space for the 3d space linear lines, that are all starting at the origin.
# we will have two parameters, the angle of the line in the x-y plane, and the angle of the line in the x-z plane.
def trasform_point(x, y, z):
    """
    take a 3d point, and return the polar coordinates of the point in the x-y plane, and the x-z plane.
    :param point:
    :return:
    """
    theta_xy = np.arctan2(y, x)
    theta_xz = np.arctan2(z, x)
    return theta_xy, theta_xz

def trasform_points(x, y, z):
    """
    take a list of 3d points, and return the polar coordinates of the points in the x-y plane, and the x-z plane.
    :param points:
    :return:
    """
    thetas_xy = np.arctan2(y, x)
    thetas_xz = np.arctan2(z, x)
    r_xyz = np.sqrt(x**2 + y**2 + z**2)

    return np.stack([thetas_xy, thetas_xz, r_xyz], axis=1)

# plot all the thetas in the x-y plane, and the x-z plane.
def plot_thetas(thetas_xy, thetas_xz, r_xyz):
    fig, axis = plt.subplots(figsize=(8, 8))
    axis.scatter(thetas_xy, thetas_xz, marker=".", c=r_xyz, cmap="Purples")
    axis.set_xlabel("theta_xy")
    axis.set_ylabel("theta_xz")
    plt.show()


In [162]:
red_ball = cv2.imread('red_ball.png')
plot_3d_hist(red_ball)

<IPython.core.display.Javascript object>

In [145]:
# get r, g, b channels
r, g, b = cv2.split(red_ball)

transformed_points = trasform_points(r.flatten(), g.flatten(), b.flatten())
plot_thetas(transformed_points[:, 0], transformed_points[:, 1], transformed_points[:, 2])

<IPython.core.display.Javascript object>

In [146]:
# plot a histogram of the r_values
plt.hist(transformed_points[:, 2], bins=100)
plt.show()

<IPython.core.display.Javascript object>

In [147]:
# we will use the DBSCAN algorithm to cluster the points in the hough space.
from sklearn.cluster import KMeans

def find_centeroids(data_2d, weights):
    # Perform K-means clustering with different numbers of clusters
    max_clusters = 10
    inertia = []

    # make a subplot for each number of clusters
    fig, axes = plt.subplots(2, 5, figsize=(12, 6), sharex=True, sharey=True)
    axes = axes.flatten()

    for n_clusters in range(1, max_clusters + 1):
        kmeans = KMeans(n_clusters=n_clusters)
        kmeans.fit(data_2d, sample_weight=weights)
        inertia.append(kmeans.inertia_)
        # Plot the clusters and cluster centers
        ax = axes[n_clusters - 1]
        ax.scatter(data_2d[:, 0], data_2d[:, 1], c=kmeans.labels_, cmap='viridis')
        ax.scatter(kmeans.cluster_centers_[:, 0], kmeans.cluster_centers_[:, 1], marker='x', s=100, c='r')
        ax.set_title('k = {}'.format(n_clusters))

    plt.show()

    # Plot the inertia to see which number of clusters is best
    plt.plot(range(1, max_clusters + 1), inertia)
    plt.xlabel('Number of clusters')
    plt.ylabel('Inertia')
    plt.show()

# remove really dark points
# transformed_points = transformed_points[transformed_points[:, 2] > 0.2]
# make this weight more significant, with a power of 2.
weights = transformed_points[:, 2] ** 3
centeroids = find_centeroids(transformed_points[:, :2], weights)

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [148]:
# the user will choose a k number for the number of clusters (will be automated) and we will use the kmeans algorithm
# to cluster the points in the hough space.
N_CLUSTERS = 1
kmeans = KMeans(n_clusters=N_CLUSTERS)
kmeans.fit(transformed_points[:, :2], sample_weight=transformed_points[:, 2])

# get the cluster centers
cluster_centers = kmeans.cluster_centers_

In [149]:
# turn the cluster centers into 3d lines from the origin, and plot them on the 3d histogram
def polar_to_cartesian(theta_xy, theta_xz, r):
    """
    take a point in polar coordinates, and return the point in cartesian coordinates.
    :param theta_xy:
    :param theta_xz:
    :param r:
    :return:
    """
    x = r * np.cos(theta_xy) * np.cos(theta_xz)
    y = r * np.sin(theta_xy) * np.cos(theta_xz)
    z = r * np.sin(theta_xz)
    return x, y, z

# plot the lines and the points
fig = plt.figure(figsize=(8, 8))
ax = fig.add_subplot(111, projection='3d')
r, g, b = cv2.split(red_ball)
r, g, b = r.flatten(), g.flatten(), b.flatten()
ax.scatter(r, g, b, marker=".", c="r", alpha=0.1)

for i in range(len(cluster_centers)):
    x, y, z = polar_to_cartesian(cluster_centers[i, 0], cluster_centers[i, 1], np.linspace(0, 300))
    ax.plot(x, y, z, c="b", linewidth=3, label="line {}".format(i))
fig.legend()
plt.show()


<IPython.core.display.Javascript object>

In [150]:
def min_perpendicular_distance(point, lines):
    """
    find the minimum perpendicular distance between a point and a set of lines.
    :param point:
    :param lines: a tuple of two lists, each list contains the start and end points of the lines.
    :return: a tuple of the closest point on the line, and the distance.
    """
    min_dist = float('inf')
    closest_color_ret = None
    starts = lines[0]
    ends = lines[1]
    for idx in range(len(starts)):
        p1, p2 = starts[idx], ends[idx]
        u = p2 - p1
        v = point - p1
        w = np.cross(u, v)
        dist = np.linalg.norm(w) / np.linalg.norm(u)
        if dist < min_dist:
            min_dist = dist
            closest_color_ret = ends[idx]
    return closest_color_ret, min_dist

# turn a polar representation into a line (start and end points, start at the origin, end at length r)
def polar_to_line(theta_xy, theta_xz, r):
    """
    take a point in polar coordinates, and return the point in cartesian coordinates.
    :param theta_xy:
    :param theta_xz:
    :param r:
    :return:
    """
    x = r * np.cos(theta_xy) * np.cos(theta_xz)
    y = r * np.sin(theta_xy) * np.cos(theta_xz)
    z = r * np.sin(theta_xz)
    return [[np.array([0, 0, 0]), np.array([x, y, z])[:, i]] for i in range(len(x))]

# find the closest point on the line to the point
def closest_point_on_line(point, line):
    """
    find the closest point on a line to a given point.
    :param point:
    :param line: a tuple of two points, the start and end points of the line.
    :return: the closest point on the line, and the distance.
    """
    p1, p2 = line
    u = p2 - p1
    v = point - p1
    w = np.cross(u, v)
    dist = np.linalg.norm(w) / np.linalg.norm(u)
    return (p1 + u * np.dot(u, v) / np.dot(u, u)), dist

# plot all points, and the closest point on the line to each point, and the lines
def plot_points_and_lines(points, lines):
    """
    plot all points, and the closest point on the line to each point, and the lines
    :param points:
    :param lines: start and end points of the lines
    :return:
    """
    r, g, b = points
    fig = plt.figure(figsize=(8, 8))
    ax = fig.add_subplot(111, projection='3d')
    ax.scatter(r.flatten(), g.flatten(), b.flatten(), marker=".", c="r", alpha=0.1)

    for i in range(len(lines)):
        # each line is in format [start (x, y, z), end (x, y, z)], and there is a list of lines
        ax.plot([lines[i][0][0], lines[i][1][0]], [lines[i][0][1], lines[i][1][1]], [lines[i][0][2], lines[i][1][2]], c="b", linewidth=3, label="line {}".format(i))
    for point_idx in range(len(r)):

        closest = [closest_point_on_line(np.array([r[point_idx], g[point_idx], b[point_idx]]), lines[i]) for i in range(len(lines))]
        closest_point = min(closest, key=lambda x: x[1])[0]
        ax.scatter(closest_point[0], closest_point[1], closest_point[2], marker=".", c="g", alpha=0.8)

    fig.legend()
    plt.savefig("lines.png")

lines = polar_to_line(cluster_centers[:, 0], cluster_centers[:, 1], 300)
plot_points_and_lines((r[::30],g[::30],b[::30]), lines)

In [157]:
# transform each point to it's closest point on the line, and save the new points
def transform_points_to_lines(points, lines):
    """
    transform each point to it's closest point on the line, and save the new points
    :param points:
    :param lines:
    :return:
    """
    r, g, b = points[0], points[1], points[2]
    new_points = []
    for point_idx in range(len(r)):
        closest = [closest_point_on_line(np.array([r[point_idx], g[point_idx], b[point_idx]]), lines[i]) for i in range(len(lines))]
        closest_point = min(closest, key=lambda x: x[1])[0]
        new_points.append(closest_point)
    return np.array(new_points)

new_points = transform_points_to_lines((r, g, b), lines)
ax = plt.axes(projection='3d')
ax.plot(new_points[:, 0], new_points[:, 1], new_points[:, 2], c="b")
plt.show()

<IPython.core.display.Javascript object>

In [165]:
# each point is a pixel in new_img
# create a new image, where each pixel is the closest point on the line to the original pixel
new_img = np.zeros_like(red_ball)
new_img = new_points.reshape(new_img.shape)
# normalize the image
new_img = new_img / np.max(new_img)
plot_3d_hist(new_img)
# plt.imshow(new_img)
# plt.show()


<IPython.core.display.Javascript object>

In [168]:
# display the original image and the new image
def plot_2_images(img1, img2):
    """
    display the original image and the new image
    :param img1:
    :param img2:
    :return:
    """
    fig, axes = plt.subplots(1, 2, figsize=(8, 4))
    ax = axes.ravel()
    ax[0].imshow(img1)
    ax[0].set_title("Original image")
    ax[1].imshow(img2)
    ax[1].set_title("New image")
    fig.tight_layout()
    plt.show()

plot_2_images(red_ball, new_img)

<IPython.core.display.Javascript object>

In [180]:
# all the procces in one function, that gets a path to an image, and returns the new image
def quantize_img(img_path, N_CLUSTERS=1):
    """
    all the procces in one function, that gets a path to an image, and returns the new image
    :param N_CLUSTERS:
    :param img_path:
    :return:
    """
    img = iio.imread(img_path)
    r, g, b = img[:, :, 0], img[:, :, 1], img[:, :, 2]
    r, g, b = r.flatten(), g.flatten(), b.flatten()
    transformed_points = trasform_points(r.flatten(), g.flatten(), b.flatten())
    weights = transformed_points[:, 2] ** 3
    kmeans = KMeans(n_clusters=N_CLUSTERS)
    kmeans.fit(transformed_points[:, :2], sample_weight=weights)
    cluster_centers = kmeans.cluster_centers_

    lines = polar_to_line(cluster_centers[:, 0], cluster_centers[:, 1], 300)
    new_points = transform_points_to_lines((r, g, b), lines)
    new_img = np.zeros_like(img)
    new_img = new_points.reshape(new_img.shape)
    new_img = new_img / np.max(new_img)
    plot_2_images(img, new_img)
    # write to a file
    iio.imwrite("new_img.png", new_img, format="png")
    return new_img

quantize_img("colored balls.png", 8)

<IPython.core.display.Javascript object>

TypeError: Cannot handle this data type: (1, 1, 3), <f8