# Multi-Label Classification on Protein-Protein Interaction

In [2]:
import torch
from sklearn.metrics import f1_score

from torch_geometric.datasets import PPI
from torch_geometric.data import Batch
from torch_geometric.loader import DataLoader, NeighborLoader
from torch_geometric.nn import GraphSAGE

device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')

In [3]:
# Load training, evaluation, and test sets
train_dataset = PPI(root='./datasets/', split='train')
val_dataset = PPI(root='./datasets/', split='val')
test_dataset = PPI(root='./datasets/', split='test')

In [4]:
# Unify the training graphs and apply neighbor sampling
train_data = Batch.from_data_list(train_dataset)
train_loader = NeighborLoader(
    train_data,
    batch_size=2048,
    shuffle=True,
    num_neighbors=[20, 10],
    num_workers=2,
    persistent_workers=True,
)

In [5]:
print(train_data)
print('Num classes:', train_dataset.num_classes)
print('Num node features:', train_dataset.num_features)

DataBatch(x=[44906, 50], edge_index=[2, 1226368], y=[44906, 121], batch=[44906], ptr=[21])
Num classes: 121
Num node features: 50


- Each node is a protein, with 50 features (x)
- Each node belongs to as much as 121 classes (y)

In [6]:
# Evaluation loaders (one datapoint corresponds to a graph)
val_loader = DataLoader(val_dataset, batch_size=2)
test_loader = DataLoader(test_dataset, batch_size=2)

In [7]:
model = GraphSAGE(
    in_channels=train_dataset.num_features,
    hidden_channels=64,
    num_layers=2,
    out_channels=train_dataset.num_classes,
).to(device)

criterion = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.005)

In [8]:
def fit(loader):
    model.train()
    total_loss = 0

    for data in loader:
        data = data.to(device)
        optimizer.zero_grad()

        out = model(data.x, data.edge_index)
        loss = criterion(out, data.y)
        # we have 2 graphs per batch, so we multiply the loss by 2
        total_loss += loss.item() * data.num_graphs  

        loss.backward()
        optimizer.step()

    return total_loss / len(loader.dataset)

In [13]:
@torch.no_grad()
def test(loader):
    model.eval()

    data = next(iter(loader))
    out = model(data.x.to(device), data.edge_index.to(device))
    preds = (out > 0).float().cpu()
    print(out)
    print(preds)

    y, pred = data.y.numpy(), preds.numpy()
    return f1_score(y, pred, average='micro') if pred.sum() > 0 else 0

In [14]:
# test(val_loader)

tensor([[ 0.0813, -0.1173,  0.1798,  ...,  0.2478, -0.0782,  0.3677],
        [ 0.5665, -0.2947, -0.2368,  ...,  0.0677,  0.3263,  0.1181],
        [ 0.1805, -0.1533, -0.1491,  ..., -0.0139,  0.0287,  0.1809],
        ...,
        [ 0.4943, -0.3411, -0.4811,  ...,  0.1616, -0.1372,  0.0780],
        [ 0.1127,  0.0255, -0.0986,  ..., -0.0507, -0.0936, -0.0121],
        [ 0.0092, -0.1319, -0.1448,  ..., -0.1282,  0.0401,  0.1612]],
       device='mps:0')
tensor([[1., 0., 1.,  ..., 1., 0., 1.],
        [1., 0., 0.,  ..., 1., 1., 1.],
        [1., 0., 0.,  ..., 0., 1., 1.],
        ...,
        [1., 0., 0.,  ..., 1., 0., 1.],
        [1., 1., 0.,  ..., 0., 0., 0.],
        [1., 0., 0.,  ..., 0., 1., 1.]])


0.3652859757197721

In [None]:
for epoch in range(301):
    loss = fit(train_loader)
    val_f1 = test(val_loader)
    if epoch % 50 == 0:
        print(f'Epoch {epoch:>3} | Train Loss: {loss:.3f} | Val F1-score: {val_f1:.4f}')

Epoch   0 | Train Loss: 0.006 | Val F1-score: 0.4921
Epoch  50 | Train Loss: 0.004 | Val F1-score: 0.7926
Epoch 100 | Train Loss: 0.004 | Val F1-score: 0.8093
Epoch 150 | Train Loss: 0.004 | Val F1-score: 0.8188
Epoch 200 | Train Loss: 0.004 | Val F1-score: 0.8226
Epoch 250 | Train Loss: 0.004 | Val F1-score: 0.8224
Epoch 300 | Train Loss: 0.004 | Val F1-score: 0.8261

In [None]:
print(f'Test F1-score: {test(test_loader):.4f}')

Test F1-score: 0.6370
