In [1]:
import jax.numpy as jnp
from jax import grad, jit, vmap
import pandas as pd
import numpy as np

from sklearn.preprocessing import StandardScaler

In [2]:
PREDICTORS = ["tmax", "tmin", "rain"]
TARGET = "tmax_tomorrow"

data = pd.read_csv("clean_weather.csv", index_col=0)
data = data.ffill()

scaler = StandardScaler()
data[PREDICTORS] = scaler.fit_transform(data[PREDICTORS])

split_data = np.split(data, [int(.7*len(data)), int(.85*len(data))])
(train_x, train_y), (valid_x, valid_y), (test_x, test_y) = [[d[PREDICTORS].to_numpy(), d[[TARGET]].to_numpy()] for d in split_data]

In [3]:
def init_layers(inputs):
    layers = []
    for i in range(1, len(inputs)):
        layers.append([
            np.random.rand(inputs[i-1], inputs[i]) / 5 - .1,
            np.ones((1,inputs[i]))
        ])
    return layers

In [4]:
def forward(layers, x):
    for i in range(len(layers)):
        x = jnp.matmul(x, layers[i][0]) + layers[i][1]
        if i < len(layers) - 1:
            x = jnp.maximum(x, 0)
    return x

In [5]:
def mse(y, preds):
    return jnp.mean((y - preds)**2)

def loss(layers, x, y):
    preds = forward(layers, x) 
    return mse(y, preds)

In [6]:
@jit
def backward(layers, x, y):
    grads = grad(loss)(layers, x, y)
    for layer, g in zip(layers, grads):
        layer[0] -= (g[0] + layer[0] * .01) * lr 
        layer[1] -= g[1] * lr
    return layers

In [10]:
layer_conf = [3,10,10,1]
lr = 1e-6
batch_size = 64
epochs=15

layers = init_layers(layer_conf)

for epoch in range(epochs+1):

    for i in range(0, train_x.shape[0] - batch_size, batch_size):
        batch_ind = range(i, min(train_x.shape[0]-1, i + batch_size))
        layers = backward(layers, train_x[batch_ind,:].copy(), train_y[batch_ind,:])
    
    valid_preds = forward(layers, valid_x)
    print(f"Epoch: {epoch} Valid MSE: {mse(valid_y, valid_preds)}")

Epoch: 0 Valid MSE: 4447.9287109375
Epoch: 1 Valid MSE: 4447.90625
Epoch: 2 Valid MSE: 4447.88525390625
Epoch: 3 Valid MSE: 4447.86376953125
Epoch: 4 Valid MSE: 4447.84130859375
Epoch: 5 Valid MSE: 4447.81982421875
Epoch: 6 Valid MSE: 4447.7978515625
Epoch: 7 Valid MSE: 4447.7763671875
Epoch: 8 Valid MSE: 4447.75439453125
Epoch: 9 Valid MSE: 4447.73291015625
Epoch: 10 Valid MSE: 4447.7109375
Epoch: 11 Valid MSE: 4447.68896484375
Epoch: 12 Valid MSE: 4447.6669921875
Epoch: 13 Valid MSE: 4447.64501953125
Epoch: 14 Valid MSE: 4447.623046875
Epoch: 15 Valid MSE: 4447.6015625


In [8]:
test_preds = forward(layers, test_x)
mse(test_y, test_preds)

DeviceArray(4586.5703, dtype=float32)