In [28]:
# Import libraries

import os
# Run from the experiment directory; the local `gw_parameters` module imported
# further down is presumably resolved relative to it.
os.chdir('/home/scur2012/Thesis/master-thesis/experiments/tmnre')

import torch
torch.set_float32_matmul_precision('high')
from torch import nn
# Import the functional API from its public location. `torch.functional.F`
# only exists because that module happens to import it internally.
import torch.nn.functional as F
import swyft.lightning as sl
from sklearn.metrics import roc_curve, auc

import matplotlib.pyplot as plt
from pytorch_lightning import loggers as pl_loggers
import logging
# Raise the level of two chatty Lightning loggers to WARNING.
logging.getLogger("pytorch_lightning.utilities.rank_zero").setLevel(logging.WARNING)
logging.getLogger("pytorch_lightning.accelerators.cuda").setLevel(logging.WARNING)

import importlib
import gw_parameters
# Reload so that edits to gw_parameters.py are picked up without a kernel restart.
importlib.reload(gw_parameters)

<module 'gw_parameters' from '/gpfs/home3/scur2012/Thesis/master-thesis/experiments/tmnre/gw_parameters.py'>

In [29]:
# 1D Unet implementation below
class DoubleConv(nn.Module):
    """Two consecutive ``Conv1d -> BatchNorm1d -> ReLU`` stages.

    Args:
        in_channels: number of channels of the incoming tensor.
        out_channels: number of channels produced by the second stage.
        kernel_size: kernel width shared by both convolutions.
        mid_channels: channels between the two stages; any falsy value
            (``None`` or ``0``) falls back to ``out_channels``.
        padding: padding shared by both convolutions.
        bias: whether the convolutions carry a bias term.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=3,
        mid_channels=None,
        padding=1,
        bias=False,
    ):
        super().__init__()
        hidden_channels = mid_channels or out_channels
        conv_options = dict(kernel_size=kernel_size, padding=padding, bias=bias)

        # Layers are created in the same order as they are applied, so the
        # sequential indices (0..5) line up with conv/bn/relu, conv/bn/relu.
        stages = []
        for n_in, n_out in ((in_channels, hidden_channels), (hidden_channels, out_channels)):
            stages.append(nn.Conv1d(n_in, n_out, **conv_options))
            stages.append(nn.BatchNorm1d(n_out))
            stages.append(nn.ReLU(inplace=True))
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through both convolution stages and return the result."""
        convolved = self.double_conv(x)
        return convolved


class Down(nn.Module):
    """Encoder stage: max-pooling followed by a ``DoubleConv``.

    Args:
        in_channels: number of channels of the incoming tensor.
        out_channels: number of channels produced by the ``DoubleConv``.
        down_sampling: kernel size handed to ``nn.MaxPool1d``.
    """

    def __init__(self, in_channels, out_channels, down_sampling=2):
        super().__init__()
        pooling = nn.MaxPool1d(down_sampling)
        convolutions = DoubleConv(in_channels, out_channels)
        self.maxpool_conv = nn.Sequential(pooling, convolutions)

    def forward(self, x):
        """Pool ``x`` and apply the double convolution."""
        reduced = self.maxpool_conv(x)
        return reduced


class Up(nn.Module):
    """Decoder stage with an attention-gated skip connection.

    The input is upsampled by a transposed convolution that halves the channel
    count, padded to the skip connection's length, used as the gate of an
    ``AttentionGate`` over the skip connection, concatenated with the gated
    skip and finally passed through a ``DoubleConv``.

    Args:
        in_channels: channels of the tensor coming from the deeper level; also
            the channel count the final ``DoubleConv`` expects after the
            concatenation.
        out_channels: channels produced by this stage; also the channel count
            the attention gate is built for.
        kernel_size: kernel size of the transposed convolution.
        stride: stride of the transposed convolution.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=2):
        super().__init__()
        halved_channels = in_channels // 2
        self.up = nn.ConvTranspose1d(
            in_channels, halved_channels, kernel_size=kernel_size, stride=stride
        )
        self.att = AttentionGate(out_channels, out_channels // 2)
        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        """Upsample ``x1``, gate the skip connection ``x2`` with it and merge both."""
        upsampled = self.up(x1)

        # Pad along the last axis so the upsampled signal is as long as the skip
        # connection; an odd difference puts the extra sample on the right.
        length_gap = x2.size(2) - upsampled.size(2)
        pad_left = length_gap // 2
        pad_right = length_gap - pad_left
        upsampled = F.pad(upsampled, [pad_left, pad_right])

        gated_skip = self.att(upsampled, x2)
        merged = torch.cat([gated_skip, upsampled], dim=1)
        return self.conv(merged)

class AttentionGate(nn.Module):
    """Additive attention gate applied to a skip connection.

    Gate and skip connection are each projected by a kernel-size-1 convolution
    plus batch norm, summed, passed through a ReLU, and reduced to a single
    channel squashed by a sigmoid. That single channel multiplies the skip
    connection.

    Args:
        in_channels: channel count expected for both the gate and the skip
            connection.
        out_channels: channel count of the intermediate projections.
        bias: whether the convolutions carry a bias term.
    """

    def __init__(self, in_channels, out_channels, bias=False):
        super().__init__()

        def pointwise(n_in, n_out):
            # Kernel-size-1 convolution shared by all three projections.
            return nn.Conv1d(n_in, n_out, kernel_size=1, stride=1, padding=0, bias=bias)

        self.Wg = nn.Sequential(
            pointwise(in_channels, out_channels),
            nn.BatchNorm1d(out_channels),
        )
        self.Wx = nn.Sequential(
            pointwise(in_channels, out_channels),
            nn.BatchNorm1d(out_channels),
        )
        self.psi = nn.Sequential(
            pointwise(out_channels, 1),
            nn.BatchNorm1d(1),
            nn.Sigmoid(),
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        """Return ``x`` scaled by the attention computed from gate ``g`` and ``x``.

        Args:
            g: gating signal.
            x: skip connection to be re-weighted.
        """
        combined = self.relu(self.Wg(g) + self.Wx(x))
        attention = self.psi(combined)
        return attention * x
    
class OutConv(nn.Module):
    """Output head: a single ``Conv1d`` mapping to the requested channel count.

    Args:
        in_channels: number of channels of the incoming tensor.
        out_channels: number of channels to produce.
        kernel_size: kernel width of the convolution.
    """

    def __init__(self, in_channels, out_channels, kernel_size=1):
        super().__init__()
        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size=kernel_size)

    def forward(self, x):
        """Apply the output convolution to ``x``."""
        projected = self.conv(x)
        return projected


class Unet(nn.Module):
    """1D U-Net with four encoder and four decoder levels.

    Args:
        n_in_channels: channels of the input tensor.
        n_out_channels: channels of the output tensor.
        sizes: channel widths per level; indices 0..4 are used.
        down_sampling: per-level pooling factors; indices 0..3 are used. Each
            value is also passed to the matching ``Up`` stage as its third
            positional argument.
    """

    def __init__(
        self,
        n_in_channels,
        n_out_channels,
        sizes=(16, 32, 64, 128, 256),
        down_sampling=(2, 2, 2, 2),
    ):
        super().__init__()
        depth = 4

        self.inc = DoubleConv(n_in_channels, sizes[0])

        # Encoder: down1..down4, each going from sizes[i] to sizes[i + 1].
        for level in range(depth):
            setattr(
                self,
                f"down{level + 1}",
                Down(sizes[level], sizes[level + 1], down_sampling[level]),
            )

        # Decoder: up1..up4 mirror the encoder from the deepest level outwards.
        for step, level in enumerate(reversed(range(depth)), start=1):
            setattr(
                self,
                f"up{step}",
                Up(sizes[level + 1], sizes[level], down_sampling[level]),
            )

        self.outc = OutConv(sizes[0], n_out_channels)

    def forward(self, x):
        """Encode ``x``, decode it using the skip connections and project the result."""
        skip0 = self.inc(x)
        skip1 = self.down1(skip0)
        skip2 = self.down2(skip1)
        skip3 = self.down3(skip2)
        bottleneck = self.down4(skip3)

        decoded = self.up1(bottleneck, skip3)
        decoded = self.up2(decoded, skip2)
        decoded = self.up3(decoded, skip1)
        decoded = self.up4(decoded, skip0)
        return self.outc(decoded)


class LinearCompression(nn.Module):
    """MLP that compresses a flattened input to 16 features.

    The layer widths are 1024 -> 256 -> 64 -> 16 with a ReLU between
    consecutive linear layers. The layers are lazy, so the input width is
    inferred on the first forward pass.
    """

    def __init__(self):
        super().__init__()
        layers = []
        for width in (1024, 256, 64):
            layers.extend((nn.LazyLinear(width), nn.ReLU()))
        # No activation after the final projection.
        layers.append(nn.LazyLinear(16))
        self.sequential = nn.Sequential(*layers)

    def forward(self, x):
        """Return the 16-dimensional compression of ``x``."""
        compressed = self.sequential(x)
        return compressed


In [30]:
# The network architecture

class InferenceNetwork(sl.SwyftModule):
    """Swyft module estimating 1D log-ratios from two data representations.

    Two ``Unet`` branches process the ``d_t`` and ``d_f_w`` inputs. Each output
    is flattened and compressed to 16 features by a ``LinearCompression``; the
    32 concatenated features feed a ``LogRatioEstimator_1dim``.

    Expected ``conf`` keys:
        one_d_only, marginals: stored on the instance, not used in this class.
        training_batch_size: stored as ``self.batch_size``.
        shuffling: if truthy, noise is permuted across the batch in ``forward``.
        priors: dict with ``int_priors`` and ``ext_priors``; the total number
            of keys in both is the number of parameters.
        ifo_list: its length sets the input channels of the two U-Nets
            (once for ``d_t``, twice for ``d_f_w``).
        learning_rate: learning rate of the Adam optimizer.
    """

    def __init__(self, **conf):
        super().__init__()
        self.one_d_only = conf["one_d_only"]
        self.batch_size = conf["training_batch_size"]
        self.noise_shuffling = conf["shuffling"]
        self.num_params = len(conf["priors"]["int_priors"].keys()) + len(
            conf["priors"]["ext_priors"].keys()
        )
        self.marginals = conf["marginals"]

        self.netw_t = Unet(
            n_in_channels=len(conf["ifo_list"]),
            n_out_channels=1,
            sizes=(16, 32, 64, 128, 256),
            down_sampling=(8, 8, 8, 8),
        )

        # Twice as many input channels as netw_t -- presumably real and
        # imaginary parts per detector; confirm against the simulator.
        self.netw_f = Unet(
            n_in_channels=2 * len(conf["ifo_list"]),
            n_out_channels=1,
            sizes=(16, 32, 64, 128, 256),
            down_sampling=(2, 2, 2, 2),
        )

        self.flatten = nn.Flatten(1)
        self.linear_t = LinearCompression()
        self.linear_f = LinearCompression()

        # num_features = 16 (linear_t) + 16 (linear_f).
        self.logratios_1d = sl.LogRatioEstimator_1dim(
            num_features=32, num_params=int(self.num_params), varnames="z_total"
        )

        self.optimizer_init = sl.AdamOptimizerInit(lr=conf["learning_rate"])

    def forward(self, A, B):
        """Compute the 1D log-ratios for data ``A`` and parameters ``B``.

        Args:
            A: mapping with tensors ``d_t``, ``n_t``, ``d_f_w`` and ``n_f_w``;
                each ``n_*`` entry is added to the matching ``d_*`` entry.
            B: mapping with the parameter tensor ``z_total``.

        Returns:
            The output of the ``LogRatioEstimator_1dim``.
        """
        n_samples = A["d_t"].size(0)
        if self.noise_shuffling and n_samples != 1:
            # Size the permutation from the batch actually received, not from
            # the configured training batch size. A shorter final batch or a
            # different validation batch size would otherwise index out of
            # range or produce mismatched shapes.
            noise_shuffling = torch.randperm(n_samples, device=A["d_t"].device)
            d_t = A["d_t"] + A["n_t"][noise_shuffling]
            d_f_w = A["d_f_w"] + A["n_f_w"][noise_shuffling]
        else:
            d_t = A["d_t"] + A["n_t"]
            d_f_w = A["d_f_w"] + A["n_f_w"]

        z_total = B["z_total"]

        d_t = self.netw_t(d_t)
        # The last element along the final axis is dropped before the U-Net.
        d_f_w = self.netw_f(d_f_w[:, :, :-1])
        flatten_t = self.flatten(d_t)
        flatten_f = self.flatten(d_f_w)
        features_t = self.linear_t(flatten_t)
        features_f = self.linear_f(flatten_f)

        features = torch.cat([features_t, features_f], dim=1)
        logratios_1d = self.logratios_1d(features, z_total)

        return logratios_1d
 

In [31]:
# Settings for trainer and network

# Defaults and parameter limits come from the local gw_parameters module.
conf = gw_parameters.default_conf
bounds = gw_parameters.limits

intrinsic_variables = gw_parameters.intrinsic_variables
extrinsic_variables = gw_parameters.extrinsic_variables

# Training-loop and dataloader options.
trainer_settings = {
    "min_epochs": 1,
    "max_epochs": 10,
    "early_stopping": 7,
    "num_workers": 8,
    "training_batch_size": 256,
    "validation_batch_size": 256,
    "train_split": 0.9,
    "val_split": 0.1,
}

# Keyword arguments for InferenceNetwork.
network_settings = {
    # Peregrine
    "shuffling": True,
    "include_noise": False,
    "priors": {
        "int_priors": conf['priors']['int_priors'],
        "ext_priors": conf['priors']['ext_priors'],
    },
    "marginals": ((0, 1),),
    "one_d_only": True,
    "ifo_list": conf["waveform_params"]["ifo_list"],
    "learning_rate": 5e-4,
    "training_batch_size": trainer_settings['training_batch_size'],
}


In [32]:
# Open the Zarr store that holds the simulated training data
simulation_store_path = '/scratch-shared/scur2012/training_data/default_limits_2e6/training_data'
zarr_store = sl.ZarrStore(simulation_store_path)


In [33]:
# Initialise dataloaders and setup trainer

# Training samples: store indices 0..30000.
train_data = zarr_store.get_dataloader(
    num_workers=trainer_settings['num_workers'],
    batch_size=trainer_settings['training_batch_size'],
    idx_range=[0, 30000],
    # No post-load hook; None is used for both loaders so they are configured
    # identically (the training loader previously passed False).
    on_after_load_sample=None
)

# Validation samples: store indices 30000..35000.
val_data = zarr_store.get_dataloader(
    num_workers=trainer_settings['num_workers'],
    batch_size=trainer_settings['validation_batch_size'],
    idx_range=[30000, 35000],
    on_after_load_sample=None
)

In [34]:
# Initialise network
# network_settings supplies every conf key InferenceNetwork.__init__ reads.

network = InferenceNetwork(**network_settings)



In [35]:
name_of_run = 'unet_attention_layer_1'

# TensorBoard logger; the run name is used as the logger version
logger_tbl = pl_loggers.TensorBoardLogger(
    save_dir="/home/scur2012/Thesis/master-thesis/experiments/unet_mods",
    name="tb_logs",
    version=name_of_run,
    default_hp_metric=False,
)

# Single-GPU trainer. Epoch limits come from trainer_settings;
# val_check_interval and max_time are set inline.
swyft_trainer = sl.SwyftTrainer(
    accelerator='gpu',
    devices=1,
    min_epochs=trainer_settings["min_epochs"],
    max_epochs=trainer_settings["max_epochs"],
    logger=logger_tbl,
    enable_progress_bar=True,
    val_check_interval=50,
    max_time='00:00:10:00',
)


  rank_zero_warn(


In [36]:
# Start training
# Fits `network` using `train_data` for training and `val_data` for validation.
swyft_trainer.fit(network, train_data, val_data)


  | Name         | Type                   | Params
--------------------------------------------------------
0 | netw_t       | Unet                   | 744 K 
1 | netw_f       | Unet                   | 744 K 
2 | flatten      | Flatten                | 0     
3 | linear_t     | LinearCompression      | 0     
4 | linear_f     | LinearCompression      | 0     
5 | logratios_1d | LogRatioEstimator_1dim | 290 K 
--------------------------------------------------------
1.8 M     Trainable params
0         Non-trainable params
1.8 M     Total params
7.121     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]