In [1]:
#The code in this file is modified code from

#The Official Graphcast Repo
#https://github.com/google-deepmind/graphcast

#Graphcast: How to Get Things Done by Abhinav Kumar
#https://towardsdatascience.com/graphcast-how-to-get-things-done-f2fd5630c5fb

import os
import cdsapi
import datetime
import isodate
import math
import numpy as np
import pandas as pd
from pysolar.radiation import get_radiation_direct
from pysolar.solar import get_altitude
import pytz
import xarray
from scipy import __name__ as scipy_name

client = cdsapi.Client() # Making a connection to CDS, to fetch data.


# The fields to be fetched from the single-level source.
singlelevelfields = [
                        '10m_u_component_of_wind',
                        '10m_v_component_of_wind',
                        '2m_temperature',
                        'geopotential',
                        'land_sea_mask',
                        'mean_sea_level_pressure',
                        'toa_incident_solar_radiation',
                        'total_precipitation'
                    ]

# The fields to be fetched from the pressure-level source.
pressurelevelfields = [
                        'u_component_of_wind',
                        'v_component_of_wind',
                        'geopotential',
                        'specific_humidity',
                        'temperature',
                        'vertical_velocity'
                    ]

# The 13 pressure levels.
pressure_levels = [50, 100, 150, 200, 250, 300, 400, 500, 600, 700, 850, 925, 1000]

# Initializing other required constants.
pi = math.pi
gap = 6 # There is a gap of 6 hours between each graphcast prediction.
first_prediction = datetime.datetime(2024, 1, 1, 18, 0) # Timestamp of the first prediction.
watts_to_joules = 3600
predictions_steps = 12 # Predicting for 4 timestamps.
lat_range = range(-180, 181, 1) # Latitude range.
lon_range = range(0, 360, 1) # Longitude range.

smallModel = False

if smallModel == False:
    spatial_resolution = '0.25/0.25'
else:
    spatial_resolution = '1.0/1.0'


# A utility function used for ease of coding.
# Converting the variable to a datetime object.
def toDatetime(dt) -> datetime.datetime:
    if isinstance(dt, datetime.date) and isinstance(dt, datetime.datetime):
        return dt
    
    elif isinstance(dt, datetime.date) and not isinstance(dt, datetime.datetime):
        return datetime.datetime.combine(dt, datetime.datetime.min.time())
    
    elif isinstance(dt, str):
        if 'T' in dt:
            return isodate.parse_datetime(dt)
        else:
            return datetime.datetime.combine(isodate.parse_date(dt), datetime.datetime.min.time())

 # Functions taken from the official Graphcast REPO  - Function from "data_utils"

In [1]:
#Graphcast Function from "data_utils"
_SEC_PER_HOUR = 3600
_HOUR_PER_DAY = 24
SEC_PER_DAY = _SEC_PER_HOUR * _HOUR_PER_DAY
_AVG_DAY_PER_YEAR = 365.24219




def get_year_progress(seconds_since_epoch: np.ndarray) -> np.ndarray:
  """Computes year progress for times in seconds.

  Args:
    seconds_since_epoch: Times in seconds since the "epoch" (the point at which
      UNIX time starts).

  Returns:
    Year progress normalized to be in the [0, 1) interval for each time point.
  """

  # Start with the pure integer division, and then float at the very end.
  # We will try to keep as much precision as possible.
  years_since_epoch = (
      seconds_since_epoch / SEC_PER_DAY / np.float64(_AVG_DAY_PER_YEAR)
  )
  # Note depending on how these ops are down, we may end up with a "weak_type"
  # which can cause issues in subtle ways, and hard to track here.
  # In any case, casting to float32 should get rid of the weak type.
  # [0, 1.) Interval.
  return np.mod(years_since_epoch, 1.0).astype(np.float32)


def get_day_progress(
    seconds_since_epoch: np.ndarray,
    longitude: np.ndarray,
) -> np.ndarray:
  """Computes day progress for times in seconds at each longitude.

  Args:
    seconds_since_epoch: 1D array of times in seconds since the 'epoch' (the
      point at which UNIX time starts).
    longitude: 1D array of longitudes at which day progress is computed.

  Returns:
    2D array of day progress values normalized to be in the [0, 1) inverval
      for each time point at each longitude.
  """

  # [0.0, 1.0) Interval.
  day_progress_greenwich = (
      np.mod(seconds_since_epoch, SEC_PER_DAY) / SEC_PER_DAY
  )

  # Offset the day progress to the longitude of each point on Earth.
  longitude_offsets = np.deg2rad(longitude) / (2 * np.pi)
  day_progress = np.mod(
      day_progress_greenwich[..., np.newaxis] + longitude_offsets, 1.0
  )
  return day_progress.astype(np.float32)

NameError: name 'np' is not defined

 # Download data from the CDSAPI
 ~WARNING ENSURE 'inputs_data.pkl'  is deleted if you have made changes to the data you wish to have downloaded

In [2]:
# File paths for saved data     
data_file = 'inputs_data.pkl' 

# Function to check for file existence and load or retrieve data
def load_or_retrieve_data():
    if os.path.exists(data_file):
        print("Loading data from file...")
        values = pd.read_pickle(data_file)
    else:
        print("File not found, retrieving data...")
        single, pressure = getSingleAndPressureValues()
        inputs = pd.merge(pressure, single, left_index=True, right_index=True, how='inner')
        inputs = integrateProgress(inputs)
        inputs = formatData(inputs)
        values = {'inputs': inputs}  # Store the DataFrame in a dictionary
        pd.to_pickle(values, data_file)  # Save to file for future runs
        print("Data saved to file.")
    return values

# Getting the single and pressure level values.
def getSingleAndPressureValues():
    client.retrieve(
        'reanalysis-era5-single-levels',
        {
            'product_type': 'reanalysis',
            'variable': singlelevelfields,
            'grid': spatial_resolution,
            'year': [2024],
            'month': [1],
            'day': [1],
            'time': ['00:00', '01:00', '02:00', '03:00', '04:00', '05:00', '06:00', '07:00', '08:00', '09:00', '10:00', '11:00', '12:00'],
            'format': 'netcdf'
        },
        'single-level.nc'
    )
    singlelevel = xarray.open_dataset('single-level.nc', engine=scipy_name).to_dataframe()
    singlelevel = singlelevel.rename(columns={col: singlelevelfields[ind] for ind, col in enumerate(singlelevel.columns.values.tolist())})
    singlelevel = singlelevel.rename(columns={'geopotential': 'geopotential_at_surface'})

    # Calculating the sum of the last 6 hours of rainfall.
    singlelevel = singlelevel.sort_index()
    singlelevel['total_precipitation_6hr'] = singlelevel.groupby(level=[0, 1])['total_precipitation'].rolling(window=6, min_periods=1).sum().reset_index(level=[0, 1], drop=True)
    singlelevel.pop('total_precipitation')

    client.retrieve(
        'reanalysis-era5-pressure-levels',
        {
            'product_type': 'reanalysis',
            'variable': pressurelevelfields,
            'grid': spatial_resolution,
            'year': [2024],
            'month': [1],
            'day': [1],
            'time': ['06:00', '12:00'],
            'pressure_level': pressure_levels,
            'format': 'netcdf'
        },
        'pressure-level.nc'
    )
    pressurelevel = xarray.open_dataset('pressure-level.nc', engine=scipy_name).to_dataframe()
    pressurelevel = pressurelevel.rename(columns={col: pressurelevelfields[ind] for ind, col in enumerate(pressurelevel.columns.values.tolist())})

    return singlelevel, pressurelevel

# Adding sin and cos of the year progress.
def addYearProgress(secs, data):
    progress = get_year_progress(secs)
    data['year_progress_sin'] = math.sin(2 * math.pi * progress)
    data['year_progress_cos'] = math.cos(2 * math.pi * progress)
    return data

# Adding sin and cos of the day progress.
def addDayProgress(secs, lon: str, data: pd.DataFrame):
    lons = data.index.get_level_values(lon).unique()
    progress: np.ndarray = get_day_progress(secs, np.array(lons))
    prxlon = {lon: prog for lon, prog in list(zip(list(lons), progress.tolist()))}
    data['day_progress_sin'] = data.index.get_level_values(lon).map(lambda x: math.sin(2 * math.pi * prxlon[x]))
    data['day_progress_cos'] = data.index.get_level_values(lon).map(lambda x: math.cos(2 * math.pi * prxlon[x]))
    return data

# Adding day and year progress.
def integrateProgress(data: pd.DataFrame):
    for dt in data.index.get_level_values('time').unique():
        seconds_since_epoch = toDatetime(dt).timestamp()
        data = addYearProgress(seconds_since_epoch, data)
        data = addDayProgress(seconds_since_epoch, 'longitude' if 'longitude' in data.index.names else 'lon', data)
    return data

# Adding batch field and renaming some others.
def formatData(data: pd.DataFrame) -> pd.DataFrame:
    data = data.rename_axis(index={'latitude': 'lat', 'longitude': 'lon'})
    if 'batch' not in data.index.names:
        data['batch'] = 0
        data = data.set_index('batch', append=True)
    return data

# Main function to run the process
if __name__ == '__main__':
    values = load_or_retrieve_data()
    print("Data processing complete and saved.")

Loading data from file...
Data processing complete and saved.


# Creating Targets

In [4]:
# Includes the packages imported and constants assigned.
# The functions created for the inputs also go here.

predictionFields = [
                        'u_component_of_wind',
                        'v_component_of_wind',
                        'geopotential',
                        'specific_humidity',
                        'temperature',
                        'vertical_velocity',
                        '10m_u_component_of_wind',
                        '10m_v_component_of_wind',
                        '2m_temperature',
                        'mean_sea_level_pressure',
                        'total_precipitation_6hr'
                    ]

# Creating an array full of nan values.
def nans(*args) -> list:
    return np.full((args), np.nan)

# Adding or subtracting time.
def deltaTime(dt, **delta) -> datetime.datetime:
    return dt + datetime.timedelta(**delta)

def getTargets(dt, data:pd.DataFrame):
    
    print("1")
    # Creating an array consisting of unique values of each index.
    lat, lon, levels, batch = sorted(data.index.get_level_values('lat').unique().tolist()), sorted(data.index.get_level_values('lon').unique().tolist()), sorted(data.index.get_level_values('level').unique().tolist()), data.index.get_level_values('batch').unique().tolist()
    time = [deltaTime(dt, hours = days * gap) for days in range(4)] #harcoded as 4
    #time = [deltaTime(dt, hours = days * gap) for days in range(predictions_steps)] #Fixed

    # Creating an empty dataset using latitude, longitude, the pressure levels and each prediction timestamp.
    target = xarray.Dataset({field: (['lat', 'lon', 'level', 'time'], nans(len(lat), len(lon), len(levels), len(time))) for field in predictionFields}, coords = {'lat': lat, 'lon': lon, 'level': levels, 'time': time, 'batch': batch})

    return target.to_dataframe()

if __name__ == '__main__':

    # The code for creating inputs will be here.

    values['targets'] = getTargets(first_prediction, values['inputs'])

1


# Creating Forcings

In [5]:
# Includes the packages imported and constants assigned.
# The functions created for the inputs and targets also go here.

# Adding a timezone to datetime.datetime variables.
def addTimezone(dt, tz = pytz.UTC) -> datetime.datetime:
    dt = toDatetime(dt)
    if dt.tzinfo == None:
        return pytz.UTC.localize(dt).astimezone(tz)
    else:
        return dt.astimezone(tz)

# Getting the solar radiation value wrt longitude, latitude and timestamp.
def getSolarRadiation(longitude, latitude, dt):
        
    altitude_degrees = get_altitude(latitude, longitude, addTimezone(dt))
    solar_radiation = get_radiation_direct(dt, altitude_degrees) if altitude_degrees > 0 else 0

    return solar_radiation * watts_to_joules

# Calculating the solar radiation values for timestamps to be predicted.
def integrateSolarRadiation(data:pd.DataFrame):
    
    dates = list(data.index.get_level_values('time').unique())
    coords = [[lat, lon] for lat in lat_range for lon in lon_range]
    values = []
    
    # For each data, getting the solar radiation value at a particular coordinate.
    for dt in dates:
        values.extend(list(map(lambda coord:{'time': dt, 'lon': coord[1], 'lat': coord[0], 'toa_incident_solar_radiation': getSolarRadiation(coord[1], coord[0], dt)}, coords)))
  
    # Setting indices.
    values = pd.DataFrame(values).set_index(keys = ['lat', 'lon', 'time'])
      
    # The forcings dataset will now contain the solar radiation values.
    return pd.merge(data, values, left_index = True, right_index = True, how = 'inner')

def getForcings(data:pd.DataFrame):
  
    # Since forcings data does not contain batch as an index, it is dropped.
    # So are all the columns, since forcings data only has 5, which will be created.
    forcingdf = data.reset_index(level = 'level', drop = True).drop(labels = predictionFields, axis = 1)
    
    # Keeping only the unique indices.
    forcingdf = pd.DataFrame(index = forcingdf.index.drop_duplicates(keep = 'first'))

    # Adding the sin and cos of day and year progress.
    # Functions are included in the creation of inputs data section.
    forcingdf = integrateProgress(forcingdf)

    # Integrating the solar radiation values.
    forcingdf = integrateSolarRadiation(forcingdf)

    return forcingdf

if __name__ == '__main__':

    # The code for creating inputs and targets will be here.

    values['forcings'] = getForcings(values['targets'])

# Processing and formatting the data

In [6]:
# Includes the packages imported and constants assigned.
# The functions created for the inputs, targets and forcings also go here.

# A dictionary created, containing each coordinate a data variable requires.
class AssignCoordinates:
    
    coordinates = {
                    '2m_temperature': ['batch', 'lon', 'lat', 'time'],
                    'mean_sea_level_pressure': ['batch', 'lon', 'lat', 'time'],
                    '10m_v_component_of_wind': ['batch', 'lon', 'lat', 'time'],
                    '10m_u_component_of_wind': ['batch', 'lon', 'lat', 'time'],
                    'total_precipitation_6hr': ['batch', 'lon', 'lat', 'time'],
                    'temperature': ['batch', 'lon', 'lat', 'level', 'time'],
                    'geopotential': ['batch', 'lon', 'lat', 'level', 'time'],
                    'u_component_of_wind': ['batch', 'lon', 'lat', 'level', 'time'],
                    'v_component_of_wind': ['batch', 'lon', 'lat', 'level', 'time'],
                    'vertical_velocity': ['batch', 'lon', 'lat', 'level', 'time'],
                    'specific_humidity': ['batch', 'lon', 'lat', 'level', 'time'],
                    'toa_incident_solar_radiation': ['batch', 'lon', 'lat', 'time'],
                    'year_progress_cos': ['batch', 'time'],
                    'year_progress_sin': ['batch', 'time'],
                    'day_progress_cos': ['batch', 'lon', 'time'],
                    'day_progress_sin': ['batch', 'lon', 'time'],
                    'geopotential_at_surface': ['lon', 'lat'],
                    'land_sea_mask': ['lon', 'lat'],
                }

def modifyCoordinates(data:xarray.Dataset):
    
    # Parsing through each data variable and removing unneeded indices.
    for var in list(data.data_vars):
        varArray:xarray.DataArray = data[var]
        nonIndices = list(set(list(varArray.coords)).difference(set(AssignCoordinates.coordinates[var])))
        data[var] = varArray.isel(**{coord: 0 for coord in nonIndices})
    data = data.drop_vars('batch')

    return data

def makeXarray(data:pd.DataFrame) -> xarray.Dataset:
    
    # Converting to xarray.
    data = data.to_xarray()
    data = modifyCoordinates(data)

    return data

if __name__ == '__main__':

    # The code for creating inputs, targets and forcings will be here.
    values = {value:makeXarray(values[value]) for value in values}

In [8]:
for key, value in values.items():
    print(f"\nSample of data in '{key}':")
    print(value.head())  # If it's a DataFrame or something with a similar method



Sample of data in 'inputs':
<xarray.Dataset> Size: 15kB
Dimensions:                       (lon: 5, lat: 5, level: 5, time: 2, batch: 1)
Coordinates:
  * lon                           (lon) float32 20B 0.0 0.25 0.5 0.75 1.0
  * lat                           (lat) float32 20B -90.0 -89.75 ... -89.0
  * level                         (level) int32 20B 50 100 150 200 250
  * time                          (time) datetime64[ns] 16B 2024-01-01T06:00:...
Dimensions without coordinates: batch
Data variables: (12/18)
    u_component_of_wind           (lon, lat, level, time, batch) float64 2kB ...
    v_component_of_wind           (lon, lat, level, time, batch) float64 2kB ...
    geopotential                  (lon, lat, level, time, batch) float64 2kB ...
    specific_humidity             (lon, lat, level, time, batch) float64 2kB ...
    temperature                   (lon, lat, level, time, batch) float64 2kB ...
    vertical_velocity             (lon, lat, level, time, batch) float64 2kB ...
 

In [9]:
# Assuming `values` is a dictionary of xarray.Datasets
for key, dataset in values.items():
    if isinstance(dataset, xarray.Dataset):
        # Create a unique filename for each dataset
        output_file = f'{key}_output_data.nc'
        
        # Save the xarray Dataset to a NetCDF file
        dataset.to_netcdf(output_file)
        
        print(f"NetCDF file saved successfully for '{key}' as {output_file}.")
    else:
        print(f"Error: '{key}' is not an xarray.Dataset. It is {type(dataset)}.")

NetCDF file saved successfully for 'inputs' as inputs_output_data.nc.
NetCDF file saved successfully for 'targets' as targets_output_data.nc.
NetCDF file saved successfully for 'forcings' as forcings_output_data.nc.


In [39]:
#Combine the Files
# Define file paths for the NetCDF files
input_file = 'inputs_output_data.nc'
target_file = 'targets_output_data.nc'
forcing_file = 'forcings_output_data.nc'

# Load the NetCDF files into xarray Datasets
inputs = xarray.open_dataset(input_file)
targets = xarray.open_dataset(target_file)
forcings = xarray.open_dataset(forcing_file)

# Print the structure of each dataset to verify they are loaded correctly
print("Inputs dataset:")
print(inputs)

print("\nTargets dataset:")
print(targets)

print("\nForcings dataset:")
print(forcings)


Inputs dataset:
<xarray.Dataset> Size: 1GB
Dimensions:                       (lon: 1440, lat: 721, level: 13, time: 2,
                                   batch: 1)
Coordinates:
  * lon                           (lon) float32 6kB 0.0 0.25 0.5 ... 359.5 359.8
  * lat                           (lat) float32 3kB -90.0 -89.75 ... 89.75 90.0
  * level                         (level) int32 52B 50 100 150 ... 850 925 1000
  * time                          (time) datetime64[ns] 16B 2024-01-01T06:00:...
Dimensions without coordinates: batch
Data variables: (12/18)
    u_component_of_wind           (lon, lat, level, time, batch) float64 216MB ...
    v_component_of_wind           (lon, lat, level, time, batch) float64 216MB ...
    geopotential                  (lon, lat, level, time, batch) float64 216MB ...
    specific_humidity             (lon, lat, level, time, batch) float64 216MB ...
    temperature                   (lon, lat, level, time, batch) float64 216MB ...
    vertical_velocity   

# Format and export the data

In [40]:
# Step 1: Merge the datasets
combined_dataset = xarray.merge([inputs, targets, forcings])

batch_dim = combined_dataset.sizes['batch']  # Number of batches
time_dim = combined_dataset.sizes['time']    # Number of time steps

# Step 2: Fix 'time' coordinate (convert datetime64[ns] to timedelta64[ns])
# Get the original 'time' values (which are in datetime64[ns])
datetime_values = combined_dataset['time'].values

# Calculate the difference between each time value and the first time value (to get timedelta)
timedelta_values = datetime_values - datetime_values[0]

# Update the 'time' coordinate in the dataset with the timedelta values (timedelta64[ns])
combined_dataset['time'] = xarray.DataArray(timedelta_values, dims="time")

# Step 3: Fix 'datetime' by aligning it only to 'batch' and 'time'
# Create a 1D array of datetime values with a frequency of 6 hours (assuming 6 time steps)
datetime_1d = pd.date_range(start="2024-01-01", periods=time_dim, freq='6H').values

# Expand to 2D with the correct batch and time dimensions
datetime_2d = xarray.DataArray(datetime_1d, dims=["time"]).expand_dims({"batch": batch_dim}, axis=0)

# Extract the underlying NumPy data from the DataArray
datetime_2d_data = datetime_2d.data

# Remove the unnecessary dimensions from 'datetime' and assign it properly
combined_dataset = combined_dataset.assign_coords(datetime=(["batch", "time"], datetime_2d_data))

# Print the structure after fixing 'datetime'
print("\nCombined dataset after fixing 'datetime' dimensions (batch, time):")
print(combined_dataset)

# Step 4: Save the final combined dataset to a NetCDF file
combined_output_file = 'source-era5_date-2024-01-01_res-0.25_levels-13_steps-12.nc' # this needs to be dynamic for 
combined_dataset.to_netcdf(combined_output_file)
print(f"\nCombined NetCDF file saved as {combined_output_file}.")


  datetime_1d = pd.date_range(start="2024-01-01", periods=time_dim, freq='6H').values



Combined dataset after fixing 'datetime' dimensions (batch, time):
<xarray.Dataset> Size: 4GB
Dimensions:                       (lon: 1440, lat: 721, level: 13, time: 6,
                                   batch: 1)
Coordinates:
  * lon                           (lon) float64 12kB 0.0 0.25 ... 359.5 359.8
  * lat                           (lat) float64 6kB -90.0 -89.75 ... 89.75 90.0
  * level                         (level) int32 52B 50 100 150 ... 850 925 1000
  * time                          (time) timedelta64[ns] 48B 00:00:00 ... 1 d...
    datetime                      (batch, time) datetime64[ns] 48B 2024-01-01...
Dimensions without coordinates: batch
Data variables: (12/18)
    u_component_of_wind           (lon, lat, level, time, batch) float64 648MB ...
    v_component_of_wind           (lon, lat, level, time, batch) float64 648MB ...
    geopotential                  (lon, lat, level, time, batch) float64 648MB ...
    specific_humidity             (lon, lat, level, time, ba