In [2]:
import os
import uproot
import awkward as ak
import numpy as np
from matplotlib import pyplot as plt
import particle 
from particle import PDGID
from particle import Particle 
from particle.pdgid import is_meson

In [None]:
import torch
import torch.nn as nn
from torch.nn import Sequential as Seq, Linear, ReLU
from torch_geometric.nn import MessagePassing
from torch_geometric.nn import global_mean_pool

class MLP(nn.Module):
    """Three-layer perceptron: Linear -> ReLU -> Linear -> ReLU -> Linear.

    Args:
        input_size: number of input features (size of the last dim of ``x``).
        output_size: number of output features.
        hidden_size: width of both hidden layers. Defaults to 16, the value
            that was previously hard-coded, so existing callers are unaffected.
    """

    def __init__(self, input_size, output_size, hidden_size=16):
        super(MLP, self).__init__()

        self.layers = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, output_size),
        )

    def forward(self, x):
        """Apply the MLP to ``x``; the last dimension of ``x`` must equal ``input_size``."""
        return self.layers(x)
    
class MPLayer(MessagePassing):
    """Message-passing layer with MLP messages and mean aggregation.

    For each edge j -> i (``flow='source_to_target'``), a message is computed as
    ``phi([x_i, x_j, edge_attr])``. Messages arriving at each node are averaged
    (``aggr='mean'``). The node is then updated to ``gamma([x, aggregated])``.

    Args:
        n_node_feats: number of input features per node.
        n_edge_feats: number of features per edge.
        message_size: width of each message produced by ``phi``.
        output_size: number of output features per node produced by ``gamma``.
    """

    def __init__(self, n_node_feats, n_edge_feats, message_size, output_size):
        super().__init__(aggr='mean', flow='source_to_target')
        phi_in = 2 * n_node_feats + n_edge_feats
        gamma_in = message_size + n_node_feats
        self.phi = MLP(phi_in, message_size)
        self.gamma = MLP(gamma_in, output_size)

    def forward(self, x, edge_index, edge_attr):
        # propagate() drives message -> aggregate -> update.
        return self.propagate(edge_index, x=x, edge_attr=edge_attr)

    def message(self, x_i, x_j, edge_attr):
        # PyG binds these arguments by name (the _i/_j suffixes select
        # target/source rows of x), so the parameter names must not change.
        edge_inputs = torch.cat([x_i, x_j, edge_attr], dim=1)
        return self.phi(edge_inputs)

    def update(self, aggr_out, x):
        # Concatenate each node's own features with its aggregated message.
        node_inputs = torch.cat([x, aggr_out], dim=1)
        return self.gamma(node_inputs)

class GNN(nn.Module):
    """Five stacked ``MPLayer``s followed by a per-graph mean pool.

    Node feature width goes 1 -> 8 -> 16 -> 16 -> 32 -> 1. Every layer uses one
    edge feature and 16-wide messages. The output is one scalar per graph.
    """

    def __init__(self):
        super(GNN, self).__init__()
        self.mpl1 = MPLayer(1, 1, 16, 8)
        self.mpl2 = MPLayer(8, 1, 16, 16)
        self.mpl3 = MPLayer(16, 1, 16, 16)
        self.mpl4 = MPLayer(16, 1, 16, 32)
        self.mpl5 = MPLayer(32, 1, 16, 1)

    def forward(self, data, batch=None):
        """Run all message-passing layers and mean-pool node outputs per graph.

        Args:
            data: graph object exposing ``x``, ``edge_index`` and ``edge_attr``.
            batch: node -> graph index vector, as used by ``global_mean_pool``.
                If omitted, ``data.batch`` is used when present. Otherwise all
                nodes are pooled into a single graph.

        Returns:
            Pooled tensor with one row per graph and one column.
        """
        x = data.x.float()
        # Assumes edge_attr is 1-D (one scalar per edge) -- TODO confirm.
        # unsqueeze gives shape (E, 1). Cast to float like x, so the torch.cat
        # inside MPLayer sees matching dtypes and the float Linear layers accept it.
        edge_attr = data.edge_attr.float().unsqueeze(1)
        edge_index = data.edge_index
        for layer in (self.mpl1, self.mpl2, self.mpl3, self.mpl4, self.mpl5):
            x = layer(x, edge_index, edge_attr)
        if batch is None:
            batch = getattr(data, 'batch', None)
        # With batch=None, global_mean_pool averages over all nodes (one graph).
        return global_mean_pool(x, batch=batch)
    
# Choose the compute device here so this cell does not depend on a later cell.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# Inspect one training graph's node-feature shape. GNN's first layer
# (MPLayer(1, 1, ...)) expects a single feature per node.
graph = train_dataset[0]
print(graph.x.shape)

# Build the model that the training cell and the final inspection cell use.
model = GNN().to(device)

In [None]:
import numpy as np
from torch import optim
from torch.optim.lr_scheduler import StepLR
from matplotlib import pyplot as plt
import torch.nn.functional as F

def train(model, train_loader, optimizer, loss_fcn, epoch):
    """Run one optimisation pass over ``train_loader``; return the mean batch loss.

    Args:
        model: module called as ``model(batch, batch_idx)``. Its output is
            squeezed along dim 1 and then compared with ``batch.y``.
        train_loader: iterable of batches exposing ``x`` and ``y`` (and
            ``batch`` when produced by a PyG DataLoader).
        optimizer: optimizer over ``model``'s parameters.
        loss_fcn: callable ``loss_fcn(prediction, target)`` returning a scalar tensor.
        epoch: epoch number; currently unused, kept for the call signature.

    Returns:
        Mean of the per-batch losses, or NaN if the loader yields no batches.
    """
    model.train()
    losses = []
    for batch in train_loader:
        optimizer.zero_grad()
        # Node -> graph assignment for pooling. PyG's DataLoader provides it
        # as ``batch.batch``. Otherwise, put every node in graph 0. Size the
        # vector by the node count (x.size(0)); x[0].shape[0] would be the
        # feature count of the first node.
        batch_idx = getattr(batch, 'batch', None)
        if batch_idx is None:
            batch_idx = torch.zeros(batch.x.size(0), dtype=torch.long,
                                    device=batch.x.device)
        output = model(batch, batch_idx).squeeze(1)
        loss = loss_fcn(output, batch.y)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
    return float(np.mean(losses)) if losses else float('nan')

def test(model, test_loader, loss_fcn):
    model.eval()
    losses = []
    with torch.no_grad():
        for i, batch in enumerate(test_loader):
            batch_idx = torch.zeros(batch.x[0].shape[0]).long().to(device)
            output = model(batch, batch_idx)
            y, output = batch.y, output.squeeze(1)
            loss = loss_fcn(output, y).item()
            losses.append(loss)
    return np.mean(losses)

# Train with Adam and a step-decay learning-rate schedule; report the L1 loss
# on train and test data each epoch. Needs `model`, `train_loader` and
# `test_loader` to exist already (the model-construction and DataLoader cells).
lr = 1e-2
optimizer = optim.Adam(model.parameters(), lr=lr)
# Multiply the learning rate by 0.9 every 5 epochs. This only takes effect
# because scheduler.step() is called once per epoch below.
scheduler = StepLR(optimizer, step_size=5, gamma=0.9)

loss_fcn = nn.L1Loss()
epochs = 100
log_every = 1  # print progress every `log_every` epochs
for epoch in range(1, epochs + 1):
    train_loss = train(model, train_loader, optimizer, loss_fcn, epoch)
    test_loss = test(model, test_loader, loss_fcn)
    scheduler.step()
    if epoch % log_every == 0:
        print('epoch={}: train_loss={:.10f}, test_loss={:.10f}'
              .format(epoch, train_loss, test_loss))

In [None]:
import torch
from torch_geometric.data import DataLoader

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
# Move every graph to `device` up front, so loader batches are already there.
test_dataset = [d.to(device) for d in test_dataset]
train_dataset = [d.to(device) for d in train_dataset]
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)
train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)

# Run one graph through the model. The pooling index must have one entry per
# node of *this* graph (x.size(0)). The feature count (x[0].shape[0]), or the
# size of a different graph, would be wrong here.
sample_graph = train_dataset[2]
batch_idx = torch.zeros(sample_graph.x.size(0), dtype=torch.long, device=device)
with torch.no_grad():
    print(model(sample_graph, batch_idx))
print(sample_graph)