In [1]:
import os
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import cv2
from glob import glob
from patchify import patchify
from sklearn.utils import shuffle
from tensorflow.keras.layers import *
from tensorflow.keras.models import Model
from tensorflow.keras.callbacks import *
from sklearn.model_selection import train_test_split
from sklearn import metrics

In [2]:
def get_class_name(data_path):
    """Return the class names found directly under ``data_path``.

    Every immediate sub-directory of ``data_path`` is treated as one class.
    Plain files (e.g. ``Thumbs.db``) are ignored, and the result is sorted so
    that the position of a name in the list -- which is used as the numeric
    class label -- is the same on every run and every machine
    (``os.listdir`` alone returns entries in arbitrary order).

    Args:
        data_path: Path of the dataset root folder.

    Returns:
        Alphabetically sorted list of sub-directory names.
    """
    return sorted(
        entry
        for entry in os.listdir(data_path)
        if os.path.isdir(os.path.join(data_path, entry))
    )


In [3]:
# Dataset root folder. load_data() matches images as <path>/*/*.jpg and
# process_image() uses the parent folder name of each image as its class.
# NOTE(review): absolute, machine-specific Windows path -- adjust per machine.
path = "E:\\ML_test\\Data\\bird-groups"

In [4]:
# Class names taken from the dataset root; the index of a name in this list
# is used as the numeric label in process_image().
names = get_class_name(path)

In [5]:
#Hyperparameters
hp = {}

# Input geometry
hp['image_size'] = 200
hp['num_channel'] = 3
hp['patch_size'] = 25
# Non-overlapping patches: (patches per side) squared. Dividing the areas
# instead would over-count whenever image_size is not a multiple of patch_size.
hp['num_patches'] = (hp['image_size'] // hp['patch_size']) ** 2
# Shape of one image after it is cut into flattened patches.
hp['flat_patches_shape'] = (hp['num_patches'], hp['patch_size']*hp['patch_size']*hp['num_channel'])

# Training settings
hp['batch_size'] = 16
hp['lr'] = 1e-4
hp['num_epochs'] = 100
# Derived from the discovered class names so the one-hot depth always
# matches the label indices produced from hp['class_names'].
hp['num_classes'] = len(names)
hp['class_names'] = names

# Transformer settings
hp["num_layers"] = 12
hp["hidden_dim"] = 500
hp["mlp_dim"] = 3072
hp["num_heads"] = 12
hp["dropout_rate"] = 0.1

# Data Pipeline

In [6]:
# Load the dataset and split it into train / validation / test sets
def load_data(path, split=0.1):
    """Collect image paths under ``path`` and split them three ways.

    Images are matched with the pattern ``<path>/*/*.jpg``. The validation
    and the test set each receive ``int(len(images) * split)`` samples; the
    remainder is the training set.

    Args:
        path: Dataset root folder.
        split: Fraction of all images used for validation and, again, for test.

    Returns:
        Tuple ``(train_data, valid_data, test_data)`` of lists of file paths.

    Raises:
        ValueError: If so few images are found that a split would be empty.
    """
    # Sort instead of shuffling with an unseeded RNG: train_test_split already
    # shuffles, and it can only reproduce the same split for random_state=42
    # when it starts from the same input order.
    images = sorted(glob(os.path.join(path, "*", "*.jpg")))

    split_size = int(len(images) * split)
    if split_size < 1:
        raise ValueError(
            f"Found {len(images)} '.jpg' image(s) under {path!r}; "
            f"not enough to hold out a fraction of {split}."
        )

    # First carve out the validation set, then the test set from what is left.
    train_data, valid_data = train_test_split(images, test_size=split_size, random_state=42)
    train_data, test_data = train_test_split(train_data, test_size=split_size, random_state=42)

    return train_data, valid_data, test_data

In [7]:
# Split the image paths found under `path` and report the size of each subset.
train_data, valid_data, test_data  = load_data(path)
print(f"Train: {len(train_data)} - Valid: {len(valid_data)} - Test: {len(test_data)}")

Train: 2238 - Valid: 279 - Test: 279


In [8]:
def process_image(path):
    """Load one image, cut it into flattened patches and look up its label.

    Meant to be called through ``tf.numpy_function`` (see ``parse``), which
    is why the result must have exactly the dtypes and shape declared there:
    ``float32`` patches of shape ``hp['flat_patches_shape']`` and an ``int32``
    class index.

    Args:
        path: Image file path, as ``bytes`` (what ``tf.numpy_function`` passes)
            or ``str``. The name of the parent folder is the class name.

    Returns:
        Tuple ``(patches, class_idx)``: a ``float32`` array of shape
        ``hp['flat_patches_shape']`` with values in [0, 1], and a 0-d
        ``int32`` array holding the index of the class in ``hp['class_names']``.

    Raises:
        ValueError: If the image cannot be read, or its parent folder is not
            listed in ``hp['class_names']``.
    """
    # tf.numpy_function hands string tensors over as bytes
    if isinstance(path, bytes):
        path = path.decode()

    # read image; cv2.imread signals failure by returning None, not by raising
    image = cv2.imread(path, cv2.IMREAD_COLOR)
    if image is None:
        raise ValueError(f"Could not read image: {path}")

    # resize, then scale pixel values to [0, 1]
    image = cv2.resize(image, (hp['image_size'], hp['image_size']))
    image = image / 255.0

    # cut into non-overlapping patches (step == patch size) ...
    patch_shape = (hp['patch_size'], hp['patch_size'], hp['num_channel'])
    patches = patchify(image, patch_shape, hp['patch_size'])
    # ... and flatten the patch grid to (num_patches, patch_pixels). The cast
    # is required because the division above produced float64 while parse()
    # declares float32.
    patches = np.reshape(patches, hp['flat_patches_shape']).astype(np.float32)

    # label = index of the parent folder name; os.path handles the separator
    # of the current platform instead of assuming "\\"
    class_name = os.path.basename(os.path.dirname(path))
    class_idx = hp['class_names'].index(class_name)
    class_idx = np.array(class_idx, dtype=np.int32)

    return patches, class_idx

In [9]:
def parse(path):
    """Map an image path tensor to ``(flattened patches, one-hot label)``.

    Runs ``process_image`` through ``tf.numpy_function`` and then attaches
    the static shapes taken from ``hp`` to both results.

    Args:
        path: Tensor holding the path of one image file.

    Returns:
        Tuple ``(patches, labels)`` with static shapes
        ``hp['flat_patches_shape']`` and ``hp['num_classes']``.
    """
    output_types = (tf.float32, tf.int32)
    flat_patches, class_index = tf.numpy_function(process_image, [path], output_types)

    # integer class index -> one-hot vector of length num_classes
    one_hot_label = tf.one_hot(class_index, hp['num_classes'])

    # declare the static shapes explicitly on both outputs
    flat_patches.set_shape(hp['flat_patches_shape'])
    one_hot_label.set_shape(hp['num_classes'])

    return flat_patches, one_hot_label

In [10]:
# Build the tf.data input pipeline
def tf_dataset(images, batch=8):
    """Create a batched, prefetching ``tf.data`` pipeline from image paths.

    Args:
        images: Sequence of image file paths.
        batch: Number of samples per batch.

    Returns:
        Dataset that maps every path through ``parse``, groups the results
        into batches of ``batch`` and prefetches with a buffer size of 8.
    """
    return (
        tf.data.Dataset.from_tensor_slices(images)
        .map(parse)
        .batch(batch)
        .prefetch(8)
    )

In [11]:
# Input pipelines for training and validation, both using the configured batch size.
train_dataset = tf_dataset(train_data, batch=hp['batch_size'])
valid_dataset = tf_dataset(valid_data, batch=hp['batch_size'])

# Model Creation

### Transformer Encoder

In [12]:
#Configuration parameters
config = {}

# Transformer architecture
config['num_layers'] = 12
config['hidden_dim'] = 768
config['mlp_dim'] = 3072
config['num_heads'] = 12
config['dropout_rate'] = 0.1

# Input geometry and class count are taken from the data-pipeline settings in
# `hp`, so the model is sized for the (num_patches, patch_pixels) tensors and
# one-hot labels that tf_dataset() actually yields. The previous literals
# (256 patches of size 32) did not match the 200px / 25px patch setup.
config['num_patches'] = hp['num_patches']
config['patch_size'] = hp['patch_size']
config['num_channels'] = hp['num_channel']
config["num_classes"] = hp['num_classes']

In [13]:
def mlp(input, config):
    """Feed-forward sub-block of a transformer encoder layer.

    Applies Dense(mlp_dim, gelu) -> Dropout -> Dense(hidden_dim) -> Dropout.

    Args:
        input: Tensor to transform. (The parameter name is kept for
            compatibility even though it shadows the ``input`` builtin.)
        config: Dict providing ``'mlp_dim'``, ``'hidden_dim'`` and
            ``'dropout_rate'``.

    Returns:
        The output tensor of the final dropout layer.
    """
    # Start from the actual argument; the earlier body read an undefined
    # local `inputs` and failed on every call.
    x = Dense(config['mlp_dim'], activation='gelu')(input)
    x = Dropout(config['dropout_rate'])(x)
    x = Dense(config['hidden_dim'])(x)
    x = Dropout(config['dropout_rate'])(x)
    return x

In [14]:
def transformer_encoder(inputs, config):
    """One pre-norm transformer encoder layer.

    Structure: LayerNorm -> multi-head self-attention -> residual add,
    followed by LayerNorm -> ``mlp`` -> residual add.

    Args:
        inputs: Input tensor of the layer.
        config: Dict providing ``'num_heads'`` and ``'hidden_dim'`` for the
            attention layer, plus the keys read by ``mlp``.

    Returns:
        Output tensor of the second residual addition.
    """
    # --- self-attention sub-block ---
    skip_connection_1 = inputs
    inputs = LayerNormalization()(inputs)
    inputs = MultiHeadAttention(
        num_heads=config['num_heads'],  # keyword is `num_heads`, not `num_head`
        key_dim=config['hidden_dim']
    )(inputs, inputs)  # query and value are the same tensor: self-attention
    inputs = Add()([inputs, skip_connection_1])

    # --- feed-forward sub-block ---
    skip_connection_2 = inputs
    inputs = LayerNormalization()(inputs)
    inputs = mlp(inputs, config)
    inputs = Add()([inputs, skip_connection_2])

    return inputs

### Embedding