In [None]:
import os
from pathlib import Path

import pandas as pd
from dotenv import load_dotenv
# Load variables from a local .env file (if present) into the process environment
# so the configuration lookup below can see them.
load_dotenv()

# Directory holding the dashboard configuration CSVs. It can be overridden with the
# DASH_CONFIG_DIR environment variable (e.g. set in .env). The default keeps the
# original location, which is resolved relative to the kernel's working directory.
config_path = Path(os.environ.get('DASH_CONFIG_DIR', '../dash_config'))

In [None]:
# Load every dashboard configuration table from `config_path`, one DataFrame per CSV.
subq = pd.read_csv(config_path / 'subquery.csv')
queryr = pd.read_csv(config_path / 'query_rule.csv')
qr_lookup = pd.read_csv(config_path / 'query_rule_map.csv')
phen = pd.read_csv(config_path / 'phenotype.csv')
phen_def = pd.read_csv(config_path / 'phenotype_definition.csv')
meas = pd.read_csv(config_path / 'measure.csv')
meas_rel = pd.read_csv(config_path / 'measure_relationship.csv')
report = pd.read_csv(config_path / 'report.csv')
report_cohorts = pd.read_csv(config_path / 'report_cohort_map.csv')
report_ind = pd.read_csv(config_path / 'report_indicator_map.csv')
report_version = pd.read_csv(config_path / 'report_version.csv')
dash_cohort = pd.read_csv(config_path / 'dash_cohort.csv')
dash_cohort_map = pd.read_csv(config_path / 'dash_cohort_def_map.csv')
dash_cohort_def = pd.read_csv(config_path / 'dash_cohort_def.csv')
indicator = pd.read_csv(config_path / 'indicator.csv')
# `report_indicators` is the same table as `report_ind` (report_indicator_map.csv).
# Copy the already-loaded frame instead of reading the file a second time; .copy()
# keeps the two names independent, as two separate reads would.
report_indicators = report_ind.copy()

In [None]:
from omop_constructs.semantics import registry_engine

In [None]:
from oa_cohorts.measurables import get_measurable_registry
from oa_cohorts.core import RuleTemporality, RuleTarget, RuleMatcher, ThresholdDirection, RuleCombination, ReportStatus
from oa_cohorts.query import (
    DashCohort, DashCohortDef, dash_cohort_def_map,
    Indicator, report_indicator_map,
    Measure, MeasureRelationship,
    MeasureNode, SubqueryNode, QueryNode, QueryPlan,
    Report, ReportCohortMap, ReportVersion,
    Subquery,
    QueryRule, ExactRule, HierarchyExclusionRule, HierarchyRule, AbsenceRule, ScalarRule, PhenotypeRule, SubstringRule,
    Phenotype, PhenotypeDefinition
)
import sqlalchemy.orm as so
import sqlalchemy as sa
from typing import Any

In [None]:
# Build the measurable registry from oa_cohorts.
# NOTE(review): `registry` is not referenced in the cells shown here. The call may be
# kept for its side effects (e.g. registering mapped classes before create_all below),
# so statement order is preserved — confirm before removing or reordering.
registry = get_measurable_registry()

# NOTE(review): `create_engine` is not used in the cells shown here; the engine in use
# is `registry_engine` from omop_constructs.semantics.
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from orm_loader.helpers import Base  

# Create the tables registered on Base.metadata in the registry database
# (create_all only issues CREATE for tables that do not already exist).
Base.metadata.create_all(registry_engine)
# Session factory bound to the registry engine; later cells open sessions from it.
Session = sessionmaker(bind=registry_engine, future=True)

In [None]:
def clean_dict(d: dict[str, Any], model_cls) -> dict[str, Any]:
    """Filter one CSV record down to the mapped columns of an ORM model.

    Args:
        d: A single record, e.g. one element of ``DataFrame.to_dict(orient='records')``.
        model_cls: A mapped class exposing ``__table__.columns``.

    Returns:
        A new dict containing only the entries of ``d`` whose key names a column of
        ``model_cls``'s table and whose value is not missing (NaN/None/NA/NaT).
        Missing cells are therefore not passed to the model constructor at all.
    """
    # A set gives constant-time membership checks instead of scanning a list per key.
    columns = set(model_cls.__table__.columns.keys())
    return {
        k: v
        for k, v in d.items()
        # pd.isna is only meaningful for scalars. On a list-like value it returns an
        # array whose truth value is ambiguous (and would raise), so non-scalars are kept.
        if k in columns and not (pd.api.types.is_scalar(v) and pd.isna(v))
    }

In [None]:

# CSV header -> Subquery attribute, for the columns whose names differ.
subquery_columns = {
    'subquery_target': 'target',
    'subquery_temporality': 'temporality',
    'subquery_name': 'name',
    'subquery_short_name': 'short_name',
}
subquery_records = subq.rename(columns=subquery_columns).to_dict(orient='records')
subqueries = [Subquery(**clean_dict(rec, Subquery)) for rec in subquery_records]

# CSV header -> QueryRule attribute. `scalar_threshold` already matches the model
# attribute, so it needs no entry here.
queryrule_columns = {
    'query_matcher': 'matcher',
    'query_concept_id': 'concept_id',
    'query_notes': 'notes',
}
queryrules = [
    QueryRule(**clean_dict(d, QueryRule))
    for d in queryr.rename(columns=queryrule_columns).to_dict(orient='records')
]

# The CSV stores each matcher as a RuleMatcher attribute name; resolve it to the
# attribute itself. A blank cell (dropped by clean_dict) or an unknown name is reported
# with the offending CSV row, instead of a bare TypeError/AttributeError from getattr.
for row_idx, qr in enumerate(queryrules):
    matcher_name = qr.matcher
    if not isinstance(matcher_name, str) or not hasattr(RuleMatcher, matcher_name):
        raise ValueError(
            f'query_rule.csv row {row_idx}: invalid query_matcher {matcher_name!r}'
        )
    qr.matcher = getattr(RuleMatcher, matcher_name)

# query_rule_map.csv as plain row dicts (one dict per CSV row), not ORM objects.
qr_lookup_rows = qr_lookup.to_dict('records')

# Phenotypes and their definitions: no columns are renamed; clean_dict drops any
# CSV column that is not a table column, as well as missing values.
phenotype_records = phen.to_dict('records')
phenotypes = [Phenotype(**clean_dict(rec, Phenotype)) for rec in phenotype_records]

phenotype_def_records = phen_def.to_dict('records')
phenotype_defs = [
    PhenotypeDefinition(**clean_dict(rec, PhenotypeDefinition))
    for rec in phenotype_def_records
]

# CSV header -> Measure attribute.
measure_columns = {
    'measure_name': 'name',
    'measure_combination': 'combination',
}
measures = [
    Measure(**clean_dict(d, Measure))
    for d in meas.rename(columns=measure_columns).to_dict(orient='records')
]

for m in measures:
    # The CSV value is the RuleCombination value behind a 'rule_' prefix. Strip only
    # that leading prefix; str.replace would also delete 'rule_' found later in the string.
    m.combination = RuleCombination(m.combination.removeprefix('rule_'))
    # Flag exported as 't'/'f'; any other value is treated as False.
    m.person_ep_override = m.person_ep_override == 't'

# Relationships between measures, loaded without any column renaming.
measure_relationships = list(
    map(
        lambda rec: MeasureRelationship(**clean_dict(rec, MeasureRelationship)),
        meas_rel.to_dict('records'),
    )
)

In [None]:
# with Session() as session:
#     session.add_all(subqueries)
#     session.add_all(queryrules)
#     session.add_all(phenotypes)
#     session.add_all(phenotype_defs)
#     session.add_all(measures)
#     session.add_all(measure_relationships)
#     session.commit()

In [None]:
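# NOTE(review): this cell would not run if uncommented: neither `rows` nor
# `subquery_rule_map` is defined in the cells above. `rows` presumably means
# `qr_lookup_rows` (built from query_rule_map.csv), and `subquery_rule_map` would
# need to be imported (presumably from oa_cohorts.query) — confirm before re-enabling.
# with Session() as session:
#     session.execute(subquery_rule_map.insert(), rows)
#     session.commit()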

In [None]:
def _orm_rows(model_cls, frame):
    """Instantiate ``model_cls`` once per record of ``frame`` (mapped, non-missing fields only)."""
    return [model_cls(**clean_dict(rec, model_cls)) for rec in frame.to_dict(orient='records')]


# Association-table rows stay plain dicts (inserted via Core later); everything else
# becomes ORM instances. Statement order matches the original cell.
cohort_def_lookup = dash_cohort_map.to_dict('records')
dash_defs = _orm_rows(DashCohortDef, dash_cohort_def)
dash_cohorts = _orm_rows(DashCohort, dash_cohort)
indicators = _orm_rows(Indicator, indicator)
report_indicator_objects = report_indicators.drop_duplicates().to_dict('records')
reports = _orm_rows(Report, report)
report_cohort_maps = _orm_rows(ReportCohortMap, report_cohorts)
report_versions = _orm_rows(ReportVersion, report_version)

In [None]:
# report_status is exported as the ReportStatus value behind an 'st_' prefix. Strip
# only that leading prefix: str.replace('st_', '') would also remove later matches,
# e.g. turning '...latest_...' into '...late...'.
for r in report_versions:
    r.report_status = ReportStatus(r.report_status.removeprefix('st_'))

In [None]:
# primary_cohort arrives as a 't'/'f' string; convert it to a bool (anything but 't' is False).
for cohort_link in report_cohort_maps:
    raw_flag = cohort_link.primary_cohort
    cohort_link.primary_cohort = (raw_flag == 't')

In [None]:
# Persist the dashboard/report ORM objects in one transaction.
# sessionmaker.begin() commits on a clean exit and rolls back on error.
with Session.begin() as session:
    for batch in (dash_defs, dash_cohorts, indicators, reports, report_cohort_maps, report_versions):
        session.add_all(batch)


In [None]:
# Bulk-insert association-table rows with Core INSERTs in one transaction
# (Session.begin() commits on a clean exit and rolls back on error).
# Empty row lists are skipped: SQLAlchemy treats an empty parameter list like "no
# parameters", so the INSERT would run once with no values (a single default/NULL row).
with Session.begin() as session:
    for table, rows in (
        (dash_cohort_def_map, cohort_def_lookup),
        (report_indicator_map, report_indicator_objects),
    ):
        if rows:
            session.execute(table.insert(), rows)