In [None]:
#Installing packages needed

!pip install netcdf4
!pip install torch_geometric
import torch
print("torch version is", torch.__version__) #torch-2.6.0+cu124

!pip install torch-sparse -f https://data.pyg.org/whl/torch-2.6.0+cu124.html
!pip install torch-scatter -f https://data.pyg.org/whl/torch-2.6.0+cu124.html

In [None]:
#importing and setting file paths
import xarray as xr
import numpy as np
import torch as t
import torch.nn as nn
import netCDF4 as nc
import matplotlib.pyplot as plt
import time
import os
import gc
import torch_sparse
import torch_geometric
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv


# Make the project data reachable.
# On Colab, Google Drive is mounted and the project folder on Drive becomes the working
# directory. Outside Colab (e.g. a local clone) the notebook runs from the current
# directory. Set the WEATHER_PROJECT_DIR environment variable to override either.

COLAB_PROJECT_DIR = '/content/drive/MyDrive/Weather Forecasting/Aash_work'
project_dir = os.environ.get("WEATHER_PROJECT_DIR")

try:
    from google.colab import drive
except ImportError:
    drive = None  # not running on Colab

if drive is not None:
    drive.mount('/content/drive')
    project_dir = project_dir or COLAB_PROJECT_DIR
if project_dir:
    os.chdir(project_dir)
#!ls /content/drive/MyDrive

# Debugging aid: makes CUDA kernel launches synchronous so device-side errors surface at
# the call that caused them. It slows GPU execution, so remove it for real training runs.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
device = t.device('cuda' if t.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

#Add api request for data on

In [None]:
import xarray as xr
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Dense, Flatten, Dropout, BatchNormalization
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau
from tensorflow.keras.regularizers import l2
import matplotlib.pyplot as plt
import os
import psutil  # For memory usage monitoring

# Define file paths (all relative to the current working directory)
# Input NetCDF files that are merged into a single dataset
file_path_train_p = "Data/FinalData/p.nc"
file_path_train_a = "Data/FinalData/a.nc"
file_path_train_i = "Data/FinalData/i.nc"
# Destination of the merged NetCDF file
output_path = "Data/FinalData/combined_3.nc"
# CSV file the per-epoch training history is written to
history_path = "training_history.csv"
# Filename prefix shared by every saved figure (e.g. "<prefix>_spatial.png")
plot_path_prefix = "prediction_plot"

# Function to print memory usage
def print_memory_usage():
    """Print the resident set size (RSS) of the current process in megabytes."""
    rss_bytes = psutil.Process(os.getpid()).memory_info().rss
    rss_mb = rss_bytes / (1024 * 1024)
    print(f"Memory usage: {rss_mb:.2f} MB")

# Step 1: Merge .nc files with checks
def merge_nc_files(file_paths, output_path, time_subset=None):
    """Merge several NetCDF files into one dataset, write it to disk and return it.

    Each file is opened lazily (chunked along `valid_time`), its `expver` variable is
    dropped to avoid merge conflicts, and the datasets are merged with
    `compat='no_conflicts'`. Diagnostic information and warnings are printed along
    the way.

    Args:
        file_paths: Iterable of paths to the NetCDF files to merge.
        output_path: Path the merged dataset is written to.
        time_subset: If not None, keep only the first `time_subset` steps along
            `valid_time` of every input file.

    Returns:
        The merged xarray Dataset.

    Raises:
        FileNotFoundError: If an input file does not exist, or the merged file is
            not found after writing.
        ValueError: If the merge fails or the merged dataset has no 't2m' variable.
    """
    print("Memory usage before merging:")
    print_memory_usage()

    datasets = []
    for fp in file_paths:
        if not os.path.exists(fp):
            raise FileNotFoundError(f"File not found: {fp}")
        # Load dataset with chunking (use valid_time instead of time)
        ds = xr.open_dataset(fp, chunks={'valid_time': 100})
        # Print dataset info for debugging. `sizes` is the name -> length mapping;
        # using `dims` as a mapping is deprecated in xarray.
        print(f"\nDataset from {fp}:")
        print(f"Dimensions: {dict(ds.sizes)}")
        print(f"Coordinates: {list(ds.coords.keys())}")
        print(f"Variables: {list(ds.variables.keys())}")
        # Report expver, then drop it to avoid merge conflicts
        if 'expver' in ds.variables:
            print(f"{fp}: expver values = {ds['expver'].values}")
            ds = ds.drop_vars('expver')
        # Subset time dimension if specified
        if time_subset is not None:
            ds = ds.isel(valid_time=slice(0, time_subset))
        datasets.append(ds)

    # Check for consistent dimensions and coordinates (warn only; merging still proceeds)
    sizes = [dict(ds.sizes) for ds in datasets]
    if len({tuple(sorted(s.items())) for s in sizes}) > 1:
        print("Warning: Datasets have different dimensions:", sizes)
    coords = [list(ds.coords.keys()) for ds in datasets]
    if len({tuple(sorted(c)) for c in coords}) > 1:
        print("Warning: Datasets have different coordinates:", coords)

    # Merge datasets, keeping the original exception as the cause for debugging
    try:
        combined = xr.merge(datasets, compat='no_conflicts')
    except Exception as e:
        raise ValueError(f"Failed to merge datasets: {str(e)}") from e

    if 't2m' not in combined.variables:
        raise ValueError("'t2m' variable not found in combined dataset")

    # Each null check scans the whole variable, so evaluate 't2m' once and reuse it.
    t2m_has_missing = bool(combined['t2m'].isnull().any())
    if t2m_has_missing:
        print("Warning: 't2m' in combined dataset contains missing values")

    for var in combined.variables:
        has_missing = t2m_has_missing if var == 't2m' else bool(combined[var].isnull().any())
        if has_missing:
            print(f"Warning: Variable '{var}' contains missing values")

    # Save combined dataset
    combined.to_netcdf(output_path)
    print(f"Merged file saved to {output_path}")

    # Check merged file size
    if os.path.exists(output_path):
        file_size_mb = os.path.getsize(output_path) / (1024 * 1024)  # Convert bytes to MB
        print(f"Merged file size: {file_size_mb:.2f} MB")
    else:
        raise FileNotFoundError(f"Merged file not found at {output_path}")

    print("Memory usage after merging:")
    print_memory_usage()

    return combined

# Merge files with checks.
# time_subset=None keeps every time step; besides returning the merged dataset,
# merge_nc_files also writes it to output_path.
file_paths = [file_path_train_p, file_path_train_a, file_path_train_i]
combined_ds = merge_nc_files(file_paths, output_path, time_subset=None)

# Step 2: Prepare data for CNN with checks
def prepare_data(dataset):
    """Turn the merged dataset into normalized next-step training and test arrays.

    The inputs are t2m plus whichever of 'tp', 'msl', 'u10', 'v10' are present, have
    the same shape as t2m, are finite and are not constant. Each variable is min-max
    scaled to [0, 1] and stacked as one channel. The target is t2m one time step
    ahead. Samples are split in time order: the first 80% for training, the rest for
    testing.

    Note: the scaling statistics are computed over the full time range, so they
    include the test period.

    Args:
        dataset: xarray Dataset containing a 3D 't2m' variable
            (valid_time, latitude, longitude).

    Returns:
        Tuple `(X_train, X_test, y_train, y_test, temp_min, temp_max)`. X arrays have
        shape (samples, latitude, longitude, channels), y arrays
        (samples, latitude, longitude). `temp_min` / `temp_max` are the t2m values
        used for scaling and are needed to denormalize predictions.

    Raises:
        ValueError: If 't2m' is missing, has a pressure_level dimension, is not 3D,
            contains NaN/inf, has fewer than 2 time steps, has no variation, or the
            resulting training set is empty.
    """
    print("Memory usage before loading data:")
    print_memory_usage()

    # Print dataset info for debugging
    print("\nCombined dataset info:")
    print(f"Dimensions: {dataset.dims}")
    print(f"Coordinates: {list(dataset.coords.keys())}")
    print(f"Variables: {list(dataset.variables.keys())}")

    if 't2m' not in dataset.variables:
        raise ValueError("'t2m' variable not found in combined dataset")

    # Handle pressure_level dimension if present
    if 'pressure_level' in dataset.dims:
        print(f"Pressure levels found: {dataset['pressure_level'].values}")
        # Since t2m is a surface variable, it shouldn't have pressure levels
        t2m_dims = dataset['t2m'].dims
        print(f"t2m dimensions: {t2m_dims}")
        if 'pressure_level' in t2m_dims:
            raise ValueError("t2m should not have a pressure_level dimension; check dataset structure")

    # Use chunks to reduce memory usage (use valid_time instead of time)
    chunk_size = {'valid_time': 100}  # Adjust based on your dataset size
    dataset = dataset.chunk(chunk_size)

    # Extract temperature; `.values` loads the full array into memory
    temp = dataset['t2m'].values  # Shape: (valid_time, latitude, longitude)

    if len(temp.shape) != 3:
        raise ValueError(f"Expected 3D temperature array (valid_time, latitude, longitude), got shape {temp.shape}")

    # isfinite is False for both NaN and +/-inf, so one pass covers both checks
    if not np.isfinite(temp).all():
        raise ValueError("Temperature data contains NaN or infinite values")

    if temp.shape[0] < 2:
        raise ValueError("Not enough time steps for training (need at least 2)")

    print(f"Dataset size: {temp.shape[0]} time steps, {temp.shape[1]} lat, {temp.shape[2]} lon")

    # Normalize temperature (scale to 0-1)
    temp_min, temp_max = np.min(temp), np.max(temp)
    if temp_max == temp_min:
        raise ValueError("Temperature data has no variation (min equals max)")
    temp_normalized = (temp - temp_min) / (temp_max - temp_min)

    # Verify normalization
    if not (np.min(temp_normalized) >= 0 and np.max(temp_normalized) <= 1):
        raise ValueError("Normalization failed: values outside [0, 1]")

    # Prepare additional variables (e.g., 'tp', 'msl', 'u10', 'v10')
    input_channels = [temp_normalized[..., np.newaxis]]  # Start with t2m
    additional_vars = ['tp', 'msl', 'u10', 'v10']
    for var in additional_vars:
        if var not in dataset.variables:
            continue
        var_array = dataset[var]
        # Cheap metadata checks come first so a variable that will be skipped is
        # never loaded into memory.
        if 'pressure_level' in var_array.dims:
            print(f"Skipping variable '{var}' due to pressure_level dimension")
            continue
        if var_array.shape != temp.shape:
            print(f"Warning: Variable '{var}' shape {var_array.shape} does not match t2m shape {temp.shape}")
            continue
        var_data = var_array.values
        if not np.isfinite(var_data).all():
            print(f"Warning: Variable '{var}' contains NaN or infinite values")
            continue
        var_min, var_max = np.min(var_data), np.max(var_data)
        if var_max == var_min:
            print(f"Warning: Variable '{var}' has no variation")
            continue
        var_normalized = (var_data - var_min) / (var_max - var_min)
        input_channels.append(var_normalized[..., np.newaxis])

    # Stack input channels
    X = np.concatenate(input_channels, axis=-1)  # Shape: (valid_time, latitude, longitude, channels)

    print(f"Input shape (with {X.shape[-1]} channels): {X.shape}")

    # Create input-output pairs (predict next time step)
    X = X[:-1]  # All but last time step
    y = temp_normalized[1:]  # t2m only, all but first time step

    if X.shape[0] != y.shape[0]:
        raise ValueError(f"Mismatch in X and y samples: {X.shape[0]} vs {y.shape[0]}")

    # Split into train and test (80-20 split, in time order)
    train_size = int(0.8 * len(X))
    if train_size == 0:
        raise ValueError("Training set is empty after split")

    X_train, X_test = X[:train_size], X[train_size:]
    y_train, y_test = y[:train_size], y[train_size:]

    # Verify train/test shapes
    print(f"X_train shape: {X_train.shape}")
    print(f"X_test shape: {X_test.shape}")
    print(f"y_train shape: {y_train.shape}")
    print(f"y_test shape: {y_test.shape}")

    print("Memory usage after loading data:")
    print_memory_usage()

    return X_train, X_test, y_train, y_test, temp_min, temp_max

# Prepare data with checks.
# temp_min / temp_max are the t2m scaling bounds; they are kept so predictions can be
# converted back to the original units later.
X_train, X_test, y_train, y_test, temp_min, temp_max = prepare_data(combined_ds)

# Step 3: Build CNN model with two additional hidden layers
def build_cnn(input_shape, output_units=None):
    """Build and compile the CNN that maps one input grid to a flattened t2m grid.

    The network has three Conv2D + BatchNormalization stages (the first two followed
    by 2x2 max pooling), then three L2-regularized dense layers with dropout after
    the first two, and a linear output layer. It is compiled with Adam
    (learning rate 0.001), MSE loss and an MAE metric.

    Args:
        input_shape: Shape of one input sample, (latitude, longitude, channels).
        output_units: Length of the flat output vector. Defaults to
            latitude * longitude from `input_shape`, i.e. one value per grid cell,
            so the function does not depend on a global training array.

    Returns:
        The compiled Keras `Sequential` model.
    """
    if output_units is None:
        output_units = np.prod(input_shape[:-1])
    output_units = int(output_units)  # Dense expects a plain Python int

    model = Sequential([
        # Explicit Input layer instead of passing `input_shape=` to the first Conv2D,
        # which Keras warns against for Sequential models.
        tf.keras.Input(shape=input_shape),
        Conv2D(32, (3, 3), activation='relu', padding='same'),
        BatchNormalization(),
        MaxPooling2D((2, 2)),
        Conv2D(64, (3, 3), activation='relu', padding='same'),
        BatchNormalization(),
        MaxPooling2D((2, 2)),
        Conv2D(128, (3, 3), activation='relu', padding='same'),
        BatchNormalization(),
        Flatten(),
        Dense(256, activation='relu', kernel_regularizer=l2(0.005)),
        Dropout(0.4),
        Dense(128, activation='relu', kernel_regularizer=l2(0.005)),
        Dropout(0.4),
        Dense(64, activation='relu', kernel_regularizer=l2(0.005)),
        Dense(output_units, activation='linear')
    ])

    optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
    model.compile(optimizer=optimizer, loss='mse', metrics=['mae'])
    return model

# Build and train model
input_shape = X_train.shape[1:]  # (lat, lon, channels)
model = build_cnn(input_shape)

# Define callbacks: stop when val_loss has not improved for 15 epochs (restoring the best
# weights), and halve the learning rate when it has not improved for 5 epochs.
early_stopping = EarlyStopping(monitor='val_loss', patience=15, restore_best_weights=True)
lr_scheduler = ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=5)

# Train model. Targets are flattened from (samples, lat, lon) to (samples, lat*lon) to
# match the model's dense output layer.
# NOTE(review): the test split is passed as validation data, so it drives early stopping,
# best-weight selection and the learning-rate schedule. Metrics later reported on
# X_test / y_test are therefore not from data the training procedure never saw.
history = model.fit(
    X_train, y_train.reshape(-1, np.prod(y_train.shape[1:])),
    epochs=50,
    batch_size=32,
    validation_data=(X_test, y_test.reshape(-1, np.prod(y_test.shape[1:]))),
    callbacks=[early_stopping, lr_scheduler],
    verbose=1
)

# Step 4: Save model and history
# The model is written to the current working directory.
model.save('temperature_cnn_model.keras')
print("Model saved to temperature_cnn_model.keras")

# Save training history: one row per epoch, one column per entry in history.history
history_df = pd.DataFrame(history.history)
history_df.to_csv(history_path, index=False)
print(f"Training history saved to {history_path}")

# Step 5: Predict and denormalize
def predict_temperature(model, X, temp_min, temp_max):
    """Run the model on X and return predictions in the original temperature units.

    Args:
        model: Object with a `predict(X)` method returning normalized predictions
            that can be reshaped to `X.shape[:-1]`.
        X: Input array whose last axis is the channel axis.
        temp_min: Lower bound used when the targets were min-max normalized.
        temp_max: Upper bound used when the targets were min-max normalized.

    Returns:
        Array of shape `X.shape[:-1]` holding the denormalized predictions.
    """
    grid_shape = X.shape[:-1]  # input shape without the channel axis
    # Keep predictions inside the [0, 1] range the targets were normalized to
    clipped = np.clip(model.predict(X), 0, 1)
    return clipped.reshape(grid_shape) * (temp_max - temp_min) + temp_min

# Get predictions for the test inputs, already converted back to the original t2m units
y_pred = predict_temperature(model, X_test, temp_min, temp_max)
# Undo the min-max scaling on the targets so both arrays are in the same units
y_true = y_test * (temp_max - temp_min) + temp_min  # Denormalize true values

# Step 6: Plot predictions
# Plot 1: predicted and actual t2m maps, side by side, for one test time step
sample_idx = 0
fig, (ax_pred, ax_true) = plt.subplots(1, 2, figsize=(12, 5))

ax_pred.set_title("Predicted t2m")
im_pred = ax_pred.imshow(y_pred[sample_idx], cmap='coolwarm')
fig.colorbar(im_pred, ax=ax_pred, label='Temperature (K)')

ax_true.set_title("Actual t2m")
im_true = ax_true.imshow(y_true[sample_idx], cmap='coolwarm')
fig.colorbar(im_true, ax=ax_true, label='Temperature (K)')

fig.tight_layout()
fig.savefig(f"{plot_path_prefix}_spatial.png")
plt.close(fig)  # the figure is only written to disk, not shown inline
print(f"Spatial plot saved to {plot_path_prefix}_spatial.png")

# Plot 2: spatially averaged t2m across the test time steps
mean_pred = y_pred.mean(axis=(1, 2))
mean_true = y_true.mean(axis=(1, 2))
fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(mean_pred, label='Predicted Mean t2m', color='blue')
ax.plot(mean_true, label='Actual Mean t2m', color='red')
ax.set_title("Mean t2m Over Time (Test Set)")
ax.set_xlabel("Time Step")
ax.set_ylabel("Temperature (K)")
ax.legend()
ax.grid(True)
fig.savefig(f"{plot_path_prefix}_timeseries.png")
plt.close(fig)
print(f"Time series plot saved to {plot_path_prefix}_timeseries.png")

# Plot 3: signed error map (predicted minus actual) for the sample test time step,
# with the colour scale fixed to +/-5 K
fig, ax = plt.subplots(figsize=(8, 5))
ax.set_title("Prediction Error (Predicted - Actual t2m)")
im_err = ax.imshow(y_pred[sample_idx] - y_true[sample_idx], cmap='coolwarm', vmin=-5, vmax=5)
fig.colorbar(im_err, ax=ax, label='Error (K)')
fig.tight_layout()
fig.savefig(f"{plot_path_prefix}_difference.png")
plt.close(fig)
print(f"Difference plot saved to {plot_path_prefix}_difference.png")

# Plot 4: predicted vs actual spatial-mean t2m, with the y = x reference line
fig, ax = plt.subplots(figsize=(8, 8))
ax.scatter(mean_true, mean_pred, alpha=0.5)
diagonal = [mean_true.min(), mean_true.max()]
ax.plot(diagonal, diagonal, 'k--')
ax.set_title("Predicted vs Actual Mean t2m")
ax.set_xlabel("Actual Mean t2m (K)")
ax.set_ylabel("Predicted Mean t2m (K)")
ax.grid(True)
fig.savefig(f"{plot_path_prefix}_scatter.png")
plt.close(fig)
print(f"Scatter plot saved to {plot_path_prefix}_scatter.png")

# Step 7: Comparison of Predictions and Actual Values
# Error metrics over every test time step and grid cell, in denormalized units
mae = np.abs(y_pred - y_true).mean()
rmse = np.sqrt(np.square(y_pred - y_true).mean())
mbe = (y_pred - y_true).mean()  # positive when the model over-predicts on average

print("\nPrediction Metrics:")
print(f"Mean Absolute Error (MAE): {mae:.4f} K")
print(f"Root Mean Squared Error (RMSE): {rmse:.4f} K")
print(f"Mean Bias Error (MBE): {mbe:.4f} K")

# Plot 5: training and validation loss per epoch, from the saved history
fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(history_df['loss'], label='Training Loss')
ax.plot(history_df['val_loss'], label='Validation Loss')
ax.set_title("Training and Validation Loss")
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss (MSE)")
ax.legend()
ax.grid(True)
fig.savefig(f"{plot_path_prefix}_loss_curve.png")
plt.close(fig)
print(f"Loss curve plot saved to {plot_path_prefix}_loss_curve.png")

# Plot 6: distribution of per-cell prediction errors over the whole test set
errors = (y_pred - y_true).flatten()
fig, ax = plt.subplots(figsize=(10, 5))
ax.hist(errors, bins=50, edgecolor='black')
ax.set_title("Histogram of Prediction Errors")
ax.set_xlabel("Error (K)")
ax.set_ylabel("Frequency")
ax.grid(True)
fig.savefig(f"{plot_path_prefix}_error_histogram.png")
plt.close(fig)
print(f"Error histogram plot saved to {plot_path_prefix}_error_histogram.png")

# Optional: Evaluate model
# Targets are flattened to (samples, lat*lon) as in training. They are still min-max
# normalized here, so the reported loss and MAE are on the normalized scale, not in kelvin.
test_loss, test_mae = model.evaluate(X_test, y_test.reshape(-1, np.prod(y_test.shape[1:])))
print(f"Test MAE (normalized scale): {test_mae:.4f}")
print("Sample prediction shape:", y_pred.shape)

Memory usage before merging:
Memory usage: 1296.32 MB


  ds = xr.open_dataset(fp, chunks={'valid_time': 100})



Dataset from Data/FinalData/p.nc:
Coordinates: ['number', 'valid_time', 'pressure_level', 'latitude', 'longitude', 'expver']
Variables: ['number', 'valid_time', 'pressure_level', 'latitude', 'longitude', 'expver', 'q', 't']
Data/FinalData/p.nc: expver values = ['0001' '0001' '0001' ... '0005' '0005' '0005']

Dataset from Data/FinalData/a.nc:
Coordinates: ['number', 'valid_time', 'latitude', 'longitude', 'expver']
Variables: ['number', 'valid_time', 'latitude', 'longitude', 'expver', 'tp', 'slhf', 'sshf', 'ssrd', 'strd']
Data/FinalData/a.nc: expver values = ['0001' '0001' '0001' ... '0005' '0005' '0005']

Dataset from Data/FinalData/i.nc:
Coordinates: ['number', 'valid_time', 'latitude', 'longitude', 'expver']
Variables: ['number', 'valid_time', 'latitude', 'longitude', 'expver', 'u10', 'v10', 'd2m', 't2m', 'sp', 'tcc', 'stl1', 'blh']
Data/FinalData/i.nc: expver values = ['0001' '0001' '0001' ... '0005' '0005' '0005']
    number          int64 8B ...
  * valid_time      (valid_time) da

  ds = xr.open_dataset(fp, chunks={'valid_time': 100})
  ds = xr.open_dataset(fp, chunks={'valid_time': 100})
  dims_set = set(tuple(sorted(d.items())) for d in dims)


Merged file saved to Data/FinalData/combined_3.nc
Merged file size: 1331.78 MB
Memory usage after merging:
Memory usage: 2082.54 MB
Memory usage before loading data:
Memory usage: 2082.54 MB

Combined dataset info:
Coordinates: ['number', 'valid_time', 'pressure_level', 'latitude', 'longitude']
Variables: ['number', 'valid_time', 'pressure_level', 'latitude', 'longitude', 'q', 't', 'tp', 'slhf', 'sshf', 'ssrd', 'strd', 'u10', 'v10', 'd2m', 't2m', 'sp', 'tcc', 'stl1', 'blh']
Pressure levels found: [850.]
t2m dimensions: ('valid_time', 'latitude', 'longitude')
Dataset size: 2160 time steps, 101 lat, 237 lon
Input shape (with 4 channels): (2160, 101, 237, 4)
X_train shape: (1727, 101, 237, 4)
X_test shape: (432, 101, 237, 4)
y_train shape: (1727, 101, 237)
y_test shape: (432, 101, 237)
Memory usage after loading data:
Memory usage: 4064.77 MB


  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


Epoch 1/50
[1m54/54[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m20s[0m 205ms/step - loss: 2.2588 - mae: 0.3468 - val_loss: 1.0223 - val_mae: 0.5188 - learning_rate: 0.0010
Epoch 2/50
[1m54/54[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 51ms/step - loss: 0.6660 - mae: 0.1453 - val_loss: 0.6569 - val_mae: 0.5202 - learning_rate: 0.0010
Epoch 3/50
[1m54/54[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 51ms/step - loss: 0.3483 - mae: 0.1094 - val_loss: 0.5147 - val_mae: 0.5287 - learning_rate: 0.0010
Epoch 4/50
[1m54/54[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 51ms/step - loss: 0.2171 - mae: 0.1086 - val_loss: 0.4349 - val_mae: 0.5337 - learning_rate: 0.0010
Epoch 5/50
[1m54/54[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 51ms/step - loss: 0.1387 - mae: 0.0964 - val_loss: 0.3607 - val_mae: 0.5134 - learning_rate: 0.0010
Epoch 6/50
[1m54/54[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 51ms/step - loss: 0.0929 - mae: 0.0879 - val_los