In [None]:
import torch
import torch_geometric

from torch import nn
import torch.nn.functional as F

from torch_geometric.nn.conv import GraphConv, EdgeConv, GCNConv

from torch_cluster import radius_graph, knn_graph

In [None]:
def _mlp(dims):
    """Build a stack of ``nn.Linear`` layers through the sizes listed in ``dims``.

    A ``ReLU`` sits between consecutive linear layers and there is no activation after the last
    one, e.g. ``_mlp([a, b, c])`` is ``Sequential(Linear(a, b), ReLU(), Linear(b, c))``.
    """
    layers = []
    for d_in, d_out in zip(dims[:-1], dims[1:]):
        if layers:
            layers.append(nn.ReLU())
        layers.append(nn.Linear(d_in, d_out))
    return nn.Sequential(*layers)


def _edgeconv_stack(hidden_dim, depth):
    """Build ``depth`` TorchScript-compatible EdgeConv layers mapping hidden_dim -> hidden_dim.

    Each layer's message MLP takes ``2 * hidden_dim`` inputs because EdgeConv feeds it the
    concatenation of two node feature vectors.
    """
    convs = nn.ModuleList()
    for _ in range(depth):
        mesg = _mlp([2 * hidden_dim, 3 * hidden_dim // 2, hidden_dim])
        convs.append(EdgeConv(nn=mesg).jittable())
    return convs


class GraphMETNetwork(nn.Module):
    """Graph network producing a per-node output from continuous and categorical node features.

    The continuous and categorical features are embedded by separate MLPs. Each embedding is
    refined by its own stack of EdgeConv layers over the same graph. The two results are then
    concatenated and passed through an output MLP.

    Args:
        continuous_dim: number of continuous input features per node.
        cat_dim: number of categorical input features per node.
        output_dim: number of outputs per node.
        hidden_dim: width of the embeddings and graph convolutions.
        conv_depth: number of EdgeConv layers in each of the two branches.

    Notes:
        * Appending ``nn.BatchNorm1d(hidden_dim)`` to the embedding / message MLPs is a possible
          remedy if the model starts overtraining.
        * ``GCNConv(hidden_dim, hidden_dim).jittable()`` can be substituted for EdgeConv in
          ``_edgeconv_stack``.
    """

    def __init__(self, continuous_dim, cat_dim, output_dim=1, hidden_dim=32, conv_depth=1):
        super(GraphMETNetwork, self).__init__()

        # Construction order (and hence parameter order / initialisation) matches the layout
        # embed_continuous, embed_categorical, conv_continuous, conv_categorical, output.
        self.embed_continuous = _mlp([continuous_dim, hidden_dim, hidden_dim, hidden_dim])
        self.embed_categorical = _mlp([cat_dim, hidden_dim, hidden_dim, hidden_dim])

        self.conv_continuous = _edgeconv_stack(hidden_dim, conv_depth)
        self.conv_categorical = _edgeconv_stack(hidden_dim, conv_depth)

        # The input is the concatenation of the two hidden_dim-wide branch outputs.
        self.output = _mlp([2 * hidden_dim, hidden_dim, hidden_dim // 2, output_dim])

    def forward(self, x_cont, x_cat, edge_index, batch):
        """Return the per-node output, shape (num_nodes, output_dim).

        ``batch`` is accepted for interface compatibility but is not used by this network.
        """
        emb_cont = self.embed_continuous(x_cont)
        emb_cat = self.embed_categorical(x_cat)

        # graph convolution for continuous variables
        for co_conv in self.conv_continuous:
            emb_cont = co_conv(emb_cont, edge_index)

        # graph convolution for discrete variables
        for ca_conv in self.conv_categorical:
            emb_cat = ca_conv(emb_cat, edge_index)

        # concatenate embeddings together to make description of weight inputs
        emb = torch.cat([emb_cont, emb_cat], dim=1)

        return self.output(emb)

class GraphMET(nn.Module):
    """Per-node weight model: a GraphMETNetwork with one output per node, passed through a sigmoid.

    Args:
        continuous_dim: number of continuous input features per node.
        categorical_dim: number of categorical input features per node.
        hidden_dim: width of the hidden layers of the underlying network.
        conv_depth: number of EdgeConv layers per branch of the underlying network.
    """

    def __init__(self, continuous_dim, categorical_dim, hidden_dim=32, conv_depth=1):
        super(GraphMET, self).__init__()
        self.graphnet = GraphMETNetwork(continuous_dim, categorical_dim,
                                        output_dim=1, hidden_dim=hidden_dim,
                                        conv_depth=conv_depth)

    def forward(self, x_cont, x_cat, edge_index, batch):
        """Return sigmoid-squashed per-node weights with shape (num_nodes, 1)."""
        weights = self.graphnet(x_cont, x_cat, edge_index, batch)
        return torch.sigmoid(weights)

In [None]:
from torch_geometric.data import (Data, Dataset)
import glob
import os.path as osp

# Placeholder dataset (the raw .npz inputs are not available): reads already-processed .pt files directly
class METDataset(Dataset):
    """PyTorch Geometric dataset built from processed hit information.

    Every ``*.pt`` file in ``<root>/processed`` holds one pre-processed item, which ``get`` loads
    with ``torch.load``. Nothing is downloaded or processed here; the files must already exist.

    Args:
        root: dataset root directory; items are read from its ``processed`` subdirectory.
    """

    # Bookkeeping files that PyG's Dataset base class may write into processed_dir
    # (see ``torch_geometric.data.Dataset._process``); they are not data items.
    _PYG_BOOKKEEPING = ('pre_transform.pt', 'pre_filter.pt')

    def __init__(self, root):
        super(METDataset, self).__init__(root)

    def download(self):
        pass  # download from xrootd or something later

    def _event_files(self):
        """Return the full paths of all item files, scanning ``processed_dir`` once and caching.

        The list is sorted so that item indices are reproducible across runs.
        """
        if not hasattr(self, 'input_files'):
            found = sorted(glob.glob(osp.join(self.processed_dir, '*.pt')))
            self.input_files = [f for f in found
                                if osp.basename(f) not in self._PYG_BOOKKEEPING]
        return self.input_files

    @property
    def raw_file_names(self):
        # There is no separate raw stage: the raw names mirror the processed item files.
        return [osp.basename(f) for f in self._event_files()]

    @property
    def processed_file_names(self):
        # PyG builds ``processed_paths`` by joining these names onto ``processed_dir``, so return
        # base names; full paths would be doubled up when ``root`` is a relative path.
        if not hasattr(self, 'processed_files'):
            self.processed_files = self._event_files()
        return [osp.basename(f) for f in self.processed_files]

    def len(self):
        """Number of items, one per file (PyG's ``Dataset.__len__`` relies on this)."""
        return len(self.processed_file_names)

    def get(self, idx):
        """Load and return item ``idx`` from its ``.pt`` file."""
        # NOTE: torch.load unpickles the file, which can execute arbitrary code; only point this
        # dataset at trusted data.
        return torch.load(self.processed_paths[idx])

    def process(self):
        pass

import os

# Dataset root; items are read from <root>/processed. Override with the GRAPHMET_DATA
# environment variable rather than editing a user-specific absolute path here.
DATA_ROOT = os.environ.get('GRAPHMET_DATA', '/home/lagray/graphmet/data')
dataset = METDataset(root=DATA_ROOT)

try:
    from torch_geometric.loader import DataLoader  # PyG >= 2.0 location
except ImportError:
    from torch_geometric.data import DataLoader  # older PyG releases

# batch_size=1: the loss in the training cell sums weighted momenta over every node it is given,
# so as written it assumes one event per batch.
train_data = DataLoader(dataset, batch_size=1, shuffle=True)

In [None]:
# Inputs: 7 continuous features [px, py, pt, eta, d0, dz, mass] and 3 categorical [pdgid, charge, fromPV].
# Fall back to the CPU when no GPU is available instead of failing on .to('cuda').
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = torch.jit.script(GraphMET(7, 3)).to(DEVICE)

In [None]:
from torch_geometric.utils import to_undirected

def f_loss(weights, px, py, true_px, true_py, batch=None):
    """Squared-error loss between the weighted momentum sum and the true MET vector.

    Computes ``0.5 * ((sum_i w_i*px_i + true_px)**2 + (sum_i w_i*py_i + true_py)**2)``.
    This is zero when the weighted momentum sum exactly cancels the true MET,
    i.e. ``sum_i w_i * p_i == -MET``.

    Args:
        weights: per-node weights, shape (N,) or (N, 1).
        px, py: per-node momentum components, shape (N,).
        true_px, true_py: true MET components (one value per event).
        batch: optional (N,) integer tensor giving each node's event index. When given, the
            weighted sums are formed per event and the loss is averaged over events. When
            omitted, all nodes are summed together as a single event.

    Returns:
        The loss tensor.
    """
    # Flatten (N, 1) weights to (N,): multiplying (N, 1) by (N,) would broadcast to (N, N) and
    # silently compute sum(w) * sum(px) instead of sum(w * px).
    w = weights.reshape(-1)
    wpx = w * px
    wpy = w * py
    if batch is None:
        return 0.5 * ((wpx.sum() + true_px) ** 2 + (wpy.sum() + true_py) ** 2)
    n_events = int(batch.max()) + 1
    # Out-of-place index_add keeps the per-event sums differentiable w.r.t. the weights.
    sum_px = torch.zeros(n_events, dtype=wpx.dtype, device=wpx.device).index_add(0, batch, wpx)
    sum_py = torch.zeros(n_events, dtype=wpy.dtype, device=wpy.device).index_add(0, batch, wpy)
    return (0.5 * ((sum_px + true_px) ** 2 + (sum_py + true_py) ** 2)).mean()

# eta-phi radius within which two particles are connected by an edge.
deltaR = 0.4
N_EPOCHS = 100
# Feature layout of data.x: columns [0, N_CONT) feed the continuous embedding and columns
# [CAT_START, ...) the categorical one.
# NOTE(review): column 7 is skipped entirely -- presumably deliberate; confirm against the inputs.
N_CONT = 7
CAT_START = 8

opt = torch.optim.AdamW(model.parameters(), lr=0.001)
# Train on whatever device the model lives on instead of hard-coding 'cuda'.
device = next(model.parameters()).device

for e in range(N_EPOCHS):
    model.train()
    avg_loss = 0.
    n_steps = 0
    for data in train_data:
        # Skip anything that is not a graph (e.g. a non-event .pt file picked up by the dataset).
        if not isinstance(data, Data):
            continue
        opt.zero_grad()
        data = data.to(device)
        x_cont = data.x[:, :N_CONT]
        x_cat = data.x[:, CAT_START:]

        # Build the graph in (eta, phi); px and py are columns 0 and 1, eta is column 3.
        phi = torch.atan2(data.x[:, 1], data.x[:, 0])
        etaphi = torch.stack([data.x[:, 3], phi], dim=1)

        # NB: there is a problem right now for comparing hits at the +/- pi boundary
        edge_index = radius_graph(etaphi, r=deltaR, batch=data.batch, loop=True, max_num_neighbors=255)

        # The model returns one weight per node with shape (N, 1); flatten it to (N,) so that
        # weights * px in f_loss is elementwise instead of broadcasting to an (N, N) matrix.
        out = model(x_cont, x_cat, edge_index, data.batch).squeeze(-1)

        true_MET = data.y[:, 0]
        true_METphi = data.y[:, 1]
        loss = f_loss(out, data.x[:, 0], data.x[:, 1],
                      true_MET * torch.cos(true_METphi), true_MET * torch.sin(true_METphi))

        loss.backward()
        opt.step()
        avg_loss += loss.item()
        n_steps += 1
    # Mean loss over the batches actually used for training in this epoch.
    print(e, ':', avg_loss / max(n_steps, 1))
        
    