In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import pickle
from tqdm.notebook import tqdm
from Eearly_stop import *
from sklearn.metrics import roc_auc_score, accuracy_score
import sys
import pandas as pd
import argparse
import os
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings('ignore')
import math
from torch.utils.data import Dataset, DataLoader
#from src_py.cpmix_original_utils import preprocess_data
from src_py.cpmix_utils import preprocess_data       ##with Bkgd
from src_py.rhorho import RhoRhoEvent
from src_py.a1a1 import A1A1Event
from src_py.a1rho import A1RhoEvent
from src_py.data_utils import read_np, EventDatasets
from src_py.process_background import convert_bkgd_raw
import train_rhorho, train_a1rho, train_a1a1
from src_py.metrics_utils import calculate_deltas_unsigned, calculate_deltas_signed
from sklearn.metrics import confusion_matrix

In [None]:
# Preferred GPU ordinal for this analysis. When the host exposes fewer GPUs the
# ordinal would be invalid and every later `.to(device)` would fail, so fall
# back to GPU 0; without CUDA, run on the CPU.
_PREFERRED_GPU = 7
if torch.cuda.is_available():
    _gpu_index = _PREFERRED_GPU if _PREFERRED_GPU < torch.cuda.device_count() else 0
    device = torch.device('cuda:{}'.format(_gpu_index))
else:
    device = torch.device('cpu')
print('Using {} device'.format(device))

In [None]:
# Run configuration: mini-batch size, the per-decay-mode training entry points,
# and every command-line option understood by the preprocessing/training code.
batch_size=512
types = {"nn_rhorho": train_rhorho.start,"nn_a1rho": train_a1rho.start,"nn_a1a1": train_a1a1.start}


def _str2bool(s):
    """Parse a flag value: 'true', 't', 'yes' or '1' (case-insensitive) is True, anything else False.

    argparse's ``type=bool`` calls ``bool()`` on the raw string, so every
    non-empty value (including "False") would be parsed as True.
    """
    return s.lower() in ['true', 't', 'yes', '1']


parser = argparse.ArgumentParser(description='Train classifier')

decaymodes = ['rhorho', 'a1rho', 'a1a1']
decaymode = decaymodes[0]  ###### Change this to corresponding decaymode ############

parser.add_argument("-i", "--input", dest="IN", default='HiggsCP_data/'+decaymode+'_bkgd')
parser.add_argument("-t", "--type", dest="TYPE", choices=types.keys(), default='nn_'+ decaymode)

parser.add_argument("--num_classes", dest="NUM_CLASSES", type=int, default=21)
parser.add_argument("-l", "--layers", dest="LAYERS", type=int, help = "number of NN layers", default=6)
parser.add_argument("-s", "--size", dest="SIZE", type=int, help="NN size", default=1000)
parser.add_argument("-lambda", "--lambda", type=float, dest="LAMBDA", help="value of lambda parameter", default=0.0)
parser.add_argument("-m", "--method", dest="METHOD", choices=["A", "B", "C"], default="A")
parser.add_argument("-o", "--optimizer", dest="OPT", 
    choices=["GradientDescentOptimizer", "AdadeltaOptimizer", "AdagradOptimizer",
         "ProximalAdagradOptimizer", "AdamOptimizer", "FtrlOptimizer",
         "ProximalGradientDescentOptimizer", "RMSPropOptimizer"], default="AdamOptimizer")
parser.add_argument("-d", "--dropout", dest="DROPOUT", type=float, default=0.0)
parser.add_argument("-e", "--epochs", dest="EPOCHS", type=int, default=25)
# The feature-set option is registered further down as "--f" (dest FEAT). A
# disabled earlier variant of that option listed these values: "Variant-All",
# "Variant-1.0", "Variant-1.1", "Variant-2.0", "Variant-2.1", "Variant-2.2",
# "Variant-3.0", "Variant-3.1", "Variant-4.0", "Variant-4.1".

parser.add_argument("--miniset", dest="MINISET", type=_str2bool, default=False)
parser.add_argument("--z_noise_fraction", dest="Z_NOISE_FRACTION", type=float, default=0.0)

parser.add_argument("--delt_classes", dest="DELT_CLASSES", type=int, default=0,
                    help='Maximal distance between predicted and valid class for event being considered as correctly classified')

parser.add_argument("--unweighted", dest="UNWEIGHTED", type=_str2bool, default=False)
parser.add_argument("--reuse_weights", dest="REUSE_WEIGHTS", type=_str2bool, default=False)
parser.add_argument("--restrict_most_probable_angle", dest="RESTRICT_MOST_PROBABLE_ANGLE", type=_str2bool, default=False)
parser.add_argument("--force_download", dest="FORCE_DOWNLOAD", type=_str2bool, default=False)
parser.add_argument("--normalize_weights", dest="NORMALIZE_WEIGHTS", type=_str2bool, default=False)


parser.add_argument("--beta",  type=float, dest="BETA", help="value of beta parameter for polynomial smearing", default=0.0)
parser.add_argument("--pol_b", type=float, dest="pol_b", help="value of b parameter for polynomial smearing", default=0.0)
parser.add_argument("--pol_c", type=float, dest="pol_c", help="value of c parameter for polynomial smearing", default=0.0)

parser.add_argument("--w1", dest="W1")
parser.add_argument("--w2", dest="W2")
parser.add_argument("--f", dest="FEAT", default="Variant-All")
parser.add_argument("--plot_features", dest="PLOT_FEATURES", choices=["NO", "FILTER", "NO-FILTER"], default="NO")
parser.add_argument("--training_method", dest="TRAINING_METHOD", choices=["soft_weights", "soft_c012s",  "soft_argmaxs", "regr_c012s", "regr_weights", "regr_argmaxs"], default="soft_weights")
parser.add_argument("--hits_c012s", dest="HITS_C012s", choices=["hits_c0s", "hits_c1s",  "hits_c2s"], default="hits_c0s")

parser.add_argument("-r", "--reprocess", dest="REPRO", type=_str2bool, default=True)
parser.add_argument("-bkgd", "--bkgdpath", dest="BKGDPATH", default= 'Ztt_dataset_Elz/pythia.Z_115_135.%s.1M.*.outTUPLE_labFrame')
parser.add_argument("--label_bkgd", dest="LABEL_BKGD", type=_str2bool, default=False)

# parse_known_args (rather than parse_args) tolerates arguments this parser
# does not declare; they are returned in `unknown`. One parse after all options
# are registered yields the same namespace as re-parsing after each addition.
args, unknown = parser.parse_known_args()

# Preprocessing signal samples from all the decaymodes

In [None]:
# Build and cache the signal-only dataset for the selected decay mode.
# Look the event class up directly instead of eval()-ing its name.
event_classes = {'nn_rhorho': RhoRhoEvent, 'nn_a1rho': A1RhoEvent, 'nn_a1a1': A1A1Event}
if args.REPRO:
    # NOTE: args is mutated in place; later cells see these values.
    args.IN = 'HiggsCP_data/'+decaymode
    args.TYPE = 'nn_'+decaymode
    data, weights, argmaxs, perm, c012s, hits_argmaxs, hits_c012s = preprocess_data(args)
    event = event_classes[args.TYPE](data, args)
    points = EventDatasets(event, weights, argmaxs, perm, c012s=c012s, hits_argmaxs=hits_argmaxs,  hits_c012s=hits_c012s, miniset=args.MINISET, unweighted=args.UNWEIGHTED)
    # Context manager guarantees the pickle is flushed and the handle closed.
    with open(args.IN+'/events_wo_background21.pk','wb') as pk_file:
        pickle.dump(points, pk_file)

## with bkgd

In [None]:
# Build and cache the signal + labelled-background dataset for the selected decay mode.
# Look the event class up directly instead of eval()-ing its name.
event_classes = {'nn_rhorho': RhoRhoEvent, 'nn_a1rho': A1RhoEvent, 'nn_a1a1': A1A1Event}
if args.REPRO:
    # NOTE: args is mutated in place; later cells see these values.
    args.LABEL_BKGD = True
    args.Z_NOISE_FRACTION = 0.8
    args.IN = 'HiggsCP_data/'+decaymode+'_bkgd'
    args.TYPE = 'nn_'+decaymode
    data, weights, argmaxs, perm, c012s, hits_argmaxs, hits_c012s = preprocess_data(args)
    event = event_classes[args.TYPE](data, args)
    points = EventDatasets(event, weights, argmaxs, perm, c012s=c012s, hits_argmaxs=hits_argmaxs,  hits_c012s=hits_c012s, miniset=args.MINISET, unweighted=args.UNWEIGHTED)
    # Context manager guarantees the pickle is flushed and the handle closed.
    with open(args.IN+'/events_w_background.pk','wb') as pk_file:
        pickle.dump(points, pk_file)

# Loading signal samples

In [None]:
# Load the cached signal-only dataset. The path is spelled out (not derived
# from args.IN) because the signal preprocessing cell writes this file to
# 'HiggsCP_data/<decaymode>/', while args.IN defaults to, and is later rebound
# to, the '<decaymode>_bkgd' directory.
with open('HiggsCP_data/'+decaymode+'/events_wo_background21.pk','rb') as pk_file:
    points = pickle.load(pk_file)

## with bkgd

In [None]:
# Load the cached signal + background dataset; this rebinds `points`, so the
# cells below train on this sample. The file handle is closed after reading.
with open(args.IN+'/events_w_background.pk','rb') as pk_file:
    points = pickle.load(pk_file)

# Training NN

In [None]:
class MyDataset(Dataset):
    """Paired dataset holding an 'mc' sample and a 'true' sample with their labels.

    All four numpy arrays are converted to float tensors and moved to the
    module-level `device` once, at construction time. Item `index` returns the
    tuple (data_mc, data_true, labels_mc, labels_true) at that position; the
    length is the shorter of the two label arrays.
    """

    def __init__(self, rhorho_data_mc,rhorho_data_true,rhorho_labels_mc,rhorho_labels_true):
        def as_device_tensor(array):
            # numpy -> float32 tensor resident on the configured device
            return torch.from_numpy(array).float().to(device)

        self.rhorho_data_mc = as_device_tensor(rhorho_data_mc)
        self.rhorho_data_true = as_device_tensor(rhorho_data_true)
        self.rhorho_labels_mc = as_device_tensor(rhorho_labels_mc)
        self.rhorho_labels_true = as_device_tensor(rhorho_labels_true)

    def __getitem__(self, index):
        fields = (self.rhorho_data_mc, self.rhorho_data_true,
                  self.rhorho_labels_mc, self.rhorho_labels_true)
        return tuple(field[index] for field in fields)

    def __len__(self):
        n_mc = len(self.rhorho_labels_mc)
        n_true = len(self.rhorho_labels_true)
        return n_mc if n_mc <= n_true else n_true

In [None]:
def _split_half(n_events):
    """Randomly split range(n_events) into two disjoint index arrays.

    Returns (mc_idx, true_idx): `mc_idx` is a random half (rounded down) drawn
    without replacement, `true_idx` is the sorted complement within the SAME
    split, so no event appears in both halves.
    """
    all_idx = np.arange(n_events)
    mc_idx = np.random.choice(all_idx, int(n_events*0.5), replace=False)
    true_idx = np.setdiff1d(all_idx, mc_idx)
    return mc_idx, true_idx


# The complement of each split must be taken against that split's own MC
# indices; subtracting the training indices from the valid/test ranges lets the
# 'mc' and 'true' halves overlap.
mc_train_idx, true_train_idx = _split_half(points.train.x.shape[0])
mc_valid_idx, true_valid_idx = _split_half(points.valid.x.shape[0])
mc_test_idx, true_test_idx = _split_half(points.test.x.shape[0])

In [None]:
# Scale of the Gaussian smearing added to the 'true' half of each split
# (0.0 adds only zeros).
uncertainty=0.0


def _build_loader(split, mc_idx, true_idx):
    """Wrap one split in a MyDataset plus a shuffling DataLoader.

    The 'true' features receive additive N(0, 1) noise scaled by `uncertainty`;
    labels are taken from the split's `weights` array.
    """
    true_features = split.x[true_idx]
    smeared = true_features + uncertainty*np.random.normal(0,1,size=true_features.shape)
    dataset = MyDataset(split.x[mc_idx], smeared, split.weights[mc_idx], split.weights[true_idx])
    loader = DataLoader(dataset = dataset,batch_size = batch_size,shuffle = True)
    return dataset, loader


# Built in train -> valid -> test order so random numbers are drawn in the same sequence.
train_datasets, train_loader = _build_loader(points.train, mc_train_idx, true_train_idx)
valid_datasets, valid_loader = _build_loader(points.valid, mc_valid_idx, true_valid_idx)
test_datasets, test_loader = _build_loader(points.test, mc_test_idx, true_test_idx)

In [None]:
class NeuralNetwork(nn.Module):
    """Bias-free MLP classifier.

    Architecture: Linear(num_features -> size), then `num_layers` repetitions
    of [Linear(size -> size), BatchNorm1d, ReLU, Dropout], then
    Linear(size -> num_classes). `forward` returns raw scores (no softmax).

    `linear3` (size -> 2) is created but not used in `forward`; it is kept
    because it is part of the module's state_dict. The `lr`, `tloss`,
    `activation`, `input_noise` and `optimizer` arguments are accepted but not
    read by this class.
    """

    def __init__(self, num_features, num_classes, num_layers=1, size=100, lr=1e-3, drop_prob=0, inplace=False, 
                 tloss="regr_weights", activation='linear', input_noise=0.0, optimizer="AdamOptimizer"):
        super().__init__()
        self.linear1 = nn.Linear(num_features, size, bias=False)

        hidden = []
        for _ in range(num_layers):
            hidden += [
                nn.Linear(size, size, bias=False),
                nn.BatchNorm1d(size),
                nn.ReLU(),
                nn.Dropout(drop_prob, inplace),
            ]
        self.linear_relu_stack = nn.Sequential(*hidden)

        self.linear2 = nn.Linear(size, num_classes, bias=False)
        self.linear3 = nn.Linear(size, 2, bias=False)

    def forward(self, x):
        hidden_out = self.linear_relu_stack(self.linear1(x))
        return self.linear2(hidden_out)

In [None]:
# Directory for checkpoints, created next to the notebook's working directory.
model_path = os.getcwd() + '/model'
if not os.path.exists(model_path):
    os.mkdir(model_path)

# One output per class plus one extra (NUM_CLASSES + 1). The hidden width is
# left at NeuralNetwork's default because `size` is not passed here.
n_input_features = points.train.x.shape[1]
model = NeuralNetwork(
    num_features=n_input_features,
    num_classes=args.NUM_CLASSES+1,
    num_layers=args.LAYERS,
    drop_prob=0,
).to(device)
opt_g = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

#early_stopping = EarlyStopping(patience=12, verbose=True,path=model_path+'/'+decaymode+'_best_model.pt')

In [None]:
def save_checkpoint(state, filename=model_path+'/'+decaymode+'_the_best_model_'+str(args.NUM_CLASSES)+str(args.FEAT)+'.pt'):
    """Serialise `state` to `filename` with torch.save.

    The default filename is computed once, when this function is defined, from
    the values of model_path, decaymode, args.NUM_CLASSES and args.FEAT at
    that moment.
    """
    print("=> saving checkpoint")
    torch.save(obj=state, f=filename)

In [None]:
def load_checkpoint(checkpoint):
    """Restore the global `model` and `opt_g` from a checkpoint dict.

    `checkpoint` must provide the keys 'state_dict' (model weights) and
    'opt_g' (optimizer state).
    """
    print("=> loading checkpoint")
    # model first, then optimizer
    for target, key in ((model, 'state_dict'), (opt_g, 'opt_g')):
        target.load_state_dict(checkpoint[key])

In [None]:
# Resume from the stored checkpoint when one exists. On a first run there is
# nothing to restore, so training starts from the fresh initialisation instead
# of aborting the notebook.
checkpoint_file = model_path+'/'+decaymode+'_the_best_model_'+str(args.NUM_CLASSES)+str(args.FEAT)+'.pt'
if os.path.exists(checkpoint_file):
    # map_location remaps tensors saved on another device (e.g. a GPU that is
    # not present on this machine) onto the device in use.
    load_checkpoint(torch.load(checkpoint_file, map_location=device))
else:
    print("=> no checkpoint found at " + checkpoint_file + ", using freshly initialised weights")

In [None]:
# Train for a fixed number of epochs, logging per-epoch loss/accuracy to stdout
# and to a text file. After the loop, tr_pred/tr_true and v_pred/v_true hold
# the softmax outputs and labels of every batch of the FINAL epoch (one row per
# event), which is what the confusion-matrix cells consume.
epoch=200
training_loss=[]
validation_loss=[]
tr_pred = []
tr_true = []
v_pred = []
v_true = []

# NOTE(review): this stores the model state as it is BEFORE training; nothing
# below writes the trained weights back to disk -- confirm that is intended.
checkpoint={'state_dict':model.state_dict(), 'opt_g':opt_g.state_dict()}
save_checkpoint(checkpoint)

# Make sure the log directory exists before opening the log file.
training_log_dir = 'Results/TrainingOutputs/'
os.makedirs(training_log_dir, exist_ok=True)
# With CrossEntropyLoss the targets are class indices (argmax of the label
# rows); any other criterion receives the label rows themselves.
use_class_targets = isinstance(criterion,nn.CrossEntropyLoss)

with open(training_log_dir+decaymode+'_TrainingOutputs_'+str(args.NUM_CLASSES)+'.txt','wb') as f:
    for i in range(epoch):
        model.train()
        train_loss,sample_numbers,acc=0,0,0
        epoch_tr_pred,epoch_tr_true=[],[]
        for batch_idx, (rhorho_s,rhorho_t,label_s,_) in enumerate(train_loader):
            # Drop events whose label row sums to zero.
            keep=label_s.sum(axis=1)!=0
            rhorho_s=rhorho_s[keep]
            label_s=label_s[keep]
            # BatchNorm1d in training mode needs more than one row per batch.
            if len(rhorho_s)<2:
                continue

            opt_g.zero_grad()
            # Single forward pass, reused for the loss and the stored predictions.
            outputs=model(rhorho_s)
            if use_class_targets:
                targets=torch.argmax(label_s,axis=1)
                loss=criterion(outputs,targets)
                _, predictions = torch.max(outputs, 1)
                acc+=(predictions==targets).sum().item()
            else:
                loss=criterion(outputs,label_s)
            loss.backward()
            opt_g.step()
            train_loss+=loss.item()*len(rhorho_s)
            sample_numbers+=len(rhorho_s)
            epoch_tr_pred.extend(torch.softmax(outputs.detach(),axis=1).cpu().numpy())
            epoch_tr_true.extend(label_s.cpu().numpy())
        tr_pred,tr_true=epoch_tr_pred,epoch_tr_true

        print('\r training loss: %.3f \t acc: %.3f \t' %(train_loss/sample_numbers,acc/sample_numbers),end='')
        training_loss.append(train_loss/sample_numbers)
        f.write(('\r training loss: %.3f \t acc: %.3f \t ' %(train_loss/sample_numbers,acc/sample_numbers)).encode())
        print()

        vaild_loss,vaild_acc,vaild_numbers=0,0,0
        epoch_v_pred,epoch_v_true=[],[]
        model.eval()
        with torch.no_grad():
            for batch_idx, (rhorho_s,rhorho_t,label_s,label_t) in enumerate(valid_loader):
                # Validation uses the 'true' half of each batch.
                keep=label_t.sum(axis=1)!=0
                rhorho_t=rhorho_t[keep]
                label_t=label_t[keep]
                if len(rhorho_t)==0:
                    continue
                outputs=model(rhorho_t)
                targets=torch.argmax(label_t,axis=1)

                _, predictions = torch.max(outputs, 1)
                vaild_acc+=(predictions==targets).sum().item()
                vaild_numbers+=len(rhorho_t)
                if use_class_targets:
                    loss=criterion(outputs,targets)
                else:
                    loss=criterion(outputs,label_t)
                vaild_loss+=loss.item()*len(rhorho_t)
                epoch_v_pred.extend(torch.softmax(outputs,axis=1).cpu().numpy())
                epoch_v_true.extend(label_t.cpu().numpy())
        v_pred,v_true=epoch_v_pred,epoch_v_true

        print()
        print('\r validation loss: %.3f \t valid acc: %.3f \t ' %(vaild_loss/vaild_numbers,vaild_acc/vaild_numbers),end='')
        f.write(('\r validation loss: %.3f \t valid acc: %.3f \t ' %(vaild_loss/vaild_numbers,vaild_acc/vaild_numbers)).encode())
        print()
        validation_loss.append(vaild_loss/vaild_numbers)
        #early_stopping(-vaild_acc/vaild_numbers,model)
        #if early_stopping.early_stop:
            #print("Early stopping")
            #f.write(("Early stopping").encode())
            #break;
            # test_loss=0
    # with torch.no_grad():
    #     for inputs, label in test_loader:
    #         outputs=model(inputs)
    #         test_loss+=mse_loss(outputs,label).item()*len(inputs)
    #     print('test loss: %f' %(test_loss/len(test_loader.dataset.tensors[0])))


In [None]:
# Per-epoch training and validation loss curves, saved as a PDF.
colors=['skyblue','orange']
plt.plot(training_loss, color=colors[0],label='training loss')
plt.plot(validation_loss, color=colors[1],label='validation_loss')
plt.legend(loc='best')
plt.ylabel('loss')
plt.xlabel('epoch')
plt.title('loss vs epoch (signal background classification) with epoches = 200')
TestResults_path = os.getcwd() + '/Results/TestResults/'
# makedirs also creates the parent 'Results' directory and tolerates an
# already existing target, unlike a bare os.mkdir.
os.makedirs(TestResults_path, exist_ok=True)
plt.savefig(TestResults_path+decaymode+'_loss.pdf')
plt.show()

# confusion matrix

## training

In [None]:
import seaborn as sns

In [None]:
# Reduce each stored label row / probability row to its class index.
Trues_train, prediction_train = (
    np.asarray(rows).argmax(axis=1) for rows in (tr_true, tr_pred)
)

In [None]:
# Raw counts, then each row divided by its total so rows sum to 1.
cf_matrix_train = confusion_matrix(y_true=Trues_train, y_pred=prediction_train)
row_totals_train = cf_matrix_train.sum(axis=1, keepdims=True)
cf_matrix_train_n = cf_matrix_train.astype(float) / row_totals_train

In [None]:
# Heatmap of the row-normalised training confusion matrix, saved as a PDF.
df_cm_train_n = pd.DataFrame(cf_matrix_train_n)
fig, ax = plt.subplots(figsize=(12, 8), dpi=100)
sns.heatmap(df_cm_train_n, annot=True, fmt='.2f', square=True, linewidths=.5, cmap="YlGnBu", ax=ax)
# Axis labels are set AFTER drawing: seaborn's heatmap assigns the x/y labels
# itself from the (unnamed) DataFrame index/columns, which would blank labels
# set beforehand.
ax.set_ylabel('True')
ax.set_xlabel('Predicted')
TestResults_path = os.getcwd() + '/Results/TestResults/'
# makedirs also creates the parent 'Results' directory if it is missing.
os.makedirs(TestResults_path, exist_ok=True)
fig.savefig(TestResults_path+decaymode+'_train_confusion_matrix_norm.pdf')

## validation

In [None]:
# Reduce each stored label row / probability row to its class index.
Trues_valid, prediction_valid = (
    np.asarray(rows).argmax(axis=1) for rows in (v_true, v_pred)
)

In [None]:
# Raw counts, then each row divided by its total so rows sum to 1.
cf_matrix_valid = confusion_matrix(y_true=Trues_valid, y_pred=prediction_valid)
row_totals_valid = cf_matrix_valid.sum(axis=1, keepdims=True)
cf_matrix_valid_n = cf_matrix_valid.astype(float) / row_totals_valid

In [None]:
# Heatmap of the row-normalised validation confusion matrix, saved as a PDF.
df_cm_valid_n = pd.DataFrame(cf_matrix_valid_n)
fig, ax = plt.subplots(figsize=(12, 8), dpi=100)
sns.heatmap(df_cm_valid_n, annot=True, fmt='.2f', square=True, linewidths=.5, cmap="YlGnBu", ax=ax)
# Axis labels are set AFTER drawing: seaborn's heatmap assigns the x/y labels
# itself from the (unnamed) DataFrame index/columns, which would blank labels
# set beforehand.
ax.set_ylabel('True')
ax.set_xlabel('Predicted')
TestResults_path = os.getcwd() + '/Results/TestResults/'
# makedirs also creates the parent 'Results' directory if it is missing.
os.makedirs(TestResults_path, exist_ok=True)
fig.savefig(TestResults_path+decaymode+'_valid_confusion_matrix_norm.pdf')

# Preprocessing signal and bkgd from all the decaymodes

In [None]:
# Build and cache a dataset from the signal directory with Z_NOISE_FRACTION = 1.
# Look the event class up directly instead of eval()-ing its name.
event_classes = {'nn_rhorho': RhoRhoEvent, 'nn_a1rho': A1RhoEvent, 'nn_a1a1': A1A1Event}
if args.REPRO:
    # NOTE: args is mutated in place; args.LABEL_BKGD keeps whatever value an
    # earlier cell left in it.
    args.Z_NOISE_FRACTION = 1
    args.IN = 'HiggsCP_data/'+decaymode
    args.TYPE = 'nn_'+decaymode
    data, weights, argmaxs, perm, c012s, hits_argmaxs, hits_c012s = preprocess_data(args)
    event = event_classes[args.TYPE](data, args)
    points = EventDatasets(event, weights, argmaxs, perm, c012s=c012s, hits_argmaxs=hits_argmaxs,  hits_c012s=hits_c012s, miniset=args.MINISET, unweighted=args.UNWEIGHTED)
    # Context manager guarantees the pickle is flushed and the handle closed.
    with open(args.IN+'/events_w_background21_test.pk','wb') as pk_file:
        pickle.dump(points, pk_file)

## signal background classification

In [None]:
# Build and cache the background test dataset (unlabelled background,
# Z_NOISE_FRACTION = 1) from the '<decaymode>_bkgd' directory.
# Look the event class up directly instead of eval()-ing its name.
event_classes = {'nn_rhorho': RhoRhoEvent, 'nn_a1rho': A1RhoEvent, 'nn_a1a1': A1A1Event}
if args.REPRO:
    # NOTE: args is mutated in place; later cells see these values.
    args.LABEL_BKGD = False
    args.Z_NOISE_FRACTION = 1
    args.IN = 'HiggsCP_data/'+decaymode+'_bkgd'
    args.TYPE = 'nn_'+decaymode
    data, weights, argmaxs, perm, c012s, hits_argmaxs, hits_c012s = preprocess_data(args)
    event = event_classes[args.TYPE](data, args)
    points = EventDatasets(event, weights, argmaxs, perm, c012s=c012s, hits_argmaxs=hits_argmaxs,  hits_c012s=hits_c012s, miniset=args.MINISET, unweighted=args.UNWEIGHTED)
    # Context manager guarantees the pickle is flushed and the handle closed.
    with open(args.IN+'/events_w_background_test.pk','wb') as pk_file:
        pickle.dump(points, pk_file)

# Loading bkgd samples

In [None]:
# NOTE(review): no cell visible in this notebook writes 'events_w_background21.pk'
# under args.IN -- confirm the file is produced elsewhere before relying on it.
# The file handle is closed as soon as the pickle has been read.
with open(args.IN+'/events_w_background21.pk','rb') as pk_file:
    background_points = pickle.load(pk_file)

## signal background classification

In [None]:
# Load the cached background test dataset; the file handle is closed as soon
# as the pickle has been read.
with open(args.IN+'/events_w_background_test.pk','rb') as pk_file:
    background_points = pickle.load(pk_file)

## loading

In [None]:
# From each split keep only the events whose weight row sums to zero; these
# are the rows this notebook treats as background.
background = [
    split.x[split.weights.sum(axis=1)==0]
    for split in (background_points.train, background_points.valid, background_points.test)
]

In [None]:
# Stack the per-split arrays into one array along the event axis.
# NOTE: this rebinds `background`, so the cell is not meant to be run twice.
background = np.concatenate(background, axis=0)
print(background.shape)

In [None]:
# Copy the stacked background features into a float32 tensor on the compute device.
background = torch.tensor(background, dtype=torch.float32, device=device)

# Testing NN w/ bkgd only

In [None]:
# Score background rows 500000..999999 in mini-batches and record, per event,
# the class with the highest softmax probability.
model.eval()
with torch.no_grad():
    batch_starts = range(500000, 1000000, batch_size)
    outputs = [
        model(background[start:start+batch_size]).detach().cpu()
        for start in tqdm(batch_starts)
    ]
outputs = torch.cat(outputs)

bg_probabilities = torch.softmax(outputs, axis=1)
bg_outputs = bg_probabilities.argmax(axis=1).numpy()
# Number of events per predicted class (only classes that actually occur).
bg_classes_seen, bg_labels_counts = np.unique(bg_outputs, return_counts=True)

In [None]:
# Cache the predicted background classes; the context manager flushes and
# closes the file.
with open(args.IN+'/NN_outputs_background_only.pk','wb') as pk_file:
    pickle.dump(bg_outputs, pk_file)

# Testing NN w/ signal only (Class 0)

In [None]:
# Score the 'mc' half of every batch from train_loader, then keep only the
# events whose label row has its maximum at index 0.
model.eval()
signal_outputs, signal_labels = [], []
with torch.no_grad():
    for batch_idx, (rhorho_s, rhorho_t, label_s, _) in enumerate(train_loader):
        signal_labels.append(label_s.detach().cpu())
        signal_outputs.append(model(rhorho_s).detach().cpu())

all_probabilities = torch.softmax(torch.cat(signal_outputs), axis=1).numpy()
true_class = np.argmax(np.concatenate(signal_labels), axis=1)

# Select events whose true (label) class is 0; `signal_labels` is then rebound
# to the predicted class of those selected events.
signal_outputs = all_probabilities[true_class == 0]
signal_labels = signal_outputs.argmax(axis=1)

In [None]:
# Cache the softmax outputs of the selected signal events; the context manager
# flushes and closes the file.
with open(args.IN+'/NN_outputs_signal_only.pk','wb') as pk_file:
    pickle.dump(signal_outputs, pk_file)

# Test Results

In [None]:
# Reload the cached network outputs; each file handle is closed after reading.
with open(args.IN+'/NN_outputs_background_only.pk','rb') as pk_file:
    bg_outputs = pickle.load(pk_file)
with open(args.IN+'/NN_outputs_signal_only.pk','rb') as pk_file:
    signal_outputs = pickle.load(pk_file)

In [None]:
# Stacked histogram of predicted classes: Ztt background and signal, one
# DataFrame column each.
columns=['Ztt','Signal']
fig,ax=plt.subplots(dpi=150)

predicted_classes = (bg_outputs, np.argmax(signal_outputs,axis=1))
frames = [
    pd.DataFrame(values, columns=[label])
    for label, values in zip(columns, predicted_classes)
]
df = pd.concat(frames, axis=1)

# The x-range is fixed before plotting.
ax.set_xlim(0,args.NUM_CLASSES)
hist_axes = df.plot.hist(stacked=True, bins=args.NUM_CLASSES,ax=ax, color = ['skyblue','red'])
# NOTE: from here on `ax` refers to the Figure, which the save cell relies on.
ax = hist_axes.get_figure()
plt.xlabel("Classes")
plt.ylabel("Events")

In [None]:
# Save the stacked histogram. `ax` is expected to hold the Figure returned by
# get_figure() in the plotting cell, hence ax.savefig.
TestResults_path = os.getcwd() + '/Results/TestResults/'
# makedirs also creates the parent 'Results' directory if it is missing.
os.makedirs(TestResults_path, exist_ok=True)
ax.savefig(TestResults_path+decaymode+'_TestResults.pdf')

In [None]:
# Per-class event counts: signal_df has columns (Class, SgCounts), bkgd_df has
# (Class, BgCounts). Column names are literals so the cell does not depend on
# the current contents of `columns`.
signal_df = (
    pd.DataFrame(np.argmax(signal_outputs, axis=1), columns=['Signal'])
    .groupby('Signal').size().to_frame('SgCounts').reset_index()
    .rename({'Signal':'Class'},axis=1)
)
bkgd_df = (
    pd.DataFrame(bg_outputs, columns=['Ztt'])
    .groupby('Ztt').size().to_frame('BgCounts').reset_index()
    .rename({'Ztt':'Class'},axis=1)
)
# Join on the class label rather than on row position: when a class is
# predicted for only one of the two samples, a positional concat would pair
# counts that belong to different classes. Missing counts become 0.
total = (
    signal_df.merge(bkgd_df, on='Class', how='outer')
    .fillna(0)
    .astype({'SgCounts': int, 'BgCounts': int})
    .sort_values('Class')
    .reset_index(drop=True)
)

In [None]:
import seaborn as sns
import matplotlib.patches as mpatches

# Two stacked panels: top = signal stacked on background per class,
# bottom = background only, with sqrt(N) error bars.
fig, axis = plt.subplots(2,1,figsize=(6,10),dpi=150)
fig.suptitle( decaymode + " Test Results", fontsize=15)
stacked_panel, background_panel = axis
stacked_panel.set_title('Signal + Background Results', fontsize=11)
background_panel.set_title('Background Results with Error Bar', fontsize=11)
# seaborn theme is applied after the axes have been created
sns.set(style="dark")

total[['BgCounts','SgCounts']].plot(kind="bar", ax = stacked_panel,stacked=True,color = ['skyblue','r'])

# Legend handles are built by hand so the labels read 'Signal' / 'Ztt'.
top_bar = mpatches.Patch(color='r', label='Signal')
bottom_bar = mpatches.Patch(color='skyblue', label='Ztt')
stacked_panel.legend(handles=[top_bar, bottom_bar])

for panel in axis:
    panel.set_xlabel("Classes",fontsize=10)
    panel.set_ylabel("Events",fontsize=10)

# Error bar per class: square root of the background count.
Poisson_std = list(map(math.sqrt, total['BgCounts'].to_numpy()))
total[['BgCounts']].plot(kind="bar", ax = background_panel,stacked=True,color = ['skyblue','r'], yerr = Poisson_std, alpha = 1)
background_panel.legend(handles=[bottom_bar])

plt.show()

In [None]:
# Save the two-panel summary figure.
TestResults_path = os.getcwd() + '/Results/TestResults/'
# makedirs also creates the parent 'Results' directory if it is missing.
os.makedirs(TestResults_path, exist_ok=True)
fig.savefig(TestResults_path+decaymode+'_TestResults2.pdf')

In [None]:
# Histogram of the predicted classes of the selected signal events only.
# NOTE: this rebinds `columns` to a single-entry list.
columns=['Signal']
fig,ax=plt.subplots(dpi=150)

signal_classes = np.argmax(signal_outputs,axis=1)
df = pd.concat([pd.DataFrame(signal_classes, columns=[columns[0]])], axis=1)
# The x-range is fixed before plotting.
ax.set_xlim(0,args.NUM_CLASSES)
hist_axes = df.plot.hist(stacked=True, bins=args.NUM_CLASSES,ax=ax, color = ['red'])
# NOTE: from here on `ax` refers to the Figure, which the save cell relies on.
ax = hist_axes.get_figure()
plt.xlabel("Classes")
plt.ylabel("Events")

In [None]:
# Save the signal-only histogram. `ax` is expected to hold the Figure returned
# by get_figure() in the plotting cell, hence ax.savefig.
TestResults_path = os.getcwd() + '/Results/TestResults/'
# makedirs also creates the parent 'Results' directory if it is missing.
os.makedirs(TestResults_path, exist_ok=True)
ax.savefig(TestResults_path+decaymode+'_TestResults3.pdf')

In [None]:
# Histogram of the predicted classes of the background events only, with
# fixed axis ranges.
columns=['Ztt','Signal']
fig,ax=plt.subplots(dpi=150)

df = pd.concat([pd.DataFrame(bg_outputs, columns=[columns[0]])], axis=1)
# Axis limits are fixed before plotting.
ax.set_xlim(1,args.NUM_CLASSES-1)
ax.set_ylim(0,1300)
hist_axes = df.plot.hist(stacked=True, bins=args.NUM_CLASSES,ax=ax, color = ['skyblue'])
# NOTE: from here on `ax` refers to the Figure, which the save cell relies on.
ax = hist_axes.get_figure()
plt.xlabel("Classes")
plt.ylabel("Events")

In [None]:
# Save the background-only histogram. `ax` is expected to hold the Figure
# returned by get_figure() in the plotting cell, hence ax.savefig.
TestResults_path = os.getcwd() + '/Results/TestResults/'
# makedirs also creates the parent 'Results' directory if it is missing.
os.makedirs(TestResults_path, exist_ok=True)
ax.savefig(TestResults_path+decaymode+'_TestResults4.pdf')

In [None]:
# Per-class counts as plain text: background first, then signal.
for counts_frame in (bkgd_df, signal_df):
    print(counts_frame)

## confusion matrix for testing

In [None]:
# Collect softmax outputs and label rows for the 'true' half of the test set.
t_pred = []
t_true = []
model.eval()
with torch.no_grad():
    for batch_idx, (rhorho_s,rhorho_t,label_s,label_t) in enumerate(test_loader):
        # Same selection as the training/validation loops: drop events whose
        # label row sums to zero, which argmax would otherwise count as true
        # class 0.
        keep = label_t.sum(axis=1)!=0
        rhorho_t = rhorho_t[keep]
        label_t = label_t[keep]
        if len(rhorho_t) == 0:
            continue

        outputs = torch.softmax(model(rhorho_t), axis=1).cpu().numpy()
        t_pred.extend(outputs)

        labels = label_t.cpu().numpy()
        t_true.extend(labels)
     

In [None]:
# Reduce each stored label row / probability row to its class index.
Trues_test, prediction_test = (
    np.asarray(rows).argmax(axis=1) for rows in (t_true, t_pred)
)

In [None]:
# Raw counts, then each row divided by its total so rows sum to 1.
cf_matrix_test = confusion_matrix(y_true=Trues_test, y_pred=prediction_test)
row_totals_test = cf_matrix_test.sum(axis=1, keepdims=True)
cf_matrix_test_n = cf_matrix_test.astype(float) / row_totals_test

In [None]:
# Heatmap of the row-normalised test confusion matrix, saved as a PDF.
df_cm_test_n = pd.DataFrame(cf_matrix_test_n)
fig, ax = plt.subplots(figsize=(12, 8), dpi=100)
sns.heatmap(df_cm_test_n, annot=True, fmt='.2f', square=True, linewidths=.5, cmap="YlGnBu", ax=ax)
# Axis labels are set AFTER drawing: seaborn's heatmap assigns the x/y labels
# itself from the (unnamed) DataFrame index/columns, which would blank labels
# set beforehand.
ax.set_ylabel('True')
ax.set_xlabel('Predicted')
TestResults_path = os.getcwd() + '/Results/TestResults/'
# makedirs also creates the parent 'Results' directory if it is missing.
os.makedirs(TestResults_path, exist_ok=True)
fig.savefig(TestResults_path+decaymode+'_test_confusion_matrix_norm.pdf')