In [5]:
import matplotlib.pyplot as plt
import uproot
import numpy as np
import pandas as pd
from pathlib import Path


import sys
sys.path.insert(1, '/afs/desy.de/user/a/axelheim/private/MC_studies/Dstlnu_Bt_generic/util_funcs/')
from pandas_colFuncs import B_ID, whichBisSig, D0_decay_type, whichBisSig_NAHS


## load the NN

In [6]:
nn_vars = ["px","py","pz","E","M","charge","dr","dz","clusterReg","clusterE9E21","pionID","kaonID","electronID","muonID","protonID"]

In [7]:
sys.path.append('/afs/desy.de/user/a/axelheim/private/baumbauen/notebooks/')
from BranchSeparatorModel import BranchSeparatorModel
# See below why I put this



model_dir="/nfs/dust/belle2/user/axelheim/MC_studies/Dstlnu_Bt_generic/saved_models/NAHSA_Gmodes_fixedD0modes/NAHS_allEvts_twoSubs_fixedD0run/NAHSA_no_xyz/256_0_64_0.1_4/"
checkpoint_name = "model_checkpoint_model_perfectSA=0.7674.pt"
specs_output_label = "256_0_64_0.1_4"
num_classes = 3    


specs = specs_output_label.split("_")

model = BranchSeparatorModel(infeatures=len(nn_vars),
            dim_feedforward=int(specs[0]),
            num_classes=num_classes,
            dropout=float(specs[3]),
            nblocks=int(specs[4]))



import torch

checkpoint = torch.load(model_dir +  checkpoint_name, map_location=torch.device('cpu'))
model.load_state_dict(checkpoint)

Using factor graph MLP encoder.


<All keys matched successfully>

## load the online data

In [8]:
nfs_path = "/nfs/dust/belle2/user/axelheim/MC_studies/Dstlnu_Bt_generic/appliedNNdata/8thRun/"
#nfs_path="/afs/desy.de/user/a/axelheim/private/MC_studies/Dstlnu_Bt_generic/load_NN_to_basf2/productive_method/testOut/"

In [5]:
FSPs_file = uproot.open(nfs_path + "FSPs.root:variables;1")
df_FSPs = FSPs_file.arrays(library="pd")

## add labels

In [6]:
df_FSPs['B_ID'] = df_FSPs.apply(B_ID, axis=1)

In [7]:
Hc_motherB_df = df_FSPs[df_FSPs["NN_prediction"].isna() == True].drop_duplicates(subset=("__event__"), keep='first')
Hc_motherB_df["B_tag_ID"] = Hc_motherB_df["B_ID"]
df_FSPs = pd.merge(df_FSPs,Hc_motherB_df[["__event__","__production__","B_tag_ID"]], on=["__event__","__production__"])

In [8]:
def labels(s):
    label = -1
    if int(s['B_ID']) == 0:
        label = 0 # background, cause not related to MC Particles
    else: 
        B_tagID = s['B_tag_ID']
        
        if int(s['B_ID']) == B_tagID:
            label = 1 # X
        else:
            label = 2 # Bsig
    return label
df_FSPs['label'] = df_FSPs.apply(labels, axis=1)

In [9]:
df_FSPs["correct_pred_onlineNN"] = (df_FSPs["label"] == df_FSPs["NN_prediction"]).astype(int)

In [10]:
df_FSPs.shape[0]

848135

## prepare input for NN

In [11]:
nonHc_FSPs = df_FSPs[df_FSPs["NN_prediction"].notna() == True]

In [12]:
nonHc_FSPs["NN_prediction"].describe()

count    702287.000000
mean          1.159428
std           0.760966
min           0.000000
25%           1.000000
50%           1.000000
75%           2.000000
max           2.000000
Name: NN_prediction, dtype: float64

In [13]:
nonHc_FSPs.shape[0]

702287

In [14]:
nonHc_FSPs["offline_NN_pred"] = -1
nonHc_FSPs["offline_NN_pred_shuffled"] = -1

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  nonHc_FSPs["offline_NN_pred"] = -1
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  nonHc_FSPs["offline_NN_pred_shuffled"] = -1


In [15]:
evts = nonHc_FSPs["__event__"].unique()

In [16]:
print(len(evts))

44945


In [17]:
evts[2]

2336782

In [18]:
labels=[]
evtnum=[]
online_NN_pred=[]
offline_NN_pred=[]
offline_NN_pred_shuff=[]

for i in range(len(evts)):
    one_evt = nonHc_FSPs[(nonHc_FSPs["__event__"] == evts[i])]
    
    if (i % 100) == 0:
        print("processing evt",i,"of",len(evts))
    
    
    num_particles = one_evt.shape[0]
    
    tmp_par_vars = []
    
    for j in range(num_particles):
        particle = one_evt.iloc[j]
        #print(particle)
        readOut_features = [particle[var] for var in nn_vars]
        tmp_par_vars.append(readOut_features)

    NN_input_features = np.array([np.array(xi) for xi in tmp_par_vars])

    # impute the nan values with -1. (check if that's logical for all values if input vars get changed)
    NN_input_features= np.nan_to_num(NN_input_features, copy=False, nan=-1.0)

    shape = NN_input_features.shape
    NN_input_features = NN_input_features.reshape(shape[0], 1, shape[1])

    #print("NN_input_features.shape:",NN_input_features.shape)
    NN_input_features = torch.Tensor(NN_input_features)
    #print(NN_input_features.shape[0])

    SA_pred = model(NN_input_features)

    probs = torch.softmax(SA_pred, dim=1)  # (N, C, d1)
    winners = probs.argmax(dim=1)
    
    
    shape = NN_input_features.shape
    r=torch.randperm(shape[0])
    shuffled_input=NN_input_features[r, :, :]
    
    shuffled_SA_pred = model(shuffled_input)

    shuffled_probs = torch.softmax(shuffled_SA_pred, dim=1)  # (N, C, d1)
    shuffled_winners = shuffled_probs.argmax(dim=1)
    #print("r:",r)
    for j in range(num_particles):
        particle = one_evt.iloc[j]
        labels.append(particle["label"].item())
        evtnum.append(particle["__event__"].item())
        online_NN_pred.append(particle["NN_prediction"].item())
        #particle["offline_NN_pred"] = winners[0,j].item()
        offline_NN_pred.append(winners[0,j].item())
        
        
        index_Shuffreversed = (r == j).nonzero(as_tuple=True)[0].item()
        #print("j:",j,"index_Shuffreversed:",index_Shuffreversed)
        offline_NN_pred_shuff.append(shuffled_winners[0,index_Shuffreversed].item())
        
        #particle["offline_NN_pred_shuffled"] = shuffled_winners[0,index_Shuffreversed].item()
        #one_evt.iloc[j] = particle
        
    #nonHc_FSPs[(nonHc_FSPs["__event__"] == evts[i])] = one_evt 


processing evt 0 of 44945
processing evt 100 of 44945
processing evt 200 of 44945
processing evt 300 of 44945
processing evt 400 of 44945
processing evt 500 of 44945
processing evt 600 of 44945
processing evt 700 of 44945
processing evt 800 of 44945
processing evt 900 of 44945
processing evt 1000 of 44945
processing evt 1100 of 44945
processing evt 1200 of 44945
processing evt 1300 of 44945
processing evt 1400 of 44945
processing evt 1500 of 44945
processing evt 1600 of 44945
processing evt 1700 of 44945
processing evt 1800 of 44945
processing evt 1900 of 44945
processing evt 2000 of 44945
processing evt 2100 of 44945
processing evt 2200 of 44945
processing evt 2300 of 44945
processing evt 2400 of 44945
processing evt 2500 of 44945
processing evt 2600 of 44945
processing evt 2700 of 44945
processing evt 2800 of 44945
processing evt 2900 of 44945
processing evt 3000 of 44945
processing evt 3100 of 44945
processing evt 3200 of 44945
processing evt 3300 of 44945
processing evt 3400 of 449

processing evt 27700 of 44945
processing evt 27800 of 44945
processing evt 27900 of 44945
processing evt 28000 of 44945
processing evt 28100 of 44945
processing evt 28200 of 44945
processing evt 28300 of 44945
processing evt 28400 of 44945
processing evt 28500 of 44945
processing evt 28600 of 44945
processing evt 28700 of 44945
processing evt 28800 of 44945
processing evt 28900 of 44945
processing evt 29000 of 44945
processing evt 29100 of 44945
processing evt 29200 of 44945
processing evt 29300 of 44945
processing evt 29400 of 44945
processing evt 29500 of 44945
processing evt 29600 of 44945
processing evt 29700 of 44945
processing evt 29800 of 44945
processing evt 29900 of 44945
processing evt 30000 of 44945
processing evt 30100 of 44945
processing evt 30200 of 44945
processing evt 30300 of 44945
processing evt 30400 of 44945
processing evt 30500 of 44945
processing evt 30600 of 44945
processing evt 30700 of 44945
processing evt 30800 of 44945
processing evt 30900 of 44945
processing

In [3]:
NN_results = pd.DataFrame({'__event__': evtnum,
                          'label' : labels,
                          'online_NN_pred' : online_NN_pred,
                          'offline_NN_pred' : offline_NN_pred,
                          'offline_NN_pred_shuff' : offline_NN_pred_shuff})

NameError: name 'evtnum' is not defined

In [None]:
NN_results.to_csv(nfs_path + "NN_results.csv")

In [9]:
NN_results = pd.read_csv(nfs_path + "NN_results.csv")

In [10]:
NN_results["off_eq_on"] = (NN_results["offline_NN_pred"] == NN_results["online_NN_pred"]).astype(int)
NN_results["off_eq_on_shuffled"] = (NN_results["offline_NN_pred_shuff"] == NN_results["online_NN_pred"]).astype(int)
NN_results["unshuff_eq_shuff"] = (NN_results["offline_NN_pred_shuff"] == NN_results["offline_NN_pred"]).astype(int)

In [11]:
NN_results["correct_pred_onlineNN"] = (NN_results["label"] == NN_results["online_NN_pred"]).astype(int)
NN_results["correct_pred_offlineNN"] = (NN_results["label"] == NN_results["offline_NN_pred"]).astype(int)
NN_results["correct_pred_offlineNN_shuff"] = (NN_results["label"] == NN_results["offline_NN_pred_shuff"]).astype(int)

In [12]:
for var in ["off_eq_on","off_eq_on_shuffled","unshuff_eq_shuff","correct_pred_onlineNN","correct_pred_offlineNN","correct_pred_offlineNN_shuff"]:
    print(var)
    print(NN_results[var].mean()*100, '% \n')

off_eq_on
81.98799066478519 % 

off_eq_on_shuffled
81.9911232872031 % 

unshuff_eq_shuff
81.89016741018985 % 

correct_pred_onlineNN
60.269234657625724 % 

correct_pred_offlineNN
60.0980795600659 % 

correct_pred_offlineNN_shuff
60.12200140398441 % 



In [13]:
off_on_pred = pd.DataFrame({'count' : NN_results.groupby( ["off_eq_on",
                    "correct_pred_onlineNN","correct_pred_offlineNN"] ).size() / NN_results.shape[0] }).reset_index()
off_on_pred

Unnamed: 0,off_eq_on,correct_pred_onlineNN,correct_pred_offlineNN,count
0,0,0,0,0.02661
1,0,0,1,0.075899
2,0,1,0,0.077611
3,1,0,0,0.294798
4,1,1,1,0.525082


## compare only the events used offline (val set) with the online results

In [14]:
# load val set

In [15]:
data_folder = "NAHSA_Gmodes_fixedD0modes/NAHS_allEvts_twoSubs_fixedD0run/NAHSA_no_xyz"
run_folder = "MC_studies/Dstlnu_Bt_generic/" #"run_HcX_globTag/" 
run_path = "/nfs/dust/belle2/user/axelheim/" + run_folder
bsize = 16
dataset_dir = Path(run_path + 'data/' + data_folder)


sys.path.insert(1, '/afs/desy.de/user/a/axelheim/private/baumbauen/notebooks/')
from ah_utils import pad_collate_fn_ah, PhasespaceSet_BranchSeparator

collate = pad_collate_fn_ah
test_set = PhasespaceSet_BranchSeparator(dataset_dir, 'val')
train_set = PhasespaceSet_BranchSeparator(dataset_dir, 'train')
import torch
train_loader = torch.utils.data.DataLoader(train_set, batch_size=bsize, drop_last=True, shuffle=True,  collate_fn=collate)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=bsize, drop_last=True, shuffle=True,  collate_fn=collate)

In [73]:
tmp_loader = torch.utils.data.DataLoader(train_set, batch_size=64, drop_last=True, shuffle=True,  collate_fn=collate)

dataiter = iter(tmp_loader)
for batch_num, batch in enumerate(dataiter):
    features, labels, tag = batch
    print("features.shape:",features.shape)    
    SA_pred = model(features)

    probs = torch.softmax(SA_pred, dim=1)  # (N, C, d1)
    winners = probs.argmax(dim=1)
    print("winners.shape:",winners.shape)

features.shape: torch.Size([22, 64, 15])
winners.shape: torch.Size([64, 22])
features.shape: torch.Size([20, 64, 15])
winners.shape: torch.Size([64, 20])
features.shape: torch.Size([20, 64, 15])
winners.shape: torch.Size([64, 20])
features.shape: torch.Size([23, 64, 15])
winners.shape: torch.Size([64, 23])
features.shape: torch.Size([20, 64, 15])
winners.shape: torch.Size([64, 20])
features.shape: torch.Size([20, 64, 15])
winners.shape: torch.Size([64, 20])
features.shape: torch.Size([22, 64, 15])
winners.shape: torch.Size([64, 22])
features.shape: torch.Size([20, 64, 15])


KeyboardInterrupt: 

In [16]:
dataiter = iter(train_loader)
evts_train=[]
for batch_num, batch in enumerate(dataiter):
    features, labels, tag = batch
    tag = np.array(tag)
    for i in range(tag.shape[0]):
        tag_tmp = str(tag[i,-1])
        event = int(tag_tmp[tag_tmp.find("evt")+3:-1])
        evts_train.append(event)

In [46]:
dataiter = iter(test_loader)
evts=[]
preds=[]
labs=[]
corrects=[]

for batch_num, batch in enumerate(dataiter):
    if batch_num % 10 == 0:
        print("#",(batch_num),"of",len(test_loader),"batches")
    features, labels, tag = batch
    tag = np.array(tag)
    
    SA_pred = model(features)

    probs = torch.softmax(SA_pred, dim=1)  # (N, C, d1)
    winners = probs.argmax(dim=1)
    #print("winners.shape:",winners.shape)
    #print("labels.shape:",labels.shape)

    
    
    for i in range(winners.shape[0]):
        tag_tmp = str(tag[i,-1])
        event = int(tag_tmp[tag_tmp.find("evt")+3:-1])
        #print("winners[i,:]:",winners[i,:])
        #print("labels[i,:]:",labels[i,:])
        for j in range(winners.shape[1]):
            evts.append(event)

            pred = winners[i,j].item()
            preds.append(pred)
            
            label = labels[i,j].item()
            labs.append(label)
            
            #print("label:",label,"winners:",pred)
            
            correct = int(label == pred)
            corrects.append(correct)
            
            

# 0 of 3367 batches
# 10 of 3367 batches
# 20 of 3367 batches
# 30 of 3367 batches
# 40 of 3367 batches
# 50 of 3367 batches
# 60 of 3367 batches
# 70 of 3367 batches
# 80 of 3367 batches
# 90 of 3367 batches
# 100 of 3367 batches
# 110 of 3367 batches
# 120 of 3367 batches
# 130 of 3367 batches
# 140 of 3367 batches
# 150 of 3367 batches
# 160 of 3367 batches
# 170 of 3367 batches
# 180 of 3367 batches
# 190 of 3367 batches
# 200 of 3367 batches
# 210 of 3367 batches
# 220 of 3367 batches
# 230 of 3367 batches
# 240 of 3367 batches
# 250 of 3367 batches
# 260 of 3367 batches
# 270 of 3367 batches
# 280 of 3367 batches
# 290 of 3367 batches
# 300 of 3367 batches
# 310 of 3367 batches
# 320 of 3367 batches
# 330 of 3367 batches
# 340 of 3367 batches
# 350 of 3367 batches
# 360 of 3367 batches
# 370 of 3367 batches
# 380 of 3367 batches
# 390 of 3367 batches
# 400 of 3367 batches
# 410 of 3367 batches
# 420 of 3367 batches
# 430 of 3367 batches
# 440 of 3367 batches
# 450 of 3367 batches

In [47]:
offline_preds_padded = pd.DataFrame({"event":evts, "label":labs,"prediction":preds,"correct":corrects})

In [48]:
# get rid of padded entries
offline_preds_padded = offline_preds_padded[offline_preds_padded["label"] != -1]

In [49]:
offline_preds_padded.shape[0]

695868

In [35]:
pd.set_option('display.float_format', lambda x: '%.3f' % x)
offline_preds_padded["corr"].describe()

count   695858.000
mean         0.750
std          0.433
min          0.000
25%          1.000
50%          1.000
75%          1.000
max          1.000
Name: corr, dtype: float64

In [19]:
on_and_offline = NN_results[NN_results['__event__'].isin(evts)]
on_and_offline_train = NN_results[NN_results['__event__'].isin(evts_train)]

In [20]:
print(on_and_offline['__event__'].nunique())
print(on_and_offline_train['__event__'].nunique())

2033
8444


In [21]:
for var in ["off_eq_on","off_eq_on_shuffled","unshuff_eq_shuff","correct_pred_onlineNN","correct_pred_offlineNN","correct_pred_offlineNN_shuff"]:
    print(var)
    print(NN_results[var].mean()*100, '%')    
    print(on_and_offline[var].mean()*100, '%')    
    print(on_and_offline_train[var].mean()*100, '% \n')

off_eq_on
81.98799066478519 %
82.99468684591854 %
83.53381218242268 % 

off_eq_on_shuffled
81.9911232872031 %
83.09450974078248 %
83.45853459632394 % 

unshuff_eq_shuff
81.89016741018985 %
83.40685879890518 %
83.2703406310771 % 

correct_pred_onlineNN
60.269234657625724 %
66.75253582353888 %
67.04880496832068 % 

correct_pred_offlineNN
60.0980795600659 %
67.04878441474803 %
66.99861991092152 % 

correct_pred_offlineNN_shuff
60.12200140398441 %
66.86201899855095 %
67.0833071952826 % 



## do it again with offline, but unpadded!

In [36]:
bsize = 1
test_loader_unpadded = torch.utils.data.DataLoader(test_set, batch_size=bsize, drop_last=True, shuffle=True,  collate_fn=collate)

In [43]:
dataiter = iter(test_loader_unpadded)
evts=[]
preds=[]
labs=[]
corrects=[]

for batch_num, batch in enumerate(dataiter):
    if batch_num % 1000 == 0:
        print("#",(batch_num),"of",len(test_loader_unpadded),"batches")
    features, labels, tag = batch
    tag = np.array(tag)
    
    SA_pred = model(features)

    probs = torch.softmax(SA_pred, dim=1)  # (N, C, d1)
    winners = probs.argmax(dim=1)
    #print("winners.shape:",winners.shape)
    #print("labels.shape:",labels.shape)

    
    
    for i in range(winners.shape[0]):
        tag_tmp = str(tag[i,-1])
        event = int(tag_tmp[tag_tmp.find("evt")+3:-1])
        #print("winners[i,:]:",winners[i,:])
        #print("labels[i,:]:",labels[i,:])
        for j in range(winners.shape[1]):
            evts.append(event)

            pred = winners[i,j].item()
            preds.append(pred)
            
            label = labels[i,j].item()
            labs.append(label)
            
            #print("label:",label,"winners:",pred)
            
            correct = int(label == pred)
            corrects.append(correct)  

# 0 of 53877 batches
# 1000 of 53877 batches
# 2000 of 53877 batches
# 3000 of 53877 batches
# 4000 of 53877 batches
# 5000 of 53877 batches
# 6000 of 53877 batches
# 7000 of 53877 batches
# 8000 of 53877 batches
# 9000 of 53877 batches
# 10000 of 53877 batches
# 11000 of 53877 batches
# 12000 of 53877 batches
# 13000 of 53877 batches
# 14000 of 53877 batches
# 15000 of 53877 batches
# 16000 of 53877 batches
# 17000 of 53877 batches
# 18000 of 53877 batches
# 19000 of 53877 batches
# 20000 of 53877 batches
# 21000 of 53877 batches
# 22000 of 53877 batches
# 23000 of 53877 batches
# 24000 of 53877 batches
# 25000 of 53877 batches
# 26000 of 53877 batches
# 27000 of 53877 batches
# 28000 of 53877 batches
# 29000 of 53877 batches
# 30000 of 53877 batches
# 31000 of 53877 batches
# 32000 of 53877 batches
# 33000 of 53877 batches
# 34000 of 53877 batches
# 35000 of 53877 batches
# 36000 of 53877 batches
# 37000 of 53877 batches
# 38000 of 53877 batches
# 39000 of 53877 batches
# 40000 of 53

In [44]:
offline_preds_unpadded = pd.DataFrame({"event":evts, "label":labs,"prediction":preds,"correct":corrects})

In [50]:
offline_preds_unpadded["correct"].mean()

0.6355035707614272

In [51]:
offline_preds_padded.shape[0]

695868

In [54]:
offline_preds_padded["correct"].mean()

0.7505575770117321

In [55]:
offline_preds_padded.shape[0]

695868

## save offline padded/unpadded results

In [56]:
offline_preds_unpadded.to_csv(nfs_path + "offline_preds_unpadded.csv")
offline_preds_padded.to_csv(nfs_path + "offline_preds_padded.csv")

In [None]:
torch.Size([numParticles, 1, numVars])

tensor([[[0.6748, 0.4451, 0.3961, 0.5804]],

        [[0.0027, 0.1085, 0.7915, 0.5823]],

        [[0.4262, 0.2239, 0.2660, 0.3856]],

        [[0.5164, 0.4531, 0.5464, 0.5993]],

        [[0.6673, 0.0881, 0.6711, 0.6094]],

        [[0.2769, 0.4824, 0.9434, 0.6696]]])


In [71]:
numParticles = 6
x = torch.rand(numParticles, 1, 4)

torch.set_printoptions(threshold=10_000)
print(x) # prints the whole tensor


pad_dim = 11

N = pad_dim 
K = pad_dim - numParticles # K zeros, N-K ones
padd_array = np.array([0] * K + [1] * (N-K))
np.random.shuffle(padd_array)
print(padd_array)

x_padded = torch.ones(pad_dim, 1, 4) * -1.

particle_counter = 0
for i in range(pad_dim):
    do_pad = padd_array[i]
    if do_pad == 1:    
        #print(i)
        #print(x[particle_counter,:,:])
        
        x_padded[i,:,:] = x[particle_counter,:,:]

        particle_counter += 1
        
print(x_padded)


tensor([[[0.0730, 0.0636, 0.4005, 0.7672]],

        [[0.7806, 0.7088, 0.5083, 0.7090]],

        [[0.5421, 0.5441, 0.0406, 0.8689]],

        [[0.0382, 0.3284, 0.9462, 0.7230]],

        [[0.4337, 0.9257, 0.2905, 0.4833]],

        [[0.9107, 0.2545, 0.5194, 0.6152]]])
[1 1 1 0 0 1 1 0 0 0 1]
tensor([[[ 0.0730,  0.0636,  0.4005,  0.7672]],

        [[ 0.7806,  0.7088,  0.5083,  0.7090]],

        [[ 0.5421,  0.5441,  0.0406,  0.8689]],

        [[-1.0000, -1.0000, -1.0000, -1.0000]],

        [[-1.0000, -1.0000, -1.0000, -1.0000]],

        [[ 0.0382,  0.3284,  0.9462,  0.7230]],

        [[ 0.4337,  0.9257,  0.2905,  0.4833]],

        [[-1.0000, -1.0000, -1.0000, -1.0000]],

        [[-1.0000, -1.0000, -1.0000, -1.0000]],

        [[-1.0000, -1.0000, -1.0000, -1.0000]],

        [[ 0.9107,  0.2545,  0.5194,  0.6152]]])


In [None]:
[1, pad_dim]

In [87]:
winners = torch.tensor([[ 0.],
        [ 1.],
        [ 2.],
        [-1.],
        [-1.],
        [ 3.],
        [ 4.],
        [-1.],
        [-1.],
        [-1.],
        [ 5.]])
winners = winners.reshape(1, pad_dim)

In [88]:
winners.shape

torch.Size([1, 11])

In [89]:
print(winners)

tensor([[ 0.,  1.,  2., -1., -1.,  3.,  4., -1., -1., -1.,  5.]])


In [91]:
num_particles = numParticles


winners_unpadded = torch.empty(1, num_particles)
particle_counter = 0
for i in range(pad_dim):
    do_pad = padd_array[i]
    if do_pad == 1:  
        winners_unpadded[0, particle_counter] = winners[0, i]
        particle_counter += 1
        
print(winners_unpadded)

winners=winners_unpadded

tensor([[0., 1., 2., 3., 4., 5.]])
