In [19]:
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense

def above_and_below(n):
    return [int(is_prime(n-1)),int(is_prime(n+1))]

# --- Helper function to check primality ---
def is_prime(n):
    if n < 2:
        return False
    for i in range(2, int(n**0.5)+1):
        if n % i == 0:
            return False
    return True

# --- Prepare the dataset ---
def create_dataset(start, end):
    X, y = [], []
    for n in range(start, end+1):
        left = 1 if is_prime(n-1) else 0
        right = 1 if is_prime(n+1) else 0
        X.append([left, right])
        y.append(1 if is_prime(n) else 0)
    return np.array(X, dtype=np.float32), np.array(y, dtype=np.float32)

# Training: numbers 50-99
X_train, y_train = create_dataset(50, 99)

# Testing: numbers 100-149
X_test, y_test = create_dataset(100, 149)

# --- Build the model ---
model = Sequential([
    Dense(8, activation='relu', input_shape=(2,)),
    Dense(4, activation='relu'),
    Dense(1, activation='sigmoid')  # output between 0 and 1
])

model.compile(optimizer='adam',
              loss='binary_crossentropy',
              metrics=['accuracy'])

# --- Train the model ---
model.fit(X_train, y_train, epochs=100, verbose=0)

# --- Evaluate the model ---
loss, acc = model.evaluate(X_test, y_test, verbose=0)
print(f"Test accuracy: {acc:.3f}")

# --- Test with your examples ---
examples = np.array([above_and_below(i) for i in range(50, 150)], dtype=np.float32)

predictions = model.predict(examples)
for inp, pred in zip(examples, predictions):
    label = 1 if pred >= 0.5 else 0
    print(f"Input {inp} -> Predicted: {pred[0]:.3f} -> Label: {label}")

Test accuracy: 0.740
Input [0. 0.] -> Predicted: 0.455 -> Label: 0
Input [0. 0.] -> Predicted: 0.455 -> Label: 0
Input [0. 1.] -> Predicted: 0.455 -> Label: 0
Input [0. 0.] -> Predicted: 0.455 -> Label: 0
Input [1. 0.] -> Predicted: 0.455 -> Label: 0
Input [0. 0.] -> Predicted: 0.455 -> Label: 0
Input [0. 0.] -> Predicted: 0.455 -> Label: 0
Input [0. 0.] -> Predicted: 0.455 -> Label: 0
Input [0. 1.] -> Predicted: 0.455 -> Label: 0
Input [0. 0.] -> Predicted: 0.455 -> Label: 0
Input [1. 1.] -> Predicted: 0.502 -> Label: 1
Input [0. 0.] -> Predicted: 0.455 -> Label: 0
Input [1. 0.] -> Predicted: 0.455 -> Label: 0
Input [0. 0.] -> Predicted: 0.455 -> Label: 0
Input [0. 0.] -> Predicted: 0.455 -> Label: 0
Input [0. 0.] -> Predicted: 0.455 -> Label: 0
Input [0. 1.] -> Predicted: 0.455 -> Label: 0
Input [0. 0.] -> Predicted: 0.455 -> Label: 0
Input [1. 0.] -> Predicted: 0.455 -> Label: 0
Input [0. 0.] -> Predicted: 0.455 -> Label: 0
Input [0. 1.] -> Predicted: 0.455 -> Label: 0
Input [0. 0.]