# Hair Type Classification – Clean Local Notebook

Hybrid version rebuilt from the original Colab notebook.

This notebook trains a convolutional neural network (MobileNetV2 backbone) to classify hair types
from images stored in a local folder structure.

Expected folder layout (inside the directory the notebook is run from — the code resolves `Hairdata` against the current working directory):

```text
Hairdata/
    Straight/
    Wavy/
    curly/
    kinky/
    dreadlocks/
```

Each subfolder should contain images for that class.


## 1. Imports and dataset configuration

In [None]:
import os

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from sklearn.metrics import classification_report, confusion_matrix

# Base directory where the dataset is stored, resolved against the current
# working directory (the folder the notebook is run from).
DATA_DIR = os.path.join(os.getcwd(), "Hairdata")

print("Dataset directory:", DATA_DIR)

# Fail early with an actionable message instead of a bare listdir() traceback.
if not os.path.isdir(DATA_DIR):
    raise FileNotFoundError(
        f"Dataset directory not found: {DATA_DIR}. "
        "Create a 'Hairdata' folder with one sub-folder per class in the "
        "current working directory."
    )

# Only sub-directories represent classes; stray files such as .DS_Store or
# Thumbs.db are left out of the listing. Sorted for a stable, readable output.
print(
    "Classes available:",
    sorted(
        entry
        for entry in os.listdir(DATA_DIR)
        if os.path.isdir(os.path.join(DATA_DIR, entry))
    ),
)


## 2. Load images into TensorFlow datasets

In [None]:
# Image size requested from the loader, batch size, and the seed shared by
# both splits.
img_height = 224
img_width = 224
batch_size = 32
seed = 42

# Build both subsets with one parameterised call: the "training" subset first,
# then the "validation" subset. validation_split=0.2 means 80% training and
# 20% validation; both calls use the same seed.
train_ds, val_ds = (
    tf.keras.preprocessing.image_dataset_from_directory(
        DATA_DIR,
        validation_split=0.2,
        subset=subset_name,
        seed=seed,
        image_size=(img_height, img_width),
        batch_size=batch_size,
    )
    for subset_name in ("training", "validation")
)

# Class labels as reported by the training dataset, and how many there are.
class_names = train_ds.class_names
num_classes = len(class_names)

print("Detected classes:", class_names)


### Input pipeline tuning and optional diagnostic of problematic image files

In [None]:
# Buffer-size constant handed to prefetch() for both datasets.
AUTOTUNE = tf.data.AUTOTUNE

# Training pipeline: cache -> shuffle (buffer of 1000) -> prefetch.
train_ds = train_ds.cache()
train_ds = train_ds.shuffle(1000)
train_ds = train_ds.prefetch(buffer_size=AUTOTUNE)

# Validation pipeline: cache -> prefetch (no shuffle step).
val_ds = val_ds.cache()
val_ds = val_ds.prefetch(buffer_size=AUTOTUNE)

# Collects the paths of files that the decoding check rejects.
problematic_files_tf = []
print("Starting TensorFlow-based image diagnostic...")

def can_decode_tf(filepath):
    """Check whether TensorFlow is able to read and decode an image file.

    The file is read with ``tf.io.read_file`` and decoded as a 3-channel image
    with ``expand_animations=False``. Any failure is printed and reported as
    ``False`` rather than raised.

    Args:
        filepath: Path of the file to test.

    Returns:
        True if reading and decoding succeed, False otherwise.
    """
    try:
        tf.image.decode_image(
            tf.io.read_file(filepath),
            channels=3,
            expand_animations=False,
        )
    except tf.errors.OpError as err:
        # InvalidArgumentError is an OpError subclass, so it is handled here.
        print(f"TensorFlow decoding error for {filepath}: {err}")
        return False
    except Exception as err:
        # Best-effort diagnostic: one unexpected failure must not stop the scan.
        print(f"Unexpected error during TensorFlow decoding for {filepath}: {err}")
        return False
    return True

# Walk every class folder and record each regular file that fails to decode.
for label in class_names:
    folder = os.path.join(DATA_DIR, label)
    if not os.path.isdir(folder):
        print(f"Warning: directory not found: {folder}")
        continue
    for entry in os.listdir(folder):
        candidate = os.path.join(folder, entry)
        # Skip anything that is not a regular file (e.g. nested folders).
        if not os.path.isfile(candidate):
            continue
        if not can_decode_tf(candidate):
            problematic_files_tf.append(candidate)

# Summarise the scan: either confirm a clean result or list the offenders.
if not problematic_files_tf:
    print("\nAll files appear to be valid images for TensorFlow.")
else:
    print("\nProblematic files detected (TensorFlow could not decode these):")
    for bad_path in problematic_files_tf:
        print(" -", bad_path)
    print("\nPlease remove or fix these files and re-run the dataset loading step.")


### Visualize a few training images

In [None]:
# Show up to nine images from the first training batch in a 3x3 grid.
plt.figure(figsize=(8, 8))
for images, labels in train_ds.take(1):
    # A batch can hold fewer than nine images (tiny dataset or small
    # batch_size); cap the loop so indexing never runs past the batch.
    for i in range(min(9, int(images.shape[0]))):
        ax = plt.subplot(3, 3, i + 1)
        # Cast to uint8 so imshow renders the pixel values as 0-255 colours.
        plt.imshow(images[i].numpy().astype("uint8"))
        # Convert the scalar label tensor to a plain int before indexing.
        plt.title(class_names[int(labels[i])])
        plt.axis("off")
plt.show()


## 3. Define the CNN model (MobileNetV2 backbone)

In [None]:
# Augmentation stack applied inside the model: horizontal flip, small random
# rotation (factor 0.05) and random zoom (factor 0.1).
data_augmentation = keras.Sequential(
    [
        layers.RandomFlip("horizontal"),
        layers.RandomRotation(0.05),
        layers.RandomZoom(0.1),
    ]
)

# MobileNetV2 feature extractor with ImageNet weights and no classifier head.
base_model = keras.applications.MobileNetV2(
    weights="imagenet",
    include_top=False,
    input_shape=(img_height, img_width, 3),
)

# Freeze backbone initially
base_model.trainable = False

# Functional graph:
# input -> augmentation -> MobileNetV2 preprocessing -> backbone
#       -> global average pooling -> dropout -> softmax classifier.
inputs = keras.Input(shape=(img_height, img_width, 3))
augmented = data_augmentation(inputs)
preprocessed = keras.applications.mobilenet_v2.preprocess_input(augmented)
# The backbone is always called with training=False.
features = base_model(preprocessed, training=False)
pooled = layers.GlobalAveragePooling2D()(features)
regularized = layers.Dropout(0.2)(pooled)
outputs = layers.Dense(num_classes, activation="softmax")(regularized)

model = keras.Model(inputs, outputs)

# Integer class labels -> sparse categorical cross-entropy loss.
model.compile(
    loss="sparse_categorical_crossentropy",
    optimizer=keras.optimizers.Adam(learning_rate=1e-3),
    metrics=["accuracy"],
)

model.summary()


## 4. Train the model

In [None]:
# Number of passes over the training dataset.
epochs = 10

# Fit on train_ds with val_ds supplied as validation data; the returned
# object is kept in `history` for later inspection.
history = model.fit(x=train_ds, epochs=epochs, validation_data=val_ds)


## 5. Detailed evaluation (classification report + confusion matrix)

In [None]:
# Class indices in the same order as class_names. Passing them explicitly
# keeps one row/column per class even when a class is missing from the
# validation split; without them classification_report raises a ValueError
# because target_names no longer matches the labels it finds.
label_ids = np.arange(len(class_names))

y_true = []
y_pred = []

# Predict batch by batch so each prediction stays paired with its own label.
for images, labels in val_ds:
    # predict_on_batch handles exactly one batch, avoiding the per-call
    # setup cost of model.predict() inside a Python loop.
    preds = model.predict_on_batch(images)
    y_true.extend(labels.numpy())
    y_pred.extend(np.argmax(preds, axis=1))

y_true = np.array(y_true)
y_pred = np.array(y_pred)

print("Classification Report:")
print(
    classification_report(
        y_true,
        y_pred,
        labels=label_ids,
        target_names=class_names,
    )
)

cm = confusion_matrix(y_true, y_pred, labels=label_ids)
print("Confusion Matrix:\n", cm)

# Heatmap of the confusion matrix: rows are true classes, columns predicted.
plt.figure(figsize=(6, 5))
sns.heatmap(
    cm,
    annot=True,
    fmt="d",
    xticklabels=class_names,
    yticklabels=class_names,
)
plt.xlabel("Predicted")
plt.ylabel("True")
plt.title("Confusion Matrix - Hair Type Classifier")
plt.show()


## 6. Save the trained model locally

In [None]:
# Destination file for the trained model, inside the current working directory.
MODEL_PATH = os.path.join(
    os.getcwd(),
    "hair_type_model.keras",
)

# Write the model to disk, then report where it was written.
model.save(filepath=MODEL_PATH)
print("Model saved at:", MODEL_PATH)


## 7. Predict hair type for a single image

In [None]:
from PIL import Image

def predict_image(path):
    """Classify a single image file and display it with the predicted label.

    The image is opened with Pillow, converted to RGB, resized to
    ``(img_width, img_height)`` and passed to the global ``model`` as a batch
    of one. The original (un-resized) image is shown with the prediction in
    the title.

    Args:
        path: Path of the image file to classify.

    Returns:
        A ``(pred_class, confidence)`` tuple: the predicted entry of
        ``class_names`` and the model's output score for it as a float.
    """
    # Close the file handle as soon as the pixels are loaded; convert()
    # returns an independent in-memory copy that stays usable afterwards.
    with Image.open(path) as source:
        img = source.convert("RGB")

    # PIL's resize() takes (width, height), hence the argument order.
    img_resized = img.resize((img_width, img_height))
    # Cast explicitly to float32 rather than relying on Keras to convert the
    # uint8 pixel array, then add the batch axis -> shape (1, h, w, 3).
    img_array = np.asarray(img_resized, dtype="float32")
    img_array = np.expand_dims(img_array, axis=0)

    predictions = model.predict(img_array)
    pred_idx = int(np.argmax(predictions[0]))
    pred_class = class_names[pred_idx]
    confidence = float(predictions[0][pred_idx])

    plt.imshow(img)
    plt.axis("off")
    plt.title(f"Predicted: {pred_class} ({confidence:.4f})")
    plt.show()

    return pred_class, confidence

# Example usage (uncomment and set an image path):
# predict_image("./example_hair.jpg")


## 8. (Optional) Save additional information as pickle files

In [None]:
import pickle

# Output folder for the pickle artifacts, created on demand.
PKL_DIR = os.path.join(os.getcwd(), "pkl")
os.makedirs(PKL_DIR, exist_ok=True)
print("PKL directory:", PKL_DIR)

# NOTE: pickle files must only be loaded from trusted sources; unpickling
# untrusted data can execute arbitrary code.

# Save class names
class_names_path = os.path.join(PKL_DIR, "hair_class_names.pkl")
with open(class_names_path, "wb") as f:
    pickle.dump(class_names, f)
print("Saved:", class_names_path)

# Save classification report as dict for later analysis.
# Explicit labels keep target_names aligned with the class indices even when
# a class does not occur in y_true/y_pred; without them scikit-learn raises a
# ValueError on the length mismatch.
report_dict = classification_report(
    y_true,
    y_pred,
    labels=list(range(len(class_names))),
    target_names=class_names,
    output_dict=True,
)

report_path = os.path.join(PKL_DIR, "classification_report.pkl")
with open(report_path, "wb") as f:
    pickle.dump(report_dict, f)
print("Saved:", report_path)
