In [1]:
import os
import json

import pandas as pd
import numpy as np
import plotly.graph_objects as go

## Строим графики полученных результатов

In [2]:
def load_json(path: str) -> pd.DataFrame:
    """Load one hyper-parameter-search log into a DataFrame.

    The file is expected to hold a JSON list of per-trial records (dicts).
    Only the ``trial``, ``mse_val``, ``mse_test`` and ``hypothesis`` fields
    are kept; a missing field becomes ``None``. When ``mse_test`` is absent
    from a record, its ``mse_val`` is used instead.

    Parameters
    ----------
    path : str
        Path to the JSON log file.

    Returns
    -------
    pd.DataFrame
        One row per record with columns ``trial``, ``mse_val``, ``mse_test``,
        ``hypothesis`` (in that order).

    Raises
    ------
    FileNotFoundError
        If ``path`` does not exist.
    json.JSONDecodeError
        If the file content is not valid JSON.
    """
    try:
        # utf-8-sig reads plain UTF-8 and also strips a leading BOM, which
        # json.load would otherwise reject.
        with open(path, "r", encoding="utf-8-sig") as f:
            records = json.load(f)
    except UnicodeDecodeError:
        # Not UTF-8: retry with the platform default encoding. Any other
        # error (missing file, malformed JSON) propagates unchanged.
        with open(path, "r") as f:
            records = json.load(f)

    result = {
        "trial": [rec.get("trial") for rec in records],
        "mse_val": [rec.get("mse_val") for rec in records],
        # Fall back to the validation score when no test score was logged.
        "mse_test": [rec.get("mse_test", rec.get("mse_val")) for rec in records],
        "hypothesis": [rec.get("hypothesis") for rec in records],
    }
    return pd.DataFrame(result)

In [3]:
# Root directory of the LLM-driven search logs. Built with os.path.join so the
# path is portable; a backslash literal here would be Windows-only and "\l" is
# an invalid escape sequence (SyntaxWarning on Python 3.12+).
llm_nas_dir = os.path.join("experiments", "llm_opt")

In [4]:
def _result_paths(dataset: str, grid_12_name: str, grid_3_name: str, optuna_name: str) -> dict:
    """Build the ``{pred_len: {experiment: path}}`` mapping for one dataset.

    Each ``*_name`` argument is a file-name template containing ``{pred_len}``.
    LLM logs live under ``experiments/llm_opt/<dataset>`` and Optuna logs under
    ``experiments/optuna_results/<dataset>``. Paths are joined with
    ``os.path.join`` so the notebook runs on any OS. Hard-coded Windows
    backslashes were platform-specific and formed invalid escape sequences
    (``\\l``, ``\\e``, ``\\p`` ...).
    """
    llm_dir = os.path.join("experiments", "llm_opt", dataset)
    optuna_dir = os.path.join("experiments", "optuna_results", dataset)
    return {
        pred_len: {
            "llm_grid_12": os.path.join(llm_dir, grid_12_name.format(pred_len=pred_len)),
            "llm_grid_3": os.path.join(llm_dir, grid_3_name.format(pred_len=pred_len)),
            "optuna": os.path.join(optuna_dir, optuna_name.format(pred_len=pred_len)),
        }
        # Forecast horizons evaluated in every experiment.
        for pred_len in ("24", "48", "168")
    }


etth_1_path = _result_paths(
    "etth1",
    grid_12_name="pred_len={pred_len};th_b=8192;hypot.json",
    grid_3_name="ETTh1-grid3_pred_len={pred_len};th_b=8192;hypot.json",
    optuna_name="pred_len={pred_len};OPTUNA.json",
)


etth_2_path = _result_paths(
    "etth2",
    grid_12_name="ETTh2-pred_len={pred_len};th_b=8192;hypot.json",
    grid_3_name="ETTh2-grid3_pred_len={pred_len};th_b=8192;hypot.json",
    optuna_name="ETTh2;pred_len={pred_len};OPTUNA.json",
)

In [5]:
def _load_experiments(paths_by_pred_len: dict) -> dict:
    """Read every result file of one dataset with ``load_json``.

    Turns ``{pred_len: {experiment: path}}`` into
    ``{pred_len: {experiment: DataFrame}}``, preserving key order.
    """
    loaded = {}
    for horizon, experiments in paths_by_pred_len.items():
        loaded[horizon] = {
            name: load_json(file_path) for name, file_path in experiments.items()
        }
    return loaded


data_etth_1 = _load_experiments(etth_1_path)

data_etth_2 = _load_experiments(etth_2_path)


### График ошибки `MSE` для датасета `ETTH1`

In [50]:
def plot_graph(data: dict, title: str, dtick: float = 0.05):
    """Plot the MSE-per-trial curve of every experiment on one Plotly figure.

    Parameters
    ----------
    data : dict
        Mapping ``experiment name -> DataFrame``; each frame must have
        ``trial`` and ``mse_test`` columns (as produced by ``load_json``).
    title : str
        Figure title.
    dtick : float, default 0.05
        Spacing between y-axis ticks.

    Returns
    -------
    plotly.graph_objects.Figure
        One line+marker trace per experiment, plus a dashed green horizontal
        line at the lowest ``mse_test`` across all experiments (omitted when no
        experiment has a non-missing minimum).
    """
    fig = go.Figure()
    best_per_run = []
    for name, frame in data.items():
        best_per_run.append(frame["mse_test"].min())
        # Optuna runs get their trial numbers shifted by +1.
        # NOTE(review): presumably Optuna counts trials from 0 and the LLM
        # logs from 1, so this aligns the x-axes — confirm.
        x = frame["trial"] + 1 if "optuna" in name else frame["trial"]
        fig.add_trace(
            go.Scatter(x=x, y=frame["mse_test"], mode="lines+markers", name=name)
        )

    # Global minimum MSE as a horizontal reference line. Missing (NaN/None)
    # per-run minima are dropped: Python's min() over NaN is order-dependent.
    finite_minima = [m for m in best_per_run if pd.notna(m)]
    if finite_minima:
        best_mse = min(finite_minima)
        fig.add_hline(
            y=best_mse,
            line_dash="dash",
            line_color="green",
            line_width=3,
            annotation_text=f"Минимальное значение MSE = {best_mse:.3f}",
            annotation_position="top right",
        )

    fig.update_layout(
        title=title,
        xaxis_title="Trial",
        yaxis_title="MSE",
        yaxis=dict(tickmode="linear", dtick=dtick),
        legend_title="Название эксперимента",
        hovermode="x unified",
        height=600,  # pixels
    )
    return fig

In [52]:
# One interactive figure per forecast horizon for the ETTh1 dataset.
for pred_len, runs_by_name in data_etth_1.items():
    fig = plot_graph(
        data=runs_by_name,
        title=f"График MSE при поиске гиперпараметров для ETTH1 и pred_len: {pred_len}",
    )
    fig.show()

### График ошибки `MSE` для датасета `ETTH2`

In [53]:
# One interactive figure per forecast horizon for the ETTh2 dataset.
for pred_len, runs_by_name in data_etth_2.items():
    fig = plot_graph(
        data=runs_by_name,
        # Title must name ETTH2: this cell plots data_etth_2.
        title=f"График MSE при поиске гиперпараметров для ETTH2 и pred_len: {pred_len}",
    )
    fig.show()