In [1]:
import matplotlib.pyplot as plt
import uproot
import numpy as np
import pandas as pd
from pathlib import Path


import sys
sys.path.insert(1, '/afs/desy.de/user/a/axelheim/private/MC_studies/Dstlnu_Bt_generic/util_funcs/')
from pandas_colFuncs import B_ID, whichBisSig, D0_decay_type, whichBisSig_NAHS


## load the NN

In [2]:
nn_vars = ["px","py","pz","E","M","charge","dr","dz","clusterReg","clusterE9E21","pionID","kaonID","electronID","muonID","protonID"]

In [3]:
sys.path.append('/afs/desy.de/user/a/axelheim/private/baumbauen/notebooks/')
from BranchSeparatorModel import BranchSeparatorModel
# See below why I put this



model_dir="/nfs/dust/belle2/user/axelheim/MC_studies/Dstlnu_Bt_generic/saved_models/NAHSA_Gmodes_fixedD0modes/NAHS_allEvts_twoSubs_fixedD0run/NAHSA_no_xyz/256_0_64_0.1_4/"
checkpoint_name = "model_checkpoint_model_perfectSA=0.7674.pt"
specs_output_label = "256_0_64_0.1_4"
num_classes = 3    


specs = specs_output_label.split("_")

model = BranchSeparatorModel(infeatures=len(nn_vars),
            dim_feedforward=int(specs[0]),
            num_classes=num_classes,
            dropout=float(specs[3]),
            nblocks=int(specs[4]))



import torch

checkpoint = torch.load(model_dir +  checkpoint_name, map_location=torch.device('cpu'))
model.load_state_dict(checkpoint)

Using factor graph MLP encoder.


<All keys matched successfully>

In [None]:
model

## load the online data

In [4]:
nfs_path = "/nfs/dust/belle2/user/axelheim/MC_studies/Dstlnu_Bt_generic/appliedNNdata/10thRun/"
#nfs_path="/afs/desy.de/user/a/axelheim/private/MC_studies/Dstlnu_Bt_generic/load_NN_to_basf2/productive_method/testOut/"

In [5]:
FSPs_file = uproot.open(nfs_path + "FSPs.root:variables;1")
df_FSPs = FSPs_file.arrays(library="pd")

In [6]:
Ups4S_file = uproot.open(nfs_path + "Ups4S_NN_predicted.root:variables;1")
df_Ups4S = Ups4S_file.arrays(library="pd")

## add labels

In [7]:
df_FSPs['B_ID'] = df_FSPs.apply(B_ID, axis=1)

In [8]:
Hc_motherB_df = df_FSPs[df_FSPs["NN_prediction"].isna() == True].drop_duplicates(subset=("__event__"), keep='first')
Hc_motherB_df["B_tag_ID"] = Hc_motherB_df["B_ID"]
df_FSPs = pd.merge(df_FSPs,Hc_motherB_df[["__event__","__production__","B_tag_ID"]], on=["__event__","__production__"])

In [9]:
def labels(s):
    label = -1
    if int(s['B_ID']) == 0:
        label = 0 # background, cause not related to MC Particles
    else: 
        B_tagID = s['B_tag_ID']
        
        if int(s['B_ID']) == B_tagID:
            label = 1 # X
        else:
            label = 2 # Bsig
    return label
df_FSPs['label'] = df_FSPs.apply(labels, axis=1)

In [10]:
df_FSPs["correct_pred_onlineNN"] = (df_FSPs["label"] == df_FSPs["NN_prediction"]).astype(int)

In [11]:
df_FSPs.shape[0]

503316

## check DO decays

In [12]:
df_Ups4S['Bsig_uniqParID'] = df_Ups4S.apply(whichBisSig_NAHS, axis=1)

In [13]:
df_Ups4S['D0_decay'] = df_Ups4S.apply(D0_decay_type, axis=1)

In [14]:
df_Ups4S['D0_decay'].value_counts()

notWanted     90057
Kpipi0        16833
Kpipipipi0    12039
Kpipipi        8745
Kpi            3508
Name: D0_decay, dtype: int64

In [15]:
df_Ups4S = df_Ups4S[df_Ups4S['D0_decay'] != "notWanted"]

In [16]:
df_FSPs = df_FSPs[df_FSPs['__event__'].isin(df_Ups4S["__event__"])]

## prepare input for NN

In [17]:
nonHc_FSPs = df_FSPs[df_FSPs["NN_prediction"].notna() == True]

In [18]:
nonHc_FSPs["NN_prediction"].describe()

count    128494.000000
mean          1.212033
std           0.735228
min           0.000000
25%           1.000000
50%           1.000000
75%           2.000000
max           2.000000
Name: NN_prediction, dtype: float64

In [19]:
nonHc_FSPs.shape[0]

128494

In [20]:
nonHc_FSPs["offline_NN_pred"] = -1
nonHc_FSPs["offline_NN_pred_shuffled"] = -1

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  nonHc_FSPs["offline_NN_pred"] = -1
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  nonHc_FSPs["offline_NN_pred_shuffled"] = -1


In [21]:
evts = nonHc_FSPs["__event__"].unique()

In [22]:
print(len(evts))

8657


In [23]:
evts[2]

2337804

In [24]:
labels=[]
evtnum=[]
online_NN_pred=[]
offline_NN_pred=[]
offline_NN_pred_shuff=[]

for i in range(len(evts)):
    one_evt = nonHc_FSPs[(nonHc_FSPs["__event__"] == evts[i])]
    
    if (i % 100) == 0:
        print("processing evt",i,"of",len(evts))
    
    
    num_particles = one_evt.shape[0]
    
    tmp_par_vars = []
    
    for j in range(num_particles):
        particle = one_evt.iloc[j]
        #print(particle)
        readOut_features = [particle[var] for var in nn_vars]
        tmp_par_vars.append(readOut_features)

    NN_input_features = np.array([np.array(xi) for xi in tmp_par_vars])

    # impute the nan values with -1. (check if that's logical for all values if input vars get changed)
    NN_input_features= np.nan_to_num(NN_input_features, copy=False, nan=-1.0)

    shape = NN_input_features.shape
    NN_input_features = NN_input_features.reshape(shape[0], 1, shape[1])

    #print("NN_input_features.shape:",NN_input_features.shape)
    NN_input_features = torch.Tensor(NN_input_features)
    #print(NN_input_features.shape[0])

    SA_pred = model(NN_input_features)

    probs = torch.softmax(SA_pred, dim=1)  # (N, C, d1)
    winners = probs.argmax(dim=1)
    
    
    shape = NN_input_features.shape
    r=torch.randperm(shape[0])
    shuffled_input=NN_input_features[r, :, :]
    
    shuffled_SA_pred = model(shuffled_input)

    shuffled_probs = torch.softmax(shuffled_SA_pred, dim=1)  # (N, C, d1)
    shuffled_winners = shuffled_probs.argmax(dim=1)
    #print("r:",r)
    for j in range(num_particles):
        particle = one_evt.iloc[j]
        labels.append(particle["label"].item())
        evtnum.append(particle["__event__"].item())
        online_NN_pred.append(particle["NN_prediction"].item())
        #particle["offline_NN_pred"] = winners[0,j].item()
        offline_NN_pred.append(winners[0,j].item())
        
        
        index_Shuffreversed = (r == j).nonzero(as_tuple=True)[0].item()
        #print("j:",j,"index_Shuffreversed:",index_Shuffreversed)
        offline_NN_pred_shuff.append(shuffled_winners[0,index_Shuffreversed].item())
        
        #particle["offline_NN_pred_shuffled"] = shuffled_winners[0,index_Shuffreversed].item()
        #one_evt.iloc[j] = particle
        
    #nonHc_FSPs[(nonHc_FSPs["__event__"] == evts[i])] = one_evt 


processing evt 0 of 8657
processing evt 100 of 8657
processing evt 200 of 8657
processing evt 300 of 8657
processing evt 400 of 8657
processing evt 500 of 8657
processing evt 600 of 8657
processing evt 700 of 8657
processing evt 800 of 8657
processing evt 900 of 8657
processing evt 1000 of 8657
processing evt 1100 of 8657
processing evt 1200 of 8657
processing evt 1300 of 8657
processing evt 1400 of 8657
processing evt 1500 of 8657
processing evt 1600 of 8657
processing evt 1700 of 8657
processing evt 1800 of 8657
processing evt 1900 of 8657
processing evt 2000 of 8657
processing evt 2100 of 8657
processing evt 2200 of 8657
processing evt 2300 of 8657
processing evt 2400 of 8657
processing evt 2500 of 8657
processing evt 2600 of 8657
processing evt 2700 of 8657
processing evt 2800 of 8657
processing evt 2900 of 8657
processing evt 3000 of 8657
processing evt 3100 of 8657
processing evt 3200 of 8657
processing evt 3300 of 8657
processing evt 3400 of 8657
processing evt 3500 of 8657
proc

In [25]:
NN_results = pd.DataFrame({'__event__': evtnum,
                          'label' : labels,
                          'online_NN_pred' : online_NN_pred,
                          'offline_NN_pred' : offline_NN_pred,
                          'offline_NN_pred_shuff' : offline_NN_pred_shuff})

In [26]:
NN_results.to_csv(nfs_path + "NN_results.csv")

In [27]:
NN_results = pd.read_csv(nfs_path + "NN_results.csv")

In [28]:
NN_results["off_eq_on"] = (NN_results["offline_NN_pred"] == NN_results["online_NN_pred"]).astype(int)
NN_results["off_eq_on_shuffled"] = (NN_results["offline_NN_pred_shuff"] == NN_results["online_NN_pred"]).astype(int)
NN_results["unshuff_eq_shuff"] = (NN_results["offline_NN_pred_shuff"] == NN_results["offline_NN_pred"]).astype(int)

In [29]:
NN_results["correct_pred_onlineNN"] = (NN_results["label"] == NN_results["online_NN_pred"]).astype(int)
NN_results["correct_pred_offlineNN"] = (NN_results["label"] == NN_results["offline_NN_pred"]).astype(int)
NN_results["correct_pred_offlineNN_shuff"] = (NN_results["label"] == NN_results["offline_NN_pred_shuff"]).astype(int)

In [30]:
for var in ["off_eq_on","off_eq_on_shuffled","unshuff_eq_shuff","correct_pred_onlineNN","correct_pred_offlineNN","correct_pred_offlineNN_shuff"]:
    print(var)
    print(NN_results[var].mean()*100, '% \n')

off_eq_on
80.1220290441577 % 

off_eq_on_shuffled
80.0854514607686 % 

unshuff_eq_shuff
82.43886874095288 % 

correct_pred_onlineNN
67.37279561691597 % 

correct_pred_offlineNN
64.2528055784706 % 

correct_pred_offlineNN_shuff
64.15552477158467 % 



In [31]:
off_on_pred = pd.DataFrame({'count' : NN_results.groupby( ["off_eq_on",
                    "correct_pred_onlineNN","correct_pred_offlineNN"] ).size() / NN_results.shape[0] }).reset_index()
off_on_pred

Unnamed: 0,off_eq_on,correct_pred_onlineNN,correct_pred_offlineNN,count
0,0,0,0,0.027511
1,0,0,1,0.070034
2,0,1,0,0.101234
3,1,0,0,0.228727
4,1,1,1,0.572494


## compare only the events used offline (val set) with the online results

In [32]:
# load val set

In [33]:
data_folder = "NAHSA_Gmodes_fixedD0modes/NAHS_allEvts_twoSubs_fixedD0run/NAHSA_no_xyz"
run_folder = "MC_studies/Dstlnu_Bt_generic/" #"run_HcX_globTag/" 
run_path = "/nfs/dust/belle2/user/axelheim/" + run_folder
bsize = 16
dataset_dir = Path(run_path + 'data/' + data_folder)


sys.path.insert(1, '/afs/desy.de/user/a/axelheim/private/baumbauen/notebooks/')
from ah_utils import pad_collate_fn_ah, PhasespaceSet_BranchSeparator

collate = pad_collate_fn_ah
test_set = PhasespaceSet_BranchSeparator(dataset_dir, 'val')
train_set = PhasespaceSet_BranchSeparator(dataset_dir, 'train')
import torch
train_loader = torch.utils.data.DataLoader(train_set, batch_size=bsize, drop_last=True, shuffle=True,  collate_fn=collate)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=bsize, drop_last=True, shuffle=True,  collate_fn=collate)

In [34]:
dataiter = iter(train_loader)
evts_train=[]
for batch_num, batch in enumerate(dataiter):
    features, labels, tag = batch
    tag = np.array(tag)
    for i in range(tag.shape[0]):
        tag_tmp = str(tag[i,-1])
        event = int(tag_tmp[tag_tmp.find("evt")+3:-1])
        evts_train.append(event)

In [35]:
dataiter = iter(test_loader)
evts=[]
preds=[]
labs=[]
corrects=[]

for batch_num, batch in enumerate(dataiter):
    if batch_num % 10 == 0:
        print("#",(batch_num),"of",len(test_loader),"batches")
    features, labels, tag = batch
    tag = np.array(tag)
    
    SA_pred = model(features)

    probs = torch.softmax(SA_pred, dim=1)  # (N, C, d1)
    winners = probs.argmax(dim=1)
    #print("winners.shape:",winners.shape)
    #print("labels.shape:",labels.shape)

    
    
    for i in range(winners.shape[0]):
        tag_tmp = str(tag[i,-1])
        event = int(tag_tmp[tag_tmp.find("evt")+3:-1])
        #print("winners[i,:]:",winners[i,:])
        #print("labels[i,:]:",labels[i,:])
        for j in range(winners.shape[1]):
            evts.append(event)

            pred = winners[i,j].item()
            preds.append(pred)
            
            label = labels[i,j].item()
            labs.append(label)
            
            #print("label:",label,"winners:",pred)
            
            correct = int(label == pred)
            corrects.append(correct)
            
            

# 0 of 3367 batches
# 10 of 3367 batches
# 20 of 3367 batches
# 30 of 3367 batches
# 40 of 3367 batches
# 50 of 3367 batches
# 60 of 3367 batches
# 70 of 3367 batches
# 80 of 3367 batches
# 90 of 3367 batches
# 100 of 3367 batches
# 110 of 3367 batches
# 120 of 3367 batches
# 130 of 3367 batches
# 140 of 3367 batches
# 150 of 3367 batches
# 160 of 3367 batches
# 170 of 3367 batches
# 180 of 3367 batches
# 190 of 3367 batches
# 200 of 3367 batches
# 210 of 3367 batches
# 220 of 3367 batches
# 230 of 3367 batches
# 240 of 3367 batches
# 250 of 3367 batches
# 260 of 3367 batches
# 270 of 3367 batches
# 280 of 3367 batches
# 290 of 3367 batches
# 300 of 3367 batches
# 310 of 3367 batches
# 320 of 3367 batches
# 330 of 3367 batches
# 340 of 3367 batches
# 350 of 3367 batches
# 360 of 3367 batches
# 370 of 3367 batches
# 380 of 3367 batches
# 390 of 3367 batches
# 400 of 3367 batches
# 410 of 3367 batches
# 420 of 3367 batches
# 430 of 3367 batches
# 440 of 3367 batches
# 450 of 3367 batches

In [36]:
offline_preds_padded = pd.DataFrame({"event":evts, "label":labs,"prediction":preds,"correct":corrects})

In [37]:
# get rid of padded entries
offline_preds_padded = offline_preds_padded[offline_preds_padded["label"] != -1]

In [38]:
offline_preds_padded.shape[0]

695854

In [40]:
pd.set_option('display.float_format', lambda x: '%.3f' % x)
offline_preds_padded["correct"].describe()

count   695854.000
mean         0.750
std          0.433
min          0.000
25%          1.000
50%          1.000
75%          1.000
max          1.000
Name: correct, dtype: float64

In [41]:
on_and_offline = NN_results[NN_results['__event__'].isin(evts)]
on_and_offline_train = NN_results[NN_results['__event__'].isin(evts_train)]

In [42]:
print(on_and_offline['__event__'].nunique())
print(on_and_offline_train['__event__'].nunique())

1098
5344


In [43]:
for var in ["off_eq_on","off_eq_on_shuffled","unshuff_eq_shuff","correct_pred_onlineNN","correct_pred_offlineNN","correct_pred_offlineNN_shuff"]:
    print(var)
    print(NN_results[var].mean()*100, '%')    
    print(on_and_offline[var].mean()*100, '%')    
    print(on_and_offline_train[var].mean()*100, '% \n')

off_eq_on
80.1220290441577 %
81.84001921691089 %
81.19207115502164 % 

off_eq_on_shuffled
80.0854514607686 %
81.70790295460006 %
81.09476940733256 % 

unshuff_eq_shuff
82.43886874095288 %
83.4854672111458 %
83.38635031124085 % 

correct_pred_onlineNN
67.37279561691597 %
69.54720153735286 %
70.48887890922246 % 

correct_pred_offlineNN
64.2528055784706 %
67.12106653855393 %
67.62720956052044 % 

correct_pred_offlineNN_shuff
64.15552477158467 %
67.15709824645688 %
67.65091126829086 % 



## do offline padded again BUT with bsize = 1 

In [64]:
model.eval()

BranchSeparatorModel(
  (initial_mlp): Sequential(
    (0): MLP(
      (fc1): Linear(in_features=15, out_features=256, bias=True)
      (fc2): Linear(in_features=256, out_features=256, bias=True)
      (bn): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (mlp2): MLP(
    (fc1): Linear(in_features=512, out_features=256, bias=True)
    (fc2): Linear(in_features=256, out_features=256, bias=True)
    (bn): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (blocks): ModuleList(
    (0): ModuleList(
      (0): MLP(
        (fc1): Linear(in_features=256, out_features=256, bias=True)
        (fc2): Linear(in_features=256, out_features=256, bias=True)
        (bn): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): MLP(
        (fc1): Linear(in_features=768, out_features=256, bias=True)
        (fc2): Linear(in_features=256, out_features=256, bias=True)
        (bn

In [65]:
test_loader_b1 = torch.utils.data.DataLoader(test_set, batch_size=1, drop_last=True, shuffle=True,  collate_fn=collate)
dataiter = iter(test_loader_b1)
evts=[]
preds=[]
labs=[]
corrects=[]

for batch_num, batch in enumerate(dataiter):
    if batch_num % 1000 == 0:
        print("#",(batch_num),"of",len(test_loader_b1),"batches")
    features, labels, tag = batch
    tag = np.array(tag)
    shape = features.shape
    pad_dim = 22
    num_particles = shape[0]
    pad_input = (num_particles < pad_dim)
    #print("features.shape:",features.shape)

    # padd_array is needed if padding is done, but needs to be initialized outside of the if clause
    N = pad_dim 
    K = pad_dim - num_particles # K zeros, N-K ones
    padd_array = np.array([0] * K + [1] * (N-K))
    np.random.shuffle(padd_array)
    #print(padd_array)


    #print("pad_input:",pad_input)
    if pad_input:
        #print("in pad_input:   num_particles:", num_particles)


        torch.set_printoptions(threshold=10_000)        
        #print("shuffled_input:",) # prints the whole tensor

        input_padded = torch.ones(pad_dim, 1, shape[2]) * 0.
        
        #print("input_padded.shape:",input_padded.shape)
        #print("features.shape:",features.shape)
        
        particle_counter = 0
        for i in range(pad_dim):
            do_pad = padd_array[i]
            if do_pad == 1:    
                #print(i)
                #print(x[particle_counter,:,:])
                input_padded[i,:,:] = features[particle_counter,:,:]
                particle_counter += 1

        #print(input_padded)
        features = input_padded
    
    
    SA_pred = model(features)

    probs = torch.softmax(SA_pred, dim=1)  # (N, C, d1)
    winners = probs.argmax(dim=1)
    #print("winners.shape:",winners.shape)
    #print("winners:",winners)

    



    # unpad winners if pad_input
    if pad_input:
        winners_unpadded = torch.empty(1, num_particles)
        particle_counter = 0
        for i in range(pad_dim):
            do_pad = padd_array[i]
            if do_pad == 1:  
                winners_unpadded[0, particle_counter] = winners[0, i]
                particle_counter += 1

        #print(winners_unpadded)

        winners=winners_unpadded

    #print("winners.shape:",winners.shape)
    #print("winners:",winners)
    
    
    
    
    for i in range(winners.shape[0]):
        tag_tmp = str(tag[i,-1])
        event = int(tag_tmp[tag_tmp.find("evt")+3:-1])
        #print("winners[i,:]:",winners[i,:])
        #print("labels[i,:]:",labels[i,:])
        for j in range(winners.shape[1]):
            evts.append(event)

            pred = winners[i,j].item()
            preds.append(pred)
            
            label = labels[i,j].item()
            labs.append(label)
            
            #print("label:",label,"winners:",pred)
            
            correct = int(label == pred)
            corrects.append(correct)
            
            

# 0 of 53877 batches
# 1000 of 53877 batches
# 2000 of 53877 batches
# 3000 of 53877 batches
# 4000 of 53877 batches
# 5000 of 53877 batches
# 6000 of 53877 batches
# 7000 of 53877 batches
# 8000 of 53877 batches
# 9000 of 53877 batches
# 10000 of 53877 batches
# 11000 of 53877 batches
# 12000 of 53877 batches
# 13000 of 53877 batches
# 14000 of 53877 batches
# 15000 of 53877 batches
# 16000 of 53877 batches
# 17000 of 53877 batches
# 18000 of 53877 batches
# 19000 of 53877 batches
# 20000 of 53877 batches
# 21000 of 53877 batches
# 22000 of 53877 batches
# 23000 of 53877 batches
# 24000 of 53877 batches
# 25000 of 53877 batches
# 26000 of 53877 batches
# 27000 of 53877 batches
# 28000 of 53877 batches
# 29000 of 53877 batches
# 30000 of 53877 batches
# 31000 of 53877 batches
# 32000 of 53877 batches
# 33000 of 53877 batches
# 34000 of 53877 batches
# 35000 of 53877 batches
# 36000 of 53877 batches
# 37000 of 53877 batches
# 38000 of 53877 batches
# 39000 of 53877 batches
# 40000 of 53

In [66]:
offline_preds_padded_b1_eval = pd.DataFrame({"event":evts, "label":labs,"prediction":preds,"correct":corrects})

In [67]:
offline_preds_padded_b1_eval["correct"].mean()

0.7630609400370727

In [59]:
offline_preds_padded_b1.shape[0]

695930

In [60]:
# get rid of padded entries
offline_preds_padded_b1 = offline_preds_padded_b1[offline_preds_padded_b1["label"] != -1]

In [61]:
offline_preds_padded_b1.shape[0]

695930

In [62]:
pd.set_option('display.float_format', lambda x: '%.3f' % x)
offline_preds_padded_b1["correct"].describe()

count   695930.000
mean         0.664
std          0.472
min          0.000
25%          0.000
50%          1.000
75%          1.000
max          1.000
Name: correct, dtype: float64

## check results for different batch sizes:

In [63]:
for bsize in [16,64,128]:
    test_loader_var = torch.utils.data.DataLoader(test_set, batch_size=bsize, drop_last=True, shuffle=True,  collate_fn=collate)
    dataiter = iter(test_loader_var)
    evts=[]
    preds=[]
    labs=[]
    corrects=[]

    for batch_num, batch in enumerate(dataiter):
        if batch_num % 10 == 0:
            print("#",(batch_num),"of",len(test_loader_var),"batches")
        features, labels, tag = batch
        tag = np.array(tag)

        SA_pred = model(features)

        probs = torch.softmax(SA_pred, dim=1)  # (N, C, d1)
        winners = probs.argmax(dim=1)
        #print("winners.shape:",winners.shape)
        #print("labels.shape:",labels.shape)



        for i in range(winners.shape[0]):
            tag_tmp = str(tag[i,-1])
            event = int(tag_tmp[tag_tmp.find("evt")+3:-1])
            #print("winners[i,:]:",winners[i,:])
            #print("labels[i,:]:",labels[i,:])
            for j in range(winners.shape[1]):
                evts.append(event)

                pred = winners[i,j].item()
                preds.append(pred)

                label = labels[i,j].item()
                labs.append(label)

                #print("label:",label,"winners:",pred)

                correct = int(label == pred)
                corrects.append(correct)

    offline_preds_padded_var = pd.DataFrame({"event":evts, "label":labs,"prediction":preds,"correct":corrects})
    print(offline_preds_padded_var.shape[0])    
    offline_preds_padded_var = offline_preds_padded_var[offline_preds_padded_var["label"] != -1]
    print(offline_preds_padded_var.shape[0])
    print(offline_preds_padded_var["correct"].mean())

# 0 of 3367 batches
# 10 of 3367 batches
# 20 of 3367 batches
# 30 of 3367 batches
# 40 of 3367 batches
# 50 of 3367 batches
# 60 of 3367 batches
# 70 of 3367 batches
# 80 of 3367 batches
# 90 of 3367 batches
# 100 of 3367 batches
# 110 of 3367 batches
# 120 of 3367 batches
# 130 of 3367 batches
# 140 of 3367 batches
# 150 of 3367 batches
# 160 of 3367 batches
# 170 of 3367 batches
# 180 of 3367 batches
# 190 of 3367 batches
# 200 of 3367 batches
# 210 of 3367 batches
# 220 of 3367 batches
# 230 of 3367 batches
# 240 of 3367 batches
# 250 of 3367 batches
# 260 of 3367 batches
# 270 of 3367 batches
# 280 of 3367 batches
# 290 of 3367 batches
# 300 of 3367 batches
# 310 of 3367 batches
# 320 of 3367 batches
# 330 of 3367 batches
# 340 of 3367 batches
# 350 of 3367 batches
# 360 of 3367 batches
# 370 of 3367 batches
# 380 of 3367 batches
# 390 of 3367 batches
# 400 of 3367 batches
# 410 of 3367 batches
# 420 of 3367 batches
# 430 of 3367 batches
# 440 of 3367 batches
# 450 of 3367 batches

KeyboardInterrupt: 

## turn batchnorm off and see what happens

In [None]:
 batchnorm=False

## do it again with offline, but unpadded!

In [44]:
bsize = 1
test_loader_unpadded = torch.utils.data.DataLoader(test_set, batch_size=bsize, drop_last=True, shuffle=True,  collate_fn=collate)

In [45]:
dataiter = iter(test_loader_unpadded)
evts=[]
preds=[]
labs=[]
corrects=[]

for batch_num, batch in enumerate(dataiter):
    if batch_num % 1000 == 0:
        print("#",(batch_num),"of",len(test_loader_unpadded),"batches")
    features, labels, tag = batch
    tag = np.array(tag)
    
    SA_pred = model(features)

    probs = torch.softmax(SA_pred, dim=1)  # (N, C, d1)
    winners = probs.argmax(dim=1)
    #print("winners.shape:",winners.shape)
    #print("labels.shape:",labels.shape)

    
    
    for i in range(winners.shape[0]):
        tag_tmp = str(tag[i,-1])
        event = int(tag_tmp[tag_tmp.find("evt")+3:-1])
        #print("winners[i,:]:",winners[i,:])
        #print("labels[i,:]:",labels[i,:])
        for j in range(winners.shape[1]):
            evts.append(event)

            pred = winners[i,j].item()
            preds.append(pred)
            
            label = labels[i,j].item()
            labs.append(label)
            
            #print("label:",label,"winners:",pred)
            
            correct = int(label == pred)
            corrects.append(correct)  

# 0 of 53877 batches
# 1000 of 53877 batches
# 2000 of 53877 batches
# 3000 of 53877 batches
# 4000 of 53877 batches
# 5000 of 53877 batches
# 6000 of 53877 batches
# 7000 of 53877 batches
# 8000 of 53877 batches
# 9000 of 53877 batches
# 10000 of 53877 batches
# 11000 of 53877 batches
# 12000 of 53877 batches
# 13000 of 53877 batches
# 14000 of 53877 batches
# 15000 of 53877 batches
# 16000 of 53877 batches
# 17000 of 53877 batches
# 18000 of 53877 batches
# 19000 of 53877 batches
# 20000 of 53877 batches
# 21000 of 53877 batches
# 22000 of 53877 batches


KeyboardInterrupt: 

In [None]:
offline_preds_unpadded = pd.DataFrame({"event":evts, "label":labs,"prediction":preds,"correct":corrects})

In [None]:
offline_preds_unpadded["correct"].mean()

In [None]:
offline_preds_padded.shape[0]

In [None]:
offline_preds_padded["correct"].mean()

In [None]:
offline_preds_padded.shape[0]

In [None]:
nfs_path = "/nfs/dust/belle2/user/axelheim/MC_studies/Dstlnu_Bt_generic/appliedNNdata/10thRun/"

## save offline padded/unpadded results

In [None]:
offline_preds_unpadded.to_csv(nfs_path + "offline_preds_unpadded.csv")
offline_preds_padded.to_csv(nfs_path + "offline_preds_padded.csv")

## load offline padded/unpadded results

In [4]:
offline_preds_unpadded = pd.read_csv(nfs_path + "offline_preds_unpadded.csv")
offline_preds_padded = pd.read_csv(nfs_path + "offline_preds_padded.csv")
