In [1]:
import numpy as np
import math
import torch
import einops
import torch.nn as nn
import torch.nn.functional as F
from torch import nn, optim
from torch.utils.data import DataLoader, Dataset
import matplotlib.pyplot as plt
import numpy as np
import numba as nb
import pandas as pd
import matplotlib.pyplot as plt
import bottleneck as bn
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F

from src.mamba import Mamba, MambaConfig
from sbi import utils as utils
from sbi import analysis as analysis
from sbi.inference.base import infer
from sbi.inference import SNPE, prepare_for_sbi, simulate_for_sbi
from sbi import analysis, utils
from sbi.inference import SNPE, simulate_for_sbi
from sbi.utils.user_input_checks import (
    check_sbi_inputs,
    process_prior,
    process_simulator,
)
# import required modules
from sbi.utils.get_nn_models import posterior_nn
from sbi.neural_nets.embedding_nets import (
    FCEmbedding,
    CNNEmbedding,
    PermutationInvariantEmbedding
)
# Seed PyTorch's global RNG so network initialisation and sampling are repeatable.
# NOTE(review): the simulator draws its noise with np.random inside a numba-jitted
# function (langevin_integrator); this call presumably does not seed that stream,
# so simulated trajectories may still differ between runs — confirm.
seed = 0 
torch.manual_seed(seed) 
from src.temporal_encoders import ResidualTemporalBlock, Residual, PreNorm, LinearAttention, Downsample1d, Conv1dBlock

In [2]:
@nb.jit(nopython=True, fastmath=True)
def force(x, k=3):
    """Return the linear restoring force ``-k * x`` at position ``x``.

    ``k`` is the spring constant (defaults to 3).
    """
    return -k * x

@nb.jit(nopython=True, fastmath=True)
def langevin_integrator(
    x0=-1.5,
    vx0=0.0,
    mx=1,
    nux=1,
    k=3,
    N=1000,
    dt=5e-5,
    fs=1
):
    """Simulate one trajectory of a 1-D Langevin particle in a harmonic well.

    The deterministic force is ``force(x, k) = -k * x`` and the noise amplitude
    is ``sqrt(2 * nux / (beta * mx))`` with ``beta`` fixed to 1.

    Parameters
    ----------
    x0, vx0 : initial position and velocity.
    mx : mass (the force is multiplied by ``1 / mx``).
    nux : damping rate; enters the velocity damping factor ``b1``.
    k : spring constant forwarded to ``force``.
    N : number of integration steps (the loop runs for ``i = 1 .. N - 1``).
    dt : integration time step.
    fs : sampling stride; the position is stored whenever ``i % fs == 0``.

    Returns
    -------
    x : 1-D float array of length ``int(N / fs)`` with ``x[0] = x0`` and
        ``x[j]`` the position after step ``j * fs``.

    Raises
    ------
    ValueError
        If ``int(N / fs) < 1`` (no sample could be stored).

    Notes
    -----
    Noise is drawn with ``np.random.standard_normal`` inside jitted code.
    NOTE(review): Numba presumably keeps its own RNG state, so seeding NumPy
    from regular Python code may not make this function reproducible — verify.
    """
    beta = 1  # inverse temperature, fixed to 1 to simplify the calculation
    invmass = 1.0 / mx
    sigma = np.sqrt(2.0 * nux / (beta * mx))
    # b1 is the second-order Taylor expansion of exp(-nux * dt / 2).
    b1 = 1.0 - 0.5 * dt * nux + 0.125 * (dt ** 2) * (nux ** 2)
    b2 = 0.5 * dt - 0.125 * nux * dt ** 2
    s3 = np.sqrt(3.0)
    sdt3 = sigma * np.sqrt(dt ** 3)
    sdt = sigma * np.sqrt(dt)

    n = int(N / fs)
    if n < 1:
        # Without this guard ``x[0] = x0`` below would write into an empty
        # array (jitted code does not bounds-check by default).
        raise ValueError("N must be at least fs so that one sample is stored")

    x = np.zeros(n)
    xi = np.random.standard_normal(size=n)
    eta = np.random.standard_normal(size=n)
    x[0] = x0
    vx = vx0
    xold = x0
    Fx = force(x0, k=k) * invmass

    for i in range(1, N):

        # The noise buffers hold n draws each; refill them every n steps.
        if i % n == 0:
            xi = np.random.standard_normal(size=n)
            eta = np.random.standard_normal(size=n)

        _xi = xi[i % n]
        _eta = eta[i % n]

        n1 = 0.5 * sdt * _xi
        n3 = 0.5 * sdt3 * _eta / s3
        n4 = sdt3 * (0.125 * _xi + 0.25 * _eta / s3)
        n5 = n1 - nux * n4

        # First velocity update (old force), then the position update.
        vx = vx * b1 + Fx * b2 + n5
        xnew = xold + dt * vx + n3

        # Second velocity update with the force at the new position.
        Fx = force(xnew, k=k) * invmass
        vx = vx * b1 + Fx * b2 + n5

        if (i % fs) == 0:
            idx = int(i / fs)
            # When fs does not divide N the last sampled step maps to index n,
            # one past the end of x; skip it instead of writing out of bounds.
            if idx < n:
                x[idx] = xnew

        xold = xnew

    return x

In [4]:
def _simulate_langevin_trajectory(params):
    """Run one Langevin simulation and return the positions as a 1-D tensor.

    ``params`` is a tensor whose first two entries are the log10 of the mass
    and the log10 of the damping rate (they are exponentiated with ``10**``
    before being passed to ``langevin_integrator``). The remaining integrator
    settings are fixed: start at rest at the origin, ``k=3``, 1000 steps of
    ``dt=5e-3``, every step stored.

    Returns a float32 tensor of shape ``(1000,)``.
    """
    # Move to host memory and use float64 for the numba-compiled integrator.
    params = np.array(params.cpu(), dtype=np.float64)
    x = langevin_integrator(
        x0=0,
        vx0=0.0,
        mx=10**params[0],
        nux=10**params[1],
        k=3,
        N=1000,
        dt=5e-3,
        fs=1
    )
    return torch.tensor(x, dtype=torch.float32)


def Langevin_simulator_SSM(params):
    """Simulator for the Mamba encoder: trajectory of shape ``(1000, 1)``."""
    return _simulate_langevin_trajectory(params).unsqueeze(-1)


def Langevin_simulator_Transformer(params):
    """Simulator for the Transformer encoder: trajectory of shape ``(1000,)``."""
    return _simulate_langevin_trajectory(params)


def Langevin_simulator_Opt_CNN(params):
    """Simulator for the CNN encoder: trajectory of shape ``(1000, 1)``."""
    return _simulate_langevin_trajectory(params).unsqueeze(1)

In [5]:
#Transformer
class PositionalEncoding(nn.Module):
    """Add a fixed sinusoidal positional encoding to a sequence, then dropout.

    The encoding table ``pe`` has shape ``(max_len, 1, d_model)`` and is
    indexed with ``x.size(0)``, so the input must be sequence-first:
    ``x`` is ``[seq_len, batch, d_model]``. Even feature indices hold sines,
    odd feature indices hold cosines.

    Args:
        d_model: feature dimension of the input (even or odd).
        dropout: dropout probability applied after adding the encoding.
        max_len: number of positions precomputed in the table.
    """

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        angles = position * div_term  # [max_len, ceil(d_model / 2)]
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer cosine slot than sine slots;
        # trimming the angles keeps the assignment shape-consistent. For even
        # d_model the slice is a no-op.
        pe[:, 0, 1::2] = torch.cos(angles[:, : d_model // 2])
        # Buffer: moves with .to(device) and is saved in state_dict, not trained.
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return ``dropout(x + pe[:seq_len])`` for sequence-first ``x``."""
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)

class TemporalTransformer(nn.Module):
    """Conv + Transformer encoder that maps a trajectory to one scalar.

    Pipeline: strided 1-D convolutions shorten the sequence and lift it to
    ``dim`` channels, a sinusoidal positional encoding is added, a learnable
    CLS token is prepended, a batch-first Transformer encoder is applied and
    the CLS output is projected to a single value.

    Args:
        transition_dim: number of input channels per time step.
        dim: channel width of the convolutions and the Transformer.
        kernel_sizes: one kernel size per convolution layer.
        stride: stride shared by all convolution layers.
        num_heads: attention heads per Transformer layer.
        depth: number of Transformer encoder layers.
        mlp_dim: feed-forward width inside each Transformer layer.
    """

    def __init__(self, transition_dim, dim=32, kernel_sizes=(4, 4), stride=2, num_heads=4, depth=4, mlp_dim=32):
        super().__init__()

        self.conv_layers = nn.ModuleList([
            nn.Conv1d(in_channels=transition_dim if i == 0 else dim, out_channels=dim, 
                      kernel_size=ks, stride=stride, padding=ks//2)
            for i, ks in enumerate(kernel_sizes)
        ])

        self.pos_embedding = PositionalEncoding(dim, dropout=0.1, max_len=5000)
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        
        transformer_layer = nn.TransformerEncoderLayer(
            d_model=dim, 
            nhead=num_heads, 
            dim_feedforward=mlp_dim,
            dropout=0.1,
            batch_first=True
        )
        self.transformer = nn.TransformerEncoder(transformer_layer, num_layers=depth)

        self.to_out = nn.Linear(dim, 1)

    def forward(self, x):
        '''
            x : [ batch x horizon ] (a channel axis is appended) or
                [ batch x horizon x transition ]

            returns : [ batch x 1 ]
        '''

        # The original code always appended a channel axis; only do so when it
        # is missing so that 3-D input is accepted as well.
        if x.dim() == 2:
            x = x.unsqueeze(2)
        # Conv1d expects [batch, channels, length].
        x = einops.rearrange(x, 'b h t -> b t h')

        for conv in self.conv_layers:
            x = F.relu(conv(x))

        # Back to [batch, length, dim] for the batch-first Transformer.
        x = einops.rearrange(x, 'b t h -> b h t')
        # PositionalEncoding indexes its table with x.size(0), i.e. it is
        # sequence-first. Feeding it the batch-first tensor directly would
        # index positions by batch element and broadcast over time, so
        # transpose to [length, batch, dim] around the call.
        x = self.pos_embedding(x.transpose(0, 1)).transpose(0, 1)

        cls_tokens = self.cls_token.expand(x.shape[0], -1, -1) 
        x = torch.cat((cls_tokens, x), dim=1)  # Shape: [batch, length + 1, dim]
        
        x = self.transformer(x)  

        # The CLS token (position 0) summarises the whole sequence.
        x = x[:, 0]
        x = self.to_out(x)
        return x
    
#Optional Attention
class TemporalCNN(nn.Module):
    """Residual 1-D CNN encoder with optional attention; outputs one scalar.

    The channel widths are ``[transition_dim, dim * m for m in dim_mults]``.
    Each stage applies two ``ResidualTemporalBlock`` s, an optional
    ``LinearAttention`` wrapped in ``Residual(PreNorm(...))`` and, except on
    the last stage, a ``Downsample1d``. A middle section (block, optional
    attention, block) follows, then the last axis is averaged and projected to
    a single value.
    """

    def __init__(
        self,
        transition_dim,
        dim=32,
        dim_mults=(1, 2, 4),
        attention=True,
        padding_mode='reflect',
        kernel_size=3,
    ):
        super().__init__()

        dims = [transition_dim] + [dim * mult for mult in dim_mults]
        in_out = list(zip(dims[:-1], dims[1:]))
        print(f'[ models/temporal ] Channel dimensions: {dims}')

        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])
        last_stage = len(in_out) - 1

        for stage, (ch_in, ch_out) in enumerate(in_out):
            # Modules are created in the same order as they are applied.
            block_a = ResidualTemporalBlock(ch_in, ch_out, kernel_size=kernel_size, padding_mode=padding_mode)
            block_b = ResidualTemporalBlock(ch_out, ch_out, kernel_size=kernel_size, padding_mode=padding_mode)
            if attention:
                attn = Residual(PreNorm(ch_out, LinearAttention(ch_out)))
            else:
                attn = nn.Identity()
            # The final stage keeps its temporal resolution.
            if stage >= last_stage:
                resample = nn.Identity()
            else:
                resample = Downsample1d(ch_out)
            self.downs.append(nn.ModuleList([block_a, block_b, attn, resample]))

        mid_dim = dims[-1]
        self.mid_block1 = ResidualTemporalBlock(mid_dim, mid_dim, kernel_size=kernel_size, padding_mode=padding_mode)
        if attention:
            self.mid_attn = Residual(PreNorm(mid_dim, LinearAttention(mid_dim)))
        else:
            self.mid_attn = nn.Identity()
        self.mid_block2 = ResidualTemporalBlock(mid_dim, mid_dim, kernel_size=kernel_size, padding_mode=padding_mode)

        self.proj_out = nn.Linear(mid_dim, 1)

    def forward(self, x):
        '''
            x : [ batch x horizon x transition ]
        '''
        # Channels first: [ batch x transition x horizon ].
        x = einops.rearrange(x, 'b h t -> b t h')

        # Each stage holds (block, block, attention, downsample) in call order.
        for stage_layers in self.downs:
            for layer in stage_layers:
                x = layer(x)

        for layer in (self.mid_block1, self.mid_attn, self.mid_block2):
            x = layer(x)

        # Average over the last axis, then project the channels to one value.
        pooled = x.mean(dim=-1)
        return self.proj_out(pooled)

#SSM
class TemporalMamba(nn.Module):
    """Mamba (state-space) encoder that maps a trajectory to one scalar.

    The input is embedded to ``dim`` features, passed through ``num_layers``
    residual ``Mamba`` + ``LayerNorm`` stages, and the features of the last
    time step are projected to a single value.
    """

    def __init__(
        self,
        transition_dim,
        dim=128,
        kernel_size=3,
        expand=2,
        num_layers=1, 
    ):
        super().__init__()

        self.mamba_layers = nn.ModuleList()
        self.l_norm_layers = nn.ModuleList()

        for _ in range(num_layers):
            # NOTE(review): each Mamba is itself configured with
            # n_layers=num_layers inside a loop that already runs num_layers
            # times; if Mamba stacks n_layers blocks internally the effective
            # depth is num_layers ** 2 — confirm this is intended.
            stage_config = MambaConfig(
                n_layers=num_layers,
                d_model=dim,
                d_state=dim,
                d_conv=kernel_size,
                expand_factor=expand,
            )
            self.mamba_layers.append(Mamba(stage_config))
            self.l_norm_layers.append(nn.LayerNorm(dim))

        self.x_emb = nn.Linear(transition_dim, dim)
        self.proj_out = nn.Linear(dim, 1)

    def forward(self, x):
        '''
            x : [ batch x horizon x transition ]
        '''
        hidden = self.x_emb(x)
        for mamba, l_norm in zip(self.mamba_layers, self.l_norm_layers):
            # Residual connection around the Mamba stage, then LayerNorm.
            hidden = l_norm(mamba(hidden) + hidden)

        # Summarise the sequence by its final time step.
        final_step = hidden[:, -1]
        return self.proj_out(final_step)

In [None]:
# Embedding network that summarises each simulated trajectory for the density
# estimator. Exactly one of the three encoders is active; the other two are
# kept as commented-out alternatives.
# NOTE(review): the cells below wire in Langevin_simulator_SSM; switching the
# encoder presumably requires switching to the matching Langevin_simulator_* — verify.
model = TemporalMamba(transition_dim=1, dim=32, kernel_size=3, expand=2, num_layers=2) #SSM
#model = TemporalCNN(1, dim=32, dim_mults=(1, 2, 4), attention=True, padding_mode='reflect', kernel_size=1) #Opt CNN
#model = TemporalTransformer(transition_dim=1, dim=32, kernel_sizes=(4, 4), stride=2, num_heads=4, depth=4, mlp_dim=32) # Transformer 

In [6]:
# Prior bounds [low, high] expressed as log10 exponents: the simulators use
# 10**params[0] as the mass and 10**params[1] as the damping rate nux.
nux_limits = [-1, 2]
mass_limits = [-1, 2]

In [None]:
# Box-uniform prior over (log10 mass, log10 nux); the order matches
# params[0] / params[1] in the simulators. The bounds are created on the GPU,
# so this cell requires a CUDA device.
prior = utils.BoxUniform(
    low = torch.tensor([mass_limits[0],nux_limits[0]], device='cuda'),
    high = torch.tensor([mass_limits[1],nux_limits[1]], device='cuda')
)

# Let sbi validate/wrap the prior; `prior` is rebound to the processed object.
prior, num_parameters, prior_returns_numpy= process_prior(prior)

# Wrap the raw simulator so sbi can call it with parameters drawn from `prior`.
simulator_wrapper = process_simulator(Langevin_simulator_SSM, prior, prior_returns_numpy)

# Consistency check of the wrapped simulator against the prior.
check_sbi_inputs(simulator_wrapper, prior)

In [8]:
# Bind the prepared simulator to `simulator_wrapper` (the name used by
# simulate_for_sbi below) instead of overwriting the raw
# `Langevin_simulator_SSM` function: rebinding the function to its own wrapped
# version made this cell non-idempotent, since a re-run wrapped the wrapper.
# NOTE(review): this step looks redundant with the explicit process_prior /
# process_simulator / check_sbi_inputs calls in the previous cell — confirm
# against the installed sbi version before removing it.
simulator_wrapper, prior = prepare_for_sbi(Langevin_simulator_SSM, prior)

In [None]:
# Density-estimator builder: an 'nsf' flow conditioned on the output of the
# embedding network `model` rather than on the raw trajectory.
# NOTE(review): every encoder defined above ends in a Linear(..., 1)
# projection, so each trajectory is summarised by a single scalar while two
# parameters are inferred — confirm this bottleneck is intended.
neural_posterior = posterior_nn(model='nsf', embedding_net = model)

# Posterior-estimation object; training is placed on the GPU (requires CUDA).
inference = SNPE(prior, device='cuda', density_estimator=neural_posterior)

In [None]:
# Build the training set for a single round: draw 50000 parameter sets from
# the prior and simulate one trajectory for each, using a single worker.
theta, x = simulate_for_sbi(simulator_wrapper, prior, num_simulations=50000, num_workers=1)

In [None]:
# Register the simulated (theta, x) pairs (data_device='cuda' presumably keeps
# them on the GPU — verify) and train the density estimator with mini-batches
# of 128, printing a training summary at the end.
density_estimator = inference.append_simulations(theta, x, data_device='cuda').train(training_batch_size=128, show_train_summary=True)

# Wrap the trained density estimator in a posterior object.
posterior = inference.build_posterior(density_estimator)

In [None]:
# Save the trained posterior for later use. The file must be opened for binary
# *writing* and torch.save needs the object to serialise as its first argument;
# the previous form opened the file read-only and called torch.save(f), which
# raises and would otherwise have rebound `posterior`.
# NOTE(review): reload with torch.load on the same path; recent PyTorch
# versions may require weights_only=False for arbitrary pickled objects — verify.
with open('your_path.pkl', 'wb') as f:
    torch.save(posterior, f)