In [1]:
import jax.numpy as jnp
from jax import grad, jit, vmap
import pandas as pd
import numpy as np

from sklearn.preprocessing import StandardScaler

In [2]:
PREDICTORS = ["tmax", "tmin", "rain"]
TARGET = "tmax_tomorrow"

data = pd.read_csv("../../data/clean_weather.csv", index_col=0)
data = data.ffill()

scaler = StandardScaler()
data[PREDICTORS] = scaler.fit_transform(data[PREDICTORS])

split_data = np.split(data, [int(.7*len(data)), int(.85*len(data))])
(train_x, train_y), (valid_x, valid_y), (test_x, test_y) = [[d[PREDICTORS].to_numpy(), d[[TARGET]].to_numpy()] for d in split_data]

In [3]:
def init_layers(inputs):
    layers = []
    for i in range(1, len(inputs)):
        layers.append([
            np.random.rand(inputs[i-1], inputs[i]) / 5 - .1,
            np.ones((1,inputs[i]))
        ])
    return layers

In [4]:
def forward(layers, x):
    for i in range(len(layers)):
        x = jnp.matmul(x, layers[i][0]) + layers[i][1]
        if i < len(layers) - 1:
            x = jnp.maximum(x, 0)
    return x

In [5]:
def mse(y, preds):
    return jnp.mean((y - preds)**2)

def loss(layers, x, y):
    preds = forward(layers, x) 
    return mse(y, preds)

In [6]:
@jit
def backward(layers, x, y):
    grads = grad(loss)(layers, x, y)
    for layer, g in zip(layers, grads):
        layer[0] -= (g[0] + layer[0] * .01) * lr 
        layer[1] -= g[1] * lr
    return layers

In [7]:
layer_conf = [3,10,10,1]
lr = 5e-7
batch_size = 32
epochs=100

layers = init_layers(layer_conf)

for epoch in range(epochs+1):

    for i in range(0, train_x.shape[0] - batch_size, batch_size):
        batch_ind = range(i, min(train_x.shape[0]-1, i + batch_size))
        layers = backward(layers, train_x[batch_ind,:].copy(), train_y[batch_ind,:])
    
    if epoch % 10 == 0:
        valid_preds = forward(layers, valid_x)
        print(f"Epoch: {epoch} Valid MSE: {mse(valid_y, valid_preds)}")

Epoch: 0 Valid MSE: 4421.083984375
Epoch: 10 Valid MSE: 4088.173095703125
Epoch: 20 Valid MSE: 3300.42919921875
Epoch: 30 Valid MSE: 1643.950439453125
Epoch: 40 Valid MSE: 258.70050048828125
Epoch: 50 Valid MSE: 40.592655181884766
Epoch: 60 Valid MSE: 25.406206130981445
Epoch: 70 Valid MSE: 23.1761417388916
Epoch: 80 Valid MSE: 22.51946258544922
Epoch: 90 Valid MSE: 22.194660186767578
Epoch: 100 Valid MSE: 21.970237731933594


In [9]:
test_preds = forward(layers, test_x)
mse(test_y, test_preds)

DeviceArray(24.077864, dtype=float32)