In [1]:
# ssh
import os
import paramiko
from matplotlib import pyplot
from tqdm import tqdm
import stat

# torch
import torch
# plot
import pandas as pd
import seaborn as sns

# local dependencies
from utility.Control import load_config
from utility.DataLoader import get_data_loaders
from utility.EverythingNeeded import build_model, get_item_from_dataloader, convert_batch_to_df

In [None]:
### Only for training remotely
# This cell opens an SSH + SFTP session to the remote cluster; the names `ssh`
# and `sftp_client` created here are used by the download code further down.

# SSH host and credentials. The credentials are read from the environment so
# they are not stored in the notebook; os.environ.get returns None when a
# variable is not set.
host = 'lxslc7.ihep.ac.cn'
username = os.environ.get('IHEP_USERNAME')
password = os.environ.get('IHEP_PASSWORD')

# Establish an SSH connection
ssh = paramiko.SSHClient()
# NOTE(review): AutoAddPolicy accepts host keys that are not already known,
# which gives no protection against a spoofed host — consider loading the
# system host keys and using a stricter policy.
ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
ssh.connect(host, username=username, password=password)
print(f'Successfully connected to {username}@{host}')

# Open an SFTP channel on top of the SSH connection; it is used to download
# files to the local machine.
sftp_client = ssh.open_sftp()
print(f'Open sftp to {username}@{host}')


def download(remote_path, local_path, sftp=None):
    """Recursively copy a remote file or directory tree to the local machine over SFTP.

    Parameters
    ----------
    remote_path : str
        Path of the file or directory on the remote host ('/'-separated).
    local_path : str
        Destination path on the local machine. Missing directories are created.
    sftp : optional
        Object providing ``stat``, ``listdir_attr`` and ``get`` (for example a
        ``paramiko.SFTPClient``). When omitted, the notebook-level
        ``sftp_client`` is used, which is the original behaviour.

    Raises
    ------
    Whatever the SFTP client raises when ``remote_path`` cannot be stat-ed,
    listed or fetched.
    """
    # Remote paths always use '/' regardless of the local OS, so they must not
    # be joined with os.path (which would use '\\' on Windows).
    import posixpath

    client = sftp_client if sftp is None else sftp

    # Get the file or folder attributes
    remote_attr = client.stat(remote_path)

    if stat.S_ISDIR(remote_attr.st_mode):
        # Directory: mirror it locally, then recurse into every entry.
        os.makedirs(local_path, exist_ok=True)
        for entry in client.listdir_attr(remote_path):
            download(
                posixpath.join(remote_path, entry.filename),
                os.path.join(local_path, entry.filename),
                sftp=client,
            )
        return

    # Regular file: make sure the parent directory exists. dirname() is '' for
    # a bare file name, and os.makedirs('') would raise.
    parent_dir = os.path.dirname(local_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    with tqdm(
            total=remote_attr.st_size, unit='B', unit_scale=True, desc=posixpath.basename(remote_path)
    ) as progress_bar:
        # The callback receives the cumulative number of bytes transferred,
        # while tqdm.update expects an increment, hence the subtraction.
        client.get(
            remote_path, local_path,
            callback=lambda current, _total: progress_bar.update(current - progress_bar.n)
        )


# Define the remote and local paths

remote_project_base = '/hpcfs/cepc/higgs/frank/GNN/bepc2'
local_base = '/Users/avencast/PycharmProjects/trkgnn/workspace/bepc2'

# Artefacts to fetch, relative to both base directories so that the remote
# and local layouts cannot drift apart.
relative_paths = [
    'mpnn.yaml',
    'output/summaries_0.csv',
    'output/model.checkpoints/model_checkpoint_063.pth.tar',
]

# (remote, local) pairs consumed by the download loop below
paths = [
    (f'{remote_project_base}/{relative}', f'{local_base}/{relative}')
    for relative in relative_paths
]

try:
    # Download each file and folder
    for remote_path, local_path in paths:
        download(remote_path, local_path)
finally:
    # Close the SFTP and SSH clients even when a download fails, so the
    # remote session is not left open.
    sftp_client.close()
    ssh.close()

In [2]:
# Local locations of the training artefacts used by the rest of the notebook.
# `local_base` is assigned here again, presumably so this cell works when the
# optional remote-download cell is skipped.
# NOTE(review): this is an absolute, machine-specific path; consider making
# it configurable.
local_base = '/Users/avencast/PycharmProjects/trkgnn/workspace/bepc2'
config_path = f'{local_base}/mpnn.yaml'  # passed to load_config
summary = f'{local_base}/output/summaries_0.csv'  # training summary CSV path
model_path = f'{local_base}/output/model.checkpoints/model_checkpoint_063.pth.tar'  # checkpoint passed to torch.load

In [3]:
# load config
load_config(config_path)
# load model
model = build_model('cpu', False)
# NOTE(review): torch.load unpickles the file, so only load checkpoints that
# come from a trusted source.
tar = torch.load(model_path, map_location=torch.device('cpu'))
# This notebook only runs inference: switch the model to evaluation mode so
# that layers such as dropout / batch-norm (if the model has any) behave
# deterministically.
model.eval()
# Kept as the last expression so the cell still reports the key-matching result.
model.load_state_dict(tar['model'])

Parameters: 55233


<All keys matched successfully>

In [97]:
# load data
data_loader = get_data_loaders(
    input_dir=r'/Users/avencast/CLionProjects/darkshine-simulation/workspace/Tracker_GNN.root',
    chunk_size="500 MB",
    batch_size=1
)
loader, _ = next(data_loader)

# pick the batch to visualise (index 2 of the first chunk's loader)
batch_index = 2
batch = get_item_from_dataloader(loader, batch_index)

# send it to model; gradients are not needed for visualisation, so skip
# building the autograd graph
with torch.no_grad():
    batch_output = torch.sigmoid(model(batch))

df = convert_batch_to_df(batch)
df['predict'] = pd.DataFrame(batch_output.detach().cpu().numpy(), columns=['predict'])

# one row per edge: the edge columns, its truth label and the predicted score
df_edge_value = pd.concat([df['edge'], df['y'], df['predict']], axis=1)
# attach the coordinates of the start node (looked up by node index) ...
merged_df = df_edge_value.merge(df['node'], left_on='start', right_index=True)

# ... and of the end node; overlapping node columns get _start / _end suffixes
df_edge = merged_df.merge(df['node'], left_on='end', right_index=True, suffixes=('_start', '_end'))
df_node = df['node']


def select_df_with_cut(df_in, cut_value):
    """Label every edge against a score threshold and drop the uninteresting ones.

    Parameters
    ----------
    df_in : pandas.DataFrame
        Must contain the columns ``truth`` and ``predict``.
    cut_value : float
        Edges with ``predict >= cut_value`` count as selected.

    Returns
    -------
    pandas.DataFrame
        The rows of ``df_in`` whose category is not ``'ignore'``, including
        the ``category`` column.

    Notes
    -----
    Categories are assigned as follows:

    * ``'Correct'``        – ``truth > 0.5`` and ``predict >= cut_value``
    * ``'True Negative'``  – ``truth > 0.5`` and ``predict < cut_value``
    * ``'False Positive'`` – ``truth < 0.5`` and ``predict >= cut_value``
    * ``'ignore'``         – everything else (dropped from the result)

    NOTE(review): the ``'True Negative'`` label marks true edges that fall
    below the cut, i.e. missed edges. The string is kept unchanged because
    ``color_dict`` and ``line_dict`` are keyed on it.

    Side effect: the ``category`` column is also written to ``df_in`` itself,
    as before; existing callers read it from the frame they passed in.
    """
    is_true_edge = df_in['truth'] > 0.5
    is_fake_edge = df_in['truth'] < 0.5
    # Both comparisons are spelled out (instead of negating one) so that NaN
    # scores end up in 'ignore', exactly like the former row-wise logic.
    above_cut = df_in['predict'] >= cut_value
    below_cut = df_in['predict'] < cut_value

    # Vectorised labelling: much faster than a row-wise apply and also
    # well-defined for an empty frame.
    category = pd.Series('ignore', index=df_in.index, dtype=object)
    category.loc[is_true_edge & above_cut] = 'Correct'
    category.loc[is_true_edge & below_cut] = 'True Negative'
    category.loc[is_fake_edge & above_cut] = 'False Positive'

    df_in['category'] = category
    return df_in.loc[df_in['category'] != 'ignore']

## Visualization



In [101]:

# Styling lookup tables for the edge plots. The keys are the category labels
# produced by select_df_with_cut (the 'ignore' category is filtered out
# before plotting, so it needs no entry).

# category label -> line colour (hex)
color_dict = {
    'Correct': '#965601',
    'True Negative': '#da4892',
    'False Positive': '#1c965c',
}
# category label -> plotly line dash style
line_dict = {
    'Correct': 'solid',
    'True Negative': 'dash',
    'False Positive': 'dash',
}

In [106]:
import plotly.graph_objs as go
from plotly.subplots import make_subplots


def plot_xyz_plotly(node, edge, threshold=0.5):
    """Show the x–z and y–z projections of the nodes and the classified edges.

    Parameters
    ----------
    node : pandas.DataFrame
        Node table with the columns ``x``, ``y`` and ``z``.
    edge : pandas.DataFrame
        Edge table with ``{x,y,z}_start`` / ``{x,y,z}_end`` columns plus the
        ``truth`` and ``predict`` columns required by ``select_df_with_cut``.
        Note that ``select_df_with_cut`` writes a ``category`` column to it.
    threshold : float, default 0.5
        Score cut forwarded to ``select_df_with_cut``.

    The figure is displayed with ``fig.show()``; nothing is returned.
    """
    df_edges = select_df_with_cut(edge, threshold)
    fig = make_subplots(
        rows=1, cols=2, shared_yaxes=True, horizontal_spacing=0.0
    )

    axis_attr = dict(
        showgrid=False,
        mirror=True,
        linecolor="#666666", gridcolor='#d9d9d9',
        zeroline=False,
    )

    def segments(frame, axis):
        """Return start/end coordinates interleaved with None gaps.

        The None after every pair breaks the line, so a single trace can draw
        many independent edges instead of one trace per edge.
        """
        coords = []
        for start, end in zip(frame[f'{axis}_start'], frame[f'{axis}_end']):
            coords.extend((start, end, None))
        return coords

    # Categories that already own a legend entry; each category is listed
    # once for the whole figure, not once per subplot.
    in_legend = set()

    for (x, y, xi) in zip(['x', 'y'], ['z', 'z'], [1, 2]):
        # Add scatter plot for node positions
        fig.add_trace(
            go.Scatter(
                x=node[x],
                y=node[y],
                mode='markers',
                marker=dict(color='#965601', size=7),
                showlegend=False,
            ), row=1, col=xi,
        )

        # One trace per category (taken from the filtered frame that is
        # actually drawn, not from a notebook-level variable).
        # sort=False keeps the categories in order of first appearance.
        for category, group in df_edges.groupby('category', sort=False):
            first_time = category not in in_legend
            fig.add_trace(
                go.Scatter(
                    x=segments(group, x),
                    y=segments(group, y),
                    mode='lines',
                    line=dict(
                        width=1 if category == "Correct" else 2,
                        color=color_dict[category],
                        dash=line_dict[category],
                    ),
                    legendgroup=category,
                    name=category if first_time else None,
                    showlegend=first_time,
                ), row=1, col=xi,
            )
            in_legend.add(category)

        fig.update_xaxes(title_text=f'{x} [mm]', row=1, col=xi, **axis_attr)
        # The y axis is shared, so only the left panel carries its title.
        fig.update_yaxes(title_text=f'{y} [mm]' if xi == 1 else "", row=1, col=xi, **axis_attr)

    fig.update_layout(
        width=1200,
        height=700,
        autosize=False,
        paper_bgcolor='rgba(0,0,0,0)',
        plot_bgcolor='rgba(0,0,0,0)',
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=1.01,
            xanchor="right",
            x=0.99,
            traceorder='reversed',
        ),
    )

    fig.show()


# Usage example: draw both projections, keeping edges whose predicted score
# is at least 0.1 (plus true edges that fall below that cut).
plot_xyz_plotly(df_node, df_edge, threshold=0.1)
# plot_xyz_plotly(df_node, df_edge, threshold=0.2)


In [107]:
def plot_xyz_plotly_3d(node, edge, threshold=0.5):
    """Show the nodes and the classified edges in an interactive 3D scene.

    Parameters
    ----------
    node : pandas.DataFrame
        Node table with the columns ``x``, ``y`` and ``z``.
    edge : pandas.DataFrame
        Edge table with ``{x,y,z}_start`` / ``{x,y,z}_end`` columns plus the
        ``truth`` and ``predict`` columns required by ``select_df_with_cut``.
        Note that ``select_df_with_cut`` writes a ``category`` column to it.
    threshold : float, default 0.5
        Score cut forwarded to ``select_df_with_cut``.

    The figure is displayed with ``fig.show()``; nothing is returned.
    """
    fig = go.Figure()

    # Add scatter plot for node positions
    fig.add_trace(
        go.Scatter3d(
            x=node['x'],
            y=node['y'],
            z=node['z'],
            mode='markers',
            marker=dict(color='blue', size=5),
            name='Nodes',
            showlegend=False
        )
    )

    def segments(frame, axis):
        """Return start/end coordinates interleaved with None gaps.

        The None after every pair breaks the line, so a single trace can draw
        many independent edges instead of one trace per edge.
        """
        coords = []
        for start, end in zip(frame[f'{axis}_start'], frame[f'{axis}_end']):
            coords.extend((start, end, None))
        return coords

    df_edges = select_df_with_cut(edge, threshold)
    # All edges of one category share colour and dash style, so one trace
    # per category is enough.
    for category, group in df_edges.groupby('category', sort=False):
        fig.add_trace(
            go.Scatter3d(
                mode='lines',
                x=segments(group, 'x'),
                y=segments(group, 'y'),
                z=segments(group, 'z'),
                line=dict(
                    width=1,
                    color=color_dict[category],
                    dash=line_dict[category],
                ),
                showlegend=False
            )
        )

    # Set the camera position and orientation
    camera = dict(
        eye=dict(x=1, y=1, z=-2),  # View the scene from the negative-z side
        up=dict(x=0, y=1, z=0),  # Set the "up" direction to be along the +y axis
        center=dict(x=0, y=0, z=0)  # Set the center of the scene
    )

    def axis_style(title):
        """Grid/zero-line colours used by all three axes, plus the axis title."""
        return dict(
            title=dict(text=title),
            gridcolor='lightgrey',
            zerolinecolor='lightgrey',
            showspikes=False,
        )

    # Background, axes, camera and figure size in a single layout update
    fig.update_layout(
        scene=dict(
            bgcolor='black',
            xaxis=axis_style('X'),
            yaxis=axis_style('Y'),
            zaxis=axis_style('Z'),
            camera=camera,
        ),
        margin=dict(l=0, r=0, t=0, b=0),
        width=1000,
        height=1000,
    )

    fig.show()


# Usage example: 3D view with a tighter cut on the predicted score (0.8).
plot_xyz_plotly_3d(df_node, df_edge, threshold=0.8)
