In [None]:
import os
from pprint import pprint
from importlib import import_module

from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import joinedload, subqueryload, Load, load_only
from sqlalchemy.orm.exc import NoResultFound
from sqlalchemy.dialects import postgresql

from dataservice.extensions import db
from dataservice import create_app
from dataservice.api.investigator.models import Investigator
from dataservice.api.study.models import Study
from dataservice.api.participant.models import Participant
from dataservice.api.family.models import Family
from dataservice.api.family_relationship.models import FamilyRelationship
from dataservice.api.diagnosis.models import Diagnosis
from dataservice.api.outcome.models import Outcome
from dataservice.api.phenotype.models import Phenotype
from dataservice.api.biospecimen.models import Biospecimen
from dataservice.api.genomic_file.models import GenomicFile
from dataservice.api.workflow.models import Workflow, WorkflowGenomicFile
from dataservice.api.study_file.models import StudyFile

from dataservice.util.data_import.utils import to_camel_case
from dataservice.util.data_import.etl.defaults import DEFAULT_ENTITY_TYPES

# kf_id of the study that the queries in this notebook are filtered by
study_id = "SD_FN5YSGZE"

class BaseLoader(object):
    """
    Boots a dataservice Flask app, pushes an application context and
    creates the database tables so that models can be queried directly
    """

    def __init__(self, config_name=None):
        """
        Set up the app and database for the given configuration

        :param config_name: configuration name handed to ``create_app``;
            an empty or missing value falls back to 'testing'
        """
        if not config_name:
            config_name = 'testing'
        self.setup(config_name)
        self.entity_id_map = {}

    def setup(self, config_name):
        """
        Creates the app, pushes an app context and creates tables in database

        :param config_name: configuration name handed to ``create_app``
        """
        self.app = create_app(config_name)
        self.app.config['SQLALCHEMY_ECHO'] = True
        self.app_context = self.app.app_context()
        self.app_context.push()
        # Import the model modules before creating tables so that every
        # model is registered with db by the time create_all() runs
        self.import_models()
        db.create_all()

    def teardown(self):
        """
        Clean up: close the session and drop all tables
        """
        db.session.close()
        db.drop_all()

    def drop_all(self, study_external_id):
        """
        Delete all data related to a study

        Deletes the study with the given external id and, when the study
        references one, its investigator. Prints a message and does
        nothing if no such study exists.

        :param study_external_id: value matched against Study.external_id
        """
        from dataservice.api.study.models import Study
        from dataservice.api.investigator.models import Investigator

        try:
            study = Study.query.filter_by(external_id=study_external_id).one()
        except NoResultFound:
            print("Study {} not found. Aborting drop all for this dataset"
                  .format(study_external_id))
        else:
            # Save investigator id
            investigator_id = study.investigator_id

            # Delete study
            db.session.delete(study)

            # Delete investigator
            if investigator_id:
                investigator = Investigator.query.get(investigator_id)
                # get() returns None when the row no longer exists;
                # passing None to session.delete() would raise
                if investigator is not None:
                    db.session.delete(investigator)

            db.session.commit()

    def import_models(self, skip_entities=None):
        """
        Import the model class of every default entity type

        Nothing is written to the database; each model module is imported
        and its class looked up by name.

        :param skip_entities: entity types to leave out; None means
            nothing is skipped
        :raises ImportError: if a model module cannot be imported
        :raises AttributeError: if a module lacks the expected model class
        """
        # None instead of a shared mutable default list
        if skip_entities is None:
            skip_entities = ()

        # For each entity type
        for entity_type in DEFAULT_ENTITY_TYPES:
            # Skip some entities
            if entity_type in skip_entities:
                continue
            # Dynamically import entity model class
            model_name = to_camel_case(entity_type)
            model_module_path = 'dataservice.api.{}.models'.format(
                entity_type)
            models_module = import_module(model_module_path)
            # Looked up only to fail loudly if the class is missing
            getattr(models_module, model_name)

In [None]:
# No config name given, so BaseLoader falls back to its 'testing' config
loader = BaseLoader(config_name=None)

In [None]:
# Diagnoses with their participant eagerly loaded through an inner join,
# restricted to the participant's kf_id column.
# NOTE(review): joinedload joins against its own alias, so the Participant
# filter below may not constrain the eager-loaded join - check the printed
# SQL. An alternative tried here was:
#   .options(Load(Participant).load_only('kf_id', 'study_id'))
q = Diagnosis.query
q = q.options(
    joinedload(Diagnosis.participant, innerjoin=True).load_only('kf_id'))
q = q.filter(Participant.study_id == study_id)
print(q.statement.compile(dialect=postgresql.dialect()))

In [None]:
# Diagnoses explicitly joined to participants, filtered by study
q = Diagnosis.query.join(Participant.diagnoses)
q = q.options(Load(Participant).load_only('kf_id', 'study_id'))
q = q.filter(Participant.study_id == study_id)
print(q.statement.compile(dialect=postgresql.dialect()))

In [None]:
# q = (Diagnosis.query
#      .options(joinedload(Participant.diagnoses).load_only('kf_id')))
#      .options(Load(Participant).load_only('kf_id', 'study_id'))
#      .filter(Participant.study_id == study_id))
# print(q.statement.compile(dialect=postgresql.dialect()))

### Wrong way to load children through joins

In [None]:
# Genomic files with a chain of joined eager loads, then an explicit join
# to participants for the study filter.
# NOTE(review): SequencingExperiment, Aliquot and Sample are not among the
# imports at the top of this notebook - confirm they are defined before
# running this cell.
q = GenomicFile.query
q = q.options(
    joinedload(GenomicFile.sequencing_experiment).load_only("kf_id")
    .joinedload(SequencingExperiment.aliquot).load_only("kf_id")
    .joinedload(Aliquot.sample).load_only("kf_id"))
q = q.join(Sample.participant)
q = q.options(Load(Participant).load_only("kf_id", "study_id"))
q = q.filter(Participant.study_id == study_id)
print(q.statement.compile(dialect=postgresql.dialect()))

### Correct way to load children through joins

In [None]:
# Genomic files reached from participants purely through explicit joins.
# NOTE(review): SequencingExperiment, Aliquot and Sample are not among the
# imports at the top of this notebook - confirm they are defined before
# running this cell.
q = GenomicFile.query
q = q.join(SequencingExperiment.genomic_files)
q = q.join(Aliquot.sequencing_experiments)
q = q.join(Sample.aliquots)
q = q.join(Participant.samples)
q = q.filter(Participant.study_id == study_id)
print(q.statement.compile(dialect=postgresql.dialect()))

In [None]:
# Participants with each child collection eagerly loaded, keeping only
# the kf_id column of every child
q = Participant.query.options(
    joinedload(Participant.diagnoses).load_only('kf_id'),
    joinedload(Participant.samples).load_only('kf_id'),
    joinedload(Participant.phenotypes).load_only('kf_id'),
    joinedload(Participant.outcomes).load_only('kf_id'),
)
print(q.statement.compile(dialect=postgresql.dialect()))

In [None]:
# Families that have participants in the study, distinct on and ordered
# by the family kf_id
q = Family.query.join(Family.participants)
q = q.options(Load(Participant).load_only('kf_id', 'study_id'))
q = q.filter(Participant.study_id == study_id)
q = q.distinct(Family.kf_id)
q = q.order_by(Family.kf_id)

In [None]:
# Show the SQL for the current query in the PostgreSQL dialect, then run it
compiled = q.statement.compile(dialect=postgresql.dialect())
print(compiled)
results = q.all()

In [None]:
# Family relationships joined to their participant, filtered by study
q = FamilyRelationship.query.join(FamilyRelationship.participant)
q = q.options(Load(Participant).load_only('kf_id', 'study_id'))
q = q.filter(Participant.study_id == study_id)
print(q.statement.compile(dialect=postgresql.dialect()))
results = q.all()

In [None]:
# Study files that belong directly to the study
q = StudyFile.query.filter_by(study_id=study_id)
print(q.statement.compile(dialect=postgresql.dialect()))
results = q.all()

In [None]:
# Investigators joined to their studies, keeping only the study whose
# kf_id matches
q = Investigator.query.join(Investigator.studies)
q = q.options(Load(Study).load_only('kf_id'))
q = q.filter(Study.kf_id == study_id)
print(q.statement.compile(dialect=postgresql.dialect()))
q.count()

In [None]:
# Sequencing experiments reached from the study's participants through
# biospecimens and genomic files.
# SequencingExperiment is not imported at the top of the notebook, so it is
# imported here; the path follows the dataservice.api.<entity>.models
# layout used by the other model imports.
from dataservice.api.sequencing_experiment.models import SequencingExperiment

q = (SequencingExperiment.query
     .join(GenomicFile.sequencing_experiment)
     .join(Biospecimen.genomic_files)
     .join(Participant.biospecimens)
     .filter(Participant.study_id == study_id)
     )