In [None]:
# Mount Google Drive at /content/drive (Colab-only API); the following cells
# change into it to reach the dataset files.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
! pip install uproot
! pip install plotly
! pip install "notebook>=5.3" "ipywidgets>=7.5"
! pip install hdf5plugin
! pip install awkward
! pip install numba
! pip install vector
! pip install scikit-tda



In [None]:
%ls

[0m[01;34mdrive[0m/  [01;34msample_data[0m/


In [None]:
%cd drive/MyDrive/

/content/drive/MyDrive


In [None]:
%cd reference_data/

/content/drive/MyDrive/reference_data


In [None]:
%ls

[0m[01;34mconverted[0m/  train.h5  val.h5


In [None]:
import h5py as hp
import numpy as np
import pandas as pd
import os
import requests
import functools
import pathlib
import shutil
import logging

import awkward as ak
import torch
import tqdm.auto as tqdm

In [None]:
# Open the training HDF5 file read-only with h5py.
# NOTE(review): the handle `f` is never closed and is not used by any cell
# visible here (the conversion below re-reads the file through pandas) --
# confirm whether it is still needed, and close it when done.
filename = 'train.h5'
f = hp.File(filename, 'r')

In [None]:
import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots

def plot_jets_interactive(coordinates_array, max_jets=5):
    """Show the leading jets as 3-D scatter traces, one at a time, behind a dropdown.

    For every jet, component 0 of the last axis drives the marker colour and
    components 1, 2 and 3 are plotted as x, y and z.

    :param coordinates_array: np.ndarray, jets stacked along the first axis; the
        last axis must hold at least four components
    :param max_jets: int, number of leading jets to plot (defaults to 5, the
        value that was previously hard-coded)
    :raises ValueError: if ``coordinates_array`` is not a NumPy array
    """
    # Ensure the input is a NumPy array
    if not isinstance(coordinates_array, np.ndarray):
        raise ValueError("Input must be a NumPy array.")

    jets = coordinates_array[:max_jets]
    jet_names = [f"Jet {i+1}" for i in range(len(jets))]

    # One trace per jet, all living in the same figure.
    combined_fig = go.Figure()
    for i, (jet, jet_name) in enumerate(zip(jets, jet_names)):
        combined_fig.add_trace(go.Scatter3d(
            x=jet[..., 1],
            y=jet[..., 2],
            z=jet[..., 3],
            mode='markers',
            marker=dict(
                size=2,
                color=jet[..., 0],  # set color to intensity (component 0)
                colorscale='Viridis',  # choose a colorscale
                opacity=0.8
            ),
            name=jet_name,
            # Start with only the first jet shown so the initial view matches
            # the first dropdown entry instead of overlaying every jet.
            visible=(i == 0),
        ))

    # Each button makes exactly one trace visible.
    buttons = [
        dict(
            label=jet_name,
            method="update",
            args=[{"visible": [j == i for j in range(len(jet_names))]}],
        )
        for i, jet_name in enumerate(jet_names)
    ]

    combined_fig.update_layout(
        updatemenus=[
            dict(
                type="dropdown",
                showactive=False,
                buttons=buttons,
                pad={"r": 10, "t": 10},
                x=0.1,
                xanchor="left",
                y=1.1,
                yanchor="top"
            )
        ],
        autosize=False,
        width=500,
        height=400,
        margin=dict(l=50, r=50, b=100, t=100),
        paper_bgcolor="LightSteelBlue",
        scene=dict(
            xaxis_title='X',
            yaxis_title='Y',
            zaxis_title='Z'
        )
    )

    # Show the combined plot with dropdown
    combined_fig.show()




In [None]:
import vector
import numba as nb
from vector.backends import awkward_constructors as awk
from vector._compute import lorentz as lz
"""
        Takes a DataFrame and converts it into a Awkward array representation
        with features relevant to our model.

        :param df: Pandas DataFrame, The DataFrame with all the momenta-energy coordinates for all the particles
        :param start: int, First element of the DataFrame
        :param stop: int, Last element of the DataFrame
        :return v: OrderedDict, A Ordered Dictionary with all properties of interest



        Here the function is just computing 4 quantities of interest:
        * Eta value relative to the jet
        * Phi value relative to the jet
        * Transverse Momentum of the Particle (log of it)
        * Energy of the Particle (log of it)
"""


# TODO: Compute as many properties as required, keep scope for more always. Maybe use some of it to transform between latent spaces and some for Message Passing.
# TODO: Initially, simply use the low-level features for everything and see how the model trained on low level features compares to model with domain knowledge.

# @nb.jit(nopython=True)
# def compute_features(v, label, mom_objects):
#     v['label'] = np.stack((label, 1-label), axis=-1)
#     v['part_pt_log'] = np.log(jet_p4.pt())
#     v['part_e_log'] = np.log(energy)
#     v['part_etarel'] = vector._compute.spatial.deltaeta()
#     v['part_phirel'] = jet_p4.pseudorapidity(jet_p4)


def _transform(df, start=0, stop=-1):
    """Compute per-event (jet) and per-particle features for a slice of ``df``.

    Each row of ``df`` is one event holding up to 200 particles in the columns
    ``PX_i``, ``PY_i``, ``PZ_i`` and ``E_i``. Only particles with ``E > 0`` are
    kept. The jet four-momentum of an event is the sum over its kept particles.

    :param df: Pandas DataFrame with the momentum/energy columns plus
        ``is_signal_new`` and ``ttv``
    :param start: int, first row of the slice
    :param stop: int, end of the slice (exclusive, plain ``iloc`` semantics, so
        the default ``-1`` leaves out the last row)
    :return v: OrderedDict with
        * ``jet_pt``, ``jet_log_pt``, ``jet_eta``, ``log_jet_energy``: one value per event
        * ``part_ptrel``, ``eta_rel``: jagged, one list of particles per event
        * ``label``: (n_events, 2) array ``[is_signal_new, 1 - is_signal_new]``
        * ``train_val_test``: the ``ttv`` column
        * ``n_parts``: number of kept particles per event

    Events without any kept particle have a zero jet momentum, so their
    logarithms and pseudorapidities are not finite.
    """
    from collections import OrderedDict
    v = OrderedDict()

    # generate the column list to be extracted
    def _col_list(prefix, max_particles=200):
        return ['%s_%d'%(prefix,i) for i in range(max_particles)]

    df = df.iloc[start:stop]

    # One row per event, one column per particle slot.
    _px = df[_col_list(prefix = 'PX')].values
    _py = df[_col_list(prefix = 'PY')].values
    _pz = df[_col_list(prefix = 'PZ')].values
    _e = df[_col_list(prefix = 'E')].values

    # Keep only the slots that hold a real particle (positive energy).
    mask = _e > 0
    n_particles = np.sum(mask, axis=1)  # kept particles per event

    # Boolean indexing flattens row by row; unflatten with the per-event counts
    # to get back one variable-length list of particles per event.
    px = ak.unflatten(_px[mask], n_particles)
    py = ak.unflatten(_py[mask], n_particles)
    pz = ak.unflatten(_pz[mask], n_particles)
    energy = ak.unflatten(_e[mask], n_particles)

    # Jet four-momentum: sum over the particles of each event (axis=1).
    jet_px = ak.sum(px, axis=1)
    jet_py = ak.sum(py, axis=1)
    jet_pz = ak.sum(pz, axis=1)
    jet_energy = ak.sum(energy, axis=1)

    # Transverse momentum (p_T)
    part_pt = np.sqrt(px**2 + py**2)
    jet_pt = np.sqrt(jet_px**2 + jet_py**2)

    # Pseudorapidity: eta = 0.5 * ln((|p| + pz) / (|p| - pz))
    jet_p = np.sqrt(jet_px**2 + jet_py**2 + jet_pz**2)
    jet_eta = 0.5 * np.log((jet_p + jet_pz) / (jet_p - jet_pz))
    part_p = np.sqrt(px**2 + py**2 + pz**2)
    part_eta = 0.5 * np.log((part_p + pz) / (part_p - pz))

    # Storing the calculated parameters in the OrderedDict
    v['jet_pt'] = jet_pt
    v['jet_log_pt'] = np.log(jet_pt)
    v['jet_eta'] = jet_eta
    # Per-particle quantities relative to their own jet (the per-event jet
    # values broadcast over the particles of that event).
    v['part_ptrel'] = part_pt / jet_pt
    v['log_jet_energy'] = np.log(jet_energy)
    v['eta_rel'] = part_eta - jet_eta

    # outputs
    _label = df['is_signal_new'].values
    v['label'] = np.stack((_label, 1-_label), axis=-1)
    v['train_val_test'] = df['ttv'].values
    v['n_parts'] = n_particles

    return v

In [None]:
def _features_to_frame(features):
    """Build a one-row-per-event DataFrame from the mapping returned by ``_transform``."""
    columns = {}
    for name, values in features.items():
        if isinstance(values, ak.Array):
            # Awkward (possibly jagged) arrays become plain Python lists.
            columns[name] = ak.to_list(values)
        elif np.ndim(values) > 1:
            # pandas only accepts 1-D columns: store each row (e.g. the
            # two-component label) as a list.
            columns[name] = np.asarray(values).tolist()
        else:
            columns[name] = values
    return pd.DataFrame(columns)


def convert(source, destdir, basename, step=None, limit=None):
    """
    Reads the DataFrame stored in an H5 file, transforms it batch by batch with
    ``_transform`` and writes every batch to its own Parquet file
    ``<destdir>/<basename>_<batch index>.parquet``. A batch whose output file
    already exists is skipped (and is not part of the returned frame).

    :param source: str, The location to the H5 file with the dataframe
    :param destdir: str, The location we need to write to (created if missing)
    :param basename: str, Prefix for all the output file names
    :param step: int, Number of rows per output file, None for all rows in 1 file
    :param limit: int, Number of rows to read.
    :return: Pandas DataFrame, the concatenation of all batches written by this
        call (empty if nothing was written)
    """
    df = pd.read_hdf(source, key='table')
    logging.info('Total events: %s' % str(df.shape[0]))
    if limit is not None:
        df = df.iloc[0:limit]
        logging.info('Restricting to the first %s events:' % str(df.shape[0]))
    if step is None:
        # max(..., 1) keeps range() valid when the frame is empty
        step = max(df.shape[0], 1)

    os.makedirs(destdir, exist_ok=True)

    # Collect the batches and concatenate once at the end.
    batches = []
    # The batch index is tied to the batch position, so skipping an existing
    # file does not shift (or block) the names of the following batches.
    for idx, start in enumerate(range(0, df.shape[0], step)):
        output = os.path.join(destdir, '%s_%d.parquet'%(basename, idx))
        logging.info(output)
        if os.path.exists(output):
            logging.warning('... file already exists: continue ...')
            continue
        v = _transform(df, start=start, stop=start+step)
        batch_df = _features_to_frame(v)

        # Write the batch DataFrame to a Parquet file
        batch_df.to_parquet(output)
        batches.append(batch_df)

    if not batches:
        return pd.DataFrame()
    return pd.concat(batches, ignore_index=True)

In [None]:
PROJECT_DIR = os.getcwd()

In [None]:
v = convert(source = os.path.join(PROJECT_DIR, 'train.h5'), destdir = os.path.join(PROJECT_DIR, 'converted'), basename = 'train-file', limit = 5)

In [None]:
import math
from typing import Callable, Optional
import numpy as np
from numpy.typing import ArrayLike, NDArray
from rich.table import Table
from rich.highlighter import ReprHighlighter
from rich import box
from tabulate import tabulate


def dict2table(input_dict: dict, num_cols: int = 4, title: Optional[str] = None) -> Table:
    """Lay out the items of ``input_dict`` as ``key: value`` pairs in a rich table.

    The items are split, in order, into groups of ``ceil(len / num_cols)``
    entries; every group contributes one column of keys and one column of
    values. The columns are formatted with ``tabulate`` (plain format) and the
    resulting text is placed, syntax-highlighted, in a single-cell table.

    Args:
        input_dict (dict): items to display.
        num_cols (int): number of key/value column pairs to aim for.
        title (Optional[str]): title of the returned table.

    Returns:
        Table: the rendered table.
    """
    rows_per_column = math.ceil(len(input_dict) / num_cols)
    columns = {}
    next_col = 0
    pending_keys, pending_vals = [], []

    for count, (name, value) in enumerate(input_dict.items(), start=1):
        pending_keys.append(f'{name}:')
        pending_vals.append(value)

        # A column pair is complete once it holds `rows_per_column` entries.
        if count % rows_per_column == 0:
            columns[next_col] = pending_keys
            columns[next_col + 1] = pending_vals
            pending_keys, pending_vals = [], []
            next_col += 2

    # Whatever is left (possibly nothing) always forms a final column pair.
    columns[next_col] = pending_keys
    columns[next_col + 1] = pending_vals

    plain_text = tabulate(columns, tablefmt='plain')
    table = Table(title=title, show_header=False, box=box.HORIZONTALS)
    table.add_row(ReprHighlighter()(plain_text))
    return table

In [None]:
from rich.console import Console as RichConsole
from rich.logging import RichHandler
from rich.spinner import Spinner
from rich.table import Table
from rich.status import Status
from rich.live import Live
from rich._log_render import LogRender
from time import time
import logging


class LogStatus(Status):
    """Status spinner drawn as a log line that reports the elapsed time on exit.

    Used as a context manager: while the block runs, the status text followed
    by a 'simpleDots' spinner is shown; when the block ends,
    ``'<status>...done in <seconds> s'`` is logged through the console.
    When ``enabled`` is False, entering and leaving the context does nothing.
    """
    def __init__(self,
        status,
        console: RichConsole,
        level: int = logging.INFO,
        enabled: bool = True,
        speed: float = 1.0,
        refresh_per_second: float = 12.5,
    ):
        """
        Args:
            status: text shown next to the spinner and repeated in the final log line.
            console: console used for the live display and the final log line.
            level: logging level used for the level text and passed to ``console.log``.
            enabled: if False, ``__enter__``/``__exit__`` are no-ops.
            speed: spinner speed.
            refresh_per_second: refresh rate of the live display.
        """
        super().__init__(status,
            console=console,
            spinner='simpleDots',
            speed=speed,
            refresh_per_second=refresh_per_second
        )

        self.status = status
        self.level = level
        self.enabled = enabled
        spinner = Spinner('simpleDots', style='status.spinner', speed=speed)
        # Throw-away record: only its level is used, to obtain the level text
        # from the handler below.
        record = logging.LogRecord(name=None, level=level, pathname=None, lineno=None, msg=None, args=None, exc_info=None)
        handler = RichHandler(console=console)
        # Status text and spinner side by side.
        table = Table.grid()
        table.add_row(self.status, spinner)

        # Replace the spinner set up by the parent with a renderable produced by
        # rich's (private) LogRender, i.e. time + level + the grid above.
        self._spinner = LogRender(show_level=True, time_format='[%X]')(
            console=console,
            level=handler.get_level_text(record),
            renderables=[table]
        )
        # Rebuild the live display around the new renderable.
        # NOTE(review): presumably `self.renderable` returns `self._spinner`
        # (defined by the parent class) -- confirm against the installed rich version.
        self._live = Live(
            self.renderable,
            console=console,
            refresh_per_second=refresh_per_second,
            transient=True,
        )

    def __enter__(self):
        """Start the timer and the live display; returns None when disabled."""
        if self.enabled:
            self._start_time = time()
            return super().__enter__()

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Stop the live display and log the elapsed time (no-op when disabled)."""
        if self.enabled:
            super().__exit__(exc_type, exc_val, exc_tb)
            self._end_time = time()
            # NOTE(review): `level=` is passed to `console.log`; presumably the
            # console is a project subclass that accepts it -- confirm.
            self.console.log(f'{self.status}...done in {self._end_time - self._start_time:.2f} s', level=self.level)

In [None]:
from torch_geometric.transforms import BaseTransform
from torch_geometric.data import Data


class RemoveIsolatedNodes(BaseTransform):
    """Transform that keeps only the nodes that are an endpoint of at least one edge."""

    def __call__(self, data: Data) -> Data:
        # Boolean node mask created from `data.y` (same device), all False.
        connected = data.y.new_zeros(data.num_nodes, dtype=bool)
        # Row 0 and row 1 of edge_index hold the two endpoints of every edge.
        for row in (0, 1):
            connected[data.edge_index[row]] = True
        return data.subgraph(connected)

In [None]:
from torch_geometric.utils import remove_self_loops
from torch_geometric.transforms import BaseTransform
from torch_geometric.data import Data


class RemoveSelfLoops(BaseTransform):
    """Transform that drops self-loops from ``edge_index`` and/or ``adj_t``."""

    def __call__(self, data: Data) -> Data:
        edge_index = getattr(data, 'edge_index', None)
        if edge_index is not None:
            # remove_self_loops returns (edge_index, edge_attr); keep the former
            data.edge_index = remove_self_loops(edge_index)[0]

        if hasattr(data, 'adj_t'):
            data.adj_t = data.adj_t.remove_diag()

        return data

In [None]:
from typing import Iterable
from rich.console import Group
from rich.padding import Padding
from rich.table import Column, Table
from rich.progress import Progress, SpinnerColumn, BarColumn, TimeElapsedColumn, Task
from rich.highlighter import ReprHighlighter





class TrainerProgress(Progress):
    """Progress display for training: an overall epoch bar plus train/val/test step bars.

    The four tasks are addressed by the keys ``'epoch'``, ``'train'``, ``'val'``
    and ``'test'``. The metrics string stored on the epoch task is rendered
    underneath the bars.
    """

    def __init__(self,
                 num_epochs: int,
                 **kwargs
                 ):
        """
        Args:
            num_epochs (int): total of the overall ``'epoch'`` task.
            **kwargs: forwarded to ``rich.progress.Progress``. A ``console``
                entry, if given, is used instead of creating a new console.
        """
        # Imported locally so the class does not depend on a bare `Console`
        # name being defined by some other cell.
        from rich.console import Console

        progress_bar = [
            SpinnerColumn(),
            "{task.description}",
            "[cyan]{task.completed:>3}[/cyan]/[cyan]{task.total}[/cyan]",
            "{task.fields[unit]}",
            BarColumn(),
            "[cyan]{task.percentage:>3.0f}[/cyan]%",
            TimeElapsedColumn(),
            # "{task.fields[metrics]}"
        ]

        # Popping avoids passing `console` twice when the caller supplies one.
        console = kwargs.pop('console', None)
        if console is None:
            console = Console()

        super().__init__(*progress_bar, console=console, **kwargs)

        self.trainer_tasks = {
            'epoch': self.add_task(total=num_epochs, metrics='', unit='epochs', description='overal progress'),
            'train': self.add_task(metrics='', unit='steps', description='training', visible=False),
            'val':   self.add_task(metrics='', unit='steps', description='validation', visible=False),
            'test':  self.add_task(metrics='', unit='steps', description='testing', visible=False),
        }

        # Largest number of rows drawn so far; used to keep the height stable.
        self.max_rows = 0

    def update(self, task: str, **kwargs):
        """Update the task registered under ``task`` (a key of ``trainer_tasks``).

        A ``metrics`` mapping, if present, is converted to its display string first.
        """
        if 'metrics' in kwargs:
            kwargs['metrics'] = self.render_metrics(kwargs['metrics'])

        super().update(self.trainer_tasks[task], **kwargs)

    def reset(self, task: str, **kwargs):
        """Reset the task registered under ``task`` (a key of ``trainer_tasks``)."""
        super().reset(self.trainer_tasks[task], **kwargs)

    def render_metrics(self, metrics: dict) -> str:
        """Format the metrics as ``'name: value'`` pairs grouped by train/val/test.

        Only entries whose name contains ``'train/'``, ``'val/'`` or ``'test/'``
        are shown; values are printed with three decimals.
        """
        out = []
        for split in ['train', 'val', 'test']:
            metric_str = ' '.join(f'{k}: {v:.3f}' for k, v in metrics.items() if f'{split}/' in k)
            out.append(metric_str)

        return '  '.join(out)

    def make_tasks_table(self, tasks: Iterable[Task]) -> Table:
        """Get a table to render the Progress display.

        Args:
            tasks (Iterable[Task]): An iterable of Task instances, one per row of the table.
                The first one is expected to be the epoch task.

        Returns:
            The bare grid when there are no tasks; otherwise the grid of visible
            tasks followed by the epoch task's metrics, wrapped in padding.
        """
        # The first task is indexed below, so materialise arbitrary iterables.
        tasks = list(tasks)

        table_columns = (
            (
                Column(no_wrap=True)
                if isinstance(_column, str)
                else _column.get_table_column().copy()
            )
            for _column in self.columns
        )

        highlighter = ReprHighlighter()
        table = Table.grid(*table_columns, padding=(0, 1), expand=self.expand)

        if tasks:
            epoch_task = tasks[0]
            metrics = epoch_task.fields['metrics']

            for task in tasks:
                if task.visible:
                    table.add_row(
                        *(
                            (
                                column.format(task=task)
                                if isinstance(column, str)
                                else column(task)
                            )
                            for column in self.columns
                        )
                    )

            # Pad above the metrics line so the display does not jump when
            # phase bars appear and disappear; no padding once training is done.
            self.max_rows = max(self.max_rows, table.row_count)
            pad_top = 0 if epoch_task.finished else self.max_rows - table.row_count
            group = Group(table, Padding(highlighter(metrics), pad=(pad_top,0,0,2)))
            return Padding(group, pad=(0,0,1,18))

        else:
            return table

In [None]:
from typing import Optional
from torch import Tensor
from torch.nn import Module
from torch.optim.optimizer import Optimizer
from torch_geometric.data import Data
from abc import ABC, abstractmethod
from class_resolver.contrib.torch import optimizer_resolver

# Defined as a subclass of both Module and ABC, inheriting attributes from both.

# Serves as a base class for all the trainable modules in the project (baseline{i.e. without persistence}, Persistent Mod, etc.)
class TrainableModule(Module, ABC):
    """Abstract torch module that also knows how to build its own optimizer.

    Subclasses implement ``forward``, ``step``, ``predict`` and
    ``reset_parameters``; the optimizer name and its hyper-parameters are
    stored at construction time and used by ``configure_optimizers``.
    """

    def __init__(self, optimizer: str, learning_rate: float, weight_decay: float):
        super().__init__()
        self.optimizer_name, self.learning_rate, self.weight_decay = (
            optimizer, learning_rate, weight_decay
        )

    @abstractmethod
    def forward(self, *args, **kwargs):
        """Forward pass; to be implemented by subclasses."""

    @abstractmethod
    def step(self, data: Data, phase: Phase) -> tuple[Optional[Tensor], Metrics]:
        """Process one batch in the given phase and return ``(loss, metrics)``."""

    @abstractmethod
    def predict(self, data: Data) -> Tensor:
        """Return the predictions for ``data``."""

    @abstractmethod
    def reset_parameters(self):
        """Re-initialise the parameters of the module."""

    def configure_optimizers(self):
        """Resolve ``optimizer_name`` and instantiate it over this module's parameters."""
        optimizer_kwargs = dict(
            params=self.parameters(),
            lr=self.learning_rate,
            weight_decay=self.weight_decay,
        )
        return optimizer_resolver.make(query=self.optimizer_name, **optimizer_kwargs)

In [None]:
from copy import deepcopy
from typing import Iterable, Optional, Annotated, Literal
from torch.optim import Optimizer
from torch.types import Number
from torchmetrics import MeanMetric


class Trainer:
    """Runs the train / validation / test loops of a ``TrainableModule``.

    Per-batch metrics returned by the model are averaged (weighted by batch
    size) per phase. During ``fit`` the model state with the best monitored
    validation metric is kept and restored at the end.

    NOTE(review): ``ArgInfo``, ``Logger``, ``DummyLogger``, ``Phase`` and
    ``Metrics`` are not defined in the notebook cells preceding this class;
    they must be in scope before this cell is executed.
    """

    def __init__(self,
                 monitor:       str = 'val/acc',
                 monitor_mode:  Literal['min', 'max'] = 'max',
                 epochs:        Annotated[int,  ArgInfo(help='number of epochs for training')] = 100,
                 device:        Annotated[str,  ArgInfo(help='device to use for training', choices=['cpu', 'cuda', 'auto'])] = 'auto',
                 verbose:       Annotated[bool, ArgInfo(help='display progress')] = True,
                 logger:        Logger = None,
                 ):
        """
        Args:
            monitor: name of the metric used to select the best epoch.
            monitor_mode: ``'max'`` if larger is better, ``'min'`` otherwise.
            epochs: number of training epochs.
            device: ``'cpu'``, ``'cuda'`` or ``'auto'`` (cuda if available).
            verbose: whether to display the progress bars.
            logger: metrics logger; a ``DummyLogger`` is used when omitted.
        """
        self.epochs = epochs
        self.monitor = monitor
        self.monitor_mode = monitor_mode
        self.verbose = verbose
        self.logger = logger or DummyLogger()

        # setup device
        if device == 'auto':
            self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
        else:
            self.device = torch.device(device)

        # trainer internal state
        self.model: TrainableModule = None
        self.metrics: dict[str, MeanMetric] = {}

    def reset(self) -> None:
        """Forget the current model and all accumulated metrics."""
        self.model = None
        self.metrics.clear()

    def update_metrics(self, metric_name: str, metric_value: object, batch_size: int = 1) -> None:
        """Add one observation of ``metric_name``, weighted by ``batch_size``."""
        # if this is a new metric, add it to self.metrics
        device = metric_value.device if torch.is_tensor(metric_value) else 'cpu'
        if metric_name not in self.metrics:
            self.metrics[metric_name] = MeanMetric().to(device)

        # update the metric
        self.metrics[metric_name].update(metric_value, weight=batch_size)

    def aggregate_metrics(self, phase: Phase='train') -> Metrics:
        """Compute and reset every metric whose '/'-separated name contains ``phase``."""
        metrics = {}

        for metric_name, metric_value in self.metrics.items():
            if phase in metric_name.split('/'):
                value = metric_value.compute()
                metric_value.reset()
                metrics[metric_name] = value

        return metrics

    def is_better(self, current_metric: Number, previous_metric: Number) -> bool:
        """Tell whether ``current_metric`` improves on ``previous_metric`` under ``monitor_mode``."""
        assert self.monitor_mode in ['min', 'max'], f'Unknown metric mode: {self.monitor_mode}'
        if self.monitor_mode == 'max':
            return current_metric > previous_metric
        elif self.monitor_mode == 'min':
            return current_metric < previous_metric

    def fit(self,
            model: TrainableModule,
            train_dataloader: Iterable,
            val_dataloader: Optional[Iterable]=None,
            test_dataloader: Optional[Iterable]=None,
            ) -> Metrics:
        """Train ``model`` for ``self.epochs`` epochs.

        With a validation dataloader, the state of the epoch with the best
        monitored metric is restored at the end and that epoch's metrics are
        returned; otherwise the metrics of the last epoch are returned.
        """
        self.model = model.to(self.device)
        self.optimizer: Optimizer = self.model.configure_optimizers()

        self.progress = TrainerProgress(
            num_epochs=self.epochs,
            disable=not self.verbose,
        )

        best_state_dict = None
        best_metrics = None
        # Defined up front so the fallback below also works when no epoch runs.
        metrics = {}

        with self.progress:
            for epoch in range(1, self.epochs + 1):
                metrics = {'epoch': epoch}

                # train loop
                train_metrics = self.loop(train_dataloader, phase='train')
                metrics.update(train_metrics)

                # validation loop
                if val_dataloader:
                    val_metrics = self.loop(val_dataloader, phase='val')
                    metrics.update(val_metrics)

                    if best_metrics is None or self.is_better(
                        metrics[self.monitor], best_metrics[self.monitor]
                        ):
                        # Same dict object as `metrics`: the test metrics of this
                        # epoch, added below, are therefore part of the best metrics.
                        best_metrics = metrics
                        best_state_dict = deepcopy(self.model.state_dict())

                # test loop
                if test_dataloader:
                    test_metrics = self.loop(test_dataloader, phase='test')
                    metrics.update(test_metrics)

                # log and update progress
                self.progress.update(task='epoch', metrics=metrics, advance=1)
                self.logger.log(metrics)

        if best_metrics is None:
            best_metrics = metrics
        else:
            self.model.load_state_dict(best_state_dict)

        # log and return best metrics
        self.logger.log_summary(best_metrics)

        return best_metrics

    def test(self, dataloader: Iterable) -> Metrics:
        """Run one test loop over ``dataloader`` with the current model.

        Relies on the model and progress display created by ``fit``.
        """
        self.metrics.clear()
        metrics = self.loop(dataloader, phase='test')
        return metrics

    def predict(self, dataloader: Iterable, move_to_cpu: bool=False) -> 'torch.Tensor | tuple':
        """Concatenate the model's predictions over all batches of ``dataloader``.

        If the model returns a tuple per batch, a tuple of concatenated tensors
        is returned. An empty dataloader yields an empty tensor.
        """
        preds = []
        self.model.eval()
        with torch.no_grad():
            for batch in dataloader:
                batch = self.to_device(batch)
                # out might be a tuple of predictions
                out = self.model.predict(batch)
                if move_to_cpu:
                    if isinstance(out, tuple):
                        out = tuple(item.cpu() for item in out)
                    else:
                        out = out.cpu()
                preds.append(out)

        if not preds:
            return torch.empty(0)

        # concatenate predictions, check if they are tuples
        if isinstance(preds[0], tuple):
            preds = tuple(torch.cat([p[i] for p in preds]) for i in range(len(preds[0])))
        else:
            preds = torch.cat(preds)

        return preds

    def loop(self, dataloader: Iterable, phase: Phase) -> Metrics:
        """Run one pass over ``dataloader`` in ``phase`` and return the aggregated metrics.

        Gradients are enabled only for the ``'train'`` phase.
        """
        self.model.train(phase == 'train')
        self.progress.update(phase, visible=len(dataloader) > 1, total=len(dataloader))

        # As a context manager the previous grad mode is restored on exit,
        # even when a step raises.
        with torch.set_grad_enabled(phase == 'train'):
            for batch in dataloader:
                batch = self.to_device(batch)
                metrics = self.step(batch, phase)
                for item in metrics:
                    self.update_metrics(item, metrics[item], batch_size=batch.batch_nodes.size(0))
                self.progress.update(phase, advance=1)

        self.progress.reset(phase, visible=False)
        return self.aggregate_metrics(phase)

    def step(self, batch, phase: Phase) -> Metrics:
        """Run the model on one batch; in the train phase also back-propagate and update."""
        if phase == 'train':
            self.optimizer.zero_grad(set_to_none=True)

        loss, metrics = self.model.step(batch, phase=phase)

        if phase == 'train':
            loss.backward()
            self.optimizer.step()

        return metrics

    def to_device(self, batch):
        """Move ``batch`` (or every item of a tuple batch) to the trainer's device."""
        if isinstance(batch, tuple):
            return tuple(item.to(self.device) for item in batch)
        return batch.to(self.device)

In [None]:
from torch.nn import Module, MultiheadAttention
from torch_geometric.nn import JumpingKnowledge as JK, Linear


class SelfAttention(MultiheadAttention):
    """Multi-head self-attention over the last (phase) axis, averaged over phases."""

    def forward(self, xs: Tensor) -> Tensor:
        """forward propagation

        Args:
            xs (Tensor): input with shape (batch_size, hidden_dim, num_phases)

        Returns:
            Tensor: output tensor with size (batch_size, hidden_dim)
        """
        # Bring the input into the layout expected by MultiheadAttention:
        # (batch, phases, hidden) if batch_first else (phases, batch, hidden).
        if self.batch_first:
            x = xs.permute(0, 2, 1)
        else:
            x = xs.permute(2, 0, 1)
        out: Tensor = super().forward(x, x, x, need_weights=True)[0]
        # Average over the phase axis (dim 1 if batch_first, else dim 0).
        return out.mean(dim=int(self.batch_first))

    def reset_parameters(self):
        """Re-initialise the attention weights."""
        super()._reset_parameters()


class WeightedSum(Module):
    """Learned, softmax-weighted sum over the last (phase) axis.

    ``Q`` scores every phase once per head; the scores are normalised over the
    phases and used to mix them. With more than one head, ``fc`` combines the
    per-head results into a single one.
    """

    def __init__(self, hidden_dim: int, num_heads: int):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.Q = Linear(in_channels=hidden_dim, out_channels=num_heads, bias=False)

        multi_head = num_heads > 1
        if multi_head:
            self.fc = Linear(in_channels=num_heads, out_channels=1, bias=False)

    def forward(self, xs: Tensor) -> Tensor:
        """forward propagation

        Args:
            xs (Tensor): input with shape (batch_size, hidden_dim, num_phases)

        Returns:
            Tensor: output tensor with size (batch_size, hidden_dim)
        """
        per_phase = xs.transpose(1, 2)                   # (node, hop, dim)
        weights = self.Q(per_phase).softmax(dim=1)       # (node, hop, head), normalised over hops
        mixed = xs.matmul(weights)                       # (node, dim, head)

        if self.num_heads > 1:
            mixed = self.fc(mixed)                       # (node, dim, 1)

        return mixed.squeeze(-1)

    def reset_parameters(self):
        """Re-initialise ``Q`` and, if present, ``fc``."""
        for layer in self.children():
            layer.reset_parameters()


class JumpingKnowledge(Module):
    """Aggregates per-phase representations along the last axis of the input.

    ``mode`` selects the aggregation: ``'cat'``, ``'sum'``, ``'mean'`` and
    ``'max'`` are parameter-free; ``'attn'`` and ``'wsum'`` need the keyword
    arguments ``hidden_dim`` and ``num_heads``; ``'lstm'`` forwards its keyword
    arguments to torch_geometric's ``JumpingKnowledge``. An unsupported mode is
    only reported when ``forward`` is called.
    """
    supported_modes = ['cat', 'max', 'lstm', 'sum', 'mean', 'attn', 'wsum']

    def __init__(self, mode: str, **kwargs):
        super().__init__()
        self.mode = mode
        if mode in ('attn', 'wsum'):
            self.hidden_dim = kwargs['hidden_dim']
            self.num_heads = kwargs['num_heads']
            if mode == 'attn':
                self.attn = SelfAttention(self.hidden_dim, num_heads=self.num_heads, batch_first=True)
            else:
                self.wsum = WeightedSum(self.hidden_dim, num_heads=self.num_heads)
        elif mode == 'lstm':
            self.lstm = JK(mode='lstm', **kwargs)

    def forward(self, xs: Tensor) -> Tensor:
        """forward propagation

        Args:
            xs (Tensor): input with shape (batch_size, hidden_dim, num_phases)

        Returns:
            Tensor: aggregated output with shape (batch_size, hidden_dim)
                (``'cat'`` instead flattens phases and hidden dim into one axis)
        """
        mode = self.mode

        # Parameter-free reductions.
        if mode == 'cat':
            return xs.transpose(1, 2).reshape(xs.size(0), -1)
        if mode == 'sum':
            return xs.sum(dim=-1)
        if mode == 'mean':
            return xs.mean(dim=-1)
        if mode == 'max':
            return xs.max(dim=-1).values

        # Learned aggregations.
        if mode == 'attn':
            return self.attn(xs)
        if mode == 'wsum':
            return self.wsum(xs)
        if mode == 'lstm':
            # one tensor per phase
            return self.lstm(xs.unbind(dim=-1))

        raise NotImplementedError(f'Unsupported JK mode: {self.mode}')

    def reset_parameters(self):
        """Re-initialise the parameters of the aggregation sub-module, if any."""
        for child in self.children():
            child.reset_parameters()

## Building the Persistent Homology Module

`~`

`Algebraic Topology` works to describe the shape of a `continuous manifold`. However, real-world datasets are typically given as point clouds, a discrete set of points sampled from an underlying manifold. In this setting, true homologies are trivial, as there is one connected component per point and no holes whatsoever; instead, `persistent homology` can be used to find holes in point clouds and to assign an importance score called persistence to each. Holes with high persistence are indicative of holes in the underlying manifold.

* Persistent homology is a tool that computes topologically-informed features (or topological invariants) for a dataspace at different scales.

* In simple terms, persistent homology keeps track of births and deaths of k-dimensional simplices. (vertices, edges, triangles, tetrahedra and so on)

* This scale grows from the local level up to the global level (up to infinity, in theory).

* Persistent features are computed for entities known as `Abstract Simplicial Complexes`

* The simplicial complexes can be determined by one, or more parameters - `d`.

* Here, varying the parameter(s) `d` in an increasing manner creates supersets of abstract simplicial complexes created before.


In [None]:
! pip install torch-geometric

Collecting torch-geometric
  Downloading torch_geometric-2.5.3-py3-none-any.whl (1.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m6.5 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: torch-geometric
Successfully installed torch-geometric-2.5.3


`Ripser` and `Persim` are among the libraries, built on C++, that are used to compute filtrations and the persistent homology of point cloud data.

Persistent homology can work pretty well for low-dimensional data, and our purpose of applying this idea here is to just obtain useful filtrations/globally informed features that can be further used to ease the task of GNN learning.

If this approach is successful in our task of quark-gluon classification, it might be indicative of how different instances of particle jets might be just the outcome of sampling different low-dimensional manifolds in a higher-dimensional ambient space.

In [None]:
! pip install ripser persim

Collecting ripser
  Downloading ripser-0.6.10-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (834 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m834.5/834.5 kB[0m [31m5.7 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting persim
  Downloading persim-0.3.7-py3-none-any.whl (48 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m48.6/48.6 kB[0m [31m4.6 MB/s[0m eta [36m0:00:00[0m
Collecting deprecated (from persim)
  Downloading Deprecated-1.2.14-py2.py3-none-any.whl (9.6 kB)
Collecting hopcroftkarp (from persim)
  Downloading hopcroftkarp-1.2.5.tar.gz (16 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: hopcroftkarp
  Building wheel for hopcroftkarp (setup.py) ... [?25l[?25hdone
  Created wheel for hopcroftkarp: filename=hopcroftkarp-1.2.5-py2.py3-none-any.whl size=18101 sha256=680cb8350170d1459dd7fb80baf6b5062950a9e6a18b50f63bd5c30584e5f7a4
  Stored in directory: /root/.cache/pip/wheels

## Beyond "d" : Using Zig-Zag or Multiparameter Persistence

Persistent homology is well suited for detecting structure in high-dimensional datasets, so it is no surprise that the technique has mostly been applied in cosmology to either constrain non-Gaussianity in the CMB or find cosmic voids or filament loops in the large-scale structure of matter.
<br>

Never before has persistent homology been used in particle physics.

In [None]:
!pip install chart-studio

Collecting chart-studio
  Downloading chart_studio-1.1.0-py3-none-any.whl (64 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/64.4 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m64.4/64.4 kB[0m [31m2.5 MB/s[0m eta [36m0:00:00[0m
Collecting retrying>=1.3.3 (from chart-studio)
  Downloading retrying-1.3.4-py3-none-any.whl (11 kB)
Installing collected packages: retrying, chart-studio
Successfully installed chart-studio-1.1.0 retrying-1.3.4


In [None]:
import os
from getpass import getpass

import chart_studio

# Credentials must never be stored in the notebook itself: read them from the
# environment and fall back to an interactive prompt (getpass does not echo the key).
# NOTE(review): the API key that used to be hard-coded in this cell has been exposed
# in the notebook's history and should be regenerated in the Chart Studio settings.
username = os.environ.get('CHART_STUDIO_USERNAME') or input('Chart Studio username: ')
api_key = os.environ.get('CHART_STUDIO_API_KEY') or getpass('Chart Studio API key: ')
chart_studio.tools.set_credentials_file(username=username,
                                        api_key=api_key)

In [None]:
import chart_studio.plotly as py
import chart_studio.tools as tls
import plotly.express as px

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
! pip install --upgrade hepml

Collecting hepml
  Downloading hepml-0.0.12-py3-none-any.whl (25 kB)
Collecting black (from hepml)
  Downloading black-24.4.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.8/1.8 MB[0m [31m33.1 MB/s[0m eta [36m0:00:00[0m
Collecting wget (from hepml)
  Downloading wget-3.2.zip (10 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting nbdev (from hepml)
  Downloading nbdev-2.3.25-py3-none-any.whl (67 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m67.1/67.1 kB[0m [31m8.3 MB/s[0m eta [36m0:00:00[0m
Collecting giotto-tda (from hepml)
  Downloading giotto_tda-0.6.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.4 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.4/1.4 MB[0m [31m29.0 MB/s[0m eta [36m0:00:00[0m
Collecting mypy-extensions>=0.4.3 (from black->hepml)
  Downloading mypy_extensions-1.0.0-py3-none-any.whl (4.7 kB)

In [None]:
# data wrangling
import numpy as np
import pandas as pd
from pathlib import Path
import pickle
from typing import List
from PIL import Image
from hepml.core import download_dataset
from scipy import ndimage

# tda magic
from gtda.homology import VietorisRipsPersistence, CubicalPersistence
from gtda.diagrams import PersistenceEntropy
from gtda.plotting import plot_heatmap, plot_point_cloud, plot_diagram
from gtda.pipeline import Pipeline
from hepml.core import make_point_clouds, load_shapes
from gtda.graphs import GraphGeodesicDistance

# ml tools
from sklearn.ensemble import RandomForestClassifier
from sklearn.preprocessing import LabelEncoder

# dataviz
import matplotlib.pyplot as plt

In [None]:
# Build the toy dataset of point clouds together with their labels.
point_clouds_basic, labels_basic = make_point_clouds(
    n_samples_per_shape=100,
    n_points=20,
    noise=0.5,
)
# Trailing expression: the cell displays both array shapes
# (the recorded run shows (300, 400, 3) and (300,)).
point_clouds_basic.shape, labels_basic.shape

((300, 400, 3), (300,))

In [None]:
# Inspect the container type of a single sample (the first point cloud of the dataset).
type(point_clouds_basic[0])

numpy.ndarray

In [None]:
# Plot the first point cloud in 3D. scatter_3d needs an explicit x/y/z mapping: handing it
# the bare (n_points, 3) array leaves the trace without coordinates, so the three columns
# are passed one by one.
first_cloud_xyz = point_clouds_basic[0]
fig = px.scatter_3d(
    x=first_cloud_xyz[:, 0],
    y=first_cloud_xyz[:, 1],
    z=first_cloud_xyz[:, 2],
)
fig.show()

In [None]:
# Upload `fig` through chart_studio.plotly under the name "plotly_scatter"; the recorded
# output of this cell is the URL of the hosted plot.
# NOTE(review): auto_open=True presumably tries to open that URL in a local browser, which
# is unlikely to do anything on a hosted runtime — confirm whether it is needed.
py.plot(fig, filename="plotly_scatter", auto_open = True)

'https://plotly.com/~DarthRevan07/1/'

In [None]:
# Render the first point cloud (index 0) with the plot_point_cloud helper from gtda.plotting.
plot_point_cloud(point_clouds_basic[0])

In [None]:
# Render the sample at index 100. With 100 samples requested per shape, this is presumably
# the first cloud of the second shape class — TODO confirm the ordering of the dataset.
plot_point_cloud(point_clouds_basic[100])

In [None]:
# Render the last point cloud of the dataset (index -1).
plot_point_cloud(point_clouds_basic[-1])

In [None]:
from gtda.point_clouds import ConsistentRescaling, ConsecutiveRescaling
from gtda.graphs import TransitionGraph, KNeighborsGraph

def adjust_density(point_cloud, density_factor):
    """Multiply a point cloud by ``density_factor``.

    Parameters
    ----------
    point_cloud
        Object supporting ``point_cloud * density_factor``.
    density_factor
        Right-hand operand of the multiplication.

    Returns
    -------
    The product ``point_cloud * density_factor``.
    """
    rescaled_cloud = point_cloud * density_factor
    return rescaled_cloud

def adjust_distance(point_cloud, *args):
    """Rescale ``point_cloud`` with a freshly constructed ``ConsistentRescaling``.

    Parameters
    ----------
    point_cloud
        Input handed to ``ConsistentRescaling.fit_transform`` as its first argument.
    *args
        Extra positional arguments forwarded unchanged to ``fit_transform``.

    Returns
    -------
    Whatever ``ConsistentRescaling().fit_transform(point_cloud, *args)`` returns.
    """
    return ConsistentRescaling().fit_transform(point_cloud, *args)

In [None]:
from gtda.diagrams import PersistenceLandscape
from gtda.homology import VietorisRipsPersistence
from scipy.sparse.csgraph import connected_components
import networkx as nx
from gtda.plotting import plot_diagram, plot_betti_curves

class VietorisPersistenceModule:
    """Vietoris-Rips persistence analysis of a single point cloud.

    The module computes a persistence diagram, derives Betti curves from it,
    plots the diagram and builds the neighbourhood graph of the point cloud at
    scale ``max_edge_length``.

    Parameters
    ----------
    max_edge_length : float, default ``np.inf``
        Maximum filtration value handed to ``VietorisRipsPersistence``; also
        the distance threshold used when building the graph.
    homology_dim : tuple of int, default ``(0, 1, 2, 3)``
        Homology dimensions handed to ``VietorisRipsPersistence``.
    point_cloud : array-like of shape (n_points, n_dims), optional
        Point cloud to analyse. When omitted, a synthetic dataset is generated
        with ``make_point_clouds(n_samples_per_shape=100, n_points=40, noise=0.5)``
        and its first sample is analysed (the original behaviour).

    Attributes
    ----------
    persistence_diagram : ndarray or None
        Diagram of ``point_cloud``; set by ``vietoris_rips_complex``.
    betti_numbers : list of ndarray or None
        One Betti curve per homology dimension (ascending dimension), i.e. the
        Betti numbers sampled along the filtration; set by ``compute_betti``.
    adj_graph : networkx.Graph or None
        Neighbourhood graph; set by ``create_persistent_graph``.
    rips_complex : VietorisRipsPersistence or None
        The fitted transformer; set by ``vietoris_rips_complex``.
    """

    def __init__(self, max_edge_length=np.inf, homology_dim=(0, 1, 2, 3), point_cloud=None):
        if point_cloud is None:
            # Default: analyse the first sample of a freshly generated toy dataset.
            self.point_clouds_basic, self.labels_basic = make_point_clouds(
                n_samples_per_shape=100, n_points=40, noise=0.5
            )
            point_cloud = self.point_clouds_basic[0]
        else:
            self.point_clouds_basic, self.labels_basic = None, None
        self.point_cloud = np.asarray(point_cloud)
        self.max_edge_length = max_edge_length
        self.homology_dim = homology_dim
        self.persistence_diagram = None
        self.betti_numbers = None
        self.adj_graph = None
        self.rips_complex = None

    def _ensure_diagram(self):
        """Compute the persistence diagram if it has not been computed yet."""
        if self.persistence_diagram is None:
            self.vietoris_rips_complex()

    def vietoris_rips_complex(self):
        """Fit the Vietoris-Rips transformer and store the diagram of ``point_cloud``."""
        self.rips_complex = VietorisRipsPersistence(
            metric='euclidean',
            max_edge_length=self.max_edge_length,
            homology_dimensions=self.homology_dim,
        )
        # The transformer works on a collection of point clouds: wrap the single
        # cloud in a list and unwrap the single resulting diagram.
        self.persistence_diagram = self.rips_complex.fit_transform([self.point_cloud])[0]

    def compute_betti(self, n_bins=100):
        """Compute and plot the Betti curves of the persistence diagram.

        Parameters
        ----------
        n_bins : int, default 100
            Number of filtration values at which each Betti curve is sampled.
        """
        # Local import keeps this cell self-contained; gtda is already a dependency.
        from gtda.diagrams import BettiCurve

        self._ensure_diagram()

        betti_curve = BettiCurve(n_bins=n_bins)
        # BettiCurve expects a batch of diagrams, hence the leading sample axis.
        curves = betti_curve.fit_transform(self.persistence_diagram[None, :, :])

        # curves has shape (n_samples, n_homology_dimensions, n_bins); keep one
        # curve per homology dimension for the single sample.
        self.betti_numbers = [curves[0, k] for k in range(curves.shape[1])]

        # The previous implementation read PersistenceLandscape.sampling_range_,
        # which does not exist (AttributeError); the transformer's own plot
        # method uses the sampling it fitted.
        fig = betti_curve.plot(curves, sample=0)
        fig.show()

    def plot_persistence_diagram(self):
        """Plot the persistence diagram as an interactive plotly figure."""
        self._ensure_diagram()
        # gtda's plot_diagram returns a plotly figure; it takes no matplotlib
        # ``ax``/``show`` arguments.
        fig = plot_diagram(self.persistence_diagram)
        fig.update_layout(title="Persistence Diagram")
        fig.show()

    def create_persistent_graph(self):
        """Build and draw the graph joining points closer than ``max_edge_length``.

        Nodes are the point indices; an edge (i, j) carries the Euclidean
        distance between points i and j as its ``weight``.
        """
        from scipy.spatial.distance import pdist, squareform

        points = self.point_cloud
        n_points = len(points)

        # The edges come from pairwise distances; the transformer's fit_transform
        # output is a persistence diagram, not an adjacency matrix.
        distances = squareform(pdist(points, metric='euclidean'))

        self.adj_graph = nx.Graph()
        # Register every point so isolated points still appear in the graph.
        self.adj_graph.add_nodes_from(range(n_points))

        # Strict upper triangle: each unordered pair once, no self-loops.
        rows, cols = np.nonzero(np.triu(distances < self.max_edge_length, k=1))
        self.adj_graph.add_weighted_edges_from(
            (int(i), int(j), float(distances[i, j])) for i, j in zip(rows, cols)
        )

        # nx.draw lays nodes out in the plane, so use the first two coordinates.
        pos = {i: points[i, :2] for i in range(n_points)}
        nx.draw(self.adj_graph, pos, with_labels=True, node_size=50)
        plt.title("Graph from Vietoris-Rips Complex")
        plt.show()

    def preprocess(self):
        """Run the full pipeline: diagram, Betti curves, diagram plot, graph."""
        self.vietoris_rips_complex()
        self.compute_betti()
        self.plot_persistence_diagram()
        self.create_persistent_graph()

In [None]:

# Build the analyzer with a maximum edge length of 1.5 and run the whole pipeline:
# preprocess() calls vietoris_rips_complex, compute_betti, plot_persistence_diagram and
# create_persistent_graph in that order.
# NOTE(review): the recorded output of this cell is an AttributeError
# ('PersistenceLandscape' object has no attribute 'sampling_range_').
topology_analyzer = VietorisPersistenceModule(max_edge_length=1.5)
topology_analyzer.preprocess()



AttributeError: 'PersistenceLandscape' object has no attribute 'sampling_range_'

### Approaching Multiparameter Persistence - `RIVET`

* Using only a single parameter such as the distance `d` essentially reduces to a k-Nearest-Neighbor construction all over again, at which point the cost of applying persistent homology becomes too significant for it to be practical.

* I would like to add another parameter that might be indicative of the presence of a global phenomenon in the vicinity of the particles present in a jet.

* I have implemented a persistent module for 2 features here (distance and density for now).


For this purpose we can use RIVET, a tool for topological data analysis and, more specifically, for the visualization and analysis of two-parameter persistent homology.

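A Python API for RIVET is provided as `pyrivet`. Note that `pip install pyrivet` fails in the cell below because no such distribution is found on the package index, so the package must be installed another way (e.g. from its source repository).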

In [None]:
# NOTE(review): the recorded output shows this install failing — pip finds no distribution
# named `pyrivet` on the package index — so the package has to be obtained some other way.
# TODO: confirm the correct installation source.
! pip install pyrivet

[31mERROR: Could not find a version that satisfies the requirement pyrivet (from versions: none)[0m[31m
[0m[31mERROR: No matching distribution found for pyrivet[0m[31m
[0m

In [None]:
import numpy as np
from ripser import ripser
from persim import plot_diagrams
import torch
from collections import defaultdict
import networkx as nx
from torch_geometric.data import Data


In [None]:
class PersistentHomologyFeatureExtractor:
    """Turn one point-cloud entry into a PyG graph built from its H0 persistence diagram.

    Parameters
    ----------
    data_entry : torch.Tensor
        A single data entry; it is converted to a NumPy array and handed to
        ``ripser`` (the demo below uses a (140, 4) float tensor).
    """

    def __init__(self, data_entry):
        self.data_entry = data_entry

    def compute_persistence_diagram(self):
        """Return ripser's persistence diagrams (``'dgms'``) up to dimension 2."""
        # detach()/cpu() make the conversion safe for tensors that track
        # gradients or live on an accelerator; plain CPU tensors are unaffected.
        points = self.data_entry.detach().cpu().numpy()
        return ripser(points, maxdim=2)['dgms']

    def create_graph_from_diagram(self, diagram):
        """Build a fully connected PyG ``Data`` graph from the 0-dimensional diagram.

        Each (birth, death) pair of ``diagram[0]`` becomes one node with those two
        values as features; every unordered pair of nodes (i < j) is joined by
        one edge.

        Parameters
        ----------
        diagram : sequence of array-like
            Persistence diagrams indexed by homology dimension; only
            ``diagram[0]`` is used.

        Returns
        -------
        torch_geometric.data.Data
            Graph with ``x`` of shape (n_nodes, 2) and ``edge_index`` of shape
            (2, n_edges).
        """
        # One conversion of the whole array instead of a Python list of rows.
        # NOTE(review): ripser's H0 diagram presumably contains one bar with an
        # infinite death time; that inf is passed through as a node feature —
        # confirm how downstream code should treat it.
        node_features = torch.as_tensor(
            np.asarray(diagram[0]), dtype=torch.float
        ).reshape(-1, 2)
        n_nodes = node_features.shape[0]

        # PyG expects edge_index with shape [2, num_edges]. triu_indices with
        # offset=1 lists every pair i < j in row-major order (the order of the
        # former nested loops) and yields a valid (2, 0) tensor when there are
        # fewer than two nodes.
        edge_indices = torch.triu_indices(n_nodes, n_nodes, offset=1)

        return Data(x=node_features, edge_index=edge_indices)

    def preprocess(self):
        """Compute the persistence diagram and return the graph built from it."""
        persistence_diagram = self.compute_persistence_diagram()
        graph_object = self.create_graph_from_diagram(persistence_diagram)
        return graph_object


In [None]:
# Generate a reproducible random dataset: 1 entry with 140 points and 4 features each.
# A seeded generator makes Restart & Run All produce the same entry every time.
rng = np.random.default_rng(seed=0)
data_entry = torch.tensor(rng.random((140, 4)), dtype=torch.float)

# Instantiate the PersistentHomologyFeatureExtractor with the generated data entry
extractor = PersistentHomologyFeatureExtractor(data_entry)

# Preprocess the data entry to compute persistent homology features and create a graph object
graph_object = extractor.preprocess()

# Print information about the generated graph object
print("Graph Object:")
print(f"Node Features Shape: {graph_object.x.shape}")
print(f"Edge Indices Shape: {graph_object.edge_index.shape}\n")

Graph Object:
Node Features Shape: torch.Size([140, 2])
Edge Indices Shape: torch.Size([9730, 2])

