In [1]:
import os
import sys
from pathlib import Path
import glob
# Base directory for the project-local files used in this notebook
# (part_models/, data_config/ and the utils package imported below).
PATH_TO_PARTGP = os.getenv("PATH_TO_PARTGP")
if PATH_TO_PARTGP is None:
    # os.path.abspath(None) would otherwise fail with an opaque TypeError.
    raise RuntimeError(
        "Environment variable PATH_TO_PARTGP is not set; it must point to the "
        "directory that contains part_models/, data_config/ and utils/."
    )
# Make the project-local packages importable. The membership test keeps
# sys.path from growing each time this cell is re-run.
if os.path.abspath(PATH_TO_PARTGP) not in sys.path:
    sys.path.append(os.path.abspath(PATH_TO_PARTGP))

import numpy
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import uproot

from weaver.utils.import_tools import import_module
from weaver.utils.dataset import SimpleIterDataset
from weaver.train import test_load
from weaver.utils.nn.tools import evaluate_classification

from utils.nn_utils.hook_handler import register_forward_hooks, remove_all_forward_hooks

In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Forwarded as for_training= when the data config is built in initialize_models.
training = False
# When True, initialize_models loads the saved weights into each model.
load_model = True

# Base directory of the datasets; None when PART_DATA is not set.
datasets_path = os.environ.get("PART_DATA")

In [None]:
# Arguments for weaver
class Args:
    """Plain attribute container passed to weaver's ``test_load``.

    Every keyword argument becomes an attribute of the instance. Attributes
    that are not supplied fall back to the defaults listed in ``__init__``.
    """

    def __init__(self, **kwargs):
        # Start from the defaults; a fresh dict (and a fresh list for
        # data_test) is built on every call, so instances never share state.
        settings = {
            'data_test': [],
            'num_workers': 0,
            'data_config': '',
            'extra_test_selection': None,
            'data_fraction': 0,
            'batch_size': 0,
        }
        # Caller-supplied values override defaults; unknown keys are kept too.
        settings.update(kwargs)

        for attr_name, attr_value in settings.items():
            setattr(self, attr_name, attr_value)

In [None]:
# Set up paths

# One key per model variant; the same keys index both tables below.
data_types = ['full', 'kinpid', 'kin']

# Saved weights for each variant (joined onto PATH_TO_PARTGP when loaded).
weight_paths = {
    variant: f'part_models/ParT_{variant}.pt' for variant in data_types
}

# Data-config YAML for each variant (joined onto PATH_TO_PARTGP when loaded).
config_paths = {
    variant: f'data_config/JetClass/JetClass_{variant}.yaml' for variant in data_types
}

# JetClass split directories (joined onto datasets_path when used).
jetclass_data = dict(
    train='JetClass/Pythia/train_100M',
    validation='JetClass/Pythia/val_5M',
    test='JetClass/Pythia/test_20M',
)

In [None]:
# Location of the model wrapper script under PATH_TO_PARTGP.
path_to_network = Path(PATH_TO_PARTGP, 'part_models/part_wrapper.py')

# Load the wrapper as a module; initialize_models calls its get_model / get_loss.
network_module = import_module(path_to_network, name='_network_module')

In [None]:
def initialize_models(types, weights, configs):
    """Build one model per entry of ``types`` and return them in a dict.

    Reads the notebook-level names ``PATH_TO_PARTGP``, ``training``,
    ``load_model``, ``device`` and ``network_module``.

    Args:
        types: iterable of keys identifying the model variants to build.
        weights: mapping from key to a saved-weights path, relative to
            ``PATH_TO_PARTGP``. Only read when ``load_model`` is true.
        configs: mapping from key to a data-config path, relative to
            ``PATH_TO_PARTGP``.

    Returns:
        dict mapping each key to ``{'model': ..., 'info': ..., 'loss': ...}``,
        holding the two values returned by ``network_module.get_model`` and
        the value returned by ``network_module.get_loss``.
    """
    registry = {}

    for variant in types:
        # The dataset object is created only to obtain its parsed .config.
        yaml_file = Path(PATH_TO_PARTGP) / configs[variant]
        data_config = SimpleIterDataset({}, str(yaml_file), for_training=training).config
        net, net_info = network_module.get_model(data_config)

        if load_model:
            # weights_only=True restricts unpickling to tensors/plain containers.
            checkpoint_file = Path(PATH_TO_PARTGP) / weights[variant]
            state = torch.load(str(checkpoint_file), map_location=device, weights_only=True)
            net.load_state_dict(state)

        registry[variant] = dict(
            model=net,
            info=net_info,
            loss=network_module.get_loss(data_config),
        )

    return registry

# Build every variant listed in data_types; keyword arguments make the
# pairing of each table with its parameter explicit.
models = initialize_models(
    types=data_types,
    weights=weight_paths,
    configs=config_paths,
)

In [None]:
# Pick the model stored under the 'full' key by initialize_models.
full = models['full']['model']

# Place it on the device selected earlier (CUDA when available, else CPU).
part_full = full.to(device)

In [None]:
# Alias used by the following cells (module listing, hook targets, evaluation).
model = part_full

In [2]:
# Print the name of every submodule; requires `model` to be defined by the
# cells above, so run the notebook top to bottom.
for name, module in model.named_modules():
    print(name)

NameError: name 'model' is not defined

In [None]:
# Submodules whose forward outputs are captured, keyed by a descriptive label.
# The "final" entries index with -1 instead of a hard-coded depth, so the
# labels stay accurate if the number of blocks / cls_blocks ever changes.
interesting_layers = {
    'post_layer_embed': model.mod.embed.embed,
    'post_pair_embed': model.mod.pair_embed.embed,
    'first_layer_attn': model.mod.blocks[0].attn,
    'first_layer_block': model.mod.blocks[0],
    'final_layer_attn': model.mod.blocks[-1].attn,
    'final_layer_block': model.mod.blocks[-1],
    'first_cls_attn': model.mod.cls_blocks[0].attn,
    'first_cls_block': model.mod.cls_blocks[0],
    'final_cls_attn': model.mod.cls_blocks[-1].attn,
    'final_cls_block': model.mod.cls_blocks[-1],
    'logits': model.mod.fc
}

In [None]:
# Clear any hooks left over from an earlier run of this cell before adding new ones.
remove_all_forward_hooks(model)

# Container handed to the hooks; later cells read outputs['logits'] from it.
outputs = dict()

# One hook per entry of interesting_layers; handles is indexed by the same labels.
handles = register_forward_hooks(location_dict=interesting_layers, outputs=outputs)

In [None]:
# Check to see if hooks exist

# _forward_hooks is the layer's own hook registry (a private attribute,
# inspected here for debugging only).
print(interesting_layers['logits']._forward_hooks)
# Entry returned by register_forward_hooks for the same label.
print(handles['logits'])

In [None]:
# Paths

# datasets_path comes from the PART_DATA environment variable; fail with a
# clear message instead of the TypeError that Path(None) would raise.
if datasets_path is None:
    raise RuntimeError(
        "Environment variable PART_DATA is not set; it must point to the "
        "directory that contains the JetClass datasets."
    )

training_set = str(Path(datasets_path) / jetclass_data['train'])
# glob returns matches in arbitrary order; sorting makes the demo input
# identical from run to run.
demo_files = sorted(glob.glob(training_set + '/*_000.root'))
if not demo_files:
    # Without input files the cells that read scores / observers below
    # would fail with a confusing NameError.
    raise FileNotFoundError(f"No '*_000.root' files found in {training_set}")
config_path = str(Path(PATH_TO_PARTGP) / config_paths['full'])

# Set arguments
args = Args(data_test = demo_files, data_config = config_path, data_fraction = 0.005, batch_size = 32) # Run an incredibly small demo

test_loaders, data_config = test_load(args)

# test_loaders maps a name to a factory that builds the loader on demand.
# Results are overwritten on every iteration, so only those of the last
# loader remain available afterwards.
for name, get_test_loader in test_loaders.items():

    test_loader = get_test_loader()

    test_metric, scores, labels, observers = evaluate_classification(model, test_loader, device, epoch=None, for_training=False)

    # Drop the loader explicitly once it has been consumed.
    del test_loader

In [None]:
# Display the 'jet_pt' entry of the observers returned by evaluate_classification.
observers['jet_pt']

In [None]:
# Scores returned by evaluate_classification for the last loader.
# NOTE(review): this cell was labelled "logits"; weaver may apply a softmax
# before returning scores -- confirm, and compare with outputs['logits'].
scores

In [None]:
# Iteration 1 Outputs
# First entry captured by the 'logits' hook (presumably the first batch -- confirm
# against register_forward_hooks).

print(outputs['logits'][0])

In [None]:
# Concatenate the captured batches along the leading dimension, then copy to the CPU.
processed = torch.cat(outputs['logits'], dim=0).cpu()

In [None]:
# Display the concatenated hook output as a NumPy array.
processed.numpy()