# Vision Transformer (ViT)

In [1]:
%matplotlib inline
# Automatically reload edited modules (e.g. utils.transformer, utils.visualize)
# before each cell executes, so changes to the helper .py files are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [12]:
from utils.transformer import TransformerEncoderLayer, TransformerEncoder, ClassEmbedding, Patches
from utils.visualize import plotImages, plotHistory
import tensorflow as tf
import tensorflow_datasets as tfds
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import cv2
from tqdm.notebook import tqdm

In [3]:
# set some paths
# Directory for saved model artefacts, relative to the notebook's working directory.
model_dir = Path('bin')
# Alias kept for the original (misspelled) name in case later cells reference it.
mode_dir = model_dir

# set some variables
input_size = (224, 224, 3)  # (height, width, channels) every image is resized to
patch_size = 16             # side length in pixels of each square ViT patch

# A ViT splits the image into a grid of patch_size x patch_size patches
# (presumably non-overlapping in utils.transformer.Patches -- confirm), so each
# spatial side should be an exact multiple of patch_size (224 / 16 = 14 per side).
assert input_size[0] % patch_size == 0 and input_size[1] % patch_size == 0, \
    "input_size height/width must be divisible by patch_size"

In [4]:
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    # Restrict TensorFlow to the first GPU and let its memory allocation grow on
    # demand instead of reserving it all up front. Both settings must be applied
    # before the GPU is initialised, otherwise TensorFlow raises RuntimeError.
    tf.config.experimental.set_visible_devices(gpus[0], 'GPU')
    tf.config.experimental.set_memory_growth(gpus[0], True)
else:
    # Indexing gpus[0] would raise IndexError on a CPU-only machine.
    print('No GPU found; running on CPU.')

# 1.0 Import the Dataset

## 1.1 Download the Cats vs. Dogs dataset (TFDS `cats_vs_dogs`)

In [5]:
# No `split=` is given, so tfds.load returns a dict keyed by split name
# ({'train': <dataset>}); `ds_train` is therefore that dict, which is why it is
# indexed with ['train'] later. shuffle_files=True is unseeded, so the read
# order is not reproducible across runs.
ds_train, ds_info = tfds.load(
    name='cats_vs_dogs',
    as_supervised=True,   # yield (image, label) tuples instead of feature dicts
    shuffle_files=True,
    with_info=True,       # also return the DatasetInfo metadata object
)

In [6]:
# Inspect the dataset metadata: features, splits, example counts and citation.
ds_info

tfds.core.DatasetInfo(
    name='cats_vs_dogs',
    version=4.0.0,
    description='A large set of images of cats and dogs.There are 1738 corrupted images that are dropped.',
    homepage='https://www.microsoft.com/en-us/download/details.aspx?id=54765',
    features=FeaturesDict({
        'image': Image(shape=(None, None, 3), dtype=tf.uint8),
        'image/filename': Text(shape=(), dtype=tf.string),
        'label': ClassLabel(shape=(), dtype=tf.int64, num_classes=2),
    }),
    total_num_examples=23262,
    splits={
        'train': 23262,
    },
    supervised_keys=('image', 'label'),
    citation="""@Inproceedings (Conference){asirra-a-captcha-that-exploits-interest-aligned-manual-image-categorization,
    author = {Elson, Jeremy and Douceur, John (JD) and Howell, Jon and Saul, Jared},
    title = {Asirra: A CAPTCHA that Exploits Interest-Aligned Manual Image Categorization},
    booktitle = {Proceedings of 14th ACM Conference on Computer and Communications Security (CCS)},
    ye

In [7]:
# Class names of the ClassLabel feature; list position i is the name of integer label i.
ds_info.features['label'].names

['cat', 'dog']

In [8]:
# Number of examples in the 'train' split; pre_process uses it to preallocate
# the image and label arrays.
n_images = ds_info.splits['train'].num_examples
assert n_images > 0, "the 'train' split is empty"
n_images

23262


# 2.0 Pre-Process the Dataset

In [9]:
def pre_process(ds, n_images, size=None, dtype=np.float64):
    """Resize every (image, label) pair of ``ds['train']`` into preallocated arrays.

    Parameters
    ----------
    ds : mapping
        Split name -> dataset, as returned by ``tfds.load`` without ``split=``.
        Only ``ds['train']`` is read. Its elements are expected to be
        ``(image, label)`` pairs exposing ``.numpy()`` (``as_supervised=True``).
    n_images : int
        Expected number of elements in ``ds['train']``. It sizes the output
        arrays and is the progress-bar total.
    size : tuple of int, optional
        Target ``(height, width, channels)``. Defaults to the notebook-level
        ``input_size``.
    dtype : numpy dtype, optional
        dtype of the image array. The default ``float64`` matches the original
        behaviour but needs ``n_images * H * W * C * 8`` bytes (about 28 GB for
        23262 images at 224x224x3). ``np.uint8`` needs one eighth of that.

    Returns
    -------
    X : np.ndarray of shape (n, height, width, channels)
    y : np.ndarray of shape (n,), float64 labels
        ``n == n_images`` unless the dataset yields fewer elements. In that
        case both arrays are truncated to the rows actually filled, rather than
        returning uninitialised ``np.empty`` rows.

    Raises
    ------
    IndexError
        If ``ds['train']`` yields more than ``n_images`` elements.
    """
    height, width, channels = input_size if size is None else size
    X = np.empty((n_images, height, width, channels), dtype=dtype)
    y = np.empty(n_images)
    n_read = 0
    # total= gives tqdm a length; wrapping enumerate() alone shows no real progress bar.
    for i, (image, label) in enumerate(tqdm(ds['train'], total=n_images)):
        if i >= n_images:
            raise IndexError(
                f"ds['train'] yielded more than n_images={n_images} elements")
        # cv2.resize takes dsize as (width, height), the reverse of numpy's (rows, cols).
        X[i] = cv2.resize(image.numpy(), (width, height))
        y[i] = label.numpy()
        n_read = i + 1
    if n_read < n_images:
        # Rows past n_read were never written (np.empty garbage); drop them.
        X, y = X[:n_read], y[:n_read]
    return X, y

In [10]:
# Make one pass over the whole 'train' split, materialising resized images (X)
# and their labels (y) as in-memory numpy arrays.
X, y = pre_process(ds=ds_train, n_images=n_images)

HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…


