In [6]:
import dataclasses
import functools
import shutil
import tensorflow_datasets as tfds
import pandas as pd
import tensorflow as tf
import pickle
import numpy as np
from typing import Any
from climsim_utils.data_utils import *
import fnmatch
from etils import etree
import time
import gc
import rich

from google.cloud import storage

2024-07-18 11:28:58.363228: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-07-18 11:28:58.456206: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [4]:
# import concurrent.futures
# from huggingface_hub import HfFileSystem
# import time

# # Initialize HfFileSystem
# fs = HfFileSystem()

# # Function to list files in a single directory
# def list_directory(path):
#     retry_count = 0
#     max_retries = 5
#     while retry_count < max_retries:
#         try:
#             items = fs.ls(path, detail=True)
#             directories = [item['name'] for item in items if item['type'] == 'directory']
#             print(directories)
#             files = [item['name'] for item in items if item['type'] != 'directory']
#             return directories, files
#         except Exception as e:
#             print(f"Error listing {path}: {e}")
#             retry_count += 1
#             time.sleep(2 ** retry_count)  # Exponential backoff
#     return [], []

# # Function to list all files in a dataset repository using multithreading
# def list_all_files_multithreaded(repo_id):
#     stack = [f"datasets/{repo_id}"]
#     all_files = []
#     count = 0
    
#     with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
#         while stack:
#             futures = {executor.submit(list_directory, path): path for path in stack}
#             stack = []
            
#             for future in concurrent.futures.as_completed(futures):
#                 directories, files = future.result()
#                 stack.extend(directories)
#                 all_files.extend(files)
#                 count += len(files)
                
#                 if count % 100 == 0:
#                     print(f"Retrieved {count} files")
    
#     return all_files

# # List files in the dataset repository
# repo_id = "LEAP/ClimSim_low-res"
# file_paths = list_all_files_multithreaded(repo_id)

# # Print the file paths
# # for path in file_paths:
# #     print(path)

# # Save file paths to a pickle file
# import pickle

# with open('file_paths.pkl', 'wb') as f:
#     pickle.dump(file_paths, f)

# # Shorten file paths for easier directory structure inspection
# shorten_fp = ['/'.join(f.split('/')[3:]) for f in file_paths]
# shorten_fp = [f for f in shorten_fp if 'train' in f]
# shorten_fp = sorted(shorten_fp)

# # Print directory structure 
# def print_directory_structure(file_paths):
#     directories = set()
#     for f in file_paths:
#         directory = '/'.join(f.split('/')[:-1])
#         directories.add(directory)
#     directories = sorted(directories)
# print_directory_structure(shorten_fp)

# np.save("all_files_in_low_res.npy", shorten_fp, allow_pickle=True)
# !mv all_files_in_low_res.npy climsim/all_files_in_low_res.npy

In [4]:
import fnmatch
import time
import numpy as np
from huggingface_hub import hf_hub_download

all_file_paths = np.load("climsim/all_files_in_low_res.npy", allow_pickle=True)

input_files = [f for f in all_file_paths if "mli" in f]
output_files = set([f for f in all_file_paths if "mlo" in f])

print(f"Total input files: {len(input_files)}")
print(f"Total output files: {len(output_files)}")

start = time.time()

regexps = [
        "*/*/E3SM-MMF.mli.000[1234567]-*-*-*.nc",  # years 1 through 7
        "*/*/E3SM-MMF.mli.0008-01-*-*.nc",  # first month of year 8
    ]
matched_files = []
for pattern in regexps:
    matched_files.extend(fnmatch.filter(input_files, pattern))

print(f"Total matched files: {len(matched_files)}")
print(f"Matching files took {time.time() - start} seconds")

start = time.time()

filtered_files = [
    ip_f for ip_f in matched_files if ip_f.replace("mli", "mlo") in output_files
]
print(f"Filtering files took {time.time() - start} seconds")

filtered_files = filtered_files[::7]  # Time Stride
filtered_files = np.array(filtered_files)

print(f"Total filtered files: {len(filtered_files)}")

n_splits = 8

total_files = len(filtered_files)
files_per_split = total_files // n_splits
print(f"Files per split: {files_per_split}")

np.random.seed(42)
np.random.shuffle(filtered_files)

splits = np.array_split(filtered_files, n_splits)
splits = [s.tolist() for s in splits]

print(len(splits[0]))

Total input files: 210240
Total output files: 210240
Total matched files: 183960
Matching files took 0.14578962326049805 seconds
Filtering files took 0.031223773956298828 seconds
Total filtered files: 26280
Files per split: 3285
3285


In [7]:
# from build_high_res_ds_local import get_dataset_from_file_names
from typing import Any

def get_dataset_from_file_names(
    train_files,
    repo_id="LEAP/ClimSim_high-res",
) -> dict[str, dict[str, Any]]:

    download_file_fn = functools.partial(
        hf_hub_download,
        repo_id=repo_id,
        repo_type="dataset",
    )

    def download_file(filename: str) -> str:
        try:
            return download_file_fn(filename=filename)
        except Exception as e:
            print(f"Failed to download {filename} with error {e}")
            return None

    download_paths = etree.parallel_map(download_file, train_files)
    download_paths = [x for x in download_paths if x is not None]
    file_paths = dict(zip(train_files, download_paths))
    data_path = "/".join(file_paths[train_files[0]].split("/")[:-3])
    return file_paths, data_path


train_files = splits[0][:3]

train_files.extend([f.replace("mli", "mlo") for f in train_files])
train_files.append("ClimSim_low-res_grid-info.nc")

file_paths, data_path = get_dataset_from_file_names(
    train_files, repo_id="LEAP/ClimSim_low-res"
)

file_paths, data_path

({'train/0006-06/E3SM-MMF.mli.0006-06-04-37200.nc': '/home/joylunkad/.cache/huggingface/hub/datasets--LEAP--ClimSim_low-res/snapshots/bab82a2ebdc750a0134ddcd0d5813867b92eed2a/train/0006-06/E3SM-MMF.mli.0006-06-04-37200.nc',
  'train/0002-10/E3SM-MMF.mli.0002-10-05-37200.nc': '/home/joylunkad/.cache/huggingface/hub/datasets--LEAP--ClimSim_low-res/snapshots/bab82a2ebdc750a0134ddcd0d5813867b92eed2a/train/0002-10/E3SM-MMF.mli.0002-10-05-37200.nc',
  'train/0003-01/E3SM-MMF.mli.0003-01-24-14400.nc': '/home/joylunkad/.cache/huggingface/hub/datasets--LEAP--ClimSim_low-res/snapshots/bab82a2ebdc750a0134ddcd0d5813867b92eed2a/train/0003-01/E3SM-MMF.mli.0003-01-24-14400.nc',
  'train/0006-06/E3SM-MMF.mlo.0006-06-04-37200.nc': '/home/joylunkad/.cache/huggingface/hub/datasets--LEAP--ClimSim_low-res/snapshots/bab82a2ebdc750a0134ddcd0d5813867b92eed2a/train/0006-06/E3SM-MMF.mlo.0006-06-04-37200.nc',
  'train/0002-10/E3SM-MMF.mlo.0002-10-05-37200.nc': '/home/joylunkad/.cache/huggingface/hub/datasets--LE

In [40]:
import xarray as xr

norm_path = "climsim/preprocessing/normalizations/"
input_mean = xr.open_dataset(norm_path + "inputs/input_mean.nc")
input_max = xr.open_dataset(norm_path + "inputs/input_max.nc")
input_min = xr.open_dataset(norm_path + "inputs/input_min.nc")
output_scale = xr.open_dataset(norm_path + "outputs/output_scale.nc")

grid_path = os.path.join(data_path, "ClimSim_low-res_grid-info.nc")
grid_info = xr.open_dataset(grid_path)
train_col_names = np.load("train_col_names.npy").tolist()

data = data_utils(
    grid_info=grid_info,
    input_mean=input_mean,
    input_max=input_max,
    input_min=input_min,
    output_scale=output_scale,
)

data.set_to_v2_vars()

data.normalize = False
data.data_path = f"{data_path}/*/"

data.set_regexps(
    data_split="train",
    regexps=[
        "E3SM-MMF.mli.000[1234567]-*-*-*.nc",  # years 1 through 7
        "E3SM-MMF.mli.0008-01-*-*.nc",  # first month of year 8
    ],
)

data.set_stride_sample(data_split="train", stride_sample=1)
data.set_filelist(data_split="train")
data_loader = data.load_ncdata_with_generator(data_split="train")
npy_iterator = list(data_loader.as_numpy_iterator())

# filelist = np.array(
#     [npy_iterator[x][2] for x in range(len(npy_iterator))]
# ).flatten()
# filelist = [x.decode() for x in filelist]
# file_ids = [f.split("/")[-1].split(".")[-2] for f in filelist]

# train_index = [
#     [
#         f"train_{file_ids[f_idx]}_{str(x)}"
#         for x in range(len(npy_iterator[f_idx][0]))
#     ]
#     for f_idx in range(len(file_ids))
# ]
# train_index = np.concatenate(train_index)

# npy_input = np.concatenate([npy_iterator[x][0] for x in range(len(npy_iterator))])
# print("dropping cam_in_SNOWHICE because of strange values")
# drop_idx = train_col_names.index("cam_in_SNOWHICE")
# npy_input = np.delete(npy_input, drop_idx, axis=1)
# npy_output = np.concatenate([npy_iterator[x][1] for x in range(len(npy_iterator))])

2024-07-18 12:07:10.442964: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:984] could not open file to read NUMA node: /sys/bus/pci/devices/0000:03:00.0/numa_node
Your kernel may have been built without NUMA support.
2024-07-18 12:07:13.183600: W tensorflow/core/common_runtime/gpu/gpu_device.cc:2251] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://www.tensorflow.org/install/gpu for how to download and setup the required libraries for your platform.
Skipping registering GPU devices...
2024-07-18 12:07:14.553647: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


In [53]:
npy_input = np.concatenate([npy_iterator[x][0] for x in range(len(npy_iterator))])
npy_output = np.concatenate([npy_iterator[x][1] for x in range(len(npy_iterator))])

grid_ids = np.arange(0, 384)
grid_ids = np.broadcast_to(grid_ids, (len(npy_iterator), 384)).flatten()

npy_input.shape, npy_output.shape, grid_ids.shape

((1152, 557), (1152, 368), (1152,))

In [51]:
npy_input = np.concatenate([npy_iterator[x][0] for x in range(len(npy_iterator))])
npy_output = np.concatenate([npy_iterator[x][1] for x in range(len(npy_iterator))])

npy_input.shape, npy_output.shape

((1152, 557), (1152, 368))

In [9]:
from dataclasses import dataclass, field, asdict

@dataclass
class Test:
    a: int = 1
    b: int = 2
    c: int = field(init=False)

    def __post_init__(self):
        self.c = self.a + self.b

test = Test()
test

Test(a=1, b=2, c=3)

In [10]:
asdict(test)

{'a': 1, 'b': 2, 'c': 3}

In [14]:
import jax
import jax.numpy as jnp
import numpy as np

a = np.arange(8)
np.random.shuffle(a)

emb = np.random.normal(0, 1, size=(384, 12))

a, emb, emb[a], a.shape, emb.shape

(array([1, 6, 0, 7, 5, 4, 2, 3]),
 array([[ 0.75371362,  1.28111299, -1.37435467, ...,  1.53462682,
          1.03634234, -1.33140981],
        [ 2.01270494,  0.76469265, -2.29979263, ...,  0.27500362,
          0.6303607 , -0.11230725],
        [-0.23218852,  1.22245143,  0.04383246, ...,  0.54443714,
         -1.00727361,  0.42453657],
        ...,
        [ 0.39611081, -0.22550056,  0.63230307, ..., -1.00806094,
         -1.03804627,  0.48448222],
        [-0.65697366, -0.38875687,  0.9545601 , ...,  1.34281384,
         -1.09468071, -1.00750662],
        [-0.01321841,  0.37577535, -0.62182652, ..., -1.03458696,
          1.06802559, -0.30713247]]),
 array([[ 2.01270494e+00,  7.64692655e-01, -2.29979263e+00,
          3.35268690e-01, -3.11094517e-01, -1.28600009e+00,
          1.29145402e+00,  1.04409942e+00,  5.74793761e-01,
          2.75003619e-01,  6.30360704e-01, -1.12307253e-01],
        [ 3.72432391e-01, -3.91675625e-01, -2.85045704e+00,
         -1.77669907e+00,  4.64291424e