In [None]:
!pip install wandb

In [2]:
cd /content/drive/MyDrive/DeepLearningTasks/FaceRecognition-ExpertMode

/content/drive/MyDrive/DeepLearningTasks/FaceRecognition-ExpertMode


### Settig Parameters and Imports

In [35]:
%%writefile config.py

width = height = 112
batch_size = 16
epochs = 30

Overwriting config.py


In [45]:
import os
import numpy as np
import cv2
import tensorflow as tf
from tensorflow.keras.layers import Conv2D, Dense, Flatten, MaxPooling2D, Dropout
from keras import Input
import wandb
from wandb.keras import WandbCallback
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.models import Sequential
from tqdm import tqdm
from config import *

### Preparing Dataset

In [4]:
data_path = '/content/drive/MyDrive/7-7 dataset'

data_generator = ImageDataGenerator(
    rescale = 1./255,
    validation_split=0.2
)

train_data = data_generator.flow_from_directory(
    data_path,
    target_size=(width, height),
    batch_size = batch_size,
    class_mode='categorical',
    shuffle=True,
    subset='training'
)


val_data = data_generator.flow_from_directory(
    data_path,
    target_size=(width, height),
    batch_size = batch_size,
    class_mode='categorical',
    shuffle=False,
    subset='validation'
)


label_map = (train_data.class_indices)
np.save('label_map.npy', label_map) 
num_classes = len(np.bincount(train_data.labels))

print('num_classes: {}'.format(num_classes))
print(np.bincount(train_data.labels))
print(np.bincount(val_data.labels))

Found 1091 images belonging to 14 classes.
Found 268 images belonging to 14 classes.
num_classes: 14
[73 81 82 80 84 71 80 60 80 77 78 90 74 81]
[18 20 20 20 21 17 19 15 20 19 19 22 18 20]


### Define Model and wandb Initializing

1-1- My Model

In [46]:
class MyModel(tf.keras.Model):
  def __init__(self, num_class, input_shape):
    super().__init__()
    self.conv1 = Conv2D(32, (3, 3), activation='relu', input_shape = input_shape)
    self.conv2 = Conv2D(64, (3, 3), activation='relu')
    self.conv3 = Conv2D(128, (3, 3), activation='relu')
    self.conv4 = Conv2D(256, (3, 3), activation='relu')
    self.maxpool = MaxPooling2D()
    self.flatten = Flatten()
    self.dropout = Dropout(0.2)
    self.fc1 = Dense(128, activation='relu')
    self.fc2 = Dense(64, activation='relu')
    self.fc3= Dense(num_class, activation='softmax')
    self.dim = input_shape


  def call(self, x):
    conv1 = self.conv1(x)
    maxpool1 = self.maxpool(conv1)

    conv2 = self.conv2(maxpool1)
    maxpool2 = self.maxpool(conv2)

    conv3 = self.conv3(maxpool2)
    maxpool3 = self.maxpool(conv3)

    conv4 = self.conv4(maxpool3)
    maxpool4 = self.maxpool(conv4)

    flatten = self.flatten(maxpool4)
    dropout = self.dropout(flatten)
    
    fc1 = self.fc1(dropout)
    dropout = self.dropout(fc1)
    
    fc2 = self.fc2(dropout)
    dropout = self.dropout(fc2)
    
    output = self.fc3(dropout)

    return output

  def build_graph(self):
      x = Input(shape=(self.dim))
      return tf.keras.Model(inputs=[x], outputs=self.call(x))


input_shape = (width ,height ,3)
model = MyModel(num_classes, (input_shape))
model.build((None, *input_shape))
model.build_graph().summary()

Model: "model_3"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_6 (InputLayer)            [(None, 112, 112, 3) 0                                            
__________________________________________________________________________________________________
conv2d_40 (Conv2D)              (None, 110, 110, 32) 896         input_6[0][0]                    
__________________________________________________________________________________________________
max_pooling2d_8 (MaxPooling2D)  multiple             0           conv2d_40[0][0]                  
                                                                 conv2d_41[0][0]                  
                                                                 conv2d_42[0][0]                  
                                                                 conv2d_43[0][0]            

In [None]:
wandb.init(project='FaceRecognition-ExpertMode')
config = wandb.config
config.learning_rate=0.0001

### Define Loss and Training Loop

In [48]:
loss_function = tf.keras.losses.CategoricalCrossentropy()
optimizer = tf.keras.optimizers.Adam(learning_rate=config.learning_rate)

In [49]:
train_loss = tf.keras.metrics.Mean(name='train_loss')
train_acc = tf.keras.metrics.CategoricalAccuracy(name='train_acc')

val_loss = tf.keras.metrics.Mean(name='val_loss')
val_acc = tf.keras.metrics.CategoricalAccuracy(name='val_acc')

In [50]:
def train_step(images, y):
  with tf.GradientTape() as tape:
    y_pred =  model(images, training = True)
    loss = loss_function(y, y_pred)

  # calculate gradient
  gradient = tape.gradient(loss, model.trainable_variables)
  
  # update
  optimizer.apply_gradients(zip(gradient, model.trainable_variables))

  train_loss(loss)
  train_acc(y, y_pred)

In [51]:
def val_step(images, y):
  y_pred = model(images, training=False)
  loss = loss_function(y, y_pred)

  val_loss(loss)
  val_acc(y, y_pred)

In [52]:
def train():
  
  train_step_per_epoch = train_data.samples // batch_size
  val_step_per_epoch = val_data.samples // batch_size

  for epoch in range(epochs):
    train_loss.reset_state()
    train_acc.reset_state()
    val_loss.reset_state()
    val_acc.reset_state()

    
    for _ in tqdm(range(train_step_per_epoch)):
      images, labels = next(train_data)
      train_step(images, labels)
    
    for _ in tqdm(range(val_step_per_epoch)):
      images, labels = next(val_data)
      val_step(images, labels)

    print('epoch: {}'.format(epoch + 1))
    print('loss: {}'.format(train_loss.result()))
    print('accuracy: {}'.format(train_acc.result()))
    print('val_loss: {}'.format(val_loss.result()))
    print('val_accuracy: {}'.format(val_acc.result()))


    # log metrics using wandb.log
    wandb.log({'epochs':  epoch + 1,
                'loss': np.mean(train_loss.result()),
                'acc': float(train_acc.result()), 
                'val_loss': np.mean(val_loss.result()),
                'val_acc':float(val_acc.result())})

In [53]:
train()

100%|██████████| 68/68 [00:19<00:00,  3.49it/s]
100%|██████████| 16/16 [00:04<00:00,  3.83it/s]


epoch: 1
loss: 2.636568069458008
accuracy: 0.07534883916378021
val_loss: 2.611361503601074
val_accuracy: 0.0833333358168602


100%|██████████| 68/68 [00:19<00:00,  3.58it/s]
100%|██████████| 16/16 [00:04<00:00,  3.90it/s]


epoch: 2
loss: 2.592245101928711
accuracy: 0.13209302723407745
val_loss: 2.482846975326538
val_accuracy: 0.2976190447807312


100%|██████████| 68/68 [00:19<00:00,  3.57it/s]
100%|██████████| 16/16 [00:04<00:00,  3.90it/s]


epoch: 3
loss: 2.3394315242767334
accuracy: 0.23720930516719818
val_loss: 2.0370912551879883
val_accuracy: 0.4126984179019928


100%|██████████| 68/68 [00:19<00:00,  3.55it/s]
100%|██████████| 16/16 [00:04<00:00,  3.91it/s]


epoch: 4
loss: 1.9643079042434692
accuracy: 0.3488371968269348
val_loss: 1.6303080320358276
val_accuracy: 0.4722222089767456


100%|██████████| 68/68 [00:19<00:00,  3.54it/s]
100%|██████████| 16/16 [00:04<00:00,  3.83it/s]


epoch: 5
loss: 1.6475915908813477
accuracy: 0.4502325654029846
val_loss: 1.3865840435028076
val_accuracy: 0.5714285969734192


100%|██████████| 68/68 [00:19<00:00,  3.52it/s]
100%|██████████| 16/16 [00:04<00:00,  3.90it/s]


epoch: 6
loss: 1.540524959564209
accuracy: 0.5041860342025757
val_loss: 1.3172893524169922
val_accuracy: 0.6111111044883728


100%|██████████| 68/68 [00:19<00:00,  3.51it/s]
100%|██████████| 16/16 [00:04<00:00,  3.90it/s]


epoch: 7
loss: 1.3113093376159668
accuracy: 0.5832558274269104
val_loss: 1.2582142353057861
val_accuracy: 0.6150793433189392


100%|██████████| 68/68 [00:19<00:00,  3.53it/s]
100%|██████████| 16/16 [00:04<00:00,  3.89it/s]


epoch: 8
loss: 1.1835156679153442
accuracy: 0.6139534711837769
val_loss: 1.1214003562927246
val_accuracy: 0.6507936716079712


100%|██████████| 68/68 [00:19<00:00,  3.54it/s]
100%|██████████| 16/16 [00:04<00:00,  3.89it/s]


epoch: 9
loss: 1.1117478609085083
accuracy: 0.6353488564491272
val_loss: 1.059070348739624
val_accuracy: 0.7142857313156128


100%|██████████| 68/68 [00:19<00:00,  3.54it/s]
100%|██████████| 16/16 [00:04<00:00,  3.92it/s]


epoch: 10
loss: 1.0067936182022095
accuracy: 0.6790697574615479
val_loss: 0.8845313787460327
val_accuracy: 0.7341269850730896


100%|██████████| 68/68 [00:19<00:00,  3.51it/s]
100%|██████████| 16/16 [00:04<00:00,  3.88it/s]


epoch: 11
loss: 0.9133316874504089
accuracy: 0.70790696144104
val_loss: 0.9504942893981934
val_accuracy: 0.7222222089767456


100%|██████████| 68/68 [00:19<00:00,  3.53it/s]
100%|██████████| 16/16 [00:04<00:00,  3.94it/s]


epoch: 12
loss: 0.830324113368988
accuracy: 0.7339534759521484
val_loss: 0.9632441997528076
val_accuracy: 0.7142857313156128


100%|██████████| 68/68 [00:19<00:00,  3.56it/s]
100%|██████████| 16/16 [00:04<00:00,  3.82it/s]


epoch: 13
loss: 0.748000979423523
accuracy: 0.7525581121444702
val_loss: 0.8080567717552185
val_accuracy: 0.76171875


100%|██████████| 68/68 [00:19<00:00,  3.53it/s]
100%|██████████| 16/16 [00:04<00:00,  3.91it/s]


epoch: 14
loss: 0.6928069591522217
accuracy: 0.7674418687820435
val_loss: 0.7819676995277405
val_accuracy: 0.7817460298538208


100%|██████████| 68/68 [00:19<00:00,  3.54it/s]
100%|██████████| 16/16 [00:04<00:00,  3.88it/s]


epoch: 15
loss: 0.6234845519065857
accuracy: 0.806511640548706
val_loss: 0.8339870572090149
val_accuracy: 0.7579365372657776


100%|██████████| 68/68 [00:19<00:00,  3.54it/s]
100%|██████████| 16/16 [00:04<00:00,  3.85it/s]


epoch: 16
loss: 0.5811430811882019
accuracy: 0.8176743984222412
val_loss: 0.8494935035705566
val_accuracy: 0.7698412537574768


100%|██████████| 68/68 [00:19<00:00,  3.53it/s]
100%|██████████| 16/16 [00:04<00:00,  3.94it/s]


epoch: 17
loss: 0.5330876111984253
accuracy: 0.8241860270500183
val_loss: 0.7843586206436157
val_accuracy: 0.7658730149269104


100%|██████████| 68/68 [00:19<00:00,  3.54it/s]
100%|██████████| 16/16 [00:04<00:00,  3.92it/s]


epoch: 18
loss: 0.5095170736312866
accuracy: 0.8316279053688049
val_loss: 0.7897982597351074
val_accuracy: 0.761904776096344


100%|██████████| 68/68 [00:19<00:00,  3.53it/s]
100%|██████████| 16/16 [00:04<00:00,  3.90it/s]


epoch: 19
loss: 0.4013980031013489
accuracy: 0.8762790560722351
val_loss: 0.8090662360191345
val_accuracy: 0.7817460298538208


100%|██████████| 68/68 [00:19<00:00,  3.53it/s]
100%|██████████| 16/16 [00:04<00:00,  3.85it/s]


epoch: 20
loss: 0.41479194164276123
accuracy: 0.8651162981987
val_loss: 0.8082689046859741
val_accuracy: 0.7777777910232544


100%|██████████| 68/68 [00:19<00:00,  3.54it/s]
100%|██████████| 16/16 [00:04<00:00,  3.84it/s]


epoch: 21
loss: 0.3756784200668335
accuracy: 0.8865116238594055
val_loss: 0.6631747484207153
val_accuracy: 0.8015872836112976


100%|██████████| 68/68 [00:19<00:00,  3.41it/s]
100%|██████████| 16/16 [00:04<00:00,  3.72it/s]


epoch: 22
loss: 0.36110758781433105
accuracy: 0.8706976771354675
val_loss: 0.7182731032371521
val_accuracy: 0.8055555820465088


100%|██████████| 68/68 [00:20<00:00,  3.40it/s]
100%|██████████| 16/16 [00:04<00:00,  3.69it/s]


epoch: 23
loss: 0.33967792987823486
accuracy: 0.8799999952316284
val_loss: 0.7838874459266663
val_accuracy: 0.7857142686843872


100%|██████████| 68/68 [00:19<00:00,  3.43it/s]
100%|██████████| 16/16 [00:04<00:00,  3.82it/s]


epoch: 24
loss: 0.3033677637577057
accuracy: 0.8976744413375854
val_loss: 0.6791784763336182
val_accuracy: 0.8333333134651184


100%|██████████| 68/68 [00:19<00:00,  3.53it/s]
100%|██████████| 16/16 [00:04<00:00,  3.98it/s]


epoch: 25
loss: 0.25388339161872864
accuracy: 0.9181395173072815
val_loss: 0.7957773208618164
val_accuracy: 0.817460298538208


100%|██████████| 68/68 [00:18<00:00,  3.60it/s]
100%|██████████| 16/16 [00:04<00:00,  3.92it/s]


epoch: 26
loss: 0.2476063072681427
accuracy: 0.9116278886795044
val_loss: 0.8154959678649902
val_accuracy: 0.8015872836112976


100%|██████████| 68/68 [00:18<00:00,  3.60it/s]
100%|██████████| 16/16 [00:03<00:00,  4.04it/s]


epoch: 27
loss: 0.2381187379360199
accuracy: 0.9265116453170776
val_loss: 0.6258754730224609
val_accuracy: 0.841269850730896


100%|██████████| 68/68 [00:19<00:00,  3.50it/s]
100%|██████████| 16/16 [00:04<00:00,  3.67it/s]


epoch: 28
loss: 0.22210989892482758
accuracy: 0.9311627745628357
val_loss: 0.7616292834281921
val_accuracy: 0.8095238208770752


100%|██████████| 68/68 [00:20<00:00,  3.35it/s]
100%|██████████| 16/16 [00:04<00:00,  3.68it/s]


epoch: 29
loss: 0.18624596297740936
accuracy: 0.9376744031906128
val_loss: 0.7430930733680725
val_accuracy: 0.817460298538208


100%|██████████| 68/68 [00:20<00:00,  3.37it/s]
100%|██████████| 16/16 [00:04<00:00,  3.59it/s]


epoch: 30
loss: 0.18340212106704712
accuracy: 0.9413953423500061
val_loss: 0.7341812252998352
val_accuracy: 0.81640625


100%|██████████| 68/68 [00:20<00:00,  3.32it/s]
100%|██████████| 16/16 [00:04<00:00,  3.94it/s]


epoch: 31
loss: 0.15785783529281616
accuracy: 0.9488372206687927
val_loss: 0.7166642546653748
val_accuracy: 0.8333333134651184


100%|██████████| 68/68 [00:18<00:00,  3.61it/s]
100%|██████████| 16/16 [00:03<00:00,  4.01it/s]


epoch: 32
loss: 0.19436943531036377
accuracy: 0.9358139634132385
val_loss: 0.8168303966522217
val_accuracy: 0.8134920597076416


100%|██████████| 68/68 [00:18<00:00,  3.64it/s]
100%|██████████| 16/16 [00:04<00:00,  4.00it/s]


epoch: 33
loss: 0.14992006123065948
accuracy: 0.952558159828186
val_loss: 0.889971911907196
val_accuracy: 0.8095238208770752


100%|██████████| 68/68 [00:18<00:00,  3.65it/s]
100%|██████████| 16/16 [00:04<00:00,  3.94it/s]


epoch: 34
loss: 0.18822398781776428
accuracy: 0.934883713722229
val_loss: 0.7917313575744629
val_accuracy: 0.8015872836112976


100%|██████████| 68/68 [00:19<00:00,  3.57it/s]
100%|██████████| 16/16 [00:03<00:00,  4.03it/s]


epoch: 35
loss: 0.14502890408039093
accuracy: 0.950697660446167
val_loss: 0.744623064994812
val_accuracy: 0.817460298538208


100%|██████████| 68/68 [00:18<00:00,  3.66it/s]
100%|██████████| 16/16 [00:03<00:00,  4.03it/s]


epoch: 36
loss: 0.16980038583278656
accuracy: 0.9516279101371765
val_loss: 0.9215313196182251
val_accuracy: 0.7857142686843872


100%|██████████| 68/68 [00:18<00:00,  3.71it/s]
100%|██████████| 16/16 [00:03<00:00,  4.05it/s]


epoch: 37
loss: 0.15086960792541504
accuracy: 0.9627906680107117
val_loss: 0.7807472944259644
val_accuracy: 0.8095238208770752


100%|██████████| 68/68 [00:18<00:00,  3.61it/s]
100%|██████████| 16/16 [00:03<00:00,  4.02it/s]


epoch: 38
loss: 0.14101479947566986
accuracy: 0.9609302282333374
val_loss: 0.6482131481170654
val_accuracy: 0.8253968358039856


100%|██████████| 68/68 [00:18<00:00,  3.66it/s]
100%|██████████| 16/16 [00:03<00:00,  4.06it/s]


epoch: 39
loss: 0.1397770792245865
accuracy: 0.9534883499145508
val_loss: 0.711073637008667
val_accuracy: 0.829365074634552


100%|██████████| 68/68 [00:18<00:00,  3.64it/s]
100%|██████████| 16/16 [00:04<00:00,  3.92it/s]


epoch: 40
loss: 0.09143207222223282
accuracy: 0.9776744246482849
val_loss: 0.7403008341789246
val_accuracy: 0.841269850730896


100%|██████████| 68/68 [00:18<00:00,  3.62it/s]
100%|██████████| 16/16 [00:04<00:00,  3.97it/s]


epoch: 41
loss: 0.08669579774141312
accuracy: 0.9758139252662659
val_loss: 0.8374333381652832
val_accuracy: 0.8214285969734192


100%|██████████| 68/68 [00:18<00:00,  3.69it/s]
100%|██████████| 16/16 [00:03<00:00,  4.11it/s]


epoch: 42
loss: 0.11547188460826874
accuracy: 0.9590697884559631
val_loss: 0.8481805920600891
val_accuracy: 0.8333333134651184


100%|██████████| 68/68 [00:18<00:00,  3.69it/s]
100%|██████████| 16/16 [00:03<00:00,  4.06it/s]


epoch: 43
loss: 0.1034543439745903
accuracy: 0.9693022966384888
val_loss: 0.788997232913971
val_accuracy: 0.8214285969734192


100%|██████████| 68/68 [00:18<00:00,  3.66it/s]
100%|██████████| 16/16 [00:03<00:00,  4.06it/s]


epoch: 44
loss: 0.07476228475570679
accuracy: 0.9804651141166687
val_loss: 0.7074779272079468
val_accuracy: 0.8492063283920288


100%|██████████| 68/68 [00:18<00:00,  3.63it/s]
100%|██████████| 16/16 [00:03<00:00,  4.06it/s]


epoch: 45
loss: 0.07153719663619995
accuracy: 0.9702325463294983
val_loss: 0.7347460389137268
val_accuracy: 0.8690476417541504


100%|██████████| 68/68 [00:18<00:00,  3.67it/s]
100%|██████████| 16/16 [00:03<00:00,  4.02it/s]


epoch: 46
loss: 0.08088342100381851
accuracy: 0.9776744246482849
val_loss: 0.9370545744895935
val_accuracy: 0.8015872836112976


100%|██████████| 68/68 [00:18<00:00,  3.63it/s]
100%|██████████| 16/16 [00:04<00:00,  3.96it/s]


epoch: 47
loss: 0.08347026258707047
accuracy: 0.9720930457115173
val_loss: 0.7446317076683044
val_accuracy: 0.84765625


100%|██████████| 68/68 [00:18<00:00,  3.64it/s]
100%|██████████| 16/16 [00:03<00:00,  4.04it/s]


epoch: 48
loss: 0.07747302204370499
accuracy: 0.968372106552124
val_loss: 0.8426114916801453
val_accuracy: 0.85317462682724


100%|██████████| 68/68 [00:18<00:00,  3.59it/s]
100%|██████████| 16/16 [00:03<00:00,  4.05it/s]


epoch: 49
loss: 0.0809091255068779
accuracy: 0.9770220518112183
val_loss: 0.8185986876487732
val_accuracy: 0.8492063283920288


100%|██████████| 68/68 [00:18<00:00,  3.65it/s]
100%|██████████| 16/16 [00:03<00:00,  4.03it/s]

epoch: 50
loss: 0.06746424734592438
accuracy: 0.9720930457115173
val_loss: 0.9579612612724304
val_accuracy: 0.8134920597076416





In [24]:
model.save_weights(filepath='checkpoint',save_format='HDF5')