In [1]:
import numpy as np
import pandas as pd
import plotly.graph_objects as go
import statsmodels.api as sm
from statsmodels.sandbox.regression.predstd import wls_prediction_std
from statsmodels.tsa.stattools import grangercausalitytests

In [4]:
df = pd.read_csv("../output/loss.csv")
df.head()

Unnamed: 0,run_id,epoch,mesa,run_name,step,base
0,xotfsc10,0,12.497458,train-model,0,10.910637
1,xotfsc10,0,8.819035,train-model,1,7.821729
2,xotfsc10,0,10.557375,train-model,2,7.866491
3,xotfsc10,0,8.533783,train-model,3,8.117002
4,xotfsc10,0,12.024865,train-model,4,10.225608


In [5]:
loss1, loss2 = "base", "mesa"

In [6]:
# df[loss1] = df[loss1].apply(lambda x: np.exp(x))
df[loss2] = df[loss2].apply(lambda x: -x)

In [23]:
# Perform linear regression
X = sm.add_constant(df[loss1])  # Adds a constant term to the predictor
model = sm.OLS(df[loss2], X).fit()
predictions = model.get_prediction(X)
prediction_summary = predictions.summary_frame(alpha=0.05)  # 95% confidence interval

# Sort the data by loss1 for plotting
df_sorted = df.sort_values(by=loss1)
prediction_summary_sorted = prediction_summary.loc[df_sorted.index]

# Get regression coefficients and standard errors
coef = model.params
std_err = model.bse
intercept, slope = coef[0], coef[1]
intercept_err = std_err[0]
slope_err = std_err[1]

# Format the regression equation
regression_equation = f"y = ({slope:.2f}±{slope_err:.2f})x + ({intercept:.2f}±{intercept_err:.2f})"
print(regression_equation)

y = (-0.92±0.01)x + (-2.90±0.15)



Series.__getitem__ treating keys as positions is deprecated. In a future version, integer keys will always be treated as labels (consistent with DataFrame behavior). To access a value by position, use `ser.iloc[pos]`


Series.__getitem__ treating keys as positions is deprecated. In a future version, integer keys will always be treated as labels (consistent with DataFrame behavior). To access a value by position, use `ser.iloc[pos]`


Series.__getitem__ treating keys as positions is deprecated. In a future version, integer keys will always be treated as labels (consistent with DataFrame behavior). To access a value by position, use `ser.iloc[pos]`



In [31]:
# Create scatter plot
fig = go.Figure()
# Add scatter plot for actual data
fig.add_trace(go.Scatter(
    x=df[loss1],
    y=df[loss2],
    mode='markers',
    marker=dict(color='blue'),
    name='Batch'
))
# Add regression line
fig.add_trace(go.Scatter(
    x=df_sorted[loss1],
    y=prediction_summary_sorted['mean'],
    mode='lines',
    line=dict(color='black'),
    name=f'Regression'
))
# Add confidence interval
fig.add_trace(go.Scatter(
    x=np.concatenate([df_sorted[loss1], df_sorted[loss1][::-1]]),
    y=np.concatenate([prediction_summary_sorted['mean_ci_upper'], prediction_summary_sorted['mean_ci_lower'][::-1]]),
    fill='toself',
    fillcolor='rgba(169,169,169,0.5)',  # Gray color with transparency
    line=dict(color='rgba(255,255,255,0)'),
    hoverinfo="skip",
    showlegend=False,
    name='Confidence Interval'
))
# Update layout for fonts and axis scaling
fig.update_layout(
    xaxis_title=f"log({loss1})",
    yaxis_title=loss2,
    font=dict(
        family="Arial, sans-serif",
        size=14,
        color="black"
    ),
    xaxis=dict(
        range=[0, 30],  # Set the range for x-axis
        scaleanchor="y",
        scaleratio=1,
        constrain="domain",
        showgrid=False,
        zeroline=True,
        zerolinewidth=1,
        zerolinecolor='black',
        showline=True,
        linecolor='black',
        linewidth=1,
        mirror=True,
        tickmode='linear',
        dtick=5,
        gridcolor='lightgray',
        gridwidth=1,
        tick0=0,
        showticklabels=True,
        tickfont=dict(size=12),
        ticks='outside',
        ticklen=5,
        tickwidth=1,
        tickcolor='black',
        showspikes=True,
        spikethickness=1,
        spikedash='dot'
    ),
    yaxis=dict(
        range=[-35, 0],  # Set the range for y-axis
        scaleanchor="x",
        scaleratio=1,
        constrain="domain",
        showgrid=False,
        zeroline=True,
        zerolinewidth=1,
        zerolinecolor='black',
        showline=True,
        linecolor='black',
        linewidth=1,
        mirror=True,
        tickmode='linear',
        dtick=5,
        gridcolor='lightgray',
        gridwidth=1,
        tick0=0,
        showticklabels=True,
        tickfont=dict(size=12),
        ticks='outside',
        ticklen=5,
        tickwidth=1,
        tickcolor='black',
        showspikes=True,
        spikethickness=1,
        spikedash='dot'
    ),
    autosize=False,
    width=600,  # Set the width of the plot
    height=600,  # Set the height of the plot
    plot_bgcolor='rgba(0,0,0,0)',  # Transparent background
    paper_bgcolor='rgba(0,0,0,0)'  # Transparent background
)
# Show the plot
fig.show()

In [7]:
# Granger causality test
data = df[[loss2, loss1]].values
grangercausalitytests(data, 4)


Granger Causality
number of lags (no zero) 1
ssr based F test:         F=293.3069, p=0.0000  , df_denom=2612, df_num=1
ssr based chi2 test:   chi2=293.6438, p=0.0000  , df=1
likelihood ratio test: chi2=278.2957, p=0.0000  , df=1
parameter F test:         F=293.3069, p=0.0000  , df_denom=2612, df_num=1

Granger Causality
number of lags (no zero) 2
ssr based F test:         F=70.1831 , p=0.0000  , df_denom=2609, df_num=2
ssr based chi2 test:   chi2=140.6352, p=0.0000  , df=2
likelihood ratio test: chi2=136.9825, p=0.0000  , df=2
parameter F test:         F=70.1831 , p=0.0000  , df_denom=2609, df_num=2

Granger Causality
number of lags (no zero) 3
ssr based F test:         F=40.8096 , p=0.0000  , df_denom=2606, df_num=3
ssr based chi2 test:   chi2=122.7577, p=0.0000  , df=3
likelihood ratio test: chi2=119.9614, p=0.0000  , df=3
parameter F test:         F=40.8096 , p=0.0000  , df_denom=2606, df_num=3

Granger Causality
number of lags (no zero) 4
ssr based F test:         F=38.7479 , p=0.

{1: ({'ssr_ftest': (293.30694749499116, 2.128440126144512e-62, 2612.0, 1),
   'ssr_chi2test': (293.64382377465614, 7.99128964853475e-66, 1),
   'lrtest': (278.29571377648244, 1.7660118509201035e-62, 1),
   'params_ftest': (293.3069474949914, 2.1284401261441486e-62, 2612.0, 1.0)},
  [<statsmodels.regression.linear_model.RegressionResultsWrapper at 0x7f0b2cb78b10>,
   <statsmodels.regression.linear_model.RegressionResultsWrapper at 0x7f0b2cb7ead0>,
   array([[0., 1., 0.]])]),
 2: ({'ssr_ftest': (70.18309079153398, 2.0489308003843166e-30, 2609.0, 2),
   'ssr_chi2test': (140.63518538065915, 2.893726715920439e-31, 2),
   'lrtest': (136.9824863847134, 1.797343241201568e-30, 2),
   'params_ftest': (70.18309079153397, 2.0489308003843166e-30, 2609.0, 2.0)},
  [<statsmodels.regression.linear_model.RegressionResultsWrapper at 0x7f0b2cb76f90>,
   <statsmodels.regression.linear_model.RegressionResultsWrapper at 0x7f0b2cbb1690>,
   array([[0., 0., 1., 0., 0.],
          [0., 0., 0., 1., 0.]])]),
 3:

In [32]:
# Save the plot as a PDF
fig.write_image("../output/loss.pdf", format="pdf")