In [None]:
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, models
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split

# ---------------------------------
# 1. Generate Synthetic HR Climate Data
# ---------------------------------
np.random.seed(42)  # fixed seed so the synthetic fields are reproducible
n_samples = 1000
hr_size = 20  # High resolution: 20x20 grid
lr_size = hr_size // 2  # 2x coarser grid used to synthesise the LR inputs

# Synthetic HR "temperature" fields: uniform noise in [0, 1), laid out as
# (samples, height, width, channels).
HR_data = np.random.rand(n_samples, hr_size, hr_size, 1)

# Build the LR inputs: bicubic-downsample by 2x, then bicubic-upsample back to
# the HR grid, because SRCNN takes an interpolated image of the target size.
bicubic = tf.image.ResizeMethod.BICUBIC
LR_data = tf.image.resize(
    tf.image.resize(HR_data, size=(lr_size, lr_size), method=bicubic),
    size=(hr_size, hr_size),
    method=bicubic,
)

# ---------------------------------
# 2. Train/Test Split
# ---------------------------------
# Hold out 20% of the (LR input, HR target) pairs for testing; the fixed
# random_state makes the shuffle reproducible.
lr_array = LR_data.numpy()  # eager tensor -> ndarray for scikit-learn
X_train, X_test, y_train, y_test = train_test_split(
    lr_array,
    HR_data,
    random_state=42,
    test_size=0.2,
)

# ---------------------------------
# 3. Define SRCNN Model
# ---------------------------------
# Three-layer SRCNN (9-1-5 kernels): patch extraction, non-linear mapping,
# reconstruction. 'same' padding keeps the output on the input grid, so the
# prediction can be compared pixel-for-pixel with the HR target.
model = models.Sequential([
    # Declare the input shape with an explicit Input layer; passing
    # `input_shape=` to a layer is discouraged in Keras 3 (it emits a warning).
    layers.Input(shape=(hr_size, hr_size, 1)),
    layers.Conv2D(64, (9, 9), activation='relu', padding='same'),
    layers.Conv2D(32, (1, 1), activation='relu', padding='same'),
    # Single linear output channel: a direct regression of the HR field.
    layers.Conv2D(1, (5, 5), activation='linear', padding='same'),
])

model.compile(optimizer='adam', loss='mse', metrics=['mae'])
model.summary()

# ---------------------------------
# 4. Train the Model
# ---------------------------------
# validation_split=0.1: Keras holds back a tenth of X_train/y_train and
# reports val_loss / val_mae on it after every epoch.
fit_config = {"epochs": 10, "batch_size": 32, "validation_split": 0.1}
history = model.fit(X_train, y_train, **fit_config)

# ---------------------------------
# 5. Evaluate & Visualize Results
# ---------------------------------
loss, mae = model.evaluate(X_test, y_test)
print(f"\nTest Loss (MSE): {loss:.4f}, MAE: {mae:.4f}")

# Reference point: how far the bicubic-interpolated inputs already are from
# the ground truth. The SRCNN only adds value if its test MSE is lower.
bicubic_mse = float(np.mean((X_test - y_test) ** 2))
print(f"Bicubic baseline MSE: {bicubic_mse:.4f}")

# Show example prediction
idx = 0
# Slicing (rather than indexing) keeps the leading batch axis: a batch of one.
predicted = model.predict(X_test[idx:idx + 1])

# X_test / y_test come from train_test_split on NumPy inputs, so they are
# plain ndarrays (no `.numpy()` method); predict() also returns an ndarray.
panels = [
    ("Low-Res Input", X_test[idx]),
    ("Super-Res Output", predicted[0]),
    ("High-Res Ground Truth", y_test[idx]),
]

# Shared colour limits so the three panels are visually comparable; otherwise
# imshow rescales each image to its own min/max independently.
vmin = min(float(np.min(image)) for _, image in panels)
vmax = max(float(np.max(image)) for _, image in panels)

plt.figure(figsize=(12, 4))
for position, (title, image) in enumerate(panels, start=1):
    plt.subplot(1, 3, position)
    plt.title(title)
    plt.imshow(np.squeeze(image), cmap='hot', vmin=vmin, vmax=vmax)
    plt.axis('off')

plt.tight_layout()
plt.show()
