In [None]:
import matplotlib as mpl
import trackhhl.toy.simple_generator as toy
import trackhhl.hamiltonians.simple_hamiltonian as hamiltonian
import numpy as np
import matplotlib.pyplot as plt
from itertools import pairwise
import dataclasses
import itertools
import trackhhl.event_model as em
# Select the non-interactive Agg backend for matplotlib.
# NOTE(review): a later cell runs `%matplotlib inline`, which presumably
# switches the backend again for the cells after it - confirm which one is wanted.
mpl.use('Agg')

from matplotlib.ticker import FormatStrFormatter
import matplotlib.patches as mpatches

# Color palette shared by the figures in this notebook.
default_color = "#478DCB"
grey_color = "#D0D0D0"
colors = [
    "#CF3D1E",
    "#F15623",
    "#F68B1F",
    "#FFC60B",
    "#DFCE21",
    "#BCD631",
    "#95C93D",
    "#48B85C",
    "#00833D",
    "#00B48D",
    "#60C4B1",
    "#27C4F4",
    "#3E67B1",
    "#4251A3",
    "#59449B",
    "#6E3F7C",
    "#6A246D",
    "#8A4873",
    "#EB0080",
    "#EF58A0",
    "#C05A89",
]

# Default sizing factors for the figures.
scale = 4
plotscale = 1.

# Coordinate index -> axis name.
ntox = dict(enumerate("XYZ"))

In [None]:
@dataclasses.dataclass
class EventCollection:
    """Bundle several events so they can be handled as one combined event.

    On construction the modules, tracks and hits of all events are flattened
    into ``combined_modules``, ``combined_tracks`` and ``combined_hits``, and
    the ``track_id`` of the hits of every event after the first is shifted so
    that hits from different events do not share an id.
    """
    events: list[em.Event]

    def __post_init__(self):
        """Flatten the per-event containers and make hit track ids unique."""
        self.combined_modules = [module for event in self.events for module in event.modules]
        self.combined_tracks = [track for event in self.events for track in event.tracks]
        self.combined_hits = [hit for event in self.events for hit in event.hits]
        # The hits are shared objects, so relabelling after flattening still
        # updates the entries of combined_hits.
        self.relabel_track_ids()

    def relabel_track_ids(self):
        """Shift hit ``track_id`` values so they are unique across events.

        The hits of event ``k`` are shifted by the total number of tracks in
        events ``0..k-1``. The offset keeps growing with every event, so three
        or more events no longer end up sharing the same shift.

        An event is only shifted while its smallest hit ``track_id`` is below
        its offset, which makes repeated calls a no-op. The decision is taken
        per event rather than per hit so that every hit of an event receives
        the same shift.

        Assumes each event numbers its tracks from 0 before relabelling -
        TODO confirm against the generator.
        NOTE(review): only ``hit.track_id`` is changed; if the track objects
        carry an id of their own it is left untouched - confirm this is intended.
        """
        offset = 0
        for event in self.events:
            hits = list(event.hits)
            if offset and hits and min(hit.track_id for hit in hits) < offset:
                for hit in hits:
                    hit.track_id += offset
            offset += len(event.tracks)

    def get_combined_event(self):
        """Return a new ``em.Event`` built from the flattened modules, tracks and hits."""
        return em.Event(self.combined_modules, self.combined_tracks, self.combined_hits)

In [None]:
N_MODULES = 7
# Module extents in x and y (float("+inf") was the alternative value tried here).
LX = 10
LY = 10
Z_SPACING = 1.0

detector = toy.SimpleDetectorGeometry(
    module_id=list(range(N_MODULES)),
    lx=[LX]*N_MODULES,
    ly=[LY]*N_MODULES,
    # Module i sits at (i + 1) * Z_SPACING so that Z_SPACING is the actual
    # distance between modules (identical to the old i + Z_SPACING at 1.0).
    z=[(i + 1) * Z_SPACING for i in range(N_MODULES)])

generator = toy.SimpleGenerator(
    detector_geometry=detector,
    theta_max=np.pi/16)

generator1 = toy.SimpleGenerator(
    detector_geometry=detector,
    theta_max=np.pi/8)

defined_primary_vertex = [(0,0,0),(0,0,0)]
N_PARTICLES = 10

event = generator.generate_event(N_PARTICLES, n_events=1, defined_primary_vertex=[(0,0,0)])
event1 = generator1.generate_event(N_PARTICLES, n_events=1, defined_primary_vertex=[(0,0,2.3)])

# Only the first event is displayed; add event1 to this list to overlay it.
event = [event]

# modules[i] holds the hits of layer i gathered from every listed event.
modules = [
    [hit for e in event for hit in e.modules[i].hits]
    for i in range(N_MODULES)
]
print(f"Number of modules: {len(modules)}")
for i, module_hits in enumerate(modules):
    print(f"Module {i}: {len(module_hits)} hits")

# Building the collection relabels the hit track ids in place, so the hit
# objects already stored in `modules` see the new ids as well.
event_collection = EventCollection(event)
combined_event = event_collection.get_combined_event()
event = [combined_event]

In [None]:
%matplotlib inline
def print_events_2d(event_collection, modules, x=2, y=1, filename="visual.png", save_to_file=False):
    """Draw a 2D projection of the events with all candidate segments.

    Every hit of a layer is connected to every hit of the next layer, and each
    event's primary vertex is connected to that event's first-layer hits. All
    segments share one faint style; hits are drawn as white dots with a black rim.

    NOTE: another cell further down defines a function with this same name,
    which replaces this one once that cell has run.

    Args:
        event_collection: object exposing ``events`` and ``get_combined_event()``.
        modules: sequence of layers, each a sequence of hits, in the order in
            which consecutive layers should be connected.
        x: index into a hit / vertex giving the horizontal plot coordinate.
        y: index into a hit / vertex giving the vertical plot coordinate.
        filename: name of the PNG written under ``figures/``; a trailing
            ``.png`` is optional and is not duplicated.
        save_to_file: when true, save the figure before showing it.
    """
    import os

    fig, ax = plt.subplots(figsize=(20*plotscale, 11*plotscale), dpi=200)

    combined_event = event_collection.get_combined_event()
    # Pair each event with the primary vertex of its first track; events
    # without tracks have no vertex to draw from and are skipped.
    vertex_events = [
        (ev.tracks[0].mc_info.primary_vertex, ev)
        for ev in event_collection.events
        if len(ev.tracks) > 0
    ]

    # Modules as vertical lines, one per distinct z position.
    module_zs = sorted({module.z for module in combined_event.modules})
    for z in module_zs:
        ax.axvline(x=z, color='black', linewidth=6)

    # Candidate segments: primary vertex -> first layer, then layer -> next layer.
    segments = [(pv, hit) for pv, ev in vertex_events for hit in ev.modules[0].hits]
    for layer_a, layer_b in pairwise(modules):
        segments.extend((hit_a, hit_b) for hit_a in layer_a for hit_b in layer_b)

    for start, end in segments:
        ax.plot([start[x], end[x]], [start[y], end[y]],
                color='whitesmoke', linewidth=2, alpha=0.5, zorder=2)

    # Hits: two batched scatter calls (black rim below, white core on top).
    hit_xs = [hit[x] for hit in combined_event.hits]
    hit_ys = [hit[y] for hit in combined_event.hits]
    ax.scatter(hit_xs, hit_ys, color='black', s=250, zorder=3, linewidth=0)
    ax.scatter(hit_xs, hit_ys, color='white', s=100, zorder=4, linewidth=0)

    ax.set_xlabel(ntox.get(x, str(x)), fontsize=16)
    ax.set_ylabel(ntox.get(y, str(y)), fontsize=16)

    # Horizontal range: from the left-most vertex (or 0) to one unit past the
    # last module, taken from the drawn modules instead of notebook globals.
    x_low = min([0] + [pv[x] for pv, _ in vertex_events]) - 0.5
    x_high = max(module_zs, default=0) + 1
    ax.set_xlim(x_low, x_high)

    # Vertical range: spread of the last layer across all events, falling
    # back to all hits when that layer is empty.
    last_layer = modules[-1] if len(modules) > 0 else []
    ys = [hit[y] for hit in last_layer] or hit_ys
    if ys:
        y_min, y_max = min(ys), max(ys)
        y_range = y_max - y_min
        # 10% padding; a fixed pad avoids identical limits for a single value.
        pad = 0.1 * y_range if y_range > 0 else 0.5
        ax.set_ylim(y_min - pad, y_max + pad)

    ax.set_axis_off()

    # Lay out first so the saved file matches what is shown.
    plt.tight_layout()
    if save_to_file:
        os.makedirs('figures', exist_ok=True)
        stem = filename[:-4] if filename.lower().endswith('.png') else filename
        plt.savefig(f'figures/{stem}.png', bbox_inches='tight', transparent=True)
    plt.show()

# Call the function with the EventCollection
print_events_2d(event_collection, modules, filename='Eventstters', save_to_file=False)

In [None]:
def print_events_2d(event_collection, modules, x=2, y=1, filename="visual.png", save_to_file=False,
                    highlight_track_ids=(1, 3, 4)):
    """Draw a 2D projection of the events showing only same-track segments.

    Hits in consecutive layers are connected when they carry the same
    ``track_id``. Segments of tracks listed in ``highlight_track_ids`` are
    drawn in lime green, all others in magenta. Hits are drawn as white dots
    with a black rim.

    NOTE: this definition reuses the name of a function defined in an earlier
    cell and replaces it once this cell has run.

    Args:
        event_collection: object exposing ``events`` and ``get_combined_event()``.
        modules: sequence of layers, each a sequence of hits, in the order in
            which consecutive layers should be connected.
        x: index into a hit / vertex giving the horizontal plot coordinate.
        y: index into a hit / vertex giving the vertical plot coordinate.
        filename: name of the PNG written under ``figures/``; a trailing
            ``.png`` is optional and is not duplicated.
        save_to_file: when true, save the figure before showing it.
        highlight_track_ids: track ids whose segments are drawn in lime green.
    """
    import os

    fig, ax = plt.subplots(figsize=(25*plotscale, 15*plotscale), dpi=250)

    combined_event = event_collection.get_combined_event()
    # Primary vertices are only needed for the horizontal range; events
    # without tracks contribute none.
    primary_vertices = [
        ev.tracks[0].mc_info.primary_vertex
        for ev in event_collection.events
        if len(ev.tracks) > 0
    ]

    # Modules as vertical lines, one per distinct z position.
    module_zs = sorted({module.z for module in combined_event.modules})
    for z in module_zs:
        ax.axvline(x=z, color='black', linewidth=6)

    # Connect hits of consecutive layers that share a track id. Grouping the
    # second layer by track id avoids visiting (and plotting) every pair;
    # pairs from different tracks were fully transparent before.
    for layer_a, layer_b in pairwise(modules):
        by_track = {}
        for hit in layer_b:
            by_track.setdefault(hit.track_id, []).append(hit)
        for hit_a in layer_a:
            color = 'limegreen' if hit_a.track_id in highlight_track_ids else 'magenta'
            for hit_b in by_track.get(hit_a.track_id, ()):
                ax.plot([hit_a[x], hit_b[x]], [hit_a[y], hit_b[y]],
                        color=color, linewidth=4, alpha=1, zorder=2)

    # Hits: two batched scatter calls (black rim below, white core on top).
    hit_xs = [hit[x] for hit in combined_event.hits]
    hit_ys = [hit[y] for hit in combined_event.hits]
    ax.scatter(hit_xs, hit_ys, color='black', s=250, zorder=3, linewidth=0)
    ax.scatter(hit_xs, hit_ys, color='white', s=100, zorder=4, linewidth=0)

    ax.set_xlabel(ntox.get(x, str(x)), fontsize=16)
    ax.set_ylabel(ntox.get(y, str(y)), fontsize=16)

    # Horizontal range: from the left-most vertex (or 0) to one unit past the
    # last module, taken from the drawn modules instead of notebook globals.
    x_low = min([0] + [pv[x] for pv in primary_vertices]) - 0.5
    x_high = max(module_zs, default=0) + 1
    ax.set_xlim(x_low, x_high)

    # Vertical range: spread of the last layer across all events, falling
    # back to all hits when that layer is empty.
    last_layer = modules[-1] if len(modules) > 0 else []
    ys = [hit[y] for hit in last_layer] or hit_ys
    if ys:
        y_min, y_max = min(ys), max(ys)
        y_range = y_max - y_min
        # 10% padding; a fixed pad avoids identical limits for a single value.
        pad = 0.1 * y_range if y_range > 0 else 0.5
        ax.set_ylim(y_min - pad, y_max + pad)

    ax.set_facecolor('whitesmoke')
    ax.set_axis_off()

    # Lay out first so the saved file matches what is shown.
    plt.tight_layout()
    if save_to_file:
        os.makedirs('figures', exist_ok=True)
        stem = filename[:-4] if filename.lower().endswith('.png') else filename
        plt.savefig(f'figures/{stem}.png', bbox_inches='tight', transparent=True)
    plt.show()

# Call the function with the EventCollection
print_events_2d(event_collection, modules, filename='Event_no_scatters', save_to_file=False)

In [None]:
N_MODULES = 5
# Unbounded module extents in x and y.
LX = float("+inf")
LY = float("+inf")
Z_SPACING = 1.0
N_PARTICLES = 10

detector = toy.SimpleDetectorGeometry(
    module_id=list(range(N_MODULES)),
    lx=[LX]*N_MODULES,
    ly=[LY]*N_MODULES,
    # Module i sits at (i + 1) * Z_SPACING so that Z_SPACING is the actual
    # distance between modules (identical to the old i + Z_SPACING at 1.0).
    z=[(i + 1) * Z_SPACING for i in range(N_MODULES)])

generator = toy.SimpleGenerator(
    detector_geometry=detector,
    theta_max=np.pi/2)

event = generator.generate_event(N_PARTICLES, n_events=1, defined_primary_vertex=[(0,0,0)])

ham = hamiltonian.SimpleHamiltonian(
    epsilon=1e-3,
    gamma=2.0,
    delta=1.0)

# Build the Hamiltonian for the generated event and solve it classically.
ham.construct_hamiltonian(event=event)
b = ham.b
solution = ham.solve_classicaly()

# Binarise the solution: entries above the threshold T become 1, the rest 0.
T = 0.5
classical_solution = (solution > T).astype(int)


In [None]:
# Number of modules assumed for the scaling estimate; set before the loop so
# the cell does not depend on a value left behind by another cell.
N_MODULES = 5
num_particles = range(2,2000)

ones = []
total_size = []
matrix_size = []  # never filled in this cell

# A dedicated loop variable is used so the N_PARTICLES constant used by the
# event-generation cells is not overwritten by this sweep.
for n_particles in num_particles:
    n_linear = n_particles * (N_MODULES-1)        # N * (M - 1)
    n_quadratic = n_particles**2 * (N_MODULES-1)  # N^2 * (M - 1)

    if n_particles == 10:
        print(n_quadratic)
        print(n_linear)

    ones.append(n_linear * np.log(n_linear))
    total_size.append(n_quadratic**2)


In [None]:
# Show the length and a short preview of each list instead of dumping
# every value into the cell output.
for label, values in (("ones", ones), ("total_size", total_size), ("matrix_size", matrix_size)):
    preview = ", ".join(f"{value:.4g}" for value in values[:5])
    print(f"{label}: {len(values)} values, first entries: [{preview}]")

In [None]:
import matplotlib.pyplot as plt
from matplotlib.ticker import FuncFormatter
import numpy as np

fig, ax = plt.subplots(figsize=(12, 6))

# Two curves against the particle count: the n*log(n) estimate and the
# squared quadratic size.
ax.plot(num_particles, ones, label='1-Bit State Space', color='limegreen')
ax.plot(num_particles, total_size, label='Original State Space', color='forestgreen')

# Axis labels and a logarithmic y-axis.
ax.set(
    xlabel='Number of Particles',
    ylabel='Sample Size',
    yscale='log',
)

ax.legend(loc='upper left')
ax.grid(True, which="both", ls="-", alpha=0.2)

ax.set_title('Samples needed for reconstruction VS number of particles')
fig.tight_layout()
plt.show()

In [None]:
import matplotlib.pyplot as plt
from matplotlib.ticker import FuncFormatter
import numpy as np

fig, ax = plt.subplots(figsize=(12, 6))

# Both curves use total_size on the horizontal axis; the second one is
# therefore the identity line.
ax.plot(total_size, ones, label='1-Bit HHL', color='tab:blue')
ax.plot(total_size, total_size, label='Total Size', color='tab:orange')

# Axis labels, with logarithmic scales on both axes.
ax.set(
    xlabel='Matrix Size',
    ylabel='Number of Samples',
    yscale='log',
    xscale='log',
)

# Format x-axis labels as powers of 10
def format_func(x, _):
    """Return the tick value ``x`` as a LaTeX power-of-ten label.

    Args:
        x: tick value; expected to be an exact decade on the log axis.
        _: tick position passed by matplotlib's ``FuncFormatter`` (unused).

    Returns:
        A string such as ``'$10^{3}$'``, or an empty string when ``x`` is not
        a positive number (no power-of-ten exponent exists for it).
    """
    # Also rejects NaN, for which every comparison is false.
    if not x > 0:
        return ''
    # Round rather than truncate so a tiny floating-point error in the
    # logarithm cannot shift the exponent by one.
    exponent = int(round(float(np.log10(x))))
    return f'$10^{{{exponent}}}$'

# Label the major x ticks through format_func.
power_of_ten_formatter = FuncFormatter(format_func)
ax.xaxis.set_major_formatter(power_of_ten_formatter)

ax.grid(True, which="both", ls="-", alpha=0.2)
ax.legend(loc='upper left')

ax.set_title('Growth of Ones (Cones) and Total Size vs Matrix Size')
fig.tight_layout()
plt.show()

In [None]:
N_MODULES = 5
# Module extents in x and y (float("+inf") was the alternative value tried here).
LX = 10
LY = 10
Z_SPACING = 1.0
# Set explicitly so this cell does not depend on whatever value earlier
# cells left in N_PARTICLES.
N_PARTICLES = 10

detector = toy.SimpleDetectorGeometry(
    module_id=list(range(N_MODULES)),
    lx=[LX]*N_MODULES,
    ly=[LY]*N_MODULES,
    # Module i sits at (i + 1) * Z_SPACING so that Z_SPACING is the actual
    # distance between modules (identical to the old i + Z_SPACING at 1.0).
    z=[(i + 1) * Z_SPACING for i in range(N_MODULES)])

generator = toy.SimpleGenerator(
    detector_geometry=detector,
    theta_max=np.pi/16)

event = generator.generate_event(N_PARTICLES, n_events=1, defined_primary_vertex=[(0,0,0)])

ham = hamiltonian.SimpleHamiltonian(
    epsilon=1e-3,
    gamma=2.0,
    delta=1.0)

ham.construct_hamiltonian(event=event)
b = ham.b

# Dense copy of the Hamiltonian matrix as a plain ndarray for plotting.
matrix = np.asarray(ham.A.todense())

fig, ax = plt.subplots(figsize=(6, 5), dpi=40)

# Heatmap of the matrix entries with a colorbar.
im = ax.imshow(matrix, cmap='terrain', aspect='auto')
cbar = fig.colorbar(im, ax=ax)

# Per-index ticks are only shown for small matrices (up to 20 rows).
if matrix.shape[0] > 20:
    ax.set_xticks([])
    ax.set_yticks([])
else:
    ax.set_xticks(np.arange(matrix.shape[1]))
    ax.set_yticks(np.arange(matrix.shape[0]))
#plt.savefig(f'figures/matrix.png', bbox_inches='tight', transparent=True)
# Adjust layout and display
plt.tight_layout()
plt.show()

# Optional: Print matrix statistics
#print(f"Matrix shape: {matrix.shape}")
#print(f"Min value: {np.min(matrix):.2e}")
#print(f"Max value: {np.max(matrix):.2e}")
#print(f"Mean value: {np.mean(matrix):.2e}")
#print(f"Median value: {np.median(matrix):.2e}")