In [163]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'

In [164]:
import pyarrow.parquet as pq
import numpy as np

import pennylane.numpy as pnp
import pennylane as qml
import jax
from tqdm import tqdm
import jax.numpy as jnp
import optax

from sklearn.svm import SVC
from jax_utils import square_kernel_matrix_jax, kernel_matrix_jax, target_alignment_jax
from pathlib import Path
from torch_geometric.nn import knn_graph
from torch_geometric.utils import to_undirected, k_hop_subgraph
import pandas as pd
import pyarrow as pa
from itertools import zip_longest
import networkx as nx
from jax.config import config
config.update("jax_enable_x64", True)

import torch
seed = 42
np.random.seed(seed)
pnp.random.seed(seed)

jax.devices()

[CpuDevice(id=0)]

In [165]:
def generate(pf, path, iter_batch_size, max_batches, zero_idx, one_idx, prefix):
    record_batch = pf.iter_batches(batch_size=iter_batch_size)
    count = 0
    
    while True:
        try:
            batch = next(record_batch)
            zero_idx, one_idx = transform_to_graph(batch, path, zero_idx, one_idx, prefix)
            count += 1
            break
        except StopIteration as e:
            print(e)
            return zero_idx, one_idx

        if count == max_batches:
            break

    return zero_idx, one_idx

In [166]:
def transform_to_graph(batch, path, zero_idx, one_idx, prefix):
    p = batch.to_pandas()
    im = np.array(np.array(np.array(p.iloc[:, 0].tolist()).tolist()).tolist())
    meta = np.array(p.iloc[:, 3])
    return saver(im, meta, path, zero_idx, one_idx, prefix)

In [167]:
from torch.nn.functional import pad
from torch_geometric.utils import to_dense_batch, to_dense_adj
from time import time

max_nodes = 1024
max_ego_nodes = 0

def saver(im, meta, path, zero_idx, one_idx, prefix):

    im[im < 1.e-3] = 0 #Zero_suppression
    # im[:,0,:,:] = (im[:,0,:,:] - im[:,0,:,:].mean())/(im[:,0,:,:].std())
    # im[:,1,:,:] = (im[:,1,:,:] - im[:,1,:,:].mean())/(im[:,1,:,:].std())
    # im[:,2,:,:] = (im[:,2,:,:] - im[:,2,:,:].mean())/(im[:,2,:,:].std())

    new_file = True
    
    with tqdm(range(meta.shape[0]), unit='datum') as tbatch:   
        for i in tbatch:
            img = im[i,:,:,:]
            label = int(meta[i])

#             channel1 = img[0,:,:]
#             channel2 = img[1,:,:]
#             channel3 = img[2,:,:]

#             channel1 = np.clip(channel1, 0, 500*channel1.std())
#             channel2 = np.clip(channel2, 0, 500*channel2.std())
#             channel3 = np.clip(channel3, 0, 500*channel3.std())

#             p = channel1.max() == 0.0
#             q = channel2.max() == 0.0
#             r = channel3.max() == 0.0

#             if p | q | r:
#                 continue

#             channel1 = channel1/channel1.max()
#             channel2 = channel2/channel2.max()
#             channel3 = channel3/channel3.max()

#             img[0,:,:] = channel1
#             img[1,:,:] = channel2
#             img[2,:,:] = channel3

            img = img.T

            # graph conversion
            img = torch.Tensor(img)
            xhit, yhit, zhit = torch.nonzero(img, as_tuple=True)
            
            # print(xhit.shape, yhit.shape, zhit.shape)
            
            # indices
            chs = [(zhit == 0).nonzero(as_tuple=True)[0], 
                   (zhit == 1).nonzero(as_tuple=True)[0], 
                   (zhit == 2).nonzero(as_tuple=True)[0]]

            hcal_coordinates = torch.stack((xhit[chs[2]], yhit[chs[2]])).T
            hcal_values =  img[xhit[chs[2]], yhit[chs[2]], 2]
            
            assert hcal_coordinates.shape[0]%5 == 0
            
            # print(hcal_indices.shape)
            # Reshape the padded array into groups of 5 elements
            grouped_coordinates = hcal_coordinates.reshape(-1, 5, 2).to(torch.float32)
            grouped_values = hcal_values.reshape(-1, 5)
#             print(grouped_indices)
#             print(grouped_values)
            
            # Calculate the mean along the second axis (axis=1)
            mean_coordinates = torch.mean(grouped_coordinates, axis=1).numpy()
#             print(mean_indices)
            mean_values = torch.mean(grouped_values, axis=1)
#             print(mean_values)
            
            mean_coordinates = mean_coordinates[np.lexsort((mean_coordinates[:, 0], mean_coordinates[:, 1]))]
            
#             print(mean_indices)
            
            mean_coordinates = torch.tensor(mean_coordinates)
        
            grouped_coordinates = mean_coordinates.reshape(-1, 5, 2).to(torch.float32)
            grouped_values = mean_values.reshape(-1, 5)
#             print(grouped_indices)
#             print(grouped_values)
            
            mean_coordinates = torch.mean(grouped_coordinates, axis=1).to(torch.int)
            mean_values = torch.mean(grouped_values, axis=1)
            # print(mean_indices.shape)
            # print(mean_values.shape)
            
            xhit = torch.cat((xhit[chs[0]], xhit[chs[1]], mean_coordinates[:,0])).to(torch.int)
            yhit = torch.cat((yhit[chs[0]], yhit[chs[1]], mean_coordinates[:,1])).to(torch.int)
            zhit = torch.cat((zhit[chs[0]], zhit[chs[1]], torch.ones(mean_coordinates.shape[0])*2)).to(torch.int)
            
            non_zero_values = img[xhit, yhit, zhit]
            non_zero_values *= 50
            
            xhit = xhit.to(torch.float32)
            yhit = yhit.to(torch.float32)
            zhit = zhit.to(torch.float32)
            
            zhit[zhit == 0] = 3
            zhit[zhit == 1] = 5.5
            zhit[zhit == 2] = 8.5
            
            node_feats = torch.stack((xhit, yhit, zhit, non_zero_values), dim=1)
            node_feats = node_feats[:max_nodes]
            coords = node_feats[:, [0, 1, 2]]
            
            if len(coords) == 0:
                continue

            # Create knn graph adjacency matrix
            edge_index = knn_graph(coords,
                                   k=6,
                                   batch=None,
                                   loop=True,
                                   num_workers=16)
            
            
            # generate ego graphs
            n_nodes = node_feats.shape[0]
            ego_nodes = []
            k = 5
            
#             edge_index = to_undirected(edge_index)
            global max_ego_nodes
            for node in range(n_nodes):
                try:
                    subset, sub_edge_index, _, _ = k_hop_subgraph(node,
                                                              k,
                                                              edge_index,
                                                              directed=False)
                except:
                    subset = []
                    sub_edge_index = torch.tensor([])
                    
                n_subset_nodes = len(subset)
                
                if n_subset_nodes:
                    G = nx.Graph()
                    G.add_edges_from(sub_edge_index.numpy().T)
                    paths = nx.single_source_shortest_path_length(G, node, cutoff=k)

                    nodes = np.array(list(paths.keys()))
                    dists = np.array(list(paths.values()))

                    hop_nodes = [
                        [node] + list(nodes[np.where(dists == hop)[0]]) for hop in range(1, k + 1)
                    ]

                    hop_nodes = np.array(list(zip_longest(*hop_nodes, fillvalue=max_nodes+1))).T
                else:
                    dists = np.array([])
                    hop_nodes = np.array([np.array([])]*k)
                    
                ego_nodes.append(hop_nodes)

                max_ego_nodes = max(max_ego_nodes, hop_nodes.shape[-1])
            
                
#             print(node_feats.shape)
#             print(edge_index.shape)

#             dense_node_feats, node_mask = to_dense_batch(node_feats, max_num_nodes=max_nodes)
#             dense_adj = to_dense_adj(edge_index, edge_attr=None, max_num_nodes=max_nodes)
            
            # print(node_feats.shape)
            # print(dense_node_feats.shape)
            
            # print(dense_node_feats[0].shape, node_mask[0].shape)
            # print(dense_adj[0].shape)
#             print('-'*50)
            
            # num_nodes.append(non_zero_values.shape[0])
            # num_edges.append(edge_index.shape[1])
            
            # if label == 0:
            #     save_path = path / 'quark'
            #     # np.savez_compressed(save_path / str(zero_idx), 
            #     #                     node_feats=node_feats, 
            #     #                     edge_index=edge_index,
            #     #                    )
            #     np.savez_compressed(save_path / str(zero_idx), 
            #                         node_feats=dense_node_feats[0], 
            #                         node_mask=node_mask[0],
            #                         adj=dense_adj[0]
            #                        )
            #     zero_idx += 1
            # else:
            #     save_path = path / 'gluon'
            #     np.savez_compressed(save_path / str(one_idx), 
            #                         node_feats=dense_node_feats[0], 
            #                         node_mask=node_mask[0],
            #                         adj=dense_adj[0]
            #                        )
            #     one_idx += 1
            
            
            # sparse
#             parquet_df = pd.DataFrame({
#                 'x': [node_feats[:, 0].numpy()],
#                 'y': [node_feats[:, 1].numpy()],
#                 'detector': [node_feats[:, 2].numpy()],
#                 'energy': [node_feats[:, 3].numpy()],
#                 'edge_index_from': [edge_index[0, :].numpy()],
#                 'edge_index_to': [edge_index[1, :].numpy()],
#                 'y': [label],
#             })

#             table = pa.Table.from_pandas(parquet_df)
#             if new_file:
#                 output_filename = Path(path) / f'{prefix}_{zero_idx}.parquet'
#                 pqwriter = pq.ParquetWriter(output_filename, table.schema, compression='snappy')
#                 new_file = False
#                 zero_idx += 1
#             pqwriter.write_table(table)
    
            
    return zero_idx, one_idx

In [168]:
processed_dir = Path('./processed_qg_parquets/')

In [169]:
os.listdir(processed_dir)

['qg_train_0.parquet']

In [170]:
train_files = ['QCDToGGQQ_IMGjet_RH1all_jet0_run0_n36272.test.snappy.parquet']

In [171]:
zero_idx = 0
one_idx = 0
prefix = 'qg_train'
iter_batch_size = 1024
max_batches = -1

for raw_path in train_files:
    print("Processing file:", raw_path, zero_idx, one_idx)
    zero_idx, one_idx = generate(pq.ParquetFile(raw_path), processed_dir,
                                 iter_batch_size, max_batches, zero_idx,
                                 one_idx, prefix)

print("The files were successfully generated")

Processing file: QCDToGGQQ_IMGjet_RH1all_jet0_run0_n36272.test.snappy.parquet 0 0


100%|████████████████████████████████████| 1024/1024 [04:44<00:00,  3.60datum/s]


The files were successfully generated


In [172]:
max_ego_nodes

54

Tested with 1000 qg samples

k qubits
1 6

2 23

3 41

4 47

5 54