In [2]:
import os
import sys
import datetime 
import numpy as np
sys.path.append("..")
# from tqdm.rich import tqdm, trange
from tqdm import tqdm, trange
from copy import deepcopy
import h5py
import pandas as pd 

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import ToTensor
from torchmetrics.regression import MeanAbsolutePercentageError

import seaborn as sns
import matplotlib.pylab as plt

from data_loader import get_database_path, get_h5_files, read_h5_file



In [3]:
random_seed = 114514
torch.manual_seed(random_seed)
np.random.seed(random_seed)

# Set default tensor type
if sys.platform == "darwin":
    # Mac OS
    default_precision = torch.float32
else:
    # Linux or Windows
    default_precision = torch.float64
torch.set_default_dtype(default_precision)

# Set device
#device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
device = "cuda" if torch.cuda.is_available() else "mps" if sys.platform == "darwin" else "cpu"
print(f"Using {device} device")

Using cuda device


# Load dataset

In [3]:
database_path = get_database_path()
bkg_files, sig_files = get_h5_files()


# SM processes
bkg = read_h5_file(database_path, bkg_files[0]['file'])
bkg_pairs, bkg_emds = read_h5_file(database_path, "bkg_emds.h5", datatype='EMD')

In [4]:
class PairedEventsDataset(Dataset):
    def __init__(self, events, pairs, emds):
        assert len(emds) == len(pairs)
        assert len(pairs.shape) == 2 and pairs.shape[1] == 2
        self.events = torch.from_numpy(events)
        self.pairs = torch.from_numpy(pairs)
        self.emds = torch.from_numpy(emds)
        
    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        return self.events[self.pairs[idx][0]], self.events[self.pairs[idx][1]], self.emds[idx]

In [5]:
bkg_dataset = PairedEventsDataset(bkg, bkg_pairs, bkg_emds)

# Model

In [None]:
class MLP(nn.Module):
    def __init__(self, input_size, hidden_sizes, output_size):
        super(MLP, self).__init__()
        def make_layer(in_size, out_size):
            return nn.Sequential(
                nn.Linear(in_size, out_size),
                nn.LeakyReLU(),
                nn.Dropout(0.1)
            )
        self.layers = nn.Sequential(
            make_layer(input_size, hidden_sizes[0]),
            *[make_layer(hidden_sizes[i], hidden_sizes[i+1]) for i in range(len(hidden_sizes)-1)],
            nn.Linear(hidden_sizes[-1], output_size)
        )

    def forward(self, x):
        x = x.view(x.size(0), -1)
        return self.layers(x)

# Loss & Metric

In [4]:
def Euclidean_distance(y1, y2):
    return torch.norm(y1 - y2, dim=1)

def hyperbolic_distance(y1, y2):
    return torch.acosh(1 + 2 * torch.sum((y1 - y2)**2) / ((1 - torch.sum(y1**2)) * (1 - torch.sum(y2**2))))

In [10]:
a = torch.tensor([[ 1, 3], [-1, 4], [2, 2]] , dtype=torch.float)
b = torch.tensor([[ 1, 4], [-1, 3], [2, 2]] , dtype=torch.float)
print(torch.norm(a-b, dim=1))
print(Euclidean_distance(a, b))
print(hyperbolic_distance(a, b))

tensor([1., 1., 0.], dtype=torch.float32)
tensor([1., 1., 0.], dtype=torch.float32)
tensor([0.0588, 0.0588, 0.0000], dtype=torch.float32)


In [None]:
def MAPELoss(y_pred, y_true):
        return MeanAbsolutePercentageError(y_pred, y_true)

def MAPEonVAR(y_pred, y_true):
        normed_var_pred = torch.var(y_pred, dim=0) / torch.mean(y_pred, dim=0)
        normed_var_true = torch.var(y_true, dim=0) / torch.mean(y_true, dim=0)
        return MeanAbsolutePercentageError(normed_var_pred, normed_var_true)

# Train Pipeline

In [None]:
def printlog(info):
    nowtime = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    print("\n"+"=========="*8 + "%s"%nowtime)
    print(str(info)+"\n")

class StepRunner:
    def __init__(self, net, loss_fn,
                 stage = "train", metrics_dict = None, 
                 optimizer = None
                 ):
        self.model,self.loss_fn,self.metrics_dict,self.stage = net,loss_fn,metrics_dict,stage
        self.optimizer = optimizer
            
    def step(self, source_event, target_event, emd):
        #loss
        source_event, target_event, emd = source_event.to(device), target_event.to(device), emd.to(device)

        source_emebedding = self.model(source_event)
        target_embedding = self.model(target_event)
        loss = self.loss_fn(source_emebedding, target_embedding, emd)
        
        #backward()
        if self.optimizer is not None and self.stage=="train": 
            loss.backward()
            self.optimizer.step()
            self.optimizer.zero_grad()
            
        #metrics
        step_metrics = {self.stage+"_"+name:metric_fn(source_emebedding, target_embedding, emd).item() 
                        for name,metric_fn in self.metrics_dict.items()}
        return loss.item(), step_metrics
    
    def train_step(self, source_event, target_event, emd):
        self.model.train() 
        return self.step(source_event, target_event, emd)
    
    @torch.no_grad()
    def eval_step(self, source_event, target_event, emd):
        self.model.eval() 
        return self.step(source_event, target_event, emd)
    
    def __call__(self, source_event, target_event, emd):
        if self.stage=="train":
            return self.train_step(source_event, target_event, emd) 
        else:
            return self.eval_step(source_event, target_event, emd)
        
class EpochRunner:
    def __init__(self,steprunner):
        self.steprunner = steprunner
        self.stage = steprunner.stage
        
    def __call__(self,dataloader):
        total_loss,step = 0,0
        loop = tqdm(enumerate(dataloader), total =len(dataloader), file = sys.stdout)
        for batch_idx, (source_event, target_event, emd) in loop: 
            loss, step_metrics = self.steprunner(source_event, target_event, emd)
            step_log = dict({self.stage+"_loss":loss},**step_metrics)
            total_loss += loss
            step+=1
            if batch_idx!=len(dataloader)-1:
                loop.set_postfix(**step_log)
            else:
                epoch_loss = total_loss/step
                epoch_metrics = {self.stage+"_"+name:metric_fn.compute()
                                 for name,metric_fn in self.steprunner.metrics_dict.items()}
                epoch_log = dict({self.stage+"_loss":epoch_loss},**epoch_metrics)
                loop.set_postfix(**epoch_log)

                for name,metric_fn in self.steprunner.metrics_dict.items():
                    metric_fn.reset()
        return epoch_log

@torch.compile()
def train_model(net, optimizer, loss_fn, metrics_dict, 
                train_dataloader, val_dataloader=None, 
                epochs=10, ckpt_path='checkpoint.pt',
                patience=5, monitor="MAPE", mode="min"):
    
    history = {}

    for epoch in range(1, epochs+1):
        printlog("Epoch {0} / {1}".format(epoch, epochs))

        # 1，train -------------------------------------------------  
        train_step_runner = StepRunner(net = net,stage="train",
                loss_fn = loss_fn,metrics_dict=deepcopy(metrics_dict),
                optimizer = optimizer)
        train_epoch_runner = EpochRunner(train_step_runner)
        train_metrics = train_epoch_runner(train_dataloader)

        for name, metric in train_metrics.items():
            history[name] = history.get(name, []) + [metric]

        # 2，validate -------------------------------------------------
        if val_dataloader:
            val_step_runner = StepRunner(net = net,stage="val",
                loss_fn = loss_fn,metrics_dict=deepcopy(metrics_dict))
            val_epoch_runner = EpochRunner(val_step_runner)
            with torch.no_grad():
                val_metrics = val_epoch_runner(val_dataloader)
            val_metrics["epoch"] = epoch
            for name, metric in val_metrics.items():
                history[name] = history.get(name, []) + [metric]

        # 3，early-stopping -------------------------------------------------
        arr_scores = history[monitor]
        best_score_idx = np.argmax(arr_scores) if mode=="max" else np.argmin(arr_scores)
        if best_score_idx==len(arr_scores)-1:
            torch.save(net.state_dict(),ckpt_path)
            print("<<<<<< reach best {0} : {1} >>>>>>".format(monitor,
                 arr_scores[best_score_idx]))
        if len(arr_scores)-best_score_idx>patience:
            print("<<<<<< {} without improvement in {} epoch, early stopping >>>>>>".format(
                monitor,patience))
            break 
    net.load_state_dict(torch.load(ckpt_path))

    return pd.DataFrame(history)

# Train!