# 1. Importing dependencies, setting environment

In [None]:
import os
import warnings
# Silence all warnings for the rest of the session; both calls install a
# blanket "ignore" action.
warnings.filterwarnings(action='ignore')
warnings.simplefilter(action='ignore')

In [None]:
import scipy
import multiprocessing as mp

from PIL import Image

import numpy as np
import matplotlib
from matplotlib import pyplot as plt
from sklearn.model_selection import train_test_split

import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator, img_to_array, load_img, array_to_img
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, Dropout, GlobalAveragePooling2D
from tensorflow.keras.optimizers import AdamW
from tensorflow.keras.regularizers import l2

In [None]:
# Fixed seed so that NumPy- and TensorFlow-driven randomness is repeatable.
random_seed = 997
np.random.seed(random_seed)
tf.random.set_seed(random_seed)

In [None]:
# Report how many GPUs TensorFlow can see.
print("Num GPUs Available: ", len(tf.config.list_physical_devices(device_type='GPU')))

In [None]:
# Notebook shell escape: run the `rocm-smi` command-line tool and show its output.
! rocm-smi

In [None]:
gpus = tf.config.list_physical_devices('GPU')
if gpus:
    # Expose only the first physical GPU to TensorFlow.
    try:
        tf.config.set_visible_devices(gpus[0], 'GPU')
        logical_gpus = tf.config.list_logical_devices('GPU')
        print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPU")
    except RuntimeError as err:
        # Report the problem and carry on instead of aborting the notebook.
        print(err)

# 2. Image generation by augmentation

In [None]:
# GLOBALS
IMG_X = 320  # target image width in pixels
IMG_Y = 213  # target image height in pixels (width:height is roughly 3:2)

In [None]:
def augment_image_for_country(
        country, directories, datagen, num_augmented_images=5
    ) -> None:
    """Write augmented variants of every image of one country to disk.

    Each file found in ``<source_dir>/<country>`` is loaded, resized to
    ``IMG_X`` x ``IMG_Y`` pixels and passed through ``datagen``. The results
    are saved as ``aug_<source stem>_<n>.jpeg`` in
    ``<augmented_dir>/<country>``, which is created when missing.

    The files are written explicitly instead of through the ``save_to_dir``
    option of ``datagen.flow``. Keras builds those file names from the shared
    prefix plus a short random number, so generating many images into one
    folder can silently overwrite earlier outputs. Deterministic names
    guarantee exactly ``num_augmented_images`` files per source image.

    Args:
        country: Name of the sub-directory (class) to process.
        directories: Mapping with the keys ``'source_dir'`` and
            ``'augmented_dir'``.
        datagen: Object whose ``flow(x, batch_size=...)`` yields batches of
            augmented images (an ``ImageDataGenerator`` in this notebook).
        num_augmented_images: Number of variants per source image. Values
            less than or equal to zero produce no files.
    """
    country_input_path = os.path.join(directories['source_dir'], country)
    country_output_path = os.path.join(directories['augmented_dir'], country)
    os.makedirs(country_output_path, exist_ok=True)

    for image_name in os.listdir(country_input_path):
        img = load_img(os.path.join(country_input_path, image_name))
        img = img.resize((IMG_X, IMG_Y))
        x = img_to_array(img)
        # Add a leading batch axis; with the default channels-last layout the
        # result is (1, height, width, channels).
        x = x.reshape((1,) + x.shape)

        stem = os.path.splitext(image_name)[0]
        flow = datagen.flow(x, batch_size=1)
        # range() comes first so zip() stops without drawing a surplus batch
        # from the endless generator.
        for idx, batch in zip(range(num_augmented_images), flow):
            out_path = os.path.join(country_output_path, f'aug_{stem}_{idx}.jpeg')
            array_to_img(batch[0]).save(out_path)

def augment_images(directories, data_generator, num_augmented_images=5) -> None:
    """Augment the images of every country in parallel.

    One task per entry of ``directories['source_dir']`` is dispatched to a
    process pool sized to the number of CPUs; each task runs
    ``augment_image_for_country``.

    Args:
        directories: Mapping with the keys ``'source_dir'`` and
            ``'augmented_dir'``.
        data_generator: Augmentation object forwarded to every task.
        num_augmented_images: Number of variants to create per source image.
    """
    countries = os.listdir(directories['source_dir'])
    tasks = [
        (country, directories, data_generator, num_augmented_images)
        for country in countries
    ]
    # The context manager releases the worker processes even when a task
    # raises. starmap() blocks until every task has finished, so nothing is
    # cut short when the pool is torn down on exit.
    with mp.Pool(mp.cpu_count()) as pool:
        pool.starmap(augment_image_for_country, tasks)

In [None]:
# Augmentation settings shared by all generated images.
datagen = ImageDataGenerator(
    # geometric transforms
    rotation_range=90.0,
    shear_range=0.4,
    zoom_range=[0.8, 1.2],
    width_shift_range=0.2,
    height_shift_range=0.2,
    horizontal_flip=True,
    # colour transform: 51 is 20% of the 0-255 value range
    channel_shift_range=51.0,
    # how pixels uncovered by a transform are filled
    fill_mode='nearest',
)

In [None]:
# Central registry of every path used by the notebook.
dir_dict = dict(
    # image folders
    source_dir='data/country_flag',
    augmented_dir='data/augmented_flags',
    # TFRecord folder and the three split files inside it
    tfrecords='data/tfrecords',
    train_tfrecord='data/tfrecords/train.tfrecord',
    val_tfrecord='data/tfrecords/val.tfrecord',
    test_tfrecord='data/tfrecords/test.tfrecord',
    # model artefacts
    model='models/flag_classifier_model.h5',
    checkpoints='models/checkpoints/ckpt.weights.keras',
)

In [None]:
# Run the (expensive) augmentation only when its output folder is absent.
# NOTE: the guard checks only that the folder exists, so a run that failed
# part-way is not resumed until the folder is removed.
if not os.path.exists(dir_dict['augmented_dir']):
    os.makedirs(dir_dict['augmented_dir'])
    augment_images(dir_dict, datagen, num_augmented_images=200)
    print("Augmentation completed.")

# 3. Creation of TFRecords

In [None]:
def _bytes_feature(value):
    """Wrap a single string / bytes value in a bytes_list ``tf.train.Feature``."""
    wrapped = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=wrapped)

def _int64_feature(value):
    """Wrap a single bool / enum / int / uint in an int64_list ``tf.train.Feature``."""
    wrapped = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=wrapped)

def image_example(image_string, label):
    """Build a ``tf.train.Example`` holding one image and its label.

    The example carries two features: ``'label'`` (int64) and
    ``'image_raw'`` (bytes).
    """
    features = tf.train.Features(feature={
        'label': _int64_feature(label),
        'image_raw': _bytes_feature(image_string),
    })
    return tf.train.Example(features=features)

def write_tfrecord(images, labels, output_path):
    """Serialize images and their labels into one TFRecord file.

    Every image is resized to ``IMG_X`` x ``IMG_Y``, converted to RGB and
    stored as its raw pixel bytes next to its label, one record per image.

    Args:
        images: Iterable of image file paths.
        labels: Iterable of integer labels, paired with ``images`` by
            position.
        output_path: Destination path of the TFRecord file.
    """
    with tf.io.TFRecordWriter(output_path) as writer:
        for img_path, label in zip(images, labels):
            # The context manager closes the image file deterministically,
            # even when resizing or conversion fails.
            with Image.open(img_path) as img:
                pixels = img.resize((IMG_X, IMG_Y)).convert('RGB').tobytes()

            tf_example = image_example(pixels, label)
            writer.write(tf_example.SerializeToString())

def convert_to_tfrecord(directories, test_size=0.2, val_size=0.2, random_state=None):
    """Split the augmented images into train/val/test TFRecord files.

    Integer labels are assigned by position in the alphabetically sorted
    list of class folders. ``os.listdir`` returns entries in arbitrary order,
    so sorting is what makes the mapping reproducible; it can be rebuilt
    later with ``sorted(os.listdir(directories['augmented_dir']))``.

    Both splits are stratified by label. The validation split is taken from
    what remains after the test split, so ``val_size`` is a fraction of that
    remainder, not of the full dataset.

    Args:
        directories: Mapping with the keys ``'augmented_dir'``,
            ``'train_tfrecord'``, ``'val_tfrecord'`` and ``'test_tfrecord'``.
        test_size: Share of all images held out for testing.
        val_size: Share of the remaining images held out for validation.
        random_state: Optional seed forwarded to both ``train_test_split``
            calls. ``None`` keeps the previous behaviour.
    """
    countries = sorted(os.listdir(directories['augmented_dir']))
    label_map = {country: idx for idx, country in enumerate(countries)}

    all_images = []
    all_labels = []
    for country, label in label_map.items():
        country_input_path = os.path.join(directories['augmented_dir'], country)
        for image_name in os.listdir(country_input_path):
            all_images.append(os.path.join(country_input_path, image_name))
            all_labels.append(label)

    train_images, test_images, train_labels, test_labels = train_test_split(
        all_images, all_labels, test_size=test_size, stratify=all_labels,
        random_state=random_state
    )
    train_images, val_images, train_labels, val_labels = train_test_split(
        train_images, train_labels, test_size=val_size, stratify=train_labels,
        random_state=random_state
    )

    write_tfrecord(train_images, train_labels, directories['train_tfrecord'])
    write_tfrecord(val_images, val_labels, directories['val_tfrecord'])
    write_tfrecord(test_images, test_labels, directories['test_tfrecord'])

In [None]:
# Build the TFRecord files only when their folder is absent.
# NOTE: the guard checks only that the folder exists, so an interrupted
# conversion is not redone until the folder is removed.
if not os.path.exists(dir_dict['tfrecords']):
    os.makedirs(dir_dict['tfrecords'])
    convert_to_tfrecord(dir_dict, test_size=0.1, val_size=0.1)
    print("TFRecord conversion completed.")

In [None]:
# Show which files are present in the TFRecord folder.
print(os.listdir(dir_dict['tfrecords']))

# 4. Reading TFRecords

In [None]:
# GLOBALS
BATCH_SIZE = 32
# One class per entry of the source folder.
# NOTE(review): labels are derived from the augmented folder; confirm that
# both folders always contain the same set of entries.
NUM_CLASSES = len(os.listdir(dir_dict['source_dir']))
# Last expression of the cell, so the notebook displays it.
f'Number of flag classes: {NUM_CLASSES}'

In [None]:
def _parse_function(proto):
    """Decode one serialized example into an ``(image, label)`` pair.

    The image comes back as a float32 tensor of shape
    ``[IMG_Y, IMG_X, 3]`` scaled to the 0-1 range; the label is int32.
    """
    feature_spec = {
        'label': tf.io.FixedLenFeature([], tf.int64),
        'image_raw': tf.io.FixedLenFeature([], tf.string),
    }
    example = tf.io.parse_single_example(proto, feature_spec)

    label = tf.cast(example['label'], tf.int32)

    # Raw uint8 bytes -> [height, width, channels] -> floats in [0, 1].
    pixels = tf.io.decode_raw(example['image_raw'], tf.uint8)
    pixels = tf.reshape(pixels, [IMG_Y, IMG_X, 3])
    image = tf.cast(pixels, tf.float32) / 255.0

    return image, label

def load_tfrecord_dataset(
    tfrecord_path, batch_size=32, shuffle=True, shuffle_buffer_size=10000
):
    """Create a batched, prefetched dataset from a TFRecord file.

    Args:
        tfrecord_path: Path (or paths) of the TFRecord file(s) to read.
        batch_size: Number of examples per batch.
        shuffle: Whether to shuffle the examples. Defaults to ``True``, as
            before; evaluation datasets can pass ``False``.
        shuffle_buffer_size: Number of records held in the shuffle buffer.

    Returns:
        A ``tf.data.Dataset`` yielding ``(images, labels)`` batches.
    """
    dataset = tf.data.TFRecordDataset(tfrecord_path)
    if shuffle:
        # Shuffle the serialized records before decoding: the buffer then
        # holds uint8 byte strings instead of decoded float32 images, which
        # need four times the memory per pixel.
        dataset = dataset.shuffle(buffer_size=shuffle_buffer_size)
    dataset = dataset.map(_parse_function, num_parallel_calls=tf.data.AUTOTUNE)
    dataset = dataset.batch(batch_size)
    dataset = dataset.prefetch(buffer_size=tf.data.AUTOTUNE)
    return dataset

In [None]:
# Pass the BATCH_SIZE global explicitly so that changing it actually takes
# effect, instead of silently relying on the loader's default value.
train_dataset = load_tfrecord_dataset(dir_dict['train_tfrecord'], batch_size=BATCH_SIZE)
val_dataset = load_tfrecord_dataset(dir_dict['val_tfrecord'], batch_size=BATCH_SIZE)
test_dataset = load_tfrecord_dataset(dir_dict['test_tfrecord'], batch_size=BATCH_SIZE)

In [None]:
# Sanity check: print the shape of one training batch and its labels.
for images, labels in train_dataset.take(1):
    print(images.shape, '\n', labels)

In [None]:
# Show the first nine images of one training batch in a 3x3 grid, each
# titled with its integer label.
plt.figure(figsize=(10, 10))
for images, labels in train_dataset.take(1):
    for i in range(9):
        ax = plt.subplot(3, 3, i + 1)
        # Images are stored scaled to 0-1; convert back to 0-255 for display.
        plt.imshow(np.array(images[i] * 255).astype("uint8"))
        plt.title(int(labels[i]))
        plt.axis("off")

# 5. Model definition

In [None]:
# GLOBALS
NUM_EPOCHS = 15  # epochs for the first training stage
PATIENCE = 8  # patience passed to the early-stopping callback

In [None]:
# ResNet50 backbone with ImageNet weights and without its classification top.
# NOTE(review): inputs are scaled to 0-1 when the TFRecords are parsed;
# confirm this matches the preprocessing the ImageNet weights expect.
base_model = ResNet50(
    weights='imagenet', include_top=False, input_shape=(IMG_Y, IMG_X, 3)
)
# Classification head: global pooling -> regularized dense layer -> dropout
# -> softmax over the flag classes.
io_x = base_model.output
io_x = GlobalAveragePooling2D()(io_x)
io_x = Dense(1024, activation='selu', kernel_regularizer=l2(5e-5))(io_x)
io_x = Dropout(0.2)(io_x)
predictions = Dense(NUM_CLASSES, activation='softmax')(io_x)

In [None]:
# Freeze every backbone layer so that only the new head is trained at first,
# then assemble the full model from backbone input to class probabilities.
for layer in base_model.layers:
    layer.trainable = False

model = Model(inputs=base_model.input, outputs=predictions)

In [None]:
# Optimizer and loss for the first training stage; the loss takes integer
# (sparse) class labels.
OPTIMIZER = AdamW(learning_rate=2e-5)
LOSS = 'sparse_categorical_crossentropy'

In [None]:
def _one_hot_like(y_true, y_pred):
    """One-hot encode integer class ids to match the shape of ``y_pred``."""
    class_ids = tf.cast(tf.reshape(y_true, [-1]), tf.int32)
    return tf.one_hot(class_ids, depth=tf.shape(y_pred)[-1])


class SparsePrecision(tf.keras.metrics.Precision):
    """Precision metric that accepts integer (sparse) class labels."""

    def update_state(self, y_true, y_pred, sample_weight=None):
        return super().update_state(
            _one_hot_like(y_true, y_pred), y_pred, sample_weight
        )


class SparseRecall(tf.keras.metrics.Recall):
    """Recall metric that accepts integer (sparse) class labels."""

    def update_state(self, y_true, y_pred, sample_weight=None):
        return super().update_state(
            _one_hot_like(y_true, y_pred), y_pred, sample_weight
        )


# The stock Precision / Recall metrics compare labels and predictions
# element-wise and need both to have the same shape. The pipeline yields
# integer labels while the model outputs one probability per class, so the
# labels are one-hot encoded before the metrics see them. Both metrics use
# the default 0.5 threshold on the softmax probabilities.
metrics = [
    'accuracy', SparsePrecision(name='precision'), SparseRecall(name='recall')
]

In [None]:
# Configure training with the loss, optimizer and metrics defined above.
model.compile(loss=LOSS, optimizer=OPTIMIZER, metrics=metrics)

In [None]:
# Callbacks
# Keep only the checkpoint with the highest validation accuracy.
checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(
    filepath=dir_dict['checkpoints'],
    save_best_only=True,
    monitor='val_accuracy',
    mode='max',
)
# Stop after PATIENCE epochs without improvement and roll back to the best
# weights seen.
# NOTE(review): no `monitor` is given here, so the callback's default
# quantity is watched while the checkpoint tracks 'val_accuracy'; confirm
# the difference is intended.
early_stopping_cb = tf.keras.callbacks.EarlyStopping(
    restore_best_weights=True,
    patience=PATIENCE,
)
callbacks = [checkpoint_cb, early_stopping_cb]

# 6. Model training

In [None]:
# First stage: train the new head on top of the frozen backbone.
history = model.fit(
    train_dataset,
    epochs=NUM_EPOCHS,
    validation_data=val_dataset,
    callbacks=callbacks,
    verbose=1,
)

In [None]:
# Print the layer-by-layer overview of the model.
model.summary()

In [None]:
# model.save_weights(dir_dict['checkpoints'])

In [None]:
# Restore the best checkpoint into the existing model. load_weights() works
# in place and its return value is not a model, so the result must not be
# assigned back to `model` (doing so breaks every later `model.*` call).
model.load_weights(dir_dict['checkpoints'])

# 7. Fine-tuning

In [None]:
# Unfreeze part of the backbone to continue training (fine-tuning), see:
# https://medium.com/@kenneth.ca95/a-guide-to-transfer-learning-with-keras-using-resnet50-a81a4a28084b
UNFREEZE_FROM = 143  # index of the first backbone layer to unfreeze
FT_EPOCHS = 10  # epochs for the fine-tuning stage

In [None]:
# Keep the first UNFREEZE_FROM backbone layers frozen and train the rest.
for layer_idx, layer in enumerate(base_model.layers):
    layer.trainable = layer_idx >= UNFREEZE_FROM

# Recompile so the new trainable flags take effect. The first-stage optimizer
# was already built for the old set of trainable variables, and current Keras
# optimizers reject variables they were not built with. A fresh optimizer
# with the same configuration is therefore created for this stage.
# NOTE(review): fine-tuning commonly uses a lower learning rate; the original
# value is kept here.
ft_optimizer = type(OPTIMIZER).from_config(OPTIMIZER.get_config())

model.compile(
    optimizer=ft_optimizer,
    loss=LOSS,
    metrics=metrics
)

In [None]:
# Second stage: fine-tune the partially unfrozen model. This rebinds
# `history`, replacing the record of the first stage.
history = model.fit(
    train_dataset,
    epochs=FT_EPOCHS,
    validation_data=val_dataset,
    callbacks=callbacks,
    verbose=1,
)

In [None]:
# Print the model overview again after the trainable flags were changed.
model.summary()

# 8. Evaluation

In [None]:
# Evaluate on the held-out test set; the values unpack as the loss followed
# by the three compiled metrics (accuracy, precision, recall).
test_loss, test_acc, test_precision, test_recall = model.evaluate(test_dataset)

In [None]:
# Report the test-set metrics (the loss is not printed).
print('Test accuracy:', test_acc)
print('Test precision:', test_precision)
print('Test recall:', test_recall)

# 9. Model saving

In [None]:
#model.save(dir_dict['model'])