In [1]:
import h5py
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, TensorDataset
import exoplanet
from transitFinder import TransitModel, plot_light_curves





In [2]:
print(torch.version.cuda)  # Check the installed CUDA version

12.4


In [3]:
# Define the function to generate the light curve
def generate_multi_planet_light_curve(planets, star_radius=1.0, observation_noise=0.001, total_time=365, snr_threshold=5, u1=0.3, u2=0.2, cadence=0.2):
      time = np.arange(0, total_time, cadence)
      planet_light_curves = np.zeros_like(time)
      individual_light_curves = []
      
      detected_count = 0
      star_radius_squared = star_radius ** 2

      for planet in planets:
            period = planet['period']
            rp = planet['rp'] * star_radius
            a = planet['a']
            incl = planet['incl']
            t0 = planet['transit_midpoint']
            
            orbit = exoplanet.orbits.KeplerianOrbit(period=period, t0=t0, a=a, incl=incl)
            light_curve_model = exoplanet.LimbDarkLightCurve([u1, u2]).get_light_curve(
                  orbit=orbit, r=rp, t=time
            ).eval().flatten()

            planet_light_curves += light_curve_model
            
            transit_depth = (rp ** 2) / star_radius_squared
            snr = transit_depth / observation_noise

            if snr > snr_threshold:
                  detected_count += 1
                  individual_light_curves.append(light_curve_model)
      
      flux_with_noise = planet_light_curves + np.random.normal(0, observation_noise, len(time))
      
      return time, flux_with_noise, planet_light_curves, detected_count, individual_light_curves



In [4]:
# Define the function to load the model and make predictions
def load_model_and_predict(model_path, planets):
      # Generate the light curve
      time, flux_with_noise, planet_light_curves, detected_count, individual_light_curves = generate_multi_planet_light_curve(planets)
      
      # Load the trained model
      device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
      if torch.cuda.is_available():
            print("CUDA is available. Using GPU.")
      else:
            print("CUDA is not available. Using CPU.")
      
      max_flux_len = len(flux_with_noise)
      max_time_len = len(time)
      
      model = TransitModel(max_flux_len=max_flux_len, max_time_len=max_time_len).to(device)
      model.load_state_dict(torch.load(model_path, map_location=device))
      model.eval()
      
      # Preprocess the generated light curve
      flux_with_noise_tensor = torch.tensor(flux_with_noise, dtype=torch.float32).unsqueeze(0).to(device)
      time_tensor = torch.tensor(time, dtype=torch.float32).unsqueeze(0).to(device)
      
      # Make predictions
      with torch.no_grad():
            detected_count_pred, individual_light_curves_flux_pred, individual_light_curves_time_pred = model(flux_with_noise_tensor, time_tensor)
      
      # Plot the results
      plot_light_curves(detected_count_pred.cpu(), individual_light_curves_flux_pred.cpu(), individual_light_curves_time_pred.cpu())



In [5]:
# Example usage
planets = [
    {'period': 10, 'rp': 0.1, 'a': 0.1, 'incl': np.pi/2, 'transit_midpoint': 5},
    {'period': 20, 'rp': 0.2, 'a': 0.2, 'incl': np.pi/2, 'transit_midpoint': 10},
]

load_model_and_predict("transit_model.pth", planets)

CUDA is available. Using GPU.


  model.load_state_dict(torch.load(model_path, map_location=device))


RuntimeError: Error(s) in loading state_dict for TransitModel:
	Missing key(s) in state_dict: "flux_input.weight", "flux_input.bias", "time_input.weight", "time_input.bias", "concat.0.weight", "concat.0.bias", "bayesian1.weight_mean", "bayesian1.weight_std", "bayesian1.bias_mean", "bayesian1.bias_std", "bayesian2.weight_mean", "bayesian2.weight_std", "bayesian2.bias_mean", "bayesian2.bias_std", "detected_count_output.weight", "detected_count_output.bias", "individual_light_curves_flux_output.weight", "individual_light_curves_flux_output.bias", "individual_light_curves_time_output.weight", "individual_light_curves_time_output.bias". 
	Unexpected key(s) in state_dict: "model_state_dict". 