In [1]:
import time
import os
import paramiko
from io import StringIO
import time
import numpy as np
import pandas as pd
import plotly.graph_objects as go
from IPython.display import clear_output
import datetime

In [2]:
# SSH credentials and remote CSV file path
host = 'lxslc7.ihep.ac.cn'
username = os.environ.get('IHEP_USERNAME')
password = os.environ.get('IHEP_PASSWORD')
remote_csv_file = '/hpcfs/cepc/higgs/frank/GNN/bepc2/output/summaries_0.csv'
# remote_csv_file = '/hpcfs/cepc/higgs/frank/GNN/quick_test/output/summaries_0.csv'

# Establish an SSH connection
ssh = paramiko.SSHClient()
ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
ssh.connect(host, username=username, password=password)
print(f'Successfully connected to {username}@{host}')

Successfully connected to frank@lxslc7.ihep.ac.cn


In [3]:
def update_csv():
    sftp = ssh.open_sftp()
    with sftp.open(remote_csv_file, 'r') as file:
        remote_csv_data = file.read().decode('utf-8')
        timestamp = sftp.stat(remote_csv_file).st_mtime

        dt_object = datetime.datetime.fromtimestamp(timestamp)

        # Parse and display the CSV data
        df = pd.read_csv(StringIO(remote_csv_data))
        return df, f"{dt_object.year} - {dt_object.month} - {dt_object.day}   {dt_object.hour} : {dt_object.minute}"


def update_fig(df, t):
    y_max, y_min = df.max()[['train_loss', 'valid_loss']].max(), df.min()[['train_loss', 'valid_loss']].min()

    df_new = df[['epoch', 'train_loss', 'valid_loss']]. \
        groupby('epoch').transform("mean").drop_duplicates(keep='last', subset=['train_loss'])

    fig.data = []
    fig.add_trace(
        go.Scatter(x=df.index, y=df['train_loss'], mode='lines+markers', name="itr: train", line=dict(dash='dot')))
    fig.add_trace(
        go.Scatter(x=df.index, y=df['valid_loss'], mode='lines+markers', name="itr: valid", line=dict(dash='dot')))
    fig.add_trace(
        go.Scatter(
            x=df.index,
            y=-np.log10(1 - df['valid_acc']),
            # y=df['valid_acc'],
            mode='lines+markers', name="itr: accuracy",
            line=dict(dash='dot', color="#11ADF0"), yaxis='y2'
        ))

    fig.add_trace(go.Scatter(x=df_new.index, y=df_new['train_loss'], mode='lines+markers', name="epoch: train"))
    fig.add_trace(go.Scatter(x=df_new.index, y=df_new['valid_loss'], mode='lines+markers', name="epoch: valid"))

    for lr in df.drop_duplicates(keep="first", subset=['lr'])['lr'].items():
        print(lr)
        fig.add_vline(x=lr[0], line_width=2, line_dash="dash", line_color="grey")
        fig.add_annotation(
            text=f'$\eta = {lr[1]}$',
            x=lr[0]+0.5, y=np.log10(y_min) * 1.08,  # Set the position using numeric values
            xanchor='left', yanchor='bottom',
            showarrow=False,
            font=dict(size=14, color='grey')
        )

    fig.update_layout(
        title=f"{t} --> Epoch: {df['epoch'].iloc[-1]}, Iteration: {df['itr'].iloc[-1]}",
        width=1200,
        height=700,
        paper_bgcolor='rgba(0,0,0,0)',
        plot_bgcolor='rgba(0,0,0,0)',
        xaxis=dict(
            title_text="Total iteration",
            showgrid=False,
            mirror=True,
            showline=True,
            zeroline=False,
            linewidth=2,
            linecolor='#666666', gridcolor='#d9d9d9',
            domain=[0., 0.97],
        ),
        yaxis=dict(
            title=r"Loss",
            titlefont=dict(color="#d62728"),
            tickfont=dict(color="#d62728"),
            showgrid=False,
            linecolor="#d62728", gridcolor='#d9d9d9',
            zeroline=False,
            type="log",
            range=[np.log10(y_min) * 1.1, np.log10(y_max) * 1.5],
        ),
        yaxis2=dict(
            # title=r"$\textrm{95% CI }{\Large \kappa_{2V}}\textrm{ Interval}$",
            title=r"$-log(\epsilon_{error})$",
            titlefont=dict(color="#11ADF0"),
            tickfont=dict(color="#11ADF0"),
            anchor="x",
            overlaying="y",
            side="right",
            showgrid=False,
            linecolor="#11ADF0", gridcolor='#d9d9d9',
            zeroline=False,
            # type="log",
            # range=[3.0, 3.7],
        ),
    )

In [12]:
fig = go.FigureWidget()

df, t = update_csv()
update_fig(df, t)
fig.show()
# while True:
#     df, t = update_csv()
#     update_fig(df, t)
#     time.sleep(30)
#     clear_output(wait=False)
#     fig.show()

(0, 0.001)


In [13]:
display(df)

Unnamed: 0,lr,train_loss,l1,l2,train_batches,itr,epoch,valid_loss,valid_acc,valid_batches,valid_sum_total,valid_dp_mean,valid_dp_std
0,0.001,1.428173,3731.38243,27.572549,20,0,0,1.098026,0.955159,9,1406248,-1.012211,0.11087
1,0.001,0.948382,3604.087541,27.26695,20,1,0,0.904707,0.104004,9,1400519,-1.015072,0.10997
2,0.001,0.905296,3506.022962,27.030411,20,2,0,0.900899,0.249969,9,1412958,-1.015553,0.112022
3,0.001,0.901154,3433.910396,26.84352,20,3,0,0.90033,0.173045,9,1402228,-1.015366,0.111032
4,0.001,0.898263,3381.234774,26.691412,20,4,0,0.890242,0.372573,9,1398106,-1.017955,0.109635
5,0.001,0.896445,3341.795349,26.562989,20,5,0,0.877623,0.399793,9,1405196,-1.016072,0.11323
6,0.001,0.864564,3315.627436,26.464575,20,6,0,0.851184,0.327109,9,1419419,-1.018943,0.110502
7,0.001,0.833757,3307.726991,26.415708,20,7,0,0.80826,0.445818,9,1408398,-1.023152,0.109927
8,0.001,0.773811,3327.917135,26.436433,20,0,1,0.707763,0.699749,9,1406248,-1.024471,0.112818
9,0.001,0.639016,3390.835811,26.534117,20,1,1,0.586167,0.770731,9,1400519,-1.01043,0.111965
