[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/18lcAtxvFn51-newA-r3ZW1wcimq3PsOT?usp=sharing)

# Advanced AI: Transformers for Computer Vision

## Setup

In [None]:
!pip install transformers datasets evaluate gradio

In [None]:
!wget https://github.com/jonfernandes/flowers-dataset/raw/main/flower_photos.tgz
!tar -xvf flower_photos.tgz


In [None]:
!ls flower_photos

In [None]:
from datasets import load_dataset

# Build the dataset from the archive downloaded in the setup cell instead of fetching the same
# ~229 MB file from GitHub a second time. imagefolder derives the 'label' feature from the
# class sub-directory names inside the archive.
ds = load_dataset('imagefolder', data_files='flower_photos.tgz')
ds

In [None]:
# Preview the first five training images; display() renders each PIL image inline.
for example in ds['train'].select(range(5)):
  display(example['image'])

In [None]:
# Class names of the 'label' column, in label-index order (position i is class id i).
labels = ds['train'].features['label'].names
labels

In [None]:
# Hold out 10% of the full training split (fixed seed for reproducibility) for validation.
ds_train_validation = ds['train'].train_test_split(test_size=0.1, seed=1, shuffle=True)
ds_train_validation

In [None]:
# train_test_split calls its held-out part 'test'; rename it to 'validation'.
ds_train_validation['validation'] = ds_train_validation.pop('test')
ds_train_validation

In [None]:
# Replace ds['train'] with the 90% remainder and add the 'validation' split (mutates ds in place).
# NOTE(review): not idempotent — once this has run, re-running any split cell re-splits the already
# reduced ds['train']. Restart and re-run from the load cell to get the intended splits back.
ds.update(ds_train_validation)
ds

In [None]:
# Carve a further 10% out of the remaining training data as the test split.
ds_train_test = ds['train'].train_test_split(test_size=0.1, seed=1, shuffle=True)
ds_train_test

In [None]:
# Overwrite ds['train'] and add ds['test']; ds now holds train / validation / test.
ds.update(ds_train_test)
ds

## Using a pre-trained model without fine-tuning

In [None]:
from transformers import AutoModelForImageClassification, AutoFeatureExtractor
import torch
# Pre-trained checkpoint used as-is, with its original classification head.
model_id = 'google/vit-base-patch16-224'

# Use the GPU when available; model inputs must later be moved to this same device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = AutoModelForImageClassification.from_pretrained(model_id).to(device)
# Inference mode (e.g. dropout off); the returned model is also shown as the cell output.
model.eval()


In [None]:
# Pre-processor paired with the checkpoint; its size and image_mean / image_std are reused
# below to build the torchvision transforms.
feature_extractor = AutoFeatureExtractor.from_pretrained(model_id)
feature_extractor

In [None]:
# Index of the training example used in the single-image demos.
train_image_id = 3
one_image = ds['train'][train_image_id]['image']
one_image

In [None]:
# Turn the PIL image into model-ready PyTorch tensors (these stay on the CPU).
inp = feature_extractor(images=one_image, return_tensors='pt')
inp

In [None]:
# Demonstration: `inp` lives on the CPU. When the model is on a GPU the forward pass fails with a
# device-mismatch RuntimeError; it is caught and printed so Restart & Run All keeps going.
try:
  outp = model(**inp)
except RuntimeError as err:
  outp = None
  print(f'Expected failure with the model on {device}: {err}')
outp

In [None]:
# Same call, but the inputs are moved to the model's device first, so it works on CPU and GPU.
inp = feature_extractor(images=one_image, return_tensors='pt').to(device)
outp = model(**inp)
outp

In [None]:
# Logits: one row per image in the batch, one column per class of the checkpoint's head.
outp.logits.shape

In [None]:
# Index of the highest-scoring class for each image.
torch.argmax(outp.logits, dim=1)

In [None]:
# .item() turns the single-element tensor into a plain Python int.
pred = torch.argmax(outp.logits, dim=1).item()
pred

In [None]:
# The config carries the id2label / label2id mappings of the checkpoint's classes.
model.config

In [None]:
# Human-readable name of the predicted class.
model.config.id2label[pred]

In [None]:
# Does the pre-trained label set contain 'daisy'? The next section replaces the head with the flower classes.
'daisy' in model.config.label2id

## Defining a model

In [None]:
# class index -> class name, e.g. 0 -> labels[0]
id2label = dict(enumerate(labels))
id2label

In [None]:
# Inverse mapping: class name -> class index.
label2id = dict(zip(labels, range(len(labels))))
label2id

In [None]:
# Reload the checkpoint with a classification head sized for the flower classes.
# ignore_mismatched_sizes=True lets from_pretrained drop the original head (whose output size
# differs from num_labels) and initialize a new one instead of raising.
model = AutoModelForImageClassification.from_pretrained(model_id,
                                                        num_labels=len(labels),
                                                        id2label=id2label,
                                                        label2id=label2id,
                                                        ignore_mismatched_sizes=True
                                                        )

## Pre-processing images

In [None]:
import torchvision

from torchvision.transforms import (
    Compose,
    Normalize,
    RandomHorizontalFlip,
    RandomResizedCrop,
    ToTensor,
    Resize,
    CenterCrop
)

In [None]:
# Normalize with the same per-channel mean / std as the checkpoint's own pre-processor.
normalize = Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std)

In [None]:
def _crop_size(size):
  """Return the (height, width) described by a feature extractor's ``size`` attribute.

  Older transformers releases store ``size`` as an int, newer ones as a dict such as
  ``{'height': 224, 'width': 224}`` or ``{'shortest_edge': 224}``. torchvision transforms take
  an int or an (h, w) pair but not a dict, so normalize every form to (h, w).

  Raises:
    ValueError: if ``size`` has none of the recognised forms.
  """
  if isinstance(size, int):
    return (size, size)
  if isinstance(size, (tuple, list)) and len(size) == 2:
    return (size[0], size[1])
  if isinstance(size, dict):
    if 'height' in size and 'width' in size:
      return (size['height'], size['width'])
    if 'shortest_edge' in size:
      return (size['shortest_edge'], size['shortest_edge'])
  raise ValueError(f'Unsupported feature extractor size: {size!r}')

crop_size = _crop_size(feature_extractor.size)

# Training: random crop + horizontal flip for augmentation, then tensor conversion and normalization.
train_transform = Compose(
    [
     RandomResizedCrop(crop_size),
     RandomHorizontalFlip(),
     ToTensor(),
     normalize
    ]
)

# Evaluation: deterministic — scale the shorter edge so the crop always fits, then center-crop.
validation_transform = Compose(
        [
            Resize(max(crop_size)),
            CenterCrop(crop_size),
            ToTensor(),
            normalize,
        ]
    )

def _with_pixel_values(images, transform):
  # Shared body of both wrappers: add a 'pixel_values' column built from the RGB-converted 'image' column.
  images["pixel_values"] = [transform(picture.convert("RGB")) for picture in images["image"]]
  return images

def train_transform_images(images):
  """Batch transform for the training split: applies train_transform to every image."""
  return _with_pixel_values(images, train_transform)

def validation_transform_images(images):
  """Batch transform for the validation / test splits: applies validation_transform to every image."""
  return _with_pixel_values(images, validation_transform)

In [None]:
# Per-split on-the-fly transforms: augmentation for training, deterministic crops otherwise.
_split_transforms = {
    'train': train_transform_images,
    'validation': validation_transform_images,
    'test': validation_transform_images,
}
transformed_ds = ds.with_transform(train_transform_images)
for _split_name, _transform_fn in _split_transforms.items():
  transformed_ds[_split_name] = ds[_split_name].with_transform(_transform_fn)

## A transformed image

In [None]:
# The same training example as `one_image` (index train_image_id), before any transform.
sample_image = ds['train'][train_image_id]['image']
sample_image

In [None]:
# Re-run cell multiple times: each run draws a new random crop / flip.
import matplotlib.pyplot as plt
import torch

# Convert to RGB first, exactly as the dataset transforms do.
transformed_sample_image = train_transform(sample_image.convert("RGB"))

# The tensor is (channels, height, width) and normalized, so many values fall outside [0, 1] and
# imshow would clip them. Undo the normalization (and clamp) to show the true colors.
_mean = torch.tensor(feature_extractor.image_mean).view(-1, 1, 1)
_std = torch.tensor(feature_extractor.image_std).view(-1, 1, 1)
_display_image = (transformed_sample_image * _std + _mean).clamp(0, 1)
plt.imshow(_display_image.permute(1, 2, 0))
plt.title('Randomly augmented training image')
plt.axis('off');

### Getting images in the correct format

**A batch of four images**

In [None]:
# Each transformed example keeps its original columns plus the new 'pixel_values' tensor.
four_images = [transformed_ds['train'][i] for i in range(4)]
four_images

In [None]:
# The crop gives every image tensor the same shape, which is what makes batching possible.
print(four_images[0]['pixel_values'].shape, four_images[1]['pixel_values'].shape, four_images[2]['pixel_values'].shape, four_images[3]['pixel_values'].shape)

In [None]:
# Labels as a plain Python list of ints.
four_images_labels = [image['label'] for image in four_images]
four_images_labels

In [None]:
import torch
# The same labels as a 1-D tensor — the form collate_fn later returns under 'labels'.
four_images_labels = torch.tensor([image['label'] for image in four_images])
four_images_labels

In [None]:
# Demonstration: torch.tensor() cannot build a single tensor from a list of multi-element
# tensors. The error is caught and printed so the notebook still runs top to bottom.
try:
  four_images_pixel_values = torch.tensor([image['pixel_values'] for image in four_images])
  display(four_images_pixel_values)
except (ValueError, TypeError, RuntimeError) as err:
  print(f'{type(err).__name__}: {err}')

In [None]:
# torch.cat joins along the existing first (channel) dimension: no batch dimension is created.
four_images_pixel_values = torch.cat([image['pixel_values'] for image in four_images])
four_images_pixel_values

In [None]:
# 4 images x 3 RGB channels concatenated into one channel axis — not a batch.
four_images_pixel_values.shape

In [None]:
# torch.stack adds a new leading batch dimension: (batch, channels, height, width).
four_images_pixel_values = torch.stack([image['pixel_values'] for image in four_images])
four_images_pixel_values.shape

In [None]:
from torch.utils.data import DataLoader

def collate_fn(images):
  """Batch a list of transformed examples into model inputs.

  Each example must provide an integer 'label' and a 'pixel_values' tensor, all of one shape.
  Returns {'pixel_values': tensors stacked on a new leading batch axis, 'labels': 1-D tensor}.
  """
  label_tensor = torch.tensor([example['label'] for example in images])
  pixel_batch = torch.stack([example['pixel_values'] for example in images])
  return {'pixel_values': pixel_batch, 'labels': label_tensor}

# Stand-alone loaders for inspecting batches (batch size 4). The Trainer below is given the
# datasets and collate_fn directly and uses its own batch size. Only the training loader shuffles.
train_dataloader = DataLoader(transformed_ds['train'], batch_size=4, collate_fn=collate_fn, shuffle=True)
validation_dataloader = DataLoader(transformed_ds['validation'], batch_size=4, collate_fn=collate_fn, shuffle=False)
test_dataloader = DataLoader(transformed_ds['test'], batch_size=4, collate_fn=collate_fn, shuffle=False)

In [None]:
# Pull one batch and print the shape of each tensor it contains.
batch = next(iter(train_dataloader))
for name in batch:
  print(name, batch[name].shape)

## Training arguments

In [None]:
from transformers import TrainingArguments, Trainer

import dataclasses

batch_size=32
metric_name = "accuracy"
model_name = 'vit-base-patch16-224-finetuned-flower'

# Evaluate and checkpoint once per epoch. 3 epochs at batch size 32 only reach the default
# 500-step eval/save interval if the training split has more than ~5,300 images; flower_photos
# has ~3.7k images before splitting. With "steps", evaluation and checkpointing would therefore
# never happen during training, and load_best_model_at_end would have nothing to restore.
# save_strategy must match the evaluation strategy when load_best_model_at_end=True.

# Newer transformers releases renamed `evaluation_strategy` to `eval_strategy`; use whichever
# name the installed TrainingArguments dataclass defines.
_training_arg_names = {field.name for field in dataclasses.fields(TrainingArguments)}
_eval_strategy_arg = 'eval_strategy' if 'eval_strategy' in _training_arg_names else 'evaluation_strategy'

args = TrainingArguments(
    model_name,
    learning_rate=5e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=3,
    load_best_model_at_end=True,
    metric_for_best_model=metric_name,
    # Keep 'image' and 'label': the on-the-fly transform reads 'image', collate_fn reads 'label'.
    remove_unused_columns=False,
    logging_dir='./logs',
    # Log the training loss often enough to draw a curve in TensorBoard.
    logging_steps=10,
    save_strategy="epoch",
    push_to_hub=True,
    **{_eval_strategy_arg: "epoch"},
)

In [None]:
from huggingface_hub import notebook_login
# Log in to the Hugging Face Hub; required because push_to_hub=True uploads the model.
notebook_login()

In [None]:
# Let git remember credentials, presumably so pushes to the Hub repository do not prompt again.
# NOTE(review): the "store" helper writes credentials in plain text to ~/.git-credentials —
# fine on a throwaway Colab VM, avoid on shared machines.
!git config --global credential.helper store

## Model training

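From the [evaluate documentation](https://huggingface.co/docs/evaluate/a_quick_tour#compute), a metric is computed from the ground-truth labels (`references`) and the model's predicted labels (`predictions`):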

```
metric.compute(
          references=..., 
          predictions=...)
```

In [None]:
import evaluate
import numpy as np

metric = evaluate.load('accuracy')

def compute_metrics(batch):
  """Accuracy of the arg-max class against the reference labels.

  `batch` is the prediction object passed in by the Trainer: `.predictions` holds per-class
  scores (classes on axis 1) and `.label_ids` the true class ids. Returns the metric's result.
  """
  predicted_ids = np.argmax(batch.predictions, axis=1)
  return metric.compute(predictions=predicted_ids, references=batch.label_ids)

In [None]:
# Wire everything together. Passing the feature extractor as `tokenizer` makes the Trainer save
# it with the model, which the inference cells rely on (AutoFeatureExtractor.from_pretrained on
# the fine-tuned repo). collate_fn turns each list of examples into a batch of tensors.
# NOTE(review): recent transformers versions prefer `processing_class=` over `tokenizer=` — confirm
# against the installed version.
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=transformed_ds['train'],
    eval_dataset=transformed_ds['validation'],
    tokenizer=feature_extractor,
    data_collator=collate_fn,
    compute_metrics=compute_metrics
)

In [None]:
# Live training curves; 'logs/' matches logging_dir='./logs' in the TrainingArguments.
%load_ext tensorboard
%tensorboard --logdir logs/

In [None]:
# Fine-tune, then save the final model to args.output_dir (and push it, since push_to_hub=True).
trainer.train()
trainer.save_model()

In [None]:
# Training-set accuracy, measured with the deterministic validation transform rather than the
# random augmentation attached to transformed_ds['train'], so the score is reproducible and
# comparable with validation/test. The prefix reports it as train_* instead of eval_*.
trainer.evaluate(ds['train'].with_transform(validation_transform_images), metric_key_prefix='train')

In [None]:
# Validation-set metrics, reported with the default eval_ prefix.
trainer.evaluate(eval_dataset=transformed_ds['validation'])

In [None]:
# Held-out test-set metrics, reported as test_* so they cannot be mistaken for validation scores.
trainer.evaluate(transformed_ds['test'], metric_key_prefix='test')

## Inference in the notebook

In [None]:
# Second-to-last image of the test split, used by the inference examples below.
test_image = ds['test'][-2]['image']
test_image

In [None]:
import torch
from functools import lru_cache
from transformers import AutoModelForImageClassification, AutoFeatureExtractor

model_id = 'jonathanfernandes/vit-base-patch16-224-finetuned-flower'

@lru_cache(maxsize=None)
def _load_classifier(checkpoint):
  """Load the model, its pre-processor and the target device once per checkpoint."""
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  model = AutoModelForImageClassification.from_pretrained(checkpoint).to(device)
  model.eval()
  feature_extractor = AutoFeatureExtractor.from_pretrained(checkpoint)
  return model, feature_extractor, device

def classify_image(image):
  """Return the name of the most likely class for a single image.

  The model is loaded on the first call and reused afterwards instead of being reloaded per image.
  """
  model, feature_extractor, device = _load_classifier(model_id)
  inp = feature_extractor(image, return_tensors='pt').to(device)
  with torch.no_grad():  # inference only: no autograd graph needed
    outp = model(**inp)
  pred = torch.argmax(outp.logits, dim=-1).item()
  return model.config.id2label[pred]

classify_image(test_image)

In [None]:
import torch
from functools import lru_cache
from transformers import AutoModelForImageClassification, AutoFeatureExtractor

model_id = 'jonathanfernandes/vit-base-patch16-224-finetuned-flower'

@lru_cache(maxsize=None)
def _load_classifier(checkpoint):
  """Load the model, its pre-processor and the target device once per checkpoint."""
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  model = AutoModelForImageClassification.from_pretrained(checkpoint).to(device)
  model.eval()
  feature_extractor = AutoFeatureExtractor.from_pretrained(checkpoint)
  return model, feature_extractor, device

def classify_image(image):
  """Return a {class name: probability} dict for a single image.

  Class names come from the model's own config (id2label), so the function does not depend on
  the notebook's `labels` variable and can be reused outside this notebook.
  """
  model, feature_extractor, device = _load_classifier(model_id)
  inp = feature_extractor(image, return_tensors='pt').to(device)
  with torch.no_grad():  # inference only: no autograd graph needed
    outp = model(**inp)
  probs = torch.nn.functional.softmax(outp.logits, dim=-1)[0].cpu().numpy()
  return {model.config.id2label[i]: float(p) for i, p in enumerate(probs)}

classify_image(test_image)

In [None]:
import torch
from transformers import pipeline

model_id = 'jonathanfernandes/vit-base-patch16-224-finetuned-flower'

# Given a model id, the pipeline loads the matching pre-processor itself, so it is not loaded
# separately here (which also leaves the global `feature_extractor` from earlier cells untouched).
# device=0 selects the first GPU; -1 keeps the pipeline on the CPU.
image_classifier = pipeline('image-classification', model=model_id,
                            device=0 if torch.cuda.is_available() else -1)
image_classifier(test_image)

In [None]:
# Show the pipeline() docstring: supported tasks and arguments.
help(pipeline)

## Inference on your phone using Gradio

In [None]:
!wget https://github.com/jonfernandes/Advanced_AI_Transformers_for_Computer_Vision/raw/main/flower-1.jpg
!wget https://github.com/jonfernandes/Advanced_AI_Transformers_for_Computer_Vision/raw/main/flower-2.jpeg

--2022-11-15 07:24:27--  https://github.com/jonfernandes/Advanced_AI_Transformers_for_Computer_Vision/raw/main/flower-1.jpg
Resolving github.com (github.com)... 20.205.243.166
Connecting to github.com (github.com)|20.205.243.166|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://raw.githubusercontent.com/jonfernandes/Advanced_AI_Transformers_for_Computer_Vision/main/flower-1.jpg [following]
--2022-11-15 07:24:27--  https://raw.githubusercontent.com/jonfernandes/Advanced_AI_Transformers_for_Computer_Vision/main/flower-1.jpg
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.109.133, 185.199.110.133, 185.199.108.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.109.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 712963 (696K) [image/jpeg]
Saving to: ‘flower-1.jpg’


2022-11-15 07:24:27 (100 MB/s) - ‘flower-1.jpg’ saved [712963/712963]

--2022-11-15 07:24:27--  http

In [None]:
# Confirm the example images were downloaded into the working directory.
!ls -l

total 448164
drwxr-xr-x 2 root   root      4096 Nov 15 07:17 flagged
-rw-r--r-- 1 root   root    712963 Nov 15 07:24 flower-1.jpg
-rw-r--r-- 1 root   root    560399 Nov 15 07:24 flower-2.jpeg
drwxr-x--- 7 270850 5000      4096 Feb 10  2016 flower_photos
-rw-r--r-- 1 root   root 228813984 Nov 15 05:46 flower_photos.tgz
-rw-r--r-- 1 root   root 228813984 Nov 15 06:19 flower_photos.tgz.1
drwxr-xr-x 1 root   root      4096 Nov 11 14:32 sample_data


In [None]:
import torch
from functools import lru_cache
from transformers import AutoModelForImageClassification, AutoFeatureExtractor
import gradio as gr

model_id = 'jonathanfernandes/vit-base-patch16-224-finetuned-flower'

@lru_cache(maxsize=None)
def _load_classifier(checkpoint):
  """Load the model, its pre-processor and the target device once per checkpoint."""
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  model = AutoModelForImageClassification.from_pretrained(checkpoint).to(device)
  model.eval()
  feature_extractor = AutoFeatureExtractor.from_pretrained(checkpoint)
  return model, feature_extractor, device

def classify_image(image):
  """Gradio callback: return {class name: probability} for the uploaded image.

  The model is loaded once and reused across requests. Class names come from the model config,
  so this cell does not depend on variables defined earlier in the notebook.
  """
  model, feature_extractor, device = _load_classifier(model_id)
  inp = feature_extractor(image, return_tensors='pt').to(device)
  with torch.no_grad():  # inference only: no autograd graph needed
    outp = model(**inp)
  probs = torch.nn.functional.softmax(outp.logits, dim=-1)[0].cpu().numpy()
  return {model.config.id2label[i]: float(p) for i, p in enumerate(probs)}

# Keep a handle on the Interface itself; launch() returns server/URL details, not the Interface.
interface = gr.Interface(fn=classify_image,
                         inputs='image',
                         examples=['flower-1.jpg', 'flower-2.jpeg'],
                         outputs='label')
# debug=True keeps the cell running (interrupt to stop); share=True creates a temporary public URL.
interface.launch(debug=True, share=True)

Colab notebook detected. This cell will run indefinitely so that you can see errors and logs. To turn off, set debug=False in launch().
Running on public URL: https://c269db8febe09b89.gradio.app

This share link expires in 72 hours. For free permanent hosting and GPU upgrades (NEW!), check out Spaces: https://huggingface.co/spaces


Keyboard interruption in main thread... closing server.
