# In [None]:
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout, BatchNormalization
from tensorflow.keras.utils import to_categorical
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt

#NOTE: most of this setup has already been defined and explained in the 1-Preprocessing file; it is repeated here for simplicity
#names of the animal supergroups (matched against the dataset's coarse label names) to use for training and testing
animal_supergroups = [
    'aquatic_mammals',
    'fish',
    'insects',
    'large_carnivores',
    'large_omnivores_and_herbivores',
    'medium_mammals',
    'non-insect_invertebrates',
    'reptiles',
    'small_mammals',
]

#helper function to load CIFAR-100 data
def unpickle(file):
    """Load and return the pickled object stored at the path `file`.

    The file is opened in binary mode and unpickled with
    ``encoding='bytes'``, so 8-bit strings written by Python 2 come back
    as ``bytes`` objects instead of being decoded to ``str``.

    NOTE(review): unpickling can execute arbitrary code, so this must only
    be used on trusted files.
    """
    import pickle

    with open(file, mode='rb') as stream:
        return pickle.load(stream, encoding='bytes')

#read the training batch and the metadata (which holds the label names) from disk
train_data = unpickle('Data/train')
meta_data = unpickle('Data/meta')

#the superclass names are stored as bytes, so decode them to str first
coarse_label_names = [
    raw_name.decode('utf-8')
    for raw_name in meta_data[b'coarse_label_names']
]
#positions (in the coarse label list) of the superclasses that are animal supergroups
animal_indices = [
    position
    for position, name in enumerate(coarse_label_names)
    if name in animal_supergroups
]

#this function bundles most of the preprocessing steps I defined on the previous page
def preprocess_data(data, animal_indices):
    """Keep only the samples of the selected coarse classes and prepare them for the CNN.

    Parameters
    ----------
    data : dict
        Unpickled batch. ``data[b'data']`` holds one flattened image per row
        and ``data[b'coarse_labels']`` holds the coarse label of each row.
    animal_indices : list of int
        Coarse label ids to keep. The position of an id in this list becomes
        the class index of the samples carrying that id.

    Returns
    -------
    images : numpy array, float32, shape (n_kept, 32, 32, 3)
        Pixel values divided by 255.
    labels : numpy array
        One-hot labels produced by ``to_categorical`` with
        ``len(animal_supergroups)`` columns.
    """
    #map each coarse id to its class index; the first occurrence wins, exactly like list.index()
    class_of = {}
    for position, coarse_id in enumerate(animal_indices):
        class_of.setdefault(coarse_id, position)

    #select the wanted samples with one boolean mask instead of a per-sample Python loop
    coarse_labels = np.asarray(data[b'coarse_labels'])
    keep = np.isin(coarse_labels, list(class_of))
    images = np.asarray(data[b'data'])[keep]

    #each row is read as 3 planes of 32x32, moved to channels-last (32, 32, 3) and scaled by 1/255
    images = images.reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1).astype('float32') / 255.0

    #translate the kept coarse ids to class indices, then one-hot encode them
    labels = np.array([class_of[coarse_id] for coarse_id in coarse_labels[keep].tolist()])
    labels = to_categorical(labels, num_classes=len(animal_supergroups))
    return images, labels

#filter and preprocess the training batch, then hold out 20% of it for validation (random_state is fixed so the split is repeatable)
filtered_images, filtered_labels = preprocess_data(data=train_data, animal_indices=animal_indices)
X_train, X_val, y_train, y_val = train_test_split(
    filtered_images,
    filtered_labels,
    test_size=0.2,
    random_state=42,
)

#define the CNN model: five 3x3 convolutions (each followed by batch normalization), four 2x2 max-pooling stages, then a dense classifier head
model = Sequential()

#convolutional feature extractor; the first layer also declares the 32x32 RGB input
model.add(Conv2D(filters=32, kernel_size=(3, 3), activation='relu', kernel_initializer='he_uniform', padding='same', input_shape=(32, 32, 3)))
model.add(BatchNormalization())

model.add(Conv2D(filters=64, kernel_size=(3, 3), activation='relu', kernel_initializer='he_uniform', padding='same'))
model.add(BatchNormalization())
model.add(MaxPooling2D(pool_size=(2, 2)))

model.add(Conv2D(filters=128, kernel_size=(3, 3), activation='relu', kernel_initializer='he_uniform', padding='same'))
model.add(BatchNormalization())
model.add(MaxPooling2D(pool_size=(2, 2)))

model.add(Conv2D(filters=64, kernel_size=(3, 3), activation='relu', kernel_initializer='he_uniform', padding='same'))
model.add(BatchNormalization())
model.add(MaxPooling2D(pool_size=(2, 2)))

model.add(Conv2D(filters=128, kernel_size=(3, 3), activation='relu', kernel_initializer='he_uniform', padding='same'))
model.add(BatchNormalization())
model.add(MaxPooling2D(pool_size=(2, 2)))

#classifier head: one hidden dense layer with dropout, then one softmax output per animal supergroup
model.add(Flatten())
model.add(Dense(units=128, activation='relu', kernel_initializer='he_uniform'))
model.add(BatchNormalization())
model.add(Dropout(rate=0.5))
model.add(Dense(units=len(animal_supergroups), activation='softmax'))

#categorical cross-entropy matches the one-hot labels produced during preprocessing
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
from tensorflow.keras.preprocessing.image import ImageDataGenerator

#data augmentation: this helps reduce overfitting by "recreating" the same image in several distorted ways (small rotations, horizontal/vertical shifts, horizontal flips), which helps fight against memorization by the model
datagen = ImageDataGenerator(
    rotation_range=15,
    width_shift_range=0.1,
    height_shift_range=0.1,
    horizontal_flip=True
)

#NOTE: datagen.fit(X_train) is intentionally not called. It only computes the statistics needed by
#featurewise_center, featurewise_std_normalization and zca_whitening, none of which are enabled here.

#fit the model: training batches of 64 come from the augmenting generator, while the validation data is passed as-is (not augmented)
history = model.fit(
    datagen.flow(X_train, y_train, batch_size=64),
    validation_data=(X_val, y_val),
    epochs=100
)

#plot the per-epoch training and validation accuracy recorded in the training history
plt.title('Training and Validation Accuracy')
plt.plot(history.history['accuracy'], label='Train Accuracy')
plt.plot(history.history['val_accuracy'], label='Validation Accuracy')
plt.ylabel('Accuracy')
plt.xlabel('Epoch')
plt.legend(loc='lower right')
plt.show()

#load the test batch and run it through the same filtering/preprocessing as the training data
test_data = unpickle('Data/test')
test_images, test_labels = preprocess_data(data=test_data, animal_indices=animal_indices)

#evaluate the trained model on the test data and report loss and accuracy
test_loss, test_accuracy = model.evaluate(x=test_images, y=test_labels, verbose=2)
print(f"Test Loss: {test_loss}")
print(f"Test Accuracy: {test_accuracy}")