In [None]:
import numpy as np
import cv2
import hdbscan

from sklearn.cluster import KMeans
from tqdm import tqdm
from matplotlib import pyplot as plt

In [None]:
def segment_kmeans(image, color_space='RGB', num_clusters=3, channel='all'):
    """Quantise an image by k-means clustering of its pixel values.

    Every pixel is replaced by the centre of the cluster it was assigned to.

    Args:
        image: H x W x C array. Treated as BGR when ``color_space == 'Lab'``.
        color_space: ``'Lab'`` converts the input with ``cv2.COLOR_BGR2LAB``;
            any other value leaves the data untouched.
        num_clusters: Number of k-means clusters.
        channel: ``'all'`` clusters on every channel; otherwise the value is
            converted with ``int()`` and only that channel is clustered.

    Returns:
        ``uint8`` array shaped like the clustered data: H x W x C for
        ``channel='all'``, H x W x 1 for a single channel.
    """
    working = cv2.cvtColor(image, cv2.COLOR_BGR2LAB) if color_space == 'Lab' else image

    if channel != 'all':
        # Keep a trailing axis of length 1 so the reshape below works unchanged.
        working = working[:, :, int(channel), np.newaxis]

    # One row per pixel, one column per channel.
    samples = working.reshape((-1, working.shape[2])).astype(np.float32)

    model = KMeans(n_clusters=num_clusters, random_state=0)
    assignment = model.fit_predict(samples)
    # Casting to uint8 truncates the fractional part of each centre.
    palette = model.cluster_centers_.astype(np.uint8)

    return palette[assignment.ravel()].reshape(working.shape)

In [None]:
def segment_dbscan(image, color_space='RGB', min_cluster_size=100, min_samples=None, channel='all', subsample=5):
    """Segment an image by density-based clustering (HDBSCAN) of its pixel values.

    Every cluster is painted with a random colour; pixels labelled as noise
    (label -1) stay black.

    Args:
        image: H x W x C array. Treated as BGR when ``color_space == 'Lab'``.
        color_space: ``'Lab'`` converts the input with ``cv2.COLOR_BGR2LAB``;
            any other value leaves the data untouched.
        min_cluster_size: Forwarded to ``hdbscan.HDBSCAN``.
        min_samples: Forwarded to ``hdbscan.HDBSCAN``.
        channel: ``'all'`` clusters on every channel; otherwise the value is
            converted with ``int()`` and only that channel is used.
        subsample: When greater than 1, only every ``subsample``-th row and
            column is clustered and the colour mask is scaled back up with
            nearest-neighbour interpolation.

    Returns:
        H x W x 3 ``uint8`` colour mask with the same height and width as
        the input image.
    """
    if color_space == 'Lab':
        image = cv2.cvtColor(image, cv2.COLOR_BGR2LAB)

    if channel != 'all':
        # Keep a trailing axis of length 1 so the reshape below works unchanged.
        image = image[:, :, int(channel), np.newaxis]

    # Remember the real size: strided slicing rounds up, so multiplying the
    # subsampled shape by `subsample` can overshoot the original dimensions.
    full_height, full_width = image.shape[:2]

    if subsample > 1:
        image = image[::subsample, ::subsample]
    small_height, small_width = image.shape[:2]

    points = image.reshape((-1, image.shape[2]))
    clusterer = hdbscan.HDBSCAN(min_cluster_size=min_cluster_size, min_samples=min_samples)
    # Reshape once instead of on every loop iteration.
    label_map = clusterer.fit_predict(points).reshape(small_height, small_width)

    mask = np.zeros((small_height, small_width, 3), dtype=np.uint8)
    for label in np.unique(label_map):
        if label == -1:
            continue  # noise pixels stay black
        mask[label_map == label] = np.random.randint(0, 256, size=3)

    if (small_height, small_width) != (full_height, full_width):
        # cv2.resize takes the target size as (width, height).
        mask = cv2.resize(mask, (full_width, full_height), interpolation=cv2.INTER_NEAREST)

    return mask

In [None]:
def segment_seeds(image, num_seeds=5, intensity_difference=30, color_space='RGB', channel='all', seed_points=None):
    """Segment an image by flood-filling outwards from seed points.

    Each seed is flood-filled (``cv2.floodFill`` with
    ``FLOODFILL_FIXED_RANGE``) with its own random colour. All fills share one
    mask.

    Args:
        image: H x W x C array. Treated as BGR when ``color_space == 'Lab'``.
        num_seeds: Number of random seeds to draw when ``seed_points`` is
            ``None``; otherwise replaced by ``len(seed_points)``.
        intensity_difference: Lower and upper tolerance, applied to all three
            channels.
        color_space: ``'Lab'`` converts the input with ``cv2.COLOR_BGR2LAB``.
        channel: ``'all'`` keeps every channel; otherwise the selected channel
            is expanded back to three channels with ``cv2.COLOR_GRAY2BGR``.
        seed_points: Optional iterable of ``(x, y)`` points.

    Returns:
        A copy of the (possibly converted) image with the filled regions
        painted in.
    """
    if color_space == 'Lab':
        image = cv2.cvtColor(image, cv2.COLOR_BGR2LAB)

    if channel != 'all':
        selected = image[:, :, int(channel)]
        image = cv2.cvtColor(selected, cv2.COLOR_GRAY2BGR)

    canvas = image.copy()
    rows, cols = image.shape[0], image.shape[1]

    if seed_points is None:
        seed_points = []
        for _ in range(num_seeds):
            # x is drawn before y.
            seed_x = np.random.randint(0, cols)
            seed_y = np.random.randint(0, rows)
            seed_points.append((seed_x, seed_y))
    else:
        num_seeds = len(seed_points)

    palette = []
    for _ in range(num_seeds):
        palette.append(tuple(np.random.randint(0, 256, size=3).tolist()))

    tolerance = (intensity_difference,) * 3
    # floodFill requires a mask two pixels larger than the image in each dimension.
    fill_mask = np.zeros((rows + 2, cols + 2), dtype=np.uint8)
    for (x, y), fill_color in zip(seed_points, palette):
        cv2.floodFill(canvas, fill_mask, (x, y), fill_color, tolerance, tolerance, cv2.FLOODFILL_FIXED_RANGE)

    return canvas

In [None]:
def segment_region_growing(image, seed=None, intensity_threshold=10, color_space='RGB', channel='all'):
    """Grow a single region from a seed pixel over an intensity image.

    A pixel joins the region when it is 8-connected to a pixel already in the
    region and their absolute intensity difference is below
    ``intensity_threshold``. The comparison is against the neighbouring region
    pixel, not against the seed.

    Args:
        image: H x W x C array. Treated as BGR when ``color_space == 'Lab'``.
        seed: ``(row, col)`` start pixel; defaults to the image centre.
        intensity_threshold: Exclusive upper bound on the intensity difference
            between neighbouring pixels.
        color_space: ``'Lab'`` converts the input with ``cv2.COLOR_BGR2LAB``.
        channel: ``'all'`` grows on a grey-scale version of the image;
            otherwise the value is converted with ``int()`` and that single
            channel is used.

    Returns:
        Mask with 255 inside the region and 0 elsewhere: H x W for a single
        channel, H x W x 3 (via ``cv2.COLOR_GRAY2BGR``) for ``channel='all'``.

    Raises:
        IndexError: If ``seed`` lies outside the image.
    """
    if color_space == 'Lab':
        image = cv2.cvtColor(image, cv2.COLOR_BGR2LAB)

    if channel != 'all':
        image = image[:, :, int(channel)]
    else:
        # NOTE(review): for color_space='Lab' this applies BGR luma weights to
        # L/a/b data; the L channel may be what was intended - confirm.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)

    height, width = image.shape[:2]

    if seed is None:
        seed = (height // 2, width // 2)
    # A list such as [r, c] would trigger NumPy fancy indexing and mark whole
    # rows, so normalise to a plain tuple of ints.
    seed = (int(seed[0]), int(seed[1]))
    if not (0 <= seed[0] < height and 0 <= seed[1] < width):
        raise IndexError(f"seed {seed} is outside the image of size {height}x{width}")

    segmented = np.zeros_like(image)
    segmented[seed] = 255

    offsets = [(dr, dc) for dr in (-1, 0, 1) for dc in (-1, 0, 1) if (dr, dc) != (0, 0)]
    # The region is the connected component of the seed, so the order in which
    # pending pixels are expanded does not affect the result.
    pending = [seed]
    while pending:
        row, col = pending.pop()
        reference = int(image[row, col])
        for dr, dc in offsets:
            n_row, n_col = row + dr, col + dc
            if not (0 <= n_row < height and 0 <= n_col < width):
                continue
            if segmented[n_row, n_col] != 0:
                continue
            # int() avoids unsigned wrap-around when subtracting uint8 values.
            if abs(int(image[n_row, n_col]) - reference) < intensity_threshold:
                segmented[n_row, n_col] = 255
                pending.append((n_row, n_col))

    if channel == 'all':
        segmented = cv2.cvtColor(segmented, cv2.COLOR_GRAY2BGR)

    return segmented


In [None]:
def segment_watershed(image, morph_kernel_size=(3, 3), morph_iter=2, dilate_iter=3, dist_transform_mask_size=5, fg_threshold_factor=0.5, color_space='RGB', channel='all'):
    """Mark watershed boundaries on an image.

    Pipeline: inverted Otsu threshold -> morphological opening -> dilation
    (sure background) -> distance transform + threshold (sure foreground) ->
    connected-component markers -> ``cv2.watershed``. Boundary pixels (marker
    -1) are painted ``[255, 0, 0]``.

    Args:
        image: H x W x 3 array. Treated as BGR when ``color_space == 'Lab'``.
            The input array is never modified.
        morph_kernel_size: Shape of the all-ones structuring element.
        morph_iter: Iterations of the opening.
        dilate_iter: Iterations of the dilation that yields the sure background.
        dist_transform_mask_size: Mask size passed to ``cv2.distanceTransform``.
        fg_threshold_factor: Fraction of the maximum distance used as the
            sure-foreground threshold.
        color_space: ``'Lab'`` converts the input with ``cv2.COLOR_BGR2LAB``.
        channel: ``'all'`` keeps every channel; otherwise the selected channel
            is expanded back to three channels with ``cv2.COLOR_GRAY2BGR``.

    Returns:
        A new H x W x 3 array: the (possibly converted) image with the
        watershed boundaries drawn on it.
    """
    if color_space == 'Lab':
        image = cv2.cvtColor(image, cv2.COLOR_BGR2LAB)

    if channel != 'all':
        image = cv2.cvtColor(image[:, :, int(channel)], cv2.COLOR_GRAY2BGR)

    # NOTE(review): for color_space='Lab' with channel='all' this applies BGR
    # luma weights to L/a/b data; the L channel may be what was intended - confirm.
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    _, thresh = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)

    kernel = np.ones(morph_kernel_size, np.uint8)
    opening = cv2.morphologyEx(thresh, cv2.MORPH_OPEN, kernel, iterations=morph_iter)
    sure_bg = cv2.dilate(opening, kernel, iterations=dilate_iter)

    dist_transform = cv2.distanceTransform(opening, cv2.DIST_L2, dist_transform_mask_size)
    _, sure_fg = cv2.threshold(dist_transform, fg_threshold_factor * dist_transform.max(), 255, cv2.THRESH_BINARY)
    sure_fg = np.uint8(sure_fg)

    # Pixels that are neither sure background nor sure foreground.
    unknown = cv2.subtract(sure_bg, sure_fg)

    _, markers = cv2.connectedComponents(sure_fg)
    # Shift labels so the background component is 1 and 0 is free to mean
    # "unknown" for the watershed call.
    markers = markers + 1
    markers[unknown == 255] = 0
    markers = cv2.watershed(image, markers)

    # Draw on a copy: in the RGB / channel='all' path `image` is still the
    # caller's array, and painting it in place would corrupt their data.
    outlined = image.copy()
    outlined[markers == -1] = [255, 0, 0]

    return outlined

In [None]:
def _as_bgr_frame(frame, frame_size):
    """Coerce a segmentation result into a 3-channel frame of ``frame_size`` (width, height)."""
    if frame.ndim == 2 or frame.shape[2] == 1:
        frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2BGR)
    if (frame.shape[1], frame.shape[0]) != frame_size:
        # Nearest-neighbour keeps flat label colours from being blended.
        frame = cv2.resize(frame, frame_size, interpolation=cv2.INTER_NEAREST)
    return frame


def process_video(input_video_path, output_video_path, segmentation_function, **kwargs):
    """Run a segmentation function on every frame of a video and save the result.

    Args:
        input_video_path: Path of the video to read.
        output_video_path: Path of the video to write (XVID fourcc).
        segmentation_function: Callable invoked as
            ``segmentation_function(frame, **kwargs)`` that returns an image.
        **kwargs: Forwarded unchanged to ``segmentation_function``.

    Returns:
        None. A message is printed and the function returns early when the
        input cannot be opened or the output writer cannot be created.

    Single-channel or differently sized results are converted to 3-channel
    frames of the input size before writing, so that they are not dropped by
    the writer. The capture and the writer are released even if the
    segmentation function raises.
    """
    cap = cv2.VideoCapture(input_video_path)
    if not cap.isOpened():
        print("Ошибка при открытии видеофайла")
        return

    out = None
    try:
        fourcc = cv2.VideoWriter_fourcc(*'XVID')
        # Keep the fractional part: truncating 29.97 to 29 changes playback speed.
        fps = cap.get(cv2.CAP_PROP_FPS)
        frame_size = (int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)), int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)))
        out = cv2.VideoWriter(output_video_path, fourcc, fps, frame_size)
        if not out.isOpened():
            print("Ошибка при создании выходного видеофайла")
            return

        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        # Fall back to an open-ended progress bar when the frame count is
        # reported as zero or negative.
        progress_total = total_frames if total_frames > 0 else None

        with tqdm(total=progress_total, unit="frame", desc="Processing Video") as pbar:
            while True:
                ret, frame = cap.read()
                if not ret:
                    break

                segmented_frame = segmentation_function(frame, **kwargs)
                out.write(_as_bgr_frame(segmented_frame, frame_size))
                pbar.update(1)
    finally:
        cap.release()
        if out is not None:
            out.release()

In [None]:
# Input clip for the video demos below; the path is relative to the working directory.
video_path = "../assets/sample_640x360.mp4"

In [None]:
# K-means colour quantisation of every frame.
process_video(
    video_path,
    "output_video_kmeans.mp4",
    segment_kmeans,
    color_space='RGB',
    num_clusters=3,
)

In [None]:
# HDBSCAN clustering of every frame, on a 4x subsampled pixel grid.
process_video(
    video_path,
    "output_video_dbscan.mp4",
    segment_dbscan,
    color_space='RGB',
    min_cluster_size=100,
    min_samples=10,
    subsample=4,
)

In [None]:
# Flood fill from five seeds; no seed_points are given, so they are drawn at random for each frame.
process_video(
    video_path,
    "output_video_seeds.mp4",
    segment_seeds,
    num_seeds=5,
    intensity_difference=30,
)

In [None]:
# Region growing from the fixed seed pixel (row 100, column 100).
process_video(
    video_path,
    "output_video_region_growing.mp4",
    segment_region_growing,
    seed=(100, 100),
    intensity_threshold=10,
)

In [None]:
# Watershed boundaries drawn on every frame.
process_video(
    video_path,
    "output_video_watershed.mp4",
    segment_watershed,
    morph_kernel_size=(3, 3),
    fg_threshold_factor=0.1,
)

In [None]:
def _show_bgr(ax, picture, title):
    """Draw a BGR or single-channel image on ``ax`` with a title and hidden axes."""
    if picture.ndim == 2 or picture.shape[2] == 1:
        # BGR2RGB rejects single-channel input, so show it as grey levels.
        # Fixed limits assume 8-bit data and stop imshow from stretching contrast.
        ax.imshow(picture.reshape(picture.shape[:2]), cmap='gray', vmin=0, vmax=255)
    else:
        ax.imshow(cv2.cvtColor(picture, cv2.COLOR_BGR2RGB))
    ax.set_title(title)
    ax.axis('off')


def compare_segmentation(image_path, segmentation_function, **kwargs):
    """Segment one image in RGB and in CIE Lab and plot both next to the original.

    Args:
        image_path: Path of the image, loaded with ``cv2.imread``.
        segmentation_function: Callable invoked as
            ``segmentation_function(image, color_space=..., **kwargs)``. It is
            expected to perform the Lab conversion itself when
            ``color_space='Lab'``.
        **kwargs: Forwarded to ``segmentation_function``. If a ``'channel'``
            other than ``'all'`` is given, only that channel of the original
            is displayed.

    Returns:
        Tuple ``(rgb_segmented, lab_segmented)``.

    Raises:
        ValueError: If the image cannot be loaded.
    """
    image = cv2.imread(image_path)
    if image is None:
        raise ValueError("Image could not be loaded. Please check the path.")

    # Each call gets its own copy of the BGR image:
    #  * the segmentation function converts to Lab itself, so pre-converting
    #    here would apply BGR->Lab twice;
    #  * a function that draws on its input must not alter the image used for
    #    the other call or for the "Original Image" panel.
    rgb_segmented = segmentation_function(image.copy(), color_space='RGB', **kwargs)
    lab_segmented = segmentation_function(image.copy(), color_space='Lab', **kwargs)

    channel = kwargs.get('channel')
    original = image if channel in (None, 'all') else image[:, :, int(channel)]

    fig, axes = plt.subplots(1, 3, figsize=(10, 5))
    _show_bgr(axes[0], original, 'Original Image')
    _show_bgr(axes[1], rgb_segmented, 'Segmented in RGB')
    _show_bgr(axes[2], lab_segmented, 'Segmented in CIE Lab')

    plt.show()

    return rgb_segmented, lab_segmented

In [None]:
# Still image used by the side-by-side RGB vs. Lab comparisons below.
image_path = '../assets/image.png'

In [None]:
# K-means on the full colour image, then on each channel separately.
for channel in ('all', '0', '1', '2'):
    rgb_segmented, lab_segmented = compare_segmentation(
        image_path,
        segment_kmeans,
        num_clusters=3,
        channel=channel,
    )

In [None]:
# HDBSCAN on the full colour image, then on each channel separately.
for channel in ('all', '0', '1', '2'):
    rgb_segmented, lab_segmented = compare_segmentation(
        image_path,
        segment_dbscan,
        min_cluster_size=100,
        min_samples=10,
        subsample=4,
        channel=channel,
    )

In [None]:
# Load the test image to inspect it; cv2.imread yields None when loading fails
# (compare_segmentation guards against exactly that case).
image = cv2.imread(image_path)

In [None]:
# Show the array shape; the seed coordinates in the next cell are derived from the constant 512.
image.shape

In [None]:
# Three fixed (x, y) seeds on the diagonal, derived from the constant 512.
seed_points = [(256, 256), (512 // 3, 512 // 3), (512 // 6, 512 // 6)]
for channel in ('all', '0', '1', '2'):
    rgb_segmented, lab_segmented = compare_segmentation(
        image_path,
        segment_seeds,
        seed_points=seed_points,
        intensity_difference=30,
        channel=channel,
    )

In [None]:
# Same flood-fill comparison, but with five randomly placed seeds per call.
for channel in ('all', '0', '1', '2'):
    rgb_segmented, lab_segmented = compare_segmentation(
        image_path,
        segment_seeds,
        num_seeds=5,
        intensity_difference=30,
        channel=channel,
    )

In [None]:
# Region growing from the seed pixel (row 100, column 100).
for channel in ('all', '0', '1', '2'):
    rgb_segmented, lab_segmented = compare_segmentation(
        image_path,
        segment_region_growing,
        seed=(100, 100),
        intensity_threshold=10,
        channel=channel,
    )

In [None]:
# Watershed with a 5x5 kernel on the full colour image and on each channel.
for channel in ('all', '0', '1', '2'):
    rgb_segmented, lab_segmented = compare_segmentation(
        image_path,
        segment_watershed,
        morph_kernel_size=(5, 5),
        morph_iter=3,
        dilate_iter=5,
        dist_transform_mask_size=3,
        fg_threshold_factor=0.4,
        channel=channel,
    )

In [None]:
# Watershed parameter sweep (disabled)
# 
# image = cv2.imread(image_path)
# image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# 
# kernel_sizes = [(5, 5)]
# morph_iters = [3]
# dilate_iters = [5]
# dist_transform_mask_sizes = [3, 5]
# fg_threshold_factors = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
# 
# results = []
# for ks in kernel_sizes:
#     for mi in morph_iters:
#         for di in dilate_iters:
#             for dtms in dist_transform_mask_sizes:
#                 for fgf in fg_threshold_factors:
#                     segmented_image = segment_watershed(
#                         image.copy(), morph_kernel_size=ks, morph_iter=mi, dilate_iter=di, 
#                         dist_transform_mask_size=dtms, fg_threshold_factor=fgf
#                     )
#                     results.append((ks, mi, di, dtms, fgf, segmented_image))
# 
# num_rows = (len(results) + 2) // 3
# 
# fig, axs = plt.subplots(num_rows, 3, figsize=(15, 5 * num_rows))
# for idx, ax in enumerate(axs.flat):
#     if idx < len(results):
#         ax.imshow(results[idx][5])
#         ax.set_title(f"KS: {results[idx][0]}, MI: {results[idx][1]}, DI: {results[idx][2]}, DTMS: {results[idx][3]}, FGF: {results[idx][4]}")
#         ax.axis('off')
# 
# plt.tight_layout()
# plt.show()