# Squad404

In [3]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications import MobileNetV2, VGG16
from tensorflow.keras.applications.resnet50 import ResNet50

from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
from tensorflow.keras.models import Model
from sklearn.model_selection import StratifiedKFold




### Loading the augmented images and converting to pd dataframe

In [4]:
# Reading the csv and images 
data_annotations_csv = 'resources/augmented_data/image_labels.csv'
augmented_images = 'resources/augmented_data'

data_file = pd.read_csv(data_annotations_csv, dtype={'label': str})


In [5]:
# Creating vectors of filenames and labels matched by index 
img_filenames_vector = data_file['filename'].values
labels_vector = data_file['label'].values

### Creating our model

Add prototypical network to our model before running it through the training

In [6]:
# Function to create a model from the pretrained CNN base you pass in
# Example usage: model = create_model(MobileNetv2)
def create_model(base):
    base_model = base(input_shape=(224, 224, 3), include_top=False)
    base_model.trainable = False # Freeze the base_model
    x = base_model.output
    x = GlobalAveragePooling2D()(x)
    x = Dense(1024, activation='relu')(x)
    predictions = Dense(1, activation='sigmoid')(x)
    model = Model(inputs=base_model.input, outputs=predictions)
    model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
    return model


### Setting up k-fold

In [7]:
k = 5
kf = StratifiedKFold(n_splits=k, shuffle=True)

In [8]:
for train_index, val_index in kf.split(img_filenames_vector, labels_vector):
    print(f"train: {train_index}, val: {val_index}")

train: [  0   1   2   3   4   6   7   8   9  10  11  12  13  14  16  17  18  20
  21  23  24  25  27  28  31  32  33  34  35  36  38  39  40  41  42  44
  45  48  49  51  52  53  54  55  56  57  58  59  60  61  63  65  66  67
  68  69  70  72  74  76  77  78  79  80  83  84  85  87  88  89  90  91
  92  93  94  95  97  98  99 101 102 103 104 105 106 107 108 110 111 112
 113 114 115 116 117 118 119 120 121 122 123 125 126 127 128 129 131 132
 133 134 135 137 139 140 141 144 145 146 148 149 152 153 154 155 156 157
 158 159], val: [  5  15  19  22  26  29  30  37  43  46  47  50  62  64  71  73  75  81
  82  86  96 100 109 124 130 136 138 142 143 147 150 151]
train: [  0   2   4   5   6   7   8   9  10  11  12  13  15  16  17  18  19  20
  22  23  24  25  26  27  29  30  31  32  33  35  37  39  41  42  43  44
  45  46  47  48  49  50  51  52  53  54  55  56  60  61  62  63  64  65
  66  67  69  70  71  73  75  77  78  79  81  82  83  84  86  87  88  89
  90  91  96  97  99 100 101 102 103

In [9]:
CNN_scores = []
CNN_models_to_test = [MobileNetV2, ResNet50]

In [10]:
for CNN_base_model in CNN_models_to_test:
    fold_scores_for_CNN = []
    fold_no = 1
    print(f"TRAINING FOR {CNN_base_model}")
    for train_index, val_index in kf.split(img_filenames_vector, labels_vector):
        print(f"Fold number: {fold_no}/5")
        train_filenames, val_filenames = img_filenames_vector[train_index], img_filenames_vector[val_index]
        train_labels, val_labels = labels_vector[train_index], labels_vector[val_index]
        
        # Create ImageDataGenerator for train and validation
        train_datagen = ImageDataGenerator(rescale=1./255)
        val_datagen = ImageDataGenerator(rescale=1./255)
        
        # Create generators
        train_generator = train_datagen.flow_from_dataframe(
            dataframe=data_file.iloc[train_index],
            directory=augmented_images,
            x_col='filename',
            y_col='label',
            target_size=(224, 224),
            batch_size=16,
            class_mode='binary'
        )
        
        val_generator = val_datagen.flow_from_dataframe(
            dataframe=data_file.iloc[val_index],
            directory=augmented_images,
            x_col='filename',
            y_col='label',
            target_size=(224, 224),
            batch_size=16,
            class_mode='binary'
        )
        model = create_model(CNN_base_model)
        history = model.fit(train_generator, validation_data=val_generator, epochs=5)
        fold_scores_for_CNN.append(model.evaluate(val_generator))
        fold_no += 1
    CNN_scores.append(fold_scores_for_CNN)

TRAINING FOR <function MobileNetV2 at 0x30da8fce0>
Fold number: 1/5
Found 128 validated image filenames belonging to 2 classes.
Found 32 validated image filenames belonging to 2 classes.
Epoch 1/5


  self._warn_if_super_not_called()


[1m8/8[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 190ms/step - accuracy: 0.5425 - loss: 1.4929 - val_accuracy: 0.8125 - val_loss: 0.3889
Epoch 2/5
[1m8/8[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 126ms/step - accuracy: 0.7074 - loss: 0.6151 - val_accuracy: 0.8750 - val_loss: 0.2888
Epoch 3/5
[1m8/8[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 127ms/step - accuracy: 0.8817 - loss: 0.2435 - val_accuracy: 1.0000 - val_loss: 0.0842
Epoch 4/5
[1m8/8[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 135ms/step - accuracy: 1.0000 - loss: 0.0779 - val_accuracy: 1.0000 - val_loss: 0.0574
Epoch 5/5
[1m8/8[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 140ms/step - accuracy: 1.0000 - loss: 0.0369 - val_accuracy: 1.0000 - val_loss: 0.0411
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 104ms/step - accuracy: 1.0000 - loss: 0.0418
Fold number: 2/5
Found 128 validated image filenames belonging to 2 classes.
Found 32 validated image filena

  self._warn_if_super_not_called()


[1m8/8[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 194ms/step - accuracy: 0.6350 - loss: 0.7876 - val_accuracy: 0.9375 - val_loss: 0.2402
Epoch 2/5
[1m8/8[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 128ms/step - accuracy: 0.9198 - loss: 0.1696 - val_accuracy: 0.8438 - val_loss: 0.2105
Epoch 3/5
[1m8/8[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 129ms/step - accuracy: 0.9477 - loss: 0.1021 - val_accuracy: 0.9688 - val_loss: 0.1210
Epoch 4/5
[1m8/8[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 130ms/step - accuracy: 1.0000 - loss: 0.0120 - val_accuracy: 1.0000 - val_loss: 0.0194
Epoch 5/5
[1m8/8[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 128ms/step - accuracy: 1.0000 - loss: 0.0057 - val_accuracy: 1.0000 - val_loss: 0.0091
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 95ms/step - accuracy: 1.0000 - loss: 0.0092
Fold number: 3/5
Found 128 validated image filenames belonging to 2 classes.
Found 32 validated image filenam

  self._warn_if_super_not_called()


[1m8/8[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 190ms/step - accuracy: 0.6161 - loss: 1.2506 - val_accuracy: 0.5312 - val_loss: 0.8616
Epoch 2/5
[1m8/8[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 127ms/step - accuracy: 0.8349 - loss: 0.3586 - val_accuracy: 0.7500 - val_loss: 0.4513
Epoch 3/5
[1m8/8[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 127ms/step - accuracy: 0.8729 - loss: 0.2813 - val_accuracy: 0.9062 - val_loss: 0.1783
Epoch 4/5
[1m8/8[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 127ms/step - accuracy: 0.9675 - loss: 0.0758 - val_accuracy: 1.0000 - val_loss: 0.0392
Epoch 5/5
[1m8/8[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 131ms/step - accuracy: 0.9930 - loss: 0.0205 - val_accuracy: 1.0000 - val_loss: 0.0304
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 94ms/step - accuracy: 1.0000 - loss: 0.0367
Fold number: 4/5
Found 128 validated image filenames belonging to 2 classes.
Found 32 validated image filenam

  self._warn_if_super_not_called()


[1m8/8[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 197ms/step - accuracy: 0.5460 - loss: 0.8338 - val_accuracy: 0.9688 - val_loss: 0.1697
Epoch 2/5
[1m8/8[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 128ms/step - accuracy: 0.9192 - loss: 0.1929 - val_accuracy: 0.9375 - val_loss: 0.1330
Epoch 3/5
[1m8/8[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 129ms/step - accuracy: 0.9907 - loss: 0.0516 - val_accuracy: 1.0000 - val_loss: 0.0447
Epoch 4/5
[1m8/8[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 128ms/step - accuracy: 1.0000 - loss: 0.0370 - val_accuracy: 1.0000 - val_loss: 0.0101
Epoch 5/5
[1m8/8[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 128ms/step - accuracy: 1.0000 - loss: 0.0071 - val_accuracy: 1.0000 - val_loss: 0.0173
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 94ms/step - accuracy: 1.0000 - loss: 0.0168
Fold number: 5/5
Found 128 validated image filenames belonging to 2 classes.
Found 32 validated image filenam

  self._warn_if_super_not_called()


[1m8/8[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 190ms/step - accuracy: 0.5266 - loss: 0.8379 - val_accuracy: 0.7188 - val_loss: 0.5610
Epoch 2/5
[1m8/8[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 127ms/step - accuracy: 0.9685 - loss: 0.1224 - val_accuracy: 0.7812 - val_loss: 0.3629
Epoch 3/5
[1m8/8[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 128ms/step - accuracy: 0.9534 - loss: 0.1118 - val_accuracy: 0.9062 - val_loss: 0.2848
Epoch 4/5
[1m8/8[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 128ms/step - accuracy: 0.9930 - loss: 0.0271 - val_accuracy: 0.9375 - val_loss: 0.2614
Epoch 5/5
[1m8/8[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 142ms/step - accuracy: 1.0000 - loss: 0.0074 - val_accuracy: 0.9688 - val_loss: 0.1189
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 106ms/step - accuracy: 0.9792 - loss: 0.1011
TRAINING FOR <function ResNet50 at 0x30dab1bc0>
Fold number: 1/5
Found 128 validated image filenames belongi

  self._warn_if_super_not_called()


[1m8/8[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m7s[0m 623ms/step - accuracy: 0.5605 - loss: 1.1773 - val_accuracy: 0.5000 - val_loss: 0.8339
Epoch 2/5
[1m8/8[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 540ms/step - accuracy: 0.4564 - loss: 0.9329 - val_accuracy: 0.5000 - val_loss: 0.9961
Epoch 3/5
[1m8/8[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 560ms/step - accuracy: 0.5074 - loss: 0.8869 - val_accuracy: 0.5000 - val_loss: 0.8930
Epoch 4/5
[1m8/8[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 572ms/step - accuracy: 0.6519 - loss: 0.6679 - val_accuracy: 0.5000 - val_loss: 0.9834
Epoch 5/5
[1m8/8[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 557ms/step - accuracy: 0.4632 - loss: 0.8947 - val_accuracy: 0.5000 - val_loss: 0.7555
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 441ms/step - accuracy: 0.5208 - loss: 0.7372
Fold number: 2/5
Found 128 validated image filenames belonging to 2 classes.
Found 32 validated image filena

  self._warn_if_super_not_called()


[1m8/8[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m7s[0m 641ms/step - accuracy: 0.4941 - loss: 1.1352 - val_accuracy: 0.5000 - val_loss: 0.6965
Epoch 2/5
[1m8/8[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 558ms/step - accuracy: 0.3608 - loss: 1.1008 - val_accuracy: 0.5000 - val_loss: 0.8873
Epoch 3/5
[1m8/8[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 568ms/step - accuracy: 0.5075 - loss: 0.8145 - val_accuracy: 0.5000 - val_loss: 0.7757
Epoch 4/5
[1m8/8[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 569ms/step - accuracy: 0.4894 - loss: 0.7116 - val_accuracy: 0.5000 - val_loss: 0.7108
Epoch 5/5
[1m8/8[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 575ms/step - accuracy: 0.5149 - loss: 0.6804 - val_accuracy: 0.4375 - val_loss: 0.7077
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 447ms/step - accuracy: 0.4167 - loss: 0.7107
Fold number: 3/5
Found 128 validated image filenames belonging to 2 classes.
Found 32 validated image filena

  self._warn_if_super_not_called()


[1m8/8[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m7s[0m 664ms/step - accuracy: 0.5129 - loss: 1.3299 - val_accuracy: 0.5000 - val_loss: 0.6726
Epoch 2/5
[1m8/8[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 565ms/step - accuracy: 0.5240 - loss: 1.0399 - val_accuracy: 0.5000 - val_loss: 0.8822
Epoch 3/5
[1m8/8[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 570ms/step - accuracy: 0.5531 - loss: 0.8728 - val_accuracy: 0.5000 - val_loss: 0.9533
Epoch 4/5
[1m8/8[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 573ms/step - accuracy: 0.4234 - loss: 0.9833 - val_accuracy: 0.5000 - val_loss: 0.7122
Epoch 5/5
[1m8/8[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 594ms/step - accuracy: 0.6534 - loss: 0.6148 - val_accuracy: 0.5625 - val_loss: 0.6916
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 446ms/step - accuracy: 0.5625 - loss: 0.6818
Fold number: 4/5
Found 128 validated image filenames belonging to 2 classes.
Found 32 validated image filena

  self._warn_if_super_not_called()


[1m8/8[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m7s[0m 660ms/step - accuracy: 0.4926 - loss: 1.0881 - val_accuracy: 0.5000 - val_loss: 0.7877
Epoch 2/5
[1m8/8[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 579ms/step - accuracy: 0.5623 - loss: 0.8944 - val_accuracy: 0.5000 - val_loss: 1.1497
Epoch 3/5
[1m8/8[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 580ms/step - accuracy: 0.4741 - loss: 0.9906 - val_accuracy: 0.5000 - val_loss: 0.6855
Epoch 4/5
[1m8/8[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 584ms/step - accuracy: 0.5810 - loss: 0.6902 - val_accuracy: 0.7500 - val_loss: 0.6381
Epoch 5/5
[1m8/8[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 597ms/step - accuracy: 0.5499 - loss: 0.8038 - val_accuracy: 0.5000 - val_loss: 0.8432
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 474ms/step - accuracy: 0.5417 - loss: 0.7821
Fold number: 5/5
Found 128 validated image filenames belonging to 2 classes.
Found 32 validated image filena

  self._warn_if_super_not_called()


[1m8/8[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8s[0m 687ms/step - accuracy: 0.4636 - loss: 1.2493 - val_accuracy: 0.5000 - val_loss: 0.9965
Epoch 2/5
[1m8/8[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 594ms/step - accuracy: 0.4607 - loss: 0.9064 - val_accuracy: 0.5312 - val_loss: 0.6929
Epoch 3/5
[1m8/8[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 588ms/step - accuracy: 0.6018 - loss: 0.6579 - val_accuracy: 0.5000 - val_loss: 0.6960
Epoch 4/5
[1m8/8[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 586ms/step - accuracy: 0.5742 - loss: 0.6691 - val_accuracy: 0.5000 - val_loss: 0.6864
Epoch 5/5
[1m8/8[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 601ms/step - accuracy: 0.5550 - loss: 0.6473 - val_accuracy: 0.5312 - val_loss: 0.7883
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 450ms/step - accuracy: 0.6042 - loss: 0.7179


In [11]:
for score in CNN_scores:
    print(np.mean(score, axis=0))

[0.04338196 0.99375   ]
[0.7572655 0.50625  ]
