Skip to content

Commit

Permalink
Adds some tests for mongo adapter
Browse files Browse the repository at this point in the history
  • Loading branch information
Måns Magnusson committed Oct 24, 2018
1 parent 6268cb3 commit 0ad8d61
Show file tree
Hide file tree
Showing 6 changed files with 157 additions and 62 deletions.
2 changes: 1 addition & 1 deletion chanjo/store/mongo/calculate.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,6 @@ def mean(self, sample_ids=None, gene_ids=None):

def gene_metrics(self, *genes):
"""Calculate gene statistics."""
query = self.mean(gene_ids=genes))
query = self.mean(gene_ids=genes)

return query
157 changes: 97 additions & 60 deletions chanjo/store/mongo/mongo_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@
import logging
import os

from datetime import datetime

from .calculate import CalculateMixin
from chanjo.store.models import (Transcript, TranscriptStat, Sample, Exon)
from sqlalchemy.exc import IntegrityError
from pymongo.errors import BulkWriteError
from pymongo.errors import (BulkWriteError, DuplicateKeyError)

from mongo_adapter import (MongoAdapter, get_client)

Expand Down Expand Up @@ -127,6 +129,9 @@ def setup(self, db_name='chanjo'):

if self.db is None:
self.db = self.client[db_name]

print(self.db)
print(type(self.db))
self.db_name = db_name
self.session = Session(self.db)

Expand Down Expand Up @@ -160,9 +165,9 @@ def dialect(self):
NOT APPLICABLE FOR MONGODB
Returns:
None
dialect(str): 'mongodb'
"""
return None
return 'mongodb'

def set_up(self):
"""Initialize a new database with the default tables and columns.
Expand All @@ -182,59 +187,64 @@ def tear_down(self):
Store: self
"""
# drop/delete the tables
self.db.drop_database()
self.client.drop_database(self.db_name)
return self

def add(self, obj):
"""Add objects to mocked session"""
if isinstance(obj, Transcript):
tx_obj = dict(
_id = obj.id,
gene_id = obj.gene_id,
gene_name = obj.gene_name,
chromosome = obj.chromosome,
length = obj.length,
)

self.transcripts_bulk.append(tx_obj)
def add(self, *objs):
"""Add objects to mocked session
elif isinstance(obj, Sample):
sample_obj = dict(
_id = obj.id,
group_id = obj.group_id,
source = obj.source,
created_at = obj.created_at,
name = obj.name,
group_name = obj.group_name,
)

self.sample_bulk.append(sample_obj)
self.session.add_sample(sample_obj)

elif isinstance(obj, TranscriptStat):
if self.tx_dict is None:
self.tx_dict = self.transcripts_genes()
Args:
objs(iterable): Could be different types of objects
"""
for obj in objs:
if isinstance(obj, Transcript):
tx_obj = dict(
_id = obj.id,
gene_id = obj.gene_id,
gene_name = obj.gene_name,
chromosome = obj.chromosome,
length = obj.length,
)

self.transcripts_bulk.append(tx_obj)

tx_info = self.tx_dict[obj.transcript_id]

tx_stats_obj = dict(
mean_coverage = obj.mean_coverage,
completeness_10 = obj.completeness_10,
completeness_15 = obj.completeness_15,
completeness_20 = obj.completeness_20,
completeness_50 = obj.completeness_50,
completeness_100 = obj.completeness_100,
threshold = obj.threshold,
sample_id = obj.sample_id,
transcript_id = obj.transcript_id,
gene_id = tx_info['gene_id'],
gene_name = tx_info['gene_name'],
)
if obj._incomplete_exons is not None:
tx_stats_obj['_incomplete_exons'] = obj._incomplete_exons.split(',')

self.tx_stats_bulk.append(tx_stats_obj)
self.session.add_transcript_stat(tx_stats_obj)
elif isinstance(obj, Sample):
sample_obj = dict(
_id = obj.id,
group_id = obj.group_id,
source = obj.source,
created_at = datetime.now(),
name = obj.name,
group_name = obj.group_name,
)

self.sample_bulk.append(sample_obj)
self.session.add_sample(sample_obj)

elif isinstance(obj, TranscriptStat):
if self.tx_dict is None:
self.tx_dict = self.transcripts_genes()

tx_info = self.tx_dict[obj.transcript_id]

tx_stats_obj = dict(
mean_coverage = obj.mean_coverage,
completeness_10 = obj.completeness_10,
completeness_15 = obj.completeness_15,
completeness_20 = obj.completeness_20,
completeness_50 = obj.completeness_50,
completeness_100 = obj.completeness_100,
threshold = obj.threshold,
sample_id = obj.sample_id,
transcript_id = obj.transcript_id,
gene_id = tx_info['gene_id'],
gene_name = tx_info['gene_name'],
)
if obj._incomplete_exons is not None:
tx_stats_obj['_incomplete_exons'] = obj._incomplete_exons.split(',')

self.tx_stats_bulk.append(tx_stats_obj)
self.session.add_transcript_stat(tx_stats_obj)

def clean(self):
"""Clean the bulks"""
Expand All @@ -255,13 +265,15 @@ def save(self):
self.transcripts_collection.insert_many(self.transcripts_bulk)

if self.sample_bulk:
try:
self.sample_collection.insert_many(self.sample_bulk)
except BulkWriteError as err:
# This means that the sample already exists so we do not want to remove
# All the previously inserted data
self.session.clean()
raise err
for sample in self.sample_bulk:
try:
self.sample_collection.insert_one(sample)
except Exception as err:
# This means that the sample already exists so we do not want to remove
# All the previously inserted data
self.session.clean()
self.clean()
raise DuplicateKeyError('E11000 Duplicate key error', 11000)

if self.tx_stats_bulk:
self.transcript_stat_collection.insert_many(self.tx_stats_bulk)
Expand All @@ -285,7 +297,32 @@ def sample(self, sample_id):
if not sample_obj:
return None

return Sample(id=sample_obj['_id'], group_id=sample_obj.get('group_id'), source=sample_obj.get('source'))
return Sample(id=sample_obj['_id'], group_id=sample_obj.get('group_id'),
source=sample_obj.get('source'), created_at=sample_obj.get('created_at'))

def samples(self):
"""Return all samples from database
Args:
sample_id(str)
Returns:
sample_objs(list(models.Sample))
"""
LOG.info("Fetch all samples")
sample_objs = []
res = self.sample_collection.find()

for sample in res:
sample_objs.append(Sample(
id=sample['_id'],
group_id=sample.get('group_id'),
source=sample.get('source'),
created_at=sample.get('created_at'))
)

return sample_objs


def transcripts(self):
"""Return all transcripts
Expand Down
1 change: 1 addition & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@ mock
ipdb
pytest-cov
coveralls
mongomock
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@ click
path.py
toolz
ruamel.yaml
mongo_adapter
8 changes: 8 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from chanjo.cli import root
from chanjo.store.api import ChanjoDB
from chanjo.store.mongo import ChanjoMongoDB
from chanjo.load.parse import bed, sambamba
from chanjo.load.sambamba import load_transcripts
from chanjo.load.link import link_elements
Expand All @@ -29,6 +30,13 @@ def chanjo_db():
yield _chanjo_db
_chanjo_db.tear_down()

@pytest.yield_fixture(scope='function')
def chanjo_mongo_db():
_chanjo_db = ChanjoMongoDB('mongomock://')
_chanjo_db.set_up()
yield _chanjo_db
_chanjo_db.tear_down()


@pytest.yield_fixture(scope='function')
def existing_db(tmpdir):
Expand Down
50 changes: 49 additions & 1 deletion tests/store/test_store_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import pytest
from sqlalchemy.orm.exc import FlushError
from pymongo.errors import DuplicateKeyError

from chanjo.store.api import ChanjoDB
from chanjo.store.models import Sample
Expand All @@ -12,7 +13,6 @@ def test_dialect(chanjo_db):
assert chanjo_db.dialect == 'sqlite'
assert hasattr(chanjo_db, 'query')


def test_no_dialect():
# GIVEN not explicity specifying SQLite dialect
# WHEN setting up the API
Expand Down Expand Up @@ -55,3 +55,51 @@ def test_add_many(chanjo_db):
chanjo_db.save()
# THEN all samples should be added
assert Sample.query.all() == new_samples

### Mongo DB tests

def test_mongo_dialect(chanjo_mongo_db):
## GIVEN a mongo database connection
## WHEN setting up the api with a connection to mongomock
chanjo_db = chanjo_mongo_db
## THEN assert that the dialect is mongodb
assert chanjo_db.dialect == 'mongodb'
## THEN assert that the uri is mongomock
assert chanjo_db.uri == 'mongomock://'

def test_mongo_save(chanjo_mongo_db):
chanjo_db = chanjo_mongo_db
# GIVEN a new sample
sample_id = 'ADM12'
new_sample = Sample(id=sample_id, group_id='ADMG1', source='alignment.bam')
# WHEN added and saved to the database
chanjo_db.add(new_sample)
chanjo_db.save()
db_sample = chanjo_db.sample(sample_id)
# THEN is should exist in the database
assert db_sample.id == sample_id
assert isinstance(db_sample.created_at, datetime)

# GIVEN sample already exists
conflict_sample = Sample(id=sample_id, group_id='ADMG2')
# WHEN saving it again with same id
# THEN error is raised _after_ rollback
with pytest.raises(DuplicateKeyError):
chanjo_db.add(conflict_sample)
chanjo_db.save()

new_sampleid = 'ADM13'
chanjo_db.add(Sample(id=new_sampleid))
chanjo_db.save()
assert chanjo_db.sample(new_sampleid)

def test_mongo_add_many(chanjo_mongo_db):
chanjo_db = chanjo_mongo_db
# GIVEN multiple new samples
new_samples = [Sample(id='ADM12'), Sample(id='ADM13')]
# WHEN added to the session
chanjo_db.add(*new_samples)
chanjo_db.save()
# THEN all samples should be added
res = chanjo_db.samples()
assert len(res) == len(new_samples)

0 comments on commit 0ad8d61

Please sign in to comment.