<a href="https://colab.research.google.com/github/MorningStarTM/Transformers-in-Vision/blob/main/ViT_for_classification.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [41]:
import os
import numpy as np
import matplotlib.pyplot as plt
import cv2
import random
from glob import glob
import tensorflow as tf
from sklearn.utils import shuffle
from patchify import patchify
from tensorflow.keras.layers import *
from tensorflow.keras.model import Model
from tensorflow.keras.callbacks import Callback
from sklearn.model_selection import train_test_split
from sklearn import metrics

ModuleNotFoundError: ignored

In [1]:
#Hyperparameter
hp = {}
hp['image_size'] = 200
hp['num_channel'] = 3
hp['patch_size'] = 25
hp['num_patches'] = (hp['image_size']**2) // (hp['patch_size']**2)
hp['flat_patches_shape'] = (hp['num_patches'], hp['patch_size']*hp['patch_size']*hp['num_channel'])

hp['batch_size'] = 32
hp['lr'] = 1e-4
hp['num_epochs'] = 500
hp['num_classes'] = 9
hp['class_names'] = ["Ant", "Butterfly", "Cockroach", "Frog", "Grasshopper", "Honey bee", "Spider", "dragonfly", "lizard"]

In [None]:
!unzip "/content/drive/MyDrive/archive podiwije.zip" -d "/content/drive/MyDrive/DataSet/Insects/"

In [3]:
dataset_path = "/content/drive/MyDrive/DataSet/Insects/Reptiles-Insects"

In [4]:
def create_dir(path):
  if not os.path.exists(path):
    os.makedirs(path)

In [33]:
#function for load the data file
def load_data(path, split=0.1):
  images = shuffle(glob(os.path.join(path, "*", "*.jpg")))
  
  split_size = int(len(images) * split)
  #split the data 
  train_x, valid_x = train_test_split(images, test_size=split_size, random_state=42)
  train_x, test_x = train_test_split(train_x, test_size=split_size, random_state=42)

  return train_x, valid_x, test_x

In [34]:
def process_image_label(path):
  path = path.decode()
  #read image
  image = cv2.imread(path, cv2.IMREAD_COLOR)
  #resize the image
  image = cv2.resize(image, (hp['image_size'], hp['image_size']))
  #scale the image
  image = image/255.0
  print(image.shape)

  #image into patch
  patch_shape = (hp['patch_size'], hp['patch_size'], hp['num_channel'])
  patches = patchify(image, patch_shape, hp['patch_size'])

  patches = np.reshape(patches, hp['flat_patches_shape'])
  patches = patches.astype(np.float32)
  print(path)

  #labeling
  class_name = path.split('/')[-2]
  class_idx = hp['class_names'].index(class_name)
  class_idx = np.array(class_idx, dtype=np.int32)
  print(class_idx)

  return patches, class_idx

In [35]:
#we used opencv to read images not tensorflow. so we need to use tf.numpy_function to use these function in tensorflow
def parse(path):
  patches, labels = tf.numpy_function(process_image_label, [path], [tf.float32, tf.int32])
  labels = tf.one_hot(labels, hp['num_classes'])

  patches.set_shape(hp['flat_patches_shape'])
  labels.set_shape(hp['num_classes'])

  return patches, labels

In [36]:
def tf_dataset(images, batch=32):
  dataset = tf.data.Dataset.from_tensor_slices((images))
  dataset = dataset.map(parse).batch(batch).prefetch(8)
  return dataset

In [37]:
train_x, valid_x, test_x = load_data(dataset_path)

In [38]:
print(f"Train: {len(train_x)} Valid: {len(valid_x)} Test: {len(test_x)}")

Train: 680 Valid: 85 Test: 85


In [39]:
train_dataset = tf_dataset(train_x, batch=hp['batch_size'])
valid_dataset = tf_dataset(valid_x, batch=hp['batch_size'])

#Model

In [42]:
#Configuratin parameters
config = {}
config['num_layers'] = 12
config['hidden_dim'] = 768
config['mlp_dim'] = 3072
config['num_heads'] = 12
config['dropout_rate'] = 0.1
config['num_patches'] = 256
config['patch_size'] = 32
config['num_channels'] = 3

In [47]:
def ViT(config):
  #input layer
  input_shape = (config['num_patches'], config['patch_size']*config['patch_size']*config['num_channels'])
  inputs = Input(input_shape)
  
  #patch and position embedding
  patch_embedding = Dense(config['hidden_dim'])(inputs)
  
  positions = tf.range(start=0, limit=config['num_patches'], delta=1)
  position_embedding = Embedding(input_dim=config['num_patches'], output_dim=config['hidden_dim'])(positions)
  print(position_embedding.shape)

In [48]:
ViT(config)

(256, 768)
