# Notebook 4: Experimental V3 (Deep ResNet)

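This notebook implements a **Deep Residual Network (ResNet)** that predicts the residual correction (`Res_X`, `Res_Y`, `Res_Z`) to Mars's Keplerian position. The final position estimate is the Keplerian position plus this learned correction.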
Hypothesis: Deeper networks with skip connections can capture finer gravitational perturbations than shallow MLPs.

In [None]:
import os
import sys

import numpy as np
import pandas as pd
import plotly.graph_objects as go
import tensorflow as tf
from tensorflow.keras import layers, models, callbacks, optimizers

# Make the project's src/ directory importable. The membership check keeps the
# cell idempotent: re-running it does not push duplicate entries onto sys.path.
SRC_DIR = os.path.abspath(os.path.join(os.getcwd(), '..', 'src'))
if SRC_DIR not in sys.path:
    sys.path.append(SRC_DIR)
import stars_utils

# Seed Python's `random`, NumPy and TensorFlow in one call so weight
# initialisation, dropout masks and mini-batch shuffling repeat across
# Restart & Run All. (Bit-exact GPU results would additionally need
# tf.config.experimental.enable_op_determinism(), at a speed cost.)
SEED = 42
tf.keras.utils.set_random_seed(SEED)

print(f"TensorFlow Version: {tf.__version__}")

## 1. Data Loading

In [None]:
data_path = '../data/mars_processed_data.csv'
df = pd.read_csv(data_path)
# NOTE(review): stars_utils.load_scaler presumably unpickles / joblib-loads the
# file — only ever point it at scaler files produced by this project.
scaler = stars_utils.load_scaler('../models/scaler_features.pkl')
y_scaler = stars_utils.load_scaler('../models/scaler_targets.pkl')

# V3 uses the same features as V2 (Physics + Lags)
FEATURES = [
    'Time_Index', 'Time_Index_2', 
    'Sin_Mars', 'Cos_Mars',
    'Sin_Jupiter', 'Cos_Jupiter', 'Inv_Dist_Jupiter',
    'Sin_Saturn', 'Cos_Saturn', 'Inv_Dist_Saturn',
    'Sin_Venus', 'Cos_Venus', 'Inv_Dist_Venus',
    'X_au_Lag1', 'Y_au_Lag1', 'Z_au_Lag1', 
    'X_au_Lag2', 'Y_au_Lag2', 'Z_au_Lag2',
    'Kepler_X', 'Kepler_Y', 'Kepler_Z'     
]
# The model predicts residuals; section 4 adds them back onto Kepler_X/Y/Z.
TARGETS = ['Res_X', 'Res_Y', 'Res_Z']
# Ground-truth positions used by the evaluation cell. Checked here so a schema
# problem surfaces before the (long) training run instead of after it.
TRUTH_COLUMNS = ['X_au', 'Y_au', 'Z_au']

missing_cols = [c for c in FEATURES + TARGETS + TRUTH_COLUMNS if c not in df.columns]
assert not missing_cols, f"Columns missing from {data_path}: {missing_cols}"

X = df[FEATURES].to_numpy(dtype=float)
y = df[TARGETS].to_numpy(dtype=float)
# Any NaN/inf left by preprocessing (e.g. in early rows of the lag columns)
# passes straight through the scalers and turns the training loss into NaN.
assert np.isfinite(X).all(), "Non-finite values in feature columns"
assert np.isfinite(y).all(), "Non-finite values in target columns"

# Chronological hold-out: the last TEST_SIZE fraction of rows is the test set.
# NOTE(review): assumes the CSV rows are already in time order — confirm.
TEST_SIZE = 0.2
split_index = int(len(df) * (1 - TEST_SIZE))
assert 0 < split_index < len(df), "Train/test split produced an empty partition"
X_train, X_test = X[:split_index], X[split_index:]
y_train, y_test = y[:split_index], y[split_index:]

# Reuse the existing scalers (presumably the ones saved with V2) rather than
# fitting new ones, so V3 sees exactly the same input/target transforms and
# the comparison with V2 is apples-to-apples.
X_train_scaled = scaler.transform(X_train)
X_test_scaled = scaler.transform(X_test)
y_train_scaled = y_scaler.transform(y_train)
y_test_scaled = y_scaler.transform(y_test)

## 2. Define ResNet Architecture
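We use the Keras Functional API to build skip connections of the form `x = relu(block(x) + shortcut(x))`. Here `shortcut` is the identity, or a linear `Dense` projection when the widths differ.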

In [None]:
def residual_block(x, units, dropout=0.1):
    """Append a two-Dense-layer residual block to `x` and return its output.

    The block computes ``relu(F(x) + shortcut(x))`` where
    ``F = Dense(units, relu) -> Dropout(dropout) -> Dense(units, linear)`` and
    ``shortcut`` is the identity when the last dimension of `x` already equals
    `units`, or a linear ``Dense(units)`` projection otherwise.

    Args:
        x: Keras (functional-API) tensor feeding the block.
        units: Width of both Dense layers and of the block output.
        dropout: Dropout rate applied between the two Dense layers.

    Returns:
        Keras tensor whose last dimension is `units`.
    """
    identity = x

    # Main branch. The second Dense is left linear on purpose: the ReLU is
    # applied only after the skip connection has been summed in.
    branch = layers.Dense(units, activation='relu')(x)
    branch = layers.Dropout(dropout)(branch)
    branch = layers.Dense(units)(branch)

    # Widths must agree for the element-wise sum; project only when needed.
    if identity.shape[-1] != units:
        identity = layers.Dense(units)(identity)

    merged = layers.Add()([branch, identity])
    return layers.Activation('relu')(merged)

def build_resnet(input_dim, output_dim, width=128, n_blocks=4, head_units=64, dropout=0.1):
    """Build the V3 fully-connected ResNet regressor (uncompiled).

    Architecture: Dense(width, relu) entry -> `n_blocks` residual blocks of
    width `width` -> Dense(head_units, relu) -> Dense(output_dim, linear).
    The defaults reproduce the original fixed V3 architecture exactly.

    Args:
        input_dim: Number of input features.
        output_dim: Number of regression outputs.
        width: Units in the entry layer and in every residual block.
        n_blocks: Number of residual blocks (each adds two Dense layers).
        head_units: Units in the ReLU layer before the linear output.
        dropout: Dropout rate used inside each residual block.

    Returns:
        A `tf.keras.Model` named "ResNet_V3".

    Raises:
        ValueError: If `n_blocks` is negative.
    """
    if n_blocks < 0:
        raise ValueError(f"n_blocks must be >= 0, got {n_blocks}")

    inputs = layers.Input(shape=(input_dim,))

    # Entry: lift the inputs to the residual width so the first block can use
    # an identity shortcut (no projection layer needed).
    x = layers.Dense(width, activation='relu')(inputs)

    # Deep residual stack (default: 4 blocks = 8 Dense layers).
    for _ in range(n_blocks):
        x = residual_block(x, width, dropout=dropout)

    # Exit head; linear output because the targets are unbounded residuals.
    x = layers.Dense(head_units, activation='relu')(x)
    outputs = layers.Dense(output_dim, activation='linear')(x)

    return models.Model(inputs=inputs, outputs=outputs, name="ResNet_V3")

# Output width follows TARGETS instead of a hard-coded 3, so the model stays in
# sync with the data cell if the target list ever changes.
LEARNING_RATE = 5e-4
model = build_resnet(X_train_scaled.shape[1], len(TARGETS))
model.compile(optimizer=optimizers.Adam(learning_rate=LEARNING_RATE), loss='mse')
model.summary()

## 3. Train

In [None]:
# Early stopping must not look at the test set. Picking (and restoring) the
# epoch with the lowest *test* loss leaks test information into model
# selection and makes the test MAE reported in section 4 optimistically biased.
# Keras' `validation_split` holds out the LAST fraction of the training arrays,
# taken before shuffling, i.e. the most recent chronological slice of the
# training period. The test set stays untouched until evaluation.
# NOTE(review): if V2 early-stopped on the test set, re-run it the same way
# before comparing the two models' MAE.
VAL_FRACTION = 0.1
EPOCHS = 300
BATCH_SIZE = 64
PATIENCE = 30

early_stop = callbacks.EarlyStopping(monitor='val_loss', patience=PATIENCE, restore_best_weights=True)

history = model.fit(
    X_train_scaled, y_train_scaled,
    validation_split=VAL_FRACTION,
    epochs=EPOCHS,
    batch_size=BATCH_SIZE,
    callbacks=[early_stop],
    verbose=1,
)

## 4. Evaluation & Visualization

In [None]:
# Persist the trained model first so the expensive training result survives
# any failure in the analysis / plotting code below.
MODEL_DIR = '../models'
os.makedirs(MODEL_DIR, exist_ok=True)
model.save(os.path.join(MODEL_DIR, 'mars_geocentric_v3_resnet.keras'))

# Predict scaled residuals on the held-out period and undo the target scaling.
pred_scaled = model.predict(X_test_scaled)
pred_res = y_scaler.inverse_transform(pred_scaled)

# Reconstruct positions: Keplerian baseline + learned residual correction.
test_df = df.iloc[split_index:].copy()
assert len(test_df) == len(pred_res), "Prediction count does not match the test rows"
for axis_idx, axis in enumerate(['X', 'Y', 'Z']):
    test_df[f'Pred_{axis}'] = test_df[f'Kepler_{axis}'] + pred_res[:, axis_idx]

# Per-row 3-D Euclidean position error. "MAE" below is the mean of these
# distances (mean absolute position error), not a per-component MAE.
diff = test_df[['Pred_X', 'Pred_Y', 'Pred_Z']].to_numpy() - test_df[['X_au', 'Y_au', 'Z_au']].to_numpy()
position_error = np.linalg.norm(diff, axis=1)
mae = position_error.mean()
print(f"V3 ResNet MAE: {mae:.6f} AU")

# Visualization: a reproducible random subset of the test period, capped at
# the test-set size (sample(n) raises if n > len). Re-sorted by index so the
# 'True' line trace follows row order (assumed to be time order).
VIZ_POINTS = 2000
viz_df = test_df.sample(n=min(VIZ_POINTS, len(test_df)), random_state=42).sort_index()

fig = go.Figure()
fig.add_trace(go.Scatter3d(x=viz_df['X_au'], y=viz_df['Y_au'], z=viz_df['Z_au'], mode='lines', name='True'))
fig.add_trace(go.Scatter3d(x=viz_df['Pred_X'], y=viz_df['Pred_Y'], z=viz_df['Pred_Z'], mode='markers', marker=dict(size=3, color='cyan'), name='V3 ResNet'))
fig.update_layout(
    title=f'V3 ResNet Orbit (MAE: {mae:.6f} AU)',
    template='plotly_dark',
    scene=dict(xaxis_title='X (AU)', yaxis_title='Y (AU)', zaxis_title='Z (AU)'),
)
fig.write_html('../v3_resnet_viz.html')
print("Visualization saved to ../v3_resnet_viz.html")