In [1]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '7'
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'

import jax.numpy as jnp
import pylab as plt
import seaborn as sns
import pandas as pd
import numpy as np
from process_data import process_data
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler, StandardScaler

In [2]:
# Returns the columnn names and the data
cols, data = process_data("bld1.csv")

  data = pd.read_csv(filename)


In [3]:
# find indices at which the column names contain 'Zone Air Temperature'
# we want to make these what we're predicting.
# the rest is just inputs to our model
idx_temperature = np.array(['Zone Air Temperature' in col for col in cols])
X = data[cols[~idx_temperature]]
y = data[cols[idx_temperature]]
y1 = y.iloc[:,0:1]
y2 = y.iloc[:,1:2]
y3 = y.iloc[:,2:3]

X = np.concatenate([X, y1, y2, y3], axis = 1)[:-1]
y = np.concatenate([y1, y2, y3], axis = 1)[1:]

In [4]:
# Split the data into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state = 20)

scaler_x = StandardScaler()
X_train = scaler_x.fit_transform(X_train)
X_test = scaler_x.transform(X_test)

scaler_y = StandardScaler()
y_train = scaler_y.fit_transform(y_train)
y_test = scaler_y.transform(y_test)

In [5]:
X_train = jnp.array(X_train)
X_test = jnp.array(X_test)
y_train = jnp.array(y_train)
y_test = jnp.array(y_test)

In [6]:
import jax.numpy as jnp
import flax.linen as nn
import matplotlib.pyplot as pl
import itertools
import jax
import optax
import jaxopt
import pickle
import h5py
import warnings
import wandb
import process_data

warnings.filterwarnings("ignore")
from jax import random, jit, grad, jacfwd, vmap
from optax_adan import adan
from jax.nn import sigmoid, gelu, relu, softmax, softplus
from functools import partial
from tqdm import trange
from torch.utils import data
from typing import Sequence, Callable, Any, Tuple
from absl import logging

logging.set_verbosity(logging.INFO)

# from tueplots import bundles
from jax.nn.initializers import glorot_normal, normal, zeros

In [7]:
class MLP(nn.Module):
    features: Sequence[int]
    activation: Callable

    def setup(self):
        # Regular Dense layer:
        self.layers = [nn.Dense(feat, kernel_init = jax.nn.initializers.glorot_normal()) for feat in self.features]

    def __call__(self, x):
        for layer in self.layers[:-1]:
            x = self.activation(layer(x))
        # Final Layer has no activation function applied:
        return self.layers[-1](x)

In [8]:
class DataGenerator(data.Dataset):
    def __init__(self, key, X, y, batch_size):
        self.key = key
        self.X = X
        self.y = y
        self.batch_size = batch_size

    def __getitem__(self, index):
        "Generate one batch of data"
        self.key, subkey = random.split(self.key)
        batch = self.__data_generation(subkey)
        return batch

    @partial(jit, static_argnums=(0,))
    def __data_generation(self, key):
        idx = random.choice(key, self.X.shape[0], shape=(self.batch_size,))
        batch = (self.X[idx], self.y[idx])
        return batch

In [9]:
class Regression:
    def __init__(
        self,
        key,
        layers,
    ):
        # Initialize parameters:
        model1 = MLP(layers, jnp.sin)
        model2 = MLP(layers, jnp.sin)
        model3 = MLP(layers, jnp.sin)
        
        self.init1, self.apply1 = model1.init, model1.apply
        self.init2, self.apply2 = model2.init, model2.apply
        self.init3, self.apply3 = model3.init, model3.apply
        
        keys = random.split(key, 6)
        params1 = self.init1(keys[0], random.normal(keys[1], (44,)))
        params2 = self.init2(keys[2], random.normal(keys[3], (44,)))
        params3 = self.init3(keys[4], random.normal(keys[5], (44,)))
        
        self.params = (params1, params2, params3)
        self.avg_params = self.params
        learning_schedule = optax.cosine_decay_schedule(1e-3, 200000, 1e-5)
        self.optimizer = optax.adabelief(1e-3)
        self.opt_state = self.optimizer.init(self.params)
        self.loss_log = []        

    @partial(jit, static_argnums=(0,))
    def loss(self, params, batch):
        params1, params2, params3 = params
        X, y = batch
        y1, y2, y3 = y[:, 0], y[:, 1], y[:, 2]
        y1_pred = self.apply1(params1, X)
        y2_pred = self.apply2(params2, X)
        y3_pred = self.apply3(params3, X)
        loss = jnp.mean((y1_pred[:, 0] - y1) ** 2 + (y2_pred[:, 0] - y2) ** 2 + (y3_pred[:, 0] - y3) ** 2)
        return loss
    
    @partial(jit, static_argnums=(0,))
    def step(self, params, opt_state, batch):
        grad = jax.grad(self.loss)(params, batch)
        updates, opt_state = self.optimizer.update(grad, opt_state)
        params = optax.apply_updates(params, updates)
        return params, opt_state

    @partial(jit, static_argnums=(0,))
    def ema_update(self, params, avg_params):
        return optax.incremental_update(params, avg_params, step_size=0.001)
    
    def model_predict(self, params, X):
        y1 = self.apply1(params[0], X)
        y2 = self.apply2(params[1], X)
        y3 = self.apply3(params[2], X)
        return y1, y2, y3

    # Optimize parameters in a loop
    def train(self, dataloader, nIter=10000):
        pbar = trange(nIter)
        # Main training loop
        for it in pbar:
            batch = next(dataloader)
            self.params, self.opt_state = self.step(self.params, self.opt_state, batch)
            self.avg_params = self.ema_update(self.params, self.avg_params)
            # Logger
            if it % 1000 == 0:
                params = self.avg_params
                loss = self.loss(params, batch)
                pbar.set_postfix({"loss": loss})

In [10]:
loader = iter(DataGenerator(random.PRNGKey(0), X_train, y_train, 1024))
model = Regression(random.PRNGKey(42), [256, 256, 256, 1])

In [11]:
model.train(loader, nIter=500000)

100%|██████████| 500000/500000 [08:40<00:00, 959.74it/s, loss=0.00013228813] 


In [13]:
y_test_pred = model.model_predict(model.avg_params, X_test)

In [18]:
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state = 20)
X_train = scaler.fit_transform(X_train)
X_new = scaler.transform(X[:432])

y_train = scaler.fit_transform(y_train)

In [None]:
pl.plot(y_new[0].ravel(), label = 'Predicted')


In [None]:
from sklearn.metrics import r2_score
acte = r2_score(y_test, y_test_pred)

In [None]:
abs(y_test_pred - y_test).min()

DeviceArray(1.9073486e-06, dtype=float32)

In [None]:
y_test

DeviceArray([[  4.7566223,  -8.610168 , -13.243989 ],
             [ 24.469007 ,  -5.446783 ,  -3.693832 ],
             [ 33.19131  ,   8.277213 ,   4.131795 ],
             ...,
             [ 25.102198 ,   2.096973 ,  -1.5878501],
             [ 24.972366 ,   6.1852317,   4.4039264],
             [ 22.779865 ,  22.935507 ,  40.540382 ]], dtype=float32)