In [1]:
import cv2
import tensorflow as tf
import pandas as pd
import math
import numpy as np
import os

from tqdm import tqdm

from random import shuffle

from tensorflow.keras import layers, Model
from tensorflow.keras.models import Sequential, load_model
from tensorflow.keras.applications import VGG16
from tensorflow.keras.optimizers import Adam

from tensorflow.keras.utils import Sequence, to_categorical

In [2]:
IMAGES_DIR = os.path.join("..", "datasets", "CASIA-PalmprintV1")
num_classes = 300


def train_test_split(images_dir=None, max_classes=None, train_ratio=0.7, seed=None):
    """Split per-person image folders into train and eval file lists.

    Every sub-directory of ``images_dir`` is one class (person). Its ``.jpg`` files are
    shuffled and the first ``int(len(files) * train_ratio)`` go to training, the rest to
    evaluation, so the split is done per class (a class with very few files may end up
    with no training images).

    Args:
        images_dir: Dataset root directory. Defaults to ``IMAGES_DIR``.
        max_classes: Keep only the first N class folders in sorted order.
            Defaults to ``num_classes``.
        train_ratio: Fraction of each class's files assigned to training.
        seed: When given, files are shuffled with ``random.Random(seed)`` so the split is
            reproducible. When None, the global ``random`` state is used (old behaviour).

    Returns:
        ``(train_files, eval_files, labels)`` where ``labels`` is the sorted list of
        class folder names actually used.
    """
    import random

    if images_dir is None:
        images_dir = IMAGES_DIR
    if max_classes is None:
        max_classes = num_classes
    shuffle_files = random.Random(seed).shuffle if seed is not None else shuffle

    # Only directories are classes; stray files (e.g. .DS_Store) would otherwise be
    # treated as labels and make os.listdir(folder) fail below.
    labels = sorted(
        entry for entry in os.listdir(images_dir)
        if os.path.isdir(os.path.join(images_dir, entry))
    )[:max_classes]

    train_files = []
    eval_files = []
    for label in labels:
        folder = os.path.join(images_dir, label)
        # Sort first: os.listdir order is arbitrary, so even a seeded shuffle of it
        # would not give a reproducible split.
        files = sorted(
            os.path.join(folder, name)
            for name in os.listdir(folder)
            if name.lower().endswith(".jpg")
        )
        shuffle_files(files)
        split_idx = int(len(files) * train_ratio)
        train_files.extend(files[:split_idx])
        eval_files.extend(files[split_idx:])
    return train_files, eval_files, labels

train_files, eval_files, labels = train_test_split()
# Report the size of each split and the number of classes found.
for template, collection in (
    ("Train files: {}", train_files),
    ("Eval files: {}", eval_files),
    ("Labels: {}", labels),
):
    print(template.format(len(collection)))

Train files: 3584
Eval files: 1615
Labels: 300


In [3]:
def load_normalized_image(path, size=(224, 224)):
    """Read ``path`` with OpenCV, resize it and scale it to float32 in [0, 1].

    Args:
        path: Image file path. OpenCV loads colour images in BGR channel order.
        size: Target size passed to ``cv2.resize`` as ``dsize``, i.e. ``(width, height)``.

    Returns:
        float32 array of the resized image divided by 255.

    Raises:
        FileNotFoundError: if OpenCV cannot read the file (``cv2.imread`` returns None).
    """
    image = cv2.imread(path)
    if image is None:
        raise FileNotFoundError("Could not read image: {}".format(path))
    return cv2.resize(image, size).astype(np.float32) / 255.


# Pre-load every train/eval image once so batches are served from memory.
# Each 224x224x3 float32 image takes 602,112 bytes, so this cache is large.
# NOTE(review): VGG16's ImageNet weights were trained with keras' vgg16.preprocess_input
# (mean subtraction on BGR), not /255 scaling — worth evaluating.
files = train_files + eval_files
image_by_filename = {file: load_normalized_image(file) for file in tqdm(files)}

100%|█████████████████████████████████████████████████████████████████████████████| 5199/5199 [00:14<00:00, 354.31it/s]


In [4]:
class PersonIDSequence(Sequence):
    """Keras ``Sequence`` serving ``(images, one_hot_labels)`` batches for palm-print identification.

    A file's class is the name of its parent directory, which must appear in ``labels``.
    Its one-hot index is that label's position in ``labels``. By default images come from
    the module-level ``image_by_filename`` cache. With ``extract_palm=True``, or when a
    file is missing from the cache, they are read from disk via ``load_palm_print``.

    Args:
        files: Image paths. Copied, so shuffling never reorders the caller's list.
        labels: Ordered class names.
        batch_size: Samples per batch; the last batch may be smaller.
        extract_palm: Apply ``extract_palm_from_img`` before resizing (disk path only).
    """

    def __init__(self, files, labels, batch_size, extract_palm=False):
        # Copy: shuffling here and in on_epoch_end must not reorder the caller's list.
        self.files = list(files)
        self.labels = labels
        shuffle(self.files)
        self.num_labels = len(self.labels)
        # label -> one-hot index, built once. setdefault keeps the first index for
        # duplicate labels, matching list.index.
        self.label_to_index = {}
        for index, label in enumerate(self.labels):
            self.label_to_index.setdefault(label, index)
        self.batch_size = batch_size
        self.extract_palm = extract_palm

    def __len__(self):
        """Number of batches per epoch (a partial final batch counts as one)."""
        return math.ceil(len(self.files) / self.batch_size)

    def load_palm_print(self, image_file):
        """Read, optionally palm-crop, resize and scale a single image from disk.

        The image is read as grayscale when ``INPUT_SHAPE[2] == 1``, otherwise as colour.

        Returns:
            float32 array resized to ``INPUT_SHAPE[:2]`` and divided by 255. Grayscale
            images get a trailing channel axis.

        Raises:
            FileNotFoundError: if OpenCV cannot read ``image_file``.
            ValueError: if palm extraction yields None or the image cannot be resized.
        """
        grayscale = INPUT_SHAPE[2] == 1
        image = cv2.imread(image_file, 0) if grayscale else cv2.imread(image_file)
        if image is None:
            raise FileNotFoundError("Could not read image: {}".format(image_file))
        if self.extract_palm:
            # extract_palm_from_img is defined outside this part of the notebook.
            image = extract_palm_from_img(image)
            if image is None:
                raise ValueError("Palm extraction failed for image: {}".format(image_file))
        height, width = INPUT_SHAPE[:2]
        try:
            # cv2.resize expects dsize as (width, height), the reverse of INPUT_SHAPE's order.
            image = cv2.resize(image, (width, height))
        except cv2.error as exc:
            raise ValueError("Could not resize image: {}".format(image_file)) from exc
        image = image.astype(np.float32) / 255.
        if grayscale:
            # Single-channel cv2 images have no channel axis; the model input needs one.
            image = image[..., np.newaxis]
        return image

    def _image_for(self, image_file):
        """Return the preprocessed image, preferring the in-memory cache unless palm extraction is on."""
        if not self.extract_palm and image_file in image_by_filename:
            return image_by_filename[image_file]
        return self.load_palm_print(image_file)

    def __getitem__(self, idx):
        """Return batch ``idx`` as ``(images, one_hot_labels)`` numpy arrays.

        Raises:
            ValueError: if an image's parent directory is not one of ``labels``.
        """
        batch_files = self.files[idx * self.batch_size:(idx + 1) * self.batch_size]
        palm_prints = np.array([self._image_for(file) for file in batch_files])
        # The person label is the name of the directory containing the image.
        batch_labels = [os.path.basename(os.path.dirname(file)) for file in batch_files]
        try:
            indices = [self.label_to_index[label] for label in batch_labels]
        except KeyError as exc:
            raise ValueError("Unknown label: {}".format(exc.args[0])) from exc
        return palm_prints, to_categorical(indices, num_classes=self.num_labels)

    def on_epoch_end(self):
        """Reshuffle this sequence's private copy of the file list between epochs."""
        shuffle(self.files)


In [5]:
INPUT_SHAPE = (224, 224, 3)


def palm_model(num_classes=None, input_shape=None):
    """Build a VGG16-based palm-print classifier.

    The ImageNet-pretrained VGG16 convolutional base is frozen. Only the new
    fully-connected head is trainable.

    Args:
        num_classes: Width of the softmax output. Defaults to ``len(labels)``.
        input_shape: Image shape ``(height, width, channels)``. Defaults to ``INPUT_SHAPE``.

    Returns:
        An uncompiled ``Model`` mapping images to class probabilities.
    """
    if num_classes is None:
        num_classes = len(labels)
    if input_shape is None:
        input_shape = INPUT_SHAPE
    vgg16 = VGG16(include_top=False, weights='imagenet', input_shape=input_shape)
    for layer in vgg16.layers:
        layer.trainable = False                            # freeze the pretrained feature extractor
    x = layers.Flatten()(vgg16.output)                     # flatten feature maps for the FC layers
    x = layers.Dense(4096, activation='relu')(x)
    x = layers.Dense(4096, activation='relu')(x)
    x = layers.Dropout(0.2)(x)                             # dropout to reduce overfitting
    x = layers.Dense(num_classes, name="last_dense")(x)
    x = layers.Softmax()(x)                                # class probabilities (multiclass)
    return Model(inputs=vgg16.input, outputs=x)


model = palm_model()

In [6]:
BATCH_SIZE = 64  # samples per batch, shared by the training and evaluation sequences

train_ds = PersonIDSequence(train_files, labels, batch_size=BATCH_SIZE)
eval_ds = PersonIDSequence(eval_files, labels, batch_size=BATCH_SIZE)

In [7]:
epochs = 10
lr = 0.0001

# Adam's first parameter is the learning rate; pass it by keyword for clarity.
optimizer = Adam(learning_rate=lr)
model.compile(
    optimizer=optimizer,
    loss='categorical_crossentropy',
    metrics=['accuracy'],
)
history = model.fit(train_ds, validation_data=eval_ds, epochs=epochs)

Train for 56 steps, validate for 26 steps
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


In [8]:
eval_metrics = model.evaluate(eval_ds)  # [loss, accuracy], as configured in model.compile
eval_metrics



[0.37008239386173397, 0.9331269]

In [9]:
# Encode the training hyper-parameters in the checkpoint file name.
checkpoint_path = "palm_model_e{}_lr{}.h5".format(epochs, lr)
model.save(checkpoint_path)