In [1]:
import itertools
import os

os.environ["DGLBACKEND"] = "pytorch"

import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl.function as fn
from dgl.nn.pytorch import edge_softmax, GraphConv

In [2]:
glist, label_dict = dgl.load_graphs("./hetero_graph.bin")
g = glist[0]

In [21]:
class HeteroGraphConvLayer(nn.Module):
    def __init__(self, in_size_species, in_size_molecule, out_size_species, out_size_mol):
        super(HeteroGraphConvLayer, self).__init__()
        self.conv = dgl.nn.HeteroGraphConv({
            'is_present_in': dgl.nn.GraphConv(in_size_molecule, out_size_species),
            'has': dgl.nn.GraphConv(in_size_species, out_size_mol),
            'similar_to': dgl.nn.GraphConv(in_size_molecule, out_size_mol)
        })

    def forward(self, g, h_dict):
        h_dict = self.conv(g, h_dict)
        return h_dict

In [22]:
# Initialization
in_size_species = 439
in_size_molecules = 378
hidden_size = 128

out_size_species = 885  # assuming you have 885 different species to predict
out_size_mol =4872

model = HeteroGraphConvLayer(in_size_species, in_size_molecules, out_size_species, out_size_mol)


In [33]:
def dot_score(x, y):
    return (x * y).sum(1)

def train(model, g, etype, epochs=100):
    optimizer = torch.optim.Adam(model.parameters())
    best_val_loss = float('inf')
    best_model = None
    
    for epoch in range(epochs):
        model.train()

        # Get embeddings for all nodes
        h_dict = model(g, g.ndata['h'])
        print(h_dict['molecule'].shape)
        
        # Compute scores for all existing and non-existing links
        pos_score = dot_score(h_dict['molecule'][g.edges(etype=etype)[0]],
                              h_dict['molecule'][g.edges(etype=etype)[1]])
        neg_score = dot_score(h_dict['species'][torch.randint(0, g.num_nodes('species'), (g.num_edges(),))],
                              h_dict['species'][torch.randint(0, g.num_nodes('species'), (g.num_edges(),))])
        
        # Compute the loss
        pos_loss = F.binary_cross_entropy_with_logits(pos_score, torch.ones(g.num_edges()))
        neg_loss = F.binary_cross_entropy_with_logits(neg_score, torch.zeros(g.num_edges()))
        loss = pos_loss + neg_loss

        # Backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        print('Epoch', epoch, 'Loss', loss.item())

        # Save the model with the best validation loss
        if loss < best_val_loss:
            best_val_loss = loss
            best_model = copy.deepcopy(model)

    return best_model

In [34]:
best_model = train(model, g, etype='is_present_in')

torch.Size([4872, 4872])


ValueError: Target size (torch.Size([34212])) must be the same as input size (torch.Size([10000]))

In [20]:
g

Graph(num_nodes={'molecule': 4872, 'species': 885},
      num_edges={('molecule', 'is_present_in', 'species'): 10000, ('molecule', 'similar_to', 'molecule'): 14212, ('species', 'has', 'molecule'): 10000},
      metagraph=[('molecule', 'species', 'is_present_in'), ('molecule', 'molecule', 'similar_to'), ('species', 'molecule', 'has')])