In [1]:
import numpy as np
import bilby 
#import pycbc 
import sys
import matplotlib.pyplot as plt
import glob 

#import zuko
from glasflow import RealNVP, CouplingNSF
import torch
from torch.utils.data import Dataset, DataLoader
from torch import nn

import river.data
from river.data.datagenerator import DataGeneratorBilbyFD
from river.data.dataset import DatasetSVDStrainFDFromSVDWFonGPU, DatasetSVDStrainFDFromSVDWFonGPUBatch
#import river.data.utils as datautils
from river.data.utils import *

from river.models import embedding
from river.models.utils import *
from river.models.embedding.conv import EmbeddingConv1D, EmbeddingConv2D
from river.models.embedding.mlp import EmbeddingMLP1D
from river.models.inference.cnf import GlasNSFConv1DRes, GlasNSFConv1D, GlasNSFTest, GlasflowEmbdding

import logging
import sys
import os
import json
from copy import deepcopy



SWIGLAL standard output/error redirection is enabled in IPython.
This may lead to performance penalties. To disable locally, use:

with lal.no_swig_redirect_standard_output_error():
    ...

To disable globally, use:

lal.swig_redirect_standard_output_error(True)

Note however that this will likely lead to error messages from
LAL functions being either misdirected or lost when called from
Jupyter notebooks.


import lal

  import lal


In [66]:
config_path = 'test_train_output'
with open(f"{config_path}/config.json", 'r') as f:
    config = json.load(f)

config_datagenerator = config['data_generator_parameters']
config_training = config['training_parameters']
config_model = config['model_parameters']
config_precaldata = config['precaldata_parameters']



# Set up logger
PID = os.getpid()
device='cuda:1'
logger = logging.getLogger()
logger.setLevel(logging.INFO)
formatter = logging.Formatter('%(asctime)s | %(levelname)s | %(message)s')

stdout_handler = logging.StreamHandler(sys.stdout)
stdout_handler.setLevel(logging.DEBUG)
stdout_handler.setFormatter(formatter)

ckpt_dir = config['ckpt_dir']
if not os.path.exists(ckpt_dir):
    os.mkdir(ckpt_dir)
    logger.warning(f"{ckpt_dir} does not exist. Made dir {ckpt_dir}.")

logfilename = f"{ckpt_dir}/logs.log"
file_handler = logging.FileHandler(logfilename)
file_handler.setLevel(logging.DEBUG)
file_handler.setFormatter(formatter)
logger.addHandler(file_handler)
ckpt_path = f'{ckpt_dir}/checkpoint.pickle'

logger.info(f'PID={PID}.')
logger.info(f'Output path: {ckpt_dir}')

detector_names = config_datagenerator['detector_names']



logger.info(f'Loading precalculated data.')
train_filenames = glob.glob(f"{config_precaldata['train']['folder']}/batch*/*.h5")[:2]
valid_filenames = glob.glob(f"{config_precaldata['valid']['folder']}/batch*/*.h5")
#logger.info(f'{len(train_precaldata_filelist)}, {len(valid_precaldata_filelist)}')

data_generator = DataGeneratorBilbyFD(**config_datagenerator)

Vhfile = config_model['Vhfile']
Nbasis = config_model['Nbasis']
batch_size_train = config_training['batch_size_train']
minibatch_size_train = config_training['minibatch_size_train']
batch_size_valid = config_training['batch_size_valid']


dataset_train = DatasetSVDStrainFDFromSVDWFonGPUBatch(train_filenames, PARAMETER_NAMES_CONTEXT_PRECESSINGBNS_BILBY, data_generator,
                                 Nbasis=Nbasis, Vhfile=Vhfile, device=device, minibatch_size=minibatch_size_train, 
                                                      fix_extrinsic=True, shuffle=False, add_noise=False)
dataset_valid = DatasetSVDStrainFDFromSVDWFonGPU(valid_filenames, PARAMETER_NAMES_CONTEXT_PRECESSINGBNS_BILBY, data_generator,
                                 Nbasis=Nbasis, Vhfile=Vhfile, device=device, fix_extrinsic=True, shuffle=False, add_noise=False)



Nsample = len(dataset_train)*minibatch_size_train
Nvalid = len(dataset_valid)
logger.info(f'Nsample: {Nsample}, Nvalid: {Nvalid}.')
logger.info(f'batch_size_train: {batch_size_train}, batch_size_valid: {batch_size_valid}')

batch_size_train = 4096
train_loader = DataLoader(dataset_train, batch_size=batch_size_train // minibatch_size_train, shuffle=False)
valid_loader = DataLoader(dataset_valid, batch_size=batch_size_valid, shuffle=False)



02:10 bilby INFO    : Waveform generator initiated with
  frequency_domain_source_model: bilby.gw.source.lal_binary_neutron_star
  time_domain_source_model: None
  parameter_conversion: bilby.gw.conversion.convert_to_lal_binary_black_hole_parameters


Using bilby_default PSDs to generate data.


In [67]:
config_dict = config.copy()
config_dict['model_parameters']['embedding']['model'] = 'EmbeddingResConv1DMLP'
NCOND = 64
config_dict['model_parameters']['embedding']['nout'] = NCOND
#config_dict['model_parameters']['embedding']['ndet'] = 3
config_dict['model_parameters']['embedding']['nbasis'] = config_dict['model_parameters']['Nbasis']
config_dict['model_parameters']['embedding']['conv_params'] = {
        'in_channel':  [6,  3, ],
        'out_channel': [3, 1],
        'kernel_size': [16, 16, 16, 8, 8, 8, 8, 8, 8, 8, 4, 4, 4, 4, 2, 2, 2, 1],
        'stride':      [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        'padding':     [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        'dilation':    [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        'dropout':     [0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
    }
config_dict['model_parameters']['embedding']['mlp_params'] = {
        'in_features': [0,],
        'out_features': [NCOND,],
    }



In [68]:
config_dict['model_parameters']['flow'] = {}

config_dict['model_parameters']['flow']['model'] = 'CouplingNSF'
config_dict['model_parameters']['flow']['n_inputs'] = 17 
config_dict['model_parameters']['flow']['n_transforms'] = 3
config_dict['model_parameters']['flow']['n_conditional_inputs'] = NCOND
config_dict['model_parameters']['flow']['n_neurons'] = 6  # 32 by default
config_dict['model_parameters']['flow']['batch_norm_between_transforms'] = True
config_dict['model_parameters']['flow']['batch_norm_within_blocks'] = False
config_dict['model_parameters']['flow']['n_blocks_per_transform'] = 2  # 2 by default, 5
config_dict['model_parameters']['flow']['num_bins'] = 4  # 4 by default, 8
config_dict['model_parameters']['flow']['tail_bound'] = 1 # 5 by default, 1


In [69]:

#model = GlasNSFConv1DRes(config).to(device)
#model = GlasNSFConv1D(config).to(device)
model = GlasflowEmbdding(config_dict).to(device)


lr = config_training['lr']
gamma = config_training['gamma']
weight_decay = config_training['weight_decay']
optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

logger.info(f'Initial learning rate: {lr}')
logger.info(f'Gamma: {gamma}')

max_epoch = config_training['max_epoch']
#epoches_pretrain = config_training['epoches_pretrain']
epoches_save_loss = config_training['epoches_save_loss']
epoches_adjust_lr = config_training['epoches_adjust_lr']
epoches_adjust_lr_again = config_training['epoches_adjust_lr_again']
#load_from_previous_train = 1

load_from_previous_train = False
if load_from_previous_train:
    checkpoint = torch.load(ckpt_path)

    best_epoch = checkpoint['epoch']
    start_epoch = best_epoch + 1
    lr_updated_epoch = start_epoch
    model.load_state_dict(checkpoint['model_state_dict']) 

    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])


    train_losses = checkpoint['train_losses']
    valid_losses = checkpoint['valid_losses']


    logger.info(f'Loaded states from {ckpt_path}, epoch={start_epoch}.')
else:
    best_epoch = 0
    train_losses = []
    valid_losses = []

    start_epoch = 0
    lr_updated_epoch = start_epoch

npara_flow = count_parameters(model.flow)
#npara_embd_proj = count_parameters(model.embedding)
#npara_embd_res = count_parameters(model.resnet)
npara_total = count_parameters(model)
#logger.info(f'Learnable parameters: flow: {npara_flow}, embedding_PCA: {npara_embd_proj}, ResNet: {npara_embd_res}. Total: {npara_total}. ')
#logger.info(f'Learnable parameters: flow: {npara_flow}, embedding_PCA: {npara_embd_proj}. Total: {npara_total}. ')
logger.info(f'Learnable parameters: flow: {npara_flow}, total: {npara_total}. ')

###
#for g in optimizer.param_groups:
#    g['lr'] = 1e-5
#    logger.info(f'Set lr to 1e-5.')



Initialized MLP in channel: 482


In [70]:
best_epoch

0

In [71]:
f'Learnable parameters: flow: {npara_flow}, total: {npara_total}. '

'Learnable parameters: flow: 6197, total: 37819. '

In [72]:
start_epoch

0

In [73]:
def mytrain_GlasNSFWarpper(model, optimizer, dataloader, detector_names=None, ipca_gen=None, device='cpu',downsample_rate=1, minibatch_size=0):
    model.train()
    loss_list = []
    for theta, x in dataloader:
        optimizer.zero_grad()
        theta = theta.to(device)
        x = x.to(device)
        
        if minibatch_size>0:
            # x: [bs, minibatch_size, nchannel, nbasis]
            # theta: [bs, minibatch_size, npara]
            bs = x.shape[0]
            nbasis = x.shape[-1]
            nchannel = x.shape[-2]
            npara = theta.shape[-1]
            theta = theta.view(bs*minibatch_size, npara)
            x = x.view(bs*minibatch_size, nchannel, nbasis)
        loss = -model.log_prob(theta, x).mean()
        print('train loss ', loss)
        loss.backward()
        optimizer.step()

        loss_list.append(loss.detach())

    mean_loss = torch.stack(loss_list).mean().item() # mean(list of mean losses of each batch)
    std_loss = torch.stack(loss_list).std().item()
    return mean_loss, std_loss

def myeval_GlasNSFWarpper(model, dataloader, detector_names=None, ipca_gen=None, device='cpu',downsample_rate=1):
    model.eval()
    loss_list = []
    with torch.no_grad():
        for theta, x in dataloader:
            theta = theta.to(device)
            x = x.to(device)
            loss = -model.log_prob(theta, x).mean()
            print('valid loss ', loss)
            loss_list.append(loss.detach())

    mean_loss = torch.stack(loss_list).mean().item()
    std_loss = torch.stack(loss_list).std().item()
    return mean_loss, std_loss

In [74]:
logger.info(f'Training started, device:{device}. ')

max_epoch = 5
#for epoch in range(start_epoch, max_epoch):    
for epoch in range(0, 2):   
    train_loss, train_loss_std = mytrain_GlasNSFWarpper(model, optimizer, train_loader, device=device, minibatch_size=minibatch_size_train)
    valid_loss, valid_loss_std = myeval_GlasNSFWarpper(model, valid_loader, device=device)
    
    #
    model.train()
    with torch.no_grad():
        for theta, x in train_loader:

            theta = theta.to(device)
            x = x.to(device)
            bs = x.shape[0]
            nbasis = x.shape[-1]
            nchannel = x.shape[-2]
            npara = theta.shape[-1]
            theta = theta.view(bs*minibatch_size_train, npara)
            x = x.view(bs*minibatch_size_train, nchannel, nbasis)

            loss = -model.log_prob(theta, x).mean()
            print('recal train loss ', loss)

            break
        
    
    with torch.no_grad():
        for theta, x in valid_loader:
            
            theta = theta[0:2].to(device)
            x = x[0:2].to(device)
            loss = -model.log_prob(theta, x).mean()
            print('recal valid loss ', loss)
            
            break
            
    train_losses.append(train_loss)
    valid_losses.append(valid_loss)

    logger.info(f'epoch {epoch}, train loss = {train_loss}±{train_loss_std}, valid loss = {valid_loss}±{valid_loss_std}')

    #if valid_loss==min(valid_losses):
    if 0:
        best_epoch = epoch
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'train_losses': train_losses,
            'valid_losses': valid_losses,
            }, ckpt_path)

        logger.info(f'Current best epoch: {best_epoch}. Checkpoint saved.')

    if epoch%epoches_save_loss == 0 and epoch!=0:
        save_loss_data(train_losses, valid_losses, ckpt_dir)

    if epoch-best_epoch>=epoches_adjust_lr and epoch-lr_updated_epoch>=epoches_adjust_lr_again:
        adjust_lr(optimizer, gamma)
        logger.info(f'Validation loss has not dropped for {epoch-best_epoch} epoches. Learning rate is decreased by a factor of {gamma}.')
        lr_updated_epoch = epoch

    #dataset_train.shuffle_indexinfile()
    #dataset_train.shuffle_wflist()
    train_loader = DataLoader(dataset_train, batch_size=batch_size_train // minibatch_size_train, shuffle=False)

train loss  tensor(-1.8704, device='cuda:1', grad_fn=<NegBackward0>)
train loss  tensor(-3.5586, device='cuda:1', grad_fn=<NegBackward0>)
train loss  tensor(-5.5375, device='cuda:1', grad_fn=<NegBackward0>)
train loss  tensor(-7.9637, device='cuda:1', grad_fn=<NegBackward0>)
train loss  tensor(-10.2041, device='cuda:1', grad_fn=<NegBackward0>)
train loss  tensor(-12.0127, device='cuda:1', grad_fn=<NegBackward0>)
train loss  tensor(-12.9838, device='cuda:1', grad_fn=<NegBackward0>)
train loss  tensor(-13.4175, device='cuda:1', grad_fn=<NegBackward0>)
train loss  tensor(-13.8539, device='cuda:1', grad_fn=<NegBackward0>)
train loss  tensor(-14.2418, device='cuda:1', grad_fn=<NegBackward0>)
train loss  tensor(-14.7444, device='cuda:1', grad_fn=<NegBackward0>)
train loss  tensor(-15.0322, device='cuda:1', grad_fn=<NegBackward0>)
train loss  tensor(-15.1203, device='cuda:1', grad_fn=<NegBackward0>)
train loss  tensor(-15.2579, device='cuda:1', grad_fn=<NegBackward0>)
train loss  tensor(-15.7

train loss  tensor(-48.5215, device='cuda:1', grad_fn=<NegBackward0>)
train loss  tensor(-48.0518, device='cuda:1', grad_fn=<NegBackward0>)
train loss  tensor(-48.5588, device='cuda:1', grad_fn=<NegBackward0>)
train loss  tensor(-48.4366, device='cuda:1', grad_fn=<NegBackward0>)
train loss  tensor(-48.7357, device='cuda:1', grad_fn=<NegBackward0>)
train loss  tensor(-48.3036, device='cuda:1', grad_fn=<NegBackward0>)
train loss  tensor(-48.8190, device='cuda:1', grad_fn=<NegBackward0>)
train loss  tensor(-49.3209, device='cuda:1', grad_fn=<NegBackward0>)
train loss  tensor(-48.9690, device='cuda:1', grad_fn=<NegBackward0>)
train loss  tensor(-49.5478, device='cuda:1', grad_fn=<NegBackward0>)
train loss  tensor(-49.1854, device='cuda:1', grad_fn=<NegBackward0>)
train loss  tensor(-49.4106, device='cuda:1', grad_fn=<NegBackward0>)
train loss  tensor(-49.4097, device='cuda:1', grad_fn=<NegBackward0>)
valid loss  tensor(10425445., device='cuda:1')
valid loss  tensor(10425443., device='cuda:

In [75]:
model.train()
with torch.no_grad():
    for theta, x in valid_loader:

        theta = theta[0:2].to(device)
        x = x[0:2].to(device)
        loss = -model.log_prob(theta, x).mean()
        print('recal valid loss ', loss)

        break

recal valid loss  tensor(-69.6999, device='cuda:1')


In [76]:

model.eval()

    
with torch.no_grad():
    for theta, x in valid_loader:

        theta = theta[0:2].to(device)
        x = x[0:2].to(device)
        loss = -model.log_prob(theta, x).mean()
        print('recal valid loss ', loss)

        break

recal valid loss  tensor(7133270., device='cuda:1')


In [23]:
theta.shape

torch.Size([1, 17])

In [16]:
valid_losses

[3298545631232.0, 16470815.0, 2748165.0, 3105646.0, 3529984.5]

In [17]:
def mysample_GlasNSFWarpper(model, dataset, detector_names=None, ipca_gen=None, device='cpu', Nsample=5000, max_event=1e3,
                           batch_size=1):
    model.eval()
    loss_list = []
    sample_list = []
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    with torch.no_grad():
        i = 0
        for theta, x in dataloader:
            theta = theta.to(device)
            x = x.to(device)
            
            if type(dataset) == DatasetSVDStrainFDFromSVDWFonGPUBatch:
                theta = theta.view(dataset.minibatch_size*batch_size, theta.shape[-1])
                x = x.view(dataset.minibatch_size*batch_size, x.shape[-2], x.shape[-1])
            
            #print(theta)
            lenx = x.shape[-1]
            lentheta = theta.shape[-1]
            loss = -model.log_prob(theta, x=x).mean()
            #samples = model.sample(Nsample, x=x)


            loss_list.append(loss.detach().cpu())
            #sample_list.append(samples.cpu().numpy())
            i+=1
            if i>=max_event:
                break
    #samples = np.array(sample_list)
    #samples = torch.from_numpy(samples)

    #return samples.movedim(1,2), loss_list
    return loss_list

In [26]:
dataset_test = DatasetSVDStrainFDFromSVDWFonGPUBatch(train_filenames[0:1], PARAMETER_NAMES_CONTEXT_PRECESSINGBNS_BILBY, data_generator,
                                     Nbasis=512, Vhfile=Vhfile, fix_extrinsic=True, shuffle=False, add_noise=False)
#sample_list, loss_list =  mysample_GlasNSFWarpper(model, dataset_test, device=device, Nsample=5000, max_event=10e3, batch_size = 16384)
loss_list =  mysample_GlasNSFWarpper(model, dataset_test, device=device, Nsample=5000,
                                     max_event=10, batch_size = 1)



In [24]:
dataset_test[0]

(tensor([[ 1.3860e+00,  5.0744e-01,  1.1953e-03,  4.1631e-02,  2.1004e+00,
           1.1277e+00,  5.6505e+00,  6.1753e+00,  1.9041e+03, -5.9970e+02,
           9.6461e-01,  1.0000e+02,  1.0000e+00,  1.0000e+00,  1.0000e+00,
           3.0525e+00,  0.0000e+00]], device='cuda:0'),
 tensor([[[ 0.0371,  0.0506, -0.1011,  ..., -0.1928, -0.1159, -0.0336],
          [ 0.0051, -0.0815, -0.0314,  ..., -0.3367,  0.0696, -0.1080],
          [-0.0649,  0.0327,  0.1192,  ..., -0.1099,  0.0522, -0.0448],
          [ 0.0399, -0.0800, -0.0875,  ...,  0.3880, -0.1349,  0.0475],
          [-0.0760,  0.0218,  0.1387,  ...,  0.1063, -0.1521, -0.0273],
          [-0.0412,  0.1007,  0.0600,  ...,  0.0143, -0.0500,  0.0155]]],
        device='cuda:0'))

In [21]:
dataset_train[0]

(tensor([[1.3860e+00, 5.0744e-01, 1.1953e-03,  ..., 1.0000e+00, 3.0525e+00,
          0.0000e+00],
         [1.2603e+00, 7.4590e-01, 4.4495e-02,  ..., 1.0000e+00, 6.2830e+00,
          0.0000e+00],
         [2.0932e+00, 8.6074e-01, 6.6544e-02,  ..., 1.0000e+00, 1.0084e+00,
          0.0000e+00],
         ...,
         [2.5070e+00, 9.6915e-01, 9.2409e-02,  ..., 1.0000e+00, 6.0114e+00,
          0.0000e+00],
         [1.3420e+00, 5.9567e-01, 6.0994e-02,  ..., 1.0000e+00, 3.4336e+00,
          0.0000e+00],
         [1.4278e+00, 5.1544e-01, 7.7941e-02,  ..., 1.0000e+00, 6.1844e-01,
          0.0000e+00]], device='cuda:1'),
 tensor([[[ 3.7062e-02,  5.0639e-02, -1.0107e-01,  ..., -1.9278e-01,
           -1.1593e-01, -3.3627e-02],
          [ 5.0555e-03, -8.1532e-02, -3.1366e-02,  ..., -3.3670e-01,
            6.9623e-02, -1.0800e-01],
          [-6.4943e-02,  3.2658e-02,  1.1915e-01,  ..., -1.0990e-01,
            5.2219e-02, -4.4834e-02],
          [ 3.9880e-02, -7.9975e-02, -8.7511e-02,  .

In [27]:
loss_list

[tensor(3530054.7500),
 tensor(3529944.2500),
 tensor(3530091.7500),
 tensor(3530026.7500),
 tensor(3530048.5000),
 tensor(3530029.7500),
 tensor(3529994.7500),
 tensor(3530016.5000),
 tensor(3529949.),
 tensor(3529895.2500)]