In [1]:
import tensorflow as tf
import numpy as np
import tensorflow_datasets as tfds
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.applications.vgg16 import VGG16
from tensorflow.keras.applications.vgg16 import preprocess_input
from tensorflow.keras import layers, models
from tensorflow.keras.callbacks import EarlyStopping
import torchvision.transforms as transforms

## Loading images and labels

In [2]:
(train_ds, train_labels), (test_ds, test_labels) = tfds.load(
    "caltech101",
    split=["train[:70%]", "train[:30%]"], ## Train test split
    batch_size=-1,
    as_supervised=True,  # Include labels
)

[1mDownloading and preparing dataset 125.64 MiB (download: 125.64 MiB, generated: 132.86 MiB, total: 258.50 MiB) to ~/tensorflow_datasets/caltech101/3.0.1...[0m


Dl Completed...: 0 url [00:00, ? url/s]

Dl Size...: 0 MiB [00:00, ? MiB/s]

Extraction completed...: 0 file [00:00, ? file/s]

Generating splits...:   0%|          | 0/2 [00:00<?, ? splits/s]

Generating train examples...:   0%|          | 0/3060 [00:00<?, ? examples/s]

Shuffling ~/tensorflow_datasets/caltech101/3.0.1.incompleteNCXU64/caltech101-train.tfrecord*...:   0%|        …

Generating test examples...:   0%|          | 0/6084 [00:00<?, ? examples/s]

Shuffling ~/tensorflow_datasets/caltech101/3.0.1.incompleteNCXU64/caltech101-test.tfrecord*...:   0%|         …

[1mDataset caltech101 downloaded and prepared to ~/tensorflow_datasets/caltech101/3.0.1. Subsequent calls will reuse this data.[0m


In [3]:
train_ds.shape

TensorShape([2142, 919, 969, 3])

## Resizing images

In [4]:
size = (224,224)

train_ds = tf.image.resize(train_ds, (224, 224))
test_ds = tf.image.resize(test_ds, (224, 224))

In [5]:
train_ds.shape

TensorShape([2142, 224, 224, 3])

## Transforming labels to correct format

In [6]:
train_labels = to_categorical(train_labels, num_classes=102)
test_labels = to_categorical(test_labels, num_classes=102)

In [7]:
train_ds.shape

TensorShape([2142, 224, 224, 3])

## Preprocessing input

In [8]:
train_ds = preprocess_input(train_ds) 
test_ds = preprocess_input(test_ds)

## Loading VGG16 model

In [9]:
base_model = VGG16(weights="imagenet", include_top=False, input_shape=train_ds[0].shape)
base_model.trainable = False ## Not trainable weights

Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/vgg16/vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5


In [10]:
base_model.summary()

Model: "vgg16"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_1 (InputLayer)        [(None, 224, 224, 3)]     0         
                                                                 
 block1_conv1 (Conv2D)       (None, 224, 224, 64)      1792      
                                                                 
 block1_conv2 (Conv2D)       (None, 224, 224, 64)      36928     
                                                                 
 block1_pool (MaxPooling2D)  (None, 112, 112, 64)      0         
                                                                 
 block2_conv1 (Conv2D)       (None, 112, 112, 128)     73856     
                                                                 
 block2_conv2 (Conv2D)       (None, 112, 112, 128)     147584    
                                                                 
 block2_pool (MaxPooling2D)  (None, 56, 56, 128)       0     

## Adding layers

In [11]:
flatten_layer = layers.Flatten()
dense_layer_1 = layers.Dense(500, activation='relu')
dense_layer_2 = layers.Dense(300, activation='relu')
prediction_layer = layers.Dense(102, activation='softmax')


model = models.Sequential([
    base_model,
    flatten_layer,
    dense_layer_1,
    dense_layer_2,
    prediction_layer
])

## Training model

In [12]:
model.compile(
    optimizer='adam',
    loss='categorical_crossentropy',
    metrics=['accuracy'],
)

es = EarlyStopping(monitor='val_accuracy', mode='max', patience=2,  restore_best_weights=True)

In [13]:
model.fit(train_ds, train_labels, epochs=1, validation_split=0.2, batch_size=32, callbacks=[es])




<keras.callbacks.History at 0x7fe6551d06d0>