In [4]:
import os
import numpy as np
from PIL import Image
from sklearn.model_selection import train_test_split
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Conv2D, MaxPooling2D, UpSampling2D, GlobalAveragePooling2D, Dense, Input
from tensorflow.keras.utils import to_categorical
import keras

def load_images_with_labels(base_directory, image_size=(64, 64)):
    """Load labelled images found under *base_directory*.

    The tree is walked recursively. A file is used when its extension is
    ``.png`` or ``.pgm`` (case-insensitive) and its name, without the
    extension, has the form ``<prefix>_<orientation>[_...]`` where the
    orientation is one of ``up``, ``straight``, ``left`` or ``right``.
    Each accepted image is converted to RGB, resized to *image_size* and
    divided by 255.0.

    Files with a missing or unknown orientation are reported on stdout and
    skipped *before* any decoding is attempted. Files that fail to decode
    are reported and skipped as well, so one bad file never aborts the load.

    Directories and files are visited in sorted order so the returned arrays
    have the same ordering on every machine (``os.walk`` itself yields
    entries in arbitrary order), which keeps seeded splits reproducible.

    Args:
        base_directory: Root directory to search.
        image_size: ``(width, height)`` passed to ``Image.resize``.

    Returns:
        A tuple ``(images, labels)`` of NumPy arrays. ``images`` stacks the
        float arrays of all accepted files (``height x width x 3`` each);
        ``labels`` holds the matching integer class ids
        (up=0, straight=1, left=2, right=3). Both are empty when nothing
        was accepted.
    """
    label_map = {'up': 0, 'straight': 1, 'left': 2, 'right': 3}
    images = []
    labels = []

    for root, dirs, files in os.walk(base_directory):
        # Sorting in place steers os.walk's own descent order.
        dirs.sort()
        folder_name = os.path.basename(root)

        for filename in sorted(files):
            if not filename.lower().endswith(('.png', '.pgm')):
                continue

            # Drop the extension first so a two-part name such as
            # "id_left.pgm" yields "left" rather than "left.pgm".
            parts = os.path.splitext(filename)[0].split('_')
            if len(parts) < 2:
                print(f"Skipping file with missing orientation: {filename} in {folder_name}")
                continue

            orientation = parts[1]
            if orientation not in label_map:
                print(f"Skipping file with unexpected orientation: {filename} in {folder_name}")
                continue

            file_path = os.path.join(root, filename)
            try:
                # The context manager closes the underlying file promptly.
                with Image.open(file_path) as img:
                    resized = img.convert('RGB').resize(image_size)
                    img_array = np.array(resized) / 255.0
            except Exception as e:
                # Deliberately broad: loading is best-effort and any decode
                # failure should only cost us that one file.
                print(f"Error processing {filename} in {folder_name}: {e}")
                continue

            images.append(img_array)
            labels.append(label_map[orientation])

    return np.array(images), np.array(labels)

def create_combined_model(input_shape, num_classes):
    """Build and compile a convolutional autoencoder with a classifier head.

    A shared three-stage Conv/MaxPool encoder feeds two outputs:

    * ``decoded`` - a three-stage UpSampling/Conv decoder ending in a sigmoid
      convolution that reconstructs the input image.
    * ``classification`` - global average pooling over the encoded feature
      map followed by a softmax ``Dense`` layer with *num_classes* units.

    The model is compiled with Adam, mean squared error on ``decoded``
    (weight 1.0), categorical cross-entropy on ``classification``
    (weight 0.5), and accuracy tracked for the classifier.

    Args:
        input_shape: Shape of one input image, channels last, e.g.
            ``(64, 64, 3)``. Height and width should be multiples of 8:
            the encoder halves them three times (rounding up) and the
            decoder doubles them three times, so other sizes produce a
            reconstruction larger than the input.
        num_classes: Number of units in the softmax output.

    Returns:
        The compiled ``Model`` with outputs ``[decoded, classification]``.
    """
    inputs = Input(shape=input_shape)

    # The reconstruction is compared against the input itself, so it must
    # have the same number of channels as the input rather than a fixed 3.
    output_channels = input_shape[-1]

    # Encoder
    x = Conv2D(32, (3, 3), activation='relu', padding='same')(inputs)
    x = MaxPooling2D((2, 2), padding='same')(x)
    x = Conv2D(64, (3, 3), activation='relu', padding='same')(x)
    x = MaxPooling2D((2, 2), padding='same')(x)
    x = Conv2D(128, (3, 3), activation='relu', padding='same')(x)
    encoded = MaxPooling2D((2, 2), padding='same')(x)

    # Decoder
    x = UpSampling2D((2, 2))(encoded)
    x = Conv2D(128, (3, 3), activation='relu', padding='same')(x)
    x = UpSampling2D((2, 2))(x)
    x = Conv2D(64, (3, 3), activation='relu', padding='same')(x)
    x = UpSampling2D((2, 2))(x)
    # Sigmoid keeps the reconstruction in [0, 1].
    decoded = Conv2D(
        output_channels, (3, 3), activation='sigmoid', padding='same', name='decoded'
    )(x)

    # Classification head
    x = GlobalAveragePooling2D()(encoded)
    classification_output = Dense(num_classes, activation='softmax', name='classification')(x)

    model = Model(inputs, [decoded, classification_output])

    # Loss/metric dictionaries are keyed by the output layer names above.
    model.compile(
        optimizer='adam',
        loss={
            'decoded': 'mean_squared_error',
            'classification': 'categorical_crossentropy',
        },
        loss_weights={
            'decoded': 1.0,
            'classification': 0.5,
        },
        metrics={
            'classification': ['accuracy'],
        },
    )

    return model

# Train the combined autoencoder/classifier on one face directory and report
# per-sample predictions on a held-out 20% split.
directory = 'faces/tammo'
images, labels = load_images_with_labels(directory)
print(len(images))
# The loader appends images and labels in lockstep, so one emptiness check
# covers both arrays.
if len(images) == 0:
    print("No data found. Ensure the directory is correct and files match the expected format.")
else:
    # Class id -> name; its size is the single source of truth for the number
    # of classes used by the one-hot encoding and the model head.
    orientation_map = {0: 'up', 1: 'straight', 2: 'left', 3: 'right'}
    num_classes = len(orientation_map)

    X_train, X_test, y_train, y_test = train_test_split(
        images, labels, test_size=0.2, random_state=42
    )
    y_train_cat = to_categorical(y_train, num_classes=num_classes)
    y_test_cat = to_categorical(y_test, num_classes=num_classes)

    # Take the model input shape from the data instead of repeating it here.
    combined_model = create_combined_model(
        input_shape=images.shape[1:], num_classes=num_classes
    )

    # The autoencoder branch is trained to reproduce its own input, hence
    # X_* appears both as input and as the 'decoded' target.
    combined_model.fit(
        X_train,
        {
            'decoded': X_train,
            'classification': y_train_cat,
        },
        batch_size=64,
        epochs=100,
        validation_data=(
            X_test,
            {
                'decoded': X_test,
                'classification': y_test_cat,
            },
        ),
    )

    decoded_imgs, predictions = combined_model.predict(X_test)
    predicted_labels = np.argmax(predictions, axis=1)

    correct = 0
    wrong = 0
    for true_label, predicted_label in zip(y_test, predicted_labels):
        actual = orientation_map[true_label]
        predicted = orientation_map[predicted_label]
        print(f"Actual: {actual}, Predicted: {predicted}")
        if actual == predicted:
            correct += 1
        else:
            wrong += 1
    print(f"Total Correct: {correct}")
    print(f"Total Wrong: {wrong}")


1872
Epoch 1/100
[1m24/24[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m12s[0m 386ms/step - classification_accuracy: 0.2544 - loss: 0.7834 - val_classification_accuracy: 0.2213 - val_loss: 0.7221
Epoch 2/100
[1m24/24[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 360ms/step - classification_accuracy: 0.2545 - loss: 0.7168 - val_classification_accuracy: 0.2907 - val_loss: 0.7067
Epoch 3/100
[1m24/24[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8s[0m 336ms/step - classification_accuracy: 0.2761 - loss: 0.7045 - val_classification_accuracy: 0.3200 - val_loss: 0.7008
Epoch 4/100
[1m24/24[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8s[0m 345ms/step - classification_accuracy: 0.2898 - loss: 0.6995 - val_classification_accuracy: 0.2800 - val_loss: 0.7030
Epoch 5/100
[1m24/24[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8s[0m 339ms/step - classification_accuracy: 0.3087 - loss: 0.6972 - val_classification_accuracy: 0.2347 - val_loss: 0.6971
Epoch 6/100
[1m24/24[0m [3