# FRB Type-II vs Type-IV Classification 

This notebook explores CNN-based classification of the simulated type II and type IV
FRB morphologies described in Kumar et al. (2025).

## Notes
- Training data is generated using simulation scripts in the folder frabjous_sim.
- This notebook is used to obtain the hyperparameter-optimised models.
- Type II is referred to here as type B.
- Type IV is referred to here as type C1.

In [None]:

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import tensorflow as tf
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.models import Sequential
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
from sklearn.metrics import precision_recall_curve, roc_curve, auc
from sklearn.metrics import  ConfusionMatrixDisplay

import argparse
import os
import sys
import locale
import time
import datetime
import json 
import glob

import keras
import keras.backend as K
import keras.layers as KL
import keras.models as KM
import keras.optimizers as KO
import keras.callbacks as KC
import keras.utils as KU
import keras.preprocessing.image as KI
from keras.layers import Dense, Dropout, Flatten, Activation, Concatenate,Input
from keras.layers import Conv2D, MaxPooling2D, AveragePooling2D 
from keras.models import Model
import keras_tuner

In [None]:
# Morphology classes being compared: B (type II) gets label 0, C1 (type IV) gets label 1.
type1, type2 = 'B', 'C1'

In [None]:
# Reproducibility: make Keras draw random numbers through tf.random.Generator and seed
# the RNGs. Per the TF docs, set_random_seed seeds Python `random`, NumPy and TensorFlow.
tf.keras.backend.experimental.enable_tf_random_generator()
tf.keras.utils.set_random_seed(1334534)

In [None]:
# Root directory that holds the generated FRB samples. To override it, export the
# `base_directory` environment variable before starting the kernel.
BASE_DIR = os.environ.get("base_directory", "/DATA/ajay/")
print(f"Using BASE_DIR = {BASE_DIR}")

## Expected Data Directory Structure

```
<BASE_DIR>/
└── simfrb/
    ├── simdata/
    │   ├── type_B/
    │   │   ├── SNR_15/
    │   │   ├── SNR_25/
    │   │   ├── SNR_35/
    │   │   ├── SNR_50/
    │   │   └── SNR_100/
    │   └── type_C1/
    │       ├── SNR_15/
    │       ├── SNR_25/
    │       ├── SNR_35/
    │       ├── SNR_50/
    │       └── SNR_100/
    ├── remove_SNR_100/             # files named *_<index>.<ext>; each <index> is removed at SNR 100
    ├── remove_type_C1_SNR_50.txt   # one index per line to remove
    ├── remove_type_C1_SNR_35.txt   # one index per line to remove
    ├── remove_type_C1_SNR_25.txt   # one index per line to remove
    └── type_C1_SNR_15.txt          # one index per line to KEEP; all other indices are removed
```

Each `SNR_xx/` directory contains:
- `simulatefrbs_type_<TYPE>.npz` : simulated FRB dynamic spectra
- `frb_header_type_<TYPE>.json` : corresponding metadata

In [None]:
### Several simulated type C1 examples look like a single-component burst.
### Their indices are read below and skipped when the training data are loaded.

# Root of the simulated-FRB tree (see "Expected Data Directory Structure").
SIMFRB_BASEDIR = os.path.join(BASE_DIR, "simfrb")

def load_indices_from_txt(filepath):
    """Read one integer index per non-blank line of ``filepath`` and return them as a list."""
    indices = []
    with open(filepath, "r") as handle:
        for raw_line in handle:
            token = raw_line.strip()
            # Blank or whitespace-only lines are ignored.
            if token:
                indices.append(int(token))
    return indices

# Number of simulated FRBs per (type, SNR) set. It is needed to turn the SNR=15
# "keep" list into a removal list.
# NOTE(review): this must match the size of the simulation output.
N_SIM_PER_SNR = 1000

# SNR = 100: each file in remove_SNR_100/ is named "<...>_<index>.<ext>".
# The trailing integer is an index to remove.
remove_indices_100 = sorted(
    int(os.path.basename(path).split("_")[-1].split(".")[0])
    for path in glob.glob(os.path.join(SIMFRB_BASEDIR, "remove_SNR_100", "*"))
)

# SNR = 50, 35, 25: plain-text lists of indices to remove, one per line.
remove_indices_50, remove_indices_35, remove_indices_25 = (
    load_indices_from_txt(
        os.path.join(SIMFRB_BASEDIR, f"remove_type_C1_SNR_{snr}.txt")
    )
    for snr in ("50", "35", "25")
)

# SNR = 15: the file lists the indices to KEEP, so every other index is removed.
# A sorted list of plain ints keeps the result deterministic and readable.
indices_15 = load_indices_from_txt(
    os.path.join(SIMFRB_BASEDIR, "type_C1_SNR_15.txt")
)
remove_indices_15 = sorted(set(range(N_SIM_PER_SNR)) - set(indices_15))

In [None]:
def read_data(type1, type2, snr, remove_indices, base_dir=None):
    """
    Load simulated FRB images and labels for two morphology classes.

    Samples are paired by index. For every index ``i`` not in ``remove_indices``,
    the type1 image ``i`` (label 0) is appended, followed by the type2 image
    ``i`` (label 1). This keeps the two classes balanced.

    Parameters
    ----------
    type1, type2 : str
        FRB morphology labels (e.g., 'B', 'C1').
    snr : str
        Signal-to-noise ratio label (e.g., '50').
    remove_indices : iterable of int
        Indices of simulated FRBs to exclude. A removed index is dropped from
        BOTH classes, so the pairing is preserved.
    base_dir : str, optional
        Root data directory. Defaults to the module-level ``BASE_DIR``.

    Returns
    -------
    labels : list of int
        Binary class labels (0 for type1, 1 for type2), interleaved pairwise.
    images : list of ndarray
        Dynamic spectra, each rescaled so its maximum value is 255.
        All-zero images are returned unscaled (as zeros).

    Raises
    ------
    ValueError
        If type1 has fewer samples than type2, so the index pairing is impossible.
    """
    if base_dir is None:
        base_dir = BASE_DIR

    def _load_images(frb_type):
        """Return the image stack for one morphology type at this SNR."""
        type_dir = os.path.join(
            base_dir, "simfrb", "simdata", f"type_{frb_type}", f"SNR_{snr}"
        )
        # The JSON metadata is not used yet. It is still read so that a missing
        # or malformed header file is reported immediately.
        with open(os.path.join(type_dir, f"frb_header_type_{frb_type}.json")) as f:
            json.load(f)
        # `with` closes the NpzFile handle; indexing it returns an in-memory array.
        with np.load(os.path.join(type_dir, f"simulatefrbs_type_{frb_type}.npz")) as npz:
            return npz["arr_0"]

    frbd1 = _load_images(type1)
    frbd2 = _load_images(type2)
    if len(frbd1) < len(frbd2):
        raise ValueError(
            f"type {type1} has {len(frbd1)} samples at SNR {snr}, fewer than the "
            f"{len(frbd2)} samples of type {type2}; samples are paired by index."
        )

    skip = set(remove_indices)  # O(1) membership test instead of scanning a list
    labels = []
    images = []
    for i in range(len(frbd2)):
        if i in skip:
            continue
        # Emit the pair for index i: type1 -> label 0, then type2 -> label 1.
        for label, stack in ((0, frbd1), (1, frbd2)):
            image = stack[i]
            immax = image.max()
            # Rescale so the brightest pixel is 255. Without the guard, an
            # all-zero image would become NaN through a 0/0 division.
            if immax != 0:
                images.append(image / (immax / 255))
            else:
                images.append(image / 1.0)
            labels.append(label)
    return labels, images


In [None]:
    
# Pool every SNR level into one dataset. read_data returns exactly (labels, images).
# Each entry maps an SNR label to the indices dropped at that SNR; insertion order
# fixes the concatenation order (100, 50, 35, 25, 15).
remove_indices_by_snr = {
    '100': remove_indices_100,
    '50': remove_indices_50,
    '35': remove_indices_35,
    '25': remove_indices_25,
    '15': remove_indices_15,
}

frbdl = []
frbdn = []
for snr, remove_indices in remove_indices_by_snr.items():
    snr_labels, snr_images = read_data(type1, type2, snr, remove_indices)
    frbdl.extend(snr_labels)
    frbdn.extend(snr_images)

# Stack the images and append a trailing channel axis so Conv2D sees (H, W, 1).
frbdn = np.asarray(frbdn)[..., np.newaxis]
frbdl = np.asarray(frbdl)

In [None]:

# Shape of one model input: (height, width, channels) of a single dynamic spectrum.
input_shape = tuple(frbdn.shape[1:3]) + (1,)
# Hold out 15% of the pooled samples as the final test set ("testing*" variables).
split = train_test_split(frbdl, frbdn, random_state=42, test_size=0.15)
trainAttrX, testingAttrX, trainImagesX, testingImagesX = split

In [None]:
# Carve a validation set (20% of the training portion) out of the training data.
# Naming caveat: `testAttrX`/`testImagesX` are the VALIDATION set used during tuning;
# the held-out test set is `testingAttrX`/`testingImagesX`.
split = train_test_split(trainAttrX, trainImagesX, random_state=42, test_size=0.2)
trainAttrX, testAttrX, trainImagesX, testImagesX = split

In [None]:
def model_builder(hp, input_shape=(256, 256, 1)):
    """
    CNN model builder for hyperparameter tuning with keras_tuner.

    The convolutional trunk is fixed. The tuned hyperparameters are:
    - the widths of the two dense layers (``units_1``, ``units_2``);
    - their dropout rates (``dropout1``, ``dropout2``);
    - the Adam learning rate (``learning_rate``).

    ``batch_size`` is also registered in the search space, but it is a
    training-time setting. It only takes effect if the tuner's hypermodel
    passes it to ``model.fit``.

    Parameters
    ----------
    hp : keras_tuner.HyperParameters
        Hyperparameter container supplied by the tuner.
    input_shape : tuple of int, optional
        (height, width, channels) of one input image. Defaults to (256, 256, 1).

    Returns
    -------
    keras.Sequential
        A compiled binary classifier with a single sigmoid output, trained with
        binary cross-entropy and reporting accuracy.
    """
    model = keras.Sequential()

    ######    CNN layers  #########
    # Fixed convolutional trunk; only the dense head and the optimiser are tuned.
    model.add(Conv2D(32, (3, 3), activation="relu", input_shape=input_shape))
    model.add(MaxPooling2D(pool_size=(2, 2)))

    model.add(Conv2D(32, (3, 3), activation="relu"))
    model.add(MaxPooling2D(pool_size=(2, 2)))

    model.add(Conv2D(32, (5, 5), activation="relu"))
    model.add(MaxPooling2D(pool_size=(3, 3)))

    model.add(Flatten())

    ###### fully connected layers #######
    model.add(Dense(units=hp.Choice("units_1", values=[16, 32, 64]), activation="relu"))
    model.add(Dropout(hp.Choice("dropout1", values=[0.15, 0.2, 0.25, 0.3])))

    model.add(Dense(units=hp.Choice("units_2", values=[4, 8, 16, 32]), activation="relu"))
    model.add(Dropout(hp.Choice("dropout2", values=[0.15, 0.2, 0.25, 0.3])))

    # A single sigmoid unit gives the predicted probability of label 1.
    model.add(Dense(1, activation="sigmoid"))

    ###### Optimizer hyperparameters #####
    hp_learning_rate = hp.Choice(
        "learning_rate",
        values=[2e-4, 5e-4, 1e-5, 2e-5, 5e-5],
    )

    # `batch_size` belongs to training, not to the model. It is registered here
    # so it appears in the search space. A plain build function cannot apply it;
    # the hypermodel's `fit` must read it and pass it to `model.fit`.
    hp.Choice("batch_size", values=[32, 64, 128])

    model.compile(
        optimizer=keras.optimizers.Adam(learning_rate=hp_learning_rate),
        loss="binary_crossentropy",
        metrics=["accuracy"],
    )

    return model


In [None]:
# Sanity check: build one model from the search space's default values and show its
# architecture. Building with `hp` registers the search space on this object.
hp = keras_tuner.HyperParameters()
model_builder(hp).summary()

In [None]:
# NOTE(review): this registers `batch_size` on the standalone `hp` object created above,
# which the tuner never receives. `batch_sizes` is not used elsewhere in this notebook.
# The tuner's `batch_size` hyperparameter comes from `model_builder` instead.
batch_sizes = hp.Choice( 'batch_size' , values=[ 32 , 64, 128 ] )

In [None]:
class BatchSizeHyperModel(keras_tuner.HyperModel):
    """Hypermodel that builds with `model_builder` and trains with the sampled batch size.

    `model_builder` only registers the `batch_size` hyperparameter. A plain build
    function cannot influence `fit`, so without this override every trial would
    train with Keras' default batch size regardless of the sampled value.
    """

    def build(self, hp):
        return model_builder(hp)

    def fit(self, hp, model, *args, **kwargs):
        # Declaring the same Choice again returns the value already sampled for this trial.
        kwargs["batch_size"] = hp.Choice("batch_size", values=[32, 64, 128])
        return model.fit(*args, **kwargs)


# Tuning results are stored under BASE_DIR so the location follows the configured data root.
TUNING_DIR = os.path.join(BASE_DIR, "ML_training", "tuning_model_B_C1")

tuner = keras_tuner.RandomSearch(
    hypermodel=BatchSizeHyperModel(),
    objective="val_loss",
    max_trials=100,
    executions_per_trial=1,
    overwrite=True,
    directory=TUNING_DIR,
    project_name="version3_equal_weights",
)

In [None]:
# Print the hyperparameters registered by the hypermodel and their allowed values.
tuner.search_space_summary()

In [None]:
# Stop a trial once the validation loss has not improved for 30 epochs. This is the
# same quantity the tuner optimises (objective="val_loss"), so stopping and model
# selection agree.
stop_early = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=30)

In [None]:
# Write TensorBoard logs under BASE_DIR rather than a hard-coded absolute path.
tensorboard = keras.callbacks.TensorBoard(
    os.path.join(
        BASE_DIR, "ML_training", "tuning_model_B_C1", "version3_equal_weights", "tb_logs"
    )
)

In [None]:
# Run the random search on the training split, validating on the validation split.
tuner.search(
    x=trainImagesX,
    y=trainAttrX,
    validation_data=(testImagesX, testAttrX),
    epochs=200,
    callbacks=[stop_early, tensorboard],
)

In [None]:
# Retrieve the top models found by the search (best first) and show the best
# architecture. Per the keras_tuner docs, the returned models are already built and
# loaded with their best checkpointed weights. No extra build() call is needed; a
# leftover build(input_shape=(None, 28, 28)) would not match the 256x256x1 inputs.
models = tuner.get_best_models(num_models=10)
best_model = models[0]
best_model.summary()

In [None]:
# Print a summary of the best trials: their hyperparameter values and objective scores.
tuner.results_summary()

In [None]:
# Save the best models found by the search (rank 0 = best) in TF SavedModel format.
# NOTE(review): adjust BEST_MODELS_DIR if the models should live elsewhere.
BEST_MODELS_DIR = os.path.join(
    BASE_DIR, "ML_training", "tuning_model_B_C1", "version3_equal_weights", "best_models"
)
os.makedirs(BEST_MODELS_DIR, exist_ok=True)

models = tuner.get_best_models(num_models=10)
best_model = models[0]
# Iterate over what was actually returned: fewer than 10 models come back when
# fewer trials completed, and a fixed range(10) would raise IndexError.
for rank, ranked_model in enumerate(models):
    ranked_model.save(os.path.join(BEST_MODELS_DIR, f"best_model_{rank}"), save_format='tf')

In [None]:
# Rebuild the architecture of the top-ranked hyperparameter set. The result is
# compiled but untrained, and is used only to inspect the architecture.
best_hps = tuner.get_best_hyperparameters(num_trials=5)
top_hps = best_hps[0]
model = model_builder(top_hps)
model.summary()

## Testing the best model on the held-out test set

In [None]:
# Recreate the same 85/15 split as before (same random_state). The validation split
# overwrote the training set, and this restores it; the held-out
# `testingAttrX`/`testingImagesX` are identical to the earlier ones.
split = train_test_split(frbdl, frbdn, random_state=42, test_size=0.15)
trainAttrX, testingAttrX, trainImagesX, testingImagesX = split

In [None]:
# Score the held-out test set, threshold the sigmoid outputs at 0.5, and tabulate
# the result against the true labels (rows = true class, columns = predicted class).
predictions = best_model.predict(testingImagesX)
predictions_bool = predictions > 0.5
cm_1 = confusion_matrix(testingAttrX, predictions_bool, labels=[0, 1])
cm_1

In [None]:
# Accuracy on the held-out test set (from Keras) and F1 score (from scikit-learn),
# both rounded to 4 decimals.
loss, acc = best_model.evaluate(testingImagesX, testingAttrX, verbose=2)
f1score = f1_score(testingAttrX, predictions_bool)
print(f"{round(acc, 4)} {round(f1score, 4)}")

In [None]:
# Plot the false-positive and false-negative rates of the best model against the
# decision threshold. roc_curve's first threshold is a sentinel above every score
# (nothing predicted positive), so it is dropped. The arrays are flipped so the
# thresholds ascend.
fpr, tpr, thresholds = roc_curve(testingAttrX, np.ravel(predictions))
thresholds_asc = np.flip(thresholds[1:])
fpr_pct = np.flip(fpr[1:]) * 100
fnr_pct = np.flip(1 - tpr[1:]) * 100

plt.rc('font', size=10)
fig, ax = plt.subplots()
ax.plot(thresholds_asc, fpr_pct, label='False Positive Rate')
ax.plot(thresholds_asc, fnr_pct, label='False Negative Rate')
ax.set_xlabel('Threshold', fontsize=10)
ax.set_ylabel('Cumulative percentage', fontsize=10)
ax.set_title('FPR FNR vs threshold', fontsize=10)
ax.set_yscale('log')  # zero-valued points are not drawn on a log axis
ax.grid(True)
ax.legend()
plt.show()