In [None]:
import wandb
from wandb.keras import WandbCallback
import numpy as np
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import pandas as pd
import tensorflow as tf
from sklearn.model_selection import train_test_split
# import os
# for dirname, _, filenames in os.walk('/kaggle/input'):
#     for filename in filenames:
#         print(os.path.join(dirname, filename))

In [None]:
# Read the tab-separated CelebA annotation table, then rewrite the relative
# image filenames in the `datadir` column as absolute paths under the Kaggle
# input mount (`datadir` is later passed to the image generator as `x_col`).
train_df = (
    pd.read_csv('/kaggle/input/datadir-celeb-a/aligned_celeba.txt', sep='\t')
    .pipe(lambda frame: frame.assign(
        datadir='/kaggle/input/celeba-dataset/img_align_celeba/img_align_celeba/'
        + frame['datadir'].astype(str),
    ))
)

In [None]:
# Hyperparameters. They are logged to W&B and read back through `wandb.config`,
# so every entry here should actually be consumed below (otherwise the logged
# config misdescribes the run).
defaults = {
    'epochs': 20,
    'batch_size': 32,
    'fc1_num_neurons': 1024,
    'fc2_num_neurons': 512,
    'fc3_num_neurons': 256,
    'seed': 7,
    'learning_rate': 3e-4,
    'optimizer': 'adam',
    'hidden_activation': 'relu',
    'output_activation': 'sigmoid',
    'loss_function': 'binary_crossentropy',
    'metrics': ['accuracy'],
}

run = wandb.init(config=defaults, resume=True, name='No Validation', project='CelebA Runs', notes='use full dataset for training')
config = wandb.config

# `config.seed` is otherwise only passed to the image generator below; seed the
# global NumPy/TensorFlow RNGs too so the new head layers initialise reproducibly.
np.random.seed(config.seed)
tf.random.set_seed(config.seed)

# Load images into keras image generator, applying MobileNetV2's own
# input preprocessing so inputs match what the ImageNet weights expect.
datagen_train = ImageDataGenerator(
    preprocessing_function=tf.keras.applications.mobilenet_v2.preprocess_input,
)

# class_mode='raw' feeds the `gender` column values straight through as labels
# for the single sigmoid output below.
train_generator = datagen_train.flow_from_dataframe(
    dataframe=train_df,
    x_col='datadir',
    y_col='gender',
    batch_size=config.batch_size,
    seed=config.seed,
    shuffle=True,
    class_mode='raw',
    target_size=(224,224),
)

# ImageNet-pretrained backbone without its classifier; global average pooling
# reduces the feature map to a flat vector. The whole backbone is fine-tuned.
mobile_net_v2 = tf.keras.applications.MobileNetV2(
    include_top=False,
    pooling='avg',
    weights='imagenet',
    input_shape=(224,224,3),
)
mobile_net_v2.trainable = True

fc1 = tf.keras.layers.Dense(
    config.fc1_num_neurons,
    activation=config.hidden_activation,
)

fc2 = tf.keras.layers.Dense(
    config.fc2_num_neurons,
    activation=config.hidden_activation,
)

# Bug fix: this layer previously used `fc2_num_neurons`, silently ignoring
# the logged `fc3_num_neurons` setting.
fc3 = tf.keras.layers.Dense(
    config.fc3_num_neurons,
    activation=config.hidden_activation,
)

bn1 = tf.keras.layers.BatchNormalization()
bn2 = tf.keras.layers.BatchNormalization()
bn3 = tf.keras.layers.BatchNormalization()
bn4 = tf.keras.layers.BatchNormalization()

model = tf.keras.models.Sequential([
    mobile_net_v2,
    # With pooling='avg' the backbone already emits a flat (None, 1280) tensor,
    # so Flatten is a no-op; kept so the layer stack is unchanged.
    tf.keras.layers.Flatten(),
    bn1,
    fc1,
    bn2,
    fc2,
    bn3,
    fc3,
    bn4,
    tf.keras.layers.Dense(1, activation=config.output_activation),
])

model.summary()

# Compile model. The optimizer is built from `config.optimizer` (previously
# Adam was hard-coded and that setting was ignored); for the default 'adam'
# this is equivalent to tf.keras.optimizers.Adam(learning_rate=...).
optimizer = tf.keras.optimizers.get({
    'class_name': config.optimizer,
    'config': {'learning_rate': config.learning_rate},
})
model.compile(
    optimizer=optimizer,
    loss=config.loss_function,
    metrics=config.metrics,
)

model.fit(
    train_generator,
    shuffle=True,
    epochs=config.epochs,
    callbacks=[WandbCallback()],
)
model.save_weights('model_celeba_no_val.h5') 
# NOTE(review): the W&B run is left open here; call run.finish() once all
# logging for this run is complete.

[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize


wandb: Paste an API key from your profile and hit enter: ········


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: wandb version 0.10.10 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


Found 202599 validated image filenames.
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
mobilenetv2_1.00_224 (Functi (None, 1280)              2257984   
_________________________________________________________________
flatten (Flatten)            (None, 1280)              0         
_________________________________________________________________
batch_normalization (BatchNo (None, 1280)              5120      
_________________________________________________________________
dense (Dense)                (None, 1024)              1311744   
_________________________________________________________________
batch_normalization_1 (Batch (None, 1024)              4096      
_________________________________________________________________
dense_1 (Dense)              (None, 512)               524800    
_________________________________________________________________
batch_normalizat

KeyboardInterrupt: 