<a href="https://colab.research.google.com/github/aswinaus/graphrag/blob/main/Graph_RAG.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install the notebook's dependencies; langchain is pinned to 0.1.17, the other packages are unpinned.
%pip install pyvis IPython cchardet datasets langchain==0.1.17 neo4j openai tiktoken langchain-community langchain-experimental json-repair

In [None]:
from getpass import getpass
import os
from google.colab import userdata
# Prefer the key stored in Colab's secrets panel. If the secret does not exist
# or this notebook has not been granted access to it, prompt for the key
# (input is not echoed) instead of failing the cell.
try:
    os.environ["OPENAI_API_KEY"] = userdata.get("OPENAI_API_KEY")
except (userdata.SecretNotFoundError, userdata.NotebookAccessError):
    os.environ["OPENAI_API_KEY"] = getpass("OPENAI_API_KEY: ")

In [None]:
from datasets import load_dataset
import pandas as pd
# Load the tax-statistics dataset, bypassing any cached copy
# (download_mode="force_redownload"), and turn its 'train' split into a
# pandas DataFrame for the cells below.
dataset = load_dataset(
    "aswinaus/tax_statistics_dataset_by_income_range",
    download_mode="force_redownload",
)
train_split = dataset['train']
df = pd.DataFrame(train_split)

In [None]:
# Preview the first 10 rows of the loaded DataFrame.
df.head(10)

In [None]:
# @title Knowledge Graph Builder & Visualizer
# @markdown Note that rendering the graph will take a minute or two under the default 10% sample size, longer with higher sample sizes.
# Sampling fraction, exposed as a numeric Colab form field via the @param annotation.
sample_size = 0.66 # @param {type:"number", default:0.10}
import pandas as pd
import networkx as nx
from pyvis.network import Network
from IPython.display import IFrame
from IPython.display import Markdown, HTML


# Node color per style key.
# 'No of joint returns' used to be listed twice ('magenta', then 'purple');
# a dict literal keeps only the last value, so the dead 'magenta' entry is
# removed and the effective mapping is unchanged.
colors = {
    'zipcode': 'lightblue',
    'incometaxby_state_name': 'orange',
    'No of returns': 'red',
    'No of single returns': 'green',
    'No of joint returns': 'purple',
}

# Node size per style key.
sizes = {
    'zipcode': 20,
    'Size of adjusted gross income': 15,
    'No of returns': 25,
    'No of single returns': 30,
    'No of joint returns': 10,
}

# Stratify on income bracket, zipcode, state and number of returns.
stratify_cols = ['Size of adjusted gross income', 'zipcode', 'STATE', 'No of returns']

# Draw the same fraction from every stratum. `sample_size` is the form field
# set at the top of this cell; it is intentionally NOT reassigned here, so the
# value chosen in the form is the one actually used.
# GroupBy.sample keeps the grouping columns in the result (the graph-building
# loop reads them) and avoids the deprecated groupby(...).apply(...) pattern.
# NOTE(review): each group gets round(frac * group_size) rows, so very small
# groups can contribute zero rows at low fractions — confirm this is acceptable.
sampled_df = df.groupby(stratify_cols).sample(frac=sample_size)

# Build a directed graph from the sampled rows.
#
# Nodes are keyed by the raw cell values of each row:
#   * STATE and 'Size of adjusted gross income' values (added once each),
#   * the zipcode,
#   * the three return counts.
# Edges run STATE -> each return count and zipcode -> 'No of returns'; every
# edge carries the same literal `relationship` label.
#
# NOTE(review): because node ids are raw values, equal values coming from
# different columns (e.g. a zipcode and a count) share a single node, and the
# most recent add_node call determines its attributes — confirm this is
# acceptable before relying on per-entity styling.
G = nx.DiGraph()

# Label attached to every edge.
_RELATIONSHIP = 'Size of adjusted gross income'

# Count columns that become nodes: (column, entity label, fallback color,
# fallback size). The column name is also the lookup key into `colors` and
# `sizes`, so the styles configured there are actually applied; previously
# underscored keys were used that never matched and always fell back.
_COUNT_COLUMNS = (
    ('No of returns', 'No_of_returns', 'green', 25),
    ('No of single returns', 'No_of_single_returns', 'orange', 25),
    ('No of joint returns', 'No_of_joint_returns', 'brown', 25),
)

for _, row in sampled_df.iterrows():
    state = row['STATE']
    income_band = row['Size of adjusted gross income']
    zipcode = row['zipcode']

    # States and income bands repeat across rows; add each only once.
    if state not in G:
        G.add_node(
            state,
            entity='STATE',
            color=colors.get('STATE', 'blue'),
            size=sizes.get('STATE', 5),
        )

    if income_band not in G:
        G.add_node(
            income_band,
            entity='Size of adjusted gross income',
            color=colors.get('Size of adjusted gross income', 'gray'),
            size=sizes.get('Size of adjusted gross income', 40),
        )

    # Add the zipcode explicitly so it carries entity/color/size attributes;
    # otherwise the zipcode edge below would create it as a bare node.
    G.add_node(
        zipcode,
        entity='ZIPCODE',
        color=colors.get('zipcode', 'orange'),
        size=sizes.get('zipcode', 20),
    )

    for column, entity, fallback_color, fallback_size in _COUNT_COLUMNS:
        count = row[column]
        G.add_node(
            count,
            entity=entity,
            color=colors.get(column, fallback_color),
            size=sizes.get(column, fallback_size),
        )
        G.add_edge(state, count, relationship=_RELATIONSHIP)

    G.add_edge(zipcode, row['No of returns'], relationship=_RELATIONSHIP)


# Step 4: Visualization (Optional)
# Render the networkx graph with pyvis and display the resulting HTML file.
graph_html = 'income_tax_2019_graph.html'

# notebook=True targets Jupyter-style environments; outside of one you might
# need notebook=False instead.
nt = Network(
    height='700px',
    width='700px',
    notebook=True,
    cdn_resources='in_line',
)
nt.from_nx(G)
nt.toggle_physics(True)  # force-directed layout

# Write the page to disk, then let pyvis show it.
nt.save_graph(graph_html)
nt.show(graph_html)

# Last expression of the cell: display the saved page inline.
HTML(graph_html)