# Implementação de uma ViT

Author: Lucas Silva

## Introdução

Usando o ViT: [Vision Transformer (ViT)](https://arxiv.org/abs/2010.11929)
model por Alexey Dosovitskiy et al. para classificação de ervas daninhas no cultivo da soja.

Esse modelo ViT aplica a Arquitetura Transformer com auto-atenção para sequências de patches de imagens, **SEM** utilização de camadas de convolução.

Será necessário o uso do TF Addons:
[TensorFlow Addons](https://www.tensorflow.org/addons/overview),

## Setup

In [1]:
# Bibliotecas
import os, time, random, sys
os.environ['PYTHONHASHSEED']=str(1)
import numpy as np
import tensorflow as tf
import pandas as pd
import keras
from tensorflow.keras import layers
import tensorflow_addons as tfa
import matplotlib.pyplot as plt
from keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau, TensorBoard, CSVLogger
from keras.optimizers import Adam
import csv
from keras.models import Model, load_model

# Preparando os Dados

In [2]:
def runSeed():
    global seed
    seed=12
    os.environ['PYTHONHASHSEED']=str(12)
    tf.random.set_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

runSeed()

# Configurando os HyperParameters

In [3]:
learning_rate = 0.001
weight_decay = 0.0001
NUM_EPOCHS = 15
MAX_EPOCH = 20
RAW_IMG_SIZE = (256, 256)
IMG_SIZE = (224, 224)
INPUT_SHAPE = (IMG_SIZE[0], IMG_SIZE[1], 3)
BATCH_SIZE = 32
FOLDS = 5
STOPPING_PATIENCE = 32
LR_PATIENCE = 16
INITIAL_LR = 0.0001

In [4]:
#------------- Parâmetros ViT-Base -------------------------
patch_size = 16  # Tamanho dos patches para serem extraidos.
num_patches = (224 // patch_size) ** 2
projection_dim = 64
num_heads = 12
transformer_units = [
    projection_dim * 2,
    projection_dim,
]  

# Tamanho das camadas de transformação.
transformer_layers = 12
mlp_head_units = [2048, 1024]

In [5]:
IMG_DIRECTORY = '/kaggle/input/deepweeds/images/'
LABEL_DIRECTORY = '/kaggle/input/deepweeds/labels/'
OUTPUT_DIRECTORY = './'

In [6]:
CLASSES_STR = ['0', '1', '2', '3', '4', '5', '6', '7', '8']
CLASSES = [0, 1, 2, 3, 4, 5, 6, 7, 8]

CLASS_NAMES = ['Chinee Apple',
               'Lantana',
               'Parkinsonia',
               'Parthenium',
               'Prickly Acacia',
               'Rubber Vine',
               'Siam Weed',
               'Snake Weed',
               'Negatives']
NUM_CLASSES=9

In [7]:
def crop(img, size):
    (h, w, c) = img.shape
    x = int((w - size[0]) / 2)
    y = int((h - size[1]) / 2)
    return img[y:(y + size[1]), x:(x + size[0]), :]

def crop_generator(batches, size):
    while True:
        batch_x, batch_y = next(batches)
        (b, h, w, c) = batch_x.shape
        batch_crops = np.zeros((b, size[0], size[1], c))
        for i in range(b):
            batch_crops[i] = crop(batch_x[i], (size[0], size[1]))
        yield (batch_crops, batch_y)

# Setando os Dataframes

In [8]:
from keras.preprocessing.image import ImageDataGenerator

for k in range(FOLDS):
        # Prepare training, validation and testing labels for kth fold
        train_label_file = "{}train_subset{}.csv".format(LABEL_DIRECTORY, k)
        val_label_file = "{}val_subset{}.csv".format(LABEL_DIRECTORY, k)
        test_label_file = "{}test_subset{}.csv".format(LABEL_DIRECTORY, k)
        train_dataframe = pd.read_csv(train_label_file)
        val_dataframe = pd.read_csv(val_label_file)
        test_dataframe = pd.read_csv(test_label_file)
        train_image_count = train_dataframe.shape[0]
        val_image_count = train_dataframe.shape[0]
        test_image_count = test_dataframe.shape[0]
        train_dataframe['Label'] = train_dataframe.Label.astype(str)
        val_dataframe['Label'] = val_dataframe.Label.astype(str)
        test_dataframe['Label'] = test_dataframe.Label.astype(str)

        # Training image augmentation
        train_data_generator = ImageDataGenerator(
            rescale=1. / 255,
            fill_mode="constant",
            shear_range=0.2,
            zoom_range=(0.5, 1),
            horizontal_flip=True,
            rotation_range=360,
            channel_shift_range=25,
            brightness_range=(0.75, 1.25))

        # Validation image augmentation
        val_data_generator = ImageDataGenerator(
            rescale=1. / 255,
            fill_mode="constant",
            shear_range=0.2,
            zoom_range=(0.5, 1),
            horizontal_flip=True,
            rotation_range=360,
            channel_shift_range=25,
            brightness_range=(0.75, 1.25))

        # No testing image augmentation (except for converting pixel values to floats)
        test_data_generator = ImageDataGenerator(rescale=1. / 255)

        # Load train images in batches from directory and apply augmentations
        train_data_generator = train_data_generator.flow_from_dataframe(
            train_dataframe,
            IMG_DIRECTORY,
            x_col='Filename',
            y_col='Label',
            target_size=RAW_IMG_SIZE,
            batch_size=BATCH_SIZE,
            #classes=CLASSES,
            class_mode='categorical')

        # Load validation images in batches from directory and apply rescaling
        val_data_generator = val_data_generator.flow_from_dataframe(
            val_dataframe,
            IMG_DIRECTORY,
            x_col="Filename",
            y_col="Label",
            target_size=RAW_IMG_SIZE,
            batch_size=BATCH_SIZE,
            #classes=CLASSES,
            class_mode='categorical')

        # Load test images in batches from directory and apply rescaling
        test_data_generator = test_data_generator.flow_from_dataframe(
            test_dataframe,
            IMG_DIRECTORY,
            x_col="Filename",
            y_col="Label",
            target_size=IMG_SIZE,
            batch_size=BATCH_SIZE,
            shuffle=False,
            #classes=CLASSES,
            class_mode='categorical')
        
        # Crop augmented images from 256x256 to 224x224
        train_data_generator = crop_generator(train_data_generator, IMG_SIZE)
        val_data_generator = crop_generator(val_data_generator, IMG_SIZE)

Found 10501 validated image filenames belonging to 9 classes.
Found 3501 validated image filenames belonging to 9 classes.
Found 3507 validated image filenames belonging to 9 classes.
Found 10504 validated image filenames belonging to 9 classes.
Found 3502 validated image filenames belonging to 9 classes.
Found 3503 validated image filenames belonging to 9 classes.
Found 10506 validated image filenames belonging to 9 classes.
Found 3502 validated image filenames belonging to 9 classes.
Found 3501 validated image filenames belonging to 9 classes.
Found 10506 validated image filenames belonging to 9 classes.
Found 3503 validated image filenames belonging to 9 classes.
Found 3500 validated image filenames belonging to 9 classes.
Found 10508 validated image filenames belonging to 9 classes.
Found 3503 validated image filenames belonging to 9 classes.
Found 3498 validated image filenames belonging to 9 classes.


# Implementando a Percepção de Multi Camada - Multi Layer Perception (MLP)

In [9]:
def mlp(x, hidden_units, dropout_rate):
    for units in hidden_units:
        x = layers.Dense(units, activation=tf.nn.gelu)(x)
        x = layers.Dropout(dropout_rate)(x)
    return x

# Declarando a criação de patches como uma camada da rede

In [10]:
class Patches(layers.Layer):
    def __init__(self, patch_size):
        super(Patches, self).__init__()
        self.patch_size = patch_size

    def call(self, images):
        batch_size = tf.shape(images)[0]
        patches = tf.image.extract_patches(
            images=images,
            sizes=[1, self.patch_size, self.patch_size, 1],
            strides=[1, self.patch_size, self.patch_size, 1],
            rates=[1, 1, 1, 1],
            padding="VALID",
        )
        patch_dims = patches.shape[-1]
        patches = tf.reshape(patches, [batch_size, -1, patch_dims])
        return patches

# Implementando a camada de Patch Encoding

A camada de `PatchEncoder` irá realizar a transformação linear de um patch, fazendo a projeção em um vetor de tamanho `projection_dim`. Junto, irá realizar a adição da posição de embedding para o vetor projetado.

In [11]:
class PatchEncoder(layers.Layer):
    def __init__(self, num_patches, projection_dim):
        super(PatchEncoder, self).__init__()
        self.num_patches = num_patches
        self.projection = layers.Dense(units=projection_dim)
        self.position_embedding = layers.Embedding(
            input_dim=num_patches, output_dim=projection_dim
        )

    def call(self, patch):
        positions = tf.range(start=0, limit=self.num_patches, delta=1)
        encoded = self.projection(patch) + self.position_embedding(positions)
        return encoded

## Construindo o modelo do ViT

O modelo ViT consiste de múltiplos blocos Transformer, onde usamos o `layer.MultiHeadAttetion` como camada de self-attention, aplicando em uma sequência de patches. Os blocos Transformers produzem um tensor: `[batch_size,num_patches,projection_dim]`, que será processado via uma cabeça classificadora, com um softmax para produzir as probabilidades de saída.

In [12]:
def create_vit_classifier():
    inputs = layers.Input(shape=INPUT_SHAPE)
    # Criacao de patches
    patches = Patches(patch_size)(inputs)
    # Encode  dos patches.
    encoded_patches = PatchEncoder(num_patches, projection_dim)(patches)

    # Camadas do block transformer (range é limite de camadas)
    for _ in range(transformer_layers):
        # Layer normalization 1.
        x1 = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
        # Camada MLP
        attention_output = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=projection_dim, dropout=0.1
        )(x1, x1)
        # Skip connection 1.
        x2 = layers.Add()([attention_output, encoded_patches])
        # Layer normalization 2.
        x3 = layers.LayerNormalization(epsilon=1e-6)(x2)
        # MLP.
        x3 = mlp(x3, hidden_units=transformer_units, dropout_rate=0.1)
        # Skip connection 2.
        encoded_patches = layers.Add()([x3, x2])

    # Cria um Tensor [batch_size, projection_dim].
    representation = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
    representation = layers.Flatten()(representation)
    representation = layers.Dropout(0.5)(representation)
    # Adiciona MLP.
    features = mlp(representation, hidden_units=mlp_head_units, dropout_rate=0.5)
    # Classifica.
    logits = layers.Dense(NUM_CLASSES,activation="sigmoid")(features)
    # Cria o modelo Keras.
    model = keras.Model(inputs=inputs, outputs=logits)
    return model


# Compilando, treinando e avaliando o modelo.

In [13]:
# Callbacks

model_checkpoint = ModelCheckpoint(OUTPUT_DIRECTORY + "lastbest-0.hdf5",save_weights_only=True, verbose=1, save_best_only=True)
early_stopping = EarlyStopping(patience=STOPPING_PATIENCE, restore_best_weights=True)
tensorboard = TensorBoard(log_dir=OUTPUT_DIRECTORY, histogram_freq=0, write_graph=True, write_images=False)
reduce_lr = ReduceLROnPlateau('val_loss', factor=0.5, patience=LR_PATIENCE, min_lr=0.000003125)
csv_logger = CSVLogger(OUTPUT_DIRECTORY + "training_metrics.csv")

In [14]:
model = create_vit_classifier()
model.compile(loss='binary_crossentropy', optimizer=Adam(lr=INITIAL_LR), metrics=['categorical_accuracy'])

global_epoch = 0
restarts = 0
last_best_losses = []
last_best_epochs = []
while global_epoch < MAX_EPOCH:
    history = model.fit(
        train_data_generator,
        steps_per_epoch=train_image_count // BATCH_SIZE,
        epochs=MAX_EPOCH - global_epoch, #alterar depois
        validation_data=val_data_generator,
        validation_steps=val_image_count // BATCH_SIZE,
        callbacks=[tensorboard, model_checkpoint, early_stopping, reduce_lr, csv_logger],
        shuffle=False)
    if early_stopping.stopped_epoch == 0:
        print("Completed training after {} epochs.".format(MAX_EPOCH))
        break
    else:
        global_epoch = global_epoch + early_stopping.stopped_epoch - STOPPING_PATIENCE + 1
        print("Early stopping triggered after local epoch {} (global epoch {}).".format(
            early_stopping.stopped_epoch, global_epoch))
        print("Restarting from last best val_loss at local epoch {} (global epoch {}).".format(
            early_stopping.stopped_epoch - STOPPING_PATIENCE, global_epoch - STOPPING_PATIENCE))
        restarts = restarts + 1
        model.compile(loss='binary_crossentropy', optimizer=Adam(lr=INITIAL_LR / 2 ** restarts),
                      metrics=['categorical_accuracy'])
        model_checkpoint = ModelCheckpoint(OUTPUT_DIRECTORY + "lastbest-{}.hdf5".format(restarts),
                                           monitor='val_loss',save_weights_only=True, verbose=1, save_best_only=True, mode='min')

# Save last best model info
# with open(OUTPUT_DIRECTORY + "last_best_models.csv", 'w', newline='') as file:
#     writer = csv.writer(file, delimiter=',')
#     writer.writerow(['Model file', 'Global epoch', 'Validation loss'])
#     for i in range(restarts + 1):
#         writer.writerow(["lastbest-{}.hdf5".format(i), last_best_epochs[i], last_best_losses[i]])

Epoch 1/20

Epoch 00001: val_loss improved from inf to 0.27752, saving model to ./lastbest-0.hdf5
Epoch 2/20

Epoch 00002: val_loss did not improve from 0.27752
Epoch 3/20

Epoch 00003: val_loss improved from 0.27752 to 0.26727, saving model to ./lastbest-0.hdf5
Epoch 4/20

Epoch 00004: val_loss did not improve from 0.26727
Epoch 5/20

Epoch 00005: val_loss improved from 0.26727 to 0.23462, saving model to ./lastbest-0.hdf5
Epoch 6/20

Epoch 00006: val_loss improved from 0.23462 to 0.22788, saving model to ./lastbest-0.hdf5
Epoch 7/20

Epoch 00007: val_loss improved from 0.22788 to 0.21836, saving model to ./lastbest-0.hdf5
Epoch 8/20

Epoch 00008: val_loss did not improve from 0.21836
Epoch 9/20

Epoch 00009: val_loss improved from 0.21836 to 0.21436, saving model to ./lastbest-0.hdf5
Epoch 10/20

Epoch 00010: val_loss improved from 0.21436 to 0.20160, saving model to ./lastbest-0.hdf5
Epoch 11/20

Epoch 00011: val_loss did not improve from 0.20160
Epoch 12/20

Epoch 00012: val_loss i

#  Trabalhando o modelo

In [15]:
model.save_weights('./checkpoints/ViT_base_weights')
model_weight_save = ModelCheckpoint(OUTPUT_DIRECTORY + "model_best_weights.hdf5",save_weights_only=True, verbose=1, save_best_only=True)

## Metricas por Classe

In [16]:
from sklearn.metrics import confusion_matrix, classification_report
from keras.models import Model, load_model

# model_load = create_vit_classifier()
# model_load.load_weights('./checkpoints/ViT_base_weights')

# # Evaluate model on test subset for kth fold
# predictions = model_load.predict(test_data_generator, test_image_count // BATCH_SIZE + 1)
# y_true = test_data_generator.classes
# y_pred = np.argmax(predictions, axis=1)
# y_pred[np.max(predictions, axis=1) < 1 / 9] = 8  # Assign predictions worse than random guess to negative class

# Evaluate model on test subset for kth fold
# predictions = model.predict(test_data_generator, test_image_count // BATCH_SIZE + 1)
# y_true = test_data_generator.classes
# y_pred = np.argmax(predictions, axis=1)
# y_pred[np.max(predictions, axis=1) < 1 / 9] = 8  # Assign predictions worse than random guess to negative class

# Métricas por Classe

In [17]:
predictions = model.predict_generator(test_data_generator, test_image_count // BATCH_SIZE + 1)
y_true = test_data_generator.classes
y_pred = np.argmax(predictions, axis=1)
y_pred[np.max(predictions, axis=1) < 1 / 9] = 8

# Generate and print classification metrics and confusion matrix
print(classification_report(y_true, y_pred, labels=CLASSES, target_names=CLASS_NAMES))
report = classification_report(y_true, y_pred, labels=CLASSES, target_names=CLASS_NAMES, output_dict=True)
with open('classification_report.csv', 'w') as f:
    for key in report.keys():
        f.write("%s,%s\n" % (key, report[key]))
conf_arr = confusion_matrix(y_true, y_pred, labels=CLASSES)
print(conf_arr)

#Get the confusion matrix
cm = conf_arr

#Now the normalize the diagonal entries
cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]

#The diagonal entries are the accuracies of each class
cm.diagonal()



                precision    recall  f1-score   support

  Chinee Apple       0.33      0.05      0.09       225
       Lantana       0.56      0.46      0.50       212
   Parkinsonia       0.37      0.72      0.49       206
    Parthenium       0.34      0.09      0.15       204
Prickly Acacia       0.55      0.79      0.65       212
   Rubber Vine       0.73      0.18      0.29       201
     Siam Weed       0.53      0.58      0.55       214
    Snake Weed       0.36      0.09      0.15       203
     Negatives       0.70      0.84      0.76      1821

      accuracy                           0.61      3498
     macro avg       0.50      0.42      0.40      3498
  weighted avg       0.59      0.61      0.57      3498

[[  12    9   18    1    9    0    7   29  140]
 [   4   97    1    1    0    2   33    2   72]
 [   0    0  149    0   21    0    0    0   36]
 [   0    3   45   19   39    0    1    0   97]
 [   0    0   36    0  168    0    0    0    8]
 [   8   11    6    0   20   

array([0.05333333, 0.45754717, 0.72330097, 0.09313725, 0.79245283,
       0.1840796 , 0.57943925, 0.09359606, 0.8380011 ])

In [18]:
print(y_pred)

[6 8 8 ... 8 2 8]
