# Setup and Mount Google Drive

In [3]:
# Installs the torchsummary package through a shell command.
# NOTE(review): no cell shown in this notebook imports torchsummary (the model
# is built with TensorFlow/Keras) — confirm this install is still needed.
!pip install torchsummary



NOTES
- Look into the student–teacher learning model.

- Look at the confusion matrix: the classes are imbalanced, which is causing this issue.

# Import Required Libraries

In [7]:
# Import all the required libraries; TensorFlow is used for the model. TODO: watch a video about TensorFlow functions.

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
from tensorflow.keras.optimizers import Adam
from sklearn.metrics import classification_report, confusion_matrix
from IPython.display import display
from PIL import Image

# Set Parameters and Data Preparation


In [9]:

img_width, img_height = 224, 224  # every image is resized to this size
batch_size = 32  # the dataset is served in batches of 32 images
epochs = 5  # number of training passes over the dataset
validation_fraction = 0.2  # share of each class directory held out for validation


path_to_training_data = "/Users/kdfer/Desktop/ACM Research/GalaxyImages"  # dataset root, one sub-folder per class


# Data augmentation: applied on the fly to every batch that
# train_generator yields while model.fit() is running.
train_datagen = ImageDataGenerator(
    rescale=1./255,
    rotation_range=20,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    fill_mode='nearest',
    validation_split=validation_fraction)

# Validation images must only be rescaled. Drawing them from train_datagen
# would randomly rotate/shift/zoom them, so validation metrics would be
# measured on distorted images and change from one evaluation to the next.
# The same validation_split keeps the train/validation partition unchanged.
validation_datagen = ImageDataGenerator(
    rescale=1./255,
    validation_split=validation_fraction)


# target_size is given as (height, width).
train_generator = train_datagen.flow_from_directory(
    path_to_training_data,
    target_size=(img_height, img_width),
    batch_size=batch_size,
    class_mode='categorical',
    subset='training')

validation_generator = validation_datagen.flow_from_directory(
    path_to_training_data,
    target_size=(img_height, img_width),
    batch_size=batch_size,
    class_mode='categorical',
    subset='validation')

Found 3568 images belonging to 4 classes.
Found 890 images belonging to 4 classes.


# Base Model
# How to implement a basic CNN

In [11]:
def build_model(num_classes=4, learning_rate=0.0001):
    """Build and compile a MobileNetV2-based image classifier.

    The ImageNet-pretrained MobileNetV2 backbone is frozen and topped with
    global average pooling, a 1024-unit ReLU layer and a softmax output layer.

    Args:
        num_classes: number of units in the softmax output layer.
        learning_rate: learning rate passed to the Adam optimizer.

    Returns:
        The compiled Keras Model (categorical cross-entropy loss, accuracy metric).
    """
    # input_shape is (height, width, channels).
    base_model = MobileNetV2(weights='imagenet', include_top=False,
                             input_shape=(img_height, img_width, 3))

    # Freeze the pretrained backbone so only the new head is trained.
    base_model.trainable = False

    x = base_model.output
    x = GlobalAveragePooling2D()(x)
    x = Dense(1024, activation='relu')(x)
    predictions = Dense(num_classes, activation='softmax')(x)

    model = Model(inputs=base_model.input, outputs=predictions)
    model.compile(optimizer=Adam(learning_rate=learning_rate),
                  loss='categorical_crossentropy', metrics=['accuracy'])
    return model

# Size the output layer from the class folders the generator discovered.
model = build_model(num_classes=len(train_generator.class_indices))

# Model Training Before Transfer Learning

In [None]:
# Train the network on the training generator, validating after each epoch.
history = model.fit(
    train_generator,
    epochs=epochs,
    validation_data=validation_generator)

# Evaluate on every validation batch. Passing
# steps=validation_generator.samples // batch_size would silently skip the
# final, smaller batch whenever the sample count is not a multiple of
# batch_size.
initial_test_loss, initial_test_acc = model.evaluate(validation_generator)
print(f"Test Accuracy Before Fine-tuning: {initial_test_acc*100:.2f}%")

  self._warn_if_super_not_called()


Epoch 1/5
[1m 67/112[0m [32m━━━━━━━━━━━[0m[37m━━━━━━━━━[0m [1m33s[0m 745ms/step - accuracy: 0.7085 - loss: 0.8566

# Plot Accuracy and Loss Before Fine-Tuning / Transfer Learning

In [None]:
def plot_acc_loss(history):
    """Plot training/validation accuracy and loss in two side-by-side panels.

    `history` must expose a `.history` mapping containing the keys
    'accuracy', 'val_accuracy', 'loss' and 'val_loss'; the x-axis runs over
    range(len(history.history['accuracy'])).
    """
    record = history.history
    # Read every series first so a missing key fails before a figure is opened.
    train_acc, valid_acc, train_loss, valid_loss = (
        record[key] for key in ('accuracy', 'val_accuracy', 'loss', 'val_loss')
    )
    x_axis = range(len(train_acc))

    # (title, (series, label), (series, label), legend location) per panel
    panels = (
        ('Training and Validation Accuracy',
         (train_acc, 'Training Accuracy'),
         (valid_acc, 'Validation Accuracy'),
         'lower right'),
        ('Training and Validation Loss',
         (train_loss, 'Training Loss'),
         (valid_loss, 'Validation Loss'),
         'upper right'),
    )

    plt.figure(figsize=(12, 6))
    for column, (heading, train_curve, valid_curve, corner) in enumerate(panels, start=1):
        plt.subplot(1, 2, column)
        for series, caption in (train_curve, valid_curve):
            plt.plot(x_axis, series, label=caption)
        plt.legend(loc=corner)
        plt.title(heading)
    plt.show()

# Curves for the initial training run.
plot_acc_loss(history)

# No fine-tuning has happened at this point, so the message reports the
# "before" accuracy. The whole validation generator is evaluated; the
# previous steps=samples // batch_size skipped the final partial batch.
test_loss, test_acc = model.evaluate(validation_generator)
print(f"Test Accuracy Before Fine-tuning: {test_acc*100:.2f}%")

# Fine-Tuning and Re-Training

In [None]:
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
from tensorflow.keras.optimizers import Adam


# Fine-tune the classifier trained above instead of rebuilding it. Creating a
# new MobileNetV2 with every base layer frozen would throw away the trained
# head and update no backbone weights, i.e. repeat the initial training run.
FINE_TUNE_LAYERS = 30           # number of top backbone layers to unfreeze
FINE_TUNE_LEARNING_RATE = 1e-5  # lower than the initial 1e-4 used for the head
HEAD_LAYER_COUNT = 3            # GlobalAveragePooling2D + the two Dense layers added in build_model

backbone_layers = model.layers[:-HEAD_LAYER_COUNT]
for layer in backbone_layers[-FINE_TUNE_LAYERS:]:
    # BatchNormalization layers stay frozen, as recommended for fine-tuning,
    # so their stored statistics are kept.
    if not isinstance(layer, tf.keras.layers.BatchNormalization):
        layer.trainable = True

# The model has to be compiled again for the `trainable` changes to apply.
model.compile(optimizer=Adam(learning_rate=FINE_TUNE_LEARNING_RATE), loss='categorical_crossentropy', metrics=['accuracy'])

model.fit(train_generator, epochs=epochs, validation_data=validation_generator)


# Plot Accuracy and Loss After Transfer Learning

In [None]:
# Continue training and keep the per-epoch metrics for plotting. The epoch
# count comes from the shared `epochs` setting rather than a repeated literal.
fine_tune_history = model.fit(
    train_generator,
    epochs=epochs,
    validation_data=validation_generator
)
import matplotlib.pyplot as plt

def plot_acc_loss(training_history):
    """Plot accuracy and loss curves, training as dots and validation as lines.

    `training_history` must expose a `.history` mapping containing the keys
    'accuracy', 'val_accuracy', 'loss' and 'val_loss'. Epochs are numbered
    from 1 on the x-axis.
    """
    logged = training_history.history
    train_acc, valid_acc = logged['accuracy'], logged['val_accuracy']
    train_loss, valid_loss = logged['loss'], logged['val_loss']
    epoch_numbers = range(1, len(train_acc) + 1)

    # (title, then one (series, format string, legend label) per curve)
    panels = [
        ('Training and Validation Accuracy',
         (train_acc, 'bo', 'Training accuracy'),
         (valid_acc, 'b', 'Validation accuracy')),
        ('Training and Validation Loss',
         (train_loss, 'ro', 'Training loss'),
         (valid_loss, 'r', 'Validation loss')),
    ]

    plt.figure(figsize=(12, 6))
    for slot, (heading, *curves) in enumerate(panels, start=1):
        plt.subplot(1, 2, slot)
        for series, fmt, caption in curves:
            plt.plot(epoch_numbers, series, fmt, label=caption)
        plt.title(heading)
        plt.legend()

    plt.show()


# Draw the accuracy/loss curves recorded in fine_tune_history.
plot_acc_loss(fine_tune_history)


# Final Evaluation and Display Predictions

In [None]:
# Evaluate on every validation batch; the previous
# steps=validation_generator.samples // batch_size skipped the final partial
# batch whenever the sample count was not a multiple of batch_size.
test_loss, test_acc = model.evaluate(validation_generator)
print(f"Test Accuracy After Fine-tuning: {test_acc*100:.2f}%")

# Display Classification Metrics and Confusion Matrix

In [None]:
from sklearn.metrics import confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt

def display_confusion_matrix(true_classes, predicted_classes, class_labels):
    """
    Compute a confusion matrix and show it as an annotated heatmap.

    true_classes and predicted_classes are passed unchanged to
    confusion_matrix; class_labels supplies the tick labels on both axes.
    """
    matrix = confusion_matrix(true_classes, predicted_classes)

    heatmap_options = dict(
        annot=True,
        fmt='d',
        cmap='Blues',
        xticklabels=class_labels,
        yticklabels=class_labels,
    )

    plt.figure(figsize=(6, 5))
    sns.heatmap(matrix, **heatmap_options)
    plt.title('Confusion Matrix')
    plt.xlabel('Predicted Label')
    plt.ylabel('True Label')
    plt.show()

# Walk the validation generator once, predicting each batch as it is drawn so
# every prediction stays paired with the label of the same image.
# validation_generator.classes cannot be used for this while the generator
# shuffles (flow_from_directory's default), because its order then differs
# from the order in which batches are served.
validation_generator.reset()
batch_predictions = []
batch_labels = []
for _ in range(len(validation_generator)):  # one full pass over the batches
    batch_images, batch_onehot = next(validation_generator)
    batch_predictions.append(model.predict(batch_images, verbose=0))
    batch_labels.append(batch_onehot)

# One row of softmax scores per image.
predictions = np.concatenate(batch_predictions)

# Multi-class output: the predicted class is the index of the highest score.
# Thresholding at 0.5 and flattening gave one value per class per image,
# which did not match the number of true labels.
predicted_classes = np.argmax(predictions, axis=1)

# Labels come back one-hot (class_mode='categorical'); recover class indices.
true_classes = np.argmax(np.concatenate(batch_labels), axis=1)

# Class names taken from the generator's class_indices mapping.
class_labels = list(validation_generator.class_indices.keys())

# Display the confusion matrix
display_confusion_matrix(true_classes, predicted_classes, class_labels)

In [None]:
def display_predictions(model, generator, num_images=20):
    """
    Show one batch of images titled with their actual and predicted class.

    Args:
        model: object whose ``predict(images)`` returns one row of class
            scores per image.
        generator: batch iterator providing ``reset()``, ``class_indices``
            and ``(images, labels)`` batches through ``next()``. Labels are
            expected to be one-hot rows, as produced by
            class_mode='categorical'.
        num_images: maximum number of images to draw; capped at the size of
            the fetched batch.
    """
    generator.reset()  # start again from the generator's first batch
    x, y_true = next(generator)
    predictions = model.predict(x)

    # Multi-class data: take the index of the largest score in each row.
    # Thresholding at 0.5 and flattening produced one value per class per
    # image, so the titles were looked up with the wrong indices.
    predicted_classes = np.argmax(predictions, axis=1)
    true_classes = np.argmax(y_true, axis=1)

    # Get class labels from generator
    class_labels = list(generator.class_indices.keys())

    shown = min(num_images, len(x))  # never exceed the batch size
    n_cols = 4
    # Keep the original 5x4 grid, adding rows only when more than 20 images
    # are requested so plt.subplot never gets an out-of-range index.
    n_rows = max(5, -(-shown // n_cols))

    plt.figure(figsize=(15, 10))
    for i in range(shown):
        plt.subplot(n_rows, n_cols, i + 1)
        plt.imshow(x[i])
        plt.title(f'Actual: {class_labels[true_classes[i]]}\nPredicted: {class_labels[predicted_classes[i]]}')
        plt.axis('off')

    plt.tight_layout()
    plt.show()

# Show up to 20 validation images with their actual and predicted labels.
display_predictions(model, validation_generator, num_images=20)