## Rearrange a Gaussian-Noise Image to Match an Input Image using Optimal Transport

In [1]:
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import ot

# Generate a random RGB image with Gaussian distribution
def generate_gaussian_image(height, width, mean=(0.5, 0.5, 0.5), variance=0.1,
                            rng=None):
    """Sample an image whose pixels are i.i.d. Gaussian colour vectors.

    Each pixel is drawn from a multivariate normal with the given ``mean`` and
    a diagonal covariance (``variance`` on the diagonal, so channels are
    independent), then clipped to ``[0, 1]`` so it can be shown as a float
    image.

    Args:
        height: Output height in pixels.
        width: Output width in pixels.
        mean: Per-channel mean; its length sets the number of channels
            (default: mid-grey RGB).
        variance: Per-channel variance, either a scalar or one value per
            channel.
        rng: Optional ``numpy.random.Generator`` (or ``RandomState``) for
            reproducible output. When ``None`` the global ``np.random`` state
            is used, matching the original behaviour.

    Returns:
        Float array of shape ``(height, width, len(mean))`` with values in
        ``[0, 1]``.
    """
    mean = np.asarray(mean, dtype=np.float64)
    cov = np.diag(np.broadcast_to(np.asarray(variance, dtype=np.float64),
                                  mean.shape))
    sampler = np.random if rng is None else rng
    samples = sampler.multivariate_normal(mean, cov, (height, width))
    # Clipping piles the Gaussian tails up at 0 and 1, so the stored values
    # are only approximately normally distributed.
    return np.clip(samples, 0, 1)

def _to_unit_float(pixels):
    """Return ``pixels`` as float64, mapping integer dtypes onto ``[0, 1]``."""
    if np.issubdtype(pixels.dtype, np.integer):
        return pixels.astype(np.float64) / np.iinfo(pixels.dtype).max
    return pixels.astype(np.float64, copy=False)


# Function to permute pixels using optimal transport
def permute_image(input_image, gaussian_image, *, num_itermax=1_000_000):
    """Reorder the pixels of ``gaussian_image`` to follow the layout of ``input_image``.

    Every pixel is treated as a point in colour space. An exact optimal
    transport problem (uniform weights, squared Euclidean cost) is solved
    between the input pixels and the Gaussian pixels. Each input pixel is then
    replaced by the Gaussian pixel it is matched to. The result contains the
    Gaussian image's colours arranged like the input image.

    Note: the cost matrix and the transport plan are both ``(N, N)`` for ``N``
    pixels, so memory grows quadratically. Downsample large images first.

    Args:
        input_image: ``(height, width, channels)`` array. Integer arrays
            (e.g. ``uint8`` from PIL) are rescaled to ``[0, 1]`` floats before
            distances are computed.
        gaussian_image: Array with the same number of pixels and channels as
            ``input_image``; the output is built from its values.
        num_itermax: Iteration cap forwarded to ``ot.emd`` as ``numItermax``.
            POT's default (100000) can be too small for images with
            thousands of pixels, leaving the matching sub-optimal.

    Returns:
        ``(height, width, channels)`` array of pixels taken from
        ``gaussian_image`` (same dtype).

    Raises:
        ValueError: If the two images differ in pixel or channel count.
    """
    height, width, channels = input_image.shape
    gaussian_flat = gaussian_image.reshape(-1, gaussian_image.shape[-1])
    if gaussian_flat.shape != (height * width, channels):
        # A one-to-one pixel permutation needs identical pixel/channel counts.
        raise ValueError(
            f"gaussian_image must have {height * width} pixels with "
            f"{channels} channels, got shape {gaussian_image.shape}"
        )
    input_flat = input_image.reshape(-1, channels)

    # Compute the cost matrix on float copies: squaring integer (uint8)
    # pixels inside the distance computation can overflow, and the input must
    # share the [0, 1] colour scale of the Gaussian samples.
    cost_matrix = ot.dist(_to_unit_float(input_flat),
                          _to_unit_float(gaussian_flat))

    # Solve the exact optimal transport problem with uniform pixel weights.
    n_pixels = len(input_flat)
    transport_plan = ot.emd(ot.unif(n_pixels), ot.unif(n_pixels), cost_matrix,
                            numItermax=num_itermax)

    # With equal counts and uniform weights the exact solver returns a vertex
    # solution, i.e. a permutation matrix scaled by 1/N. Each row's argmax is
    # therefore the Gaussian pixel assigned to that input pixel.
    matched = np.argmax(transport_plan, axis=1)
    return gaussian_flat[matched].reshape(height, width, channels)

In [None]:
# Example usage
from numpy import asarray

# Exact optimal transport builds an (N x N) cost matrix over all N pixels, so
# memory grows quadratically. A 512x512 image would need 2**36 float64
# entries (512 GiB). Downsample first to keep the problem tractable.
MAX_SIDE = 64  # 64x64 -> 4096 pixels -> 128 MiB cost matrix

img = Image.open('data/noise/fire.png').convert('RGB')
img.thumbnail((MAX_SIDE, MAX_SIDE))  # in place, keeps aspect ratio
# Scale to floats in [0, 1] so the input shares the Gaussian samples' colour
# space (and avoids uint8 overflow when squared distances are computed).
input_image = asarray(img, dtype=np.float64) / 255.0
height, width, _ = input_image.shape
gaussian_image = generate_gaussian_image(height, width)
permuted_image = permute_image(input_image, gaussian_image)

# Plot the images side by side
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
panels = (
    (input_image, 'Input Image'),
    (gaussian_image, 'Gaussian Image'),
    (permuted_image, 'Permuted Image'),
)
for ax, (image, title) in zip(axes, panels):
    ax.imshow(image)
    ax.set_title(title)
    ax.axis('off')
plt.tight_layout()
plt.show()