In [16]:
import torch
from torch import nn

from einops import rearrange
from einops.layers.torch import Rearrange

# helpers

def posemb_sincos_1d(patches, temperature = 10000, dtype = torch.float32):
    _, n, dim, device, dtype = *patches.shape, patches.device, patches.dtype

    n = torch.arange(n, device = device)
    assert (dim % 2) == 0, 'feature dimension must be multiple of 2 for sincos emb'
    omega = torch.arange(dim // 2, device = device) / (dim // 2 - 1)
    omega = 1. / (temperature ** omega)

    n = n.flatten()[:, None] * omega[None, :]
    pe = torch.cat((n.sin(), n.cos()), dim = 1)
    return pe.type(dtype)

# classes

class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, dim),
        )
    def forward(self, x):
        return self.net(x)

class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64):
        super().__init__()
        inner_dim = dim_head *  heads
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.norm = nn.LayerNorm(dim)

        self.attend = nn.Softmax(dim = -1)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Linear(inner_dim, dim, bias = False)

    def forward(self, x):
        x = self.norm(x)

        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        attn = self.attend(dots)

        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim, heads = heads, dim_head = dim_head),
                FeedForward(dim, mlp_dim)
            ]))
    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return self.norm(x)

class SimpleViT(nn.Module):
    def __init__(self, *, seq_len, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64):
        super().__init__()

        assert seq_len % patch_size == 0

        num_patches = seq_len // patch_size
        patch_dim = channels * patch_size

        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (n p) -> b n (p c)', p = patch_size),
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            nn.LayerNorm(dim),
        )

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim)

        self.to_latent = nn.Identity()
        self.linear_head = nn.Linear(dim, num_classes)

    def forward(self, series):
        *_, n, dtype = *series.shape, series.dtype

        x = self.to_patch_embedding(series)
        pe = posemb_sincos_1d(x)
        x = rearrange(x, 'b ... d -> b (...) d') + pe
        print(x.shape)
        x = self.transformer(x)
        print(x.shape)
        x = x.mean(dim = 1)
        print(x.shape)

        x = self.to_latent(x)
        print(x.shape)
        return self.linear_head(x)

In [42]:
class SimpleViT2(nn.Module):
    def __init__(self, *, seq_len, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64, **kwargs):
        super().__init__()

        assert seq_len % patch_size == 0

        num_patches = seq_len // patch_size
        patch_dim = channels * patch_size

        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (n p) -> b n (p c)', p = patch_size),
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            nn.LayerNorm(dim),
        )

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim)

        self.to_latent = nn.Identity()
        self.linear_head = nn.Linear(dim*num_patches, num_classes)

    def forward(self, series):
        *_, n, dtype = *series.shape, series.dtype
        
        x = self.to_patch_embedding(series)
        pe = posemb_sincos_1d(x)
        x = rearrange(x, 'b ... d -> b (...) d') + pe

        x = self.transformer(x)
        return x
        #x = x.mean(dim = 1)
        #x = x.flatten(1,2)
        ##x = self.to_latent(x)
        return self.linear_head(x)

In [43]:
v = SimpleViT2(
        seq_len = 3328,
        patch_size = 256,
        num_classes = 256,
        dim = 128,
        depth = 6,
        heads = 8,
        mlp_dim = 2048,
        channels = 6, 
        dim_head = 64,
    ).to('cuda')



In [73]:
x = torch.randn(10) + torch.randn(10)*1j

In [74]:
torch.cat((x.real, x.imag)).float()

tensor([-0.0688, -1.8898,  0.4153, -0.4304,  0.3556,  0.0640, -1.0887, -1.0803,
         0.4693,  1.3235, -0.4128, -0.5735,  1.2037,  0.7315, -0.8733, -0.3219,
        -0.2529,  1.4177,  2.1978, -0.0811])

In [63]:
ff.repeat((6,1)).shape

torch.Size([6, 10])

In [64]:
ff

tensor([-0.2239, -2.1005, -0.2320, -1.4094,  0.4021,  0.7713,  0.8216,  1.7990,
        -0.3338, -3.4558])

In [67]:
torch.cat((ff, ))

TypeError: cat() received an invalid combination of arguments - got (Tensor, Tensor), but expected one of:
 * (tuple of Tensors tensors, int dim, *, Tensor out)
 * (tuple of Tensors tensors, name dim, *, Tensor out)


In [71]:
torch.cat((ff,ff))

tensor([-0.2239, -2.1005, -0.2320, -1.4094,  0.4021,  0.7713,  0.8216,  1.7990,
        -0.3338, -3.4558, -0.2239, -2.1005, -0.2320, -1.4094,  0.4021,  0.7713,
         0.8216,  1.7990, -0.3338, -3.4558])

In [66]:
ff.repeat((6,1))

tensor([[-0.2239, -2.1005, -0.2320, -1.4094,  0.4021,  0.7713,  0.8216,  1.7990,
         -0.3338, -3.4558],
        [-0.2239, -2.1005, -0.2320, -1.4094,  0.4021,  0.7713,  0.8216,  1.7990,
         -0.3338, -3.4558],
        [-0.2239, -2.1005, -0.2320, -1.4094,  0.4021,  0.7713,  0.8216,  1.7990,
         -0.3338, -3.4558],
        [-0.2239, -2.1005, -0.2320, -1.4094,  0.4021,  0.7713,  0.8216,  1.7990,
         -0.3338, -3.4558],
        [-0.2239, -2.1005, -0.2320, -1.4094,  0.4021,  0.7713,  0.8216,  1.7990,
         -0.3338, -3.4558],
        [-0.2239, -2.1005, -0.2320, -1.4094,  0.4021,  0.7713,  0.8216,  1.7990,
         -0.3338, -3.4558]])

In [22]:
3328/256

13.0

In [44]:
time_series = torch.randn(1024, 6, 3328).to('cuda')


In [45]:
logits = v(time_series)

In [46]:
logits.shape

torch.Size([1024, 13, 128])

In [47]:
logit = logits[0]

In [49]:
logit.mean()

tensor(0., device='cuda:0', grad_fn=<MeanBackward0>)

In [None]:
coun

In [25]:
13*128

1664

In [11]:
logits.shape

torch.Size([1024, 256])

In [None]:
nvidia

In [41]:
!nvidia-smi

Tue Apr  9 07:05:39 2024       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA GeForce GTX 1650        Off |   00000000:02:00.0 Off |                  N/A |
| 30%   23C    P8              4W /   75W |       2MiB /   4096MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   1  Tesla P40                     

In [26]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

In [27]:
count_parameters(v)


4968064

In [40]:
count_parameters(v)


5361280

In [12]:
1000 * 1000

1000000

# train

In [2]:

import numpy as np
import bilby 
#import pycbc 
import sys
import matplotlib.pyplot as plt
import glob 

#import zuko
from glasflow import RealNVP, CouplingNSF
import torch
from torch.utils.data import Dataset, DataLoader
from torch import nn

import river.data
from river.data.datagenerator import DataGeneratorBilbyFD
from river.data.dataset_multiband import DatasetMBStrainFDFromMBWFonGPU, DatasetMBStrainFDFromMBWFonGPUBatch
#import river.data.utils as datautils
from river.data.utils import *

from river.models import embedding
from river.models.utils import *
from river.models.embedding.conv import EmbeddingConv1D, EmbeddingConv2D
from river.models.embedding.mlp import EmbeddingMLP1D
from river.models.inference.cnf import GlasNSFConv1DRes, GlasNSFConv1D, GlasNSFTest, GlasflowEmbdding

import logging
import sys
import os
import json
from copy import deepcopy





SWIGLAL standard output/error redirection is enabled in IPython.
This may lead to performance penalties. To disable locally, use:

with lal.no_swig_redirect_standard_output_error():
    ...

To disable globally, use:

lal.swig_redirect_standard_output_error(True)

Note however that this will likely lead to error messages from
LAL functions being either misdirected or lost when called from
Jupyter notebooks.


import lal

  import lal


In [3]:
import time

config_path = '/home/qian.hu/mlpe/river/scripts/trained_models/BNS20MB_8M'
with open(f"{config_path}/config.json", 'r') as f:
    config = json.load(f)

config_datagenerator = config['data_generator_parameters']
config_training = config['training_parameters']
config_model = config['model_parameters']
config_precaldata = config['precaldata_parameters']


dmin = config_datagenerator['d_min']
dmax = config_datagenerator['d_max']
dpower = config_datagenerator['d_power']
tc_min = config_datagenerator['tc_min']
tc_max = config_datagenerator['tc_max']
timing_std = config_datagenerator['timing_std']
full_duration = config_datagenerator['full_duration']

detector_names = config_datagenerator['detector_names']


wf_folder_train = config_precaldata['train']['folder']
wf_folder_valid = config_precaldata['valid']['folder']
asd_folder = config_precaldata['asd_path']

batch_size_train = config_training['batch_size_train']
minibatch_size_train = config_training['minibatch_size_train']
batch_size_valid = config_training['batch_size_valid']
minibatch_size_valid = config_training['minibatch_size_valid']

device='cuda'

In [4]:
t1 =time.time()
minibatch_size_train = 1024
dataset_train = DatasetMBStrainFDFromMBWFonGPUBatch(wf_folder = config_precaldata['train']['folder'],
                                                    asd_folder = asd_folder,
                                                    parameter_names = PARAMETER_NAMES_CONTEXT_PRECESSINGBNS_BILBY, 
                                                    full_duration = full_duration, 
                                                    detector_names = detector_names,
                                                    dmin = dmin,
                                                    dmax = dmax,
                                                    dpower = dpower, 
                                                    tc_min = tc_min,
                                                    tc_max = tc_max,
                                                    timing_std = timing_std,
                                                    device = device,
                                                    minibatch_size = minibatch_size_train,
                                                    add_noise = True,
                                                    fix_extrinsic = False,
                                                    reparameterize = True,
                                                    random_asd = False)
t2 =time.time()
print(t2-t1)

25.129700183868408


In [5]:
minibatch_size_train

1024

In [6]:
batch_size_train = 8192
train_loader = DataLoader(dataset_train, batch_size=batch_size_train // minibatch_size_train, shuffle=False)


In [7]:
len(dataset_train)

8192

In [8]:
tt = []

t1 =time.time()
for t,x in train_loader:
    t2 =time.time()
    tt.append(t2-t1)
    print(t2-t1, '!')
    t1 =time.time()

2.1345620155334473 !
0.2530670166015625 !
1.597344160079956 !
0.23201775550842285 !
1.4167084693908691 !
0.2352135181427002 !
4.226032018661499 !
3.6697959899902344 !
1.3655133247375488 !
0.26023006439208984 !
0.2669806480407715 !
3.169447183609009 !
0.22391986846923828 !
0.2427988052368164 !
0.2726259231567383 !
huh
23.914116144180298 !
0.2500483989715576 !
0.277482271194458 !
0.26161742210388184 !
0.26476573944091797 !
0.2657797336578369 !
0.3048679828643799 !
0.26399731636047363 !
0.2325451374053955 !
0.28082895278930664 !
0.2776827812194824 !
0.28624439239501953 !
0.30045628547668457 !
0.2827482223510742 !
0.31584954261779785 !
0.28054308891296387 !
huh
22.877390146255493 !
0.25937867164611816 !
0.26985979080200195 !
0.27018117904663086 !
0.2723073959350586 !
0.28502464294433594 !
8.500895500183105 !
11.06877589225769 !
0.2531454563140869 !
0.27324461936950684 !
0.3037087917327881 !
0.34827566146850586 !
0.29572415351867676 !
0.24181008338928223 !
0.24010896682739258 !
0.2539408206

KeyboardInterrupt: 

In [47]:
del dataset_train, train_loader

In [4]:

config_path = '/home/qian.hu/mlpe/river/scripts/trained_models/BNS20MB_8M'
with open(f"{config_path}/config.json", 'r') as f:
    config = json.load(f)

config_datagenerator = config['data_generator_parameters']
config_training = config['training_parameters']
config_model = config['model_parameters']
config_precaldata = config['precaldata_parameters']


dmin = config_datagenerator['d_min']
dmax = config_datagenerator['d_max']
dpower = config_datagenerator['d_power']
tc_min = config_datagenerator['tc_min']
tc_max = config_datagenerator['tc_max']
timing_std = config_datagenerator['timing_std']
full_duration = config_datagenerator['full_duration']



# Set up logger
PID = os.getpid()
device='cuda'
logger = logging.getLogger()
logger.setLevel(logging.INFO)
formatter = logging.Formatter('%(asctime)s | %(levelname)s | %(message)s')

stdout_handler = logging.StreamHandler(sys.stdout)
stdout_handler.setLevel(logging.DEBUG)
stdout_handler.setFormatter(formatter)

ckpt_dir = 'test_train_output'
if not os.path.exists(ckpt_dir):
    os.mkdir(ckpt_dir)
    logger.warning(f"{ckpt_dir} does not exist. Made dir {ckpt_dir}.")

logfilename = f"{ckpt_dir}/logs.log"
file_handler = logging.FileHandler(logfilename)
file_handler.setLevel(logging.DEBUG)
file_handler.setFormatter(formatter)
logger.addHandler(file_handler)
ckpt_path = f'{ckpt_dir}/checkpoint.pickle'

logger.info(f'PID={PID}.')
logger.info(f'Output path: {ckpt_dir}')

detector_names = config_datagenerator['detector_names']


logger.info(f'Loading precalculated data.')
wf_folder_train = config_precaldata['valid']['folder']
wf_folder_valid = config_precaldata['valid']['folder']
asd_folder = config_precaldata['asd_path']

batch_size_train = config_training['batch_size_valid']
minibatch_size_train = config_training['minibatch_size_valid']
batch_size_valid = config_training['batch_size_valid']
minibatch_size_valid = config_training['minibatch_size_valid']


dataset_train = DatasetMBStrainFDFromMBWFonGPUBatch(wf_folder = wf_folder_train,
                                                    asd_folder = asd_folder,
                                                    parameter_names = PARAMETER_NAMES_CONTEXT_PRECESSINGBNS_BILBY, 
                                                    full_duration = full_duration, 
                                                    detector_names = detector_names,
                                                    dmin = dmin,
                                                    dmax = dmax,
                                                    dpower = dpower, 
                                                    tc_min = tc_min,
                                                    tc_max = tc_max,
                                                    timing_std = timing_std,
                                                    device = device,
                                                    minibatch_size = minibatch_size_train,
                                                    add_noise = True,
                                                    fix_extrinsic = False,
                                                    reparameterize = True,
                                                    random_asd = False)

dataset_valid = DatasetMBStrainFDFromMBWFonGPUBatch(wf_folder = wf_folder_valid,
                                                    asd_folder = asd_folder,
                                                    parameter_names = PARAMETER_NAMES_CONTEXT_PRECESSINGBNS_BILBY, 
                                                    full_duration = full_duration, 
                                                    detector_names = detector_names,
                                                    dmin = dmin,
                                                    dmax = dmax,
                                                    dpower = dpower, 
                                                    tc_min = tc_min,
                                                    tc_max = tc_max,
                                                    timing_std = timing_std,
                                                    device = device,
                                                    minibatch_size = minibatch_size_train,
                                                    add_noise = True,
                                                    fix_extrinsic = False,
                                                    reparameterize = True,
                                                    random_asd = False)


Nsample = len(dataset_train)*minibatch_size_train
Nvalid = len(dataset_valid)*minibatch_size_valid
logger.info(f'Nsample: {Nsample}, Nvalid: {Nvalid}.')
logger.info(f'batch_size_train: {batch_size_train}, batch_size_valid: {batch_size_valid}')

train_loader = DataLoader(dataset_train, batch_size=batch_size_train // minibatch_size_train, shuffle=False)
valid_loader = DataLoader(dataset_valid, batch_size=batch_size_valid // minibatch_size_valid, shuffle=False)

model = GlasflowEmbdding(config).to(device)


lr = config_training['lr']
gamma = config_training['gamma']
weight_decay = config_training['weight_decay']
optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

logger.info(f'Initial learning rate: {lr}')
logger.info(f'Gamma: {gamma}')

max_epoch = config_training['max_epoch']
#epoches_pretrain = config_training['epoches_pretrain']
epoches_save_loss = config_training['epoches_save_loss']
epoches_adjust_lr = config_training['epoches_adjust_lr']
epoches_adjust_lr_again = config_training['epoches_adjust_lr_again']
#load_from_previous_train = 1
load_from_previous_train = config_training['load_from_previous_train']
if load_from_previous_train:
    checkpoint = torch.load(ckpt_path)

    best_epoch = checkpoint['epoch']
    start_epoch = best_epoch + 1
    lr_updated_epoch = start_epoch
    model.load_state_dict(checkpoint['model_state_dict']) 

    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])


    train_losses = checkpoint['train_losses']
    valid_losses = checkpoint['valid_losses']


    logger.info(f'Loaded states from {ckpt_path}, epoch={start_epoch}.')
else:
    best_epoch = 0
    train_losses = []
    valid_losses = []

    start_epoch = 0
    lr_updated_epoch = start_epoch

npara_embd = count_parameters(model.embedding)
npara_flow = count_parameters(model.flow)
npara_total = count_parameters(model)
logger.info(f'Learnable parameters: embedding: {npara_embd}, flow: {npara_flow}, total: {npara_total}. ')

###
#for g in optimizer.param_groups:
#    g['lr'] = 1e-5
#    logger.info(f'Set lr to 1e-5.')

logger.info(f'Training started, device:{device}. ')


In [5]:
start_epoch

0

In [6]:
max_epoch

1000

In [17]:

for epoch in range(start_epoch, 1):    

    train_loss, train_loss_std = train_GlasNSFWarpper(model, optimizer, train_loader, device=device, minibatch_size=minibatch_size_train)
    valid_loss, valid_loss_std = eval_GlasNSFWarpper(model, valid_loader, device=device, minibatch_size=minibatch_size_valid)


    train_losses.append(train_loss)
    valid_losses.append(valid_loss)

    logger.info(f'epoch {epoch}, train loss = {train_loss}±{train_loss_std}, valid loss = {valid_loss}±{valid_loss_std}')

    if valid_loss==min(valid_losses):
        best_epoch = epoch
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'train_losses': train_losses,
            'valid_losses': valid_losses,
            }, ckpt_path)

        logger.info(f'Current best epoch: {best_epoch}. Checkpoint saved.')

    if epoch%epoches_save_loss == 0 and epoch!=0:
        save_loss_data(train_losses, valid_losses, ckpt_dir)

    if epoch-best_epoch>=epoches_adjust_lr and epoch-lr_updated_epoch>=epoches_adjust_lr_again:
        adjust_lr(optimizer, gamma)
        logger.info(f'Validation loss has not dropped for {epoch-best_epoch} epoches. Learning rate is decreased by a factor of {gamma}.')
        lr_updated_epoch = epoch

    #dataset_train.shuffle_indexinfile()
    dataset_train.shuffle_wflist()
    train_loader = DataLoader(dataset_train, batch_size=batch_size_train // minibatch_size_train, shuffle=False)

In [36]:
aa = []

In [38]:
aa.append(None)

In [39]:
aa

[None]

In [18]:
train_loss

23.024776458740234

In [19]:
valid_loss

7709.6376953125

In [10]:
valid_loss_std

134071248.0

In [20]:
for t, x in train_loader:
    pass

In [22]:
t.shape

torch.Size([50, 10, 17])

In [23]:
x.shape

torch.Size([50, 10, 6, 3328])

In [27]:
loss_list = []
minibatch_size = 10
for theta, x in train_loader:
    theta = theta.to(device)
    x = x.to(device)

    if minibatch_size>0:
        # x: [bs, minibatch_size, nchannel, nbasis]
        # theta: [bs, minibatch_size, npara]
        bs = x.shape[0]
        nbasis = x.shape[-1]
        nchannel = x.shape[-2]
        npara = theta.shape[-1]
        theta = theta.view(bs*minibatch_size, npara)
        x = x.view(bs*minibatch_size, nchannel, nbasis)
    loss = -model.log_prob(theta, x).mean()

    loss_list.append(loss.detach())

In [28]:
loss_list

[tensor(7733.1514, device='cuda:0'), tensor(7739.5537, device='cuda:0')]

In [34]:
x.shape

torch.Size([500, 6, 3328])

In [35]:
loss

tensor(7739.5537, device='cuda:0', grad_fn=<NegBackward0>)

In [46]:
!nvidia-smi


Thu Mar 28 22:14:38 2024       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA GeForce RTX 3080        Off |   00000000:01:00.0 Off |                  N/A |
|  0%   16C    P8             15W /  320W |       3MiB /  10240MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA A100-PCIE-40GB         