In [24]:
import os
from collections import defaultdict
from getpass import getpass

import matplotlib.pyplot as plt
import networkx as nx
import pandas as pd
import requests

In [66]:

def generate_gql_payload(dbt_cloud_environment_id, first=500):
    """Build the Discovery API GraphQL query and variables for an environment's lineage.

    The query requests the ``sources``, ``seeds``, ``snapshots`` and ``models``
    connections of the environment definition with their ``uniqueId``,
    ``name``, ``resourceType`` and immediate ``children``; ``snapshots`` and
    ``models`` also request ``parents``. Only ``sources`` and ``models``
    request ``tags``.

    Args:
        dbt_cloud_environment_id: dbt Cloud environment id, sent as the
            ``environmentId`` variable (declared ``Int!`` in the query).
        first: page size sent as the ``first`` variable for every connection.
            Defaults to 500, the value that used to be hard-coded. The query
            does not request ``pageInfo`` and no pagination is performed, so
            nodes beyond the first ``first`` of each type are not returned.
            NOTE(review): the API may enforce a maximum page size — confirm.

    Returns:
        dict with keys ``"gql_query"`` (the query string) and ``"variables"``
        (the GraphQL variables dict).
    """
    # Write Discovery API query
    gql_query = """
    query Lineage($environmentId: Int!, $first: Int!) {
        environment(id: $environmentId) {
          definition {

            sources(first: $first) {
              edges {
                node {
                  uniqueId
                  name
                  resourceType
                  children{
                    uniqueId
                    name
                    resourceType
                  }
                tags
                }
              }
            }

            seeds(first: $first) {
              edges {
                node {
                  uniqueId
                  name
                  resourceType
                  children {
                    uniqueId
                    name
                    resourceType
                  }
                }
              }
            }
            snapshots(first: $first) {
              edges {
                node {
                  uniqueId
                  name
                  resourceType
                  parents {
                    uniqueId
                    name
                    resourceType
                  }
                  children {
                    uniqueId
                    name
                    resourceType
                  }
                }
              }
            }

            models(first:$first){
              edges {
                node {
                  uniqueId
                  name
                  resourceType
                  parents{
                    uniqueId
                    name
                    resourceType
                  }
                  children{
                    uniqueId
                    name
                    resourceType
                  }
                  tags
                }
              }
            }

          }
        }
    }
    """

    # Query variables: the page size is now a parameter instead of a literal.
    variables = {
        "environmentId": dbt_cloud_environment_id,
        "first": first,
    }

    return {"gql_query": gql_query, "variables": variables}


# Query the API
def query_discovery_api(auth_token, gql_query, variables,
                        url='https://metadata.cloud.getdbt.com/beta/graphql',
                        timeout=60):
    """POST a GraphQL query to the dbt Cloud Discovery API and return ``data.environment``.

    Args:
        auth_token: dbt Cloud API token, sent as a Bearer token.
        gql_query: GraphQL query string.
        variables: dict of GraphQL variables.
        url: GraphQL endpoint. Defaults to the endpoint that used to be
            hard-coded; override it for other regions/deployments.
        timeout: seconds to wait for the HTTP response. Without a timeout,
            ``requests`` can block indefinitely on a stalled connection.

    Returns:
        The ``environment`` object from the response's ``data`` payload.

    Raises:
        requests.HTTPError: the endpoint returned a non-2xx status.
        RuntimeError: the response contained no ``environment`` data (e.g. the
            query was rejected or the environment id is unknown); the
            GraphQL ``errors`` list is included in the message.
    """
    import warnings  # local import: only needed for the partial-error warning

    response = requests.post(
        url,
        headers={"authorization": "Bearer " + auth_token, "content-type": "application/json"},
        json={"query": gql_query, "variables": variables},
        timeout=timeout,
    )
    # Fail loudly on HTTP errors (401/403/5xx) instead of a KeyError from
    # indexing into an error payload.
    response.raise_for_status()
    payload = response.json()

    # GraphQL reports query/auth problems in an "errors" list, often alongside
    # "data": null; an unknown environment can yield "environment": null.
    errors = payload.get("errors")
    environment = (payload.get("data") or {}).get("environment")
    if environment is None:
        raise RuntimeError(f"Discovery API returned no environment data; errors: {errors!r}")
    if errors:
        # Partial data: keep going, but don't hide that the lineage may be incomplete.
        warnings.warn(f"Discovery API returned partial data with errors: {errors!r}")
    return environment

# flatten the API response into a DataFrame with one row per node
def extract_node_definitions(api_response):
    """Flatten a Discovery API ``environment`` payload into one row per node.

    For each node type present under ``api_response["definition"]``, every
    ``edges[*].node`` dict becomes one row, with an added ``"type"`` column set
    to the connection name (e.g. ``"models"``). The input payload is not
    modified.

    Args:
        api_response: the ``environment`` object returned by
            ``query_discovery_api``; must contain a ``"definition"`` dict.

    Returns:
        pandas.DataFrame with one row per node. Its columns are the union of
        the node fields returned plus ``"type"``. Fields a node type did not
        return are NaN. The frame is empty if no nodes were returned.
    """
    definition = api_response["definition"]
    node_types = ["models", "sources", "seeds", "snapshots", "metrics", "exposures"]
    nodes = []
    for node_type in node_types:
        # A connection may be absent from the response or explicitly null.
        connection = definition.get(node_type) or {}
        for node_edge in connection.get("edges") or []:
            # Build a new dict so the caller's API response is not mutated.
            nodes.append({**node_edge["node"], "type": node_type})
    return pd.DataFrame(nodes)


# create a network graph
def create_generic_lineage_graph(nodes_df):
    """Build a directed lineage graph (edges point parent -> child) from node rows.

    Every row becomes a graph node keyed by ``uniqueId`` with ``name``,
    ``type`` and ``tags`` attributes. For rows whose type is not ``sources``
    or ``seeds``, an edge is added from each entry of ``parents`` to the node.

    ``tags`` and ``parents`` are normalised to lists. When rows of different
    node types are combined into one DataFrame, fields a type did not return
    become NaN. For example, seeds and snapshots get NaN ``tags`` because the
    query does not request tags for them. If a field is absent for every row,
    the column is missing entirely. Both cases are treated as "empty".

    Args:
        nodes_df: DataFrame as produced by ``extract_node_definitions``.

    Returns:
        networkx.DiGraph of the lineage.
    """
    def _as_list(value):
        # Missing column -> None; field missing for this row -> NaN (a float).
        return list(value) if isinstance(value, (list, tuple)) else []

    G = nx.DiGraph()
    # Materialise rows once as plain dicts so `.get` works for absent columns.
    records = nodes_df.to_dict("records")
    for node in records:
        G.add_node(node["uniqueId"], name=node["name"], type=node["type"],
                   tags=_as_list(node.get("tags")))
    for node in records:
        if node["type"] not in ["sources", "seeds"]:
            for parent in _as_list(node.get("parents")):
                G.add_edge(parent["uniqueId"], node["uniqueId"])
    return G

# a function to list a given model and all nodes downstream of it (BFS order)
def calculate_downstream_nodes(G, starting_dbt_cloud_model_id):
    """Return the starting node followed by every node reachable from it.

    Nodes come back in breadth-first order. Because edges point
    parent -> child, "reachable" means downstream.

    Args:
        G: lineage DiGraph.
        starting_dbt_cloud_model_id: unique id of the node to start from.

    Returns:
        list of node ids, beginning with ``starting_dbt_cloud_model_id``.
    """
    reachable_tree = nx.bfs_tree(G, starting_dbt_cloud_model_id)
    return list(reachable_tree.nodes)

# get all downstream nodes of the given model
def get_downstream_nodes(G, starting_dbt_cloud_model_id):
    """Return a view of ``G`` restricted to the starting node and everything downstream.

    Args:
        G: lineage DiGraph (edges point parent -> child).
        starting_dbt_cloud_model_id: unique id of the node to start from.

    Returns:
        A networkx subgraph view holding only the downstream nodes and the
        edges among them. Node iteration order follows ``G``.
    """
    # Use a set: the filter runs once per node of G on every traversal of the
    # lazy view, so a list membership test made this quadratic.
    downstream_nodes = set(calculate_downstream_nodes(G, starting_dbt_cloud_model_id))

    def filter_get_downstream_nodes(n1):
        return n1 in downstream_nodes

    return nx.subgraph_view(G, filter_node=filter_get_downstream_nodes)


# return a dataframe of downstream models, their tags, and whether or not they are missing the given tag
def analyze_downstream_node_tags(G, tag_to_analyze):
    """Tabulate each node's tags and whether it carries ``tag_to_analyze``.

    Args:
        G: graph (or graph view) whose node data holds ``name`` and ``tags``.
        tag_to_analyze: the tag to look for.

    Returns:
        DataFrame with columns ``model_name``, ``model_tags`` and
        ``has_specified_tag``. The last column holds the strings ``"TRUE"`` or
        ``"FALSE"``. There is one row per node, in ``G``'s iteration order. The
        columns are present even when ``G`` has no nodes, so downstream code
        that selects them does not fail.
    """
    columns = ["model_name", "model_tags", "has_specified_tag"]
    rows = []
    for node_id, attrs in G.nodes(data=True):
        tags = attrs.get("tags")
        # Tags may be missing or NaN for node types whose query omitted them.
        # Nodes created only through an edge have no attributes at all.
        # Treat all of these as "no tags" instead of failing on the `in` test.
        if not isinstance(tags, (list, tuple)):
            tags = []
        rows.append({
            "model_name": attrs.get("name", node_id),
            "model_tags": tags,
            "has_specified_tag": "TRUE" if tag_to_analyze in tags else "FALSE",
        })
    return pd.DataFrame(rows, columns=columns)


# end-to-end: query the API, build the lineage graph, and check downstream nodes for the tag
def _style_tag_analysis(analysis_df, starting_dbt_cloud_model_id, tag_to_analyze):
    """Return a pandas Styler for the tag analysis with coloring and a caption.

    ``has_specified_tag`` cells are shown red for ``"FALSE"`` and green
    otherwise. The caption names the starting model and the tag.
    """
    import html  # local import: only needed to escape the caption

    def _color_red_or_green(val):
        color = 'red' if val == "FALSE" else 'green'
        return 'color: %s' % color

    styler = analysis_df.style
    # Styler.applymap is deprecated since pandas 2.1 in favour of Styler.map;
    # prefer the new name and fall back on older pandas.
    apply_elementwise = styler.map if hasattr(styler, "map") else styler.applymap

    # The caption is rendered as raw HTML, so escape the user-supplied values.
    model_id_html = html.escape(str(starting_dbt_cloud_model_id))
    tag_html = html.escape(str(tag_to_analyze))

    return (apply_elementwise(_color_red_or_green, subset=pd.IndexSlice[:, ['has_specified_tag']])
            .set_caption(f'Analyzing models downstream of <span style="color:blue;font-weight:bold;"> "{model_id_html}"</span> to see if they have the tag <span style="color:blue;font-weight:bold;">"{tag_html}"</span>')
            .set_properties(subset=['model_name'], **{'width': '300px'})
            .set_properties(subset=['model_tags'], **{'width': '200px'})
            .set_table_styles([{
                'selector': 'caption',
                'props': [
                    ('font-size', '16px'),
                    ('color', 'black'),
                    ('border', '3px solid !important'),
                    ('padding', '10px'),
                    ('margin-bottom', '25px'),
                    ('background-color', 'white'),
                ],
            }]))


def analyze_dbt_models_downstream_nodes_tags(auth_token, dbt_cloud_environment_id, starting_dbt_cloud_model_id, tag_to_analyze):
    """Report which nodes downstream of a model carry a given tag.

    Steps: query the Discovery API for the environment's lineage, build a
    parent -> child graph, keep the starting model and its downstream nodes,
    and tabulate whether each has ``tag_to_analyze``.

    Args:
        auth_token: dbt Cloud API token.
        dbt_cloud_environment_id: dbt Cloud environment id.
        starting_dbt_cloud_model_id: unique id such as
            ``"model.<project>.<model>"``.
        tag_to_analyze: tag to check for.

    Returns:
        A pandas Styler wrapping the DataFrame from
        ``analyze_downstream_node_tags``, with rows missing the tag in red.
    """
    # generate the gql payload
    gql_pl = generate_gql_payload(dbt_cloud_environment_id)

    # hit the discovery API
    disco_api_resp = query_discovery_api(auth_token=auth_token, gql_query=gql_pl["gql_query"], variables=gql_pl["variables"])

    # parse the API response
    parsed_api_resp = extract_node_definitions(disco_api_resp)

    # create a networkx graph
    lineage_graph = create_generic_lineage_graph(parsed_api_resp)

    # filter down the lineage graph to just downstream nodes of the given model
    model_downstream_nodes = get_downstream_nodes(lineage_graph, starting_dbt_cloud_model_id)

    # analyze whether the downstream nodes are missing the specified tag
    analysis_df = analyze_downstream_node_tags(model_downstream_nodes, tag_to_analyze)

    return _style_tag_analysis(analysis_df, starting_dbt_cloud_model_id, tag_to_analyze)
    

In [79]:
# ----------------------------------------------------------
# inputs
# ----------------------------------------------------------

# dbt Cloud environment whose lineage is analyzed
dbt_cloud_environment_id = 157050

# dbt Cloud auth token: read from the DBT_CLOUD_AUTH_TOKEN environment variable,
# or prompt for it without echoing. Never hard-code the secret in the notebook,
# where it would be saved and shared with the file.
auth_token = os.environ.get("DBT_CLOUD_AUTH_TOKEN") or getpass("dbt Cloud auth token: ")

# this follows the syntax of "model.<<project_name>>.<<model_name>>"
starting_dbt_cloud_model_id = "model.sample_demo.stg_tpch_orders"

# the tag you want to analyze
tag_to_analyze = "finance"

# ----------------------------------------------------------

# analyze whether the downstream nodes are missing the specified tag
analysis_df = analyze_dbt_models_downstream_nodes_tags(auth_token, dbt_cloud_environment_id, starting_dbt_cloud_model_id, tag_to_analyze)

# display results
display(analysis_df)

Unnamed: 0,model_name,model_tags,has_specified_tag
0,stg_tpch_orders,[],False
1,order_items,[],False
2,monthly_counts_hundred_row_sample,[],False
3,monthly_gross_revenue,"['finance', 'RevOps']",True
4,fact_orders,['finance'],True
5,fct_order_items,['finance'],True
6,mothly_counts_reporting_table,"['finance', 'RevOps']",True
