In [1]:
import pandas as pd
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import torch

import numpy as np
import networkx as nx
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import sys
import uproot
import random

import torch
import torch_geometric
from torch_geometric.loader import DataLoader
sys.path.append('/sps/t2k/cehrhardt/analysis_tools/tools')
from test_GCN.utils import train, Normalize
from dataset_from_processed import DatasetFromProcessed

In [2]:
graph_folder_path = '/sps/t2k/cehrhardt/dataset/graph_20_hitxyztc_t_xyzt_r12_types'

# Index of the event to display.
# NOTE(review): the plotting cell looks up the truth direction from the ROOT file by
# event number with its own index constant -- keep the two in sync.
event_index = 16

# Load the processed graph dataset
dataset = DatasetFromProcessed(graph_folder_path=graph_folder_path, graph_file_names=['data.pt'], verbose=1)

event = dataset[event_index]
# Assumed node-feature layout (inferred from the folder name "hitxyztc..._types", TODO confirm):
# columns 0-2 are hit x, y, z and column 5 is the signal/noise type label.
hit_positions = event.x[:, 0:3]
types = event.x[:, 5]
# First three target entries, presumably the true vertex x, y, z.
# `.cpu()` keeps the NumPy conversion working if the graph lives on a GPU.
vtx = event.y[0:3].cpu().numpy()




In [3]:
root_file_path='/sps/t2k/lperisse/Soft/wcsim/results/electron/wcsim112_UnifVtx_electron_HK_10MeV_WINDOWS.root'


def _read_branch_float32(tree, branch_name):
    """Read one branch of a ROOT tree into a float32 NumPy array.

    Same extraction as a per-entry loop: iterate the record array returned by
    ``tree[branch_name].arrays()`` and take the field named after the branch.
    Presumably one entry per event (the display cell indexes it by event number).
    """
    records = tree[branch_name].arrays()
    return np.array([entry[branch_name] for entry in records], dtype=np.float32)


# Context manager so the ROOT file handle is released once the branches are read;
# the arrays are fully materialised in NumPy before the file is closed.
with uproot.open(root_file_path) as root_file:
    thits_tree = root_file['THits']
    direction_x_array, direction_y_array, direction_z_array = (
        _read_branch_float32(thits_tree, branch)
        for branch in ('direction_x', 'direction_y', 'direction_z')
    )


In [11]:
file_path = '/sps/t2k/lperisse/Data/list_PMT_positions.txt'
# Tank dimensions, used only to set the plot ranges. Units are assumed to match the
# PMT position file (TODO confirm against the WCSim geometry).
tank_diamet = 6480.0
tank_height = 6575.1

# Event whose truth direction is drawn.
# NOTE(review): must equal the index used to pick `event` from the graph dataset,
# otherwise the hits and the direction come from different events.
direction_event_index = 16

# `sep=r'\s+'` is the supported equivalent of `delim_whitespace=True` (deprecated in pandas 2.2).
df = pd.read_csv(file_path, sep=r'\s+', header=None, names=['Index', 'X', 'Y', 'Z'])

direction = np.array([
    direction_x_array[direction_event_index],
    direction_y_array[direction_event_index],
    direction_z_array[direction_event_index],
])

# `.cpu()` for consistency with `types` below: `.numpy()` fails on GPU tensors.
hit_df = pd.DataFrame(hit_positions.cpu().numpy(), columns=['X', 'Y', 'Z'])

# Nonzero type label -> drawn as signal hit; zero -> drawn as noise hit.
types_np = types.cpu().numpy().astype(bool)
signal = hit_df[types_np]
noise = hit_df[~types_np]


def _marker_trace(points, color, opacity, name, size=5):
    """Return a Scatter3d trace drawing the 'X'/'Y'/'Z' columns of ``points`` as markers."""
    return go.Scatter3d(
        x=points['X'],
        y=points['Y'],
        z=points['Z'],
        mode='markers',
        marker=dict(size=size, color=color, opacity=opacity),
        name=name,
    )


def _line_trace(start, end, name, dash=None):
    """Return a blue Scatter3d line segment from ``start`` to ``end`` (x, y, z triples)."""
    line_style = dict(color='blue', width=5)
    if dash is not None:
        line_style['dash'] = dash
    return go.Scatter3d(
        x=[start[0], end[0]],
        y=[start[1], end[1]],
        z=[start[2], end[2]],
        mode='lines',
        line=line_style,
        name=name,
    )


def _hidden_axis(half_extent, pad):
    """Axis spec spanning +/-(half_extent + pad) with grid, line, ticks and title hidden."""
    return dict(
        range=[-half_extent - pad, half_extent + pad],
        showgrid=False,
        showline=False,
        showticklabels=False,
        zeroline=False,
        title='',
        showaxeslabels=False,
    )


# Segment lengths, in the same units as the vertex / PMT coordinates.
line_length_factor = 2750  # forward segment along the momentum direction
extension_length = 5000    # backward segment (dashed origin track)
extended_end = vtx + line_length_factor * direction
extended_start = vtx - extension_length * direction

fig = make_subplots(rows=1, cols=1, specs=[[{'type': 'scatter3d'}]])

# Background: every PMT in grey, then hit PMTs coloured by their type label.
fig.add_trace(_marker_trace(df, 'grey', 0.25, 'All PMTs'))
fig.add_trace(_marker_trace(signal, 'green', 1, 'Signal hits'))
fig.add_trace(_marker_trace(noise, 'red', 0.75, 'Noise hits'))

# True vertex and the momentum direction through it.
vertex_point = {'X': [vtx[0]], 'Y': [vtx[1]], 'Z': [vtx[2]]}
fig.add_trace(_marker_trace(vertex_point, 'blue', 1.0, 'Vertex position', size=10))
fig.add_trace(_line_trace(vtx, extended_end, 'Momentum direction'))
fig.add_trace(_line_trace(extended_start, vtx, 'Origin track', dash='dash'))

margin = 5000  # padding added around the tank on every axis

fig.update_layout(
    scene=dict(
        xaxis=_hidden_axis(tank_diamet / 2, margin),
        yaxis=_hidden_axis(tank_diamet / 2, margin),
        zaxis=_hidden_axis(tank_height / 2, margin),
        bgcolor='white',
    ),
    title='',
    showlegend=False,
    width=1700,
    height=1500,
)

fig.update_scenes(xaxis_visible=False, yaxis_visible=False, zaxis_visible=False)

fig.show()