In [9]:
import torch
from torch.nn import Linear, ReLU
from torch_geometric.data import Data
from torch_geometric.nn import GATConv, GCNConv, global_max_pool, global_mean_pool
# from torch.utils.data import DataLoader
from torch_geometric.loader import DataLoader

from sklearn.model_selection import train_test_split

import torch.nn.functional as F
from torch import nn
import pandas as pd
import numpy as np
from itertools import combinations
import yaml
from copy import deepcopy
import pickle
import os
from tqdm import tqdm
from collections import OrderedDict
from sklearn.preprocessing import StandardScaler



In [10]:
local=False
filename = "data_dict_removed_outliersAug4.pkl" #name of fake data generate pickle file

In [11]:
# Define the base path and the specific folder where your file is saved
base_path = os.path.dirname(os.getcwd())


# filename = "data_dict_removed_outliersAug3.pkl"


if local:
    local_path = "/Users/dimademler/Desktop/UCSD/Labs/EPFL2023/local_gnn_trying/"
    full_file_path = os.path.join(local_path, filename)
else:
    folder = "saved_files/fake_data"
    full_file_path = os.path.join(base_path, folder, filename)

# Load the data dictionary from the pickle file
with open(full_file_path, 'rb') as f:
    clean_data_dict = pickle.load(f)
print(clean_data_dict.keys())
print(clean_data_dict['2_phi'].shape)
data_dict=deepcopy(clean_data_dict)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

dict_keys(['event', 'genWeight', 'MET_phi', '1_phi', '1_genPartFlav', '2_phi', '2_genPartFlav', '3_phi', '3_genPartFlav', 'charge_1', 'charge_2', 'charge_3', 'pt_1', 'pt_2', 'pt_3', 'pt_MET', 'eta_1', 'eta_2', 'eta_3', 'mass_1', 'mass_2', 'mass_3', 'deltaphi_12', 'deltaphi_13', 'deltaphi_23', 'deltaphi_1MET', 'deltaphi_2MET', 'deltaphi_3MET', 'deltaphi_1(23)', 'deltaphi_2(13)', 'deltaphi_3(12)', 'deltaphi_MET(12)', 'deltaphi_MET(13)', 'deltaphi_MET(23)', 'deltaphi_1(2MET)', 'deltaphi_1(3MET)', 'deltaphi_2(1MET)', 'deltaphi_2(3MET)', 'deltaphi_3(1MET)', 'deltaphi_3(2MET)', 'deltaeta_12', 'deltaeta_13', 'deltaeta_23', 'deltaeta_1(23)', 'deltaeta_2(13)', 'deltaeta_3(12)', 'deltaR_12', 'deltaR_13', 'deltaR_23', 'deltaR_1(23)', 'deltaR_2(13)', 'deltaR_3(12)', 'pt_123', 'mt_12', 'mt_13', 'mt_23', 'mt_1MET', 'mt_2MET', 'mt_3MET', 'mt_1(23)', 'mt_2(13)', 'mt_3(12)', 'mt_MET(12)', 'mt_MET(13)', 'mt_MET(23)', 'mt_1(2MET)', 'mt_1(3MET)', 'mt_2(1MET)', 'mt_2(3MET)', 'mt_3(1MET)', 'mt_3(2MET)', 'ma

In [12]:
# Define the input data and labels according to the notebook
input_data_names_ordered = [
    ['MET_phi', 'pt_MET'], 
    ['1_phi', 'charge_1', 'pt_1', 'eta_1', 'mass_1'], 
    ['2_phi', 'charge_2', 'pt_2', 'eta_2', 'mass_2'], 
    ['3_phi', 'charge_3', 'pt_3', 'eta_3', 'mass_3']
]
input_data_particle_order = ['MET', '1', '2', '3']
used_labels2 = [
    ['deltaphi_1MET', 'mt_1MET'], 
    ['deltaphi_2MET', 'mt_2MET'], 
    ['deltaphi_3MET', 'mt_3MET'], 
    ['deltaphi_12', 'deltaeta_12', 'deltaR_12', 'mt_12', 'norm_mt_12'], 
    ['deltaphi_13', 'deltaeta_13', 'deltaR_13', 'mt_13', 'norm_mt_13'], 
    ['deltaphi_23', 'deltaeta_23', 'deltaR_23', 'mt_23', 'norm_mt_23']
]
edge_order = ["MET_1", "MET_2", "MET_3", "1_2", "1_3", "2_3"]

# Create dictionaries for input data and labels
input_data = {}
for particle, features in zip(input_data_particle_order, input_data_names_ordered):
    input_data[particle] = {feature: data_dict[feature] for feature in features if feature in data_dict}

labels = {}
for edge, features in zip(edge_order, used_labels2):
    labels[edge] = {feature: data_dict[feature] for feature in features if feature in data_dict}

print(labels.keys())
print(labels['MET_1'].keys())

print("----------------------------------")

def normalize_data(data):
    mean = np.mean(data)
    std = np.std(data)
    return (data - mean) / std

inputdatanamess_nocharge=deepcopy(input_data_names_ordered)
for i, list in enumerate(input_data_names_ordered):
    if "charge_"+str(i+1) in list:
        inputdatanamess_nocharge[i].remove("charge_"+str(i+1))  

# for particle, features in zip(input_data_particle_order, inputdatanamess_nocharge):
#     input_data[particle] = {feature: normalize_data(data_dict[feature]) for feature in features if feature in data_dict}

# for edge, features in zip(edge_order, used_labels2):
#     labels[edge] = {feature: normalize_data(data_dict[feature]) for feature in features if feature in data_dict}



dict_keys(['MET_1', 'MET_2', 'MET_3', '1_2', '1_3', '2_3'])
dict_keys(['deltaphi_1MET', 'mt_1MET'])
----------------------------------


## Normalization ##

In [13]:
scalers = {}

for particle, features in zip(input_data_particle_order, input_data_names_ordered):
    for feature in features:
        if feature in data_dict:
            scaler = StandardScaler()
            data = data_dict[feature].reshape(-1, 1)
            scaler.fit(data)
            data_dict[feature] = scaler.transform(data).flatten()
            scalers[(particle, feature)] = scaler

for edge, features in zip(edge_order, used_labels2):
    for feature in features:
        if feature in data_dict:
            scaler = StandardScaler()
            data = data_dict[feature].reshape(-1, 1)
            scaler.fit(data)
            data_dict[feature] = scaler.transform(data).flatten()
            scalers[(edge, feature)] = scaler

for particle, features in zip(input_data_particle_order, input_data_names_ordered):
    input_data[particle] = {feature: data_dict[feature] for feature in features if feature in data_dict}

for edge, features in zip(edge_order, used_labels2):
    labels[edge] = {feature: data_dict[feature] for feature in features if feature in data_dict}


In [14]:
edge_types = OrderedDict(zip(edge_order, used_labels2))
edge_mapping = {edge: i for i, edge in enumerate(edge_types.keys())}

print("edge_types:",edge_types)
print("edge_mapping:",edge_mapping)

edge_type_mapping = {0: 'MET_1', 1: 'MET_2', 2: 'MET_3', 3: '1_2', 4: '1_3', 5: '2_3'}

class NodeDNN(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(NodeDNN, self).__init__()
        self.network = nn.Sequential(
            nn.Linear(input_dim, output_dim),
            nn.ReLU(),
        )

    def forward(self, x):
        return self.network(x)


class EdgeModel(nn.Module):
    def __init__(self, input_dim, hidden_dim, edge_types):
        # for edge, edge_attributes in edge_types.items():
        #     print(2*input_dim + len(edge_attributes))
        print(4*input_dim, hidden_dim)
        super(EdgeModel, self).__init__()
        self.edge_networks = nn.ModuleDict({
            edge: nn.Sequential(
                nn.Linear(4*input_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, len(edge_attributes)),
                nn.ReLU()
            )
            
            for edge, edge_attributes in edge_types.items()
        })


    def forward(self, src, dest, edge_attr, edge_type):
        outputs = []
        for i, edge_type_i in enumerate(edge_type):
            src_i, dest_i = src[i], dest[i]
            edge_attr_i = edge_attr[i]

            # print(f"src_i shape: {src_i.shape}")
            # print(f"dest_i shape: {dest_i.shape}")
            # print(f"edge_attr_i shape: {edge_attr_i.shape}")

            # edge_input = torch.cat([src_i, dest_i, edge_attr_i], dim=0)  
            edge_input = torch.cat([src_i, dest_i, edge_attr_i], dim=-1) 
            # print("Size of edge_input: ", edge_input.shape) 
            # print("edge_networks:",self.edge_networks.keys())
            output = self.edge_networks[edge_type_mapping[edge_type_i.item()]](edge_input)
            outputs.append(output)
            # print(" /output shape:",output.shape)
        # print("outputs shape:",outputs)
        return outputs
        # return torch.cat(outputs, dim=0)




class GNN(nn.Module):
    def __init__(self, hidden_dim, edge_types, node_input_dim):
        super(GNN, self).__init__()
        self.node_dnn_MET = NodeDNN(2, hidden_dim, node_input_dim)
        self.node_dnn = NodeDNN(5, hidden_dim, node_input_dim)
        self.gcn1 = GCNConv(node_input_dim, 15)
        self.gcn2 = GCNConv(15, node_input_dim)
        self.edge_model = EdgeModel(node_input_dim, hidden_dim, edge_types)

    def forward(self, data):
        x, edge_index, edge_attr, edge_type = data.x, data.edge_index, data.edge_attr, data.edge_type
        # print("edge_type:",edge_type)
        x = self.gcn1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)
        x = self.gcn2(x, edge_index)
        # Transform the edge attributes
        row, col = edge_index
        out = self.edge_model(x[row], x[col], edge_attr, edge_type)

        return out


node_input_dim=5
nodednn_hidden_dim=10

node_dnn = NodeDNN(5, nodednn_hidden_dim, node_input_dim)
node_dnn_MET = NodeDNN(2, nodednn_hidden_dim, node_input_dim)
edge_model = EdgeModel(node_input_dim, nodednn_hidden_dim, edge_types)
gnn = GNN(nodednn_hidden_dim, edge_types, node_input_dim)  # change input size to 5



edge_types: OrderedDict([('MET_1', ['deltaphi_1MET', 'mt_1MET']), ('MET_2', ['deltaphi_2MET', 'mt_2MET']), ('MET_3', ['deltaphi_3MET', 'mt_3MET']), ('1_2', ['deltaphi_12', 'deltaeta_12', 'deltaR_12', 'mt_12', 'norm_mt_12']), ('1_3', ['deltaphi_13', 'deltaeta_13', 'deltaR_13', 'mt_13', 'norm_mt_13']), ('2_3', ['deltaphi_23', 'deltaeta_23', 'deltaR_23', 'mt_23', 'norm_mt_23'])])
edge_mapping: {'MET_1': 0, 'MET_2': 1, 'MET_3': 2, '1_2': 3, '1_3': 4, '2_3': 5}
20 10
20 10


In [15]:
already_saved = True # Set to true if already preprocessed the data and saved it
processed_name="gnn_processed_100k.pt"


In [16]:
gnn_input_data = []
max_label_len = max(len(labels) for labels in used_labels2)
# processed_name="gnn_processed_1M.pt"
folder2 = os.path.join(base_path, 'saved_files', 'gnnfeatregr_processed_data')

if not os.path.exists(folder2):
    already_saved = False


os.makedirs(folder2, exist_ok=True)

savepath= os.path.join(folder2, processed_name)


if not already_saved:
    for i in tqdm(range(len(data_dict['1_phi'])), desc='events'):
        x = []
        for j, particle in enumerate(input_data_particle_order):
            nodedata = [data_dict[feature][i] for feature in input_data_names_ordered[j]]
            nodedata = torch.tensor(nodedata).float()

            # Transform the node here
            if j == 0:  # This is the MET node
                node = node_dnn_MET(nodedata)
            else:  # This is one of the l123 nodes
                node = node_dnn(nodedata)
            x.append(node)
            
        # Now all node features should have the same size, and you can stack them
        x = torch.stack(x, dim=0)
        edges = []
        edge_attrs = []
        edge_labels = []
        edge_types = []
        for j, edge in enumerate(edge_order):
            particle_mapping = {'MET': 0, '1': 1, '2': 2, '3': 3}
            edge_index = [particle_mapping[particle] for particle in edge.split('_')]
            edges.append(edge_index)

            # Create the edge attributes by concatenating the node features
            node1_features = x[edge_index[0]]
            node2_features = x[edge_index[1]]
            edge_attr = torch.cat((node1_features, node2_features))
            edge_attrs.append(edge_attr)
            
            edge_label = [data_dict[feature][i] for feature in used_labels2[j]]
            edge_label += [-1] * (max_label_len - len(edge_label))  # Padding with -1
            edge_label = torch.tensor(edge_label)
            edge_labels.append(edge_label)
            # edge_types.append(edge) 
            edge_types.append(edge_mapping[edge])
        edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
        edge_attr = torch.stack(edge_attrs)
        y = torch.stack(edge_labels)
        edge_type = torch.tensor(edge_types, dtype=torch.long)

        data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y, edge_type=edge_type)
        gnn_input_data.append(data)

    torch.save(gnn_input_data, savepath)

else:
    gnn_input_data = torch.load(savepath)


In [17]:
# Print out the actual tensor data for the first Data object
# data0 = gnn_input_data[0]
# print("x:", data0.x)
# print("edge_index:", data0.edge_index)
# print("edge_attr:", data0.edge_attr)
# print("y:", data0.y)
# print("edge_type:", data0.edge_type)


# Split the data into a training set and a validation set
train_data, val_data = train_test_split(gnn_input_data, test_size=0.2, random_state=42)

# Convert the lists to PyTorch Geometric DataLoader objects
train_loader = DataLoader(train_data, batch_size=640, shuffle=True)
val_loader = DataLoader(val_data, batch_size=320, shuffle=False)


# Define the loss function and the optimizer
criterion = nn.MSELoss()


In [18]:
#TODO: change loss function to log(tanh)
def custom_loss(outs, targets):
    diffs = []
    for i in range(0, len(outs), 6):
        for j in range(6):
            out = outs[i+j]
            target = targets[i+j]
            mask = target != -1  # Mask to ignore padding values
            filtered_target = target[mask]
            filtered_out = out[:len(filtered_target)]  # Adjust out to match the size of filtered_target
            diff = (filtered_out - filtered_target)**2  # Square differences
            diffs.append(diff)
    return torch.cat(diffs).mean()  # Compute the mean of all diffs

def train(model, data_loader, optimizer, device):
    model.train()
    total_loss = 0
    for batch in tqdm(data_loader, desc="Training batch"):
        batch = batch.to(device)  # Move the batch data to the GPU
        optimizer.zero_grad()
        outs = model(batch)
        loss = custom_loss(outs, batch.y)
        loss.backward(retain_graph=True)
        optimizer.step()
        total_loss += loss.item()
    return total_loss

def validate(model, data_loader, device):
    model.eval()
    total_val_loss = 0
    with torch.no_grad():
        for batch in tqdm(data_loader, desc="Validation batch", disable=True):
            batch = batch.to(device)  # Move the batch data to the GPU
            outs = model(batch)
            val_loss = custom_loss(outs, batch.y)
            total_val_loss += val_loss.item()
    return total_val_loss

In [19]:
gnn = gnn.to(device)

optimizer = torch.optim.Adam(gnn.parameters(), lr=0.001)

n_epochs = 5

for epoch in tqdm(range(n_epochs), desc='Epoch', disable=True):
    train_loss = train(gnn, train_loader, optimizer, device)
    val_loss = validate(gnn, val_loader, device)
    print('Epoch: {:03d}, Training Loss: {:.4f}, Validation Loss: {:.4f}'.format(epoch, train_loss, val_loss))


Training batch: 100%|██████████| 108/108 [03:26<00:00,  1.91s/it]


Epoch: 000, Training Loss: 106.9818, Validation Loss: 51.6501


Training batch: 100%|██████████| 108/108 [03:29<00:00,  1.94s/it]


Epoch: 001, Training Loss: 101.6927, Validation Loss: 50.0146


Training batch: 100%|██████████| 108/108 [03:29<00:00,  1.94s/it]


Epoch: 002, Training Loss: 99.7730, Validation Loss: 49.4698


Training batch: 100%|██████████| 108/108 [03:27<00:00,  1.92s/it]


Epoch: 003, Training Loss: 98.9140, Validation Loss: 49.1445


Training batch: 100%|██████████| 108/108 [03:27<00:00,  1.92s/it]


Epoch: 004, Training Loss: 98.3380, Validation Loss: 48.9048
