# CNN Training for Evolved Epithelial Modulation Detection
In this notebook, we'll demonstrate how to build, train, and evaluate a Convolutional Neural Network (CNN) for detecting epithelial modulation. We'll cover:
- Building the CNN model
- Training the model
- Evaluating the model
- Visualizing training performance

In [None]:
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import matplotlib.pyplot as plt


## Building the CNN Model
We'll start by defining a CNN architecture that is suitable for image classification tasks.

In [None]:
def build_cnn(input_shape=(128, 128, 3), num_classes=2):
    """Build and compile a three-stage convolutional image classifier.

    The network stacks three Conv2D + MaxPooling2D stages (32, 64 and 128
    filters, 3x3 kernels, 2x2 pooling), flattens the feature maps and
    classifies them with a 128-unit dense layer, 50% dropout and a softmax
    output layer.

    Args:
        input_shape: Shape of a single input image. Defaults to
            ``(128, 128, 3)``.
        num_classes: Number of units in the softmax output layer.

    Returns:
        The Sequential model, compiled with the Adam optimizer, categorical
        cross-entropy loss and the accuracy metric.
    """
    model = Sequential([
        # Declare the input explicitly: passing `input_shape=` to the first
        # layer is deprecated in recent Keras releases and emits a warning.
        tf.keras.Input(shape=input_shape),
        Conv2D(32, (3, 3), activation='relu'),
        MaxPooling2D(pool_size=(2, 2)),
        Conv2D(64, (3, 3), activation='relu'),
        MaxPooling2D(pool_size=(2, 2)),
        Conv2D(128, (3, 3), activation='relu'),
        MaxPooling2D(pool_size=(2, 2)),
        Flatten(),
        Dense(128, activation='relu'),
        Dropout(0.5),
        Dense(num_classes, activation='softmax')
    ])
    model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
    return model

# Build the model with build_cnn's default arguments (input shape
# (128, 128, 3), 2 output classes) and print its layer-by-layer summary.
model = build_cnn()
model.summary()

## Training the CNN Model
Next, we train the CNN model using an image dataset. We will use data augmentation to increase the variability of the training data.

In [None]:
def train_model(model, train_dir, val_dir, batch_size=32, epochs=50):
    """Train ``model`` on images read from directories, with augmentation.

    Training images are rescaled to [0, 1] and randomly rotated, zoomed,
    shifted, sheared and horizontally flipped. Validation images are only
    rescaled.

    Args:
        model: A compiled Keras model to fit in place.
        train_dir: Directory passed to ``flow_from_directory`` for training.
        val_dir: Directory passed to ``flow_from_directory`` for validation.
        batch_size: Number of images per batch for both generators.
        epochs: Number of passes over the training generator.

    Returns:
        The ``History`` object returned by ``model.fit``.
    """
    # Resize images to the spatial size the model expects instead of a
    # hard-coded 128x128, so models built with another input shape also work.
    # Falls back to 128x128 when the model exposes no usable 4-D input shape.
    target_size = (128, 128)
    expected = getattr(model, 'input_shape', None)
    if isinstance(expected, tuple) and len(expected) == 4:
        channels_last = tf.keras.backend.image_data_format() == 'channels_last'
        spatial = expected[1:3] if channels_last else expected[2:4]
        if all(isinstance(dim, int) for dim in spatial):
            target_size = tuple(spatial)

    train_datagen = ImageDataGenerator(rescale=1./255, rotation_range=20, zoom_range=0.15,
                                       width_shift_range=0.2, height_shift_range=0.2, shear_range=0.15,
                                       horizontal_flip=True, fill_mode="nearest")

    # No augmentation for validation: only the same rescaling as training.
    val_datagen = ImageDataGenerator(rescale=1./255)

    train_generator = train_datagen.flow_from_directory(train_dir, target_size=target_size,
                                                        batch_size=batch_size, class_mode='categorical')

    val_generator = val_datagen.flow_from_directory(val_dir, target_size=target_size,
                                                    batch_size=batch_size, class_mode='categorical')

    history = model.fit(train_generator, epochs=epochs, validation_data=val_generator)
    return history

# Specify the training and validation directories.
# NOTE: these look like placeholder paths; point them at the real dataset
# directories before running this cell.
train_dir = 'path/to/train_data'
val_dir = 'path/to/val_data'

# Train the model with train_model's defaults (batch_size=32, epochs=50) and
# keep the returned history for the plotting cell further down.
history = train_model(model, train_dir, val_dir)

## Evaluating the CNN Model
Once the model is trained, we evaluate its performance on a test dataset.

In [None]:
def evaluate_model(model, test_dir, batch_size=32):
    """Evaluate ``model`` on images read from ``test_dir``.

    Test images are only rescaled to [0, 1]; no augmentation is applied.

    Args:
        model: A compiled, trained Keras model.
        test_dir: Directory passed to ``flow_from_directory`` for testing.
        batch_size: Number of images per evaluation batch.

    Returns:
        A dict mapping each metric name (the loss plus the compiled
        metrics) to its value on the test data.
    """
    # Resize images to the spatial size the model expects instead of a
    # hard-coded 128x128. Falls back to 128x128 when the model exposes no
    # usable 4-D input shape.
    target_size = (128, 128)
    expected = getattr(model, 'input_shape', None)
    if isinstance(expected, tuple) and len(expected) == 4:
        channels_last = tf.keras.backend.image_data_format() == 'channels_last'
        spatial = expected[1:3] if channels_last else expected[2:4]
        if all(isinstance(dim, int) for dim in spatial):
            target_size = tuple(spatial)

    test_datagen = ImageDataGenerator(rescale=1./255)
    test_generator = test_datagen.flow_from_directory(test_dir, target_size=target_size,
                                                    batch_size=batch_size, class_mode='categorical')
    # Let Keras label the results itself: zipping `model.metrics_names` with
    # the result list mislabels values on Keras versions where that attribute
    # no longer lists the individual compiled metrics.
    return model.evaluate(test_generator, return_dict=True)

# Evaluate the model on the test data and print the resulting metrics dict.
# NOTE: this looks like a placeholder path; replace it with the real test
# directory before running this cell.
test_dir = 'path/to/test_data'
evaluation_results = evaluate_model(model, test_dir)
print(evaluation_results)

## Visualizing Training Performance
We plot the training and validation accuracy and loss for each epoch to observe how the model performed over time.

In [None]:
def visualize_training_performance(history):
    """Plot training and validation accuracy and loss side by side.

    Args:
        history: Object whose ``history`` mapping holds the per-epoch
            'accuracy', 'val_accuracy', 'loss' and 'val_loss' series, such
            as the value returned by ``model.fit``.
    """
    # One entry per subplot: (train key, validation key, title, y-axis label).
    panels = (
        ('accuracy', 'val_accuracy', 'Model Accuracy', 'Accuracy'),
        ('loss', 'val_loss', 'Model Loss', 'Loss'),
    )
    curves = history.history

    plt.figure(figsize=(12, 5))
    for position, (train_key, val_key, title, y_label) in enumerate(panels, start=1):
        # Left panel (1) shows accuracy, right panel (2) shows loss.
        plt.subplot(1, 2, position)
        plt.plot(curves[train_key], label='Train')
        plt.plot(curves[val_key], label='Validation')
        plt.title(title)
        plt.xlabel('Epoch')
        plt.ylabel(y_label)
        plt.legend()

    plt.show()

# Visualize the training performance recorded in `history` by the training
# cell (accuracy and loss curves for the training and validation data).
visualize_training_performance(history)