# run a model from file
This script is used to run a 10 day forecast for the whole of 2020 using saved model parameters from a file.

# Installation and initialisation.

In [1]:
# @title Imports
import sys
sys.path.append("../")

import dataclasses
import functools

from google.cloud import storage
from graphcast import autoregressive
from graphcast import casting
from graphcast import data_utils
from graphcast import graphcast
from graphcast import rollout
from graphcast import normalization
from graphcast import xarray_jax
from graphcast import xarray_tree
import haiku as hk
import jax
import numpy as np
import jax.numpy as jnp
import pandas as pd
import xarray

import optax

import os
import time
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'

# Access data from GCS

In [2]:
# @title Authenticate with Google Cloud Storage
gcs_client = storage.Client.create_anonymous_client()
gcs_bucket = gcs_client.get_bucket("dm_graphcast")

## load model from file

In [3]:
# load model params:

def unflatten_dict(d, sep='//'):
    result_dict = {}
    for flat_key, value in d.items():
        keys = flat_key.split(sep)
        d = result_dict
        for key in keys[:-1]:
            if key not in d:
                d[key] = {}
            d = d[key]
        d[keys[-1]] = value
    return result_dict

def load_model_params(file_path):
    with np.load(file_path, allow_pickle=True) as npz_file:
        # Convert NumPy arrays back to JAX arrays
        jax_dict = {k: jnp.array(v) for k, v in npz_file.items()}
    return unflatten_dict(jax_dict)

"""" change path below for correct params"""
rel_path = 'models/params_64x32_mse.npz'
params_path = os.path.join(os.path.dirname(os.getcwd()), rel_path)

## load model parameters

In [4]:
# @title choose model parameters
random_mesh_size = 5 # mesh size: 4 - 6
random_gnn_msg_steps = 8 # message passing steps: 1 - 32
random_latent_size = 128 # latent size: 16,32,64,128,256,512
random_levels = 13 # levels: 13 or 37


In [5]:
# @title load the model parameters
params = load_model_params(params_path)
state = {}
model_config = graphcast.ModelConfig(
    resolution=0,
    mesh_size=random_mesh_size,
    latent_size=random_latent_size,
    gnn_msg_steps=random_gnn_msg_steps,
    hidden_layers=1,
    radius_query_fraction_edge_length=0.6)
task_config = graphcast.TaskConfig(
    input_variables=graphcast.TASK.input_variables,
    target_variables=graphcast.TASK.target_variables,
    forcing_variables=graphcast.TASK.forcing_variables,
    pressure_levels=graphcast.PRESSURE_LEVELS[random_levels],
    input_duration=graphcast.TASK.input_duration,
)
model_config

ModelConfig(resolution=0, mesh_size=5, latent_size=128, gnn_msg_steps=8, hidden_layers=1, radius_query_fraction_edge_length=0.6, mesh2grid_edge_normalization_factor=None)

## load obs_data:

In [6]:
# load obs data:
# Define the relative path to the file
relative_path = "ERA5_data/obs_data.zarr"

# Get the absolute path by joining the current directory with the relative path
absolute_path = os.path.join(os.path.dirname(os.getcwd()), relative_path)
print(absolute_path)

# Open the Zarr file using xarray
obs_data = xarray.open_zarr(absolute_path)

/home/jupyter-lukas/Masters-Thesis/ERA5_data/obs_data.zarr


In [7]:
# get eval data from obs data:
#time slice: 2020-01-01T00:00:00.000000000 to 2022-01-01T00:00:00.000000000 - EVALUATION
eval_data = obs_data.sel(time=slice('2020-01-01T00:00:00.000000000','2021-03-31T00:00:00.000000000'))
eval_steps = 40 # {1 - obs_data.sizes["time"]-2} | 40 = 10days
eval_data = eval_data.compute()
eval_data

In [8]:
eval_inputs, eval_targets, eval_forcings = data_utils.extract_inputs_targets_forcings(
    eval_data, target_lead_times=slice("6h", f"{eval_steps*6}h"),
    **dataclasses.asdict(task_config))

## load normalisation data

In [9]:
with gcs_bucket.blob("stats/diffs_stddev_by_level.nc").open("rb") as f:
    diffs_stddev_by_level = xarray.load_dataset(f).compute()
with gcs_bucket.blob("stats/mean_by_level.nc").open("rb") as f:
    mean_by_level = xarray.load_dataset(f).compute()
with gcs_bucket.blob("stats/stddev_by_level.nc").open("rb") as f:
    stddev_by_level = xarray.load_dataset(f).compute()

## build functions

In [10]:
# build functions + jit:
def construct_wrapped_graphcast(
    model_config: graphcast.ModelConfig,
    task_config: graphcast.TaskConfig):
  """Constructs and wraps the GraphCast Predictor."""
  # Deeper one-step predictor.
  predictor = graphcast.GraphCast(model_config, task_config)

  # Modify inputs/outputs to `graphcast.GraphCast` to handle conversion to
  # from/to float32 to/from BFloat16.
  predictor = casting.Bfloat16Cast(predictor)

  # Modify inputs/outputs to `casting.Bfloat16Cast` so the casting to/from
  # BFloat16 happens after applying normalization to the inputs/targets.
  predictor = normalization.InputsAndResiduals(
      predictor,
      diffs_stddev_by_level=diffs_stddev_by_level,
      mean_by_level=mean_by_level,
      stddev_by_level=stddev_by_level)

  # Wraps everything so the one-step model can produce trajectories.
  predictor = autoregressive.Predictor(predictor, gradient_checkpointing=True)
  return predictor


@hk.transform_with_state
def run_forward(model_config, task_config, inputs, targets_template, forcings):
  predictor = construct_wrapped_graphcast(model_config, task_config)
  return predictor(inputs, targets_template=targets_template, forcings=forcings)


@hk.transform_with_state
def loss_fn(model_config, task_config, inputs, targets, forcings):
  predictor = construct_wrapped_graphcast(model_config, task_config)
  loss, diagnostics = predictor.loss(inputs, targets, forcings)
  return xarray_tree.map_structure(
      lambda x: xarray_jax.unwrap_data(x.mean(), require_jax=True),
      (loss, diagnostics))

def grads_fn(params, state, inputs, targets, forcings, model_config, task_config):
  def _aux(params, state, i, t, f):
    (loss, diagnostics), next_state = loss_fn.apply(
        params, state, jax.random.PRNGKey(0), model_config, task_config,
        i, t, f)
    return loss, (diagnostics, next_state)
  (loss, (diagnostics, next_state)), grads = jax.value_and_grad(
      _aux, has_aux=True)(params, state, inputs, targets, forcings)
  return loss, diagnostics, next_state, grads

# Jax doesn't seem to like passing configs as args through the jit. Passing it
# in via partial (instead of capture by closure) forces jax to invalidate the
# jit cache if you change configs.
def with_configs(fn):
  return functools.partial(
      fn, model_config=model_config, task_config=task_config)

# Always pass params and state, so the usage below are simpler
def with_params(fn):
  return functools.partial(fn, params=params, state=state)

# Our models aren't stateful, so the state is always empty, so just return the
# predictions. This is required by our rollout code, and generally simpler.
def drop_state(fn):
  return lambda **kw: fn(**kw)[0]

init_jitted = jax.jit(with_configs(run_forward.init))

if params is None:
  params, state = init_jitted(
      rng=jax.random.PRNGKey(0),
      inputs=train_inputs.compute(),
      targets_template=train_targets.compute(),
      forcings=train_forcings.compute())

loss_fn_jitted = drop_state(with_params(jax.jit(with_configs(loss_fn.apply))))
grads_fn_jitted = jax.jit(with_configs(grads_fn))
run_forward_jitted = drop_state(with_params(jax.jit(with_configs(
    run_forward.apply))))

## model running loop: 

In [11]:
# loop:
N = 1465 # = 366*4 + 1
"""" change path below for correct destination file"""
store_path = 'predictions/pred_64x32_2020_mse1.zarr'
store_path = os.path.join(os.path.dirname(os.getcwd()), store_path)
time_values = []

for i in range(N):
    eval_batch = eval_data.isel(time=slice(i, i + eval_steps + 2))
    eval_batch = eval_batch.compute()

    eval_inputs, eval_targets, eval_forcings = data_utils.extract_inputs_targets_forcings(
    eval_batch, target_lead_times=slice("6h", f"{eval_steps*6}h"),
    **dataclasses.asdict(task_config))

    time_value = pd.Timestamp("2020-01-01 00:00:00") + pd.Timedelta(hours=i*6)
    time_values.append(time_value)

    prediction = rollout.chunked_prediction(
        run_forward_jitted,
        rng=jax.random.PRNGKey(0),
        inputs=eval_inputs,
        targets_template=eval_targets * np.nan,
        forcings=eval_forcings)
    
    # modify dataset:
    prediction = xarray.concat([eval_inputs.isel(time=1),prediction], dim='time') #!!! possibly remove this to remove 00:00:00.
    prediction = prediction.rename({'time': 'prediction_timedelta'})
    prediction = prediction.expand_dims(time=[time_value])
    
    # Write the prediction dataset to the zarr store
    if i == 0:
        prediction.to_zarr(store_path, mode='w',encoding={'time': {'dtype': 'float64', 'units': 'hours since 2020-01-01'}})
    else:
        prediction.to_zarr(store_path, append_dim='time')

    if i% (4*7) == 0:
        print("time:" , time_value)

print("prediction run completed!")




time: 2020-01-01 00:00:00
time: 2020-01-08 00:00:00
time: 2020-01-15 00:00:00
time: 2020-01-22 00:00:00
time: 2020-01-29 00:00:00
time: 2020-02-05 00:00:00
time: 2020-02-12 00:00:00
time: 2020-02-19 00:00:00
time: 2020-02-26 00:00:00
time: 2020-03-04 00:00:00
time: 2020-03-11 00:00:00
time: 2020-03-18 00:00:00
time: 2020-03-25 00:00:00
time: 2020-04-01 00:00:00
time: 2020-04-08 00:00:00
time: 2020-04-15 00:00:00
time: 2020-04-22 00:00:00
time: 2020-04-29 00:00:00
time: 2020-05-06 00:00:00
time: 2020-05-13 00:00:00
time: 2020-05-20 00:00:00
time: 2020-05-27 00:00:00
time: 2020-06-03 00:00:00
time: 2020-06-10 00:00:00
time: 2020-06-17 00:00:00
time: 2020-06-24 00:00:00
time: 2020-07-01 00:00:00
time: 2020-07-08 00:00:00
time: 2020-07-15 00:00:00
time: 2020-07-22 00:00:00
time: 2020-07-29 00:00:00
time: 2020-08-05 00:00:00
time: 2020-08-12 00:00:00
time: 2020-08-19 00:00:00
time: 2020-08-26 00:00:00
time: 2020-09-02 00:00:00
time: 2020-09-09 00:00:00
time: 2020-09-16 00:00:00
time: 2020-0

In [12]:
# load pred data:

# Get the absolute path by joining the current directory with the relative path
absolute_path = os.path.join(os.path.dirname(os.getcwd()), store_path)
# Open the Zarr file using xarray
pred_data = xarray.open_zarr(absolute_path)
pred_data

Unnamed: 0,Array,Chunk
Bytes,469.26 MiB,164.00 kiB
Shape,"(1465, 1, 64, 32, 41)","(1, 1, 32, 32, 41)"
Dask graph,2930 chunks in 2 graph layers,2930 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 469.26 MiB 164.00 kiB Shape (1465, 1, 64, 32, 41) (1, 1, 32, 32, 41) Dask graph 2930 chunks in 2 graph layers Data type float32 numpy.ndarray",1  1465  41  32  64,

Unnamed: 0,Array,Chunk
Bytes,469.26 MiB,164.00 kiB
Shape,"(1465, 1, 64, 32, 41)","(1, 1, 32, 32, 41)"
Dask graph,2930 chunks in 2 graph layers,2930 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,469.26 MiB,164.00 kiB
Shape,"(1465, 1, 64, 32, 41)","(1, 1, 32, 32, 41)"
Dask graph,2930 chunks in 2 graph layers,2930 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 469.26 MiB 164.00 kiB Shape (1465, 1, 64, 32, 41) (1, 1, 32, 32, 41) Dask graph 2930 chunks in 2 graph layers Data type float32 numpy.ndarray",1  1465  41  32  64,

Unnamed: 0,Array,Chunk
Bytes,469.26 MiB,164.00 kiB
Shape,"(1465, 1, 64, 32, 41)","(1, 1, 32, 32, 41)"
Dask graph,2930 chunks in 2 graph layers,2930 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,469.26 MiB,164.00 kiB
Shape,"(1465, 1, 64, 32, 41)","(1, 1, 32, 32, 41)"
Dask graph,2930 chunks in 2 graph layers,2930 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 469.26 MiB 164.00 kiB Shape (1465, 1, 64, 32, 41) (1, 1, 32, 32, 41) Dask graph 2930 chunks in 2 graph layers Data type float32 numpy.ndarray",1  1465  41  32  64,

Unnamed: 0,Array,Chunk
Bytes,469.26 MiB,164.00 kiB
Shape,"(1465, 1, 64, 32, 41)","(1, 1, 32, 32, 41)"
Dask graph,2930 chunks in 2 graph layers,2930 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,14.66 MiB,10.25 kiB
Shape,"(1465, 41, 1, 64)","(1, 41, 1, 64)"
Dask graph,1465 chunks in 2 graph layers,1465 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 14.66 MiB 10.25 kiB Shape (1465, 41, 1, 64) (1, 41, 1, 64) Dask graph 1465 chunks in 2 graph layers Data type float32 numpy.ndarray",1465  1  64  1  41,

Unnamed: 0,Array,Chunk
Bytes,14.66 MiB,10.25 kiB
Shape,"(1465, 41, 1, 64)","(1, 41, 1, 64)"
Dask graph,1465 chunks in 2 graph layers,1465 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,14.66 MiB,10.25 kiB
Shape,"(1465, 41, 1, 64)","(1, 41, 1, 64)"
Dask graph,1465 chunks in 2 graph layers,1465 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 14.66 MiB 10.25 kiB Shape (1465, 41, 1, 64) (1, 41, 1, 64) Dask graph 1465 chunks in 2 graph layers Data type float32 numpy.ndarray",1465  1  64  1  41,

Unnamed: 0,Array,Chunk
Bytes,14.66 MiB,10.25 kiB
Shape,"(1465, 41, 1, 64)","(1, 41, 1, 64)"
Dask graph,1465 chunks in 2 graph layers,1465 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,5.96 GiB,574.00 kiB
Shape,"(1465, 1, 13, 64, 32, 41)","(1, 1, 7, 32, 16, 41)"
Dask graph,11720 chunks in 2 graph layers,11720 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 5.96 GiB 574.00 kiB Shape (1465, 1, 13, 64, 32, 41) (1, 1, 7, 32, 16, 41) Dask graph 11720 chunks in 2 graph layers Data type float32 numpy.ndarray",13  1  1465  41  32  64,

Unnamed: 0,Array,Chunk
Bytes,5.96 GiB,574.00 kiB
Shape,"(1465, 1, 13, 64, 32, 41)","(1, 1, 7, 32, 16, 41)"
Dask graph,11720 chunks in 2 graph layers,11720 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,469.26 MiB,168.00 kiB
Shape,"(1465, 41, 1, 64, 32)","(1, 21, 1, 64, 32)"
Dask graph,2930 chunks in 2 graph layers,2930 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 469.26 MiB 168.00 kiB Shape (1465, 41, 1, 64, 32) (1, 21, 1, 64, 32) Dask graph 2930 chunks in 2 graph layers Data type float32 numpy.ndarray",41  1465  32  64  1,

Unnamed: 0,Array,Chunk
Bytes,469.26 MiB,168.00 kiB
Shape,"(1465, 41, 1, 64, 32)","(1, 21, 1, 64, 32)"
Dask graph,2930 chunks in 2 graph layers,2930 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,469.26 MiB,168.00 kiB
Shape,"(1465, 41, 1, 64, 32)","(1, 21, 1, 64, 32)"
Dask graph,2930 chunks in 2 graph layers,2930 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 469.26 MiB 168.00 kiB Shape (1465, 41, 1, 64, 32) (1, 21, 1, 64, 32) Dask graph 2930 chunks in 2 graph layers Data type float32 numpy.ndarray",41  1465  32  64  1,

Unnamed: 0,Array,Chunk
Bytes,469.26 MiB,168.00 kiB
Shape,"(1465, 41, 1, 64, 32)","(1, 21, 1, 64, 32)"
Dask graph,2930 chunks in 2 graph layers,2930 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,469.26 MiB,164.00 kiB
Shape,"(1465, 1, 64, 32, 41)","(1, 1, 32, 32, 41)"
Dask graph,2930 chunks in 2 graph layers,2930 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 469.26 MiB 164.00 kiB Shape (1465, 1, 64, 32, 41) (1, 1, 32, 32, 41) Dask graph 2930 chunks in 2 graph layers Data type float32 numpy.ndarray",1  1465  41  32  64,

Unnamed: 0,Array,Chunk
Bytes,469.26 MiB,164.00 kiB
Shape,"(1465, 1, 64, 32, 41)","(1, 1, 32, 32, 41)"
Dask graph,2930 chunks in 2 graph layers,2930 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,5.96 GiB,574.00 kiB
Shape,"(1465, 1, 13, 64, 32, 41)","(1, 1, 7, 32, 16, 41)"
Dask graph,11720 chunks in 2 graph layers,11720 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 5.96 GiB 574.00 kiB Shape (1465, 1, 13, 64, 32, 41) (1, 1, 7, 32, 16, 41) Dask graph 11720 chunks in 2 graph layers Data type float32 numpy.ndarray",13  1  1465  41  32  64,

Unnamed: 0,Array,Chunk
Bytes,5.96 GiB,574.00 kiB
Shape,"(1465, 1, 13, 64, 32, 41)","(1, 1, 7, 32, 16, 41)"
Dask graph,11720 chunks in 2 graph layers,11720 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,5.96 GiB,574.00 kiB
Shape,"(1465, 1, 13, 64, 32, 41)","(1, 1, 7, 32, 16, 41)"
Dask graph,11720 chunks in 2 graph layers,11720 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 5.96 GiB 574.00 kiB Shape (1465, 1, 13, 64, 32, 41) (1, 1, 7, 32, 16, 41) Dask graph 11720 chunks in 2 graph layers Data type float32 numpy.ndarray",13  1  1465  41  32  64,

Unnamed: 0,Array,Chunk
Bytes,5.96 GiB,574.00 kiB
Shape,"(1465, 1, 13, 64, 32, 41)","(1, 1, 7, 32, 16, 41)"
Dask graph,11720 chunks in 2 graph layers,11720 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,469.26 MiB,168.00 kiB
Shape,"(1465, 41, 1, 64, 32)","(1, 21, 1, 64, 32)"
Dask graph,2930 chunks in 2 graph layers,2930 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 469.26 MiB 168.00 kiB Shape (1465, 41, 1, 64, 32) (1, 21, 1, 64, 32) Dask graph 2930 chunks in 2 graph layers Data type float32 numpy.ndarray",41  1465  32  64  1,

Unnamed: 0,Array,Chunk
Bytes,469.26 MiB,168.00 kiB
Shape,"(1465, 41, 1, 64, 32)","(1, 21, 1, 64, 32)"
Dask graph,2930 chunks in 2 graph layers,2930 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,469.26 MiB,164.00 kiB
Shape,"(1465, 1, 64, 32, 41)","(1, 1, 32, 32, 41)"
Dask graph,2930 chunks in 2 graph layers,2930 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 469.26 MiB 164.00 kiB Shape (1465, 1, 64, 32, 41) (1, 1, 32, 32, 41) Dask graph 2930 chunks in 2 graph layers Data type float32 numpy.ndarray",1  1465  41  32  64,

Unnamed: 0,Array,Chunk
Bytes,469.26 MiB,164.00 kiB
Shape,"(1465, 1, 64, 32, 41)","(1, 1, 32, 32, 41)"
Dask graph,2930 chunks in 2 graph layers,2930 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,5.96 GiB,574.00 kiB
Shape,"(1465, 1, 13, 64, 32, 41)","(1, 1, 7, 32, 16, 41)"
Dask graph,11720 chunks in 2 graph layers,11720 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 5.96 GiB 574.00 kiB Shape (1465, 1, 13, 64, 32, 41) (1, 1, 7, 32, 16, 41) Dask graph 11720 chunks in 2 graph layers Data type float32 numpy.ndarray",13  1  1465  41  32  64,

Unnamed: 0,Array,Chunk
Bytes,5.96 GiB,574.00 kiB
Shape,"(1465, 1, 13, 64, 32, 41)","(1, 1, 7, 32, 16, 41)"
Dask graph,11720 chunks in 2 graph layers,11720 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,5.96 GiB,574.00 kiB
Shape,"(1465, 1, 13, 64, 32, 41)","(1, 1, 7, 32, 16, 41)"
Dask graph,11720 chunks in 2 graph layers,11720 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 5.96 GiB 574.00 kiB Shape (1465, 1, 13, 64, 32, 41) (1, 1, 7, 32, 16, 41) Dask graph 11720 chunks in 2 graph layers Data type float32 numpy.ndarray",13  1  1465  41  32  64,

Unnamed: 0,Array,Chunk
Bytes,5.96 GiB,574.00 kiB
Shape,"(1465, 1, 13, 64, 32, 41)","(1, 1, 7, 32, 16, 41)"
Dask graph,11720 chunks in 2 graph layers,11720 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,5.96 GiB,574.00 kiB
Shape,"(1465, 1, 13, 64, 32, 41)","(1, 1, 7, 32, 16, 41)"
Dask graph,11720 chunks in 2 graph layers,11720 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 5.96 GiB 574.00 kiB Shape (1465, 1, 13, 64, 32, 41) (1, 1, 7, 32, 16, 41) Dask graph 11720 chunks in 2 graph layers Data type float32 numpy.ndarray",13  1  1465  41  32  64,

Unnamed: 0,Array,Chunk
Bytes,5.96 GiB,574.00 kiB
Shape,"(1465, 1, 13, 64, 32, 41)","(1, 1, 7, 32, 16, 41)"
Dask graph,11720 chunks in 2 graph layers,11720 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,234.63 kiB,164 B
Shape,"(1465, 41, 1)","(1, 41, 1)"
Dask graph,1465 chunks in 2 graph layers,1465 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 234.63 kiB 164 B Shape (1465, 41, 1) (1, 41, 1) Dask graph 1465 chunks in 2 graph layers Data type float32 numpy.ndarray",1  41  1465,

Unnamed: 0,Array,Chunk
Bytes,234.63 kiB,164 B
Shape,"(1465, 41, 1)","(1, 41, 1)"
Dask graph,1465 chunks in 2 graph layers,1465 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,234.63 kiB,164 B
Shape,"(1465, 41, 1)","(1, 41, 1)"
Dask graph,1465 chunks in 2 graph layers,1465 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 234.63 kiB 164 B Shape (1465, 41, 1) (1, 41, 1) Dask graph 1465 chunks in 2 graph layers Data type float32 numpy.ndarray",1  41  1465,

Unnamed: 0,Array,Chunk
Bytes,234.63 kiB,164 B
Shape,"(1465, 41, 1)","(1, 41, 1)"
Dask graph,1465 chunks in 2 graph layers,1465 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray


In [13]:
# Example usage
size_in_gb = pred_data.nbytes / (1024*1024*1024)
print("Size of pred_data dataset:", size_in_gb, "GB")

Size of pred_data dataset: 39.43942487239838 GB


Now the predictions run for 2020 is completed.