<a href="https://colab.research.google.com/github/XenoicZ/PointCloud/blob/main/test_script.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
%%capture
!pip install torch_geometric
!pip install uproot
!pip install torch-scatter torch-sparse -f https://data.pyg.org/whl/torch-1.13.0+cpu.html
import torch
import uproot
import torch.nn as nn
import torch.utils.data
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader

from google.colab import drive
drive.mount('/content/drive')

# to-do: make new directory
#        assign directory with variable

In [2]:
from torch_geometric.data import Batch

In [2]:
from glob import glob
pi0_files = sorted(glob('/content/drive/MyDrive/ml4pion/data/onetrack_multicluster/pi0_files/'+'*.npy'))
pion_files = sorted(glob('/content/drive/MyDrive/ml4pion/data/onetrack_multicluster/pion_files/'+'*.npy'))

In [3]:
data = uproot.open('/content/drive/MyDrive/ml4pion/RhoDeltaPion/Samp_data/Rho/user.angerami.29450173.OutputStream._000001.root')['CellGeo']
data['cell_geo_prevInPhi'].array(library='np')

array([array([    63,      0,      1, ..., 187641, 187642, 187643], dtype=int32)],
      dtype=object)

In [148]:

import uproot
import os.path as osp
import numpy as np

import torch
from tqdm import tqdm
from torch_geometric.data import Dataset, Data

torch.set_default_dtype(torch.float32)

node_feature_names = ['cluster_cell_E', 'cell_geo_sampling', 'cell_geo_eta', 
                      'cell_geo_phi', 'cell_geo_rPerp', 'cell_geo_deta', 'cell_geo_dphi']
edge_feature_names = ['cell_geo_prevInPhi', 'cell_geo_nextInPhi',
                      'cell_geo_prevInEta', 'cell_geo_nextInEta',
                      'cell_geo_prevInSamp','cell_geo_nextInSamp',
                      'cell_geo_prevSubDet','cell_geo_nextSubDet',
                      'cell_geo_prevSuperCalo','cell_geo_nextSuperCalo']
global_feature_names = ['cluster_E', 'trackPt', 'trackZ0', 'trackEta', 'trackPhi']

cellGeo_data = uproot.open('/content/drive/MyDrive/ml4pion/RhoDeltaPion/Samp_data/Rho/user.angerami.29450173.OutputStream._000001.root')['CellGeo']
cellGeo_ID = cellGeo_data['cell_geo_ID'].array()[0]
sorter = np.argsort(cellGeo_ID)

class GraphDataset(Dataset):
    def __init__(self, root, i_file,
                 transform=None, pre_transform=None, pre_filter=None):
      
        self.i_file = i_file
        self.cellGeo_data = uproot.open('/content/drive/MyDrive/ml4pion/RhoDeltaPion/Samp_data/Rho/user.angerami.29450173.OutputStream._000001.root')['CellGeo']
        self.cellGeo_ID = self.cellGeo_data['cell_geo_ID'].array()[0]
        self.sorter = np.argsort(self.cellGeo_ID)
        self.edgeFeatureNames = self.cellGeo_data.keys()[9:]
        super().__init__(root, transform, pre_transform, pre_filter)
        
        # to-do: index by file

    @property
    def raw_file_names(self):
        return 'user.angerami.29450173.OutputStream._' + format(self.i_file, '06d') + '.root'

    @property
    def processed_file_names(self):
        return 'test.pt'

    def download(self):
        print('raw_file not found warning')
        pass

    def process(self):
        self.N_events = 0
        raw_path = self.raw_paths[0]

        # Read data from `raw_path`.
        
        event_file = uproot.open(raw_path)['EventTree']
        N_events = len(event_file[0].array())
        processed_data = []

        for i_event in tqdm(range(N_events)):
            if event_file['nCluster'].array()[i_event] != 2 or event_file['nTrack'].array()[i_event] != 1:
                continue
            node_features, N_nodes, cell_IDmap = self._get_node_features(event_file, i_event)
            edge_features, edge_index = self._get_edge_features(N_nodes, cell_IDmap)
            global_features = self._get_global_features(event_file, i_event)

            processed_data.append(Data(x=node_features, edge_attr=edge_features,
                                       edge_index=edge_index, y=global_features, dtype=torch.float32))
            
        torch.save(processed_data, osp.join(self.processed_dir, f'data_{self.i_file}.pt'))
        self.N_events = len(processed_data)
        

    def len(self):
        return self.N_events

    def get(self, i_event):
        data = torch.load(osp.join(self.processed_dir, f'data_{self.i_file}.pt'))[i_event]
        return data

    def _get_node_features(self, event_file, i_event):
        
        cell_IDs = event_file['cluster_cell_ID'].array()[i_event][0] ### to-do: genelize to multiple clusters
        cell_IDmap = self.sorter[np.searchsorted(self.cellGeo_ID, cell_IDs, sorter=self.sorter)]
        N_nodes = len(cell_IDs)

        node_features = torch.zeros((N_nodes, len(node_feature_names)))
        
        temp = np.log10(event_file['cluster_cell_E'].array(library='np')[i_event][0])
        node_features[:,0] = torch.from_numpy(temp)
        
        for i, name in enumerate(node_feature_names[1:]):
            temp = cellGeo_data[name].array(library='np')[0][cell_IDmap].astype(float)
            node_features[:,i+1] = torch.from_numpy(temp)
        
        return node_features, N_nodes, cell_IDmap

    def _get_edge_features(self, N_nodes, cell_IDmap):
        edge_index = np.zeros((N_nodes, len(self.edgeFeatureNames)))
        for i, name in enumerate(edge_feature_names):
            edge_index[:,i] = self.cellGeo_data[name].array(library='np')[0][cell_IDmap]
            mask = np.logical_not(np.isin(edge_index[:,i], cell_IDmap))
            edge_index[mask,i] = np.nan

        senders, edge_on_inds = np.nonzero(np.isin(edge_index, cell_IDmap))
        
        N_edges = len(senders)
        edge_features = np.zeros((N_edges, len(self.edgeFeatureNames)))
        edge_features[np.arange(N_edges), edge_on_inds] = 1

        cell_IDmap_sorter = np.argsort(cell_IDmap)
        rank = np.searchsorted(cell_IDmap, edge_index , sorter=cell_IDmap_sorter)
        receivers = cell_IDmap_sorter[rank[rank!=N_nodes]]

        return torch.tensor(edge_features, dtype=torch.float), torch.tensor([senders, receivers], dtype=torch.float)


    def _get_global_features(self, event_file, i_event):
        global_features = []
        for name in global_feature_names:
            global_features.append(event_file[name].array()[i_event][0])
        global_features = torch.tensor(global_features, dtype=torch.float)
        return global_features

In [75]:
data_set = GraphDataset('/content/drive/MyDrive/ml4pion/RhoDeltaPion/Samp_data/Rho', 1)

Processing...
Done!


In [149]:
datas =torch.utils.data.ConcatDataset([
            GraphDataset('/content/drive/MyDrive/ml4pion/RhoDeltaPion/Samp_data/Rho', i)\
            for i in range( 1,2 )            
    ])

Processing...
100%|██████████| 5000/5000 [07:51<00:00, 10.60it/s]
Done!


In [151]:
datas[0].edge_index.dtype

torch.float32

In [154]:
from torch_geometric.loader import DataLoader
loader = DataLoader(datas, batch_size=3, shuffle=False)
for i, samp in enumerate(loader):
    global_features = torch.reshape(samp.y, (3,5))
    print(samp.y)
    print(global_features)
    #print(samp.batch[samp.edge_index[0]])
    #print(samp.batch[samp.edge_index[1]])
    x = samp.batch[samp.edge_index[1]]
    print(model(samp.x, samp.edge_index, samp.edge_attr, global_features, samp.batch))
    break

tensor([ 4.1431e+00,  3.5067e+00, -1.5764e+01, -1.6843e+00, -9.5085e-01,
         1.3381e+03,  1.0181e+03, -3.2734e+01, -5.5517e-01,  1.4673e+00,
         1.2889e+01,  1.2400e+00,  5.0983e+01,  2.1723e+00, -4.0360e-01])
tensor([[ 4.1431e+00,  3.5067e+00, -1.5764e+01, -1.6843e+00, -9.5085e-01],
        [ 1.3381e+03,  1.0181e+03, -3.2734e+01, -5.5517e-01,  1.4673e+00],
        [ 1.2889e+01,  1.2400e+00,  5.0983e+01,  2.1723e+00, -4.0360e-01]])


IndexError: ignored

In [152]:
from torch_geometric.nn import MetaLayer
from torch.nn import Sequential as Seq, Linear as Lin, ReLU, BatchNorm1d
from torch_scatter import scatter_mean

torch.set_default_dtype(torch.float32)

#inputs = np.array([7,10,5])
global_size = 5
edge_size = 10
node_size = 7
latent_size = 64

class MLP(nn.Module):
    def __init__(self, input_size):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(input_size, 64)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(64, 64)
        self.relu2 = nn.ReLU()
        self.fc3 = nn.Linear(64, 64)
    
    def forward(self, x):
        print('x: '+str(x.dtype))
        out = self.fc1(x)
        out = self.relu1(out)
        out = self.fc2(out)
        out = self.relu2(out)
        out = self.fc3(out)
        return out

class EdgeBlock(torch.nn.Module):
    def __init__(self, input_size):
        super(EdgeBlock, self).__init__()
        self.edge_mlp = MLP(input_size*2+edge_size+global_size)

    def forward(self, src, dest, edge_attr, u, batch):
        print(src.size())
        print(dest.size())
        print(len(edge_attr))
        print(u.size())
        print(batch.size())
        out = torch.cat([src, dest, edge_attr, u[batch]], 1)
        print(out)
        return self.edge_mlp(out)


class GlobalBlock(torch.nn.Module):
    def __init__(self, input_size):
        super(GlobalBlock, self).__init__()
        
        self.global_mlp = MLP(input_size+node_size) ## 
        
    def forward(self, x, edge_index, edge_attr, u, batch):

        out = torch.cat([u,scatter_mean(x, batch, dim=0)], dim=1)

        return self.global_mlp(out)

class GraphBlock(torch.nn.Module):
    def __init__(self, input_size):
        super(GraphBlock, self).__init__()
        self.graph_block = MetaLayer(EdgeBlock(7), None, GlobalBlock(input_size))
        #self.bn = BatchNorm1d(inputs)
        
    def forward(self, x, edge_index, edge_attr, u, batch):
        
        #x = self.bn(x)
        x, edge_attr, u = self.graph_block(x, edge_index, edge_attr, u, batch)

        return edge_attr, u

class GraphNetwork(torch.nn.Module):
    def __init__(self):
        super(GraphNetwork, self).__init__()

        self.block0 = GraphBlock(global_size)
        self.block1 = GraphBlock(latent_size)
        self.block2 = GraphBlock(latent_size)
        self.block3 = GraphBlock(latent_size)

    def forward(self, x, edge_index, edge_attr, u, batch):
        edge_attr, u = self.block0(x, edge_index, edge_attr, u, batch)

        edge_attr, u = self.block1(x, edge_index, edge_attr, u, batch)

        edge_attr, u = self.block2(x, edge_index, edge_attr, u, batch)

        edge_attr, u = self.block3(x, edge_index, edge_attr, u, batch)
        '''
        print(u)
        x, edge_attr, u = self.block0(x, edge_index, edge_attr, u, batch)
        print(u)
        x, edge_attr, u = self.block1(x, edge_index, edge_attr, u, batch)
        print(u)
        x, edge_attr, u = self.block2(x, edge_index, edge_attr, u, batch)
        print(u)
        x, edge_attr, u = self.block3(x, edge_index, edge_attr, u, batch)
        '''
        return u

In [153]:
model = GraphNetwork()

In [138]:
for name, param in model.named_parameters():
    print(param.dtype)
    break

torch.float64


In [18]:
dir(model.named_parameters())

['__class__',
 '__del__',
 '__delattr__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__iter__',
 '__le__',
 '__lt__',
 '__name__',
 '__ne__',
 '__new__',
 '__next__',
 '__qualname__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 'close',
 'gi_code',
 'gi_frame',
 'gi_running',
 'gi_yieldfrom',
 'send',
 'throw']