In [None]:
import numpy as np
import xarray as xr
import pandas as pd
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Input
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import MeanSquaredError
from tensorflow.keras.activations import sigmoid
from sklearn.model_selection import train_test_split


# Load data
def load_data(folder_path):
    """Read the SIC and SST variables from a NetCDF dataset.

    Args:
        folder_path: Location handed straight to ``xr.open_dataset``. Despite
            the name, the caller in this file passes the path of a single
            ``.nc`` file rather than a directory.

    Returns:
        A ``(sic_data, sst_data)`` tuple holding the dataset's ``'sic'`` and
        ``'sst'`` variables, already loaded into memory.
    """
    # The context manager closes the underlying file as soon as the two
    # variables have been read. load() pulls their values into memory first,
    # so the returned arrays stay usable after the file is closed.
    with xr.open_dataset(folder_path) as data:
        sic_data = data['sic'].load()
        sst_data = data['sst'].load()
    return sic_data, sst_data

# Preprocess data
def preprocess_data(sic_data, sst_data):
    """Build the feature matrix and label vector from the SIC and SST arrays.

    Both inputs are flattened to DataFrames and joined on the ``'year'`` and
    ``'location'`` columns. Rows with a missing ``'sst'`` or ``'sic'`` value
    are discarded before the arrays are extracted.

    Args:
        sic_data: Object exposing ``to_dataframe()``; the resulting frame must
            provide ``'sic'`` plus ``'year'`` and ``'location'`` (as index
            levels or columns).
        sst_data: Same as ``sic_data`` but providing an ``'sst'`` column.

    Returns:
        ``(X, y)`` where ``X`` is a 2-D array with a single ``'sst'`` column
        and ``y`` is the 1-D array of the matching ``'sic'`` values.
    """
    # Flatten to tabular form; reset_index turns the index levels into
    # ordinary columns so they can serve as join keys.
    sic_frame = sic_data.to_dataframe().reset_index()
    sst_frame = sst_data.to_dataframe().reset_index()

    # pd.merge defaults to an inner join, so only (year, location) pairs
    # present in both frames survive.
    data = pd.merge(sic_frame, sst_frame, on=['year', 'location'])

    # A single NaN in the inputs or targets turns the MSE loss into NaN
    # during training, so incomplete rows are removed here.
    data = data.dropna(subset=['sst', 'sic'])

    # Double brackets keep X two-dimensional (n_samples, 1).
    X = data[['sst']].values
    y = data['sic'].values

    return X, y

# Split data into training, testing, and validation sets
def split_data(X, y):
    """Partition the samples into training, test and validation subsets.

    The split is done in two stages with a fixed seed: 40% of the samples are
    held out first, then that hold-out is divided in half, giving roughly a
    60/20/20 partition.

    Args:
        X: Feature array.
        y: Label array aligned with ``X``.

    Returns:
        ``(X_train, X_test, X_val, y_train, y_test, y_val)`` -- note that the
        test subset comes before the validation subset in the tuple.
    """
    seed = 42

    # Stage 1: keep 60% for training, hold out the remaining 40%.
    X_train, X_holdout, y_train, y_holdout = train_test_split(
        X, y, test_size=0.4, random_state=seed
    )

    # Stage 2: halve the hold-out; the first half becomes the test subset and
    # the second half the validation subset.
    X_test, X_val, y_test, y_val = train_test_split(
        X_holdout, y_holdout, test_size=0.5, random_state=seed
    )

    return X_train, X_test, X_val, y_train, y_test, y_val

# Define the PINN model
def create_pinn_model():
    """Build and compile the network used to predict SIC.

    The model is a feed-forward stack: one scalar input, two 64-unit ReLU
    hidden layers and a single sigmoid output unit. It is compiled with Adam
    (learning rate 0.001) and a mean-squared-error loss. Only that plain MSE
    loss is configured here; this function adds no physics-based term.

    Returns:
        The compiled ``Sequential`` model.
    """
    hidden_width = 64
    stack = [
        Input(shape=(1,)),
        Dense(hidden_width, activation='relu'),
        Dense(hidden_width, activation='relu'),
        # Sigmoid keeps the predicted SIC between 0 and 1.
        Dense(1, activation=sigmoid),
    ]

    network = Sequential(stack)
    network.compile(
        optimizer=Adam(learning_rate=0.001),
        loss=MeanSquaredError(),
    )
    return network

# Train the model
def train_model(model, X_train, y_train, X_val, y_val, *, epochs=100, batch_size=32):
    """Fit ``model`` on the training split, passing the validation split along.

    Args:
        model: Object exposing a Keras-style ``fit`` method.
        X_train: Training features.
        y_train: Training labels.
        X_val: Validation features.
        y_val: Validation labels.
        epochs: Value forwarded as ``fit``'s ``epochs`` argument. Defaults to
            100, the value previously hard-coded here.
        batch_size: Value forwarded as ``fit``'s ``batch_size`` argument.
            Defaults to 32, the value previously hard-coded here.

    Returns:
        Whatever ``model.fit`` returns.
    """
    return model.fit(
        X_train,
        y_train,
        epochs=epochs,
        batch_size=batch_size,
        validation_data=(X_val, y_val),
    )

# Evaluate the model
def evaluate_model(model, X_test, y_test):
    """Score ``model`` on the test split, print the loss and return it.

    Args:
        model: Object exposing a Keras-style ``evaluate`` method.
        X_test: Test features.
        y_test: Test labels.

    Returns:
        The value produced by ``model.evaluate(X_test, y_test)``, so callers
        can record or compare it instead of parsing the printed line.
    """
    loss = model.evaluate(X_test, y_test)
    print(f'Test Loss: {loss}')
    return loss

# Main function
def main(folder_path):
    """Run the whole pipeline: load, preprocess, split, train and evaluate.

    Args:
        folder_path: Dataset location forwarded to ``load_data``.
    """
    sic, sst = load_data(folder_path)
    features, targets = preprocess_data(sic, sst)

    # split_data returns the test subset before the validation subset.
    (train_x, test_x, val_x,
     train_y, test_y, val_y) = split_data(features, targets)

    network = create_pinn_model()
    train_model(network, train_x, train_y, val_x, val_y)
    evaluate_model(network, test_x, test_y)

# Path to the NetCDF (.nc) file that holds the prepared SIC/SST data. The
# variable keeps its historical name even though it points at a file.
folder_path = '/Users/skylargale/Documents/VSCode/MLGEO2024_SeaIcePrediction/data/ready_sic_sst_data.nc'

# Run the pipeline only when this code is executed directly (as a script or
# in a notebook cell), so importing the module does not start a training run.
if __name__ == '__main__':
    main(folder_path)