# California House Prices with MLP

### $\textbf{Task: Create a MLP regressor for predicting California home prices}$

In [16]:
from typing import Dict, Tuple, List

import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
import flax.linen as nn
from jax.nn.initializers import lecun_normal, lecun_uniform
import optax
import jax
from jax.nn import one_hot
import jax.numpy as jnp

In [17]:
df = pd.read_csv("datasets/housing.csv")
df.head(10)

Unnamed: 0,longitude,latitude,housing_median_age,total_rooms,total_bedrooms,population,households,median_income,median_house_value,ocean_proximity
0,-122.23,37.88,41.0,880.0,129.0,322.0,126.0,8.3252,452600.0,NEAR BAY
1,-122.22,37.86,21.0,7099.0,1106.0,2401.0,1138.0,8.3014,358500.0,NEAR BAY
2,-122.24,37.85,52.0,1467.0,190.0,496.0,177.0,7.2574,352100.0,NEAR BAY
3,-122.25,37.85,52.0,1274.0,235.0,558.0,219.0,5.6431,341300.0,NEAR BAY
4,-122.25,37.85,52.0,1627.0,280.0,565.0,259.0,3.8462,342200.0,NEAR BAY
5,-122.25,37.85,52.0,919.0,213.0,413.0,193.0,4.0368,269700.0,NEAR BAY
6,-122.25,37.84,52.0,2535.0,489.0,1094.0,514.0,3.6591,299200.0,NEAR BAY
7,-122.25,37.84,52.0,3104.0,687.0,1157.0,647.0,3.12,241400.0,NEAR BAY
8,-122.26,37.84,42.0,2555.0,665.0,1206.0,595.0,2.0804,226700.0,NEAR BAY
9,-122.25,37.84,52.0,3549.0,707.0,1551.0,714.0,3.6912,261100.0,NEAR BAY


In [18]:
df.describe()

Unnamed: 0,longitude,latitude,housing_median_age,total_rooms,total_bedrooms,population,households,median_income,median_house_value
count,20640.0,20640.0,20640.0,20640.0,20433.0,20640.0,20640.0,20640.0,20640.0
mean,-119.569704,35.631861,28.639486,2635.763081,537.870553,1425.476744,499.53968,3.870671,206855.816909
std,2.003532,2.135952,12.585558,2181.615252,421.38507,1132.462122,382.329753,1.899822,115395.615874
min,-124.35,32.54,1.0,2.0,1.0,3.0,1.0,0.4999,14999.0
25%,-121.8,33.93,18.0,1447.75,296.0,787.0,280.0,2.5634,119600.0
50%,-118.49,34.26,29.0,2127.0,435.0,1166.0,409.0,3.5348,179700.0
75%,-118.01,37.71,37.0,3148.0,647.0,1725.0,605.0,4.74325,264725.0
max,-114.31,41.95,52.0,39320.0,6445.0,35682.0,6082.0,15.0001,500001.0


In [19]:
nan_indexes = df.isna().sum(axis=0)
print(f"Nan element counts:\n{nan_indexes}")
df = df.dropna()

Nan element counts:
longitude               0
latitude                0
housing_median_age      0
total_rooms             0
total_bedrooms        207
population              0
households              0
median_income           0
median_house_value      0
ocean_proximity         0
dtype: int64


In [20]:
numerical_cols = df.select_dtypes(include=['number']).columns
df['ocean_proximity'], _ = pd.factorize(df['ocean_proximity'])
train_df, test_df = train_test_split(df, test_size = 0.1)

In [21]:
def get_moments(df : pd.DataFrame, cols : List[str]) -> Dict[str, Tuple[float, float]]:
    col_moment_map : Dict[str, Tuple[float, float]] = dict()
    for col in cols:
        col_mean = df.loc[:, col].mean()
        col_std = df.loc[:, col].std(ddof=1)
        col_moment_map[col] = (col_mean, col_std)
    return col_moment_map

def standardise(df : pd.DataFrame, moment_map) -> pd.DataFrame:
    for col in moment_map.keys():
        df[col] = (df[col] - moment_map[col][0]) / moment_map[col][1]
    return df

In [22]:
train_moments = get_moments(train_df, numerical_cols)
train_df = standardise(train_df, train_moments)
test_df = standardise(test_df, train_moments)
train_df.describe()

Unnamed: 0,longitude,latitude,housing_median_age,total_rooms,total_bedrooms,population,households,median_income,median_house_value,ocean_proximity
count,18389.0,18389.0,18389.0,18389.0,18389.0,18389.0,18389.0,18389.0,18389.0,18389.0
mean,-4.854673e-15,-1.339247e-15,5.583415e-17,-4.675386e-17,7.186957e-17,2.9366060000000004e-17,5.023142e-17,-1.275105e-16,7.843829000000001e-17,1.466801
std,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,0.853501
min,-2.391925,-1.447872,-2.189491,-1.204508,-1.273583,-1.248638,-1.304326,-1.775834,-1.664761,0.0
25%,-1.111558,-0.796758,-0.8416473,-0.5424267,-0.5746356,-0.5619618,-0.5725872,-0.6878501,-0.7573493,1.0
50%,0.5389158,-0.6421769,0.03048691,-0.2333032,-0.245301,-0.229134,-0.2354645,-0.1759478,-0.2351156,1.0
75%,0.7789846,0.9738985,0.6647664,0.2324369,0.2569936,0.2622249,0.2741395,0.4586259,0.5022577,2.0
max,2.629515,2.960032,1.85404,16.7484,13.99428,30.00126,14.58748,5.858319,2.542621,4.0


In [23]:
test_df.describe()

Unnamed: 0,longitude,latitude,housing_median_age,total_rooms,total_bedrooms,population,households,median_income,median_house_value,ocean_proximity
count,2044.0,2044.0,2044.0,2044.0,2044.0,2044.0,2044.0,2044.0,2044.0,2044.0
mean,-0.01583,0.010776,0.013963,-0.015722,-0.015678,-0.032068,-0.017426,-0.009111,-0.003322,1.446673
std,1.020705,1.007411,0.983436,0.978027,0.983952,0.922391,0.990902,0.999711,1.014171,0.860621
min,-2.346912,-1.438504,-2.110206,-1.200399,-1.266475,-1.239879,-1.293873,-1.775834,-1.556315,0.0
25%,-1.137815,-0.796758,-0.762362,-0.5527,-0.581744,-0.556926,-0.583041,-0.697393,-0.759084,1.0
50%,0.513909,-0.642177,0.030487,-0.246088,-0.25004,-0.246651,-0.252451,-0.181529,-0.247694,1.0
75%,0.783986,0.983267,0.664766,0.213145,0.241001,0.210549,0.253233,0.457836,0.499872,2.0
max,2.50448,2.927242,1.85404,11.52024,10.634589,9.220547,11.380894,5.858319,2.542621,3.0


In [24]:
print(train_df.isna().any())

longitude             False
latitude              False
housing_median_age    False
total_rooms           False
total_bedrooms        False
population            False
households            False
median_income         False
median_house_value    False
ocean_proximity       False
dtype: bool


## Modelling

In [25]:
def dropout(rng, x, rate):
    mask = jax.random.bernoulli(rng, shape=x.shape, p=rate)
    return x * mask / rate

In [26]:
class MLP(nn.Module):
    hidden_layer_sizes : list

    @nn.compact
    def __call__(self, x, rng, rate = 1.0):
        for hidden_dim in self.hidden_layer_sizes[:-1]:
            rng, rng_dropout = jax.random.split(rng, 2)
            x = nn.Dense(features=hidden_dim, kernel_init=lecun_normal())(x)
            x = nn.sigmoid(x)
            x = dropout(rng, x, rate)
        x = nn.Dense(features=self.hidden_layer_sizes[-1], kernel_init = lecun_normal())(x)
        return nn.sigmoid(x)

model = MLP(hidden_layer_sizes=[10, 10, 1])

In [27]:
def get_batch(df, batch_size):
    batch_samples = df.sample(batch_size)
    batch_inputs = batch_samples.loc[:, df.columns != "median_house_value"]
    batch_outputs = batch_samples.loc[:, "median_house_value"]
    encoded_proximities = one_hot(batch_inputs.loc[:, "ocean_proximity"].values, num_classes = 5)
    batch_inputs = batch_inputs.loc[:, batch_inputs.columns != "ocean_proximity"]
    batch_inputs = jnp.concatenate([batch_inputs.values, encoded_proximities], axis=1)
    return jnp.array(batch_inputs), jnp.array(batch_outputs)

In [28]:
rng_key = jax.random.PRNGKey(42)
rng_key, init_rng, dropout_rng = jax.random.split(rng_key, 3)

num_epochs = 1
batch_size = 128
N = train_df.shape[0]

example_batch = jnp.ones((batch_size, 13))
params = model.init(init_rng, example_batch, dropout_rng)

optimiser = optax.chain(
    optax.clip(1.0),
    optax.adam(learning_rate = 5e-3)
)

optimiser_state = optimiser.init(params)

def mse_loss(params, batch_in, batch_out, rng):
    return jnp.mean(jnp.square(model.apply(params, batch_in, rng, rate=0.1) - batch_out))

In [29]:
for epoch in range(num_epochs):
    batch_losses = []
    for batch in range(N // batch_size):
        dropout_rng, batch_rng = jax.random.split(dropout_rng, 2)
        batch_inputs, batch_outputs = get_batch(train_df, batch_size)
        loss, grads = jax.value_and_grad(mse_loss, argnums=(0))(params, batch_inputs, batch_outputs, batch_rng)
        #grads = optax.clip_by_global_norm(grads, 1.0)
        updates, optimiser_state = optimiser.update(grads, optimiser_state)
        params = optax.apply_updates(params, updates)
        batch_losses.append(loss)        
    if ((epoch + 1) % 1 == 0):
        print(f"Epoch {epoch + 1} average training loss: {jnp.array(batch_losses).mean()}")

Epoch 1 average training loss: 1.0709253549575806
