In [1]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import rcParams
# plt.xkcd()
# rcParams['font.family'] = ['xkcd', 'Comic Neue', 'Comic Mono']

import xarray as xr
import random 
import os 

from work import handler
from work import casestudy
from work import storm_tracker

from work.plots.hist import simple_hist
from work.transect import add_transects_with_aligned_boxes,make_mask_box

# Settings file path handed to handler.Handler(...) further down in this cell.
settings_path = 'settings/sam3d.yaml'

# import matplotlib.cm as cm
# from scipy.interpolate import CloughTocher2DInterpolator, LinearNDInterpolator, NearestNDInterpolator
# import glob
# import intake
# import dask
# import functools
# import pandas as pd
# dask.config.set({"array.slicing.split_large_chunks": True}) 
import cartopy.crs as ccrs
import cartopy.feature as cf
# import cmocean
# # !pip install easygems
# import tqdm
# import scipy
# import datetime as dt 
import matplotlib.colors as mcolors
from matplotlib.colors import LinearSegmentedColormap
# from funcs import *


# Build the project objects in dependency order:
# settings handler -> case study -> storm tracker.
hdlr = handler.Handler(settings_path)

cs = casestudy.CaseStudy(
    hdlr,
    overwrite=False,
    verbose=False,
)

# overwrite = True is super long, computes growth rate (triangle fit)
st = storm_tracker.StormTracker(
    cs,
    overwrite_storms=False,
    overwrite=False,
    verbose=True,
)

Data loaded from /homedata/mcarenso/shear/SAM3d_Tropics/var_id_days_i_t.json
Loading storms...
loading storms from netcdf
Time elapsed for loading storms: 1.76 seconds


In [None]:
# Selection thresholds that are encoded in the dataset file name.
duration_min = 6  # or 10
surfmaxkm2_min = 10000  # or other value

# Region bounds that are encoded in the dataset file name.
region_latmin = -15
region_latmax = 30
region_lonmin = -180
region_lonmax = 180

# File name is assembled from the thresholds and bounds above.
filename_save = (
    f"profile_dataset_storms_dmin{duration_min}_smin{surfmaxkm2_min}"
    f"_lat{region_latmin}_{region_latmax}_lon{region_lonmin}_{region_lonmax}.nc"
)

# The file lives under <DIR_DATA_OUT>/<case study name>/.
storms_path = os.path.join(st.settings["DIR_DATA_OUT"], cs.name, filename_save)
ds = xr.open_dataset(storms_path)

In [None]:
# pd.to_datetime is used below, but the pandas import in the first cell is
# commented out; import it here so this cell does not raise a NameError.
import pandas as pd

# Profile variables carrying the `_init_profile` suffix, pulled out as NumPy arrays.
TABS_init_profile = ds['TABS_init_profile'].values  # Shape: (num_samples, num_levels)
QV_init_profile = ds['QV_init_profile'].values
U_init_profile = ds['U_init_profile'].values
V_init_profile = ds['V_init_profile'].values

# Same four variables with the `_max_instant_profile` suffix.
TABS_max_instant_profile = ds['TABS_max_instant_profile'].values
QV_max_instant_profile = ds['QV_max_instant_profile'].values
U_max_instant_profile = ds['U_max_instant_profile'].values
V_max_instant_profile = ds['V_max_instant_profile'].values

# Extract scalar variables
lon_init = ds['lon_init'].values  # Shape: (num_samples,)
lat_init = ds['lat_init'].values
# Times are parsed into pandas datetimes so they can be turned into numbers later.
time_init = pd.to_datetime(ds['time_init'].values)

lon_max_instant = ds['lon_max_instant'].values
lat_max_instant = ds['lat_max_instant'].values
time_max_instant = pd.to_datetime(ds['time_max_instant'].values)

In [None]:
# Convert time variables to Unix timestamp (seconds since epoch).
#
# `.astype(np.int64)` on a pandas datetime index returns a pandas Index, and a
# pandas Index has no `.reshape()` method, which the feature-assembly cell calls
# on these variables. Wrapping in np.asarray hands plain ndarrays downstream.
# NOTE(review): the `// 10**9` assumes nanosecond-resolution timestamps — confirm.
time_init_numeric = np.asarray(time_init.astype(np.int64)) // 10**9
time_max_instant_numeric = np.asarray(time_max_instant.astype(np.int64)) // 10**9


In [None]:
# Sample / level counts, read off the first profile array.
num_samples = TABS_init_profile.shape[0]
num_levels = TABS_init_profile.shape[1]


def _as_column(values):
    """Return `values` reshaped to a single column, i.e. shape (-1, 1)."""
    return values.reshape(-1, 1)


# Group the four profile variables for each of the two suffix families.
profile_vars_init = [TABS_init_profile, QV_init_profile, U_init_profile, V_init_profile]
profile_vars_max = [
    TABS_max_instant_profile,
    QV_max_instant_profile,
    U_max_instant_profile,
    V_max_instant_profile,
]

# Join each group side by side along the feature axis.
profiles_init_concat = np.concatenate(profile_vars_init, axis=1)  # Shape: (num_samples, 4 * num_levels)
profiles_max_concat = np.concatenate(profile_vars_max, axis=1)    # Shape: (num_samples, 4 * num_levels)

# Scalars become single-column 2D arrays so they can be stacked with the profiles.
lon_init = _as_column(lon_init)
lat_init = _as_column(lat_init)
time_init_numeric = _as_column(time_init_numeric)

lon_max_instant = _as_column(lon_max_instant)
lat_max_instant = _as_column(lat_max_instant)
time_max_instant_numeric = _as_column(time_max_instant_numeric)

# Six scalar columns: the three `init` scalars, then the three `max_instant` scalars.
scalar_vars = np.concatenate(
    [
        lon_init, lat_init, time_init_numeric,
        lon_max_instant, lat_max_instant, time_max_instant_numeric,
    ],
    axis=1,
)  # Shape: (num_samples, 6)

# Full input matrix: init profiles | max-instant profiles | scalars.
X = np.concatenate(
    [profiles_init_concat, profiles_max_concat, scalar_vars],
    axis=1,
)  # Shape: (num_samples, total_features)


In [None]:
from sklearn.preprocessing import StandardScaler, MinMaxScaler

# Number of profile features
num_profile_features = profiles_init_concat.shape[1] + profiles_max_concat.shape[1]  # 8 * num_levels

# Split profiles and scalars
X_profiles = X[:, :num_profile_features]
X_scalars = X[:, num_profile_features:]

# Normalize profiles using StandardScaler
scaler_profiles = StandardScaler()
X_profiles_scaled = scaler_profiles.fit_transform(X_profiles)

# Normalize scalars using MinMaxScaler
scaler_scalars = MinMaxScaler()
X_scalars_scaled = scaler_scalars.fit_transform(X_scalars)

# Combine scaled profiles and scalars
X_scaled = np.hstack([X_profiles_scaled, X_scalars_scaled])


In [None]:
from sklearn.model_selection import train_test_split

X_train, X_test = train_test_split(X_scaled, test_size=0.2, random_state=42)


In [None]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

# Seed TensorFlow so that weight initialisation is repeatable across kernel
# restarts; without it every run starts from different random weights.
# 42 matches the random_state used for the train/test split.
SEED = 42
tf.random.set_seed(SEED)

input_dim = X_train.shape[1]  # Total number of features

# Hidden layer widths of the encoder; the decoder mirrors them in reverse.
ENCODER_WIDTHS = (512, 256, 128)
BOTTLENECK_WIDTH = 64

# Input layer
input_layer = keras.Input(shape=(input_dim,))

# Encoder: 512 -> 256 -> 128, then the bottleneck
encoded = input_layer
for width in ENCODER_WIDTHS:
    encoded = layers.Dense(width, activation='relu')(encoded)
bottleneck = layers.Dense(BOTTLENECK_WIDTH, activation='relu')(encoded)  # Bottleneck layer

# Decoder: 128 -> 256 -> 512, then a linear layer back to the input width
decoded = bottleneck
for width in reversed(ENCODER_WIDTHS):
    decoded = layers.Dense(width, activation='relu')(decoded)
output_layer = layers.Dense(input_dim, activation='linear')(decoded)

# Autoencoder model
autoencoder = keras.Model(inputs=input_layer, outputs=output_layer)

# Compile the model
autoencoder.compile(optimizer='adam', loss='mean_squared_error')

autoencoder.summary()

In [None]:
# Fit the autoencoder to reproduce its own input; the held-out split supplies
# the per-epoch validation loss.
history = autoencoder.fit(
    x=X_train,
    y=X_train,
    validation_data=(X_test, X_test),
    batch_size=32,
    epochs=100,
    shuffle=True,
)

In [None]:
import matplotlib.pyplot as plt

# Training vs. validation loss per epoch, as recorded by `fit`.
loss_fig, loss_ax = plt.subplots(figsize=(10, 5))
for history_key, curve_label in (('loss', 'Training Loss'), ('val_loss', 'Validation Loss')):
    loss_ax.plot(history.history[history_key], label=curve_label)
loss_ax.set_title('Autoencoder Loss Over Epochs')
loss_ax.set_xlabel('Epochs')
loss_ax.set_ylabel('Loss')
loss_ax.legend()
plt.show()

In [None]:
# Sub-model from the autoencoder's input tensor up to its bottleneck tensor.
encoder = keras.Model(input_layer, bottleneck)

# Bottleneck representation of every row of the scaled feature matrix.
encoded_data = encoder.predict(X_scaled)

In [None]:
# Write both models to HDF5 files in the current working directory.
for _model, _filename in (
    (autoencoder, 'autoencoder_model.h5'),
    (encoder, 'encoder_model.h5'),
):
    _model.save(_filename)

In [None]:
# Run the held-out rows through the full autoencoder.
reconstructed_data = autoencoder.predict(X_test)

# Overlay one sample's scaled feature vector with its reconstruction.
sample_index = 0  # Change as needed
recon_fig, recon_ax = plt.subplots(figsize=(15, 5))
recon_ax.plot(X_test[sample_index], label='Original')
recon_ax.plot(reconstructed_data[sample_index], label='Reconstructed')
recon_ax.legend()
plt.show()

In [None]:
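### -- Chat separate encoding (disabled draft): profiles and scalars each get their own encoder branch, concatenated before the bottleneck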

# # Profile input
# input_profiles = keras.Input(shape=(num_profile_features,))
# x_profiles = layers.Dense(256, activation='relu')(input_profiles)
# x_profiles = layers.Dense(128, activation='relu')(x_profiles)

# # Scalar input
# input_scalars = keras.Input(shape=(X_scalars_scaled.shape[1],))
# x_scalars = layers.Dense(32, activation='relu')(input_scalars)

# # Combine profiles and scalars
# combined = layers.concatenate([x_profiles, x_scalars])

# # Bottleneck layer
# bottleneck = layers.Dense(64, activation='relu')(combined)

# # Decoder for combined data
# x = layers.Dense(128, activation='relu')(bottleneck)
# x = layers.Dense(256, activation='relu')(x)
# output_layer = layers.Dense(input_dim, activation='linear')(x)

# # Autoencoder model with two inputs
# autoencoder = keras.Model(inputs=[input_profiles, input_scalars], outputs=output_layer)

# # Compile the model
# autoencoder.compile(optimizer='adam', loss='mean_squared_error')

# autoencoder.summary()

# # Train the model
# history = autoencoder.fit(
#     [X_train[:, :num_profile_features], X_train[:, num_profile_features:]], X_train,
#     epochs=100,
#     batch_size=32,
#     shuffle=True,
#     validation_data=([X_test[:, :num_profile_features], X_test[:, num_profile_features:]], X_test)
# )
