In [1]:
from dotenv import load_dotenv
import os
from typing import List, Dict, Any, Optional, Union

import pandas as pd
import numpy as np
from pathlib import Path
from tqdm.notebook import tqdm

load_dotenv()


# Import from our Classes module
from Classes.model_classes import SQLLineageExtractor, SQLLineageResult, create_sql_lineage_extractor
from Classes.validation_classes import SQLLineageValidator
from Classes.regexp_extractor import RegexSQLExtractor


# --- Model / inference provider ---------------------------------------------
MODEL = "Qwen/Qwen3-Coder-30B-A3B-Instruct"
PROVIDER = "scaleway"
# Read from the environment (populated by load_dotenv above); None when unset.
HF_TOKEN = os.getenv("HF_TOKEN")

# --- Input data ---------------------------------------------------------------
# Input CSV lives under ./data relative to the directory the kernel was started in.
current_dir = Path.cwd()
file_path = current_dir.joinpath('data', 'views.csv')

# --- Prompt template passed to the LLM extractor ------------------------------
PROMPT = """Please extract source-to-target lineage from the SQL query with the following requirements:

### SQL Lineage Extraction Task
Extract source-to-target lineage from the SQL statement below. Return ONLY valid JSON containing:
- "target": The main object being created or modified (fully qualified name)
- "sources": List of DISTINCT base tables/views (fully qualified names)"""



In [2]:
# LLM-backed lineage extractor built through the project factory.
# NOTE(review): do_sample=False presumably selects greedy (deterministic)
# decoding inside the factory — confirm in Classes.model_classes.
extractor = create_sql_lineage_extractor(
    # model / provider / auth
    model=MODEL,
    provider=PROVIDER,
    hf_token=HF_TOKEN,
    # generation settings
    max_new_tokens=2048,
    do_sample=False,
    # robustness / output parsing
    max_retries=3,
    use_pydantic_parser=True,
    human_prompt_template=PROMPT,
)

# Rule-based extractor; its output is used as the expected result when the
# LLM extractor is validated in the extraction loop.
re_extractor = RegexSQLExtractor()

In [3]:
# Validator used in the extraction loop: it runs the extractor on each DDL
# statement and compares the output with an expected (regex-derived) result.
validation = SQLLineageValidator()

In [4]:
# Schema prepended to every view name to build a synthetic INSERT target.
TARGET_SCHEMA = "s_grnplm_vd_t_bvd_db_dmslcl"

df_data = pd.read_csv(file_path)

# Wrap each view definition as `INSERT INTO <schema>.<table_name> <view_def>` so
# the extractor sees an explicit target object for every statement.
df_data['ddl'] = (
    "INSERT INTO " + TARGET_SCHEMA + "." + df_data['table_name'] + " " + df_data['view_def']
)
# A missing table_name or view_def would silently yield a NaN statement that is
# then sent to both extractors; fail early instead.
assert df_data['ddl'].notna().all(), "NaN in table_name/view_def produced an empty DDL"

# Newline-joined preview of the first 10 statements (not referenced by any later
# cell shown here; handy for ad-hoc prompt experiments).
lines = '\n'.join(df_data['ddl'].head(10).values)

In [5]:
# Run the LLM extractor on every DDL statement, validate it against the regex
# extractor's output, and collect per-table F1 scores plus the extracted lineage.
data_lineage = []
f1_scores = {}
results = []

for _, row in tqdm(df_data.iterrows(), total=len(df_data), desc="🎨 Extracting S2T"):

    # Rule-based lineage for this statement, used as the reference answer.
    expected_ = re_extractor.extract(row['ddl'])

    res_ = validation.run_comprehensive_validation(
        extractor,
        row['ddl'],
        expected_result=expected_,
    )
    results.append(res_)

    # Failed validations may carry no metrics (missing key, or None). Score them
    # as 0 so they still pull the overall mean down. Only these lookup errors are
    # caught; anything else (including KeyboardInterrupt) propagates.
    # NOTE(review): keyed by table_name, so duplicate names overwrite earlier scores.
    try:
        f1_scores[row['table_name']] = res_['metrics']['f1_score']
    except (KeyError, TypeError):
        f1_scores[row['table_name']] = 0

    data_lineage.append(res_['result'])

🎨 Extracting S2T:   0%|          | 0/127 [00:00<?, ?it/s]

In [6]:
# Mean F1 across every processed table; tables whose validation produced no
# metrics were recorded as 0 in the extraction loop.
np.asarray([*f1_scores.values()]).mean()

np.float64(0.9589912810554935)

In [7]:
# One row per validation result. Tabulate how many results fall into each
# status x validation_type cell (counting non-null `message` values); empty
# combinations are shown as 0.
df = pd.DataFrame(results)
(
    df.pivot_table(
        index='status',
        columns='validation_type',
        values='message',
        aggfunc='count',
    )
    .fillna(0.0)
)

validation_type,comprehensive,uniqueness
status,Unnamed: 1_level_1,Unnamed: 2_level_1
FAILED,0.0,4.0
SUCCESS,123.0,0.0


In [8]:
# Keep only successful validations and expand each row's `metrics` record into
# precision / recall / f1_score columns. The filter is applied once, and
# `metrics` is dropped explicitly: calling .pop() on a boolean-filtered frame
# only mutates a temporary copy and would leave the column in place.
df_success = (
    df.loc[df['status'].eq('SUCCESS')]
    .pipe(lambda ok: ok.drop(columns='metrics').join(
        pd.DataFrame(
            ok['metrics'].tolist(),
            index=ok.index,
            columns=['precision', 'recall', 'f1_score'],
        )
    ))
)

# Mean F1 over successful validations only (failed rows are excluded here,
# unlike the overall mean, which counts them as 0).
df_success['f1_score'].mean()

np.float64(0.9901779893825015)