# Train a simple Sichuan regression model

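This notebook loads the prepared Sichuan sample data (GDP + night lights)
and fits a random-forest regressor on the night-light (`ntl`) column alone,
as a quick check of whether the night-light signal explains any of the
variation in local GDP. Test-set R^2 and MAE are saved to `results/metrics.json`,
together with a true-vs-predicted scatter plot.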


In [None]:
# basic imports
import json
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import r2_score, mean_absolute_error
from sklearn.model_selection import train_test_split

In [None]:
# Folder where the prepare notebook saved sample_points.csv.
# Set the SICHUAN_DATA_DIR environment variable to use a different location
# without editing the notebook; otherwise the original Colab/Drive path is used.
DATA_DIR = os.environ.get(
    'SICHUAN_DATA_DIR', '/content/drive/MyDrive/sichuan_data/processed'
)
CSV_PATH = os.path.join(DATA_DIR, 'sample_points.csv')

# Fail early with a hint on how to fix it, instead of a bare read_csv traceback.
if not os.path.isfile(CSV_PATH):
    raise FileNotFoundError(
        f'{CSV_PATH} not found - run the prepare notebook first, or set '
        'SICHUAN_DATA_DIR (or edit DATA_DIR) to point at its output folder.'
    )

df = pd.read_csv(CSV_PATH)
df.head()

In [None]:
# Basic cleaning: keep only rows that have both the target and the feature(s).
# The result goes into df_clean so `df` still holds the raw data shown above,
# and re-running this cell always starts from the same input.
TARGET_COL = 'gdp'
FEATURE_COLS = ['ntl']  # night lights only for now; add more feature columns here

missing_cols = {TARGET_COL, *FEATURE_COLS} - set(df.columns)
assert not missing_cols, f'sample_points.csv is missing columns: {sorted(missing_cols)}'

df_clean = df.dropna(subset=[TARGET_COL, *FEATURE_COLS])
assert not df_clean.empty, 'no rows left after dropping missing gdp/ntl values'

X = df_clean[FEATURE_COLS].values
y = df_clean[TARGET_COL].values

# Hold out 20% of rows for evaluation; fixed seed so the split is reproducible.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)

len(X_train), len(X_test)

In [None]:
# Train a 200-tree random forest (fixed seed, all CPU cores) on the training
# split; fit() returns the estimator itself, so construction and fitting chain.
rf = RandomForestRegressor(n_estimators=200, random_state=42, n_jobs=-1).fit(
    X_train, y_train
)

# Score on the held-out split: R^2 is the share of test-set GDP variance
# explained; MAE is the mean absolute error in the same units as the gdp column.
y_pred = rf.predict(X_test)
r2 = r2_score(y_true=y_test, y_pred=y_pred)
mae = mean_absolute_error(y_true=y_test, y_pred=y_pred)

print('R^2:', r2)
print('MAE:', mae)

In [None]:
# Save the headline metrics to results/metrics.json, next to the processed/ folder.
# normpath collapses the '..' so the path reported below is the real location.
results_dir = os.path.normpath(os.path.join(DATA_DIR, '..', 'results'))
os.makedirs(results_dir, exist_ok=True)

# Cast to built-in Python types so json can serialise numpy scalars.
metrics = {
    'r2': float(r2),
    'mae': float(mae),
    'n_train': int(len(X_train)),
    'n_test': int(len(X_test)),
}

metrics_path = os.path.join(results_dir, 'metrics.json')
with open(metrics_path, 'w', encoding='utf-8') as f:
    json.dump(metrics, f, indent=2)

metrics_path

In [None]:
# True vs predicted GDP on the test split. Points on the dashed y = x line are
# perfect predictions; systematic offset from it shows over/under-prediction.
fig, ax = plt.subplots(figsize=(5, 5))
ax.scatter(y_test, y_pred, alpha=0.6, label='test points')

# Reference diagonal spanning the combined range of true and predicted values.
lo = min(y_test.min(), y_pred.min())
hi = max(y_test.max(), y_pred.max())
ax.plot([lo, hi], [lo, hi], linestyle='--', color='grey', linewidth=1, label='y = x')

ax.set(
    xlabel='True GDP',
    ylabel='Predicted GDP',
    title='Random forest: true vs predicted GDP',
)
ax.grid(True)
ax.legend()

plot_path = os.path.join(results_dir, 'scatter_true_vs_pred.png')
fig.savefig(plot_path, dpi=150, bbox_inches='tight')
plot_path