In [1]:
import numpy as np
import tensorflow as tf
from tensorflow import keras
from sklearn.datasets import fetch_california_housing
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

In [2]:
dataset = fetch_california_housing()
X, y = dataset.data, dataset.target

X_train_full, X_test, y_train_full, y_test = train_test_split(X, y)
X_train, X_val, y_train, y_val = train_test_split(X_train_full, y_train_full)

scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)
X_val = scaler.transform(X_val)

In [3]:
model = keras.models.Sequential([
    keras.layers.Dense(30, activation="elu", kernel_initializer="he_normal"),
    keras.layers.BatchNormalization(),
    keras.layers.Dense(1)
])

In [4]:
def random_batch(X, y, batch_size=32):
    ind = np.random.randint(len(X), size=batch_size)
    return X[ind], y[ind]

In [5]:
n_epochs = 5
batch_size = 32
n_steps = len(X_train) // batch_size
loss_fn = keras.losses.mean_squared_error
optimizer = keras.optimizers.SGD(learning_rate=1e-3)
acc_metric = keras.metrics.MeanAbsoluteError()

In [8]:
for epoch in range(1, n_epochs + 1):
    print(f"Epoch {epoch}/{n_epochs}")
    for step in range(1, n_steps + 1):
        X_batch, y_batch = random_batch(X_train, y_train)
        with tf.GradientTape() as g:
            y_pred = model(X_batch, training=True)
            loss = loss_fn(y_batch, y_pred)
        gradients = g.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))
        
        acc_metric.update_state(y_batch, y_pred)
        print(f"{acc_metric.name} - {acc_metric.result()}")
        
        acc_metric.reset_states()
    

Epoch 1/5
mean_absolute_error - 0.853930652141571
mean_absolute_error - 0.8784062266349792
mean_absolute_error - 0.9771437644958496
mean_absolute_error - 1.029927134513855
mean_absolute_error - 0.9090728759765625
mean_absolute_error - 0.8816479444503784
mean_absolute_error - 0.7363301515579224
mean_absolute_error - 1.0765409469604492
mean_absolute_error - 0.93986976146698
mean_absolute_error - 0.8094229698181152
mean_absolute_error - 1.110020637512207
mean_absolute_error - 0.8939132690429688
mean_absolute_error - 0.9605505466461182
mean_absolute_error - 1.0748507976531982
mean_absolute_error - 0.6595859527587891
mean_absolute_error - 0.9725304841995239
mean_absolute_error - 1.0728195905685425
mean_absolute_error - 1.0419197082519531
mean_absolute_error - 0.9744284749031067
mean_absolute_error - 0.9067421555519104
mean_absolute_error - 0.8653457760810852
mean_absolute_error - 0.6944453716278076
mean_absolute_error - 0.882390558719635
mean_absolute_error - 0.9224207997322083
mean_absolut

mean_absolute_error - 0.7599613666534424
mean_absolute_error - 0.8027165532112122
mean_absolute_error - 0.9925816655158997
mean_absolute_error - 1.0831916332244873
mean_absolute_error - 0.957379162311554
mean_absolute_error - 1.0184533596038818
mean_absolute_error - 0.6982077956199646
mean_absolute_error - 0.7675094604492188
mean_absolute_error - 1.0036479234695435
mean_absolute_error - 0.6863205432891846
mean_absolute_error - 0.9144772291183472
mean_absolute_error - 0.9506895542144775
mean_absolute_error - 0.8134213089942932
mean_absolute_error - 0.9465526342391968
mean_absolute_error - 0.9547934532165527
mean_absolute_error - 0.8975192904472351
mean_absolute_error - 1.0815799236297607
mean_absolute_error - 0.827939510345459
mean_absolute_error - 0.7408430576324463
mean_absolute_error - 0.9514978528022766
mean_absolute_error - 0.7667974829673767
mean_absolute_error - 1.0247960090637207
mean_absolute_error - 0.913394570350647
mean_absolute_error - 0.9483819603919983
mean_absolute_error

mean_absolute_error - 1.0894478559494019
mean_absolute_error - 0.7610623240470886
mean_absolute_error - 0.7196141481399536
mean_absolute_error - 0.8087957501411438
mean_absolute_error - 0.7577946782112122
mean_absolute_error - 0.8042945861816406
mean_absolute_error - 0.8411910533905029
mean_absolute_error - 0.7634831070899963
mean_absolute_error - 0.8160803318023682
mean_absolute_error - 0.8309726715087891
mean_absolute_error - 1.0197370052337646
mean_absolute_error - 1.0319328308105469
mean_absolute_error - 0.7364125847816467
mean_absolute_error - 0.8440775871276855
mean_absolute_error - 1.0198465585708618
mean_absolute_error - 0.9343161582946777
mean_absolute_error - 0.8238666653633118
mean_absolute_error - 0.9976544380187988
mean_absolute_error - 0.9289703965187073
mean_absolute_error - 1.1155295372009277
mean_absolute_error - 1.235798954963684
mean_absolute_error - 1.150693655014038
mean_absolute_error - 0.9350751638412476
mean_absolute_error - 0.9954854249954224
mean_absolute_erro

mean_absolute_error - 0.8729735612869263
mean_absolute_error - 0.8798640370368958
mean_absolute_error - 1.1442341804504395
mean_absolute_error - 0.8535826206207275
mean_absolute_error - 1.163739800453186
mean_absolute_error - 0.9925625324249268
mean_absolute_error - 0.6521564722061157
mean_absolute_error - 1.0348668098449707
mean_absolute_error - 0.8761246204376221
mean_absolute_error - 0.7760699987411499
mean_absolute_error - 1.031869888305664
mean_absolute_error - 0.8116822242736816
mean_absolute_error - 1.051245093345642
mean_absolute_error - 1.2380216121673584
mean_absolute_error - 1.0205483436584473
mean_absolute_error - 0.9876201152801514
mean_absolute_error - 1.0352321863174438
mean_absolute_error - 1.2186896800994873
mean_absolute_error - 0.6808595657348633
mean_absolute_error - 0.8728709816932678
mean_absolute_error - 0.8106642961502075
mean_absolute_error - 1.1300760507583618
mean_absolute_error - 0.7934117317199707
mean_absolute_error - 1.1455366611480713
mean_absolute_error

mean_absolute_error - 0.9636489152908325
mean_absolute_error - 0.7368370890617371
mean_absolute_error - 0.8178056478500366
mean_absolute_error - 0.8780196309089661
mean_absolute_error - 0.8718830347061157
mean_absolute_error - 0.7755018472671509
mean_absolute_error - 0.9506237506866455
mean_absolute_error - 0.8958463072776794
mean_absolute_error - 0.9069450497627258
mean_absolute_error - 0.6731590032577515
mean_absolute_error - 1.1520277261734009
mean_absolute_error - 0.8923872113227844
mean_absolute_error - 0.8649791479110718
mean_absolute_error - 0.8663791418075562
mean_absolute_error - 0.8485002517700195
mean_absolute_error - 1.0489158630371094
mean_absolute_error - 0.7802408337593079
mean_absolute_error - 0.770459771156311
mean_absolute_error - 0.8492239713668823
mean_absolute_error - 0.9301426410675049
mean_absolute_error - 1.0938427448272705
mean_absolute_error - 1.0731182098388672
mean_absolute_error - 0.662129282951355
mean_absolute_error - 0.7474695444107056
mean_absolute_erro

mean_absolute_error - 0.9011262655258179
mean_absolute_error - 0.9060310125350952
mean_absolute_error - 0.7588793039321899
Epoch 4/5
mean_absolute_error - 1.1044398546218872
mean_absolute_error - 0.963528573513031
mean_absolute_error - 0.8943360447883606
mean_absolute_error - 1.0039925575256348
mean_absolute_error - 0.9191592931747437
mean_absolute_error - 1.133172869682312
mean_absolute_error - 0.8102856874465942
mean_absolute_error - 0.8138561248779297
mean_absolute_error - 0.9220722317695618
mean_absolute_error - 0.88869708776474
mean_absolute_error - 0.9747075438499451
mean_absolute_error - 0.9189482927322388
mean_absolute_error - 0.8584659099578857
mean_absolute_error - 0.7776070237159729
mean_absolute_error - 0.9768628478050232
mean_absolute_error - 0.9009659886360168
mean_absolute_error - 0.8264892101287842
mean_absolute_error - 0.7600737810134888
mean_absolute_error - 1.0344144105911255
mean_absolute_error - 0.8208959102630615
mean_absolute_error - 1.1426905393600464
mean_absol

mean_absolute_error - 0.988498866558075
mean_absolute_error - 0.8031429648399353
mean_absolute_error - 0.8094101548194885
mean_absolute_error - 0.8964700698852539
mean_absolute_error - 0.8605004549026489
mean_absolute_error - 0.7488217353820801
mean_absolute_error - 1.074514627456665
mean_absolute_error - 1.2133557796478271
mean_absolute_error - 1.0230101346969604
mean_absolute_error - 0.9488530158996582
mean_absolute_error - 0.7754124999046326
mean_absolute_error - 0.8205567598342896
mean_absolute_error - 0.879523754119873
mean_absolute_error - 0.9850105047225952
mean_absolute_error - 0.8644531965255737
mean_absolute_error - 1.021834135055542
mean_absolute_error - 1.251306414604187
mean_absolute_error - 0.9549989104270935
mean_absolute_error - 1.0070570707321167
mean_absolute_error - 1.1941057443618774
mean_absolute_error - 0.8893712759017944
mean_absolute_error - 1.0618351697921753
mean_absolute_error - 0.9584502577781677
mean_absolute_error - 0.8805280923843384
mean_absolute_error -

mean_absolute_error - 0.7680113315582275
mean_absolute_error - 0.8498488664627075
mean_absolute_error - 0.6625335812568665
mean_absolute_error - 0.5944277048110962
mean_absolute_error - 1.1726248264312744
mean_absolute_error - 0.7055603861808777
mean_absolute_error - 0.7160165309906006
mean_absolute_error - 0.8024841547012329
mean_absolute_error - 0.9058295488357544
mean_absolute_error - 0.9229192137718201
mean_absolute_error - 0.8658384084701538
mean_absolute_error - 0.9689757227897644
mean_absolute_error - 0.9092569947242737
mean_absolute_error - 0.7882813215255737
mean_absolute_error - 0.7410675287246704
mean_absolute_error - 0.842898964881897
mean_absolute_error - 0.8918235301971436
mean_absolute_error - 0.7480751872062683
mean_absolute_error - 0.9574306011199951
mean_absolute_error - 0.7475161552429199
mean_absolute_error - 0.8720283508300781
mean_absolute_error - 0.9962175488471985
mean_absolute_error - 1.0468206405639648
mean_absolute_error - 0.882051944732666
mean_absolute_erro

mean_absolute_error - 0.8258883953094482
mean_absolute_error - 0.9753727912902832
mean_absolute_error - 0.668532133102417
mean_absolute_error - 0.9970394372940063
mean_absolute_error - 0.7262908220291138
mean_absolute_error - 0.9861672520637512
mean_absolute_error - 0.6634024381637573
mean_absolute_error - 0.8335012197494507
mean_absolute_error - 0.9316094517707825
mean_absolute_error - 1.0334148406982422
mean_absolute_error - 0.9658461809158325
mean_absolute_error - 1.0633728504180908
mean_absolute_error - 1.141964316368103
mean_absolute_error - 0.971053421497345
mean_absolute_error - 0.9854966402053833
mean_absolute_error - 0.8230469226837158
mean_absolute_error - 0.8652759194374084
mean_absolute_error - 0.8979952335357666
mean_absolute_error - 0.8362104296684265
mean_absolute_error - 0.7585030794143677
mean_absolute_error - 0.883000373840332
mean_absolute_error - 0.9219729900360107
mean_absolute_error - 0.9081094264984131
mean_absolute_error - 0.9508042335510254
mean_absolute_error 