- FIXME: Copied from stttrkx-iml

In [1]:
import sys, os, glob, yaml
import pprint, pkg_resources

In [2]:
# Environment variables come first: presumably the stt package reads TRKXINPUTDIR /
# TRKXOUTPUTDIR when it is imported in the cells below, so they must be set beforehand.
prep_config_path = 'stt/configs/prep_stt.yaml'
with open(prep_config_path) as prep_file:
    config = yaml.safe_load(prep_file)

# Expand any $VARS embedded in the configured input directory.
input_dir = os.path.expandvars(config['input_dir'])
os.environ.update({
    'TRKXINPUTDIR': input_dir,
    'TRKXOUTPUTDIR': "stt_output",
})

In [3]:
# deep learning imports
import torch
import tensorflow as tf
import pytorch_lightning as pl
from pytorch_lightning import callbacks as pl_callbacks
from pytorch_lightning import loggers as pl_loggers
import trackml.dataset

In [4]:
# local imports
from stt import config_dict             # for accessing predefined configuration files
from stt import outdir_dict             # for accessing predefined output directories
from stt.src import utils_dir           # for accessing directory structure

In [5]:
# for preprocessing
from stt import FeatureStore

# for embedding
from stt import LayerlessEmbedding
from stt import EmbeddingInferenceCallback

# for filtering
from stt import VanillaFilter
from stt import FilterInferenceCallback

In [6]:
# print(os.environ['TRKXINPUTDIR'])

In [7]:
# print(os.environ['TRKXOUTPUTDIR'])

In [8]:
# print("tf: {}, pytorch: {}".format(tf.__version__, torch.__version__))

In [9]:
# Pretty-printer (2-space indent) for the Data objects and config dicts displayed below.
pp = pprint.PrettyPrinter(indent=2)

## _Loading Data from Preprocessing_

In [10]:
# List the per-event files written by the preprocessing (feature-store) stage.
!ls stt_output/feature_store

0   13	18  22	27  31	36  40	45  5	54  59	63  68	72  77	81  86	90  95
1   14	19  23	28  32	37  41	46  50	55  6	64  69	73  78	82  87	91  96
10  15	2   24	29  33	38  42	47  51	56  60	65  7	74  79	83  88	92  97
11  16	20  25	3   34	39  43	48  52	57  61	66  70	75  8	84  89	93  98
12  17	21  26	30  35	4   44	49  53	58  62	67  71	76  80	85  9	94  99


In [11]:
# Event to inspect, and the device that loaded tensors are mapped onto.
evtid = 1
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'

In [12]:
# Load the preprocessed graph for event `evtid`; map_location places its tensors on
# `device` at load time (no separate .to(device) needed).
feature_path = os.path.join(utils_dir.feature_outdir, str(evtid))
feature_data = torch.load(feature_path, map_location=device)

# NOTE: len() of this Data object counts its stored attributes (event_file, hid,
# layerless_true_edges, layers, pid, pt, x -> 7), not the number of hits or edges.
print("Length of Data: {}".format(len(feature_data)))
pp.pprint(feature_data)

Length of Data: 7
Data(event_file="/home/adeel/current/data_sets/sttiml/data-3-7-GeV/event0000000001", hid=[255], layerless_true_edges=[2, 251], layers=[255], pid=[255], pt=[255], x=[255, 3])


In [13]:
# feature_data.x

In [14]:
# feature_data.pt

## _Loading Data from Embedding_

In [15]:
!ls stt_output/embedding_output/train/

0000  0001  0011  0012	0013  0015  0016  0017


In [16]:
# Load one event written by the embedding stage. The embedding stage only writes some
# event ids (see the listing above: 0000, 0001, 0011, ...), so a hard-coded id such as
# "0002" can be missing. Pick the first file that actually exists instead.
embedding_train_dir = os.path.join(utils_dir.embedding_outdir, "train")
embedding_files = sorted(
    name for name in os.listdir(embedding_train_dir) if not name.startswith('.')
)
assert embedding_files, "no embedding outputs found in {}".format(embedding_train_dir)
embedding_file = embedding_files[0]
embedding_data = torch.load(os.path.join(embedding_train_dir, embedding_file), map_location=device)

print("Loaded event file: {}".format(embedding_file))
print("Length of Data: {}".format(len(embedding_data)))

In [None]:
# Display the loaded embedding-stage object (rich repr as the cell's output).
embedding_data

In [None]:
# embedding_data.x

In [None]:
# embedding_data.pt

## _Filtering_

In [None]:
# Load the packaged filtering config that ships inside the stt package.
action = 'filtering'

config_file = pkg_resources.resource_filename("stt", os.path.join('configs', config_dict[action]))
with open(config_file) as f:
    # safe_load, consistent with the prep config loaded at the top of the notebook:
    # a plain YAML config needs no arbitrary Python object construction.
    f_config = yaml.safe_load(f)

In [None]:
# Packaged filtering defaults, before the STT-specific overrides.
pp.pprint(f_config)

In [None]:
# Let's change the config for the STT.
f_config.update({
    'train_split': [1, 1, 1],  # presumably counts of train/val/test events; confirm in VanillaFilter
    'in_channels': 3,          # the feature-store x above has 3 columns per hit ([255, 3])
    'regime': ['non-ci'],
})

In [None]:
# Filtering config after the STT-specific overrides.
pp.pprint(f_config)

In [None]:
# Build the filtering model from the (overridden) config.
f_model = VanillaFilter(f_config)

In [None]:
# Show the model's submodule structure.
print(f_model)

### _(+) Loggers_

- Default logging from a LightningModule (`pl.Trainer(logger=True)`)
- TensorBoardLogger (default)
- TestTubeLogger (Test Tube is a TensorBoard logger but with nicer file structure.)

In [None]:
# from pytorch_lightning.loggers import TensorBoardLogger, TestTubeLogger
# from pytorch_lightning import loggers as pl_loggers

In [None]:
# To use TensorBoard as your logger do the following.
# tb_logger = pl_loggers.TensorBoardLogger('tb_logs', name='filtering')
# Then pass it to pl.Trainer(logger=tb_logger)

In [None]:
# To use TestTube as your logger do the following.
# tt_logger = pl_loggers.TestTubeLogger('tt_logs', name='filtering')
# Then pass it to pl.Trainer(logger=tt_logger)

In [None]:
# Lightning supports the use of multiple loggers. 
# Then pass a list to pl.Trainer(logger=[tb_logger, tt_logger])

### _(+) Callbacks_

- ModelCheckpoint() Callback
- FilterInferenceCallback() Custom Callback

In [None]:
# ModelCheckpoint callback: keep the three checkpoints with the lowest validation loss.
# `dirpath` + `filename` replace the `filepath` argument deprecated since PL 1.0. This is
# the same directory and name template the old filepath spelled out.
f_checkpoint_callback = pl_callbacks.ModelCheckpoint(
    dirpath=utils_dir.filtering_outdir,
    filename='ckpt-{epoch:02d}-{val_loss:.2f}',
    monitor='val_loss',
    save_top_k=3,
    mode='min',
)

# Custom inference callback from the stt package (presumably what writes the filtering
# outputs listed further below; confirm in FilterInferenceCallback).
f_callback_list = [FilterInferenceCallback()]

In [None]:
# Test Tube logger: runs land under tt_logs/filtering (viewable with TensorBoard).
tt_logger = pl_loggers.TestTubeLogger(save_dir='tt_logs', name='filtering')

In [None]:
# Lightning Trainer.
# The ModelCheckpoint is passed through `callbacks` together with checkpoint_callback=True.
# Passing the instance via `checkpoint_callback` is deprecated since PL 1.1 and
# unsupported from 1.3. A new list is built, so re-running this cell never mutates
# f_callback_list.
f_trainer = pl.Trainer(
    logger=tt_logger,                  # logger=True would give the default TensorBoardLogger
    checkpoint_callback=True,
    callbacks=f_callback_list + [f_checkpoint_callback],

    # default_root_dir='/your/path/to/save/checkpoints',

    max_epochs=10,
    min_epochs=1,
    limit_train_batches=1.0,
    limit_val_batches=1.0,
    process_position=0,
    num_nodes=1,
    num_processes=1,
    gpus=None,                         # trains on CPU even if `device` above is 'cuda'
    progress_bar_refresh_rate=1,       # set to 0 to disable the bar (e.g. with a custom progress callback)
)

In [None]:
# Run training (with validation) using the trainer configured above; the checkpoint and
# inference callbacks are invoked by the trainer during fit.
f_trainer.fit(f_model)

### _(+) Loading Checkpoints_

After training finishes, use `f_checkpoint_callback.best_model_path` to retrieve the path to the best checkpoint file and `f_checkpoint_callback.best_model_score` to retrieve its score.

In [None]:
# manually save checkpoints
# f_trainer.save_checkpoint("example.ckpt")
# new_model = VanillaFilter.load_from_checkpoint(checkpoint_path="example.ckpt")

In [None]:
# Path of the best checkpoint (lowest val_loss) saved by f_checkpoint_callback.
f_checkpoint_callback.best_model_path

In [None]:
# load_from_checkpoint is a classmethod that builds a fresh model from the checkpoint,
# so instantiating VanillaFilter first was wasted work. Extra keyword arguments (hparams)
# are forwarded to the model constructor.
new_model = VanillaFilter.load_from_checkpoint(
    checkpoint_path=f_checkpoint_callback.best_model_path,
    hparams=f_config,
)

In [None]:
# Put the restored model into evaluation mode before running inference.
new_model.eval()

In [None]:
# model loaded from best checkpoint
# new_model(filtering_data)

### _(+) View TensorBoard Logs_

In [None]:
# first load tensorboard extension
# %load_ext tensorboard

In [None]:
# view the default TensorBoardLogger logs (written to lightning_logs when logger=True)
# %tensorboard --logdir lightning_logs

In [None]:
# view the Test Tube logs written by tt_logger
# %tensorboard --logdir tt_logs/filtering

### _(+) Efficiency_

### _(+) Loading Data from Filtering_

In [None]:
# List the files the filtering stage wrote for the training split
# (presumably produced by FilterInferenceCallback during fit; confirm).
!ls stt_output/filtering_output/train/

In [None]:
# Load one event written by the filtering stage. Only some event ids are written
# (the embedding outputs above skip e.g. "0002"), so pick the first file that exists
# instead of hard-coding an id.
filtering_train_dir = os.path.join(utils_dir.filtering_outdir, "train")
filtering_files = sorted(
    name for name in os.listdir(filtering_train_dir) if not name.startswith('.')
)
assert filtering_files, "no filtering outputs found in {}".format(filtering_train_dir)
filtering_file = filtering_files[0]
filtering_data = torch.load(os.path.join(filtering_train_dir, filtering_file), map_location=device)

print("Loaded event file: {}".format(filtering_file))
print("Length of Data: {}".format(len(filtering_data)))

In [None]:
# Display the loaded filtering-stage object.
filtering_data

In [None]:
# Per-hit `pt` values (in the feature store, pt has one entry per hit; presumably the same here).
filtering_data.pt

In [None]:
# Hit features (in the feature store, x is [num_hits, 3]; presumably unchanged by filtering).
filtering_data.x