In [None]:
import sqlalchemy as sa
import pandas as pd

from orm_loader.helpers import configure_logging, bootstrap, explain_sqlite_fk_error, bulk_load_context, configure_logging
from sqlalchemy.orm import sessionmaker
from sqlalchemy.exc import IntegrityError

from omop_alchemy import get_engine_name, load_environment, TEST_PATH, ROOT_PATH

from omop_alchemy.cdm.model.vocabulary import (
    Domain,
    Vocabulary,
    Concept_Class,
    Relationship,
    Concept,
    Concept_Ancestor,
    Concept_Relationship,
    Concept_Synonym,
    Concept_Synonym,
)

# Core vocabulary tables. These have mutual FK constraints, so they are loaded
# together inside a single bulk-load context.
ATHENA_INITIAL_LOAD = [
    Domain,
    Vocabulary,
    Concept_Class,
    Relationship,
    Concept,
]

# Relationship tables that are loaded once the core tables exist.
# Concept_Synonym is intentionally not listed here; it is loaded on its own
# with database-inclusive de-duplication.
ATHENA_SUBSEQUENT_LOAD = [
    Concept_Ancestor,
    Concept_Relationship,
]

from random import randint, choice
import numpy as np

from sqlalchemy.orm import Session
from omop_alchemy.cdm.model.health_system import Location, Care_Site, Provider, Visit_Detail, Visit_Occurrence
from omop_alchemy.cdm.model.clinical import Person, Condition_Occurrence, Procedure_Occurrence, Death, Specimen, Drug_Exposure, Measurement, Observation
from omop_alchemy.cdm.model.structural import Episode, Episode_Event
from omop_alchemy.cdm.model.derived import Observation_Period
from datetime import date, timedelta

# Logging first, so the environment/engine setup messages are emitted.
configure_logging()
load_environment()

# Connection string presumably resolved from the environment loaded above.
engine_string = get_engine_name()
engine = sa.create_engine(
    engine_string,
    echo=False,
    future=True,
)
# NOTE(review): create=True presumably creates missing tables - confirm in orm_loader.
bootstrap(engine, create=True)

2026-01-12 09:22:52,338 | INFO     | sql_loader.omop_alchemy.config | Environment variables loaded from .env file
2026-01-12 09:22:52,339 | INFO     | sql_loader.omop_alchemy.config | Database engine configured


In [2]:
import os
from pathlib import Path

In [3]:
# NOTE: this rebinds `Session` (imported from sqlalchemy.orm in the first cell);
# from here on it is the session factory bound to `engine`.
Session = sessionmaker(future=True, bind=engine)
session = Session()  # long-lived session used by the vocabulary-loading cells

In [None]:
# Where the Athena vocabulary CSVs are read from:
#   False -> minimal test fixture set (fast, for rapid iteration)
#   True  -> full Athena export located at the SOURCE_PATH environment variable
USE_FULL_ATHENA_SOURCE = False

if USE_FULL_ATHENA_SOURCE:
    base_path = Path(os.environ["SOURCE_PATH"])
else:
    base_path = TEST_PATH / "fixtures" / "athena_source"

In [5]:
# Core vocabulary tables have mutual FK constraints, so they are loaded inside one
# bulk-load context (trusted sources only) and committed together at the end.
with bulk_load_context(session):
    for vocab_model in ATHENA_INITIAL_LOAD:
        csv_file = base_path / f"{vocab_model.__tablename__.upper()}.csv"
        _ = vocab_model.load_csv(session, csv_file, dedupe=True)
    session.commit()

Found 1 rows with unexpected nulls in vocabulary.vocabulary_id
Found 34 rows with unexpected nulls in concept.concept_name
Found 1 rows with unexpected nulls in concept.vocabulary_id
Found 1 rows with unexpected nulls in concept.concept_code


In [None]:
# No mutual FK dependency among these tables, but the bulk-load context is still used
# for speed. Each table is loaded in 5000-row chunks, committing after every chunk
# and once more after the table.
with bulk_load_context(session):
    for rel_model in ATHENA_SUBSEQUENT_LOAD:
        csv_file = base_path / f"{rel_model.__tablename__.upper()}.csv"
        _ = rel_model.load_csv(
            session,
            csv_file,
            dedupe=True,
            chunksize=5000,
            commit_on_chunk=True,
        )
        session.commit()

In [5]:
# The concept_synonym file is very duplicate-prone, so this table alone is also
# de-duplicated against rows already in the database (dedupe_incl_db=True).
# That DB-inclusive check is too slow to use for the largest tables on a fresh load.
with bulk_load_context(session):
    synonym_csv = base_path / f"{Concept_Synonym.__tablename__.upper()}.csv"
    _ = Concept_Synonym.load_csv(
        session,
        synonym_csv,
        dedupe=True,
        chunksize=5000,
        commit_on_chunk=True,
        dedupe_incl_db=True,
    )
    session.commit()

Dropping 5000 rows from concept_synonym that already exist in the database
Dropping 1 rows from concept_synonym that already exist in the database
Dropping 20 rows from concept_synonym that already exist in the database
Dropping 1 rows from concept_synonym that already exist in the database
Dropping 17 rows from concept_synonym that already exist in the database
Dropping 7 rows from concept_synonym that already exist in the database
Dropping 5 rows from concept_synonym that already exist in the database
Dropping 1 rows from concept_synonym that already exist in the database
Found 1 rows with unexpected nulls in concept_synonym.concept_synonym_name
Dropping 20 rows from concept_synonym that already exist in the database
Dropping 2 rows from concept_synonym that already exist in the database
Dropping 1 rows from concept_synonym that already exist in the database
Dropping 18 rows from concept_synonym that already exist in the database
Dropping 2 rows from concept_synonym that already exis

In [None]:
# Concepts needed to synthesise reference/person/visit data, plus ICD-O-3
# condition concepts that serve as candidate cancer diagnoses.
reference_domain_filter = Concept.domain_id.in_(
    ['Gender', 'Ethnicity', 'Race', 'Visit', 'Location', 'Provider', 'Type Concept']
)
icdo3_condition_filter = sa.and_(
    Concept.domain_id == 'Condition',
    Concept.vocabulary_id == 'ICDO3',
)
concept_by_domain = pd.DataFrame(
    session.query(*Concept.__table__.columns)
    .filter(sa.or_(reference_domain_filter, icdo3_condition_filter))
)

In [8]:
def _concept_ids(mask):
    """Return the concept_id values of the concept_by_domain rows selected by a boolean mask."""
    return list(concept_by_domain.loc[mask, 'concept_id'])

avail_gender = _concept_ids(concept_by_domain.domain_id == 'Gender')
avail_ethnicity = _concept_ids(concept_by_domain.domain_id == 'Ethnicity')
avail_race = _concept_ids(concept_by_domain.domain_id == 'Race')
avail_place_of_service = _concept_ids(concept_by_domain.domain_id == 'Visit')
# NOTE(review): unlike its siblings this filters on concept_class_id, not domain_id,
# and only rows whose domain_id is in the query's list are available - confirm that
# country concepts are actually returned (choice() on an empty list raises later).
avail_country = _concept_ids(concept_by_domain.concept_class_id == 'Location')
avail_provider = _concept_ids(concept_by_domain.domain_id == 'Provider')
avail_types = _concept_ids(concept_by_domain.domain_id == 'Type Concept')

In [9]:
# Candidate cancer diagnoses: ICD-O-3 condition concepts whose code contains '/3'
# (presumably the ICD-O-3 "malignant" behaviour code - confirm).
# regex=False: '/3' is matched literally. na=False: concepts with a null
# concept_code are explicitly non-matches rather than NaN entries in the mask.
is_icdo3_condition = (
    (concept_by_domain.domain_id == 'Condition')
    & (concept_by_domain.vocabulary_id == 'ICDO3')
)
is_malignant_code = concept_by_domain.concept_code.str.contains('/3', regex=False, na=False)
cancers = list(concept_by_domain.loc[is_icdo3_condition & is_malignant_code, 'concept_id'])

In [10]:
def _concept_query():
    """Base query selecting every column of the concept table via the notebook session."""
    return session.query(*Concept.__table__.columns)

# Direct children (max_levels_of_separation == 1) of concept 734320.
# NOTE(review): 734320 is presumably the root of the cancer-staging hierarchy - confirm.
staging_parents = pd.DataFrame(
    _concept_query()
    .join(Concept_Ancestor, Concept.concept_id == Concept_Ancestor.descendant_concept_id)
    .filter(Concept_Ancestor.ancestor_concept_id == 734320)
    .filter(Concept_Ancestor.max_levels_of_separation == 1)
)

staging_sets = {}
for axis in ['T', 'N', 'M', 'Stage']:
    # Heuristic: a parent belongs to an axis if its name contains the axis label
    # (case-sensitive substring match).
    axis_parent_ids = list(
        staging_parents.loc[staging_parents.concept_name.str.contains(axis), 'concept_id']
    )
    staging_sets[axis] = pd.DataFrame(
        _concept_query()
        .join(Concept_Ancestor, Concept.concept_id == Concept_Ancestor.descendant_concept_id)
        .filter(Concept_Ancestor.ancestor_concept_id.in_(axis_parent_ids))
        # Keep codes containing '8th' and drop those containing 'yp'
        # (presumably 8th-edition, non-post-therapy staging - confirm).
        .filter(Concept.concept_code.ilike('%8th%'))
        .filter(~Concept.concept_code.ilike('%yp%'))
    )

In [11]:
# Optional sanity check: confirm that the name-matching heuristic used to identify staging axes groups the codes as expected.
# staging_sets['Stage'].concept_code.map(lambda x: x.split('-')[-1]).value_counts()

In [12]:
# Naive, brute-force generators for very basic synthetic test data - good enough for now; richer content is planned.

def populate_reference_data(session, n_locations=10, n_care_sites=30, n_providers=50):
    """Create and commit random locations, care sites and providers.

    Concept ids are sampled from the module-level ``avail_country``,
    ``avail_place_of_service``, ``avail_provider`` and ``avail_gender`` lists;
    primary keys come from each model's ``allocator``.

    Args:
        session: SQLAlchemy session used to add and commit the new rows.
        n_locations: number of Location rows to create (default 10).
        n_care_sites: number of Care_Site rows, each placed at a random newly
            created location (default 30).
        n_providers: number of Provider rows, each attached to a random newly
            created care site (default 50).

    Returns:
        Tuple ``(locations, care_sites, providers)`` of the newly created ORM objects.
    """
    loc_ids = Location.allocator(session)
    cs_ids = Care_Site.allocator(session)
    pro_ids = Provider.allocator(session)

    locations = [
        Location(
            location_id=loc_ids.next(),
            country_concept_id=choice(avail_country),
            city=f'City {idx}',
        )
        for idx in range(n_locations)
    ]
    # Ids are assigned explicitly, so children can reference parents before any flush.
    care_sites = [
        Care_Site(
            care_site_id=cs_ids.next(),
            care_site_name=f'Care Site {idx}',
            location_id=choice(locations).location_id,
            place_of_service_concept_id=choice(avail_place_of_service),
        )
        for idx in range(n_care_sites)
    ]
    providers = [
        Provider(
            provider_id=pro_ids.next(),
            specialty_concept_id=choice(avail_provider),
            gender_concept_id=choice(avail_gender),
            care_site_id=choice(care_sites).care_site_id,
        )
        for _ in range(n_providers)
    ]

    session.add_all(locations)
    session.add_all(care_sites)
    session.add_all(providers)
    session.commit()

    return locations, care_sites, providers

def populate_people_and_visits(session, care_sites, n_people=1000, max_visits_per_person=3):
    """Create and commit random people, each with 1..max_visits_per_person single-day visits.

    Demographic concept ids are sampled from the module-level ``avail_gender``,
    ``avail_race``, ``avail_ethnicity`` and ``avail_place_of_service`` lists.
    All of a person's visits happen at one randomly chosen care site, on dates
    within the year starting 2020-01-01.

    Args:
        session: SQLAlchemy session used to add and commit the new rows.
        care_sites: Care_Site objects to attach visits to (``care_site_id`` is read).
        n_people: number of Person rows to create (default 1000).
        max_visits_per_person: upper bound (inclusive) on visits per person (default 3).

    Returns:
        Tuple ``(people, visits)`` of the newly created ORM objects.
    """
    person_ids = Person.allocator(session)
    visit_ids = Visit_Occurrence.allocator(session)

    visit_window_start = date(2020, 1, 1)
    # Births are capped at the year before the visit window so that no visit can
    # precede the person's birth (previously births up to 2020 were possible).
    latest_birth_year = visit_window_start.year - 1

    people = [
        Person(
            person_id=person_ids.next(),
            year_of_birth=randint(1950, latest_birth_year),
            month_of_birth=randint(1, 12),
            gender_concept_id=choice(avail_gender),
            race_concept_id=choice(avail_race),
            ethnicity_concept_id=choice(avail_ethnicity),
        )
        for _ in range(n_people)
    ]

    visits = []
    for person in people:
        care_site = choice(care_sites)
        for _ in range(randint(1, max_visits_per_person)):
            visit_date = visit_window_start + timedelta(days=randint(0, 365))
            visits.append(
                Visit_Occurrence(
                    visit_occurrence_id=visit_ids.next(),
                    person_id=person.person_id,
                    care_site_id=care_site.care_site_id,
                    visit_concept_id=choice(avail_place_of_service),
                    visit_start_date=visit_date,
                    visit_end_date=visit_date,  # single-day visits
                )
            )

    session.add_all(people)
    session.add_all(visits)
    session.commit()
    return people, visits

def populate_observation_periods(session, death_probability=0.05):
    """Create one observation period (and, randomly, a death) per person lacking a period.

    For every person with visits but no observation period, the period runs from
    their earliest visit start to their latest visit end. If the person already
    has a Death row, the period ends on that date and no second Death is created.
    Otherwise, with probability ``death_probability``, a Death is added 1-365 days
    after the last visit and the period is extended to it.

    Args:
        session: SQLAlchemy session used to query, add and commit rows.
        death_probability: chance that a person without a recorded death is
            marked deceased (default 0.05).

    Returns:
        List of the newly created Observation_Period objects.
    """
    op_ids = Observation_Period.allocator(session)
    rows = (
        session.query(
            Visit_Occurrence.person_id,
            sa.func.min(Visit_Occurrence.visit_start_date).label("start"),
            sa.func.max(Visit_Occurrence.visit_end_date).label("end"),
            # Aggregated so every selected column is either grouped or aggregated;
            # strict backends (e.g. PostgreSQL) reject bare ungrouped columns.
            sa.func.max(Death.death_date).label("death_date"),
        )
        .join(Death, Death.person_id == Visit_Occurrence.person_id, isouter=True)
        .join(Observation_Period, Observation_Period.person_id == Visit_Occurrence.person_id, isouter=True)
        # Anti-join: only people who have no observation period yet.
        .filter(Observation_Period.observation_period_id.is_(None))
        .group_by(Visit_Occurrence.person_id)
        .all()
    )

    deaths = []
    obs = []
    for r in rows:
        if r.death_date is not None:
            # Already deceased: close the period at the recorded death instead of
            # adding a duplicate Death row for this person.
            obs_end = r.death_date
        elif np.random.choice([True, False], p=[death_probability, 1 - death_probability]):
            death_date = r.end + timedelta(days=randint(1, 365))
            deaths.append(
                Death(
                    person_id=r.person_id,
                    death_date=death_date,
                    death_type_concept_id=choice(avail_types),
                )
            )
            obs_end = death_date
        else:
            obs_end = r.end
        obs.append(
            Observation_Period(
                observation_period_id=op_ids.next(),
                person_id=r.person_id,
                observation_period_start_date=r.start,
                observation_period_end_date=obs_end,
                period_type_concept_id=choice(avail_types),
            )
        )

    session.add_all(deaths)
    session.add_all(obs)
    session.commit()
    return obs

def populate_conditions_and_modifiers(session):
    """Give each person without a condition one random cancer diagnosis plus T/N/M staging.

    For every person (found via their observation period) who has no
    Condition_Occurrence yet, this creates:
      * a Condition_Occurrence drawn from ``cancers``, starting on the
        observation period start date;
      * an Episode spanning the condition start to the death date, or to the
        observation period end if there is no death;
      * one Measurement per T/N/M axis (sampled from ``staging_sets``), linked to
        the condition via measurement_event_id;
      * Episode_Event rows linking the episode to each measurement and to the condition.
    All new rows are committed at the end.

    Args:
        session: SQLAlchemy session used to query, add and commit rows.
    """
    cond_ids = Condition_Occurrence.allocator(session)
    meas_ids = Measurement.allocator(session)
    ep_ids = Episode.allocator(session)

    # Candidate staging concepts per axis are loop-invariant, so they are built once.
    # Overall stage is not sampled; it should be derived from T/N/M.
    t_options = list(staging_sets['T'].concept_id)
    n_options = list(staging_sets['N'].concept_id)
    m_options = list(staging_sets['M'].concept_id)

    rows = (
        session.query(Observation_Period, Death, Condition_Occurrence)
        .join(Death, Observation_Period.person_id == Death.person_id, isouter=True)
        .join(Condition_Occurrence, Observation_Period.person_id == Condition_Occurrence.person_id, isouter=True)
        .all()
    )

    conditions = []
    measurements = []
    episodes = []
    episode_events = []
    handled_people = set()
    for obs, death, existing_condition in rows:
        # Skip people who already have a condition, and people already given one in
        # this run (someone with several observation periods appears on several rows).
        if existing_condition is not None or obs.person_id in handled_people:
            continue
        handled_people.add(obs.person_id)

        t = choice(t_options)
        n = choice(n_options)
        m = choice(m_options)
        condition = Condition_Occurrence(
            condition_occurrence_id=cond_ids.next(),
            condition_concept_id=choice(cancers),
            condition_start_date=obs.observation_period_start_date,
            condition_type_concept_id=choice(avail_types),
            person_id=obs.person_id,
            condition_status_concept_id=32902,  # NOTE(review): presumably 'Primary diagnosis' - confirm
        )
        conditions.append(condition)

        episode = Episode(
            episode_id=ep_ids.next(),
            person_id=obs.person_id,
            episode_concept_id=32533,  # NOTE(review): originally labelled 'Episode of care' - confirm concept name
            episode_object_concept_id=condition.condition_concept_id,
            episode_start_date=condition.condition_start_date,
            episode_end_date=(
                death.death_date if death is not None else obs.observation_period_end_date
            ),
            episode_type_concept_id=choice(avail_types),  # EHR / registry / derived
        )
        episodes.append(episode)

        for stage in (t, n, m):
            measurement = Measurement(
                person_id=obs.person_id,
                measurement_id=meas_ids.next(),
                measurement_concept_id=stage,
                measurement_event_id=condition.condition_occurrence_id,
                meas_event_field_concept_id=1147127,  # condition_occurrence.condition_occurrence_id
                measurement_date=condition.condition_start_date,
                measurement_type_concept_id=choice(avail_types),
                value_as_number=1,
            )
            measurements.append(measurement)
            episode_events.append(
                Episode_Event(
                    episode_id=episode.episode_id,
                    event_id=measurement.measurement_id,
                    episode_event_field_concept_id=1147138,  # measurement.measurement_id
                )
            )
        episode_events.append(
            Episode_Event(
                episode_id=episode.episode_id,
                event_id=condition.condition_occurrence_id,
                episode_event_field_concept_id=1147127,  # condition_occurrence.condition_occurrence_id
            )
        )

    session.add_all(conditions)
    session.add_all(measurements)
    session.add_all(episodes)
    session.add_all(episode_events)
    session.commit()

In [13]:
# Reference data first: locations, care sites, providers.
with Session() as sess:
    populate_reference_data(session=sess)
    sess.commit()
    # Every care site now in the database; the next cell draws visit sites from these.
    care_sites = sess.query(Care_Site).all()

In [14]:
# People and visits, then observation periods (and deaths) derived from those visits.
with Session() as sess:
    populate_people_and_visits(session=sess, care_sites=care_sites)
    populate_observation_periods(session=sess)

In [15]:
# Oncology layer: cancer diagnoses, T/N/M measurements and episodes.
with Session() as sess:
    populate_conditions_and_modifiers(session=sess)