In [210]:
import pathlib

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pickle
import torch
import xarray as xr

from sklearn.cluster import KMeans
from torch.utils.data import Dataset

# Symmetric Log Transformation
def symmetric_log(x, C):
    return np.sign(x) * np.log1p(np.abs(x) + C)

# Inverse of the symmetric log
def inverse_symmetric_log(y, C):
    return np.sign(y) * (np.exp(np.sign(y) * y) - C - 1)

class MappableDataset(Dataset):
    def __init__(self, features, target):
        self.features = features
        self.target = target

    def __len__(self):
        return len(self.target)

    def __getitem__(self, idx):
        return torch.tensor(self.features[idx], dtype=torch.float32), torch.tensor(self.target[idx], dtype=torch.float32)

class EKE_Dataset(MappableDataset):
    def __init__(self, file_path, do_normalization=True):
        self.do_normalization = do_normalization

        ds = xr.open_dataset(file_path)
        # TODO: Check to see if we should be dynamically changing this
        self.C = self._compute_C(ds.RV_vert_avg.values.flatten())

        vars = ["KE_vert_sum", "RV_vert_avg", "slope_vert_avg", "Rd_dx_scaled"]
        # Extract features & target, flattening them to 1D vectors
        features = np.stack(
            [ds[var].values.flatten() for var in vars],
            axis=1
        )
        target = np.log1p(ds['EKE'].values.flatten())  # Log transform target

        # **Filter out samples where ln(EKE) < 0**
        valid_indices = target > 0
        super().__init__(features[valid_indices,:], target[valid_indices])

        # Compute mean & std for standardization (across dataset)
        self.mean = self.features.mean(axis=0)
        self.std = self.features.std(axis=0)
        self.features = self.transform(self.features)

    def __len__(self):
        return len(self.target)

    def normalize(self, X):
        return (X - self.mean)/self.std

    def inverse_normalize(self, X):
        return X*self.std + self.mean

    def transform(self, X):
        Y = np.zeros_like(X)
        Y[:,0] = np.log1p(X[:,0])
        Y[:,1] = symmetric_log(X[:,1], self.C)  # Symmetric Log
        Y[:,2] = np.log1p(X[:,2])
        Y[:,3] = X[:,3]

        if self.do_normalization:
            Y = self.normalize(Y)
        return Y

    # Undo the transform
    def inverse_transform(self, X):
        if self.do_normalization:
            Y = self.inverse_normalize(X)
        else:
            Y = X.copy()

        Y[:,0] = np.expm1(Y[:,0])
        Y[:,1] = inverse_symmetric_log(Y[:,1], self.C)
        Y[:,2] = np.expm1(Y[:,2])
        return Y

    # Function to compute C dynamically based on the smallest nonzero absolute value in RV_vert_avg
    def _compute_C(self, RV):
        nonzero_values = np.abs(RV[RV != 0])
        C = np.min(nonzero_values) if len(nonzero_values) > 0 else 1.0  # Avoid zero
        return np.log(C + 1)

    # Return a truncated dataset by deliberating excluding a cluster of data
    # Default is the most positive relative vorticity
    def truncate(self, feature_idx=1):
        clusters = KMeans(n_clusters=6, random_state=0).fit(self.features)
        centers_dimensional = self.inverse_transform(clusters.cluster_centers_)
        print(centers_dimensional)
        excluded_cluster = np.argmax(centers_dimensional[:,feature_idx])
        retained_idx = clusters.labels_ != excluded_cluster
        truncated = MappableDataset(self.features[retained_idx], self.target[retained_idx])
        truncated.clusters = clusters
        truncated.excluded_cluster = excluded_cluster
        return truncated


In [211]:
DATAPATH = pathlib.Path("/lustre/data/shao/cug_2024/")
SIMULATION_DATA = DATAPATH / "featurized.nc"
TRAINING_DATA = DATAPATH / "training_data.pkl"
OUTPUT_CLUSTER = DATAPATH / "clusters.pkl"
ds = EKE_Dataset(SIMULATION_DATA)



In [213]:
truncated_ds = ds.truncate()
with open(TRAINING_DATA, "wb") as f:
  pickle.dump(truncated_ds, f)

[[ 2.44920196e+02  1.15463336e-02  3.57502202e-03  5.74257489e-01]
 [ 1.53860461e+02 -1.03021710e-03  2.16386084e-03  6.80731294e-01]
 [ 9.14740434e+00 -6.29118175e-04  4.45399527e-04  7.76396870e-01]
 [ 1.12734897e+01 -3.85966267e-04  1.16298642e-03  3.68724724e-01]
 [ 2.48926633e+02 -1.68839914e-03  3.68238037e-03  5.26326932e-01]
 [ 2.03785435e+02 -1.28160361e-02  2.97113680e-03  5.63655733e-01]]


In [197]:
print(excluded_cluster)

0


In [198]:
labels = clusters.predict(truncated_ds.features)
print(f"Samples in excluded cluster: {np.sum(labels == excluded_cluster)}")

Samples in excluded cluster: 0


In [200]:
labels = clusters.predict(ds.features)
print(f"Samples in excluded cluster: {np.sum(labels == excluded_cluster)}")
print(f"Total number of samples: {ds.features.shape[0]}")

Samples in excluded cluster: 846338
Total number of samples: 8821872


In [201]:
with open(OUTPUT_CLUSTER, "wb") as f:
  pickle.dump(clusters, f)

In [202]:
with open(OUTPUT_CLUSTER, "rb") as f:
  clusters_test = pickle.load(f)
labels = clusters_test.predict(truncated_ds.features)
print(f"Samples in excluded cluster: {np.sum(labels == excluded_cluster)}")
labels = clusters.predict(ds.features)
print(f"Samples in excluded cluster: {np.sum(labels == excluded_cluster)}")

Samples in excluded cluster: 0
Samples in excluded cluster: 846338
