In [None]:
import os
import tempfile

import cv2
import imageio
import matplotlib.pyplot as plt
import numpy as np

def _relative_rmse_percent(target, prediction):
    """Return the RMSE of ``prediction`` vs ``target`` as a percentage of ``target``'s value range.

    MRE = sqrt(mean((target - prediction) ** 2)) / (max(target) - min(target)) * 100

    Both inputs are promoted to float64 first so integer frames (e.g. uint8)
    cannot wrap around when subtracted/squared. For a constant ``target``
    (zero value range) the result is 0.0 when the prediction matches exactly
    and ``inf`` otherwise, instead of NaN plus a RuntimeWarning.
    """
    target = np.asarray(target, dtype=np.float64)
    prediction = np.asarray(prediction, dtype=np.float64)
    rmse = np.sqrt(np.mean((target - prediction) ** 2))
    value_range = np.max(target) - np.min(target)
    if value_range == 0:
        return 0.0 if rmse == 0 else float('inf')
    return float(rmse / value_range * 100)


def show_video_gif_multiple_withError(prev, true, pred, vmax=1.0, vmin=0.0, cmap='jet', norm=None,
                                      out_path=None, use_rgb=False, dpi=300, duration=0.1):
    """Save a GIF comparing ground truth, prediction and absolute error frame by frame.

    Each GIF frame has three panels: ground truth, predicted frame, and
    |truth - prediction| annotated with the per-frame mean relative error
    (RMSE as a percentage of the ground-truth value range). The first
    ``len(prev)`` GIF frames show the context frames ``prev`` in both image
    panels with an all-zero error panel; the remaining frames pair
    ``true[k]`` with ``pred[k]``.

    Parameters
    ----------
    prev, true, pred : array_like
        Frame sequences indexed along axis 0. 4-D input is treated as
        channels-first (presumably (T, C, H, W)) and moved to (T, H, W, C);
        3-D input is used as-is.
    vmax, vmin : float
        Colour limits for every panel. Ignored when ``norm`` is given.
    cmap : str or Colormap
        Colormap for every panel.
    norm : matplotlib.colors.Normalize, optional
        Explicit normalisation. When given, it replaces ``vmin``/``vmax``.
    out_path : str or os.PathLike, optional
        Destination of the GIF; ``.gif`` is appended if missing. When None,
        nothing is rendered or written.
    use_rgb : bool
        Accepted for signature compatibility; currently unused.
    dpi : int
        Resolution of each rendered frame (default 300, as before).
    duration : float
        Passed straight to ``imageio.mimsave`` (its unit depends on the
        installed imageio version/plugin).
    """
    if out_path is None:
        # Nothing would be written, so skip the expensive high-dpi rendering.
        return
    out_path = os.fspath(out_path)
    if not out_path.endswith('gif'):
        out_path = out_path + '.gif'

    def swap_axes(x):
        # Move axis 1 to the end for >3-D input: (T, C, H, W) -> (T, H, W, C).
        x = np.asarray(x)
        if x.ndim > 3:
            return x.swapaxes(1, 2).swapaxes(2, 3)
        return x

    prev, true, pred = map(swap_axes, (prev, true, pred))
    n_prev = prev.shape[0]
    n_frames = n_prev + true.shape[0]

    # Matplotlib refuses an explicit `norm` together with vmin/vmax, and the
    # vmin/vmax defaults are never None, so pass exactly one of the two.
    color_scale = {'norm': norm} if norm is not None else {'vmin': vmin, 'vmax': vmax}

    images = []
    # Frames go to a private temp dir that is removed even if rendering
    # fails, instead of ./tmp_frame_*.png files in the working directory.
    with tempfile.TemporaryDirectory() as tmp_dir:
        for i in range(n_frames):
            if i < n_prev:
                # Context frames: both image panels show the input, so the error is zero.
                gt_frame = pred_frame = prev[i]
                abs_error = np.zeros_like(prev[i])
                mre = 0.0
            else:
                gt_frame = true[i - n_prev]
                pred_frame = pred[i - n_prev]
                # float64 so unsigned-integer frames cannot wrap around.
                abs_error = np.abs(gt_frame.astype(np.float64) - pred_frame.astype(np.float64))
                mre = _relative_rmse_percent(gt_frame, pred_frame)

            # (title, title x-offset, title colour, image, colorbar label) per panel.
            panels = (
                ('Ground Truth', 0.3, 'green', gt_frame, 'Pixel Value'),
                ('Predicted Frames', 0.2, 'red', pred_frame, 'Pixel Value'),
                ('Absolute Error', 0.2, 'blue', abs_error, 'Absolute Error'),
            )
            fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(18, 9))
            try:
                for ax, (title, title_x, title_color, image, cbar_label) in zip(axes, panels):
                    # ax.text (not plt.text) so the title belongs to this panel's axes.
                    ax.text(title_x, 1.05, title, fontsize=15, color=title_color, transform=ax.transAxes)
                    im = ax.imshow(image, cmap=cmap, **color_scale)
                    cbar = fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04, aspect=10)
                    cbar.ax.tick_params(labelsize=8)
                    cbar.set_label(cbar_label, fontsize=10)
                    ax.axis('off')
                error_ax = axes[2]
                error_ax.text(0.5, -0.1, f'Mean Relative Error: {mre:.4f}%',
                              fontsize=12, color='blue', ha='center', transform=error_ax.transAxes)
                frame_path = os.path.join(tmp_dir, f'frame_{i}.png')
                fig.savefig(frame_path, bbox_inches='tight', format='png', dpi=dpi)
            finally:
                # Close this specific figure even if drawing failed, to avoid leaking figures.
                plt.close(fig)
            images.append(imageio.imread(frame_path))

    # loop=0 makes the GIF loop forever.
    imageio.mimsave(out_path, images, duration=duration, loop=0)

In [None]:
# Render one ground-truth / prediction / error GIF per test example.
# Relies on `inputs`, `trues`, `preds` from earlier cells (same leading length).
GIF_DIR = './prediction_gif'
os.makedirs(GIF_DIR, exist_ok=True)  # imageio.mimsave does not create missing folders

for example_idx in range(len(inputs)):
    output_gif_filename = f'{GIF_DIR}/example_{example_idx}.gif'
    show_video_gif_multiple_withError(inputs[example_idx], trues[example_idx], preds[example_idx],
                                      out_path=output_gif_filename)
    print(f"GIF saved as {output_gif_filename}")

In [None]:
import numpy as np
import matplotlib.pyplot as plt

def ComputeTestError(prediction, target):
    """Compute the Mean Relative Error (MRE) between prediction and target, in percent.

    MRE = sqrt(mean((target - prediction) ** 2)) / (max(target) - min(target)) * 100

    i.e. the RMSE normalised by the value range of ``target``.

    Parameters
    ----------
    prediction, target : array_like
        Arrays of matching (broadcastable) shape. They are promoted to float64
        so integer inputs (e.g. uint8 images) cannot wrap around when
        subtracted or squared.

    Returns
    -------
    float
        The MRE in percent. If ``target`` is constant (zero range), the
        result is 0.0 for an exact match and ``inf`` otherwise, instead of
        NaN plus a RuntimeWarning.
    """
    target = np.asarray(target, dtype=np.float64)
    prediction = np.asarray(prediction, dtype=np.float64)
    rmse = np.sqrt(np.mean((target - prediction) ** 2))
    value_range = np.max(target) - np.min(target)
    if value_range == 0:
        return 0.0 if rmse == 0 else float('inf')
    return float(rmse / value_range * 100)

def TestErrorPlot(pred, true, figsize=(10, 6), dpi=200, out_path="./statistics_testdata.png"):
    """Plot the per-sample test error of predictions vs. ground truth and save it as PNG.

    Every pair ``pred[i, j]`` / ``true[i, j]`` is scored with ``ComputeTestError``
    and drawn as one point at sample ID ``i * num_steps + j + 1``. The colour
    encodes the case ``i`` (tab20, cycling after 20 cases). The opacity grows
    with the step ``j``, so later predictions appear darker. A dashed red line
    marks the mean error over all samples.

    Parameters
    ----------
    pred, true : array_like
        Arrays whose first two axes are (case, step); ``true`` must be
        indexable by the same (i, j) pairs as ``pred``.
    figsize : tuple of float
        Figure size in inches.
    dpi : int
        Resolution of the saved PNG.
    out_path : str or os.PathLike
        Where the PNG is written.

    Raises
    ------
    ValueError
        If ``pred`` has fewer than two axes or an empty case/step axis.
    """
    pred = np.asarray(pred)
    true = np.asarray(true)
    if pred.ndim < 2 or pred.shape[0] == 0 or pred.shape[1] == 0:
        raise ValueError(f'pred must have shape (cases, steps, ...) with non-empty leading axes; got {pred.shape}')
    num_cases, num_comparisons = pred.shape[:2]

    # errors[i, j] = MRE of step j of case i.
    errors = np.array(
        [[ComputeTestError(pred[i, j], true[i, j]) for j in range(num_comparisons)]
         for i in range(num_cases)],
        dtype=np.float64,
    )
    # Sample IDs run 1..N in row-major (case, step) order.
    sample_ids = np.arange(1, errors.size + 1).reshape(errors.shape)
    # Alpha grows with the step index: later predictions are more opaque (darker).
    alphas = np.arange(1, num_comparisons + 1) / num_comparisons

    # Use a dedicated figure so `figsize` applies and no stale pyplot figure is drawn on.
    fig, ax = plt.subplots(figsize=figsize)
    for i in range(num_cases):
        colors = np.tile(plt.cm.tab20(i % 20), (num_comparisons, 1))
        colors[:, 3] = alphas
        # One scatter per case instead of one Line2D per point; s=16 equals markersize=4.
        ax.scatter(sample_ids[i], errors[i], s=16, c=colors)

    avg_error = np.mean(errors)
    ax.axhline(avg_error, color='red', linestyle='--', linewidth=2, label=f'Average Error: {avg_error:.4f}%')

    ax.set_xlabel('Samples (color = case, darker = later prediction)', fontsize=12)
    ax.set_ylabel('Mean Relative Error (%)', fontsize=12)
    ax.set_title('Accuracy Statistics Plot for Test Dataset (previously unseen)', fontsize=14)
    ax.grid(True)
    ax.legend(loc='upper right', fontsize=8)

    fig.tight_layout()
    fig.savefig(out_path, dpi=dpi, format='png')

    # Flattened 0-based index, plus the (case, step) it corresponds to.
    max_idx = int(np.argmax(errors))
    max_case, max_step = np.unravel_index(max_idx, errors.shape)
    print(f'Average Error: {avg_error:.4f}%')
    print(f'Max error index: {max_idx} (case {max_case}, step {max_step})')

# Score every (case, step) of the test set and save the summary error plot.
TestErrorPlot(
    pred=preds,
    true=trues,
    figsize=(20, 8),
    dpi=300,
    out_path="./test_error_plot.png",
)
