In [53]:
import os
os.chdir('/home/scur2012/Thesis/master-thesis/experiments/peregrine')
import zarr
import pickle
import torch

import swyft.lightning as sl

# Initial values for a single run/round; both names are reassigned by the
# loops below before they are read.
run_name = 'lowSNR'
rnd_id = 1

# Root directory that holds the trainer and simulation outputs.
data_directory = '/scratch-shared/scur2012/peregrine_data/bhardwaj2023'

# Process every run name together with each of its eight rounds.
for run_name in ['lowSNR', 'highSNR']:
    for rid in range(8):
        # The paths and checkpoint names use 1-based round numbers (R1..R8),
        # whereas range() counts from 0.
        rnd_id = rid+1

        # Test the network
        # Checkpoint file name for every (run name, round) pair. A lookup table
        # is used so that an unknown run or round fails immediately with a
        # KeyError instead of silently reusing the `ckpt` value left over from
        # the previous loop iteration.
        checkpoint_files = {
            'highSNR': {
                1: 'epoch=91_val_loss=-4.33_train_loss=-4.65_R1.ckpt',
                2: 'epoch=87_val_loss=-6.76_train_loss=-7.49_R2.ckpt',
                3: 'epoch=49_val_loss=-7.17_train_loss=-7.48_R3.ckpt',
                4: 'epoch=70_val_loss=-5.81_train_loss=-5.89_R4.ckpt',
                5: 'epoch=34_val_loss=-5.53_train_loss=-5.32_R5.ckpt',
                6: 'epoch=24_val_loss=-5.06_train_loss=-5.13_R6.ckpt',
                7: 'epoch=72_val_loss=-5.55_train_loss=-5.79_R7.ckpt',
                8: 'epoch=38_val_loss=-5.33_train_loss=-5.22_R8.ckpt',
            },
            'lowSNR': {
                1: 'epoch=71_val_loss=-4.34_train_loss=-4.30_R1.ckpt',
                2: 'epoch=71_val_loss=-4.77_train_loss=-4.81_R2.ckpt',
                3: 'epoch=59_val_loss=-4.36_train_loss=-4.02_R3.ckpt',
                4: 'epoch=48_val_loss=-4.41_train_loss=-4.19_R4.ckpt',
                5: 'epoch=50_val_loss=-4.57_train_loss=-4.41_R5.ckpt',
                6: 'epoch=36_val_loss=-4.47_train_loss=-4.28_R6.ckpt',
                7: 'epoch=79_val_loss=-4.47_train_loss=-4.27_R7.ckpt',
                8: 'epoch=65_val_loss=-4.36_train_loss=-4.53_R8.ckpt',
            },
        }
        ckpt = checkpoint_files[run_name][rnd_id]

        # Each round has its own trainer directory below the data root.
        checkpoint_path = f'{data_directory}/trainer_{run_name}_R{rnd_id}/{ckpt}'

        # Load the data from the store

        simulation_path = f'{data_directory}/simulations_{run_name}_R{rnd_id}'
        # Raw zarr handle on the simulation directory.
        simulation_results = zarr.convenience.open(simulation_path)

        # swyft wrapper around the same directory; the data loaders are built from it.
        zarr_store = sl.ZarrStore(simulation_path)

        # The log-ratio file is optional and lives in a separate "_v2" tree.
        # Reset the variable first so that a missing file cannot leave the data
        # of a previous round in place.
        logratio_path = f'/scratch-shared/scur2012/peregrine_data/bhardwaj2023_v2/logratios_{run_name}/logratios_R{rnd_id}'
        logratio_data = None
        if os.path.exists(logratio_path):
            # NOTE: unpickling executes arbitrary code; only load files that
            # come from a trusted source.
            with open(logratio_path, 'rb') as f:
                logratio_data = pickle.load(f)

        # Add function to network to get predictions

        import torch
        from torch import nn
        from torch.functional import F
        from toolz.dicttoolz import valmap
        from peregrine_network import InferenceNetwork

        class InferenceNetworkX(InferenceNetwork):
            """InferenceNetwork extended with access to its per-batch predictions.

            Adds `get_logratios_probabilities`, which evaluates the network on a
            batch and returns the log-ratios together with the matching labels and
            the element-wise binary cross-entropy.
            """

            def get_A_B_samples(self, batch):
                """Build the network inputs, with contrastive examples appended.

                Args:
                    batch: Either a list whose first two items are dicts of tensors
                        (positive samples, then contrastive samples), or a single
                        dict of tensors. In the latter case the contrastive samples
                        are the same tensors rolled by one position along dim 0.

                Returns:
                    Tuple ``(x, z)``. ``x`` is the positive-sample dict, unchanged.
                    ``z`` maps every key of the contrastive dict to the positive
                    tensor concatenated with the contrastive tensor.
                """
                if isinstance(batch, list):
                    # Multiple dataloaders provided: the second one supplies the
                    # contrastive samples.
                    positives, contrastive = batch[0], batch[1]
                else:
                    # Single dataloader: reuse the same samples, shifted by one.
                    positives = batch
                    contrastive = valmap(
                        lambda t: torch.roll(t, 1, dims=0), positives
                    )

                # Concatenate positive samples and negative (contrastive) examples
                x = positives
                z = {
                    key: torch.cat([positives[key], contrastive[key]])
                    for key in contrastive
                }

                return x, z

            def get_logratios_probabilities(self, batch):
                """Evaluate the network on `batch` and return its predictions.

                Args:
                    batch: A batch in one of the forms accepted by
                        `get_A_B_samples`.

                Returns:
                    Tuple ``(logratios, probabilities, y, loss_xe)``:
                    ``logratios`` is the output of ``self._get_logratios``,
                    ``probabilities`` is the softmax of the log-ratios along
                    dim 0, ``y`` has the shape of ``logratios`` with the first
                    ``num_pos`` rows set to 1 and the rest 0, and ``loss_xe`` is
                    the unreduced binary cross-entropy with logits between
                    ``logratios`` and ``y``.

                Raises:
                    ValueError: If ``self._get_logratios`` returns ``None``, in
                        which case none of the outputs can be computed.
                """
                A, B = self.get_A_B_samples(batch)

                num_pos = len(next(iter(A.values())))  # Number of positive examples
                num_neg = len(next(iter(B.values()))) - num_pos  # Number of negative examples

                out = self(A, B)  # Evaluate network

                # Generates concatenated flattened list of all estimated log ratios
                logratios = self._get_logratios(out)
                if logratios is None:
                    # Fail with a clear message here; otherwise the labels and
                    # loss below would be undefined and softmax would raise an
                    # unrelated TypeError.
                    raise ValueError(
                        "The network output contains no log-ratios; "
                        "cannot compute probabilities, labels or loss."
                    )

                # Rows of positive examples come first and are labelled 1.
                y = torch.zeros_like(logratios)
                y[:num_pos, ...] = 1

                # Weight the positive class by the negative/positive ratio.
                pos_weight = torch.ones_like(logratios[0]) * num_neg / num_pos
                loss_xe = F.binary_cross_entropy_with_logits(
                    logratios, y, reduction="none", pos_weight=pos_weight
                )

                # Use soft-max to convert logratios to probabilities
                # NOTE(review): dim=0 normalises across the rows of this batch,
                # so the values are relative to the other samples in the batch
                # rather than a per-sample probability — confirm this is intended.
                probabilities = nn.functional.softmax(logratios, dim=0)

                return logratios, probabilities, y, loss_xe
        import gw_parameters

        conf = gw_parameters.default_conf
        bounds = gw_parameters.limits

        # Constructor arguments for the inference network.
        network_settings = dict(
            # Peregrine
            shuffling = True,
            priors = dict(
                int_priors = conf['priors']['int_priors'],
                ext_priors = conf['priors']['ext_priors'],
            ),
            marginals = ((0, 1),),
            one_d_only = True,
            ifo_list = conf["waveform_params"]["ifo_list"],
            learning_rate = 5e-4,
            training_batch_size = 256,
        )

        # Load network model
        network = InferenceNetworkX(**network_settings)
        # Deserialise the checkpoint onto the CPU so that a checkpoint written
        # on a GPU can also be read on a CPU-only machine; the network is moved
        # to the selected device later on.
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        network.load_state_dict(checkpoint['state_dict'])
        # Initialise data loader
        # The first 90% of the stored samples feed the training loader, the
        # remainder feeds the validation loader.
        n_stored = len(zarr_store.data.z_int)
        split_at = int(0.9 * n_stored)
        loader_options = dict(num_workers=8, batch_size=256)

        train_data = zarr_store.get_dataloader(
            idx_range=[0, split_at],
            on_after_load_sample=False,
            **loader_options,
        )

        val_data = zarr_store.get_dataloader(
            idx_range=[split_at, n_stored - 1],
            on_after_load_sample=None,
            **loader_options,
        )
        from tqdm import tqdm

        network.eval()

        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        network.to(device)

        # Per-batch results collected over the validation loader.
        val_losses = []
        val_losses_xe = []
        logratios = []
        probabilities = []
        labels = []

        # NOTE(review): the progress bars recorded with this cell stop short of
        # this total (e.g. 9/11), so the estimate appears to overcount — confirm
        # what `n_samples` refers to.
        n_batches_estimate = val_data.dataset.n_samples // val_data.batch_size

        # Disable autograd only for the duration of the evaluation instead of
        # switching it off globally, which would leak into later code.
        with torch.no_grad():
            for batch_idx, batch in enumerate(tqdm(val_data, total=n_batches_estimate)):

                batch = {key: value.to(device) for key, value in batch.items()}

                loss = network.validation_step(batch, batch_idx)
                val_losses.append(loss.item())

                (
                    batch_logratios,
                    batch_probabilities,
                    batch_labels,
                    batch_loss_xe,
                ) = network.get_logratios_probabilities(batch)
                logratios.append(batch_logratios)
                probabilities.append(batch_probabilities)
                labels.append(batch_labels)
                val_losses_xe.append(batch_loss_xe)

        avg_epoch_val_loss = sum(val_losses)/len(val_losses)

        # Merge the per-batch tensors and move them to NumPy for plotting.
        logratios_np = torch.concat(logratios).detach().cpu().numpy()
        probabilities_np = torch.concat(probabilities).detach().cpu().numpy()
        labels_np = torch.concat(labels).cpu().numpy()

        import matplotlib.pyplot as plt
        from sklearn.metrics import roc_curve, auc

        intrinsic_variables = gw_parameters.intrinsic_variables
        extrinsic_variables = gw_parameters.extrinsic_variables

        # One ROC figure per parameter group, described as
        # (figure number, group label, parameter names, first output column).
        # NOTE(review): the hard-coded offset 10 presumably equals
        # len(intrinsic_variables) — confirm before changing the parameter lists.
        roc_groups = (
            (1, 'intrinsic', intrinsic_variables, 0),
            (2, 'extrinsic', extrinsic_variables, 10),
        )

        for fig_num, group, names, first_col in roc_groups:
            plt.figure(fig_num)

            # One ROC curve per parameter, with its area under the curve in the legend.
            for i, name in enumerate(names):
                col = first_col + i
                fpr, tpr, thresholds = roc_curve(labels_np[:, col], probabilities_np[:, col])
                roc_auc = auc(fpr, tpr)

                plt.plot(fpr, tpr, lw=1, label=f'{name} (area = {roc_auc:.2f})')

            # Diagonal reference line.
            plt.plot([0, 1], [0, 1], color='grey', lw=0.5, linestyle='--')
            plt.xlim([0.0, 1.0])
            plt.ylim([0.0, 1.05])
            plt.xlabel('False Positive Rate')
            plt.ylabel('True Positive Rate')
            plt.title(f'ROC {group}. {run_name} round {rnd_id}')
            # The legend is anchored outside the axes; bbox_inches='tight' keeps
            # it inside the saved image.
            plt.legend(bbox_to_anchor=(1.7, 1), loc="upper right")
            plt.savefig(f'ROC_curve_{group}_{run_name}_round_{rnd_id}.png', dpi=600, bbox_inches='tight')
            plt.close()
        

  rank_zero_warn(
 82%|████████▏ | 9/11 [00:12<00:02,  1.44s/it]
  rank_zero_warn(
 78%|███████▊  | 18/23 [00:18<00:05,  1.04s/it]
  rank_zero_warn(
 89%|████████▊ | 31/35 [00:26<00:03,  1.19it/s]
  rank_zero_warn(
 91%|█████████▏| 42/46 [00:32<00:03,  1.28it/s]
  rank_zero_warn(
 91%|█████████▏| 42/46 [00:29<00:02,  1.42it/s]
  rank_zero_warn(
 90%|████████▉ | 52/58 [00:37<00:04,  1.40it/s]
  rank_zero_warn(
 90%|████████▉ | 52/58 [00:37<00:04,  1.39it/s]
  rank_zero_warn(
 90%|████████▉ | 52/58 [00:38<00:04,  1.35it/s]
  rank_zero_warn(
 82%|████████▏ | 9/11 [00:13<00:03,  1.51s/it]
  rank_zero_warn(
 78%|███████▊  | 18/23 [00:18<00:05,  1.04s/it]
  rank_zero_warn(
 89%|████████▊ | 31/35 [00:25<00:03,  1.24it/s]
  rank_zero_warn(
 91%|█████████▏| 42/46 [00:30<00:02,  1.37it/s]
  rank_zero_warn(
 91%|█████████▏| 42/46 [00:31<00:02,  1.35it/s]
  rank_zero_warn(
 90%|████████▉ | 52/58 [00:37<00:04,  1.40it/s]
  rank_zero_warn(
 90%|████████▉ | 52/58 [00:36<00:04,  1.42it/s]
  rank_zero_

1

11