## Imports

In [None]:
import os
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import ot
import cv2
import imageio
from glob import glob

## Data Import

In [None]:
# Load every PNG frame in `directory` (in sorted file-name order) as a
# grayscale float image scaled to [0, 1], stacked into `images`.
# NOTE(review): sorting is lexicographic; if the frames are named without
# zero padding (e.g. image_2 / image_10) the order will be wrong — confirm.
directory = r'C:\Users\ariel\PycharmProjects\MLDM_Project\generations\patch_match_24_30'
# List all files in directory
file_list = sorted(os.listdir(directory))
print(file_list)

np_images = []
for filename in file_list:
    # Case-insensitive so frames saved with a ".PNG" extension are not skipped.
    if not filename.lower().endswith('.png'):
        continue
    # The context manager guarantees the file handle is released.
    with Image.open(os.path.join(directory, filename)) as img:
        np_images.append(np.asarray(img.convert('L')) / 255.0)

# Stacking into one array requires every frame to have the same size.
images = np.array(np_images)

print(f"Processed {len(images)} images.")

## Optimal Transport function

In [None]:
def _barycentric_projection(plan, target_values):
    """Map each source sample to the plan-weighted mean of ``target_values``.

    Row i of ``plan`` holds the mass sent from source sample i to every target
    sample. Dividing ``plan @ target_values`` by each row's total mass turns
    the sum into a weighted average, so the result is on the same scale as
    ``target_values``. Rows that carry no mass are divided by 1 (yielding 0)
    rather than by 0.
    """
    row_mass = plan.sum(axis=1)
    return plan.dot(target_values) / np.where(row_mass > 0, row_mass, 1.0)


def compute_optimal_transport(image1, image2, reg_e=0.0001):
    """Return an intermediate image between ``image1`` and ``image2`` via exact OT.

    Both images are flattened and each pixel is treated as a 1-D sample whose
    only coordinate is its intensity (pixel positions do not enter the cost).
    An exact optimal-transport plan with uniform weights is solved between the
    two pixel sets. Each image is then mapped onto the other through the
    plan's barycentric projection, and the two mapped images are averaged.

    Args:
        image1: Array of pixel intensities.
        image2: Array of pixel intensities; must be shape-compatible with
            ``image1`` for the final averaging step.
        reg_e: Unused. The exact solver ``ot.emd`` takes no regularization;
            the parameter is kept so existing calls such as ``reg_e=0.01``
            keep working.

    Returns:
        Float array: the average of image1 mapped through the plan (shape of
        ``image1``) and image2 mapped through the transposed plan (shape of
        ``image2``).

    Note:
        The cost matrix and plan are dense ``n1 x n2`` arrays, where ``n1`` and
        ``n2`` are the pixel counts, so memory grows quadratically with
        image size.
    """
    source = image1.reshape(-1)
    target = image2.reshape(-1)

    # Squared Euclidean cost between pixel intensities (one feature column).
    cost = ot.dist(source.reshape(-1, 1), target.reshape(-1, 1))
    # Empty weight lists make ot.emd use uniform weights on both sides.
    plan = ot.emd([], [], cost)

    # The cost is symmetric and the uniform marginals simply swap, so the
    # transposed plan is optimal for the image2 -> image1 problem; no second
    # (expensive) solve is needed.
    forward = _barycentric_projection(plan, target).reshape(image1.shape)
    backward = _barycentric_projection(plan.T, source).reshape(image2.shape)

    return (forward + backward) / 2

## Registration (optimal-transport interpolation between consecutive frames)

In [None]:
# For every pair of consecutive frames, build an OT-interpolated frame and
# write the interleaved sequence
#   original_1, ot_1_2, original_2, ot_2_3, ..., original_N
# to `output_dir` as image_1.png, image_2.png, ...
output_dir = 'C:/Users/ariel/PycharmProjects/MLDM_Project/generations/optimal_transport_24_30'
os.makedirs(output_dir, exist_ok=True)

frame_size = (128, 128)  # (width, height) passed to cv2.resize
num_pairs = len(images) - 1
counter = 1


def save_frame(frame_uint8, index):
    """Write a uint8 grayscale frame to output_dir as image_<index>.png."""
    Image.fromarray(frame_uint8).save(f'{output_dir}/image_{index}.png')
    print(f"Saved image_{index}.png")


for i in range(num_pairs):
    # `counter` advances by two per pair, so report progress in pairs instead.
    print(f"Processing pair {i + 1}/{num_pairs}")

    image1 = cv2.resize(images[i], frame_size)
    image2 = cv2.resize(images[i + 1], frame_size)

    # Compute the Optimal Transport between the two images
    ot_img = compute_optimal_transport(image1, image2, reg_e=0.01)

    # Min-max rescale to [0, 255]. A constant image would divide by zero, so
    # it is mapped to all zeros instead of producing NaNs.
    value_range = ot_img.max() - ot_img.min()
    if value_range > 0:
        ot_img = (ot_img - ot_img.min()) / value_range * 255
    else:
        ot_img = np.zeros_like(ot_img)
    ot_img = cv2.fastNlMeansDenoising(ot_img.astype(np.uint8), None, 10, 7, 31)

    # Plot the pair together with the OT interpolation between them.
    fig, ax = plt.subplots(1, 3, figsize=(15, 5))
    ax[0].imshow(image1, cmap='gray')
    ax[0].set_title('Image 1')
    ax[1].imshow(ot_img, cmap='gray')
    ax[1].set_title('Optimal Transport Plan')
    ax[2].imshow(image2, cmap='gray')
    ax[2].set_title('Image 2')
    plt.show()

    # Save the original frame, then the interpolated frame that follows it.
    save_frame((image1 * 255).astype(np.uint8), counter)
    counter += 1
    save_frame(ot_img.astype(np.uint8), counter)
    counter += 1

# Each iteration saves only the first frame of its pair. Write the final
# original frame as well, so the sequence ends on a real frame.
if num_pairs > 0:
    last_frame = cv2.resize(images[-1], frame_size)
    save_frame((last_frame * 255).astype(np.uint8), counter)

## Making GIFs

In [None]:
def _frame_number(path):
    """Return N for a file named like 'image_N.png', or -1 when no numeric suffix exists."""
    stem = os.path.splitext(os.path.basename(path))[0]
    suffix = stem.rsplit('_', 1)[-1]
    return int(suffix) if suffix.isdecimal() else -1


# Collect the frames in numeric order. A plain lexicographic sort would put
# image_10.png before image_2.png and scramble the animation; ties and
# non-numeric names fall back to ordering by path.
# NOTE(review): this reads ...\generations\optimal_transport_generations\optimal_transport_24_30,
# while the registration cell writes to .../generations/optimal_transport_24_30 — confirm
# which folder is intended.
file_list = sorted(
    glob(r"C:\Users\ariel\PycharmProjects\MLDM_Project\generations\optimal_transport_generations\optimal_transport_24_30/*.png"),
    key=lambda path: (_frame_number(path), path),
)
if not file_list:
    raise FileNotFoundError("No PNG frames found to assemble into the GIF.")

# Read the images into a list
images = [imageio.imread(file) for file in file_list]

# Save the images as a GIF
imageio.mimsave(r'C:\Users\ariel\PycharmProjects\MLDM_Project\plots/ot_24_30.gif', images)