In [1]:
import pickle
import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from dgl.nn import GraphConv, RelGraphConv, HeteroGraphConv
import dgl.function as fn
import torchvision.models as models
import torchvision.transforms as transforms

In [2]:
# Load the pickle file containing the heterogeneous graph
with open('test_complete_hetero_graph.pkl', 'rb') as f:
    hetero_graph = pickle.load(f)

In [3]:
edge_label = hetero_graph.etypes #Collecting all edge type
#print(len(edge_label))
edge_label = ['To' if item == 'to' else item for item in edge_label] # Changing the 'to' edge types 'To'
set_edge_label = edge_label
# Convert back to set
set_edge_label = set(set_edge_label)
#print((set_edge_label))

# Model type 1 #
** To predict the type of an edge, use MuduleDict which is a dictionary that holds each type of edge. HeteroGraphConv performs graph convolution on edge type on all nodes type to get a node embbeding**

# Model type 2 #
** To predict the type of an edge, the HeteroDotProductPredictor takes in another graph with only one edge type that merges all the edge types to be predicted, and emits the score of each type for every edge. **

In [14]:
subgraphs = {}

# Iterate over each relationship type
for triplet in hetero_graph.canonical_etypes:
    # Extract the source and destination node types for the relationship
    #print(triplet)
    src_type, _, dst_type = triplet
    # Create a subgraph for the relationship type
    subgraph = hetero_graph[triplet]
    user_feats = hetero_graph.nodes[src_type].data['features']
    item_feats = hetero_graph.nodes[dst_type].data['features']
    node_features = {src_type: user_feats, dst_type: item_feats}
    #print(node_features)
    # Add the subgraph to the dictionary with the relationship type as key
    subgraphs[(src_type, dst_type)] = (subgraph,node_features)

print( subgraphs)

{('1', 'jersey'): (Graph(num_nodes={'1': 386, 'jersey': 385},
      num_edges={('1', 'ON', 'jersey'): 1},
      metagraph=[('1', 'jersey', 'ON')]), {'1': tensor([[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.4258, 0.6753],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.4258, 0.6753],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.4258, 0.6753],
        ...,
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.4258, 0.6753],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.4258, 0.6753],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.4258, 0.6753]]), 'jersey': tensor([[0.3769, 0.0000, 0.0000,  ..., 0.0000, 0.8917, 0.0000],
        [0.3769, 0.0000, 0.0000,  ..., 0.0000, 0.8917, 0.0000],
        [0.3769, 0.0000, 0.0000,  ..., 0.0000, 0.8917, 0.0000],
        ...,
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 2.4999, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 2.4999, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 2.4999, 0.0000]])}), ('103', 'car'): (Graph(num_nodes={'1

In [5]:
# performs a separate graph convolution on each edge type, then sums the message aggregations on each edge type as the final result for all node types
class RGCN(nn.Module):
    def __init__(self, in_feats, hid_feats, out_feats, rel_names):
        super().__init__()

        self.conv1 = HeteroGraphConv({
            rel: GraphConv(in_feats, out_feats)
            for rel in rel_names}, aggregate='sum')

    def forward(self, graph, inputs):
        # inputs are features of nodes
        print(inputs.keys())
        h = self.conv1(graph, inputs)
        print(h.keys())
        h = {k: F.relu(v) for k, v in h.items()}
        print(h.keys())
        return h

In [17]:
class HeteroEdgePredictor(nn.Module):
    def __init__(self, in_dims, n_classes):
        super().__init__()
        self.W = nn.Linear(in_dims*2, n_classes)

    def apply_edges(self, edges):
        h_u = edges.src['h']
        h_v = edges.dst['h']
        y = self.W(torch.cat([h_u, h_v], 1))
        
        #x = torch.cat([edges.src['h'], edges.dst['h']], 1)
        #print(x.shape) 
        #y = self.W(x) 
        return {'score': y}

    def forward(self, graph, h):
        with graph.local_scope():
            #first_key = next(iter(h.keys()))  # Get the first key in the dictionary
            #(h ) 
            for ntype, features in h.items():
                graph.nodes[ntype].data['h'] = features
            print(graph)    
            graph.apply_edges(self.apply_edges)
            return graph.edata['score']


In [18]:
class Model(nn.Module):
    def __init__(self, in_features, hidden_features, out_features, rel_names):
        super().__init__()
        self.sage = RGCN(in_features, hidden_features, out_features, rel_names)
        self.pred = HeteroEdgePredictor(out_features, len(rel_names))
    def forward(self, g, x, dec_graph):
        h = self.sage(g, x)
        return self.pred(dec_graph, h)

In [19]:
model = Model(4096, 128, len(set_edge_label), set_edge_label) #set_edge_label length = 3196


# Define your training loop
def train_model(model, optimizer, hetero_graph, node_features, dec_graph, set_edge_label, num_epochs=10):
    for epoch in range(num_epochs):
        logits = model(hetero_graph, node_features, dec_graph)
        loss = F.cross_entropy(logits, edge_label)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {loss.item()}")


optimizer = torch.optim.Adam(model.parameters())

# Iterate over each relationship type
for relationship, (subgraph, node_features) in subgraphs.items():
    print(f"Training model for relationship: {relationship}")
        
    # Train the model on the current subgraph and its node features
    train_model(model, optimizer, hetero_graph, node_features, subgraph, edge_label)


Training model for relationship: ('1', 'jersey')
dict_keys(['1', 'jersey'])
dict_keys(['jersey'])
dict_keys(['jersey'])
Graph(num_nodes={'1': 386, 'jersey': 385},
      num_edges={('1', 'ON', 'jersey'): 1},
      metagraph=[('1', 'jersey', 'ON')])


KeyError: 'h'

# Model 3 #

In [None]:
class TestRGCN(nn.Module):
    def __init__(self, in_feats, hid_feats, out_feats, canonical_etypes):
        super(TestRGCN, self).__init__()

        self.conv1 = dglnn.HeteroGraphConv({
                etype : dglnn.GraphConv(in_feats[utype], hid_feats, norm='right')
                for utype, etype, vtype in canonical_etypes
                })
        self.conv2 = dglnn.HeteroGraphConv({
                etype : dglnn.GraphConv(hid_feats, out_feats, norm='right')
                for _, etype, _ in canonical_etypes
                })

    def forward(self, blocks, inputs):
        x = self.conv1(blocks[0], inputs)
        x = self.conv2(blocks[1], x)

        return x

class HeteroScorePredictor(nn.Module):
    def forward(self, edge_subgraph, x):
        with edge_subgraph.local_scope():
            edge_subgraph.ndata['h'] = x
            for etype in edge_subgraph.canonical_etypes:
                edge_subgraph.apply_edges(dgl.function.u_dot_v('h', 'h', 'score'), etype=etype)
                # edge_subgraph.apply_edges(self.apply_edges, etype=etype)
            return edge_subgraph.edata['score']

class TestModel(nn.Module):
    # here we have a model that first computes the representation and then predicts the scores for the edges
    def __init__(self, in_features, hidden_features, out_features, canonical_etypes):
        super().__init__()
        self.sage = TestRGCN(in_features, hidden_features, out_features, canonical_etypes)
        self.pred = HeteroScorePredictor()
    def forward(self, g, neg_g, blocks, x):
        x = self.sage(blocks, x)
        pos_score = self.pred(g, x)
        neg_score = self.pred(neg_g, x)
        return pos_score, neg_score 

def compute_loss(pos_score, neg_score, canonical_etypes):
    # Margin loss
    all_losses = []
    for given_type in canonical_etypes:
        n_edges = pos_score[given_type].shape[0]
        if n_edges == 0:
            continue
        all_losses.append((1 - neg_score[given_type].view(n_edges, -1) + pos_score[given_type].unsqueeze(1)).clamp(min=0).mean())
    return torch.stack(all_losses, dim=0).mean()

model = TestModel(in_features={'source':700, 'user':800}, hidden_features=512, out_features=256, canonical_etypes=g.canonical_etypes)
...
for epoch in range(args.n_epochs):
        for input_nodes, positive_graph, negative_graph, blocks in dataloader:
            model.train()
            blocks = [b.to(torch.device('cuda')) for b in blocks]

            positive_graph = positive_graph.to(torch.device('cuda'))
            negative_graph = negative_graph.to(torch.device('cuda'))

            node_features = {'source': blocks[0].srcdata['source_embedding']['source'], 'user': blocks[0].srcdata['user_embedding']['user']}
            pos_score, neg_score = model(positive_graph, negative_graph, blocks, node_features)
            loss = compute_loss(pos_score, neg_score, g.canonical_etypes)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()