In [1]:
import numpy as np
import h5py
import matplotlib.pyplot as plt
import pandas as pd
from tqdm import tqdm
from scipy.interpolate import griddata
from pyinstrument import Profiler

In [2]:
dataset_count = 19        
dataset_prefix = '../data/'

In [3]:
# Read the geometry of detector
with h5py.File(f'{dataset_prefix}geo.h5', 'r') as geo_file:
    ChannelID_ = geo_file['Geometry']['ChannelID']
    theta_ = geo_file['Geometry']['theta']
    phi_ = geo_file['Geometry']['phi'] - 180 # let phi in [-180, 180]
    geo_dict = {key: (val1, val2) for key, val1, val2 in zip(ChannelID_, theta_, phi_)}

In [4]:
# Count events in datasets
event_count = np.zeros(dataset_count, dtype=int)

for data_id in range(dataset_count):
    with h5py.File(f'{dataset_prefix}{16930+data_id}.h5', 'r') as data_file: 
        event_count[data_id] = data_file['ParticleTruth'].shape[0]

event_total = event_count.sum()

In [5]:
Ek_train = np.zeros(event_total)
Evis_train = np.zeros(event_total)
EventInfo = []
EventImage = []

# vecotrize the lookup function of geo_dict
event_index = np.insert(np.cumsum(event_count), 0, 0)
vectorized_lookup = np.vectorize(geo_dict.get)

for data_id in range(dataset_count): # tqdm把iterator包起来，就可以实现进度条
    with h5py.File(f'{dataset_prefix}{16930+data_id}.h5', 'r') as data_file:
        
        Ek_train[event_index[data_id]:event_index[data_id+1]] = data_file['ParticleTruth']['Ek'][...]
        Evis_train[event_index[data_id]:event_index[data_id+1]] = data_file['ParticleTruth']['Evis'][...]
        EventIDs_ = data_file['PETruth']['EventID'][...]
        ChannelID_ = data_file['PETruth']['ChannelID'][...]
        PETime_ = data_file['PETruth']['PETime'][...]
        
        for event_id in tqdm(range(event_count[data_id])):
            profiler = Profiler()
            profiler.start()
            
            indices = np.where(EventIDs_ == event_id)
            geo_info = vectorized_lookup(ChannelID_[indices])
            time_info = PETime_[indices]
            event_info = np.column_stack((geo_info[0], geo_info[1] * np.sin(geo_info[0] / 180 * np.pi), time_info))
            EventInfo.append(event_info)
            
            # convert data to  DataFrame
            df = pd.DataFrame(event_info, columns=['Latitude', 'Longitude', 'Value'])

            # calculate mean and count
            grouped = df.groupby(['Latitude', 'Longitude']).agg(['mean', 'count'])
            grouped.columns = ['_'.join(col).strip() for col in grouped.columns.values]

            a = grouped[['Value_mean']].reset_index().to_numpy()
            b = grouped[['Value_count']].reset_index().to_numpy()
            
            # Average arrival time
            # plt.figure(figsize=(8, 6))
            # plt.scatter(a[:, 1], a[:, 0], c=a[:, 2], cmap='jet', s=10)
            # plt.colorbar(label='Average arrival time')
            # plt.xlabel('Longitude (0-360)')
            # plt.ylabel('Latitude (0-180)')
            # plt.title('Average arrival time')
            # plt.show()

            def green_func(X, Y, scatter):
                R = np.sqrt((X - scatter[:, 1, np.newaxis, np.newaxis]) ** 2 + (Y - scatter[:, 0, np.newaxis, np.newaxis]) ** 2)
                V = scatter[:, 2, np.newaxis, np.newaxis] * np.exp(-R ** 2)
                return np.sum(V, axis=0)


            x = np.linspace(-180, 180, 128)
            y = np.linspace(0, 180, 128)
            xx, yy = np.meshgrid(x, y)

            zz1 = green_func(xx, yy, a)
            zz2 = green_func(xx, yy, b)
            EventImage.append((zz1, zz2))
            
            # plt.figure(figsize=(8, 6))
            # plt.imshow(zz1, origin='lower', cmap='jet')
            # plt.colorbar()
            # plt.title("Interpolated Arrival Time")
            # plt.show()
            
            # # Arrival count
            # plt.figure(figsize=(8, 6))
            # plt.scatter(b[:, 1], b[:, 0], c=b[:, 2], cmap='jet', s=10)
            # plt.colorbar(label='Arrival count')
            # plt.xlabel('Longitude (0-360)')
            # plt.ylabel('Latitude (0-180)')
            # plt.title('Arrival count')
            # plt.show()
            
            # plt.figure(figsize=(8, 6))
            # plt.imshow(zz2, origin='lower', cmap='jet')
            # plt.colorbar()
            # plt.title("Interpolated Arrival Count")
            # plt.show()
            profiler.stop()
            print(profiler.output_text(unicode=True, color=True))
            




  0%|          | 0/10000 [00:00<?, ?it/s]


  _     ._   __/__   _ _  _  _ _/_   Recorded: 00:29:12  Samples:  45
 /_//_/// /_\ / //_// / //_'/ //     Duration: 3.679     CPU time: 3.690
/   _/                      v4.6.1

Program: /data/xiazy/miniconda3/lib/python3.11/site-packages/ipykernel_launcher.py --f=/home/pinn-benchmark/.local/share/jupyter/runtime/kernel-v2-39925zqz4nnuK0xdu.json

[31m3.678[0m [48;5;24m[38;5;15m<module>[0m  [2m../../../../tmp/ipykernel_53459/947811341.py:1[0m
├─ [31m3.486[0m [48;5;24m[38;5;15mgreen_func[0m  [2m../../../../tmp/ipykernel_53459/947811341.py:50[0m
│  ├─ [31m3.390[0m [self][0m  [2m../../../../tmp/ipykernel_53459/947811341.py[0m
│  └─ [92m[2m0.096[0m sum[0m  [2mnumpy/core/fromnumeric.py:2177[0m
│        [3 frames hidden]  [2mnumpy, <built-in>[0m
└─ [92m[2m0.150[0m [self][0m  [2m../../../../tmp/ipykernel_53459/947811341.py[0m




  0%|          | 0/10000 [00:20<?, ?it/s]


KeyboardInterrupt: 