In [1]:
import os

import cv2
import matplotlib.pyplot as plt
import numpy as np
from scipy.optimize import linear_sum_assignment


In [2]:

def histogram_intersection(p, q):
    """Histogram intersection of a query histogram against a reference.

    Parameters
    ----------
    p : array_like
        Query distribution; its total is used as the normaliser.
    q : array_like
        Reference distribution.

    Returns
    -------
    float or None
        ``sum(min(p, q)) / sum(p)``. If that ratio falls outside [0, 1], the
        ratio and both inputs are printed and ``None`` is returned.
    """
    # NOTE(review): the ``None`` sentinel makes callers such as ``1 - HI``
    # fail later with a TypeError; consider raising instead.
    shared_mass = np.sum(np.minimum(p, q))
    hi = np.true_divide(shared_mass, np.sum(p))

    out_of_range = hi < 0 or hi > 1
    if out_of_range:
        print('Error, HI =', hi)
        print(p)
        print(q)
        return None
    return hi


def chi_square_distance(p, q):
    """Symmetric chi-square distance between a query and a reference histogram.

    Both inputs are normalised to unit mass, then shifted by +1 in every bin,
    and ``0.5 * sum((p - q)**2 / (p + q))`` is evaluated on the shifted values.

    Parameters
    ----------
    p : array_like
        Query distribution.
    q : array_like
        Reference distribution.

    Returns
    -------
    float
        The (shifted) chi-square distance.
    """
    # NOTE(review): adding 1 *after* normalisation keeps every denominator
    # >= 2 (for non-negative inputs). That avoids 0/0 but also compresses the
    # distance: when per-bin probabilities are small it is roughly a quarter
    # of the summed squared difference. kl_divergence smooths *before*
    # normalising instead -- confirm which behaviour is intended.
    shifted_p = np.array(p) / np.sum(p) + 1
    shifted_q = np.array(q) / np.sum(q) + 1

    numerator = (shifted_p - shifted_q) ** 2
    denominator = shifted_p + shifted_q
    return np.sum(numerator / denominator) / 2


def kl_divergence(p, q):
    """Kullback-Leibler divergence D(p || q) with add-one smoothing.

    One count is added to every bin of both histograms before they are
    normalised to unit mass. For non-negative inputs this keeps every bin
    strictly positive, so the logarithm stays finite.

    Parameters
    ----------
    p : array_like
        Query distribution.
    q : array_like
        Reference distribution.

    Returns
    -------
    float
        ``sum(p' * log(p' / q'))`` for the smoothed, normalised ``p'`` and ``q'``.
    """
    def _smoothed_pmf(hist):
        counts = np.array(hist, dtype=float) + 1
        return counts / counts.sum()

    p_pmf = _smoothed_pmf(p)
    q_pmf = _smoothed_pmf(q)
    return np.sum(p_pmf * np.log(p_pmf / q_pmf))


def earth_movers_distance(p, q):
    """One-dimensional Earth Mover's Distance between two histograms.

    Both inputs are normalised to unit mass. The distance is the L1 distance
    between their cumulative distributions, ``sum_i |P_i - Q_i|``, measured in
    bin units.

    The previous implementation solved a linear assignment problem on the
    ``n x n`` matrix ``|P_i - Q_j|`` (O(n^3)). For non-negative histograms the
    CDFs are non-decreasing. For two sorted sequences the identity pairing
    minimises ``sum |P_i - Q_sigma(i)|``, so that assignment always cost exactly
    ``sum |P - Q|``. The closed form below returns the same value in O(n).

    Parameters
    ----------
    p : array_like
        Query distribution (non-negative counts or weights).
    q : array_like
        Reference distribution, with the same number of bins as ``p``.

    Returns
    -------
    float
        The EMD in bin units.

    Raises
    ------
    ValueError
        If ``p`` and ``q`` have a different number of bins.
    """
    p = np.array(p, dtype=float)
    q = np.array(q, dtype=float)

    # Normalise the distributions so they sum to 1.
    p /= p.sum()
    q /= q.sum()

    # Cumulative distributions.
    P = np.cumsum(p)
    Q = np.cumsum(q)

    if P.size != Q.size:
        raise ValueError(
            f'p and q must have the same number of bins, got {P.size} and {Q.size}'
        )

    return np.sum(np.abs(P - Q))


def kolmogorov_smirnov_distance(p, q):
    """Kolmogorov-Smirnov distance: the largest absolute gap between two CDFs.

    Each histogram is normalised to unit mass and accumulated into a CDF.

    Parameters
    ----------
    p : array_like
        Query distribution.
    q : array_like
        Reference distribution.

    Returns
    -------
    float
        ``max_i |P_i - Q_i|``.
    """
    def _cdf(hist):
        mass = np.array(hist, dtype=float)
        return np.cumsum(mass / mass.sum())

    gaps = np.abs(_cdf(p) - _cdf(q))
    return np.max(gaps)


def rank_probability_score(p, q):
    """Ranked probability score: the sum of squared differences between two CDFs.

    Each histogram is normalised to unit mass and accumulated into a CDF. No
    ``1 / (K - 1)`` normalisation is applied.

    Parameters
    ----------
    p : array_like
        Query distribution.
    q : array_like
        Reference distribution.

    Returns
    -------
    float
        ``sum_i (P_i - Q_i) ** 2``.
    """
    def _cdf(hist):
        mass = np.array(hist, dtype=float)
        return np.cumsum(mass / mass.sum())

    squared_gaps = (_cdf(p) - _cdf(q)) ** 2
    return np.sum(squared_gaps)


def RDS(p, q):
    """Relative distribution shift of a query histogram with respect to a reference.

    For a histogram ``h`` with ``n`` bins, total mass ``N`` and cumulative sums
    ``C``, a spread score is computed as::

        z    = (n + 1) / n
        S(h) = (sum_i C_i**z / N**z - 1) / (n - 1)

    The result is ``S(q) - S(p)``.

    The cumulative sums are computed with ``np.cumsum``. The previous
    implementation rebuilt each prefix with ``sum(h[:i+1])``, which was O(n^2)
    per call.

    Parameters
    ----------
    p : array_like
        Query distribution.
    q : array_like
        Reference distribution (may have a different number of bins than ``p``).

    Returns
    -------
    float
        ``S(q) - S(p)``. With a single bin the final division is by zero, and
        numpy yields inf/nan with a RuntimeWarning.
    """
    def _spread(hist):
        # Spread score S(h) for one histogram.
        hist = np.asarray(hist, dtype=float)
        n_bins = len(hist)
        n_obs = hist.sum()
        z = (n_bins + 1) / n_bins
        powered_cdf = np.cumsum(hist) ** z
        return (np.sum(powered_cdf / n_obs ** z) - 1) / (n_bins - 1)

    return _spread(q) - _spread(p)


In [3]:

def flatten_histogram(image, num_bins, htype='color'):
    """Build a normalised 1-D histogram of an 8-bit image.

    Parameters
    ----------
    image : numpy.ndarray
        2-D (single-channel) or 3-D ``(H, W, C)`` image.
    num_bins : int
        Number of bins per histogram.
    htype : {'color', 'hue', 'intensity'}
        * ``'color'``: for a 3-D image, one histogram per channel over
          [0, 256). The histograms are concatenated and normalised by the
          total count over all channels. A 2-D image yields a single
          histogram over [0, 256).
        * ``'hue'``: histogram of the HSV hue channel over [0, 180).
        * ``'intensity'``: histogram of the grayscale image over [0, 256).

    Returns
    -------
    numpy.ndarray
        Histogram whose entries sum to 1. Its length is ``num_bins``, or
        ``C * num_bins`` for a 3-D image with ``htype='color'``.

    Raises
    ------
    ValueError
        If ``htype`` is not one of the supported values.
    """
    if htype == 'color':
        if image.ndim == 3:
            # One histogram per channel (same channels cv2.split would give for
            # an H x W x C array). The value range is the 8-bit pixel range, not
            # (0, num_bins): with fewer than 256 bins the old range silently
            # dropped every pixel >= num_bins.
            histograms = [
                np.histogram(image[:, :, channel], bins=num_bins, range=(0, 256))[0]
                for channel in range(image.shape[2])
            ]
            histogram = np.concatenate(histograms)
        else:
            histogram, _ = np.histogram(image, bins=num_bins, range=(0, 256))
        return histogram / np.sum(histogram)

    if htype == 'hue':
        # NOTE(review): frames read with cv2.VideoCapture are BGR, but this
        # conversion assumes RGB input; BGR frames would have R and B swapped.
        hsv_image = cv2.cvtColor(image, cv2.COLOR_RGB2HSV)
        # Hue values range from 0 to 180 in OpenCV.
        histogram, _ = np.histogram(hsv_image[:, :, 0], bins=num_bins, range=(0, 180))
        return histogram / np.sum(histogram)

    if htype == 'intensity':
        # NOTE(review): same RGB-vs-BGR assumption as the 'hue' branch.
        grayscale_image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY) if image.ndim == 3 else image
        histogram, _ = np.histogram(grayscale_image, bins=num_bins, range=(0, 256))
        return histogram / np.sum(histogram)

    raise ValueError(
        f"Unknown htype {htype!r}; expected 'color', 'hue' or 'intensity'"
    )
    


# Videos whose first and last TRIM_FRAMES frames are excluded from the analysis.
VIDEOS_TO_TRIM = frozenset([
    'data/time_lapse_video/855852-sd_640_360_30fps.mp4',
    'data/time_lapse_video/856030-sd_640_360_25fps.mp4',
    'data/time_lapse_video/856006-sd_640_360_25fps.mp4',
    'data/time_lapse_video/856990-sd_640_360_30fps.mp4',
    'data/time_lapse_video/2038351-sd_640_360_30fps.mp4',
    'data/time_lapse_video/855215-sd_640_360_24fps.mp4',
    'data/time_lapse_video/3541930-sd_426_240_24fps.mp4',
    'data/time_lapse_video/9255202-sd_426_240_25fps.mp4',
    'data/time_lapse_video/9255065-sd_426_240_25fps.mp4',
    'data/time_lapse_video/10881636-sd_640_360_25fps.mp4',
    'data/time_lapse_video/10881635-sd_640_360_25fps.mp4',
    'data/time_lapse_video/10881640-sd_640_360_25fps.mp4',
    'data/time_lapse_video/14668-258508783_tiny.mp4',
    'data/time_lapse_video/11440-230272102_tiny.mp4',
    'data/time_lapse_video/9289-218679635_tiny.mp4',
    'data/time_lapse_video/16986-277920487_tiny.mp4',
    'data/time_lapse_video/8560-211218186_tiny.mp4',
    'data/time_lapse_video/18020-287831424_tiny.mp4',
    'data/time_lapse_video/81089-574743647_medium.mp4',
])
TRIM_FRAMES = 10


def process_video_with_metric(video_path, metric_name, num_bins, return_frames):
    """Score every analysed frame of a video against its first analysed frame.

    Each frame's colour histogram (``flatten_histogram(frame, num_bins, 'color')``)
    is compared with the histogram of the first analysed frame using the
    chosen metric. For videos listed in ``VIDEOS_TO_TRIM``, the first and last
    ``TRIM_FRAMES`` frames are skipped. ``CAP_PROP_FRAME_COUNT`` bounds the
    loop, and reading stops early if a frame cannot be decoded.

    Parameters
    ----------
    video_path : str
        Path passed to ``cv2.VideoCapture``; also matched verbatim against
        ``VIDEOS_TO_TRIM``.
    metric_name : str
        One of '1 - HI', 'Chi-Square', 'KL', 'EMD', 'KS', 'RPS', 'RDS'.
    num_bins : int
        Bins per channel for the colour histogram.
    return_frames : int
        When 1, also return the reference frame and the frame with the largest
        |RDS|. Both are only populated when ``metric_name == 'RDS'``.

    Returns
    -------
    tuple
        ``(metric_values, frame_indices)``, or, when ``return_frames == 1``,
        ``(metric_values, frame_indices, first_frame, max_rds_frame, max_rds_value)``.

    Raises
    ------
    ValueError
        If ``metric_name`` is not recognised.
    IOError
        If the video cannot be opened.
    """
    # Each entry maps (query_hist, reference_hist) -> float.
    metrics = {
        '1 - HI': lambda p, q: 1 - histogram_intersection(p, q),
        'Chi-Square': chi_square_distance,
        'KL': kl_divergence,
        'EMD': earth_movers_distance,
        'KS': kolmogorov_smirnov_distance,
        'RPS': rank_probability_score,
        'RDS': RDS,
    }
    if metric_name not in metrics:
        raise ValueError(
            f'Unknown metric {metric_name!r}; expected one of {sorted(metrics)}'
        )
    metric_fn = metrics[metric_name]

    metric_values = []
    frame_indices = []
    first_frame = None
    max_rds_value = None
    max_rds_frame = None
    reference_hist = None

    video = cv2.VideoCapture(video_path)
    try:
        if not video.isOpened():
            raise IOError(f'Could not open video: {video_path}')

        total_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
        trim = video_path in VIDEOS_TO_TRIM
        start_index = TRIM_FRAMES if trim else 0
        end_index = total_frames - TRIM_FRAMES if trim else total_frames

        # Frames at or after end_index are never scored, so stop reading there.
        for frame_count in range(end_index):
            success, frame = video.read()
            if not success:
                break
            if frame_count < start_index:
                continue  # leading trimmed frames are decoded but not scored

            frame_hist = flatten_histogram(frame, num_bins, 'color')
            if frame_count == start_index:
                # The first analysed frame is the reference for every later frame.
                reference_hist = frame_hist
                if metric_name == 'RDS':
                    first_frame = frame

            metric_value = metric_fn(frame_hist, reference_hist)
            metric_values.append(metric_value)

            # Track the frame with the largest |RDS| seen so far.
            if metric_name == 'RDS' and (max_rds_value is None or abs(metric_value) > abs(max_rds_value)):
                max_rds_value = metric_value
                max_rds_frame = frame

            frame_indices.append(frame_count)
    finally:
        video.release()

    if return_frames == 1:
        return metric_values, frame_indices, first_frame, max_rds_frame, max_rds_value
    return metric_values, frame_indices

num_bins = 256

# Order in which the metrics are evaluated for every video.
_METRIC_ORDER = ('EMD', 'RDS', '1 - HI', 'Chi-Square', 'KL', 'KS', 'RPS')


def _compute_all_metrics(path, bins):
    """Run every metric in ``_METRIC_ORDER`` on one video, in that order.

    Returns a dict keyed by metric name whose values are exactly what
    ``process_video_with_metric`` returned. Only 'RDS' requests the extra frame
    outputs (``return_frames=1``).
    """
    return {
        name: process_video_with_metric(path, name, bins, 1 if name == 'RDS' else 0)
        for name in _METRIC_ORDER
    }


video_path = 'data/time_lapse_video/5094592-sd_226_426_25fps.mp4'  # main
video1_results = _compute_all_metrics(video_path, num_bins)
EMD1, EMD1_frames = video1_results['EMD']
RDS1, RDS1_frames, RDS1_first_frame, RDS1_max_rds_frame, RDS1_max_rds_value = video1_results['RDS']
HI1, HI1_frames = video1_results['1 - HI']
CS1, CS1_frames = video1_results['Chi-Square']
KL1, KL1_frames = video1_results['KL']
KS1, KS1_frames = video1_results['KS']
RPS1, RPS1_frames = video1_results['RPS']

video_path = 'data/time_lapse_video/5388577-sd_240_426_25fps.mp4'  # main
video2_results = _compute_all_metrics(video_path, num_bins)
EMD2, EMD2_frames = video2_results['EMD']
RDS2, RDS2_frames, RDS2_first_frame, RDS2_max_rds_frame, RDS2_max_rds_value = video2_results['RDS']
HI2, HI2_frames = video2_results['1 - HI']
CS2, CS2_frames = video2_results['Chi-Square']
KL2, KL2_frames = video2_results['KL']
KS2, KS2_frames = video2_results['KS']
RPS2, RPS2_frames = video2_results['RPS']

video_path = 'data/time_lapse_video/7101966-hd_1080_1644_25fps.mp4'  # main
video3_results = _compute_all_metrics(video_path, num_bins)
EMD3, EMD3_frames = video3_results['EMD']
RDS3, RDS3_frames, RDS3_first_frame, RDS3_max_rds_frame, RDS3_max_rds_value = video3_results['RDS']
HI3, HI3_frames = video3_results['1 - HI']
CS3, CS3_frames = video3_results['Chi-Square']
KL3, KL3_frames = video3_results['KL']
KS3, KS3_frames = video3_results['KS']
RPS3, RPS3_frames = video3_results['RPS']

video_path = 'data/time_lapse_video/14219843-sd_426_240_24fps.mp4'
video4_results = _compute_all_metrics(video_path, num_bins)
EMD4, EMD4_frames = video4_results['EMD']
RDS4, RDS4_frames, RDS4_first_frame, RDS4_max_rds_frame, RDS4_max_rds_value = video4_results['RDS']
HI4, HI4_frames = video4_results['1 - HI']
CS4, CS4_frames = video4_results['Chi-Square']
KL4, KL4_frames = video4_results['KL']
KS4, KS4_frames = video4_results['KS']
RPS4, RPS4_frames = video4_results['RPS']

video_path = 'data/time_lapse_video/7235032-sd_240_426_30fps.mp4'
video5_results = _compute_all_metrics(video_path, num_bins)
EMD5, EMD5_frames = video5_results['EMD']
RDS5, RDS5_frames, RDS5_first_frame, RDS5_max_rds_frame, RDS5_max_rds_value = video5_results['RDS']
HI5, HI5_frames = video5_results['1 - HI']
CS5, CS5_frames = video5_results['Chi-Square']
KL5, KL5_frames = video5_results['KL']
KS5, KS5_frames = video5_results['KS']
RPS5, RPS5_frames = video5_results['RPS']


In [5]:
def rotate_if_needed(image):
    """Return a portrait view of ``image``.

    Landscape images, whose width (``shape[1]``) exceeds their height
    (``shape[0]``), are rotated 90 degrees clockwise with OpenCV. Anything else
    is returned unchanged (the same object, not a copy).
    """
    is_landscape = image.shape[1] > image.shape[0]
    if not is_landscape:
        return image
    return cv2.rotate(image, cv2.ROTATE_90_CLOCKWISE)

    
# Font sizes: panel titles, axis labels, tick labels.
fs1 = 22
fs2 = 18
fs3 = 12

# (metric title, line colour) for the three time-series columns, left to right.
SERIES_STYLES = (('RDS', 'm'), ('EMD', 'c'), ('1 - HI', '0.5'))

# One entry per figure row (one video): the reference frame, the frame with
# the largest |RDS|, that RDS value, then (frames, values) pairs in
# SERIES_STYLES order.
figure_rows = [
    (RDS1_first_frame, RDS1_max_rds_frame, RDS1_max_rds_value,
     [(RDS1_frames, RDS1), (EMD1_frames, EMD1), (HI1_frames, HI1)]),
    (RDS2_first_frame, RDS2_max_rds_frame, RDS2_max_rds_value,
     [(RDS2_frames, RDS2), (EMD2_frames, EMD2), (HI2_frames, HI2)]),
    (RDS3_first_frame, RDS3_max_rds_frame, RDS3_max_rds_value,
     [(RDS3_frames, RDS3), (EMD3_frames, EMD3), (HI3_frames, HI3)]),
    (RDS4_first_frame, RDS4_max_rds_frame, RDS4_max_rds_value,
     [(RDS4_frames, RDS4), (EMD4_frames, EMD4), (HI4_frames, HI4)]),
    (RDS5_first_frame, RDS5_max_rds_frame, RDS5_max_rds_value,
     [(RDS5_frames, RDS5), (EMD5_frames, EMD5), (HI5_frames, HI5)]),
]


def _show_frame(ax, frame, title):
    """Draw a BGR video frame (rotated to portrait if needed) with a title and no axes."""
    ax.imshow(cv2.cvtColor(rotate_if_needed(frame), cv2.COLOR_BGR2RGB))
    ax.set_title(title, fontsize=fs1)
    ax.axis('off')


def _plot_series(ax, frames, values, color, title=None):
    """Plot one metric against frame index; the bold coloured title is drawn only when given."""
    ax.plot(frames, values, c=color, linewidth=0.7)
    ax.set_xlabel('Frame', fontsize=fs2)
    ax.set_ylabel('', fontsize=fs2)
    if title is not None:
        ax.set_title(title, fontsize=fs1, fontweight='bold', color=color)
    ax.tick_params(axis='both', labelsize=fs3)


fig = plt.figure(figsize=(18, 18))
n_rows, n_cols = len(figure_rows), 5

for row, (start_frame, peak_frame, peak_value, series) in enumerate(figure_rows):
    first_slot = row * n_cols + 1
    _show_frame(fig.add_subplot(n_rows, n_cols, first_slot), start_frame, 'First frame')
    _show_frame(fig.add_subplot(n_rows, n_cols, first_slot + 1), peak_frame, f'RDS: {peak_value:.3f}')
    for col, ((frames, values), (label, color)) in enumerate(zip(series, SERIES_STYLES)):
        # Metric names appear only above the top row, acting as column headers.
        _plot_series(fig.add_subplot(n_rows, n_cols, first_slot + 2 + col),
                     frames, values, color, label if row == 0 else None)

fig.patch.set_facecolor('white')
fig.subplots_adjust(hspace=0.4, wspace=0.45)
# Create the output directory so a fresh checkout can Restart & Run All.
os.makedirs('Final_Figs/manuscript', exist_ok=True)
fig.savefig('Final_Figs/manuscript/Fig6.jpg', bbox_inches='tight', format='jpg', dpi=600)
plt.close(fig)
