<a href="https://colab.research.google.com/github/Cassini-chris/Transfer_Learning_Image_Classification_Overwatch/blob/main/Transfer_Learning_Overwatch_Classification.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Import packages

In [None]:
import os
import random
import shutil
import zipfile
from shutil import copyfile

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Dense, Conv2D, Flatten, Dropout, MaxPooling2D
from tensorflow.keras.models import Sequential
from tensorflow.keras.preprocessing.image import ImageDataGenerator

In [None]:
# Work from /tmp and remove folders left over from a previous run, so that
# re-running the notebook does not mix old images into the new split.
os.chdir('/tmp')
print(os.getcwd())

# shutil.rmtree on absolute paths does not depend on the current directory
# and, unlike `!rm -rf`, also works when the code runs outside IPython.
# ignore_errors=True keeps the "nothing to delete yet" case silent.
for stale_dir in ('/tmp/overwatch', '/tmp/SOURCE_DATA'):
    shutil.rmtree(stale_dir, ignore_errors=True)

In [None]:
# PARAMETERS
# Input resolution fed to the data generators and the MobileNetV2 base model.
IMG_HEIGHT = 224
IMG_WIDTH = 224

# Seed every RNG the notebook relies on: `random` drives the train/test file
# split, NumPy drives Keras generator shuffling/augmentation, and TensorFlow
# drives weight initialisation. This makes Restart & Run All repeatable
# (GPU kernels may still introduce small nondeterminism).
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)

## Load data

In [None]:
# Mount Google Drive (Colab only) so the dataset ZIP can be read.
from google.colab import drive
drive.mount('/content/gdrive')

# --- Input --- location of the ZIP file on Drive
ZIP_FILE = '/content/gdrive/My Drive/__TECH/_My Flask Apps/Overwatch_data/Overwatch_data.zip'
FOLDER_NAME = 'SOURCE_DATA/'
UNZIP_DIR = '/tmp/' + FOLDER_NAME

# Extract the archive into UNZIP_DIR. The context manager closes the archive
# even if extraction fails part-way.
with zipfile.ZipFile(ZIP_FILE, 'r') as zip_ref:
    zip_ref.extractall(UNZIP_DIR)

# Root of the extracted data; the split step reads PATH + <hero> + "/" from here.
PATH = UNZIP_DIR
print(PATH)

In [None]:
# Class names; each one is expected to be a sub-folder of the extracted ZIP.
overwatch_heros = ['tracer', 'reaper', 'widowmaker', 'pharah', 'reinhardt', 'mercy', 'torbjörn', 'hanzo', 'winston', 'zenyatta', 'bastion', 'symmetra', 'zarya', 'mccree', 'soldier76', 'lucio', 'roadhog', 'junkrat', 'dva', 'mei', 'genji', 'ana', 'sombra', 'orisa', 'doomfist', 'moira', 'brigitte', 'wreckingball', 'ashe', 'baptiste', 'sigma', 'echo']

# Create /tmp/overwatch/{training,testing}/<hero>/. makedirs creates any
# missing parents and, with exist_ok=True, tolerates folders that already
# exist. Each folder is handled independently, so one pre-existing folder
# can no longer cause the others to be skipped, and real errors (e.g.
# permissions) are raised instead of swallowed.
for hero in overwatch_heros:
    os.makedirs(os.path.join('/tmp/overwatch/training', hero), exist_ok=True)
    os.makedirs(os.path.join('/tmp/overwatch/testing', hero), exist_ok=True)

In [None]:
 def split_data(SOURCE, TRAINING, TESTING, SPLIT_SIZE):
    files = []
    for filename in os.listdir(SOURCE):
        file = SOURCE + filename
        if os.path.getsize(file) > 0:
            files.append(filename)
        else:
            print(filename + " is zero length, so ignoring.")

    training_length = int(len(files) * SPLIT_SIZE)
    testing_length = int(len(files) - training_length)
    shuffled_set = random.sample(files, len(files))
    training_set = shuffled_set[0:training_length]
    testing_set = shuffled_set[-testing_length:]

    for filename in training_set:
        this_file = SOURCE + filename
        destination = TRAINING + filename
        copyfile(this_file, destination)

    for filename in testing_set:
        this_file = SOURCE + filename
        destination = TESTING + filename
        copyfile(this_file, destination)

split_size = 0.9

# Split each hero's images according to split_size, then report how many
# files ended up in the training and testing folders.
for hero in overwatch_heros:
    hero_source = PATH + hero + "/"
    hero_train = '/tmp/overwatch/training/' + hero + "/"
    hero_test = '/tmp/overwatch/testing/' + hero + "/"
    split_data(hero_source, hero_train, hero_test, split_size)

    n_train = len(os.listdir(hero_train))
    n_test = len(os.listdir(hero_test))
    print(f"TRAINING: {hero}: {n_train}")
    print(f"TESTING: {hero}: {n_test}\n")

In [None]:
# Root folders handed to flow_from_directory (one sub-folder per class).
overwatch_root = '/tmp/overwatch'
train_dir, validation_dir = (os.path.join(overwatch_root, sub) for sub in ('training', 'testing'))

## Data preparation

In [None]:
# Training generator: rescale pixels to [0, 1] and apply random geometric
# augmentation so each epoch sees different variations of the images.
train_image_generator = ImageDataGenerator(rescale=1./255,
                                           rotation_range=40,
                                           width_shift_range=0.2,
                                           height_shift_range=0.2,
                                           shear_range=0.2,
                                           zoom_range=0.2,
                                           horizontal_flip=True,
                                           fill_mode='nearest')

# Validation generator: rescale only. Augmenting validation images makes
# val_loss / val_accuracy fluctuate with random distortions and measures the
# model on images it will never see at prediction time; evaluation must use
# unmodified images.
validation_image_generator = ImageDataGenerator(rescale=1./255)

In [None]:
# Stream shuffled, augmented training batches of 10 images from train_dir,
# resized to the model's input size, with one-hot (categorical) labels.
train_data_gen = train_image_generator.flow_from_directory(
    directory=train_dir,
    target_size=(IMG_HEIGHT, IMG_WIDTH),
    batch_size=10,
    class_mode='categorical',
    shuffle=True,
)

In [None]:
# Validation batches from validation_dir. shuffle=False keeps a fixed,
# reproducible order, so predictions can be matched back to
# val_data_gen.filenames. Shuffling adds nothing here because every
# validation batch is evaluated each epoch.
val_data_gen = validation_image_generator.flow_from_directory(directory=validation_dir,
                                                              batch_size=10,
                                                              shuffle=False,
                                                              target_size=(IMG_HEIGHT, IMG_WIDTH),
                                                              class_mode='categorical')

In [None]:
# Invert class_indices ({hero_name: index}) into {index: hero_name} so the
# model's output positions can be translated back into hero names.
labels = {index: hero_name for hero_name, index in train_data_gen.class_indices.items()}
print(labels)

### Visualize training images

In [None]:
# Draw one (images, labels) batch from the augmented training stream;
# only the images are needed for the preview below.
first_batch = next(train_data_gen)
sample_training_images, _ = first_batch

In [None]:
def plotImages(images_arr, n_cols=5):
    """Show images side by side in a single row.

    Args:
        images_arr: sequence of images accepted by ``plt.imshow`` (e.g. a
            batch of H x W x 3 arrays scaled to [0, 1]).
        n_cols: number of columns in the row. Only the first ``n_cols``
            images are drawn. Defaults to 5.
    """
    # squeeze=False always returns a 2-D array of Axes, so flatten() also
    # works when n_cols == 1. A 4-inch-per-column figure fits one row instead
    # of a mostly empty 20x20 canvas.
    fig, axes = plt.subplots(1, n_cols, figsize=(4 * n_cols, 4), squeeze=False)
    axes = axes.flatten()
    # zip() stops at the shorter sequence. Turn every axis off first so
    # unused columns (fewer images than n_cols) stay blank instead of
    # showing empty frames.
    for ax in axes:
        ax.axis('off')
    for img, ax in zip(images_arr, axes):
        ax.imshow(img)
    plt.tight_layout()
    plt.show()

In [None]:
# Preview the first five augmented training images.
first_five_images = sample_training_images[:5]
plotImages(first_five_images)

## Create the model

In [None]:
# Model input shape derived from the PARAMETERS constants, so the base model
# and the data generators (target_size=(IMG_HEIGHT, IMG_WIDTH), RGB) always
# agree on image size.
IMG_SHAPE = (IMG_HEIGHT, IMG_WIDTH, 3)

# Create the base model from MobileNetV2 pre-trained on ImageNet.
# include_top=False drops the ImageNet classifier so a new head can be attached.
base_model = tf.keras.applications.MobileNetV2(input_shape=IMG_SHAPE,
                                               include_top=False,
                                               weights='imagenet')

In [None]:
# Freeze the pre-trained MobileNetV2 weights so that only the new
# classification head added to the model is trained (feature extraction).
base_model.trainable = False
# Uncomment to inspect the layers of the frozen base model.
#base_model.summary()

In [None]:
# Classification head: average MobileNetV2's final feature maps into one
# vector per image, then apply a softmax over the hero classes.
global_average_layer = tf.keras.layers.GlobalAveragePooling2D()

# Derive the number of outputs from the training data instead of
# hard-coding 32, so adding or removing a hero folder cannot silently
# mis-size the output layer. No input_shape is given: Keras infers it from
# the pooled features of the previous layer.
num_classes = len(train_data_gen.class_indices)
prediction_layer = tf.keras.layers.Dense(units=num_classes, activation='softmax')

model = tf.keras.Sequential([
  base_model,
  global_average_layer,
  prediction_layer
])

model.summary()

In [None]:
# categorical_crossentropy matches the one-hot labels produced by
# class_mode='categorical'. Adam runs with its default settings.
model.compile(loss='categorical_crossentropy',
              optimizer='adam',
              metrics=['accuracy'])

## Train the model

In [None]:
# Number of passes over the training data.
EPOCHS = 30

# No batch_size argument: the inputs are Keras generators that already yield
# batches (batch_size=10 in flow_from_directory). The Keras docs say
# batch_size must not be specified for generator/Sequence input.
# Without steps_per_epoch / validation_steps, each epoch consumes every
# batch of both generators once.
history = model.fit(
    train_data_gen,
    epochs=EPOCHS,
    validation_data=val_data_gen,
    verbose=1,
)

### Visualize training results

Now visualize the results after training the network.

In [None]:
acc = history.history['accuracy']
val_acc = history.history['val_accuracy']

loss = history.history['loss']
val_loss = history.history['val_loss']

# Take the x-axis from the recorded history rather than hard-coding 30,
# so the plot stays correct if the epoch count changes or training stops early.
epochs_range = range(len(acc))

fig, (acc_ax, loss_ax) = plt.subplots(1, 2, figsize=(12, 5))

acc_ax.plot(epochs_range, acc, label='Training Accuracy')
acc_ax.plot(epochs_range, val_acc, label='Validation Accuracy')
acc_ax.legend(loc='lower right')
acc_ax.set(title='Training and Validation Accuracy', xlabel='Epoch', ylabel='Accuracy')

loss_ax.plot(epochs_range, loss, label='Training Loss')
loss_ax.plot(epochs_range, val_loss, label='Validation Loss')
loss_ax.legend(loc='upper right')
loss_ax.set(title='Training and Validation Loss', xlabel='Epoch', ylabel='Loss')

plt.tight_layout()
plt.show()

In [None]:
#model.save('overwatch_model.h5')  # uncomment to persist the trained model

In [None]:
# Test Random Image: Drive folder holding images for a one-off prediction
# spot-check (presumably images outside the training ZIP — confirm).
image_path = ("/content/gdrive/My Drive/__TECH/_My Flask Apps/"
              "Overwatch_data/RANDOM")

def loadImages(path, extensions=('.jpg', '.png')):
    """Return the sorted full paths of the image files directly inside ``path``.

    Args:
        path: directory to scan (not recursive).
        extensions: a suffix or tuple of suffixes to keep. Matching is
            case-insensitive, so files such as ``photo.JPG`` or ``shot.PNG``
            are not silently skipped.

    Returns:
        list[str]: ``os.path.join(path, name)`` for every matching file, sorted.
    """
    if isinstance(extensions, str):
        extensions = (extensions,)
    suffixes = tuple(ext.lower() for ext in extensions)
    return sorted(
        os.path.join(path, name)
        for name in os.listdir(path)
        if name.lower().endswith(suffixes)
    )

image_list = loadImages(image_path)

# Which of the (sorted) images to spot-check. Fail with a clear message
# instead of an opaque IndexError when the folder holds fewer files.
IMAGE_INDEX = 2
if len(image_list) <= IMAGE_INDEX:
    raise ValueError(
        f"Expected at least {IMAGE_INDEX + 1} images in {image_path!r}, "
        f"found {len(image_list)}."
    )
path_string = image_list[IMAGE_INDEX]

img = tf.io.read_file(path_string)
# loadImages() returns PNG as well as JPEG files, so decode by content rather
# than with decode_jpeg. expand_animations=False guarantees a single 3-D frame.
img = tf.io.decode_image(img, channels=3, expand_animations=False)
# uint8 [0, 255] -> float32 [0, 1], matching the generators' rescale=1./255.
img = tf.image.convert_image_dtype(img, tf.float32)
# tf.image.resize takes the new size as [height, width].
final_img = tf.image.resize(img, [IMG_HEIGHT, IMG_WIDTH])

plt.imshow(final_img)
plt.title(os.path.basename(path_string))
plt.axis('off')
plt.show()

In [None]:
#Expand Tensor for Model (Input shape)
y = np.expand_dims(final_img, axis=0)

#Predict Image Tensor with model
prediction = model.predict(y)
prediction_squeeze = np.squeeze(prediction, axis=0)

label_array = np.array(labels)

#print(type(label))
for key, value in labels.items():
    real_label = prediction_squeeze[key]
    
    print ("{0:.0%}".format(real_label), value)

In [None]:
#predictions = [labels[k] for k in predicted_class_indices]