## Parse input

In [1]:
from PIL import Image
import os
from tqdm import tqdm

# Define the function to split the image
def split_and_save_image(filename, input_image_path, output_folder_path, save_to_tmp=False):
    """Split a composite image into its sub-images and save each one as a file.

    The image is treated as two equal-height rows. The top row is cut into five
    equal-width tiles saved as ``x_1`` ... ``x_5``. From the bottom row only the
    leftmost 2/15 of the full width is kept, saved as ``y``. Every crop reuses
    the source file's extension (which also selects the output format).

    Args:
        filename: Image file name, relative to ``input_image_path``.
        input_image_path: Directory that contains ``filename``.
        output_folder_path: Directory under which the output folder is created.
        save_to_tmp: If True, write into ``<output_folder_path>/tmp``. That folder
            is shared by every call, so crops of a later image overwrite those of
            an earlier one with the same extension. Otherwise write into
            ``<output_folder_path>/<filename without extension>``.
    """
    # splitext only strips the final extension, so dotted names or directories
    # in `filename` are no longer truncated at the first '.'.
    stem, dot_extension = os.path.splitext(filename)
    extension = dot_extension.lstrip('.')
    foldername = "tmp" if save_to_tmp else stem

    with Image.open(os.path.join(input_image_path, filename)) as img:
        # Two rows of equal height; the top row holds five equal-width tiles.
        width, height = img.size
        single_image_height = height // 2
        single_image_width = width // 5

        output_directory = os.path.join(output_folder_path, foldername)
        os.makedirs(output_directory, exist_ok=True)

        # Top row: five tiles, left to right.
        for i in range(5):
            left = i * single_image_width
            right = (i + 1) * single_image_width
            box = (left, 0, right, single_image_height)
            part_img = img.crop(box)
            part_img.save(os.path.join(output_directory, f'x_{i+1}.{extension}'))

        # Bottom row: only the leftmost 2/15 of the width is kept (it does NOT
        # span the whole width). NOTE(review): 2/15 looks empirical; confirm
        # against the source layout. The right edge is rounded explicitly
        # instead of passing a float coordinate to crop().
        y_right = round(width * (2 / 15))
        box = (0, single_image_height, y_right, 2 * single_image_height)
        part_img = img.crop(box)
        part_img.save(os.path.join(output_directory, f'y.{extension}'))

## Detect circles and crop squares around them

In [4]:
import numpy as np
import matplotlib.pyplot as plt
import cv2
from copy import deepcopy
from PIL import Image, ImageDraw


def open_image_and_visualize(image_path, visualize=False):
    """Read an image from disk with OpenCV, optionally displaying it.

    Args:
        image_path: Path of the image file.
        visualize: If True, show the image (converted to RGB for display only)
            with a colorbar.

    Returns:
        The array returned by ``cv2.imread`` (BGR channel order for colour images).

    Raises:
        FileNotFoundError: If OpenCV cannot read the file. ``cv2.imread`` returns
            None instead of raising, which would otherwise fail later with an
            unrelated TypeError.
    """
    im = cv2.imread(image_path)
    if im is None:
        raise FileNotFoundError(f"Could not read image: {image_path}")
    if visualize:
        # OpenCV stores colour images as BGR; matplotlib expects RGB.
        plt.imshow(cv2.cvtColor(im, cv2.COLOR_BGR2RGB))
        plt.colorbar()
        plt.show()
    return im


def make_scalar_product_mask(im, visualize=False):
    """Score each pixel by how close its colour is to grey.

    Each pixel's channel vector is scaled to [0, 1], normalised to unit length
    and dotted with (1, 1, 1)/sqrt(3). The cosine is multiplied by 255, so
    achromatic pixels score about 255 and a single pure channel scores about
    147. Pixels whose channels are all zero produce 0/0 = NaN.

    Args:
        im: Array whose third axis (axis 2) holds three colour channels.
        visualize: If True, show the resulting map with a colorbar.

    Returns:
        2-D float array with the scaled cosine for every pixel.
    """
    scaled = im / 255.
    grey_axis = np.array([1 / np.sqrt(3)] * 3)

    # Per-pixel Euclidean length, kept as a trailing axis so it broadcasts
    # over the channels.
    lengths = np.sqrt(np.sum(scaled * scaled, axis=2, keepdims=True))
    unit_pixels = scaled / lengths

    result = np.sum(unit_pixels * grey_axis, axis=2) * 255
    if visualize:
        plt.imshow(result)
        plt.colorbar()
        plt.show()
    return result


def make_binar(im, threshold=240, visualize=False):
    """Binarise an array to 0/255 around ``threshold`` without touching the input.

    The steps run in order on a copy:
      1. NaN entries become 255.
      2. Entries below ``threshold`` become 0.
      3. Entries at or above ``threshold`` become 255.
    The dtype of ``im`` is preserved.

    Args:
        im: NumPy array to binarise.
        threshold: Cut-off value; defaults to 240.
        visualize: If True, show the result in greyscale with a colorbar.

    Returns:
        The binarised copy.
    """
    binary = im.copy()
    binary[np.isnan(binary)] = 255
    binary[binary < threshold] = 0
    # Re-evaluated after zeroing, exactly like step 3 above.
    binary[binary >= threshold] = 255
    if visualize:
        plt.imshow(binary, cmap="gray")
        plt.colorbar()
        plt.show()
    return binary


def pad_borders(im, pad_width=10, visualize=False):
    """Pad every side of an array with a linear ramp that ends at 255 (white).

    Args:
        im: Array to pad. A scalar ``pad_width`` pads every axis.
        pad_width: Number of elements added before and after each axis. The
            parameter used to be ignored and 10 was always applied, so the
            default is now 10. Calls that rely on the default behave as before.
        visualize: If True, show the padded array in greyscale with a colorbar.

    Returns:
        The padded array.
    """
    im = np.pad(im, pad_width, 'linear_ramp', end_values=255)
    if visualize:
        plt.imshow(im, cmap='gray')
        plt.colorbar()
        plt.show()
    return im


def smooth_circle_borders(im, visualize=False):
    """Clean a binary mask with a morphological opening followed by a closing.

    A 3x3 elliptical structuring element is used for both operations: two
    opening iterations, then one closing iteration.

    Args:
        im: Image accepted by ``cv2.morphologyEx``.
        visualize: If True, show the result in greyscale with a colorbar.

    Returns:
        The smoothed image.
    """
    element = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
    opened = cv2.morphologyEx(im, cv2.MORPH_OPEN, element, iterations=2)
    smoothed = cv2.morphologyEx(opened, cv2.MORPH_CLOSE, element, iterations=1)
    if visualize:
        plt.imshow(smoothed, cmap='gray')
        plt.colorbar()
        plt.show()
    return smoothed


def detect_circles(im, visualize=False):
    """Find circles of radius ~15-25 px with OpenCV's Hough gradient method.

    Args:
        im: Single-channel image; converted to ``uint8`` before detection.
        visualize: If True, draw every detected circle (green) and its centre
            (red) on a colour copy of ``im`` and display it.

    Returns:
        List of ``(x, y)`` centres as plain Python ints, or an empty list when
        nothing is detected. Plain ints (rather than ``np.uint16``) keep later
        subtractions and squared distances from wrapping around under NumPy's
        unsigned-integer promotion rules.
    """
    im = np.uint8(im)
    radius_tolerance = 5
    min_radius = 20 - radius_tolerance
    max_radius = 20 + radius_tolerance

    dp = 1.5  # The inverse ratio of the accumulator resolution to the image resolution
    minDist = 25  # Minimum distance between the centers of the detected circles
    param1 = 40  # The higher threshold of the two passed to the Canny edge detector
    param2 = 18  # Accumulator threshold for the circle centers at the detection stage

    detected_circles = cv2.HoughCircles(
        im,
        cv2.HOUGH_GRADIENT,
        dp,
        minDist,
        param1=param1,
        param2=param2,
        minRadius=min_radius,
        maxRadius=max_radius
    )

    if visualize:
        output_image_circles = cv2.cvtColor(im, cv2.COLOR_GRAY2BGR)

    centers = []
    # HoughCircles returns None when no circle is found, so check before
    # rounding (rounding None raises).
    if detected_circles is not None:
        # detected_circles[0] holds one (x, y, r, ...) row per circle.
        for circle in np.around(detected_circles[0]).astype(int):
            center = (int(circle[0]), int(circle[1]))
            centers.append(center)

            if visualize:
                radius = int(circle[2])
                # Draw the outer circle
                cv2.circle(output_image_circles, center, radius, (0, 255, 0), 2)
                # Draw the center of the circle
                cv2.circle(output_image_circles, center, 2, (0, 0, 255), 3)

    if visualize:
        plt.imshow(output_image_circles)
        plt.colorbar()
        plt.show()

    return centers


def substract_padding(centers, padding=10):
    """Shift points back by ``padding`` on both axes to undo border padding.

    NumPy integer coordinates (e.g. ``np.uint16`` from a Hough transform) are
    converted to Python ints first. Under NumPy's scalar promotion rules an
    unsigned value minus ``padding`` would otherwise wrap to a huge number
    instead of going negative. Non-NumPy values (ints, floats) are used as-is.

    Args:
        centers: Iterable of points indexable as ``c[0]`` (x) and ``c[1]`` (y).
        padding: Amount subtracted from both coordinates.

    Returns:
        List of ``(x - padding, y - padding)`` tuples.
    """
    def _plain(value):
        # Convert NumPy integer scalars to Python ints; leave everything else.
        return int(value) if isinstance(value, np.integer) else value

    return [(_plain(c[0]) - padding, _plain(c[1]) - padding) for c in centers]


def sort_centers(centers, img_path, visualize=False, reference_point=(40, 190)):
    """Order centres by Euclidean distance from a fixed reference point.

    Args:
        centers: Sequence of points indexable as ``c[0]`` (x) and ``c[1]`` (y).
        img_path: Image drawn on when ``visualize`` is True.
        visualize: If True, draw the following on an RGB copy of the image and
            show it:
            - the reference point (blue);
            - each centre (red);
            - a line from each centre to the reference point (green);
            - each centre's 1-based rank (purple).
        reference_point: ``(x, y)`` point distances are measured from. Defaults
            to ``(40, 190)``, the previously hard-coded "6th point".
            NOTE(review): presumably the location of the sixth sub-image;
            confirm.

    Returns:
        The original centre objects, nearest first. Ties keep input order
        (``sorted`` is stable).
    """
    ref_x, ref_y = reference_point

    def _distance(point):
        # float() first: unsigned NumPy ints (e.g. np.uint16) would wrap around
        # on subtraction and overflow when squared.
        dx = float(point[0]) - ref_x
        dy = float(point[1]) - ref_y
        return np.sqrt(dx ** 2 + dy ** 2)

    sorted_centers_only = sorted(centers, key=_distance)

    if visualize:
        # convert() returns a copy, so the file can be closed immediately.
        with Image.open(img_path) as img:
            img_rgb = img.convert('RGB')
        draw = ImageDraw.Draw(img_rgb)

        # Draw the reference point
        draw.ellipse((ref_x - 3, ref_y - 3, ref_x + 3, ref_y + 3), fill='blue', outline='blue')

        # Draw each centre, a line to the reference point, and its rank
        for num, (x, y) in enumerate(sorted_centers_only):
            draw.ellipse((x - 3, y - 3, x + 3, y + 3), fill='red', outline='red')
            draw.line((x, y, ref_x, ref_y), fill='green', width=1)
            draw.text((x + 5, y), f'{num+1}', fill='purple')

        plt.imshow(img_rgb)
        plt.colorbar()
        plt.show()

    return sorted_centers_only


def define_corners(center, radius=15):
    """Return the axis-aligned square of half-side ``radius`` around ``center``.

    Args:
        center: Point indexable as ``center[0]`` (x) and ``center[1]`` (y).
        radius: Half the side length of the square.

    Returns:
        ``(left, up, right, down)`` box tuple.
    """
    cx, cy = center[0], center[1]
    return (cx - radius, cy - radius, cx + radius, cy + radius)



def draw_squares(image_path, squares):
    """
    Draws squares on the image at the given path, displays it and returns it.

    Parameters:
    - image_path: Path to the image file (left unmodified on disk).
    - squares: A list of tuples, each representing a square in the format (left, up, right, down).

    Returns:
    - A new RGB PIL Image object with the squares drawn in yellow.
    """
    # convert() returns a copy, so the file handle can be released right away.
    with Image.open(image_path) as source:
        img = source.convert('RGB')
    draw = ImageDraw.Draw(img)

    for (left, up, right, down) in squares:
        draw.rectangle([(left, up), (right, down)], outline='yellow', width=2)

    plt.imshow(img)
    plt.colorbar()
    plt.show()
    # Previously nothing was returned despite the documented contract.
    return img

In [5]:
def combined_all_preprocessing_functions(image_path, visualize=False):
    """Run the full circle-detection pipeline on one image.

    The stages run in this order:
      1. read the image;
      2. compute the grey-axis cosine mask;
      3. binarise at 240;
      4. pad the borders;
      5. apply morphological smoothing;
      6. run Hough circle detection;
      7. undo the padding;
      8. sort by distance from the reference point.

    Args:
        image_path: Path of the image to process.
        visualize: If True, every stage shows its intermediate result, and the
            final 40x40 squares around the centres are drawn.

    Returns:
        List of ``(x, y)`` centres in original-image coordinates, ordered by
        ``sort_centers``.
    """
    # Padding added before detection; the same amount must be subtracted after.
    pad = 10

    im = open_image_and_visualize(image_path, visualize=visualize)
    im = make_scalar_product_mask(im, visualize=visualize)
    im = make_binar(im, threshold=240, visualize=visualize)
    im = pad_borders(im, pad_width=pad, visualize=visualize)
    im = smooth_circle_borders(im, visualize=visualize)
    centers = detect_circles(im, visualize=visualize)
    centers = substract_padding(centers, padding=pad)
    centers = sort_centers(centers, image_path, visualize=visualize)
    if visualize:
        # The squares are only used for the overlay, so build them only here.
        squares = [define_corners(center, radius=20) for center in centers]
        draw_squares(image_path, squares)
    return centers

## Compare images by CLIP embedding similarity

In [None]:
import clip
import torch
from PIL import Image
import matplotlib.pyplot as plt

def preprocess_image(image_path, preprocess):
    """Load an image as RGB and apply the given preprocessing transform.

    Args:
        image_path: Path of the image file.
        preprocess: Callable applied to the RGB PIL image (presumably the CLIP
            transform from ``clip.load``).

    Returns:
        ``(batch, rgb_image)``: the transform's output with a leading dimension
        added via ``unsqueeze(0)``, and the RGB image that was transformed.
    """
    with Image.open(image_path) as source:
        rgb_image = source.convert("RGB")
        transformed = preprocess(rgb_image)
    batch = transformed.unsqueeze(0)
    return batch, rgb_image

# Select the compute device: a CUDA GPU when available, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load the CLIP ViT-B/32 model and its preprocessing transform onto that
# device, then switch the model to evaluation mode.
model, preprocess = clip.load("ViT-B/32", device=device)
model.eval()

def get_image_features(image_path, model, preprocess):
    """Compute an L2-normalised image embedding and return it with the image.

    Uses the module-level ``device`` for inference; gradients are disabled.

    Args:
        image_path: Path of the image file.
        model: Model exposing ``encode_image`` (the CLIP model here).
        preprocess: Transform passed through to ``preprocess_image``.

    Returns:
        ``(embedding, rgb_image)``: the float embedding divided by its norm
        along the last dimension, and the RGB PIL image.
    """
    batch, rgb_image = preprocess_image(image_path, preprocess)
    with torch.no_grad():
        embedding = model.encode_image(batch.to(device)).float()
        embedding /= embedding.norm(dim=-1, keepdim=True)
    return embedding, rgb_image

def compare_images(base_image_path, model, preprocess, compare_image_paths=None):
    """Compare a base image with candidate images and display each pair.

    Similarity is the dot product of the normalised embeddings returned by
    ``get_image_features`` (i.e. cosine similarity).

    Args:
        base_image_path: Path of the reference image.
        model: Model passed to ``get_image_features``.
        preprocess: Transform passed to ``get_image_features``.
        compare_image_paths: Candidate image paths. Defaults to the previously
            hard-coded ``183099_icon_2.jpg`` ... ``183099_icon_6.jpg``. Candidates
            are numbered from 2 in the plot titles (the base image being 1).

    Missing candidate files are reported and skipped.
    """
    if compare_image_paths is None:
        compare_image_paths = [f"183099_icon_{i}.jpg" for i in range(2, 7)]

    base_features, base_image = get_image_features(base_image_path, model, preprocess)
    for i, compare_image_path in enumerate(compare_image_paths, start=2):
        # Only the file load is expected to raise FileNotFoundError.
        try:
            compare_features, compare_image = get_image_features(compare_image_path, model, preprocess)
        except FileNotFoundError:
            print(f"Файл {compare_image_path} не найден.")
            continue
        similarity = (base_features @ compare_features.T).cpu().numpy()
        display_comparison(base_image, compare_image, similarity, i)

def display_comparison(base_image, compare_image, similarity, i):
    """Show the base image and one candidate side by side.

    Args:
        base_image: Image drawn in the left panel.
        compare_image: Image drawn in the right panel.
        similarity: Single-element array-like holding the cosine similarity;
            shown in the right panel's title via ``.item()``.
        i: Candidate number, shown in the right panel's title.
    """
    def _panel(position, picture, caption):
        # One of the two side-by-side subplots, with its axes hidden.
        plt.subplot(1, 2, position)
        plt.imshow(picture)
        plt.title(caption)
        plt.axis('off')

    plt.figure(figsize=(10, 5))
    _panel(1, base_image, "Base Image")
    _panel(2, compare_image, f"Compare to Image {i}\nSimilarity: {similarity.item():.4f}")
    plt.show()

# Reference image that every candidate is compared against.
base_image_path = "183099_icon_1.jpg"

# Run the comparison; one figure is shown per candidate file that exists.
compare_images(base_image_path, model=model, preprocess=preprocess)

In [3]:
# Define the paths
input_image_path = 'orbits/train'
output_folder_path = 'orbits/parsed_input'

# Images to split. The earlier batch keeps its own name instead of being
# assigned to `filenames` and immediately overwritten (a dead store).
earlier_filenames = ["183199.jpg", "183201.jpg", "183203.jpg", "183205.jpg",
                     "183208.jpg", "183218.jpg", "183220.jpg"]
filenames = ["183125.jpg", "183108.jpg", "183239.jpg"]

# NOTE(review): with save_to_tmp=True every image is written to the same
# `<output_folder_path>/tmp` folder with the same crop names (x_1..x_5, y),
# so only the last file's crops survive this loop. Use save_to_tmp=False to
# keep one folder per image.
for filename in tqdm(filenames):
    split_and_save_image(filename, input_image_path, output_folder_path, save_to_tmp=True)

100%|██████████| 3/3 [00:00<00:00, 42.13it/s]
