In [2]:
import warnings

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import optax
import pandas as pd
from flax import nnx
from jaxkan.models.KAN import KAN
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split

from dst_predict.imports import etl
from dst_predict.imports import window

In [3]:
# Re-execute `window` so edits to window.py take effect without restarting the
# kernel; the reloaded module object is shown as the cell output.
from importlib import reload

reload(window)

<module 'dst_predict.imports.window' from '/home/a100/Projects/dst_predict/src/dst_predict/imports/window.py'>

In [4]:
# Path to the Dst/AE data file, relative to the notebook's working directory.
DATA_PATH = "../datasets/WWW_dstae01508718.dat"

# Open the file in a `with` block so the handle is closed even if parsing
# fails. Pull the first record while the handle is still open: `read_records`
# returns an iterator (consumed with `next`), which may read from the file lazily.
with open(DATA_PATH, "r") as data_file:
    records = etl.read_records(data_file, read_all=True)
    rec = next(records, None)

# Fail with a clear message instead of a bare StopIteration on an empty file.
if rec is None:
    raise ValueError(f"No records found in {DATA_PATH}")

rec["data"].size

88416

In [None]:
# Build sliding-window (input, truth) samples from the record. Per the original
# note the windows are 64 data points long; assuming hourly Dst values that is
# 64 / 24 ~= 2.7 days of history per sample.
training_sets = window.build_training_dataset(rec)
print("Total samples:", len(training_sets["inputs"]))

first_sample = training_sets["inputs"][0]

# Each heading and its value go out in one print call (same text as two prints).
print("\nKeys in each sample:", first_sample.keys(), sep="\n")
print("\nShape of dst_nT window:", first_sample["dst_nT"].shape, sep="\n")
print("\nShape of time_enc window:", first_sample["time_enc"].shape, sep="\n")

In [None]:
# Tabular view of the prediction targets (pandas truncates to head/tail rows).
df_truths = pd.DataFrame(data=training_sets["truths"])
df_truths

Unnamed: 0,0
0,-10.0
1,-9.0
2,-10.0
3,-13.0
4,-14.0
...,...
530050,-7.0
530051,-4.0
530052,-7.0
530053,-4.0


In [None]:
def _flatten_sample(sample):
    """Place one sample's Dst column beside its time encoding and flatten row-major."""
    dst_col = np.asarray(sample["dst_nT"]).reshape(-1, 1)
    return np.concatenate([dst_col, sample["time_enc"]], axis=1).ravel()


def build_feature_matrix(inputs, truths):
    """Turn windowed samples into a 2-D feature matrix ``X`` and target vector ``y``.

    For every sample, ``dst_nT`` is reshaped to a column and concatenated with
    ``time_enc`` along axis 1, so ``time_enc`` must have one row per Dst value.
    The result is flattened row-major, so each row of ``X`` is laid out as
    ``[dst_0, enc_0..., dst_1, enc_1..., ...]``.

    Parameters
    ----------
    inputs : sequence of dict
        Samples with ``"dst_nT"`` and ``"time_enc"`` arrays. All samples are
        assumed to share the same window length (TODO confirm in
        ``window.build_training_dataset``), otherwise the rows are ragged.
    truths : sequence
        Targets, assumed aligned with ``inputs`` index-for-index.

    Returns
    -------
    (X, y) : tuple of np.ndarray
        ``X`` has one flattened row per sample; ``y`` is ``np.array(truths)``.
        Both are truncated to the shorter of the two lengths.

    Warns
    -----
    UserWarning
        If ``inputs`` and ``truths`` differ in length. Truncating from the end
        silently assumes they are aligned at index 0, which should be verified.
    """
    X = np.array([_flatten_sample(sample) for sample in inputs])
    y = np.array(truths)

    if len(X) != len(y):
        warnings.warn(
            f"inputs ({len(X)}) and truths ({len(y)}) differ in length; "
            "truncating both to the shorter one assumes they are aligned at "
            "index 0 - check window.build_training_dataset.",
            stacklevel=2,
        )

    n = min(len(X), len(y))
    return X[:n], y[:n]


X, y = build_feature_matrix(training_sets["inputs"], training_sets["truths"])

print("X shape:", X.shape)
print("y shape:", y.shape)

X shape: (530055, 448)
y shape: (530055,)


In [None]:
# 80/20 split by position with no shuffling. Assuming samples are in time
# order, the test set is the most recent 20% of the data.
split_index = int(0.8 * len(X))
X_train, X_test = X[:split_index], X[split_index:]
y_train, y_test = y[:split_index], y[split_index:]

In [None]:
# Report (features, targets) shapes for each split.
for label, features, targets in (
    ("Train shape:", X_train, y_train),
    ("Test shape:", X_test, y_test),
):
    print(label, features.shape, targets.shape)

Train shape: (424044, 448) (424044,)
Test shape: (106011, 448) (106011,)


In [None]:
# Standardise features and target using statistics from the TRAINING split
# only, so nothing from the test period leaks into the scaling.
# The splits are re-derived from the raw `X` / `y` (never modified in place),
# so re-running this cell does not normalise already-normalised data.
X_train_raw, X_test_raw = X[:split_index], X[split_index:]
y_train_raw, y_test_raw = y[:split_index], y[split_index:]

X_mean = X_train_raw.mean(axis=0)
X_std = X_train_raw.std(axis=0) + 1e-8  # epsilon avoids 0/0 on constant columns

X_train = (X_train_raw - X_mean) / X_std
X_test = (X_test_raw - X_mean) / X_std

y_mean = y_train_raw.mean()
y_std = y_train_raw.std() + 1e-8

y_train = (y_train_raw - y_mean) / y_std
y_test = (y_test_raw - y_mean) / y_std

In [None]:
# Hand the standardised splits to JAX. Targets become column vectors (n, 1) so
# they line up element-wise with the model's (batch, 1) predictions.
X_train, X_test = jnp.array(X_train), jnp.array(X_test)
y_train, y_test = (jnp.array(target).reshape(-1, 1) for target in (y_train, y_test))

print("JAX Train shape:", X_train.shape, y_train.shape)
print("JAX Test shape:", X_test.shape, y_test.shape)

JAX Train shape: (424044, 448) (424044, 1)
JAX Test shape: (106011, 448) (106011, 1)


In [None]:
# One input unit per flattened window feature.
input_dim = X.shape[1]

# Layer hyper-parameters forwarded to jaxkan's 'base' layer. Presumably k is
# the spline order and G the grid size - TODO confirm against the jaxkan docs.
req_params = dict(k=3, G=10)

model = KAN(
    layer_dims=[input_dim, 128, 64, 1],
    layer_type='base',
    required_parameters=req_params,
    seed=42,
)

# Smoke test: a forward pass on a small batch should give one output per row.
pred = model(X_train[:64])
print(pred.shape)

(64, 1)


In [None]:
# Adam step size used for training.
LEARNING_RATE = 1e-3

# Pair the model with its optax optimizer; `train_step` calls `.update(grads)`
# on this object to apply each gradient step.
model_and_optimizer = nnx.ModelAndOptimizer(model, optax.adam(LEARNING_RATE))

In [None]:
@nnx.jit
def train_step(model_and_optimizer, X, y):
    """Take one optimizer step on the batch (X, y) and return its MSE loss.

    Gradients of the mean squared error are computed w.r.t.
    ``model_and_optimizer.model`` and passed to ``model_and_optimizer.update``.
    """
    def mse(net):
        residual = net(X) - y
        return jnp.mean(residual ** 2)

    loss_value, grads = nnx.value_and_grad(mse)(model_and_optimizer.model)
    model_and_optimizer.update(grads)
    return loss_value

In [None]:
batch_size = 64
epochs = 1000        # NOTE: ~6.6k steps per epoch at batch 64 -> a very long run
log_every = 50
shuffle_seed = 42    # fixed so the batch order is reproducible across runs


def get_batches(X, y, batch_size):
    """Yield consecutive (X, y) slices of ``batch_size`` rows; the last may be shorter."""
    for i in range(0, len(X), batch_size):
        yield X[i:i + batch_size], y[i:i + batch_size]


rng = np.random.default_rng(shuffle_seed)

for epoch in range(epochs):
    # Shuffle into epoch-local copies so the notebook-level X_train / y_train
    # keep their original order for later evaluation or plotting.
    perm = rng.permutation(len(X_train))
    X_epoch = X_train[perm]
    y_epoch = y_train[perm]

    epoch_loss = 0.0
    num_samples = 0

    for X_batch, y_batch in get_batches(X_epoch, y_epoch, batch_size):
        loss = train_step(model_and_optimizer, X_batch, y_batch)
        # Weight each batch mean by its size so the short final batch does not
        # skew the epoch average (it also triggers one extra jit trace).
        epoch_loss += loss * len(X_batch)
        num_samples += len(X_batch)

    # Per-sample mean training MSE, as a plain Python float for printing.
    epoch_loss = float(epoch_loss) / num_samples

    if epoch % log_every == 0 or epoch == epochs - 1:
        print(f"Epoch {epoch}, Loss: {epoch_loss}")

NameError: name 'np' is not defined

In [None]:
batch_size = 64


@nnx.jit
def batch_sse(model, X_batch, y_batch):
    """Sum of squared errors of ``model`` on one batch (jit-compiled forward pass)."""
    preds = model(X_batch)
    return jnp.sum((preds - y_batch) ** 2)


total_sse = 0.0
num_batches = 0

for i in range(0, len(X_test), batch_size):
    total_sse += batch_sse(
        model_and_optimizer.model,
        X_test[i:i + batch_size],
        y_test[i:i + batch_size],
    )
    num_batches += 1

# Exact MSE over every test target, in standardised units as during training.
# Averaging per-batch means instead would over-weight the short final batch.
test_loss = float(total_sse) / y_test.size

# Undo the target standardisation to get RMSE in the original units of y
# (presumably nT, given the `dst_nT` field name).
test_rmse = float(np.sqrt(test_loss)) * y_std

print("Test Loss:", test_loss)
print("Test RMSE (original units):", test_rmse)

In [None]:
# Range of the raw (un-standardised) targets, to give a sense of their scale.
print("y min:", np.min(y))
print("y max:", np.max(y))