In [2]:
import torch
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import Dataset, DataLoader
import numpy as np
from scipy.spatial import distance_matrix

from src.gnn import GNN

In [5]:
from src.conv import GNNLayer

# GNN model definition for graph-level classification
class GNNClassifier(torch.nn.Module):
    """Graph classifier: stacked GNN layers, global pooling, then a linear head.

    Args:
        in_features: Size of each node's input feature vector.
        embedding_dim: Width of the intermediate (hidden) node embeddings.
        out_features: Width of the node embeddings produced by the last GNN
            layer; also the input size of the linear classification head.
        num_layers: Number of GNN layers (must be >= 1). A single layer maps
            ``in_features`` directly to ``out_features``.
        conv_type: Forwarded to every ``GNNLayer``.
        pool_type: Forwarded to ``PoolingLayer``.
        num_classes: Number of output logits (10 for MNIST).
    """

    def __init__(self, in_features, embedding_dim, out_features, num_layers=2, conv_type="gcn", pool_type="mean", num_classes=10):
        super().__init__()
        if num_layers < 1:
            raise ValueError(f"num_layers must be >= 1, got {num_layers}")
        self.num_layers = num_layers

        # Layer widths: in -> embedding -> ... -> embedding -> out
        # (exactly num_layers layers; one layer means in -> out).
        dims = [in_features] + [embedding_dim] * (num_layers - 1) + [out_features]
        self.layers = torch.nn.ModuleList(
            GNNLayer(d_in, d_out, conv_type=conv_type)
            for d_in, d_out in zip(dims[:-1], dims[1:])
        )

        # NOTE(review): PoolingLayer is not imported in any visible cell; confirm its module.
        self.pooling = PoolingLayer(pool_type=pool_type)
        self.final_layer = torch.nn.Linear(out_features, num_classes)

    def forward(self, graphs):
        """Classify a batch of graphs.

        Args:
            graphs: Iterable of graph objects exposing ``G.nodes()``,
                ``G.nodes[n]['feature']`` and ``G.edges()`` (presumably
                networkx graphs — confirm against the dataset).

        Returns:
            Logits stacked with one row per graph, ``num_classes`` wide
            (assumes ``PoolingLayer`` returns a 1-D embedding per graph — TODO confirm).
        """
        # Build input tensors on the same device as the model's parameters.
        device = self.final_layer.weight.device
        batch_embeddings = []
        for G in graphs:
            h, edge_index = self._graph_to_tensors(G, device)
            for layer in self.layers:
                h_neighbors = self.aggregate_neighbors(h, edge_index)
                h = layer(h, h_neighbors)
            batch_embeddings.append(self.pooling(h))  # global pooling -> one embedding per graph
        return self.final_layer(torch.stack(batch_embeddings))

    @staticmethod
    def _graph_to_tensors(G, device=None):
        """Convert one graph into ``(x, edge_index)`` tensors.

        Node labels are mapped to their position in ``G.nodes()`` order, so
        non-integer labels work. If the graph reports ``is_directed() == False``,
        every non-self-loop edge is also added reversed so information flows
        both ways.

        Returns:
            x: float tensor of shape (N, F); scalar features give F == 1.
            edge_index: long tensor of shape (2, E); E may be 0.
        """
        nodes = list(G.nodes())
        position = {n: i for i, n in enumerate(nodes)}
        x = torch.tensor([G.nodes[n]['feature'] for n in nodes], dtype=torch.float, device=device)
        x = x.reshape(len(nodes), -1)

        pairs = [(position[u], position[v]) for u, v in G.edges()]
        # reshape(-1, 2) keeps a (2, 0) shape when the graph has no edges
        edge_index = torch.tensor(pairs, dtype=torch.long, device=device).reshape(-1, 2).T
        if hasattr(G, "is_directed") and not G.is_directed():
            not_loop = edge_index[0] != edge_index[1]
            edge_index = torch.cat([edge_index, edge_index.flip(0)[:, not_loop]], dim=1)
        return x, edge_index

    def aggregate_neighbors(self, h, edge_index):
        """Sum neighbour embeddings for every node.

        Row i of the result is the sum of ``h[j]`` over all edges ``(i, j)`` in
        ``edge_index``. This equals ``A @ h`` for an adjacency matrix with
        ``A[i, j] = 1``, without materialising ``A``. Nodes with no such edge
        get zeros.

        Args:
            h: float tensor of shape (N, F).
            edge_index: long tensor of shape (2, E) holding node positions.

        Returns:
            Tensor of shape (N, F) with the same dtype and device as ``h``.
        """
        src, dst = edge_index
        # index_add accumulates repeated src indices, i.e. sums over neighbours
        return torch.zeros_like(h).index_add(0, src, h[dst])

In [6]:
# Training routine for graph classification
def train_gnn(model, dataset, epochs=10, lr=0.001, batch_size=32):
    """Train a graph classifier on a dataset of (graph, label) pairs.

    Uses Adam and cross-entropy, and prints the mean loss after each epoch.

    Args:
        model: Module mapping a sequence of graphs to logits of shape
            (batch, num_classes).
        dataset: Indexable dataset yielding ``(graph, label)`` pairs.
        epochs: Number of passes over the dataset.
        lr: Adam learning rate.
        batch_size: Number of graphs per optimisation step.

    Returns:
        List with the mean training loss of each epoch.
    """
    # Graphs cannot be stacked into one tensor, so a batch is collated as
    # [tuple_of_graphs, tuple_of_labels] instead of using default_collate.
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=lambda x: list(zip(*x)))
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = torch.nn.CrossEntropyLoss()
    # Labels must live on the same device as the model's logits.
    device = next(model.parameters()).device

    # Training-mode behaviour (dropout, batch norm) even if eval() was called earlier.
    model.train()
    epoch_losses = []
    for epoch in range(epochs):
        total_loss = 0.0
        for graphs, labels in dataloader:
            labels = torch.tensor(labels, dtype=torch.long, device=device)
            optimizer.zero_grad()
            loss = criterion(model(graphs), labels)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        mean_loss = total_loss / len(dataloader)
        epoch_losses.append(mean_loss)
        print(f"Epoch {epoch+1}/{epochs}, Loss: {mean_loss:.4f}")
    return epoch_losses

In [7]:
# Example usage: build the MNIST graph dataset.
# NOTE(review): MNISTGraphDataset is not defined or imported in any visible cell; confirm where it comes from.
# NOTE(review): the saved run of this cell ended with a KeyboardInterrupt, so construction may be slow.
dataset = MNISTGraphDataset()
# dataset.visualize_graph(0)  # Visualize one MNIST graph

KeyboardInterrupt: 

In [None]:
# Model dimensions
in_features = 1  # each pixel (node) carries a single intensity value
embedding_dim = 32  # width of the hidden node embeddings
out_features = 64  # width of the last GNN layer, fed to the classification head
num_layers = 3  # number of GNN layers

# Seed torch's global RNG so weight initialisation and DataLoader shuffling are reproducible.
torch.manual_seed(0)

# Build the model
model = GNNClassifier(in_features, embedding_dim, out_features, num_layers=num_layers, conv_type="gcn", pool_type="mean")

# Inspect the architecture (bare last expression is displayed by Jupyter)
model

In [None]:
# Train with train_gnn's default hyperparameters; the trailing ";" suppresses echoing any return value.
train_gnn(model=model, dataset=dataset);