Skip to content

Commit

Permalink
Refactoring of _add_variants_to_db.
Browse files Browse the repository at this point in the history
  • Loading branch information
talavis committed Aug 19, 2019
1 parent 17bd3c6 commit 0d25699
Showing 1 changed file with 36 additions and 24 deletions.
60 changes: 36 additions & 24 deletions scripts/importer/data_importer/raw_data_importer.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,29 +298,43 @@ def _parse_manta(self): # pylint: disable=too-many-locals,too-many-branches,too
finished=True)
self.log_insertion(counter, "breakend", start)

def _add_variants_to_db(self, batch: list, genes: list, transcripts: list, # pylint: disable=too-many-arguments
ref_genes: dict, ref_transcripts: dict):
"""Add variants to db."""
def _estimate_variant_lastid(self): # pylint: disable=no-self-use
"""
Return the id of the variant with the highest id.
Returns 0 if table is empty.
Returns:
int: id of the variant with highest id or 0
"""
try:
return (db.Variant.select(db.Variant.id)
.order_by(db.Variant.id.desc())
.limit(1)
.get().id)
except db.Variant.DoesNotExist:
return 0

def _add_variants_to_db(self, batch: list, genes: list, transcripts: list, references: dict):
"""
Add variants to db.
Args:
batch (list): variant data (dict)
genes (list): genes for the variants
transcripts(list): transcripts for the variants
references (dict): reference genes and transcripts
"""
if not self.settings.beacon_only:
# estimate variant dbid start
try:
curr_id = (db.Variant.select(db.Variant.id)
.order_by(db.Variant.id.desc())
.limit(1)
.get().id)
except db.Variant.DoesNotExist:
# assumes next id will be 1 if table is empty
curr_id = 0
curr_id = self._estimate_variant_lastid()

db.Variant.insert_many(batch).execute()

# check if the variant dbid estimate is correct, otherwise must check manually
if not self.settings.beacon_only:
last_id = (db.Variant.select(db.Variant.id)
.order_by(db.Variant.id.desc())
.limit(1)
.get().id)
if last_id-curr_id == len(batch):
last_id = self._estimate_variant_lastid()
if last_id and last_id-curr_id == len(batch):
indexes = list(range(curr_id+1, last_id+1))
else:
logging.warning("Bad match between ids - slow check")
Expand All @@ -329,8 +343,8 @@ def _add_variants_to_db(self, batch: list, genes: list, transcripts: list, # py
indexes.append(db.Variant.select(db.Variant.id)
.where(db.Variant.variant_id == entry['variant_id'])
.get().id)
self._add_variant_genes(indexes, genes, ref_genes)
self._add_variant_transcripts(indexes, transcripts, ref_transcripts)
self._add_variant_genes(indexes, genes, references['genes'])
self._add_variant_transcripts(indexes, transcripts, references['transcripts'])

def _get_genes_transcripts(self):
"""
Expand Down Expand Up @@ -368,7 +382,7 @@ def _insert_variants(self): # pylint: disable=too-many-locals,too-many-branches
vep_field_names = None
with db.database.atomic():
for filename in self.settings.variant_file: # pylint: disable=too-many-nested-blocks
ref_genes, ref_transcripts = self._get_genes_transcripts()
references = dict(zip(('genes', 'transcripts'), self._get_genes_transcripts()))
for line in self._open(filename, binary=False):
line = line.strip()

Expand Down Expand Up @@ -480,8 +494,7 @@ def _insert_variants(self): # pylint: disable=too-many-locals,too-many-branches
self._add_variants_to_db(batch,
genes,
transcripts,
ref_genes,
ref_transcripts)
references)
genes = []
transcripts = []
batch = []
Expand All @@ -495,8 +508,7 @@ def _insert_variants(self): # pylint: disable=too-many-locals,too-many-branches
self._add_variants_to_db(batch,
genes,
transcripts,
ref_genes,
ref_transcripts)
references)

if self.settings.set_vcf_sampleset_size and samples:
self.sampleset.sample_size = samples
Expand Down

0 comments on commit 0d25699

Please sign in to comment.