In [32]:
import glob
import json
import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from collections import defaultdict
from pprint import pprint
import pandas as pd

In [49]:
# Define the path pattern to collect all result.json files
path_pattern = "/home/users/MTrappett/ray_results/run_continual_learning_2024-05-03_10-56-44/*/result.json"

# Collect data from files
data = {}
for file_path in glob.glob(path_pattern):
    print(file_path)
    with open(file_path, 'r') as file:
        for line in file:
            result = json.loads(line)
            config = result['config']
            nb1 = config['n_b_1']
            nb2 = config['n_b_2']
            key = (nb1, nb2)
            
            if key not in data:
                data[key] = {'remembering': [], 'forward_transfer': []}
            
            data[key]['remembering'].append(result['remembering'])
            data[key]['forward_transfer'].append(result['forward_transfer'])
            break  # only need the first line

# Calculate mean and standard deviation for each metric for each combination of n_b_1 and n_b_2
records = []
for (nb1, nb2), metrics in data.items():
    for metric, values in metrics.items():
        mean_val = np.mean(values)
        std_val = np.std(values)
        records.append({
            'n_b_1': nb1,
            'n_b_2': nb2,
            'metric': metric,
            'mean_value': mean_val,
            'std_value': std_val
        })

# Create a DataFrame from the aggregated records
df = pd.DataFrame(records)


/home/users/MTrappett/ray_results/run_continual_learning_2024-05-03_10-56-44/run_continual_learning_80ea8_00003_3_n_b_2=1000_2024-05-03_10-56-44/result.json
/home/users/MTrappett/ray_results/run_continual_learning_2024-05-03_10-56-44/run_continual_learning_80ea8_00000_0_n_b_2=2_2024-05-03_10-56-44/result.json
/home/users/MTrappett/ray_results/run_continual_learning_2024-05-03_10-56-44/run_continual_learning_80ea8_00001_1_n_b_2=10_2024-05-03_10-56-44/result.json
/home/users/MTrappett/ray_results/run_continual_learning_2024-05-03_10-56-44/run_continual_learning_80ea8_00002_2_n_b_2=500_2024-05-03_10-56-44/result.json


In [50]:
print(data)

{(14, 1000): {'remembering': [0.9304434161967742], 'forward_transfer': [0.008429926238145424]}, (14, 2): {'remembering': [0.916680814139928], 'forward_transfer': [-0.050718212449015775]}, (14, 10): {'remembering': [0.8047105189077601], 'forward_transfer': [-0.16929228308676533]}, (14, 500): {'remembering': [0.8549353811575201], 'forward_transfer': [-0.06997342781222321]}}


In [51]:
df.head()

Unnamed: 0,n_b_1,n_b_2,metric,mean_value,std_value
0,14,1000,remembering,0.930443,0.0
1,14,1000,forward_transfer,0.00843,0.0
2,14,2,remembering,0.916681,0.0
3,14,2,forward_transfer,-0.050718,0.0
4,14,10,remembering,0.804711,0.0


In [56]:
def create_heatmap(data):
    
    # Prepare the data for 'remembering' and 'forward_transfer'
    remembering_df = data.loc[data['metric'] == 'remembering', ['n_b_1', 'n_b_2', 'mean_value']]
    forward_transfer_df = data.loc[data['metric'] == 'forward_transfer', ['n_b_1', 'n_b_2', 'mean_value']]
    
    # Create heatmaps
    def create_plot(metric_df, metric_name):
        # Pivot data for heatmap
        z = metric_df.pivot_table(index='n_b_1', columns='n_b_2', values='mean_value', aggfunc=np.mean)

        # Create figure
        fig = go.Figure(data=go.Heatmap(
            x=z.columns.astype(str),
            y=z.index.astype(str),
            z=z.values,
            colorbar=dict(title=metric_name),
            colorscale='Viridis' if metric_name == 'Remembering' else 'Cividis'
        ))

        # Update layout
        fig.update_layout(
            title=f"{metric_name} Heatmap",
            xaxis_title="n_b_2",
            yaxis_title="n_b_1",
            height=400,
            width=500
        )
        
        fig.show()

    # Generate plots
    create_plot(remembering_df, 'Remembering')
    create_plot(forward_transfer_df, 'Forward Transfer')
    
create_heatmap(df)

In [59]:
def create_plot(df, nb_var, title):
    fig = go.Figure()
    df = df.sort_values(by=nb_var)

    # Filter the DataFrame for remembering and forward_transfer
    remembering_df = df[df['metric'] == 'remembering']
    forward_transfer_df = df[df['metric'] == 'forward_transfer']

    # Add the first y-axis (left) for remembering
    fig.add_trace(go.Scatter(
        x=remembering_df[nb_var],
        y=remembering_df['mean_value'],
        error_y=dict(
            type='data',
            array=remembering_df['std_value'],
            visible=True),
        name='Remembering',
        marker_color='blue'
    ))

    # Add the second y-axis (right) for forward_transfer
    fig.add_trace(go.Scatter(
        x=forward_transfer_df[nb_var],
        y=forward_transfer_df['mean_value'],
        error_y=dict(
            type='data',
            array=forward_transfer_df['std_value'],
            visible=True),
        name='Forward Transfer',
        marker_color='red',
        yaxis='y2'
    ))

    # Layout adjustments
    fig.update_layout(
        title=title,
        xaxis_title=nb_var,
        yaxis=dict(
            title='Remembering',
            titlefont=dict(color='blue'),
            tickfont=dict(color='blue')
        ),
        yaxis2=dict(
            title='Forward Transfer',
            titlefont=dict(color='red'),
            tickfont=dict(color='red'),
            overlaying='y',
            side='right'
        ),
        width=800,  # Optional: Adjust the width
        height=600   # Optional: Adjust the height
    )
    fig.update_xaxes(type='log')  # Set x-axis to log scale

    fig.show()

# Create and display the plots
create_plot(df, 'n_b_2', 'Performance Analysis by n_b_2')


In [60]:
create_plot(df, 'n_b_1', 'Performance Analysis by n_b_1')