In [None]:
import wandb
from wandb.integration.keras import WandbMetricsLogger, WandbModelCheckpoint
import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow import keras

In [None]:
# W&B sweep definition. A grid sweep runs every combination:
# 5 batch sizes x 3 learning rates x 4 head widths x 3 image sizes x 2 epoch budgets = 360 runs.
sweep_config = {
    'method': 'grid',
    'metric': {
        # WandbMetricsLogger logs epoch-level Keras metrics under an "epoch/" prefix,
        # so the logged validation-accuracy key is "epoch/val_accuracy".
        'name': 'epoch/val_accuracy',
        'goal': 'maximize'
    },
    'parameters': {
        'batch_size': {'values': [8, 16, 32, 64, 128]},
        'learning_rate': {'values': [0.001, 0.0001, 0.00001]},
        # Candidate widths for a hidden Dense layer in the classifier head.
        'hidden_nodes': {'values': [32, 64, 128, 256]},
        # Square image side fed to MobileNetV2, which rejects inputs smaller than 32x32
        # (the previous value 16 crashed every such run). Sizes other than
        # 96/128/160/192/224 still work, but Keras warns and loads the 224x224 weights.
        'img_size': {'values': [32, 64, 224]},
        # Upper bound on training epochs; EarlyStopping may end a run sooner.
        'epochs': {'values': [5, 10]}
    }
}
sweep_id = wandb.sweep(sweep_config, project="5-Flower-Dataset")

In [None]:
def parse_csvline(csv_line, img_size=(224, 224),
                  class_names=("daisy", "dandelion", "roses", "sunflowers", "tulips"),
                  channels=3):
  """Parse one ``filename,label`` CSV line into an ``(image, label_index)`` pair.

  This version is self-contained. The earlier cell looked up ``read_and_decode``,
  ``IMG_HEIGHT``, ``IMG_WIDTH`` and ``CLASS_NAMES`` as notebook globals that no cell
  defines, so calling it on a fresh kernel raised ``NameError``.

  Args:
    csv_line: scalar string tensor holding ``"<image path>,<label string>"``.
    img_size: ``(height, width)`` the decoded image is resized to.
    class_names: ordered class labels; a label's position is its integer index.
    channels: number of colour channels to decode the JPEG with.

  Returns:
    ``(img, label)``. ``img`` is a float32 tensor scaled to [0, 1] and resized to
    ``img_size``. ``label`` is the index of ``label_string`` in ``class_names``.
    NOTE: an unknown label yields index 0, because argmax of an all-False mask is 0.
  """
  # Both columns are strings, so both record defaults are empty strings.
  filename, label_string = tf.io.decode_csv(csv_line, record_defaults=["", ""])

  img_bytes = tf.io.read_file(filename)
  img = tf.image.decode_jpeg(img_bytes, channels=channels)
  img = tf.image.convert_image_dtype(img, tf.float32)  # uint8 -> float32 in [0, 1]
  img = tf.image.resize(img, img_size)

  # Boolean mask over the class list; argmax picks the position of the matching name.
  label = tf.argmax(tf.math.equal(tf.constant(class_names), label_string))
  return img, label

In [None]:
# Flower dataset CSVs on the public Cloud ML sample bucket: one "filename,label" pair per line.
_TRAIN_CSV = "gs://cloud-ml-data/img/flower_photos/train_set.csv"
_EVAL_CSV = "gs://cloud-ml-data/img/flower_photos/eval_set.csv"
_CLASS_NAMES = ["daisy", "dandelion", "roses", "sunflowers", "tulips"]
_IMG_CHANNELS = 3


def _read_and_decode(filename, resize_dims):
  """Read a JPEG from ``filename``, convert it to float32 in [0, 1] and resize it to ``resize_dims``."""
  img_bytes = tf.io.read_file(filename)
  img = tf.image.decode_jpeg(img_bytes, channels=_IMG_CHANNELS)
  img = tf.image.convert_image_dtype(img, tf.float32)  # uint8 -> float32 in [0, 1]
  return tf.image.resize(img, resize_dims)


def _make_flower_dataset(csv_path, img_size, batch_size):
  """Build a batched, prefetched ``tf.data`` pipeline of ``(image, label_index)`` pairs.

  Args:
    csv_path: path/URL of a CSV whose lines are ``"<image path>,<label string>"``.
    img_size: square side length every image is resized to.
    batch_size: number of examples per batch.
  """
  def parse_line(csv_line):
    filename, label_string = tf.io.decode_csv(csv_line, record_defaults=["", ""])
    img = _read_and_decode(filename, [img_size, img_size])
    # Boolean mask over the class list; argmax turns it into an integer label
    # (an unknown label would silently map to index 0).
    label = tf.argmax(tf.math.equal(_CLASS_NAMES, label_string))
    return img, label

  return (
      tf.data.TextLineDataset(csv_path)
      # AUTOTUNE lets tf.data choose how many lines to decode in parallel.
      .map(parse_line, num_parallel_calls=tf.data.AUTOTUNE)
      .batch(batch_size)
      # Prepare the next batch in the background while the current one is training.
      .prefetch(tf.data.AUTOTUNE)
  )


def _build_flower_model(img_size, hidden_nodes, learning_rate):
  """Build and compile a frozen ImageNet MobileNetV2 backbone with a trainable classifier head."""
  input_shape = (img_size, img_size, _IMG_CHANNELS)  # (height, width, channels)
  base_model = tf.keras.applications.MobileNetV2(
      input_shape=input_shape,
      include_top=False,
      weights="imagenet",
  )
  base_model.trainable = False  # only the head below is trained

  regularizer = tf.keras.regularizers.l1_l2(l1=0.01, l2=0.01)
  model = keras.Sequential([
      keras.Input(shape=input_shape),
      # The pipeline yields pixels in [0, 1]; MobileNetV2's ImageNet weights expect [-1, 1].
      keras.layers.Rescaling(2.0, offset=-1.0),
      base_model,
      keras.layers.GlobalAveragePooling2D(),
      keras.layers.BatchNormalization(),
      # Hidden layer sized by the sweep's `hidden_nodes` (previously swept but never used).
      keras.layers.Dense(hidden_nodes, activation="relu"),
      keras.layers.Dense(len(_CLASS_NAMES), kernel_regularizer=regularizer),
      keras.layers.Activation("softmax"),
  ])

  model.compile(
      # Use the swept learning rate; the bare 'adam' string always used Adam's default.
      optimizer=keras.optimizers.Adam(learning_rate=learning_rate),
      # The model ends in softmax, so the loss receives probabilities, not logits.
      loss=keras.losses.SparseCategoricalCrossentropy(from_logits=False),
      metrics=["accuracy"],
  )
  return model


def train():
  """Run one W&B sweep trial: build the flower datasets, then train and log a classifier.

  Every hyperparameter comes from the sweep-provided run config:
    - img_size:      square side length of the images and of the MobileNetV2 input
    - batch_size:    batch size of both the training and validation pipelines
    - hidden_nodes:  width of the hidden Dense layer in the classifier head
    - learning_rate: Adam learning rate
    - epochs:        maximum number of epochs (EarlyStopping may stop earlier)
  Per-epoch metrics are sent to W&B through WandbMetricsLogger.
  """
  with wandb.init() as run:
    config = run.config

    train_dataset = _make_flower_dataset(_TRAIN_CSV, config.img_size, config.batch_size)
    val_dataset = _make_flower_dataset(_EVAL_CSV, config.img_size, config.batch_size)
    model = _build_flower_model(config.img_size, config.hidden_nodes, config.learning_rate)

    model.fit(
        train_dataset,
        validation_data=val_dataset,
        epochs=config.epochs,
        callbacks=[
            WandbMetricsLogger(),
            # Stop after 2 epochs without val_loss improvement and roll back to the best weights.
            tf.keras.callbacks.EarlyStopping(
                monitor="val_loss", patience=2, restore_best_weights=True),
        ],
    )
