In [1]:
import sys
sys.path.append('/home/ilia/pythia8313-install/lib/')

import pickle
import gzip
import glob
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.nn import Linear, Dropout, BatchNorm1d, LeakyReLU
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv, global_mean_pool, GATConv, GINConv, global_add_pool, SAGEConv, JumpingKnowledge, PNAConv
from torch_geometric.loader import DataLoader
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_absolute_error, r2_score
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt
import numpy as np
from itertools import product
from pythia8 import Pythia
import fastjet as fj
from safe_root import ROOT
import pandas as pd

In [2]:
# C++ definition of the per-event record, handed to ROOT's interpreter so that
# ROOT.EventData can be instantiated from Python. Fields: an integer event id,
# a float target value (Tar), four float vectors (px, py, pz, e) and an
# integer particle count.
EVENT_DATA_DECLARATION = '''
struct EventData {
    Int_t event_id;
    Float_t Tar;
    std::vector<Float_t> px;
    std::vector<Float_t> py;
    std::vector<Float_t> pz;
    std::vector<Float_t> e;
    Int_t nParticles;
};
'''

# Kept as the last expression so the cell still displays Declare's return value.
ROOT.gInterpreter.Declare(EVENT_DATA_DECLARATION)

True

In [None]:
# ---------------------------------------------------------------------------
# Build the graph dataset: for every (pTHatMin sample, jet pTMin) combination,
# cluster each event with anti-kt, keep balanced back-to-back dijet events and
# turn the two leading jets into one PyG graph (two disjoint cliques).
# ---------------------------------------------------------------------------
pTMin_list = [1, 2, 3, 4, 5]
R_list = [0.4, 0.8]  # not used by this cell; only R below is applied
R = 0.4
pTHatMin_list = [0, 2, 3, 4, 5]
MAX_EVENTS = 10000  # fallback cap for a pTHatMin value missing from the table below
dataset = []

# Cap on accepted events per (pTHatMin, pTMin) combination, keyed by pTHatMin.
MAX_EVENTS_BY_PTHATMIN = {0: 1000000, 2: 25000, 3: 20000, 4: 15000, 5: 12500}

# Selection cuts.
PT_BALANCE_MIN = 0.9   # minimum pt(subleading) / pt(leading)
DELTA_PHI_MIN = 2.7    # minimum |delta phi| between the two leading jets
TARGET_MAX = 6         # events whose target (Tar) exceeds this are dropped

# The jet definition depends only on R, so build it once instead of per event.
jet_algorithm = fj.JetDefinition(fj.antikt_algorithm, R)

for current_pTHatMin in pTHatMin_list:
    # Same semantics as the former if-chain: unknown values keep the previous cap.
    MAX_EVENTS = MAX_EVENTS_BY_PTHATMIN.get(current_pTHatMin, MAX_EVENTS)

    file_path = f"pythia_pTHatMin_{current_pTHatMin}.root"

    print(f"Начало цикла = {current_pTHatMin}")

    root_file = ROOT.TFile(file_path)
    if root_file.IsZombie():
        # Fail early with a clear message instead of crashing on a null tree.
        raise OSError(f"Failed to open ROOT file: {file_path}")

    try:
        tree = root_file.Get("events")

        # Bind the "event" branch to a struct instance that GetEntry() fills.
        event_struct = ROOT.EventData()
        tree.SetBranchAddress("event", event_struct)
        n_entries = tree.GetEntries()

        # NOTE(review): pTMin only filters the jets returned by inclusive_jets,
        # so an event accepted at a higher pTMin is presumably accepted at the
        # lower ones too and appended again with an identical graph — confirm
        # that such duplicates are intended before any train/test split.
        for current_pTMin in pTMin_list:
            event_count = 0  # accepted events for this parameter combination

            for event_num in range(n_entries):
                if event_count >= MAX_EVENTS:
                    break  # cap reached: move on to the next pTMin

                tree.GetEntry(event_num)

                # Reject on the target first: it is the cheapest test and used
                # to be applied only after clustering and tensor construction.
                target = event_struct.Tar
                if target > TARGET_MAX:
                    continue

                # Jet reconstruction.
                input_for_fastjet = [
                    fj.PseudoJet(event_struct.px[i], event_struct.py[i],
                                 event_struct.pz[i], event_struct.e[i])
                    for i in range(event_struct.nParticles)
                ]
                clustering = fj.ClusterSequence(input_for_fastjet, jet_algorithm)
                sorted_jets = sorted(clustering.inclusive_jets(current_pTMin),
                                     key=lambda jet: -jet.pt())

                # Need two jets, each with at least two constituents.
                if len(sorted_jets) < 2:
                    continue
                leading, subleading = sorted_jets[0], sorted_jets[1]
                if len(leading.constituents()) < 2:
                    continue
                if len(subleading.constituents()) < 2:
                    continue

                # Transverse-momentum balance of the dijet pair.
                if (subleading.pt() / leading.pt()) < PT_BALANCE_MIN:
                    continue

                # Azimuthal separation wrapped into [-pi, pi]; keep only
                # (nearly) back-to-back pairs.
                delta_phi = leading.phi_std() - subleading.phi_std()
                if delta_phi > math.pi:
                    delta_phi -= 2 * math.pi
                if delta_phi < -math.pi:
                    delta_phi += 2 * math.pi
                if abs(delta_phi) < DELTA_PHI_MIN:
                    continue

                # Graph construction: node features are (px, py, pz, e) of each
                # constituent; the constituents of each jet are connected by
                # every ordered pair (i, j), self-loops (i == j) included.
                x_list = []
                edge_index = []
                offset = 0  # index of the current jet's first node

                for jet in (leading, subleading):
                    constituents = jet.constituents()
                    num_nodes = len(constituents)

                    x_list.extend([p.px(), p.py(), p.pz(), p.e()]
                                  for p in constituents)
                    edge_index.extend(
                        [offset + i, offset + j]
                        for i, j in product(range(num_nodes), repeat=2)
                    )

                    offset += num_nodes

                x_tensor = torch.tensor(x_list, dtype=torch.float)
                # Transpose the list of (source, target) pairs to shape [2, num_edges].
                edge_index_tensor = torch.tensor(edge_index, dtype=torch.long).t().contiguous()
                y_tensor = torch.tensor([target], dtype=torch.float)

                dataset.append(Data(x=x_tensor, edge_index=edge_index_tensor, y=y_tensor))

                event_count += 1

                if event_count % 100 == 0:
                    print(f"pTHatMin={current_pTHatMin}, pTMin={current_pTMin}: {event_count}/{MAX_EVENTS}")
    finally:
        # Release the file handle even if processing fails part-way through.
        root_file.Close()


In [None]:
# Сохранение dataset
with gzip.open('jet_dataset.pkl.gz', 'wb') as f:
    pickle.dump(dataset, f, protocol=pickle.HIGHEST_PROTOCOL)

In [4]:
# Загрузка dataset
with gzip.open('jet_dataset.pkl.gz', 'rb') as f:
    dataset = pickle.load(f)

In [5]:
# Summary statistics of the dataset: number of graphs and the mean node and
# edge counts per graph.
node_counts = [graph.num_nodes for graph in dataset]
edge_counts = [graph.edge_index.shape[1] for graph in dataset]

print("\nОбщая статистика:")
print(f"Всего графов: {len(dataset)}")
print(f"Среднее количество узлов на граф: {np.mean(node_counts):.2f}")
print(f"Среднее количество ребер на граф: {np.mean(edge_counts):.2f}")


Общая статистика:
Всего графов: 587388
Среднее количество узлов на граф: 5.47
Среднее количество ребер на граф: 16.76
