This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Graph Classification

The Task

This task consists of classifying graphs: given a graph, the task predicts which ‘class’ the graph belongs to. A class is a label that indicates the kind of graph. For example, a label may indicate whether one molecule interacts with another.
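To make the idea concrete, here is a minimal, dependency-free sketch of how a single graph-level prediction can be computed. Each node averages its features with those of its neighbours (one round of message passing), the node vectors are mean-pooled into one graph vector, and the class with the highest linear score is returned. The weights, class names, and toy graphs are hypothetical, and this is not the architecture GraphClassifier uses; it only shows the shape of the computation.

```python
# Toy graph classifier: one round of mean-neighbour message passing, a
# mean-pool readout, then a linear score per class. Weights are hypothetical.
from typing import Dict, List, Tuple

Features = List[List[float]]
Edges = List[Tuple[int, int]]


def message_pass(x: Features, edges: Edges) -> Features:
    """Replace each node's features with the mean over itself and its neighbours."""
    neighbours: Dict[int, List[int]] = {i: [i] for i in range(len(x))}
    for u, v in edges:  # undirected: record both directions
        neighbours[u].append(v)
        neighbours[v].append(u)
    dim = len(x[0])
    return [
        [sum(x[j][d] for j in neighbours[i]) / len(neighbours[i]) for d in range(dim)]
        for i in range(len(x))
    ]


def readout(h: Features) -> List[float]:
    """Mean-pool node embeddings into one graph-level vector."""
    dim = len(h[0])
    return [sum(row[d] for row in h) / len(h) for d in range(dim)]


def classify(x: Features, edges: Edges, weights: Dict[str, List[float]]) -> str:
    """Return the class whose weight vector gives the pooled graph vector the top score."""
    g = readout(message_pass(x, edges))
    scores = {label: sum(w * f for w, f in zip(ws, g)) for label, ws in weights.items()}
    return max(scores, key=lambda label: scores[label])


# Hypothetical weights for two classes over 2-dimensional node features.
weights = {"class_0": [1.0, -1.0], "class_1": [-1.0, 1.0]}
triangle_x, triangle_edges = [[1.0, 0.0], [0.8, 0.2], [0.9, 0.1]], [(0, 1), (1, 2), (2, 0)]
path_x, path_edges = [[0.0, 1.0], [0.1, 0.9], [0.2, 0.8]], [(0, 1), (1, 2)]
print(classify(triangle_x, triangle_edges, weights))  # class_0
print(classify(path_x, path_edges, weights))  # class_1
```

A trained model learns the message-passing and scoring weights instead of using hand-picked ones, but the output is the same kind of thing: one label per whole graph, not per node.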

The flash.graph.classification.model.GraphClassifier and flash.graph.classification.data.GraphClassificationData classes internally rely on PyTorch Geometric.
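As a rough illustration of what relying on PyTorch Geometric means for the data, the sketch below mimics, with plain Python lists instead of tensors, the layout PyG uses for a graph: node features x, a COO edge_index with one row of source nodes and one row of target nodes (each undirected edge stored in both directions), and a graph label y. It also shows how PyG batches several graphs into one disconnected graph by offsetting node indices and recording a batch vector that maps each node to its graph. The helper names to_edge_index and collate are hypothetical; in practice torch_geometric.data.Data and PyG's data loader handle this.

```python
# Plain-Python mimic of PyTorch Geometric's graph layout: lists stand in for
# tensors, and nothing here imports torch_geometric.
from typing import Any, Dict, List, Tuple


def to_edge_index(edges: List[Tuple[int, int]]) -> List[List[int]]:
    """Undirected edge list -> COO edge_index with shape [2, 2 * len(edges)]."""
    src: List[int] = []
    dst: List[int] = []
    for u, v in edges:
        # An undirected edge is stored as two directed edges.
        src += [u, v]
        dst += [v, u]
    return [src, dst]


def collate(graphs: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Merge several graphs into one disconnected graph, similar to PyG batching."""
    x: List[List[float]] = []
    src: List[int] = []
    dst: List[int] = []
    batch: List[int] = []
    y: List[int] = []
    offset = 0
    for graph_id, g in enumerate(graphs):
        x += g["x"]
        # Shift node indices so they point into the concatenated feature list.
        src += [i + offset for i in g["edge_index"][0]]
        dst += [i + offset for i in g["edge_index"][1]]
        # batch[n] records which graph node n came from.
        batch += [graph_id] * len(g["x"])
        y.append(g["y"])
        offset += len(g["x"])
    return {"x": x, "edge_index": [src, dst], "batch": batch, "y": y}


g1 = {"x": [[1.0], [2.0]], "edge_index": to_edge_index([(0, 1)]), "y": 0}
g2 = {"x": [[3.0], [4.0], [5.0]], "edge_index": to_edge_index([(0, 1), (1, 2)]), "y": 1}
b = collate([g1, g2])
print(b["edge_index"])  # [[0, 1, 2, 3, 3, 4], [1, 0, 3, 2, 4, 3]]
print(b["batch"])  # [0, 0, 1, 1, 1]
```

Because the graphs in a batch share no edges, message passing over the merged graph never mixes them; a graph-level pooling step can then use batch to produce one vector per graph.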


Example

Let's look at the task of classifying graphs from the KKI data set, one of the benchmark graph data sets published by TU Dortmund University.

Once we've created the TUDataset, we use it to create the GraphClassificationData. We then create our GraphClassifier and train it on the KKI data. Next, we use the trained GraphClassifier for inference. Finally, we save the model. The full example is in flash_examples/graph_classification.py in the Flash repository.
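TUDataset downloads and processes the raw KKI files for you, so you never need to parse them yourself. For readers curious about what it consumes: the TU Dortmund collection stores all graphs of a data set DS in shared text files. DS_A.txt lists every edge as a "row, col" pair of 1-based global node ids (undirected edges appear in both directions), line i of DS_graph_indicator.txt gives the graph id of node i, and line i of DS_graph_labels.txt gives the class label of graph i. The sketch below splits such files, already read into strings, into per-graph edge lists with local node numbering. The function name parse_tu and the toy inputs are hypothetical, and the sketch assumes each graph's nodes occupy one contiguous block of ids.

```python
# Sketch: split TU Dortmund-style text files into per-graph edge lists.
# parse_tu and the toy inputs are hypothetical; TUDataset does this for you.
from typing import Dict, List, Tuple

EdgeList = List[Tuple[int, int]]


def parse_tu(a_txt: str, indicator_txt: str, labels_txt: str) -> List[Tuple[EdgeList, int]]:
    # Line i of the indicator file holds the (1-based) graph id of node i.
    graph_of = [int(tok) for tok in indicator_txt.split()]
    # Line i of the labels file holds the class label of graph i.
    labels = [int(tok) for tok in labels_txt.split()]
    # First 0-based global node index of each graph. Assumes every graph's
    # nodes form one contiguous block of ids.
    first_node: Dict[int, int] = {}
    for node, gid in enumerate(graph_of):
        first_node.setdefault(gid, node)
    edges: Dict[int, EdgeList] = {gid: [] for gid in range(1, len(labels) + 1)}
    for line in a_txt.strip().splitlines():
        # Each line is "row, col" with 1-based global node ids; convert to 0-based.
        u, v = (int(tok) - 1 for tok in line.split(","))
        gid = graph_of[u]
        # Renumber both endpoints relative to the start of their graph.
        edges[gid].append((u - first_node[gid], v - first_node[gid]))
    return [(edges[gid], labels[gid - 1]) for gid in sorted(edges)]


# Two toy graphs: a single edge (nodes 1-2) and a path (nodes 3-4-5).
a_txt = "1, 2\n2, 1\n3, 4\n4, 3\n4, 5\n5, 4\n"
indicator_txt = "1\n1\n2\n2\n2\n"
labels_txt = "0\n1\n"
graphs = parse_tu(a_txt, indicator_txt, labels_txt)
print(graphs)  # [([(0, 1), (1, 0)], 0), ([(0, 1), (1, 0), (1, 2), (2, 1)], 1)]
```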

To learn how to view the available backbones and heads for this task, see the Backbones and Heads guide (backbones_heads).


Flash Zero

The graph classifier can be used directly from the command line, with zero code, using Flash Zero. You can run the above example with:

flash graph_classification

To view the configuration options, including options for running the graph classifier on your own data, use:

flash graph_classification --help