In [None]:
import dataclasses
import datetime
import functools
import math
import re
from typing import Optional

import cartopy.crs as ccrs
from google.cloud import storage
from graphcast import autoregressive
from graphcast import casting
from graphcast import checkpoint
from graphcast import data_utils
from graphcast import graphcast
from graphcast import normalization
from graphcast import rollout
from graphcast import xarray_jax
from graphcast import xarray_tree
from IPython.display import HTML
import ipywidgets as widgets
import haiku as hk
import jax
import matplotlib
import matplotlib.pyplot as plt
from matplotlib import animation
import numpy as np
import xarray
import his_utils

def parse_file_parts(file_name):
  """Parse a dataset file name of the form `key-value_key-value_...` into a dict.

  Only the final path component is parsed, so a name that carries a directory
  prefix (for example "testdata/2022-01-01/source-era5_res-0.25") yields the
  same keys as the bare file name instead of polluting the first key with the
  directory.

  Args:
    file_name: File name, optionally preceded by "/"-separated directories.
      Any extension such as ".nc" must already be stripped by the caller.

  Returns:
    A dict mapping each key to its string value. Each "_"-separated part is
    split on its first "-" only, so values may themselves contain "-".

  Raises:
    ValueError: If a part contains no "-" and cannot form a key/value pair.
  """
  # Drop any directory components; only the base name encodes the key/value parts.
  base_name = file_name.rsplit("/", 1)[-1]
  return dict(part.split("-", 1) for part in base_name.split("_"))

# Anonymous client: no credentials are supplied when talking to Cloud Storage.
gcs_client = storage.Client.create_anonymous_client()
# Handle to the "dm_graphcast" bucket; later cells read params/ and stats/ blobs from it.
gcs_bucket = gcs_client.get_bucket("dm_graphcast")
# Bare last expression so the notebook displays the bucket object.
gcs_bucket

### Set the model and properties

In [None]:
# initialization
# TODO: change the values here. They do not matter when a pretrained model is loaded,
# because the checkpoint cell below overwrites params, model_config and task_config.
# NOTE(review): mesh_size and latent_size below are outside the ranges given in
# their own trailing comments - confirm before using the random-init path.
resolution = 0 # 0.25 or 1.0
mesh_size = 1000 # 4~6
latent_size = 100 # 2^4 ~ 2^9
gnn_msg_steps = 10 # 1~32
pressure_levels = 37 # 13, 25, 37
hidden_layers = 1 
radius_query_fraction_edge_length = 0.6 # 1 is also possible
# params stays None until a checkpoint is loaded or random weights are initialised.
params = None
state = {}

In [None]:
# Build the model and task configurations from the hyper-parameters set above.
model_config = graphcast.ModelConfig(
        resolution=resolution, 
        mesh_size=mesh_size,
        latent_size=latent_size,
        gnn_msg_steps=gnn_msg_steps,
        hidden_layers=hidden_layers,
        radius_query_fraction_edge_length=radius_query_fraction_edge_length    
    )

# Inputs are the targets (surface + atmospheric) plus forcing and static variables.
task_config = graphcast.TaskConfig(
        input_variables=(graphcast.TARGET_SURFACE_VARS + graphcast.TARGET_ATMOSPHERIC_VARS + graphcast.FORCING_VARS +
        graphcast.STATIC_VARS),
        target_variables=graphcast.TARGET_SURFACE_VARS + graphcast.TARGET_ATMOSPHERIC_VARS,
        forcing_variables=graphcast.FORCING_VARS,
        # PRESSURE_LEVELS is looked up by the `pressure_levels` value chosen above.
        pressure_levels=graphcast.PRESSURE_LEVELS[pressure_levels],
        input_duration="12h"
    )

# Bare last expression so the notebook displays the resulting config.
model_config

In [None]:
# Prerequisites: file names of the pretrained checkpoints, read from params/ in the bucket.

GC_original = 'GraphCast - ERA5 1979-2017 - resolution 0.25 - pressure levels 37 - mesh 2to6 - precipitation input and output.npz'
GC_operational = 'GraphCast_operational - ERA5-HRES 1979-2021 - resolution 0.25 - pressure levels 13 - mesh 2to6 - precipitation output only.npz'
GC_small = 'GraphCast_small - ERA5 1979-2015 - resolution 1.0 - pressure levels 13 - mesh 2to5 - precipitation input and output.npz'

# TODO: pick one model to load
pretrained_model = GC_original

with gcs_bucket.blob(f"params/{pretrained_model}").open("rb") as f:
    ckpt = checkpoint.load(f, graphcast.CheckPoint)

    
params = ckpt.params  # take the parameters from the loaded checkpoint
state = {}  # the initial state is an empty dict

# Take the model configuration from the checkpoint (replaces the one built by hand above)
model_config = ckpt.model_config
# Take the task configuration from the checkpoint (replaces the one built by hand above)
task_config = ckpt.task_config

# Print the model description
print("Model description:\n", ckpt.description, "\n")
# Print the model license
print("Model license:\n", ckpt.license, "\n")

print("Model config:\n", model_config, "\n")

print("Task config:\n", task_config)

# Check whether a dataset file is compatible with a model/task configuration
def data_valid_for_model(
    file_name: str, 
    model_config: graphcast.ModelConfig, 
    task_config: graphcast.TaskConfig):
    """Return True when the dataset file name matches the model and task configs.

    The file name (with a trailing ".nc" removed) is parsed by `parse_file_parts`
    and its "res", "levels" and "source" parts are compared against the configs.

    Args:
        file_name: Dataset file name encoding `res-`, `levels-` and `source-` parts.
        model_config: Model configuration; a resolution of 0 matches any file.
        task_config: Task configuration providing pressure levels and input variables.

    Returns:
        True if resolution, number of pressure levels and data source are all
        compatible, otherwise False.
    """
    parts = parse_file_parts(file_name.removesuffix(".nc"))

    # The model resolution must be 0 or equal to the file's resolution.
    if model_config.resolution not in (0, float(parts["res"])):
        return False

    # The number of pressure levels must match the file's level count.
    if len(task_config.pressure_levels) != int(parts["levels"]):
        return False

    # Models taking precipitation as an input need "era5" (or "fake") data;
    # models without it need "hres" (or "fake") data.
    takes_precipitation = "total_precipitation_6hr" in task_config.input_variables
    if takes_precipitation:
        allowed_sources = ("era5", "fake")
    else:
        allowed_sources = ("hres", "fake")
    return parts["source"] in allowed_sources

### Load Data

In [None]:
# load data

file_name = 'testdata/2022-01-01/source-era5_date-2022-01-01_res-0.25_levels-37_steps-12.nc'

dataset = xarray.open_dataset(file_name)

# Report whether the TOA solar-radiation forcing is present in the file.
# The variable is named "toa_incident_solar_radiation" (the name the plotting
# cell drops from this kind of file); the previous spelling
# "TOA_solar_incident_radiation" does not match that name.
print("toa_incident_solar_radiation" in dataset.data_vars)

# Toggle whether TOA radiation is kept. drop_vars returns a new dataset, so the
# result must be assigned back for the drop to take effect:
# dataset = dataset.drop_vars("toa_incident_solar_radiation")


In [None]:
# Sanity checks (currently disabled)

# Check that the dataset file is valid for the model configuration
# if not data_valid_for_model(file_name, model_config, task_config):
    # raise ValueError(
        # "Invalid dataset file, rerun the cell above and choose a valid dataset file.")
# 
# Check that the time dimension has at least 3 steps (2 for input, at least 1 for targets)
# assert dataset.dims["time"] >= 3  # 2 for input, >=1 for targets


# Print the information encoded in the selected dataset file name
# print(", ".join([f"{k}: {v}" for k, v in parse_file_parts(file_name.removesuffix(".nc")).items()]))

# Display the loaded dataset
# dataset

In [None]:
# Number of time steps in the dataset minus 2 (two steps are used as inputs,
# per the check in the cell above), i.e. the steps left over for targets.
total_time_steps = dataset.sizes["time"] - 2

# Bare last expression so the notebook displays the value.
total_time_steps

In [None]:
# split dataset
# train_steps is 0, so no training data is extracted (the train extraction
# below is commented out).
train_steps = 0
eval_steps = 12

# Check that train_steps and eval_steps are within the valid range (currently disabled)
# assert 1 <= train_steps <= total_time_steps, f"train_steps must be between 1 and {total_time_steps}."

# assert 1 <= eval_steps <= total_time_steps, f"eval_steps must be between 1 and {total_time_steps}."

In [None]:
# Training split is disabled; train_inputs / train_targets / train_forcings are
# therefore NOT defined by this cell.
# train_inputs, train_targets, train_forcings = data_utils.extract_inputs_targets_forcings(
#     dataset, 
#     target_lead_times=slice("6h", f"{train_steps * 6}h"), 
#     **dataclasses.asdict(task_config)
# )

# Split the dataset into inputs, targets and forcings for evaluation.
# Target lead times run from 6h to eval_steps * 6h.
eval_inputs, eval_targets, eval_forcings = data_utils.extract_inputs_targets_forcings(
    dataset, 
    target_lead_times=slice("6h", f"{eval_steps * 6}h"), 
    **dataclasses.asdict(task_config)
)

# 실행

데이터 뽑기는 여기 앞까지

In [None]:
# Print the dimensions of the example batch and the extracted evaluation data.
# `.sizes` is used (as elsewhere in this notebook) instead of `.dims.mapping`:
# it is the documented name -> length mapping, whereas the mapping form of
# `Dataset.dims` is being deprecated by xarray.
print("All Examples:  ", dict(dataset.sizes))
# print("----------------------------------------------")
# print("Train Inputs:  ", dict(train_inputs.sizes))
# print("Train Targets: ", dict(train_targets.sizes))
# print("Train Forcings:", dict(train_forcings.sizes))
print("----------------------------------------------")
print("Eval Inputs:   ", dict(eval_inputs.sizes))
print("Eval Targets:  ", dict(eval_targets.sizes))
print("Eval Forcings: ", dict(eval_forcings.sizes))
print("----------------------------------------------")
# Full text representation of each evaluation dataset.
print("Eval Inputs:   \n", eval_inputs)
print("Eval Targets:  \n", eval_targets)
print("Eval Forcings: \n", eval_forcings)

In [None]:
# @title Load normalization data

def _load_stats(blob_name):
    """Read one statistics file from the GCS bucket and load it into memory."""
    with gcs_bucket.blob(blob_name).open("rb") as stats_file:
        # load_dataset reads the data eagerly, so it is safe to close the file afterwards.
        return xarray.load_dataset(stats_file).compute()


# Standard deviations of differences, by level
diffs_stddev_by_level = _load_stats("stats/diffs_stddev_by_level.nc")

# Mean values, by level
mean_by_level = _load_stats("stats/mean_by_level.nc")

# Standard deviations, by level
stddev_by_level = _load_stats("stats/stddev_by_level.nc")


In [None]:
# Build the JIT-compiled functions and initialise random weights if needed

def construct_wrapped_graphcast(
    model_config: graphcast.ModelConfig,
    task_config: graphcast.TaskConfig):
  """Construct the GraphCast predictor and apply its wrappers.

  The wrappers are applied from the inside out: the GraphCast model, a bfloat16
  cast, input/residual normalization using the by-level statistics loaded
  earlier, and finally an autoregressive predictor with gradient checkpointing.

  Args:
    model_config: Model configuration passed to `graphcast.GraphCast`.
    task_config: Task configuration passed to `graphcast.GraphCast`.

  Returns:
    The outermost `autoregressive.Predictor` wrapping the model.
  """
  core_model = graphcast.GraphCast(model_config, task_config)
  bf16_model = casting.Bfloat16Cast(core_model)
  normalized_model = normalization.InputsAndResiduals(
      bf16_model,
      diffs_stddev_by_level=diffs_stddev_by_level,
      mean_by_level=mean_by_level,
      stddev_by_level=stddev_by_level)
  return autoregressive.Predictor(normalized_model, gradient_checkpointing=True)

@hk.transform_with_state
def run_forward(model_config, task_config, inputs, targets_template, forcings):
    """Run a forward pass through the wrapped predictor.

    Builds the predictor with `construct_wrapped_graphcast` and calls it on
    `inputs`, passing `targets_template` and `forcings` by keyword.
    """
    wrapped_model = construct_wrapped_graphcast(model_config, task_config)
    predictions = wrapped_model(
        inputs, targets_template=targets_template, forcings=forcings)
    return predictions

@hk.transform_with_state
def loss_fn(model_config, task_config, inputs, targets, forcings):
    """Compute the loss and diagnostics with the wrapped predictor.

    Every leaf of `(loss, diagnostics)` is reduced with `.mean()` and unwrapped
    to its underlying JAX data.
    """
    def _mean_as_jax(leaf):
        # Reduce with mean, then strip the xarray wrapper (must hold JAX data).
        return xarray_jax.unwrap_data(leaf.mean(), require_jax=True)

    wrapped_model = construct_wrapped_graphcast(model_config, task_config)
    loss_value, diag = wrapped_model.loss(inputs, targets, forcings)
    return xarray_tree.map_structure(_mean_as_jax, (loss_value, diag))

def grads_fn(params, state, model_config, task_config, inputs, targets, forcings):
  """Compute the gradient of the loss with respect to the parameters.

  Returns:
    A tuple `(loss, diagnostics, next_state, grads)`.
  """
  def _loss_with_aux(p, s, batch_inputs, batch_targets, batch_forcings):
    # loss_fn.apply returns ((loss, diagnostics), next_state); regroup so the
    # loss is the differentiated value and everything else is auxiliary output.
    outputs, new_state = loss_fn.apply(
        p, s, jax.random.PRNGKey(0), model_config, task_config,
        batch_inputs, batch_targets, batch_forcings)
    loss_value, diag = outputs
    return loss_value, (diag, new_state)

  value_and_grad = jax.value_and_grad(_loss_with_aux, has_aux=True)
  (loss, aux), grads = value_and_grad(params, state, inputs, targets, forcings)
  diagnostics, next_state = aux
  return loss, diagnostics, next_state, grads

def with_configs(fn):
  """Bind the notebook's current model_config and task_config to `fn` as keyword arguments."""
  bound_configs = {"model_config": model_config, "task_config": task_config}
  return functools.partial(fn, **bound_configs)

def with_params(fn):
  """Bind the notebook's current params and state to `fn` as keyword arguments."""
  bound_weights = {"params": params, "state": state}
  return functools.partial(fn, **bound_weights)

def drop_state(fn):
  """Wrap `fn` so that only the first element of its result is returned.

  The wrapper accepts keyword arguments only. Needed by the rollout code,
  which expects predictions without the accompanying state.
  """
  def _first_output_only(**kw):
    result = fn(**kw)
    return result[0]
  return _first_output_only

init_jitted = jax.jit(with_configs(run_forward.init))

if params is None:
    # Parameters and state have not been set yet (no checkpoint loaded), so
    # initialise them randomly. The evaluation data is used as the example
    # batch because the training split is not extracted in this notebook;
    # referencing train_inputs / train_targets / train_forcings here raised a
    # NameError.
    params, state = init_jitted(
        rng=jax.random.PRNGKey(0),
        inputs=eval_inputs,
        targets_template=eval_targets,
        forcings=eval_forcings
    )

# JIT-compile the functions with the configs bound, then bind params/state.
# drop_state keeps only the first element of each result, discarding the state.
loss_fn_jitted = drop_state(with_params(jax.jit(with_configs(loss_fn.apply))))
grads_fn_jitted = with_params(jax.jit(with_configs(grads_fn)))
run_forward_jitted = drop_state(with_params(jax.jit(with_configs(run_forward.apply))))

In [None]:
# Ensure that the model resolution matches the data resolution.
# The data resolution is derived as 360 degrees divided by the number of
# longitude points; a model resolution of 0 is accepted for any data.
assert model_config.resolution in (0, 360. / eval_inputs.sizes["lon"]), (
    "Model resolution doesn't match the data resolution. You likely want to "
    "re-filter the dataset list, and download the correct data.")

In [None]:
# Build the targets template passed to the rollout below via the project-local
# his_utils helper, sized by eval_steps, the model resolution and the number of
# pressure levels.
# NOTE(review): his_utils is not visible here - presumably this returns an
# empty dataset shaped like eval_targets; confirm against the helper.
target_template = his_utils.create_target_dataset(time_steps=eval_steps, 
                   resolution=model_config.resolution, 
                   pressure_levels=len(task_config.pressure_levels))

In [None]:
# Build the forcings passed to the rollout below via the project-local
# his_utils helper.
# start_time now matches the date of the dataset loaded above (2022-01-01);
# it was previously "2021-01-01", a different year from the inputs.
# NOTE(review): his_utils is not visible here - confirm whether start_time is
# meant to be the first input time or the first target time.
forcings = his_utils.create_forcing_dataset(time_steps=eval_steps,
                                              resolution=model_config.resolution,
                                              start_time="2022-01-01")

In [None]:
# Show the JAX version and the devices the rollout will run on.
print(jax.__version__)
print(jax.devices())

# NOTE(review): his_utils is already imported in the first cell and is not
# used in this cell; this re-import looks redundant.
import his_utils

# Perform autoregressive rollout to generate predictions.
# The targets template and forcings are the ones built with his_utils in the
# cells above, not eval_targets / eval_forcings extracted from the dataset.
predictions = rollout.chunked_prediction(
    run_forward_jitted,
    rng=jax.random.PRNGKey(0),
    inputs=eval_inputs,
    targets_template=target_template,
    forcings=forcings
)

# Display the predictions
predictions

In [None]:
# Save the rollout predictions to a NetCDF file in the working directory.
predictions.to_netcdf("predictions_google_2022-01-01T00h_12step.nc")

# Plot

이쁘게 포장하는 공정은 여기부터

## 움짤 만들기

In [None]:
from PIL import Image
from pathlib import Path
import re
import os
import his_utils

# Path of the 'figure' directory.
figure_dir = 'figure'

# # Regular expression extracting the date and time from a file name
# pattern = r'polar_GC_temperature_(\d{4}-\d{2}-\d{2}T\d{2}:\d{2}).png'

# # Build a sorted list of the dates and times extracted from the file names.
# date_list = sorted([re.search(pattern, f).group(1) for f in os.listdir(figure_dir) if f.startswith('polar_GC_temperature_') and f.endswith('.png')])

# Load frames 0..13 in order. Each file is opened in a `with` block and copied
# into memory so the file handle is closed immediately instead of staying open
# for the rest of the kernel session.
image_frames = []
for frame_index in range(0, 14):
    frame_path = os.path.join(figure_dir, f'total_precipitation_6hrRR_{frame_index}.png')
    with Image.open(frame_path) as img:
        image_frames.append(img.copy())

# Save the GIF.
his_utils.save_gif(image_frames, 'tp_6hr diff.gif', duration=700)

In [None]:
from multiprocessing import Pool
from functools import partial

# Variables to animate. Named `var_names` rather than `list` so the builtin
# `list` is not shadowed for the rest of the kernel session.
var_names = [
 'total_precipitation_6hr']

# Frame suffixes used in the figure file names. `date_list` was not defined
# anywhere on a fresh kernel (its only definition is commented out above).
# NOTE(review): 0..13 mirrors the frame indices used by the other GIF and
# plotting cells - confirm this matches the files in figure/.
date_list = range(0, 14)

def process_var(var, date_list, duration):
    """Assemble the per-frame PNGs of one variable into a GIF.

    Args:
        var: Variable name; frames are read from `figure/{var}_{date}.png`.
        date_list: Iterable of frame suffixes, in playback order.
        duration: Duration value forwarded to `his_utils.save_gif`.
    """
    image_frames = []
    for date in date_list:
        # Copy each frame into memory so the file handle can be closed right away.
        with Image.open(os.path.join("figure/", f'{var}_{date}.png')) as img:
            image_frames.append(img.copy())
    
    his_utils.save_gif(image_frames, f'2021-01-01_{var}.gif', duration=duration)

process_var_partial = partial(process_var, date_list=date_list, duration=700)

# Create a multiprocessing pool and process each variable in a worker
with Pool() as pool:
    pool.map(process_var_partial, var_names)

# 연습장


In [None]:
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import xarray as xr
import numpy as np
from matplotlib.colors import TwoSlopeNorm
import matplotlib.pyplot as plt
from multiprocessing import Pool

# Datasets compared in the plots below.
ERA5 = xr.open_dataset(f'testdata/2022-01-01RR.nc')
# NOTE(review): this path has no "2022-01-01/" sub-directory, unlike the path
# used in the data-loading cell - confirm which location is correct.
# The TOA solar-radiation variable is dropped right after opening.
google = xr.open_dataset('testdata/source-era5_date-2022-01-01_res-0.25_levels-37_steps-12.nc').drop_vars('toa_incident_solar_radiation')

def plot(args):
    """Render one time step of a field on a global map and save it as a PNG.

    Takes a single tuple so it can be used directly with `Pool.map`.

    Args:
        args: Tuple `(dataset, target_var, time_index)` where `dataset` is the
            field to plot (with `time`, `lat` and `lon`), `target_var` is the
            label used in the title and file name, and `time_index` is the
            integer position along `time`.

    The figure is written to `figure/{flag} {target_var}_{time_index}.png`,
    where `flag` is True when the data has a `batch` coordinate; this keeps
    the outputs of the two sources from overwriting each other.
    """
    dataset, target_var, time_index = args

    fig = plt.figure(figsize=(20, 10))
    ax = fig.add_subplot(1, 1, 1, projection=ccrs.PlateCarree(central_longitude=180))

    # Select the requested frame first, then scale it, so only one time step is
    # multiplied instead of the whole array on every call.
    # The factor 1000 presumably converts metres to millimetres - TODO confirm units.
    frame = dataset.isel(time=time_index).squeeze() * 1000

    ax.pcolormesh(frame.lon, frame.lat, frame,
                  transform=ccrs.PlateCarree(),
                  cmap='Blues',
                  norm=TwoSlopeNorm(vmin=0, vcenter=0.5, vmax=20),
                  shading='auto')

    ax.add_feature(cfeature.COASTLINE)
    ax.add_feature(cfeature.BORDERS)
    ax.gridlines(draw_labels=True)

    # True when the data carries a `batch` coordinate; used to tell the sources apart.
    flag = 'batch' in dataset.coords

    ax.set_title(f'{flag} {target_var}\nTime: {time_index}')
    ax.set_global()

    fig.savefig(f'figure/{flag} {target_var}_{time_index}.png', dpi=300, bbox_inches='tight')
    # Close this specific figure so worker processes do not accumulate open figures.
    plt.close(fig)


# Render frames 0..13 of the precipitation field for each source in turn,
# using a fresh worker pool per source.
for source in (ERA5, google):
    with Pool() as pool:
        args_list = [
            (source['total_precipitation_6hr'], 'total_precipitation_6hr', time_index)
            for time_index in range(0, 14)
        ]
        pool.map(plot, args_list)