diff --git a/scripts/importer/data_importer/raw_data_importer.py b/scripts/importer/data_importer/raw_data_importer.py index 54becd36b..ec754f597 100644 --- a/scripts/importer/data_importer/raw_data_importer.py +++ b/scripts/importer/data_importer/raw_data_importer.py @@ -298,29 +298,43 @@ def _parse_manta(self): # pylint: disable=too-many-locals,too-many-branches,too finished=True) self.log_insertion(counter, "breakend", start) - def _add_variants_to_db(self, batch: list, genes: list, transcripts: list, # pylint: disable=too-many-arguments - ref_genes: dict, ref_transcripts: dict): - """Add variants to db.""" + def _estimate_variant_lastid(self): # pylint: disable=no-self-use + """ + Return the id of the variant with the highest id. + + Returns 0 if table is empty. + + Returns: + int: id of the variant with highest id or 0 + + """ + try: + return (db.Variant.select(db.Variant.id) + .order_by(db.Variant.id.desc()) + .limit(1) + .get().id) + except db.Variant.DoesNotExist: + return 0 + + def _add_variants_to_db(self, batch: list, genes: list, transcripts: list, references: dict): + """ + Add variants to db. + + Args: + batch (list): variant data (dict) + genes (list): genes for the variants + transcripts(list): transcripts for the variants + references (dict): reference genes and transcripts + """ if not self.settings.beacon_only: - # estimate variant dbid start - try: - curr_id = (db.Variant.select(db.Variant.id) - .order_by(db.Variant.id.desc()) - .limit(1) - .get().id) - except db.Variant.DoesNotExist: - # assumes next id will be 1 if table is empty - curr_id = 0 + curr_id = self._estimate_variant_lastid() db.Variant.insert_many(batch).execute() # check if the variant dbid estimate is correct, otherwise must check manually if not self.settings.beacon_only: - last_id = (db.Variant.select(db.Variant.id) - .order_by(db.Variant.id.desc()) - .limit(1) - .get().id) - if last_id-curr_id == len(batch): + last_id = self._estimate_variant_lastid() + if last_id and last_id-curr_id == len(batch): indexes = list(range(curr_id+1, last_id+1)) else: logging.warning("Bad match between ids - slow check") @@ -329,8 +343,8 @@ def _add_variants_to_db(self, batch: list, genes: list, transcripts: list, # py indexes.append(db.Variant.select(db.Variant.id) .where(db.Variant.variant_id == entry['variant_id']) .get().id) - self._add_variant_genes(indexes, genes, ref_genes) - self._add_variant_transcripts(indexes, transcripts, ref_transcripts) + self._add_variant_genes(indexes, genes, references['genes']) + self._add_variant_transcripts(indexes, transcripts, references['transcripts']) def _get_genes_transcripts(self): """ @@ -368,7 +382,7 @@ def _insert_variants(self): # pylint: disable=too-many-locals,too-many-branches vep_field_names = None with db.database.atomic(): for filename in self.settings.variant_file: # pylint: disable=too-many-nested-blocks - ref_genes, ref_transcripts = self._get_genes_transcripts() + references = dict(zip(('genes', 'transcripts'), self._get_genes_transcripts())) for line in self._open(filename, binary=False): line = line.strip() @@ -480,8 +494,7 @@ def _insert_variants(self): # pylint: disable=too-many-locals,too-many-branches self._add_variants_to_db(batch, genes, transcripts, - ref_genes, - ref_transcripts) + references) genes = [] transcripts = [] batch = [] @@ -495,8 +508,7 @@ def _insert_variants(self): # pylint: disable=too-many-locals,too-many-branches self._add_variants_to_db(batch, genes, transcripts, - ref_genes, - ref_transcripts) + references) if self.settings.set_vcf_sampleset_size and samples: self.sampleset.sample_size = samples