![Practicum AI Logo image](https://github.com/PracticumAI/practicumai.github.io/blob/main/images/logo/PracticumAI_logo_250x50.png?raw=true) <img src='https://github.com/PracticumAI/deep_learning/blob/main/images/practicumai_deep_learning.png?raw=true' alt='Practicum AI: Deep Learning Foundations icon' align='right' width=50>
***

# ViT Bonus Notebook

This bonus notebook demonstrates how to use a Vision Transformer (ViT) model for image classification. The dataset is the [Fruits Detection dataset](https://www.kaggle.com/datasets/lakshaytyagi01/fruit-detection/data) from Kaggle, which contains images of fruits along with their bounding boxes. The goal is to classify the fruits in the images.

As before, the dataset was found on Kaggle. [Check out the dataset information](https://www.kaggle.com/datasets/lakshaytyagi01/fruit-detection/data)

<img src="notebook_images/fruits_detection_dataset-cover.jpg" 
        alt="Image of fruits and bounding boxes from the dataset cover image" 
        width="1000" 
        height="600" 
        style="display: block; margin: 0 auto" />

This notebook builds a `ViT` model directly in PyTorch, using the Vision Transformer architecture for image classification. Pretrained ViT models are also available in the `transformers` library, developed by Hugging Face, which provides state-of-the-art models for Natural Language Processing (NLP) and Computer Vision (CV).



## 1.  Import the libraries we will use

In [None]:
# Install torchsummary
%pip install torchsummary

In [None]:
# Importing supporting libraries
import os
import sys
import json
import random
import pathlib
import requests
import tarfile
import zipfile
import time

import cv2


# Importing the necessary libraries for the dataset
import numpy as np # NumPy
from PIL import Image # Python Imaging Library
import pandas as pd # Pandas
from sklearn.model_selection import train_test_split # Train-Test Split
from sklearn.preprocessing import StandardScaler # Standard Scaler
from sklearn.datasets import fetch_openml # Fetch OpenML
from sklearn.metrics import accuracy_score # Accuracy Score

# Importing the necessary libraries for the visualization
import matplotlib.pyplot as plt # Matplotlib
%matplotlib inline



# Importing the necessary libraries for ViT
import torch # PyTorch
import torch.nn as nn # Neural Network Module
import torch.nn.functional as F # Functional Module that contains activation functions
import torchvision # TorchVision
from einops import rearrange # Rearrange function that rearranges the dimensions of the input tensor
from einops.layers.torch import Rearrange # Rearrange Layer that rearranges the dimensions of the input tensor
from torch import einsum # Einsum function that evaluates Einstein-summation expressions (sums of products over labeled tensor indices)
from torchvision import transforms # Transforms that are used for image preprocessing and data augmentation
from torchsummary import summary # Summary function that displays the architecture of the model


## 2. Getting the data

As we did in Notebook 1, we will have to download the dataset. The file is stored as a compressed `.tar.gz` archive, so we will need to extract it.

In [None]:
def download_file(url="https://www.dropbox.com/s/x70hm8mxqhe7fa6/bee_vs_wasp.tar.gz?dl=1", filename="bee_vs_wasp.tar.gz"):

    # Download the file using requests, streaming it so it is not held in memory
    response = requests.get(url, stream=True)
    # Stop on HTTP errors instead of saving an error page as the archive
    response.raise_for_status()

    # Create a file object and write the response content in chunks
    with open(filename, "wb") as f:
        for chunk in response.iter_content(chunk_size=8192):
            f.write(chunk)

    # The with-block only exits once the whole file is written, so no waiting is needed
    print(f"Downloaded {filename} successfully.")

def extract_file(filename, data_folder):
    # Check if the file is a tar file
    if tarfile.is_tarfile(filename):
        # Open the tar file; the context manager closes it even if extraction fails
        with tarfile.open(filename, "r:gz") as tar:
            # Extract all the files to the data folder
            tar.extractall(data_folder)
        # Print a success message
        print(f"Extracted {filename} to {data_folder} successfully.")
    else:
        # Print an error message
        print(f"{filename} is not a valid tar file.")
    
def manage_data(folder_name='bee_vs_wasp'):
    '''Try to find the data for the exercise and return the path'''
    
    # Check common paths of where the data might be on different systems
    likely_paths= [os.path.normpath(f'/blue/practicum-ai/share/data/{folder_name}'),
                   os.path.normpath(f'/project/scinet_workshop2/data/{folder_name}'),
                   os.path.join('data', folder_name),
                   os.path.normpath(folder_name)]
    
    for path in likely_paths:
        if os.path.exists(path):
            print(f'Found data at {path}.')
            return path

    answer = input(f'Could not find data in the common locations. Do you know the path? (yes/no): ')

    if answer.lower() == 'yes':
        path = os.path.join(os.path.normpath(input('Please enter the path to the data folder: ')),folder_name)
        if os.path.exists(path):
            print(f'Thanks! Found your data at {path}.')
            return path
        else:
            print(f'Sorry, that path does not exist.')
    
    answer = input('Do you want to download the data? (yes/no): ')

    if answer.lower() == 'yes':

        ''' Check and see if the downloaded data is inside the .gitignore file, and adds them to the list of files to ignore if not. 
        This is to prevent the data from being uploaded to the repository, as the files are too large for GitHub.'''
        
        if os.path.exists('.gitignore'):
            with open('.gitignore', 'r') as f:
                ignore = f.read().split('\n')
        else:
            # No .gitignore yet: start with an empty list (the file is written below)
            ignore = []

        # Check if the .gz file is in the ignore list
        if 'bee_vs_wasp.tar.gz' not in ignore:
            ignore.append('bee_vs_wasp.tar.gz')
            
        # Check if the data/ folder is in the ignore list
        if 'data/' not in ignore:
            ignore.append('data/')

        # Write the updated ignore list back to the .gitignore file
        with open('.gitignore', 'w') as f:
            f.write('\n'.join(ignore))

        print("Updated .gitignore file.")
        print('Downloading data, this may take a minute.')
        download_file()
        print('Data downloaded, unpacking')
        extract_file("bee_vs_wasp.tar.gz", "data")
        print('Data downloaded and unpacked. Now available at data/bee_vs_wasp.')
        return os.path.normpath('data/bee_vs_wasp')   

    print('Sorry, I cannot find the data. Please download it manually from https://www.dropbox.com/s/x70hm8mxqhe7fa6/bee_vs_wasp.tar.gz and unpack it to the data folder.')      


data_path = manage_data()

## 3. Explore the dataset

We will take a look at the dataset to see what it contains. We will also look at the annotations file, which contains the bounding box information for each image.

In [None]:
def check_data(data_path):
    '''Check the contents of the data folder'''
    
    # Check the contents of the data folder
    data_folder = pathlib.Path(data_path)
    # Check if the data folder exists
    if data_folder.exists():
        # Create a list of the contents of the data folder
        data_folder_contents = list(data_folder.glob('*/*'))
        # Create a list of the classes in the data folder
        classes = [item.name for item in data_folder.glob('*') if item.is_dir()]
        # Print the classes in the data folder
        print(f'The data folder contains the following classes: {classes}')
        # Create a dictionary to store the class distribution
        class_distribution = {}
        # Create a list to store the paths of the images
        image_paths = []
        # Create a list to store the labels of the images
        labels = []
        # Loop through the classes in the data folder
        for class_name in classes:
            # Create a list of the contents of the class folder
            class_folder_contents = list(data_folder.glob(f'{class_name}/*'))
            # Filter out any directories that start with ".ipynb_checkpoints"
            class_folder_contents = [item for item in class_folder_contents if not item.name.startswith('.ipynb_checkpoints')]
            # Print the number of images in the class folder
            print(f'The class {class_name} contains {len(class_folder_contents)} images.')
            # Update the class distribution dictionary with the class name and the number of images
            class_distribution[class_name] = len(class_folder_contents)
            # Keep up to four sample images from this class
            samples = class_folder_contents[:4]
            image_paths.extend(samples)
            # Add one label per sampled image so paths and labels stay aligned
            labels.extend([class_name] * len(samples))
        # Create a figure with one row per class and four columns
        fig, axes = plt.subplots(len(classes), 4, figsize=(20, 5 * len(classes)), squeeze=False)
        # Turn off every axis up front so unused cells stay blank
        for ax in axes.flat:
            ax.axis('off')
        # Loop through the image paths and labels, filling the grid in order
        for ax, image_path, label in zip(axes.flat, image_paths, labels):
            # Load and display the image
            ax.imshow(Image.open(image_path))
            ax.set_title(label)
        # Show the figure
        plt.show()
        # Create a figure to display the class distribution
        fig, ax = plt.subplots()
        # Plot the class distribution as a bar chart
        ax.bar(class_distribution.keys(), class_distribution.values())
        # Set the title of the plot
        ax.set_title('Class Distribution')
        # Set the x-label of the plot
        ax.set_xlabel('Class')

        # Set the y-label of the plot
        ax.set_ylabel('Number of Images')
        # Show the plot
        plt.show()
    else:
        # Print an error message if the data folder does not exist
        print(f'The data folder {data_folder} does not exist.')

check_data(data_path)

There's a lot to unpack here! First, depending on the images sampled, we can see that the images are of different sizes and have different numbers of fruits. Other things to note:
- Some of the images don't really contain fruits...
- The annotation files have the same names as the images, but with a `.txt` extension.
- Each annotation file contains the class ID of each fruit (0 corresponds to 'Apple', etc.) and its bounding box coordinates.
- The bounding box coordinates are in the format (x_min, y_min, x_max, y_max).
- The bounding box coordinates are normalized, meaning that they are scaled to be between 0 and 1. This is a common practice in object detection tasks.
- The bounding boxes in some of the images are not very accurate. This is a common problem in object detection datasets.
- The dataset is very imbalanced, with a lot more oranges than other fruits.
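
To make the normalized coordinates concrete, here is a minimal sketch that converts one annotation line to pixel coordinates. It assumes the layout described in the list above (`class_id x_min y_min x_max y_max`, with all coordinates between 0 and 1). The function name `to_pixel_box` and the sample line are made up for illustration, so check a real `.txt` file before relying on the column order.

```python
def to_pixel_box(annotation_line, img_width, img_height):
    """Convert one normalized annotation line to (class_id, pixel box).

    Assumes the line reads 'class_id x_min y_min x_max y_max' with the
    coordinates scaled to the range 0-1, as described in the text above.
    """
    parts = annotation_line.split()
    class_id = int(parts[0])
    x_min, y_min, x_max, y_max = (float(value) for value in parts[1:5])
    # Multiply horizontal values by the width and vertical values by the height
    box = (
        round(x_min * img_width),
        round(y_min * img_height),
        round(x_max * img_width),
        round(y_max * img_height),
    )
    return class_id, box


# Hypothetical annotation line for a 640x480 image
class_id, box = to_pixel_box("0 0.25 0.5 0.75 1.0", img_width=640, img_height=480)
print(class_id, box)  # 0 (160, 240, 480, 480)
```

Because the coordinates are stored as fractions of the image size, the same annotation stays valid if the image is resized.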

## 4. Build the model

In [None]:
# Set up our ViT model architecture

# Define the Patch Embedding Layer
class PatchEmbedding(nn.Module):
    def __init__(self, in_channels, patch_size, emb_size):
        super(PatchEmbedding, self).__init__()
        self.patch_size = patch_size
        self.projection = nn.Conv2d(in_channels, emb_size, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        # Project each patch to an embedding: (batch, emb_size, H/patch_size, W/patch_size)
        x = self.projection(x)
        # Flatten the grid of patches into a sequence: (batch, num_patches, emb_size)
        x = rearrange(x, 'b e h w -> b (h w) e')
        return x
    
# Define the Multi-Head Self-Attention Layer

class MultiHeadAttention(nn.Module):
    def __init__(self, emb_size, num_heads, dropout=0.0):
        super(MultiHeadAttention, self).__init__()
        self.emb_size = emb_size
        self.num_heads = num_heads
        self.keys = nn.Linear(emb_size, emb_size)
        self.queries = nn.Linear(emb_size, emb_size)
        self.values = nn.Linear(emb_size, emb_size)
        self.fc_out = nn.Linear(emb_size, emb_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # Split the input tensor x into num_heads pieces
        queries = rearrange(self.queries(x), 'b n (h d) -> b h n d', h=self.num_heads)
        keys = rearrange(self.keys(x), 'b n (h d) -> b h n d', h=self.num_heads)
        values = rearrange(self.values(x), 'b n (h d) -> b h n d', h=self.num_heads)
        
        # Perform scaled dot-product attention, scaling by the square root of the per-head dimension
        energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys) / (queries.shape[-1] ** 0.5)
        attention = F.softmax(energy, dim=-1)
        attention = self.dropout(attention)
        
        # Apply the attention to the values
        out = torch.einsum('bhal, bhlv -> bhav', attention, values)
        out = rearrange(out, 'b h n d -> b n (h d)')
        out = self.fc_out(out)
        return out
    
# Define the Multi-Layer Perceptron (MLP) Layer

class MLP(nn.Module):
    def __init__(self, emb_size, mlp_hidden_dim, dropout=0.0):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(emb_size, mlp_hidden_dim)
        self.fc2 = nn.Linear(mlp_hidden_dim, emb_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = F.gelu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return x
    
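# Define the Transformer Encoder Block (layer norm before each sub-layer, with residual connections)

class Transformer(nn.Module):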
    def __init__(self, emb_size, num_heads, mlp_hidden_dim, dropout=0.0):
        super(Transformer, self).__init__()
        self.attention = MultiHeadAttention(emb_size, num_heads, dropout)
        self.norm1 = nn.LayerNorm(emb_size)
        self.mlp = MLP(emb_size, mlp_hidden_dim, dropout)
        self.norm2 = nn.LayerNorm(emb_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = x + self.dropout(self.attention(self.norm1(x)))
        x = x + self.dropout(self.mlp(self.norm2(x)))
        return x
    
# Define the Vision Transformer (ViT) Model

class ViT(nn.Module):
    def __init__(self, in_channels, patch_size, emb_size, img_size, num_classes, num_layers, num_heads, mlp_hidden_dim, dropout=0.0):
        super(ViT, self).__init__()
        num_patches = (img_size // patch_size) ** 2
        self.patch_embedding = PatchEmbedding(in_channels, patch_size, emb_size)
        self.cls_token = nn.Parameter(torch.randn(1, 1, emb_size))
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, emb_size))
        self.transformer = nn.Sequential(*[
            Transformer(emb_size, num_heads, mlp_hidden_dim, dropout)
            for _ in range(num_layers)
        ])
        self.fc = nn.Linear(emb_size, num_classes)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = self.patch_embedding(x)
        b, n, _ = x.shape
        cls_tokens = self.cls_token.expand(b, -1, -1)
        x = torch.cat([cls_tokens, x], dim=1)
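        # Add the learned positional embeddings (out of place, which is safer for autograd)
        x = x + self.pos_embedding
        x = self.dropout(x)
        x = self.transformer(x)
        # Mean-pool over all tokens (class token included) to get one vector per image
        x = x.mean(dim=1)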
        x = self.fc(x)
        return x

## 5. Create the model


In [None]:
# Create a ViT model instance

# Define the model hyperparameters
in_channels = 3
patch_size = 16
emb_size = 768
img_size = 128
num_classes = 6
num_layers = 12
num_heads = 12
mlp_hidden_dim = 3072
dropout = 0.1


model = ViT(in_channels, patch_size, emb_size, img_size, num_classes, num_layers, num_heads, mlp_hidden_dim, dropout)
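
As a quick sanity check on these hyperparameters, the sketch below works out how many tokens the model sees per image and how wide each attention head is. It uses only the values defined in the cell above. The helper name `vit_shapes` is ours for illustration and is not part of any library.

```python
def vit_shapes(img_size, patch_size, emb_size, num_heads):
    """Return (num_patches, sequence_length, head_dim) for a square-image ViT."""
    if img_size % patch_size != 0:
        raise ValueError("img_size must be divisible by patch_size")
    if emb_size % num_heads != 0:
        raise ValueError("emb_size must be divisible by num_heads")
    patches_per_side = img_size // patch_size
    num_patches = patches_per_side ** 2
    sequence_length = num_patches + 1  # one extra position for the class token
    head_dim = emb_size // num_heads   # embedding width handled by each head
    return num_patches, sequence_length, head_dim


print(vit_shapes(img_size=128, patch_size=16, emb_size=768, num_heads=12))
# (64, 65, 64)
```

Halving `patch_size` quadruples the number of patches, and attention cost grows with the square of the sequence length, so patch size is one of the more expensive knobs to turn.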


## 6. Train the model

In [None]:
# Fit the model to the data

# Define the loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)

# Define the training function
def train(model, data_loader, criterion, optimizer, device):
    model.train()
    running_loss = 0.0
    for images, labels in data_loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    return running_loss / len(data_loader)

# Define the validation function
def validate(model, data_loader, criterion, device):
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in data_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            running_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return running_loss / len(data_loader), correct / total

# Define the test function
def test(model, data_loader, device):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in data_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return correct / total

# Set the device to GPU if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Move the model to the device
model.to(device)

# Define the batch size
batch_size = 32

# Load the dataset
transform = transforms.Compose([
    transforms.Resize((img_size, img_size)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

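# Root folder of the dataset, as returned by manage_data() in section 2.
# ImageFolder expects train/, valid/ and test/ to each contain one subfolder per class.
data_dir = data_path

# Load the training dataset
train_dataset = torchvision.datasets.ImageFolder(root=os.path.join(data_dir, 'train'), transform=transform)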
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# Load the validation dataset
val_dataset = torchvision.datasets.ImageFolder(root=os.path.join(data_dir, 'valid'), transform=transform)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

# Load the test dataset
test_dataset = torchvision.datasets.ImageFolder(root=os.path.join(data_dir, 'test'), transform=transform)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)


In [None]:
# Train the model
num_epochs = 10
for epoch in range(num_epochs):
    train_loss = train(model, train_loader, criterion, optimizer, device)
    val_loss, val_acc = validate(model, val_loader, criterion, device)
    print(f'Epoch {epoch + 1}/{num_epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}')

## 6. Evaluate the results

Let's look at the evaluation metrics. YOLOv8 creates a **runs** folder that stores the outputs of each training run. We'll pull up the plots from the most recent run and examine what they mean.

In [None]:
# Plot the evaluation results

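# Find the latest training run by modification time
# (sorting by name would rank 'train10' before 'train2')
run_dirs = [os.path.join('runs/detect', name) for name in os.listdir('runs/detect')]
latest_training = os.path.basename(max(run_dirs, key=os.path.getmtime))
print(f'Latest training run: {latest_training}')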

# Plot the .png files in the latest training run
for file in os.listdir(f'runs/detect/{latest_training}'):
    if file.endswith('.png'):
        # Exclude the normalized confusion matrix since it's redundant
        if 'normalized' in file:
            continue
        img = Image.open(f'runs/detect/{latest_training}/{file}')
        plt.imshow(img)
        plt.axis('off')
        plt.show()

## Breaking down the graphs
If your stats are a little rusty like Kevin's, you might need a bit of a refresher to make sense of the graphs above. Expand the section below for a rundown of the terms used.

<details>
    
<summary><h2>Click to Expand for Stats Terms!</h2></summary>

<p>

##### What is Precision?
Precision is the ratio of correctly predicted positive observations to the total predicted positives. In our fruit object detection task, an example would be the ratio of correctly predicted apples to the total predicted apples. Higher precision values are better, as they indicate that the model raises fewer false positives.

##### What is Recall?
Recall is the ratio of correctly predicted positive observations to all observations in the actual class. In our fruit object detection task, an example would be the ratio of correctly predicted apples to the total actual apples. Higher recall values are better, as they indicate that the model misses fewer of the fruits that are actually there.

##### What is Confidence?
Confidence is the probability that a model assigns to a prediction. In our fruit object detection task, the confidence is the probability that a model assigns to a fruit being an apple, orange, or any other fruit. High confidence is only helpful when the prediction is correct: a good model is confident when it is right and less confident when it is wrong.

##### What is a confusion matrix?
A confusion matrix is a table that is often used to describe the performance of a classification model on a set of test data for which the true values are known. It allows the visualization of the performance of an algorithm. The confusion matrix shows the ways in which your classification model is confused when it makes predictions. It gives you insight not only into the errors being made by your classifier but more importantly the types of errors that are being made.

##### What is the F1 Confidence Curve?
The F1 Confidence Curve is a plot of the F1 score against the confidence threshold. The F1 score is the harmonic mean of precision and recall. The confidence threshold is the minimum confidence a prediction must have in order to be kept. The curve shows how the F1 score changes as the confidence threshold is varied, and its highest point marks the threshold where the model strikes the best balance between precision and recall. In our fruit object detection task, the F1 Confidence Curve shows this relationship for each fruit class. Higher F1 scores are better, as they indicate that both precision and recall are high.
    
##### What is mAP50?
mAP50 stands for "mean Average Precision at an IoU threshold of 0.5". It is a common metric used in object detection tasks to evaluate the performance of a model. A predicted box counts as correct when its intersection over union (IoU) with a ground-truth box of the same class is at least 0.5. For each class, the average precision is the area under the precision-recall curve, and mAP50 is the mean of those values across classes. A higher mAP50 value indicates better performance.
    
##### What is mAP50-95?
mAP50-95 is the mean average precision (mAP) over the range of intersection over union (IoU) thresholds from 0.5 to 0.95. IoU is a measure of the overlap between two bounding boxes. It is calculated as the area of the intersection of the two bounding boxes divided by the area of their union. An IoU of 0 means that the two bounding boxes do not overlap at all, while an IoU of 1 means that the two bounding boxes are identical. The mAP50-95 metric is a more comprehensive metric than the mAP50 metric, which only considers the IoU threshold of 0.5. The mAP50-95 metric provides a more complete picture of the model's performance across a range of IoU thresholds.
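
The IoU calculation described above can be sketched in a few lines. This is an illustrative helper (the name `iou` and the sample boxes are ours), assuming boxes are given as `(x_min, y_min, x_max, y_max)`.

```python
def iou(box_a, box_b):
    """Intersection over union of two boxes given as (x_min, y_min, x_max, y_max)."""
    # Corners of the overlapping rectangle, if there is one
    inter_x_min = max(box_a[0], box_b[0])
    inter_y_min = max(box_a[1], box_b[1])
    inter_x_max = min(box_a[2], box_b[2])
    inter_y_max = min(box_a[3], box_b[3])
    # Clamp at zero so non-overlapping boxes get an intersection of 0
    inter_w = max(0.0, inter_x_max - inter_x_min)
    inter_h = max(0.0, inter_y_max - inter_y_min)
    intersection = inter_w * inter_h
    area_a = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1])
    area_b = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1])
    union = area_a + area_b - intersection
    return intersection / union if union > 0 else 0.0


print(iou((0, 0, 2, 2), (0, 0, 2, 2)))  # identical boxes: 1.0
print(iou((0, 0, 1, 1), (2, 2, 3, 3)))  # no overlap: 0.0
print(iou((0, 0, 2, 2), (1, 1, 3, 3)))  # overlap of 1 over a union of 7
```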

</p>
</details>
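
For readers who prefer code to definitions, here is a minimal sketch of how precision, recall, and the F1 score follow from counts of true positives, false positives, and false negatives. The function name and the counts for the 'Apple' class are hypothetical and chosen only to make the arithmetic easy to follow.

```python
def precision_recall_f1(true_positives, false_positives, false_negatives):
    """Compute precision, recall and F1 from raw counts for one class."""
    predicted = true_positives + false_positives  # everything the model called positive
    actual = true_positives + false_negatives     # everything that really is positive
    precision = true_positives / predicted if predicted else 0.0
    recall = true_positives / actual if actual else 0.0
    total = precision + recall
    # F1 is the harmonic mean of precision and recall
    f1 = 2 * precision * recall / total if total else 0.0
    return precision, recall, f1


# Hypothetical counts: 8 apples found correctly, 2 false alarms, 8 apples missed
p, r, f1 = precision_recall_f1(true_positives=8, false_positives=2, false_negatives=8)
print(f"precision={p:.2f} recall={r:.2f} f1={f1:.2f}")
# precision=0.80 recall=0.50 f1=0.62
```

Note how the F1 score sits closer to the lower of the two values: the harmonic mean penalizes a model that does well on one measure but poorly on the other.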

In addition to those graphs, we can look at the predictions themselves. 

In [None]:
# Plot the labels and predictions from the last training run

# Load the images
img1 = Image.open(f'runs/detect/{latest_training}/val_batch0_labels.jpg')
img2 = Image.open(f'runs/detect/{latest_training}/val_batch0_pred.jpg')

# Plot the images
fig, axes = plt.subplots(1, 2, figsize=(15, 10))
axes[0].imshow(img1)
axes[0].axis('off')
axes[0].set_title('Ground Truth')
axes[1].imshow(img2)
axes[1].axis('off')
axes[1].set_title('Predictions')
plt.show()


From the prediction images, you'll probably notice this model loves its citrus. That's no surprise given the class distribution we looked at near the beginning of this notebook. What could we do to improve this?
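
One common option for an imbalanced dataset is to weight the loss so that rare classes count for more. The sketch below computes inverse-frequency weights in plain Python. The class counts are made-up placeholders, not the real counts for this dataset, and `inverse_frequency_weights` is a hypothetical helper. In PyTorch, weights like these can be passed as a tensor through the `weight` argument of `nn.CrossEntropyLoss`.

```python
def inverse_frequency_weights(class_counts):
    """Weight each class by total / (num_classes * count).

    A class with an average number of images gets weight 1.0,
    rarer classes get more, and common classes get less.
    """
    total = sum(class_counts.values())
    num_classes = len(class_counts)
    return {name: total / (num_classes * count) for name, count in class_counts.items()}


# Placeholder counts for illustration only
counts = {"Apple": 100, "Banana": 100, "Orange": 400}
print(inverse_frequency_weights(counts))
# {'Apple': 2.0, 'Banana': 2.0, 'Orange': 0.5}
```

Other options worth trying include oversampling the rare classes or collecting more images of them. Weighting is simply the cheapest to test.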

## 7. Inference
How does the model fare on some test images? After you run the cell below:
1. Find your own image of fruit.
2. Upload it to this folder.
3. Add or edit the code below to run on the new image rather than images in the test folder.


In [None]:
# Run the model on a few test images
data_dir = r"datasets/fruits_detection"

# Get up to ten random test images
images_dir = os.path.join(data_dir, 'test', 'images')
test_images = [os.path.join(images_dir, file_name) for file_name in os.listdir(images_dir)]
# Guard against folders with fewer than ten images, which would make random.sample fail
test_images = random.sample(test_images, min(10, len(test_images)))

# Plot the test images
fig, axes = plt.subplots(2, 5, figsize=(20, 10))
for i, img_path in enumerate(test_images):
    img = Image.open(img_path)
    axes[i // 5, i % 5].imshow(img)
    axes[i // 5, i % 5].axis('off')
plt.show()

# Run the model on the test images
infer_results = model.predict(test_images)

for img_path, result in zip(test_images, infer_results):
    # Keep the original size: the predicted boxes are in original-image pixel coordinates
    img = Image.open(img_path)

    # Create the plot for this image
    plt.imshow(img)
    plt.axis('off')

    # Access bounding boxes and class information
    boxes = result.boxes
    class_ids = boxes.cls.cpu().numpy()  # Class IDs are on the CPU
    confidences = boxes.conf.cpu().numpy()  # Confidences are on the CPU

    for box, class_id, conf in zip(boxes.xyxy, class_ids, confidences):  
        # boxes.xyxy are in [xmin, ymin, xmax, ymax] format
        x1, y1, x2, y2 = box.cpu().tolist()  # Move coordinates to CPU as plain floats
        w, h = x2 - x1, y2 - y1
        rect = plt.Rectangle((x1, y1), w, h, fill=False, color='red', linewidth=2)
        plt.gca().add_patch(rect)
        plt.text(x1, y1, result.names[int(class_id)], color='red')  

    # Display the plot
    plt.show()

## 8. Explore hyperparameters

Now that you have a good baseline, consider how you might deal with this model's issues.
- How would you address issues in the dataset?
- How would you optimize training?

A good first place to start would be YOLOv8's documentation so you can understand what hyperparameters you have access to and how changing them will affect training. Make some adjustments and see how high you can get your fruits object detection F1 score!

## Bonus Exercises

- You might have noticed the *yolov8n.pt* file that is added to the folder when you load YOLO. That is the pre-trained model using the [COCO dataset](https://cocodataset.org/#home). The YOLOv8 documentation provides instructions for transfer learning and fine-tuning with the pre-trained model. Give it a shot!

- If you're feeling very bold and want the full Object Detection Experience, find an image labeler and build your own dataset! If you use proprietary data, be sure that the software you use to label meets your institution's guidelines for data sharing. If you go this route, you will almost certainly want to use the fine-tuning option mentioned above, otherwise you're going to be spending quite a long time labeling...