# Wood Finish Classification using LAB Color Space

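This notebook builds a classifier for three wood finishes (medium cherry, desert oak, and graphite walnut). Images are converted to the CIELAB (LAB) color space before training, on the premise that separating lightness from color makes subtle tone differences between finishes easier to learn.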

## 1. Setup and Imports

In [None]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.notebook import tqdm
import random
from pathlib import Path

# For image processing
import cv2
from PIL import Image

# For model building
import tensorflow as tf
from tensorflow.keras import layers, models, optimizers, callbacks
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.utils import to_categorical

# For evaluation
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score

# Seed every random number source used in this notebook (NumPy, TensorFlow,
# Python's `random`) so that splits, sampling and training are repeatable
RANDOM_SEED = 42
for seed_fn in (np.random.seed, tf.random.set_seed, random.seed):
    seed_fn(RANDOM_SEED)

# Report the TensorFlow build and whether TensorFlow can see a GPU
print("TensorFlow version:", tf.__version__)
gpu_available = bool(tf.config.list_physical_devices('GPU'))
print("GPU Available: ", gpu_available)

## 2. Data Loading and LAB Preprocessing

### 2.1 Configure Data Paths

In [None]:
# Define constants
IMG_SIZE = 224  # Input size expected by most pretrained models
BATCH_SIZE = 32
CLASS_NAMES = ["medium_cherry", "desert_oak", "graphite_walnut"]
# Derived from CLASS_NAMES so the two can never disagree
NUM_CLASSES = len(CLASS_NAMES)

# Set your data directory path here
BASE_DATA_DIR = Path("./wood_finishes_dataset")  # Update this to your dataset path

# File patterns counted below. These are the same extensions that
# load_and_convert_to_lab reads, so the preview count matches what is loaded.
IMAGE_PATTERNS = ('*.jpg', '*.jpeg', '*.png', '*.bmp')

# Make sure the paths exist and report how many images each class folder holds
for class_name in CLASS_NAMES:
    class_path = BASE_DATA_DIR / class_name
    if not class_path.exists():
        print(f"Warning: Directory {class_path} does not exist")
    else:
        num_images = sum(len(list(class_path.glob(pattern))) for pattern in IMAGE_PATTERNS)
        print(f"Found {num_images} images in {class_name}")

### 2.2 Load and Convert Images to LAB Color Space

In [None]:
def normalize_opencv_lab(img_lab):
    """
    Scale an 8-bit OpenCV LAB image to float32 values in [0, 1].

    For uint8 input, cv2.cvtColor(..., cv2.COLOR_RGB2LAB) stores L as L*255/100
    and a/b with a +128 offset, so all three channels already span [0, 255].
    Dividing each by 255 is therefore the correct normalization. The cast to
    float32 happens before any arithmetic so no uint8 wrap-around can occur.

    Args:
        img_lab: uint8 LAB image as returned by OpenCV

    Returns:
        float32 array of the same shape with values in [0, 1]
    """
    return np.asarray(img_lab, dtype=np.float32) / 255.0


def load_and_convert_to_lab(base_dir, class_names, img_size):
    """
    Load images from class folders, resize them and convert them to normalized LAB.

    Each image is read with OpenCV, converted BGR -> RGB, resized to
    img_size x img_size, converted to 8-bit LAB and scaled to [0, 1] with
    normalize_opencv_lab. Unreadable files are reported and skipped.

    Args:
        base_dir: Base directory containing class folders
        class_names: List of class folder names; the label of a class is its index
        img_size: Target size for images (img_size x img_size)

    Returns:
        images: float32 numpy array of shape (N, img_size, img_size, 3)
        labels: Numpy array of integer labels

    Raises:
        ValueError: If no image at all could be loaded.
    """
    images = []
    labels = []
    base_dir = Path(base_dir)
    patterns = ('*.jpg', '*.jpeg', '*.png', '*.bmp')

    for i, class_name in enumerate(class_names):
        class_dir = base_dir / class_name
        print(f"Processing class {i}: {class_name}")

        # A missing class folder is reported and skipped rather than fatal
        if not class_dir.exists():
            print(f"Warning: Directory {class_dir} not found. Skipping.")
            continue

        image_files = [path for pattern in patterns for path in class_dir.glob(pattern)]
        print(f"Found {len(image_files)} images in {class_name}")

        for img_path in tqdm(image_files, desc=class_name):
            try:
                # OpenCV returns None (rather than raising) for unreadable files
                img = cv2.imread(str(img_path))
                if img is None:
                    print(f"Warning: Could not read {img_path}. Skipping.")
                    continue

                # OpenCV loads as BGR
                img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
                img_resized = cv2.resize(img_rgb, (img_size, img_size))
                img_lab = cv2.cvtColor(img_resized, cv2.COLOR_RGB2LAB)

                images.append(normalize_opencv_lab(img_lab))
                labels.append(i)

            except Exception as e:
                # Best-effort: one corrupt file should not abort the whole load
                print(f"Error processing {img_path}: {e}")

    if len(images) == 0:
        raise ValueError("No images were loaded. Please check the directory structure and image formats.")

    return np.array(images), np.array(labels)

# Read every class folder and turn the images into normalized LAB arrays
print("Loading images and converting to LAB color space...")
all_images, all_labels = load_and_convert_to_lab(
    base_dir=BASE_DATA_DIR,
    class_names=CLASS_NAMES,
    img_size=IMG_SIZE,
)
num_loaded_images = len(all_images)
print(f"Loaded {num_loaded_images} total images with shape {all_images.shape}")

### 2.3 Visualize LAB Color Space

Let's visualize the LAB color space to understand what our model will be working with:

In [None]:
def visualize_lab_channels(images, labels, class_names, num_samples=2):
    """
    Plot random samples of each class as RGB next to their L, a and b channels.

    Args:
        images: Array of normalized LAB images, shape (N, H, W, 3). Each channel
            is expected to be an OpenCV 8-bit LAB value divided by 255.
        labels: Integer class labels aligned with ``images``
        class_names: Class names, indexed by label
        num_samples: Number of random images to show per class
    """
    n_rows = len(class_names) * num_samples
    plt.figure(figsize=(15, 5 * len(class_names)))

    for i, class_name in enumerate(class_names):
        indices = np.where(labels == i)[0]
        if len(indices) == 0:
            continue

        sample_indices = np.random.choice(indices,
                                          size=min(num_samples, len(indices)),
                                          replace=False)

        for j, idx in enumerate(sample_indices):
            lab_img = images[idx]
            l_channel = lab_img[:, :, 0]
            a_channel = lab_img[:, :, 1]
            b_channel = lab_img[:, :, 2]

            # Undo the /255 scaling to recover OpenCV's 8-bit LAB encoding.
            # Round and clip before the uint8 cast so no value wraps around.
            lab_uint8 = np.clip(np.rint(lab_img * 255.0), 0, 255).astype(np.uint8)
            rgb_img = cv2.cvtColor(lab_uint8, cv2.COLOR_LAB2RGB)

            first_panel = (i * num_samples + j) * 4 + 1
            panels = [
                (rgb_img, None, f"{class_name} (RGB View)"),
                (l_channel, 'gray', "L Channel (Lightness)"),
                # Reversed map: low a (green) draws green, high a (red) draws red
                (a_channel, 'RdYlGn_r', "A Channel (Green-Red)"),
                (b_channel, 'coolwarm', "B Channel (Blue-Yellow)"),
            ]
            for offset, (data, cmap, title) in enumerate(panels):
                plt.subplot(n_rows, 4, first_panel + offset)
                plt.imshow(data, cmap=cmap)
                plt.title(title)
                plt.axis('off')

    plt.tight_layout()
    plt.suptitle("LAB Color Space Visualization by Wood Type", y=1.02, fontsize=16)
    plt.show()

# Show the per-channel LAB breakdown for a few random samples of each class
visualize_lab_channels(all_images, all_labels, CLASS_NAMES)

# Mean of each normalized LAB channel over every pixel of every image in a class
plt.figure(figsize=(12, 5))

per_class_means = []
for class_idx in range(len(CLASS_NAMES)):
    members = all_images[np.where(all_labels == class_idx)[0]]
    per_class_means.append([np.mean(members[..., ch]) for ch in range(3)])
avg_values = np.array(per_class_means)

# One bar chart per channel, classes on the x axis
channels = ['L (Lightness)', 'A (Green-Red)', 'B (Blue-Yellow)']
for c, channel_label in enumerate(channels):
    plt.subplot(1, 3, c + 1)
    plt.bar(CLASS_NAMES, avg_values[:, c])
    plt.title(f"Average {channel_label} Value")
    plt.ylim(0, 1)
    plt.xticks(rotation=45)

plt.tight_layout()
plt.show()

### 2.4 Split Data into Training, Validation, and Test Sets

In [None]:
# Stratified 70% / 15% / 15% split: hold out 30% first, then halve the hold-out
# into validation and test so every subset keeps the class proportions.
train_images, holdout_images, train_labels, holdout_labels = train_test_split(
    all_images, all_labels,
    test_size=0.3,
    random_state=RANDOM_SEED,
    stratify=all_labels,
)
val_images, test_images, val_labels, test_labels = train_test_split(
    holdout_images, holdout_labels,
    test_size=0.5,
    random_state=RANDOM_SEED,
    stratify=holdout_labels,
)
# The intermediate hold-out arrays are no longer needed
del holdout_images, holdout_labels

print(f"Training set: {train_images.shape[0]} images")
print(f"Validation set: {val_images.shape[0]} images")
print(f"Testing set: {test_images.shape[0]} images")

# One-hot targets for the categorical cross-entropy loss
train_labels_cat, val_labels_cat, test_labels_cat = (
    to_categorical(split, NUM_CLASSES) for split in (train_labels, val_labels, test_labels)
)

# Function to display class distribution
def show_class_distribution(labels, title="Class Distribution"):
    """
    Plot images-per-class as a bar chart and print the percentage breakdown.

    Args:
        labels: Integer class labels (indices into the global CLASS_NAMES)
        title: Heading used for both the chart and the printed summary
    """
    class_ids, class_counts = np.unique(labels, return_counts=True)
    names = [CLASS_NAMES[int(class_id)] for class_id in class_ids]
    counts = list(class_counts)

    plt.figure(figsize=(10, 6))
    plt.bar(names, counts, color='skyblue')
    plt.title(title)
    plt.ylabel('Number of Images')
    plt.xticks(rotation=45)

    # Annotate each bar with its count
    for position, count in enumerate(counts):
        plt.text(position, count + 5, str(count), ha='center')

    plt.tight_layout()
    plt.show()

    print(f"\n{title}:")
    total = sum(counts)
    for name, count in zip(names, counts):
        print(f"  {name}: {count} images ({count/total*100:.1f}%)")

# Check that stratification kept the class proportions consistent across splits
for split_labels, split_title in (
    (train_labels, "Training Set Class Distribution"),
    (val_labels, "Validation Set Class Distribution"),
    (test_labels, "Test Set Class Distribution"),
):
    show_class_distribution(split_labels, split_title)

### 2.5 Data Augmentation

In [None]:
# Set up data augmentation for training (geometric transforms only).
#
# brightness_range is deliberately NOT used. ImageDataGenerator applies it
# through an 8-bit PIL image round trip, and for float inputs already scaled to
# [0, 1] that round trip does not preserve the value range (depending on the
# Keras version the output is rescaled to 0-255 or truncated to 0/1). It would
# also scale the a/b color channels, not just lightness. If lightness jitter is
# wanted, scale only channel 0 (L), e.g. via preprocessing_function.
data_augmentation = ImageDataGenerator(
    rotation_range=15,
    width_shift_range=0.1,
    height_shift_range=0.1,
    horizontal_flip=True,
    zoom_range=0.1,
    fill_mode='nearest'
)

# Visualize augmented LAB images
def show_augmented_lab_images(original_image):
    """
    Show the first image of a batch next to eight random augmentations of it.

    Images are rendered by converting the normalized LAB data back to RGB.

    Args:
        original_image: Batch of normalized LAB images, shape (N, H, W, 3); only
            the first image is used. Channel values are expected to be OpenCV
            8-bit LAB values divided by 255.
    """
    # Keep the batch axis: ImageDataGenerator.flow expects a 4-D array
    first_batch = original_image[0:1]

    plt.figure(figsize=(12, 6))
    for i in range(9):
        plt.subplot(3, 3, i + 1)
        if i == 0:
            sample_image = first_batch[0]  # single (H, W, 3) image
            plt.title('Original')
        else:
            augmented_batch = next(data_augmentation.flow(first_batch, batch_size=1))
            sample_image = augmented_batch[0]
            plt.title(f'Augmented #{i}')

        # Undo the /255 scaling (round + clip so the uint8 cast cannot wrap),
        # then let OpenCV convert its 8-bit LAB encoding back to RGB
        lab_uint8 = np.clip(np.rint(sample_image * 255.0), 0, 255).astype(np.uint8)
        plt.imshow(cv2.cvtColor(lab_uint8, cv2.COLOR_LAB2RGB))
        plt.axis('off')

    plt.tight_layout()
    plt.show()

# Preview the augmentation pipeline on the first training image
# (sliced rather than indexed so the batch axis is kept)
sample_img = train_images[:1]
print("LAB Image Augmentation Examples:")
show_augmented_lab_images(sample_img)

## 3. Model Development and Training

### 3.1 Building the Model

In [None]:
def build_model(input_shape=(IMG_SIZE, IMG_SIZE, 3), num_classes=NUM_CLASSES):
    """
    Create and compile a MobileNetV2 transfer-learning classifier.

    The ImageNet-pretrained backbone is frozen and is the first layer of the
    returned Sequential model (the fine-tuning step reads it as model.layers[0]).
    A small dense head with dropout is stacked on top of it.

    NOTE(review): the ImageNet weights were learned on RGB inputs that
    mobilenet_v2.preprocess_input scales to [-1, 1]; here the backbone receives
    normalized LAB, so the pretrained features may transfer only partially.

    Args:
        input_shape: (height, width, channels) of the input images
        num_classes: Number of output classes (softmax units)

    Returns:
        Compiled tf.keras Sequential model.
    """
    backbone = MobileNetV2(
        input_shape=input_shape,
        include_top=False,
        weights='imagenet'
    )
    backbone.trainable = False  # stage 1 trains only the new head

    classification_head = [
        layers.GlobalAveragePooling2D(),
        layers.BatchNormalization(),
        layers.Dense(128, activation='relu'),
        layers.Dropout(0.5),
        layers.Dense(64, activation='relu'),
        layers.Dropout(0.3),
        layers.Dense(num_classes, activation='softmax'),
    ]
    model = models.Sequential([backbone, *classification_head])

    model.compile(
        optimizer=optimizers.Adam(learning_rate=0.001),
        loss='categorical_crossentropy',
        metrics=['accuracy'],
    )
    return model

# Build the model
model = build_model()
model.summary()

### 3.2 Training the Model

In [None]:
# Define callbacks
def get_callbacks():
    """
    Return a fresh list of training callbacks.

    - EarlyStopping on val_loss (patience 10), restoring the best weights
    - ReduceLROnPlateau on val_loss: halve the learning rate after 5 stale epochs
      (floor 1e-6)
    - ModelCheckpoint writing the best-val_accuracy model to
      'wood_classifier_lab_best.h5'

    NOTE: every call creates a new ModelCheckpoint whose "best so far" starts
    fresh, so a later training stage overwrites the file on its first epoch
    even if it is worse than an earlier stage's best.
    """
    early_stopping = callbacks.EarlyStopping(
        monitor='val_loss', patience=10, restore_best_weights=True
    )
    lr_on_plateau = callbacks.ReduceLROnPlateau(
        monitor='val_loss', factor=0.5, patience=5, min_lr=1e-6
    )
    best_checkpoint = callbacks.ModelCheckpoint(
        'wood_classifier_lab_best.h5',
        monitor='val_accuracy',
        save_best_only=True,
        mode='max',
    )
    return [early_stopping, lr_on_plateau, best_checkpoint]

# Stage 1: train only the new classification head (backbone frozen)
print("Training the model...")
train_stream = data_augmentation.flow(train_images, train_labels_cat, batch_size=BATCH_SIZE)
history = model.fit(
    train_stream,
    validation_data=(val_images, val_labels_cat),
    epochs=30,
    callbacks=get_callbacks(),
    verbose=1,
)

# Plot training history
def plot_training_history(history):
    """
    Draw training vs. validation curves for accuracy and loss side by side.

    Args:
        history: Keras History object returned by model.fit; its .history dict
            must contain 'accuracy', 'val_accuracy', 'loss' and 'val_loss'.
    """
    metric_panels = [('accuracy', 'Accuracy'), ('loss', 'Loss')]

    plt.figure(figsize=(12, 5))
    for position, (metric, label) in enumerate(metric_panels, start=1):
        plt.subplot(1, 2, position)
        plt.plot(history.history[metric], label=f'Training {label}')
        plt.plot(history.history[f'val_{metric}'], label=f'Validation {label}')
        plt.title(f'Model {label}')
        plt.xlabel('Epoch')
        plt.ylabel(label)
        plt.legend()

    plt.tight_layout()
    plt.show()

plot_training_history(history)

### 3.3 Fine-tuning the Model

In [None]:
# Fine-tune the model
print("Fine-tuning the model...")

# build_model puts the MobileNetV2 backbone first in the Sequential model
base_model = model.layers[0]
base_model.trainable = True

# How many backbone layers (counted from the top) become trainable again
FINE_TUNE_LAST_N_LAYERS = 30

# Keep the lower, more generic layers frozen
for layer in base_model.layers[:-FINE_TUNE_LAST_N_LAYERS]:
    layer.trainable = False

# Unfreeze the top layers EXCEPT BatchNormalization. In TF2 Keras a BN layer
# with trainable=False runs in inference mode, so its moving mean/variance are
# not re-estimated from the small (BATCH_SIZE // 2) fine-tuning batches. The
# head was trained in stage 1 against those fixed statistics.
for layer in base_model.layers[-FINE_TUNE_LAST_N_LAYERS:]:
    layer.trainable = not isinstance(layer, layers.BatchNormalization)

# Recompile (required for trainable changes to take effect) with a lower learning rate
model.compile(
    optimizer=optimizers.Adam(learning_rate=1e-5),
    loss='categorical_crossentropy',
    metrics=['accuracy']
)

# Continue training
fine_tune_history = model.fit(
    data_augmentation.flow(train_images, train_labels_cat, batch_size=BATCH_SIZE // 2),
    validation_data=(val_images, val_labels_cat),
    epochs=20,
    callbacks=get_callbacks(),
    verbose=1
)

# Plot fine-tuning history
plot_training_history(fine_tune_history)

## 4. Model Evaluation

Let's evaluate our model on the test set:

In [None]:
# Evaluate on test set
test_loss, test_accuracy = model.evaluate(test_images, test_labels_cat)
print(f"Test accuracy: {test_accuracy:.4f}")
print(f"Test loss: {test_loss:.4f}")

# Make predictions on the test set
y_pred_prob = model.predict(test_images)
y_pred = np.argmax(y_pred_prob, axis=1)

# Pin the label set to every configured class. Otherwise a class absent from
# both test_labels and y_pred (e.g. its folder was skipped by the loader) makes
# classification_report raise on the target_names length mismatch, and shrinks
# the confusion matrix below the tick labels drawn for it.
class_ids = list(range(len(CLASS_NAMES)))

# Print classification report
print("\nClassification Report:")
report = classification_report(
    test_labels, y_pred,
    labels=class_ids,
    target_names=CLASS_NAMES,
    zero_division=0,
)
print(report)

# Create confusion matrix (rows = true class, columns = predicted class)
cm = confusion_matrix(test_labels, y_pred, labels=class_ids)

# Plot confusion matrix
plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=CLASS_NAMES,
            yticklabels=CLASS_NAMES)
plt.xlabel('Predicted')
plt.ylabel('True')
plt.title('Confusion Matrix')
plt.tight_layout()
plt.show()

# Visualize predictions
def select_prediction_samples(true_labels, pred_labels, num_samples):
    """
    Pick up to num_samples indices, aiming for half of them misclassified.

    Misclassified indices come first, followed by correctly classified ones.
    Both groups are drawn at random without replacement.

    Args:
        true_labels: Integer ground-truth labels
        pred_labels: Integer predicted labels, aligned with true_labels
        num_samples: Maximum number of indices to return

    Returns:
        1-D integer numpy array of indices (possibly empty).
    """
    true_labels = np.asarray(true_labels)
    pred_labels = np.asarray(pred_labels)
    correct_indices = np.flatnonzero(true_labels == pred_labels)
    incorrect_indices = np.flatnonzero(true_labels != pred_labels)

    n_incorrect = min(num_samples // 2, len(incorrect_indices))
    n_correct = min(num_samples - n_incorrect, len(correct_indices))

    def draw(pool, k):
        # An explicit int-typed empty array keeps the concatenation integer;
        # concatenating a plain [] would upcast the indices to float
        if k <= 0:
            return np.empty(0, dtype=int)
        return np.random.choice(pool, k, replace=False)

    return np.concatenate([draw(incorrect_indices, n_incorrect),
                           draw(correct_indices, n_correct)]).astype(int)


def visualize_predictions(images, true_labels, pred_labels, class_names, num_samples=10):
    """
    Show a grid of test images titled with their true and predicted classes.

    Titles are green for correct predictions and red for mistakes.

    Args:
        images: Normalized LAB images, shape (N, H, W, 3); channels are expected
            to be OpenCV 8-bit LAB values divided by 255
        true_labels: Integer ground-truth labels
        pred_labels: Integer predicted labels
        class_names: Class names indexed by label
        num_samples: Maximum number of images to display
    """
    selected_indices = select_prediction_samples(true_labels, pred_labels, num_samples)
    if len(selected_indices) == 0:
        print("No predictions to visualize.")
        return

    n_cols = 5
    n_rows = (len(selected_indices) + n_cols - 1) // n_cols
    plt.figure(figsize=(15, 2 * n_rows))

    for i, idx in enumerate(selected_indices):
        # Undo the /255 scaling (round + clip so the uint8 cast cannot wrap)
        lab_uint8 = np.clip(np.rint(images[idx] * 255.0), 0, 255).astype(np.uint8)
        rgb_img = cv2.cvtColor(lab_uint8, cv2.COLOR_LAB2RGB)

        plt.subplot(n_rows, n_cols, i + 1)
        plt.imshow(rgb_img)

        true_class = class_names[true_labels[idx]]
        pred_class = class_names[pred_labels[idx]]
        color = 'green' if true_labels[idx] == pred_labels[idx] else 'red'

        plt.title(f"True: {true_class}\nPred: {pred_class}", color=color, fontsize=10)
        plt.axis('off')

    plt.tight_layout()
    plt.show()

# Display a mix of misclassified and correctly classified test images
visualize_predictions(test_images, test_labels, y_pred, CLASS_NAMES, num_samples=10)

## 5. Save and Convert Model for Deployment

In [None]:
# Persist the trained Keras model in HDF5 format
model.save('wood_classifier_lab.h5')
print("Saved Keras model to wood_classifier_lab.h5")

# Plain float TFLite conversion (no optimizations)
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()
Path('wood_classifier_lab.tflite').write_bytes(tflite_model)
print("Saved TFLite model to wood_classifier_lab.tflite")

# Second conversion with tf.lite.Optimize.DEFAULT for a smaller file.
# NOTE(review): only the float model above is evaluated in this notebook;
# measure this variant's test accuracy before deploying it.
converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_optimized_model = converter.convert()
Path('wood_classifier_lab_optimized.tflite').write_bytes(tflite_optimized_model)
print("Saved optimized TFLite model to wood_classifier_lab_optimized.tflite")

## 6. Test the TFLite Model

In [None]:
# Load the float (non-optimized) TFLite model and allocate its tensors
interpreter = tf.lite.Interpreter(model_content=tflite_model)
interpreter.allocate_tensors()

# Show the tensor metadata (index, shape, dtype) the interpreter expects
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
for heading, details in (("Input details:", input_details), ("Output details:", output_details)):
    print(heading, details)

# Test the TFLite model on a few images
def run_tflite_inference(interpreter, batch):
    """
    Run one forward pass through a TFLite interpreter and return its first output.

    The batch is cast to the dtype declared by the interpreter's first input
    tensor, because set_tensor rejects mismatched dtypes (e.g. float64 data
    for a float32 input). Only a dtype cast is applied; for an integer-quantized
    input, quantization scaling would also be needed (not handled here).

    Args:
        interpreter: A tf.lite.Interpreter with tensors already allocated
        batch: Input batch, e.g. shape (1, H, W, 3)

    Returns:
        The interpreter's first output tensor as a numpy array.
    """
    input_info = interpreter.get_input_details()[0]
    output_info = interpreter.get_output_details()[0]
    interpreter.set_tensor(input_info['index'], np.asarray(batch, dtype=input_info['dtype']))
    interpreter.invoke()
    return interpreter.get_tensor(output_info['index'])


def test_tflite_model(interpreter, images, labels, class_names, num_samples=5):
    """
    Run the TFLite model on random images and plot each with its prediction.

    Args:
        interpreter: Allocated tf.lite.Interpreter for the classifier
        images: Normalized LAB images, shape (N, H, W, 3); channels are expected
            to be OpenCV 8-bit LAB values divided by 255
        labels: Integer ground-truth labels
        class_names: Class names indexed by label
        num_samples: Number of images to test (clamped to len(images))
    """
    num_samples = min(num_samples, len(images))
    if num_samples == 0:
        print("No images to test.")
        return

    plt.figure(figsize=(15, 3 * num_samples))
    indices = np.random.choice(len(images), num_samples, replace=False)

    for i, idx in enumerate(indices):
        # Slice (not index) so the model still receives a batch of one
        probabilities = run_tflite_inference(interpreter, images[idx:idx + 1])[0]
        predicted_class = int(np.argmax(probabilities))
        confidence = probabilities[predicted_class] * 100

        plt.subplot(num_samples, 1, i + 1)

        # Undo the /255 scaling (round + clip so the uint8 cast cannot wrap)
        lab_uint8 = np.clip(np.rint(images[idx] * 255.0), 0, 255).astype(np.uint8)
        plt.imshow(cv2.cvtColor(lab_uint8, cv2.COLOR_LAB2RGB))

        title = f"True: {class_names[labels[idx]]}\n"
        title += f"Predicted: {class_names[predicted_class]} ({confidence:.1f}%)"
        plt.title(title, color='green' if predicted_class == labels[idx] else 'red')
        plt.axis('off')

    plt.tight_layout()
    plt.show()

# Despite its name this is not a pytest test; stop pytest from collecting it
# if this code is ever moved into an importable module.
test_tflite_model.__test__ = False

# Spot-check the float TFLite model on a few random test images
test_tflite_model(interpreter, test_images, test_labels, CLASS_NAMES, num_samples=5)

## 7. Using the Model in Production

To use this model in a production environment, you'll need to:

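1. Preprocess new images exactly as the training images were: read, convert BGR to RGB, resize, then convert to LAB color space
2. Normalize the LAB values with the same scaling used during training (see `load_and_convert_to_lab`); any mismatch silently degrades accuracy
3. Run the model on the preprocessed image and take the class with the highest probability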

Here's a function you can use to preprocess new images:

In [None]:
def preprocess_image_for_model(image_path, img_size=224):
    """
    Load one image file and turn it into a model-ready normalized LAB batch.

    Pipeline: OpenCV read (BGR) -> RGB -> resize to img_size x img_size ->
    8-bit LAB -> float32 scaled to [0, 1]. This must match the normalization
    applied to the training images.

    Args:
        image_path: Path to the image file (str or pathlib.Path)
        img_size: Size to resize the image to

    Returns:
        float32 array of shape (1, img_size, img_size, 3).

    Raises:
        ValueError: If OpenCV cannot read the file.
    """
    # cv2.imread needs a str; str() also accepts pathlib.Path inputs
    img = cv2.imread(str(image_path))
    if img is None:
        raise ValueError(f"Could not read image: {image_path}")

    img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img_resized = cv2.resize(img_rgb, (img_size, img_size))
    img_lab = cv2.cvtColor(img_resized, cv2.COLOR_RGB2LAB)

    # OpenCV's 8-bit LAB stores all three channels in [0, 255] (L scaled by
    # 255/100, a and b offset by 128). Cast to float32 first so the division
    # cannot wrap around in uint8, then scale every channel to [0, 1].
    normalized_lab = img_lab.astype(np.float32) / 255.0

    # Add batch dimension
    return np.expand_dims(normalized_lab, axis=0)

# Example of using the model with a new image
def classify_new_image(image_path, model, class_names):
    """
    Predict the wood finish shown in a single image file.

    Args:
        image_path: Path to the image file
        model: Trained Keras model that takes normalized LAB input
        class_names: Class names indexed by the model's output position

    Returns:
        predicted_class: Name of the highest-scoring class
        confidence: That class's output score multiplied by 100 (a percentage
            for a softmax model)
    """
    batch = preprocess_image_for_model(image_path)
    class_scores = model.predict(batch)[0]
    best_idx = np.argmax(class_scores)
    return class_names[best_idx], class_scores[best_idx] * 100

## 8. Conclusion

This notebook has demonstrated how to build a wood finish classification model on top of the LAB color space. LAB separates lightness (L channel) from color information (a and b channels), so variations in illumination are largely isolated in one channel. The working hypothesis is that this makes the subtle color differences between wood finishes easier for the model to pick up.

Key takeaways:
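1. LAB color space separates lightness from color, which is expected to help with wood finish classification (a run on plain RGB input would confirm the gain on this dataset)
2. Data augmentation helps improve model robustness
3. Transfer learning with MobileNetV2 provides a strong foundation
4. Fine-tuning the top backbone layers can further improve performance; compare the validation curves of both training stages to confirm it did here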

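The trained model can now be deployed on mobile or edge devices using the TensorFlow Lite format. Note that only the non-optimized TFLite model is evaluated above; check the optimized model's accuracy on the test set before deploying it.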