In [6]:
import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.layers import GlobalAveragePooling2D
from tensorflow.keras.models import Sequential
import numpy as np

In [7]:
# Load dataset
def load_and_preprocess_data():
    ds, info = tfds.load('malaria', split='train', with_info=True, as_supervised=True)
    
    # Preprocess images
    def preprocess_images(image, label):
        image = tf.image.resize(image, (128, 128))
        image = tf.keras.applications.mobilenet_v2.preprocess_input(image)
        return image, tf.cast(label, tf.int32)

    ds = ds.map(preprocess_images).batch(32)
    
    # Feature extraction with MobileNetV2
    base_model = MobileNetV2(include_top=False, weights='imagenet', input_shape=(128, 128, 3))
    base_model.trainable = False
    feature_extractor = Sequential([base_model, GlobalAveragePooling2D()])

    features, labels = [], []
    for images, label in ds:
        features_batch = feature_extractor(images)
        features.append(features_batch)
        labels.append(label)

    features = np.vstack(features)
    labels = np.concatenate(labels)
    
    return features, labels

features, labels = load_and_preprocess_data()


In [8]:
from spektral.data import Graph

def create_graph(features, labels):
    num_nodes = len(labels)
    adjacency = np.random.randint(0, 2, (num_nodes, num_nodes))
    adjacency = np.maximum(adjacency, adjacency.T)
    np.fill_diagonal(adjacency, 0)
    return Graph(x=features, a=adjacency, y=labels.reshape(-1, 1))

graph = create_graph(features, labels)


In [9]:
from sklearn.model_selection import train_test_split

def create_subgraph(graph, indices):
    x_sub = graph.x[indices]
    y_sub = graph.y[indices]
    a_sub = graph.a[indices][:, indices]
    return Graph(x=x_sub, a=a_sub, y=y_sub)

def split_graph(graph):
    indices = np.arange(graph.n_nodes)
    train_indices, test_indices = train_test_split(indices, test_size=0.2, random_state=42)
    train_indices, val_indices = train_test_split(train_indices, test_size=0.25, random_state=42)
    
    train_graph = create_subgraph(graph, train_indices)
    val_graph = create_subgraph(graph, val_indices)
    test_graph = create_subgraph(graph, test_indices)

    return train_graph, val_graph, test_graph

train_graph, val_graph, test_graph = split_graph(graph)


In [10]:
from spektral.layers import GCNConv
from tensorflow.keras.layers import Input, Dense, Dropout
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from spektral.data import SingleLoader

def create_gnn_model():
    inputs = Input(shape=(features.shape[1],))
    adj_input = Input(shape=(None,))
    x = GCNConv(32, activation='relu')([inputs, adj_input])
    x = Dropout(0.5)(x)
    x = GCNConv(16, activation='relu')([x, adj_input])
    x = Dropout(0.5)(x)
    x = Dense(1, activation='sigmoid')(x)
    return Model(inputs=[inputs, adj_input], outputs=x)

model = create_gnn_model()
model.compile(optimizer=Adam(learning_rate=0.01), loss='binary_crossentropy', metrics=['accuracy'])

train_loader = SingleLoader(train_graph)
val_loader = SingleLoader(val_graph)

model.fit(train_loader.load(), steps_per_epoch=train_loader.steps_per_epoch,
          validation_data=val_loader.load(), validation_steps=val_loader.steps_per_epoch,
          epochs=10)




TypeError: object of type 'Graph' has no len()

In [None]:
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

# Making predictions with the model
val_images, val_labels = next(iter(val_ds.unbatch().batch(len(ds_validation))))
val_predictions = model.predict(val_images)
val_pred_classes = (val_predictions > 0.5).astype(int)  # Binary classification threshold

# Computing the confusion matrix
cm = confusion_matrix(val_labels, val_pred_classes)
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=ds_info.features['label'].names)

# Plotting the confusion matrix
plt.figure(figsize=(10,10))
disp.plot(cmap=plt.cm.Blues)
plt.show()