Je me suis basé sur le tutoriel de Titouan Parcollet (https://github.com/TParcollet/Tutoriel-Graph-Neural-Networks) pour le modèle, l'entraînement et le test, que j'ai adaptés à la tâche demandée.

Le traitement des données (DataLoader, AtomEncoder) ainsi que l'évaluation (Evaluator) sont fournis par le sujet du défi : https://ogb.stanford.edu/docs/graphprop/#ogbg-mol

In [1]:
# Je me suis basé sur le tutoriel de Titouan Parcollet (https://github.com/TParcollet/Tutoriel-Graph-Neural-Networks) pour le modèle, le train et le test

In [2]:
# Install OGB plus the PyTorch Geometric companion packages, pulling wheels from the
# index built for torch 1.12.0 + CUDA 11.3.
# NOTE(review): presumably this index has to match the torch/CUDA build of the runtime — confirm before re-running.
!pip install ogb
!pip install torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric -f https://data.pyg.org/whl/torch-1.12.0+cu113.html

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting ogb
  Downloading ogb-1.3.4-py3-none-any.whl (78 kB)
[K     |████████████████████████████████| 78 kB 7.7 MB/s 
Collecting outdated>=0.2.0
  Downloading outdated-0.2.1-py3-none-any.whl (7.5 kB)
Collecting littleutils
  Downloading littleutils-0.2.2.tar.gz (6.6 kB)
Building wheels for collected packages: littleutils
  Building wheel for littleutils (setup.py) ... [?25l[?25hdone
  Created wheel for littleutils: filename=littleutils-0.2.2-py3-none-any.whl size=7048 sha256=e6496e4da96f5a461a32c1bf168816cccbec88f644b3b5cbd6456dc24514868d
  Stored in directory: /root/.cache/pip/wheels/d6/64/cd/32819b511a488e4993f2fab909a95330289c3f4e0f6ef4676d
Successfully built littleutils
Installing collected packages: littleutils, outdated, ogb
Successfully installed littleutils-0.2.2 ogb-1.3.4 outdated-0.2.1
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/pub

In [3]:
import copy

from ogb.graphproppred import PygGraphPropPredDataset
from ogb.graphproppred import Evaluator

from ogb.graphproppred.mol_encoder import AtomEncoder

import torch
import torch.nn as nn
import torch.nn.functional as F

from torch_geometric.loader import DataLoader
import torch_geometric.nn as pyg_nn

In [14]:
class GNN(nn.Module):
    """Graph-level classifier: stacked graph convolutions, mean pooling, linear head.

    Args:
        input_dim: feature size of the node embeddings fed to the first convolution.
        hidden_dim: output size of every convolution and width of the head's first layer.
        output_dim: number of values produced per graph.
        criterion: callable invoked by ``loss`` as ``criterion(pred, label)``.
        conv_type: "SAGE", "GIN", "GATConv" or "GATv2Conv"; any other value
            (including the default "GCN") selects ``GCNConv``.
        num_layer: number of convolutions applied in ``forward``.
        dropout: dropout probability used after each convolution and inside the head.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, criterion, conv_type="GCN", num_layer=3, dropout=0.25):
        super().__init__()
        self.conv_type = conv_type
        self.num_layers = num_layer
        self.dropout = dropout
        self.criterion = criterion

        # A single normalisation layer, applied once after the last convolution.
        self.conv_bn = pyg_nn.BatchNorm(hidden_dim)

        # The first convolution maps input_dim -> hidden_dim, the others hidden_dim -> hidden_dim.
        layer_inputs = [input_dim] + [hidden_dim] * (num_layer - 1)
        self.convs = nn.ModuleList(
            [self.build_conv_model(in_dim, hidden_dim) for in_dim in layer_inputs]
        )

        # Read-out head applied to the pooled graph representation.
        head_layers = [
            nn.Linear(hidden_dim, hidden_dim),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, output_dim),
        ]
        self.post_mp = nn.Sequential(*head_layers)

    def build_conv_model(self, input_dim, hidden_dim):
        """Create one convolution layer of the kind named by ``self.conv_type``."""
        kind = self.conv_type
        if kind == "SAGE":
            return pyg_nn.SAGEConv(input_dim, hidden_dim)
        if kind == "GIN":
            # GIN wraps a small two-layer MLP.
            gin_mlp = nn.Sequential(
                nn.Linear(input_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, hidden_dim),
            )
            return pyg_nn.GINConv(gin_mlp)
        if kind == "GATConv":
            return pyg_nn.GATConv(input_dim, hidden_dim)
        if kind == "GATv2Conv":
            return pyg_nn.GATv2Conv(input_dim, hidden_dim)
        # Every unrecognised name falls back to a plain GCN layer.
        return pyg_nn.GCNConv(input_dim, hidden_dim)

    def forward(self, x, edge_index, batch):
        """Return per-graph log-probabilities (log-softmax over dim 1).

        Args:
            x: node features given to the first convolution.
            edge_index: graph connectivity forwarded to every convolution.
            batch: forwarded to ``global_mean_pool`` to group nodes per graph.
        """
        h = x
        for layer_idx in range(self.num_layers):
            h = F.relu(self.convs[layer_idx](h, edge_index))
            h = F.dropout(h, p=self.dropout, training=self.training)
        pooled = pyg_nn.global_mean_pool(self.conv_bn(h), batch)
        return F.log_softmax(self.post_mp(pooled), dim=1)

    def loss(self, pred, label):
        """Apply the criterion supplied at construction time to ``pred`` and ``label``."""
        return self.criterion(pred, label)

In [15]:
def train(model, dataset, epochs, print_steps, batch_size, optimizer, atom_encoder, evaluator):
    """Train ``model`` on the train split and report the test ROC-AUC of the best epoch.

    After every epoch the model is scored on the validation split with ``test``.
    Whenever the validation ROC-AUC improves, the model is written to
    ``model.pth`` and a copy is kept in memory. When training ends, that
    in-memory copy is scored on the test split and a summary is printed.

    Args:
        model: network callable as ``model(x, edge_index, batch)`` and exposing
            ``loss(pred, label)``.
        dataset: object providing ``get_idx_split()`` with "train", "valid" and
            "test" entries and indexable by those entries.
        epochs: number of passes over the training split.
        print_steps: a progress line is printed every ``print_steps`` epochs.
        batch_size: batch size shared by the three loaders.
        optimizer: optimizer stepped once per training batch.
        atom_encoder: callable turning ``batch.x`` into the features fed to the model.
        evaluator: forwarded unchanged to ``test``.

    Returns:
        ``model`` in its state after the last epoch (not the best-epoch copy).
    """
    # Data preparation: one loader per split; only the training split is shuffled.
    split_idx = dataset.get_idx_split()

    train_loader = DataLoader(dataset[split_idx["train"]], batch_size=batch_size, shuffle=True)
    valid_loader = DataLoader(dataset[split_idx["valid"]], batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(dataset[split_idx["test"]], batch_size=batch_size, shuffle=False)

    # Training
    print("Learning...")
    best_valid_rocauc = 0
    best_epoch = 0
    best_loss = 0
    # Until an epoch improves on the initial score, the best model is the one passed in.
    # This keeps the final evaluation working when no epoch improves or epochs == 0.
    best_model = model

    for epoch in range(1, epochs + 1):
        total_loss = 0
        model.train()
        for batch in train_loader:
            optimizer.zero_grad()
            atom_emb = atom_encoder(batch.x)
            out = model(atom_emb, batch.edge_index, batch.batch)
            loss = model.loss(out, batch.y.squeeze(1))
            loss.backward()
            optimizer.step()
            # Weight the batch loss by its graph count so the epoch value is a per-graph mean.
            total_loss += loss.item() * batch.num_graphs
        total_loss /= len(train_loader.dataset)

        valid_rocauc = test(valid_loader, model, evaluator, atom_encoder)

        if valid_rocauc > best_valid_rocauc:
            best_valid_rocauc = valid_rocauc
            # Keep the snapshot in memory so the final evaluation never depends on
            # re-reading 'model.pth', which could be missing or left over from an
            # earlier run. The file is still written for later use.
            best_model = copy.deepcopy(model)
            torch.save(model, 'model.pth')
            best_epoch = epoch
            best_loss = total_loss

        if epoch % print_steps == 0:
            print("Itération {}. Loss: {:.4f}. Validation: {:.4f}".format(
                epoch, total_loss, valid_rocauc))

    print()
    print()
    print("Evaluating...")
    print()
    test_rocauc = test(test_loader, best_model, evaluator, atom_encoder)
    print(f'Best model at epoch: {best_epoch:02d}')
    print(f'Loss: {best_loss:.4f}, '
          f'Test: {100 * test_rocauc:.2f}%')

    return model

In [16]:
def test(loader, model, evaluator, atom_encoder):
    model.eval()

    preds = []
    labels = []

    with torch.no_grad():
      for data in loader:
            x, edge_index, num_batch = data.x, data.edge_index, data.batch
            atom_emb = atom_encoder(x)
            out = model(atom_emb, edge_index, num_batch)[:, 1]
            label = data.y
            preds.append(out.detach().cpu())
            labels.append(label.view(label.shape).detach().cpu())

    preds = torch.cat(preds, dim=0).unsqueeze(1).numpy()
    labels = torch.cat(labels, dim=0).numpy()

    input_dict = {"y_true": labels, "y_pred": preds}
    return evaluator.eval(input_dict)['rocauc']

In [17]:
def run(d_name="ogbg-molhiv", epochs=20, print_steps=1, lr=0.001, batch_size=32,
        input_dim=100, hidden_dim=128, conv_type="GATv2Conv", num_conv_layer=3, dropout=0.25):
    """Build the dataset, encoder, model and optimizer, then launch training.

    Calling ``run()`` with no arguments reproduces the original hard-coded
    experiment. Reads the module-level ``device``, which must be defined before
    this function is called.

    Args:
        d_name: OGB dataset name, used for both the dataset and the evaluator.
        epochs: number of training epochs.
        print_steps: progress is printed every ``print_steps`` epochs.
        lr: Adam learning rate.
        batch_size: batch size passed to ``train``.
        input_dim: embedding size produced by the atom encoder and consumed by the model.
        hidden_dim: hidden width of the model.
        conv_type: convolution kind forwarded to ``GNN``.
        num_conv_layer: number of convolution layers forwarded to ``GNN``.
        dropout: dropout probability forwarded to ``GNN``.
    """
    dataset = PygGraphPropPredDataset(name=d_name)
    # NOTE(review): relies on ``.to`` moving the stored tensors in place (the result
    # is discarded) — confirm against the installed torch_geometric version.
    dataset.data.to(device)
    evaluator = Evaluator(name=d_name)

    atom_encoder = AtomEncoder(emb_dim=input_dim).to(device)

    criterion = nn.CrossEntropyLoss()
    model = GNN(input_dim, hidden_dim, dataset.num_classes, criterion, conv_type=conv_type, num_layer=num_conv_layer, dropout=dropout).to(device)
    # NOTE(review): only the model's parameters are optimized; ``atom_encoder`` is a
    # separate module whose parameters are not passed here, so it is never updated.
    # If it is added to the optimizer, ``train`` must also checkpoint it with the model.
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    train(model, dataset, epochs, print_steps, batch_size, optimizer, atom_encoder, evaluator)

In [18]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Device : ", device)

# Launch the experiment; `run` reads the module-level `device` defined above.
run()

Device :  cuda
Learning...
Itération 1. Loss: 0.1740. Validation: 0.7133
Itération 2. Loss: 0.1553. Validation: 0.7242
Itération 3. Loss: 0.1534. Validation: 0.7213
Itération 4. Loss: 0.1503. Validation: 0.6021
Itération 5. Loss: 0.1498. Validation: 0.7268
Itération 6. Loss: 0.1486. Validation: 0.6696
Itération 7. Loss: 0.1479. Validation: 0.7597
Itération 8. Loss: 0.1481. Validation: 0.7436
Itération 9. Loss: 0.1463. Validation: 0.7529
Itération 10. Loss: 0.1462. Validation: 0.7420
Itération 11. Loss: 0.1457. Validation: 0.7571
Itération 12. Loss: 0.1447. Validation: 0.7437
Itération 13. Loss: 0.1437. Validation: 0.7431
Itération 14. Loss: 0.1433. Validation: 0.7108
Itération 15. Loss: 0.1432. Validation: 0.7251
Itération 16. Loss: 0.1430. Validation: 0.7538
Itération 17. Loss: 0.1422. Validation: 0.7152
Itération 18. Loss: 0.1413. Validation: 0.7510
Itération 19. Loss: 0.1414. Validation: 0.7597
Itération 20. Loss: 0.1409. Validation: 0.7983


Evaluating...

Best model at epoch: 20
L