In [1]:
import numpy as np
import random
import torch
import os

torch.manual_seed(42)
torch.cuda.manual_seed(42)
np.random.seed(42)
random.seed(42)

# torch.use_deterministic_algorithms(True)
# os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [2]:
# import importlib
# import Flash
# import Flash.inference
# import Flash.predictor
# import Flash.model.model

# importlib.reload(Flash)
# importlib.reload(Flash.inference)
# importlib.reload(Flash.model.model)
# importlib.reload(Flash.predictor)

from Flash.config import ModelTestConfig
from Flash.fn_utils import create_tokenizer, decode_sequence
from Flash.data import Data
from Flash.predictor import Predictor, sequence_accuracy

import pandas as pd
import torch

In [3]:
config = ModelTestConfig(
    model_name="vanilla_64x2_2-2-3-4",
    root_dir="/pscratch/sd/r/ritesh11/SYMBA_arxiv/models/QCD",
    data_dir="/pscratch/sd/r/ritesh11/SYMBA_arxiv/data/EW/EW_2-2-3_termwise",
    test_batch_size=64,
    device="cuda",
    embedding_size=512,
    nhead=8,
    num_encoder_layers=3,
    num_decoder_layers= 3, 
    use_torch_mha =False,
    kan_ff_dims=[],  # Not using KAN here
    is_kan=False,
    is_pre_norm=False, 
    ff_dims=4096,
    dropout=0,
    src_max_len=3110,
    tgt_max_len=2008,
    is_termwise=True,
    seed=42,
    truncate=False,
    debug=False,
    to_replace=False,
    is_beamsearch=False,
    beam_width=1,
    index_pool_size=100
)

In [4]:
df_train = pd.read_csv(config.data_dir + "train.csv", low_memory=False)
df_test = pd.read_csv(config.data_dir + "test.csv")
df_valid = pd.read_csv(config.data_dir + "valid.csv")

# df_train = df_train.iloc[:100]
df = pd.concat([df_train, df_valid, df_test]).reset_index(drop=True)

tokenizer, src_vocab, tgt_vocab = create_tokenizer(df,config)
del df, df_valid

Processing source vocab: 100%|██████████| 681599/681599 [02:28<00:00, 4598.36it/s]
Processing target vocab: 100%|██████████| 681599/681599 [01:29<00:00, 7622.53it/s] 


In [5]:
config.src_voc_size = len(src_vocab)
config.tgt_voc_size = len(tgt_vocab)

In [6]:
config.src_voc_size

157

In [8]:
df_test = pd.read_csv(config.data_dir + "test.csv")

In [9]:
from tqdm import tqdm

# Enable tqdm for pandas apply
tqdm.pandas()

df_test["length"] = df_test["sqamp"].progress_apply(lambda x: len(tokenizer.tgt_tokenize(x)))

100%|██████████| 81079/81079 [00:10<00:00, 8047.48it/s]


In [10]:
df_sorted = df_test.sort_values("length").reset_index(drop=True)

In [10]:
test_ds = Data(df_test,tokenizer,config,src_vocab,tgt_vocab)

In [11]:
config.test_batch_size = 64
config.test_size = len(test_ds)

In [12]:
with torch.no_grad():
    res = sequence_accuracy(config, test_ds, tgt_vocab, return_incorrect=True)

2025-11-13 01:48:50 - INFO - Flash.fn_utils - Weights initialized
2025-11-13 01:48:50 - INFO - Flash.fn_utils - Model(
  (transformer): Transformer(
    (encoder): TransformerEncoder(
      (layers): ModuleList(
        (0-2): 3 x TransformerEncoderLayer(
          (self_attn): FlashMHA(
            (q_proj): Linear(in_features=512, out_features=512, bias=True)
            (k_proj): Linear(in_features=512, out_features=512, bias=True)
            (v_proj): Linear(in_features=512, out_features=512, bias=True)
            (out_proj): Linear(in_features=512, out_features=512, bias=True)
          )
          (linear1): Linear(in_features=512, out_features=4096, bias=True)
          (linear2): Linear(in_features=4096, out_features=512, bias=True)
          (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (dropout1): Dropout(p=0, inplace=False)
          (dropout2): Dropout(p=0, inplace=False)
 

Using epoch 34 model for predictions.


Seq_Acc_Cal:   1%|          | 2/194 [00:36<59:08, 18.48s/it, seq_accuracy=0.797]


KeyboardInterrupt: 

In [14]:
inc_idxs = res[2]

In [15]:
inc_arr = np.array(inc_idxs)

In [16]:
np.save("QCD_2-3.npy",inc_arr)

In [11]:
inc_idxs = np.load("EW_2-2-3-vanilla_inc.npy")

In [12]:
inc_set = set(df_sorted.iloc[inc_idxs].amp.tolist())

In [13]:
len(inc_set)

5117

In [14]:
complete_set = set(df_sorted.amp.tolist())

In [15]:
correct_preds = len(complete_set - inc_set)

In [16]:
# set of correctly predicted amplitudes
corr_set = complete_set - inc_set  

# keep only rows where amp is in corr_set
df_correct = df_sorted[df_sorted["amp"].isin(corr_set)].copy()

In [17]:
df_corr_unique = df_correct.drop_duplicates(subset=["amp"], keep="first")
df_unique = df_test.drop_duplicates(subset=['amp'], keep="first")

In [18]:
len(corr_set) / len(complete_set)

0.835645917646303

In [19]:
df_unique

Unnamed: 0,amp,sqamp,process,length
0,"-1/2*i*e^2*sin(theta_W)^(-2)*P_L_{ % INDEX_0, ...",<BOS>e^4*MOMENTUM_12*MOMENTUM_34*(m_W^2+-m_d^2...,2-2,44
1,1/12*i*e*sin(theta_W)*(e*sin(theta_W)/cos(thet...,<BOS>1/36*e^2*MOMENTUM_13*MOMENTUM_24*(m_u^2+-...,2-2,73
2,"1/3*i*e^2*P_L_{ % INDEX_0, % INDEX_1}*gamma_{+...",<BOS>-4/9*e^4*(m_e^2*MOMENTUM_13+-MOMENTUM_14*...,2-2,48
3,"1/9*i*e^2*P_L_{ % INDEX_0, % INDEX_1}*P_L_{ % ...",<BOS>4/81*e^4*MOMENTUM_14*MOMENTUM_23*(m_u^2+-...,2-2,33
4,-1/18*i*e*sin(theta_W)*(e*sin(theta_W)/cos(the...,<BOS>1/81*e^2*MOMENTUM_12*MOMENTUM_34*(m_u^2+-...,2-2,74
...,...,...,...,...
81069,"1/4*i*e*v^(-2)*m_e*m_u*(P_L_{ % INDEX_0, % IND...",<BOS>1/2*e^2*v^(-4)*m_e^2*m_u^2*MOMENTUM_12*MO...,2-3,79
81071,-1/54*i*e^2*sin(theta_W)*(e*sin(theta_W)/cos(t...,<BOS>4/729*e^4*MOMENTUM_13*MOMENTUM_14*MOMENTU...,2-3,94
81073,1/27*i*e^3*cos(theta_W)^(-2)*sin(theta_W)^2*(P...,<BOS>16/729*e^6*MOMENTUM_13*MOMENTUM_15*MOMENT...,2-3,76
81075,-1/54*i*e^2*sin(theta_W)*(e*sin(theta_W)/cos(t...,<BOS>4/729*e^4*MOMENTUM_13*MOMENTUM_14*MOMENTU...,2-3,94


In [20]:
df_corr_unique 

Unnamed: 0,amp,sqamp,process,length
11,"-1/2*i*e^2*(P_L_{ % INDEX_0, % INDEX_1}*MOMENT...",<BOS>e^4*MOMENTUM_13*MOMENTUM_34*(MOMENTUM_12+...,2-2,24
12,"-1/2*i*e^2*(P_L_{ % INDEX_0, % INDEX_1}*MOMENT...",<BOS>e^4*MOMENTUM_13*MOMENTUM_34*(MOMENTUM_12+...,2-2,24
13,"-1/2*i*e^2*(P_R_{ % INDEX_0, % INDEX_1}*MOMENT...",<T1>e^4*MOMENTUM_23*MOMENTUM_24*(MOMENTUM_23+-...,2-2,25
14,"-1/2*i*e^2*(P_R_{ % INDEX_0, % INDEX_1}*MOMENT...",<T1>e^4*MOMENTUM_23*MOMENTUM_24*(MOMENTUM_23+-...,2-2,25
15,"-1/2*i*e^2*(P_R_{ % INDEX_0, % INDEX_1}*MOMENT...",<T1>e^4*MOMENTUM_23*MOMENTUM_24*(MOMENTUM_23+-...,2-2,25
...,...,...,...,...
79782,-1/216*i*e^2*(e*sin(theta_W)/cos(theta_W)+3*e*...,<T2>1/17496*i*e^2*(e*sin(theta_W)/cos(theta_W)...,2-3,797
79811,1/8*i*e^2*(e*sin(theta_W)/cos(theta_W)+-e*cos(...,<T1>-1/12*i*e^2*(e*sin(theta_W)/cos(theta_W)+-...,2-3,800
79813,1/8*i*e^2*(e*sin(theta_W)/cos(theta_W)+-e*cos(...,<T1>-1/12*i*e^2*(e*sin(theta_W)/cos(theta_W)+-...,2-3,800
79854,1/54*i*e^2*(e*sin(theta_W)/cos(theta_W)+(-3)*e...,<T2>2/2187*i*e^2*(e*sin(theta_W)/cos(theta_W)+...,2-3,809


In [22]:
df_test_og = pd.read_csv("/pscratch/sd/r/ritesh11/SYMBA_arxiv/data/EW/EW_2-2-3test.csv")

In [23]:
amp_process_dict = dict(zip(df_test_og['amp'], df_test_og['process']))

In [24]:
# Add process column to both dataframes using amp as key
df_corr_unique['process'] = df_corr_unique['amp'].map(amp_process_dict)
df_unique['process'] = df_unique['amp'].map(amp_process_dict)

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df_corr_unique['process'] = df_corr_unique['amp'].map(amp_process_dict)
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df_unique['process'] = df_unique['amp'].map(amp_process_dict)


In [25]:
count_22_corr = (df_corr_unique['process'] == "2-2").sum()
count_23_corr = (df_corr_unique['process'] == "2-3").sum()
count_24_corr = (df_corr_unique['process'] == "2-4").sum()

count_22 = (df_unique['process'] == "2-2").sum()
count_23 = (df_unique['process'] == "2-3").sum()
count_24 = (df_unique['process'] == "2-4").sum()

In [26]:
count_24

np.int64(0)

In [27]:
count_22_corr / count_22

np.float64(0.837037037037037)

In [28]:
count_23_corr / count_23

np.float64(0.83558942478024)

In [30]:
count_24_corr / count_24

np.float64(0.0014064697609001407)

In [31]:
len(complete_set)

603

In [31]:
correct_preds / len(complete_set)

0.1433990895295903