In [None]:
"""
Galaxy Morphology Classification Project (Colab-Friendly)

**Version 4: Better Training Parameters**
This version:
- Increases training time (EPOCHS = 30, FINE_TUNE_EPOCHS = 15)
- Decreases the initial LEARNING_RATE (1e-4)
- Decreases the FINE_TUNE_LR (1e-6)
This should allow the model to learn properly.
"""

import os
import zipfile
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt

# --- 0. Configuration & Constants ---
# Before running, upload both of these files to the Colab working directory
# (the root, /content/, by default):
#   * training_solutions_rev1.csv
#   * images_training_rev1.zip

print("--- Running Colab Setup ---")

# Where the uploaded files are expected (relative to the working directory).
CSV_UPLOAD_PATH = 'training_solutions_rev1.csv'
ZIP_UPLOAD_PATH = 'images_training_rev1.zip'

# Target locations for the data once it has been set up.
DATA_DIR = 'data'
CSV_PATH = os.path.join(DATA_DIR, CSV_UPLOAD_PATH)
IMAGE_DIR = os.path.join(DATA_DIR, 'images_training_rev1')

# Input / batching parameters (images are resized to IMG_SIZE x IMG_SIZE).
IMG_SIZE = 224
BATCH_SIZE = 32
SHUFFLE_BUFFER = 1000

# Training schedule for V4 -- previous values shown in brackets.
EPOCHS = 30             # head-only training epochs        [was 10]
FINE_TUNE_EPOCHS = 15   # extra epochs of fine-tuning      [was 5]
LEARNING_RATE = 1e-4    # Adam LR, head-only phase         [was 1e-3]
FINE_TUNE_LR = 1e-6     # Adam LR, fine-tuning phase       [was 1e-5]


# --- 1. Colab File Setup ---
# Moves the uploaded CSV into DATA_DIR and extracts the uploaded image archive.
# Fatal problems print an explanation and then `raise SystemExit(1)` rather
# than calling `exit()`: `exit()` is an interactive-shell helper that exits a
# script with status 0 (success), and under an IPython/Colab kernel it is
# replaced by an autocall that may not stop the running cell.

# Create the 'data' directory if it doesn't exist
os.makedirs(DATA_DIR, exist_ok=True)
print(f"Directory '{DATA_DIR}' ensured.")

# 1a. Move the CSV file
if os.path.exists(CSV_UPLOAD_PATH):
    # Only move if it's not already in the target location
    if not os.path.exists(CSV_PATH):
        os.rename(CSV_UPLOAD_PATH, CSV_PATH)
        print(f"Moved '{CSV_UPLOAD_PATH}' to '{CSV_PATH}'")
    else:
        print(f"'{CSV_UPLOAD_PATH}' already in root, but '{CSV_PATH}' also exists. Using existing.")
elif not os.path.exists(CSV_PATH):
    print(f"ERROR: '{CSV_UPLOAD_PATH}' not found in the root directory.")
    print("Please upload 'training_solutions_rev1.csv' and try again.")
    raise SystemExit(1)
else:
    print(f"CSV file already in place at '{CSV_PATH}'.")


# 1b. Unzip the images
# We only unzip if the target directory doesn't already exist or is empty.
# NOTE(review): this assumes the archive's top-level folder is
# 'images_training_rev1/', so extracting into DATA_DIR yields IMAGE_DIR -- confirm.
if os.path.exists(ZIP_UPLOAD_PATH):
    if not os.path.exists(IMAGE_DIR) or not os.listdir(IMAGE_DIR):
        print(f"Unzipping '{ZIP_UPLOAD_PATH}' to '{DATA_DIR}'...")
        print("This will take several minutes. Please wait.")
        with zipfile.ZipFile(ZIP_UPLOAD_PATH, 'r') as zip_ref:
            zip_ref.extractall(DATA_DIR)
        print("Unzipping complete!")
        # Optional: Remove the zip file after to save space
        # os.remove(ZIP_UPLOAD_PATH)
        print(f"Successfully unzipped images to '{IMAGE_DIR}'")
    else:
        print(f"Image directory '{IMAGE_DIR}' already exists and is not empty. Skipping unzip.")
elif not os.path.exists(IMAGE_DIR) or not os.listdir(IMAGE_DIR):
    print(f"ERROR: '{ZIP_UPLOAD_PATH}' not found and '{IMAGE_DIR}' is also missing/empty.")
    print("Please upload 'images_training_rev1.zip' and try again.")
    raise SystemExit(1)
else:
    print(f"Image directory '{IMAGE_DIR}' already exists and is populated. Skipping unzip.")

print("--- Colab Setup Complete ---")


# --- 2. Load and Prepare Label Data ---

print("\nStep 2: Loading and preparing label data...")

# Load the CSV file that step 1 placed at CSV_PATH. A missing file is fatal:
# raise SystemExit(1) (non-zero status, and it reliably halts a notebook cell)
# instead of calling the interactive-only exit() helper.
try:
    labels_df = pd.read_csv(CSV_PATH)
except FileNotFoundError:
    print(f"CRITICAL ERROR: CSV file not found at {CSV_PATH} even after setup.")
    print("Please check the 'data' folder in the Colab file browser.")
    raise SystemExit(1) from None

# Turn each GalaxyID into its image filename, e.g. an ID of 100008 becomes
# '100008.jpg', so it can be joined onto IMAGE_DIR later.
labels_df['GalaxyID'] = labels_df['GalaxyID'].astype(str) + '.jpg'

# Keep only the filename column and the three Class1.* target columns.
target_columns = ['Class1.1', 'Class1.2', 'Class1.3']
labels_df = labels_df[['GalaxyID'] + target_columns]

print(f"Loaded {len(labels_df)} labels.")
print(labels_df.head())


# --- 3. Create File Paths and Split the Data ---

print("\nStep 3: Creating file paths and splitting data...")

# Build the full path of every image and drop rows whose file is missing,
# so the tf.data pipeline never tries to read a non-existent image.
print("Verifying image files exist...")
labels_df['filepath'] = labels_df['GalaxyID'].apply(
    lambda name: os.path.join(IMAGE_DIR, name)
)
labels_df['file_exists'] = labels_df['filepath'].apply(os.path.isfile)

initial_count = len(labels_df)
labels_df = labels_df[labels_df['file_exists']].copy()
final_count = len(labels_df)

if final_count == 0:
    print(f"ERROR: No image files were found in '{IMAGE_DIR}'.")
    # SystemExit(1): non-zero status and reliably stops a notebook cell,
    # unlike the interactive-only exit() helper.
    raise SystemExit(1)

print(f"Verified files: {final_count} images found (filtered out {initial_count - final_count} missing images).")


# Split roughly 90% train / 5% validation / 5% test: hold out 10%, then
# halve the held-out part. A fixed random_state makes the split reproducible.
train_df, val_df = train_test_split(
    labels_df,
    test_size=0.1,
    random_state=42
)
val_df, test_df = train_test_split(
    val_df,
    test_size=0.5,
    random_state=42
)

print(f"Training samples: {len(train_df)}")
print(f"Validation samples: {len(val_df)}")
print(f"Test samples: {len(test_df)}")


# --- 4. Create tf.data Input Pipeline ---

print("\nStep 4: Building tf.data input pipeline...")


def load_and_preprocess_image(filepath, label):
    """Read one JPEG from disk, resize it for the model, and pass the label through.

    Pixel values are intentionally left in the [0, 255] range: the Keras
    EfficientNet models expect [0, 255] inputs and rescale internally, so
    dividing by 255 here would scale the inputs twice.

    Args:
        filepath: Scalar string tensor with the path of the image file.
        label: Target for this image; returned unchanged.

    Returns:
        ``(image, label)`` where ``image`` is a float32 tensor of shape
        ``(IMG_SIZE, IMG_SIZE, 3)``.
    """
    raw_bytes = tf.io.read_file(filepath)
    rgb = tf.image.decode_jpeg(raw_bytes, channels=3)  # force 3 channels
    return tf.image.resize(rgb, [IMG_SIZE, IMG_SIZE]), label

def create_dataset(df, shuffle_buffer=None):
    """Build an unbatched ``(image, label)`` dataset from the rows of *df*.

    Args:
        df: DataFrame with a 'filepath' column plus every column named in
            the module-level ``target_columns`` list.
        shuffle_buffer: Optional shuffle buffer size. When given, the
            (filepath, label) pairs are shuffled *before* the images are
            read and decoded, so the buffer holds short path strings rather
            than IMG_SIZE x IMG_SIZE x 3 float32 images. The default (None)
            applies no shuffling, which keeps the row order of *df*.

    Returns:
        A ``tf.data.Dataset`` of ``(image, label)`` pairs, or ``None`` when
        *df* has no rows.
    """
    if df.empty:
        return None
    filepaths = df['filepath'].values
    labels = df[target_columns].values

    ds = tf.data.Dataset.from_tensor_slices((filepaths, labels))
    if shuffle_buffer:
        # tf.data reshuffles on every iteration (epoch) by default.
        ds = ds.shuffle(buffer_size=shuffle_buffer)
    ds = ds.map(load_and_preprocess_image, num_parallel_calls=tf.data.AUTOTUNE)
    return ds


def _batch_and_prefetch(ds):
    """Group *ds* into BATCH_SIZE batches and prefetch ahead of the model."""
    return ds.batch(BATCH_SIZE).prefetch(buffer_size=tf.data.AUTOTUNE)


# Training set: shuffle the cheap (path, label) pairs before decoding. The
# buffer no longer holds decoded images, so it can span the whole training
# split (a full reshuffle every epoch) instead of SHUFFLE_BUFFER decoded images.
train_ds = create_dataset(train_df, shuffle_buffer=len(train_df))
# Validation and test sets keep DataFrame order; the prediction demo relies on
# test_ds rows lining up with test_df rows.
val_ds = create_dataset(val_df)
test_ds = create_dataset(test_df)

if train_ds is None or val_ds is None or test_ds is None:
    print("ERROR: One or more datasets are empty. Cannot proceed.")
    # SystemExit(1): non-zero status and reliably stops a notebook cell.
    raise SystemExit(1)

train_ds = _batch_and_prefetch(train_ds)
val_ds = _batch_and_prefetch(val_ds)
test_ds = _batch_and_prefetch(test_ds)

print("Data pipelines built (train, validation, and test).")


# --- 5. Build the Model (Transfer Learning) ---

print("\nStep 5: Building the transfer learning model...")

IMG_SHAPE = (IMG_SIZE, IMG_SIZE, 3)

# ImageNet-pretrained EfficientNetB0 without its classification top. It is
# frozen for the first training phase, so only the new head learns.
base_model = tf.keras.applications.EfficientNetB0(
    include_top=False,
    weights='imagenet',
    input_shape=IMG_SHAPE,
)
base_model.trainable = False

# Random flips/rotations (Keras only applies these while training), then
# the backbone, then a pooled dense head ending in a 3-way softmax.
augmentation_layers = [
    layers.RandomFlip("horizontal_and_vertical"),
    layers.RandomRotation(0.2),
]
head_layers = [
    layers.GlobalAveragePooling2D(),
    layers.Dropout(0.3),
    layers.Dense(128, activation='relu'),
    layers.Dense(3, activation='softmax'),  # one unit per target column
]
model = tf.keras.Sequential(
    [tf.keras.Input(shape=IMG_SHAPE), *augmentation_layers, base_model, *head_layers]
)

model.summary()


# --- 6. Compile and Train the Model ---

print("\nStep 6: Compiling and starting initial training...")

# Phase 1: the base model is frozen, so fitting updates only the new head.
head_optimizer = tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE)
model.compile(
    optimizer=head_optimizer,
    loss='categorical_crossentropy',
    metrics=['accuracy'],
)

history = model.fit(train_ds, validation_data=val_ds, epochs=EPOCHS)

print("Initial training complete.")


# --- 7. Evaluate Initial Training ---

print("\nStep 7: Plotting initial training history...")

# Side-by-side accuracy and loss curves for phase 1, saved to PNG and shown.
# Plotting is best-effort: any failure is reported and the script continues.
try:
    acc = history.history['accuracy']
    val_acc = history.history['val_accuracy']
    loss = history.history['loss']
    val_loss = history.history['val_loss']

    plt.figure(figsize=(10, 5))
    # (subplot index, training curve, validation curve, metric name, legend position)
    panels = (
        (1, acc, val_acc, 'Accuracy', 'lower right'),
        (2, loss, val_loss, 'Loss', 'upper right'),
    )
    for panel_index, train_curve, val_curve, metric_name, legend_loc in panels:
        plt.subplot(1, 2, panel_index)
        plt.plot(train_curve, label=f'Training {metric_name}')
        plt.plot(val_curve, label=f'Validation {metric_name}')
        plt.legend(loc=legend_loc)
        plt.title(f'Training and Validation {metric_name}')

    plt.savefig('initial_training_history.png')
    print("Saved initial training plot to 'initial_training_history.png'")
    plt.show()

except Exception as e:
    print(f"Error plotting history: {e}")


# --- 8. Fine-Tuning the Model ---

print("\nStep 8: Unfreezing base model for fine-tuning...")

base_model.trainable = True

# Fine-tune only the top 20 layers of the base model; everything below stays frozen.
print(f"Total layers in base model: {len(base_model.layers)}")
fine_tune_at = -20

for layer in base_model.layers[:fine_tune_at]:
    layer.trainable = False

# Inside the unfrozen top block, keep BatchNormalization layers frozen. A
# non-trainable BN layer runs in inference mode and keeps its pretrained
# moving mean/variance. A trainable one would normalise with statistics from
# our BATCH_SIZE mini-batches and overwrite those averages, which can undo
# what the model learned in the first phase.
for layer in base_model.layers[fine_tune_at:]:
    if isinstance(layer, layers.BatchNormalization):
        layer.trainable = False

# Changes to `trainable` only take effect after the model is re-compiled.
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=FINE_TUNE_LR),
    loss='categorical_crossentropy',
    metrics=['accuracy']
)

model.summary()

print("Starting fine-tuning...")

# Continue the epoch counter where phase 1 ended: Keras runs epochs
# initial_epoch .. epochs-1, i.e. FINE_TUNE_EPOCHS more epochs.
total_epochs = EPOCHS + FINE_TUNE_EPOCHS
initial_epoch_num = EPOCHS

fine_tune_history = model.fit(
    train_ds,
    epochs=total_epochs,
    initial_epoch=initial_epoch_num,
    validation_data=val_ds
)

print("Fine-tuning complete.")


# --- 9. Final Evaluation and Prediction ---

print("\nStep 9: Plotting combined training history and final evaluation...")

# Concatenate both training phases into *new* lists. Using `+=` on lists
# taken from history.history would extend them in place and silently modify
# the first History object. Reading both History objects directly also keeps
# this plot working even if the step-7 plot failed before binding acc etc.
try:
    initial_hist = history.history
    fine_hist = fine_tune_history.history
    acc = initial_hist['accuracy'] + fine_hist['accuracy']
    val_acc = initial_hist['val_accuracy'] + fine_hist['val_accuracy']
    loss = initial_hist['loss'] + fine_hist['loss']
    val_loss = initial_hist['val_loss'] + fine_hist['val_loss']

    plt.figure(figsize=(10, 5))
    plt.subplot(1, 2, 1)
    plt.plot(acc, label='Training Accuracy')
    plt.plot(val_acc, label='Validation Accuracy')
    plt.ylim([min(plt.ylim()), 1])  # cap the accuracy axis at 1
    # Dashed marker at the last phase-1 epoch; fine-tuning starts after it.
    plt.plot([EPOCHS-1, EPOCHS-1], plt.ylim(), label='Start Fine-Tuning', linestyle='--')
    plt.legend(loc='lower right')
    plt.title('Combined Training and Validation Accuracy')

    plt.subplot(1, 2, 2)
    plt.plot(loss, label='Training Loss')
    plt.plot(val_loss, label='Validation Loss')
    plt.plot([EPOCHS-1, EPOCHS-1], plt.ylim(), label='Start Fine-Tuning', linestyle='--')
    plt.legend(loc='upper right')
    plt.title('Combined Training and Validation Loss')
    plt.savefig('combined_training_history.png')
    print("Saved combined training plot to 'combined_training_history.png'")
    plt.show()

except Exception as e:
    print(f"Error plotting combined history: {e}")


# Final held-out evaluation on the TEST split. evaluate() returns the loss
# followed by the compiled metrics (here: accuracy).
print("\nEvaluating model on the test set...")
test_loss, test_accuracy = model.evaluate(test_ds)
for caption, value in (("Test Loss", test_loss), ("Test Accuracy", test_accuracy)):
    print(f"{caption}: {value}")


print("\n--- Model Demonstration ---")
print("Running predictions on a batch of test data...")

# Take the first batch of the test pipeline. test_ds is not shuffled, so row i
# of this batch corresponds to row i of test_df.
image_batch, label_batch = next(iter(test_ds))

# Make predictions
predictions = model.predict(image_batch)

# The pipeline keeps pixels as float32 in [0, 255] (EfficientNet rescales
# internally). Matplotlib's imshow treats float RGB data as [0, 1] and clips
# everything above 1, which renders the images washed out. Convert to uint8
# for display only; the model input is untouched.
display_images = np.clip(image_batch.numpy(), 0, 255).astype(np.uint8)

# Show up to the first 5 examples (fewer if the batch is smaller).
num_examples = min(5, len(display_images))
for i in range(num_examples):
    plt.figure(figsize=(6, 3))
    plt.imshow(display_images[i])

    true_label = label_batch[i].numpy()
    pred_label = predictions[i]

    true_class_index = np.argmax(true_label)
    pred_class_index = np.argmax(pred_label)

    title = f"True: {target_columns[true_class_index]} ({true_label[true_class_index]:.2f})\n"
    title += f"Pred: {target_columns[pred_class_index]} ({pred_label[pred_class_index]:.2f})"

    plt.title(title)
    plt.axis('off')
    plt.savefig(f'prediction_example_{i}.png')
    plt.show()

    print(f"\nExample {i}:")
    print(f"  File: {test_df.iloc[i]['GalaxyID']}")
    print(f"  True Label (Smooth, Disk, Star): {np.around(true_label, 2)}")
    print(f"  Pred Label (Smooth, Disk, Star): {np.around(pred_label, 2)}")

print("\n--- Project Complete ---")