In [14]:
import os
import sys
from pathlib import Path
import glob
# Directory that the notebook's relative paths (part_models/..., data_config/...) are
# resolved against, taken from the PATH_TO_PARTGP environment variable.
PATH_TO_PARTGP = os.getenv("PATH_TO_PARTGP")
if PATH_TO_PARTGP is None:
    # Fail with an actionable message instead of the opaque TypeError that
    # os.path.abspath(None) would raise.
    raise RuntimeError(
        "The PATH_TO_PARTGP environment variable is not set; set it to the directory "
        "that contains part_models/ and data_config/ before running this notebook."
    )

# Make that directory importable (presumably needed by the project-local imports below).
# Only add it once so that re-running this cell does not keep growing sys.path.
_partgp_root = os.path.abspath(PATH_TO_PARTGP)
if _partgp_root not in sys.path:
    sys.path.append(_partgp_root)

import numpy
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import uproot

from weaver.nn.model.ParticleTransformer import ParticleTransformer
from weaver.utils.import_tools import import_module
from weaver.utils.dataset import SimpleIterDataset
from weaver.train import test_load
from weaver.utils.nn.tools import evaluate_classification

from utils.nn_utils.hook_handler import register_forward_hooks, remove_all_forward_hooks

In [15]:
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Switches read when the models are initialised: `load_model` decides whether saved
# weights are loaded, `training` is forwarded as the dataset's `for_training` flag.
load_model = True
training = False

# Root directory of the datasets, from the PART_DATA environment variable (None when unset).
datasets_path = os.environ.get("PART_DATA")

In [16]:
# Arguments for weaver
class Args:
    """Plain attribute container for the arguments handed to weaver.

    A fixed set of defaults is installed first; every keyword argument is then
    stored as an attribute of the same name, either replacing a default or
    adding a new attribute.
    """

    def __init__(self, **kwargs):
        # Built per instance so the mutable ``data_test`` list is never shared.
        settings = {
            'data_test': [],
            'num_workers': 0,
            'data_config': '',
            'extra_test_selection': None,
            'data_fraction': 0,
            'batch_size': 0,
        }
        # Caller-supplied values take precedence over the defaults.
        settings.update(kwargs)

        for attr_name, attr_value in settings.items():
            setattr(self, attr_name, attr_value)

In [4]:
# Set up paths

# One entry per model variant handled in this notebook.
data_types = ['full', 'kinpid', 'kin']

# Weight files and data configs follow one naming pattern per variant
# (both are later joined onto PATH_TO_PARTGP).
weight_paths = {variant: f'part_models/ParT_{variant}.pt' for variant in data_types}

config_paths = {variant: f'data_config/JetClass/JetClass_{variant}.yaml' for variant in data_types}

# JetClass dataset directories (later joined onto datasets_path).
jetclass_data = {
    'train':      'JetClass/Pythia/train_100M',
    'validation': 'JetClass/Pythia/val_5M',
    'test':       'JetClass/Pythia/test_20M',
}

In [5]:
# Import the network definition file as a module; it is later asked for the
# model (get_model) and the loss (get_loss).
path_to_network = Path(PATH_TO_PARTGP).joinpath('part_models/ParticleTransformer.py')
network_module = import_module(path_to_network, name='_network_module')

In [6]:
def initialize_models(types, weights, configs):
    """Build one network per entry of ``types`` and optionally load its saved weights.

    Reads the notebook-level globals ``PATH_TO_PARTGP``, ``training``,
    ``load_model``, ``device`` and ``network_module``.

    Args:
        types: Iterable of keys; each must be present in ``configs`` (and in
            ``weights`` when ``load_model`` is true).
        weights: Mapping from key to a weights file path relative to
            ``PATH_TO_PARTGP``.
        configs: Mapping from key to a data-config path relative to
            ``PATH_TO_PARTGP``.

    Returns:
        dict: ``{key: {'model': ..., 'info': ..., 'loss': ...}}`` holding what
        ``network_module.get_model`` and ``network_module.get_loss`` returned
        for that key.
    """
    built = {}

    for key in types:
        root = Path(PATH_TO_PARTGP)

        # The data config drives both the network construction and the loss.
        cfg_file = str(root / configs[key])
        data_config = SimpleIterDataset({}, cfg_file, for_training=training).config
        net, net_info = network_module.get_model(data_config)

        if load_model:
            ckpt_file = str(root / weights[key])
            # map_location places the loaded tensors on the notebook's device.
            state = torch.load(ckpt_file, map_location=device, weights_only=True)
            net.load_state_dict(state)

        # Bundle the network with its extra info and loss under the variant key.
        loss_fn = network_module.get_loss(data_config)
        built[key] = {
            'model': net,
            'info': net_info,
            'loss': loss_fn,
        }

    return built

# Build every variant listed in data_types; the result is keyed by variant name and
# each value holds the 'model', 'info' and 'loss' entries.
models = initialize_models(data_types, weight_paths, config_paths)

In [7]:
# Pick the network of the 'full' variant out of the initialised models.
full = models['full']['model']

# Move it to the selected device.
# NOTE(review): if this is a torch.nn.Module, .to() returns the module itself, so
# `full` and `part_full` would refer to the same object — confirm.
part_full = full.to(device)

In [8]:
# `model` is the name the following cells use for the network being inspected.
model = part_full

In [9]:
# Print the dotted name of every (sub)module of the model, one per line.
# Only the names are shown; the module objects themselves are not used here.
for name, module in model.named_modules():
    print(name)


mod
mod.trimmer
mod.embed
mod.embed.input_bn
mod.embed.embed
mod.embed.embed.0
mod.embed.embed.1
mod.embed.embed.2
mod.embed.embed.3
mod.embed.embed.4
mod.embed.embed.5
mod.embed.embed.6
mod.embed.embed.7
mod.embed.embed.8
mod.pair_embed
mod.pair_embed.embed
mod.pair_embed.embed.0
mod.pair_embed.embed.1
mod.pair_embed.embed.2
mod.pair_embed.embed.3
mod.pair_embed.embed.4
mod.pair_embed.embed.5
mod.pair_embed.embed.6
mod.pair_embed.embed.7
mod.pair_embed.embed.8
mod.pair_embed.embed.9
mod.pair_embed.embed.10
mod.pair_embed.embed.11
mod.pair_embed.embed.12
mod.blocks
mod.blocks.0
mod.blocks.0.pre_attn_norm
mod.blocks.0.attn
mod.blocks.0.attn.out_proj
mod.blocks.0.post_attn_norm
mod.blocks.0.dropout
mod.blocks.0.pre_fc_norm
mod.blocks.0.fc1
mod.blocks.0.act
mod.blocks.0.act_dropout
mod.blocks.0.post_fc_norm
mod.blocks.0.fc2
mod.blocks.1
mod.blocks.1.pre_attn_norm
mod.blocks.1.attn
mod.blocks.1.attn.out_proj
mod.blocks.1.post_attn_norm
mod.blocks.1.dropout
mod.blocks.1.pre_fc_norm
mod.blo

In [17]:
# Layers to attach forward hooks to, keyed by a descriptive label.
# The "final_*" entries index from the end (-1) instead of hard-coding the depth
# (previously 7 and 1), so the labels stay correct — and no IndexError is raised —
# if the model is built with a different number of blocks.
interesting_layers = {
    'post_layer_embed': model.mod.embed.embed,
    'post_pair_embed': model.mod.pair_embed.embed,
    'first_layer_attn': model.mod.blocks[0].attn,
    'first_layer_block': model.mod.blocks[0],
    'final_layer_attn': model.mod.blocks[-1].attn,
    'final_layer_block': model.mod.blocks[-1],
    'first_cls_attn': model.mod.cls_blocks[0].attn,
    'first_cls_block': model.mod.cls_blocks[0],
    'final_cls_attn': model.mod.cls_blocks[-1].attn,
    'final_cls_block': model.mod.cls_blocks[-1],
    'final_logits': model.mod.fc
}

In [11]:
# One empty list per labelled layer; this dict is handed to the hooks as their output store.
outputs = {layer_name: [] for layer_name in interesting_layers}

# Safety precaution: clear whatever forward hooks may still be attached to the model
# before registering new ones.
remove_all_forward_hooks(model)

# Keep the returned handles so individual hooks can be inspected later.
handles = register_forward_hooks(outputs=outputs, location_dict=interesting_layers)

Removing all hooks within the model
Forward Hook Registered: post_layer_embed
Forward Hook Registered: post_pair_embed
Forward Hook Registered: first_layer_attn
Forward Hook Registered: first_layer_block
Forward Hook Registered: final_layer_attn
Forward Hook Registered: final_layer_block
Forward Hook Registered: first_cls_attn
Forward Hook Registered: first_cls_block
Forward Hook Registered: final_cls_attn
Forward Hook Registered: final_cls_block
Forward Hook Registered: final_logits


In [12]:
# Check to see if hooks exist

# `_forward_hooks` is the module's private registry of forward hooks; a non-empty
# mapping here means the hook on the 'final_logits' layer is attached.
print(interesting_layers['final_logits']._forward_hooks)
# The handle stored for the same layer when the hooks were registered.
print(handles['final_logits'])

OrderedDict([(10, <function save_outputs.<locals>.hook at 0x7c53f37c04c0>)])
<torch.utils.hooks.RemovableHandle object at 0x7c53f5149630>


In [14]:
# Paths

if datasets_path is None:
    # Path(None) would otherwise fail with an unhelpful TypeError.
    raise RuntimeError(
        "The PART_DATA environment variable is not set; it must point at the "
        "directory that contains the JetClass data."
    )

training_set = str(Path(datasets_path) / jetclass_data['train'])
# glob returns matches in arbitrary order; sort so the file list is reproducible between runs.
demo_files = sorted(glob.glob(training_set + '/*_000.root'))
if not demo_files:
    # An empty file list would only surface later as a confusing downstream failure.
    raise FileNotFoundError(
        f"No files matching '*_000.root' were found in {training_set!r}; "
        "check that PART_DATA points at the JetClass data."
    )
config_path = str(Path(PATH_TO_PARTGP) / config_paths['full'])

# Set arguments
# data_fraction and batch_size are deliberately tiny: this is an incredibly small demo run.
args = Args(data_test = demo_files, data_config = config_path, data_fraction = 0.005, batch_size = 32)

test_loaders, data_config = test_load(args)

# NOTE: the result variables are rebound on every iteration, so after the loop they
# hold the values for the last entry of test_loaders only.
for name, get_test_loader in test_loaders.items():

    # Each entry is a factory; the loader is only created when it is called.
    test_loader = get_test_loader()

    test_metric, scores, labels, observers = evaluate_classification(model, test_loader, device, epoch=None, for_training=False)

    # Drop the loader reference before moving on to the next one.
    del test_loader

0it [00:00, ?it/s]

=== Restarting DataIter test_, seed=None ===


157it [00:11, 14.03it/s, Loss=0.00000, AvgLoss=0.00000, Acc=0.75000, AvgAcc=0.85840]


In [24]:
# Display the 'jet_pt' entry of the observers returned by evaluate_classification.
observers['jet_pt']

In [26]:
# Iteration 1 outputs: element 0 of the list collected for the 'final_logits' layer
# (presumably what the hook stored for the first batch evaluated — confirm against
# the hook implementation).

print(outputs['final_logits'][0])

tensor([[ 1.2881,  8.6138, -1.9069,  1.8779, -4.5168, -8.0102,  3.4128, -5.4139,
         -3.9643, -5.0155],
        [-0.2050,  1.2807, -0.5661, -1.1434, -2.0372, -0.8013,  1.1555,  2.0963,
         -1.2802, -0.7061],
        [-0.8722,  2.5090,  2.1592,  4.2234,  2.8238, -1.4500, -1.9209, -3.1409,
         -3.4871, -2.6612],
        [ 3.8417, 10.3349, -0.9703, -0.8127, -5.9063, -6.3644,  2.8431, -6.6620,
         -1.7865, -0.4859],
        [ 2.4278, 11.6852, -2.4629,  0.4811, -6.8531, -8.7825,  4.4419, -6.4352,
         -3.9579, -2.8979],
        [ 1.3520,  6.2128, -1.4011, -0.6865, -0.1083, -6.2026,  0.9606, -3.0825,
          1.5226, -2.0052],
        [-0.3114,  8.2566, -4.1326, -2.0885, -8.8568, -5.1508,  7.9680, -3.0312,
         -6.3298, -3.8241],
        [-0.2341,  8.4554, -1.9621,  4.0125, -1.8350, -9.6541,  2.4948, -6.2103,
         -2.1334, -5.1094],
        [ 2.9507,  8.2878, -3.2807,  4.1379, -5.9809, -8.4298,  2.9307, -5.4795,
         -2.6054, -2.4350],
        [-0.5854,  