In [1]:
import os
import re
import numpy as np
import pandas as pd
import geopandas as gpd
import rasterio
from rasterio.mask import mask
import matplotlib.pyplot as plt
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import mean_squared_error, r2_score
from torch.utils.data import TensorDataset, DataLoader
import torch
import torch.nn as nn
import torch.optim as optim
import logging

In [2]:
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Set global font to Times New Roman
plt.rcParams["font.family"] = "Times New Roman"

# Constant definitions
PROJECTED_CRS = "EPSG:32651"  # UTM Zone 51N

# File paths
lst_dir = "shanghai_LST/LST"
boundary_path = "OSMB-a05454319b28f099ff3da3a49c5dd21e484d8b2d.geojson"
output_dir = "Results"
processed_dir = "processed_data"

In [3]:
# Ensure output directory and processed data directory exist
os.makedirs(output_dir, exist_ok=True)
os.makedirs(processed_dir, exist_ok=True)

# Check file existence
def check_file_exists(file_path):
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"File {file_path} does not exist")
    return file_path

In [4]:
# Load Shanghai boundary
boundary_save_path = os.path.join(processed_dir, "shanghai_boundary.geojson")
if os.path.exists(boundary_save_path):
    gdf_shanghai = gpd.read_file(boundary_save_path)
else:
    gdf_shanghai = gpd.read_file(check_file_exists(boundary_path))
    gdf_shanghai.to_file(boundary_save_path, driver="GeoJSON")

In [5]:
# Read and clip raster data
def read_raster(file_path, boundary):
    """Read and clip raster data, return data and metadata"""
    with rasterio.open(file_path) as src:
        out_image, out_transform = mask(src, [boundary], crop=True)
        out_meta = src.meta.copy()
        out_meta.update({'height': out_image.shape[1], 'width': out_image.shape[2], 'transform': out_transform})
        out_image = out_image.astype(np.float32)
        out_image[out_image == src.nodata] = np.nan
    return out_image[0], out_meta

In [6]:
# Data loading and preprocessing with flexible filename parsing
lst_data = []
skipped_files = 0
for year in range(2000, 2025):
    year_dir = os.path.join(lst_dir, str(year))
    if os.path.exists(year_dir):
        for filename in os.listdir(year_dir):
            if filename.endswith(".tif"):
                # Use regex to match both "LST_DayYYYY_MM.tif", "LST_DayYYYY-MM.tif", "LST_DayYYYY-M.tif", and case variations
                match = re.match(r"LST_Day(\d{4})[-_](\d{1,2})\.tif$", filename, re.IGNORECASE)
                if match:
                    year = int(match.group(1))
                    month = int(match.group(2))
                    if month < 1 or month > 12:
                        logger.warning(f"Invalid month value {month} in file {filename}, skipping")
                        continue
                    file_path = os.path.join(year_dir, filename)
                    lst, _ = read_raster(file_path, gdf_shanghai.unary_union)
                    lst_mean = np.nanmean(lst)
                    if np.isnan(lst_mean):
                        logger.warning(f"LST data in file {filename} is entirely NaN, skipping")
                        skipped_files += 1
                        continue
                    lst_data.append((year, month, lst_mean))
                    logger.info(f"Loaded file: {filename}, Average LST: {lst_mean:.2f}")
                else:
                    logger.warning(f"Skipping invalid filename: {filename}")
logger.info(f"Total files skipped due to NaN: {skipped_files}")


  lst, _ = read_raster(file_path, gdf_shanghai.unary_union)
INFO:__main__:Loaded file: LST_Day2000_10.tif, Average LST: 26.24
  lst, _ = read_raster(file_path, gdf_shanghai.unary_union)
INFO:__main__:Loaded file: LST_Day2000_11.tif, Average LST: 15.98
  lst, _ = read_raster(file_path, gdf_shanghai.unary_union)
INFO:__main__:Loaded file: LST_Day2000_12.tif, Average LST: 11.81
  lst, _ = read_raster(file_path, gdf_shanghai.unary_union)
INFO:__main__:Loaded file: LST_Day2000_2.tif, Average LST: 12.52
  lst, _ = read_raster(file_path, gdf_shanghai.unary_union)
INFO:__main__:Loaded file: LST_Day2000_3.tif, Average LST: 16.33
  lst, _ = read_raster(file_path, gdf_shanghai.unary_union)
INFO:__main__:Loaded file: LST_Day2000_4.tif, Average LST: 20.53
  lst, _ = read_raster(file_path, gdf_shanghai.unary_union)
INFO:__main__:Loaded file: LST_Day2000_5.tif, Average LST: 28.75
  lst, _ = read_raster(file_path, gdf_shanghai.unary_union)
INFO:__main__:Loaded file: LST_Day2000_6.tif, Average LST: 29.

In [7]:
# Create time series data
ts_data = pd.DataFrame(lst_data, columns=['year', 'month', 'lst'])
ts_data['datetime'] = pd.to_datetime(ts_data[['year', 'month']].assign(day=1))
initial_count = len(ts_data)
ts_data = ts_data.sort_values('datetime').drop_duplicates(subset='datetime', keep='first')
logger.info(f"Removed {initial_count - len(ts_data)} duplicate timestamps")

INFO:__main__:Removed 0 duplicate timestamps


In [8]:
# Fill missing months
all_dates = pd.date_range(start='2000-01-01', end='2024-12-01', freq='MS')
ts_data_full = pd.DataFrame(all_dates, columns=['datetime'])
ts_data_full = ts_data_full.merge(ts_data, on='datetime', how='left')
missing_before = ts_data_full['lst'].isna().sum()
ts_data_full['lst'] = ts_data_full['lst'].interpolate(method='linear')
if ts_data_full['lst'].isna().any():
    logger.warning("NaN values remain after interpolation, filling with nearest values")
    ts_data_full['lst'] = ts_data_full['lst'].fillna(method='ffill').fillna(method='bfill')
logger.info(f"Interpolated {missing_before} missing values")
ts_data = ts_data_full.dropna(subset=['lst']).drop(columns=['year', 'month'])
logger.info(f"Number of data points after interpolation: {len(ts_data)}")


  ts_data_full['lst'] = ts_data_full['lst'].fillna(method='ffill').fillna(method='bfill')
INFO:__main__:Interpolated 1 missing values
INFO:__main__:Number of data points after interpolation: 300


In [9]:
# Standardize data
scaler = StandardScaler()
lst_scaled = scaler.fit_transform(ts_data['lst'].values.reshape(-1, 1)).flatten()

In [10]:
# Create sequences
def create_sequences(data, seq_length):
    if len(data) < seq_length:
        raise ValueError(f"Data length ({len(data)}) is less than sequence length ({seq_length})")
    X, y = [], []
    for i in range(len(data) - seq_length):
        X.append(data[i:i + seq_length])
        y.append(data[i + seq_length])
    return np.array(X), np.array(y)

In [11]:
seq_length = 12
X_seq, y_seq = create_sequences(lst_scaled, seq_length)
X_seq = X_seq.reshape((X_seq.shape[0], X_seq.shape[1], 1))
X_seq, y_seq = torch.FloatTensor(X_seq), torch.FloatTensor(y_seq)
logger.info(f"Number of training sequences generated: {len(X_seq)}")

INFO:__main__:Number of training sequences generated: 288


In [12]:
# Split training, validation, and test sets
train_size = int(0.7 * len(X_seq))
val_size = int(0.15 * len(X_seq))
X_train_ts, X_val_ts, X_test_ts = X_seq[:train_size], X_seq[train_size:train_size+val_size], X_seq[train_size+val_size:]
y_train_ts, y_val_ts, y_test_ts = y_seq[:train_size], y_seq[train_size:train_size+val_size], y_seq[train_size+val_size:]
logger.info(f"Training set size: {len(X_train_ts)}, Validation set size: {len(X_val_ts)}, Test set size: {len(X_test_ts)}")


INFO:__main__:Training set size: 201, Validation set size: 43, Test set size: 44


In [13]:
# Define LSTM model
class LSTMModel(nn.Module):
    def __init__(self, input_size=1, hidden_size=50, num_layers=1, dropout=0.2):
        super(LSTMModel, self).__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True, dropout=dropout if num_layers > 1 else 0)
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(hidden_size, 1)

    def forward(self, x):
        out, _ = self.lstm(x)
        out = self.dropout(out[:, -1, :])
        out = self.fc(out)
        return out

In [15]:
# Train model
model = LSTMModel()
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

train_dataset = TensorDataset(X_train_ts, y_train_ts)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=False)

num_epochs = 500
model.train()
for epoch in range(num_epochs):
    for batch_X, batch_y in train_loader:
        outputs = model(batch_X)
        loss = criterion(outputs.squeeze(), batch_y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # Validation step
    model.eval()
    with torch.no_grad():
        val_outputs = model(X_val_ts)
        val_loss = criterion(val_outputs.squeeze(), y_val_ts)
    model.train()
    if (epoch + 1) % 20 == 0:
        print(f'Epoch [{epoch + 1}/{num_epochs}], Train Loss: {loss.item():.4f}, Val Loss: {val_loss.item():.4f}')


Epoch [20/500], Train Loss: 0.0897, Val Loss: 0.0448
Epoch [40/500], Train Loss: 0.0584, Val Loss: 0.0439
Epoch [60/500], Train Loss: 0.0520, Val Loss: 0.0376
Epoch [80/500], Train Loss: 0.1088, Val Loss: 0.0446
Epoch [100/500], Train Loss: 0.0700, Val Loss: 0.0450
Epoch [120/500], Train Loss: 0.0355, Val Loss: 0.0413
Epoch [140/500], Train Loss: 0.1075, Val Loss: 0.0292
Epoch [160/500], Train Loss: 0.0247, Val Loss: 0.0364
Epoch [180/500], Train Loss: 0.0342, Val Loss: 0.0364
Epoch [200/500], Train Loss: 0.0510, Val Loss: 0.0326
Epoch [220/500], Train Loss: 0.0269, Val Loss: 0.0371
Epoch [240/500], Train Loss: 0.0530, Val Loss: 0.0341
Epoch [260/500], Train Loss: 0.0515, Val Loss: 0.0323
Epoch [280/500], Train Loss: 0.0243, Val Loss: 0.0282
Epoch [300/500], Train Loss: 0.0123, Val Loss: 0.0261
Epoch [320/500], Train Loss: 0.0276, Val Loss: 0.0293
Epoch [340/500], Train Loss: 0.0356, Val Loss: 0.0293
Epoch [360/500], Train Loss: 0.0332, Val Loss: 0.0253
Epoch [380/500], Train Loss: 0.0

In [16]:
# Evaluate model
model.eval()
with torch.no_grad():
    test_pred = model(X_test_ts).squeeze().numpy()
    test_pred_rescaled = scaler.inverse_transform(test_pred.reshape(-1, 1)).flatten()
    y_test_rescaled = scaler.inverse_transform(y_test_ts.numpy().reshape(-1, 1)).flatten()
    test_rmse = np.sqrt(mean_squared_error(y_test_rescaled, test_pred_rescaled))
    test_r2 = r2_score(y_test_rescaled, test_pred_rescaled)
    print(f'LSTM Test RMSE (after inverse scaling): {test_rmse:.2f}, R²: {test_r2:.2f}')


LSTM Test RMSE (after inverse scaling): 1.62, R²: 0.96


In [18]:
# Future prediction
future_months = 72
last_sequence = X_seq[-1].reshape(1, seq_length, 1)
future_pred = []
model.eval()
with torch.no_grad():
    current_seq = last_sequence.clone()  # Use clone() for Tensor
    for _ in range(future_months):
        pred = model(current_seq).item()
        future_pred.append(pred)
        current_seq = torch.roll(current_seq, -1, dims=1)
        current_seq[0, -1, 0] = pred

future_pred_rescaled = scaler.inverse_transform(np.array(future_pred).reshape(-1, 1)).flatten()

# Visualization with validation comparison
# Split historical data into train, validation, and test periods for plotting
train_val_dates = ts_data['datetime'].iloc[:train_size + val_size + seq_length]
train_val_lst = ts_data['lst'].iloc[:train_size + val_size + seq_length]
test_dates = ts_data['datetime'].iloc[train_size + val_size + seq_length:]
future_dates = pd.date_range(start=test_dates.iloc[-1] + pd.offsets.MonthBegin(1), periods=future_months, freq='MS')

# Plotting
plt.figure(figsize=(12, 6))
# Plot historical data (train + validation)
plt.plot(train_val_dates, train_val_lst, label='Historical LST (Train + Val)', marker='o', color='blue')
# Plot test set (actual vs predicted)
plt.plot(test_dates, y_test_rescaled, label='Actual LST (Test)', marker='o', color='green')
plt.plot(test_dates, test_pred_rescaled, label='Predicted LST (Test)', marker='x', linestyle='--', color='orange')
# Plot future predictions
plt.plot(future_dates, future_pred_rescaled, label='Predicted LST (2025-2030)', marker='x', linestyle='--', color='red')

plt.xlabel('Time', fontsize=14)
plt.ylabel('Average LST (°C)', fontsize=14)
plt.title('Shanghai LST Time Series Forecast with Validation (2000-2030)', fontsize=16)
plt.legend()
plt.xticks(rotation=45)
plt.grid(True, linestyle='--', alpha=0.7)
plt.tight_layout()
plt.savefig(os.path.join(output_dir, 'lstm_forecast_with_validation_2000_2030.png'), dpi=300)
plt.close()