Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cleanup in insert_variants #593

Merged
merged 4 commits into from
Aug 20, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion .pylintrc
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,6 @@ disable=too-few-public-methods,
import-error,
missing-docstring,
abstract-method,
bad-continuation,
invalid-name,
too-few-public-methods,
keyword-arg-before-vararg,
Expand Down
179 changes: 101 additions & 78 deletions scripts/importer/data_importer/raw_data_importer.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,74 @@ def _parse_manta(self): # pylint: disable=too-many-locals,too-many-branches,too
finished=True)
self.log_insertion(counter, "breakend", start)

def _estimate_variant_lastid(self):  # pylint: disable=no-self-use
    """
    Look up the highest variant id currently stored in the database.

    The variants are ordered by descending id and only the first row is
    fetched, so a single row is read regardless of table size.

    Returns:
        int: the largest ``db.Variant.id``, or 0 when the query finds no row

    """
    newest_first = db.Variant.select(db.Variant.id).order_by(db.Variant.id.desc())
    try:
        newest = newest_first.limit(1).get()
    except db.Variant.DoesNotExist:
        # no variant rows at all
        return 0
    return newest.id

def _add_variants_to_db(self, batch: list, genes: list, transcripts: list, references: dict):
    """
    Insert a batch of variants and link them to their genes and transcripts.

    In beacon-only mode the variants are inserted and nothing else is done.
    Otherwise the database ids of the new rows are needed for the
    variant-gene and variant-transcript link tables. They are first guessed
    from the highest variant id before and after the insert; if the
    difference does not equal the batch size, each variant is looked up
    individually instead.

    Args:
        batch (list): variant data (dict), each with at least 'variant_id'
        genes (list): genes for the variants, in the same order as batch
        transcripts (list): transcripts for the variants, in the same order as batch
        references (dict): reference data with the keys 'genes' and 'transcripts'

    """
    if not batch:
        # nothing to insert or link
        return

    if self.settings.beacon_only:
        db.Variant.insert_many(batch).execute()
        return

    curr_id = self._estimate_variant_lastid()
    db.Variant.insert_many(batch).execute()
    last_id = self._estimate_variant_lastid()

    # check if the variant dbid estimate is correct, otherwise must check manually
    if last_id and last_id - curr_id == len(batch):
        indexes = list(range(curr_id + 1, last_id + 1))
    else:
        logging.warning("Bad match between ids - slow check")
        # The same variant_id may exist in other dataset versions, so the
        # lookup is restricted to the version being imported; otherwise the
        # links could be attached to a variant of another dataset version.
        indexes = [db.Variant.select(db.Variant.id)
                   .where(db.Variant.variant_id == entry['variant_id'],
                          db.Variant.dataset_version == self.dataset_version)
                   .get().id
                   for entry in batch]

    self._add_variant_genes(indexes, genes, references['genes'])
    self._add_variant_transcripts(indexes, transcripts, references['transcripts'])

def _get_genes_transcripts(self):
    """
    Build lookup tables from gene and transcript identifiers to database ids.

    Only genes belonging to the reference set of the current dataset version
    are included; transcripts are selected through a join on their gene with
    the same reference set condition.

    Returns:
        tuple: (genes, transcripts), where genes maps ``gene_id`` to dbid and
            transcripts maps ``transcript_id`` to dbid

    """
    reference_set = self.dataset_version.reference_set

    gene_rows = (db.Gene
                 .select(db.Gene.id, db.Gene.gene_id)
                 .where(db.Gene.reference_set == reference_set))
    genes = {}
    for row in gene_rows:
        genes[row.gene_id] = row.id

    transcript_rows = (db.Transcript
                       .select(db.Transcript.id, db.Transcript.transcript_id)
                       .join(db.Gene)
                       .where(db.Gene.reference_set == reference_set))
    transcripts = {}
    for row in transcript_rows:
        transcripts[row.transcript_id] = row.id

    return genes, transcripts

def _insert_variants(self): # pylint: disable=too-many-locals,too-many-branches,too-many-statements
"""Import variants from a VCF file."""
logging.info(f"Inserting variants{' (dry run)' if self.settings.dry_run else ''}")
Expand All @@ -314,18 +382,7 @@ def _insert_variants(self): # pylint: disable=too-many-locals,too-many-branches
vep_field_names = None
with db.database.atomic():
for filename in self.settings.variant_file: # pylint: disable=too-many-nested-blocks
# Get reference set for the variant
ref_set = self.dataset_version.reference_set

# Get all genes and transcripts for foreign keys
ref_genes = {gene.gene_id: gene.id
for gene in (db.Gene.select(db.Gene.id, db.Gene.gene_id)
.where(db.Gene.reference_set == ref_set))}
ref_transcripts = {tran.transcript_id: tran.id
for tran in (db.Transcript.select(db.Transcript.id,
db.Transcript.transcript_id)
.join(db.Gene)
.where(db.Gene.reference_set == ref_set))}
references = dict(zip(('genes', 'transcripts'), self._get_genes_transcripts()))
for line in self._open(filename, binary=False):
line = line.strip()

Expand All @@ -343,10 +400,8 @@ def _insert_variants(self): # pylint: disable=too-many-locals,too-many-branches
"Make sure VCF header is present.")
sys.exit(1)

base = {}
base = {'dataset_version': self.dataset_version}
for i, item in enumerate(line.strip().split("\t")):
if i == 0:
base['dataset_version'] = self.dataset_version
if i < 7:
base[header[i][0]] = header[i][1](item)
elif i == 7 or not self.settings.beacon_only:
Expand Down Expand Up @@ -375,6 +430,7 @@ def _insert_variants(self): # pylint: disable=too-many-locals,too-many-branches
except KeyError:
hom_counts = None # null is better than 0, as 0 has a meaning
except ValueError:
# multiple variants on same row
hom_counts = [int(count) for count in info['AC_Hom'].split(',')]

fmt_alleles = [f'{base["chrom"]}-{base["pos"]}-{base["ref"]}-{x}'
Expand Down Expand Up @@ -435,37 +491,10 @@ def _insert_variants(self): # pylint: disable=too-many-locals,too-many-branches

if len(batch) >= self.settings.batch_size:
if not self.settings.dry_run:
if not self.settings.beacon_only:
try:
curr_id = (db.Variant.select(db.Variant.id)
.order_by(db.Variant.id.desc())
.limit(1)
.get().id)
except db.Variant.DoesNotExist:
# assumes next id will be 1 if table is empty
curr_id = 0

db.Variant.insert_many(batch).execute()

if not self.settings.beacon_only:
last_id = (db.Variant.select(db.Variant.id)
.order_by(db.Variant.id.desc())
.limit(1)
.get().id)
if last_id-curr_id == len(batch):
indexes = list(range(curr_id+1, last_id+1))
else:
indexes = []
for entry in batch:
indexes.append(db.Variant.select(db.Variant.id)
.where(db.Variant.variant_id == \
entry['variant_id'])
.get().id)
self.add_variant_genes(indexes, genes, ref_genes)
self.add_variant_transcripts(indexes,
transcripts,
ref_transcripts)

self._add_variants_to_db(batch,
genes,
transcripts,
references)
genes = []
transcripts = []
batch = []
Expand All @@ -476,33 +505,10 @@ def _insert_variants(self): # pylint: disable=too-many-locals,too-many-branches
last_progress)

if batch and not self.settings.dry_run:
if not self.settings.beacon_only:
try:
curr_id = (db.Variant.select(db.Variant.id)
.order_by(db.Variant.id.desc())
.limit(1)
.get().id)
except db.Variant.DoesNotExist:
# assumes next id will be 1 if table is empty
curr_id = 0

db.Variant.insert_many(batch).execute()

if not self.settings.beacon_only:
last_id = (db.Variant.select(db.Variant.id)
.order_by(db.Variant.id.desc())
.limit(1)
.get().id)
if last_id-curr_id == len(batch):
indexes = list(range(curr_id+1, last_id+1))
else:
indexes = []
for entry in batch:
indexes.append(db.Variant.select(db.Variant.id)
.where(db.Variant.variant_id == entry['variant_id'])
.get().id)
self.add_variant_genes(indexes, genes, ref_genes)
self.add_variant_transcripts(indexes, transcripts, ref_transcripts)
self._add_variants_to_db(batch,
genes,
transcripts,
references)

if self.settings.set_vcf_sampleset_size and samples:
self.sampleset.sample_size = samples
Expand Down Expand Up @@ -584,8 +590,17 @@ def start_import(self):
if not self.settings.beacon_only and self.settings.coverage_file:
self._insert_coverage()

def add_variant_genes(self, variant_indexes: list, genes_to_add: list, ref_genes: dict):
"""Add genes associated with the provided variants."""
def _add_variant_genes(self, variant_indexes: list,
genes_to_add: list,
ref_genes: dict):
"""
Add genes associated with the provided variants.

Args:
variant_indexes (list): dbids of the variants
genes_to_add (list): the genes for each variant (str)
ref_genes (dict): genename: dbid
"""
batch = []
for i in range(len(variant_indexes)):
connected_genes = [{'variant': variant_indexes[i], 'gene': ref_genes[gene]}
Expand All @@ -595,9 +610,17 @@ def add_variant_genes(self, variant_indexes: list, genes_to_add: list, ref_genes
if not self.settings.dry_run:
db.VariantGenes.insert_many(batch).execute()

def add_variant_transcripts(self, variant_indexes: list,
transcripts_to_add: list, ref_transcripts: dict):
"""Add genes associated with the provided variants."""
def _add_variant_transcripts(self, variant_indexes: list,
transcripts_to_add: list,
ref_transcripts: dict):
"""
Add transcripts associated with the provided variants.

Args:
variant_indexes (list): dbids of the variants
transcripts_to_add (list): the transcripts for each variant (str)
ref_transcripts (dict): genename: dbid
"""
batch = []
for i in range(len(variant_indexes)):
connected_transcripts = [{'variant': variant_indexes[i],
Expand Down