In [1]:
%load_ext autoreload
%autoreload 2

import numpy as np
import os
import sys
import glob
import uproot as ur
import matplotlib.pyplot as plt
import time
import seaborn as sns
import tensorflow as tf
from graph_nets import utils_np
from graph_nets import utils_tf
from graph_nets.graphs import GraphsTuple
import sonnet as snt
import argparse
import yaml
import logging
import tensorflow as tf
from tqdm import tqdm

from gn4pions.modules.data import GraphDataGenerator
from gn4pions.modules.models import MultiOutWeightedRegressModel
from gn4pions.modules.utils import convert_to_tuple

sns.set_context('poster')

In [2]:
# Loading model config
config_file = 'gn4pions/configs/ming_tracks_full.yaml' # for a quick run of the notebook
# config_file = 'gn4pions/configs/baseline.yaml' # for actual training

# Read the YAML through a context manager so the file handle is closed
# deterministically instead of being left open until garbage collection.
# NOTE(review): FullLoader can construct arbitrary Python objects; only point
# this at trusted configuration files.
with open(config_file) as config_stream:
    config = yaml.load(config_stream, Loader=yaml.FullLoader)

# Data config
data_config = config['data']

data_dir = data_config['data_dir']
num_train_files = data_config['num_train_files']
num_val_files = data_config['num_val_files']
batch_size = data_config['batch_size']
shuffle = data_config['shuffle']
num_procs = data_config['num_procs']
preprocess = data_config['preprocess']
output_dir = data_config['output_dir']
already_preprocessed = data_config['already_preprocessed']  # Set to false when running training for first time

# Model config
model_config = config['model']

concat_input = model_config['concat_input']


# Training config
train_config = config['training']

epochs = train_config['epochs']
learning_rate = train_config['learning_rate']
alpha = train_config['alpha']
# Restrict TensorFlow to the configured GPU.
# NOTE(review): this is set after `import tensorflow`; presumably it still
# takes effect because no GPU has been initialised yet -- confirm.
os.environ['CUDA_VISIBLE_DEVICES'] = str(train_config['gpu'])
log_freq = train_config['log_freq']
# Results go to <save_dir><config file stem>_<YYYYMMDD>
save_dir = train_config['save_dir'] + config_file.replace('.yaml','').split('/')[-1] + '_' + time.strftime("%Y%m%d")

os.makedirs(save_dir, exist_ok=True)
# Store a copy of the configuration next to the results; the context manager
# guarantees the copy is flushed and closed before training starts.
with open(save_dir + '/config.yaml', 'w') as config_copy:
    yaml.dump(config, config_copy)


In [None]:
# Locate the input files and build the training / validation data generators

# Sorted lists of the ROOT files matching the two sample patterns
pi0_files = np.sort(glob.glob(data_dir + '*pi0*/*root'))
pion_files = np.sort(glob.glob(data_dir + '*pion*/*root'))

# Consecutive slices of the sorted lists: training files first, validation files next
train_start = 0
train_end = train_start + num_train_files
val_end = train_end + num_val_files

pi0_train_files, pi0_val_files = pi0_files[train_start:train_end], pi0_files[train_end:val_end]
pion_train_files, pion_val_files = pion_files[train_start:train_end], pion_files[train_end:val_end]

# The output directories are only passed on when data still has to be preprocessed
train_output_dir = None
val_output_dir = None

if preprocess:
    if already_preprocessed:
        # Reuse the pickled files: they are handed over through the pi0 file list,
        # the pion list is unused and no output directory is given.
        train_files = np.sort(glob.glob(output_dir + 'train/' + '*.p'))[:num_train_files]
        val_files = np.sort(glob.glob(output_dir + 'val/' + '*.p'))[:num_val_files]

        pi0_train_files, pion_train_files = train_files, None
        pi0_val_files, pion_val_files = val_files, None
    else:
        train_output_dir = output_dir + 'train/'
        val_output_dir = output_dir + 'val/'

# Settings that are identical for both generators
common_generator_args = dict(cellGeo_file=data_dir + 'CellGeo.neighbours.root',
                             batch_size=batch_size,
                             shuffle=shuffle,
                             num_procs=num_procs,
                             preprocess=preprocess)

# Training data generator
# (per the original note: preprocesses the data if it does not find pickled files)
data_gen_train = GraphDataGenerator(pi0_file_list=pi0_train_files,
                                    pion_file_list=pion_train_files,
                                    output_dir=train_output_dir,
                                    **common_generator_args)

# Validation data generator
# (per the original note: preprocesses the data if it does not find pickled files)
data_gen_val = GraphDataGenerator(pi0_file_list=pi0_val_files,
                                  pion_file_list=pion_val_files,
                                  output_dir=val_output_dir,
                                  **common_generator_args)



Preprocessing and saving data to /clusterfs/ml4hep/mfong/ML4Pions/graph_data_tracks/train/
Processing file number 0
Processing file number 1
Processing file number 2
Processing file number 3
Processing file number 4
Processing file number 5
Processing file number 6
Processing file number 7
Processing file number 8
Processing file number 9
Processing file number 10
Processing file number 11
Processing file number 12
Processing file number 13
Processing file number 14
Processing file number 15
Processing file number 16
Processing file number 17
Processing file number 18Processing file number 19
Processing file number 20

Processing file number 21
Processing file number 22
Processing file number 23
Processing file number 24
Processing file number 25
Processing file number 26
Processing file number 27
Processing file number 28Processing file number 29

Processing file number 30
Processing file number 31


In [None]:
def get_batch(data_iter):
    """Yield ``(graphs, targets)`` batches ready to be fed to the model.

    Args:
        data_iter: iterable producing ``(graphs, targets)`` pairs.

    Yields:
        A pair whose first element is ``convert_to_tuple(graphs)`` and whose
        second element is ``targets`` converted to a TensorFlow tensor.
    """
    for raw_graphs, raw_targets in data_iter:
        yield convert_to_tuple(raw_graphs), tf.convert_to_tensor(raw_targets)
        
# Loss objects shared by the training and validation steps
mae_loss = tf.keras.losses.MeanAbsoluteError()
bce_loss = tf.keras.losses.BinaryCrossentropy(from_logits=True)


def loss_fn(targets, regress_preds, class_preds):
    """Compute the regression, classification and combined losses.

    Args:
        targets: 2-D tensor; column 0 is compared with ``regress_preds`` and
            the remaining column(s) with ``class_preds``.
        regress_preds: regression predictions, scored with mean absolute error.
        class_preds: classification logits (the binary cross-entropy is
            configured with ``from_logits=True``).

    Returns:
        Tuple ``(regress_loss, class_loss, combined_loss)`` where
        ``combined_loss = alpha*regress_loss + (1 - alpha)*class_loss`` and
        ``alpha`` is the module-level weight read from the training config.
    """
    regress_target, class_target = targets[:, :1], targets[:, 1:]
    regress_loss = mae_loss(regress_target, regress_preds)
    class_loss = bce_loss(class_target, class_preds)
    combined_loss = alpha*regress_loss + (1 - alpha)*class_loss
    return regress_loss, class_loss, combined_loss

In [None]:
# Pull a single batch so its structure can serve as the input signature of the
# tf.function-decorated steps
samp_graph, samp_target = next(get_batch(data_gen_train.generator()))
# Only one batch was needed from this generator run.
# NOTE(review): kill_procs presumably stops the generator's worker processes -- confirm.
data_gen_train.kill_procs()
# All three "dynamic" flags are enabled, spelled out by keyword for readability
graph_spec = utils_tf.specs_from_graphs_tuple(
    samp_graph,
    dynamic_num_graphs=True,
    dynamic_num_nodes=True,
    dynamic_num_edges=True,
)

# Training step
@tf.function(input_signature=[graph_spec, tf.TensorSpec(shape=[None,2], dtype=tf.float32)])
def train_step(graphs, targets):
    """Run one optimisation step on a batch.

    The forward pass and the loss are recorded on a gradient tape; the
    gradients of the combined loss with respect to the model's trainable
    variables are then applied through the module-level ``optimizer``.

    Args:
        graphs: batch of input graphs matching ``graph_spec``.
        targets: float32 tensor of shape ``[None, 2]``.

    Returns:
        Tuple ``(regress_loss, class_loss, loss)`` as returned by ``loss_fn``.
    """
    with tf.GradientTape() as tape:
        regress_graph, class_graph = model(graphs)
        # The predictions are read from the `globals` field of each output graph
        step_losses = loss_fn(targets, regress_graph.globals, class_graph.globals)

    regress_loss, class_loss, loss = step_losses
    trainable = model.trainable_variables
    optimizer.apply_gradients(zip(tape.gradient(loss, trainable), trainable))
    return regress_loss, class_loss, loss

# Validation step
@tf.function(input_signature=[graph_spec, tf.TensorSpec(shape=[None,2], dtype=tf.float32)])
def val_step(graphs, targets):
    """Evaluate the model on a batch without updating any weights.

    Args:
        graphs: batch of input graphs matching ``graph_spec``.
        targets: float32 tensor of shape ``[None, 2]``.

    Returns:
        Tuple ``(regress_loss, class_loss, loss, regress_preds, class_preds)``:
        the three losses from ``loss_fn`` followed by the raw model outputs
        taken from the `globals` field of each output graph.
    """
    regress_graph, class_graph = model(graphs)
    regress_preds, class_preds = regress_graph.globals, class_graph.globals
    losses = loss_fn(targets, regress_preds, class_preds)
    return (*losses, regress_preds, class_preds)

In [22]:
# Debug display of the sample batch that was used to derive graph_spec.
# NOTE(review): this prints the full node/edge arrays; consider clearing the output before sharing.
samp_graph

GraphsTuple(nodes=<tf.Tensor: shape=(65955, 11), dtype=float32, numpy=
array([[-0.2833688 ,  0.21428572, -2.1990156 , ...,  0.        ,
         0.        ,  0.        ],
       [-0.68640566,  0.21428572, -2.1989949 , ...,  0.        ,
         0.        ,  0.        ],
       [-0.883118  ,  0.21428572, -2.1990368 , ...,  0.        ,
         0.        ,  0.        ],
       ...,
       [-1.2855362 ,  0.        ,  0.74026513, ...,  0.        ,
         0.        ,  0.        ],
       [-2.298425  ,  0.03571429,  0.7353846 , ...,  0.        ,
         0.        ,  0.        ],
       [-1.9210465 ,  0.03571429,  0.73531723, ...,  0.        ,
         0.        ,  0.        ]], dtype=float32)>, edges=<tf.Tensor: shape=(205464, 10), dtype=float32, numpy=
array([[1., 0., 0., ..., 0., 0., 0.],
       [0., 1., 0., ..., 0., 0., 0.],
       [0., 0., 1., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.]]

In [25]:
# Model with two outputs (unpacked as regression and classification in the step functions)
model = MultiOutWeightedRegressModel(global_output_size=1, num_outputs=2, model_config=model_config)

# Optimizer
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)

# Loss histories; the training loop appends one list of batch losses per epoch
training_loss_epoch, training_loss_regress_epoch, training_loss_class_epoch = [], [], []
val_loss_epoch, val_loss_regress_epoch, val_loss_class_epoch = [], [], []

# Model checkpointing: resume from the newest checkpoint in save_dir when one
# exists, otherwise write an initial checkpoint.
# NOTE(review): only the model is tracked; the optimizer state is not part of the checkpoint.
checkpoint = tf.train.Checkpoint(module=model)
checkpoint_prefix = os.path.join(save_dir, 'latest_model')
latest = tf.train.latest_checkpoint(save_dir)
if latest is None:
    checkpoint.save(checkpoint_prefix)
else:
    checkpoint.restore(latest)

In [26]:
# for i, (graph_data_tr, targets_tr) in enumerate(get_batch(data_gen_train.generator())):
#     print("Targets: ",targets_tr[:,0])

In [27]:
# Run training

def _log_progress(prefix, iteration, total_losses, regress_losses, class_losses, elapsed):
    """Print one progress line with the running means of the batch losses.

    Args:
        prefix: label put in front of the loss names ('Tr' or 'Val').
        iteration: index of the current batch.
        total_losses, regress_losses, class_losses: batch losses collected so
            far in this epoch; their means are printed.
        elapsed: seconds since the previous progress line.
    """
    print(f'Iter: {iteration:04d}, ', end='')
    print(f'{prefix}_loss_mean: {np.mean(total_losses):.4f}, ', end='')
    print(f'{prefix}_loss_rg_mean: {np.mean(regress_losses):.4f}, ', end='')
    print(f'{prefix}_loss_cl_mean: {np.mean(class_losses):.4f}, ', end='')
    print(f'Took {elapsed:.4f}secs')


# Learning-rate schedule: the rate is divided by `lr_decay_factor` every
# `lr_decay_every` epochs.
# NOTE(review): with a decay every 2 epochs the rate shrinks by 1e3 every six
# epochs (the recorded run reports 1e-12 after epoch 17); the value 20 was left
# as a commented-out alternative in the original condition -- confirm which one
# is intended for real training.
lr_decay_every = 2   # alternative noted in the original code: 20
lr_decay_factor = 10

# Best mean validation loss seen so far (large initial value so the first epoch always checkpoints)
curr_loss = 1e5

for e in range(epochs):

    print(f'\n\nStarting epoch: {e}')
    epoch_start = time.time()

    # Batchwise losses
    training_loss = []
    training_loss_regress = []
    training_loss_class = []
    val_loss = []
    val_loss_regress = []
    val_loss_class = []

    # Train
    print('Training...')
    start = time.time()
    for i, (graph_data_tr, targets_tr) in enumerate(get_batch(data_gen_train.generator())):
        losses_tr_rg, losses_tr_cl, losses_tr = train_step(graph_data_tr, targets_tr)

        training_loss.append(losses_tr.numpy())
        training_loss_regress.append(losses_tr_rg.numpy())
        training_loss_class.append(losses_tr_cl.numpy())

        # Log at i = 1, 1 + log_freq, 1 + 2*log_freq, ...
        if not (i-1)%log_freq:
            end = time.time()
            _log_progress('Tr', i, training_loss, training_loss_regress, training_loss_class, end-start)
            start = time.time()

    training_loss_epoch.append(training_loss)
    training_loss_regress_epoch.append(training_loss_regress)
    training_loss_class_epoch.append(training_loss_class)
    training_end = time.time()

    # Validate
    print('\nValidation...')
    all_targets = []
    all_outputs = []
    # NOTE(review): all_energies is never filled in this loop, so the 'energies'
    # entry of the predictions file is always empty.
    all_energies = []
    start = time.time()
    for i, (graph_data_val, targets_val) in enumerate(get_batch(data_gen_val.generator())):
        losses_val_rg, losses_val_cl, losses_val, regress_vals, class_vals = val_step(graph_data_val, targets_val)

        targets_val = targets_val.numpy()
        regress_vals = regress_vals.numpy()
        class_vals = class_vals.numpy()

        # Undo the base-10 log on the regression target and prediction, and turn
        # the classification logits into probabilities. The sigmoid result is
        # converted back to NumPy so that only NumPy arrays are stacked below.
        targets_val[:,0] = 10**targets_val[:,0]
        regress_vals = 10**regress_vals
        class_vals = tf.math.sigmoid(class_vals).numpy()

        output_vals = np.hstack([regress_vals, class_vals])

        val_loss.append(losses_val.numpy())
        val_loss_regress.append(losses_val_rg.numpy())
        val_loss_class.append(losses_val_cl.numpy())

        all_targets.append(targets_val)
        all_outputs.append(output_vals)

        if not (i-1)%log_freq:
            end = time.time()
            _log_progress('Val', i, val_loss, val_loss_regress, val_loss_class, end-start)
            start = time.time()

    epoch_end = time.time()

    all_targets = np.concatenate(all_targets)
    all_outputs = np.concatenate(all_outputs)

    val_loss_epoch.append(val_loss)
    val_loss_regress_epoch.append(val_loss_regress)
    val_loss_class_epoch.append(val_loss_class)


    # Book keeping: wall-clock durations split into whole minutes and seconds
    val_mins, val_secs = (int(part) for part in divmod(epoch_end - training_end, 60))
    training_mins, training_secs = (int(part) for part in divmod(training_end - epoch_start, 60))
    print(f'\nEpoch {e} ended')
    print(f'Training: {training_mins:2d}:{training_secs:02d}')
    print(f'Validation: {val_mins:2d}:{val_secs:02d}')


    # Save losses (the file is rewritten after every epoch with the full history)
    np.savez(save_dir+'/losses',
            training=training_loss_epoch, validation=val_loss_epoch,
            training_regress=training_loss_regress_epoch, validation_regress=val_loss_regress_epoch,
            training_class=training_loss_class_epoch, validation_class=val_loss_class_epoch,
            )


    # Checkpoint if validation loss improved
    if np.mean(val_loss)<curr_loss:
        print(f'Loss decreased from {curr_loss:.4f} to {np.mean(val_loss):.4f}')
        print(f'Checkpointing and saving predictions to:\n{save_dir}')
        curr_loss = np.mean(val_loss)
        np.savez(save_dir+'/predictions',
                targets=all_targets,
                outputs=all_outputs,
                energies=all_energies)
        checkpoint.save(checkpoint_prefix)
    else:
        print(f'Loss didnt decrease from {curr_loss:.4f}')


    # Decrease learning rate every few epochs
    if not (e+1)%lr_decay_every:
        optimizer.learning_rate = optimizer.learning_rate/lr_decay_factor
        print(f'Learning rate decreased to: {optimizer.learning_rate.value():.3e}')



Starting epoch: 0
Training...


2022-02-25 01:51:16.242679: W tensorflow/python/util/util.cc:348] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them.
2022-02-25 01:51:17.874576: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:185] None of the MLIR Optimization Passes are enabled (registered 2)


Iter: 0001, Tr_loss_mean: 1.1912, Tr_loss_rg_mean: 1.3625, Tr_loss_cl_mean: 0.6774, Took 12.8812secs
Iter: 0101, Tr_loss_mean: 0.3749, Tr_loss_rg_mean: 0.2921, Tr_loss_cl_mean: 0.6232, Took 18.3654secs
Iter: 0201, Tr_loss_mean: 0.3174, Tr_loss_rg_mean: 0.2196, Tr_loss_cl_mean: 0.6107, Took 18.9418secs
Iter: 0301, Tr_loss_mean: 0.2811, Tr_loss_rg_mean: 0.1868, Tr_loss_cl_mean: 0.5639, Took 19.0842secs
Iter: 0401, Tr_loss_mean: 0.2584, Tr_loss_rg_mean: 0.1696, Tr_loss_cl_mean: 0.5247, Took 19.2515secs
Iter: 0501, Tr_loss_mean: 0.2436, Tr_loss_rg_mean: 0.1601, Tr_loss_cl_mean: 0.4940, Took 18.8412secs

Validation...
Iter: 0001, Val_loss_mean: 0.1558, Val_loss_rg_mean: 0.0948, Val_loss_cl_mean: 0.3385, Took 4.2126secs
Iter: 0101, Val_loss_mean: 0.1601, Val_loss_rg_mean: 0.0923, Val_loss_cl_mean: 0.3633, Took 11.0951secs
Iter: 0201, Val_loss_mean: 0.1608, Val_loss_rg_mean: 0.0922, Val_loss_cl_mean: 0.3666, Took 11.2257secs
Iter: 0301, Val_loss_mean: 0.1606, Val_loss_rg_mean: 0.0921, Val_los

Iter: 0501, Val_loss_mean: 0.1241, Val_loss_rg_mean: 0.0733, Val_loss_cl_mean: 0.2767, Took 10.7907secs

Epoch 5 ended
Training:  1:41
Validation:  1:03
Loss decreased from 0.1244 to 0.1242
Checkpointing and saving predictions to:
results/test_ming_20220224
Learning rate decreased to: 1.000e-06


Starting epoch: 6
Training...
Iter: 0001, Tr_loss_mean: 0.1272, Tr_loss_rg_mean: 0.0742, Tr_loss_cl_mean: 0.2860, Took 3.3853secs
Iter: 0101, Tr_loss_mean: 0.1240, Tr_loss_rg_mean: 0.0734, Tr_loss_cl_mean: 0.2757, Took 18.3857secs
Iter: 0201, Tr_loss_mean: 0.1236, Tr_loss_rg_mean: 0.0727, Tr_loss_cl_mean: 0.2762, Took 18.7961secs
Iter: 0301, Tr_loss_mean: 0.1237, Tr_loss_rg_mean: 0.0727, Tr_loss_cl_mean: 0.2766, Took 18.6806secs
Iter: 0401, Tr_loss_mean: 0.1236, Tr_loss_rg_mean: 0.0726, Tr_loss_cl_mean: 0.2763, Took 19.2454secs
Iter: 0501, Tr_loss_mean: 0.1235, Tr_loss_rg_mean: 0.0727, Tr_loss_cl_mean: 0.2757, Took 18.4393secs

Validation...
Iter: 0001, Val_loss_mean: 0.1219, Val_loss_rg_mean:

Iter: 0101, Val_loss_mean: 0.1235, Val_loss_rg_mean: 0.0730, Val_loss_cl_mean: 0.2749, Took 10.7699secs
Iter: 0201, Val_loss_mean: 0.1238, Val_loss_rg_mean: 0.0730, Val_loss_cl_mean: 0.2764, Took 10.6784secs
Iter: 0301, Val_loss_mean: 0.1236, Val_loss_rg_mean: 0.0728, Val_loss_cl_mean: 0.2762, Took 10.7143secs
Iter: 0401, Val_loss_mean: 0.1235, Val_loss_rg_mean: 0.0727, Val_loss_cl_mean: 0.2760, Took 10.7344secs
Iter: 0501, Val_loss_mean: 0.1237, Val_loss_rg_mean: 0.0729, Val_loss_cl_mean: 0.2764, Took 9.6659secs

Epoch 11 ended
Training:  1:43
Validation:  1:00
Loss decreased from 0.1238 to 0.1238
Checkpointing and saving predictions to:
results/test_ming_20220224
Learning rate decreased to: 1.000e-09


Starting epoch: 12
Training...
Iter: 0001, Tr_loss_mean: 0.1270, Tr_loss_rg_mean: 0.0705, Tr_loss_cl_mean: 0.2968, Took 3.3840secs
Iter: 0101, Tr_loss_mean: 0.1241, Tr_loss_rg_mean: 0.0734, Tr_loss_cl_mean: 0.2761, Took 18.2638secs
Iter: 0201, Tr_loss_mean: 0.1235, Tr_loss_rg_mean: 0.0


Validation...
Iter: 0001, Val_loss_mean: 0.1234, Val_loss_rg_mean: 0.0785, Val_loss_cl_mean: 0.2584, Took 2.9635secs
Iter: 0101, Val_loss_mean: 0.1236, Val_loss_rg_mean: 0.0730, Val_loss_cl_mean: 0.2754, Took 10.8453secs
Iter: 0201, Val_loss_mean: 0.1239, Val_loss_rg_mean: 0.0729, Val_loss_cl_mean: 0.2769, Took 11.0873secs
Iter: 0301, Val_loss_mean: 0.1236, Val_loss_rg_mean: 0.0728, Val_loss_cl_mean: 0.2761, Took 10.9806secs
Iter: 0401, Val_loss_mean: 0.1236, Val_loss_rg_mean: 0.0727, Val_loss_cl_mean: 0.2760, Took 10.9541secs
Iter: 0501, Val_loss_mean: 0.1238, Val_loss_rg_mean: 0.0728, Val_loss_cl_mean: 0.2765, Took 10.3336secs

Epoch 17 ended
Training:  1:46
Validation:  1:01
Loss didnt decrease from 0.1238
Learning rate decreased to: 1.000e-12


Starting epoch: 18
Training...
Iter: 0001, Tr_loss_mean: 0.1268, Tr_loss_rg_mean: 0.0739, Tr_loss_cl_mean: 0.2857, Took 3.1704secs
Iter: 0101, Tr_loss_mean: 0.1239, Tr_loss_rg_mean: 0.0733, Tr_loss_cl_mean: 0.2757, Took 18.4677secs
Iter: 02

In [28]:
# Debug display of the last training batch left over from the training loop.
# NOTE(review): this prints the full node/edge arrays; consider clearing the output before sharing.
graph_data_tr

GraphsTuple(nodes=<tf.Tensor: shape=(8905, 11), dtype=float32, numpy=
array([[ 2.0124032 ,  0.25      ,  2.9625819 , ...,  0.        ,
         0.        ,  0.        ],
       [ 1.2589308 ,  0.25      ,  2.9625587 , ...,  0.        ,
         0.        ,  0.        ],
       [ 0.87126935,  0.25      ,  2.9625828 , ...,  0.        ,
         0.        ,  0.        ],
       ...,
       [-1.6004568 ,  0.03571429,  0.55865055, ...,  0.        ,
         0.        ,  0.        ],
       [-1.7220358 ,  0.03571429,  0.56489515, ...,  0.        ,
         0.        ,  0.        ],
       [-1.8561385 ,  0.03571429,  0.5742621 , ...,  0.        ,
         0.        ,  0.        ]], dtype=float32)>, edges=<tf.Tensor: shape=(28045, 10), dtype=float32, numpy=
array([[1., 0., 0., ..., 0., 0., 0.],
       [0., 1., 0., ..., 0., 0., 0.],
       [0., 0., 1., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.]], 

In [29]:
# Debug display of the targets of the last training batch left over from the training loop.
targets_tr

<tf.Tensor: shape=(129, 2), dtype=float32, numpy=
array([[ 2.755323  ,  1.        ],
       [ 3.0317209 ,  0.        ],
       [ 0.8071258 ,  1.        ],
       [ 0.44958335,  1.        ],
       [ 1.6490003 ,  1.        ],
       [ 2.8600674 ,  0.        ],
       [ 1.84583   ,  0.        ],
       [ 2.3885376 ,  1.        ],
       [-0.6348547 ,  0.        ],
       [ 1.8926154 ,  1.        ],
       [ 1.1382633 ,  1.        ],
       [ 2.632147  ,  1.        ],
       [ 0.23887107,  1.        ],
       [ 1.5309668 ,  1.        ],
       [ 0.7487196 ,  1.        ],
       [ 2.7842493 ,  0.        ],
       [ 0.6117757 ,  1.        ],
       [ 2.8833387 ,  1.        ],
       [ 3.156846  ,  1.        ],
       [ 1.1219668 ,  0.        ],
       [-1.6813831 ,  1.        ],
       [ 0.9821644 ,  1.        ],
       [ 2.0277817 ,  0.        ],
       [ 0.294213  ,  1.        ],
       [-0.7414676 ,  1.        ],
       [-0.32059962,  1.        ],
       [ 0.01045107,  1.        ],
     