<a href="https://colab.research.google.com/github/Ahmed-A-A-Elhag/GraphSAGE/blob/main/GraphSAGE.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install PyTorch Geometric and its compiled extensions.
# NOTE(review): the wheel index URL is pinned to torch 1.9.0 + CUDA 10.2;
# confirm it matches the torch build actually installed in the runtime.
!pip install -q torch-scatter -f https://pytorch-geometric.com/whl/torch-1.9.0+cu102.html
!pip install -q torch-sparse -f https://pytorch-geometric.com/whl/torch-1.9.0+cu102.html
!pip install -q torch-geometric

[K     |████████████████████████████████| 3.0 MB 6.4 MB/s 
[K     |████████████████████████████████| 1.6 MB 7.2 MB/s 
[K     |████████████████████████████████| 222 kB 8.1 MB/s 
[K     |████████████████████████████████| 376 kB 43.2 MB/s 
[K     |████████████████████████████████| 45 kB 3.4 MB/s 
[?25h  Building wheel for torch-geometric (setup.py) ... [?25l[?25hdone


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import random
import math

import torch_geometric
from torch_geometric.utils import to_dense_adj


In [None]:
# Seed every RNG the notebook imports (Python, NumPy, PyTorch CPU/CUDA)
# so that repeated runs give the same initialisation and results.
seed = 150
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
    # Seed all visible GPUs, not only the current one.
    torch.cuda.manual_seed_all(seed)

In [None]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
# Show the selected device (the notebook displays the value of the last expression).
device

device(type='cpu')

In [None]:
def train(model, data, num_epochs, use_edge_index=False, lr=0.01,
          weight_decay=5e-4, device=None):
    """Train ``model`` full-batch on ``data`` with Adam and NLL loss.

    After every optimisation step the model is evaluated on the validation
    and test masks. A progress line is printed whenever the validation
    accuracy improves, and once more for the final epoch.

    Args:
        model: module called as ``model(x, adj)``. Its output is fed to
            ``nll_loss``, so it should return per-node log-probabilities
            of shape (num_nodes, num_classes).
        data: graph object exposing ``x``, ``y``, ``edge_index``,
            ``train_mask``, ``val_mask``, ``test_mask`` and ``to(device)``
            (for example a PyG ``Data`` object).
        num_epochs: number of full-batch training steps.
        use_edge_index: if True, pass ``data.edge_index`` to the model as
            ``adj``; otherwise pass a dense (num_nodes, num_nodes)
            adjacency matrix built with ``to_dense_adj``. Models that
            unpack ``adj`` as ``u, v`` need this set to True.
        lr: Adam learning rate.
        weight_decay: Adam weight decay (L2 penalty).
        device: device to train on. Defaults to CUDA when available,
            otherwise CPU.

    Returns:
        The best validation accuracy seen (a Python float), or -1.0 when
        ``num_epochs`` is 0.
    """
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Move model and data once up front instead of re-sending tensors every epoch.
    model = model.to(device)
    data = data.to(device)

    if use_edge_index:
        # Hand the sparse (2, num_edges) edge list straight to the model.
        adj = data.edge_index
    else:
        # Dense adjacency matrix for models that expect one.
        adj = to_dense_adj(data.edge_index)[0]

    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

    def get_acc(outs, y, mask):
        """Return the fraction of masked nodes whose argmax prediction equals the label."""
        correct = (outs[mask].argmax(dim=1) == y[mask]).sum().float()
        # clamp keeps an empty mask from producing 0/0 = NaN
        return (correct / mask.sum().clamp(min=1)).item()

    best_acc_val = -1.0
    msg = None
    for epoch in range(num_epochs):
        # Training step: zero grads -> forward -> loss -> backprop -> update.
        model.train()
        optimizer.zero_grad()
        outs = model(data.x, adj)
        loss = torch.nn.functional.nll_loss(outs[data.train_mask], data.y[data.train_mask])
        loss.backward()
        optimizer.step()

        # Evaluation pass: no autograd graph is needed.
        model.eval()
        with torch.no_grad():
            outs = model(data.x, adj)
        acc_val = get_acc(outs, data.y, data.val_mask)
        acc_test = get_acc(outs, data.y, data.test_mask)

        msg = (f'[Epoch {epoch+1}/{num_epochs}] Loss: {loss.item():.4f} '
               f'| Val: {acc_val:.3f} | Test: {acc_test:.3f}')

        # Report only when validation accuracy improves.
        if acc_val > best_acc_val:
            best_acc_val = acc_val
            print(msg)

    # Always report the final epoch; skipped when no epoch ran.
    if msg is not None:
        print(msg)
    return best_acc_val

In [None]:
class GraphSAGE_Mean(torch.nn.Module):
    """GraphSAGE layer with a mean neighbour aggregator.

    For every node ``i`` the layer computes::

        out_i = W_l x_i + W_r * mean({x_j : (i, j) in edge_index})

    and, when ``normalize`` is True, rescales each output row to unit L2
    norm. Nodes that never occur in row 0 of ``edge_index`` receive a zero
    neighbour message.

    The layer returns raw embeddings, not log-probabilities. It is also
    used as a hidden layer, and ``ReLU(log_softmax(.))`` is identically
    zero, so any ``log_softmax`` belongs on the model output.

    Args:
        in_features: size of each input node feature vector.
        out_features: size of each output node feature vector.
        normalize: L2-normalise each output row.
        bias: add a learnable bias to both linear maps.
    """

    def __init__(self, in_features, out_features, normalize=True, bias=False):
        super().__init__()

        self.in_features = in_features
        self.out_features = out_features
        self.normalize = normalize
        self.bias = bias

        # Linear map applied to the central node's own features.
        self.linear_l = torch.nn.Linear(self.in_features, self.out_features, bias=self.bias)

        # Linear map applied to the aggregated neighbour message.
        self.linear_r = torch.nn.Linear(self.in_features, self.out_features, bias=self.bias)

    def forward(self, fts, edge_index):
        """Apply the layer.

        Args:
            fts: node features, shape (num_nodes, in_features).
            edge_index: (2, num_edges) integer tensor. Row 0 (``u``) is the
                node that receives a message, and row 1 (``v``) is the node
                whose features are averaged into it. For an edge list stored
                in both directions, this is the usual neighbourhood mean.

        Returns:
            Tensor of shape (num_nodes, out_features) on ``fts``'s device.
        """
        u, v = edge_index
        u = u.to(fts.device)
        v = v.to(fts.device)
        num_nodes = fts.size(0)

        # Scatter-mean of fts[v] into rows u. The result always has num_nodes
        # rows, so it lines up with fts even when the highest node ids have
        # no entries in u.
        summed = fts.new_zeros((num_nodes, fts.size(1))).index_add_(0, u, fts[v])
        counts = fts.new_zeros(num_nodes).index_add_(0, u, fts.new_ones(u.size(0)))
        aggregate = summed / counts.clamp(min=1).unsqueeze(-1)

        out = self.linear_l(fts) + self.linear_r(aggregate)

        if self.normalize:
            # F.normalize clamps the norm by eps, so all-zero rows stay zero instead of NaN.
            out = torch.nn.functional.normalize(out, p=2, dim=1)

        return out



In [None]:
class GraphSAGE_MaxPooling(torch.nn.Module):
    """GraphSAGE layer with a max-pooling neighbour aggregator.

    For every node ``i`` the layer computes::

        h_j   = ReLU(W_r x_j)
        agg_i = elementwise max over {h_j : (i, j) in edge_index}
        out_i = W_l [x_i || agg_i]

    and, when ``normalize`` is True, rescales each output row to unit L2
    norm. Nodes that never occur in row 0 of ``edge_index`` get a zero
    pooled message.

    Args:
        in_features: size of each input node feature vector.
        out_features: size of each output node feature vector.
        normalize: L2-normalise each output row.
        bias: add a learnable bias to both linear maps.
    """

    def __init__(self, in_features, out_features, normalize=True, bias=False):
        super().__init__()

        self.in_features = in_features
        self.out_features = out_features
        self.normalize = normalize
        self.bias = bias

        # Linear map applied to the concatenation [x_i || agg_i].
        self.linear_l = torch.nn.Linear(2 * self.in_features, self.out_features, bias=self.bias)

        # Linear map applied to neighbour features before max pooling.
        self.linear_r = torch.nn.Linear(self.in_features, self.in_features, bias=self.bias)

        # Non-linearity applied before pooling.
        self.relu = torch.nn.ReLU()

    def forward(self, fts, edge_index):
        """Apply the layer.

        Args:
            fts: node features, shape (num_nodes, in_features).
            edge_index: (2, num_edges) integer tensor. Row 0 (``u``) is the
                node that receives a message, and row 1 (``v``) is the node
                whose transformed features are pooled into it.

        Returns:
            Tensor of shape (num_nodes, out_features) on ``fts``'s device.
        """
        u, v = edge_index
        u = u.to(fts.device)
        v = v.to(fts.device)

        # Transform each node once (num_nodes rows) rather than each edge.
        hidden = self.relu(self.linear_r(fts))
        messages = hidden[v]

        # Scatter-max of messages into rows u. The output is sized to
        # num_nodes, and rows with no incoming message keep the zero init.
        aggregate = hidden.new_zeros(hidden.shape).scatter_reduce(
            0, u.unsqueeze(-1).expand_as(messages), messages,
            reduce='amax', include_self=False)

        out = self.linear_l(torch.cat([fts, aggregate], dim=1))

        if self.normalize:
            # F.normalize clamps the norm by eps, so all-zero rows stay zero instead of NaN.
            out = torch.nn.functional.normalize(out, p=2, dim=1)

        return out



In [None]:

class GraphSAGE(torch.nn.Module):
    """Two-layer GraphSAGE node classifier.

    Architecture: aggregator layer -> ReLU -> aggregator layer ->
    log_softmax, so the output can be fed directly to ``nll_loss``.

    Args:
        nfeat: number of input features per node.
        nhid: hidden embedding size.
        nclass: number of output classes.
        aggregator: ``'Mean'`` or ``'MaxPooling'``.

    Raises:
        ValueError: if ``aggregator`` is not one of the supported names.
    """

    def __init__(self, nfeat, nhid, nclass, aggregator='Mean'):
        super().__init__()
        # Validate first, so an unknown name fails here instead of as an
        # AttributeError on gc1 during the first forward pass.
        if aggregator == 'Mean':
            layer_cls = GraphSAGE_Mean
        elif aggregator == 'MaxPooling':
            layer_cls = GraphSAGE_MaxPooling
        else:
            raise ValueError(
                f"Unknown aggregator {aggregator!r}; expected 'Mean' or 'MaxPooling'")

        self.gc1 = layer_cls(nfeat, nhid)
        self.gc2 = layer_cls(nhid, nclass)
        self.relu = torch.nn.ReLU()

    def forward(self, fts, adj):
        """Return per-node log-class-probabilities of shape (num_nodes, nclass)."""
        fts = self.relu(self.gc1(fts, adj))
        fts = self.gc2(fts, adj)
        # log_softmax is idempotent, so this is safe even if a layer already applied it.
        return fts.log_softmax(dim=-1)

In [None]:
# Load the Cora citation graph (Planetoid split). `root` is the directory where
# PyG stores the downloaded raw/processed files; use a local, writable
# directory rather than the filesystem root.
Cora = torch_geometric.datasets.Planetoid(root='data/Planetoid', name='Cora')

In [None]:
# Train a two-layer max-pooling GraphSAGE (1024 hidden units) on Cora for 300 epochs.
train(
    model=GraphSAGE(
        nfeat=Cora.num_features,
        nhid=1024,
        nclass=Cora.num_classes,
        aggregator='MaxPooling',
    ),
    data=Cora[0],
    num_epochs=300,
    use_edge_index=True,
)