This task consists of classifying graphs: it predicts which ‘class’ a whole graph belongs to. A class is a label that indicates the kind of graph. For example, a label may indicate whether one molecule interacts with another.
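To make the idea of a graph-level label concrete, here is a toy sketch in plain Python. It is not Flash's model. Each graph is summarized by two hand-picked statistics (node count and mean degree) and assigned to the class whose average training summary is closest, a nearest-centroid rule. The tuple-based graph representation, the `"sparse"`/`"dense"` labels, and every function name below are hypothetical and chosen only for illustration.

```python
from itertools import combinations
from math import dist

# Hypothetical toy representation: (number of nodes, list of undirected edges).
Graph = tuple[int, list[tuple[int, int]]]


def graph_features(graph: Graph) -> tuple[float, float]:
    """Summarize a whole graph as a fixed-size vector: (node count, mean degree)."""
    num_nodes, edges = graph
    # Each undirected edge adds one to the degree of two nodes.
    mean_degree = 2 * len(edges) / num_nodes if num_nodes else 0.0
    return (float(num_nodes), mean_degree)


def fit_centroids(graphs: list[Graph], labels: list[str]) -> dict[str, tuple[float, ...]]:
    """Average the feature vectors of the training graphs in each class."""
    sums: dict[str, list[float]] = {}
    counts: dict[str, int] = {}
    for graph, label in zip(graphs, labels):
        feats = graph_features(graph)
        acc = sums.setdefault(label, [0.0] * len(feats))
        for i, value in enumerate(feats):
            acc[i] += value
        counts[label] = counts.get(label, 0) + 1
    return {label: tuple(v / counts[label] for v in acc) for label, acc in sums.items()}


def predict(centroids: dict[str, tuple[float, ...]], graph: Graph) -> str:
    """Assign the whole graph (not its individual nodes) to the nearest class centroid."""
    feats = graph_features(graph)
    return min(centroids, key=lambda label: dist(feats, centroids[label]))


def path(n: int) -> Graph:
    """Hypothetical helper: a chain 0-1-2-...-(n-1)."""
    return (n, [(i, i + 1) for i in range(n - 1)])


def complete(n: int) -> Graph:
    """Hypothetical helper: every pair of nodes is connected."""
    return (n, list(combinations(range(n), 2)))


# Train on two sparse and two dense graphs, then label unseen graphs.
centroids = fit_centroids(
    [path(4), path(5), complete(4), complete(5)],
    ["sparse", "sparse", "dense", "dense"],
)
print(predict(centroids, path(6)))      # → sparse
print(predict(centroids, complete(6)))  # → dense
```

The features here are deliberately crude. GraphClassifier instead learns its graph representation with a graph neural network, but the contract is the same: one label per graph.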
The flash.graph.classification.model.GraphClassifier and flash.graph.classification.data.GraphClassificationData classes rely internally on pytorch-geometric.
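pytorch-geometric stores a graph as node features `x`, an `edge_index` of shape `[2, num_edges]` in COO format, and a label `y`. It batches several graphs by stacking them into one large disconnected graph plus a `batch` vector that maps each node to its source graph. The sketch below mimics those conventions with plain Python lists instead of tensors, then averages node features per graph to obtain one vector per graph. pytorch-geometric provides this readout as `torch_geometric.nn.global_mean_pool`. The helper names `collate` and `mean_pool_per_graph` are hypothetical, and nothing here asserts which readout GraphClassifier actually uses.

```python
# Plain-list imitation of pytorch-geometric's Data/Batch conventions (not its API).


def collate(graphs: list[dict]) -> dict:
    """Merge graphs into one disconnected graph, in the style of a PyG Batch."""
    xs, src, dst, batch, ys = [], [], [], [], []
    offset = 0  # number of nodes already placed in the merged graph
    for graph_id, g in enumerate(graphs):
        xs.extend(g["x"])
        # Shift node indices so each edge still points at its own graph's rows.
        src.extend(s + offset for s in g["edge_index"][0])
        dst.extend(t + offset for t in g["edge_index"][1])
        batch.extend([graph_id] * len(g["x"]))
        ys.append(g["y"])
        offset += len(g["x"])
    return {"x": xs, "edge_index": [src, dst], "batch": batch, "y": ys}


def mean_pool_per_graph(x: list[list[float]], batch: list[int], num_graphs: int) -> list[list[float]]:
    """Average node features per graph, giving one fixed-size vector per graph."""
    dim = len(x[0])
    sums = [[0.0] * dim for _ in range(num_graphs)]
    counts = [0] * num_graphs
    for row, graph_id in zip(x, batch):
        for j, value in enumerate(row):
            sums[graph_id][j] += value
        counts[graph_id] += 1
    # max(..., 1) makes a graph with no nodes pool to a zero vector.
    return [[s / max(counts[g], 1) for s in sums[g]] for g in range(num_graphs)]


# Two tiny graphs with 2-dimensional node features.
g1 = {"x": [[1.0, 0.0], [0.0, 1.0]], "edge_index": [[0, 1], [1, 0]], "y": 0}
g2 = {"x": [[2.0, 2.0], [4.0, 0.0], [0.0, 4.0]], "edge_index": [[0, 1, 2], [1, 2, 0]], "y": 1}

merged = collate([g1, g2])
pooled = mean_pool_per_graph(merged["x"], merged["batch"], num_graphs=len(merged["y"]))
print(merged["batch"])  # → [0, 0, 1, 1, 1]
print(pooled)           # → [[0.5, 0.5], [2.0, 2.0]]
```

Shifting edge indices by the running node offset is what lets a whole batch be processed as a single graph while still pooling back to one vector, and therefore one prediction, per input graph.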
Let's look at the task of classifying graphs from the KKI data set, part of the TUDataset graph benchmark collection from TU Dortmund University.
Once we've created the TUDataset, we create the flash.graph.classification.data.GraphClassificationData. We then create our flash.graph.classification.model.GraphClassifier and train it on the KKI data. Next, we use the trained flash.graph.classification.model.GraphClassifier for inference. Finally, we save the model. Here's the full example:
../../../flash_examples/graph_classification.py
To learn how to view the available backbones/heads for this task, see backbones_heads.
The graph classifier can be used directly from the command line with zero code using flash_zero. You can run the above example with:
flash graph_classification
To view the configuration options, including options for running the graph classifier on your own data, use:
flash graph_classification --help