In [3]:
import json
import os
import cv2
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.utils import shuffle
from tensorflow.keras.applications import MobileNet
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D, Dropout
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.utils import to_categorical
import matplotlib.pyplot as plt
from collections import Counter, defaultdict

# Step 1: Load the JSON data
with open('/Users/mjrchy/Documents/BCS-models-training-new-data/bcs_dataset_new.json', 'r') as f:
    data = json.load(f)
print(f"Total entries: {len(data)}")

# Step 2: Process and validate data
images = []
bcs_labels = []
skipped_count = 0

for i, entry in enumerate(data):
    # Check BCS value first
    bcs_raw = entry.get("BCS")
    if bcs_raw is None or bcs_raw in ['ไม่ระบุ (ไม่ทราบ)', 'ไม่ระบุ', 'ไม่ทราบ']:
        skipped_count += 1
        continue
    
    # Convert BCS to integer
    try:
        if isinstance(bcs_raw, (int, float)):
            bcs_value = int(bcs_raw)
        elif isinstance(bcs_raw, str) and bcs_raw.replace('.', '').isdigit():
            bcs_value = int(float(bcs_raw))
        else:
            print(f"Invalid BCS value: {bcs_raw}")
            skipped_count += 1
            continue
    except (ValueError, TypeError):
        print(f"Could not convert BCS: {bcs_raw}")
        skipped_count += 1
        continue
    
    # Validate BCS range (1-9)
    if bcs_value < 1 or bcs_value > 9:
        print(f"BCS out of range: {bcs_value}")
        skipped_count += 1
        continue
    
    # Collect all 4 view paths
    views = [
        entry.get("ภาพด้านบน (Top View)"),
        entry.get("ภาพด้านหลัง (Back View)"),
        entry.get("ภาพด้านขวา (Right View)"),
        entry.get("ภาพด้านซ้าย (Left View)")
    ]

    # Load and validate images
    img_list = []
    valid = True
    for j, v in enumerate(views):
        if v is None:
            print(f"Entry {i}: Missing view {j}")
            valid = False
            break
        
        path = os.path.join("bcs_dataset", v)
        if not os.path.exists(path):
            print(f"Entry {i}: Missing image file: {path}")
            valid = False
            break
        
        try:
            img = cv2.imread(path)
            if img is None:
                print(f"Entry {i}: Could not load image: {path}")
                valid = False
                break
            
            # Resize to 224x224 for MobileNet
            img = cv2.resize(img, (224, 224))
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # Convert BGR to RGB
            img_list.append(img)
        except Exception as e:
            print(f"Entry {i}: Error processing image {path}: {e}")
            valid = False
            break
    
    if not valid:
        skipped_count += 1
        continue

    # Create 2x2 grid (448x448x3)
    try:
        top = np.concatenate([img_list[0], img_list[1]], axis=1)  # top row
        bottom = np.concatenate([img_list[2], img_list[3]], axis=1)  # bottom row
        combined = np.concatenate([top, bottom], axis=0)  # shape (448, 448, 3)
        
        images.append(combined)
        bcs_labels.append(bcs_value - 1)  # Convert to 0-8 for model training
        
    except Exception as e:
        print(f"Entry {i}: Error combining images: {e}")
        skipped_count += 1
        continue

print(f"Successfully processed: {len(images)} samples")
print(f"Skipped: {skipped_count} samples")

# Convert to numpy arrays
X = np.array(images, dtype=np.float32) / 255.0  # Normalize to [0,1]
y = np.array(bcs_labels, dtype=np.int32)

print(f"Final dataset shape: {X.shape}")
print(f"Labels shape: {y.shape}")
print(f"Label range: {y.min()} to {y.max()}")

# Check class distribution
class_counts = Counter(y)
print("\nClass distribution:")
for class_label in sorted(class_counts.keys()):
    print(f"BCS {class_label + 1}: {class_counts[class_label]} samples")

# Step 3: Ensure we have enough data
if len(X) == 0:
    raise ValueError("No valid samples found! Check your data and image paths.")

if len(np.unique(y)) < 2:
    raise ValueError("Need at least 2 different classes for training.")

# Step 4: Custom train-test split to ensure all classes are represented
class_indices = defaultdict(list)
for idx, label in enumerate(y):
    class_indices[label].append(idx)

train_indices = []
test_indices = []

for class_label, indices in class_indices.items():
    indices = shuffle(indices, random_state=42)
    n_samples = len(indices)

    if n_samples == 1:
        # Only one sample → put it in both sets
        train_indices.extend(indices)
        test_indices.extend(indices)
        print(f"Warning: Only 1 sample for BCS {class_label + 1}, duplicating for train/test")
    else:
        # At least 80% train, 20% test, but ensure at least 1 in each
        n_test = max(1, int(0.2 * n_samples))
        n_train = n_samples - n_test

        if n_train == 0:
            n_train = 1
            n_test = n_samples - 1

        train_indices.extend(indices[:n_train])
        test_indices.extend(indices[n_train:])

# Create train/test splits
X_train, X_test = X[train_indices], X[test_indices]
y_train, y_test = y[train_indices], y[test_indices]

print(f"\nTrain set: {len(X_train)} samples")
print(f"Test set: {len(X_test)} samples")
print("Train class distribution:", Counter(y_train))
print("Test class distribution:", Counter(y_test))

# Step 5: Resize images to match MobileNet input (224x224)
def resize_to_mobilenet_input(images):
    """Resize 448x448 images to 224x224 for MobileNet"""
    resized = []
    for img in images:
        resized_img = cv2.resize(img, (224, 224))
        resized.append(resized_img)
    return np.array(resized, dtype=np.float32)

X_train_resized = resize_to_mobilenet_input(X_train)
X_test_resized = resize_to_mobilenet_input(X_test)

print(f"Resized train shape: {X_train_resized.shape}")
print(f"Resized test shape: {X_test_resized.shape}")

# Step 6: Build the model using MobileNet
base_model = MobileNet(weights='imagenet', include_top=False, input_shape=(224, 224, 3))

# Freeze base model initially
base_model.trainable = False

# Add custom classification head
x = base_model.output
x = GlobalAveragePooling2D()(x)
x = Dense(256, activation='relu')(x)
x = Dropout(0.5)(x)
x = Dense(128, activation='relu')(x)
x = Dropout(0.3)(x)
predictions = Dense(9, activation='softmax')(x)  # 9 classes (BCS 1-9)

# Create the model
model = Model(inputs=base_model.input, outputs=predictions)

# Step 7: Compile the model
model.compile(
    optimizer=Adam(learning_rate=0.001),
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)

print("Model compiled successfully!")
print(f"Model input shape: {model.input_shape}")
print(f"Model output shape: {model.output_shape}")

# Step 8: Add data validation before training
print("\nValidating training data...")
print(f"X_train_resized shape: {X_train_resized.shape}")
print(f"X_train_resized dtype: {X_train_resized.dtype}")
print(f"X_train_resized range: [{X_train_resized.min():.3f}, {X_train_resized.max():.3f}]")
print(f"y_train shape: {y_train.shape}")
print(f"y_train dtype: {y_train.dtype}")
print(f"y_train range: [{y_train.min()}, {y_train.max()}]")

# Check for NaN or infinite values
if np.isnan(X_train_resized).any():
    print("WARNING: NaN values found in X_train_resized")
    X_train_resized = np.nan_to_num(X_train_resized)

if np.isinf(X_train_resized).any():
    print("WARNING: Infinite values found in X_train_resized")
    X_train_resized = np.nan_to_num(X_train_resized)

# Step 9: Train the model with callbacks for better monitoring
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau

callbacks = [
    EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True),
    ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=5, min_lr=1e-7)
]

try:
    print("\nStarting training...")
    history = model.fit(
        X_train_resized, y_train,
        epochs=30,
        batch_size=4,
        validation_split=0.2,
        callbacks=callbacks,
        verbose=1
    )
    print("Training completed successfully!")
    
except Exception as e:
    print(f"Training failed with error: {e}")
    print("Trying with smaller batch size...")
    
    # Try with batch size 1 if batch size 4 fails
    try:
        history = model.fit(
            X_train_resized, y_train,
            epochs=30,
            batch_size=1,
            validation_split=0.2,
            callbacks=callbacks,
            verbose=1
        )
        print("Training completed with batch size 1!")
    except Exception as e2:
        print(f"Training failed even with batch size 1: {e2}")
        raise e2

# Step 10: Evaluate the model
print("\nEvaluating model...")
loss, accuracy = model.evaluate(X_test_resized, y_test, verbose=0)
print(f'Test Accuracy: {accuracy * 100:.2f}%')

# Step 11: Save the model
model.save('bcs_prediction_model.h5')
print("Model saved successfully!")

# Step 12: Generate predictions and detailed analysis
print("\nGenerating predictions...")
y_pred_probs = model.predict(X_test_resized)
y_pred = np.argmax(y_pred_probs, axis=1)

# Create confusion matrix
from sklearn.metrics import confusion_matrix, classification_report
import seaborn as sns

all_labels = list(range(9))
cm = confusion_matrix(y_test, y_pred, labels=all_labels)

plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=[f'BCS {i}' for i in range(1, 10)],
            yticklabels=[f'BCS {i}' for i in range(1, 10)])
plt.xlabel('Predicted BCS')
plt.ylabel('Actual BCS')
plt.title('Confusion Matrix for BCS Prediction')
plt.show()

# Print classification report
print("\nClassification Report:")
print(classification_report(
    y_test, y_pred,
    labels=all_labels,
    target_names=[f'BCS {i}' for i in range(1, 10)],
    zero_division=0
))

# Plot training history
if 'history' in locals():
    plt.figure(figsize=(12, 4))
    
    plt.subplot(1, 2, 1)
    plt.plot(history.history['loss'], label='Training Loss')
    plt.plot(history.history['val_loss'], label='Validation Loss')
    plt.title('Model Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    
    plt.subplot(1, 2, 2)
    plt.plot(history.history['accuracy'], label='Training Accuracy')
    plt.plot(history.history['val_accuracy'], label='Validation Accuracy')
    plt.title('Model Accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.legend()
    
    plt.tight_layout()
    plt.show()

Total entries: 34
Entry 1: Missing image file: bcs_dataset/IMG_9099 - PANNAWAT SINTHUBTHONG.jpeg
Entry 2: Missing image file: bcs_dataset/IMG_8980 - Patchareenat N.jpeg
Entry 3: Missing view 0
Entry 4: Missing image file: bcs_dataset/IMG_2570 - Pasavit TAPEN.jpeg
Entry 5: Missing image file: bcs_dataset/IMG_0577 - Pakkaphan SAMAKWONGPANICH.jpeg
Entry 6: Missing image file: bcs_dataset/IMG_2718 - Pasavit TAPEN.jpeg
Entry 7: Missing image file: bcs_dataset/IMG_20250726_082932 - Khomson Satchasataporn.jpg
Entry 8: Missing image file: bcs_dataset/IMG_20250726_083110 - Khomson Satchasataporn.jpg
Entry 9: Missing image file: bcs_dataset/IMG_20250726_083323 - Khomson Satchasataporn.jpg
Entry 10: Missing image file: bcs_dataset/IMG_20250726_084003 - Khomson Satchasataporn.jpg
Entry 11: Missing image file: bcs_dataset/IMG_20250726_084722 - Khomson Satchasataporn.jpg
Entry 12: Missing image file: bcs_dataset/IMG_20250726_084915 - Khomson Satchasataporn.jpg
Entry 13: Missing image file: bcs_datas

ValueError: zero-size array to reduction operation minimum which has no identity