In [13]:
import os
import pandas as pd
import jax.numpy as np

# 2.2.1. Reading the Dataset

In [6]:
def mkdir_if_not_exist(path):
    """Create the directory ``path``, including missing parents, if absent.

    Args:
        path: A string or ``os.PathLike`` object naming the directory, or an
            iterable of path components that is joined with ``os.path.join``.

    Raises:
        FileExistsError: If ``path`` already exists but is not a directory.
    """
    # Strings and path objects are used as-is; anything else is treated as a
    # sequence of components (e.g. ('..', 'data')).
    if not isinstance(path, (str, os.PathLike)):
        path = os.path.join(*path)
    # exist_ok=True replaces the separate os.path.exists() check, which left a
    # window where another process could create the directory first and make
    # os.makedirs() fail.
    os.makedirs(path, exist_ok=True)

In [7]:
# Create a small CSV file on disk; 'NA' marks a missing entry.
data_file = '../data/house_tiny.csv'
mkdir_if_not_exist('../data')
with open(data_file, 'w') as f:
    f.writelines([
        'NumRooms,Alley,Price\n',  # Column names
        'NA,Pave,127500\n',  # Each row is a data point
        '2,NA,106000\n',
        '4,NA,178100\n',
        'NA,NA,140000\n',
    ])

In [10]:
# Load the CSV written above into a DataFrame. The 'NA' fields come back as
# NaN, as the printed table shows.
data = pd.read_csv(data_file)
print(data)

   NumRooms Alley   Price
0       NaN  Pave  127500
1       2.0   NaN  106000
2       4.0   NaN  178100
3       NaN   NaN  140000


# 2.2.2. Handling Missing Data

In [11]:
# Split the table by position: the first two columns are the features and
# the third column is the target.
inputs, outputs = data.iloc[:, 0:2], data.iloc[:, 2]
# Fill missing numeric entries with the column mean. numeric_only=True is
# needed because `Alley` holds strings: pandas >= 2.0 raises a TypeError when
# asked to average a non-numeric column instead of silently skipping it.
inputs = inputs.fillna(inputs.mean(numeric_only=True))
print(inputs)

   NumRooms Alley
0       3.0  Pave
1       2.0   NaN
2       4.0   NaN
3       3.0   NaN


In [12]:
# One-hot encode the categorical column; dummy_na=True adds a separate
# indicator column for missing values (`Alley_nan`).
# dtype=int keeps the indicators as 0/1 integers. pandas >= 2.0 would
# otherwise emit bool columns, which makes `inputs.values` an object array
# once mixed with the float `NumRooms` column.
inputs = pd.get_dummies(inputs, dummy_na=True, dtype=int)
print(inputs)

   NumRooms  Alley_Pave  Alley_nan
0       3.0           1          0
1       2.0           0          1
2       4.0           0          1
3       3.0           0          1


# 2.2.3. Conversion to the ndarray Format

In [14]:
# Convert the encoded features and the target column into arrays.
# Note: `np` here is jax.numpy (see the import cell), not NumPy.
X = np.array(inputs.values)
y = np.array(outputs.values)
X, y



(DeviceArray([[3., 1., 0.],
              [2., 0., 1.],
              [4., 0., 1.],
              [3., 0., 1.]], dtype=float32),
 DeviceArray([127500, 106000, 178100, 140000], dtype=int32))