In [9]:
import torch
import torch.nn as nn
from torch import Tensor
from lampe.nn import ResMLP
from zuko.flows import MAF
from pathlib import Path
import os
from lampe.inference import NPE

In [10]:
class SoftClip(nn.Module):
    r"""Element-wise soft clipping: ``x / (1 + |x / bound|)``.

    The output magnitude is always strictly smaller than ``|bound|``, while
    inputs with ``|x|`` much smaller than ``|bound|`` pass through almost
    unchanged.

    Arguments:
        bound: Saturation level of the clipping. Must be non-zero, otherwise
            the division ``x / bound`` produces infinities/NaNs.
    """

    def __init__(self, bound: float = 1.0):
        super().__init__()

        self.bound = bound

    def forward(self, x: Tensor) -> Tensor:
        """Return the soft-clipped version of ``x`` (same shape as ``x``)."""
        return x / (1 + abs(x / self.bound))

    def extra_repr(self) -> str:
        # Expose the hyperparameter in `print(model)`; without this the module
        # prints as a bare `SoftClip()` and the bound in use is invisible.
        return f"bound={self.bound}"

In [31]:
class NPEWithEmbedding(nn.Module):
    """Neural posterior estimator (NPE) preceded by an embedding network.

    The observation ``x`` is first soft-clipped (``SoftClip(100.0)``) and
    compressed by a ``ResMLP``; the resulting embedding is then passed to a
    MAF-based ``NPE`` as its second (context) argument.

    The default sizes reproduce the original hard-coded architecture, so
    ``NPEWithEmbedding()`` builds the same modules (with the same submodule
    names ``embedding`` and ``npe``) as before and existing checkpoints keep
    loading.

    Arguments:
        theta_dim: Number of features of the parameters ``theta``.
        x_dim: Number of features of the raw observation ``x``.
        embedding_dim: Output size of the embedding network; also used as the
            context size of the NPE so that the two always agree.
    """

    def __init__(
        self,
        theta_dim: int = 19,
        x_dim: int = 6144 * 2,
        embedding_dim: int = 128,
    ):
        super().__init__()

        self.embedding = nn.Sequential(
            SoftClip(100.0),
            ResMLP(
                x_dim, embedding_dim,
                hidden_features=[512] * 2 + [256] * 3 + [128] * 5,
                activation=nn.ELU,
            ),
        )

        # The context size must match the embedding output size, hence the
        # shared `embedding_dim` instead of two independent literals.
        self.npe = NPE(
            theta_dim, embedding_dim,
            transforms=3,
            build=MAF,
            hidden_features=[512] * 5,
            activation=nn.ELU,
        )

    def forward(self, theta: Tensor, x: Tensor) -> Tensor:
        """Embed ``x`` and evaluate the NPE on ``(theta, embedding)``.

        Prints a diagnostic message when the embedding contains NaNs.
        The return value is whatever ``NPE.__call__`` returns — presumably the
        log-density of ``theta`` given ``x``; confirm against the lampe docs.
        """
        y = self.embedding(x)
        if torch.isnan(y).any():
            print('NaNs in embedding')
        return self.npe(theta, y)

    def flow(self, x: Tensor):  # -> Distribution
        """Return ``self.npe.flow`` evaluated on the embedding of ``x``."""
        return self.npe.flow(self.embedding(x))

In [32]:
# Locations of the data set and of the training runs, rooted at $SCRATCH.
# When SCRATCH is not set the root is the empty string, so every path below
# becomes relative to the current working directory.
scratch = os.getenv('SCRATCH', '')
datapath = Path(scratch).joinpath('highres-sbi/data_fulltheta')
savepath = Path(scratch).joinpath('highres-sbi/runs/imp/MAF/')
runpath = savepath.joinpath('ancient-destroyer-24')  # directory of the selected run

# Epoch number used to select the checkpoint file inside `runpath`.
epoch = 1350

In [33]:
# Rebuild the architecture in float64 before loading the weights
# (presumably the checkpoint was trained/saved in double precision — confirm).
estimator = NPEWithEmbedding().double()

# NOTE(security): torch.load unpickles the file, which can execute arbitrary
# code. Only load checkpoints that come from a trusted source.
states = torch.load(runpath / f'states_{epoch}.pth', map_location='cpu')
estimator.load_state_dict(states['estimator'])

# Use the GPU when one is present, otherwise stay on the CPU instead of
# crashing; the weights were already mapped to CPU above. The trailing
# expression displays the module in evaluation mode.
estimator.to('cuda' if torch.cuda.is_available() else 'cpu').eval()

NPEWithEmbedding(
  (embedding): Sequential(
    (0): SoftClip()
    (1): ResMLP(
      (0): Linear(in_features=12288, out_features=512, bias=True)
      (1): Residual(MLP(
        (0): Linear(in_features=512, out_features=512, bias=True)
        (1): ELU(alpha=1.0)
        (2): Linear(in_features=512, out_features=512, bias=True)
      ))
      (2): Residual(MLP(
        (0): Linear(in_features=512, out_features=512, bias=True)
        (1): ELU(alpha=1.0)
        (2): Linear(in_features=512, out_features=512, bias=True)
      ))
      (3): Linear(in_features=512, out_features=256, bias=True)
      (4): Residual(MLP(
        (0): Linear(in_features=256, out_features=256, bias=True)
        (1): ELU(alpha=1.0)
        (2): Linear(in_features=256, out_features=256, bias=True)
      ))
      (5): Residual(MLP(
        (0): Linear(in_features=256, out_features=256, bias=True)
        (1): ELU(alpha=1.0)
        (2): Linear(in_features=256, out_features=256, bias=True)
      ))
      (6): R

In [None]:
# Training hyperparameters (reference notes; presumably for the run loaded above — confirm):
#factor = 0.5
#lr = 1e-4
#min_lr = 1e-7
#opt = Adam
#sch = ReduceLROnPlateau
#patience = 32
#noise = /250
#batch_size = 2048
#weight_decay = 1e-2