In [1]:
import tensorflow as tf

In [14]:
import numpy as np

In [34]:
def retrieve_data(tfrecord_file):
    """Build a ``tf.data`` pipeline of parsed examples from a TFRecord file.

    Args:
        tfrecord_file: Passed unchanged to ``tf.data.TFRecordDataset``.

    Returns:
        A dataset whose elements are ``((global_view, local_view), label)``.
        ``global_view`` and ``local_view`` are float32 vectors of length 2001
        and 201; ``label`` is a scalar int64 that is 1 when the record's
        ``av_training_set`` field equals ``"PC"`` and 0 for any other value.
    """
    schema = {
        "global_view": tf.io.FixedLenFeature([2001], tf.float32),
        "local_view": tf.io.FixedLenFeature([201], tf.float32),
        "av_training_set": tf.io.FixedLenFeature([], tf.string),
        "kepid": tf.io.FixedLenFeature([], tf.int64),
    }

    def _decode(serialized):
        parsed = tf.io.parse_single_example(serialized, schema)

        # Binary target: 1 for a planet candidate ("PC"), 0 for everything else.
        is_candidate = tf.equal(parsed["av_training_set"], tf.constant("PC"))
        target = tf.cast(is_candidate, tf.int64)

        # "kepid" is parsed but intentionally not part of the returned element.
        return (parsed["global_view"], parsed["local_view"]), target

    return tf.data.TFRecordDataset(tfrecord_file).map(_decode)


In [35]:
import matplotlib.pyplot as plt

# Relative to the notebook's working directory, the same convention the test
# shard path uses later in this notebook, so the cell is not tied to one
# user's home directory. The file name indicates shard 0 of 8.
TRAIN_SHARD = "Astronet_Preprocessed_Data/train-00000-of-00008"
dataset = retrieve_data(TRAIN_SHARD)

# Preview the first three parsed examples: label, shapes, and both views.
for (global_view, local_view), label in dataset.take(3):
    print("Label:", label.numpy())
    print(f"{global_view.shape=}")
    print(f"{local_view.shape=}")

    fig, (ax_global, ax_local) = plt.subplots(1, 2, figsize=(14, 5))

    ax_global.plot(global_view.numpy(), ".")
    ax_global.set_title("Global view (full phase)")
    ax_global.set_xlabel("Sample index")

    ax_local.plot(local_view.numpy(), ".")
    ax_local.set_title("Local view (transit zoom)")
    ax_local.set_xlabel("Sample index")

    plt.show()


In [13]:
from tensorflow.keras import layers, models


In [None]:


# Two named inputs; the names are the keys used when feeding dict inputs.
input_global = layers.Input(shape=(2001,), name='global_view')
input_local = layers.Input(shape=(201,), name='local_view')

# Global branch: two fully connected ReLU layers (128 -> 64 units).
xg = input_global
for units in (128, 64):
    xg = layers.Dense(units, activation='relu')(xg)

# Local branch: two fully connected ReLU layers (32 -> 16 units).
xl = input_local
for units in (32, 16):
    xl = layers.Dense(units, activation='relu')(xl)

# Join both branches, then a hidden layer and a single sigmoid output unit.
x = layers.Concatenate()([xg, xl])
x = layers.Dense(64, activation='relu')(x)
output = layers.Dense(1, activation='sigmoid')(x)

model = models.Model(inputs=[input_global, input_local], outputs=output)
model.compile(
    optimizer='adam',
    loss='binary_crossentropy',
    metrics=['accuracy'],
)
model.summary()


In [32]:
# Pull every element of `dataset` into memory in a single pass, then split
# the collected tuples into one NumPy array per field.
samples = [
    (g_view.numpy(), l_view.numpy(), target.numpy())
    for (g_view, l_view), target in dataset
]

X_global = np.array([sample[0] for sample in samples])  # Shape: (num_samples, 2001)
X_local = np.array([sample[1] for sample in samples])   # Shape: (num_samples, 201)
y = np.array([sample[2] for sample in samples])


In [33]:
# Keys must match the names given to the model's Input layers.
train_inputs = {'global_view': X_global, 'local_view': X_local}

# Train for 20 epochs in batches of 32, holding out 10% for validation.
model.fit(train_inputs, y, epochs=20, batch_size=32, validation_split=0.1)


In [38]:
# Load the held-out test shard and materialize it as NumPy arrays.
dataTest = retrieve_data("Astronet_Preprocessed_Data/test-00000-of-00001")
X_global_test, X_local_test, y_test = [], [], []

# Iterate the test split itself; looping over `dataset` here would copy the
# training examples and make the reported test metrics meaningless.
for (global_view, local_view), label in dataTest:
    X_global_test.append(global_view.numpy())
    X_local_test.append(local_view.numpy())
    y_test.append(label.numpy())

X_global_test = np.array(X_global_test)  # Shape: (num_samples, 2001)
X_local_test = np.array(X_local_test)    # Shape: (num_samples, 201)
y_test = np.array(y_test)


In [39]:
# Score the trained model on the test arrays; with the compiled metrics this
# yields the loss followed by accuracy.
test_inputs = {'global_view': X_global_test, 'local_view': X_local_test}
loss, acc = model.evaluate(test_inputs, y_test)
print("Test Accuracy:", acc)
