In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np

# Plot `reduce_scatter` runtime against data size, one subplot per group size
# (`num_gpus`) and one colour per `stride`.  Reads the columns `operation`,
# `data_size_mb`, `duration_sec`, `num_gpus` and `stride` from the CSV.

# Load the data
file_path = './output_strided_3.csv'  # Replace with your file path
data = pd.read_csv(file_path)

# Filter for `reduce_scatter` operation.  Take an explicit copy so that the
# derived columns added below are written to an independent frame instead of
# a selection of `data` (which triggers pandas' SettingWithCopyWarning).
reduce_scatter_data = data[data['operation'] == 'reduce_scatter'].copy()
if reduce_scatter_data.empty:
    # Fail early with a clear message; otherwise plt.subplots(1, 0) raises an
    # opaque ValueError about the number of columns.
    raise ValueError(f"No 'reduce_scatter' rows found in {file_path}")

# Log transformations
reduce_scatter_data['log2_data_size'] = np.log2(reduce_scatter_data['data_size_mb'])
reduce_scatter_data['log10_duration'] = np.log10(reduce_scatter_data['duration_sec'])

# Define colors based on `stride` values from `num_gpus=2` to ensure consistency
reference_data = reduce_scatter_data[reduce_scatter_data['num_gpus'] == 2]
if reference_data.empty:
    # No `num_gpus=2` measurements: fall back to the strides seen in any row
    # so the figure is not silently empty.
    reference_data = reduce_scatter_data
stride_values = reference_data['stride'].unique()
stride_palette = plt.cm.tab10(np.linspace(0, 1, len(stride_values)))
stride_colors = dict(zip(stride_values, stride_palette))

# Set up subplots with shared y-axis
num_gpus_list = reduce_scatter_data['num_gpus'].unique()
# squeeze=False always returns a 2-D array of Axes, so iterating below also
# works when there is only a single group size (a bare Axes is not iterable).
fig, axes = plt.subplots(
    1, len(num_gpus_list), figsize=(9, 3), sharey=True, dpi=500, squeeze=False
)

for ax, num_gpus in zip(axes[0], num_gpus_list):
    gpu_data = reduce_scatter_data[reduce_scatter_data['num_gpus'] == num_gpus]

    for stride, color in stride_colors.items():
        stride_data = gpu_data[gpu_data['stride'] == stride]

        # Scatter plot for samples with matching color
        ax.scatter(
            stride_data['log2_data_size'],
            stride_data['log10_duration'],
            alpha=0.6, color=color, s=10,
            marker='x'
        )

        # Average line plot with matching color (mean of log10 duration at
        # each distinct log2 data size; groupby returns the sizes sorted)
        avg_duration = stride_data.groupby('log2_data_size')['log10_duration'].mean()
        ax.plot(avg_duration.index, avg_duration.values, color=color, linewidth=2)

    # Plot settings for each subplot
    ax.set_title(f'Group Size: {num_gpus}')
    ax.set_xlabel('log2(Data Size in MB)')
    ax.grid(True)

# Shared y-axis label
fig.text(0.00, 0.5, 'log10(Runtime in seconds)', va='center', rotation='vertical')

# Single legend for all subplots, built from proxy lines so each stride
# appears exactly once regardless of how many subplots draw it
handles = [
    plt.Line2D([0], [0], color=color, lw=2, label=f'Stride {stride}')
    for stride, color in stride_colors.items()
]
fig.legend(handles=handles, bbox_to_anchor=(1.00, 0.5), loc='center left', title="Stride")

plt.tight_layout()
plt.show()
