In [None]:
import torch
import torch.nn.functional as F
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree

class CustomSAGEConv(MessagePassing):
    """Mean-aggregating message-passing layer with scalar edge weights.

    Each neighbor's features are multiplied by the weight of the connecting
    edge, the weighted messages are averaged per target node (``aggr='mean'``),
    and the average is passed through a single linear layer.

    Args:
        in_channels: Size of each input node feature vector.
        out_channels: Size of each output node feature vector.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__(aggr='mean')
        self.lin = torch.nn.Linear(in_channels, out_channels)

    def forward(self, x, edge_index, edge_attr):
        """Run one round of weighted mean aggregation.

        Args:
            x: Node feature matrix; ``x.size(0)`` is used as the node count.
            edge_index: Edge list passed to ``add_self_loops`` / ``propagate``.
            edge_attr: One scalar weight per edge, given either as a flat
                ``(num_edges,)`` tensor or as a ``(num_edges, 1)`` column.

        Returns:
            The linearly transformed aggregated features, one row per node.
        """
        # Flat weights used to crash in the concatenation below because the
        # padding block is 2-D; promote them to a column first.
        if edge_attr.dim() == 1:
            edge_attr = edge_attr.unsqueeze(-1)

        num_nodes = x.size(0)
        edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes)

        # One zero weight per added self-loop, appended after the real edges so
        # the rows line up with the loops add_self_loops appends.
        # new_zeros keeps the dtype and device of the supplied weights.
        # NOTE(review): with a zero weight the self-loop contributes nothing to
        # the sum but is still an extra message in the mean - confirm intended.
        self_loop_attr = edge_attr.new_zeros((num_nodes, 1))
        edge_attr = torch.cat([edge_attr, self_loop_attr], dim=0)

        return self.propagate(edge_index, x=x, edge_attr=edge_attr)

    def message(self, x_j, edge_attr):
        """Scale each source-node feature row by its edge weight."""
        return x_j * edge_attr.view(-1, 1)

    def update(self, aggr_out):
        """Apply the linear projection to the aggregated messages."""
        return self.lin(aggr_out)

In [None]:
class GraphSAGE(torch.nn.Module):
    """Stack of ``CustomSAGEConv`` layers that produces node embeddings.

    Args:
        in_channels: Size of the input node features.
        hidden_channels: Width of every intermediate layer.
        out_channels: Size of the final node embeddings.
        num_layers: Number of convolution layers (must be at least 1).

    Raises:
        ValueError: If ``num_layers`` is smaller than 1.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, num_layers):
        super().__init__()
        if num_layers < 1:
            raise ValueError(f"num_layers must be >= 1, got {num_layers}")

        self.num_layers = num_layers

        # Feature widths at every layer boundary: in -> hidden ... hidden -> out.
        # A single layer maps in_channels straight to out_channels.
        widths = [in_channels] + [hidden_channels] * (num_layers - 1) + [out_channels]
        self.convs = torch.nn.ModuleList(
            [CustomSAGEConv(w_in, w_out) for w_in, w_out in zip(widths[:-1], widths[1:])]
        )

    def forward(self, x, adjs, edge_attrs, pos_pair=None, neg_pair=None):
        """Compute node embeddings and optionally gather them for node pairs.

        Args:
            x: Node feature matrix.
            adjs: Either a single ``edge_index`` tensor, reused by every layer,
                or a sequence with one ``edge_index`` per layer.
            edge_attrs: Edge weights matching ``adjs``: a single tensor when
                ``adjs`` is a tensor, otherwise one entry per layer.
            pos_pair: Optional pair of node-index tensors (``pos_pair[0]``,
                ``pos_pair[1]``) for positive examples.
            neg_pair: Optional pair of node-index tensors for negative examples.

        Returns:
            The embedding matrix when either pair argument is ``None``.
            Otherwise ``(pos_out, neg_out)``, where each element is a tuple of
            the embeddings gathered at the two index tensors of that pair.

        Raises:
            ValueError: If per-layer inputs are given and their count differs
                from the number of layers.
        """
        num_convs = len(self.convs)

        if isinstance(adjs, torch.Tensor):
            # Full-graph call: every layer sees the same edges and weights.
            layer_inputs = [(adjs, edge_attrs)] * num_convs
        else:
            layer_inputs = list(zip(adjs, edge_attrs))
            if len(layer_inputs) != num_convs:
                raise ValueError(
                    f"expected {num_convs} (edge_index, edge_attr) pairs, "
                    f"got {len(layer_inputs)}"
                )

        for i, (conv, (edge_index, edge_attr)) in enumerate(zip(self.convs, layer_inputs)):
            x = conv(x, edge_index, edge_attr)
            # No non-linearity or dropout after the final layer.
            if i != num_convs - 1:
                x = x.relu()
                x = F.dropout(x, p=0.5, training=self.training)

        if pos_pair is not None and neg_pair is not None:
            pos_out = x[pos_pair[0]], x[pos_pair[1]]
            neg_out = x[neg_pair[0]], x[neg_pair[1]]
            return pos_out, neg_out

        return x

In [None]:
import torch
from torch.nn import functional as F

def nce_loss(pos_out, neg_out, neg_sample_size, scale=100.0):
    """Softmax-based contrastive loss of positive scores against negative scores.

    Every positive score gets its own softmax row made of that score followed
    by its negatives; the loss is the mean negative log-probability assigned
    to the positive entry. With exactly one positive score this equals a
    single softmax over ``[pos, *neg]``.

    Args:
        pos_out: Tensor of positive scores; flattened to ``(num_pos,)``.
        neg_out: Tensor of negative scores. A 2-D tensor whose first dimension
            equals ``num_pos`` supplies separate negatives per positive;
            any other shape is flattened and shared by all positives.
        neg_sample_size: Accepted for interface compatibility; the number of
            negatives is taken from the shape of ``neg_out``.
        scale: Factor applied to all scores before the softmax.

    Returns:
        A 0-dimensional tensor holding the loss.

    Raises:
        ValueError: If ``pos_out`` contains no elements.
    """
    pos_scores = pos_out.reshape(-1)
    num_pos = pos_scores.numel()
    if num_pos == 0:
        raise ValueError("nce_loss requires at least one positive score")

    if neg_out.dim() == 2 and neg_out.size(0) == num_pos:
        neg_scores = neg_out
    else:
        # Same pool of negatives for every positive.
        neg_scores = neg_out.reshape(1, -1).expand(num_pos, -1)

    # Column 0 is the positive; keeping positives in separate rows stops them
    # from competing with each other in the softmax denominator.
    logits = torch.cat((pos_scores.unsqueeze(1), neg_scores), dim=1) * scale
    return -F.log_softmax(logits, dim=1)[:, 0].mean()

In [None]:
def train(data, model, optimizer, neg_sample_size, device, epochs):
    """Train ``model`` full-batch with a contrastive loss over the graph's edges.

    Every edge in ``data.edge_index`` is a positive pair. For each positive
    edge, ``neg_sample_size`` destination nodes are drawn uniformly at random
    each epoch and paired with the edge's source node as negatives.

    Args:
        data: Graph object exposing ``x``, ``edge_index``, ``edge_attr`` and ``to``.
        model: Module called as ``model(x, edge_index, edge_attr, pos_pair, neg_pair)``
            and returning ``((pos_a, pos_b), (neg_a, neg_b))`` embedding tuples.
        optimizer: Optimizer over the model's parameters.
        neg_sample_size: Number of negative destinations drawn per positive edge.
        device: Device the model and data are moved to.
        epochs: Number of full-batch optimisation steps.

    Prints the loss after every epoch; returns nothing.
    """
    model = model.to(device)
    data = data.to(device)

    num_nodes = data.x.size(0)
    pos_pair = data.edge_index
    src = data.edge_index[0]
    num_edges = src.numel()

    # Each source node is repeated once per negative so both rows of neg_pair
    # have the same length (stacking a per-edge row with a row of only
    # neg_sample_size nodes fails unless the two counts happen to match).
    neg_src = src.repeat_interleave(neg_sample_size)

    for epoch in range(epochs):
        model.train()
        # Clear gradients before the backward pass so nothing accumulated
        # earlier leaks into this step.
        optimizer.zero_grad()

        # Negative sampling: fresh random destinations every epoch.
        neg_dst = torch.randint(0, num_nodes, (neg_src.numel(), ), dtype=torch.long, device=data.x.device)
        neg_pair = torch.stack([neg_src, neg_dst])

        # Obtain the node embeddings gathered at both ends of every pair.
        (pos_a, pos_b), (neg_a, neg_b) = model(data.x, data.edge_index, data.edge_attr, pos_pair, neg_pair)

        # The loss consumes scores, not embedding tuples.
        # NOTE(review): cosine similarity is used so scores stay in [-1, 1]
        # before nce_loss scales them by 100 - confirm this is the intended
        # similarity measure.
        pos_score = F.cosine_similarity(pos_a, pos_b, dim=-1)
        neg_score = F.cosine_similarity(neg_a, neg_b, dim=-1).view(num_edges, neg_sample_size)

        loss = nce_loss(pos_score, neg_score, neg_sample_size)

        loss.backward()
        optimizer.step()

        print(f"Epoch: {epoch+1}, Loss: {loss.item()}")


In [None]:
from torch_geometric.data import NeighborSampler
from torch.optim import AdamW

# Driver cell: build the model, train it full-batch, then extract embeddings.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
data = data.to(device)

# Model and optimizer
model = GraphSAGE(in_channels=data.num_node_features, hidden_channels=64, out_channels=128, num_layers=2).to(device)
optimizer = AdamW(model.parameters(), lr=0.01)

# Data loader
# NOTE(review): nothing in this cell consumes `loader`; train() is handed the
# full `data` object. Kept in case a later cell relies on it - confirm.
loader = NeighborSampler(edge_index=data.edge_index, sizes=[-1]*model.num_layers, batch_size=128)

# Train the model
epochs = 20
neg_sample_size = 5
train(data, model, optimizer, neg_sample_size, device, epochs)

# Extract embeddings
model.eval()
with torch.no_grad():
    # Called without pos/neg pairs the model returns a single embedding tensor,
    # so it must not be unpacked into two names.
    x = model(data.x, data.edge_index, data.edge_attr)
