In [3]:
import tensorflow as tf
import numpy as np
import os
import scipy.io
import datetime

In [9]:
url = 'https://data.vision.ee.ethz.ch/cvl/rrothe/imdb-wiki/static/wiki_crop.tar'

data = tf.keras.utils.get_file("wiki_crop",
                              url,untar=True,cache_dir='.',cache_subdir='')

dataset_dir = os.path.join(os.path.dirname(data),"wiki_crop/")
mat = scipy.io.loadmat(os.path.join(dataset_dir,'wiki.mat'))

dob = np.vectorize(lambda x : datetime.datetime.fromordinal(x).year)(mat["wiki"]["dob"][0][0][0])
photo_taken = mat["wiki"]["photo_taken"][0][0][0]
age = (photo_taken - dob).astype(np.float32)

file_path = np.vectorize(lambda x : os.path.join(dataset_dir,x[0]))(mat["wiki"]["full_path"][0][0][0])

file_age_ds = tf.data.Dataset.from_tensor_slices((file_path,age))

def parse_function(filename,label):
    img_string = tf.io.read_file(filename)
    img_decoded = tf.io.decode_jpeg(img_string,channels=1)
    img = tf.image.resize(img_decoded,[256,256])
    return img,tf.expand_dims(label,0)

image_age_ds = file_age_ds.map(parse_function).shuffle(buffer_size=64,seed=2)

data_size = image_age_ds.cardinality().numpy()
AUTOTUNE = tf.data.AUTOTUNE

train_ds = image_age_ds.take(data_size*0.6).batch(32).prefetch(AUTOTUNE)
val_ds = image_age_ds.skip(data_size * 0.6).batch(32).prefetch(AUTOTUNE)
test_ds= image_age_ds.skip(data_size*0.8).batch(32).prefetch(AUTOTUNE)

model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(32,(7,7),activation="relu",input_shape = (256,256,1)),
    tf.keras.layers.MaxPool2D((4,4),strides=4),
    tf.keras.layers.Conv2D(64,(3,3),activation="relu"),
    tf.keras.layers.MaxPool2D((4,4),strides=4),
    tf.keras.layers.Conv2D(128,(3,3),activation="relu"),
    tf.keras.layers.MaxPool2D((3,3),strides=3),
    tf.keras.layers.Conv2D(256,(1,1),activation="relu"),
    tf.keras.layers.MaxPool2D((2,2),strides=2),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(64,activation="relu"),
    tf.keras.layers.Dense(1)
])

model.compile(optimizer = "adam",loss = tf.keras.losses.MeanAbsoluteError(),metrics=['MAE'])

model.fit(train_ds,validation_data=val_ds,epochs=10)

loss,accuracy = model.evaluate(test_ds)
print("\n")
print(f"Loss : {loss} ; Accuracy : {accuracy}")

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


Loss : 11.635746955871582 ; Accuracy : 11.635746955871582
