In [1]:
import sys
sys.path.append('/home/ilia/pythia8313-install/lib/')

import pickle
import gzip
import glob
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.nn import Linear, Dropout, BatchNorm1d, LeakyReLU
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv, global_mean_pool, GATConv, GINConv, global_add_pool
from torch_geometric.loader import DataLoader
from sklearn.model_selection import train_test_split, RandomizedSearchCV
from sklearn.metrics import mean_absolute_error, r2_score, mean_squared_error
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt
import numpy as np
from itertools import product
from pythia8 import Pythia
import fastjet as fj
from safe_root import ROOT
import pandas as pd

In [2]:
# Определяем структуру данных
ROOT.gInterpreter.Declare('''
struct EventData {
    Int_t event_id;
    Float_t Tar;
    std::vector<Float_t> px;
    std::vector<Float_t> py;
    std::vector<Float_t> pz;
    std::vector<Float_t> e;
    Int_t nParticles;
};
''')

True

In [None]:
pTMin_list = [1, 2, 3, 4, 5]
R_list = [0.4, 0.8]
R = 0.4
pTHatMin_list = [0, 2, 3, 4, 5]
#output_file = "jets_output.pkl.gz"
MAX_EVENTS = 100000
dataset = []

# Создаем DataFrame для хранения признаков
columns = [
    'px_jet1', 'py_jet1', 'pz_jet1', 'e_jet1', 
    'n_constituents_jet1', 'mean_pt_jet1',
    'Tar'
    ]
jet_features = pd.DataFrame(columns=columns)

for current_pTHatMin in pTHatMin_list:

    if current_pTHatMin == 0: MAX_EVENTS = 1000000
    if current_pTHatMin == 2: MAX_EVENTS = 25000
    if current_pTHatMin == 3: MAX_EVENTS = 20000
    if current_pTHatMin == 4: MAX_EVENTS = 15000
    if current_pTHatMin == 5: MAX_EVENTS = 12500
    
    file_path = f"pythia_pTHatMin_{current_pTHatMin}.root"

    print(f"Начало цикла = {current_pTHatMin}")
    
    root_file = ROOT.TFile(file_path)
    tree = root_file.Get("events")

    # Подготовка для чтения
    event_struct = ROOT.EventData()
    tree.SetBranchAddress("event", event_struct)
    
    for current_pTMin in pTMin_list:
        event_count = 0  # Счетчик для текущей комбинации параметров
        
        # Цикл по событиям
        for event_num in range(tree.GetEntries()):
            if event_count >= MAX_EVENTS:
                break  # Переходим к следующему pTMin
            
            tree.GetEntry(event_num)

            # Сбор частиц
            particles_in_event = [
                {'px': event_struct.px[i], 'py': event_struct.py[i], 
                 'pz': event_struct.pz[i], 'e': event_struct.e[i]}
                for i in range(event_struct.nParticles)
            ]

            # Реконструкция струй
            input_for_fastjet = [fj.PseudoJet(p['px'], p['py'], p['pz'], p['e']) 
                                for p in particles_in_event]
            jet_algorithm = fj.JetDefinition(fj.antikt_algorithm, R)
            clustering = fj.ClusterSequence(input_for_fastjet, jet_algorithm)
            sorted_jets = sorted(clustering.inclusive_jets(current_pTMin), key=lambda jet: -jet.pt())

            if len(sorted_jets) < 2: continue
            if len(sorted_jets[0].constituents()) < 2: continue 
            if len(sorted_jets[1].constituents()) < 2: continue

            pt_1 = sorted_jets[0].pt()
            pt_2 = sorted_jets[1].pt()

            if (pt_2 / pt_1) < 0.9: continue
    
            dPhiii = sorted_jets[0].phi_std() - sorted_jets[1].phi_std()
            if dPhiii > math.pi: 
                dPhiii = dPhiii - 2 * math.pi
            if dPhiii < -math.pi: 
                dPhiii = dPhiii + 2 * math.pi

            if abs(dPhiii) < 2.7: continue

            leading_jet = sorted_jets[0]
            constituents = leading_jet.constituents()
            
            mean_pt = sum(p.pt() for p in constituents) / len(constituents)

            # Добавляем признаки в DataFrame
            new_row = {
                'px_jet1': leading_jet.px(),
                'py_jet1': leading_jet.py(),
                'pz_jet1': leading_jet.pz(),
                'e_jet1': leading_jet.e(),
                'n_constituents_jet1': len(constituents),
                'mean_pt_jet1': mean_pt,
                'Tar': event_struct.Tar
            }

            if event_struct.Tar > 6: continue
            
            event_count += 1

            jet_features = pd.concat([jet_features, pd.DataFrame([new_row])], ignore_index=True)
            if event_count % 1000 == 0:
                print(f"pTHatMin={current_pTHatMin}, pTMin={current_pTMin}: {event_count}/{MAX_EVENTS}")

# Сохраняем данные в CSV
jet_features.to_csv(output_csv, index=False)

In [4]:
# Загрузка данных
data = pd.read_csv("jets_features.csv")

print(f"Всего событий: {len(data)}")
print(data.head())

Всего событий: 1197388
    px_jet1   py_jet1   pz_jet1    e_jet1  n_constituents_jet1  mean_pt_jet1  \
0 -0.571583 -0.998807  1.727804  2.545876                    3      0.386372   
1  0.019483  2.329153  4.876208  5.405641                    2      1.164696   
2 -0.850123 -1.404215  0.957455  1.958818                    3      0.552656   
3 -0.310189  1.822441 -0.889179  2.848866                    2      0.935092   
4  1.064952  1.112783 -2.209582  2.781367                    2      0.770140   

        Tar  
0  1.643249  
1  1.627437  
2  1.206322  
3  1.185830  
4  1.019930  


In [5]:
data

Unnamed: 0,px_jet1,py_jet1,pz_jet1,e_jet1,n_constituents_jet1,mean_pt_jet1,Tar
0,-0.571583,-0.998807,1.727804,2.545876,3,0.386372,1.643249
1,0.019483,2.329153,4.876208,5.405641,2,1.164696,1.627437
2,-0.850123,-1.404215,0.957455,1.958818,3,0.552656,1.206322
3,-0.310189,1.822441,-0.889179,2.848866,2,0.935092,1.185830
4,1.064952,1.112783,-2.209582,2.781367,2,0.770140,1.019930
...,...,...,...,...,...,...,...
1197383,5.226670,1.755417,-1.704252,5.804058,2,2.759522,5.062728
1197384,4.329836,3.734208,-2.708769,6.369870,2,2.877641,5.522048
1197385,-5.156638,1.975925,3.443621,6.562316,3,1.847294,5.127409
1197386,-3.717511,-4.138001,4.307997,7.195243,6,0.938452,5.778429
