# In [None]:
import requests
import pandas as pd
import os
import json
from datetime import datetime
from requests.auth import HTTPBasicAuth
from dotenv import load_dotenv

# --- LOAD ENV FILE ---
# Must run before the os.getenv() calls in the configuration section below so
# that values defined in a local .env file are visible to them.
load_dotenv()

# --- SNOWFLAKE IMPORTS ---
import snowflake.connector
from snowflake.connector.pandas_tools import write_pandas

# --- CONFIGURATION ---
# rater8 API credentials are read exclusively from the environment (or the
# .env file loaded above), exactly like the Snowflake settings below.
# Secrets must never be hard-coded as fallback defaults: anything committed
# to source control has to be treated as leaked and rotated.
# If either variable is missing, the first API call fails authentication.
API_USER = os.getenv('RATER8_USER')
API_PASS = os.getenv('RATER8_PASS')
BASE_URL = "https://rapi.rater8.com/api"

# Reporting window sent to the reviews endpoint as `fromdate` / `todate`.
FROM_DATE = "2025-12-31"
TO_DATE = "2026-01-09"

# Shared HTTP Basic auth object used by every rater8 request.
auth = HTTPBasicAuth(API_USER, API_PASS)

# Snowflake Config
SF_USER = os.getenv('SNOWFLAKE_USER')
SF_PASSWORD = os.getenv('SNOWFLAKE_PASSWORD')
SF_ACCOUNT = os.getenv('SNOWFLAKE_ACCOUNT')
SF_WAREHOUSE = os.getenv('SNOWFLAKE_WAREHOUSE')
SF_DATABASE = os.getenv('SNOWFLAKE_DATABASE')
SF_SCHEMA = os.getenv('SNOWFLAKE_SCHEMA')
SF_ROLE = os.getenv('SNOWFLAKE_ROLE')

# DEBUGGING
# NOTE(review): DEBUG_PRINT_RAW_JSON is not referenced by the code visible in
# this file; kept in case another cell/module reads it.
DEBUG_PRINT_RAW_JSON = True
# Question_ID assigned to custom questions whose questionType is "NPS".
CUSTOM_Q_NPS_ID = -1

# --- UTILITY: API Fetcher ---
def get_permissions(timeout=60):
    """Fetch the permissions payload for the authenticated rater8 account.

    Args:
        timeout: Seconds to wait for the server to connect/respond before
            giving up. Passed straight to ``requests.get``.

    Returns:
        The decoded JSON body of ``GET {BASE_URL}/permissions``. Callers in
        this file iterate it as a list of practice dicts.

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status.
        requests.Timeout: If the server does not respond within ``timeout``.
    """
    url = f"{BASE_URL}/permissions"
    # requests applies no timeout by default; without one a stalled
    # connection would block the whole pull indefinitely.
    response = requests.get(url, auth=auth, timeout=timeout)
    response.raise_for_status()
    return response.json()

def _join_client_codes(entity):
    """Return the entity's ``clientCodes`` as one comma-separated string."""
    # `or []` covers both a missing key and an explicit JSON null.
    return ", ".join(map(str, entity.get('clientCodes') or []))


def build_client_code_lookup(permissions):
    """Map practice and employee IDs to their comma-joined client codes.

    Args:
        permissions: Iterable of practice dicts. Each may carry ``id``,
            ``clientCodes`` and a list of ``employees`` (dicts with ``id``
            and ``clientCodes``).

    Returns:
        A dict keyed by ``"P_<practice id>"`` and ``"E_<employee id>"`` whose
        values are the entity's client codes joined with ``", "`` (an empty
        string when there are none). Entities with a falsy ``id`` are skipped.
    """
    lookup = {}
    for practice in permissions:
        practice_id = practice.get('id')
        if practice_id:
            lookup[f"P_{practice_id}"] = _join_client_codes(practice)
        # A null `employees` value must behave like an empty list rather
        # than raising TypeError.
        for employee in practice.get('employees') or []:
            employee_id = employee.get('id')
            if employee_id:
                lookup[f"E_{employee_id}"] = _join_client_codes(employee)
    return lookup

def fetch_survey_data(practice_id, employee_id, timeout=60):
    """Fetch the surveys of one employee for the configured date window.

    Args:
        practice_id: rater8 practice ID used in the URL path.
        employee_id: rater8 employee ID used in the URL path.
        timeout: Seconds to wait for the server to connect/respond before
            giving up. Passed straight to ``requests.get``.

    Returns:
        The decoded JSON body, or ``[]`` when the body is empty or the
        literal text ``null`` (compared case-insensitively).

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status.
        requests.Timeout: If the server does not respond within ``timeout``.
    """
    # Company ID 0 is typically used for internal surveys
    url = f"{BASE_URL}/reviews/practice/{practice_id}/employee/{employee_id}"
    params = {"companyid": 0, "fromdate": FROM_DATE, "todate": TO_DATE}
    # requests applies no timeout by default; one hung employee request
    # would otherwise stall the entire run.
    response = requests.get(url, auth=auth, params=params, timeout=timeout)
    response.raise_for_status()

    body = response.text.strip()
    if not body or body.upper() == 'NULL':
        return []
    return response.json()

# --- UTILITY: Data Flattener ---
def flatten_internal_surveys(surveys, practice_id, practice_name, employee_id, employee_name, client_code_lookup):
    """Flatten nested survey payloads into one flat dict per question/rating.

    For every survey the following rows are produced, in this order:

    * one "main rating" row (``Question_ID`` 0) when the survey has an
      employee/location rating or a non-empty ``Comment``;
    * one row per entry in ``questions``;
    * one row per entry in ``customQuestions`` (``Question_ID`` is
      ``CUSTOM_Q_NPS_ID`` for question type ``"NPS"``, otherwise the
      question type itself);
    * a single placeholder row with ``Question_ID`` ``None`` when none of
      the above produced anything.

    Args:
        surveys: Iterable of survey dicts (``None`` is treated as empty).
        practice_id: ID copied into every row as ``Practice_ID``.
        practice_name: Name copied into every row as ``Practice_Name``.
        employee_id: ID copied into every row as ``Entity_ID``.
        employee_name: Name copied into every row as ``Entity_Name``.
        client_code_lookup: Dict keyed by ``"E_<id>"`` / ``"P_<id>"``; the
            employee entry wins, the practice entry is the fallback.

    Returns:
        A list of row dicts sharing the same key order.
    """
    client_code = (
        client_code_lookup.get(f"E_{employee_id}", "")
        or client_code_lookup.get(f"P_{practice_id}", "")
    )

    rows = []
    for survey in surveys or []:
        # The MRN key name varies, so try each known spelling in turn.
        mrn_val = survey.get("patientMrn") or survey.get("mrn") or survey.get("MRN")

        base_fields = {
            "Practice_ID": practice_id,
            "Practice_Name": practice_name,
            "Entity_ID": employee_id,
            "Entity_Name": employee_name,
            "Entity_Type": "Employee",
            "Client_Code": client_code,
            "Survey_ID": survey.get("id"),
            "MRN": mrn_val,
            "CompanyId": survey.get("companyId"),
            "CompanyName": survey.get("companyName"),
            "ReviewMonth": survey.get("ReviewMonth"),
            "ReviewDate": survey.get("ReviewDate"),
            "apptLocationCode": survey.get("apptLocationCode"),
            "apptTypeCode": survey.get("apptTypeCode"),
            "apptDeptCode": survey.get("apptDeptCode"),
            "atLoc_rater8Id": survey.get("atLoc_rater8Id"),
            "atLoc_Name": survey.get("atLoc_Name"),
        }

        survey_rows = []

        # NOTE(review): `or` treats an employeeRating of 0 as missing and
        # falls through to locationRating - confirm 0 is never a real score.
        main_rating = survey.get("employeeRating") or survey.get("locationRating")
        main_comment = survey.get("Comment")
        if main_rating is not None or main_comment:
            survey_rows.append({
                **base_fields,
                "Question_ID": 0,
                "Question_Name": "Employee/Location Main Rating",
                "Rating": main_rating,
                "Comment_Text": main_comment,
            })

        # `or []` also covers an explicit JSON null, which `.get(key, [])`
        # would pass through and crash the loop with TypeError.
        for q in survey.get("questions") or []:
            survey_rows.append({
                **base_fields,
                "Question_ID": q.get("id"),
                "Question_Name": q.get("name"),
                "Rating": q.get("rating"),
                "Comment_Text": q.get("Comment"),
            })

        for cq in survey.get("customQuestions") or []:
            question_type = cq.get("questionType")
            q_id = CUSTOM_Q_NPS_ID if question_type == "NPS" else question_type
            survey_rows.append({
                **base_fields,
                "Question_ID": q_id,
                "Question_Name": question_type,
                "Rating": cq.get("rating"),
                "Comment_Text": cq.get("comment"),
            })

        if not survey_rows:
            # Keep the survey visible downstream even without any answers.
            survey_rows.append({
                **base_fields,
                "Question_ID": None,
                "Question_Name": "No Question/Rating Data",
                "Rating": None,
                "Comment_Text": None,
            })

        rows.extend(survey_rows)
    return rows

# --- SNOWFLAKE UTILITIES ---

def get_snowflake_conn():
    """Open and return a Snowflake connection built from the SF_* settings."""
    connection_settings = {
        "user": SF_USER,
        "password": SF_PASSWORD,
        "account": SF_ACCOUNT,
        "warehouse": SF_WAREHOUSE,
        "database": SF_DATABASE,
        "schema": SF_SCHEMA,
        "role": SF_ROLE,
    }
    return snowflake.connector.connect(**connection_settings)

def map_pandas_dtype_to_snowflake(dtype):
    """Translate a pandas dtype into the Snowflake column type used for DDL.

    The predicates are tried in order (integer, float, datetime, bool);
    anything that matches none of them is stored as ``VARCHAR``.
    """
    type_checks = (
        (pd.api.types.is_integer_dtype, "NUMBER"),
        (pd.api.types.is_float_dtype, "FLOAT"),
        (pd.api.types.is_datetime64_any_dtype, "TIMESTAMP_NTZ"),
        (pd.api.types.is_bool_dtype, "BOOLEAN"),
    )
    for matches, snowflake_type in type_checks:
        if matches(dtype):
            return snowflake_type
    return "VARCHAR"

def create_or_replace_table(conn, df, table_name):
    """(Re)create ``table_name`` with one column per DataFrame column.

    Column types are derived from the DataFrame dtypes via
    ``map_pandas_dtype_to_snowflake``. Any existing table with the same name
    is replaced.

    Args:
        conn: Open Snowflake connection.
        df: DataFrame whose column names and dtypes define the schema.
        table_name: Name of the table to create; it is double-quoted in the
            DDL, so it is case-sensitive. Names are interpolated into the
            statement, so only pass trusted identifiers.
    """
    cols_def = [
        f'"{col_name}" {map_pandas_dtype_to_snowflake(dtype)}'
        for col_name, dtype in df.dtypes.items()
    ]
    ddl = f'CREATE OR REPLACE TABLE "{table_name}" ({", ".join(cols_def)})'

    cursor = conn.cursor()
    try:
        cursor.execute(ddl)
    finally:
        # Release the cursor even when the DDL fails.
        cursor.close()

def run_merge(conn, target_table, source_table, join_keys):
    """Upsert every row of ``source_table`` into ``target_table``.

    The column list is read from ``DESC TABLE`` on the source table. Rows
    that match on all ``join_keys`` are updated; all others are inserted.

    Args:
        conn: Open Snowflake connection.
        target_table: Table to merge into (double-quoted, case-sensitive).
        source_table: Table to merge from (double-quoted, case-sensitive).
        join_keys: Column names that identify a row. Must not be empty.

    Raises:
        ValueError: If ``join_keys`` is empty (the ON clause would be blank).
    """
    if not join_keys:
        raise ValueError("run_merge requires at least one join key")

    cursor = conn.cursor()
    try:
        # Quote the identifier exactly like the MERGE statement below so both
        # statements resolve to the same (case-sensitive) table.
        cursor.execute(f'DESC TABLE "{source_table}"')
        columns = [row[0] for row in cursor.fetchall()]

        # EQUAL_NULL is NULL-safe: with plain `=` a NULL key never matches,
        # so rows whose key is NULL would be inserted again on every run.
        on_clause = " AND ".join(
            f'EQUAL_NULL(target."{k}", source."{k}")' for k in join_keys
        )
        update_clause = ", ".join(f'target."{c}" = source."{c}"' for c in columns)
        col_list = ", ".join(f'"{c}"' for c in columns)
        val_list = ", ".join(f'source."{c}"' for c in columns)

        merge_sql = f"""
    MERGE INTO "{target_table}" AS target
    USING "{source_table}" AS source
    ON {on_clause}
    WHEN MATCHED THEN UPDATE SET {update_clause}
    WHEN NOT MATCHED THEN INSERT ({col_list}) VALUES ({val_list})
    """
        cursor.execute(merge_sql)
    finally:
        # Release the cursor even when DESC/MERGE fails.
        cursor.close()

def upload_to_snowflake_merge(df, table_name, merge_keys):
    """Load ``df`` into Snowflake, merging into ``table_name`` on ``merge_keys``.

    The frame is first written to a ``<table_name>_TEMP`` staging table. If
    the target table does not exist yet the staging table is renamed to it;
    otherwise the staging rows are merged in and the staging table dropped.

    Column names are upper-cased and spaces, ``/`` and ``.`` are replaced
    with ``_`` before upload; ``merge_keys`` must use these normalized names.
    The caller's DataFrame is left untouched.

    Any Snowflake error is reported on stdout rather than raised
    (best-effort upload).

    Args:
        df: Rows to upload. An empty frame is skipped.
        table_name: Target table (double-quoted, case-sensitive).
        merge_keys: Normalized column names identifying a row.
    """
    if df.empty:
        print(f"⚠️ Skipping Snowflake upload for {table_name}: DataFrame is empty.")
        return

    # rename() returns a new frame, so the caller's column labels are not
    # modified as a side effect of uploading.
    df = df.rename(
        columns=lambda c: c.upper().replace(" ", "_").replace("/", "_").replace(".", "_")
    )

    print(f"\n--- Starting Snowflake Merge Operation: {table_name} ---")
    try:
        conn = get_snowflake_conn()
        try:
            cursor = conn.cursor()
            temp_table_name = f"{table_name}_TEMP"

            create_or_replace_table(conn, df, temp_table_name)
            write_pandas(conn, df, temp_table_name, auto_create_table=False, quote_identifiers=True)

            try:
                cursor.execute(f'SELECT TOP 1 * FROM "{table_name}"')
                table_exists = True
            except snowflake.connector.errors.ProgrammingError:
                # Snowflake reports a missing (or inaccessible) object as a
                # ProgrammingError. Other failures, e.g. a dropped
                # connection, must not be mistaken for "table is missing".
                table_exists = False

            if not table_exists:
                cursor.execute(f'ALTER TABLE "{temp_table_name}" RENAME TO "{table_name}"')
                print(f"✅ Created {table_name} successfully.")
            else:
                run_merge(conn, table_name, temp_table_name, merge_keys)
                cursor.execute(f'DROP TABLE IF EXISTS "{temp_table_name}"')
                print(f"✅ Merge to {table_name} complete.")
        finally:
            # Always release the connection, including when a step above fails.
            conn.close()
    except Exception as e:
        print(f"❌ Snowflake Error: {e}")

# --- MAIN EXECUTION ---

def main():
    """Pull rater8 surveys for every permitted employee and load them to Snowflake.

    Steps:
      1. Fetch the permissions payload and build the client-code lookup.
      2. For each practice and employee, fetch and flatten the surveys. A
         failure for one employee is reported and skipped; a failure outside
         that per-employee step aborts the run.
      3. Build a DataFrame, normalize MRN and ReviewDate, and merge it into
         the RATER8_SURVEYS table keyed on SURVEY_ID + QUESTION_ID.
    """
    collected_rows = []

    print(f"--- Starting rater8 SURVEY ONLY Pull ({FROM_DATE} to {TO_DATE}) ---")

    try:
        permissions = get_permissions()
        client_code_lookup = build_client_code_lookup(permissions)
        practice_count = len(permissions)

        for position, practice in enumerate(permissions, start=1):
            practice_id = practice.get('id')
            practice_name = practice.get('name', 'N/A')
            print(f"Processing Practice {position}/{practice_count}: {practice_name}...")

            for employee in practice.get('employees', []):
                try:
                    employee_id = employee.get('id')
                    surveys = fetch_survey_data(practice_id, employee_id)
                    if not surveys:
                        continue
                    flattened = flatten_internal_surveys(
                        surveys,
                        practice_id,
                        practice_name,
                        employee_id,
                        employee.get('name'),
                        client_code_lookup,
                    )
                    collected_rows.extend(flattened)
                except Exception as e:
                    # One bad employee must not stop the remaining pulls.
                    print(f"  ❌ Error fetching surveys for {employee.get('name')}: {e}")

    except Exception as e:
        print(f"\n❌ FATAL ERROR: {e}")
        return

    if not collected_rows:
        print("\n⚠️ No survey data found for the period.")
        return

    df_surveys = pd.DataFrame(collected_rows)

    # Store MRN as text; the stringified missing values produced by
    # astype(str) are mapped back to None.
    mrn_as_text = df_surveys['MRN'].astype(str)
    df_surveys['MRN'] = mrn_as_text.replace({'None': None, 'nan': None})

    if 'ReviewDate' in df_surveys.columns:
        df_surveys['ReviewDate'] = pd.to_datetime(df_surveys['ReviewDate'])

    # NOTE(review): Question_ID can mix ints, strings and None across rows -
    # confirm the resulting column type stays stable between runs.
    upload_to_snowflake_merge(
        df_surveys,
        "RATER8_SURVEYS",
        merge_keys=["SURVEY_ID", "QUESTION_ID"],
    )

# Entry point: run the pull-and-load pipeline only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    main()