# Vision Transformers
The following networks are used for musical genre classification, because manually classifying all the new music released nowadays is impossible for a human being.

## Libraries

In [ ]:
import os
import cv2
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow_addons as tfa

from keras.utils import np_utils
from keras.utils import image_dataset_from_directory

from vit_keras import vit, utils, visualize

# Import function to plot the results
import plots

## Data and parameters

### Data parameters and paths

In [ ]:
# When True, a new random seed is drawn on every run (so the initial network
# weights differ between runs); when False, a fixed seed is used instead
random_seed = True

# Paths to where training, testing, and validation images are
database_dir = 'dataset'
train_dir = f'{database_dir}/training'
val_dir = f'{database_dir}/val'
test_dir = f'{database_dir}/test'

# Directory where to store weights of the model and results
root_dir = "results"
# Create root directory for results if it does not exist
# (exist_ok avoids the check-then-create race of os.path.exists + makedirs)
os.makedirs(root_dir, exist_ok=True)

# Output dimension of the classifier (number of musical genres)
num_classes = 6

# Name of each class, taken from the sub-directories of the training set.
# os.listdir returns entries in arbitrary order, so the names are sorted to
# keep the class -> index mapping identical across runs and machines.
CLASSES = sorted(x for x in os.listdir(train_dir) if os.path.isdir(os.path.join(train_dir, x)))
print(f'The classes to classify are: {CLASSES}')
# Example of the expected content:
# ['Alternative', 'Pop', 'Rock', 'Dance', 'Classical', 'Techno']

# Parameters that characterize the images
img_height = 369
img_width = 496
resize_size = 400  # side of the square image fed to the transformer
img_channels = 3  # the preprocessing step replicates single-channel images to 3 channels
color_mode = 'rgb'

### Configuration Training Parameters

In [ ]:
# Parameters that configure the training process
batch_size = 1  # Batch size
epochs = 5  # Number of epochs
initial_lr = 1e-5   # Initial learning rate
modelRNN = 'ViT'  # Identifier of the model (a Vision Transformer); used to name outputs
version = f'{modelRNN}_BS{batch_size}_E{epochs}_LR{initial_lr}'
experiment_dir = f'{root_dir}/{modelRNN}'

# Create experiment directory if it does not exist
os.makedirs(experiment_dir, exist_ok=True)

# Set random seed
if random_seed:
    # Draw the seed from the non-negative 32-bit integer range.
    # Note the exponent: 2**31 - 1, not 2*31 - 1 (which is only 61).
    seed = np.random.randint(0, 2**31 - 1)
else:
    # Fixed seed for reproducible runs
    seed = 5
np.random.seed(seed)
tf.random.set_seed(seed)

### Data Load

In [ ]:
# Arguments shared by the three dataset loaders.
# Keras expects image_size as (height, width), in that order.
common_ds_kwargs = dict(label_mode='categorical',
                        class_names=CLASSES,
                        batch_size=batch_size,
                        color_mode=color_mode,
                        image_size=(img_height, img_width))

# 1. Generate train dataset (ds) from directory of samples (shuffled)
train_ds = image_dataset_from_directory(directory=train_dir,
                                        shuffle=True,
                                        **common_ds_kwargs)

# 2. Generate validation dataset (ds) from directory of samples
val_ds = image_dataset_from_directory(directory=val_dir,
                                      **common_ds_kwargs)

# 3. Generate test dataset (ds) from directory of samples.
#    Not shuffled, so predictions keep the same order as the files on disk.
test_ds = image_dataset_from_directory(directory=test_dir,
                                       shuffle=False,
                                       **common_ds_kwargs)

### Data Preprocessing
Because Transformers split each image into fixed-size patches, it is better to work with square images.
1. First the function with the needed transformations is defined
2. The transformations are applied

In [ ]:
def preprocess_image(image, resize_size):
    """Resize one image to a square and make sure it has three channels.

    Args:
        image: a single image, either an object exposing ``.numpy()`` (e.g. an
            eager tensor) or anything ``np.asarray`` accepts. Expected layout
            is (height, width) or (height, width, channels).
        resize_size: side length, in pixels, of the square output image.

    Returns:
        A NumPy array of shape (resize_size, resize_size, 3) for
        single-channel and 3-channel inputs. Single-channel images are
        replicated across the three channels; multi-channel images are
        returned as resized.
    """
    # Accept both eager tensors and plain arrays
    pixels = image.numpy() if hasattr(image, 'numpy') else np.asarray(image)
    resized_image = cv2.resize(pixels, (resize_size, resize_size))
    if resized_image.ndim == 2:
        # cv2.resize drops a singleton channel axis: restore it and replicate
        # the grey values so the result has 3 channels
        return np.repeat(resized_image[:, :, np.newaxis], 3, axis=2)
    # Already multi-channel: reshaping to one channel would fail here
    return resized_image

In [ ]:
def _dataset_to_arrays(dataset, resize_size):
    """Load a batched (image, label) dataset into memory as NumPy arrays.

    The dataset is iterated eagerly (``preprocess_image`` needs concrete
    pixel values, which are not available inside ``Dataset.map``) and
    un-batched so that every element is a single image.

    Returns:
        (images, labels): images as an array of preprocessed samples and
        labels exactly as yielded by the dataset.
    """
    images, labels = [], []
    for image, label in dataset.unbatch():
        if image.shape[-1] == 3:
            # The pipeline works on grayscale content that is later replicated
            # to 3 channels, so collapse RGB inputs to a single channel first
            image = tf.image.rgb_to_grayscale(image)
        images.append(preprocess_image(image, resize_size))
        labels.append(label.numpy())
    return np.asarray(images), np.asarray(labels)


# The datasets were created with label_mode='categorical', so the labels are
# already one-hot encoded and must not be passed through to_categorical again.

# 1. Training dataset
X_train_images, y_train = _dataset_to_arrays(train_ds, resize_size)
X_train = X_train_images  # array of images, usable for indexing and model.fit
y_train_labels = np.argmax(y_train, axis=-1)  # integer class indices

# 2. Validation dataset
X_val, y_val = _dataset_to_arrays(val_ds, resize_size)
y_val_labels = np.argmax(y_val, axis=-1)

# 3. Test dataset
X_test, y_test = _dataset_to_arrays(test_ds, resize_size)
y_test_labels = np.argmax(y_test, axis=-1)

## Pre-trained model
The architecture comes from the paper *An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale* (Dosovitskiy et al., 2021).

### Load the transformer

In [ ]:
# Load ViT-B/16 with pre-trained weights as a feature extractor.
# The ImageNet classification head is left out (include_top=False) so the
# backbone outputs its learned representation rather than 1000 ImageNet class
# probabilities; a task-specific head is stacked on top of it afterwards.
vit_model = vit.vit_b16(
    image_size=resize_size,
    pretrained=True,
    include_top=False,
    pretrained_top=False
)

vit_model.summary()

### Example of Attention Map

In [ ]:
# Pick one training image from the in-memory array (a tf.data dataset cannot
# be indexed). The index is clamped so small training sets do not raise.
sample_index = min(300, len(X_train_images) - 1)
image = X_train_images[sample_index]

attention_map = visualize.attention_map(model=vit_model, image=image)

# Plot results
fig, (ax1, ax2) = plt.subplots(ncols=2)
ax1.axis('off')
ax2.axis('off')
ax1.set_title('Original')
ax2.set_title('Attention Map')
# imshow expects floats in [0, 1] or integers in [0, 255]; the images are
# presumably floats in [0, 255] (Keras loader default), so cast for display
_ = ax1.imshow(image.astype('uint8'), interpolation='none')
_ = ax2.imshow(attention_map, interpolation='none')

### Model training
1. Model definition
2. Model callbacks
3. Train

In [ ]:
# Freeze the backbone: none of its layers is updated during training
for layer in vit_model.layers:
    layer.trainable = False

# Classification head stacked on top of the frozen transformer
model = tf.keras.Sequential(name='vision_transformer')
model.add(vit_model)
model.add(tf.keras.layers.Flatten())
model.add(tf.keras.layers.BatchNormalization())
model.add(tf.keras.layers.Dense(512, activation=tfa.activations.gelu))
model.add(tf.keras.layers.BatchNormalization())
# One output unit per class, normalised to probabilities
model.add(tf.keras.layers.Dense(num_classes, activation='softmax'))

model.summary()

In [ ]:
# Optimizer and loss (cross-entropy with label smoothing)
optimizer = tfa.optimizers.RectifiedAdam(learning_rate=initial_lr)
smoothed_loss = tf.keras.losses.CategoricalCrossentropy(label_smoothing=0.2)

model.compile(optimizer=optimizer, loss=smoothed_loss, metrics=['accuracy'])

# All callbacks watch the validation accuracy and treat higher as better
monitor_kwargs = dict(monitor='val_accuracy', mode='max', verbose=1)

# Multiply the learning rate by 0.3 after 5 epochs without improvement
reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(factor=0.3,
                                                 patience=5,
                                                 min_delta=1e-4,
                                                 min_lr=1e-6,
                                                 **monitor_kwargs)

# Stop after 20 epochs without improvement and roll back to the best weights
earlystopping = tf.keras.callbacks.EarlyStopping(min_delta=1e-4,
                                                 patience=20,
                                                 restore_best_weights=True,
                                                 **monitor_kwargs)

# Keep on disk only the weights of the best epoch so far
checkpointer = tf.keras.callbacks.ModelCheckpoint(filepath='./model.hdf5',
                                                  save_best_only=True,
                                                  save_weights_only=True,
                                                  **monitor_kwargs)

callbacks = [earlystopping, reduce_lr, checkpointer]

In [ ]:
# Train on the in-memory image array (Keras rejects a separate `y` when `x`
# is a dataset) and pass the callbacks, which are otherwise built but never
# used.
history = model.fit(x=X_train_images,
                    y=y_train,
                    validation_data=(X_val, y_val),
                    batch_size=batch_size,
                    epochs=epochs,
                    callbacks=callbacks,
                    verbose=1)

# Store the trained model under the experiment directory
model.save(f'{experiment_dir}/{version}.h5')

## Training Results
Accuracy and Loss obtained along the training process

In [ ]:
# Plot the accuracy and loss recorded in `history` using the local `plots`
# helper module. The remaining arguments are presumably used to title and
# store the figures under `experiment_dir` - TODO confirm in plots.py.
plots.accloss(history, modelRNN, experiment_dir, version)


## Testing
### Model Testing
1. Compute the loss function and accuracy for the test data
2. Confusion Matrix obtained from testing results

In [ ]:
# Evaluate model on the preprocessed test images. The raw `test_ds` has not
# been resized to the model's input size, so it cannot be fed to the model,
# and using X_test keeps evaluation consistent with the predictions below.
scores = model.evaluate(X_test, y_test, batch_size=batch_size, verbose=0)
print("Accuracy: %.2f%%" % (scores[1]*100))
print("Loss: %.2f" % scores[0])

# Obtain results to present the confusion matrix
prob_class = model.predict(X_test, batch_size=batch_size)
# Classified labels (index of the most probable class for each sample)
y_pred = tf.argmax(prob_class, axis=-1)
# Visualize confusion matrix
# NOTE(review): y_test is one-hot encoded while y_pred holds class indices;
# confirm which format plots.cm expects for the true labels.
plots.cm(y_test, y_pred, modelRNN, CLASSES, experiment_dir, version)