### Set autoreloading
The `autoreload` extension re-imports modified packages automatically before each cell runs, so changes to package code take effect without restarting the kernel.

In [100]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
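
As a rough illustration of what autoreload saves you from doing by hand, the sketch below edits a throwaway module on disk and reloads it explicitly with `importlib.reload`. The module name `demo_mod` and its contents are invented for this example, and autoreload's own mechanism is more involved than this.

```python
import importlib
import sys
import tempfile
from pathlib import Path

# Skip bytecode caching so the quick edit below is always read from source.
sys.dont_write_bytecode = True

with tempfile.TemporaryDirectory() as tmp:
    module_path = Path(tmp) / "demo_mod.py"  # hypothetical module
    module_path.write_text("VALUE = 1\n")
    sys.path.insert(0, tmp)
    try:
        import demo_mod
        before = demo_mod.VALUE
        # Simulate editing the package source while the kernel is running.
        module_path.write_text("VALUE = 22\n")
        importlib.invalidate_caches()
        importlib.reload(demo_mod)
        after = demo_mod.VALUE
    finally:
        # Leave the interpreter state as we found it.
        sys.path.remove(tmp)
        sys.modules.pop("demo_mod", None)

print(before, after)
```

With `%autoreload 2` active, the explicit `importlib.reload` step is not needed for ordinary imported modules.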


### Import packages
We need the `pytorch_lightning` and `nugraph` packages to train the network.

In [101]:
import os
from pathlib import Path
import nugraph as ng
import pytorch_lightning as pl

### Set model and data to use

Binding the data module and model classes to the names `Data` and `Model` lets you switch to a different model architecture or dataset by editing only this cell.

In [102]:
Data = ng.data.NuGraphDataModule
Model = ng.models.NuGraph3

### Configure data module
Declare a data module. Depending on where you're working, you should edit the data path below to point to a valid data location.

In [103]:
nudata = Data(model=Model, data_path="/exp/sbnd/app/users/yuhw/nugraph/pywcml/sample/2level.h5")
# nudata = Data(model=Model, data_path="/scratch/7DayLifetime/yuhw/wirecell/nugraph/data/23334072.h5")
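
Because the valid data location depends on where you are working, it can help to check candidate paths before building the data module. The helper below is a minimal sketch: `resolve_data_path` is a hypothetical name, not part of `nugraph`, and it only checks that a file exists, not that it is a valid dataset.

```python
from pathlib import Path


def resolve_data_path(candidates):
    """Return the first candidate path that exists as a file.

    Hypothetical helper for this notebook; not a nugraph function.
    """
    for candidate in candidates:
        path = Path(candidate).expanduser()
        if path.is_file():
            return path
    tried = ", ".join(str(c) for c in candidates)
    raise FileNotFoundError(f"No data file found; tried: {tried}")


# Example usage with the two locations from the cell above:
# data_path = resolve_data_path([
#     "/exp/sbnd/app/users/yuhw/nugraph/pywcml/sample/2level.h5",
#     "/scratch/7DayLifetime/yuhw/wirecell/nugraph/data/23334072.h5",
# ])
# nudata = Data(model=Model, data_path=str(data_path))
```

Failing early with a clear message is usually easier to debug than an error raised from inside the data module.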

### Configure network
Declare a model. You can edit the arguments below to change the network configuration.

In [104]:
nugraph = Model(
    in_features=8,
    hit_features=128,
    nexus_features=32,
    instance_features=32,
    interaction_features=32,
    semantic_classes=nudata.semantic_classes,
    event_classes=nudata.event_classes,
    num_iters=5,
    event_head=False,
    semantic_head=True,
    filter_head=True,
    vertex_head=False,
    instance_head=True,
    use_checkpointing=True,
    lr=0.001)

### Set log directory
Set the `NUGRAPH_LOG` environment variable, which the logger configuration reads to decide where training output is written.

In [105]:
import os

# Set environment variable for NUGRAPH_LOG
os.environ["NUGRAPH_LOG"] = "/exp/sbnd/app/users/yuhw/nugraph/log"

# Confirm the variable has been set
print(f"NUGRAPH_LOG is now set to: {os.environ['NUGRAPH_LOG']}")

NUGRAPH_LOG is now set to: /exp/sbnd/app/users/yuhw/nugraph/log
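
If `NUGRAPH_LOG` may already be set in your shell, one option is to keep that value and only fall back to a default. This is a small sketch using `os.environ.setdefault`; the fallback directory is a placeholder, not a path from the original setup.

```python
import os

# Hypothetical fallback location; replace with a directory you can write to.
DEFAULT_LOG_DIR = "/tmp/nugraph/log"

# setdefault keeps an existing value and only assigns when the variable is unset.
log_root = os.environ.setdefault("NUGRAPH_LOG", DEFAULT_LOG_DIR)
print(f"NUGRAPH_LOG resolves to: {log_root}")
```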


### Configure logger and callbacks
Declare a Weights & Biases (`wandb`) logger and define the output directory, so we can monitor network training. Also define callbacks to monitor learning rate evolution and to checkpoint the model with the lowest validation loss.

In [106]:
name = "test-10-03"
logdir = Path(os.environ["NUGRAPH_LOG"])/name
logdir.mkdir(parents=True, exist_ok=True)
logger = pl.loggers.WandbLogger(save_dir=logdir, project="nugraph3", name=name,
                                log_model="all")
callbacks = [
    pl.callbacks.LearningRateMonitor(logging_interval="step"),
    pl.callbacks.ModelCheckpoint(monitor="loss/val", mode="min"),
]
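
To make the `monitor="loss/val", mode="min"` setting concrete, the sketch below picks the epoch with the lowest monitored value from a small dictionary. The loss values are invented for illustration, and `best_checkpoint` is a hypothetical helper; Lightning's `ModelCheckpoint` also handles file management and other options not shown here.

```python
def best_checkpoint(metric_by_epoch, mode="min"):
    """Return (epoch, value) for the best monitored metric.

    Hypothetical helper illustrating the min/max selection rule only.
    """
    if mode not in ("min", "max"):
        raise ValueError("mode must be 'min' or 'max'")
    pick = min if mode == "min" else max
    epoch = pick(metric_by_epoch, key=metric_by_epoch.get)
    return epoch, metric_by_epoch[epoch]


# Invented validation losses, keyed by epoch.
val_loss = {0: 12.4, 1: 11.0, 2: 10.6, 3: 10.9}
print(best_checkpoint(val_loss, mode="min"))
```

With `mode="min"`, a later epoch with a higher loss (epoch 3 here) does not displace the earlier best one.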

### Declare trainer and run training
First we set the training device. To train on a GPU, pass its integer index to `ng.util.configure_device`; otherwise, training defaults to the CPU. We then instantiate a PyTorch Lightning trainer and run the training stage, which iterates over all batches in the train and validation datasets to optimise model parameters. Finally, we evaluate the trained model on the test dataset. Output metrics are written to the Weights & Biases logger configured above.

In [107]:
accelerator, devices = ng.util.configure_device(0)
print(f"Using accelerator={accelerator}, devices={devices}")
trainer = pl.Trainer(accelerator=accelerator,
                     devices=devices,
                     max_epochs=10,
                     logger=logger,
                     callbacks=callbacks)
trainer.fit(nugraph, datamodule=nudata)
trainer.test(datamodule=nudata)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


Using accelerator=gpu, devices=[0]


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loading `train_dataloader` to estimate number of stepping batches.
/exp/sbnd/app/users/yuhw/dl-clustering/venv_eaf/lib/python3.10/site-packages/pytorch_lightning/loops/fit_loop.py:310: The number of training batches (1) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.

  | Name             | Type            | Params | Mode 
-------------------------------------------------------------
0 | encoder          | Encoder         | 1.2 K  | train
1 | core_net         | NuGraphCore     | 99.7 K | train
2 | semantic_decoder | SemanticDecoder | 259    | train
3 | filter_decoder   | FilterDecoder   | 130    | train
4 | instance_decoder | InstanceDecoder | 4.3 K  | train
-------------------------------------------------------------
105 K     Trainable params
17        Non-trainable params
105 K     Total params
0.422     Total estimated model params size (MB)
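
As a quick consistency check on the summary table, the estimated size can be reproduced from the parameter counts. This sketch assumes 32-bit (4-byte) parameters and that the estimate is simply the parameter count times bytes per parameter; the per-module counts are the rounded values printed in the table, so the result is approximate.

```python
# Parameter counts as printed in the summary table (rounded by Lightning).
param_counts = {
    "encoder": 1_200,
    "core_net": 99_700,
    "semantic_decoder": 259,
    "filter_decoder": 130,
    "instance_decoder": 4_300,
}
non_trainable = 17
BYTES_PER_PARAM = 4  # assumption: 32-bit floating point parameters

total_params = sum(param_counts.values()) + non_trainable
size_mb = total_params * BYTES_PER_PARAM / 1e6
print(f"{total_params} parameters, about {size_mb:.3f} MB")
```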

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=10` reached.
Restoring states from the checkpoint path at /exp/sbnd/app/users/yuhw/nugraph/log/test-10-03/nugraph3/y9guviry/checkpoints/epoch=9-step=10.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from the checkpoint at /exp/sbnd/app/users/yuhw/nugraph/log/test-10-03/nugraph3/y9guviry/checkpoints/epoch=9-step=10.ckpt


Testing: |          | 0/? [00:00<?, ?it/s]

[{'loss/test': 10.125808715820312,
  'semantic/loss-test': 0.18141819536685944,
  'semantic/recall-test': 0.8682976961135864,
  'semantic/precision-test': 0.8682976961135864,
  'filter/loss-test': 1.2939022779464722,
  'filter/recall-test': 0.0,
  'filter/precision-test': 0.0,
  'instance/loss-test': 8.650487899780273,
  'instance/bkg-loss-test': 0.5245781540870667,
  'instance/potential-loss-test': 8.121854782104492,
  'instance/adjusted-rand-test': 1.0}]
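
The warning in the output above notes that the single training batch is smaller than the default logging interval of 50 steps, so no per-step training metrics are logged. One way to address this, sketched below, is to cap `log_every_n_steps` at the number of training batches. `pick_log_interval` is a hypothetical helper, while `log_every_n_steps` is the `Trainer` argument named in the warning.

```python
def pick_log_interval(num_training_batches, default=50):
    """Cap the logging interval at the number of training batches (at least 1).

    Hypothetical helper; not part of PyTorch Lightning or nugraph.
    """
    return max(1, min(default, num_training_batches))


# The run above reported a single training batch per epoch.
trainer_kwargs = {
    "max_epochs": 10,
    "log_every_n_steps": pick_log_interval(1),
}
print(trainer_kwargs)
```

These keyword arguments could then be passed to `pl.Trainer(**trainer_kwargs, ...)` alongside the accelerator, devices, logger, and callbacks used in the cell above.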

In [108]:
import wandb
while wandb.run is not None:
    wandb.finish(quiet=True)

metric,history
epoch,▁▂▂▃▄▅▅▆▇▇█
filter/loss-test,▁
filter/loss-val,█▇▆▅▄▂▂▁▁▁
filter/precision-test,▁
filter/precision-val,▁▁▁▁▁▁▁▁▁▁
filter/recall-test,▁
filter/recall-val,▁▁▁▁▁▁▁▁▁▁
instance/adjusted-rand-test,▁
instance/adjusted-rand-val,▁▁▁▁▁▁▁▁▁▁
instance/bkg-loss-test,▁

metric,value
epoch,10
filter/loss-test,1.2939
filter/loss-val,1.29387
filter/precision-test,0
filter/precision-val,0
filter/recall-test,0
filter/recall-val,0
instance/adjusted-rand-test,1
instance/adjusted-rand-val,1
instance/bkg-loss-test,0.52458
