In [2]:
!pip3 install pyro-ppl 

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting pyro-ppl
  Downloading pyro_ppl-1.8.1-py3-none-any.whl (718 kB)
[K     |████████████████████████████████| 718 kB 18.2 MB/s 
Collecting pyro-api>=0.1.1
  Downloading pyro_api-0.1.2-py3-none-any.whl (11 kB)
Installing collected packages: pyro-api, pyro-ppl
Successfully installed pyro-api-0.1.2 pyro-ppl-1.8.1


In [3]:
!pip install extinction

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting extinction
  Downloading extinction-0.4.6-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (418 kB)
[K     |████████████████████████████████| 418 kB 27.5 MB/s 
Installing collected packages: extinction
Successfully installed extinction-0.4.6


In [4]:
!pip install corner

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting corner
  Downloading corner-2.2.1-py3-none-any.whl (15 kB)
Installing collected packages: corner
Successfully installed corner-2.2.1


In [5]:
import logging
import os

import torch
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.lines as mlines
import matplotlib.colors as colors

import corner
import time as measure_time
import os

import pyro

import pyro.distributions as dist
import pyro.distributions.constraints as constraints
from pyro.infer import SVI, Trace_ELBO,Predictive
from pyro.optim import Adam

from scipy.interpolate import interp1d
import h5py
import extinction

from astropy.cosmology import FlatLambdaCDM

from spline_hsiao_fns import *

##Get constants from files

In [6]:
W0 = np.loadtxt("W0.txt")
W1 = np.loadtxt("W1.txt")
time_knots = np.loadtxt("tau_knots.txt")
wavelength_knots = np.loadtxt("l_knots.txt")
L_sigma_epsilon = np.loadtxt("L_Sigma_epsilon.txt")
M0, sigma0, rv, tauA = np.loadtxt("M0_sigma0_RV_tauA.txt")

In [7]:
g_wavelengths, _, g_norm_throughput = np.loadtxt("g_PS1.txt", skiprows = 2, unpack = True)
r_wavelengths, _, r_norm_throughput = np.loadtxt("r_PS1.txt", skiprows = 2, unpack = True)
i_wavelengths, _, i_norm_throughput = np.loadtxt("i_PS1.txt", skiprows = 2, unpack = True)
z_wavelengths, _, z_norm_throughput = np.loadtxt("z_PS1.txt", skiprows = 2, unpack = True)

Read in metadata for Foundation dataset

In [8]:
meta = pd.read_csv("meta.txt", sep='\s+', index_col = False, names = ['SNID','PEAKMJD','MWEBV','REDSHIFT_HELIO','REDSHIFT_CMB','REDSHIFT_CMB_ERR'], skiprows = 2)

In [9]:
tmax_dict = {}
mwebv_dict = {}
z_helio_dict = {}
z_cmb_dict = {}
for index, row in meta.iterrows():
  name = row['SNID']
  tmax = row['PEAKMJD']
  mwebv = row['MWEBV']
  z_helio = row['REDSHIFT_HELIO']
  z_cmb = row['REDSHIFT_CMB']
  tmax_dict[name] = tmax
  mwebv_dict[name] = mwebv
  z_helio_dict[name] = z_helio
  z_cmb_dict[name] = z_cmb

In [10]:
def get_fluxes_from_file(filename, tmax, z_helio):
  df = pd.read_csv(filename, sep=" ", header = 0, index_col = False, names = ['MJD', "FLT", 'FLUXCAL', 'FLUXCALERR', 'MAG', 'MAGERR'])
  df = df.dropna()

  # Remove time points outside of [-10., 40.]
  df['adj_time'] = (df.MJD - tmax) / (1 + z_helio)
  mask = (df['adj_time'] < 40.) & (df['adj_time'] > -10.0)
  df = df[mask]

  # Separate based on bands
  g_time = df.adj_time[df.FLT=='g'].values
  r_time = df.adj_time[df.FLT=='r'].values
  i_time = df.adj_time[df.FLT=='i'].values
  z_time = df.adj_time[df.FLT=='z'].values

  g_flux = df.FLUXCAL[df.FLT=='g'].values
  r_flux = df.FLUXCAL[df.FLT=='r'].values
  i_flux = df.FLUXCAL[df.FLT=='i'].values
  z_flux = df.FLUXCAL[df.FLT=='z'].values

  g_fluxerr = df.FLUXCALERR[df.FLT=='g'].values
  r_fluxerr = df.FLUXCALERR[df.FLT=='r'].values
  i_fluxerr = df.FLUXCALERR[df.FLT=='i'].values
  z_fluxerr = df.FLUXCALERR[df.FLT=='z'].values

  times_dict = {'g':g_time, 'r':r_time, 'i':i_time, 'z':z_time}

  observed_fluxes = [(torch.as_tensor(g_flux), torch.as_tensor(r_flux), torch.as_tensor(i_flux), torch.as_tensor(z_flux))]

  flux_errors = [torch.as_tensor(g_fluxerr), torch.as_tensor(r_fluxerr), torch.as_tensor(i_fluxerr), torch.as_tensor(z_fluxerr)]

  return observed_fluxes, flux_errors, times_dict

In [11]:
ZPT = 27.5
M0 = -19.5
gamma = np.log(10) / 2.5

hsiao_phase, hsiao_wave, hsiao_flux = read_model_grid()

bands = ['g', 'r', 'i', 'z']
wavelengths_dict = {'g':g_wavelengths, 'r':r_wavelengths, 'i':i_wavelengths, 'z':z_wavelengths}
norm_throughput_dict = {'g':g_norm_throughput, 'r':r_norm_throughput, 'i':i_norm_throughput, 'z':z_norm_throughput}
eps_cov = torch.as_tensor(np.matmul(L_sigma_epsilon, L_sigma_epsilon.T), dtype = torch.float)

Things we can pre-calculate / helper functions
---



In [12]:
def get_lambda_int_for_band(band_wavelengths, z):
  source_wavelengths = band_wavelengths / (1 + z)
  return np.linspace(np.min(source_wavelengths), np.max(source_wavelengths), 150)

In [13]:
def calculate_band_dependent_stuff(times_dict, z_helio):
  band_Jl = {}
  band_h = {}
  band_Jt = {}
  band_xis_matrix = {}
  band_S0 = {}  

  for band in bands:
    # Calculate Jt
    times_to_interpolate = times_dict[band]
    Jt = spline_coeffs_irr(times_to_interpolate, time_knots, invKD_irr(time_knots))
    band_Jt[band] = Jt

    # Calculate wavelengths to interpolate
    band_wavelengths = wavelengths_dict[band]
    wavelengths_to_interpolate = get_lambda_int_for_band(band_wavelengths, z = z_helio)

    # Calculate Jl (J matrix for wavelengths)
    Jl = spline_coeffs_irr(wavelengths_to_interpolate, wavelength_knots, invKD_irr(wavelength_knots))
    band_Jl[band] = Jl

    # Calculate xis matrix
    xis = extinction.fitzpatrick99(wavelengths_to_interpolate, 1, rv)
    xis_matrix = np.tile(xis, (len(times_to_interpolate),1)).T
    band_xis_matrix[band] = xis_matrix

    # Interpolate throughput wavelengths
    band_norm_throughput = norm_throughput_dict[band]
    throughput_interpolator = interp1d(band_wavelengths, band_norm_throughput)
    b = throughput_interpolator([band_wavelengths[0]] + list(wavelengths_to_interpolate[1:-1] * (1 + z_helio)) + [band_wavelengths[-1]])

    # Calculate S0 matrix
    S0 = np.zeros((len(wavelengths_to_interpolate), len(times_to_interpolate)))
    for i, wavelength in enumerate(wavelengths_to_interpolate):
      for j, time_point in enumerate(times_to_interpolate):
        S0[i][j] = interpolate_hsiao(time_point, wavelength, hsiao_phase, hsiao_wave, hsiao_flux)
    band_S0[band] = S0

    # Calculate h
    xis_obs = extinction.fitzpatrick99(wavelengths_to_interpolate*(1 + z_helio), Av_obs, 3.1)
    dLambda = wavelengths_to_interpolate[1] - wavelengths_to_interpolate[0]
    h = (1 + z_helio) * dLambda * b * wavelengths_to_interpolate * np.exp(-gamma * xis_obs)
    band_h[band] = h

  return band_Jt, band_Jl, band_xis_matrix, band_S0, band_h

## Define model

In [14]:
def model_vi_with_params(obs, z_cmb, band_Jl, band_Jt, band_xis_matrix, band_S0, band_h, flux_errors):
  # epsilon_interior = pyro.sample("eps_int", dist.MultivariateNormal(torch.zeros(len(eps_cov)), covariance_matrix = eps_cov))

  nu = pyro.sample("nu", dist.MultivariateNormal(torch.zeros(len(eps_cov)), covariance_matrix = torch.eye(len(eps_cov))))
  # nu = torch.zeros(54)
  epsilon_interior = torch.matmul(torch.as_tensor(L_sigma_epsilon, dtype = torch.float), nu)
  # print(epsilon_interior)

  # theta = pyro.sample("theta", dist.Uniform(low=-1.49, high=2.77))
  theta = pyro.sample("theta", dist.Normal(0., 1.0))
  # theta = torch.as_tensor(0.03430605)

  # theta = torch.as_tensor(-0.9835187)
  cosmo = FlatLambdaCDM(H0 = 73.24, Om0 = 0.28)
  # mu_s = pyro.sample("mu_s", dist.Normal(cosmo.distmod(z_helio).value, 10.)) ## fix this later
  mu_s = pyro.sample("mu_s", dist.Normal(cosmo.distmod(z_cmb).value, 10.)) ## fix this later
  # mu_s = torch.as_tensor(37.248927127193724)

  # dMs = pyro.sample("Ms", dist.Normal(torch.tensor(0.), torch.tensor(sigma0)))
  dMs = torch.tensor(0.)
  Av = pyro.sample("Av", dist.Exponential(1 / 0.252))
  # Av = torch.tensor(7.2387767)
  # Av = pyro.sample("Av", dist.Normal(0.27, 1.))

  epsilon = torch.zeros(W0.shape)   ## populate epsilon matrix
  epsilon[1:-1] = torch.transpose(torch.reshape(epsilon_interior, (6,9)), 0, 1)

  W = torch.as_tensor(W0) + theta*torch.as_tensor(W1) + epsilon

  generated_fluxes = []

  for band in bands:
    Jl = band_Jl[band]
    S0 = band_S0[band]
    xis_matrix = band_xis_matrix[band]
    h = band_h[band]
    Jt = band_Jt[band]

    JlWJt = torch.matmul(torch.as_tensor(Jl), torch.matmul(W, torch.as_tensor(Jt.T)))

    Stilde = torch.as_tensor(S0) * torch.exp(-gamma * (JlWJt + Av * torch.as_tensor(xis_matrix)))

    f = 10**(0.4 * (torch.as_tensor(ZPT) - mu_s - torch.as_tensor(M0) - dMs)) * torch.matmul(torch.as_tensor(h), torch.as_tensor(Stilde))
    generated_fluxes.append(f)

  # generated_fluxes = torch.stack(generated_fluxes)
  # print(generated_fluxes)
  
  if obs is not None:
    # for i in range(len(bands)):
    for i in pyro.plate("bands", len(bands)):
      # print(i)
      with pyro.plate("observations" + str(bands[i])):
        # print(len(generated_fluxes[i]))
        # for j in range(len(generated_fluxes[i])):
          pyro.sample("flux" + str(bands[i]), dist.Normal(generated_fluxes[i], flux_errors[0][i]), obs = obs[0][i])
          # print(obs[0][i][j])
        # print(generated_fluxes[i])

In [15]:
def train_with_params(model, guide, obs, z_cmb, band_Jl, band_Jt, band_xis_matrix, band_S0, band_h, flux_errors, lr=0.01, n_steps=5000, verbose = True):
  pyro.get_param_store().clear()
  adam_params = {"lr": lr, "betas": (0.95, 0.999)}
  # adam = pyro.optim.Adam({"lr": 0.05}, {"clip_norm": 10.0})
  adam = pyro.optim.ClippedAdam(adam_params)
  svi = SVI(model, guide, adam, loss=Trace_ELBO())
  losses = []

  for step in range(n_steps):
      loss = svi.step(obs, z_cmb, band_Jl, band_Jt, band_xis_matrix, band_S0, band_h, flux_errors)
      losses.append(loss)
      # if verbose and step % 500 == 0:
          # print('[iter {}]  loss: {:.4f}'.format(step, loss))
          ## if it hasn't decreased in 3 steps, end fitting
          # if step > 2000:
          #  if losses_every_500[-1] >= losses_every_500[-2] and losses_every_500[-2] >= losses_every_500[-3]:
          #   break
  return losses

## Iterate over all supernovae

In [None]:
vi_loc = {}
vi_scale_tril = {}
laplace_vi_loc = {}
laplace_vi_scale_tril = {}

# Iterate through all supernovas in dataset
start_time = measure_time.time()

for sn in meta.SNID.values:
  print(sn)

  tmax = tmax_dict[sn]
  z_helio = z_helio_dict[sn]
  EBV_MW = mwebv_dict[sn]
  z_cmb = z_cmb_dict[sn]

  Av_obs = EBV_MW * 3.1
  observed_fluxes, flux_errors, times_dict = get_fluxes_from_file(sn + ".dat", tmax, z_helio)

  # Calculate band-dependent quantities (Jl, Jt, xis, S0, etc.)
  band_Jt, band_Jl, band_xis_matrix, band_S0, band_h = calculate_band_dependent_stuff(times_dict, z_helio)

  try:
    autoguide_vi = pyro.infer.autoguide.AutoMultivariateNormal(model_vi_with_params, 
                                                              init_loc_fn = pyro.infer.autoguide.initialization.init_to_sample(), 
                                                              init_scale = 1.)
    
    # Fit VI Model
    losses = train_with_params(model_vi_with_params, autoguide_vi, observed_fluxes, z_cmb, 
                  band_Jl, band_Jt, band_xis_matrix, band_S0, band_h, flux_errors, lr=0.01, n_steps=5000)
    vi_loc[sn] = autoguide_vi.get_posterior().loc.detach().numpy()
    vi_scale_tril[sn] = autoguide_vi.get_posterior().scale_tril.detach().numpy()

    np.save("vi_loc.npy", vi_loc)
    np.save("vi_scale_tril.npy", vi_scale_tril)


    # Fit Laplace Approximation
    autoguide_laplace = pyro.infer.autoguide.AutoLaplaceApproximation(model_vi_with_params)
    start_time = measure_time.time()
    losses = train_with_params(model_vi_with_params, autoguide_laplace, observed_fluxes, z_cmb, 
                  band_Jl, band_Jt, band_xis_matrix, band_S0, band_h, flux_errors, n_steps=3000)

    # Fit VI based on Laplace Approximation
    new_laplace_approx_guide = autoguide_laplace.laplace_approximation(observed_fluxes, z_cmb, 
                  band_Jl, band_Jt, band_xis_matrix, band_S0, band_h, flux_errors,)
    losses = train_with_params(model_vi_with_params, new_laplace_approx_guide, observed_fluxes, z_cmb, 
                  band_Jl, band_Jt, band_xis_matrix, band_S0, band_h, flux_errors, n_steps=5000)
    laplace_vi_loc[sn] = new_laplace_approx_guide.get_posterior().loc.detach().numpy()
    laplace_vi_scale_tril[sn] = new_laplace_approx_guide.get_posterior().scale_tril.detach().numpy()


    np.save("laplace_vi_loc.npy", laplace_vi_loc)
    np.save("laplace_vi_scale_tril.npy", laplace_vi_scale_tril)

  except:
    print("Fitting did not work for", sn)

end_time = measure_time.time()
print("Time:", end_time - start_time, "seconds")



2016W
2016afk
ASASSN-15bc
ASASSN-15fa
ASASSN-15fs
ASASSN-15go
ASASSN-15hg
ASASSN-15il
ASASSN-15jl
ASASSN-15jt
ASASSN-15kx
ASASSN-15la
ASASSN-15lg
ASASSN-15lu
ASASSN-15mf
ASASSN-15mg
ASASSN-15mi
ASASSN-15np
ASASSN-15nq
ASASSN-15nr
ASASSN-15od
ASASSN-15oh
Fitting did not work for ASASSN-15oh
ASASSN-15pm
ASASSN-15pn
ASASSN-15pr
ASASSN-15sf
ASASSN-15ss
ASASSN-15tg
ASASSN-15ti
ASASSN-15tz
ASASSN-15uu
ASASSN-15uv
Fitting did not work for ASASSN-15uv
ASASSN-15uw
ASASSN-16ad
ASASSN-16aj
ASASSN-16av
ASASSN-16ay
ASASSN-16bc
ASASSN-16bq
ASASSN-16br
ASASSN-16ch
ASASSN-16cs
ASASSN-16ct
ASASSN-16db
ASASSN-16dw
ASASSN-16em
ASASSN-16fo
ASASSN-16fs
ASASSN-16hc
ASASSN-16hr
ASASSN-16hz
ASASSN-16ip
ASASSN-16lg
ASASSN-16oz
ASASSN-17aj
ASASSN-17at
ASASSN-17bs
ASASSN-17co
ASASSN-17eb
AT2016aj
AT2016bln
AT2016cor
AT2016cvv
Fitting did not work for AT2016cvv
AT2016cvw
AT2016cyt
Fitting did not work for AT2016cyt
AT2016eoa
AT2016ews
AT2016fbk
AT2016gmg
AT2016gsu
AT2016hns
AT2016htm
AT2016htn
AT2017cfb
AT2017cfc

In [21]:
np.load("laplace_vi_scale_tril(4).npy", allow_pickle = True).item()

{'2016W': array([[ 0.787692  ,  0.        ,  0.        , ...,  0.        ,
          0.        ,  0.        ],
        [-0.28118193,  0.4807721 ,  0.        , ...,  0.        ,
          0.        ,  0.        ],
        [-0.14346032, -0.24913411,  0.79556507, ...,  0.        ,
          0.        ,  0.        ],
        ...,
        [-0.0318319 , -0.04952134, -0.03050281, ...,  0.08766367,
          0.        ,  0.        ],
        [ 0.02528861,  0.04705482,  0.01349545, ..., -0.00850895,
          0.01302131,  0.        ],
        [-0.0828293 , -0.16410697, -0.04642937, ..., -0.01866046,
         -0.04784194,  0.01679309]], dtype=float32),
 '2016afk': array([[ 0.9026067 ,  0.        ,  0.        , ...,  0.        ,
          0.        ,  0.        ],
        [-0.10011043,  0.85752505,  0.        , ...,  0.        ,
          0.        ,  0.        ],
        [-0.08394552, -0.08056036,  0.9173637 , ...,  0.        ,
          0.        ,  0.        ],
        ...,
        [-0.0026837

In [None]:
list(vi_loc.values())[0]

In [None]:
np.savetxt("vi_distributions.txt", [list(vi_loc.values()), list(vi_scale_tril.values())])

In [None]:
np.savetxt("vi_distributions.txt", [list(vi_loc.keys()), list(vi_loc.values()), list(vi_scale_tril.values())])

Now let's actually make the model
---

In [None]:
def get_flux_from_params(epsilon_interior, theta, mu_s, dMs, Av, W0 = W0, W1 = W1, xis_matrix = xis_matrix, S0 = S0, h = h, Jl = Jl, Jt = Jt):

  epsilon = np.zeros_like(W0)   ## populate epsilon matrix
  epsilon[1:-1] = epsilon_interior.numpy().reshape((9,6), order = 'F')

  W = W0 + theta.numpy()*W1 + epsilon

  JlWJt = np.matmul(Jl, np.matmul(W, Jt.T))

  Stilde = S0 * np.exp(-gamma * (JlWJt + Av.numpy() * xis_matrix))

  f = 10**(0.4 * (ZPT - mu_s.numpy() - M0 - dMs.numpy())) * np.matmul(h,Stilde)

  return f

In [None]:
tauA

In [None]:
pyro.sample("Av", dist.Exponential(tauA))

In [None]:
# for sn in meta.SNID.values[:1]:
for sn in ['AT2016aj']:
  tmax = tmax_dict[sn]
  z_helio = z_helio_dict[sn]
  EBV_MW = mwebv_dict[sn]
  z_cmb = z_cmb_dict[sn]

  Av_obs = EBV_MW * 3.1
  observed_fluxes, flux_errors, times_dict = get_fluxes_from_file(sn + ".dat", tmax, z_helio)

  # Calculate band-dependent quantities (Jl, Jt, xis, S0, etc.)
  # band_Jt, band_Jl, band_xis_matrix, band_S0, band_h = calculate_band_dependent_stuff(times_dict, z_helio)


In [None]:
EBV_MW

In [None]:
# def model_vi():
#   epsilon_interior = pyro.sample("eps_int", dist.MultivariateNormal(torch.zeros(len(eps_cov)), covariance_matrix = eps_cov))
#   theta = pyro.sample("theta", dist.Normal(torch.tensor(0.), torch.tensor(1.)))
#   mu_s = pyro.sample("mu_s", dist.Normal(34.5, 1.)) ## fix this later


#   # dMs = pyro.sample("Ms", dist.Normal(torch.tensor(0.), torch.tensor(sigma0)))
#   dMs = torch.tensor(0.)
#   # Av = pyro.sample("Av", dist.Exponential(tauA))
#   Av = pyro.sample("Av", dist.Normal(0.27, 1.))

#   epsilon = torch.zeros(W0.shape)   ## populate epsilon matrix
#   epsilon[1:-1] = torch.transpose(torch.reshape(epsilon_interior, (6,9)), 0, 1)

#   W = torch.as_tensor(W0) + theta*torch.as_tensor(W1) + epsilon

#   JlWJt = torch.matmul(torch.as_tensor(Jl), torch.matmul(W, torch.as_tensor(Jt.T)))

#   Stilde = torch.as_tensor(S0) * torch.exp(-gamma * (JlWJt + Av * torch.as_tensor(xis_matrix)))

#   f = 10**(0.4 * ZPT - mu_s - M0 - dMs) * torch.matmul(torch.as_tensor(h), torch.as_tensor(Stilde))

#   with pyro.plate("observations", len(i_flux.values)):
#     pyro.sample("flux", dist.Normal(f, 200. * torch.ones(6)).independent(1), obs = torch.as_tensor(i_flux.values))

In [None]:
Av_obs

In [None]:
# flux_errors = [(torch.as_tensor(g_fluxerr.values), torch.as_tensor(r_fluxerr.values), torch.as_tensor(i_fluxerr.values), torch.as_tensor(z_fluxerr.values))]

In [None]:
wavelengths_dict = {'g':g_wavelengths, 'r':r_wavelengths, 'i':i_wavelengths, 'z':z_wavelengths}
norm_throughput_dict = {'g':g_norm_throughput, 'r':r_norm_throughput, 'i':i_norm_throughput, 'z':z_norm_throughput}
# times_dict = {'g':g_time.values, 'r':r_time.values, 'i':i_time.values, 'z':z_time.values}

bands = ['g', 'r', 'i', 'z']

band_Jl = {}
band_xis_matrix = {}
band_S0 = {}
band_h = {}
band_Jt = {}

for band in bands:
  band_wavelengths = wavelengths_dict[band]
  band_norm_throughput = norm_throughput_dict[band]

  times_to_interpolate = times_dict[band]

  Jt = spline_coeffs_irr(times_to_interpolate, time_knots, invKD_irr(time_knots))
  band_Jt[band] = Jt

  wavelengths_to_interpolate = get_lambda_int_for_band(band_wavelengths, z = z_helio)
  
  Jl = spline_coeffs_irr(wavelengths_to_interpolate, wavelength_knots, invKD_irr(wavelength_knots))
  band_Jl[band] = Jl
  xis = extinction.fitzpatrick99(wavelengths_to_interpolate, 1, rv)
  xis_matrix = np.tile(xis, (len(times_to_interpolate),1)).T
  band_xis_matrix[band] = xis_matrix

  S0 = np.zeros((len(wavelengths_to_interpolate), len(times_to_interpolate)))
  for i, wavelength in enumerate(wavelengths_to_interpolate):
    for j, time_point in enumerate(times_to_interpolate):
      S0[i][j] = interpolate_hsiao(time_point, wavelength, hsiao_phase, hsiao_wave, hsiao_flux)

  band_S0[band] = S0

  throughput_interpolator = interp1d(band_wavelengths, band_norm_throughput)
  b = throughput_interpolator([band_wavelengths[0]] + list(wavelengths_to_interpolate[1:-1] * (1 + z_helio)) + [band_wavelengths[-1]])

  xis_obs = extinction.fitzpatrick99(wavelengths_to_interpolate*(1 + z_helio), Av_obs, 3.1)
  dLambda = wavelengths_to_interpolate[1] - wavelengths_to_interpolate[0]
  h = (1 + z_helio) * dLambda * b * wavelengths_to_interpolate * np.exp(-gamma * xis_obs)

  band_h[band] = h

In [None]:
def model_vi(obs):
  # epsilon_interior = pyro.sample("eps_int", dist.MultivariateNormal(torch.zeros(len(eps_cov)), covariance_matrix = eps_cov))

  nu = pyro.sample("nu", dist.MultivariateNormal(torch.zeros(len(eps_cov)), covariance_matrix = torch.eye(len(eps_cov))))
  # nu = torch.zeros(54)
  epsilon_interior = torch.matmul(torch.as_tensor(L_sigma_epsilon, dtype = torch.float), nu)
  # print(epsilon_interior)

  # theta = pyro.sample("theta", dist.Uniform(low=-1.49, high=2.77))
  theta = pyro.sample("theta", dist.Normal(0., 1.0))
  # theta = torch.as_tensor(0.03430605)

  # theta = torch.as_tensor(-0.9835187)
  cosmo = FlatLambdaCDM(H0 = 73.24, Om0 = 0.28)
  # mu_s = pyro.sample("mu_s", dist.Normal(cosmo.distmod(z_helio).value, 10.)) ## fix this later
  mu_s = pyro.sample("mu_s", dist.Normal(cosmo.distmod(z_cmb).value, 10.)) ## fix this later
  # mu_s = torch.as_tensor(37.248927127193724)

  # dMs = pyro.sample("Ms", dist.Normal(torch.tensor(0.), torch.tensor(sigma0)))
  dMs = torch.tensor(0.)
  Av = pyro.sample("Av", dist.Exponential(1 / 0.252))
  # Av = torch.tensor(7.2387767)
  # Av = pyro.sample("Av", dist.Normal(0.27, 1.))

  epsilon = torch.zeros(W0.shape)   ## populate epsilon matrix
  epsilon[1:-1] = torch.transpose(torch.reshape(epsilon_interior, (6,9)), 0, 1)

  W = torch.as_tensor(W0) + theta*torch.as_tensor(W1) + epsilon

  generated_fluxes = []

  for band in bands:
    Jl = band_Jl[band]
    S0 = band_S0[band]
    xis_matrix = band_xis_matrix[band]
    h = band_h[band]
    Jt = band_Jt[band]

    JlWJt = torch.matmul(torch.as_tensor(Jl), torch.matmul(W, torch.as_tensor(Jt.T)))

    Stilde = torch.as_tensor(S0) * torch.exp(-gamma * (JlWJt + Av * torch.as_tensor(xis_matrix)))

    f = 10**(0.4 * (torch.as_tensor(ZPT) - mu_s - torch.as_tensor(M0) - dMs)) * torch.matmul(torch.as_tensor(h), torch.as_tensor(Stilde))
    generated_fluxes.append(f)

  # generated_fluxes = torch.stack(generated_fluxes)
  # print(generated_fluxes)
  
  if obs is not None:
    # for i in range(len(bands)):
    for i in pyro.plate("bands", len(bands)):
      # print(i)
      with pyro.plate("observations" + str(bands[i])):
        # print(len(generated_fluxes[i]))
        # for j in range(len(generated_fluxes[i])):
          pyro.sample("flux" + str(bands[i]), dist.Normal(generated_fluxes[i], flux_errors[0][i]), obs = obs[0][i])
          # print(obs[0][i][j])
        # print(generated_fluxes[i])

In [None]:
# observed_fluxes = [(torch.as_tensor(g_flux.values), torch.as_tensor(r_flux.values), torch.as_tensor(i_flux.values), torch.as_tensor(z_flux.values))]

In [None]:
observed_fluxes

In [None]:
flux_errors

In [None]:
model_vi(obs = observed_fluxes)

In [None]:
EBV_MW

In [None]:
pyro.render_model(model_vi, model_args = (observed_fluxes,), render_distributions=True, render_params = True)

In [None]:
autoguide_vi = pyro.infer.autoguide.AutoMultivariateNormal(model_vi, init_loc_fn = pyro.infer.autoguide.initialization.init_to_sample(), init_scale = 1.)

In [None]:
pyro.get_param_store().clear()
adam = pyro.optim.Adam({"lr": 0.001, "betas": (0.90, 0.999)}, {"clip_norm": 10.0})
# adam = pyro.optim.SGD({"lr": 0.1})


svi = SVI(model_vi, autoguide_vi, adam, loss=Trace_ELBO().differentiable_loss)
# svi = SVI(model_vi, custom_guide, adam, loss=Trace_ELBO().differentiable_loss)

for i in range(10):
  svi.step(observed_fluxes)

In [None]:
def train(model, guide, lr=0.01, n_steps=5000, verbose = True):
    pyro.get_param_store().clear()
    adam_params = {"lr": lr, "betas": (0.95, 0.999)}
    # adam = pyro.optim.Adam({"lr": 0.05}, {"clip_norm": 10.0})
    adam = pyro.optim.ClippedAdam(adam_params)
    svi = SVI(model, guide, adam, loss=Trace_ELBO())
    losses = []
    losses_every_500 = []

    for step in range(n_steps):
        loss = svi.step(observed_fluxes)
        losses.append(loss)
        if verbose and step % 500 == 0:
            print('[iter {}]  loss: {:.4f}'.format(step, loss))
            losses_every_500.append(loss)
            ## if it hasn't decreased in 3 steps, end fitting
            # if step > 2000:
            #  if losses_every_500[-1] >= losses_every_500[-2] and losses_every_500[-2] >= losses_every_500[-3]:
            #   break
    return losses

In [None]:
start_time = measure_time.time()
losses = train(model_vi, autoguide_vi, lr = 0.01, n_steps = 20000)
end_time = measure_time.time()
print("Time:", end_time - start_time, "seconds")

In [None]:
autoguide_vi.get_posterior().scale_tril

In [None]:
plt.plot(losses)

In [None]:
autoguide_vi.median()['theta']

In [None]:
autoguide_vi.median()['Av']

In [None]:
autoguide_vi.median()['mu_s']

In [None]:
autoguide_vi.median()['nu']

In [None]:
autoguide_laplace = pyro.infer.autoguide.AutoLaplaceApproximation(model_vi)
start_time = measure_time.time()
losses = train(model_vi, autoguide_laplace, n_steps = 3000)
end_time = measure_time.time()
print("Time:", end_time - start_time, "seconds")

In [None]:
new_laplace_approx_guide = autoguide_laplace.laplace_approximation(observed_fluxes)
losses = train(model_vi, new_laplace_approx_guide, n_steps = 10000)

In [None]:
posterior_samples = Predictive(model_vi, guide = autoguide_vi, num_samples = 100)(None,)
posterior_samples['Av']

In [None]:
def calculate_all_fluxes_from_params(nu, theta, mu_s, dMs, Av):
  times_to_interpolate = np.linspace(-10, 40)
  Jt = spline_coeffs_irr(times_to_interpolate, time_knots, invKD_irr(time_knots))

  epsilon_interior = np.matmul(L_sigma_epsilon, nu.numpy().T)

  epsilon = np.zeros_like(W0)   ## populate epsilon matrix
  epsilon[1:-1] = epsilon_interior.reshape((9,6), order = 'F')

  W = W0 + theta.numpy()*W1 + epsilon

  generated_fluxes_from_params = []

  for band in bands:
    band_wavelengths = wavelengths_dict[band]
    band_norm_throughput = norm_throughput_dict[band]

    Jl = band_Jl[band]
    # xis_matrix = band_xis_matrix[band]
    # h = band_h[band]

    wavelengths_to_interpolate = get_lambda_int_for_band(band_wavelengths)
    
    # Jl = spline_coeffs_irr(wavelengths_to_interpolate, wavelength_knots, invKD_irr(wavelength_knots))
    xis = extinction.fitzpatrick99(wavelengths_to_interpolate, 1, rv)
    xis_matrix = np.tile(xis, (len(times_to_interpolate),1)).T

    S0 = np.zeros((len(wavelengths_to_interpolate), len(times_to_interpolate)))
    for i, wavelength in enumerate(wavelengths_to_interpolate):
      for j, time in enumerate(times_to_interpolate):
        S0[i][j] = interpolate_hsiao(time, wavelength, hsiao_phase, hsiao_wave, hsiao_flux)

    throughput_interpolator = interp1d(band_wavelengths, band_norm_throughput)
    b = throughput_interpolator([band_wavelengths[0]] + list(wavelengths_to_interpolate[1:-1] * (1 + z)) + [band_wavelengths[-1]])

    xis_obs = extinction.fitzpatrick99(wavelengths_to_interpolate*(1 + z), Av_obs, 3.1)
    dLambda = wavelengths_to_interpolate[1] - wavelengths_to_interpolate[0]
    h = (1 + z) * dLambda * b * wavelengths_to_interpolate * np.exp(-gamma * xis_obs)

    band_flux = get_flux_from_params(torch.as_tensor(epsilon_interior), torch.as_tensor(theta), 
                                torch.as_tensor(mu_s), torch.as_tensor(dMs), torch.as_tensor(Av),
                                xis_matrix = xis_matrix, S0 = S0, h = h, Jt = Jt, Jl = Jl)
    
    generated_fluxes_from_params.append(band_flux)
  
  return np.array(generated_fluxes_from_params)

In [None]:
posterior_smooth_curves = []
for nu, theta, mu_s, Av in zip(posterior_samples['nu'][:,0,0], posterior_samples['theta'][:,0,0], posterior_samples['mu_s'][:,0,0], posterior_samples['Av'][:,0,0]):
  # print(nu, theta, mu_s, Av)
  posterior_smooth_curves.append(calculate_all_fluxes_from_params(nu, theta, mu_s, torch.as_tensor(0.), Av))

In [None]:
posterior_smooth_curves = np.array(posterior_smooth_curves)
median_posterior_fit = np.median(posterior_smooth_curves, axis = 0)
upper_posterior_fit = np.percentile(posterior_smooth_curves, 50 + 34.1, axis = 0)
lower_posterior_fit = np.percentile(posterior_smooth_curves, 50 - 34.1, axis = 0)

In [None]:
colors  = ['g', 'r', 'c', 'k']
# offset = np.array([0, 1000,2000,3000]) * 30
offset = np.zeros(4)
for i, band_flux in enumerate(median_posterior_fit):
    plt.plot(np.linspace(-10, 40), band_flux + offset[i], color = colors[i])
    plt.fill_between(np.linspace(-10, 40), upper_posterior_fit[i]+ offset[i], lower_posterior_fit[i]+ offset[i], color = colors[i], alpha = 0.2)
for i, band_flux in enumerate(observed_fluxes.numpy()):
  plt.plot(times_to_interpolate, band_flux + offset[i], 'o',color = colors[i], label = bands[i])
plt.legend()
plt.title("Posterior samples from VI")

In [None]:
num_corner_samples = 1000
vi_corner_samples = Predictive(model_vi, guide = autoguide_vi, num_samples = num_corner_samples)(None,)

In [None]:
vi_trained_laplace_samples = Predictive(model_vi, guide = new_laplace_approx_guide, num_samples = num_corner_samples)(None,)

In [None]:
laplace_samples = Predictive(model_vi, guide = autoguide_laplace.laplace_approximation(observed_fluxes), num_samples = num_corner_samples)(None,)

In [None]:
other_fit = np.load("AT2016aj_chains_210610_135216.npy", allow_pickle = True).item()

In [None]:
other_fit.keys()

In [None]:
len(other_fit['mu'])

In [None]:
figure = corner.corner(np.vstack((vi_corner_samples['mu_s'].numpy(),vi_corner_samples['theta'].numpy(), 
                                  vi_corner_samples['Av'].numpy())).T, 
                       labels = [r"$\mu_s$", r"$\theta$", r"$A_v$"])
# corner.corner(np.vstack((vi_trained_laplace_samples['mu_s'][:,0,0].numpy(),vi_trained_laplace_samples['theta'][:,0,0].numpy(), 
#                                   vi_trained_laplace_samples['Av'][:,0,0].numpy())).T, 
#                        labels = [r"$\mu_s$", r"$\theta$", r"$A_v$"], fig = figure, color = 'c')
corner.corner(np.vstack((laplace_samples['mu_s'].numpy(),laplace_samples['theta'].numpy(), 
                                  laplace_samples['Av'].numpy())).T, 
                       labels = [r"$\mu_s$", r"$\theta$", r"$A_v$"], fig = figure, color = 'c')
corner.corner(np.vstack((other_fit['mu'] + other_fit['delM'],other_fit['theta'], other_fit['AV'])).T, 
                       labels = [r"$\mu_s$", r"$\theta$", r"$A_v$"], fig = figure, color = 'r')
colors = ['k','c', 'r']

labels = ['VI', 'VI init from Laplace approx', 'MCMC']

plt.legend(
    handles=[
        mlines.Line2D([], [], color=colors[i], label=labels[i])
        for i in range(len(labels))
    ],
    fontsize=16, frameon=False,
    bbox_to_anchor=(1, 3), loc="upper right"
)
figure.suptitle("Parameter distributions for ASASSN-16CS dataset", fontsize = 20)
plt.show()

In [None]:
observed_fluxes

In [None]:
flux_errors

In [None]:
def model_mcmc(obs):
  # epsilon_interior = pyro.sample("eps_int", dist.MultivariateNormal(torch.zeros(len(eps_cov)), covariance_matrix = eps_cov))

  nu = pyro.sample("nu", dist.MultivariateNormal(torch.zeros(len(eps_cov)), covariance_matrix = torch.eye(len(eps_cov))))
  # nu = torch.zeros(54)
  epsilon_interior = torch.matmul(torch.as_tensor(L_sigma_epsilon, dtype = torch.double), nu.double())
  # print(epsilon_interior)

  # theta = pyro.sample("theta", dist.Uniform(low=-1.49, high=2.77))
  theta = pyro.sample("theta", dist.Normal(0., 1.0))
  # theta = torch.as_tensor(0.03430605)

  # theta = torch.as_tensor(-0.9835187)
  cosmo = FlatLambdaCDM(H0 = 73.24, Om0 = 0.28)
  mu_s = pyro.sample("mu_s", dist.Normal(cosmo.distmod(z_helio).value, 10.)) ## fix this later
  # mu_s = torch.as_tensor(37.248927127193724)

  # dMs = pyro.sample("Ms", dist.Normal(torch.tensor(0.), torch.tensor(sigma0)))
  dMs = torch.tensor(0.)
  Av = pyro.sample("Av", dist.Exponential(1 / 0.252))
  # Av = torch.tensor(7.2387767)
  # Av = pyro.sample("Av", dist.Normal(0.27, 1.))

  epsilon = torch.zeros(W0.shape)   ## populate epsilon matrix
  epsilon[1:-1] = torch.transpose(torch.reshape(epsilon_interior, (6,9)), 0, 1)

  W = torch.as_tensor(W0) + theta*torch.as_tensor(W1) + epsilon

  generated_fluxes = []

  for band in bands:
    Jl = band_Jl[band]
    S0 = band_S0[band]
    xis_matrix = band_xis_matrix[band]
    h = band_h[band]
    Jt = band_Jt[band]

    JlWJt = torch.matmul(torch.as_tensor(Jl), torch.matmul(W, torch.as_tensor(Jt.T)))

    Stilde = torch.as_tensor(S0) * torch.exp(-gamma * (JlWJt + Av * torch.as_tensor(xis_matrix)))

    f = 10**(0.4 * (torch.as_tensor(ZPT) - mu_s - torch.as_tensor(M0) - dMs)) * torch.matmul(torch.as_tensor(h), torch.as_tensor(Stilde))
    generated_fluxes.append(f)

  # generated_fluxes = torch.stack(generated_fluxes)
  # print(generated_fluxes)

  if obs is not None:
    # for i in range(len(bands)):
    for i in pyro.plate("bands", len(bands)):
      # print(i)
      with pyro.plate("observations"):
        # print(len(generated_fluxes[i]))
        # for j in range(len(generated_fluxes[i])):
            pyro.sample("flux" + str(bands[i]), dist.Normal(generated_fluxes[i], flux_errors[0][i]), obs = obs[0][i])

In [None]:
num_chains = 1
nuts_kernel = pyro.infer.NUTS(model_mcmc, adapt_step_size=True)

mcmc = pyro.infer.MCMC(nuts_kernel, num_samples=600, warmup_steps=250, num_chains = num_chains)

start_time = measure_time.time()
mcmc.run(observed_fluxes)
end_time = measure_time.time()
print("Total:", end_time - start_time, "seconds")

In [None]:
mcmc_corner_samples = mcmc.get_samples(num_samples = num_corner_samples)

In [None]:
figure = corner.corner(np.vstack((vi_corner_samples['mu_s'].numpy(),vi_corner_samples['theta'].numpy(), 
                                  vi_corner_samples['Av'].numpy())).T, 
                       labels = [r"$\mu_s$", r"$\theta$", r"$A_v$"], color = 'k')
corner.corner(np.vstack((vi_trained_laplace_samples['mu_s'].numpy(),vi_trained_laplace_samples['theta'].numpy(), 
                                  vi_trained_laplace_samples['Av'].numpy())).T, 
                       labels = [r"$\mu_s$", r"$\theta$", r"$A_v$"], fig = figure, color = 'c')
corner.corner(np.vstack((mcmc_corner_samples['mu_s'],mcmc_corner_samples['theta'],mcmc_corner_samples['Av'])).T, 
                       labels = [r"$\mu_s$", r"$\theta$", r"$A_v$"], color ='r',fig = figure,)
# corner.corner(np.vstack((laplace_samples['mu_s'][:,0,0].numpy(),laplace_samples['theta'][:,0,0].numpy(), 
#                                   laplace_samples['Av'][:,0,0].numpy())).T, 
#                        labels = [r"$\mu_s$", r"$\theta$", r"$A_v$"], fig = figure, color = 'b')
# corner.corner(np.vstack((other_fit['mu'] + other_fit['delM'],other_fit['theta'], other_fit['AV'])).T, 
#                        labels = [r"$\mu_s$", r"$\theta$", r"$A_v$"], fig = figure, color = 'r')
# colors = ['k','c', 'r']

# labels = ['VI', 'VI init from Laplace approx', 'MCMC']
colors = [ 'k', 'c', 'r']
labels = ['VI', 'VI init from Laplace approx','MCMC']

plt.legend(
    handles=[
        mlines.Line2D([], [], color=colors[i], label=labels[i])
        for i in range(len(labels))
    ],
    fontsize=16, frameon=False,
    bbox_to_anchor=(1, 3), loc="upper right"
)

plt.text(0, 270, "  Fitting $\\mu_s, \\theta$, $A_v$, \n     and $\\epsilon$", fontsize = 20)
figure.suptitle("Parameter distributions for AT2016AJ dataset", fontsize = 20)
plt.show()

In [None]:
mcmc.diagnostics()

In [None]:
many_samples = mcmc.get_samples(num_samples = 10000)
np.savetxt("mcmc_at2016aj_samples.txt",[many_samples['mu_s'].numpy(),many_samples['theta'].numpy(),many_samples['Av'].numpy()])