In [None]:
# Install the packages this notebook relies on. Most are pinned to an exact
# version; torch_geometric, particle, pennylane, torchquantum, dgl, dglgo and
# energyflow are installed unpinned.
from IPython.display import clear_output
!pip install torch==2.2.0
!pip install torch_geometric
!pip install particle
!pip install pennylane
!pip install torchdata==0.7.1
!pip install torchvision==0.17.0
!pip install qiskit==0.46.0
!pip install torchquantum
!pip install qiskit-ibm-runtime==0.18.0
!pip install qiskit-aer==0.13.2
# DGL wheels are resolved from the cu121 wheel index given by -f.
!pip install dgl -f https://data.dgl.ai/wheels/cu121/repo.html
!pip install dglgo -f https://data.dgl.ai/wheels-test/repo.html
!pip install energyflow
# Wipe the pip logs from the cell output once everything is installed.
clear_output()

In [None]:
# Show the installed packages and their versions.
!pip list

In [None]:
import torch
import torchdata
import torch.nn as nn
import torch.nn.functional as F
import torchquantum as tq
from torchquantum.layer.entanglement.op2_layer import Op2QAllLayer
from torchquantum.layer.layers.layers import Op1QAllLayer, Op2QAllLayer
from torchquantum.measurement import measure

import numpy as np
import os
from tqdm import tqdm
import scipy
import warnings
import dgl
from dgl.data import DGLDataset
from dgl.dataloading import GraphDataLoader
from torch.utils.data.sampler import SubsetRandomSampler

from sklearn import metrics
from sklearn.preprocessing import normalize

import scipy.sparse as sp
import csv
import time
import pandas as pd
from collections import OrderedDict
from functools import partial
import pickle
import multiprocessing
import joblib

import torch_geometric
from torch_geometric.nn import global_mean_pool
from torch_geometric.utils import add_self_loops, degree, softmax
import torch.optim as optim

from copy import deepcopy
import gc

from particle import Particle
import pennylane as qml
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
### Download the jets and cache them in a folder so that they do not have to be downloaded again (uncomment to run)
# main_dir = ''

# import energyflow
# data = energyflow.qg_jets.load(num_data=20000, pad=True, ncol=4, generator='pythia',
#                         with_bc=False, cache_dir=main_dir+'/energyflow')

In [None]:
# Path of the jet data file. It is an empty placeholder here and has to be set
# before running; the same variable is later passed to the dataset class as
# `datafile_name` / `labelsfile_name`.
jet_file_path = ''
# Load the file once at notebook level.
data = np.load(jet_file_path)

In [None]:
### Reference - https://github.com/ML4SCI/QMLHEP/blob/main/Quantum_GNN_for_HEP_Roy_Forestano/utils/preprocess.py

def preprocess_fixed_nodes(x_data,y_data,nodes_per_graph=10): #,masses):
    """Turn padded jet arrays into fixed-size, fully connected graph tensors.

    Args:
        x_data: array indexed as (jet, particle, feature). Columns 0..3 are
            read as pt, rapidity, phi and PDG id; any columns from index 4 on
            are carried through unchanged.
        y_data: per-jet labels, indexable along the same first axis.
        nodes_per_graph: number of leading particles used as graph nodes; also
            the minimum multiplicity a jet needs to receive edges.

    Returns:
        Tuple ``(x, y, edge_tensor, edge_index, edge_attr, graph_help, masses)``.
        Every element is restricted to the jets that were kept (jets whose
        number of non-zero edges is at least ``nodes_per_graph``):
          * x          -- node features (pt, rapidity, phi, extra columns, mass,
                          energy, px, py, pz), each divided by its global
                          maximum, truncated to ``nodes_per_graph`` nodes.
          * y          -- labels cast to ``long``.
          * edge_tensor-- dense (nodes, nodes) distance matrix per jet, divided
                          by the global maximum.
          * edge_index -- (2, nodes*(nodes-1)) ``long`` source/target indices,
                          zero padded if fewer edges exist.
          * edge_attr  -- (nodes*(nodes-1), 1) edge weights, divided by the
                          global maximum.
          * graph_help -- per-jet summary: leading-energy fraction, jet mass,
                          multiplicity, edge count.
          * masses     -- per-particle masses in GeV (all particle columns).
    """
    print('--- Finding All Unique Particles ---')
    unique_particles = np.unique(x_data[:,:,3])
    x_data = torch.tensor(x_data)
    y_data = torch.tensor(y_data)
    print()
    print('--- Inserting Masses ---')
    masses = torch.zeros((x_data.shape[0],x_data.shape[1]))
    for particle in tqdm(unique_particles):
        # PDG id 0 marks zero padding and has no mass entry.
        if particle!=0:
            mass = Particle.from_pdgid(particle).mass/1000  # MeV -> GeV
            inds = torch.where(particle==x_data[:,:,3])
            masses[inds]=mass # GeV
    print()
    print('--- Calculating Momenta and Energies ---')
    pt        = x_data[:,:,0]     # transverse momentum
    rapidity  = x_data[:,:,1]     # rapidity
    phi       = x_data[:,:,2]     # azimuthal angle

    mt        = (pt**2+masses**2).sqrt() # Transverse mass
    energy    = mt*torch.cosh(rapidity) # Energy per multiplicity bin
    e_per_jet = energy.sum(axis=1)  # total energy per jet summed across multiplicity bins

    px = pt*torch.cos(phi)  # momentum in x
    py = pt*torch.sin(phi)  # momentum in y
    pz = mt*torch.sinh(rapidity)  # momentum in z

    # three momentum
    p  = torch.cat(( px[:,:,None],
                     py[:,:,None],
                     pz[:,:,None]), dim=2 )

    p_per_jet        = (p).sum(axis=1)  # total component momentum per jet
    mass_per_jet     = (e_per_jet**2-(p_per_jet**2).sum(axis=1)).sqrt() # mass per jet
    end_multiplicity_indx_per_jet = (pt!=0).sum(axis=1).int() # see where the jet (graph) ends

    x_data = torch.cat( ( x_data[:,:,:3],
                          x_data[:,:,4:],
                          masses[:,:,None],
                          energy[:,:,None],
                          p), dim=2)

    # Scale every feature column by its maximum over all jets and particles.
    x_data_max = (x_data.max(dim=1).values).max(dim=0).values
    x_data = x_data/x_data_max

    print()
    print('--- Calculating Edge Tensors ---')
    N = x_data[:,0,3].shape[0]  # number of jets (graphs)
    M = nodes_per_graph
    connections = nodes_per_graph
    edge_tensor = torch.zeros((N,M,M))
    # connections*(connections-1) is the number of directed edges of a fully
    # connected graph without self loops.
    edge_indx_tensor = torch.zeros((N,2,connections*(connections-1) ))
    edge_attr_matrix = torch.zeros((N,connections*(connections-1),1))

    for jet in tqdm(range(N)):
        # Jets with fewer real particles than requested nodes get no edges and
        # are filtered out through keep_inds below.
        if end_multiplicity_indx_per_jet[jet]>=connections:
            for m in range(connections):
                # Edge weight: distance in the (phi, rapidity) plane.
                # NOTE(review): the phi difference is not wrapped to [-pi, pi];
                # confirm this is intended.
                edge_tensor[jet,m,:] = torch.sqrt( (phi[jet,m]-phi[jet,:connections])**2 + (rapidity[jet,m]-rapidity[jet,:connections])**2 )
            # Only pairs with a strictly positive distance become edges, which
            # drops the diagonal (and any coincident particles).
            edges_exist_at = torch.where(edge_tensor[jet,:,:].abs()>0)
            n_edges = edges_exist_at[0].shape[0]

            edge_indx_tensor[jet,:,:n_edges] = torch.stack((edges_exist_at[0],edges_exist_at[1]),dim=0)
            edge_attr_matrix[jet,:n_edges,0]  =  edge_tensor[jet,edges_exist_at[0],edges_exist_at[1]]

    end_edges_indx_per_jet = (edge_attr_matrix!=0).sum(axis=1).int()
    keep_inds =  torch.where(end_edges_indx_per_jet>=connections)[0]

    edge_tensor = edge_tensor/edge_tensor.max()
    edge_attr_matrix = edge_attr_matrix/edge_attr_matrix.max()

    graph_help = torch.cat( ( (energy.max(axis=1).values/e_per_jet).reshape(N,1),
                              (mass_per_jet).reshape(N,1),
                              (end_multiplicity_indx_per_jet).reshape(N,1).int(),
                              (end_edges_indx_per_jet).reshape(N,1).int() ), dim=1)

    # `masses` is filtered with keep_inds like every other output so that row i
    # of each returned tensor refers to the same jet.
    return x_data[keep_inds,:nodes_per_graph], y_data[keep_inds].long(), edge_tensor[keep_inds], edge_indx_tensor[keep_inds].long(), edge_attr_matrix[keep_inds], graph_help[keep_inds], masses[keep_inds]

In [None]:
### Reference - https://github.com/bmdillon/JetCLR/blob/main/scripts/modules/jet_augs.py

def distort_jets( batch, strength=0.1, pT_clip_min=0.1 ):
    '''
    Smear the angular position of every constituent independently.

    Input: batch of jets, shape (batchsize, 3, n_constit), dim 1 ordered as (pT, eta, phi)
    Output: array of the same shape; eta and phi receive Gaussian shifts with
    mean 0 and std strength/pT (pT clipped from below at pT_clip_min), pT is unchanged
    '''
    n_jets, n_constit = batch.shape[0], batch.shape[2]
    clipped_pT = batch[:,0].clip(min=pT_clip_min)   # (batchsize, n_constit)

    def _draw_shift():
        # One independent draw per constituent; infinities are mapped to 0.
        noise = strength * np.random.randn(n_jets, n_constit) / clipped_pT
        return np.nan_to_num( noise, posinf = 0.0, neginf = 0.0 )

    # Draw order (eta first, then phi) is kept so seeded runs are reproducible.
    shift_eta = _draw_shift()
    shift_phi = _draw_shift()
    no_pT_shift = np.zeros( (n_jets, n_constit) )
    return batch + np.stack( [ no_pT_shift, shift_eta, shift_phi ], 1)

def collinear_fill_jets( batch ):
    '''
    Fill zero-padded constituent slots with collinear splittings.

    Input: batch of jets, shape (batchsize, 3, n_constit), dim 1 ordered as (pT, eta, phi).
    Constituents with pT > 0 are assumed to occupy the leading slots of each jet,
    with the zero padding after them.
    Output: new array of the same shape (the input is not modified). For each jet,
    min(n_real, n_constit - n_real) distinct real constituents are split once:
    the parent keeps a random fraction r of its pT and a new constituent with the
    remaining (1-r) and the parent's eta/phi is written into the next free slot.

    Only real constituents are chosen as parents. Sampling from the padded
    slots as well could overwrite a freshly created child with zeros, and for
    odd n_constit it asked for more samples than exist.
    '''
    batchb = batch.copy()
    nc = batch.shape[2]
    # number of real (pT > 0) constituents per jet
    nzs = ( batch[:,0,:] > 0.0 ).sum( axis=1 )

    for k in range(len(batch)):
        n_real = int( nzs[k] )
        n_split = min( n_real, nc - n_real )
        if n_split == 0:
            # jet is either empty or already full: nothing to split
            continue
        els = np.random.choice( n_real, size=n_split, replace=False )
        rs = np.random.uniform( size=n_split )
        for j in range(n_split):
            parent = int( els[j] )
            child = n_real + j
            batchb[k,0,parent] = rs[j]*batch[k,0,parent]
            batchb[k,0,child] = (1-rs[j])*batch[k,0,parent]
            batchb[k,1,child] = batch[k,1,parent]
            batchb[k,2,child] = batch[k,2,parent]

    return batchb

In [None]:
class QuarkGluonGraphDataset(dgl.data.dgl_dataset.DGLDataset):
  """DGL dataset of quark/gluon jets represented as fixed-size particle graphs.

  Jets are read from ``<raw_dir>/<data_folder_name>/<datafile_name>`` (keys
  ``'X'`` and ``'y'``). The first ``dataset_size`` jets with at least
  ``nodes_per_graph`` non-empty particles are kept, sorted by descending pT and
  truncated to ``nodes_per_graph`` particles. ``preprocess_fixed_nodes`` then
  provides node features and edges, and one ``dgl`` graph is built per jet.

  Node data: ``node_attr``, ``node_indices``, ``node_mass`` and, when
  ``irc_safety_aug`` is set, ``node_attr_irc`` (collinear-split and smeared
  copy of ``node_attr``). Edge data: ``edge_attr``.

  Items are ``(graph, label)``, or ``(graph, spectral_graph, label)`` when
  ``spectral_augmentation`` is enabled.
  """

  def __init__(self, dataset_name, raw_dir, save_dir, data_folder_name, datafile_name, labelsfile_name, datatype='particles', dataset_size=12500,
               nodes_per_graph = 5, spectral_augmentation=False, irc_safety_aug=False, url=None, hash_key=..., force_reload=False, verbose=False, transform=None,
              device='cpu'):
    """Store the configuration and hand over to ``DGLDataset``.

    Args:
      dataset_name, url, raw_dir, save_dir, hash_key, force_reload, verbose,
        transform: forwarded unchanged to ``DGLDataset.__init__``.
      data_folder_name: sub-folder (below ``raw_dir`` / ``save_dir``) holding the data.
      datafile_name: file inside the raw folder that contains ``X`` and ``y``.
      labelsfile_name: stored on the instance; not read by this class.
      datatype: only ``'particles'`` triggers graph construction.
      dataset_size: number of jets to collect.
      nodes_per_graph: particles (nodes) per jet graph.
      spectral_augmentation: additionally build ``SpectralGraph`` views.
      irc_safety_aug: additionally compute ``node_attr_irc``.
      device: device the graphs are moved to.
    """
    # All attributes are assigned before super().__init__() because process()
    # reads them (DGLDataset is expected to run its load/process pipeline
    # from its constructor).
    self.data_folder = data_folder_name
    self.datafile_name = datafile_name
    self.labelsfile_name = labelsfile_name
    self.datatype = datatype
    self.nodes_per_graph = nodes_per_graph
    self.spectral_augmentation = spectral_augmentation
    self.drop_ra_nodes = False
    self.drop_cp_nodes = False
    self.aug_ratio = None
    self.irc_safety_aug = irc_safety_aug
    self.device = device
    self.dataset_size = dataset_size
    self.augment = False
    self.nodes_per_aug_graph = None
    super().__init__(dataset_name, url, raw_dir, save_dir, hash_key, force_reload, verbose, transform)

  @property
  def data_folder_name(self):
    """Sub-folder name shared by the raw and the save location."""
    return self.data_folder

  @property
  def raw_path(self):
    """Folder the raw data file is read from."""
    return os.path.join(self.raw_dir, self.data_folder_name)

  @property
  def save_path(self):
    """Folder used for cached graphs."""
    return os.path.join(self.save_dir, self.data_folder_name)

  @property
  def graph_path(self):
    """File holding the cached graphs and labels."""
    return os.path.join(self.save_path, 'graphs_and_labels')

  @property
  def info_path(self):
    """File holding the cached dataset info.

    NOTE(review): this is the same path as ``graph_path``. No ``save`` method
    currently writes either file; give the two distinct names before
    re-enabling caching.
    """
    return os.path.join(self.save_path, 'graphs_and_labels')

  def load(self):
    """Restore graphs, labels and info from the cache files."""
    graphs, label_dict = dgl.load_graphs(str(self.graph_path))
    info_dict = dgl.data.utils.load_info(str(self.info_path))

    self.graph_lists = graphs
    self.graph_labels = label_dict["labels"]
    self.max_num_node = info_dict["max_num_node"]
    self.num_labels = info_dict["num_labels"]

  # def save(self,):
  #   label_dict = {"labels": self.graph_labels}
  #   info_dict = {
  #           "max_num_node": self.max_num_node,
  #           "num_labels": self.num_labels,
  #       }
  #   dgl.save_graphs(str(self.graph_path), self.graph_lists, label_dict)
  #   dgl.data.utils.save_info(str(self.info_path), info_dict)

  def process(self,):
    """Read the raw jets and build the graph lists.

    Raises:
      ValueError: if the file holds fewer than ``dataset_size`` jets with at
        least ``nodes_per_graph`` non-empty particles.
    """
    data = np.load(os.path.join(self.raw_path, self.datafile_name))
    X = data['X']
    y = data['y']
    X_l, y_l = [], []

    # Walk the jets in file order and stop as soon as enough were collected.
    # Iterating over the data (instead of an unbounded index) prevents running
    # past the end of the array when too few jets qualify.
    for jet, label in zip(X, y):
      if len(X_l) == self.dataset_size:
        break
      # A particle row counts as present when its feature sum is non-zero.
      if np.count_nonzero(jet.sum(axis=1)) >= self.nodes_per_graph:
        sorted_inds = np.argsort(jet[:,0])[::-1]  # descending pT
        X_l.append(jet[sorted_inds][:self.nodes_per_graph, :])
        y_l.append(label)
    if len(X_l) != self.dataset_size:
      raise ValueError(
        f'Requested dataset_size={self.dataset_size} but only {len(X_l)} jets '
        f'with at least {self.nodes_per_graph} particles are available.')
    X = np.array(X_l)
    y = np.array(y_l)


    if self.datatype == 'particles':
      self.graph_lists = []
      self.rationale_augmented_graph_lists_1 = []
      self.rationale_augmented_graph_lists_2 = []
      self.complement_augmented_graph_lists = []
      x_data_proc, y_data_proc, edge_tensor, edge_indx_tensor, edge_attr_matrix, graph_help, masses = preprocess_fixed_nodes(X,y,nodes_per_graph = self.nodes_per_graph) #,masses[:N])
      self.max_num_node = x_data_proc.shape[1]
      self.graph_labels = y_data_proc
      # Number of distinct classes (what `num_classes` reports), not the
      # number of samples.
      self.num_labels = int(torch.unique(y_data_proc).numel())

      print('--- Creating graphs ---')
      for i in tqdm(range(x_data_proc.shape[0])):
        # num_nodes is given explicitly so the node count never depends on
        # which node ids happen to appear in the edge list.
        g = dgl.graph((edge_indx_tensor[i][0], edge_indx_tensor[i][1]), num_nodes=x_data_proc[i].shape[0])
        g.ndata['node_attr'] = x_data_proc[i]
        g.ndata['node_indices'] = torch.arange(x_data_proc[i].shape[0]).reshape(-1,1)
        g.ndata['node_mass'] = masses[i][:self.nodes_per_graph]
        g.edata['edge_attr'] = edge_attr_matrix[i].view(-1,)
        # DGLGraph.to returns a new graph; the result has to be kept.
        g = g.to(self.device)
        self.graph_lists.append(g)
        # The augmented lists start out holding the very same graph objects.
        self.rationale_augmented_graph_lists_1.append(g)
        self.rationale_augmented_graph_lists_2.append(g)
        self.complement_augmented_graph_lists.append(g)

      if self.spectral_augmentation:
        self.spectral_graph_lists = []
        print('--- Creating spectral graphs ---')
        for i in tqdm(range(x_data_proc.shape[0])):
          g = SpectralGraph((edge_indx_tensor[i][0], edge_indx_tensor[i][1]), theta=0.1, delta_origin=0.05, edge_weights_matrix=edge_tensor[i])
          g.ndata['node_attr'] = x_data_proc[i]
          g.edata['edge_attr'] = edge_attr_matrix[i].view(-1,)
          self.spectral_graph_lists.append(g)

      if self.irc_safety_aug:
        for g in self.graph_lists:
          node_attr_irc = g.ndata['node_attr'].clone()
          # The augmentations work on numpy arrays shaped (1, 3, n_nodes) with
          # rows (pT, rapidity, phi); go through the CPU and move the result
          # back to wherever the graph lives.
          kinematics = node_attr_irc[:,:3].T.unsqueeze(0).cpu().numpy()
          augmented = distort_jets(collinear_fill_jets(kinematics))
          node_attr_irc[:,:3] = torch.Tensor(augmented).squeeze(0).T.to(node_attr_irc.device)
          # Recompute the derived columns from the augmented kinematics.
          pt, rapidity, phi = node_attr_irc[:, 0], node_attr_irc[:, 1], node_attr_irc[:, 2]
          mt = (pt**2+g.ndata['node_mass']**2).sqrt()
          energy = mt*torch.cosh(rapidity)
          px, py, pz = pt*torch.cos(phi), pt*torch.sin(phi), mt*torch.sinh(rapidity)
          # NOTE(review): column 3 of node_attr holds the particle mass, but
          # the transverse mass is written here. Also, node_attr is scaled by
          # its column maxima while node_mass is not. Confirm both are intended.
          node_attr_irc[:,3] =  mt
          node_attr_irc[:,4] = energy
          node_attr_irc[:,5] = px
          node_attr_irc[:,6] = py
          node_attr_irc[:,7] = pz
          g.ndata['node_attr_irc'] = node_attr_irc

  def has_cache(self):
    """Return True when both cache files exist."""
    return os.path.exists(self.graph_path) and os.path.exists(self.info_path)

  def __len__(self,):
    """Number of graphs in the dataset."""
    return len(self.graph_lists)

  def augment_dataset(self, type, batched_graph, batch_size):
    """Create augmented views of a batched graph.

    ``type='rationale'`` returns two independent ``drop_nodes_prob_batch``
    views, ``type='complement'`` returns one ``drop_nodes_cp_batch`` view; any
    other value returns ``None``. Calling this also sets ``self.augment``.
    """
    self.augment = True

    if type == 'rationale':
      return drop_nodes_prob_batch(batched_graph, batch_size), drop_nodes_prob_batch(batched_graph, batch_size)

    if type == 'complement':
      return drop_nodes_cp_batch(batched_graph, batch_size)

  def __getitem__(self, idx):
    """Return ``(graph, label)`` or ``(graph, spectral_graph, label)``.

    The optional transform is applied to every returned graph.
    """
    if self.spectral_augmentation:
      g1 = self.graph_lists[idx]
      g2 = self.spectral_graph_lists[idx]
      if self._transform is not None:
        g1 = self._transform(g1)
        g2 = self._transform(g2)
      return g1, g2, self.graph_labels[idx]

    else:
      g = self.graph_lists[idx]
      if self._transform is not None:
        g = self._transform(g)
      return g, self.graph_labels[idx]

  @property
  def num_classes(self):
    """Number of distinct label classes."""
    return int(self.num_labels)

In [None]:
# Placeholders that have to be filled in: `main_dir` is used as the dataset's
# raw_dir (and later as the checkpoint folder), `jet_folder_path` is the data
# sub-folder below it.
main_dir = ''
jet_folder_path = ''

# Build the dataset: 10000 jets, 10 nodes per graph, IRC-safety augmentation
# enabled, graphs requested on 'cuda'. `jet_file_path` comes from an earlier
# cell and is used for both the data and the labels file name.
qg_dataset = QuarkGluonGraphDataset(dataset_name='Quark Gluon', raw_dir=main_dir, save_dir='/content',
                                    data_folder_name=jet_folder_path, datafile_name=jet_file_path, labelsfile_name=jet_file_path,
                                    datatype='particles', dataset_size=10000, nodes_per_graph=10, spectral_augmentation=False, irc_safety_aug=True,
                                   device='cuda')

In [None]:
# `run` is a module outside this notebook; it provides the `run.run` entry point.
import run
import torch.multiprocessing as mp

if __name__ == "__main__":
    # Number of worker processes to start.
    world_size = 4
    # Start `world_size` processes with the "spawn" start method, each calling run.run.
    # NOTE(review): the meaning of the arguments (50, 'ParticleNet', 10) is
    # defined in run.py and not visible here -- presumably epochs, model name
    # and nodes per graph; confirm against run.run's signature.
    mp.start_processes(run.run, args=(world_size, 50, 'ParticleNet', 10), nprocs=world_size, start_method="spawn")

In [None]:
import torch

# Read the checkpoint written under `main_dir`; weights_only=True restricts
# unpickling to tensors and plain containers.
checkpoint = torch.load(main_dir+'/particle_net_model.checkpoint', weights_only=True)

# Keep only the ParticleNet backbone entries and rename their prefix from
# 'module.gnn.pn.' to 'gnn.' so the keys match the standalone model.
trained_gnn_state = {
    name.replace('module.gnn.pn.', 'gnn.'): weights
    for name, weights in checkpoint.items()
    if 'module.gnn.pn' in name
}

In [None]:
from run import ParticleNetTagger1Path

# NOTE(review): the constructor arguments (8, 2) are presumably the number of
# input features per particle and the number of classes -- confirm in run.py.
gnn = ParticleNetTagger1Path(8,2)
# Strict loading: every key of the remapped checkpoint must match the model.
gnn.load_state_dict(trained_gnn_state)

In [None]:
# Nodes per jet graph; reused below when reshaping the batched node features.
nodes_per_graph_original = 10
# Evaluation dataset: 7000 jets, no augmentation, default device ('cpu').
test_dataset = QuarkGluonGraphDataset(dataset_name='Quark Gluon', raw_dir=main_dir, save_dir='/content',
                                    data_folder_name=jet_folder_path, datafile_name=jet_file_path, labelsfile_name=jet_file_path,
                                    datatype='particles', dataset_size=7000, nodes_per_graph=nodes_per_graph_original, spectral_augmentation=False)

In [None]:
# Indices 2000..6999 of `test_dataset`, visited in random order.
test_samples = torch.arange(2000, 7000, dtype=torch.int32)
test_sampler = SubsetRandomSampler(test_samples)

# Batches of 500 graphs; the last (possibly smaller) batch is kept.
test_dataloader = GraphDataLoader(
    test_dataset, sampler=test_sampler, batch_size=500, drop_last=False
)

In [None]:
# Run the frozen GNN over the evaluation loader and collect one mean-pooled
# embedding plus the label for every jet.
embedding_chunks = []
label_chunks = []

# Inference only: switch normalisation/dropout layers to evaluation behaviour
# and skip autograd bookkeeping.
gnn.eval()
with torch.no_grad():
  for batched_graph, labels in test_dataloader:
    # Self loops are added per graph and the batch is rebuilt afterwards.
    unbatched_graph = dgl.unbatch(batched_graph)
    batched_graph = dgl.batch([dgl.add_self_loop(graph) for graph in unbatched_graph])
    # Graph id of every node: 0,0,...,0,1,1,...  (nodes_per_graph entries each).
    batch_t = torch.arange(0, batched_graph.batch_size).repeat_interleave(test_dataset.nodes_per_graph)

    ## For custom GNN
    # cls_emb = gnn.forward(batched_graph.ndata["node_attr"].float(), torch.stack(batched_graph.edges()), batched_graph.edata["edge_attr"].float())

    ## For ParticleNet
    pf_feats = batched_graph.ndata["node_attr"].reshape(len(unbatched_graph), nodes_per_graph_original, -1).float()
    points = pf_feats[:,:,1:3]  # columns 1 and 2 of node_attr
    # NOTE(review): the reshape calls below swap the last two dimension sizes
    # without transposing the data, so this is not equivalent to
    # .transpose(1, 2). It is kept unchanged because the checkpoint may have
    # been trained with the same layout in run.py. Verify, then fix both.
    cls_emb = gnn.forward(points.reshape(points.shape[0], points.shape[2], points.shape[1])
                       , pf_feats.reshape(pf_feats.shape[0], pf_feats.shape[2], pf_feats.shape[1]), None)
    cls_emb = cls_emb.reshape(cls_emb.shape[0], cls_emb.shape[2], cls_emb.shape[1])
    cls_emb = cls_emb.reshape(cls_emb.shape[0]*cls_emb.shape[1], cls_emb.shape[2])

    # Average the node embeddings of each graph.
    cls_emb = global_mean_pool(cls_emb, batch_t)
    embedding_chunks.append(cls_emb.detach())
    label_chunks.append(labels)

# Concatenate once at the end instead of growing a tensor inside the loop.
cls_embds = torch.cat(embedding_chunks, 0) if embedding_chunks else torch.Tensor([])
cls_labels = torch.cat(label_chunks).float() if label_chunks else torch.Tensor([])

In [None]:
# Number of optimisation steps for the linear probe.
cls_epochs = 1000

# First 80% of the collected embeddings train the probe, the rest evaluates it.
split_idx = int(0.8*len(cls_embds))
cls_train_data, cls_test_data = cls_embds[:split_idx], cls_embds[split_idx:]
targets, testtargets = cls_labels[:split_idx], cls_labels[split_idx:]

In [None]:
### Reference - https://github.com/sdogsq/LorentzNet-release/blob/main/scripts/QGTaggingROC/ROC.py

# Function that takes the labels and the score of the positive class
# and returns a ROC curve, as well as the background efficiency (FPR)
# and signal efficiency (TPR) at each given target signal efficiency,
# which defaults to 0.3 and 0.5

def buildROC(labels, score, targetEff=[0.3,0.5]):
    """Compute a ROC curve and read it off at the requested signal efficiencies.

    Args:
        labels: true binary labels.
        score: score of the positive class for every sample.
        targetEff: target signal efficiency, either a single number or any
            sequence/array of numbers (list, tuple, ndarray). The default list
            is never modified.

    Returns:
        ``(fpr, tpr, threshold, eB, eS)``. The first three come from
        ``metrics.roc_curve``. ``eB`` / ``eS`` are the false / true positive
        rates at the ROC points whose TPR is closest to each target.
    """
    # np.atleast_1d accepts scalars, lists, tuples and arrays alike.
    target_effs = np.atleast_1d(targetEff)
    fpr, tpr, threshold = metrics.roc_curve(labels, score)
    idx = [np.argmin(np.abs(tpr - eff)) for eff in target_effs]
    eB, eS = fpr[idx], tpr[idx]
    return fpr, tpr, threshold, eB, eS

### Reference - https://github.com/bmdillon/JetCLR/blob/main/scripts/modules/perf_eval.py

def find_nearest( array, value ):
    """Return the element of `array` closest to `value` (first one on ties)."""
    candidates = np.asarray( array )
    distances = np.abs( candidates - value )
    return candidates[ np.argmin( distances ) ]

def get_perf_stats( labels, measures ):
    """Return the ROC AUC and the inverse mistag rate at 50% signal efficiency.

    Args:
        labels: true binary labels.
        measures: classifier scores; NaNs are replaced with 0 via np.nan_to_num.

    Returns:
        ``(auc, imtafe)`` where ``imtafe`` is ``1 / fpr`` at the first ROC point
        with ``tpr >= 0.5``. A zero false-positive rate there yields the
        largest finite float (np.nan_to_num of inf). ``imtafe`` falls back to 1
        if no ROC point reaches a TPR of 0.5.
    """
    measures = np.nan_to_num( measures )
    auc = metrics.roc_auc_score( labels, measures )
    fpr,tpr,thresholds = metrics.roc_curve( labels, measures )
    # roc_curve returns tpr in increasing order, so among the points with
    # tpr >= 0.5 the one nearest to 0.5 is simply the first of them.
    reached_half = np.flatnonzero( tpr >= 0.5 )
    if reached_half.size == 0:
        imtafe = 1
    else:
        with np.errstate( divide='ignore' ):
            imtafe = np.nan_to_num( 1 / fpr[ reached_half[0] ] )
    return auc, imtafe

In [None]:
# Linear probe: one affine layer on the 128-dimensional embedding followed by
# a sigmoid, producing a single probability per jet.
classifier = nn.Sequential(nn.Linear(128, 1), nn.Sigmoid())


def _binary_accuracy_percent(probabilities, targets):
  """Percentage of rows whose rounded probability equals the target (nan if empty)."""
  total = targets.size(0)
  if total == 0:
    return float('nan')
  correct = (probabilities.round() == targets).sum().item()
  return 100 * correct / total


def train_classifier(cls_epochs, classifier, cls_train_data, labels, cls_test_data, testlabels):
  """Train a binary classifier full-batch with Adam and evaluate it every epoch.

  Args:
    cls_epochs: number of full-batch optimisation steps.
    classifier: module mapping the input features to one probability per row.
    cls_train_data, labels: training features and float 0/1 targets.
    cls_test_data, testlabels: held-out features and float 0/1 targets.

  Returns:
    ``(cls_loss, cls_accuracy, test_cls_loss, test_cls_accuracy, fpr, tpr, auc,
    eB, eS, f1_score, imtafe)``: per-epoch train/test loss and accuracy (in
    percent), followed by the ROC curve, its area, the background/signal
    efficiencies from ``buildROC``, the macro F1 score and the inverse mistag
    rate from ``get_perf_stats``. All are computed on the test set after training.
  """
  optimizer = torch.optim.Adam(classifier.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08)
  criterion = torch.nn.BCELoss()
  # BCELoss needs targets shaped like the (n, 1) classifier output.
  train_targets = labels.view(-1,1)
  test_targets = testlabels.view(-1,1)
  cls_loss = []
  cls_accuracy = []
  test_cls_loss = []
  test_cls_accuracy = []

  for ce in range(cls_epochs):
    # One full-batch gradient step.
    optimizer.zero_grad()
    outputs = classifier(cls_train_data)
    loss = criterion(outputs, train_targets)
    loss.backward()
    optimizer.step()
    accuracy = _binary_accuracy_percent(outputs.detach(), train_targets)
    cls_loss.append(loss.cpu().detach().numpy())
    cls_accuracy.append(accuracy)

    # Held-out evaluation does not need gradients.
    with torch.no_grad():
      testoutputs = classifier(cls_test_data)
      testloss = criterion(testoutputs, test_targets)
    testaccuracy = _binary_accuracy_percent(testoutputs, test_targets)
    test_cls_loss.append(testloss.cpu().detach().numpy())
    test_cls_accuracy.append(testaccuracy)
    # Progress report every 50 epochs (keyed on the loop counter `ce`).
    if ce % 50 == 0:
      print(f'Epochs : {ce} ; Loss : {loss.cpu().detach().numpy()} ; Accuracy : {accuracy} ; Test Loss : {testloss} ; Test accuracy : {testaccuracy}' )

  # Final metrics on the held-out set (recomputed so it also works for cls_epochs == 0).
  with torch.no_grad():
    testoutputs = classifier(cls_test_data)
  test_scores = testoutputs.cpu().detach().numpy()
  test_truth = testlabels.cpu()
  fpr, tpr, threshold, eB, eS = buildROC(test_truth, test_scores)
  auc = metrics.auc(fpr, tpr)
  f1_score = metrics.f1_score(test_truth, np.round(test_scores), average='macro')
  _ , imtafe = get_perf_stats(test_truth, test_scores)
  return cls_loss, cls_accuracy, test_cls_loss, test_cls_accuracy, fpr, tpr, auc, eB, eS, f1_score, imtafe


# Train the linear probe on the 80% split and evaluate it on the remaining 20%.
cls_loss, cls_accuracy, test_cls_loss, test_cls_accuracy, fpr, tpr, auc, eB, eS, f1_score, imtafe = train_classifier(cls_epochs, classifier, cls_train_data, targets, cls_test_data, testtargets)

In [None]:
# Per-epoch loss of the linear probe on the train and held-out splits.
for curve, curve_label in ((cls_loss, 'Train Loss'), (test_cls_loss, 'Val Loss')):
    plt.plot(curve, label=curve_label)
plt.legend()

In [None]:
# Per-epoch accuracy (in percent) of the linear probe on both splits.
for curve, curve_label in ((cls_accuracy, 'Train Accuracy'), (test_cls_accuracy, 'Val Accuracy')):
    plt.plot(curve, label=curve_label)
plt.legend()

In [None]:
import matplotlib.patches as mpl_patches

# Text style for an in-plot annotation; only the commented-out plt.text call
# below refers to it.
font = dict(family='serif', color='darkblue', weight='normal', size=12)

# ROC curve of the linear probe on the held-out split.
plt.plot(fpr, tpr)
plt.xlabel('FPR', fontsize=14)
plt.ylabel('TPR', fontsize=14)
# plt.text(0.9, 0.9, 'AUC = '+str(auc.round(6)), fontdict=font, wrap=True)

# The legend is used as a text box: one invisible handle per line of text.
handles = [
    mpl_patches.Rectangle((0, 0), 1, 1, fc="white", ec="white", lw=0, alpha=0),
]
labels = ["AUC = "+str(auc.round(6))]

# Setting handlelength and handletextpad to 0 removes the blank space that the
# invisible handle and its padding would otherwise occupy.
plt.legend(
    handles,
    labels,
    loc='best',
    fontsize='large',
    fancybox=True,
    framealpha=0.7,
    handlelength=0,
    handletextpad=0,
)

In [None]:
# Final held-out metrics of the linear probe.
for metric_name, metric_value in (
    ('Accuracy : ', test_cls_accuracy[-1]),
    ('AUC : ', auc),
    ('F1 score : ', f1_score),
):
    print(metric_name, metric_value)