# Light Curve Parameter Inference Using LFI -- Tutorial

# Imports

## Python Packages

In [None]:
# general modules
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
import matplotlib.lines as mlines
import os, sys, time, glob
import json
import copy
import scipy
import warnings
from tqdm import tqdm

In [None]:
# pytorch
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, TensorDataset, random_split
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter

In [None]:
# bilby
import bilby
from bilby.core.prior import Uniform, DeltaFunction
from bilby.core.likelihood import GaussianLikelihood

In [None]:
# nflows
from nflows.nn.nets.resnet import ResidualNet
from nflows import transforms, distributions, flows
from nflows.distributions import StandardNormal
from nflows.flows import Flow
from nflows.transforms.autoregressive import MaskedAffineAutoregressiveTransform
from nflows.transforms import CompositeTransform, RandomPermutation
import nflows.utils as torchutils

In [None]:
# extras
from IPython.display import clear_output
from time import time
from time import sleep
import corner
import torchvision
import torchvision.transforms as transforms
from os.path import exists

## Personal Functions

In [None]:
# importing dataloading functions

# these functions are used to open the data and assign data id's
from model.data_processing import open_json, get_names, json_to_df, add_batch_sim_nums_all, get_test_names

# these functions are used to ensure the data are the same length (121 points)
from model.data_processing import pad_the_data, pad_all_dfs

# these functions are used to read in data when it is in csv format
from model.data_processing import load_in_data, match_fix_to_var, matched

# these functions are used to convert csv files to tensors and create a dataset
from model.data_processing import repeated_df_to_tensor, Paper_data

In [None]:
# importing similarity embedding functions

from model.embedding import VICRegLoss, ConvResidualBlock, ConvResidualNet, SimilarityEmbedding, train_one_epoch_se, val_one_epoch_se

# importing resnet from ML4GW pe

from model.resnet import ResNet

In [None]:
# importing normalizing flow functions

from model.normalizingflows import Flow_data, EmbeddingNet, normflow_params, train_one_epoch, val_one_epoch

In [None]:
# importing inference functions

from model.inference import cast_as_bilby_result, live_plot_samples, ppplot, comparison_plot

## Setting Device

In [None]:
# Pick the compute device once: the GPU when CUDA is available, otherwise the CPU.
# Later cells can place tensors and models on `device` so they share one device.

if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using {device}")

# Data Specific Parameters and Priors

Light curve generation is done through the nmma package on github: https://github.com/nuclear-multimessenger-astronomy/nmma/tree/main. To properly encode the effects of changing $t$ and $d_L$, each combination of our physical parameters $\log_{10}(M_{ej})$, $\log_{10}(V_{ej})$, and $\log_{10}(X_{lan})$ is repeated 50 times to produce light curves with unique noise instances. A second set of 50 light curves is created by adjusting $t$ and $d_L$ according to the priors shown below. Our training dataset contains 8,729 unique combinations of $\log_{10}(M_{ej})$, $\log_{10}(V_{ej})$, and $\log_{10}(X_{lan})$, resulting in $8,729 \times 50 \times 2 = 872,900$ total light curves.

In [None]:
# json specific parameters, adjust this cell with commands when generating light curves -- MANDATORY

# photometric bands read from every light-curve file
bands = ['ztfg', 'ztfr', 'ztfi']
# value used as the detection threshold: later cells count entries strictly below it
# as detections and pad missing epochs with it
detection_limit = 22.0
# light curves per parameter combination (reshape dimension in the tensor section)
num_repeats = 50
# channels per light curve; equals len(bands) here
num_channels = 3
# time samples per light curve (reshape dimension in the tensor section)
# NOTE(review): an import comment above mentions 121 points and one filter cell
# hard-codes 121 — confirm which value matches the CSVs in use.
num_points = 23
in_features = num_points
# NOTE: data_dir is reassigned in the "Load in Data" section
data_dir = '/nobackup/users/mmdesai/new_csv/'

In [None]:
# time shift

# reference epoch and time window of the simulated light curves
# (presumably in days, given the `days` variable below — TODO confirm the time system)
t_zero = 44242.00021937881
# NOTE(review): t_min and t_max are reassigned by the data-processing cells further
# down, so their values depend on which cells have been executed.
t_min = 44240.0012975024
t_max = 44269.99958898723
# width of the [t_min, t_max] window rounded to the nearest integer
days = int(round(t_max - t_min))
time_step = 0.25

In [None]:
# priors

# Uniform priors over the five inferred parameters.
# The LaTeX labels are raw strings: sequences such as "\l", "\D" and "\;" are not
# valid Python escapes, so non-raw literals trigger an invalid-escape warning on
# recent Python versions. The resulting string values are unchanged.
priors = dict()
priors['log10_mej'] = Uniform(-1.9, -1, name='log10_mej', latex_label=r'$\log(M_{{ej}})$')
priors['log10_vej'] = Uniform(-1.52, -0.53, name='log10_vej', latex_label=r'$\log(V_{{ej}})$')
priors['log10_Xlan'] = Uniform(-9, -3, name='log10_Xlan', latex_label=r'$\log(X_{{lan}})$')
priors['timeshift'] = Uniform(-2, 6, name='timeshift', latex_label=r'$\Delta\;t$')
# NOTE(review): this prior's name ('luminosity distance') differs from its dict key
# ('distance'), unlike the other entries — confirm that is intended.
priors['distance'] = Uniform(50, 200, name='luminosity distance', latex_label=r'$D$')

# Data Processing

This section processes data generated through NMMA. Use this section as a guide for converting the .json files into dataframes. The dataframes are subsequently stored as .csv files. If using the data from Zenodo, skip ahead to the Tensor Processing section. 

### Varied Data

In [None]:
# Input directory and unpacking settings for the varied light-curve JSON files.
# NOTE(review): the next cell reassigns dir_path to '.../data/varied/' (no 'lc_dir')
# — confirm which location is current.
dir_path       = '/home/oppenheimer/summer2025/Kilo/data/lc_dir/varied/'
detection_limit = 22.0
bands           = ['ztfg', 'ztfr', 'ztfi']

In [None]:
# your existing helper to open a single .json
def open_json(file_name, dir_path):
    """Load one JSON file and return its decoded content.

    Args:
        file_name: File name relative to ``dir_path``. It must not start with a
            path separator: ``os.path.join`` discards ``dir_path`` when the
            second component is absolute.
        dir_path: Directory that contains the file.

    Returns:
        The object produced by ``json.load``.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    # JSON text is UTF-8; do not depend on the platform's default encoding.
    with open(os.path.join(dir_path, file_name), encoding="utf-8") as f:
        return json.load(f)

# your existing json → DataFrame function
def json_to_df(file_name, dir_path, detection_limit, bands):
    """Flatten one light-curve JSON file into a DataFrame with one row per epoch.

    Every entry of ``bands`` must be a key of the file whose value is a list of
    3-element records. The first element is stored in column ``t``, the second
    in the band's own column and the third is discarded. ``t`` is rewritten for
    each band, so the last band's times are the ones kept.

    Args:
        file_name: JSON file name, relative to ``dir_path``.
        dir_path: Directory that contains the file.
        detection_limit: Values different from this number are counted as
            detections.
        bands: List of band names to unpack.

    Returns:
        DataFrame with columns ``['t'] + bands + ['num_detections']``, where
        ``num_detections`` holds the same total (summed over all bands) on every
        row.
    """
    data = open_json(file_name, dir_path)
    df = pd.DataFrame.from_dict(data, orient="columns")
    df_unpacked = pd.DataFrame(columns=['t'] + bands)
    counter = 0
    for band in bands:
        # unpack each band's [t, value, x] records into three columns
        df_unpacked[['t', band, 'x']] = pd.DataFrame(
            df[band].tolist(), index=df.index
        )
        counter += (df_unpacked[band] != detection_limit).sum()
    df_unpacked['num_detections'] = counter
    # 'x' only exists once at least one band has been unpacked
    return df_unpacked.drop(columns=['x'], errors='ignore')

# parameters
dir_path       = '/home/oppenheimer/summer2025/Kilo/data/varied/'
detection_limit = 22.0
bands           = ['ztfg', 'ztfr', 'ztfi']
# find every test_varied_*.json file in dir_path
file_pattern = os.path.join(dir_path, 'test_varied_*.json')
all_files    = sorted(glob.glob(file_pattern))  # paths include dir_path

if not all_files:
    # pd.concat on an empty list fails with an unhelpful "No objects to concatenate"
    raise FileNotFoundError(f"No light-curve files match {file_pattern!r}")

# process them all into a list of DataFrames (one per file)
df_list = [
    json_to_df(os.path.basename(fp), dir_path, detection_limit, bands)
    for fp in all_files
]

print(len(df_list), "files found.")
# one big DataFrame with the rows of every file stacked
df_all = pd.concat(df_list, ignore_index=True)

# df_all now contains the flattened photometry from all loaded files
print(df_all.shape)


In [None]:
# show one example frame
print(df_list[3])

# earliest and latest time found in any of the loaded frames
min_time = min(frame['t'].min() for frame in df_list)
max_time = max(frame['t'].max() for frame in df_list)

print(f"Minimum time: {min_time}, Maximum time: {max_time}")

In [None]:
# Inputs for the chunked conversion cell below, which reads dir_path from here.
dir_path       = '/home/oppenheimer/summer2025/Kilo/data/varied/'
detection_limit = 22.0
bands           = ['ztfg', 'ztfr', 'ztfi']

In [None]:
import os
import glob
import numpy as np
import pandas as pd
import gc

# — helper functions that must already be available in the kernel:
#    json_to_df(file_name, dir_path, detection_limit, bands)
#    pad_all_dfs(df_list, t_min, t_max, step, data_filler, bands)

# 0) USER PARAMETERS
# dir_path        = "/path/to/your/jsons"   (taken from the previous cell)
detection_limit = 22.0
bands           = ['ztfg', 'ztfr', 'ztfi']
chunk_size      = 5000        # how many sims to hold in memory at once
batch_size      = 50          # sims per batch for batch_id
output_csv      = "all_lightcurves.csv"

# 1) GATHER ALL JSON PATHS
# NOTE(review): sorted() orders names lexicographically, so "test_10" sorts before
# "test_2" unless the numeric suffix is zero-padded. sim_id below is the position
# in this list — confirm it matches the simulation_id of the injection file.
file_list = sorted(glob.glob(os.path.join(dir_path, "test_*.json")))
if not file_list:
    # Fail here: otherwise the grid below becomes +/-inf and step 3 would delete
    # a previously written CSV before the final read fails.
    raise FileNotFoundError(f"No files matching 'test_*.json' found in {dir_path!r}")

# 2) DETERMINE GLOBAL TIME GRID
# First pass over every file: only the time extremes are kept, not the frames.
raw_min, raw_max = np.inf, -np.inf
for fp in file_list:
    df = json_to_df(os.path.basename(fp), dir_path, detection_limit, bands)
    raw_min = min(raw_min, df['t'].min())
    raw_max = max(raw_max, df['t'].max())

# NOTE: this rebinds the notebook-level t_min / t_max.
step  = 1.0
t_min = np.floor(raw_min)
t_max = np.ceil(raw_max) + step

# 3) REMOVE OLD CSV
# The chunks below are appended, so start from a clean file.
if os.path.exists(output_csv):
    os.remove(output_csv)

# 4) PROCESS IN CHUNKS
for start in range(0, len(file_list), chunk_size):
    chunk_files = file_list[start : start + chunk_size]
    df_list     = []

    # 4a) LOAD & ANNOTATE
    # enumerate(..., start=start) keeps sim_id global across chunks
    for sim_idx, fp in enumerate(chunk_files, start=start):
        df = json_to_df(os.path.basename(fp), dir_path, detection_limit, bands)
        df['sim_id']         = sim_idx
        # replaces the count from json_to_df: entries strictly below the limit
        df['num_detections'] = (df[bands] < detection_limit).sum().sum()
        df_list.append(df)

    # 4b) PAD TO UNIFORM LENGTH (missing epochs are filled with detection_limit)
    padded = pad_all_dfs(df_list, t_min, t_max, step, detection_limit, bands)

    # 4c) ASSIGN batch_id: consecutive groups of batch_size sims share one id
    for idx, df in enumerate(padded, start=start):
        df['batch_id'] = idx // batch_size

    # 4d) CONCAT & APPEND TO CSV (header only when the file does not exist yet)
    chunk_df = pd.concat(padded, ignore_index=True)
    chunk_df.to_csv(
        output_csv,
        mode='a',
        index=False,
        header=not os.path.exists(output_csv)
    )

    # 4e) CLEAN UP so only one chunk is resident at a time
    del df_list, padded, chunk_df
    gc.collect()

# 5) (optional) READ BACK FULL DATAFRAME
df_all = pd.read_csv(output_csv)
print("Final table shape:", df_all.shape)
print("Time spans:", df_all['t'].min(), "→", df_all['t'].max())
print("Unique batch_ids:", sorted(df_all['batch_id'].unique()))


In [None]:
# Show every row of df_all that belongs to one simulation.
sim_id = 10  # change this to the sim_id you want to inspect
is_selected = df_all['sim_id'] == sim_id
df_sim = df_all[is_selected]
print(f"Data for sim_id {sim_id}:\n", df_sim)

In [None]:
# container for the varied light-curve chunks, filled by the next cell
varied_simembed_dict = dict()

In [None]:
# get the varied data

sim_path = '/nobackup/users/mmdesai/lowcsimdata'
num_sims = 25000

# number of light curves expected in each of the ten chunks
simembed_num_lc_list = [24750, 25000, 25000, 25000, 25000, 25000, 24850, 25000, 25000, 25000]

# NOTE(review): json_to_df is called here with two arguments and pad_all_dfs with
# one, while the json_to_df defined earlier in this notebook takes four positional
# arguments and other cells pass six to pad_all_dfs — confirm which versions of
# these helpers this cell is meant to run against.
for i in range(10):
    key = 'varied_simembed_data_{}'.format(i)
    # get the names of each file
    file_names = get_names(sim_path, 'varied', i, simembed_num_lc_list[i])
    # open the files as dataframes
    varied_simembed_dict[key] = json_to_df(file_names, simembed_num_lc_list[i])
    # pad the data
    varied_simembed_dict[key] = pad_all_dfs(varied_simembed_dict[key])

In [None]:
# display the last element of the first varied chunk
varied_simembed_dict['varied_simembed_data_0'][-1]

In [None]:
# plot a small sample of the varied light curves

# every 50th light curve among the first 300 of chunk 0, one colour per band
sample_chunk = varied_simembed_dict['varied_simembed_data_0']
for i in range(0, 300, 50):
    for band, colour in (('ztfg', 'g'), ('ztfr', 'r'), ('ztfi', 'c')):
        plt.scatter(sample_chunk[i]['t'], sample_chunk[i][band], color = colour)
# invert the y axis so smaller magnitude values are drawn higher
plt.gca().invert_yaxis()
plt.xlabel('Time (days)')
plt.ylabel('Magnitude')

In [None]:
inj_path_simembed = '/home/oppenheimer/summer2025/Kilo/data/varied/'

# read the injection file and keep the per-simulation parameter table
varied_params = open_json('injection_varied.json', inj_path_simembed)
inj_content = varied_params['injections']['content']

# notebook column name <- key in the injection file
varied_inj_df = pd.DataFrame()
for column, source_key in (
    ('mej', 'log10_mej'),
    ('vej', 'log10_vej'),
    ('xlan', 'log10_Xlan'),
    ('shift', 'timeshift'),
    ('distance', 'luminosity_distance'),
    ('sim_id', 'simulation_id'),
):
    varied_inj_df[column] = inj_content[source_key]
varied_injections = varied_inj_df

print(varied_injections)

In [None]:
# if your varied_injections DataFrame calls its sim key something else, rename it:
# varied_injections = varied_injections.rename(columns={ 'simulation_id': 'sim_id' })

# select only the columns you care about from injections
inj_cols = ['sim_id', 'mej','vej','xlan','shift','distance']

# Left-merge onto the full light-curve table so every photometry row is kept.
# validate='many_to_one' makes pandas raise if a sim_id appears more than once in
# the injection table; without it duplicates would silently multiply rows.
df_final = df_all.merge(
    varied_injections[inj_cols],
    on='sim_id',
    how='left',
    validate='many_to_one',
)

print("Final shape:", df_final.shape)  # expect len(df_all) rows and 5 extra columns


In [None]:
# Show the merged rows (photometry + injection parameters) of one simulation.
sim_id = 10  # change this to the sim_id you want to inspect
is_selected = df_final['sim_id'] == sim_id
df_sim = df_final[is_selected]
print(f"Data for sim_id {sim_id}:\n", df_sim)

In [None]:
# Alternative to a sim_id merge: attach injection parameters by list position.
# NOTE(review): `padded_list` is not assigned anywhere in the visible part of this
# notebook (the chunked cell uses `padded` and deletes it) — confirm where it is
# expected to come from; on a fresh kernel this cell raises NameError.
# NOTE(review): this cell rebinds df_final, replacing any earlier value.
merged_list = []
for i, df in enumerate(padded_list):
    inj = varied_injections.iloc[i]  # i-th injection row, matched by position, not by sim_id
    # copy the five injection parameters onto every row of this frame (mutates df in place)
    for col in ['mej','vej','xlan','shift','distance']:
        df[col] = inj[col]
    merged_list.append(df)

# finally, glue them all together
df_final = pd.concat(merged_list, ignore_index=True)
print(df_final)

In [None]:
# show the rows of df_final that belong to sim_id 24999
is_last_sim = df_final['sim_id'] == 24999
print(df_final[is_last_sim])

In [None]:
# injection files for the additional data

# Directory ends with a separator and the file name has no leading '/'.
# A file name starting with '/' is an absolute path, and os.path.join (used by the
# open_json defined in this notebook) then drops the directory entirely.
inj_path_simembed = '/nobackup/users/mmdesai/final_injections/'
# NOTE(review): this rebinds varied_injections from a DataFrame to a dict of
# DataFrames; cells that index it with column lists must run before this one.
varied_injections = {}

for i in range(0, 10):
    varied_params = open_json('injection_simembed_varied_{}.json'.format(i), inj_path_simembed)
    content = varied_params['injections']['content']
    # notebook column name <- key in the injection file
    varied_inj_df = pd.DataFrame()
    varied_inj_df['mej'] = content['log10_mej']
    varied_inj_df['vej'] = content['log10_vej']
    varied_inj_df['xlan'] = content['log10_Xlan']
    varied_inj_df['shift'] = content['timeshift']
    varied_inj_df['distance'] = content['luminosity_distance']
    varied_inj_df['sim_id'] = content['simulation_id']
    varied_injections['varied_inj_df{}'.format(i)] = varied_inj_df

In [None]:
# display the injection table of the first varied chunk
varied_injections['varied_inj_df0']

In [None]:
# concatenate dataframe lists

# one stacked DataFrame per chunk, in chunk order
all_varied_data_list = [
    pd.concat(varied_simembed_dict['varied_simembed_data_{}'.format(i)])
    for i in range(0, 10)
]

In [None]:
# merge with injection parameters

all_varied_datawparams_list = [0] * 10

for i in range(0, 10):
    injections_i = varied_injections['varied_inj_df{}'.format(i)]
    csv_path = '/nobackup/users/mmdesai/final_csv/varied_lowc_{}.csv'.format(i)
    # default (inner) merge on sim_id: rows without a matching injection are dropped
    all_varied_datawparams_list[i] = all_varied_data_list[i].merge(injections_i, on = 'sim_id')
    # save as csv file
    all_varied_datawparams_list[i].to_csv(csv_path, index = False)

In [None]:
# display the first merged varied chunk (photometry + injection parameters)
all_varied_datawparams_list[0]

### Fixed Data

In [None]:
# Input directory and unpacking settings for the fixed light-curve JSON files.
# NOTE: this rebinds dir_path, which the cells below read.
dir_path= '/home/oppenheimer/summer2025/Kilo/data/fixed/'
detection_limit = 22.0
bands= ['ztfg', 'ztfr', 'ztfi']

In [None]:
import numpy as np
import os, glob

# --- 1) load the fixed light curves, one DataFrame per file ---
# NOTE(review): sorted() is lexicographic ("test_10" before "test_2" unless the
# suffix is zero-padded); sim_id below is the list position — confirm it matches
# the simulation_id of the injection file.
file_list = sorted(glob.glob(os.path.join(dir_path, 'test_*.json')))

df_fixed_list = []
for sim_idx, fp in enumerate(file_list):
    df = json_to_df(os.path.basename(fp), dir_path, detection_limit, bands)

    # add sim_id
    df['sim_id'] = sim_idx

    # count detections across all bands: entries strictly below detection_limit
    # (this replaces the count written by json_to_df)
    detections = (df[bands] < detection_limit).sum().sum()
    df['num_detections'] = detections

    df_fixed_list.append(df)

# --- 2) time grid ---
# raw_min / raw_max are NOT computed in this cell: they must already exist in the
# kernel from an earlier cell.
# NOTE(review): presumably the varied-data extremes are reused on purpose so both
# datasets share one grid — confirm.
step  = 1.0
t_min = np.floor(raw_min)            # round down
t_max = np.ceil(raw_max) + step      # round up, then add one step

# NOTE(review): missing epochs are filled with NaN here, whereas the chunked cells
# fill with detection_limit — confirm which filler is intended.
padded_fixed_list = pad_all_dfs(
    df_fixed_list,
    t_min,
    t_max,
    step,
    data_filler=np.nan,
    bands=bands
)
batch_size = 50  # adjust as needed

for sim_idx, df in enumerate(padded_fixed_list):
    # integer division gives you 0 for sims 0–49, 1 for sims 50–99, etc.
    df['batch_id'] = sim_idx // batch_size

df_fixed_all = pd.concat(padded_fixed_list, ignore_index=True)
# sanity check on the frame built in this cell (previously printed df_all, the varied table)
print("time runs from", df_fixed_all['t'].min(), "to", df_fixed_all['t'].max())
print("should equal", t_min, "→", t_max-step)
print("total rows:", len(df_fixed_all))


In [None]:
import os
import glob
import numpy as np
import pandas as pd
import gc

# — helper functions that must already be available in the kernel:
#    json_to_df(file_name, dir_path, detection_limit, bands)
#    pad_all_dfs(df_list, t_min, t_max, step, data_filler, bands)

# 0) USER PARAMETERS
# dir_path        = "/path/to/your/jsons"   (taken from the fixed-data path cell)
detection_limit = 22.0
bands           = ['ztfg', 'ztfr', 'ztfi']
chunk_size      = 5000        # how many sims to hold in memory at once
batch_size      = 50          # sims per batch for batch_id
# Dedicated file for the fixed data: writing to "all_lightcurves.csv" deleted and
# overwrote the CSV produced by the varied-data cell.
output_csv      = "all_lightcurves_fixed.csv"

# 1) GATHER ALL JSON PATHS
# NOTE(review): sorted() is lexicographic ("test_10" before "test_2" unless the
# suffix is zero-padded); sim_id below is the list position — confirm it matches
# the simulation_id of the injection file.
file_list = sorted(glob.glob(os.path.join(dir_path, "test_*.json")))
if not file_list:
    # Fail before step 3 deletes a previously written CSV.
    raise FileNotFoundError(f"No files matching 'test_*.json' found in {dir_path!r}")

# 2) RAW TIME SPAN OF THE FIXED DATA
# NOTE: this rebinds the notebook-level raw_min / raw_max.
raw_min, raw_max = np.inf, -np.inf
for fp in file_list:
    df = json_to_df(os.path.basename(fp), dir_path, detection_limit, bands)
    raw_min = min(raw_min, df['t'].min())
    raw_max = max(raw_max, df['t'].max())

step  = 1.0
# t_min / t_max are not recomputed: the grid left in the kernel by an earlier cell
# is reused for padding below.
# NOTE(review): presumably so fixed and varied curves share one grid — confirm.
# t_min = np.floor(raw_min)
# t_max = np.ceil(raw_max) + step

# 3) REMOVE OLD CSV
# The chunks below are appended, so start from a clean file.
if os.path.exists(output_csv):
    os.remove(output_csv)

# 4) PROCESS IN CHUNKS
for start in range(0, len(file_list), chunk_size):
    chunk_files = file_list[start : start + chunk_size]
    df_list     = []

    # 4a) LOAD & ANNOTATE
    # enumerate(..., start=start) keeps sim_id global across chunks
    for sim_idx, fp in enumerate(chunk_files, start=start):
        df = json_to_df(os.path.basename(fp), dir_path, detection_limit, bands)
        df['sim_id']         = sim_idx
        # replaces the count from json_to_df: entries strictly below the limit
        df['num_detections'] = (df[bands] < detection_limit).sum().sum()
        df_list.append(df)

    # 4b) PAD TO UNIFORM LENGTH (missing epochs are filled with detection_limit)
    padded = pad_all_dfs(df_list, t_min, t_max, step, detection_limit, bands)

    # 4c) ASSIGN batch_id: consecutive groups of batch_size sims share one id
    for idx, df in enumerate(padded, start=start):
        df['batch_id'] = idx // batch_size

    # 4d) CONCAT & APPEND TO CSV (header only when the file does not exist yet)
    chunk_df = pd.concat(padded, ignore_index=True)
    chunk_df.to_csv(
        output_csv,
        mode='a',
        index=False,
        header=not os.path.exists(output_csv)
    )

    # 4e) CLEAN UP so only one chunk is resident at a time
    del df_list, padded, chunk_df
    gc.collect()

# 5) (optional) READ BACK FULL DATAFRAME
df_fixed_all = pd.read_csv(output_csv)
print("Final table shape:", df_fixed_all.shape)
print("Time spans:", df_fixed_all['t'].min(), "→", df_fixed_all['t'].max())
print("Unique batch_ids:", sorted(df_fixed_all['batch_id'].unique()))


In [None]:
# Show every row of df_fixed_all that belongs to one simulation.
sim_id = 10  # change this to the sim_id you want to inspect
is_selected = df_fixed_all['sim_id'] == sim_id
df_sim = df_fixed_all[is_selected]
print(f"Data for sim_id {sim_id}:\n", df_sim)

In [None]:
inj_path_simembed = '/home/oppenheimer/summer2025/Kilo/data/fixed/'

# read the injection file and keep the per-simulation parameter table
fixed_params = open_json('injection_fixed.json', inj_path_simembed)
inj_content = fixed_params['injections']['content']

# notebook column name <- key in the injection file
fixed_inj_df = pd.DataFrame()
for column, source_key in (
    ('mej', 'log10_mej'),
    ('vej', 'log10_vej'),
    ('xlan', 'log10_Xlan'),
    ('shift', 'timeshift'),
    ('distance', 'luminosity_distance'),
    ('sim_id', 'simulation_id'),
):
    fixed_inj_df[column] = inj_content[source_key]
fixed_injections = fixed_inj_df

print(fixed_injections)

In [None]:
# if your fixed_injections DataFrame calls its sim key something else, rename it:
# fixed_injections = fixed_injections.rename(columns={ 'simulation_id': 'sim_id' })

# select only the columns you care about from injections
inj_cols = ['sim_id', 'mej','vej','xlan','shift','distance']

# Left-merge onto the full fixed light-curve table so every photometry row is kept.
# validate='many_to_one' makes pandas raise if a sim_id appears more than once in
# the injection table; without it duplicates would silently multiply rows.
df_fixed_final = df_fixed_all.merge(
    fixed_injections[inj_cols],
    on='sim_id',
    how='left',
    validate='many_to_one',
)

print("Final shape:", df_fixed_final.shape)  # expect len(df_fixed_all) rows and 5 extra columns


In [None]:
# Alternative to a sim_id merge: attach injection parameters by list position.
# NOTE(review): this cell rebinds df_fixed_final, replacing any earlier value, and
# pairs frames with injection rows by order only — confirm the two orders agree.
merged_fixed_list = []
for i, df in enumerate(padded_fixed_list):
    inj = fixed_injections.iloc[i]  # i-th injection row, matched by position, not by sim_id
    # copy the five injection parameters onto every row of this frame (mutates df in place)
    for col in ['mej','vej','xlan','shift','distance']:
        df[col] = inj[col]
    merged_fixed_list.append(df)

# finally, glue them all together
df_fixed_final = pd.concat(merged_fixed_list, ignore_index=True)
print(df_fixed_final)

In [None]:
# show the rows of df_fixed_final that belong to sim_id 0
is_first_sim = df_fixed_final['sim_id'] == 0
print(df_fixed_final[is_first_sim])

In [None]:
# container for the fixed light-curve chunks, filled by the next cell
fixed_simembed_dict = dict()

In [None]:
# get the fixed data

sim_path = '/nobackup/users/mmdesai/lowcsimdata'
num_sims = 25000

# number of light curves expected in each of the ten chunks
simembed_num_lc_list = [24900, 25000, 25000, 25000, 25000, 25000, 25000, 24800, 25000, 25000]

# NOTE(review): json_to_df is called here with two arguments and pad_all_dfs with
# one, while the json_to_df defined earlier in this notebook takes four positional
# arguments and other cells pass six to pad_all_dfs — confirm which versions of
# these helpers this cell is meant to run against.
for i in range(10):
    key = 'fixed_simembed_data_{}'.format(i)
    # get the names of each file
    file_names = get_names(sim_path, 'fixed', i, simembed_num_lc_list[i])
    # open the files as dataframes
    fixed_simembed_dict[key] = json_to_df(file_names, simembed_num_lc_list[i])
    # pad the data
    fixed_simembed_dict[key] = pad_all_dfs(fixed_simembed_dict[key])

In [None]:
# display the last element of the first fixed chunk
fixed_simembed_dict['fixed_simembed_data_0'][-1]

In [None]:
# plot a small sample of the fixed light curves:
# every 50th light curve among the first 300 of chunk 0, one colour per band
sample_chunk = fixed_simembed_dict['fixed_simembed_data_0']
for i in range(0, 300, 50):
    for band, colour in (('ztfg', 'g'), ('ztfr', 'r'), ('ztfi', 'c')):
        plt.scatter(sample_chunk[i]['t'], sample_chunk[i][band], color = colour)
# invert the y axis so smaller magnitude values are drawn higher
plt.gca().invert_yaxis()
plt.xlabel('Time (days)')
plt.ylabel('Magnitude')

In [None]:
#plot lightcurves from padded_fixed_list
import matplotlib.pyplot as plt

first_n = 3
# Dedicated name: assigning to `bands` here would overwrite the three-band list
# that the data-processing cells read.
plot_bands = ['ztfg']
plt.figure(figsize=(10, 6))

for i in range(first_n):
    df = padded_fixed_list[i]
    for band in plot_bands:
        plt.scatter(df['t'], df[band], label=f"{band} (sim {i})")

# the plotted values are magnitudes, so the axis is inverted and labelled accordingly
plt.gca().invert_yaxis()
plt.xlabel("Time")
plt.ylabel("Magnitude")
plt.title(f"First {first_n} Fixed Light Curves Overlapped ({', '.join(plot_bands)})")
plt.legend(loc="upper right", ncol=first_n)  # adjust layout if it's crowded
plt.tight_layout()
plt.show()

In [None]:
#plot lightcurves from padded_fixed_list
import matplotlib.pyplot as plt

first_n = 3
# Dedicated name: assigning to `bands` here would overwrite the three-band list
# that the data-processing cells read.
plot_bands = ['ztfr']
plt.figure(figsize=(10, 6))

for i in range(first_n):
    df = padded_fixed_list[i]
    for band in plot_bands:
        plt.scatter(df['t'], df[band], label=f"{band} (sim {i})")

# the plotted values are magnitudes, so the axis is inverted and labelled accordingly
plt.gca().invert_yaxis()
plt.xlabel("Time")
plt.ylabel("Magnitude")
plt.title(f"First {first_n} Fixed Light Curves Overlapped ({', '.join(plot_bands)})")
plt.legend(loc="upper right", ncol=first_n)  # adjust layout if it's crowded
plt.tight_layout()
plt.show()

In [None]:
#plot lightcurves from padded_fixed_list
import matplotlib.pyplot as plt

first_n = 3
# Dedicated name: assigning to `bands` here would overwrite the three-band list
# that the data-processing cells read.
plot_bands = ['ztfi']
plt.figure(figsize=(10, 6))

for i in range(first_n):
    df = padded_fixed_list[i]
    for band in plot_bands:
        plt.scatter(df['t'], df[band], label=f"{band} (sim {i})")

# the plotted values are magnitudes, so the axis is inverted and labelled accordingly
plt.gca().invert_yaxis()
plt.xlabel("Time")
plt.ylabel("Magnitude")
plt.title(f"First {first_n} Fixed Light Curves Overlapped ({', '.join(plot_bands)})")
plt.legend(loc="upper right", ncol=first_n)  # adjust layout if it's crowded
plt.tight_layout()
plt.show()

In [None]:
# injection files

# Directory ends with a separator and the file name has no leading '/'.
# A file name starting with '/' is an absolute path, and os.path.join (used by the
# open_json defined in this notebook) then drops the directory entirely.
inj_path_simembed = '/nobackup/users/mmdesai/final_injections/'
# NOTE(review): this rebinds fixed_injections from a DataFrame to a dict of
# DataFrames; cells that index it with column lists must run before this one.
fixed_injections = {}

for i in range(0, 10):
    fixed_params = open_json('injection_simembed_fixed_{}.json'.format(i), inj_path_simembed)
    content = fixed_params['injections']['content']
    # notebook column name <- key in the injection file
    fixed_inj_df = pd.DataFrame()
    fixed_inj_df['mej'] = content['log10_mej']
    fixed_inj_df['vej'] = content['log10_vej']
    fixed_inj_df['xlan'] = content['log10_Xlan']
    fixed_inj_df['shift'] = content['timeshift']
    fixed_inj_df['distance'] = content['luminosity_distance']
    fixed_inj_df['sim_id'] = content['simulation_id']
    fixed_injections['fixed_inj_df{}'.format(i)] = fixed_inj_df

In [None]:
# display the injection table of the first fixed chunk
fixed_injections['fixed_inj_df0']

In [None]:
# concatenate dataframe lists

# one stacked DataFrame per chunk, in chunk order
all_fixed_data_list = [
    pd.concat(fixed_simembed_dict['fixed_simembed_data_{}'.format(i)])
    for i in range(0, 10)
]

In [None]:
# merge with injection parameters

all_fixed_datawparams_list = [0] * 10

for i in range(0, 10):
    injections_i = fixed_injections['fixed_inj_df{}'.format(i)]
    csv_path = '/nobackup/users/mmdesai/final_csv/fixed_lowc_{}.csv'.format(i)
    # default (inner) merge on sim_id: rows without a matching injection are dropped
    all_fixed_datawparams_list[i] = all_fixed_data_list[i].merge(injections_i, on = 'sim_id')
    # save as csv file
    all_fixed_datawparams_list[i].to_csv(csv_path, index = False)

In [None]:
# display the first merged fixed chunk (photometry + injection parameters)
all_fixed_datawparams_list[0]

# Load in Data

If the data is stored as a .csv file, use this section to further process the data and import it to the notebook. This section ensures that the fixed (unshifted) and shifted light curves are properly paired and assigns each light curve a unique simulation id (sim_id). Batch numbers are also added, with 50 light curves in a batch. Each set of 50 light curves has the same mass, velocity, and lanthanide fraction of the ejecta. The fixed ones peak at the same time and are set to a luminosity distance of 50 Mpc, while the 50 repeated shifted light curves have a time and distance generated from a uniform prior. We also set the condition for number of detections in this section.

In [None]:
# directory where the csv files are stored
# (rebinds data_dir from the parameter cell; this is the directory the
# varied_lowc_*/fixed_lowc_* csv files are written to above)

data_dir = '/nobackup/users/mmdesai/final_csv/'

In [None]:
# set the minimum number of detections needed
# (the filter cells below keep a batch only when the smallest num_detections_x in
# that batch is at least this value)

min_num_detections = 8

## First Batch

In [None]:
# pair the 'varied' and 'fixed' csv data found in data_dir
# NOTE(review): the last two arguments (0, 10) are presumably a file-index range —
# confirm against model.data_processing.matched.
matched_df1 = matched(data_dir, 'varied', 'fixed', 0, 10) 

In [None]:
# Called for its effect on matched_df1 (the return value is not used).
# Presumably adds/refreshes the batch_id and sim_id columns that the next cell
# filters on — confirm against model.data_processing.add_batch_sim_nums_all.
add_batch_sim_nums_all(matched_df1)
matched_df1

In [None]:
# Keep only the batches in which every light curve has enough detections.
# A batch holds num_repeats light curves of num_points rows each, so the number of
# batches is the row count divided by (num_points * num_repeats).
num_batches1 = len(matched_df1) // (num_points * num_repeats)

true_list1 = []
for i in range(num_batches1):
    batch_df = matched_df1.loc[matched_df1['batch_id'] == i]
    # the batch minimum must reach the threshold (an empty batch gives NaN and is skipped)
    if batch_df['num_detections_x'].min() >= min_num_detections:
        true_list1.append(batch_df)

In [None]:
# stack the batches that passed the detection cut
# (pd.concat raises ValueError if true_list1 is empty, i.e. no batch passed)
detected_df1 = pd.concat(true_list1)
detected_df1

In [None]:
# varied

# First 12 columns hold the varied half of the matched table; copy so the frame
# can be relabelled and modified without touching detected_df1.
var_df = detected_df1.iloc[:, :12].copy()
# Remove the merge suffix only when it is a real trailing "_x".
# (str.rstrip('_x') strips any run of the characters '_' and 'x', not the suffix.)
var_df.columns = var_df.columns.str.replace(r'_x$', '', regex=True)
var_df = var_df.drop(columns=['key_1'])
add_batch_sim_nums_all(var_df)
var_df

In [None]:
# fixed

# Columns from position 12 onward hold the fixed half of the matched table; copy so
# the frame can be relabelled and modified without touching detected_df1.
fix_df = detected_df1.iloc[:, 12:].copy()
# Remove the merge suffix only when it is a real trailing "_y".
# (str.rstrip('_y') strips any run of the characters '_' and 'y', not the suffix.)
fix_df.columns = fix_df.columns.str.replace(r'_y$', '', regex=True)
add_batch_sim_nums_all(fix_df)
fix_df

## Second Batch

In [None]:
# pair the 'varied' and 'fixed' csv data found in data_dir
# NOTE(review): the last two arguments (10, 20) are presumably a file-index range —
# confirm against model.data_processing.matched.
matched_df2 = matched(data_dir, 'varied', 'fixed', 10, 20) 

In [None]:
# Called for its effect on matched_df2 (the return value is not used).
add_batch_sim_nums_all(matched_df2)
# display the frame that was just processed (`new_df2` was never defined)
matched_df2

In [None]:
# Keep only the batches in which every light curve has enough detections.
# Uses the notebook-wide num_points / num_repeats, like the first-batch filter,
# instead of the hard-coded 121 and 50.
num_batches2 = len(matched_df2) // (num_points * num_repeats)

true_list2 = []
for i in range(num_batches2):
    batch_df = matched_df2.loc[matched_df2['batch_id'] == i]
    # the batch minimum must reach the threshold (an empty batch gives NaN and is skipped)
    if batch_df['num_detections_x'].min() >= min_num_detections:
        true_list2.append(batch_df)

In [None]:
# stack the batches that passed the detection cut
# (pd.concat raises ValueError if true_list2 is empty, i.e. no batch passed)
detected_df2 = pd.concat(true_list2)
detected_df2

In [None]:
# varied

# NOTE(review): this rebinds var_df, replacing the first-batch frame — confirm the
# first batch has been consumed before this cell runs.
# First 12 columns hold the varied half of the matched table; copy so the frame
# can be relabelled and modified without touching detected_df2.
var_df = detected_df2.iloc[:, :12].copy()
# Remove the merge suffix only when it is a real trailing "_x".
# (str.rstrip('_x') strips any run of the characters '_' and 'x', not the suffix.)
var_df.columns = var_df.columns.str.replace(r'_x$', '', regex=True)
var_df = var_df.drop(columns=['key_1'])
add_batch_sim_nums_all(var_df)
var_df

In [None]:
# fixed

# NOTE(review): this rebinds fix_df, replacing the first-batch frame — confirm the
# first batch has been consumed before this cell runs.
# Columns from position 12 onward hold the fixed half of the matched table; copy so
# the frame can be relabelled and modified without touching detected_df2.
fix_df = detected_df2.iloc[:, 12:].copy()
# Remove the merge suffix only when it is a real trailing "_y".
# (str.rstrip('_y') strips any run of the characters '_' and 'y', not the suffix.)
fix_df.columns = fix_df.columns.str.replace(r'_y$', '', regex=True)
add_batch_sim_nums_all(fix_df)
fix_df

In [None]:
# Column layout relied on by the tensor section, which slices columns by position
# (1:4 for the photometry, 5:10 for the injection parameters).
desired_order = [
    't', 'ztfg', 'ztfr', 'ztfi',
    'num_detections',
    'mej', 'vej', 'xlan', 'shift', 'distance',
    'batch_id', 'sim_id',
]
df_varied = df_final.loc[:, desired_order]
df_fixed = df_fixed_final.loc[:, desired_order]

In [None]:
# df_varied
# Show the rows of df_varied that belong to one simulation.
sim_id = 0  # change this to the sim_id you want to inspect
is_selected = df_varied['sim_id'] == sim_id
df_sim = df_varied[is_selected]
print(f"Data for sim_id {sim_id}:\n", df_sim)

In [None]:
# df_fixed
# print df_fixed with a particular sim_id
# (this cell previously filtered df_varied, duplicating the cell above)
sim_id = 0  # change this to the sim_id you want to inspect
df_sim = df_fixed[df_fixed['sim_id'] == sim_id]
print(f"Data for sim_id {sim_id}:\n", df_sim)

# Tensor Processing

In [None]:
def repeated_df_to_tensor(df_varied, df_fixed, batches, repeats=None, points=None, channels=None):
    '''
    Converts dataframes into pytorch tensors
    Inputs:
        df_varied: dataframe containing the shifted light curve information
        df_fixed: dataframe containing the analagous fixed light curve information
        batches: number of unique mass, velocity, and lanthanide injections
        repeats: light curves per batch; defaults to the notebook global num_repeats
        points: rows per light curve; defaults to the notebook global num_points
        channels: photometry channels; defaults to the notebook global num_channels
                  (the photometry is read from columns 1:4, so this has to be 3)
    Outputs:
        data_shifted_list: list of tensors of shape [repeats, channels, num_points] containing the shifted light curve photometry
        data_unshifted_list: list of tensors of shape [repeats, channels, num_points] containing the fixed light curve photometry
        param_shifted_list: list of tensors of shape [repeats, 1, 5] containing the injection parameters of the shifted light curves
        param_unshifted_list: list of tensors of shape [repeats, 1, 5] containing the injection parameters of the fixed light curves
    Both dataframes are read positionally: columns 1:4 hold the photometry and columns
    5:10 the injection parameters, and rows are matched on the 'batch_id' column.
    '''
    # Fall back to the notebook-level settings so existing three-argument calls keep working.
    repeats = num_repeats if repeats is None else repeats
    points = num_points if points is None else points
    channels = num_channels if channels is None else channels

    def _batch_tensors(df, batch_id):
        # Select the batch once and derive both tensors from the same rows.
        batch_rows = df.loc[df['batch_id'] == batch_id]
        photometry = torch.tensor(batch_rows.iloc[:, 1:4].values.reshape(repeats, points, channels),
                                  dtype=torch.float32).transpose(1, 2)
        # one parameter row per light curve: take every `points`-th row
        params = torch.tensor(batch_rows.iloc[::points, 5:10].values,
                              dtype=torch.float32).unsqueeze(2).transpose(1, 2)
        return photometry, params

    data_shifted_list = []
    data_unshifted_list = []
    param_shifted_list = []
    param_unshifted_list = []
    for idx in tqdm(range(0, batches)):
        data_shifted, param_shifted = _batch_tensors(df_varied, idx)
        data_unshifted, param_unshifted = _batch_tensors(df_fixed, idx)
        data_shifted_list.append(data_shifted)
        data_unshifted_list.append(data_unshifted)
        param_shifted_list.append(param_shifted)
        param_unshifted_list.append(param_unshifted)
    return data_shifted_list, data_unshifted_list, param_shifted_list, param_unshifted_list

In [None]:
# moving the data to tensors on gpu -- ONLY RUN IF YOU ARE USING DATA FROM A CSV FILE
# num_repeats, num_channels and num_points are notebook globals that repeated_df_to_tensor
# reads when it reshapes each batch to (num_repeats, num_points, num_channels).
num_repeats = 50
num_channels = 3
# NOTE(review): other cells in this notebook describe light curves as 121 points long;
# confirm that 33 is the right length for the CSV sample loaded here.
num_points = 33
# one batch per distinct batch_id value
num_batches_paper_sample = len(df_varied['batch_id'].unique())
print(f"Number of batches in the paper sample: {num_batches_paper_sample}")
data_shifted_paper, data_unshifted_paper, param_shifted_paper, param_unshifted_paper = repeated_df_to_tensor(
    df_varied, df_fixed, num_batches_paper_sample)



In [None]:
print(param_shifted_paper[1])

In [None]:
# call the path to the tensors from Zenodo and load the data in

data_shifted_paper1 = torch.load('/nobackup/users/mmdesai/updated_tensors/data_shifted_paper4.pt')
data_unshifted_paper1 = torch.load('/nobackup/users/mmdesai/updated_tensors/data_unshifted_paper4.pt')
param_shifted_paper1 = torch.load('/nobackup/users/mmdesai/updated_tensors/param_shifted_paper4.pt')
param_unshifted_paper1 = torch.load('/nobackup/users/mmdesai/updated_tensors/param_unshifted_paper4.pt')

In [None]:
data_shifted_paper2 = torch.load('/nobackup/users/mmdesai/updated_tensors/data_shifted_paper5.pt')
data_unshifted_paper2 = torch.load('/nobackup/users/mmdesai/updated_tensors/data_unshifted_paper5.pt')
param_shifted_paper2 = torch.load('/nobackup/users/mmdesai/updated_tensors/param_shifted_paper5.pt')
param_unshifted_paper2 = torch.load('/nobackup/users/mmdesai/updated_tensors/param_unshifted_paper5.pt')

In [None]:
data_shifted_paper = torch.stack(data_shifted_paper1 + data_shifted_paper2)
data_unshifted_paper = torch.stack(data_unshifted_paper1 + data_unshifted_paper2)
param_shifted_paper = torch.stack(param_shifted_paper1 + param_shifted_paper2)
param_unshifted_paper = torch.stack(param_unshifted_paper1 + param_unshifted_paper2)

In [None]:
# the data is organized into number of repeats (50) x number of channels (3 - ztfg, ztfr, ztfi) x number of points (121)

data_shifted_paper[0].shape

In [None]:
# the parameters stored are in the order: mass, velocity, lanthanide fraction, time, and distance

param_shifted_paper[0].shape

In [None]:
# total number of batches, each batch contains 50 light curves

num_batches_paper_sample = len(data_shifted_paper)
print(num_batches_paper_sample)

# Similarity Embedding Dataset

In [None]:
# define the loss
vicreg_loss = VICRegLoss()

In [None]:
# define the dataset

dataset_paper = Paper_data(data_shifted_paper, data_unshifted_paper, param_shifted_paper, param_unshifted_paper, num_batches_paper_sample)

# check the dataset shape
# Explicit names instead of `_`: the old cell unpacked into `_` twice (so `_.shape` was the
# shape of the LAST element) and shadowed IPython's "previous output" variable.
first_item, t, d, last_item = dataset_paper[4]
last_item.shape, t.shape, d.shape

In [None]:
# split dataset into training, testing, and validation

num_batches_paper_sample = len(data_shifted_paper)

train_set_size_paper = int(0.8 * num_batches_paper_sample)
val_set_size_paper = int(0.1 * num_batches_paper_sample)
# the test set takes whatever is left so the three sizes always sum to the dataset size
test_set_size_paper = num_batches_paper_sample - train_set_size_paper - val_set_size_paper

print(f"Train set size: {train_set_size_paper}, Validation set size: {val_set_size_paper}, Test set size: {test_set_size_paper}")

# Seed the split so the same batches land in train/val/test on every run. Without this,
# reloading saved weights in a new session evaluates them on a different "test" set that
# can contain batches the weights were trained on.
split_generator_paper = torch.Generator().manual_seed(42)
train_data_paper, val_data_paper, test_data_paper = torch.utils.data.random_split(
    dataset_paper, [train_set_size_paper, val_set_size_paper, test_set_size_paper],
    generator=split_generator_paper)

In [None]:
# load and shuffle the data

train_data_loader_paper = DataLoader(train_data_paper, batch_size=50, shuffle=True)
val_data_loader_paper = DataLoader(val_data_paper, batch_size=50, shuffle=True)
test_data_loader_paper = DataLoader(test_data_paper, batch_size=1, shuffle=False)

# check lengths
len(train_data_loader_paper), len(test_data_loader_paper), len(val_data_loader_paper)

# Data Visualization

Taking a look at some of the data distributions

## Histograms

Checking whether the selected data are still uniformly distributed after keeping only light curves with more than 8 detections

In [None]:
# Collect mass, velocity and lanthanide fraction from the first light curve of every batch
# (presumably these three are shared by all light curves in a batch -- TODO confirm).
mej_list, vej_list, xlan_list = [], [], []

for batch_params in param_shifted_paper:
    first_row = batch_params[0][0]
    mej_list.append(first_row[0])
    vej_list.append(first_row[1])
    xlan_list.append(first_row[2])

In [None]:
hist = plt.hist(mej_list, bins=25)

In [None]:
# mej_list (with vej_list and xlan_list) is already built by the cell above this histogram
# section; this cell used to rebuild all three with an identical loop. Only the plot with
# matplotlib's default binning is kept.
hist = plt.hist(mej_list)

In [None]:
hist = plt.hist(vej_list, bins=25)

In [None]:
hist = plt.hist(vej_list)

In [None]:
hist = plt.hist(xlan_list, bins=25)

In [None]:
hist = plt.hist(xlan_list)

In [None]:
dist_list = []
shift_list = []

for i in range(len(param_shifted_paper)):
    for j in range(0, 50):
        dist = param_shifted_paper[i][j][0][4]
        shift = param_shifted_paper[i][j][0][3]
        dist_list.append(dist)
        shift_list.append(shift)

In [None]:
print(param_shifted_paper[0][20][0][3])

In [None]:
hist = plt.hist(dist_list, bins = 25)

In [None]:
hist = plt.hist(dist_list)


In [None]:
param_shifted_paper[0]

In [None]:
hist = plt.hist(shift_list, bins=25)

In [None]:
hist = plt.hist(shift_list)

## Light Curve Graphs

Plotting code for visualizing the light curves from the .csv files -- SKIP FOR NOW

In [None]:

fixed_colors = ['seagreen', 'crimson', 'blue']
varied_colors = ['mediumaquamarine', 'salmon', 'skyblue']
label_list = ['g band', 'r band', 'i band']
bands = ['ztfg', 'ztfr', 'ztfi']

def varied_fixed_plot(varied_df, fixed_df, sim_id, xlim_min=None, xlim_max=None, title = False, bands=bands):
    '''
    Scatter-plots the shifted and fixed light curves of one simulation, one panel per band.
    Inputs:
        varied_df: dataframe with the shifted light curves; needs 'sim_id', 't' and one column per band
        fixed_df: dataframe with the matching fixed light curves, same columns
        sim_id: value of 'sim_id' selecting the light curve to draw
        xlim_min, xlim_max: optional x-axis limits, applied only when both are given
        title: when truthy, adds a title with mej, vej and xlan (only if all three columns exist)
        bands: band column names, one subplot each; colors and legend labels are taken
               position-wise from fixed_colors, varied_colors and label_list
    Outputs:
        None; prints the shift and distance of the first selected row (None when the
        column is absent) and shows the figure
    '''
    varied_data = varied_df.loc[varied_df['sim_id'] == sim_id]
    fixed_data = fixed_df.loc[fixed_df['sim_id'] == sim_id]

    def _first_value(column):
        # Value of `column` in the first selected row, or None when the column is absent.
        if column in varied_data.columns:
            return varied_data.iloc[0, varied_data.columns.get_loc(column)]
        return None

    mej = _first_value('mej')
    vej = _first_value('vej')
    xlan = _first_value('xlan')

    # squeeze=False always returns a 2-D axes array, so a single band works too
    fig, axs = plt.subplots(len(bands), sharex=True, sharey=True, figsize=(7,7), squeeze=False)
    axs = axs[:, 0]
    for i in range(len(bands)):
        axs[i].scatter(fixed_data['t'], fixed_data[bands[i]], label = 'fixed, ' + label_list[i] , color = fixed_colors[i], s = 10)
        axs[i].scatter(varied_data['t'], varied_data[bands[i]], label = 'shifted, ' + label_list[i], color = varied_colors[i], s = 10)
        axs[i].legend()
    # Magnitudes: brighter is smaller, so flip the shared y axis. invert_yaxis() toggles, so
    # it must run exactly once; inside the loop an even number of bands cancelled it out.
    axs[0].invert_yaxis()
    if xlim_min is not None and xlim_max is not None:
        # the x axis is shared, so setting the limits on one panel sets them on all
        axs[0].set_xlim(xlim_min, xlim_max)
    fig.supxlabel('Time (Days)')
    fig.supylabel('Magnitude')
    if title and all(value is not None for value in (mej, vej, xlan)):
        # raw strings keep the LaTeX backslashes literal; the line break is a separate literal
        fig.suptitle((r'Light Curve for $\log_{{10}}(M_{{ej}})$: {:.2f}, $\log_{{10}}(V_{{ej}})$: {:.2f}, '
                      '\n'
                      r'$\log_{{10}}(X_{{lan}})$: {:.2f}').format(mej, vej, xlan),
                     fontsize = 15)
    shift = _first_value('shift')
    distance = _first_value('distance')
    print(shift, distance)
    plt.show()

In [None]:
varied_fixed_plot(df_varied, df_fixed, 2,title = True)

# Similarity Embedding

In [None]:
# puts the neural network on the gpu
similarity_embedding = SimilarityEmbedding(num_dim=7, num_hidden_layers_f=1, num_hidden_layers_h=1, num_blocks=4, kernel_size=5, num_dim_final=5).to(device)
# kept in sync by hand with the num_dim=7 literal above; the corner-plot and flow cells read this global
num_dim = 7

# optimizes
optimizer = optim.Adam(similarity_embedding.parameters(), lr=2.747064325271709e-05)

# sets learning rate steps
# The training loop calls scheduler.step() once per epoch, so the milestones are epoch counts:
# scheduler_1 up to epoch 5, scheduler_2 from 5 to 15, scheduler_3 afterwards.
# NOTE(review): scheduler_2 is configured for total_steps=20 but only gets the 10 epochs between
# the milestones, so its cycle is cut short -- confirm this is intended.
# NOTE(review): all three schedulers are constructed on the same optimizer, and each constructor
# may adjust the optimizer's initial lr -- verify the starting learning rate is the one you expect.
scheduler_1 = optim.lr_scheduler.ConstantLR(optimizer, total_iters=5) #constant lr
scheduler_2 = optim.lr_scheduler.OneCycleLR(optimizer, total_steps=20, max_lr=2e-3) #one cycle - increase and then decrease
scheduler_3 = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.95)
scheduler = optim.lr_scheduler.SequentialLR(optimizer, schedulers=[scheduler_1, scheduler_2, scheduler_3], milestones=[5, 15])

In [None]:
# check if data is the right shape for similarity embedding

for var_inj_se, fix_inj_se, var_data_se, fix_data_se in train_data_loader_paper:
    var_inj_se = var_inj_se.reshape((-1,)+var_inj_se.shape[2:])
    fix_inj_se = fix_inj_se.reshape((-1,)+fix_inj_se.shape[2:])
    var_data_se = var_data_se.reshape((-1,)+var_data_se.shape[2:])
    fix_data_se = fix_data_se.reshape((-1,)+fix_data_se.shape[2:])
    break
var_inj_se.shape, var_data_se.shape, fix_inj_se.shape, fix_data_se.shape

In [None]:
# checking the shapes

embed, rep = similarity_embedding(var_data_se)
embed.shape, rep.shape

In [None]:
# embed the data and calculate the loss for one example to check for bugs

emb_aug, rep_aug = similarity_embedding(var_data_se)
emb_orig, rep_orig = similarity_embedding(fix_data_se)
vicreg_loss(emb_aug, emb_orig)

In [None]:
# print neural network parameters that require gradients and sum parameters

sum_param=0
for name, param in similarity_embedding.named_parameters():
    if param.requires_grad:
        print(name)
        print(param.numel())
        sum_param+=param.numel()
print(sum_param)

In [None]:
# write to tensorboard for data visualization

writer = SummaryWriter("torchlogs/")
model = similarity_embedding
writer.add_graph(model, var_data_se)
writer.close()

In [None]:
%%time
# training the neural network for many epochs
# Per-epoch average losses are collected in sim_train_loss / sim_val_loss for the plot below.

epoch_number = 0
EPOCHS = 50

sim_val_loss = []
sim_train_loss = []

for epoch in range(EPOCHS):
    print('EPOCH {}:'.format(epoch_number + 1))

    # set the vicreg weights
    wt_repr, wt_cov, wt_std = (1, 1, 1)
    print(f"VicReg wts: {wt_repr} {wt_cov} {wt_std}")
    
    # training mode (affects layers such as dropout / batch norm)
    similarity_embedding.train(True)
    avg_train_loss = train_one_epoch_se(epoch_number, writer, train_data_loader_paper,
                                        similarity_embedding, optimizer, vicreg_loss, verbose=True,
                                        wt_repr=wt_repr, wt_cov=wt_cov, wt_std=wt_std)
    sim_train_loss.append(avg_train_loss)
    
    # evaluation mode for validation; train(False) and eval() do the same thing.
    # Neither call disables gradient tracking -- that is presumably handled inside val_one_epoch_se (TODO confirm).
    similarity_embedding.train(False)
    similarity_embedding.eval()
    avg_val_loss = val_one_epoch_se(epoch_number, writer, val_data_loader_paper,
                                    similarity_embedding, vicreg_loss,
                                    wt_repr=wt_repr, wt_cov=wt_cov, wt_std=wt_std)
    sim_val_loss.append(avg_val_loss)
    
    print(f"Train/Val Sim Loss after epoch: {avg_train_loss:.4f}/{avg_val_loss:.4f}")

    epoch_number += 1
    # the learning-rate schedule advances once per epoch
    scheduler.step()


In [None]:
# plot the train/val loss 

epoch_list = range(0,len(sim_train_loss))
plt.plot(epoch_list, sim_train_loss, label = 'Similarity Embedding (Train)', color = 'royalblue', alpha = 0.8, marker = 's')
plt.plot(epoch_list, sim_val_loss, label = 'Similarity Embedding (Val)', color = 'lightsteelblue', alpha=0.8, linestyle="dashed")
plt.legend()
plt.xlabel('Number of Epochs')
plt.ylabel('Loss')

## Saving Embedding Weights

In [None]:
# save the embedded weights (input your save path)

SAVEPATH = '/home/oppenheimer/summer2025/Kilo/weights/similarity-embedding-weights-tutorial2.pth'
torch.save(similarity_embedding.state_dict(), SAVEPATH)

# Loading Embedding Weights

If you do not want to retrain the embedding, just load the pretrained weights here: 

In [None]:
# load the weights 

similarity_embedding = SimilarityEmbedding(num_dim=7, num_hidden_layers_f=1, num_hidden_layers_h=1, num_blocks=4, kernel_size=5, num_dim_final=5).to(device)
num_dim = 7

SAVEPATH = '/home/oppenheimer/summer2025/Kilo/weights/similarity-embedding-weights-tutorial2.pth'
similarity_embedding.load_state_dict(torch.load(SAVEPATH, map_location=device))

# Visualizing the embedding for all test cases

In [None]:
similarity_embedding.train(False)
data_loader = test_data_loader_paper

# Open intervals (low, high) on log10(mej), log10(vej) and log10(xlan) defining the four
# regions of parameter space that are compared in the corner plot. The regions do not
# overlap, so a test batch belongs to at most one of them.
PARAM_REGIONS = {
    1: ((-1.25, -1), (-0.75, -0.5), (-4.5, -3)),
    2: ((-1.5, -1.25), (-1, -0.75), (-6, -4.5)),
    3: ((-1.75, -1.5), (-1.25, -1), (-7.5, -6)),
    4: ((-2.5, -1.75), (-1.55, -1.25), (-9, -7.5)),
}


def _region_of(params):
    """Return the key of the region that contains (mej, vej, xlan), or None if there is none."""
    for region, bounds in PARAM_REGIONS.items():
        if all((value > low) and (value < high) for value, (low, high) in zip(params, bounds)):
            return region
    return None


similarity_outputs_by_region = {region: [] for region in PARAM_REGIONS}

# One pass over the test loader instead of one full pass per region; the loader is not
# shuffled, so each region's list is filled in the same order as before.
for inj_test, shift_test, data_test, data_test_orig in data_loader:
    shift_test = shift_test.reshape((-1,)+shift_test.shape[2:])
    # the first light curve's (mej, vej, xlan) decides which region the batch belongs to
    region = _region_of(shift_test[0][0][:3])
    if region is None:
        continue
    data_test = data_test.reshape((-1,)+data_test.shape[2:])
    with torch.no_grad():
        embedding_output, similarity_output = similarity_embedding(data_test)
    similarity_outputs_by_region[region].append(similarity_output)

# names used by the stacking and corner-plot cells below
similarity_outputs_1 = similarity_outputs_by_region[1]
similarity_outputs_2 = similarity_outputs_by_region[2]
similarity_outputs_3 = similarity_outputs_by_region[3]
similarity_outputs_4 = similarity_outputs_by_region[4]



In [None]:
similarity_outputs_1 = torch.stack(similarity_outputs_1)
similarity_outputs_2 = torch.stack(similarity_outputs_2)
similarity_outputs_3 = torch.stack(similarity_outputs_3)
similarity_outputs_4 = torch.stack(similarity_outputs_4)

In [None]:
similarity_outputs_1.shape, similarity_outputs_2.shape, similarity_outputs_3.shape, similarity_outputs_4.shape

In [None]:
# One (embedding samples, color, legend label) entry per parameter-space region.
# Raw strings keep the LaTeX backslashes literal ('\l' is an invalid escape otherwise).
# NOTE(review): the last label quotes -1.9 and -1.53 as lower bounds, while the selection
# cell filters that region with -2.5 and -1.55 -- confirm which numbers should be shown.
corner_groups = [
    (similarity_outputs_1, 'C1',
     r'-1 > $\log_{{10}}(M_{{ej}})$ > -1.25, -0.5 > $\log_{{10}}(V_{{ej}})$ > -0.75, -3.0 > $\log_{{10}}(X_{{lan}})$ > -4.5'),
    (similarity_outputs_2, 'C2',
     r'-1.25 > $\log_{{10}}(M_{{ej}})$ > -1.5, -0.75 > $\log_{{10}}(V_{{ej}})$ > -1.0, -4.5 > $\log_{{10}}(X_{{lan}})$ > -6.0'),
    (similarity_outputs_3, 'C3',
     r'-1.5 > $\log_{{10}}(M_{{ej}})$ > -1.75, -1.0 > $\log_{{10}}(V_{{ej}})$ > -1.25, -6.0 > $\log_{{10}}(X_{{lan}})$ > -7.5'),
    (similarity_outputs_4, 'C4',
     r'-1.75 > $\log_{{10}}(M_{{ej}})$ > -1.9, -1.25 > $\log_{{10}}(V_{{ej}})$ > -1.53, -7.5 > $\log_{{10}}(X_{{lan}})$ > -9.0'),
]

figure = None  # the first corner.corner call creates the figure, later calls overlay on it
legend_handles = []
for outputs, color, label in corner_groups:
    # flatten (batches, light curves, ...) into one row of num_dim values per light curve
    samples = outputs.cpu().numpy().reshape((outputs.shape[0]*outputs.shape[1], num_dim))
    figure = corner.corner(samples, quantiles=[0.16, 0.5, 0.84], fig=figure, color=color)
    # empty proxy line: only used to give the legend an entry in the region's color
    legend_handles.append(mlines.Line2D([], [], color=color, label=label))
c1_line, c2_line, c3_line, c4_line = legend_handles

plt.legend(handles=
           [c1_line, c2_line, c3_line, c4_line],
           bbox_to_anchor=(0.3, 7.3),
           fontsize = 18
          )

# Normalizing Flow Training Data

### Varied Data

In [None]:
varied_normflow_dict = {}

In [None]:
# get the varied data

norm_path = '/nobackup/users/mmdesai/lowcflowdata'
num_sims = 25000

# number of light curves available in each of the ten varied files
normflow_num_lc_list = [25000, 25000, 25000, 24843, 25000, 24917, 25000, 25000, 24706, 25000]

# Load every file. The loop used to start at 3, but the following cells index
# 'varied_normflow_data_0' and iterate over files 0..9, which raised KeyError on a fresh kernel.
for i in range(len(normflow_num_lc_list)):
    key = 'varied_normflow_data_{}'.format(i)
    # get the names of each file
    file_names = get_names(norm_path, 'varied', i, normflow_num_lc_list[i])
    # open the files as dataframes
    varied_normflow_dict[key] = json_to_df(file_names, normflow_num_lc_list[i])
    # pad the data
    varied_normflow_dict[key] = pad_all_dfs(varied_normflow_dict[key])

In [None]:
varied_normflow_dict['varied_normflow_data_0'][0]

In [None]:
for i in range(0, 50):
    plt.scatter(varied_normflow_dict['varied_normflow_data_0'][i]['t'], varied_normflow_dict['varied_normflow_data_0'][i]['ztfg'], color = 'g')
    plt.scatter(varied_normflow_dict['varied_normflow_data_0'][i]['t'], varied_normflow_dict['varied_normflow_data_0'][i]['ztfr'], color = 'r')
    plt.scatter(varied_normflow_dict['varied_normflow_data_0'][i]['t'], varied_normflow_dict['varied_normflow_data_0'][i]['ztfi'], color = 'c')
plt.gca().invert_yaxis()
plt.xlabel('Time (days)')
plt.ylabel('Magnitude')

In [None]:
# injection files

varied_normflow_inj_dict = {}
inj_path_normflow = '/nobackup/users/mmdesai/final_injections/'

# dataframe column -> key inside the injection file's ['injections']['content'] table
inj_column_map = {
    'mej': 'log10_mej',
    'vej': 'log10_vej',
    'xlan': 'log10_Xlan',
    'shift': 'timeshift',
    'distance': 'luminosity_distance',
    'sim_id': 'simulation_id',
}

for i in range(0, 10):
    varied_params = open_json('injection_normflow_varied_{}.json'.format(i), inj_path_normflow)
    inj_content = varied_params['injections']['content']
    # fill the frame column by column, in the order listed in inj_column_map
    inj_df = pd.DataFrame()
    for column, source_key in inj_column_map.items():
        inj_df[column] = inj_content[source_key]
    varied_normflow_inj_dict['varied_inj_df{}'.format(i)] = inj_df

In [None]:
varied_normflow_inj_dict['varied_inj_df0']

In [None]:
# concatenate dataframe lists

# one concatenated dataframe per injection file, in file order 0..9
all_varied_data_list_flow = [
    pd.concat(varied_normflow_dict['varied_normflow_data_{}'.format(i)])
    for i in range(0, 10)
]

In [None]:
# merge with injection parameters

all_varied_datawparams_list_flow = [0] * 10

for i in range(0, 10):
    all_varied_datawparams_list_flow[i] = all_varied_data_list_flow[i].merge(varied_normflow_inj_dict['varied_inj_df{}'.format(i)], on = 'sim_id')
    # save as csv file
    all_varied_datawparams_list_flow[i].to_csv('/nobackup/users/mmdesai/final_csv/flow_varied_lowc_{}.csv'.format(i), index = False)

In [None]:
all_varied_datawparams_list_flow[0]

# Load in the Data

Run this section only if you are loading the data from CSV files.

In [None]:
data_dir_flow = '/nobackup/users/mmdesai/final_csv/'

In [None]:
df_flow1 = load_in_data(data_dir_flow, 'flow_varied', 10) 

In [None]:
df_flow1

In [None]:
detected_df1 = df_flow1.loc[df_flow1['num_detections'] >= 20]
detected_df1

In [None]:
# varied

detected_df1 = detected_df1.iloc[:29275950, :12]
add_batch_sim_nums_all(detected_df1)
detected_df1

In [None]:
matched_df_flow2 = matched(data_dir_flow, 'flow_varied', 'flow_fixed', 10, 20) 

In [None]:
new_df_flow2 = matched_df_flow2.loc[matched_df_flow2['mej_x'] >= -1.9].copy()
new_df_flow2

In [None]:
detected_df2 = new_df_flow2.loc[new_df_flow2['num_detections_x'] >= 8]
detected_df2

In [None]:
# varied

# Copy the slice so later in-place edits never write into a view of detected_df2.
var_df = detected_df2.iloc[:27097950, :12].copy()
# Remove only a literal trailing '_x'. str.rstrip('_x') strips any trailing run of the
# characters '_' and 'x', so it would also mangle names that merely end in 'x' or '_'.
var_df.columns = var_df.columns.str.replace(r'_x$', '', regex=True)
var_df = var_df.drop(columns=['key_1'])
# presumably adds the batch/sim id columns in place (its return value is unused) -- TODO confirm
add_batch_sim_nums_all(var_df)
var_df

In [None]:
# fixed

# Copy the slice so later in-place edits never write into a view of detected_df2.
fix_df = detected_df2.iloc[:27097950, 12:].copy()
# Remove only a literal trailing '_y'. str.rstrip('_y') strips any trailing run of the
# characters '_' and 'y', so it would also mangle names that merely end in 'y' or '_'.
fix_df.columns = fix_df.columns.str.replace(r'_y$', '', regex=True)
# presumably adds the batch/sim id columns in place (its return value is unused) -- TODO confirm
add_batch_sim_nums_all(fix_df)
fix_df

In [None]:
matched_df_flow3 = matched(data_dir_flow, 'flow_varied', 'flow_fixed', 20, 30) 

In [None]:
new_df_flow3 = matched_df_flow3.loc[matched_df_flow3['mej_x'] >= -1.9].copy()
new_df_flow3

In [None]:
detected_df3 = new_df_flow3.loc[new_df_flow3['num_detections_x'] >= 8]
detected_df3

In [None]:
# varied

# Copy the slice so later in-place edits never write into a view of detected_df3.
var_df = detected_df3.iloc[:27073750, :12].copy()
# Remove only a literal trailing '_x'. str.rstrip('_x') strips any trailing run of the
# characters '_' and 'x', so it would also mangle names that merely end in 'x' or '_'.
var_df.columns = var_df.columns.str.replace(r'_x$', '', regex=True)
var_df = var_df.drop(columns=['key_1'])
# presumably adds the batch/sim id columns in place (its return value is unused) -- TODO confirm
add_batch_sim_nums_all(var_df)
var_df

In [None]:
# fixed

# Copy the slice so later in-place edits never write into a view of detected_df3.
fix_df = detected_df3.iloc[:27073750, 12:].copy()
# Remove only a literal trailing '_y'. str.rstrip('_y') strips any trailing run of the
# characters '_' and 'y', so it would also mangle names that merely end in 'y' or '_'.
fix_df.columns = fix_df.columns.str.replace(r'_y$', '', regex=True)
# presumably adds the batch/sim id columns in place (its return value is unused) -- TODO confirm
add_batch_sim_nums_all(fix_df)
fix_df

In [None]:
plt.hist(var_df['mej'])

In [None]:
plt.hist(var_df['vej'])

In [None]:
plt.hist(var_df['xlan'])

In [None]:
plt.hist(var_df['distance'])

# Prep for Flow

In [None]:
# moving the data from csv to tensors on gpu -- Don't run if tensors are already stored and available

num_lc_flow = len(detected_df1['batch_id'].unique())
data_shifted_flow, param_shifted_flow = test_df_to_tensor(detected_df1, num_lc_flow, 50)

In [None]:
data_shifted_flow1 = torch.load('/nobackup/users/mmdesai/updated_tensors/data_shifted_flow4.pt')
data_unshifted_flow1 = torch.load('/nobackup/users/mmdesai/updated_tensors/data_unshifted_flow4.pt')
param_shifted_flow1 = torch.load('/nobackup/users/mmdesai/updated_tensors/param_shifted_flow4.pt')
param_unshifted_flow1 = torch.load('/nobackup/users/mmdesai/updated_tensors/param_unshifted_flow4.pt')

In [None]:
data_shifted_flow2 = torch.load('/nobackup/users/mmdesai/updated_tensors/data_shifted_flow7.pt')
data_unshifted_flow2 = torch.load('/nobackup/users/mmdesai/updated_tensors/data_unshifted_flow7.pt')
param_shifted_flow2 = torch.load('/nobackup/users/mmdesai/updated_tensors/param_shifted_flow7.pt')
param_unshifted_flow2 = torch.load('/nobackup/users/mmdesai/updated_tensors/param_unshifted_flow7.pt')

In [None]:
data_shifted_flow = torch.stack(data_shifted_flow1 + data_shifted_flow2)
param_shifted_flow = torch.stack(param_shifted_flow1 + param_shifted_flow2)

In [None]:
num_lc_flow = len(data_shifted_flow)
print(num_lc_flow)

In [None]:
data_shifted_flow[0].shape, param_shifted_flow[0].shape

In [None]:
dataset_normflow = Flow_data(data_shifted_flow, param_shifted_flow, num_lc_flow)

# check the dataset shape
t, d = dataset_normflow[4]
t.shape, d.shape

In [None]:
# split dataset into training, testing, and validation

train_set_size_flow = int(0.8 * num_lc_flow)
val_set_size_flow = int(0.1 * num_lc_flow)
# the test set takes whatever is left so the three sizes always sum to the dataset size
test_set_size_flow = num_lc_flow - train_set_size_flow - val_set_size_flow

# Seed the split so the same batches land in train/val/test on every run. Without this,
# reloading saved flow weights in a new session evaluates them on a different "test" set
# that can contain batches the weights were trained on.
split_generator_flow = torch.Generator().manual_seed(42)
train_data_flow, val_data_flow, test_data_flow = torch.utils.data.random_split(
    dataset_normflow, [train_set_size_flow, val_set_size_flow, test_set_size_flow],
    generator=split_generator_flow)

In [None]:
# load and shuffle the data

train_data_loader_flow = DataLoader(train_data_flow, batch_size=25, shuffle=True)
val_data_loader_flow = DataLoader(val_data_flow, batch_size=25, shuffle=True)
test_data_loader_flow = DataLoader(test_data_flow, batch_size=1, shuffle=False)

# check lengths
len(train_data_loader_flow), len(test_data_loader_flow), len(val_data_loader_flow)

In [None]:
# check first instance of data

for var_inj, var_data in train_data_loader_flow:
    var_inj = var_inj.reshape((-1,)+var_inj.shape[2:])
    var_data = var_data.reshape((-1,)+var_data.shape[2:])

    break
var_inj.shape, var_data.shape

# Histograms

In [None]:
mej_list = []
vej_list = []
xlan_list = []

for i in range(len(param_shifted_flow)):
    mej = param_shifted_flow[i][0][0][0]
    vej = param_shifted_flow[i][0][0][1]
    xlan = param_shifted_flow[i][0][0][2]
    mej_list.append(mej)
    vej_list.append(vej)
    xlan_list.append(xlan)

In [None]:
hist = plt.hist(mej_list, bins=25)

In [None]:
hist = plt.hist(vej_list, bins=25)

In [None]:
hist = plt.hist(xlan_list, bins=25)

In [None]:
dist_list = []
shift_list = []

for i in range(len(param_shifted_flow)):
    for j in range(0, 50):
        dist = param_shifted_flow[i][j][0][4]
        shift = param_shifted_flow[i][j][0][3]
        dist_list.append(dist)
        shift_list.append(shift)

In [None]:
hist = plt.hist(dist_list, bins = 25)

In [None]:
hist = plt.hist(shift_list, bins = 25)

# Partially Freeze the Similarity Embedding

In [None]:
# take the first training batch and merge its leading two dimensions
for var_inj_se, var_data_se in train_data_loader_flow:
    var_inj_se = var_inj_se.reshape((-1,)+var_inj_se.shape[2:]).to(device)
    var_data_se = var_data_se.reshape((-1,)+var_data_se.shape[2:]).to(device)
    break

# check shapes
print(var_data_se.shape, var_inj_se.shape)
# Named outputs instead of `_`, which IPython reserves for the previous cell's output.
# NOTE(review): the old inline comment quoted shapes of "batch_size x 1 x 10" and
# "batch_size x 1 x 2"; rely on the printed shapes, those numbers look stale.
emb, rep = similarity_embedding(var_data_se)
print(emb.shape, rep.shape)
# the flow is conditioned on the second output, so its last dimension is the context size
context_features = rep.shape[-1]
print('number of context_features: ', context_features)
print('number of dimensions: ', num_dim)

In [None]:
# define parameters

transform, base_dist, embedding_net = normflow_params(similarity_embedding, 9, 5, 90, context_features=context_features, num_dim=num_dim) 

In [None]:
flow = Flow(transform, base_dist, embedding_net).to(device=device)

In [None]:
print('Total number of trainable parameters: ', sum(p.numel() for p in flow.parameters() if p.requires_grad))

In [None]:
for idx, val in enumerate(train_data_loader_flow, 1):
    augmented_shift, augmented_data = val
    augmented_shift = augmented_shift[...,0:3].to(device)
    augmented_shift = augmented_shift.flatten(0, 2).to(device)
    augmented_data = augmented_data.reshape(-1, 3, num_points).to(device)
    print(augmented_shift.shape, augmented_data.shape)
    break

In [None]:
similarity_embedding(augmented_data)[0].shape, similarity_embedding(augmented_data)[1].shape

In [None]:
flow_loss = -flow.log_prob(augmented_shift, context=augmented_data).mean()
flow_loss

# Train and Validate

In [None]:
# optimizer
optimizer = optim.SGD(flow.parameters(), lr=0.0000912, momentum=0.5)
# scheduler
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=2, factor=0.5, threshold=0.001)

In [None]:
writer = SummaryWriter()

In [None]:
%%time
# RUN THIS CELL TO TRAIN FROM SCRATCH (skip it if you load saved flow weights below)
# Per-epoch average losses are collected in train_loss_list / val_loss_list for the plot below.

train_loss_list = []
val_loss_list = []

EPOCHS = 50
epoch_number = 0
for epoch in range(EPOCHS):
    print('EPOCH {}:'.format(epoch_number + 1))
    flow.train(True)
    # NOTE(review): this marks every embedding-net parameter as trainable on each epoch, while
    # the section heading says the embedding is partially frozen -- confirm which is intended.
    for name, param in flow._embedding_net.named_parameters():
        param.requires_grad = True
    avg_train_loss = train_one_epoch(epoch_number, writer, train_data_loader_flow, flow, optimizer, 2)
    train_loss_list.append(avg_train_loss)
    flow.train(False)
    avg_val_loss = val_one_epoch(epoch_number, writer, val_data_loader_flow, flow, 2)
    val_loss_list.append(avg_val_loss)
    print(f"Train/Val flow Loss after epoch: {avg_train_loss:.4f}/{avg_val_loss:.4f}")
    epoch_number += 1
    # ReduceLROnPlateau needs the monitored metric: it lowers the lr when the validation loss stalls
    scheduler.step(avg_val_loss)
    for param_group in optimizer.param_groups:
        print("Current LR = {:.3e}".format(param_group['lr']))

# Weights

In [None]:
# save normalizing flow weights

PATH_nflow = '/nobackup/users/mmdesai/flow_weights_tutorial.pth'
torch.save(flow.state_dict(), PATH_nflow)

In [None]:
# load the normalizing flow weights

context_features = 7
transform, base_dist, embedding_net = normflow_params(similarity_embedding, 9, 5, 90, context_features=context_features, num_dim=num_dim) 
flow = Flow(transform, base_dist, embedding_net).to(device=device)

PATH_nflow = '/nobackup/users/mmdesai/flow_weights_tutorial.pth'
flow.load_state_dict(torch.load(PATH_nflow, map_location=device))

In [None]:
# plot the train/val loss of the current run

epoch_list = range(0,200)

# Build the x values from the loss lists themselves: slicing the fixed 200-epoch range
# produced mismatched x/y lengths as soon as a run was longer than 200 epochs.
plt.plot(range(len(train_loss_list)), train_loss_list, label = 'Current Run', color = 'k')
plt.plot(range(len(val_loss_list)), val_loss_list, label = 'Validation', color = 'k', linestyle = 'dashed')
plt.ylabel('- Log. Prob.')
plt.xlabel('Epochs')
plt.legend()

In [None]:
for idx, (shift_test, data_test) in enumerate(test_data_loader_flow):
    data_test = data_test.reshape((-1,)+data_test.shape[2:])
    shift_test = shift_test.reshape((-1,)+shift_test.shape[2:])
    if idx % 100 !=0: continue 
    with torch.no_grad():
        samples = flow.sample(1000, context=data_test[0].reshape((1, 3, num_points)))
    live_plot_samples(samples.cpu().reshape(1000,3), shift_test[0][0].cpu()[...,0:3])
    plt.show()