In [7]:
import os
# Make the experiment directory the working directory. The local modules
# imported in later cells (gw_parameters, peregrine_network) presumably live
# here, and the relative "tmp_dir" output folder is created under it.
os.chdir('/home/scur2012/Thesis/master-thesis/experiments/network_training')

import zarr
import swyft.lightning as sl

# Location of the simulation store for the selected run and round.
zarr_store_dirs = '/scratch-shared/scur2012/peregrine_data/tmnre_experiments'
name_of_run = 'peregrine_copy_highSNR_v3'

# The on-disk round directory is suffixed with rnd_id + 1.
rnd_id = 2
simulation_store_path = f"{zarr_store_dirs}/{name_of_run}/simulations/round_{rnd_id+1}"

# Swyft wrapper around the store; used further down to build the dataloaders.
zarr_store = sl.ZarrStore(simulation_store_path)

# Direct zarr handle on the same path, kept for ad-hoc inspection.
simulation_results = zarr.convenience.open(simulation_store_path)


In [8]:
import importlib
import gw_parameters
importlib.reload(gw_parameters)

# Default configuration and parameter limits provided by the local
# gw_parameters module.
conf = gw_parameters.default_conf
bounds = gw_parameters.limits

# Settings consumed when building the dataloaders and the trainer.
trainer_settings = {
    "min_epochs": 30,
    "max_epochs": 200,
    "early_stopping": 7,  # used as the EarlyStopping patience
    "num_workers": 8,
    "training_batch_size": 256,
    "validation_batch_size": 256,
    "train_split": 0.9,
    "val_split": 0.1,
}

# Keyword arguments forwarded to InferenceNetwork (Peregrine network).
network_settings = {
    "shuffling": True,
    "priors": {
        "int_priors": conf['priors']['int_priors'],
        "ext_priors": conf['priors']['ext_priors'],
    },
    "marginals": ((0, 1),),
    "one_d_only": True,
    "ifo_list": conf["waveform_params"]["ifo_list"],
    "learning_rate": 5e-4,
    "training_batch_size": trainer_settings['training_batch_size'],
    "save_path": '/home/scur2012/Thesis/master-thesis/experiments/network_training/roc_curve',
}

In [9]:
import torch
torch.set_float32_matmul_precision('high')
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from pytorch_lightning import loggers as pl_loggers
from pytorch_lightning.callbacks.early_stopping import EarlyStopping

# Initialise dataloaders
#
# The store is split once at n_train: samples [0, n_train) go to training and
# samples [n_train, n_samples) go to validation. The training range stops at
# the index where the validation range starts, i.e. idx_range is used as a
# half-open interval, so the validation range must end at n_samples (ending
# at n_samples - 1 excluded the final sample).
n_samples = len(zarr_store.data.z_int)
n_train = int(trainer_settings['train_split'] * n_samples)

train_data = zarr_store.get_dataloader(
    num_workers=trainer_settings['num_workers'],
    batch_size=trainer_settings['training_batch_size'],
    idx_range=[0, n_train],
    # No per-sample hook; same sentinel as the validation loader.
    on_after_load_sample=None,
)

val_data = zarr_store.get_dataloader(
    num_workers=trainer_settings['num_workers'],
    batch_size=trainer_settings['validation_batch_size'],
    idx_range=[n_train, n_samples],
    on_after_load_sample=None,
)

# Trainer configuration: output directory, logger, callbacks and the trainer.

tmp_dir = "tmp_dir"

# TensorBoard logger writing below <tmp_dir>/logs (directory created up front).
os.makedirs(f'{tmp_dir}/logs', exist_ok=True)
logger_tbl = pl_loggers.TensorBoardLogger(
    save_dir=tmp_dir,
    name="logs",
    version=None,
    default_hp_metric=False,
)

# Callbacks: learning-rate logging per step, early stopping and checkpointing,
# the latter two both monitoring the validation loss (lower is better).
lr_monitor = LearningRateMonitor(logging_interval="step")
early_stopping_callback = EarlyStopping(
    monitor="val_loss",
    mode="min",
    min_delta=0.0,
    patience=trainer_settings["early_stopping"],
    verbose=False,
)
checkpoint_callback = ModelCheckpoint(
    monitor="val_loss",
    mode="min",
    dirpath=tmp_dir,
    filename="{epoch}_{val_loss:.2f}_{train_loss:.2f}" + f"_round_{rnd_id+1}",
)

swyft_trainer = sl.SwyftTrainer(
    accelerator='gpu',
    devices=1,
    min_epochs=trainer_settings["min_epochs"],
    max_epochs=trainer_settings["max_epochs"],
    logger=logger_tbl,
    callbacks=[lr_monitor, early_stopping_callback, checkpoint_callback],
    enable_progress_bar=True,
)


  rank_zero_warn(


In [10]:
# Reload the module first and only then bind InferenceNetwork. A
# `from ... import` executed before importlib.reload keeps the name bound to
# the pre-reload class object, so edits to peregrine_network.py would not be
# picked up by the cells below.
import peregrine_network
importlib.reload(peregrine_network)
from peregrine_network import InferenceNetwork

<module 'peregrine_network' from '/gpfs/home3/scur2012/Thesis/master-thesis/experiments/network_training/peregrine_network.py'>

In [11]:
# Restore the network from a stored checkpoint. network_settings is forwarded
# as keyword arguments to load_from_checkpoint.
ckpt = "/scratch-shared/scur2012/peregrine_data/tmnre_experiments/peregrine_copy_highSNR_v3/training/round_1/epoch=59_val_loss=-4.26_train_loss=-4.24_round_1.ckpt"
# Alternative checkpoint (round 7):
# ckpt = "/scratch-shared/scur2012/peregrine_data/tmnre_experiments/peregrine_copy_highSNR_v3/training/round_7/epoch=28_val_loss=-5.16_train_loss=-5.23_round_7.ckpt"
# Raw checkpoint dictionary, if needed for inspection:
# checkpoint = torch.load(ckpt)

network = InferenceNetwork.load_from_checkpoint(ckpt, **network_settings)
# To start from a freshly constructed network instead:
# network = InferenceNetwork(**network_settings)




In [12]:
# Fit data to model
# Trains `network` with the Swyft trainer configured above, using train_data
# for optimisation and val_data for validation (the callbacks monitor
# "val_loss").

swyft_trainer.fit(network, train_data, val_data)

  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")

  | Name         | Type                   | Params
--------------------------------------------------------
0 | unet_t       | Unet                   | 722 K 
1 | unet_f       | Unet                   | 722 K 
2 | flatten      | Flatten                | 0     
3 | linear_t     | LinearCompression      | 0     
4 | linear_f     | LinearCompression      | 0     
5 | logratios_1d | LogRatioEstimator_1dim | 290 K 
--------------------------------------------------------
1.7 M     Trainable params
0         Non-trainable params
1.7 M     Total params
6.942     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

In [None]:
import torch.optim as optim

# Fetch one batch from the training dataloader for the diagnostics below.
#
# This replaces a `for ... : break` loop whose body after the unconditional
# `break` (a manual optimisation step) was unreachable and referred to
# `model` / `optimizer`, neither of which was ever defined.
# NOTE(review): the ROC diagnostics are computed on a *training* batch;
# consider drawing the batch from val_data instead.
batch = next(iter(train_data))

# The cells below refer to the network as `model`; alias the network loaded
# and fitted above so they run from a fresh kernel.
model = network



In [None]:
from toolz.dicttoolz import valmap

# Build the inputs for the ratio estimator from a single batch:
#   A - the batch as loaded
#   B - the same batch with every tensor rolled by one position along dim 0
A = batch
B = {name: torch.roll(tensor, 1, dims=0) for name, tensor in A.items()}

# x is the original batch; z stacks the original entries followed by the
# rolled ones, so it has twice as many rows as x.
x = A
z = {name: torch.cat([A[name], B[name]]) for name in B}

num_pos = len(next(iter(x.values())))  # Number of positive examples
num_neg = len(next(iter(z.values()))) - num_pos  # Number of negative examples

In [None]:
# Size of the second dimension of the batch's 'z_total' entry
# (presumably the number of parameters per sample — TODO confirm).
A['z_total'].shape[1]

In [None]:
# Evaluate the fitted network on the (x, z) pairs built above. `network` is
# used directly because no `model` variable is defined in a fresh kernel.
# Inference only: switch to eval mode and skip gradient tracking.
network.eval()
with torch.no_grad():
    logratios = network._get_logratios(network(x, z))

In [None]:
# Binary labels with the same shape as logratios: the first num_pos rows are
# the positive examples, the remaining rows the negatives.
y = torch.zeros_like(logratios)
y[:num_pos, ...] = 1
pos_weight = torch.ones_like(logratios[0]) * num_neg / num_pos

# Each column of logratios is scored independently against its own column of
# y, so the class probability is the element-wise sigmoid of the log-ratio.
# A softmax over dim=1 would instead normalise across the columns and make
# every column's score depend on all the others, distorting the ROC curves.
probabilities = torch.sigmoid(logratios)

In [None]:
# Display the probabilities in column 1 as a NumPy array.
probabilities.detach()[:, 1].numpy()

In [None]:
import numpy as np
from sklearn.metrics import roc_curve, auc
import matplotlib.pyplot as plt

# ROC curve for every parameter listed in `bounds`.
# Column i of y / probabilities is paired with the i-th key of `bounds`
# (assumes the key order matches the column order — TODO confirm).

labels_np = y.numpy()
scores_np = probabilities.detach().numpy()

fig, ax = plt.subplots()
for col, param_name in enumerate(bounds):
    # False/true positive rates over all thresholds, and the area under the curve
    fpr, tpr, _ = roc_curve(labels_np[:, col], scores_np[:, col])
    roc_auc = auc(fpr, tpr)
    ax.plot(fpr, tpr, lw=1, label=f'{param_name} (area = {roc_auc :0.2f})')

# Diagonal reference line
ax.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
ax.set_xlim([0.0, 1.0])
ax.set_ylim([0.0, 1.05])
ax.set_xlabel('False Positive Rate')
ax.set_ylabel('True Positive Rate')
ax.set_title('Receiver Operating Characteristic')
ax.legend(bbox_to_anchor=(1.7, 1), loc="upper right")
plt.show()

In [None]:
# Display the parameter limits taken from gw_parameters.limits.
bounds