In [1]:
import torch
import torch_geometric
device =  torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
from tqdm.notebook import tqdm
import numpy as np
import os
import sys

root_file = "/home/dmngo/test-train-hep"

In [2]:
import yaml

with open(root_file + '/reg_defs.yml') as file:
    # The FullLoader parameter handles the conversion from YAML
    # scalar values to Python the dictionary format
    definitions = yaml.load(file, Loader=yaml.FullLoader)
    
features = definitions['features']
spectators = definitions['spectators']
labels = definitions['labels']

nfeatures = definitions['nfeatures']
nspectators = definitions['nspectators']
nlabels = definitions['nlabels']
ntracks = definitions['ntracks']

In [3]:
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric.transforms as T
from torch_geometric.nn import EdgeConv, global_mean_pool
from torch.nn import Sequential as Seq, Linear as Lin, ReLU, BatchNorm1d, Conv2d
from torch_scatter import scatter_mean
from torch_geometric.nn import MetaLayer

inputs = 19
hidden = 100
outputs = 1

class EdgeBlock(torch.nn.Module):
    def __init__(self):
        super(EdgeBlock, self).__init__()
        self.edge_mlp = Seq(Lin(inputs*2, hidden), 
                            BatchNorm1d(hidden),
                            ReLU(),
                            Lin(hidden, hidden),
                            BatchNorm1d(hidden),
                            ReLU(),
                            Lin(hidden, hidden))

    def forward(self, src, dest, edge_attr, u, batch):
        out = torch.cat([src, dest], 1)
        return self.edge_mlp(out)

class NodeBlock(torch.nn.Module):
    def __init__(self):
        super(NodeBlock, self).__init__()
        self.node_mlp_1 = Seq(Lin(inputs+hidden, hidden), 
                              BatchNorm1d(hidden),
                              ReLU(), 
                              Lin(hidden, hidden),
                              BatchNorm1d(hidden),
                              ReLU(),
                              Lin(hidden, hidden))
        self.node_mlp_2 = Seq(Lin(inputs+hidden, hidden), 
                              BatchNorm1d(hidden),
                              ReLU(), 
                              Lin(hidden, hidden),
                              BatchNorm1d(hidden),
                              ReLU(),
                              Lin(hidden, hidden))

    def forward(self, x, edge_index, edge_attr, u, batch):
        row, col = edge_index
        out = torch.cat([x[row], edge_attr], dim=1)
        out = self.node_mlp_1(out)
        out = scatter_mean(out, col, dim=0, dim_size=x.size(0))
        out = torch.cat([x, out], dim=1)
        return self.node_mlp_2(out)

    
class GlobalBlock(torch.nn.Module):
    def __init__(self):
        super(GlobalBlock, self).__init__()
        self.global_mlp = Seq(Lin(hidden, hidden),                               
                              BatchNorm1d(hidden),
                              ReLU(),
                              Lin(hidden, hidden),
                              BatchNorm1d(hidden),
                              ReLU(),
                              Lin(hidden, outputs))

    def forward(self, x, edge_index, edge_attr, u, batch):
        out = scatter_mean(x, batch, dim=0)
        return self.global_mlp(out)


class InteractionNetwork(torch.nn.Module):
    def __init__(self):
        super(InteractionNetwork, self).__init__()
        self.interactionnetwork = MetaLayer(EdgeBlock(), NodeBlock(), GlobalBlock())
        self.bn = BatchNorm1d(inputs)
        
    def forward(self, x, edge_index, batch):
        
        x = self.bn(x)
        x, edge_attr, u = self.interactionnetwork(x, edge_index, None, None, batch)
        return u
    
# model = InteractionNetwork().to(device)
# optimizer = torch.optim.Adam(model.parameters(), lr = 1e-2)

In [4]:
model = InteractionNetwork().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr = 2e-2)

In [5]:
#ported from GNN notebook test and train based on GNN MAE Loss
@torch.no_grad()
def test(model, loader, total, batch_size, leave=False):
    model.eval()
    
    msel = nn.L1Loss()

    sum_loss = 0.
    t = tqdm(enumerate(loader), total=total/batch_size, leave=leave)
    for i, data in t:
        data = data.to(device)
        y = data.y
        batch_output = model(data.x, data.edge_index, data.batch)
        batch_loss_item = msel(batch_output, y).item()
        sum_loss += batch_loss_item
        t.set_description("loss = %.5f" % (batch_loss_item))
        t.refresh() # to show immediately the update

    return sum_loss/i+1
def train(model, optimizer, loader, total, batch_size, leave=False):
    model.train()
    
    msel = nn.L1Loss()

    sum_loss = 0.
    t = tqdm(enumerate(loader), total=total/batch_size, leave=leave)
    for i, data in t:
        data = data.to(device)
        y = data.y
        optimizer.zero_grad()
        batch_output = model(data.x, data.edge_index, data.batch)
        batch_loss = msel(batch_output.float(), y.float())
        batch_loss.backward()
        batch_loss_item = batch_loss.item()
        t.set_description("loss = %.5f" % batch_loss_item)
        t.refresh() # to show immediately the update
        sum_loss += batch_loss_item
        optimizer.step()
    
    return sum_loss/i+1

In [6]:
from torch_geometric.data import Data, DataListLoader, Batch
from torch.utils.data import random_split
import os.path as osp
from GraphDataset import GraphDataset
import glob
import shutil
data_dir = '/teams/DSC180A_FA21_A00/a11'

import random

def collate(items):
    l = sum(items, [])
    return Batch.from_data_list(l)

In [60]:
#train_files = ['/teams/DSC180A_FA21_A00/a11/train_mass_hbb/BulkGravitonToHHTo4Q_MX-600to6000_MH-15to250_part1_TuneCP5_13TeV-madgraph_pythia8/nano_mc2017_1-1_Skim.root']
#all_files = glob.glob('/home/dmngo/teams/DSC180A_FA21_A00/a11/train/*.root')#retrieve all data files and sort them
p1_train_hbb = glob.glob(data_dir + '/train_mass_hbb/BulkGravitonToHHTo4Q_MX-600to6000_MH-15to250_part1_TuneCP5_13TeV-madgraph_pythia8/*.root')
p2_train_hbb = glob.glob(data_dir + '/train_mass_hbb/BulkGravitonToHHTo4Q_MX-600to6000_MH-15to250_part2_TuneCP5_13TeV-madgraph_pythia8/*.root')
p1_train_qcd = glob.glob(data_dir + '/train_mass_qcd/QCD_HT500to700_TuneCP5_13TeV-madgraph-pythia8/*.root')
p2_train_qcd = glob.glob(data_dir + '/train_mass_qcd/QCD_HT700to1000_TuneCP5_13TeV-madgraph-pythia8/*.root')
p3_train_qcd = glob.glob(data_dir + '/train_mass_qcd/QCD_HT1000to1500_TuneCP5_13TeV-madgraph-pythia8/*.root')
p4_train_qcd = glob.glob(data_dir + '/train_mass_qcd/QCD_HT1500to2000_TuneCP5_13TeV-madgraph-pythia8/*.root')
p5_train_qcd = glob.glob(data_dir + '/train_mass_qcd/QCD_HT2000toInf_TuneCP5_13TeV-madgraph-pythia8/*.root')
#all_files.sort()

all_train_hbb = p1_train_hbb + p2_train_hbb
all_train_qcd = p1_train_qcd + p2_train_qcd + p3_train_qcd + p4_train_qcd + p5_train_qcd

some_hbb = random.sample(all_train_hbb, 1)
some_qcd = random.sample(all_train_qcd, 2)

#data_split = .8 
#elem_split = int(len(all_files) * data_split)

#train_files = ['/home/dmngo/teams/DSC180A_FA21_A00/a11/ntuple_merged_10.root']
train_files = some_hbb + some_qcd
random.shuffle(train_files)
#train_files = ['/teams/DSC180A_FA21_A00/a11/train_mass_hbb/BulkGravitonToHHTo4Q_MX-600to6000_MH-15to250_part1_TuneCP5_13TeV-madgraph_pythia8/nano_mc2017_1-1_Skim.root']


#parameters for training per dataset
n_epochs = 10 
best_valid_loss = 9999999
patience = 5
t = tqdm(range(0, n_epochs))
dataset = 0
bump_size = 3 #how many datasets at once

#train validation split
valid_frac = 0.20
batch_size = 32

# graph_dataset = GraphDataset('gdata_train', features, labels, spectators, n_events=1000, n_events_merge=1, 
#                          file_names= train_files)

#run on each dataset
for i in range(0,len(train_files), bump_size):
    #delete old processed files
    dir_path = '/home/dmngo/test-train-hep/data/gdata_train'
    try:
        shutil.rmtree(dir_path)
    except OSError as e:
        print("Error: %s : %s" % (dir_path, e.strerror))
        
    #what dataset is it on? 
    print('Dataset:' + str(dataset))
    
    #reset stale epocs to not skip a dataset
    stale_epochs = 0
    
    #load in next dataset
    graph_dataset = GraphDataset(dir_path, features, labels, spectators, n_events=3000, n_events_merge=1, 
                             file_names= train_files[i:i+bump_size])
    full_length = len(graph_dataset)
    valid_num = int(valid_frac*full_length)
    
    #pretty sure this can be outside the loop
    torch.manual_seed(0)

    #split based on elements declared above
    train_dataset, valid_dataset = random_split(graph_dataset, [full_length-valid_num,valid_num])
    
    #train and collate
    train_loader = DataListLoader(train_dataset, batch_size=batch_size, pin_memory=True, shuffle=True)
    train_loader.collate_fn = collate
    valid_loader = DataListLoader(valid_dataset, batch_size=batch_size, pin_memory=True, shuffle=False)
    valid_loader.collate_fn = collate

    train_samples = len(train_dataset)
    valid_samples = len(valid_dataset)
    
    for epoch in t:
        if stale_epochs >= patience:
            print('next dataset after %i stale epochs'%patience)
            break
        loss = train(model, optimizer, train_loader, train_samples, batch_size, leave=bool(epoch==n_epochs-1))
        valid_loss = test(model, valid_loader, valid_samples, batch_size, leave=bool(epoch==n_epochs-1))
        print('Epoch: {:02d}, Training Loss:   {:.4f}'.format(epoch, loss))
        print('           Validation Loss: {:.4f}'.format(valid_loss))

        if valid_loss < best_valid_loss:
            best_valid_loss = valid_loss
            modpath = osp.join('interactionnetwork_best.pth')
            print('New best model saved to:',modpath)
            torch.save(model.state_dict(),modpath)
            stale_epochs = 0
        else:
            print('Stale epoch')
            stale_epochs += 1
    dataset += 1

  0%|          | 0/10 [00:00<?, ?it/s]

Dataset:0


Processing...


  0%|          | 0/3000 [00:00<?, ?it/s]

  0%|          | 0/3000 [00:00<?, ?it/s]

  0%|          | 0/3000 [00:00<?, ?it/s]

Done!


  0%|          | 0/75.0 [00:00<?, ?it/s]

  0%|          | 0/18.75 [00:00<?, ?it/s]

Epoch: 00, Training Loss:   83.8011
           Validation Loss: 80.3594
New best model saved to: interactionnetwork_best.pth


  0%|          | 0/75.0 [00:00<?, ?it/s]

  0%|          | 0/18.75 [00:00<?, ?it/s]

Epoch: 01, Training Loss:   80.7809
           Validation Loss: 80.4567
Stale epoch


  0%|          | 0/75.0 [00:00<?, ?it/s]

  0%|          | 0/18.75 [00:00<?, ?it/s]

Epoch: 02, Training Loss:   80.2659
           Validation Loss: 80.3335
New best model saved to: interactionnetwork_best.pth


  0%|          | 0/75.0 [00:00<?, ?it/s]

  0%|          | 0/18.75 [00:00<?, ?it/s]

Epoch: 03, Training Loss:   80.3882
           Validation Loss: 80.4992
Stale epoch


  0%|          | 0/75.0 [00:00<?, ?it/s]

  0%|          | 0/18.75 [00:00<?, ?it/s]

Epoch: 04, Training Loss:   80.4516
           Validation Loss: 80.7770
Stale epoch


  0%|          | 0/75.0 [00:00<?, ?it/s]

  0%|          | 0/18.75 [00:00<?, ?it/s]

Epoch: 05, Training Loss:   80.2473
           Validation Loss: 80.2721
New best model saved to: interactionnetwork_best.pth


  0%|          | 0/75.0 [00:00<?, ?it/s]

  0%|          | 0/18.75 [00:00<?, ?it/s]

Epoch: 06, Training Loss:   80.4599
           Validation Loss: 91.1666
Stale epoch


  0%|          | 0/75.0 [00:00<?, ?it/s]

  0%|          | 0/18.75 [00:00<?, ?it/s]

Epoch: 07, Training Loss:   80.0750
           Validation Loss: 79.7950
New best model saved to: interactionnetwork_best.pth


  0%|          | 0/75.0 [00:00<?, ?it/s]

  0%|          | 0/18.75 [00:00<?, ?it/s]

Epoch: 08, Training Loss:   80.0464
           Validation Loss: 80.0551
Stale epoch


  0%|          | 0/75.0 [00:00<?, ?it/s]

  0%|          | 0/18.75 [00:00<?, ?it/s]

Epoch: 09, Training Loss:   80.2344
           Validation Loss: 80.3181
Stale epoch


In [7]:
modpath = osp.join('interactionnetwork_best.pth')
model.load_state_dict(torch.load(modpath))
model.eval()
batch_size = 32

In [8]:
from GraphDataset import GraphDataset
test_files = ["/home/dmngo/teams/DSC180A_FA21_A00/a11/test_mass_hbb/BulkGravitonToHHTo4Q_MX-600to6000_MH-15to250_part1_TuneCP5_13TeV-madgraph_pythia8/nano_mc2017_3-10_Skim.root"]
#delete old processed files
dir_path = '/home/dmngo/test-train-hep/src/data/gdata_test'
try:
    shutil.rmtree(dir_path)
except OSError as e:
    print("Error: %s : %s" % (dir_path, e.strerror))
test_dataset = GraphDataset(dir_path, features, labels, spectators, n_events=2000, n_events_merge=1, 
                            file_names=test_files)

test_loader = DataListLoader(test_dataset, batch_size=batch_size, pin_memory=True, shuffle=False)
test_loader.collate_fn = collate
test_samples = len(test_dataset)

Processing...


  0%|          | 0/2000 [00:00<?, ?it/s]

Done!


In [9]:
model.eval()
t = tqdm(enumerate(test_loader),total=test_samples/batch_size)
y_test = []
y_predict = []
for i,data in t:
    data = data.to(device)    
    batch_output = model(data.x, data.edge_index, data.batch)    
    y_predict.append(batch_output.detach().cpu().numpy())
    y_test.append(data.y.cpu().numpy())
y_test = np.concatenate(y_test)
y_predict = np.concatenate(y_predict)

  0%|          | 0/62.5 [00:00<?, ?it/s]

In [10]:
y_test

array([[193],
       [127],
       [581],
       ...,
       [104],
       [130],
       [149]])

In [11]:
y_predict

array([[52.947575],
       [62.013107],
       [46.89854 ],
       ...,
       [46.3341  ],
       [57.547234],
       [44.35313 ]], dtype=float32)

In [12]:
np.sqrt(np.mean((y_test - y_predict) ** 2))

272.82726686076774

In [40]:
y_test

array([[171],
       [182],
       [119],
       ...,
       [170],
       [189],
       [159]])

In [41]:
y_predict

array([[45.082127 ],
       [ 5.6559086],
       [71.140076 ],
       ...,
       [19.494946 ],
       [38.086517 ],
       [18.608244 ]], dtype=float32)

In [15]:
y_test #1 hidden layer

array([[150],
       [ 97],
       [150],
       ...,
       [120],
       [  3],
       [107]])

In [17]:
y_predict

array([[67.05142 ],
       [76.40093 ],
       [87.14396 ],
       ...,
       [84.81449 ],
       [84.3029  ],
       [84.458755]], dtype=float32)