In [None]:
def _to_grayscale(image):
  """Return *image* as a 2-D grayscale intensity array.

  A 2-D input is treated as already grayscale and returned unchanged. Any other
  input is treated as a BGR colour image and converted with OpenCV's
  COLOR_BGR2GRAY, which weights the channels by perceived brightness:
  0.299 * Red + 0.587 * Green + 0.114 * Blue.
  """
  image = np.asarray(image)
  if image.ndim == 2:
    return image
  return cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)


def _masked_intensities(gray, mask, name):
  """Return the 1-D array of values of *gray* at pixels where *mask* is nonzero.

  *mask* may be 2-D (H, W) or carry a trailing channel axis (H, W, C), as
  happens when a mask image is loaded as a colour image. In that case a pixel
  is selected when any of its channels is nonzero.

  Raises:
  ValueError: If the mask selects no pixels.
  """
  selected = np.asarray(mask) > 0
  if selected.ndim == 3:
    selected = selected.any(axis=-1)
  pixels = gray[selected]
  if pixels.size == 0:
    raise ValueError(f"{name} selects no pixels; cannot build an intensity distribution")
  return pixels


def loss(image, min_color_mask, maj_color_mask):
  """
  Compute the Wasserstein distance (aka Earth Mover's Distance) between the distributions of grayscale pixel intensities inside the minority and majority color segmentation masks.

  In theory, it is desirable for the touched-up region's pixel color distribution to be as similar as possible to that of the majority color region,
  because the touched-up region will then blend in better with the rest of the hair. This improves the color while preserving more "naturalness". Natural hair texture is often lost because
  hair texture in an image appears to be numerically represented by particular patterns of minor variations in RGB values (approximated here by grayscale pixel intensity),
  and these patterns are vulnerable to being "smoothed" out of existence by the Stable Diffusion model. Minimizing this loss therefore aims to minimize the
  "smoothness" afflicting the touched-up region while pushing the mean (and median) pixel intensity toward that of the majority color region.

  Parameters:
  image (numpy array): The original image, either a BGR colour image or an already-grayscale 2-D image.
  min_color_mask (numpy array): The minority color segmentation mask; pixels where it is nonzero are used.
  maj_color_mask (numpy array): The majority color segmentation mask; pixels where it is nonzero are used.

  Returns:
  loss (float): The Wasserstein distance between the empirical distributions of grayscale pixel intensities inside the minority and majority color segmentation masks, in grayscale intensity units.

  Raises:
  ValueError: If either mask selects no pixels.
  """
  gray = _to_grayscale(image)

  # Select pixels through the masks themselves rather than by discarding zero
  # intensities afterwards, so genuinely black pixels inside a mask are kept.
  min_pixels = _masked_intensities(gray, min_color_mask, "min_color_mask")
  maj_pixels = _masked_intensities(gray, maj_color_mask, "maj_color_mask")

  # scipy's wasserstein_distance expects the observed *values* of each
  # distribution (optionally with weights), not pre-computed histograms.
  # Passing the raw intensities yields the distance between the two empirical
  # intensity distributions.
  return float(scipy.stats.wasserstein_distance(min_pixels, maj_pixels))