# In [None]:
import os

# Force TensorFlow to use CPU only (to avoid GPU issues).
# This must happen BEFORE TensorFlow is imported: CUDA reads
# CUDA_VISIBLE_DEVICES when it initialises, and changing the variable after
# the import is not guaranteed to take effect.
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.callbacks import LearningRateScheduler, EarlyStopping
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
from tensorflow.keras.models import Model
import glob as gb
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

print("Using CPU only mode")

# Configuration
IMG_SIZE = 224          # square input resolution fed to the model
SEED = 1000             # seed passed to the directory iterators
BATCH_SIZE = 16
TRAIN_DIR = '/media/rogerthattan/Toshiba/Thesis/Dataset/Train_Data'
TEST_DIR = '/media/rogerthattan/Toshiba/Thesis/Dataset/Test_Data'

# File extensions ImageDataGenerator.flow_from_directory loads (it matches them
# case-insensitively), so the counts below agree with what the generators see.
IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp', '.ppm', '.tif', '.tiff')


def _count_images(class_dir):
    """Return how many image files live under *class_dir*.

    Sub-directories are walked too and extensions are compared in lower case,
    mirroring how flow_from_directory discovers the images of one class.
    """
    return sum(
        1
        for _root, _dirs, files in os.walk(class_dir)
        for fname in files
        if fname.lower().endswith(IMAGE_EXTENSIONS)
    )


# Analyze dataset: each sub-directory of TRAIN_DIR is one class. Only real
# directories count (stray files are ignored) and they are sorted, which is
# the same order flow_from_directory uses to assign label indices.
categories = sorted(
    entry for entry in os.listdir(TRAIN_DIR)
    if os.path.isdir(os.path.join(TRAIN_DIR, entry))
)
class_count = [_count_images(os.path.join(TRAIN_DIR, c)) for c in categories]
train_exm = sum(class_count)

# Plot class distribution
sns.barplot(x=categories, y=class_count).set_title("Distribution of train data")
plt.show()
print(f"Total training examples: {train_exm}")

# Data generators.
# ResNet50V2's ImageNet weights were trained on inputs scaled to [-1, 1] by
# resnet_v2.preprocess_input; a plain 1/255 rescale would feed the frozen
# backbone [0, 1] inputs it never saw during pre-training.
preprocess = tf.keras.applications.resnet_v2.preprocess_input

# Training pipeline: random augmentation + the backbone's preprocessing.
# Per the Keras docs, preprocessing_function runs after resize/augmentation.
train_gen = ImageDataGenerator(
    rotation_range=20,
    width_shift_range=0.2,
    height_shift_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    brightness_range=[0.8, 1.2],
    shear_range=0.1,
    validation_split=0.2,
    preprocessing_function=preprocess,
)

# Validation pipeline: the SAME validation_split (so subset='validation'
# selects the same held-out files that subset='training' excludes) but no
# augmentation, so validation metrics are measured on unmodified images.
valid_gen = ImageDataGenerator(
    validation_split=0.2,
    preprocessing_function=preprocess,
)

test_gen = ImageDataGenerator(
    preprocessing_function=preprocess,
)

train_batch = train_gen.flow_from_directory(
    directory=TRAIN_DIR,
    target_size=(IMG_SIZE, IMG_SIZE),
    batch_size=BATCH_SIZE,
    class_mode='sparse',
    subset='training',
    seed=SEED
)

valid_batch = valid_gen.flow_from_directory(
    directory=TRAIN_DIR,
    target_size=(IMG_SIZE, IMG_SIZE),
    batch_size=BATCH_SIZE,
    class_mode='sparse',
    subset='validation',
    shuffle=False,  # order does not affect metrics; keeps evaluation reproducible
    seed=SEED
)

test_batch = test_gen.flow_from_directory(
    directory=TEST_DIR,
    target_size=(IMG_SIZE, IMG_SIZE),
    batch_size=BATCH_SIZE,
    class_mode='sparse',
    shuffle=False,  # keep predictions aligned with test_batch.classes / .filenames
    seed=SEED
)

# Calculate steps per epoch (len() of an iterator = number of batches)
steps_per_epoch = len(train_batch)
validation_steps = len(valid_batch)

print(f"Steps per epoch: {steps_per_epoch}")
print(f"Validation steps: {validation_steps}")

# Build a transfer-learning model: frozen ResNet50V2 backbone + small new head.
img_shape = (IMG_SIZE, IMG_SIZE, 3)
# Take the class count from the training iterator itself so the output layer
# always matches the label indices it yields; a raw directory listing can
# include stray entries and disagree with it.
num_classes = train_batch.num_classes

# ResNet50V2 backbone with ImageNet weights, without its 1000-class top.
base_model = tf.keras.applications.ResNet50V2(
    input_shape=img_shape,
    include_top=False,
    weights='imagenet'
)

# Freeze the backbone so only the new head is trained.
base_model.trainable = False

# Add custom classification head
x = base_model.output
x = GlobalAveragePooling2D()(x)
x = Dense(256, activation='relu')(x)
predictions = Dense(num_classes, activation='softmax')(x)

# Create the complete model
model = Model(inputs=base_model.input, outputs=predictions)
print("Model created successfully")

# Labels are integer class indices (class_mode='sparse'), hence the sparse
# categorical cross-entropy loss.
model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)

# Stop after 5 epochs without improvement of EarlyStopping's default monitor
# (val_loss) and roll back to the best weights seen.
callbacks = [
    EarlyStopping(patience=5, restore_best_weights=True)
]

import traceback


def _plot_history(history):
    """Show side-by-side accuracy and loss curves for a finished fit.

    Args:
        history: the ``History`` object returned by ``Model.fit`` run with
            validation data and ``metrics=['accuracy']``, so its ``history``
            dict holds ``accuracy``, ``val_accuracy``, ``loss`` and
            ``val_loss`` lists.
    """
    plt.figure(figsize=(20, 10))
    panels = (
        (1, 'accuracy', 'Accuracy', 'lower right'),
        (2, 'loss', 'Loss', 'upper right'),
    )
    for position, metric, ylabel, legend_loc in panels:
        plt.subplot(1, 2, position)
        plt.plot(history.history[metric], 'o-', label=f'train {metric}')
        plt.plot(history.history[f'val_{metric}'], 'o-',
                 label=f'validation {metric}')
        plt.xlabel('Epoch')
        plt.ylabel(ylabel)
        plt.grid(True)
        plt.legend(loc=legend_loc)
    plt.show()


# The model that actually finished training and the file stem it is saved
# under. These are switched to the fallback CNN if the ResNet run fails, so the
# saved file always contains the model that was trained.
trained_model = model
save_stem = 'resnet_model'

# Only training lives inside the try; plotting happens afterwards so that a
# plotting error cannot trigger a full retrain with a different model.
try:
    # Smoke test: one training step on a single batch surfaces shape/dtype/
    # memory problems quickly before committing to the full run.
    print("Testing model with a small batch...")
    x_sample, y_sample = next(iter(train_batch))
    model.fit(x_sample, y_sample, epochs=1, verbose=1)
    print("Small batch test successful, proceeding with full training")

    # Train the model
    h = model.fit(
        train_batch,
        steps_per_epoch=steps_per_epoch,
        validation_data=valid_batch,
        validation_steps=validation_steps,
        epochs=50,
        callbacks=callbacks,
        verbose=1
    )
except Exception as e:
    print(f"Error during training: {e}")
    traceback.print_exc()  # keep the full traceback for diagnosis

    # Fallback: train a small CNN from scratch on the same generators.
    print("\nTrying alternative approach with a simpler CNN...")

    from tensorflow.keras.models import Sequential
    from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout

    simple_model = Sequential([
        Conv2D(32, (3, 3), activation='relu', input_shape=img_shape),
        MaxPooling2D(2, 2),
        Conv2D(64, (3, 3), activation='relu'),
        MaxPooling2D(2, 2),
        Conv2D(128, (3, 3), activation='relu'),
        MaxPooling2D(2, 2),
        Flatten(),
        Dense(512, activation='relu'),
        Dropout(0.5),
        Dense(num_classes, activation='softmax')
    ])

    simple_model.compile(
        optimizer='adam',
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy']
    )

    h = simple_model.fit(
        train_batch,
        steps_per_epoch=steps_per_epoch,
        validation_data=valid_batch,
        validation_steps=validation_steps,
        epochs=20,
        callbacks=[EarlyStopping(patience=3, restore_best_weights=True)],
        verbose=1
    )

    trained_model = simple_model
    save_stem = 'simple_cnn_model'

# Plot training history of whichever model was trained.
_plot_history(h)

# Save the trained model: HDF5 first, then the extension-less path.
# NOTE(review): on TF 2.x / Keras 2 the extension-less path writes a
# SavedModel directory; Keras 3 rejects such paths in model.save (it expects
# .keras/.h5), so confirm against the installed version.
try:
    trained_model.save(f'{save_stem}.h5')
    print("Model saved successfully")
except Exception as e:
    print(f"Error saving model: {e}")
    try:
        # Try saving in TensorFlow SavedModel format instead
        trained_model.save(save_stem)
        print("Model saved in SavedModel format")
    except Exception as e2:
        print(f"Error saving in SavedModel format: {e2}")