In [None]:
# imports
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import wandb
import copy
import os, sys, pathlib
import yaml
import json
import requests

from deq2ff.plotting.style import (
    set_seaborn_style,
    PALETTE,
    entity,
    projectmd,
    projectoc,
    plotfolder,
    human_labels,
    set_style_after,
    myrc,
)
from deq2ff.plotting.dashboard import (
    get_runs_from_wandb,
    filter_best_runs,
    mark_sota,
    add_best_run,
    preprocess_df,
    print_table_acc_time,
)

# Values presumably meant to be treated as "missing" when cleaning run data.
nans = [
    "NaN",  # string placeholder
    pd.NA,  # pandas missing-value scalar
    None,
    float("inf"),
    np.nan,
]

# OC20

In [None]:
# Config: select the W&B project and the metric columns plotted below.
project = projectoc

# Metric column names in the runs dataframe (the "summary." prefix presumably
# comes from flattening the W&B run summary).
if project == projectmd:
    error_metric = "summary.test_f_mae"
    energy_metric = "summary.test_e_mae"
elif project == projectoc:
    # Other logged splits presumably include summary.train/forces_mae.
    error_metric = "summary.val/forces_mae"
    energy_metric = "summary.val/energy_mae"
else:
    # Fail fast: otherwise error_metric stays undefined and the plotting
    # cells fail later with a confusing NameError.
    raise ValueError(
        f"Unknown project {project!r}; expected {projectmd!r} or {projectoc!r}"
    )

In [None]:
# Plot the error metric (from the config cell) against model depth.

# W&B tag selecting the depth sweep, plus a name passed on as `fname`.
if project == projectoc:
    tag = "depthoc"
    fname = "depthocv1"
else:
    tag = "depthmd"
    fname = "depthmdv1"

# Only finished runs are fetched.
dfdepth = get_runs_from_wandb(
    project=project,
    download_data=True,
    filters={
        "tags": tag,
        "state": "finished",
    },
    # NOTE(review): this resolves to e.g. "depthdepthocv1" (the "depth"
    # prefix is duplicated). Kept unchanged in case it names an existing
    # local cache — confirm in get_runs_from_wandb before renaming.
    fname="depth" + fname,
)
# Sanity check against the expected sweep size (8 + 2 runs).
print(f"Found {len(dfdepth)} of 8+2 runs")

dfdepth = preprocess_df(df=dfdepth, project=project, error_metric=error_metric)

y = error_metric
x = "config.model.num_layers"
hue = "Class"
data = dfdepth.copy()

fig, ax = plt.subplots()
# seaborn draws the hue legend itself; an extra plt.legend() call would only
# re-create it and discard seaborn's legend formatting.
sns.scatterplot(x=x, y=y, hue=hue, data=data, ax=ax)
ax.set_xlabel(human_labels(x))
ax.set_ylabel(human_labels(y))
ax.set_title(f"{human_labels(y)} vs. {human_labels(x)}")
# Horizontal grid lines only.
ax.grid(axis="y")

# TODO: make a separate plot for each target

# save
# fig.savefig(f"{plotfolder}/{fname}.png")
plt.show()

In [None]:
# [c for c in dfdepth.keys() if ("summary" in c) and ("gradients" not in c)]

In [None]:
# Accuracy over model size: the error metric against the parameter count
# ("summary.ModelParameters") for the depth-sweep runs loaded above.

y = error_metric
x = "summary.ModelParameters"
hue = "Class"
data = dfdepth.copy()

fig, ax = plt.subplots()
# seaborn draws the hue legend itself; no extra plt.legend() call needed.
sns.scatterplot(x=x, y=y, hue=hue, data=data, ax=ax)
ax.set_xlabel(human_labels(x))
ax.set_ylabel(human_labels(y))
ax.set_title(f"{human_labels(y)} vs. {human_labels(x)}")
# Horizontal grid lines only.
ax.grid(axis="y")

# TODO: make a separate plot for each target

# save
# fig.savefig(f"{plotfolder}/{fname}_params.png")
plt.show()