# In [None]:
from torch_geometric.nn import GCNConv, SAGEConv, Linear, to_hetero
import torch.nn.functional as F
from torch import Tensor

class GNN(torch.nn.Module):
    """Two-layer homogeneous GNN template, intended to be converted with ``to_hetero``.

    Args:
        hidden_channels: input and output width of both convolution layers.
    """

    def __init__(self, hidden_channels):
        super().__init__()
        # to_hetero calls each per-edge-type copy of these layers with an
        # (x_src, x_dst) tuple on bipartite edge types (e.g. OER -> Concept).
        # GCNConv rejects tuple input; SAGEConv supports bipartite message
        # passing. Because SAGEConv has no `cached` option, it also cannot
        # reuse a normalised graph computed for a different input graph.
        self.conv1 = SAGEConv(hidden_channels, hidden_channels)
        self.conv2 = SAGEConv(hidden_channels, hidden_channels)
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, x: Tensor, edge_index: Tensor) -> Tensor:
        """Apply two convolutions, each followed by a sigmoid activation."""
        x = self.sigmoid(self.conv1(x, edge_index))
        x = self.sigmoid(self.conv2(x, edge_index))
        return x
    
class Classifier(torch.nn.Module):
    """Score ordered OER pairs from their node representations.

    Args:
        input_channels: width of the concatenated pair vector, i.e. twice the
            per-node representation width.
    """

    def __init__(self, input_channels):
        super().__init__()
        self.linear = Linear(input_channels, 1)
        # Constructed but not applied in forward(): the output is a raw logit.
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, node, edge) -> Tensor:
        """Return one logit per candidate ("OER", "before", "OER") edge.

        Args:
            node: mapping whose "OER" entry is a 2-D feature matrix indexed by
                OER node id.
            edge: mapping whose "before" entry is a 2 x E index tensor; row 0
                selects the first OER of each pair, row 1 the second. Other
                entries are ignored.

        Returns:
            A 1-D tensor of length E.
        """
        oer = node["OER"]
        first_idx, second_idx = edge["before"][0], edge["before"][1]
        edge_vec = torch.cat((oer[first_idx], oer[second_idx]), dim=1)
        # squeeze(-1), not squeeze(): a bare squeeze turned the [1, 1] output
        # of a single candidate edge into a 0-d scalar.
        return self.linear(edge_vec).squeeze(-1)
    
class Model(torch.nn.Module):
    """Link predictor: learned node embeddings -> to_hetero(GNN) -> pair classifier.

    Each node type gets its own ``torch.nn.Embedding`` table. ``GNN`` is
    converted to a heterogeneous module with ``to_hetero``, and its OER
    outputs are scored pairwise by ``Classifier``.

    Args:
        hidden_channels: embedding width and GNN width.
    """

    def __init__(self, hidden_channels):
        super().__init__()
        # NOTE(review): table sizes and metadata come from the notebook-global
        # `data` object, which must exist before this class is instantiated.
        self.oer_emb = torch.nn.Embedding(data["OER"].num_nodes, hidden_channels)
        # The *_x tensors are stored but not read by forward().
        self.oer_x = data["OER"].x
        self.concept_emb = torch.nn.Embedding(data["Concept"].num_nodes, hidden_channels)
        self.concept_x = data["Concept"].x
        self.class_emb = torch.nn.Embedding(data["Class"].num_nodes, hidden_channels)
        self.class_x = data["Class"].x
        self.gnn = GNN(hidden_channels)
        self.gnn = to_hetero(self.gnn, metadata=data.metadata())
        # Each scored pair concatenates two hidden_channels-wide OER vectors.
        self.classifier = Classifier(hidden_channels * 2)

    def forward(self, data: HeteroData) -> Tensor:
        """Return one logit per candidate ("OER", "before", "OER") edge in ``data``."""
        node_dict = {
            "OER": self.oer_emb(data["OER"].node_id),
            "Concept": self.concept_emb(data["Concept"].node_id),
            # Class ids must index the Class table. Looking them up in
            # concept_emb left class_emb untrained and could index past the
            # end of concept_emb when there are more Class than Concept nodes.
            "Class": self.class_emb(data["Class"].node_id),
        }
        node_dict = self.gnn(node_dict, data.edge_index_dict)

        edge_dict = {
            "before": data["OER", "before", "OER"].edge_label_index,
            "covers": data['OER', 'covers', 'Concept'].edge_label_index,
            "belongs": data['Concept', 'belongs', 'Class'].edge_label_index,
        }

        return self.classifier(node_dict, edge_dict)
   
# Instantiate the embedding-table / to_hetero model defined in this cell.
model = Model(
    hidden_channels=128,
)

# In [None]:
from torch_geometric.nn import HeteroConv, GCNConv, GATConv, Linear
import torch.nn.functional as F
from torch import Tensor

class HeteroGNN(torch.nn.Module):
    """Stack of HeteroConv layers over the OER / Concept / Class graph.

    Raw features of each node type are first projected to ``hidden_channels``.
    They then pass through ``num_layers`` HeteroConv layers (ReLU after each),
    and the OER representations are mapped to ``out_channels`` by a final
    linear layer.

    Args:
        hidden_channels: width of the projected and hidden representations.
        out_channels: width of the returned OER representation.
        num_layers: number of HeteroConv message-passing layers.
    """

    def __init__(self, hidden_channels, out_channels, num_layers):
        super().__init__()
        # SAGEConv supports bipartite (src type != dst type) message passing;
        # GCNConv raises on the (x_src, x_dst) tuple HeteroConv passes it.
        from torch_geometric.nn import SAGEConv

        # Per-node-type input projection (lazy -1 input width, since raw
        # feature widths may differ between node types). self.convs holds
        # HeteroConv layers and cannot be indexed by node type.
        self.lin_dict = torch.nn.ModuleDict({
            node_type: Linear(-1, hidden_channels)
            for node_type in ("OER", "Concept", "Class")
        })

        self.convs = torch.nn.ModuleList()
        for _ in range(num_layers):
            conv = HeteroConv({
                # cached is left off: cached=True reuses the normalised graph
                # from the first call, which is wrong if the model is later
                # run on a different graph (e.g. another data split).
                ('OER', 'before', 'OER'): GCNConv(-1, hidden_channels, add_self_loops=False),
                # GATConv does not accept a `cached` argument.
                ('OER', 'covers', 'Concept'): GATConv((-1, -1), hidden_channels, add_self_loops=False),
                ('Concept', 'belongs', 'Class'): SAGEConv((-1, -1), hidden_channels),
            }, aggr='mean')  # TODO: experiment with aggr='cat' instead of 'mean'
            self.convs.append(conv)

        self.lin = Linear(hidden_channels, out_channels)

    def forward(self, x_dict, edge_index_dict):
        """Return the OER representations after the final linear layer.

        Args:
            x_dict: node-type -> feature matrix for "OER", "Concept", "Class".
            edge_index_dict: edge-type -> edge_index, as consumed by HeteroConv.

        Returns:
            ``self.lin`` applied to the OER node representations
            (``out_channels`` features per OER node).
        """
        x_dict = {
            node_type: self.lin_dict[node_type](x).relu_()
            for node_type, x in x_dict.items()
        }
        for conv in self.convs:
            x_dict = conv(x_dict, edge_index_dict)
            # Without a nonlinearity between layers, stacked convolutions
            # collapse towards a single linear transform.
            x_dict = {key: x.relu() for key, x in x_dict.items()}
        return self.lin(x_dict['OER'])


class Classifier(torch.nn.Module):
    """Score ordered OER pairs from their node representations.

    Args:
        input_channels: width of the concatenated pair vector, i.e. twice the
            per-node representation width.
    """

    def __init__(self, input_channels):
        super().__init__()
        self.linear = Linear(input_channels, 1)
        # Constructed but not applied in forward(): the output is a raw logit.
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, node, edge) -> Tensor:
        """Return one logit per candidate ("OER", "before", "OER") edge.

        Args:
            node: mapping whose "OER" entry is a 2-D feature matrix indexed by
                OER node id.
            edge: mapping whose "before" entry is a 2 x E index tensor; row 0
                selects the first OER of each pair, row 1 the second. Other
                entries are ignored.

        Returns:
            A 1-D tensor of length E.
        """
        oer = node["OER"]
        first_idx, second_idx = edge["before"][0], edge["before"][1]
        edge_vec = torch.cat((oer[first_idx], oer[second_idx]), dim=1)
        # squeeze(-1), not squeeze(): a bare squeeze turned the [1, 1] output
        # of a single candidate edge into a 0-d scalar.
        return self.linear(edge_vec).squeeze(-1)


class Model(torch.nn.Module):
    """Link predictor over raw node features: HeteroGNN -> pair classifier.

    Args:
        hidden_channels: hidden width inside ``HeteroGNN``.
        out_channels: width of the per-OER output of ``HeteroGNN``.
        num_layers: number of message-passing layers in ``HeteroGNN``.
    """

    def __init__(self, hidden_channels, out_channels, num_layers):
        super().__init__()
        self.gnn = HeteroGNN(hidden_channels, out_channels, num_layers)
        self.gnn = self.gnn.float()
        # HeteroGNN returns self.lin(x_dict['OER']), i.e. out_channels features
        # per OER node, so a concatenated pair has out_channels * 2 features
        # (hidden_channels * 2 did not match that width).
        # NOTE(review): with out_channels=1 each OER node is summarised by a
        # single scalar; a larger out_channels may be intended.
        self.classifier = Classifier(out_channels * 2)

    def forward(self, data: HeteroData) -> Tensor:
        """Return one logit per candidate ("OER", "before", "OER") edge in ``data``."""
        node_dict = {
            "OER": data["OER"].x,
            "Concept": data["Concept"].x,
            "Class": data["Class"].x,
        }
        # HeteroGNN returns a single OER tensor, not a dict; wrap it so the
        # classifier's node["OER"] lookup is valid.
        oer_repr = self.gnn(node_dict, data.edge_index_dict)
        node_dict = {"OER": oer_repr}

        edge_dict = {
            "before": data["OER", "before", "OER"].edge_label_index,
            "covers": data['OER', 'covers', 'Concept'].edge_label_index,
            "belongs": data['Concept', 'belongs', 'Class'].edge_label_index,
        }

        return self.classifier(node_dict, edge_dict)


# Instantiate the HeteroConv-based model defined in this cell.
model = Model(
    hidden_channels=64,
    out_channels=1,
    num_layers=2,
)