In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import os
# Load the benchmark results stored as a pickled object inside a .npy file.
# `.item()` unwraps the 0-d object array; the result is indexed by video name below.
# NOTE(review): allow_pickle=True unpickles arbitrary objects — only load files
# produced by a trusted benchmark run.
filename = 'benchmark_results/inter_frame_results.npy'
data = np.load(filename, allow_pickle=True).item()

# Test sequences whose per-video results are averaged together.
videos = ['Beauty', 'Bosphorus', 'ShakeNDry', 'HoneyBee']
video_df = [pd.DataFrame(data[name]) for name in videos]

# Stack all videos and average every metric over rows that share the same
# index label across videos (presumably one label per codec setting — confirm
# against how the results file was written).
metric_cols = ['bpp', 'LPIPS', 'MS-SSIM', 'PSNR', 'FID']
metrics = (
    pd.concat(video_df)
    .groupby(level=0)[metric_cols]
    .mean()
)

# NOTE(review): these positional picks rely on the sorted index order that
# groupby(level=0) produces; presumably rows 0-2 are the HEVC operating points
# and row 3 is the proposed method. Selecting by label would be more robust.
hevc = metrics.iloc[:3]
ours = metrics.iloc[3]

In [None]:
# Plot each quality metric against bits per pixel: the HEVC inter-only results
# are drawn as a line and our method as a single highlighted point. One SVG is
# written per metric into `output_folder`.

# Leave False to plot the real `hevc` DataFrame and `ours` Series computed from
# the benchmark file. Set True ONLY to demo this cell without results; the
# placeholder curves below are synthetic and must never be published.
USE_PLACEHOLDER_DATA = False

if USE_PLACEHOLDER_DATA:
    placeholder_bpp = np.linspace(0.001, 0.05, 50)
    hevc_data = {
        'bpp': placeholder_bpp,
        'LPIPS': 0.3 - placeholder_bpp * 3,
        'MS-SSIM': 0.8 + placeholder_bpp * 1.5,
        'PSNR': 20 + np.sqrt(placeholder_bpp) * 25,
        'FID': 10 - np.sqrt(placeholder_bpp) * 40,
    }
    hevc = pd.DataFrame(hevc_data)

    ours_data = {
        'bpp': 0.025,
        'LPIPS': 0.18,
        'MS-SSIM': 0.91,
        'PSNR': 28.5,
        'FID': 2.5,
    }
    ours = pd.Series(ours_data)


# 1. Columns to plot: x is the rate, each y is one quality metric.
x_col = 'bpp'
y_cols = ['LPIPS', 'PSNR', 'MS-SSIM', 'FID']

# Fail early with a clear message rather than a KeyError mid-loop.
required_cols = {x_col, *y_cols}
assert required_cols <= set(hevc.columns), f"hevc is missing columns: {sorted(required_cols - set(hevc.columns))}"
assert required_cols <= set(ours.index), f"ours is missing entries: {sorted(required_cols - set(ours.index))}"

# 2. Prepare data and create the output folder.
df = hevc.sort_values(x_col)  # sorted so the baseline line is drawn left-to-right in rate
highlight_point = ours
output_folder = 'benchmark_results/gop2_plots'
os.makedirs(output_folder, exist_ok=True)  # no-op if the folder already exists

# 3. One figure per metric, each saved as its own SVG.
saved_paths = []
for y_col in y_cols:
    fig, ax = plt.subplots(figsize=(8, 6))

    # Baseline curve; markers make the individual operating points visible,
    # which matters when there are only a few of them.
    ax.plot(df[x_col], df[y_col], marker='o', label='HEVC Inter Only')

    # Our single operating point, drawn on top of the baseline (zorder).
    ax.scatter(highlight_point[x_col], highlight_point[y_col],
               color='red',
               marker='*',
               s=250,
               zorder=5,
               label='BiFCA Inter Prediction')

    ax.set_title(f'{y_col.upper()} vs {x_col.upper()}', fontsize=16)
    ax.set_xlabel(x_col.upper(), fontsize=12)
    ax.set_ylabel(y_col.upper(), fontsize=12)
    ax.grid(True, linestyle='--', alpha=0.6)
    ax.legend()

    # Save through the explicit figure handle rather than pyplot's "current
    # figure". A distinct name avoids clobbering the benchmark `filename`
    # defined in the loading cell.
    svg_name = f"{y_col}_vs_{x_col}.svg"
    full_path = os.path.join(output_folder, svg_name)
    fig.savefig(full_path, format='svg', dpi=800)
    saved_paths.append(full_path)

    # Release the figure so memory does not grow across iterations.
    plt.close(fig)

print(f"Successfully saved {len(saved_paths)} plots in the '{output_folder}' directory.")