# Dogs-vs-cats classification with ViT

In this notebook, we'll fine-tune a [Vision Transformer](https://arxiv.org/abs/2010.11929) (ViT) to distinguish images of dogs from images of cats, using TensorFlow 2 / Keras and Hugging Face's [Transformers](https://github.com/huggingface/transformers) library.

**Note that using a GPU with this notebook is highly recommended.**

First, the needed imports.

In [None]:
%matplotlib inline

from transformers import ViTFeatureExtractor, TFViTForImageClassification
from transformers.utils import check_min_version
from transformers import __version__ as transformers_version

import tensorflow as tf
from tensorflow.keras.utils import plot_model
from PIL import Image
import os, sys

from natsort import natsorted

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
sns.set()

# Guard against an older Transformers install: check_min_version errors out
# if the installed version is below the one this notebook was written for.
check_min_version("4.13.0.dev0")

# Report the library versions in use, for reproducibility.
print(' '.join(['Using TensorFlow version:', tf.__version__,
                'Keras version:', tf.keras.__version__,
                'Transformers version:', transformers_version]))

## Data

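The full training dataset consists of 2000 images of dogs and cats, split evenly between the two classes. In addition, there is a validation set of 1000 images and a test set of 22000 images. (This notebook fine-tunes on only the first 500 training images of each class and spot-checks a few other images; the validation and test sets are not used here.) Here are some random training images: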

![Random sample of dog and cat training images](imgs/dvc.png)

In [None]:
# Root of the training images (expects "cats/" and "dogs/" subdirectories).
# Override with the DOGS_VS_CATS_DATADIR environment variable instead of
# editing this hard-coded, machine-specific absolute path.
DATADIR = os.environ.get("DOGS_VS_CATS_DATADIR",
                         "/media/data/dogs-vs-cats/train-2000/train")

def pil_loadimg(path: str):
    """Read the image file at `path` and return it as an RGB PIL image.

    The conversion to RGB happens while the file handle is still open, so
    the pixel data is read before the file is closed.
    """
    with open(path, "rb") as handle:
        rgb_image = Image.open(handle).convert("RGB")
    return rgb_image
    
def pil_loader(imglist: list):
    """Load several images given as paths relative to DATADIR.

    Args:
        imglist: relative paths such as "cats/cat.1.jpg".

    Returns:
        A list of RGB PIL images, in the same order as `imglist`.
    """
    return [pil_loadimg(DATADIR + "/" + rel_path) for rel_path in imglist]

In [None]:
# Load the pretrained checkpoint and replace its classifier with a
# single-output head (num_labels=1). ignore_mismatched_sizes=True lets the
# differently-shaped head weights be freshly initialised instead of raising.
MODEL_CHECKPOINT = 'google/vit-base-patch16-224'

feature_extractor = ViTFeatureExtractor.from_pretrained(MODEL_CHECKPOINT)
model = TFViTForImageClassification.from_pretrained(
    MODEL_CHECKPOINT,
    num_labels=1,
    ignore_mismatched_sizes=True,
)

In [None]:
# Sanity check on one image. NOTE: the model was loaded with num_labels=1, so
# its head emits a single logit per image (not 1000 ImageNet scores), and that
# head is newly initialised and not yet trained — expect an arbitrary answer.
image = pil_loadimg(DATADIR + "/cats/cat.1.jpg")
inputs = feature_extractor(images=image, return_tensors="tf")
outputs = model(**inputs)
logits = outputs.logits
# argmax over a length-1 axis always returns 0, so threshold the sigmoid of the
# single logit instead; labels follow the training convention 0 = cat, 1 = dog.
prob_dog = float(tf.math.sigmoid(logits)[0, 0])
print("Predicted class:", "dog" if prob_dog > 0.5 else "cat", f"(p(dog)={prob_dog:.3f})")

In [None]:
# One cat and one dog image for a quick batched forward pass.
sample_files = ["cats/cat.1.jpg", "dogs/dog.1.jpg"]
images = pil_loader(sample_files)

In [None]:
# Batched forward pass. The head has a single output unit (num_labels=1), so
# argmax would always pick index 0; instead threshold the sigmoid of the one
# logit per image (0 = cat, 1 = dog, matching the training labels).
inputs = feature_extractor(images=images, return_tensors="tf")
outputs = model(**inputs)
logits = outputs.logits
probs_dog = tf.math.sigmoid(logits)[:, 0].numpy()
for prob in probs_dog:
    print("Predicted class:", "dog" if prob > 0.5 else "cat", f"(p(dog)={prob:.3f})")

In [None]:
# Relative "<class>/<filename>" paths for each class directory, then sorted
# naturally (e.g. cat.2 before cat.10) so the train/holdout split is stable.
cats = ["cats/" + fname for fname in os.listdir(DATADIR + "/cats")]
dogs = ["dogs/" + fname for fname in os.listdir(DATADIR + "/dogs")]
cats_sorted = natsorted(cats)
dogs_sorted = natsorted(dogs)

In [None]:
# Number of images per class used for fine-tuning (first N in natural order).
N_TRAIN_PER_CLASS = 500

train_cats = cats_sorted[:N_TRAIN_PER_CLASS]
train_dogs = dogs_sorted[:N_TRAIN_PER_CLASS]
images = pil_loader(train_cats + train_dogs)
# Derive the labels (0 = cat, 1 = dog) from the actual slices so they always
# line up with `images`, even if a class directory has fewer files than requested.
labels = [0] * len(train_cats) + [1] * len(train_dogs)
inputs = feature_extractor(images=images, return_tensors="tf")

In [None]:
BATCH_SIZE = 32

# Pair each image's features with its label, shuffle across the whole
# training set (buffer = number of examples), then batch.
dataset_train = (
    tf.data.Dataset.from_tensor_slices((inputs.data, labels))
    .shuffle(len(labels))
    .batch(BATCH_SIZE)
)

In [None]:
LR = 1e-5

optimizer = tf.keras.optimizers.Adam(learning_rate=LR)
# The classification head returns raw logits (no sigmoid applied), so the
# loss must apply the sigmoid itself.
loss = tf.keras.losses.BinaryCrossentropy(from_logits=True)
# On logits the decision boundary is 0 (sigmoid(0) = 0.5), not 0.5.
# name='accuracy' keeps the history key used by the plotting cell.
metric = tf.keras.metrics.BinaryAccuracy(threshold=0.0, name='accuracy')

model.compile(optimizer=optimizer, loss=loss, metrics=[metric])

# summary() prints the table itself and returns None, so no print() wrapper.
model.summary()

In [None]:
%%time

EPOCHS = 5

history = model.fit(dataset_train,
                    epochs=EPOCHS, verbose=2) #callbacks=callbacks)

In [None]:
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 3))

# One panel per tracked quantity. Only training curves exist because
# model.fit() above is not given any validation data.
for ax, key in ((ax1, 'loss'), (ax2, 'accuracy')):
    ax.plot(history.epoch, history.history[key], label='training')
    ax.set_title(key)
    ax.set_xlabel('epoch')
    ax.legend(loc='best')

In [None]:
# Inspect which tensors the feature extractor produced for the training images
# (these are the features paired with `labels` in dataset_train).
inputs.data.keys()

In [None]:
# Spot-check a few individual images. NOTE(review): these come from the same
# DATADIR as the training images — confirm they fall outside the first
# N images per class used for training, otherwise this is not a held-out check.
test_files = ["cats/cat.900.jpg", "dogs/dog.900.jpg", "dogs/dog.901.jpg"]
# Labels derived from the class directory (0 = cat, 1 = dog) so they cannot
# drift out of sync with the file list.
test_labels = [0 if fname.startswith("cats/") else 1 for fname in test_files]

testimages = pil_loader(test_files)
testinputs = feature_extractor(images=testimages, return_tensors="tf")
dataset_test = tf.data.Dataset.from_tensor_slices((testinputs.data, test_labels)).batch(BATCH_SIZE)

outputs = model(**testinputs)
logits = outputs.logits
# Single-logit head: argmax would always return 0, so threshold the sigmoid.
probs_dog = tf.math.sigmoid(logits)[:, 0].numpy()
for fname, prob in zip(test_files, probs_dog):
    print(fname, "-> predicted class:", "dog" if prob > 0.5 else "cat", f"(p(dog)={prob:.3f})")

In [None]:
# Compiled loss followed by the compiled metric(s) on the small spot-check set.
test_scores = model.evaluate(x=dataset_test, verbose=2)

In [None]:
# Raw model outputs for every image in dataset_test.
p = model.predict(x=dataset_test, verbose=2)

In [None]:
# Display the predictions from the previous cell. NOTE(review): these are the
# classification head's outputs (logits, not probabilities) — apply
# tf.math.sigmoid to get p(dog) under the 0 = cat / 1 = dog labelling; the
# exact container type returned by predict() depends on the Transformers version.
p