In [None]:
import os.path as osp
import time

import numpy as np
import pandas as pd
from matplotlib import pyplot as plt

import torch
from sklearn.cluster import DBSCAN
from torch_geometric.data import Data
import torch_geometric.transforms as ttr

from astropy.io import fits
from astropy.table import Table
from astropy import table as astropy_table
from astropy import units as u
from astropy.coordinates import SkyCoord
from astropy.time import Time

from mptd.reader import get_raw_data
from mptd.simple_message import SimpleMessage
from mptd.plotter import plot_data, plot_clusters, plot_fits_data


In [None]:
# Torch device that dataset tensors are moved to (see MPTDDataset); the
# trailing comment keeps the GPU alternative at hand.
DEVICE="cpu"#"cuda:0"

In [None]:
from typing import Any

class MPTDData(Data):
    """
    Graph container for Message-Passing Transient Detection (MPTD).

    The node coordinates stored in ``pos`` double as the node features:
    the ``x`` property is a read-only alias of ``pos``, so a separate
    feature tensor must not be supplied.

    Parameters
    ----------
    x : None
        Must be left as None; features are read from ``pos``.
    edge_index : torch.Tensor, optional
        Graph connectivity, forwarded unchanged to ``Data``.
    edge_attr : torch.Tensor, optional
        Edge attributes, forwarded unchanged to ``Data``.
    y : torch.Tensor, optional
        Per-node labels.
    pos : torch.Tensor, optional
        Node positions (coordinates).
    **kwargs
        Additional keyword arguments forwarded to ``Data``.
    """

    def __init__(self, x=None, edge_index=None, edge_attr=None, y=None, pos=None, **kwargs):
        # `x` is served by the property below, so storing a second tensor
        # under that name would be ambiguous.
        assert x is None, "MPTDData exposes `pos` as `x`; do not pass `x`."
        super().__init__(x, edge_index, edge_attr, y, pos, **kwargs)

    @property
    def x(self):
        """
        Node features, which are the node positions stored in ``pos``.

        Returns
        -------
        torch.Tensor
            Node positions (coordinates).
        """
        return self.pos

    def append(self, other):
        """
        Concatenate the nodes (and edges) of another MPTDData after this one.

        Neither operand is modified. Edges are kept only when both objects
        carry them; likewise for edge attributes and labels.

        Parameters
        ----------
        other : MPTDData
            The object whose nodes are appended after this object's nodes.

        Returns
        -------
        MPTDData
            A new instance holding the combined data.
        """
        if self.edge_index is not None and other.edge_index is not None:
            # The other graph's nodes land after ours in the stacked `pos`,
            # so its node ids must be shifted by our node count.
            offset = self.pos.size(0)
            new_edge_index = torch.cat([self.edge_index, other.edge_index + offset], dim=1)
        else:
            new_edge_index = None

        if self.edge_attr is not None and other.edge_attr is not None:
            # Edges run along dim 0 of `edge_attr` for both 1-D and 2-D layouts.
            new_edge_attr = torch.cat([self.edge_attr, other.edge_attr], dim=0)
        else:
            new_edge_attr = None

        if self.y is not None and other.y is not None:
            new_y = torch.cat([self.y, other.y], dim=0)
        else:
            new_y = None

        # `x` is intentionally not passed: the constructor rejects it and the
        # property derives it from `pos`.
        return MPTDData(y=new_y,
                        pos=torch.vstack([self.pos, other.pos]),
                        edge_index=new_edge_index,
                        edge_attr=new_edge_attr)

    def getsplice(self, index):
        """
        Build a new MPTDData restricted to the nodes selected by ``index``.

        Only ``y`` and ``pos`` are carried over; edges and edge attributes
        are not transferred to the returned object.

        Parameters
        ----------
        index : torch.Tensor
            Index or boolean mask used to select nodes.

        Returns
        -------
        MPTDData
            A new instance containing only the selected nodes.
        """
        y = self.y[index] if self.y is not None else None
        pos = self.pos[index] if self.pos is not None else None

        return MPTDData(y=y, pos=pos)

    def __len__(self):
        """
        Number of nodes, taken from the length of ``y``.

        Returns
        -------
        int
            Number of nodes in the data object.

        Raises
        ------
        ValueError
            If ``y`` is not set, so no length can be determined.
        """
        if self.y is not None:
            return self.y.size(0)
        raise ValueError("The data object has no valid tensor for indexing.")


In [None]:

class MPTDDataset:
    """
    Event dataset for Message-Passing Transient Detection (MPTD).

    Events are read from one or more files with ``get_raw_data`` and stored
    as an ``MPTDData`` object (``self.data``) whose positions are the
    columns named in ``keys``. ``self.groups`` holds, for every stored
    event, the value of the per-file group column for simulated events and
    -1 for all other events.

    Parameters
    ----------
    filenames : str or List[str]
        File name(s) of the dataset.
    keys : List[str]
        Column names used as node coordinates, in this order.
    filters : dict
        Filters forwarded to ``get_raw_data``.
    withsim : bool, optional
        Whether to keep events flagged ``ISSIMULATED``, by default True.

    Raises
    ------
    ValueError
        If ``filenames`` is empty.
    """

    # File-name endings for which the group id is read from "PHA";
    # any other file uses "TIME_RAW".
    _PHA_SUFFIXES = ("MIEVLF0000.FTZ", "MIEVLI0000.FTZ")

    def __init__(self, filenames, keys, filters:dict, withsim=True) -> None:
        if isinstance(filenames, str):
            filenames = [filenames]
        if len(filenames) == 0:
            raise ValueError("At least one file name is required.")

        # NOTE: coordinates are used exactly as returned by get_raw_data;
        # no sky-coordinate conversion is applied here.
        tables = []
        for filename in filenames:
            lastcolname = "PHA" if filename.endswith(self._PHA_SUFFIXES) else "TIME_RAW"
            tables.append(get_raw_data(filename, list(keys) + [lastcolname], filters))
        raw_data = tables[0] if len(tables) == 1 else astropy_table.vstack(tables)

        issimulated = torch.from_numpy(np.array(raw_data["ISSIMULATED"])).bool()
        # As in the original loop, the group column of the LAST file is the
        # one read for all rows.
        # NOTE(review): mixing files with different group columns would leave
        # rows without a value in that column - confirm this never happens.
        groups = torch.from_numpy(np.array(raw_data[lastcolname]))
        groups[~issimulated] = -1

        pos = torch.from_numpy(np.array([raw_data[key] for key in keys]).T).float()
        y = issimulated.long()
        if not withsim:
            # Drop simulated events from every per-event tensor, including
            # `groups`, so that masks built from `groups` stay aligned with
            # `data.pos`.
            keep = ~issimulated
            pos, y, groups = pos[keep], y[keep], groups[keep]

        self.groups = groups
        self.data = MPTDData(pos=pos, y=y).to(DEVICE)
        self.keys = keys

    def get_group(self, group):
        """
        Get the positions of the events belonging to a specific group.

        Parameters
        ----------
        group : int
            Group identifier, compared against ``self.groups``.

        Returns
        -------
        torch.Tensor
            Node positions (coordinates) of the matching events.
        """
        indices = self.groups == group
        return self.data.pos[indices]

    def list_groups(self):
        """
        Get the distinct group identifiers present in the dataset.

        Returns
        -------
        torch.Tensor
            Unique values of ``self.groups`` (includes -1 when some event
            is not simulated).
        """
        return torch.unique(self.groups)


In [None]:

class MPTDElaborator:
    """
    Iteratively applies a message-passing model to an MPTD dataset.

    The dataset's graph data is passed through ``transformer`` once at
    construction time; the per-node state starts at one and is updated in
    place by every forward step.

    Parameters
    ----------
    dataset : MPTDDataset
        The dataset to elaborate.
    transformer : Any
        Callable applied to ``dataset.data`` to obtain the graph used by
        the model.
    keys : List[str]
        Coordinate names, forwarded to the plotting routine.
    model : Any
        Object exposing ``forward(state, edge_index)``.
    """

    def __init__(self, dataset:MPTDDataset, transformer, keys, model) -> None:
        self.dataset = dataset
        self.net_data = transformer(dataset.data)
        self.keys = keys
        self.model = model
        self.iterations = 0
        # One scalar per node, shaped (N, 1) and initialised to one.
        first_column = self.net_data.x[:, 0]
        self.elaborated_data = torch.ones_like(first_column.unsqueeze(-1))

    def _step(self):
        """Run one in-place update: add the model output, then rescale by the maximum."""
        update = self.model.forward(self.elaborated_data, self.net_data.edge_index)
        self.elaborated_data.add_(update)
        self.iterations += 1
        self.elaborated_data.div_(self.elaborated_data.max())

    def _selection(self, sizes, max_threshold):
        """Boolean mask of entries at or above min(mean of sizes, max_threshold)."""
        cutoff = min(sizes.mean().item(), max_threshold)
        return sizes >= cutoff

    def _plot_frame(self, sizes, max_threshold):
        """Plot the currently selected nodes as the frame for the current iteration."""
        keep = self._selection(sizes, max_threshold)
        plot_data(self.net_data.pos[keep].cpu(), sizes[keep].cpu(), issimulated=self.net_data.y[keep].cpu().bool(),
                  keys=self.keys, title=f"iteration {self.iterations}",
                  outfile=osp.join("video_frames", f"frame_{self.iterations:02}.png"))

    def sizes(self):
        """
        Get the current per-node state with singleton dimensions removed.

        Returns
        -------
        torch.Tensor
            Squeezed elaborated data.
        """
        return self.elaborated_data.squeeze()

    def distances(self):
        """
        Compute the length of every edge of the transformed graph.

        Returns
        -------
        torch.Tensor
            Euclidean distance between the two end points of each edge.
        """
        positions = self.net_data.pos
        sources = self.net_data.edge_index[0]
        targets = self.net_data.edge_index[1]
        return torch.norm(positions[sources] - positions[targets], dim=1)

    def forward(self, iterations=1):
        """
        Run a fixed number of update steps.

        Parameters
        ----------
        iterations : int, optional
            How many steps to run, by default 1.

        Returns
        -------
        torch.Tensor
            Sizes of the elaborated data after the last step.
        """
        remaining = iterations
        while remaining > 0:
            self._step()
            remaining -= 1
        return self.sizes()

    def forward_plot(self, iterations, plot_every=1, plot_after=0, max_threshold=0.5):
        """
        Run update steps and plot intermediate frames.

        Parameters
        ----------
        iterations : int
            Value of the iteration counter at which to stop.
        plot_every : int, optional
            Number of steps between two plotted frames, by default 1.
        plot_after : int, optional
            Number of steps to run before the first frame, by default 0.
        max_threshold : float, optional
            Upper bound for the selection threshold, by default 0.5.

        Returns
        -------
        torch.Tensor
            Sizes of the elaborated data after the last step.
        """
        sizes = self.forward(iterations=plot_after)
        while self.iterations < iterations:
            self._plot_frame(sizes, max_threshold)
            steps_left = iterations - self.iterations
            sizes = self.forward(iterations=min(steps_left, plot_every))

        # Final frame, drawn once the target iteration count is reached.
        self._plot_frame(sizes, max_threshold)
        return sizes

    def auto_forward(self, rej_threshold, max_threshold):
        """
        Run update steps until the selection stops shrinking quickly.

        At least one step is always run. After each step the number of
        selected nodes is compared with the count before that step, and
        iteration continues while the drop exceeds ``rej_threshold``.

        Parameters
        ----------
        rej_threshold : float
            Minimum drop in the number of selected nodes that keeps the
            iteration going.
        max_threshold : float
            Upper bound for the selection threshold.

        Returns
        -------
        torch.Tensor
            Sizes of the elaborated data after the last step.
        """
        previous = self._selection(self.sizes(), max_threshold).sum()
        while True:
            self._step()
            current = self._selection(self.sizes(), max_threshold).sum()
            dropped = previous - current
            previous = current
            if not (dropped > rej_threshold):
                break
        return self.sizes()


In [None]:

class MPTDClusterer:
    """
    Selects nodes by thresholding and clusters the selection.

    Parameters
    ----------
    algorithm : Any
        Clustering object exposing ``fit_predict``.
    max_threshold : float
        Upper bound for the selection threshold.
    """

    def __init__(self, algorithm, max_threshold):
        self.algorithm = algorithm
        self.max_threshold = max_threshold

    def mask(self, elaborator:MPTDElaborator):
        """
        Select the nodes whose size reaches the threshold.

        The threshold is the smaller of the mean size and
        ``max_threshold``.

        Parameters
        ----------
        elaborator : MPTDElaborator
            Source of the per-node sizes.

        Returns
        -------
        torch.Tensor
            Boolean mask, True for selected nodes.
        """
        scores = elaborator.sizes()
        cutoff = min(scores.mean().item(), self.max_threshold)
        return scores >= cutoff

    def mask_data(self, elaborator:MPTDElaborator):
        """
        Get the selected nodes as a CPU data object.

        Parameters
        ----------
        elaborator : MPTDElaborator
            Source of the per-node sizes and of the graph data.

        Returns
        -------
        MPTDData
            The selected subset of ``elaborator.net_data``, moved to CPU.
        """
        selected = self.mask(elaborator)
        return elaborator.net_data.getsplice(selected).cpu()

    def cluster(self, elaborator:MPTDElaborator):
        """
        Cluster the selected nodes and report labels for every node.

        Parameters
        ----------
        elaborator : MPTDElaborator
            Source of the per-node sizes and of the graph data.

        Returns
        -------
        np.ndarray
            One label per node; nodes that were not selected get -1,
            selected nodes get the label returned by the algorithm.
        """
        selected = self.mask(elaborator)
        kept = elaborator.net_data.getsplice(selected).cpu()
        predicted = self.algorithm.fit_predict(kept.pos)

        node_count = elaborator.net_data.pos.shape[0]
        labels_full = np.full((node_count,), -1)
        labels_full[selected.cpu()] = predicted
        return labels_full
   

In [None]:
 
class MPTDScorer:
    """
    Scores the clusters found for an elaborated MPTD dataset.

    ``predict_labels`` must be called before any metric: it stores the
    per-node cluster labels and the label/group correspondence table that
    all other methods rely on.

    Parameters
    ----------
    elaborator : MPTDElaborator
        The elaborator whose dataset and state are scored.
    """

    def __init__(self, elaborator:MPTDElaborator):
        self.dataset    = elaborator.dataset
        self.elaborator = elaborator
        self.labels = None
        self.l2g_table = None

    def predict_labels(self, clusterer:MPTDClusterer):
        """
        Cluster the elaborated data and cache labels and the label/group table.

        Parameters
        ----------
        clusterer : MPTDClusterer
            The clusterer to run on ``self.elaborator``.

        Returns
        -------
        np.ndarray
            One cluster label per node (-1 for nodes outside any cluster).
        """
        self.labels = clusterer.cluster(self.elaborator)
        self.l2g_table = self.label_to_group_table()
        return self.labels

    def _cluster_labels(self):
        """Return the distinct non-negative cluster labels."""
        return np.unique(self.labels[self.labels >= 0])

    def fluence_vs_success(self):
        """
        Report, for every non-negative group, its size and whether it was found.

        Returns
        -------
        torch.Tensor
            Tensor of shape (n_groups, 3) with rows
            ``[group id, number of events in the group, detected flag]``.
            A group counts as detected when at least one non-negative
            label is associated with it in the label/group table.
        """
        assert self.labels is not None
        result = []
        group_table = self.l2g_table.set_index("Group")
        for group in self.dataset.list_groups():
            if group < 0:
                continue
            mask = self.dataset.groups == group
            assert mask.dtype == torch.bool
            fluence = float(mask.sum())
            success = not (group_table.loc[group.item()].to_numpy() < 0).all()
            result.append([group.item(), fluence, success])
        if not result:
            # Keep a stable 2-D shape when there is no group to report.
            return torch.empty((0, 3))
        return torch.tensor(result)

    def cluster_accuracy(self):
        """
        Fraction of clusters that map to a group.

        Returns
        -------
        float
            True positives divided by the number of clusters, or NaN when
            no cluster was found.
        """
        assert self.labels is not None
        n_clusters = len(self._cluster_labels())
        if n_clusters == 0:
            return float("nan")
        return self.num_true_positives() / n_clusters

    def num_true_positives(self):
        """
        Count the clusters associated with a non-negative group.

        Returns
        -------
        int
            The number of true positives.
        """
        assert self.labels is not None
        return (self.l2g(self._cluster_labels()) >= 0).sum()

    def num_false_positives(self):
        """
        Count the clusters not associated with any group.

        Returns
        -------
        int
            The number of false positives.
        """
        assert self.labels is not None
        return (self.l2g(self._cluster_labels()) < 0).sum()

    def label_to_group_table(self):
        """
        Build the correspondence table between cluster labels and groups.

        A cluster is assigned a group only when exactly one non-negative
        group occurs among its nodes; otherwise its group is -1. Groups
        that no cluster was assigned to are appended with label -1.

        Returns
        -------
        pd.DataFrame
            Table with columns "Label" and "Group".
        """
        assert self.labels is not None
        rows = []
        for label in self._cluster_labels():
            mask = self.labels == label
            masked_groups = self.dataset.groups[mask]
            groups = np.unique(masked_groups[masked_groups >= 0])
            group = groups[0] if len(groups) == 1 else -1
            rows.append({"Label": label, "Group": group})
        # Explicit columns keep the table well-formed when no cluster exists.
        result = pd.DataFrame(rows, columns=["Label", "Group"])

        assigned_groups = np.unique([row["Group"] for row in rows])
        all_groups = np.asarray(self.dataset.list_groups())
        unlabeled_groups = np.setdiff1d(all_groups, assigned_groups, assume_unique=True)
        if len(unlabeled_groups) > 0:
            leftovers = pd.DataFrame({"Label": np.full_like(unlabeled_groups, -1),
                                      "Group": unlabeled_groups})
            if rows:
                result = pd.concat([result, leftovers], ignore_index=True)
            else:
                result = leftovers
        return result

    def l2g(self, label):
        """
        Look up the group(s) associated with the given cluster label(s).

        Parameters
        ----------
        label : scalar or array-like
            Label(s) to look up in the "Label" column.

        Returns
        -------
        scalar or pd.Series
            The matching "Group" value(s).
        """
        assert self.l2g_table is not None
        return self.l2g_table.set_index("Label").loc[label, "Group"]

    def g2l(self, group):
        """
        Look up the cluster label(s) associated with the given group(s).

        Parameters
        ----------
        group : scalar or array-like
            Group(s) to look up in the "Group" column.

        Returns
        -------
        scalar or pd.Series
            The matching "Label" value(s).
        """
        assert self.l2g_table is not None
        return self.l2g_table.set_index("Group").loc[group, "Label"]


In [None]:
def find_common_parts(strings):
    """
    Find the characters shared at each position among a list of strings.

    The strings are compared position by position; every position at which
    all strings carry the same character contributes that character to the
    result. Positions beyond the length of the shortest string are ignored.

    Parameters
    ----------
    strings : List[str]
        The strings to compare.

    Returns
    -------
    str
        The characters common to all strings at the same position, in
        order. An empty string is returned when the input list is empty.

    Examples
    --------
    >>> find_common_parts(["apple", "apricot", "apartment"])
    'ap'
    >>> find_common_parts(["cat", "bat", "rat"])
    'at'
    >>> find_common_parts(["apple", "apricot", "banana"])
    ''
    """
    if not strings:
        return ""

    # zip stops at the shortest string, so only shared positions are visited.
    return "".join(
        column[0] for column in zip(*strings) if len(set(column)) == 1
    )


In [None]:
def single_file_scoring(filenames, keys, k, rej_threshold, max_threshold, filters, min_samples, withsim=True, max_time_interval=np.inf):
    """
    Run the full detection pipeline on a set of files and score the outcome.

    The events are loaded, connected into a k-nearest-neighbour graph,
    elaborated with ``SimpleMessage`` until the selection stabilises, and
    clustered with DBSCAN (``eps`` is the longest graph edge). Clusters
    lasting longer than ``max_time_interval`` are discarded.

    Parameters
    ----------
    filenames : str or List[str]
        File name(s) to process.
    keys : List[str]
        Coordinate column names.
        NOTE(review): the per-cluster statistics read column 0 as X,
        column 1 as time and column 2 as Y - confirm callers pass the keys
        in that order.
    k : int
        Number of neighbours for the graph.
    rej_threshold : float
        Stopping criterion forwarded to ``auto_forward``.
    max_threshold : float
        Upper bound for the selection threshold.
    filters : dict
        Filters forwarded to the dataset.
    min_samples : int
        ``min_samples`` for DBSCAN.
    withsim : bool, optional
        Whether to keep simulated events, by default True.
    max_time_interval : float, optional
        Longest accepted cluster duration, by default infinity.

    Returns
    -------
    results : pd.DataFrame
        One row per simulated group with its fluence, detection flag and
        run-level statistics.
    groups : pd.DataFrame
        One row per accepted cluster with its summary statistics.
    """
    dataset = MPTDDataset(filenames, keys, filters, withsim=withsim)
    transformer = ttr.KNNGraph(k=k, force_undirected=True)
    model = SimpleMessage()
    elaborator = MPTDElaborator(dataset, transformer, keys, model)

    eps = elaborator.distances().max().item()
    dbscan = DBSCAN(eps=eps, min_samples=min_samples)
    clusterer = MPTDClusterer(algorithm=dbscan, max_threshold=max_threshold)
    scorer = MPTDScorer(elaborator)

    start_time = time.time()

    scorer.elaborator.auto_forward(rej_threshold, max_threshold)
    scorer.predict_labels(clusterer)

    positions = dataset.data.pos
    event_groups = dataset.groups
    rows = []
    for label in np.unique(scorer.labels):
        if label < 0:
            continue
        mask = scorer.labels == label
        group_times = positions[mask, 1].cpu()
        time_interval = (group_times.max() - group_times.min()).item()
        if time_interval > max_time_interval:
            # Too long to be accepted: demote the cluster to "unclustered".
            # NOTE(review): scorer.l2g_table is not rebuilt afterwards, so
            # fluence_vs_success still sees the rejected cluster - confirm
            # whether that is intended.
            scorer.labels[mask] = -1
            continue
        group_xs = positions[mask, 0].cpu()
        group_ys = positions[mask, 2].cpu()
        det_grp = scorer.l2g(label)
        detected = det_grp >= 0
        det_grp_fl = int((event_groups[mask] == det_grp).sum()) if detected else 0
        rows.append({"Mean Time": group_times.mean().item(),
                     "Duration": time_interval,
                     "Mean X": group_xs.mean().item(),
                     "Std X": group_xs.std().item(),
                     "Mean Y": group_ys.mean().item(),
                     "Std Y": group_ys.std().item(),
                     "Counts": len(group_times),
                     "File": str(filenames),
                     "Group ID": label,
                     "Rejection": rej_threshold,
                     "Iteration": scorer.elaborator.iterations,
                     "Detects": detected,
                     "Detected Group": det_grp,
                     "Detected Group Fluence": det_grp_fl,
                     "Original Group Fluence": len(dataset.get_group(det_grp)) if detected else 0,
                     })
    groups = pd.DataFrame(rows)

    fluence = scorer.fluence_vs_success().float()
    if fluence.numel() == 0:
        # No simulated group: keep the three expected columns.
        fluence = fluence.new_empty((0, 3))
    # Convert to NumPy so the cells hold plain numbers, not tensor objects.
    results = pd.DataFrame(fluence.numpy(), columns=["ID", "Fluence", "Detected"])
    # Same textual form as the "File" column of `groups`; also valid when
    # `filenames` is a list.
    results["File"] = str(filenames)
    results["Accuracy"] = scorer.cluster_accuracy()
    results["Rejection"] = rej_threshold
    results["Iteration"] = scorer.elaborator.iterations
    results["True Pos"] = scorer.num_true_positives()
    results["False Pos"] = scorer.num_false_positives()
    results["Time"] = time.time() - start_time

    return results, groups


In [None]:
def objective(trial, keys, ks, rej_thresholds, max_thresholds, filters, min_samples_range, withsim, max_time_intervals):
    k = trial.suggest_int("k", *ks)
    rej_threshold = trial.suggest_int("rej_threshold", *rej_threshols)
    min_samples = trial.suggest_int("min_samples", *min_samples_range)