### FRB classification of all types of morphology 

This notebook explores CNN-based classification of multiple classes:
type A (I), type B (II), type C (III), type C1 (IV) and type C2 (V) morphology of the FRBs as described in Kumar et al. 2025

### Notes
- Training data is generated using simulation scripts in the folder frabjous_sim.
- This notebook is used to obtain the hyperparameter-optimised models.
- Type I is referred to here as type A.
- Type II is referred to here as type B.
- Type III is referred to here as type C.
- Type IV is referred to here as type C1.
- Type V is referred to here as type C2.

In [None]:

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import tensorflow as tf
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.models import Sequential
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
from sklearn.metrics import precision_recall_curve, roc_curve, auc
from sklearn.metrics import  ConfusionMatrixDisplay

import argparse
import os
import sys
import locale
import time
import datetime
import json 
import glob
from pathlib import Path

import keras
import keras.backend as K
import keras.layers as KL
import keras.models as KM
import keras.optimizers as KO
import keras.callbacks as KC
import keras.utils as KU
import keras.preprocessing.image as KI
from keras.layers import Dense, Dropout, Flatten, Activation, Concatenate,Input
from keras.layers import Conv2D, MaxPooling2D, AveragePooling2D 
from keras.models import Model
import keras_tuner

You can either generate your own simulated dataset using the scripts in `frabjous_sim/` or download a prepared archive:

Download: https://drive.google.com/file/d/1gABgy59nyHcNOffODCyzDRfNvvt4t29B/view?usp=drive_link

After downloading, unzip the archive so the `simfrb/` directory appears under `BASE_DIR`.

### Expected Data Directory Structure

<pre>
&lt;BASE_DIR&gt;/
├── simfrb/
│   ├── simdata/
│   │   ├── type_A/
│   │   │   ├── SNR_15/
│   │   │   ├── SNR_25/
│   │   │   ├── SNR_35/
│   │   │   ├── SNR_50/
│   │   │   └── SNR_100/
│   │   └── type_B/
│   │       ├── SNR_15/
│   │       ├── SNR_25/
│   │       ├── SNR_35/
│   │       ├── SNR_50/
│   │       └── SNR_100/
│   │   .
│   │   .   (type_C/, type_C1/ and type_C2/ follow the same SNR_* layout)
│   │   .
</pre>

### Notes
- If you don't have simulated data, you can download the prepared dataset (see the download link above).
- Place the unzipped `simfrb` folder under `BASE_DIR` so the data-loading helpers find files automatically.

In [None]:
# Morphology class labels, in the order used for the one-hot encoding
# (type1 .. type5  ->  'A', 'B', 'C', 'C1', 'C2').
# read_data() interpolates these values into the `type_<label>/` directory and
# file names of the simulated `.npz` files, so editing a value here changes
# which class is loaded from disk.
type1, type2, type3, type4, type5 = 'A', 'B', 'C', 'C1', 'C2'

In [None]:
# Reproducibility: select the TF random-generator backend for Keras and fix the
# global seed before any data is shuffled or any model weights are initialised.
# NOTE(review): `enable_tf_random_generator` is an experimental Keras API —
# confirm it is available in the installed TensorFlow/Keras version.
tf.keras.backend.experimental.enable_tf_random_generator()
tf.keras.utils.set_random_seed(1354534)

In [None]:
# Path to the simulation data directory.
# BASE_DIR is the directory that contains the unzipped `simfrb/` folder. It is
# not assigned anywhere else in this notebook, so fall back to the current
# working directory unless it has already been set in an earlier cell.
if "BASE_DIR" not in globals():
    BASE_DIR = os.getcwd()
SIMFRB_DIR = Path(BASE_DIR) / "simfrb"

In [None]:
# function to read indices from a file
def read_indices(fname):
    """
    Read a list of integer indices from a text file inside SIMFRB_DIR.

    Parameters
    ----------
    fname : str
        File name relative to SIMFRB_DIR; one integer per line.

    Returns
    -------
    list of int
        The parsed indices in file order. Lines that are empty after
        stripping whitespace are skipped.
    """
    with open(SIMFRB_DIR / fname) as handle:
        stripped_lines = (raw.strip() for raw in handle)
        return [int(entry) for entry in stripped_lines if entry]

In [None]:
# function to read simulated FRB images for all five classes for a given SNR and apply removal indices
def _load_sim_images(morph_type, snr):
    """Return the `arr_0` image stack of one morphology class at one SNR."""
    path = os.path.join(
        BASE_DIR, "simfrb", "simdata",
        f"type_{morph_type}", f"SNR_{snr}",
        f"simulatefrbs_type_{morph_type}.npz",
    )
    # Use a context manager so the .npz archive is closed once the array has
    # been read into memory.
    with np.load(path) as archive:
        return archive["arr_0"]


def _scale_to_255(image):
    """Rescale an image so that its maximum pixel value maps to 255."""
    return image / (image.max() / 255)


def read_data(snr, remove_indices_C1, remove_indices_C2):
    """
    Load simulated FRB 2-D dynamic spectra and labels for all five classes.

    Parameters
    ----------
    snr : str
        Signal-to-noise ratio label (e.g., '50'); selects the `SNR_<snr>`
        sub-directory of every class.
    remove_indices_C1 : list of int
        Simulation indices to drop. An index listed here is skipped for the
        first four classes (type1..type4, i.e. A, B, C and C1) together, so
        those classes stay aligned and balanced.
    remove_indices_C2 : list of int
        Simulation indices to drop from the fifth class (type5, i.e. C2).

    Returns
    -------
    labels : list
        One-hot labels (length-5 lists of floats).
    images : list
        2-D dynamic spectra, each rescaled so its maximum is 255.

    Notes
    -----
    For every kept index the samples are appended in the order
    type1, type2, type3, type4; all kept type5 samples follow at the end.
    The loop runs over the length of the type1 stack, so the type2..type4
    stacks must contain at least as many images.
    The JSON header files are not read: their contents were never used.
    """
    paired_stacks = [
        _load_sim_images(morph_type, snr)
        for morph_type in (type1, type2, type3, type4)
    ]
    frbd5 = _load_sim_images(type5, snr)

    # Sets give O(1) membership tests inside the loops below.
    skip_C1 = set(remove_indices_C1)
    skip_C2 = set(remove_indices_C2)

    n_classes = 5
    frbdn = []
    frbdl = []

    for i in range(len(paired_stacks[0])):
        if i in skip_C1:
            continue
        for class_idx, stack in enumerate(paired_stacks):
            frbdn.append(_scale_to_255(stack[i]))
            one_hot = [0.] * n_classes
            one_hot[class_idx] = 1.
            frbdl.append(one_hot)

    for i, image in enumerate(frbd5):
        if i in skip_C2:
            continue
        frbdn.append(_scale_to_255(image))
        frbdl.append([0., 0., 0., 0., 1.])

    return frbdl, frbdn

In [None]:
# function to load CHIME FRB images and their correct labels (interpolated examples).
def read_chime_data(data_file, label_file):
    """
    Read interpolated CHIME FRB dynamic spectra and their labels, normalize
    the images and return a per-class 50/50 train/test split.

    Parameters
    ----------
    data_file : str
        Path to the NumPy `.npz` file containing FRB images (arr_0).
    label_file : str
        Path to a text file containing morphology labels, one per line
        (A, B, C, C1, or C2). Line `i` labels image `i`.

    Returns
    -------
    labels_train, images_train, labels_test, images_test : list
        One-hot labels (length-5 lists of floats) and images rescaled so
        that each image's maximum is 255. Within each list the samples are
        grouped by class in the order A, B, C, C1, C2. For every class the
        first `len // 2` samples go to the train lists and the rest to the
        test lists.

    Notes
    -----
    Surrounding whitespace (including a missing or CRLF line ending) is
    stripped before the label is matched, so the last line of a file without
    a trailing newline is no longer dropped. Lines whose label is not one of
    the five classes are ignored.
    """
    class_order = ("A", "B", "C", "C1", "C2")
    images_by_class = {name: [] for name in class_order}

    # ---- Load data ----
    with np.load(data_file) as frbdata:
        frbd = frbdata["arr_0"]

    with open(label_file) as f:
        correct_type = f.readlines()

    # ---- Assign images to their class ----
    for i, raw_label in enumerate(correct_type):
        label = raw_label.strip()
        if label not in images_by_class:
            # Blank or unknown label: skip without touching the image.
            continue
        immax = frbd[i].max()
        images_by_class[label].append(frbd[i] / (immax / 255))

    # ---- Train / test split (per class) ----
    labels_train, images_train = [], []
    labels_test, images_test = [], []

    for class_idx, name in enumerate(class_order):
        class_images = images_by_class[name]
        n_train = len(class_images) // 2
        n_test = len(class_images) - n_train

        one_hot = [0.] * len(class_order)
        one_hot[class_idx] = 1.

        images_train += class_images[:n_train]
        labels_train += [list(one_hot) for _ in range(n_train)]
        images_test += class_images[n_train:]
        labels_test += [list(one_hot) for _ in range(n_test)]

    return labels_train, images_train, labels_test, images_test


In [None]:

# Load removal index files for C1 (type IV) examples that should be excluded from training
remove_indices_C1_100 = read_indices("remove_type_C1_SNR_100.txt")
remove_indices_C1_50  = read_indices("remove_type_C1_SNR_50.txt")
remove_indices_C1_35  = read_indices("remove_type_C1_SNR_35.txt")
remove_indices_C1_25  = read_indices("remove_type_C1_SNR_25.txt")
# For SNR 15 the file is used as a keep-list: everything in 0..999 that is
# NOT listed gets removed (assumes 1000 simulated bursts — TODO confirm).
indices_C1_15 = read_indices("type_C1_SNR_15.txt")
_keep_C1_15 = set(indices_C1_15)  # set lookup instead of scanning the list 1000 times
remove_indices_C1_15 = [x for x in range(1000) if x not in _keep_C1_15]

# Load indices for C2 (type V) examples to exclude
# NOTE(review): these four files lack the `remove_` prefix used for C1, yet they
# are passed on as removal lists, whereas the identically named SNR 15 file
# below is treated as a keep-list — confirm they really list indices to drop.
remove_indices_C2_100 = read_indices("type_C2_SNR_100.txt")
remove_indices_C2_50  = read_indices("type_C2_SNR_50.txt")
remove_indices_C2_35  = read_indices("type_C2_SNR_35.txt")
remove_indices_C2_25  = read_indices("type_C2_SNR_25.txt")
indices_C2_15 = read_indices("type_C2_SNR_15.txt")
_keep_C2_15 = set(indices_C2_15)
remove_indices_C2_15 = [x for x in range(1000) if x not in _keep_C2_15]

Loading the training and testing dataset

In [None]:
# Load simulated data at multiple SNR levels and combine them.
# Each entry is (SNR label, C1 removal indices, C2 removal indices); the order
# of the entries fixes the order of the samples in the combined lists.
snr_runs = [
    ('100', remove_indices_C1_100, remove_indices_C2_100),
    ('50', remove_indices_C1_50, remove_indices_C2_50),
    ('35', remove_indices_C1_35, remove_indices_C2_35),
    ('25', remove_indices_C1_25, remove_indices_C2_25),
    ('15', remove_indices_C1_15, remove_indices_C2_15),
]

frbdl, frbdn = [], []
for snr_label, drop_C1, drop_C2 in snr_runs:
    temp_labels, temp_data = read_data(snr_label, drop_C1, drop_C2)
    frbdl.extend(temp_labels)
    frbdn.extend(temp_data)

# reading the chime data
## chime bursts with interpolation and their label file are provided in the directory
## only the train half is added to the dataset; the test half is kept aside
chime_labels, chime_data, chime_test_labels, chime_test = read_chime_data(
    'files/chime_interp_frbs.npz',
    'files/chime_labels.txt',
)
frbdl.extend(chime_labels)
frbdn.extend(chime_data)

In [None]:
# Convert the Python lists to arrays. The images get a trailing length-1
# axis, matching the single-channel model input (Input(shape=(256, 256, 1))).
frbdl = np.asarray(frbdl)
frbdn = np.asarray(frbdn)[..., np.newaxis]

In [None]:
# split the dataset into train and test sets (80% / 20%, fixed random_state)
# Labels (frbdl) are passed first and images (frbdn) second, so the unpacked
# names are: *AttrX -> one-hot labels, *ImagesX -> images.
# testingAttrX / testingImagesX form the held-out test set evaluated later.
split = train_test_split(frbdl, frbdn, test_size=0.20, random_state=42)
(trainAttrX, testingAttrX, trainImagesX, testingImagesX) = split

In [None]:
# split the training set further into train and validation sets (80% / 20%)
# NOTE: despite the "test" prefix, testAttrX (labels) and testImagesX (images)
# are the VALIDATION split — they are what gets passed as `validation_data`
# to the tuner search. The held-out test set is testingAttrX / testingImagesX.
split = train_test_split(trainAttrX, trainImagesX, test_size=0.2, random_state=42)
(trainAttrX, testAttrX, trainImagesX, testImagesX) = split

In [None]:
# function to build the model which is used for training and hyperparameter tuning

def model_builder(hp):
    """
    Build a 2D CNN classifier for FRB morphology classification.

    The convolutional stack is fixed (three Conv2D(32, 3x3, relu) layers with
    max-pooling of 2, 2 and 5). The tunable hyperparameters are the widths of
    the two dense layers, their dropout rates and the Adam learning rate.

    Parameters
    ----------
    hp : keras_tuner.HyperParameters
        Hyperparameter object used by Keras Tuner.

    Returns
    -------
    model : keras.Model
        Compiled Keras model taking (256, 256, 1) inputs and producing a
        5-way softmax, with categorical cross-entropy loss and accuracy metric.
    """

    # Input: single-channel 256 x 256 dynamic spectrum
    input_2d = Input(shape=(256, 256, 1))

    # CNN layers
    x = Conv2D(filters=32, kernel_size=3, activation='relu')(input_2d)
    x = MaxPooling2D(pool_size=2)(x)

    x = Conv2D(filters=32, kernel_size=3, activation='relu')(x)
    x = MaxPooling2D(pool_size=2)(x)

    x = Conv2D(filters=32, kernel_size=3, activation='relu')(x)
    x = MaxPooling2D(pool_size=5)(x)

    x = Flatten()(x)

    # Fully connected layers
    x = Dense(
        units=hp.Choice("units_1", values=[64, 128, 256]),
        activation="relu",
    )(x)

    x = Dropout(
        hp.Choice("dropout1", values=[0.1, 0.2, 0.3])
    )(x)

    x = Dense(
        units=hp.Choice("units_2", values=[16, 32, 64]),
        activation="relu",
    )(x)

    x = Dropout(
        hp.Choice("dropout2", values=[0.1, 0.2, 0.3])
    )(x)

    # Final layer: one softmax output per morphology class
    output = Dense(units=5, activation='softmax')(x)

    # ---- Model ----
    model = Model(inputs=input_2d, outputs=output)
    learning_rate = hp.Choice(
        "learning_rate",
        values=[2e-4, 5e-4, 1e-5, 2e-5, 5e-5]
    )

    # Register batch_size in the search space. Building the model cannot set
    # the batch size itself: the value only takes effect if the code that
    # calls `fit` reads this hyperparameter and passes it on.
    hp.Choice("batch_size", values=[32, 64, 128])

    # `decay=0` is deliberately not passed: it equals the default (no decay)
    # on optimizers that accept it, and newer Keras optimizers no longer
    # accept the argument.
    optimizer = keras.optimizers.Adam(learning_rate=learning_rate)

    model.compile(
        optimizer=optimizer,
        loss='categorical_crossentropy',
        metrics=['accuracy']
    )

    return model


In [None]:
# define batch sizes for hyperparameter tuning
# `hp` must exist before it can be used; in this notebook it was only created
# in a later cell, so create the HyperParameters container here.
hp = keras_tuner.HyperParameters()
batch_sizes = hp.Choice('batch_size', values=[32, 64, 128])

In [None]:
## Create a HyperParameters object and build the model once as a smoke test
## (this also registers the search space on `hp`), then set up the tuner
## using RandomSearch
hp = keras_tuner.HyperParameters()
model_builder(hp)


class FRBHyperModel(keras_tuner.HyperModel):
    """
    Wrap `model_builder` so that the sampled `batch_size` is really used.

    `model_builder` registers a "batch_size" hyperparameter, but building a
    model cannot set the batch size. Overriding `fit` is what forwards the
    sampled value to `model.fit`; without it every trial trains with the
    default batch size.
    """

    def build(self, hp):
        return model_builder(hp)

    def fit(self, hp, model, *args, **kwargs):
        # Same name and values as in model_builder, so this resolves to the
        # value sampled for the current trial. An explicit batch_size passed
        # to tuner.search() still takes precedence.
        kwargs.setdefault(
            "batch_size",
            hp.Choice("batch_size", values=[32, 64, 128]),
        )
        return model.fit(*args, **kwargs)


tuner = keras_tuner.RandomSearch(
    hypermodel=FRBHyperModel(),
    objective="val_accuracy",
    max_trials=100,
    executions_per_trial=2,
    overwrite=True,  # discard any previous results stored in the same project directory
    directory="tuning_model_multiclass",
    project_name="classify_all_types"
)


In [None]:
# Print the hyperparameters registered in the tuner's search space
tuner.search_space_summary()

In [None]:
## Early stopping callback: stop a trial once val_accuracy has not improved for 30 epochs
# TensorBoard callback for training reports (logs written under tuning_model_multiclass/tb_logs/)
# Run the hyperparameter search; each trial trains for at most 300 epochs and is
# validated on the validation split (testImagesX / testAttrX).
# NOTE(review): `tuner.search` is not documented to return anything, so
# `tuner_results` is presumably None — inspect the outcome through
# `tuner.results_summary()` / `tuner.get_best_models()` instead.
stop_early = tf.keras.callbacks.EarlyStopping(monitor='val_accuracy', patience=30)

tensorboard = keras.callbacks.TensorBoard('tuning_model_multiclass/tb_logs/')

tuner_results = tuner.search( x= trainImagesX , y=trainAttrX,  epochs=300 , validation_data=(testImagesX ,testAttrX) ,
                    callbacks= [stop_early,tensorboard ]   ) 


In [None]:
# display the tuner results summary (the trials and their scores on the
# tuning objective, val_accuracy)
tuner.results_summary()

In [None]:
# Ask the tuner for (up to) the ten best models found by the search
models = tuner.get_best_models(num_models=10)
best_model = models[0]

# Save the best models.
# Iterate over the models that were actually returned rather than a fixed
# range(10), so a search with fewer than ten completed trials does not raise
# an IndexError.
# NOTE: 'path/to/directory/' is a placeholder — replace it with a real output directory.
for i, ranked_model in enumerate(models):
    ranked_model.save('path/to/directory/best_model_' + str(i), save_format='tf')

For testing with the best model

In [None]:
# testing with the held-out 20% test split (testingImagesX / testingAttrX).
# NOTE: this split contains the simulated bursts plus the CHIME bursts that
# were added to the training pool, not simulated data only.
# Returns the loss and the accuracy metric the model was compiled with.
loss, acc = best_model.evaluate(testingImagesX, testingAttrX, verbose=2 )

In [None]:
## testing with the chime data
# `chime_test` comes from read_chime_data() as a Python list of 2-D images.
# Stack it into one array and add the trailing channel axis so it has the
# (N, H, W, 1) layout of the training images; a plain list of arrays would be
# treated as N separate model inputs instead of a batch of N samples.
chime_test_images = np.asarray(chime_test)
if chime_test_images.ndim == 3:
    chime_test_images = chime_test_images[..., np.newaxis]

predictions = best_model.predict(chime_test_images)

# Convert softmax outputs / one-hot labels to integer class ids (0..4)
predicted_classes = np.argmax(predictions, axis=1)
test_labels = np.argmax(chime_test_labels, axis=1)

In [None]:
## classification report
# Pass the class ids explicitly: without `labels`, the report only covers the
# classes that occur in the data, and it fails when that number differs from
# the five `target_names` (likely for the small CHIME test set, where a class
# may be missing from both the true and the predicted labels).
class_names = ['A' ,'B', 'C', 'C1', 'C2']
report = classification_report(
    test_labels,
    predicted_classes,
    labels=list(range(len(class_names))),
    target_names=class_names,
)
print(report)