In [None]:
import h5py
import numpy as np
import matplotlib.pyplot as plt
import pyro
import pyro.infer
import pyro.optim
from pyro.infer import SVI, Trace_ELBO
import pyro.distributions as dist
import torch.nn as nn
import torch.nn.functional as F
from pyro.nn import PyroModule, PyroSample
from torch.utils.data import Dataset, DataLoader
import torch


In [None]:
class PlanetDataset(Dataset):
    """Map-style dataset over the systems stored in one group of an HDF5 file.

    Item ``idx`` is the pair ``(combined_light_curve, detected_count)`` read
    from ``<iteration>/<system>`` and converted to ``float32`` tensors.

    The file is opened in ``__init__`` and stays open until :meth:`close` is
    called. The dataset can also be used as a context manager so the handle is
    released even when the surrounding code raises::

        with PlanetDataset(path, 'iteration_0') as dataset:
            ...
    """

    def __init__(self, hdf5_file, iteration):
        """Open ``hdf5_file`` read-only and list the systems under ``iteration``.

        Args:
            hdf5_file: Path of the HDF5 file to read.
            iteration: Name of the top-level group whose children are systems.

        Raises:
            Any exception raised while opening the file or looking up the
            ``iteration`` group is propagated; the file is closed first.
        """
        self.hdf5_file = hdf5_file
        self.iteration = iteration
        self.file = h5py.File(hdf5_file, 'r')
        try:
            self.systems = list(self.file[iteration].keys())
        except Exception:
            # The object is never handed to the caller in this case, so nobody
            # else could close the handle: release it before re-raising.
            self.file.close()
            raise

    def __enter__(self):
        """Return the dataset itself so it can be used in a ``with`` block."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Close the underlying file; exceptions are not suppressed."""
        self.close()

    def __len__(self):
        """Return the number of systems found under the iteration group."""
        return len(self.systems)

    def __getitem__(self, idx):
        """Return ``(combined_light_curve, detected_count)`` for system ``idx``.

        Both values are returned as ``torch.float32`` tensors.
        """
        system = self.systems[idx]
        group = self.file[f'{self.iteration}/{system}']
        # ``[:]`` reads the whole array dataset, ``[()]`` reads a scalar dataset.
        combined_light_curve = torch.tensor(group['combined_light_curve'][:], dtype=torch.float32)
        detected_count = torch.tensor(group['detected_count'][()], dtype=torch.float32)
        return combined_light_curve, detected_count

    def close(self):
        """Close the underlying HDF5 file."""
        self.file.close()

In [None]:
class _StandardNormalPrior:
    """Callable prior for ``PyroSample``: a Normal(0, 1) of a fixed shape.

    ``PyroSample`` accepts a function of the owning module. The distribution
    is built from that module's ``prior_loc`` / ``prior_scale`` buffers, so the
    sampled tensor is created on whatever device the module currently lives on.
    """

    def __init__(self, *shape):
        self.shape = list(shape)

    def __call__(self, module):
        base = dist.Normal(module.prior_loc, module.prior_scale)
        # All dimensions are event dimensions: one joint site per tensor.
        return base.expand(self.shape).to_event(len(self.shape))


class BNN(PyroModule):
    """Bayesian MLP regressor with standard-normal priors on all weights and biases.

    Architecture: ``in_features -> hidden1 -> hidden2 -> 1`` with ReLU after
    the first two layers. The observation model is ``Normal(mean, sigma)`` with
    ``sigma ~ Uniform(0, 10)``.

    Args:
        in_features: Length of each input vector (default 1000).
        hidden1: Width of the first hidden layer (default 128).
        hidden2: Width of the second hidden layer (default 64).
    """

    def __init__(self, in_features=1000, hidden1=128, hidden2=64):
        super().__init__()
        self.fc1 = self._bayesian_linear(in_features, hidden1)
        self.fc2 = self._bayesian_linear(hidden1, hidden2)
        self.fc3 = self._bayesian_linear(hidden2, 1)

    @staticmethod
    def _bayesian_linear(in_features, out_features):
        """Build a linear layer whose weight and bias are Pyro sample sites."""
        layer = PyroModule[nn.Linear](in_features, out_features)
        # Replacing weight/bias with PyroSample removes the nn.Parameters, so
        # ``.cuda()`` / ``.to()`` would have nothing to move and priors built
        # from Python floats would stay on the CPU. Buffers do follow the
        # module; ``persistent=False`` keeps them out of the state dict.
        layer.register_buffer("prior_loc", torch.tensor(0.0), persistent=False)
        layer.register_buffer("prior_scale", torch.tensor(1.0), persistent=False)
        layer.weight = PyroSample(_StandardNormalPrior(out_features, in_features))
        layer.bias = PyroSample(_StandardNormalPrior(out_features))
        return layer

    def forward(self, x, y=None):
        """Run the network and register the likelihood.

        Args:
            x: Batch of inputs; the first dimension is used as the size of the
                ``"data"`` plate.
            y: Optional observed targets for the ``"obs"`` site. When ``None``
                the site is sampled instead of conditioned.

        Returns:
            The predicted mean with the trailing singleton dimension removed.
        """
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        mean = self.fc3(x).squeeze(-1)
        # Build the bounds on the same device/dtype as ``mean`` so that the
        # Normal below never mixes CPU and GPU tensors.
        sigma = pyro.sample(
            "sigma",
            dist.Uniform(mean.new_tensor(0.0), mean.new_tensor(10.0)),
        )
        # NOTE(review): the plate size is the batch size, so the likelihood is
        # not rescaled to the full dataset size when training on mini-batches.
        with pyro.plate("data", x.shape[0]):
            pyro.sample("obs", dist.Normal(mean, sigma), obs=y)
        return mean

In [None]:
# Make the cell reproducible and idempotent: fix the RNG seed and drop any
# guide parameters left in Pyro's global param store by a previous run.
pyro.set_rng_seed(0)
pyro.clear_param_store()

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Prepare the dataset and dataloader
# NOTE(review): the sanity-check cell opens 'planet_systems.hdf5' while this
# cell opens 'planet_systems.h5' — confirm which file name is correct.
dataset = PlanetDataset('planet_systems.h5', 'iteration_0')
try:
    dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

    # Define the model and guide
    bnn = BNN().to(device)
    guide = pyro.infer.autoguide.AutoDiagonalNormal(bnn)

    # Define the optimizer and SVI object
    optimizer = pyro.optim.Adam({"lr": 0.01})
    svi = SVI(bnn, guide, optimizer, loss=Trace_ELBO())

    # Training loop
    num_epochs = 10
    for epoch in range(num_epochs):
        epoch_loss = 0
        for light_curves, counts in dataloader:
            light_curves = light_curves.to(device)
            counts = counts.to(device)
            epoch_loss += svi.step(light_curves, counts)
        # Reported value is the mean ELBO loss per mini-batch.
        print(f"Epoch {epoch + 1}, Loss: {epoch_loss / len(dataloader)}")
finally:
    # Close the dataset even if training raised, so the HDF5 handle is released.
    dataset.close()

In [None]:
# Sanity check: print the layout of one system and a few of its stored values.
# NOTE(review): the training cell opens 'planet_systems.h5' while this cell
# opens 'planet_systems.hdf5' — confirm which file name is correct.


def print_structure(name, obj):
    """Recursively print ``name`` and every group/dataset below ``obj``.

    Groups are descended into; datasets additionally get a line with their
    shape and dtype.
    """
    print(name)
    if isinstance(obj, h5py.Group):
        for key, value in obj.items():
            print_structure(f"{name}/{key}", value)
    elif isinstance(obj, h5py.Dataset):
        print(f"  Dataset: {name}, shape: {obj.shape}, dtype: {obj.dtype}")


# Which sample to inspect.
iteration = 'iteration_0'
system = 'system_0'
planet = 'planet_0'

# (printed label, dataset key, index): ``slice(None, 5)`` previews the first
# five values of an array dataset, ``()`` reads a scalar dataset.
system_fields = [
    ("Combined light curve", 'combined_light_curve', slice(None, 5)),
    ("Detected count", 'detected_count', ()),
    ("Flux with noise", 'flux_with_noise', slice(None, 5)),
    ("Observation noise", 'observation_noise', ()),
    ("Star radius", 'star_radius', ()),
    ("Total time", 'total_time', ()),
    ("u1", 'u1', ()),
    ("u2", 'u2', ()),
    ("Time", 'time', slice(None, 5)),
]
planet_fields = ['a', 'incl', 'period', 'rp', 'transit_midpoint']

with h5py.File('planet_systems.hdf5', 'r') as hdf5_file:
    system_path = f'{iteration}/{system}'

    # Only walk the sample system: walking the whole file could flood the output.
    print(f"Structure of {system_path}:")
    print_structure(system_path, hdf5_file[system_path])

    print(f"\nExtracting sample values for {iteration}/{system}:")
    for label, key, index in system_fields:
        print(f"{label}:", hdf5_file[f'{system_path}/{key}'][index])

    print(f"\nExtracting sample values for {iteration}/{system}/{planet}:")
    for key in planet_fields:
        print(f"{key}:", hdf5_file[f'{system_path}/planets/{planet}/{key}'][()])