In [2]:
import numpy as np
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# Assume genomaps are in a folder: 'genomaps_data' with subfolders for each category
data_dir = 'organized_genomaps'

# Data generator for loading and preprocessing images
datagen = ImageDataGenerator(rescale=1.0/255.0, validation_split=0.2)

# Training and validation data
train_data = datagen.flow_from_directory(data_dir, target_size=(224, 224),
                                         batch_size=32, class_mode='categorical', subset='training')

val_data = datagen.flow_from_directory(data_dir, target_size=(224, 224),
                                       batch_size=32, class_mode='categorical', subset='validation')


Found 8887 images belonging to 10 classes.
Found 2215 images belonging to 10 classes.


In [3]:
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.models import Model

# Load pretrained ResNet50 model
base_model = ResNet50(include_top=False, weights='imagenet', input_shape=(224, 224, 3))
base_model.trainable = False  # Freeze the pretrained layers

# Add pooling and flattening layers for feature extraction
feature_extractor = tf.keras.Sequential([
    base_model,
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Flatten()
])

# Extract features for training and validation data
X_train_features = feature_extractor.predict(train_data, verbose=1)
X_val_features = feature_extractor.predict(val_data, verbose=1)
y_train = train_data.classes
y_val = val_data.classes


Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/resnet/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5
[1m94765736/94765736[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 0us/step


  self._warn_if_super_not_called()


[1m278/278[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m429s[0m 2s/step
[1m70/70[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m97s[0m 1s/step


In [4]:
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score

# Initialize Random Forest Classifier
rf = RandomForestClassifier(n_estimators=100, random_state=42)

# Train the model on deep learning extracted features
rf.fit(X_train_features, y_train)

# Predict on validation features
rf_predictions = rf.predict(X_val_features)

# Accuracy of Random Forest
rf_accuracy = accuracy_score(y_val, rf_predictions)
print(f"Random Forest Validation Accuracy: {rf_accuracy}")


Random Forest Validation Accuracy: 0.09255079006772009


In [5]:
# Predictions from Deep Learning model
dl_predictions = feature_extractor.predict(val_data, verbose=1)

# Convert Deep Learning predictions to class labels
dl_pred_labels = np.argmax(dl_predictions, axis=1)

# Combine predictions using a weighted ensemble
ensemble_predictions = (0.6 * dl_pred_labels + 0.4 * rf_predictions).round().astype(int)


[1m70/70[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m89s[0m 1s/step


In [9]:
import matplotlib.pyplot as plt
import os
import numpy as np

output_dir = 'annotated_genomaps'
os.makedirs(output_dir, exist_ok=True)

# Reverse the mapping of class indices to get the class names
class_indices_reverse = {v: k for k, v in train_data.class_indices.items()}

# Validate and fix predictions
print("Unique predictions:", np.unique(ensemble_predictions))
print("Available class indices:", list(class_indices_reverse.keys()))

# Clip predictions to valid range
valid_classes = list(class_indices_reverse.keys())
ensemble_predictions = np.clip(ensemble_predictions, min(valid_classes), max(valid_classes)).astype(int)

# Annotate and save each genomap
for idx, (image, label) in enumerate(zip(val_data.filepaths, ensemble_predictions)):
    img = plt.imread(image)
    plt.figure(figsize=(6, 6))
    try:
        # Use the reversed mapping to get the class name
        class_name = class_indices_reverse[label]
    except KeyError:
        print(f"Warning: Label {label} not found in class_indices_reverse. Skipping image {idx}.")
        continue
    
    plt.imshow(img)
    plt.title(f'Annotation: {class_name}')
    plt.axis('off')
    plt.savefig(os.path.join(output_dir, f'genomap_{idx + 1}.png'))
    plt.close()


Unique predictions: [61 62 63 64]
Available class indices: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
