In [10]:
#!/usr/bin/env python
# Make the parent directory importable before the project imports below.
# Presumably that is where the `utils` package lives (the recorded output
# shows the notebook running from a `notebooks/` folder) — TODO confirm.
import sys
if "../" not in sys.path:
    sys.path.append("../")

# Standard library and numerical / deep-learning dependencies.
import os
import shutil
import glob
import argparse
import numpy as np
import torch

from torch.utils.data import DataLoader
from importlib import import_module
import ast
# Project-local helpers, resolved through the "../" entry added above.
from utils.logger import _logger
from utils.dataset import SimpleIterDataset
from utils.nn.tools import train, evaluate

# Display the module search path to check that "../" is present.
sys.path

['/usr/share/DJC',
 '/usr/local/lib',
 '/storage/user/abao/weaver2/notebooks',
 '/usr/lib/python36.zip',
 '/usr/lib/python3.6',
 '/usr/lib/python3.6/lib-dynload',
 '',
 '/storage/user/abao/.local/lib/python3.6/site-packages',
 '/usr/local/lib/python3.6/dist-packages',
 '/usr/lib/python3/dist-packages',
 '/usr/local/lib/python3.6/dist-packages/IPython/extensions',
 '/tmp/tmpopl8xlcn',
 '../',
 '../',
 '../']

### Set Hyperparameters and File Paths

In [11]:
# --- input data -------------------------------------------------------------
# Path of the YAML data configuration handed to SimpleIterDataset.
data_config = "../data/ak8_points_pf_sv_hww_ptmasswgt_h5.yaml"
# Lists of glob patterns. The data-loading cell iterates over these and calls
# glob.glob on every entry, so they must be lists: a bare string would be
# globbed one character at a time.
data_train = ["/data/shared/abao/DNNTuples/train/*.h5"]
data_test = []
data_fraction = 1
file_fraction = 1
fetch_by_files = True
fetch_step = 10
# The load range (0, train_val_split) is used for training and
# (train_val_split, 1) for validation.
train_val_split = 0.8
# demo=True restricts the run to the first 20 files with reduced fractions.
demo = False
# None, or a "start_lr, end_lr, num_iter" string to run the LR range test
# instead of the training loop.
lr_finder = None

# --- network / training -----------------------------------------------------
network_config = "../networks/particle_net_pf_sv.py"
# Iterable of (name, literal-as-string) pairs forwarded to get_model().
network_option = []
# Prefix for every checkpoint and history file written during training.
model_prefix = "../models/testh5"
num_epochs = 20
# "adam" selects torch.optim.Adam; any other value selects Ranger.
optimizer = "ranger"
# Epoch number to resume from, or None to start from scratch.
load_epoch = None
start_lr = 2e-2
# Comma-separated MultiStepLR milestones; only read when optimizer == "adam".
lr_steps = '10,20'
batch_size = 64
use_amp = False
# Comma-separated device ids; the first one becomes the primary device.
# Use an empty string to run on the CPU.
gpus = "0,2,3,4,5"
num_workers = 4

# --- prediction / export ----------------------------------------------------
predict = False
# Output file for predictions ('.root' or an awkward-array file). It has to be
# defined because the prediction cell reads it; None skips writing an output.
predict_output = None
# Target '.onnx' path, or None to skip the ONNX export.
export_onnx = None
# io_test = False

# training/testing mode
training_mode = not predict

### Set Device (GPU if Possible)

In [None]:
# device
# `gpus` starts out as a comma-separated string of device ids and is rebound
# to a list of ints here. Accept both forms so that this cell can be re-run
# without re-running the configuration cell first.
if isinstance(gpus, str):
    gpus = [int(i) for i in gpus.split(',') if i.strip()]

# Fall back to the CPU instead of failing later in `model.to(dev)` when GPUs
# were requested but CUDA is not usable on this machine.
if gpus and not torch.cuda.is_available():
    print('CUDA is not available; ignoring gpus=%s and using the CPU.' % gpus)
    gpus = None

if gpus:
    print(gpus, gpus[0])
    # The first listed id is the primary device; the remaining ids are only
    # used when the model is wrapped in DataParallel.
    dev = torch.device('cuda', gpus[0])
else:
    gpus = None
    dev = torch.device('cpu')

### Data Loaders

In [None]:
# load data
def _expand_globs(patterns):
    """Return the sorted list of paths matching `patterns`.

    `patterns` may be a single glob string or an iterable of glob strings.
    A bare string is wrapped in a list first; iterating it directly would
    run glob.glob on every single character.
    """
    if isinstance(patterns, str):
        patterns = [patterns]
    return sorted(path for pattern in patterns for path in glob.glob(pattern))


# NOTE: `data_config` is rebound below from the YAML path to the dataset's
# config object (later cells use that object), so re-running this cell
# requires re-running the configuration cell first.
if training_mode:
    filelist = _expand_globs(data_train)
    # np.random.seed(1)
    np.random.shuffle(filelist)
    if demo:
        # Quick smoke test: only a few files and a small fraction of events.
        filelist = filelist[:20]
        print(filelist)
        data_fraction = 0.1
        fetch_step = 0.002
    # Never start more workers than there are files to read.
    num_workers = min(num_workers, int(len(filelist) * file_fraction))
    # Training and validation read the same files but different load ranges:
    # (0, train_val_split) for training, (train_val_split, 1) for validation.
    train_data = SimpleIterDataset(filelist, data_config, for_training=True, load_range_and_fraction=((0, train_val_split), data_fraction),
                                   file_fraction=file_fraction, fetch_by_files=fetch_by_files, fetch_step=fetch_step)
    val_data = SimpleIterDataset(filelist, data_config, for_training=True, load_range_and_fraction=((train_val_split, 1), data_fraction),
                                 file_fraction=file_fraction, fetch_by_files=fetch_by_files, fetch_step=fetch_step)
    train_loader = DataLoader(train_data, num_workers=num_workers, batch_size=batch_size, drop_last=True, pin_memory=True)
    val_loader = DataLoader(val_data, num_workers=num_workers, batch_size=batch_size, drop_last=True, pin_memory=True)
    data_config = train_data.config
else:
    filelist = _expand_globs(data_test)
    num_workers = min(num_workers, len(filelist))
    test_data = SimpleIterDataset(filelist, data_config, for_training=False,
                                  load_range_and_fraction=((0, 1), data_fraction),
                                  fetch_by_files=True, fetch_step=1)
    test_loader = DataLoader(test_data, num_workers=num_workers, batch_size=batch_size, drop_last=False, pin_memory=True)
    data_config = test_data.config
    print('test loader done ')
    print(data_config)

### Define Model (Load Model File)

In [None]:
# model
# Turn the network file path into an importable dotted module name.
# Relative components ('.', '..') are dropped: keeping them would yield a
# name with leading dots (e.g. '...networks.particle_net_pf_sv'), which
# import_module treats as a relative import and rejects without a package.
# The module is then resolved through sys.path.
module_path = os.path.splitext(network_config)[0]
module_name = '.'.join(part for part in module_path.split('/') if part not in ('', '.', '..'))
network_module = import_module(module_name)

# Each option value is given as a Python literal in string form.
network_options = {k: ast.literal_eval(v) for k, v in network_option}
if export_onnx:
    network_options['for_inference'] = True
if use_amp:
    network_options['use_amp'] = True
model, model_info = network_module.get_model(data_config, **network_options)
print(model)

### Option to Export to ONNX

In [None]:
# export to ONNX
if export_onnx:
    assert export_onnx.endswith('.onnx')
    # The weights are loaded from `model_prefix` itself, i.e. it must be the
    # path of a saved state dict here rather than a checkpoint prefix.
    model_path = model_prefix
    print('Exporting model %s to ONNX' % model_path)
    model.load_state_dict(torch.load(model_path, map_location='cpu'))
    model = model.cpu()
    model.eval()

    # os.makedirs('') raises, so only create the directory when the export
    # path actually has a directory component.
    export_dir = os.path.dirname(export_onnx)
    if export_dir:
        os.makedirs(export_dir, exist_ok=True)
    # Dummy all-ones inputs with the shapes declared by the network module.
    inputs = tuple(torch.ones(model_info['input_shapes'][k], dtype=torch.float32) for k in model_info['input_names'])
    torch.onnx.export(model, inputs, export_onnx,
                      input_names=model_info['input_names'],
                      output_names=model_info['output_names'],
                      dynamic_axes=model_info.get('dynamic_axes', None),
                      opset_version=11)
    print('ONNX model saved to %s' % export_onnx)

    # Store the preprocessing parameters next to the exported model.
    preprocessing_json = os.path.join(export_dir, 'preprocess.json')
    data_config.export_json(preprocessing_json)
    print('Preprocessing parameters saved to %s' % preprocessing_json)
    # A bare `return` is a SyntaxError at cell level, so the export branch
    # simply ends here. An export run is complete at this point: stop after
    # this cell instead of continuing with training or prediction.

## Training

In [None]:
# note: we should always save/load the state_dict of the original model, not the one wrapped by nn.DataParallel
# so we do not convert it to nn.DataParallel now
model = model.to(dev)

# Train only in training mode and only when no ONNX export was requested
# (an export run is meant to stop once the model has been exported).
if training_mode and not export_onnx:
    # loss function: use the network module's `get_loss` when it defines one
    try:
        loss_func = network_module.get_loss(data_config, **network_options)
        print(loss_func)
    except AttributeError:
        loss_func = torch.nn.CrossEntropyLoss()
        print('Loss function not defined in %s. Will use `torch.nn.CrossEntropyLoss()` by default.' % network_config)

    # optimizer & learning rate
    # No scheduler is built when the LR finder runs instead of the training loop.
    scheduler = None
    if optimizer == 'adam':
        opt = torch.optim.Adam(model.parameters(), lr=start_lr)
        if lr_finder is None:
            # Parse into a local so the `lr_steps` setting stays a string and
            # this cell can be re-run.
            milestones = [int(x) for x in lr_steps.split(',')]
            scheduler = torch.optim.lr_scheduler.MultiStepLR(opt, milestones=milestones, gamma=0.1)
    else:
        from utils.nn.optimizer.ranger import Ranger
        opt = Ranger(model.parameters(), lr=start_lr)
        if lr_finder is None:
            # One milestone per step over the last 30% of `num_epochs`; the
            # per-milestone factor is chosen so the combined decay is 0.01.
            lr_decay_epochs = max(1, int(num_epochs * 0.3))
            lr_decay_rate = 0.01 ** (1. / lr_decay_epochs)
            scheduler = torch.optim.lr_scheduler.MultiStepLR(opt, milestones=list(range(num_epochs - lr_decay_epochs, num_epochs)), gamma=lr_decay_rate)

    # load previous training and resume if `load_epoch` is set
    if load_epoch is not None:
        print('Resume training from epoch %d' % load_epoch)
        model_state = torch.load(model_prefix + '_epoch-%d_state.pt' % load_epoch, map_location=dev)
        model.load_state_dict(model_state)
        opt_state = torch.load(model_prefix + '_epoch-%d_optimizer.pt' % load_epoch, map_location=dev)
        opt.load_state_dict(opt_state)

    # multi-gpu
    if gpus is not None and len(gpus) > 1:
        model = torch.nn.DataParallel(model, device_ids=gpus)  # model becomes `torch.nn.DataParallel` w/ model.module being the original `torch.nn.Module`
    model = model.to(dev)

    if lr_finder is not None:
        # lr finder: keep it after all other setups. It replaces the training
        # loop (a top-level `return` is not valid inside a notebook cell).
        # Locals are used so that the `lr_finder` / `start_lr` settings are
        # not overwritten with values of a different type.
        lr_start, lr_end, lr_num_iter = lr_finder.replace(' ', '').split(',')
        from utils.lr_finder import LRFinder
        finder = LRFinder(model, opt, loss_func, device=dev, input_names=train_data.config.input_names, label_names=train_data.config.label_names)
        finder.range_test(train_loader, start_lr=float(lr_start), end_lr=float(lr_end), num_iter=int(lr_num_iter))
        finder.plot(output='lr_finder.png')  # to inspect the loss-learning rate graph
    else:
        if use_amp:
            from torch.cuda.amp import GradScaler
            scaler = GradScaler()
        else:
            scaler = None

        # Per-epoch history. Entries of epochs skipped when resuming stay 0.
        best_valid_acc = 0
        acc_vals_validation = np.zeros(num_epochs)
        loss_vals_training = np.zeros(num_epochs)
        loss_std_training = np.zeros(num_epochs)
        loss_vals_validation = np.zeros(num_epochs)
        loss_std_validation = np.zeros(num_epochs)

        # training loop
        for epoch in range(num_epochs):
            # When resuming, skip every epoch up to and including `load_epoch`.
            if load_epoch is not None and epoch <= load_epoch:
                continue
            print('-' * 50)
            print('Epoch #%d training' % epoch)
            loss_mean, loss_std = train(model, loss_func, opt, scheduler, train_loader, dev, grad_scaler=scaler)
            loss_vals_training[epoch] = loss_mean
            loss_std_training[epoch] = loss_std

            if model_prefix:
                dirname = os.path.dirname(model_prefix)
                if dirname:
                    os.makedirs(dirname, exist_ok=True)
                # Always store the weights of the unwrapped module.
                state_dict = model.module.state_dict() if isinstance(model, torch.nn.DataParallel) else model.state_dict()
                torch.save(state_dict, model_prefix + '_epoch-%d_state.pt' % epoch)
                torch.save(opt.state_dict(), model_prefix + '_epoch-%d_optimizer.pt' % epoch)

            print('Epoch #%d validating' % epoch)
            valid_acc, loss_mean, loss_std = evaluate(model, val_loader, dev, loss_func=loss_func)
            loss_vals_validation[epoch] = loss_mean
            loss_std_validation[epoch] = loss_std
            acc_vals_validation[epoch] = valid_acc
            if valid_acc > best_valid_acc:
                best_valid_acc = valid_acc
                if model_prefix:
                    # Keep a copy of the best state dict plus the full model object.
                    shutil.copy2(model_prefix + '_epoch-%d_state.pt' % epoch, model_prefix + '_best_acc_state.pt')
                    torch.save(model, model_prefix + '_best_acc_full.pt')
            print('Epoch #%d: Current validation acc: %.5f (best: %.5f)' % (epoch, valid_acc, best_valid_acc))

        # Save the loss / accuracy history next to the checkpoints.
        history_dir = '%s_history' % model_prefix
        os.makedirs(history_dir, exist_ok=True)

        np.save(os.path.join(history_dir, 'acc_vals_validation.npy'), acc_vals_validation)
        np.save(os.path.join(history_dir, 'loss_vals_training.npy'), loss_vals_training)
        np.save(os.path.join(history_dir, 'loss_vals_validation.npy'), loss_vals_validation)
        np.save(os.path.join(history_dir, 'loss_std_validation.npy'), loss_std_validation)
        np.save(os.path.join(history_dir, 'loss_std_training.npy'), loss_std_training)

## Prediction

In [None]:
# Prediction runs only in prediction mode and is skipped after an ONNX export
# (an export run is meant to stop once the model has been exported).
# `predict_output` is read from the configuration cell and must be defined
# there: an output path, or None / '' to skip writing the predictions.
if not training_mode and not export_onnx:
    # run prediction
    if model_prefix.endswith('.onnx'):
        print('Loading model %s for eval' % model_prefix)
        from utils.nn.tools import evaluate_onnx
        test_acc, scores, labels, observers = evaluate_onnx(model_prefix, test_loader)
    else:
        # A bare prefix resolves to the best-accuracy checkpoint saved by training.
        model_path = model_prefix if model_prefix.endswith('.pt') else model_prefix + '_best_acc_state.pt'
        print('Loading model %s for eval' % model_path)
        model.load_state_dict(torch.load(model_path, map_location=dev))
        if gpus is not None and len(gpus) > 1:
            model = torch.nn.DataParallel(model, device_ids=gpus)
        model = model.to(dev)
        test_acc, scores, labels, observers = evaluate(model, test_loader, dev, for_training=False)
    print('Test acc %.5f' % test_acc)

    if predict_output:
        # os.makedirs('') raises, so only create the directory when the
        # output path actually has a directory component.
        output_dir = os.path.dirname(predict_output)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
        if predict_output.endswith('.root'):
            from utils.data.fileio import _write_root
            output = {}
            # One boolean truth column and one score column per class.
            for idx, label_name in enumerate(data_config.label_value):
                output[label_name] = (labels[data_config.label_names[0]] == idx)
                output['score_' + label_name] = scores[:, idx]
            # Only 1d arrays are written to the ROOT file.
            for k, v in labels.items():
                if k == data_config.label_names[0]:
                    continue
                if v.ndim > 1:
                    print('Ignoring %s, not a 1d array.' % k)
                    continue
                output[k] = v
            for k, v in observers.items():
                if v.ndim > 1:
                    print('Ignoring %s, not a 1d array.' % k)
                    continue
                output[k] = v
            _write_root(predict_output, output)
        else:
            import awkward
            output = {'scores': scores}
            output.update(labels)
            output.update(observers)

            # Rename every array whose name is a prefix of another array's name.
            name_remap = {}
            arraynames = list(output)
            for i in range(len(arraynames)):
                for j in range(i + 1, len(arraynames)):
                    if arraynames[i].startswith(arraynames[j]):
                        name_remap[arraynames[j]] = '%s_%d' % (arraynames[j], len(name_remap))
                    if arraynames[j].startswith(arraynames[i]):
                        name_remap[arraynames[i]] = '%s_%d' % (arraynames[i], len(name_remap))
            print('Renamed the following variables in the output file: %s' % str(name_remap))
            output = {name_remap[k] if k in name_remap else k: v for k, v in output.items()}

            awkward.save(predict_output, output, mode='w')

        print('Written output to %s' % predict_output)
