# In [6]:
# Code for creating the graph data structures from labeled_data.csv, which is obtained by executing labeling.ipynb.
# Each event is represented as a single graph.
# Hit coordinates (i.e. x_dig, y_dig, z_dig) are used as node features, and the time difference between consecutive hits is used as the edge feature.
# The following code snippet can be used in Colab.
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os
from scipy.spatial.distance import cdist
import torch
import shutil
import dgl
import  networkx as nx

# Load the labeled table written by labeling.ipynb; the graph-building loop
# below reads the graphID, x_dig, y_dig, z_dig and t columns from it.
df_csv_concat = pd.read_csv("labeled_data.csv")

# NOTE(review): DGL appears to select its backend when `dgl` is imported, so
# assigning DGLBACKEND here (after the import at the top of the file) most
# likely does not affect the current session. PyTorch is DGL's default
# backend, so the result is the same; to make the choice explicit, move this
# assignment above `import dgl`.
os.environ['DGLBACKEND'] = 'pytorch'

# Create the folder the graphs are saved into. exist_ok=True replaces the
# separate "exists?" check and is a no-op when the folder is already there.
os.makedirs('saved_graphs', exist_ok=True)

# Build one graph per event (graphID) and save it in the 'saved_graphs' folder.
# The nodes of a graph are the rows of that event, in CSV row order; node i is
# connected to node i+1 by one directed edge, so n nodes give n-1 edges.
# NOTE(review): the edge feature is the difference of 't' between consecutive
# rows, which assumes the rows of an event are already in the intended
# (presumably time) order -- confirm against labeling.ipynb.
min_nodes = 5  # events with fewer than min_nodes rows are skipped

# groupby splits the table in a single pass instead of re-scanning the whole
# DataFrame with a boolean mask for every graphID. sort=False yields the groups
# in order of first appearance, i.e. the order .unique() would produce.
for graphID, graph_df in df_csv_concat.groupby('graphID', sort=False):
    num_nodes = len(graph_df)
    # Keep events with exactly min_nodes rows: the threshold is a minimum,
    # so only strictly smaller events are dropped.
    if num_nodes < min_nodes:
        continue

    # Chain edges 0->1, 1->2, ..., (n-2)->(n-1).
    src = np.arange(0, num_nodes - 1)
    dst = np.arange(1, num_nodes)
    # Passing num_nodes keeps the node count independent of the edge list.
    g = dgl.graph((src, dst), num_nodes=num_nodes)

    # Node features: the three hit coordinates of every row.
    g.ndata['pos'] = torch.tensor(graph_df[['x_dig', 'y_dig', 'z_dig']].values)

    # Edge features: t of the destination row minus t of the source row.
    t = torch.tensor(graph_df['t'].to_numpy())
    g.edata["t_diff"] = t[1:] - t[:-1]

    # Save the graph as a DGL file in the 'saved_graphs' folder.
    dgl.save_graphs(f"saved_graphs/graph_{graphID}.dgl", g)

    # To visualize the graph:
    # nxg = g.to_networkx()
    # pos = nx.kamada_kawai_layout(nxg)
    # nx.draw(nxg, pos, node_color="r", node_size=10)
    # plt.title(f"Event {graphID}")
    # plt.show()

# Pack the contents of the graph folder into saved_graphs.zip.
# folder_name must match the folder the graphs were saved into.
folder_name = 'saved_graphs'
shutil.make_archive(base_name="saved_graphs", format='zip', root_dir=folder_name)
#To read a graph
#graph_list, _ = dgl.load_graphs('/content/saved_graphs/graph_10001.dgl')
# # Get the first graph object
# data = graph_list[0]
#print(graph_list)



# Output of the cell above (path of the created archive):
# '/home/sree/Desktop/hemalata/ical_graph_data/saved_graphs.zip'