# In [None]:
import os
import numpy as np
import cv2
from skimage.feature import graycomatrix, graycoprops
import tensorflow as tf
from tensorflow.keras.applications.resnet50 import preprocess_input, ResNet50
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D, Dropout, Input, Concatenate, BatchNormalization
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ReduceLROnPlateau, EarlyStopping
import matplotlib.pyplot as plt
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# Dataset location (Google Drive mount) and training hyper-parameters.
base_dir = "/content/drive/MyDrive/dataset"
train_dir, val_dir = (os.path.join(base_dir, split) for split in ("train", "val"))
image_size = (224, 224)
batch_size, epochs = 32, 50

# Per-label (lower, upper) HSV bounds, written on OpenCV's 8-bit HSV scale
# (hue up to 180, saturation/value up to 255). Only 'ulcer', 'bleeding' and
# 'erythema' are read by the feature extractor in this cell; 'normal' spans
# the whole HSV range.
color_ranges = {
    label: (np.array(lower), np.array(upper))
    for label, lower, upper in (
        ('ulcer', [0, 0, 200], [50, 50, 255]),
        ('bleeding', [0, 50, 50], [10, 255, 255]),
        ('erythema', [160, 50, 50], [180, 255, 255]),
        ('foreign body', [25, 52, 72], [102, 255, 255]),
        ('lymphangiectasia', [20, 40, 150], [80, 255, 255]),
        ('polyp', [0, 100, 100], [10, 255, 255]),
        ('angioectasia', [160, 100, 50], [180, 255, 255]),
        ('erosion', [0, 50, 100], [50, 255, 255]),
        ('worms', [50, 100, 100], [80, 255, 255]),
        ('normal', [0, 0, 0], [180, 255, 255]),
    )
}

def color_detection(img, color_range):
    """Keep only the pixels of ``img`` whose HSV color lies inside ``color_range``.

    Args:
        img: RGB image, either ``uint8`` (0-255) or floating point. Floating-point
            input is assumed to be scaled to [0, 1], as ``preprocess_image`` does.
        color_range: ``(lower, upper)`` HSV bounds on OpenCV's 8-bit scale
            (H 0-180, S and V 0-255), like the entries of ``color_ranges``.

    Returns:
        Array with the same shape and dtype as ``img``: original pixel values
        where the color matched, zero everywhere else.
    """
    # The bounds are on the 8-bit HSV scale, but cvtColor on float input returns
    # H in [0, 360] and S, V in [0, 1], which would make most masks empty or
    # wrong. Threshold an 8-bit copy instead.
    if img.dtype == np.uint8:
        img_u8 = img
    else:
        img_u8 = np.clip(img * 255.0, 0, 255).astype(np.uint8)
    hsv_img = cv2.cvtColor(img_u8, cv2.COLOR_RGB2HSV)
    mask = cv2.inRange(hsv_img, color_range[0], color_range[1])
    # Mask the original image so the output keeps the caller's dtype and scale.
    return cv2.bitwise_and(img, img, mask=mask)

def contour_detection(img, threshold=150):
    """Count the contours OpenCV finds in a binary-thresholded grayscale copy of ``img``.

    Args:
        img: RGB image, either ``uint8`` (0-255) or floating point assumed to be
            scaled to [0, 1] (what ``preprocess_image`` passes in).
        threshold: gray level on the 0-255 scale; pixels strictly brighter than
            it become foreground. Defaults to 150.

    Returns:
        Number of contours, nested ones included (``cv2.RETR_TREE``).
    """
    gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
    # Only float input needs rescaling; multiplying uint8 data by 255 would
    # silently wrap around.
    if gray.dtype != np.uint8:
        gray = np.clip(gray * 255.0, 0, 255).astype(np.uint8)
    _, thresh = cv2.threshold(gray, threshold, 255, cv2.THRESH_BINARY)
    # [-2] is the contour list under both OpenCV 3 (3-tuple) and OpenCV 4
    # (2-tuple) return conventions.
    contours = cv2.findContours(thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)[-2]
    return len(contours)

def texture_analysis(img):
    """Describe the texture of ``img`` with three gray-level co-occurrence statistics.

    The co-occurrence matrix uses horizontally adjacent pixels only
    (distance 1, angle 0), 256 gray levels, symmetric and normalized.

    Args:
        img: RGB image, either ``uint8`` (0-255) or floating point assumed to be
            scaled to [0, 1] (what ``preprocess_image`` passes in).

    Returns:
        Three-element array ``[contrast, dissimilarity, homogeneity]``.
    """
    gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
    # graycomatrix needs integer levels in [0, 255]. Only float input is
    # rescaled, because uint8 * 255 would wrap around.
    if gray.dtype != np.uint8:
        gray = np.clip(gray * 255.0, 0, 255).astype(np.uint8)
    glcm = graycomatrix(gray, distances=[1], angles=[0], levels=256, symmetric=True, normed=True)
    return np.array([
        graycoprops(glcm, prop)[0, 0]
        for prop in ('contrast', 'dissimilarity', 'homogeneity')
    ])

def preprocess_image(img_path, label):
    """Load one image file and build both model inputs for it.

    Runs eagerly (it calls ``.numpy()``), so it is meant to be wrapped in
    ``tf.py_function``.

    Args:
        img_path: scalar string tensor holding the path of an image file.
        label: class index; returned unchanged.

    Returns:
        ``(image, extra_features, label)`` where

        * ``image`` is a float32 array of shape ``image_size + (3,)``, passed
          through ResNet50's ``preprocess_input`` (the input normalization the
          ImageNet weights expect);
        * ``extra_features`` is a float32 vector of 13 handcrafted features.
          These are the per-channel means of the ulcer, bleeding and erythema
          color masks (3 x 3), the contour count (1), and GLCM contrast,
          dissimilarity and homogeneity (3). All are computed on the image
          scaled to [0, 1].
    """
    raw = tf.io.read_file(img_path)
    # decode_image also handles PNG/BMP/GIF; expand_animations=False keeps it 3-D.
    decoded = tf.io.decode_image(raw, channels=3, expand_animations=False)
    resized = tf.image.resize(decoded, image_size).numpy()  # float32 on a 0-255 scale
    scaled = resized / 255.0  # [0, 1] copy for the handcrafted extractors

    mask_means = [
        np.mean(color_detection(scaled, color_ranges[name]), axis=(0, 1))
        for name in ('ulcer', 'bleeding', 'erythema')
    ]
    extra_features = np.concatenate(
        mask_means + [[contour_detection(scaled)], texture_analysis(scaled)]
    ).astype(np.float32)

    # preprocess_input may write into its argument, so give it a private copy.
    image = preprocess_input(resized.copy()).astype(np.float32)
    return image, extra_features, label

def ensure_shape(image, extra_features, label):
    """Pin static shapes on one example and regroup it for the two-input model.

    Outputs of ``tf.py_function`` carry no static shape. This step asserts and
    records the image, feature-vector and scalar-label shapes, then nests the
    two inputs as ``((image, extra_features), label)`` for ``Model.fit``.
    """
    checked_inputs = (
        tf.ensure_shape(image, (224, 224, 3)),
        tf.ensure_shape(extra_features, (13,)),
    )
    checked_label = tf.ensure_shape(label, ())
    return checked_inputs, checked_label

def create_dataset(directory, shuffle=True, seed=None):
    """Build a batched ``tf.data`` pipeline from a one-subfolder-per-class tree.

    Every visible subdirectory of ``directory`` is a class. Classes are indexed
    in sorted-name order, so splits with identical subfolder names share one
    label mapping.

    Args:
        directory: root folder whose subfolders contain each class's images.
        shuffle: shuffle the (path, label) pairs and reshuffle every epoch.
            Without it examples arrive grouped by class, so batches are
            (nearly) single-class, which typically degrades training.
        seed: optional seed for the shuffle.

    Returns:
        Prefetched dataset of ``((image, extra_features), label)`` batches of
        ``batch_size`` examples.

    Raises:
        ValueError: if no image files are found under ``directory``.
    """
    image_exts = ('.jpg', '.jpeg', '.png', '.bmp', '.gif')
    # Ignore stray files and hidden folders (e.g. Jupyter's .ipynb_checkpoints),
    # which would otherwise crash os.listdir or become an extra class.
    class_names = sorted(
        name for name in os.listdir(directory)
        if not name.startswith('.') and os.path.isdir(os.path.join(directory, name))
    )

    image_paths = []
    labels = []
    for class_index, class_name in enumerate(class_names):
        class_dir = os.path.join(directory, class_name)
        for img_name in sorted(os.listdir(class_dir)):
            if img_name.lower().endswith(image_exts):
                image_paths.append(os.path.join(class_dir, img_name))
                labels.append(class_index)

    if not image_paths:
        raise ValueError(f"No image files found under {directory!r}")

    dataset = tf.data.Dataset.from_tensor_slices((image_paths, labels))
    if shuffle:
        # Shuffle the cheap path strings, before the expensive decode/feature step.
        dataset = dataset.shuffle(len(image_paths), seed=seed, reshuffle_each_iteration=True)
    dataset = dataset.map(
        lambda path, label: tf.py_function(
            preprocess_image, [path, label], [tf.float32, tf.float32, tf.int32]
        ),
        num_parallel_calls=tf.data.AUTOTUNE,
    )
    dataset = dataset.map(ensure_shape)
    return dataset.batch(batch_size).prefetch(tf.data.AUTOTUNE)

# Input pipelines for the two splits (training first, then validation).
train_dataset = create_dataset(directory=train_dir)
val_dataset = create_dataset(directory=val_dir)

# Random geometric augmentation settings.
# NOTE(review): `train_datagen` is not referenced anywhere else in this cell,
# so none of these augmentations reach `train_dataset`; confirm whether that
# is intended.
train_datagen = ImageDataGenerator(
    horizontal_flip=True,
    rotation_range=20,
    zoom_range=0.2,
    shear_range=0.2,
    width_shift_range=0.2,
    height_shift_range=0.2,
    fill_mode='nearest',
)

# ImageNet-pretrained ResNet50 backbone without its classification head.
base_model = ResNet50(weights='imagenet', include_top=False, input_shape=image_size + (3,))

# Freeze everything except the last 30 backbone layers, which are fine-tuned.
for layer in base_model.layers[:-30]:
    layer.trainable = False

image_input = Input(shape=image_size + (3,))
extra_input = Input(shape=(13,))  # the 13 handcrafted features from preprocess_image

# training=False keeps every BatchNormalization layer inside the backbone in
# inference mode, including the unfrozen tail. They then use the pretrained
# moving statistics instead of per-batch ones, as recommended for fine-tuning.
x = base_model(image_input, training=False)
x = GlobalAveragePooling2D()(x)
x = Dense(1024, activation='relu')(x)
x = BatchNormalization()(x)
x = Dropout(0.5)(x)

# The handcrafted features mix very different scales: mask means in [0, 1],
# a raw contour count, and an unbounded GLCM contrast. Normalize them before
# merging so the large ones do not dominate.
extra = BatchNormalization()(extra_input)
concatenated = Concatenate()([x, extra])

x = Dense(512, activation='relu')(concatenated)
x = BatchNormalization()(x)
x = Dropout(0.5)(x)
# NOTE(review): assumes exactly 10 class subdirectories under train_dir; keep in sync.
predictions = Dense(10, activation='softmax')(x)

model = Model(inputs=[image_input, extra_input], outputs=predictions)

# Callbacks: halve the learning rate after 3 epochs without val_loss
# improvement (floor 1e-7); stop after 10 and restore the best weights.
lr_scheduler = ReduceLROnPlateau(
    monitor='val_loss',
    factor=0.5,
    patience=3,
    min_lr=1e-7,
)
early_stopping = EarlyStopping(
    monitor='val_loss',
    patience=10,
    restore_best_weights=True,
)

# Small learning rate, since part of the pretrained backbone is being fine-tuned.
optimizer = Adam(learning_rate=1e-4)
model.compile(
    optimizer=optimizer,
    loss='sparse_categorical_crossentropy',  # labels are integer class indices
    metrics=['accuracy'],
)

history = model.fit(
    train_dataset,
    epochs=epochs,
    validation_data=val_dataset,
    callbacks=[lr_scheduler, early_stopping],
)

# Persist the trained network next to the dataset on Drive.
model.save(os.path.join(base_dir, 'advanced_ulcer_classification_model.h5'))

# Learning curves side by side: accuracy (left) and loss (right), train vs. validation.
plt.figure(figsize=(12, 4))
for _panel, (_train_key, _val_key, _train_label, _val_label, _title, _ylabel) in enumerate(
    (
        ('accuracy', 'val_accuracy', 'Train Accuracy', 'Validation Accuracy', 'Model Accuracy', 'Accuracy'),
        ('loss', 'val_loss', 'Train Loss', 'Validation Loss', 'Model Loss', 'Loss'),
    ),
    start=1,
):
    plt.subplot(1, 2, _panel)
    plt.plot(history.history[_train_key], label=_train_label)
    plt.plot(history.history[_val_key], label=_val_label)
    plt.title(_title)
    plt.xlabel('Epoch')
    plt.ylabel(_ylabel)
    plt.legend()

plt.tight_layout()
plt.show()
