# __Original notebook can be found here:__ https://www.kaggle.com/code/moh3we5/traffic-sign-dataset-resnet-classification

# Import Libraries

In [None]:
import os
import glob
import pandas as pd
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.utils import plot_model
from tensorflow.keras.models import Sequential,Model
from tensorflow.keras.layers import Dense, Conv2D, MaxPool2D, Flatten
from tensorflow.keras.applications.resnet50 import ResNet50
from tensorflow.keras.layers import Dense, Dropout, Flatten, BatchNormalization
from tensorflow.keras.preprocessing import image_dataset_from_directory

# Fix the global NumPy and TensorFlow random seeds so runs are repeatable.
SEED = 42
np.random.seed(SEED)
tf.random.set_seed(SEED)

# Load the Data

### __NOTE__: ImageDataGenerator is __deprecated__

## Train Data

In [None]:
# Augmenting generator for the training images: random zoom/shear, pixel values
# rescaled to [0, 1], and 20% of the images reserved for validation via
# validation_split.
# horizontal_flip is disabled: mirroring a traffic sign can change its meaning
# (e.g. a "turn left" sign becomes a "turn right" sign), which silently
# corrupts the labels of direction-specific classes.
train_datagen = ImageDataGenerator(zoom_range=0.5, shear_range=0.8, horizontal_flip=False, rescale=1/255, validation_split=0.2)

In [None]:
# base_path = os.path.basename('images')
train_path = os.path.join('images', 'traffic_Data', 'TRAIN')

# Validation images must not be augmented, otherwise validation metrics are
# measured on randomly zoomed/sheared/flipped images. Use a separate generator
# that only rescales; keeping the same validation_split as train_datagen keeps
# the training/validation partition of train_path the same.
val_datagen = ImageDataGenerator(rescale=1/255, validation_split=0.2)

data_train_gen = train_datagen.flow_from_directory(
    train_path,
    target_size=(224,224),
    batch_size=32,
    class_mode='categorical',
    color_mode='rgb',
    seed = 1234,
    shuffle = True,
    subset='training') # set as training data (augmented)

data_val_gen = val_datagen.flow_from_directory(
    train_path,
    target_size=(224,224),
    batch_size=32,
    class_mode='categorical',
    color_mode='rgb',
    seed = 1234,
    shuffle = True,
    subset='validation') # set as validation data (rescaled only, no augmentation)


# Load the Data V2

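### __NOTE__: image_dataset_from_directory is the more up-to-date data loader; running this cell overwrites `data_train_gen` and `data_val_gen` from the ImageDataGenerator version above.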

In [None]:
# tf.data-based loaders; these overwrite data_train_gen / data_val_gen from
# the ImageDataGenerator cell. Both calls must share `seed` and
# `validation_split` so the training and validation subsets are disjoint.
#
# label_mode='categorical' yields one-hot labels, consistent with the
# ImageDataGenerator loaders (class_mode='categorical') and with the
# 'categorical_crossentropy' loss passed to model.compile; label_mode='int'
# would make that loss fail with a label/prediction shape mismatch.
data_train_gen = image_dataset_from_directory(
    train_path,
    labels='inferred',
    label_mode='categorical',
    color_mode='rgb',
    batch_size=32,
    image_size=(224,224),
    shuffle=True,
    seed=42,
    validation_split=0.2,
    subset='training',
    interpolation='bilinear'
)


data_val_gen = image_dataset_from_directory(
    train_path,
    labels='inferred',
    label_mode='categorical',
    color_mode='rgb',
    batch_size=32,
    image_size=(224,224),
    shuffle=True,
    seed=42,
    validation_split=0.2,
    subset='validation',
    interpolation='bilinear'
)


def _rescale(images, labels):
    """Scale pixel values from [0, 255] to [0, 1], leaving labels untouched.

    Matches rescale=1/255 in the ImageDataGenerator loaders and the /255.0
    applied to the test images, so train and test inputs share one scale.
    """
    return images / 255.0, labels


# .map() returns a new dataset that lacks the `class_names` attribute attached
# by image_dataset_from_directory, so carry it over explicitly.
_train_class_names = data_train_gen.class_names
_val_class_names = data_val_gen.class_names
data_train_gen = data_train_gen.map(_rescale)
data_val_gen = data_val_gen.map(_rescale)
data_train_gen.class_names = _train_class_names
data_val_gen.class_names = _val_class_names

In [None]:
# Map of class name -> label index for whichever loader built data_train_gen.
# The Keras directory iterator (ImageDataGenerator) exposes `class_indices`;
# the tf.data.Dataset from image_dataset_from_directory exposes only
# `class_names` (ordered by label index), so rebuild the dict from it.
class_indices = getattr(data_train_gen, 'class_indices', None)
if class_indices is None:
    class_indices = {name: idx for idx, name in enumerate(data_train_gen.class_names)}
class_indices

## Test Data

In [None]:
# Load every test image into memory as a float32 array of shape (224, 224, 3)
# with pixel values in [0, 1] (same /255 rescale as the training loaders).
# Paths are sorted so that indices into all_images are stable between runs;
# test_paths[i] is the file that produced all_images[i].
test_paths = sorted(glob.glob(os.path.join('images', 'traffic_Data', 'TEST', '*')))
all_images = []
for img_path in test_paths:
    # The context manager closes the file handle. convert('RGB') guarantees
    # three channels even for grayscale, palette or RGBA files, and bilinear
    # resampling matches interpolation='bilinear' in the training loader.
    with Image.open(img_path) as raw_img:
        img = raw_img.convert('RGB').resize((224, 224), Image.BILINEAR)
    img_array = np.asarray(img, dtype=np.float32) / 255.0  # rescale to [0, 1]
    all_images.append(img_array)

In [None]:
# Preview a single preprocessed test image.
sample_idx = 40
plt.imshow(all_images[sample_idx])

# Import the Model

In [None]:
# ImageNet-pretrained ResNet50 backbone without its classification head,
# taking 224x224 RGB inputs.
ResNet50_model = ResNet50(
    include_top=False,
    weights='imagenet',
    input_shape=(224, 224, 3),
)

# Model Architecture

In [None]:
# Freeze the pretrained ResNet50 weights; print the flag before and after the
# change to confirm it took effect.
initially_trainable = ResNet50_model.trainable
ResNet50_model.trainable = False
print(initially_trainable)
print(ResNet50_model.trainable)

In [None]:
# ResNet50_model.summary()

In [None]:
# plot_model(ResNet50_model, to_file= 'ResNet50_model.png', show_shapes = True, show_layer_names=True)

In [None]:
# Transfer Learning
# New classification head on top of the frozen backbone: flatten the feature
# map and project it to one probability per class.
# The class count comes from whichever loader built data_train_gen: the Keras
# directory iterator has `class_indices`, while the tf.data.Dataset from
# image_dataset_from_directory has `class_names` (and no `class_indices`).
if hasattr(data_train_gen, 'class_indices'):
    num_classes = len(data_train_gen.class_indices)
else:
    num_classes = len(data_train_gen.class_names)

flatten_layer1 = Flatten()(ResNet50_model.output)
# 'softmax' (lowercase) is the canonical Keras activation identifier.
final_layer = Dense(num_classes, activation='softmax')(flatten_layer1)

In [None]:
# Full model: ResNet50 backbone input through to the new classification head.
model = Model(
    inputs=ResNet50_model.input,
    outputs=final_layer,
)
# model.summary()

In [None]:
# plot_model(model, to_file= 'model.png', show_shapes = True, show_layer_names=True)

In [None]:
# Report, layer by layer, whether its weights will be updated during training.
for lyr in model.layers:
    trainable_flag = lyr.trainable
    print(f'{lyr} is trainable: {trainable_flag}')

In [None]:
# Number of trainable weight tensors (displayed as the cell's output).
n_trainable_weights = len(model.trainable_weights)
n_trainable_weights

# Model training

In [None]:
# Adam optimizer with categorical cross-entropy, which expects one-hot labels;
# accuracy is tracked alongside the loss.
model.compile(
    loss='categorical_crossentropy',
    optimizer='adam',
    metrics=['accuracy'],
)

In [None]:
# Train for 10 epochs, evaluating on the validation subset after each epoch.
model_training = model.fit(
    x=data_train_gen,
    validation_data=data_val_gen,
    epochs=10,
)

# History

In [None]:
# History
# Tabulate the per-epoch metrics recorded by model.fit (one column per metric).
pd.DataFrame.from_dict(model_training.history)

In [None]:
# Plot the Losses
# Training vs. validation loss per epoch, labelled so the curves can be told apart.
hist = model_training.history
plt.plot(hist['loss'], label='train loss')
plt.plot(hist['val_loss'], label='validation loss')
plt.xlabel('epoch')
plt.ylabel('loss')
plt.legend()

In [None]:
# Plot the Accuracy
# Training vs. validation accuracy per epoch, labelled so the curves can be told apart.
hist = model_training.history
plt.plot(hist['accuracy'], label='train accuracy')
plt.plot(hist['val_accuracy'], label='validation accuracy')
plt.xlabel('epoch')
plt.ylabel('accuracy')
plt.legend()

# Save the Model

In [None]:
# model.save('Traffic_ResNet50_94%.h5')