In [2]:
import pandas as pd
import numpy as np
from mrmr import mrmr_classif

In [3]:
bkg = pd.read_csv('background_train.csv')
sig = pd.read_csv('signal_train.csv')

#remove labels exclusive to one dataset
colsToRemove = set(np.concatenate([list(set(sig.columns)-set(bkg.columns)), list(set(bkg.columns)-set(sig.columns))]))


bkg.drop(columns=colsToRemove, inplace = True)
sig.drop(columns=colsToRemove, inplace = True)

bkg.columns, sig.columns

(Index(['CorsikaWeightMap.AreaSum', 'CorsikaWeightMap.Atmosphere',
        'CorsikaWeightMap.CylinderLength', 'CorsikaWeightMap.CylinderRadius',
        'CorsikaWeightMap.DiplopiaWeight', 'CorsikaWeightMap.EnergyPrimaryMax',
        'CorsikaWeightMap.EnergyPrimaryMin', 'CorsikaWeightMap.FluxSum',
        'CorsikaWeightMap.Multiplicity', 'CorsikaWeightMap.SpectralIndexChange',
        ...
        'NewAtt.DirectEllipse', 'NewAtt.DeltaZd', 'NewAtt.AbsSmooth',
        'NewAtt.emptyness', 'NewAtt.SepDevide', 'NewAtt.SPEBayVerRadius',
        'NewAtt.SplineVerRadius', 'CorsikaWeightMap.ParticleType',
        'CorsikaWeightMap.Polygonato', 'CorsikaWeightMap.PrimarySpectralIndex'],
       dtype='object', length=283),
 Index(['CorsikaWeightMap.AreaSum', 'CorsikaWeightMap.Atmosphere',
        'CorsikaWeightMap.CylinderLength', 'CorsikaWeightMap.CylinderRadius',
        'CorsikaWeightMap.DiplopiaWeight', 'CorsikaWeightMap.EnergyPrimaryMax',
        'CorsikaWeightMap.EnergyPrimaryMin', 'CorsikaWei

In [4]:
#remove non value rows
bkg.replace([np.inf, -np.inf], np.nan, inplace=True)
sig.replace([np.inf, -np.inf], np.nan, inplace=True)

In [5]:
#Columns with above 10% nan/inf get dropped all other nan/inf rows are dropped
for col in bkg.columns:
    bkgNaRatio = np.sum(bkg[col].isna())/len(bkg)
    sigNaRatio = np.sum(sig[col].isna())/len(sig)
    maxNaRatio = max(bkgNaRatio, sigNaRatio)
    if maxNaRatio > .1:
        bkg.drop(columns = col, inplace = True)
        sig.drop(columns = col, inplace = True)
bkg.dropna(inplace = True)
sig.dropna(inplace = True)

In [6]:
#Remove MC truths
truthList = ['MC', 'Weight', 'Corsika', 'I3EventHeader']
for col in bkg.columns:
    if any(sg in col for sg in truthList):
        bkg.drop(columns = col, inplace = True)
        sig.drop(columns = col, inplace = True)

In [7]:
#Combine Datasets
bkg['label'] = 0.
sig['label'] = 1.
df = pd.concat([bkg,sig])

df['label']

0        0.0
1        0.0
2        0.0
3        0.0
4        0.0
        ... 
17928    1.0
17929    1.0
17930    1.0
17931    1.0
17932    1.0
Name: label, Length: 35652, dtype: float64

In [8]:
#Feature Selection arbitrarlily choose k=33
selected_features = mrmr_classif(X=df.drop(columns=['label']), y=df['label'], K=33)

100%|███████████████████████████████████████████| 33/33 [00:02<00:00, 12.46it/s]


In [9]:
#export trainging Dataset
export_features = np.append(selected_features, 'label')
export_features

array(['LineFit_TTParams.lf_vel_z', 'HitStatisticsValues.max_pulse_time',
       'SplineMPEFitParams.rlogl', 'HitStatisticsValues.z_travel',
       'SplineMPEDirectHitsA.n_dir_strings', 'LineFit_TT.zenith',
       'NewAtt.DeltaZd', 'MuEXAngular4.zenith', 'NewAtt.SplineVerRadius',
       'SplineMPEDirectHitsA.n_dir_doms', 'MPEFitHighNoise.zenith',
       'MuEXAngular4_Sigma.value', 'SPEFit2_TT.zenith',
       'MPEFit_TTFitParams.rlogl', 'SplineMPE.zenith',
       'SplineMPEDirectHitsC.dir_track_length',
       'SplineMPEMuEXDifferential.zenith',
       'SplineMPETruncatedEnergy_SPICEMie_AllBINS_Muon.zenith',
       'NewAtt.radius', 'SplineMPECharacteristics.avg_dom_dist_q_tot_dom',
       'SplineMPETruncatedEnergy_SPICEMie_AllDOMS_Muon.zenith',
       'MPEFitHighNoiseFitParams.rlogl', 'MPEFit_TT.zenith',
       'MPEFitParaboloid.zenith',
       'SplineMPETruncatedEnergy_SPICEMie_AllBINS_MuEres.value',
       'MPEFitParaboloidFitParams.zenith',
       'SplineMPETruncatedEnergy_SPICEMie_A

In [11]:
df.sample(frac=1).to_csv('build/training_data.csv', columns = export_features, index = False)

In [16]:
df.sample(frac=1)

Unnamed: 0,HitMultiplicityValues.n_hit_strings,HitMultiplicityValues.n_hit_doms,HitMultiplicityValues.n_hit_doms_one_pulse,HitStatisticsValues.cog_z_sigma,HitStatisticsValues.min_pulse_time,HitStatisticsValues.max_pulse_time,HitStatisticsValues.q_max_doms,HitStatisticsValues.z_min,HitStatisticsValues.z_max,HitStatisticsValues.z_mean,...,NewID,label,NewAtt.radius,NewAtt.DirectEllipse,NewAtt.DeltaZd,NewAtt.AbsSmooth,NewAtt.emptyness,NewAtt.SepDevide,NewAtt.SPEBayVerRadius,NewAtt.SplineVerRadius
15705,9.0,18.0,8.0,55.547292,9887.196289,11980.652344,6.393666,-301.910004,-63.060001,-168.498334,...,2681815.0,1.0,330.360247,87.723766,0.037738,0.313742,1.0,0.927339,385.004104,326.689041
11721,17.0,40.0,32.0,47.706222,9887.921875,13640.204102,3.753440,-510.570007,-294.559998,-435.370002,...,2680223.0,0.0,328.956352,18.827214,0.050731,0.490886,1.0,1.079816,349.067492,432.851544
4218,11.0,46.0,34.0,63.393870,9818.592773,11486.043945,4.982339,-483.790009,-159.130005,-288.571087,...,719931.0,1.0,89.407762,20.340616,0.119566,0.254614,1.0,0.378553,99.318241,101.300865
16150,15.0,89.0,68.0,88.109289,9870.881836,13671.433594,13.638374,-505.410004,-94.889999,-279.197417,...,2752463.0,1.0,67.493901,48.292552,0.091172,0.170572,1.0,0.545356,95.248170,56.797587
10051,6.0,13.0,12.0,31.997275,9609.269531,11484.063477,6.181741,229.500000,378.679993,316.839999,...,2080135.0,0.0,256.560491,13.658284,0.087518,0.210672,1.0,1.385481,260.659343,213.966609
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
14968,7.0,20.0,12.0,104.786377,9874.666992,12394.272461,5.691505,-503.890015,-125.099998,-370.518500,...,2553253.0,1.0,235.084747,109.124255,0.011239,0.303759,1.0,0.428518,358.758658,239.704933
12279,28.0,88.0,71.0,50.736499,9893.904297,13777.367188,4.812295,-423.399994,-153.130005,-278.847614,...,2094265.0,1.0,104.289696,186.422953,0.001147,0.294751,1.0,0.750306,134.768399,137.089553
14924,4.0,11.0,7.0,19.716430,9527.983398,11048.561523,3.213756,435.940002,500.429993,474.468181,...,3880219.0,0.0,325.261942,39.674812,0.041947,0.282571,1.0,0.688468,338.850852,346.285226
14421,11.0,32.0,18.0,21.840135,9665.852539,12172.260742,105.892860,312.369995,500.670013,447.883751,...,3698314.0,0.0,517.681586,39.250110,0.016547,0.630587,1.0,0.170952,494.585473,458.770886
