# Load Configuration File
I'm using a YAML configuration file to manage the hyperparameters that will be used in the training

In [1]:
import yaml

# Parse the YAML configuration file.
# `yaml.safe_load` is used instead of `yaml.load(..., Loader=yaml.CLoader)`:
#   * it only builds plain Python scalars/lists/dicts, so a tampered config
#     file cannot instantiate arbitrary Python objects;
#   * it does not need the optional LibYAML C extension, which `yaml.CLoader`
#     requires (an AttributeError is raised when PyYAML is built without it).
with open('example_config.yaml') as config_file:
    config = yaml.safe_load(config_file)

# Shortcuts to the two sections that are used most often below.
branches = config['branches']
hyper_params = config['hyper_params']

You can easily check the hyper parameters or configuration of the training in this file 

In [2]:
def read_config(config, indent=0):
    """Pretty-print a (possibly nested) configuration dictionary.

    A key whose value is a dict is printed as a header centred in a
    40-character field padded with underscores, and its content is printed
    recursively one level deeper. Any other value is printed as
    `` key : value``.

    Nested entries are indented so that a top-level scalar that follows a
    section cannot be mistaken for a member of that section (without the
    indentation both are printed identically).

    Args:
        config: Dictionary mapping option names to values or nested dicts.
        indent: Number of leading spaces used for this nesting level.
            Callers normally leave the default; the recursion increases it.
    """
    pad = ' ' * indent
    for key, value in config.items():
        if isinstance(value, dict):
            # str() keeps the header formatting valid for non-string YAML keys.
            print(f"{pad}{str(key):_^40}")
            read_config(value, indent + 4)
        else:
            print(f"{pad} {key} : {value}")

# Print every section and option of the loaded configuration for inspection.
read_config(config)

______________hyper_params______________
 dim_output : 3
 dim_ffnn : 64
 num_blocks : 2
 num_heads : 2
 depth : 4
 batch_size : 128
 learning_rate : 0.0003
________________branches________________
 jet_branches : ['jet_pt', 'jet_eta', 'jet_phi', 'jet_mass', 'jet_b_tag']
 lep_branches : ['lep_pt', 'lep_eta', 'lep_phi', 'lep_mass', 'lep_charge', 'lep_isMuon']
 met_branches : ['met', 'met_phi']
 target_branch : jet_parton_match_detail
 reco_branches : ['weight']
______________loader_args_______________
 pin_memory : True
 shuffle : True
 n_epoch : 10


# Load Dataset
The file loaded in this example contains analyzed events from the $t\bar{t}$ dileptonic channel ($t\bar{t}\rightarrow bWbW$)

In [3]:
import uproot

example_rootfile = "example_dilepton.root"
# The file is opened only to list the objects it contains; the `with` block
# closes the handle again instead of leaving it open for the rest of the
# session (the dataset below is built from the path, not from this handle).
with uproot.open(example_rootfile) as f:
    print(f.keys())

['delphes;2', 'delphes;1', 'unmatched;2', 'unmatched;1', 'genWeight;1', 'cutflow;1']


This file contains two trees, 'delphes' and 'unmatched': the 'delphes' tree contains the jet-parton matched events, whereas the 'unmatched' tree contains the unmatched events

In [4]:
from saja import TTbarDileptonDataset

# Name of the tree inside the ROOT file that the dataset is built from.
tree_path = 'delphes'

# The entries of the `branches` config section are forwarded as keyword
# arguments, so their keys must match the dataset's parameter names.
dataset = TTbarDileptonDataset(
    example_rootfile,
    tree_path,
    **config['branches'],
)

[TTbarDileptonDataset] 781 / 781 (100.00 %): : 1it [00:00, 17.85it/s]


Brief view of dataset in the following cell

In [5]:
# Print every field of the first five events, each under a centred header.
for evt_num, evt in enumerate(dataset[:5]):
    header = 'event' + str(evt_num)
    print(f"{header:_^50}")
    for key in evt.keys():
        value = evt[key]
        print(f"{key}: {value}")

______________________event0______________________
jet: tensor([[86.2265, -0.2301, -1.6474,  9.3401,  1.0000],
        [33.6623,  0.1983,  1.0714,  7.3076,  0.0000]])
lepton: tensor([[ 5.4675e+01,  4.6465e-01, -2.5411e+00,  0.0000e+00, -1.0000e+00,
          1.0000e+00],
        [ 5.0344e+01, -7.5076e-01, -3.9229e-01,  9.5367e-07,  1.0000e+00,
          1.0000e+00]])
met: tensor([77.8212,  1.4392])
target: tensor([1, 1])
reco: tensor([1.])
______________________event1______________________
jet: tensor([[154.3780,  -0.4807,  -2.0845,  21.8055,   1.0000],
        [ 31.3351,   0.3568,   0.3271,   6.0551,   0.0000]])
lepton: tensor([[ 1.0395e+02,  9.9853e-03,  1.5579e+00, -1.3487e-06, -1.0000e+00,
          0.0000e+00],
        [ 4.5997e+01, -7.6262e-01,  2.6669e+00,  0.0000e+00,  1.0000e+00,
          1.0000e+00]])
met: tensor([ 7.3291e+01, -4.0187e-02])
target: tensor([1, 1])
reco: tensor([1.])
______________________event2______________________
jet: tensor([[57.9330,  0.7417, -1.3356,  6

## train and validation dataset
In this tutorial I use the same dataset for both training and validation,
but you have to use a separate (split) dataset for validation in real training

In [6]:
# Both names refer to the very same dataset object (no copy is made).
train_dataset = valid_dataset = dataset

To train the model we usually scale the variables into the same range

E.G. jet pt range [0, 700] --> [0, 1], jet eta range [-2.4, 2.4] --> [0, 1] ...

We also have to save the scaler's fitted values, because the scaling parameters must be fitted on the $\textit{"train"}$ dataset only and then kept fixed when they are applied to any other dataset (fixed parameters)

In [7]:
import os
import torch
from saja import MinMaxScaler

# Directory that receives the artefacts of this tutorial run.
save_path = "model_output/tutorial_model"
os.makedirs(save_path, exist_ok=True)

# One scaler for the jet, lepton and MET branch groups named in the config.
scaler = MinMaxScaler(branches['jet_branches'],
                      branches['lep_branches'],
                      branches['met_branches'])

# Fit on the training dataset only, then store the fitted scaler so the same
# parameters can be reloaded later.
scaler.fit(train_dataset)
torch.save(scaler, f'{save_path}/scaler.pt')

After fitting the scaler on the training dataset, we scale the datasets

In [8]:
# To reuse a previously fitted scaler instead of the one fitted above:
# scaler = torch.load(f'{save_path}/scaler.pt')  # must be the scaler fitted on the training dataset
scaler.transform(train_dataset)
# The return value of `transform` is discarded, so it presumably rescales the
# dataset in place (TODO confirm against saja.MinMaxScaler). When the
# validation dataset is the very same object as the training dataset, as in
# this tutorial, a second call would scale the already-scaled values again,
# so it is skipped.
# NOTE(review): for the same reason, re-running this cell would scale twice.
if valid_dataset is not train_dataset:
    scaler.transform(valid_dataset)

A DataLoader is used during training for tasks such as splitting the dataset into batches

In [9]:
from torch.utils.data import DataLoader

# Training loader: batch size from the hyper parameters; the remaining
# DataLoader options are taken from the `loader_args` config section.
train_loader = DataLoader(
    train_dataset,
    batch_size=hyper_params['batch_size'],
    collate_fn=train_dataset.collate,
    **config['loader_args'],
)

# Validation loader: fixed batch size of 512, otherwise DataLoader defaults.
valid_loader = DataLoader(
    valid_dataset,
    batch_size=512,
    collate_fn=valid_dataset.collate,
)

# Model and optimizer

Before defining the model and optimizer, note that everything used in the training (batches, model, ...) must be on the same device

In [10]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device('cpu')

Now we define the model with hyper parameters

In [11]:
from saja import TTbarDileptonSAJA

# Input dimensions follow the number of branches configured per object type;
# the architecture sizes come from the `hyper_params` config section.
# NOTE(review): `dim_output` is hard-coded to 2 here, while the configuration
# printed earlier lists `dim_output : 3` under hyper_params - confirm which
# value is intended.
model = TTbarDileptonSAJA(
    dim_jet=len(branches['jet_branches']),
    dim_lepton=len(branches['lep_branches']),
    dim_met=len(branches['met_branches']),
    dim_output=2,  # 0: other, 1: b-parton matched jet
    dim_ffnn=hyper_params['dim_ffnn'],
    num_blocks=hyper_params['num_blocks'],
    num_heads=hyper_params['num_heads'],
    depth=hyper_params['depth'],
)
# Send the model to the device selected above (GPU or CPU).
model = model.to(device)

In [12]:
# Adam over all model parameters, with the learning rate from the config.
opt = torch.optim.Adam(
    params=model.parameters(),
    lr=hyper_params['learning_rate'],
)

# Training

In [13]:
from saja import object_wise_cross_entropy

In [14]:
def train(model, train_loader, opt, device):
    """Run one optimisation pass over ``train_loader``.

    Gradient tracking is switched on globally and the model is put in
    training mode. For every batch the gradients are reset, the batch is
    moved to ``device``, the loss is back-propagated and the optimizer takes
    one step.

    Args:
        model: Callable module applied to each batch.
        train_loader: Iterable yielding batches that provide ``to``,
            ``target``, ``jet_data_mask`` and ``jet_lengths``.
        opt: Optimizer providing ``zero_grad`` and ``step``.
        device: Device every batch is moved to.

    Returns:
        The accumulated ``loss.item() * len(batch.target)`` over all batches
        (0 when the loader yields nothing).
    """
    torch.set_grad_enabled(True)
    model.train()

    running_total = 0
    for minibatch in train_loader:
        opt.zero_grad()
        minibatch = minibatch.to(device)
        predictions = model(minibatch)

        # The custom loss is given the logical negation of the jet data mask.
        inverted_mask = torch.logical_not(minibatch.jet_data_mask)
        batch_loss = object_wise_cross_entropy(predictions,
                                               minibatch.target,
                                               inverted_mask,
                                               minibatch.jet_lengths)

        batch_loss.backward()
        opt.step()

        # Weight the scalar loss by the length of the target tensor
        # (presumably the batch size - TODO confirm).
        running_total += batch_loss.item() * len(minibatch.target)
    return running_total

In [15]:
def validation(model, valid_loader, device):
    """Compute the summed loss of ``model`` over ``valid_loader``.

    The model is switched to evaluation mode (and left in it). Gradient
    tracking is disabled only while the batches are evaluated, by means of
    ``torch.no_grad()``, so the caller's global autograd setting is the same
    after the call as before it.

    Args:
        model: Callable module applied to each batch.
        valid_loader: Iterable yielding batches that provide ``to``,
            ``target``, ``jet_data_mask`` and ``jet_lengths``.
        device: Device every batch is moved to.

    Returns:
        The sum over all batches of the unreduced (``reduction='none'``) loss
        values; 0 when the loader yields nothing.
    """
    model.eval()
    valid_loss = 0
    # Scoped instead of torch.set_grad_enabled(False): the global switch would
    # stay off after returning and silently break any later backward pass
    # that does not re-enable it first.
    with torch.no_grad():
        for batch in valid_loader:
            batch = batch.to(device)
            logits = model(batch)
            # The custom loss is given the logical negation of the jet data mask.
            loss = object_wise_cross_entropy(logits,
                                             batch.target,
                                             torch.logical_not(
                                                 batch.jet_data_mask
                                                 ),
                                             batch.jet_lengths,
                                             reduction='none').sum()
            valid_loss += loss.item()
    return valid_loss

Now the train and validation functions are defined, and used in the following cell

In [16]:
# Number of passes over the data, read from the top-level `n_epoch` option.
n_epoch = config['n_epoch']
for epoch in range(n_epoch):
    # One optimisation pass followed by one evaluation pass per epoch.
    train_loss = train(model, train_loader, opt, device)
    valid_loss = validation(model, valid_loader, device)
    print(f"Epoch {epoch} Done",
          f"  {train_loss = :.4f}\t{valid_loss = :.4f}",
          sep="\n")

Epoch 0 Done
  train_loss = 682.0090	valid_loss = 644.4925
Epoch 1 Done
  train_loss = 590.8404	valid_loss = 537.7847
Epoch 2 Done
  train_loss = 527.0134	valid_loss = 482.5629
Epoch 3 Done
  train_loss = 487.2957	valid_loss = 446.5849
Epoch 4 Done
  train_loss = 457.8244	valid_loss = 428.2918
Epoch 5 Done
  train_loss = 437.1365	valid_loss = 420.3568
Epoch 6 Done
  train_loss = 430.8327	valid_loss = 416.9252
Epoch 7 Done
  train_loss = 425.9379	valid_loss = 415.6815
Epoch 8 Done
  train_loss = 422.8520	valid_loss = 415.6044
Epoch 9 Done
  train_loss = 421.4663	valid_loss = 415.0670
