In [1]:
import plotly.graph_objects as go
import plotly.io as pio
import numpy as np
import uproot
import os
import kaleido

In [2]:
import importlib

# List of packages to import
packages_to_import = [
    ("plotly", "plotly.graph_objects", "go"),
    ("plotly", "plotly.io", "pio"),
    ("numpy", "numpy", "np"),
    ("uproot", "uproot", "uproot"),
    ("os", "os", "os"),
    ("kaleido", "kaleido", "kaleido")
]

# Loop through the packages and check if they are imported
for package_name, module_name, alias in packages_to_import:
    try:
        module = importlib.import_module(module_name)
        print(f"{alias} is already imported.")
    except ImportError:
        print(f"{alias} is not imported. Installing...")
        !pip install {package_name}

go is already imported.
pio is already imported.
np is already imported.
uproot is already imported.
os is already imported.
kaleido is already imported.


In [55]:
def lego3D(x_position, y_position, counts, title='', x_title='x position [cm]', y_title='y position [cm]', z_title='Counts',
           width=900, height=900, thickness=0.4, colorscale='jet', 
           colorbar_tick_size = 12, tick_size = 10, colorbar_lenght = 1, **kwargs):
    '''
    Create a 3D LEGO bar plot using Plotly (similar to ROOT).

    Every detector is drawn as one box (a Mesh3d trace) plus a black wireframe
    outlining the box edges (a Scatter3d trace). All boxes share a single
    color axis, so the colorbar describes the whole plot.

    :param x_position: Array of x positions for the detectors.
    :param y_position: Array of y positions for the detectors
                       (indexed in parallel with ``counts``).
    :param counts: Array of counts for each detector (bar heights).
    :param title: Chart title.
    :param x_title: X-axis title.
    :param y_title: Y-axis title.
    :param z_title: Z-axis title, also used as the colorbar title.
    :param width: Chart width in pixels.
    :param height: Chart height in pixels.
    :param thickness: Half-width of each bar in axis units; a bar spans
                      ``position - thickness`` to ``position + thickness``
                      along both x and y.
    :param colorscale: Color scale for the plot.
    :param colorbar_tick_size: Font size of the colorbar tick labels.
    :param tick_size: Font size of the x/y/z axis tick labels.
    :param colorbar_lenght: Colorbar length, passed to the colorbar ``len``
                            setting (the misspelled name is kept so existing
                            callers keep working).
    :param **kwargs: Additional keyword arguments for Plotly Mesh3d trace.
    :return: 3D LEGO figure.
    '''

    # z coordinate of the bottom face of every bar.
    # NOTE(review): presumably slightly below zero so that zero-count bars
    # still have a visible height -- confirm.
    z_base = -0.1

    # Create an empty figure to add traces to
    fig = go.Figure()

    # Loop through each detector position and count
    for iz, count in enumerate(counts):
        x_pos, y_pos = x_position[iz], y_position[iz]
        x_min, y_min = x_pos - thickness, y_pos - thickness
        x_max, y_max = x_pos + thickness, y_pos + thickness

        # Add a Mesh3d trace for each pixel: eight box corners, with
        # alphahull=0 letting Plotly build the faces as their convex hull.
        fig.add_trace(go.Mesh3d(
            x=[x_min, x_min, x_max, x_max, x_min, x_min, x_max, x_max],
            y=[y_min, y_max, y_max, y_min, y_min, y_max, y_max, y_min],
            z=[z_base, z_base, z_base, z_base, count, count, count, count],
            alphahull=0,
            # Bottom corners map to intensity 0 and top corners to the count.
            intensity=[0, 0, 0, 0, count, count, count, count],
            coloraxis='coloraxis',
            hoverinfo='text',
            contour = dict(color = 'blue', show = True, width = 10),
            hovertext=f'x: {x_pos}<br>y: {y_pos}<br>z: {count}', **kwargs))

        # Wireframe of the box: bottom rectangle, top rectangle and the four
        # vertical edges. Each piece ends with None so Plotly breaks the line
        # there instead of joining unrelated corners (which would draw stray
        # diagonals across the faces and through the bar).
        corners = [(x_min, y_min), (x_min, y_max), (x_max, y_max), (x_max, y_min)]
        closed_ring = corners + corners[:1]
        Xe, Ye, Ze = [], [], []
        for level in (z_base, count):
            Xe.extend([cx for cx, _ in closed_ring] + [None])
            Ye.extend([cy for _, cy in closed_ring] + [None])
            Ze.extend([level] * len(closed_ring) + [None])
        for cx, cy in corners:
            Xe.extend([cx, cx, None])
            Ye.extend([cy, cy, None])
            Ze.extend([z_base, count, None])

        lines = go.Scatter3d(
            x=Xe,
            y=Ye,
            z=Ze,
            mode='lines',
            line=dict(color='rgb(0,0,0)', width=1.5))
        fig.add_trace(lines)

    # Default viewpoint: looking at the scene center from the (+x, +y, +z)
    # corner with the z axis pointing up.
    camera_params = dict(
        eye=dict(x=1.5, y=1.5, z=1.5),
        center=dict(x=0, y=0, z=0),
        up=dict(x=0, y=0, z=1))

    # Update the layout of the figure
    fig.update_layout(
        width=width, height=height,
        title=title, title_x=0.5,
        scene=dict(
            camera=camera_params,
            # Ticks are placed exactly at the detector positions.
            xaxis=dict(title=x_title, tickvals=x_position, 
                       ticktext=[f'{x:.1f}' for x in x_position], tickfont=dict(size=tick_size)),
            yaxis=dict(title=y_title, tickvals=y_position, 
                       ticktext=[f'{y:.1f}' for y in y_position], tickfont=dict(size=tick_size)),
            zaxis=dict(title=z_title,tickfont=dict(size=tick_size))),
        # Shared color axis referenced by every Mesh3d trace above.
        coloraxis=dict(
            colorscale=colorscale,
            colorbar=dict(
                title=dict(
                    text=z_title,
                    side='right'),
                xanchor='right', x=1.0,
                xpad=0,
                ticks='inside',
                len=colorbar_lenght,
                tickfont=dict(size=colorbar_tick_size))),
        legend=dict(
            yanchor='top', y=1.0,
            xanchor='left', x=0.0,
            bgcolor='rgba(0, 0, 0, 0)',
            itemclick=False,
            itemdoubleclick=False),
        showlegend=False)
 
    return fig

In [3]:
def read_data(filename = 'simTe7MeV_pgData.csv'):
    '''
    Read the x, y and z columns of a data file.

    The format is chosen from the file extension (case-insensitive):

    * ``.csv``  -- comma-separated text whose first line is a header that is
      skipped; columns 0, 1 and 2 are returned as x, y and z.
    * ``.root`` -- ROOT file; the branches named ``x``, ``y`` and ``z`` are
      read from the first key in the file.

    :param filename: Path of the file to read.
    :return: Tuple ``(x, y, z)`` of arrays.
    :raises ValueError: If the extension is neither ``.csv`` nor ``.root``.
    '''
    file_extension = os.path.splitext(filename)[-1].lower()
    if file_extension == '.csv':
        # ndmin=2 keeps the result two-dimensional even when the file holds a
        # single data row, so the column slices below always work.
        data = np.loadtxt(filename, delimiter=",", skiprows=1, ndmin=2)
        x = data[:,0]
        y = data[:,1]
        z = data[:,2]
    elif file_extension == '.root':
        # Open the ROOT file; the with-block closes it once the arrays are read
        with uproot.open(filename) as root_file:
            # Get the first key (tree name) in the ROOT file
            tree_name = root_file.keys()[0]
            tree = root_file[tree_name]
            # Assuming you have branches named 'x', 'y', and 'z' in your ROOT tree
            x = tree["x"].array()
            y = tree["y"].array()
            z = tree["z"].array()
    else:
        # Fail loudly instead of falling through to the return statement with
        # x, y and z undefined.
        raise ValueError(
            f"Wrong filename: {filename!r} (expected a '.csv' or '.root' file)")
    return x,y,z

In [63]:
# Load the detector positions (x, y) and the counts (z) from the CSV file.
x, y, z = read_data('simTe7MeV_pgData.csv')

# Options for lego3D. `opacity` and `flatshading` are not lego3D parameters;
# they travel through **kwargs to every Mesh3d trace.
lego_options = dict(
    colorscale='jet',
    thickness=0.1,
    colorbar_tick_size=14,
    tick_size=14,
    colorbar_lenght=0.5,
    opacity=0.9,
    flatshading=True,
)
fig = lego3D(x, y, z, **lego_options)

# Save a static PNG copy, then display the interactive figure.
fig.write_image('legoplot.png', scale=6, width=1080, height=1080)
fig.show()