In [None]:
!pip install pyspark
!pip install xarray

In [2]:
import xarray as xr
import pandas as pd
import numpy as np
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, last
from pyspark.sql.window import Window
from tqdm import tqdm
from pyspark.ml.feature import VectorAssembler, RobustScaler
from pyspark.ml import Pipeline
import sys
from pyspark.ml.linalg import Vectors, DenseVector, VectorUDT
import pyspark.sql.functions as F
from pyspark.sql.functions import udf, struct
from pyspark.sql.functions import col as Fcol

from pyspark.ml.feature import SQLTransformer


In [3]:
def get_chunk_as_dataframe(dataset, chunk_index):
    """
    Selects a specified chunk along the 'time' dimension and converts it to a DataFrame.

    Parameters
    ----------
    dataset
        Object exposing ``chunks['time']`` (a sequence of chunk lengths),
        ``isel(time=slice)`` and, on the selection, ``to_dataframe()``.
    chunk_index : int
        Position of the chunk along 'time'. Negative values count from the
        end, following normal Python indexing (``-1`` is the last chunk).

    Returns
    -------
    The result of ``to_dataframe().reset_index()`` on the selected chunk.

    Raises
    ------
    IndexError
        If ``chunk_index`` lies outside ``[-num_chunks, num_chunks)``.
    """
    time_chunks = dataset.chunks['time']
    num_chunks = len(time_chunks)

    # Validate both bounds so that too-negative indices produce the same
    # descriptive error as too-large ones instead of a bare tuple IndexError.
    if not -num_chunks <= chunk_index < num_chunks:
        raise IndexError(f"Chunk index {chunk_index} out of bounds for axis 'time' with {num_chunks} chunks.")

    # Normalise negative indices to their non-negative equivalent.
    if chunk_index < 0:
        chunk_index += num_chunks

    # The chunk starts after all preceding chunks and spans its own length.
    start_idx = sum(time_chunks[:chunk_index])
    end_idx = start_idx + time_chunks[chunk_index]
    dataset_chunk = dataset.isel(time=slice(start_idx, end_idx))
    print(dataset_chunk)

    return dataset_chunk.to_dataframe().reset_index()

In [None]:
# Open the NetCDF file with the 'time' dimension split into chunks of 100,
# then materialise the first chunk (index 0) as a pandas DataFrame.
file_path = "/kaggle/input/era5-82-23-three-hours/adaptor.mars.internal-1714934325.815979-26713-14-d4ecb07d-ed0d-4eca-940f-48bdda8774ae.nc"
dataset = xr.open_dataset(file_path, chunks=dict(time=100))
chunk_df = get_chunk_as_dataframe(dataset, chunk_index=0)

In [None]:
# Persist the chunk as CSV (without the pandas index) so Spark can read it back.
chunk_df.to_csv(path_or_buf="spark_chunk.csv", index=False)

In [None]:
# Create (or reuse) a Spark session running against the "local" master.
spark = (
    SparkSession.builder
    .master("local")
    .appName("NetCDF")
    .getOrCreate()
)

# Load the chunk from the CSV file instead of converting the pandas frame
# directly with spark.createDataFrame(chunk_df); the first line is treated as
# the header and column types are inferred from the data.
chunk_sparkdf = spark.read.csv("spark_chunk.csv", header=True, inferSchema=True)


In [None]:
# Drop features, then fill remaining nulls per grid point along time.

# Rows whose `sst` value is null are discarded.
chunk_sparkdf = chunk_sparkdf.dropna(subset=["sst"])


features_to_remove = [
    'cdir', 'msdrswrf', 'msdrswrfcs', 'msdwswrf', 'msdwswrfcs', 'msdwuvrf',
    'msnswrf', 'msnswrfcs', 'mtdwswrf', 'mtnswrf', 'mtnswrfcs', 'ssr', 'ssrc',
    'ssrdc', 'ssrd', 'tsr', 'tsrc', 'fdir'
]
hand_selected_features = ['mwp', 'pp1d', 'mwd', 'swh', 'expver', 'siconc']

chunk_sparkdf = chunk_sparkdf.drop(*features_to_remove, *hand_selected_features)

# Count the nulls of every column in one aggregation (a single Spark job)
# instead of launching a separate filter/count job per column.
null_counts = chunk_sparkdf.select(
    [
        F.count(F.when(chunk_sparkdf[name].isNull(), 1)).alias(name)
        for name in chunk_sparkdf.columns
    ]
).first()
NaN_cols = [name for name in chunk_sparkdf.columns if null_counts[name] > 0]

# Frame covering the current row and everything after it within one
# (latitude, longitude) series ordered by time. The spec does not depend on
# the column being filled, so it is built once outside the loop.
fill_window = (
    Window.partitionBy("latitude", "longitude")
    .orderBy("time")
    .rowsBetween(Window.currentRow, Window.unboundedFollowing)
)

# Backward fill: take the FIRST non-null value at or after the current row.
# Rows that already hold a value keep it; null rows receive the next
# available observation. (Using `last` over this frame would instead replace
# every row with the final non-null value of the remaining series.)
# NOTE(review): if a forward fill was intended, use `last(..., ignorenulls=True)`
# over rowsBetween(Window.unboundedPreceding, Window.currentRow) — confirm.
# The loop variable is deliberately not called `col`, so the `col` function
# imported from pyspark.sql.functions is not shadowed for later cells.
for nan_col in tqdm(NaN_cols):
    chunk_sparkdf = chunk_sparkdf.withColumn(
        nan_col,
        F.first(nan_col, ignorenulls=True).over(fill_window),
    )

