In [None]:
import tensorflow as tf 
from tensorflow import keras
from keras import layers,optimizers,regularizers,losses 
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os
from collections import defaultdict
import collections
from functools import reduce
from itertools import islice
import time
from Data import load_ppi
from BatchLoader import Build_batch_from_edges


In [None]:
# Location of the PPI dataset. Built with os.path.join so the separator is
# right on every platform (on Windows the value is still "Data\\ppi").
PATH = os.path.join("Data", "ppi")

# Forwarded to Build_batch_from_edges as `sample_sizes`; its length is also
# used as the number of aggregation layers of the model.
# Presumably one neighbour-sample size per hop -- confirm in BatchLoader.
SAMPLE_SIZES = [25, 10]
INTERNAL_DIM = 128    # output width of every aggregation layer
NEG_WEIGHT = 1.0      # multiplier applied to the negative-sample loss term
BATCH_SIZE = 512      # number of edges drawn (with replacement) per batch
NEG_SIZE = 20         # forwarded to Build_batch_from_edges as `neg_size`
TRAINING_STEPS = 100  # number of batches consumed by train()
LR = 0.001            # learning rate handed to the Adam optimizer

In [None]:


def generate_batch(adj_lists, batch_size, sample_sizes, neg_size):
    """Yield an endless stream of batches built from randomly drawn edges.

    Args:
        adj_lists: mapping from a node to an iterable of its neighbours.
        batch_size: number of edges drawn (with replacement) for each batch.
        sample_sizes: forwarded unchanged to ``Build_batch_from_edges``.
        neg_size: forwarded unchanged to ``Build_batch_from_edges``.

    Yields:
        The object returned by ``Build_batch_from_edges`` for each draw.
    """
    # Flatten the adjacency mapping into one (node, neighbour) row per edge.
    edge_pairs = []
    for src in adj_lists:
        for dst in adj_lists[src]:
            edge_pairs.append((src, dst))
    all_edges = np.array(edge_pairs)
    all_nodes = np.array(list(adj_lists.keys()))
    num_edges = all_edges.shape[0]

    while True:
        # Uniform row indices in [0, num_edges); duplicates are possible.
        picked_rows = np.random.randint(num_edges, size=batch_size)
        chosen_edges = all_edges[picked_rows, :]
        yield Build_batch_from_edges(
            chosen_edges, all_nodes, adj_lists, sample_sizes, neg_size
        )





# Build Model

In [None]:


class Features(tf.keras.layers.Layer):
    """Lookup layer that returns entries of a fixed feature table."""

    def __init__(self, features_data):
        """Store ``features_data`` as a constant float32 tensor."""
        super().__init__()
        self.features = tf.constant(features_data, dtype=tf.float32)

    def call(self, nodes):
        """Return the feature-table entries selected by the indices ``nodes``."""
        selected = tf.gather(self.features, nodes)
        return selected


class MeanAgg(tf.keras.layers.Layer):
    """Aggregation layer: combine neighbour features, append self features, project.

    The layer owns one trainable matrix of shape ``(2 * hu_shape, hv_shape)``;
    the factor 2 accounts for the concatenation performed in ``call``.
    """

    def __init__(self, hu_shape, hv_shape, name, active=True):
        """Create the projection weight.

        Args:
            hu_shape: feature width of the incoming node features.
            hv_shape: feature width produced by this layer.
            name: prefix for the weight's name (the weight is ``<name>_w``).
            active: when true, ``call`` applies ReLU to its result.
        """
        super().__init__()
        self.active = active
        self.w = self.add_weight(
            name=f"{name}_w",
            shape=(2 * hu_shape, hv_shape),
            dtype=tf.float32,
            trainable=True,
        )

    def call(self, nodes_features, hu, hv, dif_mat):
        """Compute the layer output for one sampling level.

        Args:
            nodes_features: feature matrix that ``hu`` and ``hv`` index into.
            hu: indices of the rows that get aggregated.
            hv: indices of the rows kept as "self" features.
            dif_mat: matrix left-multiplied onto the ``hu`` rows.
                Presumably a row-normalised adjacency so the product is a
                neighbour mean -- TODO confirm against BatchLoader.

        Returns:
            ``concat([dif_mat @ features[hu], features[hv]], axis=1) @ w``,
            passed through ReLU when the layer is active.
        """
        neighbour_feats = tf.gather(nodes_features, hu)
        self_feats = tf.gather(nodes_features, hv)
        aggregated = tf.matmul(dif_mat, neighbour_feats)

        projected = tf.matmul(tf.concat([aggregated, self_feats], 1), self.w)
        return tf.nn.relu(projected) if self.active else projected

    

class GraphSage(tf.keras.Model):
    """Stack of ``MeanAgg`` layers on top of a fixed feature lookup."""

    def __init__(self, data_features, num_layers, in_dim, active):
        """Build the feature lookup and the aggregation layers.

        Args:
            data_features: feature table handed to ``Features``; must expose
                ``.shape`` because its last dimension is the first layer's
                input width.
            num_layers: number of ``MeanAgg`` layers to create.
            in_dim: output width of every layer (and therefore the input
                width of every layer after the first).
            active: whether the *last* layer applies ReLU; all earlier
                layers always do.
        """
        super().__init__()
        self.input_layer = Features(data_features)
        self.agg_layers = []
        for i in range(1, num_layers + 1):
            name = "agg_" + str(i)
            input_dim = in_dim if i > 1 else data_features.shape[-1]
            isactive = active if i == num_layers else True
            self.agg_layers.append(
                MeanAgg(input_dim, in_dim, name=name, active=isactive)
            )

    def call(self, batch):
        """Run the aggregation stack on ``batch`` and return the final features.

        ``batch`` must expose ``src_nodes`` plus three equally ordered
        sequences ``dst_src_src``, ``dst_src_dst`` and ``dif_mats``. They are
        consumed from their last element towards their first, one element
        per layer.

        The sequences are read by negative index rather than ``.pop()`` so
        the batch is left untouched: popping emptied them, which made any
        second forward pass on the same batch fail with ``IndexError``.
        A sequence shorter than the number of layers still raises
        ``IndexError``, as before.
        """
        x = self.input_layer(tf.squeeze(batch.src_nodes))
        for depth, agg_layer in enumerate(self.agg_layers, start=1):
            # -1 for the first layer, -2 for the second, ... (same order as
            # successive pops, without mutating the batch).
            x = agg_layer(
                x,
                batch.dst_src_src[-depth],
                batch.dst_src_dst[-depth],
                batch.dif_mats[-depth],
            )
        return x
    
class GraphSageSupervise(GraphSage):
    """GraphSage encoder followed by a bias-free softmax classification head."""

    def __init__(self, data_feats, in_dim, num_layers, num_classes):
        """Build the encoder (ReLU on every layer) and the classifier.

        Note that this constructor takes ``in_dim`` before ``num_layers``,
        the reverse of the base class, hence the keyword arguments below.
        """
        super().__init__(
            data_features=data_feats,
            num_layers=num_layers,
            in_dim=in_dim,
            active=True,
        )
        self.classifier = tf.keras.layers.Dense(
            units=num_classes,
            activation=tf.nn.softmax,
            use_bias=False,
            name="classifier",
        )

    def call(self, batch):
        """Return the softmax class scores for the encoded ``batch``."""
        encoded = super().call(batch)
        return self.classifier(encoded)


def unsupervise_loss(embeddingA, embeddingB, embeddingN, neg_weight):
    """Negative-sampling loss over positive pairs and shared negatives.

    Args:
        embeddingA: rank-2 tensor; row ``i`` is one endpoint of positive pair ``i``.
        embeddingB: rank-2 tensor, same shape as ``embeddingA``; row ``i`` is
            the other endpoint of positive pair ``i``.
        embeddingN: rank-2 tensor of negative embeddings; every row of
            ``embeddingA`` is scored against every row of ``embeddingN``.
        neg_weight: multiplier applied to the summed negative loss.

    Returns:
        Scalar tensor: ``(sum(pos_loss) + neg_weight * sum(neg_loss))``
        divided by the number of rows in ``embeddingA``.
    """
    # Row-wise dot product for positives, all-pairs dot products for negatives.
    pos_logits = tf.reduce_sum(tf.multiply(embeddingA, embeddingB), axis=1)
    neg_logits = tf.matmul(embeddingA, tf.transpose(embeddingN))

    # Keyword arguments make the (labels, logits) roles explicit.
    pos_loss = tf.nn.sigmoid_cross_entropy_with_logits(
        labels=tf.ones_like(pos_logits), logits=pos_logits
    )
    neg_loss = tf.nn.sigmoid_cross_entropy_with_logits(
        labels=tf.zeros_like(neg_logits), logits=neg_logits
    )

    neg_loss_w = tf.multiply(neg_weight, tf.reduce_sum(neg_loss))
    batch_loss = tf.add(neg_loss_w, tf.reduce_sum(pos_loss))

    # Use the dynamic shape: the static `embeddingA.shape[0]` is None when
    # the leading dimension is unknown at trace time, which breaks the divide.
    num_pairs = tf.cast(tf.shape(embeddingA)[0], batch_loss.dtype)
    return tf.divide(batch_loss, num_pairs)


class GraphSageUnsupervise(GraphSage):
    """GraphSage encoder trained with the negative-sampling loss.

    The loss is registered through ``add_loss`` on every forward pass, so
    callers read it from ``model.losses`` after calling the model.
    """

    def __init__(self, data_feats, in_dim, num_layers, neg_weight):
        """Build the encoder (no ReLU on the last layer) and store ``neg_weight``.

        Args:
            data_feats: feature table forwarded to ``GraphSage``.
            in_dim: output width of every aggregation layer.
            num_layers: number of aggregation layers.
            neg_weight: multiplier for the negative term of the loss.
        """
        super().__init__(data_feats, num_layers, in_dim, False)
        self.neg_weight = neg_weight

    def call(self, batch):
        """Return unit-length embeddings and register the unsupervised loss.

        Besides the attributes required by ``GraphSage.call``, ``batch`` must
        expose ``dst2batchA`` and ``dst2batchB`` (indices used with
        ``tf.gather``) and ``dst2batchN`` (a mask used with
        ``tf.boolean_mask``).
        """
        x = super().call(batch)
        # Normalise each embedding row separately. Without `axis`,
        # l2_normalize scales the whole matrix to unit norm, so individual
        # embeddings shrink as the batch grows and the dot-product logits in
        # the loss collapse towards zero.
        embedding = tf.math.l2_normalize(x, axis=1)
        self.add_loss(
            unsupervise_loss(
                tf.gather(embedding, batch.dst2batchA),
                tf.gather(embedding, batch.dst2batchB),
                tf.boolean_mask(embedding, batch.dst2batchN),
                self.neg_weight,
            )
        )
        return embedding

In [None]:
def train():
    """Train the unsupervised GraphSage model on PPI and report timings.

    Loads the data with ``load_ppi``, draws ``TRAINING_STEPS`` random edge
    batches, takes one Adam step per batch, and prints the loss of every step
    followed by the mean wall-clock time of a step. Returns ``None``.
    """
    # The node count returned by load_ppi is not needed here.
    _num_nodes, feat_data, adj_lists = load_ppi()
    batch_generator = generate_batch(adj_lists, BATCH_SIZE, SAMPLE_SIZES, NEG_SIZE)

    model = GraphSageUnsupervise(feat_data, INTERNAL_DIM, len(SAMPLE_SIZES), NEG_WEIGHT)
    optimizer = optimizers.Adam(LR)

    times = []
    batches = islice(batch_generator, TRAINING_STEPS)
    for step, batch in enumerate(batches, start=1):
        # Only the forward/backward/update work is timed; building the batch
        # happens in the generator before the clock starts.
        start_t = time.perf_counter()
        with tf.GradientTape() as tape:
            _ = model(batch)
            # The model registers its loss via add_loss during the call.
            loss = model.losses[0]
        grads = tape.gradient(loss, model.trainable_weights)
        optimizer.apply_gradients(zip(grads, model.trainable_weights))
        end_t = time.perf_counter()
        # Elapsed time is end minus start; the reversed subtraction made
        # every recorded duration (and the reported average) negative.
        times.append(end_t - start_t)
        print(f"Loss in step {step} is :{loss.numpy()}")
    print(f"average batch time is {np.mean(times)}")

# Run the training loop: loads the PPI data, trains for TRAINING_STEPS
# batches and prints the per-step loss plus the average step time.
train()