Inspired by this notebook: https://www.kaggle.com/jessemostipak/getting-started-tpus-cassava-leaf-disease

In [1]:
import tensorflow as tf
import cv2
import pathlib
import numpy as np # linear algebra
import pandas as pd 
import os
from sklearn.model_selection import train_test_split
from functools import partial
import re

In [2]:
# Pick the distribution strategy: TPUStrategy when a TPU is attached,
# otherwise TensorFlow's default (single-device) strategy.
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
except ValueError:
    # The resolver raises ValueError when no TPU can be located. Only that
    # case triggers the fallback; anything else is a real error.
    tpu = None

if tpu is not None:
    print('Device:', tpu.master())
    # Failures while connecting/initialising are deliberately NOT swallowed:
    # silently training on CPU after a broken TPU setup hides the problem.
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.experimental.TPUStrategy(tpu)
else:
    strategy = tf.distribute.get_strategy()
print(strategy.num_replicas_in_sync)
    


Device: grpc://10.0.0.2:8470
8


In [3]:
# Root folder of the competition data inside the Kaggle kernel.
base_dir = '/kaggle/input/cassava-leaf-disease-classification'

In [4]:
# Show the entries of the competition data folder (displayed as the cell output).
os.listdir(base_dir)

['train_tfrecords',
 'train.csv',
 'train_images',
 'test_tfrecords',
 'test_images',
 'label_num_to_disease_map.json',
 'sample_submission.csv']

In [5]:
# Fetch the user's Google Cloud credential through Kaggle's secrets client and
# register it with TensorFlow.
# NOTE(review): presumably required so TensorFlow can read the GCS-hosted
# TFRecords used further down — confirm.
from kaggle_secrets import UserSecretsClient
user_secrets = UserSecretsClient()
user_credential = user_secrets.get_gcloud_credential()
user_secrets.set_tensorflow_credential(user_credential)

In [6]:
from kaggle_datasets import KaggleDatasets

In [7]:
# Notebook-wide configuration.
# Sentinel passed to tf.data calls (parallel reads/maps, prefetch) below.
AUTOTUNE = tf.data.experimental.AUTOTUNE
# GCS location reported by KaggleDatasets; the TFRecord glob is built on it.
GCS_PATH = KaggleDatasets().get_gcs_path()
# Global batch size: 16 samples for every replica in the strategy.
BATCH_SIZE = 16 * strategy.num_replicas_in_sync
# Height and width every decoded image is reshaped to (see decode_image).
IMG_SIZE = [512, 512]
# Class labels; only their number is used, to size the output layer.
CLASSES = ['0', '1', '2', '3', '4']
# Number of epochs passed to model.fit.
EPOCHS = 25

In [8]:
def decode_image(img):
    """Decode JPEG bytes into a float32 RGB tensor scaled to [0, 1].

    Args:
        img: scalar string tensor holding one JPEG-encoded image.

    Returns:
        A float32 tensor reshaped to ``[*IMG_SIZE, 3]``.
    """
    pixels = tf.image.decode_jpeg(img, channels=3)
    scaled = tf.cast(pixels, tf.float32) / 255.0
    # The reshape pins the static shape so downstream batching knows it.
    return tf.reshape(scaled, [*IMG_SIZE, 3])

In [9]:
# Split at file level: 80% of the training TFRecord files are used for
# training and 20% for validation, with a fixed seed so the split is repeatable.
tr_files, val_files = train_test_split(
    tf.io.gfile.glob(GCS_PATH + '/train_tfrecords/ld_train*.tfrec'),
    test_size=.2,
    random_state=41,
)

In [10]:
def read_tfdata(example, labeled):
    """Parse one serialized TFRecord example into an image (and its label).

    Args:
        example: serialized ``tf.train.Example`` as a scalar string tensor.
        labeled: when true, the record's ``target`` feature is read too.

    Returns:
        ``(image, label)`` with an int32 label when ``labeled`` is true,
        otherwise just the decoded image.
    """
    feature_spec = {'image': tf.io.FixedLenFeature([], tf.string)}
    if labeled:
        feature_spec['target'] = tf.io.FixedLenFeature([], tf.int64)

    parsed = tf.io.parse_single_example(example, feature_spec)
    image = decode_image(parsed['image'])

    if not labeled:
        return image
    return image, tf.cast(parsed['target'], tf.int32)
    

In [11]:
def load_dataset(filenames, labeled=True, ordered=False):
    """Build a dataset of decoded examples from a list of TFRecord files.

    Args:
        filenames: TFRecord file paths to read.
        labeled: forwarded to ``read_tfdata``; true yields (image, label).
        ordered: when false, deterministic ordering is switched off so that
            records are consumed as soon as they arrive (faster).

    Returns:
        A ``tf.data.Dataset`` of parsed, unbatched examples.
    """
    options = tf.data.Options()
    if not ordered:
        options.experimental_deterministic = False

    parse_example = partial(read_tfdata, labeled=labeled)
    # num_parallel_reads interleaves reads across the input files.
    records = tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTOTUNE)
    return records.with_options(options).map(parse_example, num_parallel_calls=AUTOTUNE)

In [12]:
def data_augment(image, label):
    """Randomly mirror the image left-to-right; the label is passed through."""
    return tf.image.random_flip_left_right(image), label

In [13]:
def get_training_dataset():
    """Return the endless, augmented, shuffled and batched training dataset.

    The dataset repeats forever, so the caller must bound each epoch with
    ``steps_per_epoch``.
    """
    return (
        load_dataset(tr_files, labeled=True)
        .map(data_augment, num_parallel_calls=AUTOTUNE)
        .repeat()
        .shuffle(2048)
        .batch(BATCH_SIZE)
        .prefetch(AUTOTUNE)
    )

In [14]:
# Training input pipeline (repeats indefinitely; see get_training_dataset).
tr_ds=get_training_dataset()

In [15]:
def get_validation_dataset(ordered=False):
    """Return the batched, cached validation dataset (single pass, no repeat).

    Args:
        ordered: forwarded to ``load_dataset``; true keeps record order
            deterministic.
    """
    batched = load_dataset(val_files, labeled=True, ordered=ordered).batch(BATCH_SIZE)
    # Caching happens after batching, so whole batches are what gets stored.
    return batched.cache().prefetch(AUTOTUNE)

In [16]:
# Validation input pipeline (finite; order not enforced by default).
val_ds=get_validation_dataset()

In [17]:
# Sanity check: pull three training batches and print the image-batch and
# label-batch shapes.
for img, lbl in tr_ds.take(3):
    print(img.numpy().shape, lbl.numpy().shape)

(128, 512, 512, 3) (128,)
(128, 512, 512, 3) (128,)
(128, 512, 512, 3) (128,)


In [18]:
def _to_resnet_input(images):
    """Convert [0, 1] float RGB images to the input convention of ResNet50.

    ``decode_image`` scales pixels to [0, 1], whereas
    ``resnet50.preprocess_input`` expects values in [0, 255]. The images are
    therefore scaled back up before the ImageNet preprocessing is applied.
    """
    return tf.keras.applications.resnet50.preprocess_input(images * 255.0)


with strategy.scope():
    # First layer of the model: fixes the input shape and applies the
    # preprocessing the pretrained backbone was trained with. No normalisation
    # layer is placed in front of it, because re-standardising the pixels
    # first would hand preprocess_input values outside its expected range.
    img_adjust_layer = tf.keras.layers.Lambda(_to_resnet_input, input_shape=[*IMG_SIZE, 3])

    # ImageNet-pretrained backbone without its classification head, kept
    # frozen so that only the small head below is trained.
    base_model = tf.keras.applications.ResNet50(weights='imagenet', include_top=False)
    base_model.trainable = False

    model = tf.keras.Sequential([
        img_adjust_layer,
        base_model,
        tf.keras.layers.GlobalAveragePooling2D(),
        tf.keras.layers.Dense(8, activation='relu'),
        # One softmax output per class.
        tf.keras.layers.Dense(len(CLASSES), activation='softmax')
    ])

    # Labels are integer class ids (see read_tfdata), hence the sparse
    # loss and metric.
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=.01, epsilon=0.001),
        loss='sparse_categorical_crossentropy',
        metrics=['sparse_categorical_accuracy'])
    

Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/resnet/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5


In [19]:
def count_data_items(filenames):
    """Sum the record counts encoded in TFRecord file names.

    Each file name is expected to contain ``-<count>.`` (for example
    ``ld_train00-1338.tfrec`` holds 1338 records).

    Args:
        filenames: iterable of file paths (``/``-separated).

    Returns:
        The total number of records as a NumPy integer; ``0`` for no files.

    Raises:
        ValueError: if a file name carries no ``-<digits>.`` count.
    """
    # Compiled once, not per file; '+' (not '*') so an empty digit run such
    # as "-." is never accepted as a count.
    count_pattern = re.compile(r"-([0-9]+)\.")
    counts = []
    for filename in filenames:
        # Only inspect the file name itself, so digits in a bucket or
        # directory name cannot be mistaken for the record count.
        basename = filename.rsplit('/', 1)[-1]
        match = count_pattern.search(basename)
        if match is None:
            raise ValueError(f"Cannot read a record count from file name: {filename!r}")
        counts.append(int(match.group(1)))
    # Explicit integer dtype: summing an empty list would otherwise give 0.0.
    return np.sum(counts, dtype=np.int64)

In [20]:
# Number of training / validation records, derived from the TFRecord file names.
tr_count=count_data_items(tr_files)
val_count=count_data_items(val_files)
print(tr_count, val_count)

16045 5352


In [21]:
# The training dataset repeats forever, so an epoch is defined by a step
# count: one (approximate) pass over the training records.
step_tr = tr_count // BATCH_SIZE
# Round up for validation: the validation dataset is finite and keeps its
# last partial batch, so flooring would skip those samples every epoch and
# would also stop before the cached dataset has been read to the end.
step_val = (val_count + BATCH_SIZE - 1) // BATCH_SIZE
history = model.fit(
    tr_ds,
    steps_per_epoch=step_tr,
    epochs=EPOCHS,
    validation_data=val_ds,
    validation_steps=step_val,
)

Epoch 1/25
Epoch 2/25
Epoch 3/25
Epoch 4/25
Epoch 5/25
Epoch 6/25
Epoch 7/25
Epoch 8/25
Epoch 9/25
Epoch 10/25
Epoch 11/25
Epoch 12/25
Epoch 13/25
Epoch 14/25
Epoch 15/25
Epoch 16/25
Epoch 17/25
Epoch 18/25
Epoch 19/25
Epoch 20/25
Epoch 21/25
Epoch 22/25
Epoch 23/25
Epoch 24/25
Epoch 25/25


In [22]:
# Write the fitted model to 'model.h5' (a relative path, so it is created in
# the current working directory).
model.save('model.h5')
