Skip to content

Commit

Permalink
Merge 1d9f08f into 2fed733
Browse files Browse the repository at this point in the history
  • Loading branch information
northwestwitch committed May 20, 2024
2 parents 2fed733 + 1d9f08f commit eaabc79
Show file tree
Hide file tree
Showing 15 changed files with 266 additions and 301 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ This project adheres to [Semantic Versioning](http://semver.org/).
### Changed
- Replace ruamel.yaml with pyyaml lib
- Updated GitHub actions
- Replaced alchy lib with sqlservice
- Unfreeze SQLAlchemy
### Fixed
- Add missing brew path to GitHub action. It has been removed from PATH variable in Ubuntu
- Badges on README page
Expand Down
62 changes: 32 additions & 30 deletions chanjo/calculate.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,19 @@ class CalculateMixin:

def mean(self, sample_ids=None):
"""Calculate the mean values of all metrics per sample."""
sql_query = self.query(
TranscriptStat.sample_id,
func.avg(TranscriptStat.mean_coverage),
func.avg(TranscriptStat.completeness_10),
func.avg(TranscriptStat.completeness_15),
func.avg(TranscriptStat.completeness_20),
func.avg(TranscriptStat.completeness_50),
func.avg(TranscriptStat.completeness_100),
).group_by(TranscriptStat.sample_id)
if sample_ids:
sql_query = sql_query.filter(TranscriptStat.sample_id.in_(sample_ids))
return sql_query
with self.begin() as session:
sql_query = session.query(
TranscriptStat.sample_id,
func.avg(TranscriptStat.mean_coverage),
func.avg(TranscriptStat.completeness_10),
func.avg(TranscriptStat.completeness_15),
func.avg(TranscriptStat.completeness_20),
func.avg(TranscriptStat.completeness_50),
func.avg(TranscriptStat.completeness_100),
).group_by(TranscriptStat.sample_id)
if sample_ids:
sql_query = sql_query.filter(TranscriptStat.sample_id.in_(sample_ids))
return sql_query

def gene_metrics(self, *genes):
"""Calculate gene statistics."""
Expand All @@ -37,22 +38,23 @@ def gene_metrics(self, *genes):

def sample_coverage(self, sample_ids: list, genes: list) -> dict:
"""Calculate coverage for samples."""
query = self.query(
TranscriptStat.sample_id.label('sample_id'),
func.avg(TranscriptStat.mean_coverage).label('mean_coverage'),
func.avg(TranscriptStat.completeness_10).label('mean_completeness'),
).join(
Transcript,
).filter(
Transcript.gene_id.in_(genes),
TranscriptStat.sample_id.in_(sample_ids),
).group_by(TranscriptStat.sample_id)

data = {
result.sample_id: {
"mean_coverage": result.mean_coverage,
"mean_completeness": result.mean_completeness,
with self.begin() as session:
query = session.query(
TranscriptStat.sample_id.label('sample_id'),
func.avg(TranscriptStat.mean_coverage).label('mean_coverage'),
func.avg(TranscriptStat.completeness_10).label('mean_completeness'),
).join(
Transcript,
).filter(
Transcript.gene_id.in_(genes),
TranscriptStat.sample_id.in_(sample_ids),
).group_by(TranscriptStat.sample_id)

data = {
result.sample_id: {
"mean_coverage": result.mean_coverage,
"mean_completeness": result.mean_completeness,
}
for result in query
}
for result in query
}
return data
return data
16 changes: 8 additions & 8 deletions chanjo/cli/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def db_cmd(context):
@click.option("--reset", is_flag=True, help="tear down existing db")
@click.pass_context
def setup(context, reset):
"""Initialize a new datbase from scratch."""
"""Initialize a new database from scratch."""
if reset:
LOG.info("tearing down existing database")
context.obj["db"].tear_down()
Expand All @@ -37,13 +37,13 @@ def remove(context, sample_id):
"""Remove all traces of a sample from the database."""
store = context.obj["db"]
LOG.debug("find sample in database with id: %s", sample_id)
sample_obj = Sample.query.get(sample_id)
if sample_obj is None:
LOG.warning("sample (%s) not found in database", sample_id)
context.abort()
LOG.info("delete sample (%s) from database", sample_id)
store.session.delete(sample_obj)
store.save()
with store.begin() as session:
sample_obj = session.first(Sample.select().where(Sample.id == sample_id))
if sample_obj is None:
LOG.warning("sample (%s) not found in database", sample_id)
context.abort()
LOG.info("delete sample (%s) from database", sample_id)
session.delete(sample_obj)


@db_cmd.command()
Expand Down
32 changes: 16 additions & 16 deletions chanjo/cli/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,16 +47,16 @@ def load(context, sample, group, name, group_name, threshold, bed_stream):
result.sample.name = name
result.sample.group_name = group_name
try:
chanjo_db.add(result.sample)
with click.progressbar(result.models, length=result.count,
label='loading transcripts') as bar:
for tx_model in bar:
chanjo_db.add(tx_model)
chanjo_db.save()
with chanjo_db.begin() as session:
session.add(result.sample)
with click.progressbar(result.models, length=result.count,
label='loading transcripts') as bar:
for tx_model in bar:
session.add(tx_model)

except IntegrityError as error:
LOG.error('sample already loaded, rolling back')
LOG.debug(error.args[0])
chanjo_db.session.rollback()
context.abort()


Expand All @@ -68,14 +68,14 @@ def link(context, bed_stream):
"""Link related genomic elements."""
chanjo_db = ChanjoDB(uri=context.obj['database'])
result = link_elements(bed_stream)
with click.progressbar(result.models, length=result.count,
label='adding transcripts') as bar:
for tx_model in bar:
chanjo_db.add(tx_model)
try:
chanjo_db.save()
with chanjo_db.begin() as session:
with click.progressbar(result.models, length=result.count,
label='adding transcripts') as bar:
for tx_model in bar:
session.add(tx_model)

except IntegrityError:
LOG.exception('elements already linked?')
chanjo_db.session.rollback()
click.echo("use 'chanjo db setup --reset' to re-build")
context.abort()
LOG.exception('elements already linked?')
click.echo("use 'chanjo db setup --reset' to re-build")
context.abort()
39 changes: 8 additions & 31 deletions chanjo/store/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,16 @@
import logging
import os

from alchy import Manager
from sqlservice import Database
from chanjo.calculate import CalculateMixin
from .models import BASE
from .fetch import FetchMixin
from .delete import DeleteMixin

log = logging.getLogger(__name__)
LOG = logging.getLogger(__name__)


class ChanjoDB(Manager, CalculateMixin, DeleteMixin, FetchMixin):
class ChanjoDB(Database, CalculateMixin, DeleteMixin, FetchMixin):
"""SQLAlchemy-based database object.
Bundles functionality required to setup and interact with various
Expand Down Expand Up @@ -45,16 +45,14 @@ class ChanjoDB(Manager, CalculateMixin, DeleteMixin, FetchMixin):
"""

def __init__(self, uri=None, debug=False, base=BASE):
self.Model = base
self.uri = uri
self.model_class = base
if uri:
self.connect(uri, debug=debug)


def connect(self, db_uri, debug=False):
"""Configure connection to a SQL database.
.. versionadded:: 2.1.0
Args:
db_uri (str): path/URI to the database to connect to
debug (Optional[bool]): whether to output logging information
Expand All @@ -66,12 +64,9 @@ def connect(self, db_uri, debug=False):
# expect only a path to a sqlite database
db_path = os.path.abspath(os.path.expanduser(db_uri))
db_uri = "sqlite:///{}".format(db_path)
self.uri = db_uri

config['SQLALCHEMY_DATABASE_URI'] = db_uri

# connect to the SQL database
super(ChanjoDB, self).__init__(config=config, Model=self.Model)
super(ChanjoDB, self).__init__(db_uri, model_class=BASE)

@property
def dialect(self):
Expand All @@ -92,8 +87,8 @@ def set_up(self):
"""
# create the tables
self.create_all()
tables = self.Model.metadata.tables.keys()
log.info("created tables: %s", ', '.join(tables))
tables = self.model_class.metadata.tables.keys()
LOG.info("created tables: %s", ', '.join(tables))
return self

def tear_down(self):
Expand All @@ -105,21 +100,3 @@ def tear_down(self):
# drop/delete the tables
self.drop_all()
return self

def save(self):
"""Manually persist changes made to various elements. Chainable.
.. versionchanged:: 2.1.2
Flush session before commit.
Returns:
Store: ``self`` for chainability
"""
try:
# commit/persist dirty changes to the database
self.commit()
except Exception as error:
log.debug('rolling back failed transaction')
self.session.rollback()
raise error
return self
17 changes: 10 additions & 7 deletions chanjo/store/delete.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Module for deleting from database"""

import logging
from chanjo.store.models import Sample

LOG = logging.getLogger(__name__)

Expand All @@ -11,15 +12,17 @@ class DeleteMixin:

def delete_sample(self, sample_id):
"""Delete single sample from database"""
sample = list(self.fetch_samples(sample_id=sample_id))
if len(sample) > 0:
LOG.info("Deleting sample %s from database", sample[0].id)
self.delete_commit(sample)
LOG.info(f"Deleting sample {sample_id} from database")
with self.begin() as session:
sample = session.get(Sample, sample_id)
if sample:
session.delete(sample)

def delete_group(self, group_id):
"""Delete entire group from database"""
LOG.info("Deleting entire group %s from database", group_id)
samples = self.fetch_samples(group_id=group_id)
for sample in samples:
LOG.info("Deleting sample %s from database", sample.id)
self.delete_commit(sample)
with self.begin() as session:
for sample in samples:
LOG.info("Deleting sample %s from database", sample.id)
session.execute(sample.delete())
22 changes: 12 additions & 10 deletions chanjo/store/fetch.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,20 @@ def fetch_samples(self, sample_id=None, group_id=None):
"""
Fetch samples from database
"""
query = self.query(Sample)
if sample_id:
query = query.filter(Sample.id == sample_id)
if group_id:
query = query.filter(Sample.group_id == group_id)
return query
with self.begin() as session:
query = session.query(Sample)
if sample_id:
query = query.filter(Sample.id == sample_id)
if group_id:
query = query.filter(Sample.group_id == group_id)
return query

def fetch_transcripts(self, sample_id):
"""
Fetch transcripts from database
"""
query = self.query(TranscriptStat).filter(
TranscriptStat.sample_id == sample_id
)
return query
with self.begin() as session:
query = session.query(TranscriptStat).filter(
TranscriptStat.sample_id == sample_id
)
return query
8 changes: 3 additions & 5 deletions chanjo/store/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,13 @@
from collections import namedtuple
from datetime import datetime

from alchy import ModelBase, make_declarative_base
from sqlservice import declarative_base
from sqlalchemy import Column, types, ForeignKey, UniqueConstraint, orm

Exon = namedtuple('Exon', ['chrom', 'start', 'end', 'completeness'])

# base for declaring a mapping
BASE = make_declarative_base(Base=ModelBase)

BASE = declarative_base()

class Transcript(BASE):

Expand Down Expand Up @@ -59,7 +58,6 @@ class Sample(BASE):
sample = orm.relationship('TranscriptStat', cascade='all,delete',
backref='sample')


class TranscriptStat(BASE):

"""Statistics on transcript level, related to sample and transcript.
Expand Down Expand Up @@ -89,7 +87,7 @@ class TranscriptStat(BASE):
threshold = Column(types.Integer)
_incomplete_exons = Column(types.Text)

sample_id = Column(types.String(32), ForeignKey('sample.id'),
sample_id = Column(types.String(32), ForeignKey('sample.id', ondelete='CASCADE'),
nullable=False)
transcript_id = Column(types.String(32), ForeignKey('transcript.id'),
nullable=False)
Expand Down
4 changes: 2 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
coloredlogs
alchy
click
path.py
toolz
pyyaml
importlib_metadata
pymysql
sqlalchemy==1.3.*
sqlalchemy
sqlservice
Loading

0 comments on commit eaabc79

Please sign in to comment.