In [10]:
import pandas as pd
import os, shutil
import time
import tensorflow as tf
import matplotlib as plt
from tensorflow.keras import layers, models
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.optimizers import Adam, Nadam, SGD
from tensorflow.keras import regularizers
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout, Activation, Flatten, BatchNormalization
from tensorflow.keras.layers import Conv2D, MaxPooling2D, AveragePooling2D
from tensorflow.keras import losses

# Define constants
img_size = 64      # generators resize every image to img_size x img_size
BATCH_SIZE = 64    # images per batch for both training and validation
NUM_CLASSES = 200  # number of target classes; sizes the softmax output layer
EPOCHS = 15        # number of training epochs (the training run in this notebook uses 15)

# ResNet50's ImageNet weights were trained on inputs prepared by
# resnet50.preprocess_input ("caffe" mode: RGB->BGR and per-channel ImageNet
# mean subtraction on 0-255 pixels). Scaling to [0, 1] with rescale=1./255
# instead feeds the pretrained backbone a very different input distribution.
from tensorflow.keras.applications.resnet50 import preprocess_input

# Random data augmentation is applied to the TRAINING split only.
# preprocessing_function runs after the random transforms, on 0-255 pixels.
train_datagen = ImageDataGenerator(
    preprocessing_function=preprocess_input,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    rotation_range=30,
    width_shift_range=0.1,
    height_shift_range=0.1,
    brightness_range=[0.8, 1.2],
    fill_mode='nearest'
)

# Evaluation images get the same deterministic preprocessing but NO random
# augmentation, so validation metrics measure performance on the real images
# and are comparable from epoch to epoch. val_datagen is kept as an alias.
test_datagen = ImageDataGenerator(preprocessing_function=preprocess_input)
val_datagen = test_datagen

train_generator = train_datagen.flow_from_directory(
    'tiny-imagenet-200/train',
    target_size=(img_size, img_size),
    batch_size=BATCH_SIZE,
    class_mode='categorical'
)

validation_generator = val_datagen.flow_from_directory(
    'tiny-imagenet-200/val',
    target_size=(img_size, img_size),
    batch_size=BATCH_SIZE,
    class_mode='categorical',
    # Fixed order: metrics are unaffected, and predictions line up with
    # validation_generator.classes if they are inspected later.
    shuffle=False,
)

# flow_from_directory derives label indices from each directory's own
# sub-folder names, so verify both splits map classes to the same indices;
# otherwise validation labels would silently be misaligned with training.
assert validation_generator.class_indices == train_generator.class_indices, \
    'train/val class folders differ; validation labels would be misaligned'

Found 100001 images belonging to 200 classes.
Found 9950 images belonging to 200 classes.


In [11]:
# ImageNet-pretrained ResNet50 used as a feature extractor: the original
# classifier head is dropped (include_top=False) and the network is built
# for img_size x img_size RGB inputs.
backbone_input_shape = (img_size, img_size, 3)
resnet50_model = tf.keras.applications.ResNet50(
    include_top=False,
    weights='imagenet',
    input_shape=backbone_input_shape,
)

In [12]:
# New classification head on top of the backbone's feature map:
# global average pooling -> 1024-unit ReLU layer -> dropout, followed by a
# NUM_CLASSES-way softmax output.
head_layers = [
    layers.GlobalAveragePooling2D(),
    layers.Dense(1024, activation='relu'),
    layers.Dropout(0.5),
]

features = resnet50_model.output
for layer in head_layers:
    features = layer(features)
predictions = layers.Dense(NUM_CLASSES, activation='softmax')(features)

# Full model: backbone input -> softmax class probabilities.
M5 = models.Model(inputs=resnet50_model.input, outputs=predictions)

In [6]:
# Small Adam learning rate: the backbone's layers are not frozen anywhere
# above, so the pretrained weights are fine-tuned together with the new head.
# categorical_crossentropy matches the one-hot labels from class_mode='categorical'.
M5.compile(
    optimizer=Adam(learning_rate=0.0001),
    loss='categorical_crossentropy',
    metrics=['accuracy'],
)

# summary() prints the table itself and returns None; wrapping it in print()
# only appends a stray "None" line to the output.
M5.summary()


Model: "model"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
 input_1 (InputLayer)        [(None, 64, 64, 3)]          0         []                            
                                                                                                  
 conv1_pad (ZeroPadding2D)   (None, 70, 70, 3)            0         ['input_1[0][0]']             
                                                                                                  
 conv1_conv (Conv2D)         (None, 32, 32, 64)           9472      ['conv1_pad[0][0]']           
                                                                                                  
 conv1_bn (BatchNormalizati  (None, 32, 32, 64)           256       ['conv1_conv[0][0]']          
 on)                                                                                          

In [8]:
# Train the model (backbone + head) for EPOCHS epochs, validating after each one.
# Keras places the computation on an available GPU automatically, so no
# explicit tf.device scope is needed. The returned History object records
# per-epoch loss/accuracy for later plotting.
history = M5.fit(
    train_generator,
    epochs=EPOCHS,
    validation_data=validation_generator,
)

Epoch 1/15
Epoch 2/15
Epoch 3/15
Epoch 4/15
Epoch 5/15
Epoch 6/15
Epoch 7/15
Epoch 8/15
Epoch 9/15
Epoch 10/15
Epoch 12/15
Epoch 13/15
Epoch 14/15
Epoch 15/15


In [None]:
import matplotlib.pyplot as plt

# Plot the learning curves recorded by the training run above.
# Calling M5.fit() again here would continue training the already-trained
# weights for another full run and plot only that second run, mislabelled
# as if it started from scratch.
train_acc = history.history['accuracy']
val_acc = history.history['val_accuracy']
epoch_numbers = range(1, len(train_acc) + 1)  # Keras logs epochs as 1..N

fig, ax = plt.subplots()
ax.plot(epoch_numbers, train_acc, label='Training Accuracy')
ax.plot(epoch_numbers, val_acc, label='Validation Accuracy')
ax.set_xlabel('Epoch')
ax.set_ylabel('Accuracy')
ax.set_title('Training and Validation Accuracy')
ax.legend()
plt.show()
