In [3]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import batman
from scipy.optimize import minimize
from scipy.signal import find_peaks
import emcee
import corner
import os
from tqdm import tqdm

# Load data
data_file = "tess2021.csv"
if not os.path.exists(data_file):
    raise FileNotFoundError(f"The file {data_file} does not exist. Please provide the correct path.")

data_planet = pd.read_csv(data_file)
time_data = np.array(data_planet['time'])
flux_data = np.array(data_planet['flux'])
flux_err = np.ones_like(flux_data) * 0.001  # Assuming a constant error for simplicity

# Defines initial parameters
initial_params = batman.TransitParams()
initial_params.t0 = 2430.15  # time of inferior conjunction; mid-transit time in days
initial_params.per = 2.770860  # orbital period in days (estimated)
initial_params.rp = 0.2  # planet radius (in units of stellar radii)
initial_params.a = 15  # semi-major axis (in units of stellar radii)
initial_params.inc = 87.47  # orbital inclination (in degrees)
initial_params.ecc = 0  # eccentricity
initial_params.w = 90  # longitude of periastron (in degrees)
initial_params.limb_dark = "quadratic"  # limb darkening model
initial_params.u = [0.1, 0.3]  # limb darkening coefficients

# Function to detect the first dip in the data
def find_first_dip(flux, prominence=0.01):
    peaks, _ = find_peaks(-flux, prominence=prominence)
    if peaks.any():
        return peaks[0]
    return None

# Define a function to generate the model light curve
def generate_model(params, time):
    m = batman.TransitModel(params, time)
    return m.light_curve(params)

# Define a function for a preliminary fit to estimate rp
def preliminary_fit(params, time, flux):
    model_flux = generate_model(params, time)
    return np.sum((flux - model_flux) ** 2)

# Use scipy.optimize.minimize to estimate rp
def estimate_initial_rp(time, flux, initial_params):
    def objective(rp):
        params = initial_params
        params.rp = rp
        return preliminary_fit(params, time, flux)
    
    result = minimize(objective, x0=[initial_params.rp], bounds=[(0.01, 1.0)])
    return result.x[0]

# Estimate rp
initial_params.rp = estimate_initial_rp(time_data, flux_data, initial_params)

# Define a log-likelihood function for MCMC
def log_likelihood(theta, time, flux, flux_err):
    t0, rp, per, a = theta
    params = batman.TransitParams()
    params.t0 = t0
    params.rp = rp
    params.per = per
    params.a = a
    params.inc = initial_params.inc
    params.ecc = initial_params.ecc
    params.w = initial_params.w
    params.limb_dark = initial_params.limb_dark
    params.u = initial_params.u

    model = generate_model(params, time)
    sigma2 = flux_err**2

    if np.any(np.isnan(model)):
        return -np.inf

    return -0.5 * np.sum((flux - model)**2 / sigma2 + np.log(sigma2))

# Define a log-prior function for MCMC
def log_prior(theta):
    t0, rp, per, a = theta
    if (initial_params.t0 - 5 < t0 < initial_params.t0 + 5 and 
        0.1 < rp < 1.0 and 
        initial_params.per * 0.8 < per < initial_params.per * 1.2 and 
        initial_params.a * 0.1 < a < initial_params.a * 10):
        return 0.0
    return -np.inf

# Define a log-probability function for MCMC
def log_probability(theta, time, flux, flux_err):
    lp = log_prior(theta)
    if not np.isfinite(lp):
        return -np.inf
    ll = log_likelihood(theta, time, flux, flux_err)
    if not np.isfinite(ll):
        return -np.inf
    return lp + ll

# Enhanced flux data filtering to remove NaN values
filtered_idx = pd.notnull(flux_data) & pd.notnull(time_data) & pd.notnull(flux_err)
filtered_flux = flux_data[filtered_idx]
filtered_time = time_data[filtered_idx]
filtered_err = flux_err[filtered_idx]
# Set up the MCMC sampler
nwalkers = 32
ndim = 4
initial_guess = [initial_params.t0, initial_params.rp, initial_params.per, initial_params.a]
pos = initial_guess + 1e-4 * np.random.randn(nwalkers, ndim)

sampler = emcee.EnsembleSampler(nwalkers, ndim, log_probability, args=(filtered_time, filtered_flux, filtered_err))

# File to save the MCMC chain
chain_file = "mcmc_chain.npy"

if os.path.exists(chain_file):
    # Load the chain if it exists
    sampler.run_mcmc(pos, 0, progress=False)  # Initialize the sampler
    sampler.chain = np.load(chain_file)
else:
    # Run MCMC and save the chain to a file
    sampler.run_mcmc(pos, 50000, progress=True)
    np.save(chain_file, sampler.get_chain())

# Get the MCMC results
samples = sampler.get_chain(discard=200, thin=1, flat=True)
t0_mcmc, rp_mcmc, per_mcmc, a_mcmc = np.percentile(samples, 50, axis=0)

# Calculate median and 16th, 84th percentiles for each parameter
medians = np.percentile(samples, 50, axis=0)
errors_lower = np.percentile(samples, 16, axis=0)
errors_upper = np.percentile(samples, 84, axis=0)

parameters = ["t0", "rp", "per", "a"]

print("Median values and errors:")
for i, param in enumerate(parameters):
    print(f"{param}: {medians[i]:.5f} (+{errors_upper[i] - medians[i]:.5f}, -{medians[i] - errors_lower[i]:.5f})")


# Trace plots
fig, axes = plt.subplots(4, figsize=(10, 7), sharex=True)
labels = ["t0", "rp", "per", "a"]
all_samples = sampler.get_chain()
for i in range(ndim):
    ax = axes[i]
    ax.plot(all_samples[:, :, i], "k", alpha=0.3)
    ax.set_ylabel(labels[i])
axes[-1].set_xlabel("Step number")
plt.savefig("trace_plot.png")
plt.close()

# Corner plot
fig = corner.corner(samples, labels=["t0", "rp", "per", "a"], show_titles=True, title_fmt=".2e", quantiles=[0.16, 0.5, 0.84])
plt.savefig("corner_plot.png")
plt.close()

# Print acceptance fraction
print("Acceptance fraction:", np.mean(sampler.acceptance_fraction))

# Define a function to fit transit parameters and generate the model light curve
def fit_transit_parameters(time, flux, initial_params):
    # Detect the first dip
    first_dip_index = find_first_dip(flux)

    if first_dip_index is not None:
        # Use the first dip as the initial transit center time
        initial_params.t0 = time[first_dip_index]

    # Define the transit model parameters
    params = batman.TransitParams()
    params.t0 = t0_mcmc
    params.per = per_mcmc
    params.rp = rp_mcmc
    params.a = a_mcmc
    params.inc = initial_params.inc
    params.ecc = initial_params.ecc
    params.w = initial_params.w
    params.limb_dark = initial_params.limb_dark
    params.u = initial_params.u

    # Generate a finer time grid during transit for a smoother model
    time_fine = np.linspace(time[0], time[-1], len(time) * 10)
    m = batman.TransitModel(params, time_fine)

    # Calculate the model light curve
    flux_model = m.light_curve(params)

    return time_fine, flux_model, [params.t0, params.rp, params.per, params.a]

# Generate the model light curve with best-fit parameters
time_fine, flux_model, temp_best_fit_params = fit_transit_parameters(time_data, flux_data, initial_params)

# Phase-fold the model
model_phase = ((time_fine - temp_best_fit_params[0] + 0.5 * temp_best_fit_params[2]) % temp_best_fit_params[2]) / temp_best_fit_params[2] - 0.5

# Phase-fold the data
phase_folded_time = ((time_data - temp_best_fit_params[0] + 0.5 * temp_best_fit_params[2]) % temp_best_fit_params[2]) / temp_best_fit_params[2] - 0.5

# Sorting the phase and corresponding flux data for plotting
sort_order_data = np.argsort(phase_folded_time)
sorted_flux = flux_data[sort_order_data]
sorted_phase = phase_folded_time[sort_order_data]

sort_order_model = np.argsort(model_phase)
sorted_model_flux = flux_model[sort_order_model]
sorted_model_phase = model_phase[sort_order_model]

# Plot the observed data without connecting the blue points
plt.figure(figsize=(10, 6))
plt.plot(sorted_phase, sorted_flux, 'bo', label='Observed Data', markersize=2)
plt.plot(sorted_model_phase, sorted_model_flux, 'r-', label='Model Light Curve')
plt.xlabel('Phase')
plt.ylabel('Flux')
plt.legend()
plt.savefig("transit_fit.png")
plt.close()

# Print the median values and errors
print("Median values and errors:")
for i, param in enumerate(["t0", "rp", "per", "a"]):
    print(f"{param}: {medians[i]:.5f} (+{errors_upper[i] - medians[i]:.5f}, -{medians[i] - errors_lower[i]:.5f})")

# Calculate and print chi-squared value
chi_squared = np.sum(((sorted_flux - np.interp(sorted_phase, sorted_model_phase, sorted_model_flux)) / flux_err[:len(sorted_flux)]) ** 2)
print("Chi-squared value:", chi_squared)


100%|██████████████████████████████████████████████████████████████████████████| 50000/50000 [1:24:49<00:00,  9.82it/s]


Median values and errors:
t0: 2430.15372 (+0.00007, -0.00007)
rp: 0.13672 (+0.00018, -0.00018)
per: 2.77094 (+0.00002, -0.00002)
a: 9.69858 (+0.01338, -0.01322)




Acceptance fraction: 0.5918806249999999
Median values and errors:
t0: 2430.15372 (+0.00007, -0.00007)
rp: 0.13672 (+0.00018, -0.00018)
per: 2.77094 (+0.00002, -0.00002)
a: 9.69858 (+0.01338, -0.01322)
Chi-squared value: nan
