In [1]:
import numpy as np
import h5py, os

import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.font_manager as font_manager
# Apply the style sheet first: a style sheet overwrites every rcParam it
# defines (the 'classic' sheet defines font.family), so the serif override
# has to come afterwards to take effect.
plt.style.use('classic')
plt.rcParams["font.family"] = "serif"
# Serif font properties object (size 16) for use in legends.
font = font_manager.FontProperties(family='serif', size=16)

import torch
import torch.nn as nn
from torch.utils.data import DataLoader 

In [None]:
def h5_loader(file_path: str='background_for_training.h5', min_pt=None, max_pt=None,
              mean_eta=None, std_eta=None, mean_phi=None, std_phi=None,
              min_class=None, max_class=None):
    """Load the 'Particles' dataset of an HDF5 file as a flat, normalised tensor.

    The dataset is read fully into memory, normalised with
    ``normalize_features`` and reshaped to ``(-1, 19*4)`` before being
    wrapped in a torch tensor.

    Parameters
    ----------
    file_path : str
        Path of the HDF5 file to read.
    min_pt, max_pt, mean_eta, std_eta, mean_phi, std_phi, min_class, max_class
        Normalisation statistics forwarded unchanged to
        ``normalize_features``; any left as ``None`` is computed there from
        the loaded data.

    Returns
    -------
    torch.Tensor
        Tensor with 19*4 columns per row.

    Raises
    ------
    KeyError
        If the file has no 'Particles' dataset.
    """
    # The context manager releases the file handle even if reading fails;
    # np.asarray copies the data out, so it stays valid after the file closes.
    with h5py.File(file_path, 'r') as ff:
        # Indexing (rather than .get) fails loudly when the dataset is missing
        # instead of silently producing an array that wraps None.
        particles = np.asarray(ff['Particles'])
    particles = normalize_features(
        particles, min_pt, max_pt, mean_eta, std_eta,
        mean_phi, std_phi, min_class, max_class,
    ).reshape((-1, 19*4))
    particles = torch.from_numpy(particles)
    print(particles.shape)
    return particles

def summary_statistics(file_path: str='background_for_training.h5'):
    """Compute the normalisation statistics of the 'Particles' dataset.

    The last axis of the dataset is indexed as (pt, eta, phi, class); each
    statistic is taken over all entries of the corresponding feature.

    Parameters
    ----------
    file_path : str
        Path of the HDF5 file to read.

    Returns
    -------
    tuple
        ``(min_pt, max_pt, mean_eta, std_eta, mean_phi, std_phi,
        min_class, max_class)``.

    Raises
    ------
    KeyError
        If the file has no 'Particles' dataset.
    """
    # Close the file as soon as the data has been copied into memory.
    with h5py.File(file_path, 'r') as ff:
        particles = np.asarray(ff['Particles'])
    idx_pt, idx_eta, idx_phi, idx_class = range(4)
    pt = particles[:, :, idx_pt]
    eta = particles[:, :, idx_eta]
    phi = particles[:, :, idx_phi]
    cls = particles[:, :, idx_class]
    return (
        np.min(pt), np.max(pt),
        np.mean(eta), np.std(eta),
        np.mean(phi), np.std(phi),
        np.min(cls), np.max(cls),
    )

def normalize_features(particles, min_pt=None, max_pt=None, mean_eta=None, std_eta=None,
                       mean_phi=None, std_phi=None, min_class=None, max_class=None):
    """Normalise the four per-object features of a 3-D particle array.

    The last axis is indexed as (pt, eta, phi, class). pt and class are
    min-max scaled; eta and phi are standardised (zero mean, unit std).

    NOTE: ``particles`` is modified in place and the same array is returned.
    No guard is applied against ``max == min`` or ``std == 0``.

    Parameters
    ----------
    particles : numpy.ndarray
        3-D array whose last axis holds at least the four features above.
    min_pt, max_pt, mean_eta, std_eta, mean_phi, std_phi, min_class, max_class
        Statistics to normalise with. Any left as ``None`` is computed from
        ``particles`` itself. Array-valued statistics are accepted as long
        as they broadcast against ``particles[:, :, k]``.

    Returns
    -------
    numpy.ndarray
        The (mutated) input array.
    """
    idx_pt, idx_eta, idx_phi, idx_class = range(4)
    # Identity checks: `== None` on a multi-element array is elementwise and
    # raises "truth value of an array is ambiguous" inside an `if`.
    if min_pt is None: min_pt = np.min(particles[:, :, idx_pt])
    if max_pt is None: max_pt = np.max(particles[:, :, idx_pt])
    if mean_eta is None: mean_eta = np.mean(particles[:, :, idx_eta])
    if std_eta is None: std_eta = np.std(particles[:, :, idx_eta])
    if mean_phi is None: mean_phi = np.mean(particles[:, :, idx_phi])
    if std_phi is None: std_phi = np.std(particles[:, :, idx_phi])
    if min_class is None: min_class = np.min(particles[:, :, idx_class])
    if max_class is None: max_class = np.max(particles[:, :, idx_class])
    # min-max normalize pt
    particles[:, :, idx_pt] = (particles[:, :, idx_pt] - min_pt) / (max_pt - min_pt)
    # standard normalize angles
    particles[:, :, idx_eta] = (particles[:, :, idx_eta] - mean_eta) / std_eta
    particles[:, :, idx_phi] = (particles[:, :, idx_phi] - mean_phi) / std_phi
    # min-max normalize class label
    particles[:, :, idx_class] = (particles[:, :, idx_class] - min_class) / (max_class - min_class)
    return particles

In [None]:
class Encoder(nn.Module):
    '''
    Fully connected encoder that outputs a mean and a log-variance
    (i.e., the parameters of the simple tractable normal distribution "q").

    hidden_dim is a sequence of layer widths: the input is projected to
    hidden_dim[0], then passed through one linear layer per consecutive
    pair of widths, each followed by LeakyReLU(0.2).
    '''
    def __init__(self, input_dim, hidden_dim, latent_dim):
        super().__init__()
        widths = list(zip(hidden_dim[:-1], hidden_dim[1:]))
        self.FC_input = nn.Linear(input_dim, hidden_dim[0])
        self.FC_hidden = nn.ModuleList([nn.Linear(n_in, n_out) for n_in, n_out in widths])
        # Two parallel heads on top of the last hidden layer.
        self.FC_mean = nn.Linear(hidden_dim[-1], latent_dim)
        self.FC_var = nn.Linear(hidden_dim[-1], latent_dim)
        self.LeakyReLU = nn.LeakyReLU(0.2)

    def forward(self, x):
        '''Return the tuple (mean, log_var) for input x.'''
        hidden = self.LeakyReLU(self.FC_input(x))
        for layer in self.FC_hidden:
            hidden = self.LeakyReLU(layer(hidden))
        mean = self.FC_mean(hidden)
        log_var = self.FC_var(hidden)
        return mean, log_var
        
class Decoder(nn.Module):
    '''
    Fully connected decoder mapping a latent vector to a reconstruction.

    The latent input is projected to hidden_dim[0], passed through one
    linear layer per consecutive pair of widths in hidden_dim (each followed
    by LeakyReLU(0.2)), and finally mapped to output_dim through a sigmoid,
    so every output lies in (0, 1).
    '''
    def __init__(self, latent_dim, hidden_dim, output_dim):
        super().__init__()
        widths = list(zip(hidden_dim[:-1], hidden_dim[1:]))
        self.FC_latent = nn.Linear(latent_dim, hidden_dim[0])
        self.FC_hidden = nn.ModuleList([nn.Linear(n_in, n_out) for n_in, n_out in widths])
        self.FC_output = nn.Linear(hidden_dim[-1], output_dim)
        self.LeakyReLU = nn.LeakyReLU(0.2)

    def forward(self, x):
        '''Return the reconstruction x_hat for latent input x.'''
        hidden = self.LeakyReLU(self.FC_latent(x))
        for layer in self.FC_hidden:
            hidden = self.LeakyReLU(layer(hidden))
        logits = self.FC_output(hidden)
        return torch.sigmoid(logits)

class Model(nn.Module):
    '''
    Variational autoencoder: an Encoder producing (mean, log_var), a
    reparameterised latent sample, and a Decoder producing the reconstruction.
    '''
    def __init__(self, input_dim, latent_dim, hidden_dim_encoder, hidden_dim_decoder):
        super().__init__()
        self.Encoder = Encoder(input_dim=input_dim, hidden_dim=hidden_dim_encoder, latent_dim=latent_dim)
        self.Decoder = Decoder(latent_dim=latent_dim, hidden_dim=hidden_dim_decoder, output_dim=input_dim)

    def reparameterization(self, mean, var):
        '''
        Draw z = mean + var * epsilon with epsilon ~ N(0, 1).

        Despite its name, `var` is used as the scale of the noise; forward()
        passes the standard deviation exp(0.5 * log_var).
        '''
        # randn_like allocates on the same device (and dtype) as `var`, so no
        # explicit transfer through a global device constant is needed.
        epsilon = torch.randn_like(var)  # sampling epsilon
        z = mean + var * epsilon         # reparameterization trick
        return z

    def forward(self, x):
        '''Return (x_hat, mean, log_var) for input x.'''
        mean, log_var = self.Encoder(x)
        # exp(0.5 * log_var) converts the log-variance into a standard deviation
        z = self.reparameterization(mean, torch.exp(0.5 * log_var))
        x_hat = self.Decoder(z)
        return x_hat, mean, log_var

def AllLoss(x, x_hat, mean, log_var, lambda_kld):
    """Combine the reconstruction loss and the weighted KL divergence.

    Returns a 3-tuple: (total, reconstruction term, lambda_kld * KL term),
    where total is the sum of the other two.
    """
    reco_term = RecoLoss(x_hat, x)
    weighted_kld = lambda_kld * KLDLoss(mean, log_var)
    total = reco_term + weighted_kld
    return total, reco_term, weighted_kld

def RecoLoss(x_hat, x, reduction='sum'):
    """Binary cross-entropy of the reconstruction x_hat against the target x.

    `reduction` is forwarded unchanged to torch's binary_cross_entropy.
    """
    bce = nn.functional.binary_cross_entropy
    return bce(input=x_hat, target=x, reduction=reduction)
    
def KLDLoss(mean, log_var, reduction='sum'):
    """KL divergence between N(mean, exp(log_var)) and the standard normal.

    Per element: -0.5 * (1 + log_var - mean**2 - exp(log_var)).

    Parameters
    ----------
    mean, log_var : torch.Tensor
        Mean and log-variance tensors; they must broadcast against each other.
    reduction : {'sum', 'mean', 'none'}
        'sum' adds all elements, 'mean' averages them, 'none' returns the
        element-wise tensor.

    Raises
    ------
    ValueError
        If `reduction` is not one of the supported values (previously such a
        value silently produced ``None``).
    """
    kld = -0.5 * (1 + log_var - mean.pow(2) - log_var.exp())
    if reduction == 'sum':
        return torch.sum(kld)
    if reduction == 'mean':
        return torch.mean(kld)
    if reduction == 'none':
        return kld
    raise ValueError("reduction must be 'sum', 'mean' or 'none', got %r" % (reduction,))

# Visualise Training Data

In [None]:
# Axis label of each per-object feature, in the order the features are stored.
xlabel = [r'$p_{\rm T}$', r'$\eta$', r'$\phi$', 'class']
# One label per object; the flattened columns are read as blocks of
# len(xlabel) features per object, in this order.
# NOTE(review): confirm the object names/order against the dataset description.
obj_label = ['MET', 
             'lepton 1', 'lepton 2', 'lepton 3', 'lepton 4', 
             'electron 1', 'electron 2', 'electron 3', 'electron 4',
             'jet 1', 'jet 2', 'jet 3', 'jet 4', 'jet 5', 'jet 6', 'jet 7', 'jet 8', 'jet 9', 'jet 10'
        ]
# Histogram bin edges per feature, same order as xlabel.
bins = [
    np.linspace(0, 0.5, 30),
    np.linspace(-6, 6, 30),
    np.linspace(-6, 6, 30),
    np.linspace(0, 1, 30),
       ]
# NOTE(review): `test_dataset` is not defined in the cells shown here; it has
# to exist in the kernel before this cell runs. The section title says
# "Training Data" while the test set is plotted — confirm which is intended.
sm = test_dataset[0]
sm = sm.view(-1, sm.shape[1]).float()
sm = sm.cpu()

n_features = len(xlabel)
# Legend font; built once instead of on every loop iteration.
font = font_manager.FontProperties(family='serif', size=14)

for i, obj_name in enumerate(obj_label):
    # One figure per object, one panel per feature.
    fig = plt.figure(figsize=(15, 4))
    fig.patch.set_facecolor('white')
    for j in range(n_features):
        # The rect is in figure-fraction coordinates and every object has its
        # own figure, so the horizontal offset depends on the panel index only
        # (an object-dependent offset places the axes outside the figure).
        ax = fig.add_axes([0.05 + 0.24*j, 0.2, 0.21, 0.75])
        ax.hist(sm[:, j + i*n_features].numpy(), density=True, color='black',
                bins=bins[j], lw=2, histtype='step', label='SM')
        ax.set_yscale('log')
        ax.set_xlabel(obj_name + ' ' + xlabel[j], fontsize=16, fontname='serif')
        # add_axes makes `ax` the current axes, so these act on this panel.
        plt.xticks(fontsize=12, fontname='serif')
        plt.yticks(fontsize=12, fontname='serif')

        if j == 0:
            ax.set_ylabel('probability', fontsize=16, fontname='serif')
            ax.set_xscale('log')
        if j == n_features - 1:
            ax.legend(prop=font, loc='best', frameon=False)
    #plt.savefig('./input_features_obj%i.pdf'%(i))
    plt.show()
    # Release the figure so the loop does not accumulate open figures.
    plt.close(fig)