In [1]:
# ssh
import os
import paramiko
from tqdm import tqdm
import stat

# torch
import torch
# plot
import pandas as pd

# local dependencies
from utility.Control import load_config
from utility.DataLoader import get_data_loaders
from utility.EverythingNeeded import build_model, get_item_from_dataloader, convert_batch_to_df

In [4]:
### Only for training remotely

# SSH host and credentials. The credentials are read from the environment so
# that they never end up stored in the notebook itself.
host = 'lxslc7.ihep.ac.cn'
username = os.environ.get('IHEP_USERNAME')
password = os.environ.get('IHEP_PASSWORD')

# Establish an SSH connection
ssh = paramiko.SSHClient()
# Load the user's known_hosts first so that a server whose key is already known
# is verified (a mismatch raises paramiko.BadHostKeyException) before the
# password is sent. A missing known_hosts file is silently ignored.
ssh.load_system_host_keys()
# SECURITY NOTE: AutoAddPolicy still accepts a host that is seen for the first
# time without any verification. Switch to paramiko.RejectPolicy() once the
# host key has been recorded in known_hosts.
ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
ssh.connect(host, username=username, password=password)
print(f'Successfully connected to {username}@{host}')

# Open an SFTP session on top of the SSH connection; download() transfers
# files through it.
sftp_client = ssh.open_sftp()
print(f'Open sftp to {username}@{host}')


def download(remote_path, local_path, sftp=None):
    """Recursively copy a remote file or directory tree to the local disk.

    Args:
        remote_path: Path on the SFTP server (file or directory).
        local_path: Destination path on the local machine. For a directory,
            the remote tree is mirrored underneath it.
        sftp: SFTP client providing ``stat``, ``listdir_attr`` and ``get``.
            Defaults to the notebook-level ``sftp_client``.

    Raises:
        Whatever the SFTP client raises when ``remote_path`` cannot be
        accessed, and ``OSError`` if a local directory cannot be created.
    """
    # Remote SFTP paths always use '/', independent of the local OS, so they
    # must not be assembled with os.path.
    import posixpath

    if sftp is None:
        sftp = sftp_client

    # Get the file or folder attributes
    remote_attr = sftp.stat(remote_path)

    if stat.S_ISDIR(remote_attr.st_mode):
        # Directory: mirror it locally, then descend into every entry.
        os.makedirs(local_path, exist_ok=True)
        for entry in sftp.listdir_attr(remote_path):
            download(
                posixpath.join(remote_path, entry.filename),
                os.path.join(local_path, entry.filename),
                sftp=sftp,
            )
        return

    # Regular file: make sure the parent directory exists. A bare file name has
    # an empty dirname, which os.makedirs would reject.
    local_dir = os.path.dirname(local_path)
    if local_dir:
        os.makedirs(local_dir, exist_ok=True)

    with tqdm(
            total=remote_attr.st_size, unit='B', unit_scale=True, desc=posixpath.basename(remote_path)
    ) as progress_bar:
        # The callback receives the cumulative byte count, while tqdm.update()
        # expects an increment, hence the subtraction of the current position.
        sftp.get(
            remote_path, local_path,
            callback=lambda current, _total: progress_bar.update(current - progress_bar.n)
        )


# Define the remote and local paths

remote_project_base = '/hpcfs/cepc/higgs/frank/GNN/bepc.magnet.movingE'
local_base = '/Users/avencast/PycharmProjects/trkgnn/workspace/bepc.magnet.movingE'

# (remote source, local destination) pairs. Note that the remote config
# `mpnn_p.yaml` is stored locally under the name `mpnn.yaml`.
paths = [
    (
        f'{remote_project_base}/mpnn_p.yaml',
        f'{local_base}/mpnn.yaml'
    ),
    (
        f'{remote_project_base}/output/summaries_0.csv',
        f'{local_base}/output/summaries_0.csv'
    ),
    (
        f'{remote_project_base}/output/model.checkpoints/model_checkpoint_099.pth.tar',
        f'{local_base}/output/model.checkpoints/model_checkpoint_099.pth.tar'
    ),
]

# Download each file and folder. The try/finally guarantees that the
# connection is released even when one of the transfers fails.
try:
    for remote_path, local_path in paths:
        download(remote_path, local_path)
finally:
    # Close the SFTP and SSH clients
    sftp_client.close()
    ssh.close()

Successfully connected to frank@lxslc7.ihep.ac.cn
Open sftp to frank@lxslc7.ihep.ac.cn


mpnn_p.yaml: 100%|██████████| 814/814 [00:01<00:00, 471B/s]
summaries_0.csv: 100%|██████████| 187k/187k [00:03<00:00, 46.7kB/s] 
model_checkpoint_099.pth.tar: 100%|██████████| 818k/818k [00:07<00:00, 103kB/s]  


In [2]:
# Local locations of the run to inspect: its config, training summary and the
# checkpoint whose weights are loaded below.
# NOTE(review): this rebinds `local_base` to the `bepc2` workspace and selects
# checkpoint 063, whereas the download cell fetches checkpoint 099 into
# `bepc.magnet.movingE` — confirm that inspecting a different run is intended.
local_base = '/Users/avencast/PycharmProjects/trkgnn/workspace/bepc2'
config_path = f'{local_base}/mpnn.yaml'
summary = f'{local_base}/output/summaries_0.csv'
model_path = f'{local_base}/output/model.checkpoints/model_checkpoint_063.pth.tar'

In [3]:
# load config
# (presumably registers the settings that build_model reads — confirm in utility.Control)
load_config(config_path)
# load model
model = build_model('cpu', False)
# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code; only load checkpoints from a trusted source. `map_location` places
# the loaded tensors on the CPU.
tar = torch.load(model_path, map_location=torch.device('cpu'))
# The weights are stored under the 'model' key of the checkpoint. As the last
# expression of the cell, the key-matching report is shown as the cell output.
model.load_state_dict(tar['model'])

Parameters: 55233


<All keys matched successfully>

In [15]:
# load data
data_loader = get_data_loaders(
    input_dir=r'/Users/avencast/CLionProjects/darkshine-simulation/workspace/Tracker_GNN.root',
    chunk_size="500 MB",
    batch_size=50,
    shuffle=False,
)
# Take the first loader produced by get_data_loaders.
loader = next(data_loader)

# load the batch at index 5
batch = get_item_from_dataloader(loader, 5)

# Inference only: switch layers such as dropout / batch-norm to evaluation
# behaviour and skip building the autograd graph.
model.eval()
with torch.no_grad():
    # The sigmoid maps the raw model output to a score in (0, 1).
    batch_output = torch.sigmoid(model(batch))

df = convert_batch_to_df(batch)
df['predict'] = pd.DataFrame(batch_output.detach().cpu().numpy(), columns=['predict'])

# One row per edge: edge endpoints, label and predicted score side by side.
df_edge_value = pd.concat([df['edge'], df['y'], df['predict']], axis=1)
# Attach the columns of the start node, looked up by the node-table index.
merged_df = df_edge_value.merge(df['node'], left_on='start', right_index=True)

# Attach the columns of the end node; overlapping node columns are
# disambiguated with the `_start` / `_end` suffixes.
df_edge = merged_df.merge(df['node'], left_on='end', right_index=True, suffixes=('_start', '_end'))
df_node = df['node']

## Visualization



In [16]:
from scripts.plotting import plot_xyz_plotly

# Plot the nodes and edges of the batch prepared above.
# NOTE(review): `threshold` is presumably the cut applied to the `predict`
# score when selecting edges — confirm in scripts.plotting.
plot_xyz_plotly(df_node, df_edge, threshold=0.5)
# Alternative, looser cut kept for comparison:
# plot_xyz_plotly(df_node, df_edge, threshold=0.2)


In [11]:
from scripts.plotting import plot_xyz_plotly_3d

# 3D plotting helper called with the same node/edge tables and the same
# threshold as the previous plot.
plot_xyz_plotly_3d(df_node, df_edge, threshold=0.5)


In [6]:
# Show the merged edge table (edge endpoints, label, predicted score and the
# coordinates of both end nodes) for a manual sanity check.
display(df_edge)

Unnamed: 0,start,end,truth,predict,x_start,y_start,z_start,x_end,y_end,z_end
0,0,21,0.0,5.715887e-07,-6.735,0.0,-607.672485,-5.295,0.0,-507.672485
20,1,21,0.0,2.499768e-09,6.045,0.0,-607.672485,-5.295,0.0,-507.672485
40,2,21,0.0,2.181598e-08,-8.715,0.0,-607.672485,-5.295,0.0,-507.672485
60,3,21,0.0,6.316399e-06,-2.865,0.0,-607.672485,-5.295,0.0,-507.672485
80,4,21,0.0,1.456391e-07,-1.695,0.0,-607.672485,-5.295,0.0,-507.672485
...,...,...,...,...,...,...,...,...,...,...
2379,117,141,0.0,1.876460e-11,10.305,0.0,-107.672501,5.925,0.0,-7.672500
2399,118,141,0.0,1.832707e-09,-2.835,0.0,-107.672501,5.925,0.0,-7.672500
2419,119,141,0.0,2.563927e-06,-1.935,0.0,-107.672501,5.925,0.0,-7.672500
2439,120,141,0.0,2.125535e-02,6.795,0.0,-107.672501,5.925,0.0,-7.672500
