# FRB Type-IV vs Type-V Classification

This notebook explores CNN-based classification of the simulated
type IV and type V FRB morphologies described in Kumar et al. (2025).

## Notes
- Training data is generated using simulation scripts in the folder frabjous_sim.
- This notebook is used to obtain the hyperparameter-optimised models.
- Type IV is referred to here as type C1.
- Type V is referred to here as type C2.

In [None]:

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import tensorflow as tf
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.models import Sequential
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
from sklearn.metrics import precision_recall_curve, roc_curve, auc
from sklearn.metrics import  ConfusionMatrixDisplay

import argparse
import os
import sys
import locale
import time
import datetime
import json 
import glob
from pathlib import Path

import keras
import keras.backend as K
import keras.layers as KL
import keras.models as KM
import keras.optimizers as KO
import keras.callbacks as KC
import keras.utils as KU
import keras.preprocessing.image as KI
from keras.layers import Dense, Dropout, Flatten, Activation, Concatenate,Input
from keras.layers import Conv2D, MaxPooling2D, AveragePooling2D 
from keras.models import Model
import keras_tuner

In [None]:
def read_indices(fname):
    """
    Read integer indices from a text file located in SIMFRB_DIR.

    Parameters
    ----------
    fname : str
        File name relative to SIMFRB_DIR; the file holds one integer per line.

    Returns
    -------
    list of int
        Parsed values in file order. Blank lines are skipped.
    """
    with open(SIMFRB_DIR / fname) as handle:
        # Strip surrounding whitespace first so blank lines can be dropped
        entries = (raw.strip() for raw in handle)
        return [int(entry) for entry in entries if entry]

In [None]:
def _load_spectra(frb_type, snr):
    """Return the ``arr_0`` array stored in the simulation archive of one class and SNR."""
    archive_path = os.path.join(
        BASE_DIR,
        "simfrb",
        "simdata",
        f"type_{frb_type}",
        f"SNR_{snr}",
        f"simulatefrbs_type_{frb_type}.npz",
    )
    # np.load keeps the .npz file open until the archive is closed, so read the
    # array inside a context manager to release the handle straight away.
    with np.load(archive_path) as archive:
        return archive["arr_0"]


def _scale_and_label(spectra, excluded, label):
    """Scale every non-excluded sample to a maximum of 255 and pair it with ``label``."""
    excluded = set(excluded)  # constant-time membership test inside the loop
    labels, images = [], []
    for idx, spectrum in enumerate(spectra):
        if idx in excluded:
            continue
        peak = spectrum.max()
        # NOTE: a sample whose maximum is 0 would divide by zero here.
        images.append(spectrum / (peak / 255))
        labels.append(label)
    return labels, images


def read_data(type1, type2, snr, remove_indices_C1, remove_indices_C2):
    """
    Load simulated FRB dynamic spectra and binary labels for two archetypes.

    The spectra are read from
    ``<BASE_DIR>/simfrb/simdata/type_<type>/SNR_<snr>/simulatefrbs_type_<type>.npz``.
    Each kept sample is divided by ``max / 255`` so that its maximum becomes 255.

    Parameters
    ----------
    type1, type2 : str
        FRB morphology labels (e.g., 'C1', 'C2').
    snr : str
        Signal-to-noise ratio label (e.g., '50').
    remove_indices_C1 : iterable of int
        Indices of simulated FRBs to exclude for type1.
    remove_indices_C2 : iterable of int
        Indices of simulated FRBs to exclude for type2.

    Returns
    -------
    labels : list of int
        Binary class labels: 0 for every kept type1 sample, 1 for type2.
    images : list
        Scaled spectra, type1 samples first, in the same order as ``labels``.
    """
    labels1, images1 = _scale_and_label(_load_spectra(type1, snr), remove_indices_C1, 0)
    labels2, images2 = _scale_and_label(_load_spectra(type2, snr), remove_indices_C2, 1)
    return labels1 + labels2, images1 + images2

In [None]:
# Archetype names used to locate the simulation folders (type_C1 / type_C2).
# read_data assigns class label 0 to type1 samples and class label 1 to type2 samples.
type1 = 'C1'
type2 = 'C2'

## Expected Data Directory Structure


<pre>
&lt;BASE_DIR&gt;/
├── simfrb/
│   ├── simdata/
│   │   ├── type_C1/
│   │   │   ├── SNR_15/
│   │   │   ├── SNR_25/
│   │   │   ├── SNR_35/
│   │   │   ├── SNR_50/
│   │   │   └── SNR_100/
│   │   └── type_C2/
│   │       ├── SNR_15/
│   │       ├── SNR_25/
│   │       ├── SNR_35/
│   │       ├── SNR_50/
│   │       └── SNR_100/
</pre>

### Notes
- `type C1` corresponds to **Type IV FRBs**
- `type C2` corresponds to **Type V FRBs**


In [None]:
# Reproducibility: switch Keras to the TensorFlow random generator first, then set
# one global seed through the Keras utility.
tf.keras.backend.experimental.enable_tf_random_generator()
tf.keras.utils.set_random_seed(1334534)

In [None]:
# Root directory that contains the `simfrb/` folder (see the directory layout above).
# Set the FRB_BASE_DIR environment variable to point the notebook at another machine's
# data without editing this cell; the fallback is the original local path.
BASE_DIR = os.environ.get("FRB_BASE_DIR", '/media/akumar/Data/')

In [None]:
SIMFRB_DIR = Path(BASE_DIR) / "simfrb"

# Number of simulated samples assumed per class and SNR when a "keep" list has to be
# inverted into an "exclude" list.
# NOTE(review): 1000 is carried over from the original code — confirm it matches the
# size of the simulation archives.
N_SIMULATED = 1000


def _complement(keep_indices, total=N_SIMULATED):
    """Return every index in range(total) that is not listed in keep_indices."""
    keep = set(keep_indices)  # set lookup instead of scanning the list for each index
    return [idx for idx in range(total) if idx not in keep]


#### Indices of type C1 samples to exclude from the training data ####
remove_indices_C1_100 = read_indices("remove_type_C1_SNR_100.txt")
remove_indices_C1_50  = read_indices("remove_type_C1_SNR_50.txt")
remove_indices_C1_35  = read_indices("remove_type_C1_SNR_35.txt")
remove_indices_C1_25  = read_indices("remove_type_C1_SNR_25.txt")

# For SNR 15 the file is treated as the list of samples to keep, so it is inverted.
indices_C1_15 = read_indices("type_C1_SNR_15.txt")
remove_indices_C1_15 = _complement(indices_C1_15)


#### Indices of type C2 samples to exclude from the training data ####
# NOTE(review): unlike the C1 files above, these file names carry no "remove_" prefix,
# yet their contents are used directly as exclusion lists. Confirm that the files list
# the samples to drop rather than the samples to keep.
remove_indices_C2_100 = read_indices("type_C2_SNR_100.txt")
remove_indices_C2_50  = read_indices("type_C2_SNR_50.txt")
remove_indices_C2_35  = read_indices("type_C2_SNR_35.txt")
remove_indices_C2_25  = read_indices("type_C2_SNR_25.txt")

# Same inversion as for C1 at SNR 15.
indices_C2_15 = read_indices("type_C2_SNR_15.txt")
remove_indices_C2_15 = _complement(indices_C2_15)


In [None]:
# One entry per simulated SNR: (SNR label, C1 exclusions, C2 exclusions).
# The order fixes the order of the samples in the combined arrays.
snr_exclusions = [
    ('100', remove_indices_C1_100, remove_indices_C2_100),
    ('50', remove_indices_C1_50, remove_indices_C2_50),
    ('35', remove_indices_C1_35, remove_indices_C2_35),
    ('25', remove_indices_C1_25, remove_indices_C2_25),
    ('15', remove_indices_C1_15, remove_indices_C2_15),
]

frbdl, frbdn = [], []
for snr_label, exclude_C1, exclude_C2 in snr_exclusions:
    temp_labels, temp_data = read_data(type1, type2, snr_label, exclude_C1, exclude_C2)
    frbdl.extend(temp_labels)
    frbdn.extend(temp_data)

# Stack the samples into one array and append a trailing axis of length 1
frbdn = np.expand_dims(np.asarray(frbdn), axis=-1)
frbdl = np.asarray(frbdl)

In [None]:
# Shape of a single sample: the two inner axes of frbdn plus one trailing axis
input_shape = (frbdn.shape[1], frbdn.shape[2], 1)

# Step 1: hold out 15 % of all samples as the final test set
# (testingAttrX = labels, testingImagesX = images).
(trainvalAttrX, testingAttrX,
 trainvalImagesX, testingImagesX) = train_test_split(
    frbdl, frbdn, test_size=0.15, random_state=42
)

# Step 2: split the remainder into training data and a validation set
# (testAttrX / testImagesX), which tuner.search receives as validation_data.
# Both steps read from names they do not overwrite, so re-running this cell always
# yields the same partitions instead of shrinking the training set further.
(trainAttrX, testAttrX,
 trainImagesX, testImagesX) = train_test_split(
    trainvalAttrX, trainvalImagesX, test_size=0.2, random_state=42
)

# The intermediate copy is no longer needed; release its memory
del trainvalAttrX, trainvalImagesX

In [None]:
# Number of samples left for training after both train_test_split calls
len(trainAttrX)

In [None]:
# Number of samples in the held-out test set (first split, test_size=0.15)
len(testingAttrX)

In [None]:
def model_builder(hp, input_shape=(256, 256, 1)):
    """
    Build and compile a CNN for one point of the hyperparameter search space.

    The three convolution/pooling stages are fixed. The tuner varies the width
    and dropout rate of the two dense layers and the Adam learning rate.

    Parameters
    ----------
    hp : keras_tuner.HyperParameters
        Hyperparameter object for tuning model configuration.
    input_shape : tuple of int, optional
        Shape of one input sample passed to the first convolution layer.
        Defaults to (256, 256, 1), the value that was previously hard-coded,
        so calling ``model_builder(hp)`` behaves as before.

    Returns
    -------
    model : keras.Model
        Compiled Keras model (binary cross-entropy loss, accuracy metric).
    """

    model = keras.Sequential()

    ######    CNN layers  #########
    model.add(Conv2D(32, (3, 3), activation='relu', input_shape=input_shape))
    model.add(MaxPooling2D(pool_size=(2, 2)))

    model.add(Conv2D(32, (3, 3), activation='relu'))
    model.add(MaxPooling2D(pool_size=(2, 2)))

    model.add(Conv2D(32, (5, 5), activation='relu'))
    model.add(MaxPooling2D(pool_size=(3, 3)))

    model.add(Flatten())

    ###### fully connected layers #######
    model.add(
        Dense(
            units=hp.Choice("units_1", values=[16, 32, 64]),
            activation="relu",
        )
    )
    model.add(
        Dropout(
            hp.Choice("dropout1", values=[0.15, 0.2, 0.25, 0.3])
        )
    )

    model.add(
        Dense(
            units=hp.Choice("units_2", values=[4, 8, 16, 32]),
            activation="relu",
        )
    )
    model.add(
        Dropout(
            hp.Choice("dropout2", values=[0.15, 0.2, 0.25, 0.3])
        )
    )

    # ---- Output layer: one sigmoid unit for the binary decision ----
    model.add(Dense(1, activation='sigmoid'))

    ###### Optimizer hyperparameters #####
    learning_rate = hp.Choice(
        "learning_rate", values=[2e-4, 5e-4, 1e-5, 2e-5, 5e-5]
    )

    # Registers "batch_size" in the search space; the returned value is not used here.
    # NOTE(review): nothing in this function forwards the value to training. Unless the
    # fitting code reads hp["batch_size"] itself, trials that differ only in batch_size
    # train identically — confirm how the batch size is meant to be applied.
    hp.Choice("batch_size", values=[32, 64, 128])

    model.compile(
        optimizer=keras.optimizers.Adam(learning_rate=learning_rate),
        loss='binary_crossentropy',
        metrics=['accuracy']
    )

    return model

In [None]:
# Smoke test: build the model once with a fresh HyperParameters object to check that
# model_builder runs. `hp` is a separate object that is used in the next cell.
hp = keras_tuner.HyperParameters()
model_builder(keras_tuner.HyperParameters())

In [None]:
# Registers 'batch_size' on the standalone `hp` object and keeps the value it returns.
# NOTE(review): `batch_sizes` is not referenced by any of the visible cells — confirm
# whether it is still needed.
batch_sizes = hp.Choice( 'batch_size' , values=[ 32 , 64, 128 ] )

In [None]:
# Random search over the hyperparameters registered in model_builder, ranked by
# validation loss. Results are written below <BASE_DIR>/model_tuning_c1_c2/classify_c1_c2.
# overwrite=True starts a new search instead of resuming one found in that directory.
tuner = keras_tuner.RandomSearch(
    hypermodel=model_builder,
    objective="val_loss",
    max_trials=100,
    executions_per_trial=2,
    overwrite=True,
    directory=os.path.join(BASE_DIR, "model_tuning_c1_c2"),  # comma was missing here (SyntaxError)
    project_name="classify_c1_c2",
)

In [None]:
# Print the hyperparameters and candidate values that the tuner will sample from
tuner.search_space_summary()

In [None]:
# Stop a training run once validation accuracy has not improved for 30 epochs.
# NOTE(review): the tuner ranks trials by "val_loss" while this callback watches
# 'val_accuracy' — confirm that the mismatch is intentional.
stop_early = tf.keras.callbacks.EarlyStopping(monitor='val_accuracy', patience=30)

In [None]:
# TensorBoard callback; its logs go to <BASE_DIR>/model_tuning_c1_c2/tensorboard_logs,
# i.e. inside the same top-level directory that the tuner writes to.
tensorboard = keras.callbacks.TensorBoard(
    log_dir=os.path.join(BASE_DIR, "model_tuning_c1_c2", "tensorboard_logs")
)

In [None]:
# Run the hyperparameter search: train on the training split and score every trial on
# the validation split (testImagesX / testAttrX). Each run lasts at most 300 epochs.
tuner.search(
    x=trainImagesX,
    y=trainAttrX,
    epochs=300,
    validation_data=(testImagesX, testAttrX),
    callbacks=[stop_early, tensorboard],
)

In [None]:
# Retrieve up to ten models from the tuner; the first entry is treated as the best one
# and is the model evaluated on the held-out test set below.
models = tuner.get_best_models(num_models=10)
best_model = models[0]

# Explicitly build the model to enable summary display
best_model.build(input_shape=(None, 256, 256, 1))
best_model.summary()

In [None]:
# Print the tuner's summary of the completed trials
tuner.results_summary()

In [None]:
# Directory that receives the exported models.
# NOTE(review): '/path/to/directory/' looks like a placeholder — point it at a real,
# writable location before running this cell.
MODEL_SAVE_DIR = '/path/to/directory/'

# Export every model returned by get_best_models as best_model_<rank> in TensorFlow
# SavedModel format. Iterating over `models` directly (instead of a fixed range of 10)
# avoids an IndexError when the tuner returned fewer models.
for rank, ranked_model in enumerate(models):
    ranked_model.save(os.path.join(MODEL_SAVE_DIR, 'best_model_' + str(rank)), save_format='tf')

In [None]:
# Score the held-out test images with the best model and threshold the sigmoid
# output at 0.5 to obtain class decisions.
predictions = best_model.predict(testingImagesX)
predictions_bool = (predictions > 0.5)
# Confusion matrix with the classes ordered as [0, 1]
cm_1 = confusion_matrix( testingAttrX , predictions_bool, labels=[0 , 1])
cm_1 

In [None]:
# Loss and accuracy of the best model on the held-out test set, plus the F1 score of
# the 0.5-thresholded predictions; accuracy and F1 are printed rounded to 4 decimals.
loss, acc = best_model.evaluate(testingImagesX, testingAttrX, verbose=2 )
f1score = f1_score(testingAttrX, predictions_bool)
print(round(acc ,4), round(f1score,4 ))

In [None]:
fpr, tpr, thresholds = roc_curve(testingAttrX, predictions)

#### Plotting the FPR-FNR curve for the best model
plt.rc('font', size=10)
fig, ax = plt.subplots()

# The first ROC point is skipped on all three arrays; the rest is reversed so that
# the threshold axis is traversed in the opposite order. Rates are shown in percent.
threshold_axis = np.flip(thresholds[1:])
false_positive_pct = np.flip(fpr[1:]) * 100
false_negative_pct = np.flip(1 - tpr[1:]) * 100

ax.plot(threshold_axis, false_positive_pct, label='False Positive Rate')
ax.plot(threshold_axis, false_negative_pct, label='False Negative Rate')
ax.set_xlabel('Threshold', fontsize=20)
ax.set_ylabel('Cumulative percentage', fontsize=10)
ax.set_title('FPR FNR vs threshold ', fontsize=10)
ax.legend()
ax.set_yscale('log')
ax.grid(True)