![deelema](Deelema.png)

In [None]:
import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader, random_split
import torch.nn.functional as F
from torch.distributions import Uniform
from torch.optim.lr_scheduler import ReduceLROnPlateau, ExponentialLR, OneCycleLR
from torch_poly_lr_decay import PolynomialLRDecay

import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks import LearningRateMonitor
from pytorch_lightning.loggers import WandbLogger
from sklearn.model_selection import train_test_split

import numpy as np
import matplotlib.pyplot as plt
import scipy.stats as stats
import pandas as pd
import matplotlib.colors as colors

import wandb

PATH_DATASETS = "."
AVAIL_GPUS = min(1, torch.cuda.device_count())  # never request more than one GPU
BATCH_SIZE = 64 if not AVAIL_GPUS else 1024     # larger batches when a GPU is present
# BATCH_SIZE=1

# Use a uniform 20 pt font for all matplotlib text: default text, x/y axis
# labels, x/y tick labels, legend, and figure title.
plt.rcParams.update({
    'font.size': 20,
    'axes.labelsize': 20,
    'xtick.labelsize': 20,
    'ytick.labelsize': 20,
    'legend.fontsize': 20,
    'figure.titlesize': 20,
})

In [None]:
# Fix the global RNG seeds so that data splits and weight initialisation are reproducible.
pl.seed_everything(seed=8407)

In [None]:
class ToyData(Dataset):
    """Map-style dataset pairing stacked input momenta with two target tensors.

    ``X`` is ``p_A``, ``p_a``, ``p_B`` and ``p_b`` stacked column-wise (with the
    (N, 4) arrays used in this notebook that gives 16 features per row).
    Item ``idx`` is the tuple ``(X[idx], q_C[idx], q_c[idx])``.
    """

    def __init__(self, p_A, p_a, p_B, p_b, q_C, q_c):
        inputs = (p_A, p_a, p_B, p_b)
        self.X = torch.column_stack(inputs)
        self.q_C, self.q_c = q_C, q_c

    def __len__(self):
        # Number of rows (events) in the stacked input matrix.
        return len(self.X)

    def __getitem__(self, idx):
        features = self.X[idx]
        return features, self.q_C[idx], self.q_c[idx]

## Data Loading

In [None]:
process = 'toy'
# Preprocessed arrays for the chosen process, i.e. 'toy_array_preproc.npz'.
np_data = np.load(f'{process}_array_preproc.npz')

In [None]:
# Map the stored arrays onto the names used by the model:
#   'b1', 'b2'   -> pa1, pa2
#   'l2', 'l1'   -> pb1, pb2
#   'nu1', 'nu2' -> qc1, qc2 (targets)
# NOTE(review): pb1 comes from 'l2' and pb2 from 'l1', crossed relative to the
# b/nu indices — confirm this pairing is intentional.
pa1, pa2, pb1, pb2, qc1, qc2 = (
    np_data[key] for key in ('b1', 'b2', 'l2', 'l1', 'nu1', 'nu2')
)

In [None]:
def np_mass_sq(p):
    """Return the squared invariant mass E^2 - px^2 - py^2 - pz^2 of each row of ``p``.

    Columns are read as (E, px, py, pz). Only arithmetic and 2-D indexing are
    used, so it works on NumPy arrays and torch tensors alike.
    """
    energy, px, py, pz = p[:, 0], p[:, 1], p[:, 2], p[:, 3]
    return energy**2 - px**2 - py**2 - pz**2

# Diagonal of the (+, -, -, -) metric matching the mass formulas in this cell;
# not referenced by the functions defined here.
g = np.array([1, -1, -1, -1])


def Mass(p, ax=1):
    """Return the invariant mass sqrt(E^2 - |p|^2) of each row (E, px, py, pz) of a torch tensor.

    ``ax`` is accepted for call compatibility but ignored. Rows with negative
    E^2 - |p|^2 yield NaN (torch.sqrt of a negative number).
    """
    energy, px, py, pz = p[:, 0], p[:, 1], p[:, 2], p[:, 3]
    return torch.sqrt(energy**2 - px**2 - py**2 - pz**2)

def npMass(p, ax=1):
    """Return the invariant mass sqrt(E^2 - |p|^2) of each row (E, px, py, pz) of a NumPy array.

    ``ax`` is accepted for call compatibility but ignored.
    """
    mass_sq = p[:, 0]**2 - p[:, 1]**2 - p[:, 2]**2 - p[:, 3]**2
    return np.sqrt(mass_sq)

In [None]:
# Quick sanity check of the loaded array shapes (displayed as the cell output).
(pa1.shape, pb2.shape)

## Conversion to PyTorch Tensors

In [None]:
np_momenta = [pa1, pa2, pb1, pb2, qc1, qc2]

# One row per event: the six momentum arrays side by side, then an 80/20 split.
X = np.concatenate(np_momenta, axis=1)
X_train, X_test = train_test_split(X, test_size=0.2, random_state=42)

# Every component is divided by this factor before conversion to float32 tensors.
scaler = 1000

# Cut each split back into its six momentum blocks (same order as np_momenta).
torch_momenta_train = [
    torch.tensor(part / scaler, dtype=torch.float32)
    for part in np.array_split(X_train, 6, axis=1)
]
torch_momenta_test = [
    torch.tensor(part / scaler, dtype=torch.float32)
    for part in np.array_split(X_test, 6, axis=1)
]

In [None]:
# np.savetxt('test_sample_toy_new.txt', X_test[:,:16], fmt = '%5f', delimiter = ' ')

In [None]:
# Held-out test dataset. The training dataset is built inside ToyNet.prepare_data.
ds_test = ToyData(*torch_momenta_test)

## Network

In [None]:
class ToyNet(pl.LightningModule):
    """MLP that reconstructs two unobserved 4-vectors ``qc1``/``qc2`` per event.

    The 16 input features are read as four 4-vectors (E, px, py, pz):
    ``pa1 = x[:, 0:4]``, ``pa2 = x[:, 4:8]``, ``pb1 = x[:, 8:12]`` and
    ``pb2 = x[:, 12:16]``. The network predicts (parts of) ``qc1``/``qc2``
    (in this notebook the targets come from the 'nu1'/'nu2' arrays).
    Chain masses are then reconstructed as ``M(qc)``, ``M(pb + qc)`` and
    ``M(pa + pb + qc)``. The loss ties both chains to each other and to the
    learnable masses ``m_B``/``m_A``. Depending on ``learn_mode`` it also ties
    them to the fixed ``m_C`` and to zero total transverse momentum.

    hparams keys:
        hidden_layer, hidden_depth: width and number of extra hidden blocks.
        learning_rate, max_lr: Adam lr and OneCycleLR peak lr.
        batch_size, epochs: DataLoader batch size and scheduler length.
        gamma: stored only.
        m_C_init, m_B_add, m_A_add: initial m_C and the gaps m_B - m_C and m_A - m_B.
        learn_mode: 'pt_mc' | 'pt' | 'mc' | None. Selects the output
            parametrisation and which loss terms are used.
        learn_mode_sq: 'sq' compares squared masses (m_* are stored squared);
            'sqrt' compares masses.
    """

    # Width of the final Linear layer for each learn_mode.
    _OUT_FEATURES = {'pt_mc': 8, 'pt': 6, 'mc': 6, None: 4}

    def __init__(self, hparams=None):
        super().__init__()
        if hparams is None:
            raise TypeError("ToyNet requires an hparams dict")

        hidden_layer = hparams["hidden_layer"]
        hidden_depth = hparams["hidden_depth"]

        self.hidden_layer = hidden_layer
        self.hidden_depth = hidden_depth
        self.learning_rate = hparams["learning_rate"]
        self.batch_size = hparams["batch_size"]
        self.epochs = hparams["epochs"]
        self.gamma = hparams["gamma"]
        self.max_lr = hparams['max_lr']
        self.learn_mode = hparams['learn_mode']        # which loss terms / output layout
        self.learn_mode_sq = hparams['learn_mode_sq']  # 'sq' or 'sqrt'

        # Fail fast instead of crashing later on an undefined local.
        if self.learn_mode not in self._OUT_FEATURES:
            raise ValueError(
                f"unknown learn_mode {self.learn_mode!r}; "
                f"expected one of {list(self._OUT_FEATURES)}"
            )
        if self.learn_mode_sq not in ('sq', 'sqrt'):
            raise ValueError(
                f"unknown learn_mode_sq {self.learn_mode_sq!r}; expected 'sq' or 'sqrt'"
            )

        m_C = torch.tensor(hparams["m_C_init"])
        m_B = m_C + torch.tensor(hparams["m_B_add"])
        m_A = m_B + torch.tensor(hparams["m_A_add"])
        if self.learn_mode_sq == 'sq':
            # In 'sq' mode all mass targets are stored (and compared) squared.
            m_C, m_B, m_A = m_C ** 2, m_B ** 2, m_A ** 2

        # m_C is fixed. A non-persistent buffer follows .to(device) but stays out
        # of state_dict, so checkpoints keep the same keys as before.
        self.register_buffer("m_C", m_C, persistent=False)
        self.m_B = nn.Parameter(m_B, requires_grad=True)
        self.m_A = nn.Parameter(m_A, requires_grad=True)

        layers = [nn.Linear(16, hidden_layer), nn.ReLU(inplace=True), nn.BatchNorm1d(hidden_layer)]
        for _ in range(hidden_depth):
            layers.extend([
                nn.Linear(hidden_layer, hidden_layer),
                nn.ReLU(inplace=True),
                nn.BatchNorm1d(hidden_layer),
            ])
        layers.append(nn.Linear(hidden_layer, self._OUT_FEATURES[self.learn_mode]))
        self.net = nn.Sequential(*layers)

        self.save_hyperparameters(hparams)

        self.ds = None
        self.ds_train = None
        self.ds_val = None
        self.ds_test = None

    def forward(self, x):
        """Return the raw network output for the 16 input features."""
        return self.net(x)

    def _invisible_momenta(self, x, q):
        """Turn the raw output ``q`` into the 4-vectors ``qc1``, ``qc2`` (columns E, px, py, pz)."""
        if self.learn_mode == 'pt_mc':
            # Both 4-vectors are predicted directly.
            return q[:, 0:4], q[:, 4:8]

        if self.learn_mode == 'pt':
            qx1, qy1, qx2, qy2, qz1, qz2 = (q[:, i:i + 1] for i in range(6))
        else:
            # 'mc' and None: only qx1, qy2, qz1, qz2 are predicted. qx2 and qy1 are
            # fixed so that px and py summed over all six 4-vectors vanish.
            qx1, qy2, qz1, qz2 = (q[:, i:i + 1] for i in range(4))
            pTx = x[:, 1:2] + x[:, 5:6] + x[:, 9:10] + x[:, 13:14]
            pTy = x[:, 2:3] + x[:, 6:7] + x[:, 10:11] + x[:, 14:15]
            qx2 = -pTx - qx1
            qy1 = -pTy - qy2

        if self.learn_mode == 'mc':
            # Energies are free outputs, so M(qc) is not fixed by construction.
            Eq1, Eq2 = q[:, 4:5], q[:, 5:6]
        else:
            # Put qc on shell with mass m_C. In 'sq' mode self.m_C already holds
            # m_C**2, so it must not be squared again.
            m_C_sq = self.m_C if self.learn_mode_sq == 'sq' else self.m_C ** 2
            Eq1 = torch.sqrt(m_C_sq + qx1 ** 2 + qy1 ** 2 + qz1 ** 2)
            Eq2 = torch.sqrt(m_C_sq + qx2 ** 2 + qy2 ** 2 + qz2 ** 2)

        qc1 = torch.cat([Eq1, qx1, qy1, qz1], 1)
        qc2 = torch.cat([Eq2, qx2, qy2, qz2], 1)
        return qc1, qc2

    @staticmethod
    def _mass_consistency(m1, m2, target):
        """Per-event L1 penalty tying both chains to each other and to ``target``."""
        return torch.abs(m1 - m2) + torch.abs(m1 - target) + torch.abs(m2 - target)

    def _shared_step(self, batch):
        """Run one batch and compute all loss terms.

        Returns a dict with the scalar ``loss``, the per-event terms
        ``loss_A``/``loss_B``/``loss_C``/``loss_pT`` and the per-event
        reconstructed masses ``m_A1`` … ``m_C2``. These are squared masses when
        ``learn_mode_sq == 'sq'``.
        """
        x, _, _ = batch
        pa1, pa2 = x[:, 0:4], x[:, 4:8]
        pb1, pb2 = x[:, 8:12], x[:, 12:16]

        qc1, qc2 = self._invisible_momenta(x, self(x))

        pB1 = pb1 + qc1
        pB2 = pb2 + qc2
        pA1 = pa1 + pB1
        pA2 = pa2 + pB2
        pT = (pA1 + pA2)[:, 1:3]

        mass = np_mass_sq if self.learn_mode_sq == 'sq' else Mass
        mC1, mC2 = mass(qc1), mass(qc2)
        mB1, mB2 = mass(pB1), mass(pB2)
        mA1, mA2 = mass(pA1), mass(pA2)

        loss_C = self._mass_consistency(mC1, mC2, self.m_C)
        loss_B = self._mass_consistency(mB1, mB2, self.m_B)
        loss_A = self._mass_consistency(mA1, mA2, self.m_A)
        loss_pT = pT[:, 0] ** 2 + pT[:, 1] ** 2

        mass_terms = loss_A + loss_B
        if self.learn_mode in ('pt_mc', 'mc'):
            mass_terms = mass_terms + loss_C
        loss = mass_terms.mean()
        if self.learn_mode in ('pt_mc', 'pt'):
            loss = loss + loss_pT.mean()

        return {
            'loss': loss,
            'loss_A': loss_A, 'loss_B': loss_B, 'loss_C': loss_C, 'loss_pT': loss_pT,
            'm_A1': mA1, 'm_A2': mA2, 'm_B1': mB1, 'm_B2': mB2, 'm_C1': mC1, 'm_C2': mC2,
        }

    def training_step(self, batch, batch_idx):
        """Return the training loss for one batch."""
        return self._shared_step(batch)['loss']

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation loss, its components and the masses."""
        out = self._shared_step(batch)
        self.log('val_loss', out['loss'])
        # self.log expects scalars, so per-event tensors are logged as batch means.
        for name in ('loss_A', 'loss_B', 'loss_C', 'loss_pT'):
            self.log(name, out[name].mean())
        self.log('m_A', self.m_A)
        self.log('m_B', self.m_B)
        self.log('m_C', self.m_C)
        for name in ('m_A1', 'm_A2', 'm_B1', 'm_B2', 'm_C1', 'm_C2'):
            self.log(name, out[name].mean())
        return out['loss']

    def configure_optimizers(self):
        """Adam with a per-step OneCycleLR schedule.

        NOTE(review): OneCycleLR overwrites the optimizer's starting lr
        (max_lr / div_factor). With its default cycle_momentum=True it also
        overwrites Adam's beta1, so ``learning_rate`` and the 0.99 in ``betas``
        are effectively replaced.
        """
        optimizer = optim.Adam(
            [
                # Network weights keep the original L2 penalty.
                {"params": list(self.net.parameters()), "weight_decay": 0.1},
                # The learnable masses are physical estimates. An L2 penalty would
                # bias them towards zero, so they are excluded from weight decay.
                {"params": [self.m_B, self.m_A], "weight_decay": 0.0},
            ],
            lr=self.learning_rate,
            betas=(0.99, 0.9999),
        )

        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": OneCycleLR(
                    optimizer,
                    max_lr=self.max_lr,
                    # +1 keeps total_steps >= the real number of steps, which
                    # OneCycleLR requires.
                    steps_per_epoch=len(self.ds_train) // self.batch_size + 1,
                    epochs=self.epochs,
                ),
                "interval": "step",
                "monitor": "val_loss",
                "strict": True,
            }
        }

    def prepare_data(self):
        """Wrap the module-level training tensors in a ToyData dataset."""
        self.ds = ToyData(*torch_momenta_train)
        self.N = len(self.ds)

    def setup(self, stage=None):
        """Split the dataset 70/30 into train/val once and reuse that split.

        The split is computed only on the first call. The test set reuses the
        validation subset, so it never overlaps the training subset (re-splitting
        per stage would draw a new permutation and leak training events).
        """
        if self.ds is None:
            # prepare_data() normally runs first; build the dataset if it did not.
            self.prepare_data()
        if self.ds_train is None:
            N_train = self.N // 10 * 7
            N_val = self.N - N_train
            self.ds_train, self.ds_val = random_split(self.ds, [N_train, N_val])
        if stage == "test" or stage is None:
            self.ds_test = self.ds_val

    def train_dataloader(self):
        # Reshuffle every epoch so batch composition (and BatchNorm stats) vary.
        return DataLoader(self.ds_train, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.ds_val, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.ds_test, batch_size=self.batch_size)

## Hyperparameter Settings

In [None]:
mc = 700 / scaler  # initial m_C, in the same scaled units as the momenta

# Training configuration consumed by ToyNet and the Trainer.
#   learn_mode:    one of 'pt_mc', 'mc', 'pt', None.
#   learn_mode_sq: 'sq' compares squared masses, 'sqrt' compares masses.
#   Note: with 'sqrt', only learn_mode 'pt' or None should be used. Presumably
#   this is because the other modes predict the energies freely, so the
#   reconstructed mass can become imaginary (NaN) — the original note only says
#   "for the physical reason".
hparams = dict(
    hidden_layer=256,
    hidden_depth=5,
    learning_rate=1e-4,
    batch_size=BATCH_SIZE,
    m_C_init=mc,
    m_B_add=0.3,
    m_A_add=0.3,
    max_lr=1e-4,
    epochs=50,
    gamma=0.9,
    learn_mode=None,
    learn_mode_sq='sqrt',
)

In [None]:
model = ToyNet(hparams=hparams)

wandb_logger = WandbLogger(project='Auxiliary_Mass_Exp_boostaug')

trainer = Trainer(
    logger=wandb_logger,
    max_epochs=hparams["epochs"],
    gpus=AVAIL_GPUS,
    enable_progress_bar=False,
    # Early stopping is currently disabled. To re-enable, add
    # EarlyStopping(monitor="val_loss", patience=20, mode="min") to this list.
    callbacks=[LearningRateMonitor(logging_interval="step")],
)

## Training

In [None]:
# Train; the model supplies its own dataloaders and optimizer configuration.
trainer.fit(model=model)

In [None]:
checkpoint_file = 'Auxiliary.pth'
trainer.save_checkpoint(checkpoint_file)  # write the Lightning checkpoint locally
wandb.save(checkpoint_file)               # and register the same file with W&B

In [None]:
# Finish the W&B run used for logging during training.
wandb.finish()