# Jupyter notebook to visualize our results after training PyTorch model

Please run this notebook from the notebooks directory after running `python src/visualizer.py` from the root folder of this project.

Some useful links:
- https://brandonrozek.com/blog/jupyterwithpyenv/

- https://github.com/microsoft/vscode-jupyter/wiki/Setting-Up-Run-by-Line-and-Debugging-for-Notebooks

### Imports

In [1]:
!python --version

Python 3.10.6


In [2]:
import sys

sys.path.append('../')
sys.path.append('../src')

%load_ext autoreload
%autoreload 2

In [3]:
import gc
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import re
import torch
import warnings

from hydra import initialize, compose
from ipywidgets import interact, FloatSlider, IntSlider
from pathlib import Path
from src import *
from src.trainer import Trainer
from src.visualizer import plot_val_residual, plot_val_pull



In [4]:
%matplotlib widget

### Load trained model & more

In [5]:
with initialize(version_base=None, config_path="../config"):
    cfg = compose(config_name='trainer.yaml')
    cfg.wandb.mode = "disabled"

    root_filename = "../" + cfg.dataset.filename
    
    # The following assumes that the user already
    # ran ../src/visualizer.py prior to this.
    # Otherwise, we need to ensure that the way we create the dataset
    # did not change since last time the .pkl file was created
    p = Path(root_filename)
    filename =  f"{str(p.parent)}/{p.stem}_dataset.pkl"
    if Path(filename).is_file():  # if exists and is a file
        cfg.dataset.filename = filename
    else:
        cfg.dataset.filename = root_filename
        cfg.dataset.save_format = "pkl"  # to save dataset
    
    trainer = Trainer(cfg)

    # Loading checkpoint
    general_checkpoint = torch.load("../checkpoints/last_general_checkpoint.pth")
    trainer.model.load_state_dict(general_checkpoint["model_state_dict"])
    trainer.optimizer.load_state_dict(general_checkpoint["optimizer_state_dict"])

    trainer.epoch = general_checkpoint["epoch"]
    trainer.train_loss = general_checkpoint["train_loss"]
    trainer.val_loss = general_checkpoint["val_loss"]
    
    trainer.model.eval()
    torch.set_grad_enabled(False)
    # it helps with memory-related issues:
    # https://stackoverflow.com/questions/69007342/disable-grad-and-backward-globally
    # https://discuss.pytorch.org/t/how-to-delete-a-tensor-in-gpu-to-free-up-memory/48879/15

[array([      0,       1,       2, ..., 3930742, 3930743, 3930744]), array([    250,     251,     506, ..., 3930363, 3930618, 3930619]), array([    252,     253,     508, ..., 3930365, 3930620, 3930621])]


Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


In [6]:
trainer.dataset_full.data_df

Unnamed: 0,unix_time,glat,glon,altitude,temperature,fe_cosmic,corrected,correrr,config,raz,...,1/rate_err[6]**2,1/rate_err[10]**2,1/rate_err[12]**2,rate[0]/rate_err[0],rate[1]/rate_err[1],rate[5]/rate_err[5],rate[6]/rate_err[6],rate[10]/rate_err[10],rate[12]/rate_err[12],(unix_time-1474004181.5460)//5535.4+10
0,1.483525e+09,41.273513,64.930544,376.565432,33.500000,1177.0,1111.225454,0.0,42,335.977594,...,0.000364,0.000260,0.000270,28.621580,22.109620,54.603536,49.826243,58.944945,57.878622,1730.0
1,1.483525e+09,41.301336,65.091966,376.553905,33.500000,1185.0,1077.260702,0.0,42,336.125659,...,0.000348,0.000256,0.000257,27.929632,21.788451,54.791671,50.310019,58.843234,58.802215,1730.0
2,1.483525e+09,41.301336,65.091966,376.553905,33.500000,1185.0,1067.501924,0.0,42,336.125659,...,0.000348,0.000252,0.000253,27.636904,21.803819,54.403405,50.046106,58.948036,58.898361,1730.0
3,1.483525e+09,41.328899,65.253475,376.536191,33.500000,2154.0,1108.812904,0.0,42,336.273850,...,0.000343,0.000246,0.000247,28.166676,21.966212,55.363248,50.761011,60.010819,59.940325,1730.0
4,1.483525e+09,41.328899,65.253475,376.536191,33.500000,2154.0,1110.524550,0.0,42,336.273850,...,0.000349,0.000251,0.000252,28.515913,22.481051,55.182131,50.641369,59.763522,59.708677,1730.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
3930740,1.489745e+09,-42.757777,126.189354,386.590141,34.400002,24460.0,1946.386405,0.0,42,91.466481,...,0.000197,0.000132,0.000132,39.073977,29.911732,73.389050,66.531368,80.931946,80.931946,2853.0
3930741,1.489745e+09,-42.758912,126.358379,386.578593,34.400002,24778.0,2020.254206,0.0,42,91.614903,...,0.000191,0.000130,0.000130,40.109691,31.212337,74.295432,67.827269,81.857269,81.857269,2853.0
3930742,1.489745e+09,-42.758912,126.358379,386.578593,34.400002,24778.0,1960.692600,0.0,42,91.614903,...,0.000197,0.000133,0.000133,39.401131,30.389685,73.269094,66.799552,81.149886,81.139297,2853.0
3930743,1.489745e+09,-42.759779,126.527354,386.562382,34.400002,24223.0,2001.103329,0.0,42,91.763169,...,0.000202,0.000133,0.000133,39.691491,30.651826,72.450085,65.728584,80.722957,80.711165,2853.0


### Recreate dataset that includes the GRBs we've removed + Apply model

In [7]:
# init trainer with a dataset that doesn't filter out the
# GRBs -> we only use it to obtain the dataset with GRBs
# may take some time as it recreates the dataset
cfg.dataset.filename = root_filename
cfg.dataset.save_format = None
cfg.dataset.filter_conditions = ["rate[0]/rate_err[0] > 20"]
trainer_with_GRBs = Trainer(cfg)
dataset_full_GRBs = trainer_with_GRBs.dataset_full



[array([      0,       1,       2, ..., 3935515, 3935516, 3935517]), array([    250,     251,     506, ..., 3935227, 3935482, 3935483]), array([    252,     253,     508, ..., 3935229, 3935484, 3935485])]


Apply our model on that dataset (do predictions)

In [8]:
from torch.utils.data import Subset
## Prediction on full dataset with GRBs (e.g rate[0])
# Need to transform before inputting the whole set into the model
X = dataset_full_GRBs.X_cpu
dataset_tensor = dataset_full_GRBs.transform(X).to(device=cfg.common.device)

# Apply the model trained without GRBs on whole dataset including GRBs.
pred = trainer.model(dataset_tensor).detach().to("cpu")

# Remove unused tensors that are on GPU
dataset_tensor = dataset_tensor.to(device="cpu")
del dataset_tensor
torch.cuda.empty_cache()
gc.collect()

# Create a PyTorch Subset with all data
data_df = dataset_full_GRBs.data_df
subset_dataset_full_GRBs = Subset(dataset_full_GRBs,
                                  indices=range(dataset_full_GRBs.n_examples))
# Note: We do it because our functions require PyTorch Subsets as input

### Show residuals, pulls or normalized pulls

In [9]:
plt.close("all")

# Create plots but don't show them now, we will show them later
# via our interactive plots.
with plt.ioff():
    rate_err_names = []
    gcfs = []
    target_id_dict = {target_name: i for i, target_name in enumerate(cfg.dataset.target_names)}
    for target_name, i in target_id_dict.items():
        # Depending on the target_name, we plot different things,
        # Either the residual: target-prediction or
        # what we call normalized pull: pull=(target-prediction)/err_rate,
        #                               normalized_pull=pull/std
        # In both cases, we also plot their normalized histograms ('density').
        if target_name not in [f"rate[{i}]" for i in range(13)]:
            fig, fig2 = plot_val_residual(subset_dataset_full_GRBs, pred, target_name=target_name,
                                          save_path=None, save_path_hist=None)
        else:
            # rate[i] != rate[j] in general except if target_names included all the rates.
            j = re.findall("[0-9]+", target_name)[0]
            rate_err_names.append(f"rate_err[{j}]")
            fig, fig2 = plot_val_pull(subset_dataset_full_GRBs, pred, target_name=target_name,
                                      rate_err_name=rate_err_names[-1], save_path=None, save_path_hist=None,
                                      normalized=True)
        gcfs += [fig, fig2]
        # Note: Even though we called the function that has as name "plot_val_residual"
        # we apply it to the whole dataset including GRBs. Same thought process
        # for "plot_val_pull".

In [10]:
mpl.rcParams['agg.path.chunksize'] = 10000

Interactive plots using ipywidgets. See these links:
-  https://stackoverflow.com/questions/72271574/interactive-plot-of-dataframe-by-index-with-ipywidgets
- https://ipywidgets.readthedocs.io/en/stable/examples/Using%20Interact.html

In [11]:
@interact(target_name=list(cfg.dataset.target_names))
def residual_plot(target_name: str) -> mpl.figure.Figure:
    fig = gcfs[2*target_id_dict[target_name]]
    fig.set_figwidth(10)
    fig.set_figheight(5)
    fig.canvas.header_visible = False
    return fig

interactive(children=(Dropdown(description='target_name', options=('rate[0]', 'rate[1]', 'rate[5]', 'rate[6]',…

In [12]:
@interact(target_name=list(cfg.dataset.target_names))
def residual_hist(target_name: str) -> mpl.figure.Figure:
    fig = gcfs[2*target_id_dict[target_name]+1]
    fig.set_figwidth(10)
    fig.set_figheight(5)
    fig.canvas.header_visible = False
    return fig

interactive(children=(Dropdown(description='target_name', options=('rate[0]', 'rate[1]', 'rate[5]', 'rate[6]',…

We used `residual_plot` and `residual_hist` but what we actually plot are either the residuals, the pulls or normalized pulls.

### Data points of interest (in red)

In [13]:
from src.visualizer import get_all_time_y_y_hat, find_moments, get_columns

# We do the following instead of picking the target
# from data_df because of what was used in our model
# The model used PyTorch float tensors (single precision)
# Although the sorting isn't necessary in this particular case,
# our function also converts PyTorch tensors into NumPy arrays
sorted_time, sorted_y, sorted_y_hat = get_all_time_y_y_hat(subset_dataset_full_GRBs, pred)

# Target - prediction
residuals = sorted_y-sorted_y_hat
# We use the variable "var" to either talk about residuals or pulls
# By default, it's "residual"
var = residuals
var_name = "residual"

# Are our target_names of the form rate[i] ?
target_of_rate_form = np.isin(cfg.dataset.target_names, [f"rate[{i}]" for i in range(13)])

if ~np.any(target_of_rate_form):
    print(f"Thresholding using residuals", end="")
elif np.all(target_of_rate_form):
    rate_errs = get_columns(subset_dataset_full_GRBs, rate_err_names)
    pulls = residuals/rate_errs
    var = pulls
    var_name = "pull"
    print("Thresholding using residuals/rate_errs (pull)", end="")
else:
    raise NotImplementedError("Did not implement the case in which the targets are a mix of rate[i] and other")

# Modified gaussian fit
new_mean, new_std = list(zip(*[find_moments(var[:, j]) for j in range(var.shape[1])]))
new_mean, new_std = np.array(new_mean), np.array(new_std)

k = 5
print(", N° points of interest:\n", np.sum(var > k*new_std, axis=0))

Thresholding using residuals/rate_errs (pull), N° points of interest:
 [ 5139  5956 21411 22027 15531 15673]


In [14]:
new_std

array([1.3375583, 1.196147 , 1.3169341, 1.2481014, 1.5024753, 1.4970227],
      dtype=float32)

In [15]:
@interact(target_name=list(cfg.dataset.target_names),
          k=FloatSlider(value=k, min=1, max=7, step=0.5, continuous_update=False))
def points_of_interest(target_name: str, k: float) -> None:
    warnings.filterwarnings("ignore")
    id = target_id_dict[target_name]
    mask = var[:, id] > k*new_std[id]
    
    fig = plt.figure()
    fig.canvas.header_visible = False
    # Whole dataset in blue
    plt.plot(data_df["unix_time"], data_df[target_name],
         linewidth=0.05, label="whole data")
    # Points of interest in red
    plt.scatter(data_df[mask]["unix_time"],
            data_df[mask][target_name],
            color='r', s=0.1,
            label=fr"data s.t {var_name} > ${k:.1f}\sigma$")
    plt.title(f"{mask.sum()}"+\
                f" data points such that {var_name} "+\
                fr"> ${k:.1f}\sigma$, $\sigma\approx{new_std[id]:.4f}$")
    plt.legend()
    plt.show()
    warnings.filterwarnings("default")

interactive(children=(Dropdown(description='target_name', options=('rate[0]', 'rate[1]', 'rate[5]', 'rate[6]',…

Remark: it's maybe a better idea to use `sorted_y`.

#### Comparison with known GRBs

In [16]:
# Import 55 GRBs extracted from:
# https://www.researchgate.net/publication/326811280_Overview_of_the_GRB_observation_by_POLAR
GRBs = pd.read_csv("../data/GRBs.csv")
# Show 5 GRBs
GRBs.head()

Unnamed: 0,Number,GRB_Name,Trigger_time_UTC,unix_time
0,1,GRB_160924A,2016-09-24T06:04:09.040,1474697000.0
1,2,GRB_160928A,2016-09-28T19:48:05.000,1475092000.0
2,3,GRB_161009651,2016-10-09T15:38:07.190,1476027000.0
3,4,GRB_161011217,2016-10-11T05:13:44.420,1476163000.0
4,5,GRB_161012989,2016-10-12T23:45:11.380,1476316000.0


Let's only focus on GRBs happening in the same time range as our data

In [17]:
GRB_mask = data_df["unix_time"].min() <= GRBs["unix_time"].values
GRB_mask &= GRBs["unix_time"].values <= data_df["unix_time"].max()
print(f"Out of 55, there are: {GRB_mask.sum()} GRBs within our time range")

Out of 55, there are: 25 GRBs within our time range


Zooming into the GRBs with a window of `w` seconds before and `w` seconds after (`w=100` by default)

In [18]:
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt

# blue_line = mlines.Line2D([], [], color='blue', marker='*',
#                           markersize=15, label='Blue stars')

In [19]:
@interact(target_name=list(cfg.dataset.target_names),
          k=FloatSlider(value=k, min=1, max=7, step=0.5, continuous_update=False),
          w=IntSlider(value=100, min=2, max=1000, step=1, continuous_update=False))
def show_grbs(target_name: str, k: float, w: int) -> None:
     # Show 25 GRBs as well as our predictions and the points of interest
     warnings.filterwarnings("ignore")
     id = target_id_dict[target_name]

     # Labels used in the legend
     labels = ["original", "predicted", "GRB trigger time",
               fr"data s.t {var_name} > ${k:.1f}\sigma$"]

     with plt.ioff():
          fig, axs = plt.subplots(5, 5, figsize=(6, 8), sharey=True)
     fig.canvas.header_visible = False
     mpl.rcParams.update({'font.size': 5})  # 10 is default
     
     for i, GRB_name, GRB_tunix in zip(range(GRB_mask.sum()),
                         GRBs["GRB_Name"][GRB_mask],
                         GRBs["unix_time"][GRB_mask]):
          # Mask for GRB interval
          m = data_df["unix_time"] >= GRB_tunix-w
          m &= data_df["unix_time"] <= GRB_tunix+w
          m = m.values

          # Mask for GRB interval & red points
          mvar = m & (var[:, id] > k*new_std[id])
          
          if m.sum() > 0:
               t = data_df[m]["unix_time"]
               y = data_df[m][target_name]
               y_hat = pred[:, id][m]

               # Original
               h1, = axs[i//5, i%5].plot(t, y, 'blue', linewidth=0.4)
               # Prediction
               h2, = axs[i//5, i%5].plot(t, y_hat, 'black', linewidth=0.4)
               # GRB trigger time
               h3 = axs[i//5, i%5].vlines(GRB_tunix, 0, y.max(), 'g', linestyle="--")
               axs[i//5, i%5].set_title(GRB_name)

          if mvar.sum() > 0:
               t = data_df[mvar]["unix_time"]
               y = data_df[mvar][target_name]

               # Points of interest in red
               axs[i//5, i%5].scatter(t, y, color='r', alpha=0.25)
          
          axs[i//5, i%5].set_box_aspect(1)
          axs[i//5, i%5].tick_params(axis='both', which='major', labelsize=5)
          axs[i//5, i%5].tick_params(axis='both', which='minor', labelsize=4)
          axs[i//5, i%5].xaxis.offsetText.set_fontsize(5)
     
     h4 = mpatches.Circle([], [], color='r', alpha=0.25)
     plt.legend([h1, h2, h3, h4], labels, bbox_to_anchor=(1, -0.3), loc="upper right")
     plt.tight_layout()
     plt.show()
     mpl.rcParams.update({'font.size': 10})  # 10 is default
     warnings.filterwarnings("default")

interactive(children=(Dropdown(description='target_name', options=('rate[0]', 'rate[1]', 'rate[5]', 'rate[6]',…

Looking at the **normalized** (division by standard deviation) residuals or pulls

In [20]:
@interact(target_name=list(cfg.dataset.target_names),
          k=FloatSlider(value=k, min=1, max=7, step=0.5, continuous_update=False),
          w=IntSlider(value=100, min=2, max=1000, step=1, continuous_update=False))
def show_var(target_name: str, k: float, w: int) -> None:
     warnings.filterwarnings("ignore")
     id = target_id_dict[target_name]

     # Labels used in the legend
     labels = [f"{var_name}", "GRB trigger time", f"threshold {k}",
               fr"data s.t {var_name} > ${k:.1f}\sigma$"]

     with plt.ioff():
          fig, axs = plt.subplots(5, 5, figsize=(6, 8), sharey=True)
     fig.canvas.header_visible = False
     mpl.rcParams.update({'font.size': 5})  # 10 is default
     
     for i, GRB_name, GRB_tunix in zip(range(GRB_mask.sum()),
                         GRBs["GRB_Name"][GRB_mask],
                         GRBs["unix_time"][GRB_mask]):
          # Mask for GRB interval
          m = data_df["unix_time"] >= GRB_tunix-w
          m &= data_df["unix_time"] <= GRB_tunix+w
          m = m.values

          # Mask for GRB interval & red points
          mvar = m & (var[:, id] > k*new_std[id])
          
          if m.sum() > 0:
               t = data_df[m]["unix_time"]
               v_norm = var[:, id][m]/new_std[id]

               # Normalized residuals or pulls
               h1, = axs[i//5, i%5].plot(t, v_norm, 'k', linewidth=0.4)
               # GRB trigger time
               h2 = axs[i//5, i%5].vlines(GRB_tunix, v_norm.min(), v_norm.max(), 'g', linestyle="--")
               # 'Threshold line'
               h3, = axs[i//5, i%5].plot(t, k*np.ones_like(t),
                                   linewidth=0.2, color='gray', alpha=0.5,
                                   linestyle="--", zorder=0)
               axs[i//5, i%5].set_title(GRB_name)
               
          if mvar.sum() > 0:
               t = data_df[mvar]["unix_time"]
               v_norm = var[:, id][mvar]/new_std[id]

               # Points of interest in red
               axs[i//5, i%5].scatter(t, v_norm, color='r', alpha=0.25)
          
          axs[i//5, i%5].set_box_aspect(1)
          axs[i//5, i%5].tick_params(axis='both', which='major', labelsize=5)
          axs[i//5, i%5].tick_params(axis='both', which='minor', labelsize=4)
          axs[i//5, i%5].xaxis.offsetText.set_fontsize(5)
     
     h4 = mpatches.Circle([], [], color='r', alpha=0.25)
     plt.legend([h1, h2, h3, h4], labels, bbox_to_anchor=(1, -0.3), loc="upper right")
     plt.tight_layout()
     plt.show()
     mpl.rcParams.update({'font.size': 10})  # 10 is default
     warnings.filterwarnings("default")

interactive(children=(Dropdown(description='target_name', options=('rate[0]', 'rate[1]', 'rate[5]', 'rate[6]',…

### Clusters

In [23]:
import mpl_interactions.ipyplot as iplt
from mpl_interactions import indexer
from mpl_interactions.widgets import scatter_selector_index

In [250]:
def get_clusters(target_name, id, k, pred_below=True, discard_w=None):
    t = data_df["unix_time"].values
    t_flip = np.flip(t)

    ## Red points of interest. Id is used to select a particular output
    if pred_below:
        mask = var[:, id] > k*new_std[id]
    else:
        mask = var[:, id] < -k*new_std[id]

    # Time difference for masked data
    t_mask = data_df[mask]["unix_time"].values
    idx_mask = np.flatnonzero(mask)  # index from original timeline
    delta = (t_mask[1:]-t_mask[:-1]).astype(int)
    delta = np.concatenate([np.array([np.inf]), delta])
    
    # Time difference for masked flipped data
    t_flip_mask = np.flip(t_mask)
    idx_flip_mask = np.flatnonzero(np.flip(mask))  # index from flipped timeline
    delta_flip = (t_flip_mask[1:]-t_flip_mask[:-1]).astype(int)
    delta_flip = np.concatenate([np.array([-np.inf]), delta_flip])

    ## Starting points of clusters
    # Indices of the start points according to original timeline
    old_idx_points_starts = idx_mask[delta > 1]
    if discard_w:
        # Filter: if discard_w data points before are still within discard_w seconds before, keep the cluster
        t_sec_before_start = t[old_idx_points_starts]-discard_w
        t_data_before_start = np.concatenate([np.array([t[0]-i for i in range(discard_w, 0, -1)]), t])[old_idx_points_starts]
        idx_points_starts = np.intersect1d(old_idx_points_starts[t_data_before_start.astype(int) >= t_sec_before_start.astype(int)],\
                                        old_idx_points_starts)
        cluster_idxs_kept = np.flatnonzero(np.isin(old_idx_points_starts, idx_points_starts))
    else:
        idx_points_starts = old_idx_points_starts

    ## Ending points of clusters
    # Indices of the end points according to original timeline
    old_idx_points_ends = np.flip((t.size - 1) - idx_flip_mask[delta_flip < -1])
    if discard_w:
        # Keep the ends if the starts were kept
        old_idx_points_ends = old_idx_points_ends[cluster_idxs_kept]

        # Do something similar to "starts" but on flipped timeline.
        # Filter: if discard_w data points after are still within discard_w seconds after, keep the cluster
        t_sec_after_end = t_flip[old_idx_points_ends]+discard_w
        t_data_after_end = np.concatenate([np.array([t[0]+i for i in range(discard_w, 0, -1)]), t_flip])[old_idx_points_ends]
        idx_points_ends = np.intersect1d(old_idx_points_ends[t_data_after_end.astype(int) <= t_sec_after_end.astype(int)],\
                                        old_idx_points_ends)
        cluster_idxs_kept = np.flatnonzero(np.isin(old_idx_points_ends, idx_points_ends))
    
        # Keep the starts if the ends were kept
        idx_points_starts = idx_points_starts[cluster_idxs_kept]
    else:
        idx_points_ends = old_idx_points_ends
         
    # Get the start and end times !
    starts = data_df["unix_time"].iloc[idx_points_starts]
    ends = data_df["unix_time"].iloc[idx_points_ends]
    
    # Recreate mask because of removed clusters
    new_mask = np.zeros_like(mask, dtype=bool)
    for begin, end in zip(idx_points_starts, idx_points_ends):
        new_mask[begin:end+1] = mask[begin:end+1]
    mask = new_mask

    # Recreate times because of removed clusters
    times = data_df[mask]["unix_time"]
    times = times.reset_index(drop=True)

    # Adding a label to each group/cluster
    # idx_starts = starts.index.values
    idx_starts = idx_points_starts
    idx_ends = idx_points_ends
    groups = np.repeat(np.arange(len(idx_starts)), idx_ends - idx_starts + 1)
    
    # XXX: only for the red points of interest !
    data = data_df[mask]
    data["group"] = groups
    data["target_id"] = id

    # Two properties of clusters
    integrals = data.groupby("group")[target_name].sum().values
    lengths = ends.values-starts.values
    data.drop(columns=["group"], inplace=True)
    
    return mask, data, times, starts, ends, groups, integrals, lengths

In [251]:
def get_cluster_data(t_points, y_points, v_points,
                     starts, ends, cluster_idx):
    start, end = starts[cluster_idx], ends[cluster_idx]
    m_points = (t_points>=start) & (t_points<=end)
    return (t_points[m_points],
            y_points[m_points],
            v_points[m_points])

In [252]:
@interact(target_name=list(cfg.dataset.target_names),
          k=FloatSlider(value=5, min=1, max=7, step=0.5, continuous_update=False),
          w=IntSlider(value=100, min=0, max=int(7e6), step=1, continuous_update=False),
          pred_below=IntSlider(value=1, min=0, max=1, step=1, continuous_update=False),
          discard_w=IntSlider(value=0, min=0, max=500, step=1, continuous_update=False))
def plot_clusters(target_name: str, k: float, w: int, pred_below: bool, discard_w: int) -> None:
     warnings.filterwarnings("ignore")
     id = target_id_dict[target_name]

     out = get_clusters(target_name, id, k,
                        pred_below=bool(pred_below), discard_w=discard_w)
     m_points, data, times, starts, ends, groups, integrals, lengths = out
     starts = starts.values.reshape(-1, 1)
     ends = ends.values.reshape(-1, 1)
     
     t = data_df["unix_time"].values.reshape(-1, 1)
     y = data_df[target_name].values.reshape(-1, 1)
     y_hat = pred[:, id].numpy().reshape(-1, 1)
     v = var[:, id].reshape(-1, 1)
     t_points = times.values.reshape(-1, 1)
     y_points = data[target_name].values.reshape(-1, 1)
     v_points = var[m_points, id].reshape(-1, 1)
     
     all_clusters_data = [get_cluster_data(t_points, y_points, v_points,
                                           starts, ends, idx) for idx in range(len(starts))]
     def get_cluster_times(idx):
          m = (t>=starts[idx]-w) & (t<=ends[idx]+w)
          min_y, max_y, s_y = y[m].min(), y[m].max(), y[m].std()
          min_v, max_v, s_v = v[m].min(), v[m].max(), v[m].std()

          # Also updates the axis limits
          axs[0, 1].set_xlim([starts[idx]-w-t[0], ends[idx]+w-t[0]])
          axs[0, 1].set_ylim([min_y-s_y, max_y+s_y])
          axs[1, 1].set_xlim([starts[idx]-w-t[0], ends[idx]+w-t[0]])
          axs[1, 1].set_ylim([(min_v-s_v)/new_std[id], (max_v+s_v)/new_std[id]])
          
          return all_clusters_data[idx][0]-t[0]

     def get_cluster_targets(x, idx):
          return all_clusters_data[idx][1]

     def get_cluster_normalized_vars(x, idx):
          return all_clusters_data[idx][2]/new_std[id]

     def get_cluster_length(idx):
          return lengths[idx].reshape(-1, 1)
          
     def get_cluster_integral(x, idx):
          return integrals[idx].reshape(-1, 1)

     # Let's plot
     with plt.ioff():
          fig, axs = plt.subplots(2, 2, figsize=(6, 6))  # fig = plt.figure()
     fig.canvas.header_visible = False
     mpl.rcParams.update({'font.size': 5})  # 10 is default

     # https://mpl-interactions.readthedocs.io/en/stable/examples/scatter-selector.html
     fig.canvas.header_visible = False

     # Original
     axs[0, 1].plot(t-t[0], y, linewidth=1, color="b")
     # Prediction
     axs[0, 1].plot(t-t[0], y_hat, linewidth=1, color="k")
     # Normalized Residuals or Pulls (division by std)
     axs[1, 1].plot(t-t[0], v/new_std[id], linewidth=1, color="k")
     # GRB trigger times
     for GRB_tunix in GRBs["unix_time"][GRB_mask]:
          axs[0, 1].vlines(GRB_tunix-t[0], y.min(), y.max(), 'g', linestyle="--")
          # 'Threshold line'
          axs[1, 1].plot(t-t[0], (2*pred_below-1)*k*np.ones_like(t), linewidth=2, color='gray', alpha=0.5,
                         linestyle="--", zorder=0)

     ## Interactive part:
     # Scatter plot of clusters using as coordinates (length, integral)
     index = scatter_selector_index(axs[0, 0], lengths, integrals, color='b', s=5)
     # Points of interest in red
     controls = iplt.scatter(get_cluster_times, get_cluster_targets, idx=index,
                    color='r', alpha=0.25, xlim="auto", ylim="auto", ax=axs[0, 1])
     
     with controls:
          # See: https://mpl-interactions.readthedocs.io/en/stable/examples/plot.html#styling-of-plot
          # Interactive title
          iplt.title(indexer([f"Cluster n°{i}:\n"+\
                              f"begin={(starts[i]-t[0]).item()}+t[0],\n"+\
                              f"end={(ends[i]-t[0]).item()}+t[0],\nt[0]={t[0].item()}\n"+\
                              fr"length={lengths[i]}, integral$\approx${integrals[i]:.5f}" for i in range(len(starts))]),
                         fontsize=5, ax=axs[0, 1])
          # Points of interest in red
          iplt.scatter(get_cluster_times, get_cluster_normalized_vars,
                         color='r', alpha=0.25, xlim="auto", ylim="auto", ax=axs[1, 1])

          # Show selected cluster "point"
          iplt.scatter(get_cluster_length, get_cluster_integral, 
                         color="k", s=5, ax=axs[0, 0])

     # Common settings for all axes
     for ax in [axs[i, j] for i in range(2) for j in range(2)]:
          ax.tick_params(axis='both', which='major', labelsize=5)
          ax.tick_params(axis='both', which='minor', labelsize=4)
          ax.xaxis.offsetText.set_fontsize(5)
          ax.yaxis.offsetText.set_fontsize(5)
          ax.grid("on")

     # Particular settings for different axes
     axs[0, 0].set_xlabel("Length [s]", fontsize=5)
     axs[0, 0].set_ylabel("Integral [Hz]", fontsize=5)

     axs[0, 1].set_xlabel("Unix time - t[0] [s]", fontsize=5)
     axs[0, 1].set_ylabel(target_name, fontsize=5)

     axs[1, 1].set_xlabel("Unix time - t[0] [s]", fontsize=5)
     axs[1, 1].set_ylabel("Normalized " + var_name + ": "+ target_name, fontsize=5)

     # Ignore lower left plot
     axs[1, 0].axis('off')

     # 'Fixed' titles
     axs[0, 0].set_title(f"{len(starts)} clusters obtained from data having {var_name}"+\
                         fr" {'> ' if bool(pred_below) else '< -'}${k:.1f}\sigma$"+\
               fr", $\sigma\approx{new_std[id]:.4f}$")
     axs[1, 1].set_title(f"Normalized {var_name} : {target_name}")

     # Displacing the scientific notation text of y axis
     axs[0, 0].yaxis.get_offset_text().set_position((-0.05,0))
     plt.tight_layout()
     plt.show()
     mpl.rcParams.update({'font.size': 10})  # 10 is default
     warnings.filterwarnings("default")

interactive(children=(Dropdown(description='target_name', options=('rate[0]', 'rate[1]', 'rate[5]', 'rate[6]',…

Note that the red points are only shown for the particular cluster

### Intersection of clusters between different `target_name`

In [254]:
out = [get_clusters(target_name, id, 5, pred_below=1, discard_w=500) for id, target in enumerate(cfg.dataset.target_names)]

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  data["group"] = groups
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  data["target_id"] = id


In [255]:
m_points, data, times, starts, ends, groups, integrals, lengths = zip(*out)

In [290]:
a = pd.DataFrame([[0, 1, 2], [1, 2, 3], [2, 3, 4]], columns=["0", "1", "2"])
a

Unnamed: 0,0,1,2
0,0,1,2
1,1,2,3
2,2,3,4


In [301]:
from functools import reduce
# https://stackoverflow.com/questions/42940507/merging-dataframes-keeping-all-items-pandas
# https://stackoverflow.com/questions/38978214/merge-a-list-of-dataframes-to-create-one-dataframe
df = reduce(lambda df1,df2: pd.merge(df1, df2, on=list(data[0].keys())[:-1], how="outer"), data)

  df = reduce(lambda df1,df2: pd.merge(df1, df2, on=list(data[0].keys())[:-1], how="outer"), data)


In [302]:
df["# inter"] = df.drop(columns=list(data[0].keys())[:-1]).count(axis=1)

See https://numpy.org/devdocs/release/1.25.0-notes.html and the docs for more information.  (Deprecated NumPy 1.25)
  common = np.find_common_type(


In [303]:
df

Unnamed: 0,unix_time,glat,glon,altitude,temperature,fe_cosmic,corrected,correrr,config,raz,...,rate[10]/rate_err[10],rate[12]/rate_err[12],(unix_time-1474004181.5460)//5535.4+10,target_id_x,target_id_y,target_id_x.1,target_id_y.1,target_id_x.2,target_id_y.2,# inter
0,1.483590e+09,-32.258670,37.436699,387.654012,29.299999,9667.0,2222.749497,0.0,42,271.839385,...,89.412637,89.407936,1741.0,0.0,1.0,2.0,3.0,4.0,5.0,6
1,1.483590e+09,-32.064056,37.811630,387.668537,29.299999,9032.0,2149.282588,0.0,42,271.839385,...,87.523686,87.518081,1741.0,0.0,1.0,2.0,3.0,4.0,5.0,6
2,1.483600e+09,-41.076786,-72.475306,383.902241,30.900000,5304.0,4070.938945,0.0,42,271.839385,...,117.763691,117.763691,1743.0,0.0,1.0,2.0,,4.0,5.0,5
3,1.483633e+09,-36.897398,131.610246,382.268808,31.000000,14658.0,2408.535828,0.0,42,133.537122,...,89.106238,89.106238,1749.0,0.0,1.0,2.0,3.0,4.0,5.0,6
4,1.483633e+09,-36.948991,131.751618,382.276959,31.000000,15392.0,2446.716802,0.0,42,133.683225,...,89.446288,89.446288,1749.0,0.0,1.0,2.0,3.0,4.0,5.0,6
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
7525,1.485691e+09,-41.705302,81.701082,380.848342,35.799999,20550.0,2540.006361,0.0,42,26.418447,...,91.826573,91.826573,2121.0,,,,,,5.0,1
7526,1.485850e+09,-41.919246,94.423579,383.570904,36.000000,22255.0,2670.001972,0.0,42,350.586855,...,94.930722,94.930722,2150.0,,,,,,5.0,1
7527,1.485938e+09,-39.003290,65.397160,384.959610,35.599998,15201.0,2665.098570,0.0,42,331.582421,...,98.324240,98.324240,2165.0,,,,,,5.0,1
7528,1.486623e+09,-42.678619,64.184940,386.947034,35.500000,18994.0,1857.865371,0.0,42,304.996987,...,80.854961,80.848786,2289.0,,,,,,5.0,1


In [191]:
idxs = ["4", "1", "2", "3", "2", "1", "3"]
arr = np.array([[0, 1, 1.5, 2.6, 2.9, 3, 3.5],
                [6, 2, 2.5, 2.7, 4.1, 4, 4.6]])
                
# for i in range(1, arr.shape[0]):
i = 1
while i < arr.shape[1]:
    print(i, arr.shape[1], len(idxs))
    print(pd.DataFrame(arr, columns=idxs))
    if i > 15:
        break
    start_new = arr[0, i]
    end_new = arr[1, i]
    
    end_prev = arr[1, i-1]
    if start_new < end_prev:
        # starts.append(start_new)
        # ends.append(min(end_prev, end_new))
        # 10*arr[0, i-1] + arr[0, i]
        arr = np.insert(arr, i, np.array([start_new, min(end_prev, end_new)]), axis=1)
        idxs.insert(i, "".join(set(idxs[i-1]).union(set(idxs[i]))))
        # print(arr.shape[1], len(idxs))
        i += 1
        if end_new > end_prev:
            # arr = np.insert(arr, , np.array([arr[0, i], end_prev, end_new]), axis=1)
            arr = np.hstack([arr, np.array([[end_prev, end_new]]).T])
            idxs.append(idxs[i])

            argsort = np.argsort(arr[1, i:])
            arr[:, i:] = arr[:, i:][:, argsort]
            idxs[i:] = list(np.array(idxs[i:])[argsort])
            i += 1
        else:
            # arr = np.insert(arr, , np.array([arr[0, i-1], end_new, end_prev]), axis=1)
            arr = np.hstack([arr, np.array([[end_new, end_prev]]).T])
            idxs.append(idxs[i-1])
        
            argsort = np.argsort(arr[1, i:])
            arr[:, i:] = arr[:, i:][:, argsort]
            idxs[i:] = list(np.array(idxs[i:])[argsort])
            i += 1
        # ids.append((10*arr[0, i-1] + arr[0, i]))
    i += 1
    # print(arr)
    # print(i)

1 7 7
     4    1    2    3    2    1    3
0  0.0  1.0  1.5  2.6  2.9  3.0  3.5
1  6.0  2.0  2.5  2.7  4.1  4.0  4.6
4 9 9
     4   14    1    2    3    1    2    3   14
0  0.0  1.0  1.0  1.5  2.6  3.0  2.9  3.5  2.0
1  6.0  2.0  2.0  2.5  2.7  4.0  4.1  4.6  6.0
5 9 9
     4   14    1    2    3    1    2    3   14
0  0.0  1.0  1.0  1.5  2.6  3.0  2.9  3.5  2.0
1  6.0  2.0  2.0  2.5  2.7  4.0  4.1  4.6  6.0
6 9 9
     4   14    1    2    3    1    2    3   14
0  0.0  1.0  1.0  1.5  2.6  3.0  2.9  3.5  2.0
1  6.0  2.0  2.0  2.5  2.7  4.0  4.1  4.6  6.0
9 11 11
     4   14    1    2    3    1   12    2    2    3   14
0  0.0  1.0  1.0  1.5  2.6  3.0  2.9  2.9  4.0  3.5  2.0
1  6.0  2.0  2.0  2.5  2.7  4.0  4.0  4.1  4.1  4.6  6.0
12 13 13
     4   14    1    2    3    1   12    2    2   23    3    3   14
0  0.0  1.0  1.0  1.5  2.6  3.0  2.9  2.9  4.0  3.5  3.5  4.1  2.0
1  6.0  2.0  2.0  2.5  2.7  4.0  4.0  4.1  4.1  4.1  4.6  4.6  6.0


In [193]:
pd.DataFrame(arr, columns=idxs)

Unnamed: 0,4,14,1,2,3,1.1,12,2.1,2.2,23,3.1,3.2,143,14.1,14.2
0,0.0,1.0,1.0,1.5,2.6,3.0,2.9,2.9,4.0,3.5,3.5,4.1,2.0,2.0,4.6
1,6.0,2.0,2.0,2.5,2.7,4.0,4.0,4.1,4.1,4.1,4.6,4.6,4.6,6.0,6.0


In [195]:
pd.DataFrame(np.array([[0, 1, 1.5, 2.6, 2.9, 3, 3.5],
            [6, 2, 2.5, 2.7, 4.1, 4, 4.6]]), columns=["4", "1", "2", "3", "2", "1", "3"])

Unnamed: 0,4,1,2,3,2.1,1.1,3.1
0,0.0,1.0,1.5,2.6,2.9,3.0,3.5
1,6.0,2.0,2.5,2.7,4.1,4.0,4.6


## Playing with the clusters manually:

In [106]:
def get_clusters_and_more(target_name, k, pred_below):
    id = target_id_dict[target_name]

    out = get_clusters(target_name, id, k, pred_below=bool(pred_below))
    m_points, data, times, starts, ends, groups, integrals, lengths = out
    starts = starts.values.reshape(-1, 1)
    ends = ends.values.reshape(-1, 1)

    t = data_df["unix_time"].values.reshape(-1, 1)
    y = data_df[target_name].values.reshape(-1, 1)
    y_hat = pred[:, id].numpy().reshape(-1, 1)
    v = var[:, id].reshape(-1, 1)
    t_points = times.values.reshape(-1, 1)
    y_points = data[target_name].values.reshape(-1, 1)
    v_points = var[m_points, id].reshape(-1, 1)
    return (m_points, data, times, starts, ends, groups, integrals, lengths,
            t, y, y_hat, v, t_points, y_points, v_points)

In [107]:
out = get_clusters_and_more(target_name="rate[0]", k=5, pred_below=1)

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  data["group"] = groups


In [108]:
(m_points, data, times, starts, ends, groups, integrals, lengths,
            t, y, y_hat, v, t_points, y_points, v_points) = out

In [109]:
m_points.sum()

5139

In [111]:
data

Unnamed: 0,unix_time,glat,glon,altitude,temperature,fe_cosmic,corrected,correrr,config,raz,...,1/rate_err[10]**2,1/rate_err[12]**2,rate[0]/rate_err[0],rate[1]/rate_err[1],rate[5]/rate_err[5],rate[6]/rate_err[6],rate[10]/rate_err[10],rate[12]/rate_err[12],(unix_time-1474004181.5460)//5535.4+10,group
14013,1.483564e+09,41.664079,-96.653040,376.699524,33.900002,4137.0,2054.359438,0.0,42,335.959749,...,0.000135,0.000135,40.587224,31.687531,73.342335,66.562366,80.298947,80.293807,1737.0,0
16830,1.483570e+09,41.731025,-119.623278,376.726384,33.900002,12918.0,5198.924021,0.0,42,336.060392,...,0.000043,0.000043,63.482672,55.632905,126.989008,120.027890,129.577507,129.577507,1738.0,1
16831,1.483570e+09,41.731025,-119.623278,376.726384,33.900002,12918.0,5208.265132,0.0,42,336.060392,...,0.000044,0.000044,64.399478,55.930942,129.397695,121.757208,132.287308,132.287308,1738.0,1
16865,1.483570e+09,42.092027,-116.822549,376.477686,33.900002,12402.0,5730.374674,0.0,42,338.613419,...,0.000038,0.000038,66.830960,57.496569,134.286194,125.766520,138.022530,138.022530,1738.0,2
19017,1.483575e+09,42.750140,-130.034266,375.686263,34.099998,13284.0,3125.119631,0.0,42,347.467124,...,0.000082,0.000082,50.658004,41.385459,96.732952,88.783724,102.149203,102.149203,1739.0,3
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
3882091,1.489684e+09,-42.591516,30.926685,385.922039,32.400002,15814.0,2975.573251,0.0,42,101.644437,...,0.000088,0.000088,49.663783,41.128444,93.834698,86.497816,99.373675,99.373675,2842.0,1936
3882092,1.489684e+09,-42.591516,30.926685,385.922039,32.400002,15814.0,2771.838998,0.0,42,101.644437,...,0.000092,0.000092,48.144650,39.767866,91.870832,84.806185,97.267368,97.267368,2842.0,1936
3882093,1.489684e+09,-42.581735,31.094718,385.911426,32.400002,14959.0,2789.120030,0.0,42,101.793397,...,0.000092,0.000092,48.258012,39.878870,91.589832,84.325020,96.819318,96.819318,2842.0,1936
3882098,1.489684e+09,-42.561338,31.430590,385.891415,32.400002,14379.0,2658.154941,0.0,42,102.092133,...,0.000098,0.000098,46.901229,38.271249,87.711620,80.822242,93.523063,93.523063,2842.0,1937


**TODO:** Use https://omegaconf.readthedocs.io/en/latest/usage.html#save-load-yaml-file to save new filters