In [None]:
import pandas as pd
import re
import yaml
import sqlparse
import os
import pandas as pd
import requests

In [None]:
# Keep DataFrame previews compact: render at most 20 rows and 30 columns.
pd.options.display.max_rows = 20
pd.options.display.max_columns = 30

#### List the repository's files and directories

In [None]:
def extract_owner_and_repo(github_url):
    """Split a GitHub repository URL into ``(owner, repo)``.

    Accepts ``https://github.com/<owner>/<repo>`` as well as ``http://``, a
    ``www.`` host, a ``.git`` suffix, a trailing slash or extra path segments
    (e.g. ``/tree/main``). A bare ``<owner>/<repo>`` string is accepted too.

    Returns:
        ``(owner, repo)``, or ``(None, None)`` after printing an error when the
        value does not contain both a non-empty owner and repository name.
    """
    match = None
    if isinstance(github_url, str):
        match = re.match(
            r'^(?:https?://(?:www\.)?github\.com/)?([^/?#]+)/([^/?#]+)',
            github_url.strip(),
        )
    if match is None:
        print("Error: Invalid GitHub URL structure.")
        return None, None

    owner, repo = match.groups()
    # Clone URLs end in ".git", which is not part of the repository name.
    if repo.endswith('.git'):
        repo = repo[:-len('.git')]
    return owner, repo

def list_local_repo_structure(repo_path):
    """List every directory and file under ``repo_path`` as repository-relative paths.

    Paths always use ``/`` as separator (whatever the OS), because later cells
    filter them with ``startswith('models/')`` and build URLs from them.
    Directories end with ``'/'``. Directories and files are visited in sorted
    order, so the result is deterministic.
    """
    paths = []
    for root, dirs, files in os.walk(repo_path):
        # Sorting in place also fixes the order in which os.walk descends.
        dirs.sort()
        rel_dir = os.path.relpath(root, repo_path)
        rel_dir = '' if rel_dir == os.curdir else rel_dir.replace(os.sep, '/')
        if rel_dir:
            paths.append(rel_dir + '/')
        for file_name in sorted(files):
            paths.append(f"{rel_dir}/{file_name}" if rel_dir else file_name)
    return paths

def list_online_repo_structure(owner, repo, token=None, timeout=30):
    """List all files and directories of a GitHub repository via the REST contents API.

    Directories are walked depth-first with an explicit stack, one request per
    directory.

    Args:
        owner: repository owner, e.g. ``'dbt-labs'``.
        repo: repository name, e.g. ``'jaffle-shop'``.
        token: optional GitHub token, sent as a Bearer ``Authorization`` header
            (unauthenticated API calls have a much lower rate limit).
        timeout: seconds to wait for each HTTP response.

    Returns:
        Repository-relative paths; directories end with ``'/'``. A directory whose
        listing request does not return 200 is reported with a printed warning and
        skipped, so the result may be partial.
    """
    headers = {'Authorization': f'Bearer {token}'} if token else {}
    stack = [(f"https://api.github.com/repos/{owner}/{repo}/contents/", '')]
    paths = []
    while stack:
        current_url, current_path = stack.pop()
        response = requests.get(current_url, headers=headers, timeout=timeout)
        if response.status_code != 200:
            # Surface the failure instead of silently returning an incomplete listing.
            print(f"Warning: could not list {current_url}: {response.status_code} {response.reason}")
            continue
        for item in response.json():
            item_path = current_path + item['name']
            if item['type'] == 'dir':
                paths.append(item_path + '/')
                stack.append((item['url'], item_path + '/'))
            else:
                paths.append(item_path)
    return paths

In [None]:
def is_online_repo(path):
    """Tell whether ``path`` is an http(s) URL (online repo) rather than a local path."""
    online_schemes = ("http://", "https://")
    return path.startswith(online_schemes)

# Repository to analyse: the GitHub URL below, or a local checkout when USE_LOCAL_REPO is True.
USE_LOCAL_REPO = False
local_dbt_repo = ''
online_dbt_repo = 'https://github.com/dbt-labs/jaffle-shop'

repo_path = local_dbt_repo if USE_LOCAL_REPO else online_dbt_repo

is_online = is_online_repo(repo_path)
if is_online:
    owner, repo = extract_owner_and_repo(repo_path)
    if owner is None:
        raise ValueError(f"Could not parse a GitHub owner/repo from {repo_path!r}")
    repo_elements = list_online_repo_structure(owner, repo)
else:
    # os.walk on a missing/empty path silently yields nothing; fail loudly instead.
    if not os.path.isdir(repo_path):
        raise ValueError(f"Local dbt repository not found: {repo_path!r} (set local_dbt_repo)")
    repo_elements = list_local_repo_structure(repo_path)

print(repo_elements)

### dbt models knowledge base

#### Select dbt elements

In [None]:
dbt_extensions = ['.sql', '.yml', '.yaml', '.csv']

def select_dbt_elements_by_extension(dbt_extensions, repo_elements):
    """Keep only the repository paths whose name ends with one of ``dbt_extensions``."""
    wanted_suffixes = tuple(dbt_extensions)
    return [element for element in repo_elements if element.endswith(wanted_suffixes)]

repo_dbt_elements = select_dbt_elements_by_extension(dbt_extensions=dbt_extensions, repo_elements=repo_elements)
print(repo_dbt_elements)

def select_dbt_models(dbt_extensions, repo_dbt_elements, model_dirs=('models/',)):
    """Keep the paths that live under one of ``model_dirs`` and have a dbt extension.

    Args:
        dbt_extensions: accepted file suffixes, e.g. ``['.sql', '.yml']``.
        repo_dbt_elements: repository-relative paths using ``/`` separators.
        model_dirs: directory prefixes to keep (default ``('models/',)``); pass e.g.
            ``('models/', 'snapshots/')`` to include snapshots. A single string is
            treated as one prefix.
    """
    if isinstance(model_dirs, str):
        # tuple('models/') would split the prefix into characters.
        model_dirs = (model_dirs,)
    prefixes = tuple(model_dirs)
    suffixes = tuple(dbt_extensions)
    return [
        element for element in repo_dbt_elements
        if element.startswith(prefixes) and element.endswith(suffixes)
    ]

repo_dbt_models = select_dbt_models(dbt_extensions=dbt_extensions, repo_dbt_elements=repo_dbt_elements)
print(repo_dbt_models)

In [None]:
# dbt project-level configuration files.
# NOTE(review): not referenced by any visible cell — confirm it is still needed.
dbt_config_elements = [
    'packages.yml',
    'dbt_project.yml',
]

In [None]:
def generate_dbt_models_df(repo_dbt_models):
    """Build a DataFrame with one row per model file.

    Columns: ``path`` (repository-relative path), ``name`` (file name) and
    ``extension`` (e.g. ``'.sql'``). The columns exist even when
    ``repo_dbt_models`` is empty, so later cells that index them do not fail.
    """
    records = []
    for path in repo_dbt_models:
        name = os.path.basename(path)
        records.append({'path': path, 'name': name, 'extension': os.path.splitext(name)[1]})
    return pd.DataFrame(records, columns=['path', 'name', 'extension'])

dbt_models_df = generate_dbt_models_df(repo_dbt_models=repo_dbt_models)
display(dbt_models_df)

#### Add file contents (formatted SQL, parsed YAML)

In [None]:
# An earlier draft of get_base_url / extract_file_content / add_code_column used to sit here
# behind `if False:` (never executed); it was removed because the next cell defines the
# versions that are actually used.


In [None]:
def get_base_url(repo_url, branch='main'):
    """Return the raw.githubusercontent.com base URL for a GitHub repository's files.

    Args:
        repo_url: ``http(s)://[www.]github.com/<owner>/<repo>[.git][/...]``.
        branch: branch (or other git ref) to read files from; defaults to ``'main'``.

    Raises:
        ValueError: if ``repo_url`` is not a github.com URL with an owner and a repo.
    """
    match = None
    if isinstance(repo_url, str):
        match = re.match(r'^https?://(?:www\.)?github\.com/([^/?#]+)/([^/?#]+)', repo_url.strip())
    if match is None:
        raise ValueError("URL not valid.")
    owner, repo = match.groups()
    if repo.endswith('.git'):
        repo = repo[:-len('.git')]
    return f"https://raw.githubusercontent.com/{owner}/{repo}/{branch}"

def extract_file_content(path, is_online=False, repo_base_url=None, timeout=30):
    """Read one repository file and process it according to its extension.

    Args:
        path: repository-relative path, e.g. ``models/stg_x.sql``.
        is_online: download over HTTP when True, otherwise read from disk.
        repo_base_url: online, the raw-content base URL that ``path`` is appended to;
            offline, an optional local repository root that ``path`` is joined to
            (when empty/None, ``path`` is opened relative to the working directory).
        timeout: seconds to wait for the HTTP response.

    Returns:
        ``yaml.safe_load`` output for .yml/.yaml files, ``sqlparse.format`` output
        (reindented, keywords upper-cased) for .sql files, raw text otherwise — or a
        string starting with ``"Error"`` if the file could not be read or parsed.
    """
    try:
        if is_online:
            # Build complete URL
            file_url = f"{repo_base_url}/{path}" if repo_base_url else path
            response = requests.get(file_url, timeout=timeout)
            if response.status_code != 200:
                return f"Error: {response.status_code} {response.reason}"
            content = response.text
        else:
            # Offline, repo_base_url (when given) is the local repository root.
            local_path = os.path.join(repo_base_url, path) if repo_base_url else path
            with open(local_path, 'r', encoding='utf-8') as file:
                content = file.read()

        # Process content based on file type
        if path.endswith(('.yml', '.yaml')):
            try:
                return yaml.safe_load(content)  # Parse YAML (dict for dbt property files)
            except yaml.YAMLError as e:
                return f"Error parsing YAML: {e}"
        elif path.endswith('.sql'):
            try:
                return sqlparse.format(content, reindent=True, keyword_case='upper')  # Format SQL
            except Exception as e:
                return f"Error parsing SQL: {e}"
        else:
            return content  # Return plain text for other types

    except Exception as e:
        return f"Error: {e}"

def add_code_column(df, is_online=False, repo_url=None):
    """Return a copy of ``df`` with a ``sql_code`` column holding each file's processed content.

    Args:
        df: frame with a ``path`` column of repository-relative file paths.
        is_online: when True, files are downloaded from the raw-content base URL that
            ``get_base_url`` derives from ``repo_url``.
        repo_url: the GitHub repository URL (online) or the local repository root
            (offline). Offline, a URL value is ignored and paths are opened relative
            to the working directory.

    Returns:
        A new DataFrame; ``sql_code`` holds whatever ``extract_file_content`` returns
        (parsed YAML, formatted SQL, raw text, or an "Error..." string).
    """
    df = df.copy()
    if is_online:
        repo_base_url = get_base_url(repo_url)
        file_paths = list(df['path'])
    else:
        repo_base_url = ''
        # Paths are repository-relative: anchor them at the local checkout when one is given.
        local_root = repo_url if repo_url and not is_online_repo(repo_url) else ''
        file_paths = [os.path.join(local_root, p) if local_root else p for p in df['path']]

    # Extract content for each file and process it based on type
    df['sql_code'] = [extract_file_content(p, is_online, repo_base_url) for p in file_paths]
    return df

# repo_path is the GitHub URL in online mode and the local checkout root otherwise.
dbt_models_df = add_code_column(dbt_models_df, is_online, repo_path)
dbt_models_df.head(3)


#### Add config block

In [None]:
def extract_config_block(sql_code):
    """Return the full ``{{ config(...) }}`` call found in a model's SQL, or None.

    Also matches the Jinja whitespace-control form ``{{- config(...) -}}`` and
    calls spread over several lines. Non-string input (None, NaN) returns None.
    """
    if not isinstance(sql_code, str):
        return None
    match = re.search(r"\{\{-?\s*config\((.*?)\)\s*-?\}\}", sql_code, re.DOTALL)
    return match.group(0) if match else None

def add_config_column(df):
    """Return a copy of ``df`` with a ``config`` column.

    For ``.sql`` rows the column holds the ``{{ config(...) }}`` block found by
    ``extract_config_block`` (or None); other rows get None. A list comprehension
    is used instead of a row-wise ``DataFrame.apply`` so that an empty frame still
    gets a plain, empty column.
    """
    df = df.copy()
    df['config'] = [
        extract_config_block(code) if extension == '.sql' else None
        for code, extension in zip(df['sql_code'], df['extension'])
    ]
    return df

dbt_models_df = add_config_column(dbt_models_df)
dbt_models_df.head(3)

In [None]:
# Sanity check on a standalone sample: a multi-line `{{ config(...) }}` block must be
# detected. The sample is checked on its own so no real model row is modified.
sample_config = """
{{
    config(
        materialized="table"
    )
}}
"""

assert extract_config_block(sample_config) is not None

#### Add model metadata

In [None]:
def extract_materialized_value(config_text):
    """Return the quoted ``materialized=...`` value from a config block, or None.

    Non-string input (None, or NaN that appears after frames are concatenated)
    and empty strings return None instead of raising.
    """
    if not isinstance(config_text, str) or not config_text:
        return None
    match = re.search(r"materialized\s*=\s*[\"']([^\"']+)[\"']", config_text)
    return match.group(1) if match else None

def check_is_snapshot(config_text):
    """True when a config block sets a snapshot ``strategy=`` argument.

    Matches ``strategy`` only as a whole keyword argument, so incremental models
    configured with ``incremental_strategy=...`` are not mistaken for snapshots.
    Non-string input returns False.
    """
    if not isinstance(config_text, str):
        return False
    return re.search(r"\bstrategy\s*=", config_text) is not None

dbt_models_df['materialized'] = dbt_models_df['config'].apply(extract_materialized_value)
dbt_models_df['is_snapshot'] = dbt_models_df['config'].apply(check_is_snapshot)
# Snapshots are reported as materialized='snapshot' whatever their config says.
# A vectorised mask replaces the row-wise apply (and its positional axis argument)
# and also works on an empty frame; astype(bool) keeps the condition boolean then.
dbt_models_df['materialized'] = dbt_models_df['materialized'].mask(
    dbt_models_df['is_snapshot'].astype(bool), 'snapshot'
)

def contains_jinja_code(code_text):
    """True when the text contains a Jinja statement tag ``{%`` or comment tag ``{#``.

    Plain ``{{ ... }}`` expressions (such as ``ref()``) do not count. Non-string
    input returns False.
    """
    if not isinstance(code_text, str):
        return False
    return "{%" in code_text or "{#" in code_text

dbt_models_df['has_jinja_code'] = dbt_models_df['sql_code'].map(contains_jinja_code)


def categorize_model(name):
    """Classify a model file by its name prefix; the first matching prefix wins, else 'other'."""
    prefix_to_category = (
        ("base", "base"),
        ("stg", "stg"),
        ("int", "int"),
        ("test", "test"),
        ("snap", "snap"),
        ("__sources", "sources"),
    )
    for prefix, category in prefix_to_category:
        if name.startswith(prefix):
            return category
    return "other"

dbt_models_df['model_category'] = dbt_models_df['name'].map(categorize_model)

def get_vertical(name, model_category):
    """Derive a model's business vertical from its file name.

    * ``sources`` category -> ``'sources'``.
    * ``stg`` / ``int`` models -> the part after ``<category>_`` and before ``__``
      (e.g. ``stg_jaffle_shop__customers.sql`` -> ``'jaffle_shop'``); falls back
      to the extension-less name if it does not follow that pattern.
    * any other category -> the file name without its extension.
    """
    stem = re.sub(r'\.[^.]+$', '', name)

    if model_category == 'sources':
        return 'sources'

    # Any category other than stg/int: return the name without its extension.
    if model_category not in ('stg', 'int'):
        return stem

    # For stg/int, take the vertical between "<category>_" and "__" (or the end).
    found = re.search(rf'^{re.escape(model_category)}_([a-z0-9_]+?)(?:__|\.|$)', stem)
    if found:
        return found.group(1)
    return stem

def _vertical_for_row(row):
    return get_vertical(row['name'], row['model_category'])

dbt_models_df['vertical'] = dbt_models_df.apply(_vertical_for_row, axis='columns')

#### Merge each model's YAML into its SQL row

In [None]:
def assign_yml_rows_to_each_model(dbt_models_df, yml_extensions=('.yml', '.yaml')):
    """Attach each YAML file's parsed content to the model with the same base name.

    For every YAML row ``<name>.yml`` (or ``.yaml``), the parsed content (stored in
    ``sql_code`` by the code-extraction step) becomes the ``yml_code`` of the first
    row named ``<name>.sql``, and the YAML row is dropped. If several YAML files
    match the same model, the last one wins. YAML rows without a matching ``.sql``
    file are kept as standalone rows with ``yml_code`` set and ``sql_code`` cleared.

    Args:
        dbt_models_df: frame with at least ``name``, ``extension`` and ``sql_code``.
        yml_extensions: extensions treated as YAML (both appear in ``dbt_extensions``).

    Returns:
        A new frame (the input is not modified): non-YAML rows first, in their original
        order, then the unmatched YAML rows; the index is reset.
    """
    is_yml = dbt_models_df['extension'].isin(list(yml_extensions))

    # File name -> index of the first non-YAML row with that name (one lookup per YAML row).
    first_index_by_name = {}
    for idx, name in dbt_models_df.loc[~is_yml, 'name'].items():
        first_index_by_name.setdefault(name, idx)

    yml_code_by_index = {}
    unmatched_yml_indices = []
    for idx, row in dbt_models_df.loc[is_yml].iterrows():
        base_name = row['name'].rsplit('.', 1)[0]
        target = first_index_by_name.get(base_name + '.sql')
        if target is None:
            unmatched_yml_indices.append(idx)
        else:
            yml_code_by_index[target] = row['sql_code']

    models_df = dbt_models_df.loc[~is_yml].copy()
    # Build the column explicitly as an object Series so dict values are stored as-is.
    models_df['yml_code'] = pd.Series(
        [yml_code_by_index.get(idx) for idx in models_df.index],
        index=models_df.index,
        dtype=object,
    )

    orphans_df = dbt_models_df.loc[unmatched_yml_indices].copy()
    orphans_df['yml_code'] = orphans_df['sql_code']
    orphans_df['sql_code'] = None

    return pd.concat([models_df, orphans_df], ignore_index=True)

dbt_models_df = assign_yml_rows_to_each_model(dbt_models_df=dbt_models_df)


#### Extract information from the SQL and YAML code

In [None]:
def extract_tests(yml_code):
    """Collect the column-level tests and unit-test names declared in a parsed dbt YAML file.

    Args:
        yml_code: the ``yaml.safe_load`` result for a YAML file. Anything that is not
            a dict (None, "Error..." strings) yields None.

    Returns:
        ``{'columns': {column_name: [tests...]}, 'unit_tests': [names...]}``, or None
        when nothing is declared. ``tests`` and ``data_tests`` are concatenated.
        Columns are keyed by name only, so identically named columns of different
        models in the same file overwrite each other (last one wins).
    """
    if not isinstance(yml_code, dict):
        return None

    tests_dict = {'columns': {}, 'unit_tests': []}

    # A key written with no value (e.g. `columns:`) parses to None, hence the `or []`.
    for model in yml_code.get('models') or []:
        if not isinstance(model, dict):
            continue
        for column in model.get('columns') or []:
            if not isinstance(column, dict):
                continue
            column_name = column.get('name')
            if column_name:
                # Combine 'tests' and 'data_tests' if present
                tests = list(column.get('tests') or []) + list(column.get('data_tests') or [])
                if tests:
                    tests_dict['columns'][column_name] = tests

    # Extract unit tests
    unit_test_names = [
        test.get('name') for test in yml_code.get('unit_tests') or []
        if isinstance(test, dict) and test.get('name')
    ]
    if unit_test_names:
        tests_dict['unit_tests'] = unit_test_names

    return tests_dict if tests_dict['columns'] or tests_dict['unit_tests'] else None

dbt_models_df['tests'] = dbt_models_df['yml_code'].map(extract_tests)
dbt_models_df['has_tests'] = dbt_models_df['tests'].map(lambda found: found is not None)


In [None]:
def extract_ids_from_query(code):
    """Return the sorted ``*_id`` names used outside the query's WITH clause, or None.

    Identifiers inside the ``WITH`` clause (CTE names and CTE bodies) are skipped, so
    only ids that appear in the final query are reported. Without a ``WITH``, every
    ``*_id`` name in the statement counts. This is a token-level heuristic: ids used
    in the final query's WHERE/JOIN/subqueries are counted too.

    Returns:
        Sorted list of id names; None for non-string input or when none are found.
    """
    if not isinstance(code, str):
        return None

    # Parse the SQL query
    parsed = sqlparse.parse(code)
    if not parsed:
        return None

    T = sqlparse.tokens
    # Regular expression to find columns ending in '_id'
    id_pattern = re.compile(r'\b(\w+_id)\b')
    output_ids = set()

    for statement in parsed:
        inside_cte = False
        depth = 0  # parenthesis nesting; CTE bodies sit at depth >= 1
        for token in statement.flatten():
            ttype = token.ttype
            if ttype in T.Punctuation:
                if token.value == '(':
                    depth += 1
                elif token.value == ')':
                    depth = max(depth - 1, 0)
            elif ttype in T.Keyword:
                # `in` follows the token-type hierarchy: SELECT is Keyword.DML and WITH is
                # Keyword.CTE, which an identity check against plain Keyword never matches.
                keyword = token.value.upper()
                if keyword == 'WITH':
                    inside_cte = True
                elif keyword == 'SELECT' and inside_cte and depth == 0:
                    # The top-level SELECT after the CTE list starts the final query.
                    inside_cte = False
            elif ttype in T.Name and not inside_cte:
                match = id_pattern.search(token.value)
                if match:
                    output_ids.add(match.group(1))

    return sorted(output_ids) or None

dbt_models_df['sql_ids'] = dbt_models_df['sql_code'].map(extract_ids_from_query)


In [None]:
def has_select_all_in_last_select(code):
    """True when the last SELECT statement's top-level select list uses a bare ``*``.

    A statement counts as a SELECT if it contains a DML ``SELECT`` keyword anywhere.
    This also recognises models that start with a Jinja ``{{ config(...) }}`` block,
    for which ``Statement.get_type()`` reports 'UNKNOWN' because the first token is
    not a SQL keyword. Only the statement's top-level tokens are checked, so
    wildcards nested in groups such as ``count(*)`` do not count.
    """
    if not isinstance(code, str):
        return False

    parsed = sqlparse.parse(code)
    if not parsed:
        return False

    T = sqlparse.tokens

    def _contains_select(statement):
        return any(
            tok.ttype in T.Keyword.DML and tok.value.upper() == 'SELECT'
            for tok in statement.flatten()
        )

    select_statements = [stmt for stmt in parsed if _contains_select(stmt)]
    if not select_statements:
        return False

    last_select = select_statements[-1]
    return any(tok.ttype is T.Wildcard and tok.value == '*' for tok in last_select.tokens)

dbt_models_df['has_select_all_in_last_select'] = dbt_models_df['sql_code'].map(has_select_all_in_last_select)


In [None]:
def has_group_by(code):
    """True when the SQL text contains ``GROUP BY`` (any case, any whitespace between the words).

    This is a plain text search, so a ``GROUP BY`` inside a comment or string literal
    also counts. Non-string input returns False.
    """
    if not isinstance(code, str):
        return False
    return re.search(r'\bgroup\s+by\b', code, re.IGNORECASE) is not None


dbt_models_df['has_group_by'] = dbt_models_df['sql_code'].map(has_group_by)

In [None]:
def find_primary_key(tests_dict):
    """Return the first column whose tests mark it as a primary key, else None.

    A column qualifies when its tests include both ``unique`` and ``not_null`` (in any
    order, possibly alongside other tests) or ``dbt_constraints.primary_key``. Tests
    written in dict form (``{'unique': {...}}``) are matched by their key.

    Args:
        tests_dict: the ``extract_tests`` structure ``{'columns': {name: [tests]}, ...}``.
    """
    if not isinstance(tests_dict, dict):
        return None

    def _test_names(tests):
        # Normalise string and dict-form test declarations to a set of test names.
        names = set()
        for test in tests or []:
            if isinstance(test, str):
                names.add(test)
            elif isinstance(test, dict):
                names.update(key for key in test if isinstance(key, str))
        return names

    for column, tests in (tests_dict.get('columns') or {}).items():
        names = _test_names(tests)
        # Check if the column has the required tests for a primary key
        if {'not_null', 'unique'} <= names or 'dbt_constraints.primary_key' in names:
            return column

    return None

dbt_models_df['primary_key'] = dbt_models_df['tests'].map(find_primary_key)

In [None]:
def extract_sql_filters(sql_query):
    """Heuristically list the filter and join conditions of a SQL query.

    The query is whitespace-collapsed and lower-cased. The bodies of WHERE, ON (join)
    and HAVING clauses are then captured by regex and split on AND/OR. A clause body
    ends at the next clause keyword, a UNION, a Jinja ``{%`` tag, ``;`` or end of
    text, or where its enclosing CTE body closes (``) , name as (`` or ``) select``).
    That last boundary keeps a WHERE inside a CTE from swallowing the rest of the query.

    Returns:
        WHERE/HAVING conditions followed by JOIN conditions, or None when none are
        found or the input is not a non-blank string.
    """
    if not isinstance(sql_query, str) or not sql_query.strip():
        return None

    sql_query_clean = ' '.join(sql_query.split()).lower()

    # Boundaries shared by every clause type.
    common_end = (
        r'\)\s*,\s*\w+\s+as\s*\('   # a CTE body closes and the next CTE starts
        r'|\)\s*select\b'           # the last CTE body closes before the final select
        r'|\bunion\b|\{%|;|$'
    )
    filters_patterns = [
        (r'\bwhere\b\s+(.*?)(?=\bgroup\b|\border\b|\blimit\b|\bhaving\b|' + common_end + r')', 'where'),
        (r'\bon\b\s+(.*?)(?=\bleft\b|\bright\b|\binner\b|\bouter\b|\bjoin\b|\bselect\b|\bwhere\b|\bgroup\b|\border\b|\blimit\b|' + common_end + r')', 'join'),
        (r'\bhaving\b\s+(.*?)(?=\bgroup\b|\border\b|\blimit\b|' + common_end + r')', 'having'),
    ]

    filters = []
    joins = []

    for pattern, clause_type in filters_patterns:
        for match in re.findall(pattern, sql_query_clean):
            for condition in re.split(r'\band\b|\bor\b', match):
                cleaned = condition.strip().strip('()')
                if cleaned:
                    (joins if clause_type == 'join' else filters).append(cleaned)

    all_filters = filters + joins
    return all_filters if all_filters else None

dbt_models_df['filters'] = dbt_models_df['sql_code'].map(extract_sql_filters)
dbt_models_df['is_filtered'] = dbt_models_df['filters'].map(lambda found: found is not None)

In [None]:
def extract_dbt_macros(sql_query, exclude=('ref', 'source', 'config')):
    """Return the sorted, unique names of macros called as ``{{ name(...) }}``, or None.

    Calls may span several lines and may use Jinja whitespace control (``{{- ... -}}``).
    dbt built-ins listed in ``exclude`` (by default ``ref``, ``source`` and ``config``)
    are not reported as macros.
    """
    if not isinstance(sql_query, str) or not sql_query.strip():
        return None

    macro_pattern = r"\{\{-?\s*([\w\.]+)\s*\(.*?\)\s*-?\}\}"
    matches = re.findall(macro_pattern, sql_query, re.DOTALL)
    excluded = set(exclude)
    filtered_macros = sorted({m for m in matches if m not in excluded})

    return filtered_macros if filtered_macros else None

dbt_models_df['macros'] = dbt_models_df['sql_code'].map(extract_dbt_macros)
dbt_models_df['has_macros'] = dbt_models_df['macros'].map(lambda found: found is not None)

#### Compute model lineage (parents, children, sources)

In [None]:
def extract_source_details(code, source_pattern):
    """Find dbt ``source()`` calls in ``code`` with a two-group regex (source name, table name).

    Returns ``(True, ['source.table', ...])`` when at least one call is found, otherwise
    ``(False, None)``. Non-string ``code`` also gives ``(False, None)``.
    """
    found = re.findall(source_pattern, code) if isinstance(code, str) else []
    if not found:
        return False, None
    return True, [f"{match[0]}.{match[1]}" for match in found]

def enrich_dbt_models(dbt_models_df):
    # Helper regex patterns
    source_pattern = r"\{\{\s*source\(['\"](.*?)['\"],\s*['\"](.*?)['\"]\)\s*\}\}"
    ref_pattern = r"\{\{\s*ref\(['\"](.*?)['\"]\)\s*\}\}"
    
    # Add 'parent_models' - extract all models referenced using 'ref'
    dbt_models_df['parent_models'] = dbt_models_df['sql_code'].apply(
        lambda code: re.findall(ref_pattern, code) if isinstance(code, str) else []
    )
    
    dbt_models_df[['is_source_model', 'source']] = dbt_models_df['sql_code'].apply(
        lambda code: pd.Series(extract_source_details(code, source_pattern))
    )
    
    # Build a dictionary to track children relationships
    model_children = {}
    for idx, row in dbt_models_df.iterrows():
        for parent in row['parent_models']:
            model_children.setdefault(parent, []).append(row['name'].replace('.sql', ''))

    # Add 'children_models' - list all models that depend on this model
    dbt_models_df['children_models'] = dbt_models_df['name'].apply(
        lambda name: model_children.get(name.replace('.sql', ''), [])
    )
    
    # Add 'is_end_model' - True if there are no children
    dbt_models_df['is_end_model'] = dbt_models_df['children_models'].apply(lambda children: len(children) == 0)
    
    return dbt_models_df

dbt_models_enriched_df = enrich_dbt_models(dbt_models_df)

In [None]:
display(dbt_models_enriched_df)

### Add descriptions using an LLM

In [None]:
import os
import sys

# The `openai_setup` module holding the OpenAI credentials lives in the parent
# directory of the current working directory; make it importable.
repo_root = os.path.abspath(os.pardir)
if repo_root not in sys.path:
    sys.path.append(repo_root)

import openai_setup

_openai_conf = openai_setup.conf
OPENAI_API_KEY = _openai_conf['key']
OPENAI_PROJECT = _openai_conf['project']
OPENAI_ORGANIZATION = _openai_conf['organization']
DEFAULT_LLM_MODEL = "gpt-4o-mini"

In [None]:
from langchain_openai import ChatOpenAI  # Importa el modelo de OpenAI
from langchain.schema import HumanMessage  # Para interactuar con mensajes

In [None]:
# Low temperature for stable, factual descriptions.
# NOTE(review): OPENAI_PROJECT is loaded from openai_setup but not passed here — confirm whether it is needed.
llm = ChatOpenAI(
    model=DEFAULT_LLM_MODEL,
    temperature=0.1,
    openai_api_key=OPENAI_API_KEY,
    openai_organization=OPENAI_ORGANIZATION,
)

In [None]:
def generate_query_description(llm, query, documentation = None):
    """Ask ``llm`` for a short natural-language description of a SQL query.

    Args:
        llm: a LangChain chat model, or any object whose ``.invoke(str)`` returns an
            object with a ``.content`` attribute.
        query: the SQL text to describe.
        documentation: optional extra context (e.g. the model's parsed YAML). Any
            falsy value is replaced by a placeholder sentence.

    Returns:
        The text content of the model's reply.
    """
    doc_text = documentation if documentation else "No additional documentation provided."
    # Built line by line so no source-code indentation leaks into the prompt.
    prompt = "\n".join([
        "Given the following SQL query and additional documentation, describe in a few sentences what the query does.",
        "",
        "SQL Query:",
        f"{query}",
        "",
        "Documentation:",
        f"{doc_text}",
        "",
        "Please provide a concise and clear description.",
    ])

    # `invoke` is LangChain's supported call style (calling the model directly is
    # deprecated); a plain string is sent as a single human message.
    response = llm.invoke(prompt)
    return response.content

In [None]:
# Worked example: the second row (iloc[1]) of the enriched models table.
example_row = dbt_models_enriched_df.iloc[1]
example_query = example_row.sql_code
print(example_query)
example_doc = example_row.yml_code
print(example_doc)

In [None]:
generate_query_description(llm, example_query, documentation=example_doc)

In [None]:
# TODO (LLM enrichment): for each model, also generate
# - a written summary of what the query does
# - for models with Jinja: an explanation of what the Jinja code is for