# RelaTree - Social Graph Analytics

This notebook builds a social graph from the given data about users and the groups they belong to.

In [1]:
# Import libraries 
import os
os.environ['PYSPARK_SUBMIT_ARGS'] = (
    '--jars lib/graphframes-release-0-5-0-assembly-0.5.0-spark2.1.jar pyspark-shell')
import lib.graphframes as GF
from pyspark import SparkContext
from pyspark import SparkConf
from pyspark.sql import SQLContext
from pyspark.sql.functions import col, size, lit, collect_list

In [2]:
def createContext(app_name="RelaTree", executor_memory="2g", no_of_executors="4", master="local"):
    '''
    Function: Create (or reuse) a Spark context and wrap it in an SQLContext
    Parameters:
        app_name        - Spark application name
        executor_memory - value for spark.executor.memory (e.g. "2g")
        no_of_executors - value for spark.executor.instances
        master          - Spark master URL; defaults to "local" as before
    Returns: Spark SQLContext
    '''
    # NOTE(review): with a "local" master the executor settings most likely
    # have no effect -- confirm before tuning them.
    sparkConf = (SparkConf()
                 .setMaster(master)
                 .setAppName(app_name)
                 .set("spark.executor.memory", executor_memory)
                 .set("spark.executor.instances", no_of_executors))
    # getOrCreate reuses an already-running context instead of raising
    # "Cannot run multiple SparkContexts at once" when this cell is re-run.
    # If a context already exists, its configuration wins over sparkConf.
    sparkContext = SparkContext.getOrCreate(conf=sparkConf)
    return SQLContext(sparkContext)

In [3]:
def loadData(file_path, sql_context):
    '''
    Function: Read a CSV file (first line is the header) into a Spark DataFrame
    Parameters:
        file_path   - path of the CSV file to read
        sql_context - SQLContext used to build the reader
    Returns: Pyspark Dataframe
    '''
    reader = sql_context.read.format('com.databricks.spark.csv')
    reader = reader.options(header='true')
    return reader.load(file_path)

In [4]:
def getStats(df):
    '''
    Function: Print row count, column count, column names and the first
              five rows of a Pyspark Dataframe
    Parameters: Pyspark Dataframe
    Returns: N/A (output is printed)
    '''
    row_total = df.count()  # triggers a Spark job
    headers = df.columns
    summary = "\nRow count: {:d}\n\nColumn Count: {:d}\n\nColumn headers: {}\n\nSample Data:\n".format(
        row_total, len(headers), headers)
    print(summary)
    df.show(5)

In [5]:
def createVertices(df):
    '''
    Function: Create vertices dataframe
    Parameters: Pyspark Dataframe containing a user_id column
    Returns: Pyspark dataframe depicting vertices -- one row per distinct
             user_id, exposed under the column name "id"
    '''
    print("Creating Vertices DataFrame..")
    # GraphFrames identifies vertices by an "id" column, so rename user_id
    # and keep a single row per user.
    vertices = df.selectExpr("user_id as id").drop_duplicates()
    print("Vertices DataFrame creation complete.")
    return vertices

In [6]:
def createEdges(df):
    '''
    Function: Create edges dataframe
    Parameters: Pyspark Dataframe containing user_id and group_id columns
    Returns: Pyspark dataframe depicting edges, with columns
             (group_id, src, dst): one row per ordered pair of users that
             share a group, for every group they share
    '''
    print("Creating Edges DataFrame..")
    members_as_src = df.select(col('user_id').alias('src'), col('group_id'))
    members_as_dst = df.select(col('user_id').alias('dst'), col('group_id'))
    # Self-join on group_id pairs every member of a group with every member
    # of the same group. An inner join is used because null group_ids never
    # match: an outer join would only add rows with a null src or dst,
    # which are not valid edges.
    # Self-loops (src == dst) are kept deliberately: getUsersOfAGroup reads
    # group members from edge sources, so a single-member group is only
    # represented by its self-loop.
    edges = members_as_src.join(members_as_dst, on=['group_id'], how='inner')
    # Remove duplicate entries (e.g. repeated membership rows in the input)
    edges = edges.drop_duplicates()
    print("Edges DataFrame creation complete.")
    return edges

In [7]:
def createGraph(vertices, edges):
    '''
    Function: Build a GraphFrame from vertex and edge DataFrames
    Parameters:
        vertices - Pyspark Dataframe with an "id" column
        edges    - Pyspark Dataframe with "src" and "dst" columns
    Returns: GraphFrame
    '''
    print("Creating graph..")
    result = GF.GraphFrame(vertices, edges)
    print("Graph creation complete.")
    return result

In [8]:
def saveGraph(graph, vertices_path='store/gv.parquet', edges_path='store/ge.parquet', mode=None):
    '''
    Function: Save graph vertices and edges to parquet files
    Parameters:
        graph         - GraphFrame to persist
        vertices_path - output location for graph.vertices
        edges_path    - output location for graph.edges
        mode          - Spark save mode (e.g. 'overwrite'); None keeps
                        Spark's default, which fails if the path exists
    Returns: N/A
    '''
    print("Saving graph to file..")
    graph.vertices.write.parquet(vertices_path, mode=mode)
    graph.edges.write.parquet(edges_path, mode=mode)
    print("Graph has been saved successfully.")

In [9]:
def loadGraph(context, vertices_path='store/gv.parquet', edges_path='store/ge.parquet'):
    '''
    Function: Load graph from parquet files
    Parameters:
        context       - SQLContext used to read the parquet files
        vertices_path - parquet location of the vertices
                        (default: 'store/gv.parquet')
        edges_path    - parquet location of the edges
                        (default: 'store/ge.parquet')
    Returns: GraphFrame
    '''
    print("\nLoading graph data..")
    vertices = context.read.parquet(vertices_path)
    edges = context.read.parquet(edges_path)
    print("\nGenerating graph..")
    graph = GF.GraphFrame(vertices, edges)
    print("\nGraph load complete.")
    return graph

In [10]:
def firstConnects(graph, vertex):
    '''
    Function: Obtain the first connects (direct out-neighbours) of a vertex
    Parameters:
        graph  - GraphFrame
        vertex - id of the source vertex
    Returns: Pyspark Dataframe with columns (id, group_id): the neighbour id
             and the group of the connecting edge
    '''
    # Compare with a Column expression instead of splicing `vertex` into a
    # SQL string, so ids containing quotes cannot break or alter the filter.
    motifs = graph.find("(v1)-[e]->(v2)").filter(col("v1.id") == vertex)
    return motifs.select("v2.id", "e.group_id")

In [11]:
def cleanUp(df_list):
    '''
    Function: Release Spark-cached data held by the given DataFrames
    Parameters: List of DataFrames
    Returns: N/A
    Note: Python references to the DataFrames held by the caller are not
          affected; the caller must drop its own names for them to be
          garbage-collected. (`del` on a loop variable frees nothing.)
    '''
    for df in df_list:
        unpersist = getattr(df, 'unpersist', None)
        # unpersist() is a no-op for DataFrames that were never cached.
        if callable(unpersist):
            unpersist()
    print("\nDataFrame clean up complete.")

In [12]:
def getUsersOfAGroup(graph, group):
    '''
    Function: Get users connected by the given group
    Parameters:
        graph - GraphFrame whose edges carry a group_id column
        group - group_id to look up
    Returns: Set of users (edge source ids for that group)
    '''
    # Filter with a Column expression (no SQL string splicing) and let Spark
    # deduplicate the sources, rather than collecting every edge of the
    # group -- the group self-join makes that quadratic in group size.
    rows = (graph.edges
            .filter(col('group_id') == group)
            .select('src')
            .distinct()
            .collect())
    return {row.src for row in rows}

In [13]:
def addFirstConnects(graph, users):
    '''
    Function: Extend a user set with each user's first connects
    Parameters:
        graph - GraphFrame
        users - set of user ids
    Returns: New set containing the original users plus their direct
             neighbours (the input set is not modified)
    '''
    # One firstConnects query per user; every neighbour id is gathered.
    neighbours = {row.id
                  for member in users
                  for row in firstConnects(graph, member).collect()}
    return users.union(neighbours)

In [14]:
def getGroupsAssociatedToUsers(graph, users):
    '''
    Function: Get all the groups associated with the given users
    Parameters:
        graph - GraphFrame whose edges carry src and group_id columns
        users - set of user ids
    Returns: Dictionary mapping group_id to its relative frequency among the
             edges whose source is one of the users (values sum to 1.0);
             empty dict when users is empty
    '''
    if not users:
        return {}
    # A single aggregated Spark query replaces one filter+collect per user.
    # NOTE(review): for very large user sets a join against a users
    # DataFrame may scale better than a long isin list.
    counts = (graph.edges
              .filter(col('src').isin(list(users)))
              .groupBy('group_id')
              .count()
              .collect())
    # Row is a tuple subclass, so use item access: row.count is tuple.count.
    groups = {row['group_id']: row['count'] for row in counts}
    # float() avoids integer division (all-zero weights) on Python 2.
    total = float(sum(groups.values()))
    return {key: value / total for key, value in groups.items()}

In [15]:
def getChannelsAssociatedToGroups(data, groups):
    '''
    Function: Get all the channels associated with the given groups
    Parameters:
        data   - Pyspark Dataframe with group_id and channel_id columns
        groups - dict mapping group_id to its weight
    Returns: Dictionary mapping channel_id to its group-weighted frequency,
             normalised to sum to 1.0 (left unnormalised if all weights are
             zero); empty dict when groups is empty
    '''
    if not groups:
        return {}
    # One aggregated query replaces a filter+collect per group; each
    # (group, channel) row count is then scaled by that group's weight.
    pair_counts = (data
                   .filter(col('group_id').isin(list(groups)))
                   .groupBy('group_id', 'channel_id')
                   .count()
                   .collect())
    channels = {}
    for row in pair_counts:
        weight = groups[row['group_id']] * row['count']
        channel = row['channel_id']
        channels[channel] = channels.get(channel, 0.0) + weight
    # float() avoids integer division on Python 2; guard the zero total.
    total = float(sum(channels.values()))
    if total == 0:
        return channels
    return {key: value / total for key, value in channels.items()}

In [16]:
def getChannelsForGroup(data, graph, group):
    '''
    Function: Rank the channels related to a group
    Parameters:
        data  - Pyspark Dataframe of group/channel pairs
        graph - GraphFrame of users linked by shared groups
        group - group_id to start from
    Returns: Dictionary of channels mapped to their group-weighted frequency
    '''
    # Members of the group, widened with their direct neighbours in the graph.
    audience = addFirstConnects(graph, getUsersOfAGroup(graph, group))
    # Weight every group that audience belongs to, then project onto channels.
    group_weights = getGroupsAssociatedToUsers(graph, audience)
    return getChannelsAssociatedToGroups(data, group_weights)

In [None]:
if __name__ == "__main__":

    # Set to True to rebuild the graph from data/group_members.csv and save
    # it under store/; when False the previously saved graph is reused.
    REBUILD_GRAPH = False
    TARGET_GROUP = 'd2c8410bb78af46155aaa96b50b082598ca69306'
    TOP_N = 10  # number of ranked channels to print

    # Create Spark context
    sql_context = createContext(app_name="RelaTree", executor_memory="2g", no_of_executors="4")

    # Load group to channel data into a Pyspark Dataframe
    df_group_channels = loadData('data/group_channel.csv', sql_context)

    if REBUILD_GRAPH:
        # Load group to member data and derive the graph from it
        df_group_members = loadData('data/group_members.csv', sql_context)
        getStats(df_group_members)
        getStats(df_group_channels)
        vertices = createVertices(df_group_members)
        edges = createEdges(df_group_members)
        getStats(vertices)
        getStats(edges)
        duta_graph = createGraph(vertices, edges)
        saveGraph(duta_graph)
        cleanUp([vertices, edges, df_group_members])

    # Load graph from file
    duta_graph = loadGraph(sql_context)
    channels = getChannelsForGroup(df_group_channels, duta_graph, TARGET_GROUP)

    # Show the highest-weighted channels first.
    ranked = sorted(channels.items(), key=lambda item: item[1], reverse=True)
    for channel_id, score in ranked[:TOP_N]:
        print("%s\t%.4f" % (channel_id, score))