In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
import yaml

from torch_geometric.data import Data

from acorn.stages.edge_classifier import InteractionGNN2
from acorn.stages.edge_classifier import RecurrentInteractionGNN2, ChainedInteractionGNN2
from acorn.utils import handle_edge_features





In [3]:
# Training configuration shared by every model built below (feature lists, net flags, ...).
model_config_file = "../examples/CTD_2023/gnn_train.yaml"
with open(model_config_file) as config_stream:
    config_text = config_stream.read()
model_config = yaml.safe_load(config_text)

In [4]:
# Reference (eager) model built straight from the training config. No checkpoint is
# loaded anywhere in this notebook, so its weights are whatever the constructor
# initialises; they are copied into the exportable model later so outputs can be compared.
gnnModel = InteractionGNN2(model_config)

In [5]:
# Load one event graph to exercise the models.
# NOTE(review): site-specific absolute path; point this at a local copy when running elsewhere.
filename = "/pscratch/sd/x/xju/ITk/ForFinalPaper/CHEP2024_data_allITk/metric_learning_v3/trainset/event000000027.pyg"
# The file holds a pickled graph object (used below as a torch_geometric `Data`), not a
# plain tensor dict, so it needs full unpickling: torch >= 2.6 defaults to
# weights_only=True and would refuse it. Full unpickling can execute arbitrary code,
# so only load trusted files this way.
data = torch.load(filename, map_location=torch.device("cpu"), weights_only=False)
# Presumably attaches/normalises the configured edge features on `data` in place
# (the return value is unused) — TODO confirm against acorn.utils.
handle_edge_features(data, model_config["edge_features"])

In [8]:
# Shrink the graph to its first `reduced_num_edges` edges so the checks below stay fast.
# Only edge_index and the configured edge features are truncated.
reduced_num_edges = 1000
edge_window = slice(None, reduced_num_edges)
data.edge_index = data["edge_index"][:, edge_window]
for edge_feature_name in model_config["edge_features"]:
    data[edge_feature_name] = data[edge_feature_name][edge_window]

In [10]:
# Reference output of the eager model; every exported variant is compared against it.
with torch.no_grad():
    output = gnnModel(data)
print(output.shape)

torch.Size([1000])


In [15]:
# Assemble the positional (node_features, edge_index, edge_attr) inputs expected by the
# exportable model from the event graph.
node_feature_columns = [data[name] for name in model_config["node_features"]]
node_features = torch.stack(node_feature_columns, dim=-1).float()

# For hits in regions 2 and 6, overwrite the whole feature row with its first four
# features tiled three times. NOTE(review): this assumes exactly 12 node features
# (3 x 4); the row assignment raises a shape error otherwise — confirm against the config.
mask = ((data.region == 2) | (data.region == 6)).reshape(-1)
node_features[mask] = node_features[mask, 0:4].repeat(1, 3)

edge_feature_columns = [data[name] for name in model_config["edge_features"]]
edge_attr = torch.stack(edge_feature_columns, dim=-1).float()
edge_index = data.edge_index
input_data = [node_features, edge_index, edge_attr]

In [13]:
# The recurrent export variant is used only when both the node-net and the edge-net
# recurrence flags are truthy in the config.
is_recurrent = (
    model_config["node_net_recurrent"]
    and model_config["edge_net_recurrent"]
)
print(f"Is recurrent: {is_recurrent}")

is recurrent: False


In [16]:
# Pick the exportable architecture that matches the config's recurrence flags.
export_gnn_class = RecurrentInteractionGNN2 if is_recurrent else ChainedInteractionGNN2
new_gnn = export_gnn_class(model_config)

In [19]:
# Copy the reference model's parameters into the exportable model. Loading is strict,
# so any parameter-name mismatch between the two architectures raises; the returned
# key report is left as the cell's displayed value.
reference_weights = gnnModel.state_dict()
new_gnn.load_state_dict(reference_weights, strict=True)

<All keys matched successfully>

In [23]:
# Rebuilds the same positional argument list already assembled in the
# feature-construction cell (same tensors). Redundant, but harmless to re-run.
input_data = [node_features, edge_index, edge_attr]

In [20]:
# The exportable model, fed explicit tensors instead of a graph object, must reproduce
# the reference output exactly (bitwise).
with torch.no_grad():
    new_output = new_gnn(*input_data)
outputs_match = torch.equal(new_output, output)
assert outputs_match

In [24]:
# Export the model to TorchScript via its `to_torchscript` method (defined outside this
# notebook). NOTE(review): `example_inputs` presumably only matters for tracing — confirm.
with torch.jit.optimized_execution(True):
    script = new_gnn.to_torchscript(example_inputs=[input_data])

# The scripted module must reproduce the eager reference exactly.
with torch.no_grad():
    script_output = script(*input_data)
assert script_output.equal(output)

# torch.jit.freeze returns a NEW frozen module instead of freezing in place. Previously
# the result was discarded, so freezing was a no-op and the un-frozen module was saved.
# Freezing requires eval mode, hence the explicit .eval().
script = torch.jit.freeze(script.eval())
with torch.no_grad():
    frozen_output = script(*input_data)
# Freezing may fold constants and shift floating-point results slightly, so compare
# with tolerance rather than bitwise.
torch.testing.assert_close(frozen_output, output)

In [26]:
# Serialise the TorchScript module to disk, relative to the notebook's working directory.
torch_script_path = "test_gnn_model.pt"
print(f"Saving model to {torch_script_path}")
script.save(torch_script_path)
print(f"Done saving model to {torch_script_path}")

Saving model to test_gnn_model.pt
Done saving model to test_gnn_model.pt
