# Photon Reconstruction in the Belle II Calorimeter Using Graph Neural Networks
Code and examples accompanying the paper on reconstructing photons in the Belle II calorimeter using graph neural networks.

### Loading the model
Load the model hyperparameters from a YAML config file and initialize the model.

In [None]:
import yaml
from torch.cuda import is_available
from model import GNNmodel

# Read the two-photon hyperparameters. safe_load only builds plain YAML
# types (dicts, lists, scalars), never arbitrary Python objects.
with open("configs/two_photon_train_config.yml") as config_file:
    config = yaml.safe_load(config_file)

# Use the GPU when torch can see one; later cells read config["device"].
config["device"] = "cuda" if is_available() else "cpu"

# Config entries forwarded unchanged, as keyword arguments, to GNNmodel.
MODEL_HPARAM_KEYS = (
    "features",
    "n_photons",
    "dense_layer_dim",
    "feature_space_dim",
    "spatial_information_dim",
    "k",
    "n_gravblocks",
    "batch_norm_momentum",
)
model_kwargs = {key: config[key] for key in MODEL_HPARAM_KEYS}
model = GNNmodel(**model_kwargs).to(config["device"])

### Loading the dataset

In [None]:

from datasets import ECLDataset
from train_loop import get_datasets

# `torch_geometric.data.DataLoader` is deprecated since PyG 2.0 in favour of
# `torch_geometric.loader.DataLoader`; fall back only for older installs.
try:
    from torch_geometric.loader import DataLoader
except ImportError:
    from torch_geometric.data import DataLoader

# Dataset backed by the raw parquet file under ./data/. NOTE(review): the
# processed_filename argument suggests ECLDataset caches the processed
# graphs there — confirm against datasets.ECLDataset.
full_dataset = ECLDataset(
    root="./data/",
    raw_filename="two_photon_data.parquet",
    processed_filename="two_photon_data_processed.pt",
    n_photons=config["n_photons"],
    n_events=config["n_events"],
    features=config["features"],
)

# Split into training and validation subsets (presumably controlled by the
# fraction in config["val_ratio"] — verify in train_loop.get_datasets).
train_dataset, val_dataset = get_datasets(config["val_ratio"], full_dataset)

# Pinned host memory only speeds up host-to-GPU copies; on a CPU-only run it
# has no effect and PyTorch warns about it, so enable it only for CUDA.
pin_memory = config["device"] == "cuda"

trainloader = DataLoader(
    train_dataset,
    batch_size=config["batch_size"],
    shuffle=True,
    num_workers=config["num_workers"],
    pin_memory=pin_memory,
)
valloader = DataLoader(
    val_dataset,
    batch_size=config["val_batch_size"],
    shuffle=False,  # keep validation order fixed so evaluations are comparable
    num_workers=config["num_workers"],
    pin_memory=pin_memory,
)

### Setting up training

In [None]:
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau

# Number of scheduler steps without improvement before the LR is reduced,
# and the multiplicative factor applied to the LR at that point.
PLATEAU_PATIENCE = 5
PLATEAU_FACTOR = 0.25

# Adam over every model parameter, starting from the configured learning rate.
optimizer = Adam(model.parameters(), lr=config["lr"])
lr_scheduler = ReduceLROnPlateau(
    optimizer,
    patience=PLATEAU_PATIENCE,
    factor=PLATEAU_FACTOR,
)


### Running the training

In [None]:
from train_loop import train
from torch import save  # NOTE(review): unused in the visible cells; kept in case later cells rely on it

# Everything the training loop needs, collected in one place.
training_inputs = dict(
    config=config,
    model=model,
    optimizer=optimizer,
    lr_scheduler=lr_scheduler,
    trainloader=trainloader,
    valloader=valloader,
)

# `train` hands back the model and, judging by the name, the last epoch it
# reached — confirm in train_loop.train.
model, last_epoch = train(**training_inputs)

### Display Resolution

The resolution is defined as $(\text{pred} - \text{true}) / \text{true}$, i.e. the relative deviation of the prediction from the true value.

In [None]:
from torch import no_grad
from utils import resolution_plot
import matplotlib.pyplot as plt

# Evaluate on the first validation batch only.
batch = next(iter(valloader)).to(config["device"])

# Run inference in eval mode without autograd. The model is configured with
# batch_norm_momentum, so presumably it contains batch-norm layers. In training
# mode those would normalise with this batch's statistics and update their
# running averages as a side effect. Restore the previous mode afterwards so
# re-running the training cell behaves as before.
was_training = model.training
model.eval()
with no_grad():
    pred = model(batch).cpu().numpy()
model.train(was_training)

true = batch.y.cpu().numpy()

# Column 0 holds the first photon's target/prediction.
resolution_plot(true[:, 0], pred[:, 0], label="gravnet")
plt.title("Resolution of photon 1")
plt.show()

In [None]:
# Same resolution plot for the second photon (column 1), reusing `true` and
# `pred` computed in the previous cell.
second_photon = 1
resolution_plot(true[:, second_photon], pred[:, second_photon], label="gravnet")
plt.title("Resolution of photon 2")
plt.show()