# **Graph Classification with 🤗 Transformers**

This notebook shows how to fine-tune the Graphormer model for Graph Classification on a dataset available on the hub. The idea is to add a randomly initialized classification head on top of a pre-trained encoder, and fine-tune the model altogether on a labeled dataset.

Depending on the model and the GPU you are using, you might need to adjust the batch size to avoid out-of-memory errors. Set the checkpoint and the batch size, then the rest of the notebook should run smoothly.

In this notebook, we'll fine-tune from the [`clefourrier/graphormer-base-pcqm4mv2`](https://huggingface.co/clefourrier/graphormer-base-pcqm4mv2) checkpoint.

## Fine-tuning Graphormer on a graph classification task

In this notebook, we will see how to fine-tune the Graphormer model with [🤗 Transformers](https://github.com/huggingface/transformers) on a Graph Classification dataset.

Given a graph, the goal is to predict its class.

### Loading the dataset

Loading a graph dataset from the Hub is very easy. This tutorial was written around the `ogbg-molhiv` dataset, stored in the `OGB` repository; note that the cell below loads `VincentPai/for-graphormer-v4` instead.
*To find other graph datasets, look for the "Graph Machine Learning" tag [on the Hub](https://huggingface.co/datasets?task_categories=task_categories:graph-ml&sort=downloads). You'll find social graphs, molecular datasets, some artificial ones, and more.*

`ogbg-molhiv` contains a collection of molecules (from MoleculeNet), and the goal is to predict whether they inhibit HIV or not.


In [None]:
from datasets import load_dataset 

dataset = load_dataset("VincentPai/for-graphormer-v4")

# rename the label to y to fit the format of the input of the Graphormer
dataset['train'] = dataset['train'].rename_column('label', 'y')

dataset = dataset.shuffle(seed = 87)

Let us also load the Accuracy metric, which we'll use to evaluate our model both during and after training.

In [None]:
# `datasets.load_metric` is deprecated; the `evaluate` library provides the same metric
import evaluate

metric = evaluate.load("accuracy")
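
For reference, accuracy is the fraction of graphs whose predicted class matches the label. The following is a minimal pure-Python sketch of that computation, not the metric library's implementation; `accuracy_from_logits` and the toy numbers are hypothetical names and values introduced only for illustration.

```python
def accuracy_from_logits(logits, labels):
    """Fraction of examples whose highest-scoring class equals the label."""
    if len(logits) != len(labels):
        raise ValueError("logits and labels must have the same length")
    if not labels:
        raise ValueError("cannot compute accuracy on an empty batch")
    correct = 0
    for row, label in zip(logits, labels):
        # argmax over the class scores of one graph
        predicted = max(range(len(row)), key=lambda i: row[i])
        correct += int(predicted == label)
    return correct / len(labels)


# Toy two-class logits for four graphs (made-up numbers)
toy_logits = [[2.0, -1.0], [0.1, 0.3], [-0.5, 0.2], [1.2, 1.1]]
toy_labels = [0, 1, 0, 0]
toy_accuracy = accuracy_from_logits(toy_logits, toy_labels)
print(toy_accuracy)
```

Three of the four toy predictions match their labels here.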

The `dataset` object itself is a [`DatasetDict`](https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasetdict), which contains one key per split (in this case, "train", "validation" and "test" splits).

In [None]:
# Show the splits and their columns, then the first training example
print(dataset)
print(dataset["train"][0])
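
To make the record layout concrete, here is a minimal sketch of a toy graph in the same spirit, with a small consistency check. The fields `edge_index`, `num_nodes` and `y` appear in this notebook; `node_feat` and `edge_attr`, their shapes, and the helper `check_graph` are assumptions made for illustration and may not match your dataset exactly.

```python
# Hypothetical toy record: a triangle with 3 nodes.
# Each undirected edge is listed in both directions (an assumed convention).
toy_graph = {
    "edge_index": [[0, 1, 1, 2, 2, 0], [1, 0, 2, 1, 0, 2]],
    "num_nodes": 3,
    "node_feat": [[6], [6], [8]],           # one integer feature per node (assumed)
    "edge_attr": [[0] for _ in range(6)],   # one integer feature per edge (assumed)
    "y": 1,                                 # class label (real format may differ)
}


def check_graph(record):
    """Return a list of problems found in a graph record (empty if none)."""
    problems = []
    sources, targets = record["edge_index"]
    if len(sources) != len(targets):
        problems.append("edge_index rows differ in length")
    if any(not 0 <= n < record["num_nodes"] for n in sources + targets):
        problems.append("edge_index refers to a node outside 0..num_nodes-1")
    if "node_feat" in record and len(record["node_feat"]) != record["num_nodes"]:
        problems.append("node_feat must have one entry per node")
    if "edge_attr" in record and len(record["edge_attr"]) != len(sources):
        problems.append("edge_attr must have one entry per edge")
    return problems


print(check_graph(toy_graph))
```

Running a check like this on a few real examples can catch malformed graphs before preprocessing.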

We can inspect the graph using networkx and pyplot.

In [None]:
import networkx as nx
import matplotlib.pyplot as plt
# Plot the first training graph (each example is a dict of graph fields)
graph = dataset["train"][0]

# edge_index holds two parallel lists: source nodes and target nodes
edges = graph["edge_index"]
print(edges)

num_edges = len(edges[0])
num_nodes = graph["num_nodes"]

# Conversion to networkx format
G = nx.Graph()
G.add_nodes_from(range(num_nodes))
G.add_edges_from([(edges[0][i], edges[1][i]) for i in range(num_edges)])

# Plot, labelling each node with its index
nx.draw(G, with_labels=True)
plt.show()


### Preprocessing the data

Graph transformer frameworks usually apply specific preprocessing to their datasets to generate added features and properties which help the underlying learning task (classification in our case).

Here, we use Graphormer's default preprocessing, which generates in/out degree information, the matrix of shortest-path distances between nodes, and other properties of interest for the model.
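
To give an intuition for these features, the sketch below computes in/out degrees and hop-count shortest paths for a tiny graph using breadth-first search. It is an illustration only, not the library's `preprocess_item` implementation, and the function name `degrees_and_shortest_paths` is hypothetical.

```python
from collections import deque


def degrees_and_shortest_paths(edge_index, num_nodes):
    """Compute in/out degrees and all-pairs hop distances with BFS."""
    sources, targets = edge_index
    neighbors = [[] for _ in range(num_nodes)]
    in_degree = [0] * num_nodes
    out_degree = [0] * num_nodes
    for s, t in zip(sources, targets):
        neighbors[s].append(t)
        out_degree[s] += 1
        in_degree[t] += 1

    # dist[i][j] = number of hops on the shortest path i -> j (None if unreachable)
    dist = [[None] * num_nodes for _ in range(num_nodes)]
    for start in range(num_nodes):
        dist[start][start] = 0
        queue = deque([start])
        while queue:
            node = queue.popleft()
            for nxt in neighbors[node]:
                if dist[start][nxt] is None:
                    dist[start][nxt] = dist[start][node] + 1
                    queue.append(nxt)
    return in_degree, out_degree, dist


# Path graph 0 - 1 - 2, with each edge stored in both directions
toy_edge_index = [[0, 1, 1, 2], [1, 0, 2, 1]]
in_deg, out_deg, dist = degrees_and_shortest_paths(toy_edge_index, 3)
print(in_deg, out_deg)
print(dist)
```

In this toy graph the middle node has degree 2, and the two end nodes are two hops apart.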

In [None]:
from transformers.models.graphormer.collating_graphormer import preprocess_item, GraphormerDataCollator

dataset_processed = dataset.map(preprocess_item, batched=False)

In [None]:
# select the preprocessed training and validation splits

train_ds = dataset_processed['train']
val_ds = dataset_processed['validation']

Let's access an element to look at all the features we've added:

In [None]:
print(train_ds[0].keys())

### Training the model


Calling the `from_pretrained` method on our model downloads and caches the weights for us. As the number of classes (for prediction) is dataset dependent, we pass the new `num_classes` as well as `ignore_mismatched_sizes` alongside the `model_checkpoint`. This makes sure a custom classification head is created, specific to our task, hence likely different from the original decoder head. 

(When using a pretrained model, you must make sure the embeddings of your data have the same shape as the ones used to pretrain your model.)

In [None]:
from transformers import GraphormerForGraphClassification

model_checkpoint = "clefourrier/graphormer-base-pcqm4mv2" # pre-trained model from which to fine-tune

model = GraphormerForGraphClassification.from_pretrained(
    model_checkpoint, 
    num_classes=2,
    ignore_mismatched_sizes = True, # provide this in case you're planning to fine-tune an already fine-tuned checkpoint
)


The warning is telling us we are throwing away some weights (the weights and bias of the `classifier` layer) and randomly initializing some others (the weights and bias of a new `classifier` layer). This is expected in this case, because we are adding a new head for which we don't have pretrained weights, so the library warns us we should fine-tune this model before using it for inference, which is exactly what we are going to do.

To instantiate a `Trainer`, we will need to define the training configuration and the evaluation metric. The most important is the [`TrainingArguments`](https://huggingface.co/transformers/main_classes/trainer.html#transformers.TrainingArguments), which is a class that contains all the attributes to customize the training. It requires one folder name, which will be used to save the checkpoints of the model.

For graph datasets, it is particularly important to play around with batch sizes and gradient accumulation steps to train on enough samples while avoiding out-of-memory errors. 

In [None]:
from transformers import TrainingArguments, Trainer
training_args = TrainingArguments(
    "graph-classification",
    logging_dir="graph-classification",
    
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,

    auto_find_batch_size=True, # batch size can be changed automatically to prevent OOMs
    gradient_accumulation_steps=10,
    dataloader_num_workers=4, 

    num_train_epochs=2,

    evaluation_strategy="epoch",
    logging_strategy="epoch",
    push_to_hub=False,
)
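
As a rough guide to how batch size and gradient accumulation interact, the sketch below computes the effective batch size (the number of graphs contributing to each optimizer update) and the accumulation steps needed to keep it constant if the per-device batch size has to shrink. The helper names are hypothetical, and a single device is assumed by default.

```python
import math


def effective_batch_size(per_device_batch_size, gradient_accumulation_steps, num_devices=1):
    """Number of examples that contribute to one optimizer update."""
    return per_device_batch_size * gradient_accumulation_steps * num_devices


def accumulation_steps_for(target_effective, per_device_batch_size, num_devices=1):
    """Accumulation steps needed to reach at least the target effective batch size."""
    per_step = per_device_batch_size * num_devices
    return max(1, math.ceil(target_effective / per_step))


# Values from the configuration above: batch size 32, 10 accumulation steps
current = effective_batch_size(32, 10)
print(current)

# If memory only allows 16 graphs per device, accumulate over more steps
print(accumulation_steps_for(current, 16))
```

Halving the per-device batch size while doubling the accumulation steps keeps the effective batch size unchanged, at the cost of more forward and backward passes per update.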

In the `Trainer` for graph classification, it is important to pass the specific data collator for the given graph dataset, which will convert individual graphs to batches for training. 

In [None]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=val_ds,  # required because evaluation_strategy="epoch" is set above
    data_collator=GraphormerDataCollator(),
)
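
The reason a dedicated collator is needed is that graphs in a batch have different numbers of nodes, so their features must be padded to a common size. The sketch below shows that idea on node features only. It is a simplified illustration, not what `GraphormerDataCollator` actually does internally, and `pad_node_features` is a hypothetical helper that assumes non-empty graphs.

```python
def pad_node_features(graphs, pad_value=0):
    """Pad per-graph node feature lists to the size of the largest graph.

    Returns (padded, mask), where mask[i][j] is True for real nodes
    and False for padding.
    """
    max_nodes = max(len(g) for g in graphs)
    feat_dim = len(graphs[0][0])
    padded, mask = [], []
    for g in graphs:
        missing = max_nodes - len(g)
        padded.append([list(f) for f in g] + [[pad_value] * feat_dim for _ in range(missing)])
        mask.append([True] * len(g) + [False] * missing)
    return padded, mask


toy_batch = [
    [[6], [8]],            # graph with 2 nodes
    [[6], [6], [7], [8]],  # graph with 4 nodes
]
padded, mask = pad_node_features(toy_batch)
print(padded)
print(mask)
```

After padding, every graph in the batch has the same number of node slots, and the mask records which slots hold real nodes.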

We can now train our model!

In [None]:
train_results = trainer.train()
# rest is optional but nice to have
trainer.save_model()
trainer.log_metrics("train", train_results.metrics)
trainer.save_metrics("train", train_results.metrics)
trainer.save_state()

You can now upload the result of the training to the Hub with the following cell. You need to be logged in first, for example by calling `notebook_login()` from `huggingface_hub` earlier in the notebook.

In [None]:
trainer.push_to_hub()