In [1]:
import dataclasses
import datetime
import functools
import math
import re
from typing import Optional

import cartopy.crs as ccrs
from google.cloud import storage
from graphcast import autoregressive
from graphcast import casting
from graphcast import checkpoint
from graphcast import data_utils
from graphcast import graphcast
from graphcast import normalization
from graphcast import rollout
from graphcast import xarray_jax
from graphcast import xarray_tree
from IPython.display import HTML
import ipywidgets as widgets
import haiku as hk
import jax
import matplotlib
import matplotlib.pyplot as plt
from matplotlib import animation
import numpy as np
import xarray
import his_utils

def parse_file_parts(file_name):
  """Parse a `key-value_key-value_...` dataset file name into a dict.

  Args:
    file_name: File name (the caller strips any extension) made of
      `_`-separated parts, each of the form `key-value`. It may be prefixed
      with directory components separated by `/`, e.g.
      `testdata/source-era5_res-0.25`.

  Returns:
    A dict mapping each key to its value (both strings). Only the first `-`
    of a part separates key from value, so values such as `2022-01-01` are
    kept intact.

  Raises:
    ValueError: if a part of the base name contains no `-`.
  """
  # Only the last path component carries the `key-value` parts. Without this,
  # a path like "testdata/source-era5_..." would produce the key
  # "testdata/source" instead of "source".
  base_name = file_name.rsplit("/", 1)[-1]
  return dict(part.split("-", 1) for part in base_name.split("_"))

# Anonymous (unauthenticated) Google Cloud Storage client. The bucket handle is
# used by later cells to read the pretrained params and the normalization stats.
gcs_client = storage.Client.create_anonymous_client()
gcs_bucket = gcs_client.get_bucket("dm_graphcast")
# Display the bucket handle as the cell output.
gcs_bucket

<Bucket: dm_graphcast>

### Set the model and properties

In [2]:
# Initialization: hyper-parameters for building a model config from scratch.
# TODO: change the values here. They do not matter when a pretrained model is
# loaded, because the checkpoint cell further down replaces the configs/params.
resolution = 0 # 0.25 or 1.0
# NOTE(review): 1000 is outside the 4~6 range noted on this line -- confirm.
mesh_size = 1000 # 4~6
latent_size = 100 # 2^4 ~ 2^9
gnn_msg_steps = 10 # 1~32
# Used as a key into graphcast.PRESSURE_LEVELS when the task config is built.
pressure_levels = 37 # 13, 25, 37
hidden_layers = 1 
radius_query_fraction_edge_length = 0.6 # 1 is also possible
# `params is None` is what later triggers random initialization of the weights.
params = None
state = {}

In [3]:
# Create the model: build the model and task configurations from the
# hyper-parameters chosen in the previous cell.
model_config = graphcast.ModelConfig(
    radius_query_fraction_edge_length=radius_query_fraction_edge_length,
    hidden_layers=hidden_layers,
    gnn_msg_steps=gnn_msg_steps,
    latent_size=latent_size,
    mesh_size=mesh_size,
    resolution=resolution,
)

task_config = graphcast.TaskConfig(
    # Inputs are the target variables plus the forcing and static variables.
    input_variables=(
        graphcast.TARGET_SURFACE_VARS
        + graphcast.TARGET_ATMOSPHERIC_VARS
        + graphcast.FORCING_VARS
        + graphcast.STATIC_VARS
    ),
    target_variables=(
        graphcast.TARGET_SURFACE_VARS + graphcast.TARGET_ATMOSPHERIC_VARS
    ),
    forcing_variables=graphcast.FORCING_VARS,
    # `pressure_levels` is used as a lookup key into graphcast.PRESSURE_LEVELS.
    pressure_levels=graphcast.PRESSURE_LEVELS[pressure_levels],
    input_duration="12h",
)

# Show the resulting model configuration as the cell output.
model_config

ModelConfig(resolution=0, mesh_size=1000, latent_size=100, gnn_msg_steps=10, hidden_layers=1, radius_query_fraction_edge_length=0.6, mesh2grid_edge_normalization_factor=None)

In [4]:
# If you just want to use a pretrained model as-is, do it here.
# The names below are checkpoint files under the bucket's "params/" prefix.
GC_original = 'GraphCast - ERA5 1979-2017 - resolution 0.25 - pressure levels 37 - mesh 2to6 - precipitation input and output.npz'
GC_operational = 'GraphCast_operational - ERA5-HRES 1979-2021 - resolution 0.25 - pressure levels 13 - mesh 2to6 - precipitation output only.npz'
GC_small = 'GraphCast_small - ERA5 1979-2015 - resolution 1.0 - pressure levels 13 - mesh 2to5 - precipitation input and output.npz'

# TODO: pick one model and load it
pretrained_model = GC_original

# Read the selected checkpoint from the bucket and deserialize it as a
# graphcast.CheckPoint.
with gcs_bucket.blob(f"params/{pretrained_model}").open("rb") as f:
    ckpt = checkpoint.load(f, graphcast.CheckPoint)

In [5]:

# From here on the checkpoint's values replace the hand-built params and
# configs defined in the earlier cells.
params = ckpt.params  # take the parameters from the loaded checkpoint
state = {}  # the initial state is an empty dict

# Take the model configuration from the checkpoint
model_config = ckpt.model_config
# Take the task configuration from the checkpoint
task_config = ckpt.task_config

# Print the model description
print("Model description:\n", ckpt.description, "\n")
# Print the model license
print("Model license:\n", ckpt.license, "\n")

# Print the configurations that will actually be used
print("Model config:\n", model_config, "\n")

print("Task config:\n", task_config)

Model description:
 
GraphCast model at 0.25deg resolution, with 37 pressure levels. This model is
trained on ERA5 data from 1979 to 2017, and can be causally evaluated on 2018
and later years. This model takes as inputs `total_precipitation_6hr`. This was
described in the paper
`GraphCast: Learning skillful medium-range global weather forecasting`
(https://arxiv.org/abs/2212.12794).
 

Model license:
 
The model weights are licensed under the Creative Commons
Attribution-NonCommercial-ShareAlike 4.0 International (CC BY-NC-SA 4.0). You
may obtain a copy of the License at:
https://creativecommons.org/licenses/by-nc-sa/4.0/.
The weights were trained on ERA5 data, see README for attribution statement.
 

Model config:
 ModelConfig(resolution=0.25, mesh_size=6, latent_size=512, gnn_msg_steps=16, hidden_layers=1, radius_query_fraction_edge_length=0.5999912857713345, mesh2grid_edge_normalization_factor=0.6180338738074472) 

Task config:
 TaskConfig(input_variables=('2m_temperature', 'mean_s

### Load Data

In [6]:
# Checks whether a dataset file is compatible with a model.
def data_valid_for_model(
    file_name: str,
    model_config: graphcast.ModelConfig,
    task_config: graphcast.TaskConfig):
    """Return True if the dataset described by `file_name` fits the configs.

    The file name (with a trailing ".nc" removed) is parsed by
    `parse_file_parts`; its "res", "levels" and "source" entries are compared
    with the model and task configuration.

    Args:
        file_name: Dataset file name encoding `res`, `levels` and `source`.
        model_config: Model configuration; only `resolution` is read.
        task_config: Task configuration; `pressure_levels` and
            `input_variables` are read.

    Returns:
        True when all three checks pass, False otherwise.
    """
    parts = parse_file_parts(file_name.removesuffix(".nc"))

    # The model resolution must be 0 or equal to the file's resolution.
    if model_config.resolution not in (0, float(parts["res"])):
        return False

    # The number of pressure levels must match the file's level count.
    if len(task_config.pressure_levels) != int(parts["levels"]):
        return False

    # A model that takes precipitation as input needs "era5" (or "fake") data;
    # a model that does not needs "hres" (or "fake") data.
    if "total_precipitation_6hr" in task_config.input_variables:
        allowed_sources = ("era5", "fake")
    else:
        allowed_sources = ("hres", "fake")
    return parts["source"] in allowed_sources

In [7]:
# Data loading happens here

# TODO: the downloaded data has to be processed properly and plugged in here...
# What are the requirements? Which params are mandatory?
# What size?
# file_name = "testdata/new_sample.nc"

# e.g.
file_name = "testdata/source-era5_date-2022-01-01_res-0.25_levels-37_steps-12.nc"
# Open the local NetCDF file as an xarray dataset.
dataset = xarray.open_dataset(file_name)
# Print the dimension sizes and the coordinates, then display the dataset.
print(dataset.dims)
print(dataset.coords)
dataset

Coordinates:
  * lon       (lon) float32 6kB 0.0 0.25 0.5 0.75 ... 359.0 359.2 359.5 359.8
  * lat       (lat) float32 3kB -90.0 -89.75 -89.5 -89.25 ... 89.5 89.75 90.0
  * level     (level) int32 148B 1 2 3 5 7 10 20 ... 875 900 925 950 975 1000
  * time      (time) timedelta64[ns] 112B 0 days 00:00:00 ... 3 days 06:00:00
    datetime  (batch, time) datetime64[ns] 112B ...


In [8]:
# Checks

# Check that the dataset file is valid for the model configuration
# if not data_valid_for_model(file_name, model_config, task_config):
    # raise ValueError(
        # "Invalid dataset file, rerun the cell above and choose a valid dataset file.")
# 
# Check that the dataset's time dimension is at least 3 (2 for inputs, >= 1 for targets)
# assert dataset.dims["time"] >= 3  # 2 for input, >=1 for targets


# Print information about the selected dataset file
# print(", ".join([f"{k}: {v}" for k, v in parse_file_parts(file_name.removesuffix(".nc")).items()]))

# Display the loaded dataset
# dataset

In [9]:
# Check the time size of the dataset.
# Two time steps are subtracted; presumably these are the two steps consumed as
# model inputs, leaving the rest as possible targets -- TODO confirm.
total_time_steps = dataset.sizes["time"] - 2

# Display the number of remaining time steps.
total_time_steps

12

In [10]:
# Split dataset: number of target steps used for training and for evaluation.
train_steps = 2
eval_steps = 2

# Verify that train_steps and eval_steps are within the valid range
# (at least 1 and at most total_time_steps).
assert 1 <= train_steps <= total_time_steps, f"train_steps must be between 1 and {total_time_steps}."

assert 1 <= eval_steps <= total_time_steps, f"eval_steps must be between 1 and {total_time_steps}."

# Display the dataset as the cell output.
dataset

In [11]:
# The training extraction below is disabled, so train_inputs / train_targets /
# train_forcings are not defined anywhere in this notebook.
# train_inputs, train_targets, train_forcings = data_utils.extract_inputs_targets_forcings(
#     dataset, 
#     target_lead_times=slice("6h", f"{train_steps * 6}h"), 
#     **dataclasses.asdict(task_config)
# )

# Extract the evaluation inputs, targets and forcings from the dataset.
# Target lead times run from "6h" to eval_steps * 6 hours; every field of
# task_config is forwarded as a keyword argument.
eval_inputs, eval_targets, eval_forcings = data_utils.extract_inputs_targets_forcings(
    dataset, 
    target_lead_times=slice("6h", f"{eval_steps * 6}h"), 
    **dataclasses.asdict(task_config)
)

|| data_utils.py -> add_derived_vars() init. ||


In [12]:
# Print the dimension sizes of the full dataset and of the extracted
# evaluation data. The training prints are disabled because the training data
# is not extracted in this notebook.
print("All Examples:  ", dataset.dims.mapping)
print("----------------------------------------------")
# print("Train Inputs:  ", train_inputs.dims.mapping)
# print("Train Targets: ", train_targets.dims.mapping)
# print("Train Forcings:", train_forcings.dims.mapping)
print("----------------------------------------------")
print("Eval Inputs:   ", eval_inputs.dims.mapping)
print("Eval Targets:  ", eval_targets.dims.mapping)
print("Eval Forcings: ", eval_forcings.dims.mapping)
print("----------------------------------------------")
# Full dumps of the evaluation datasets (disabled; very verbose).
# print("Eval Inputs:   \n", eval_inputs)
# print("Eval Targets:  \n", eval_targets)
# print("Eval Forcings: \n", eval_forcings)

All Examples:   {'lon': 1440, 'lat': 721, 'level': 37, 'time': 14, 'batch': 1}
----------------------------------------------
----------------------------------------------
Eval Inputs:    {'batch': 1, 'time': 2, 'lat': 721, 'lon': 1440, 'level': 37}
Eval Targets:   {'batch': 1, 'time': 2, 'lat': 721, 'lon': 1440, 'level': 37}
Eval Forcings:  {'batch': 1, 'time': 2, 'lat': 721, 'lon': 1440}
----------------------------------------------


In [13]:
# set(eval_inputs.data_vars) - set(eval_targets.data_vars) - set(eval_forcings.data_vars)

In [14]:
# @title Load normalization data

def _load_stats_from_bucket(blob_name):
    """Read one statistics file from the GCS bucket and load it into memory."""
    with gcs_bucket.blob(blob_name).open("rb") as f:
        return xarray.load_dataset(f).compute()

# Standard deviations of differences, by level.
diffs_stddev_by_level = _load_stats_from_bucket("stats/diffs_stddev_by_level.nc")

# Mean values, by level.
mean_by_level = _load_stats_from_bucket("stats/mean_by_level.nc")

# Standard deviations, by level.
stddev_by_level = _load_stats_from_bucket("stats/stddev_by_level.nc")


In [15]:
# Build the JIT-compiled functions and, if necessary, initialize random weights

def construct_wrapped_graphcast(
    model_config: graphcast.ModelConfig,
    task_config: graphcast.TaskConfig):
  """Constructs and wraps the GraphCast predictor.

  The layers are applied innermost first: graphcast.GraphCast, then
  casting.Bfloat16Cast, then normalization.InputsAndResiduals (using the
  module-level statistics datasets), then autoregressive.Predictor with
  gradient checkpointing enabled.

  Args:
    model_config: Model configuration passed to graphcast.GraphCast.
    task_config: Task configuration passed to graphcast.GraphCast.

  Returns:
    The outermost wrapper, an autoregressive.Predictor.
  """
  core_model = graphcast.GraphCast(model_config, task_config)
  bfloat16_model = casting.Bfloat16Cast(core_model)
  normalized_model = normalization.InputsAndResiduals(
      bfloat16_model,
      diffs_stddev_by_level=diffs_stddev_by_level,
      mean_by_level=mean_by_level,
      stddev_by_level=stddev_by_level)
  return autoregressive.Predictor(
      normalized_model, gradient_checkpointing=True)

@hk.transform_with_state
def run_forward(model_config, task_config, inputs, targets_template, forcings):
    """Runs a forward pass with a freshly constructed predictor.

    Returns whatever the predictor returns when called with `inputs`,
    `targets_template` and `forcings`.
    """
    model = construct_wrapped_graphcast(model_config, task_config)
    predictions = model(
        inputs,
        targets_template=targets_template,
        forcings=forcings,
    )
    return predictions

@hk.transform_with_state
def loss_fn(model_config, task_config, inputs, targets, forcings):
    """Computes the loss and diagnostics with a freshly constructed predictor.

    Every leaf of the `(loss, diagnostics)` structure is reduced with `.mean()`
    and unwrapped to its underlying JAX data.
    """
    def _mean_as_jax(value):
        # Reduce to the mean, then strip the xarray wrapper (JAX data required).
        return xarray_jax.unwrap_data(value.mean(), require_jax=True)

    model = construct_wrapped_graphcast(model_config, task_config)
    loss, diagnostics = model.loss(inputs, targets, forcings)
    return xarray_tree.map_structure(_mean_as_jax, (loss, diagnostics))

def grads_fn(params, state, model_config, task_config, inputs, targets, forcings):
  """Computes the gradients of the loss with respect to the parameters.

  Returns:
    A tuple `(loss, diagnostics, next_state, grads)`.
  """
  def _loss_with_aux(p, s, batch_inputs, batch_targets, batch_forcings):
    # value_and_grad(has_aux=True) expects `(value, aux)`, so the diagnostics
    # and the updated state travel in the auxiliary slot.
    (loss_value, diag), new_state = loss_fn.apply(
        p, s, jax.random.PRNGKey(0), model_config, task_config,
        batch_inputs, batch_targets, batch_forcings)
    return loss_value, (diag, new_state)

  # With default arguments, differentiation is with respect to the first
  # positional argument, i.e. `params`.
  value_and_grad_fn = jax.value_and_grad(_loss_with_aux, has_aux=True)
  (loss, (diagnostics, next_state)), grads = value_and_grad_fn(
      params, state, inputs, targets, forcings)
  return loss, diagnostics, next_state, grads

def with_configs(fn):
  """Binds the module-level model and task configs to `fn` as keyword arguments.

  The configs are read when `with_configs` is called, not when the returned
  callable is invoked.
  """
  bound_configs = {
      "model_config": model_config,
      "task_config": task_config,
  }
  return functools.partial(fn, **bound_configs)

def with_params(fn):
  """Ensures `fn` always receives the module-level params and state.

  Both are bound as keyword arguments, using the values that `params` and
  `state` hold when `with_params` is called.
  """
  bound_weights = {"params": params, "state": state}
  return functools.partial(fn, **bound_weights)

def drop_state(fn):
  """Returns only the predictions; needed by the rollout code.

  The wrapped callable accepts keyword arguments only and returns the first
  element of whatever `fn` returns.
  """
  def _first_output_only(**kwargs):
    outputs = fn(**kwargs)
    return outputs[0]
  return _first_output_only

# Jitted initializer with the model and task configs already bound.
init_jitted = jax.jit(with_configs(run_forward.init))

if params is None:
    # Params and state have not been set yet (no checkpoint loaded), so
    # initialize them. The train_* datasets are never created in this notebook
    # (their extraction is commented out), so referencing them here would raise
    # NameError. The evaluation data is used as the example batch instead.
    params, state = init_jitted(
        rng=jax.random.PRNGKey(0),
        inputs=eval_inputs,
        targets_template=eval_targets,
        forcings=eval_forcings
    )

# JIT-compile the functions with the configs bound, then bind params and state.
# This must come after the initialization above, because with_params captures
# the values of `params` and `state` at the time it is called.
loss_fn_jitted = drop_state(with_params(jax.jit(with_configs(loss_fn.apply))))
grads_fn_jitted = with_params(jax.jit(with_configs(grads_fn)))
run_forward_jitted = drop_state(with_params(jax.jit(with_configs(run_forward.apply))))

In [16]:
# Ensure that the model resolution matches the data resolution.
# The data resolution is computed as 360 divided by the number of longitude
# points; a model resolution of 0 is also accepted.
assert model_config.resolution in (0, 360. / eval_inputs.sizes["lon"]), (
    "Model resolution doesn't match the data resolution. You likely want to "
    "re-filter the dataset list, and download the correct data.")

In [17]:
# Disabled expression, kept for reference:
# eval_targets * np.nan
# List the data variables contained in the evaluation forcings.
eval_forcings.data_vars

Data variables:
    toa_incident_solar_radiation  (batch, time, lat, lon) float32 8MB ...
    year_progress_sin             (batch, time) float32 8B 0.05856 0.06285
    year_progress_cos             (batch, time) float32 8B 0.9983 0.998
    day_progress_sin              (batch, time, lon) float32 12kB 0.0 ... 1.0
    day_progress_cos              (batch, time, lon) float32 12kB 1.0 ... 0.0...

In [18]:
# Print the JAX version and the available devices.
print(jax.__version__)
print(jax.devices())

# NOTE(review): his_utils is already imported in the notebook's first cell, so
# this re-import is redundant.
import his_utils

# Perform autoregressive rollout to generate predictions.
# The targets template and the forcings are built by his_utils helpers rather
# than taken from eval_targets / eval_forcings.
# NOTE(review): the hard-coded time_steps=2, resolution=0.25, pressure_levels=37
# and start_time="2022-01-01" presumably have to match eval_inputs and the
# loaded dataset -- confirm against his_utils.
predictions = rollout.chunked_prediction(
    run_forward_jitted,
    rng=jax.random.PRNGKey(0),
    inputs=eval_inputs,
    targets_template=his_utils.create_target_dataset(time_steps=2, 
                   resolution=0.25, 
                   pressure_levels=37),
    forcings=his_utils.create_forcing_dataset(time_steps=2,
                                              resolution=0.25,
                                              start_time="2022-01-01")
)

# Display the predictions
predictions

0.4.23
[cuda(id=0), cuda(id=1)]
|| data_utils.py -> add_derived_vars() init. ||
|| rollout.py -> chunked_prediction() init. ||
|| rollout.py -> chunked_prediction_generator() init. ||


  num_target_steps = targets_template.dims["time"]
  scan_length = targets_template.dims['time']


|| graphcast.py -> __call__() init ||
latent_mesh_nodes shape: (40962, 1, 512)
updated_latent_mesh_nodes shape: (40962, 1, 512)
output_grid_nodes shape: (1038240, 1, 227)
|| graphcast.py -> __call__() done ||


  num_inputs = inputs.dims['time']
2024-07-18 16:25:08.229067: E external/xla/xla/service/slow_operation_alarm.cc:65] Constant folding an instruction is taking > 1s:

  %pad.1 = bf16[40962,1,474]{2,1,0} pad(bf16[40962,1,3]{2,1,0} %constant.348, bf16[] %constant.349), padding=0_0x0_0x471_0, metadata={op_name="jit(<unnamed wrapped function>)/jit(main)/while/body/concatenate[dimension=2]" source_file="/home/hiskim1/graphcast/graphcast/graphcast.py" source_line=643}

This isn't necessarily a bug; constant-folding is inherently a trade-off between compilation time and speed at runtime. XLA has some guards that attempt to keep constant folding from taking too long, but fundamentally you'll always be able to come up with an input program that takes a long time.

If you'd like to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
2024-07-18 16:25:08.379593: E external/xla/xla/service/slow_operation_alarm.cc:133] The operation took 1.150870224s
Constant folding

|| rollout.py -> chunked_prediction_generator() done ||
|| rollout.py -> chunked_prediction_generator() done ||
|| rollout.py -> chunked_prediction() done ||


In [None]:
# predictions.to_netcdf("testdata/googles_predictions.nc")

In [None]:
# Open two previously saved prediction files for comparison.
# NOTE(review): the only visible writer for "googles_predictions.nc" is the
# commented-out to_netcdf call in the previous cell; nothing in this notebook
# writes "his_predictions.nc" -- both files presumably come from earlier runs.
his_prediction = xarray.open_dataset("testdata/his_predictions.nc")
google_prediction = xarray.open_dataset("testdata/googles_predictions.nc")

In [None]:
# For every variable in his_prediction, print the min, max and mean of
# (his_prediction - google_prediction) at time indices 0 and 1.
for var in his_prediction.data_vars:
    for t in range(2):
        diff = his_prediction[var].isel(time=t) - google_prediction[var].isel(time=t)
        # NOTE(review): the printed line does not include the time index t.
        print(f"{var}: {diff.min().values}, {diff.max().values}, {diff.mean().values}")
    print("----------------")

In [None]:
# Difference of mean sea level pressure between the two prediction sets.
# This rebinds `diff`, which the statistics loop in the previous cell also used.
diff = his_prediction["mean_sea_level_pressure"] - google_prediction["mean_sea_level_pressure"]


# Plot the slice at index 0 of the first dimension, with a symmetric blue/red
# colour scale fixed to [-35, 35].
diff[0,:].plot(cmap="bwr", vmin=-35, vmax=35)

In [None]:

# Same plot for index 1 of the first dimension; relies on `diff` from the
# previous cell.
diff[1,:].plot(cmap="bwr",vmin=-35, vmax=35)

In [53]:
# Inspect the evaluation forcings dataset.
eval_forcings

In [55]:
# Inspect the dimensions of the evaluation targets.
eval_targets.dims

