<a href="https://colab.research.google.com/github/aldrich1221/recommendation-system/blob/main/GNN/GraphSAGE.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init

def sampling(sourceNodes,samplingNum,nodeTable):
  """Draw a fixed number of neighbors for every source node.

  Args:
    sourceNodes: iterable of node ids used as keys/indices into nodeTable.
    samplingNum: how many neighbors to draw per source node.
    nodeTable: mapping (or sequence) from node id to its neighbor ids.

  Returns:
    A flattened NumPy array holding the draws of each source node in order,
    samplingNum entries per node.
  """
  # np.random.choice is called with its default settings, so the same
  # neighbor may be drawn more than once for a given node.
  drawn=[
    np.random.choice(nodeTable[node],size=(samplingNum,))
    for node in sourceNodes
  ]
  return np.asarray(drawn).flatten()

def multiSampling(sourceNodes,samplingNum,nodeTable):
  """Sample neighbors hop by hop, starting from the source nodes.

  Args:
    sourceNodes: node ids forming hop 0.
    samplingNum: iterable giving the number of neighbors to draw per node
      at each successive hop.
    nodeTable: mapping (or sequence) from node id to its neighbor ids.

  Returns:
    A list whose first entry is sourceNodes itself and whose k-th entry is
    the result of sampling() applied to the (k-1)-th entry.
  """
  hops=[sourceNodes]
  for fanout in samplingNum:
    # Each hop expands from the nodes produced by the previous hop.
    hops.append(sampling(hops[-1],fanout,nodeTable))
  return hops



In [None]:
class Aggregator(nn.Module):
  """Reduce neighbor features over dim 1 and project them linearly.

  Args:
    inputDim: size of the incoming feature dimension.
    outputDim: size of the projected feature dimension.
    aggMethod: 'mean' or 'max'; how neighbors are reduced over dim 1.
    useBias: when True (default) a learnable bias is added to the projection.
  """

  def __init__(self,inputDim,outputDim,aggMethod='mean',useBias=True):
    super(Aggregator,self).__init__()
    self.inputDim=inputDim
    self.outputDim=outputDim
    self.useBias=useBias
    self.aggMethod=aggMethod
    self.weight=nn.Parameter(torch.empty(inputDim,outputDim))
    if useBias:
      self.bias=nn.Parameter(torch.empty(outputDim))
    else:
      # Keep the attribute present so forward() can test it uniformly.
      self.register_parameter('bias',None)
    self.reset()

  def reset(self):
    """(Re)initialise the weight with Kaiming-uniform and the bias with zeros."""
    init.kaiming_uniform_(self.weight)
    if self.bias is not None:
      init.zeros_(self.bias)

  def forward(self,neighborFeature):
    """Aggregate over dim 1 of neighborFeature and apply the projection.

    Args:
      neighborFeature: tensor whose dim 1 enumerates the neighbors and whose
        last dim has size inputDim.

    Returns:
      Tensor with dim 1 reduced away and the last dim of size outputDim.

    Raises:
      ValueError: if aggMethod is neither 'mean' nor 'max'.
    """
    if self.aggMethod=='mean':
      aggNeighbor=neighborFeature.mean(dim=1)
    elif self.aggMethod=='max':
      # max(dim=...) returns (values, indices); only the values are features.
      aggNeighbor=neighborFeature.max(dim=1).values
    else:
      raise ValueError("please use mean/max aggregation")
    neighborHidden=torch.matmul(aggNeighbor,self.weight)
    if self.bias is not None:
      neighborHidden=neighborHidden+self.bias
    return neighborHidden




In [None]:
class SageGraphCovNet(nn.Module):
  """One GraphSAGE layer: combine a node's own projection with its neighbors'.

  Args:
    inputDim: size of the incoming feature dimension.
    hiddenDim: size of each projection's output dimension.
    aggNeighborMethod: reduction passed to Aggregator as aggMethod.
    aggHiddenMethod: 'sum' adds the self and neighbor projections; 'concat'
      joins them along the last dim, giving an output of width 2*hiddenDim.
  """

  def __init__(self,inputDim,hiddenDim,aggNeighborMethod="mean",aggHiddenMethod="sum"):
    super(SageGraphCovNet,self).__init__()
    self.aggNeighborMethod=aggNeighborMethod
    self.aggHiddenMethod=aggHiddenMethod
    self.activation=F.relu
    self.aggregator=Aggregator(inputDim,hiddenDim,aggMethod=aggNeighborMethod)
    self.weight=nn.Parameter(torch.empty(inputDim,hiddenDim))
    # Without this call the weight would stay as uninitialised memory.
    self.reset()

  def reset(self):
    """(Re)initialise the self-projection weight with Kaiming-uniform."""
    init.kaiming_uniform_(self.weight)

  def forward(self,sourceNodeFeatures,neighborNodeFeature):
    """Compute the activated hidden state of the source nodes.

    Args:
      sourceNodeFeatures: tensor whose last dim has size inputDim.
      neighborNodeFeature: tensor forwarded unchanged to the aggregator.

    Returns:
      ReLU of the combined self and neighbor projections.

    Raises:
      ValueError: if aggHiddenMethod is neither 'sum' nor 'concat'.
    """
    neighborHidden=self.aggregator(neighborNodeFeature)
    currentNodeHidden=torch.matmul(sourceNodeFeatures,self.weight)
    if self.aggHiddenMethod=="sum":
      hidden=currentNodeHidden+neighborHidden
    elif self.aggHiddenMethod=="concat":
      # Join along the feature axis; the default dim=0 would stack the two
      # projections as extra rows instead of widening each node's features.
      hidden=torch.cat([currentNodeHidden,neighborHidden],dim=-1)
    else:
      raise ValueError("Please use sum/concat aggregation")
    return self.activation(hidden)

In [None]:
class GraphSage(nn.Module):
  """Stack of SageGraphCovNet layers applied to multi-hop sampled features.

  Args:
    inputDim: feature size of the raw node features.
    hiddenDim: output size of each layer; needs at least one entry per
      entry of numNeighbors (extra entries are ignored).
    numNeighbors: number of neighbors sampled per node at each hop; its
      length determines the number of layers.

  Raises:
    ValueError: if hiddenDim has fewer entries than numNeighbors.
  """

  def __init__(self,inputDim,hiddenDim=(64,64),numNeighbors=(10,10)):
    super(GraphSage,self).__init__()
    self.inputDim=inputDim
    # Copy into lists so later mutation of the caller's object has no effect.
    self.hiddenDim=list(hiddenDim)
    self.numNeighbors=list(numNeighbors)
    self.numLayers=len(self.numNeighbors)
    if len(self.hiddenDim)<self.numLayers:
      raise ValueError("hiddenDim needs one entry per entry of numNeighbors")
    layerDims=[inputDim]+self.hiddenDim[:self.numLayers]
    # nn.ModuleList registers the layers so their parameters are visible to
    # .parameters(), .to() and state_dict(); a plain list would hide them.
    self.gcn=nn.ModuleList([
      SageGraphCovNet(dimIn,dimOut)
      for dimIn,dimOut in zip(layerDims[:-1],layerDims[1:])
    ])

  def forward(self,nodeFeatures):
    """Propagate features inward from the outermost hop to the source nodes.

    Args:
      nodeFeatures: sequence of tensors, one per hop; entry 0 holds the
        source nodes and entry k+1 holds numNeighbors[k] rows for every row
        of entry k.

    Returns:
      The hidden representation of the source nodes after all layers.

    Raises:
      ValueError: if fewer than numLayers+1 hop tensors are supplied.
    """
    if len(nodeFeatures)<self.numLayers+1:
      raise ValueError("nodeFeatures needs numLayers+1 entries (one per hop)")
    hidden=nodeFeatures
    for layer in range(self.numLayers):
      gcn=self.gcn[layer]
      nextHidden=[]
      # Every layer consumes the outermost remaining hop, so one fewer hop
      # is left to update after each layer.
      for hop in range(self.numLayers-layer):
        sourceNodeFeatures=hidden[hop]
        sourceNodeNum=len(sourceNodeFeatures)
        # Regroup the flat neighbor rows as (source node, neighbor, feature).
        neighborNodeFeatures=hidden[hop+1].view(sourceNodeNum,self.numNeighbors[hop],-1)
        nextHidden.append(gcn(sourceNodeFeatures,neighborNodeFeatures))
      hidden=nextHidden
    return hidden[0]


