Skip to content

Commit

Permalink
Adding the possibility of seeing the used attribute metrics in the me…
Browse files Browse the repository at this point in the history
…trics file. Added 'default' key to the filtering (for compatibility with attributes). All properly tested.
  • Loading branch information
lucventurini committed Mar 31, 2021
1 parent 2011ee5 commit c1c365c
Show file tree
Hide file tree
Showing 8 changed files with 263 additions and 68 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/python-package.yml
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ jobs:
id: cache-miniconda
uses: actions/cache@v2
env:
CACHE_NUMBER: 0
CACHE_NUMBER: 1
with:
repo-token: ${{ secrets.GITHUB_TOKEN }}
path: ~/conda_pkgs_dir
Expand Down
41 changes: 23 additions & 18 deletions Mikado/_transcripts/scoring_configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,32 @@ def compiler(expression):
raise InvalidConfiguration("Invalid expression:\n{}".format(expression))


class Unique(validate.Validator):
    """Marshmallow-style validator rejecting sequences that contain duplicates."""

    message = "{input} is Not unique"

    def __init__(self, error=None):
        # Stored for API parity with other marshmallow validators.
        self.error = error

    def _repr_args(self):
        # No constructor arguments worth echoing in repr().
        return ""

    def _format_error(self, value):
        return self.message.format(input=value)

    def __call__(self, value):
        # A sequence is duplicate-free iff collapsing it into a set loses nothing.
        deduplicated = set(value)
        if len(deduplicated) == len(value):
            return value
        raise validate.ValidationError(self._format_error(value))


@dataclass
class SizeFilter:
    """Filter comparing a numeric metric against a threshold with gt/ge/lt/le."""

    # Comparison operator applied between the metric value and ``value``.
    operator: str = field(metadata={"required": True, "validate": validate.OneOf(["gt", "ge", "lt", "le"])})
    # Threshold the metric is compared against.
    value: float = field(metadata={"required": True})
    # Name of the metric this filter applies to; optional when implied by context.
    metric: Optional[str] = field(metadata={"required": False}, default=None)
    name: Optional[str] = field(default=None)
    source: Optional[str] = field(default=None)
    # Fallback value used when the metric/attribute is absent from a transcript.
    default: Optional[float] = field(default=0)


@dataclass
Expand All @@ -39,42 +58,28 @@ class NumBoolEqualityFilter:
metric: Optional[str] = field(metadata={"required": False}, default=None)
name: Optional[str] = field(default=None)
source: Optional[str] = field(default=None)
default: Optional[Union[float, bool]] = field(default=0)


@dataclass
class InclusionFilter:
    """Filter checking membership of a metric value in a list of candidates."""

    # Candidate values; duplicates are rejected by the Unique validator.
    # NOTE: the diff residue carried two ``value`` declarations; only the
    # validated one is kept, matching the commit's intent.
    value: list = field(metadata={"required": True, "validate": [Unique]})
    # Membership operator applied between the metric value and ``value``.
    operator: str = field(metadata={"required": True, "validate": validate.OneOf(["in", "not in"])})
    metric: Optional[str] = field(metadata={"required": False}, default=None)
    name: Optional[str] = field(default=None)
    source: Optional[str] = field(default=None)
    # Fallback used when the metric/attribute is missing from a transcript.
    default: Optional[List[float]] = field(default=None)


@dataclass
class RangeFilter:
    """Filter checking whether a metric value falls within (or outside) a range.

    The duplicate nested ``Unique`` class left over from the diff has been
    removed; the module-level ``Unique`` validator is used instead.
    """

    # Exactly two distinct boundary values defining the range.
    value: List[float] = field(metadata={
        "required": True,
        "validate": [validate.Length(min=2, max=2), Unique]})
    # Range operator applied to the metric value.
    operator: str = field(metadata={"required": True, "validate": validate.OneOf(["within", "not within"])})
    metric: Optional[str] = field(metadata={"required": False}, default=None)
    name: Optional[str] = field(default=None)
    # Fallback used when the metric/attribute is missing from a transcript.
    default: Optional[float] = field(default=0)


@dataclass
Expand Down
95 changes: 52 additions & 43 deletions Mikado/_transcripts/transcript_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,6 @@ def __init__(self, *args,
self.__derived_children = set()
self.__original_source = None
self.__external_scores = Namespace(default=(0, False))
self.__internal_orf_transcripts = []
self.__cds_not_maximal = None

# Starting settings for everything else
Expand Down Expand Up @@ -889,39 +888,51 @@ def configuration(self):
return self.__configuration

@property
def frames(self):
def frames(self) -> dict:
"""This property will return a dictionary with three keys - the three possible frames, 0, 1 and 2 - and within
each, a set of the positions that are in that frame. If the transcript does not have """
self.finalize()
frames = {0: set(), 1: set(), 2: set()}
for orf in self.internal_orfs:
if self.strand == "-":
exons = sorted([(_[1][0], _[1][1], _[2]) for _ in orf
if _[0] == "CDS"], key=operator.itemgetter(0, 1), reverse=True)
for start, end, phase in exons:
frame = ((3 - phase) % 3 - 1) % 3
for pos in range(end, start - 1, -1):
frame = abs((frame + 1) % 3)
frames[frame].add(pos)
else:
exons = sorted([(_[1][0], _[1][1], _[2]) for _ in orf
if _[0] == "CDS"], key=operator.itemgetter(0, 1), reverse=False)
for start, end, phase in exons:
frame = ((3 - phase) % 3 - 1) % 3 # Previous frame before beginning of the feature
for pos in range(start, end + 1):
frame = abs((frame + 1) % 3)
frames[frame].add(pos)
return frames

@functools.lru_cache()
def calculate_frames(internal_orfs, strand) -> dict:
frames = {0: set(), 1: set(), 2: set()}
for orf in internal_orfs:
if strand == "-":
exons = sorted([(_[1][0], _[1][1], _[2]) for _ in orf
if _[0] == "CDS"], key=operator.itemgetter(0, 1), reverse=True)
for start, end, phase in exons:
frame = ((3 - phase) % 3 - 1) % 3
for pos in range(end, start - 1, -1):
frame = abs((frame + 1) % 3)
frames[frame].add(pos)
else:
exons = sorted([(_[1][0], _[1][1], _[2]) for _ in orf
if _[0] == "CDS"], key=operator.itemgetter(0, 1), reverse=False)
for start, end, phase in exons:
frame = ((3 - phase) % 3 - 1) % 3 # Previous frame before beginning of the feature
for pos in range(start, end + 1):
frame = abs((frame + 1) % 3)
frames[frame].add(pos)
for frame in frames:
frames[frame] = frozenset(frames[frame])
return frames

return calculate_frames(tuple(tuple(_) for _ in self.internal_orfs), self.strand)

@property
def framed_codons(self):
def framed_codons(self) -> List:
"""Return the list of codons as calculated by self.frames."""

codons = list(zip(*[sorted(self.frames[0]), sorted(self.frames[1]), sorted(self.frames[2])]))
if self.strand == "-":
codons = list(reversed(codons))
@functools.lru_cache()
def calculate_codons(frames, strand) -> list:
# Reconvert to dictionary. We had to turn into a tuple of items for hashing.
frames = dict(frames)
codons = list(zip(*[sorted(frames[0]), sorted(frames[1]), sorted(frames[2])]))
if strand == "-":
codons = list(reversed(codons))
return codons

return codons
return calculate_codons(tuple(self.frames.items()), self.strand)

@property
def _selected_orf_transcript(self):
Expand All @@ -939,26 +950,28 @@ def _internal_orfs_transcripts(self):
Note: this will exclude the UTR part, even when the transcript only has one ORF."""

self.finalize()
if not self.is_coding:
return []
elif len(self.__internal_orf_transcripts) == len(self.internal_orfs):
return self.__internal_orf_transcripts
else:
for num, orf in enumerate(self.internal_orfs, start=1):
@functools.lru_cache()
def calculate_orf_transcripts(internal_orfs, chrom, strand, tid):
orf_transcripts = []
for num, orf in enumerate(internal_orfs, start=1):
torf = TranscriptBase()
torf.chrom, torf.strand = self.chrom, self.strand
torf.derives_from = self.id
torf.id = "{}.orf{}".format(self.id, num)
torf.chrom, torf.strand = chrom, strand
torf.derives_from = tid
torf.id = "{}.orf{}".format(tid, num)
__exons, __phases = [], []
for segment in [_ for _ in orf if _[0] == "CDS"]:
__exons.append(segment[1])
__phases.append(segment[2])
torf.add_exons(__exons, features="exon", phases=None)
torf.add_exons(__exons, features="CDS", phases=__phases)
torf.finalize()
self.__internal_orf_transcripts.append(torf)
orf_transcripts.append(torf)
return orf_transcripts

return self.__internal_orf_transcripts
orf_transcripts = calculate_orf_transcripts(tuple(tuple(_) for _ in self.internal_orfs),
self.chrom, self.strand, self.id)

return orf_transcripts

def as_bed12(self) -> BED12:

Expand Down Expand Up @@ -1131,7 +1144,6 @@ def unfinalize(self):
return

self.internal_orfs = []
self.__internal_orf_transcripts = []
self.combined_utr = []
self._cdna_length = None
self.finalized = False
Expand Down Expand Up @@ -1741,10 +1753,7 @@ def original_source(self, value):
@property
def gene(self):
    """Gene identifier of the transcript.

    Falls back to the first parent ID when the ``gene_id`` attribute is
    absent. Unlike the pre-commit version, the fallback is NOT written back
    into ``self.attributes`` — the lookup is side-effect free. The dead
    code left by the diff (an unreachable second return) has been removed.
    """
    return self.attributes.get("gene_id", self.parent[0])

@property
def location(self):
Expand Down Expand Up @@ -1774,7 +1783,7 @@ def score(self, score):
if not isinstance(score, (float, int)):
try:
score = float(score)
except:
except (ValueError, TypeError):
raise ValueError(
"Invalid value for score: {0}, type {1}".format(score, type(score)))
self.__score = score
Expand Down
21 changes: 20 additions & 1 deletion Mikado/loci/abstractlocus.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,6 +375,10 @@ def evaluate(param: Union[str, int, bool, float],
comparison = (param not in conf.value)
elif comparison_operator == "within":
start, end = sorted([conf.value[0], conf.value[1]])
try:
float(param)
except (TypeError, ValueError):
raise ValueError(param)
comparison = (start <= float(param) <= end)
elif comparison_operator == "not within":
start, end = sorted([conf.value[0], conf.value[1]])
Expand Down Expand Up @@ -1079,6 +1083,20 @@ def _create_metrics_row(self, tid: str, metrics: dict, transcript: Transcript) -
value = "NA"
row[key] = value

# Add the attributes
for param_dict in [self.configuration.scoring.scoring, self.configuration.scoring.cds_requirements.parameters,
self.configuration.scoring.as_requirements.parameters,
self.configuration.scoring.not_fragmentary.parameters,
self.configuration.scoring.requirements.parameters]:
for attr_name, attr_metric in [(key, metric) for key, metric in param_dict.items()
if key.startswith("attributes")]:
value = transcript.attributes.get(attr_name.replace("attributes.", ""), attr_metric.default)
if isinstance(value, float):
value = round(value, 2)
elif value is None or value == "":
value = "NA"
row[attr_name] = value

return row

def print_metrics(self):
Expand Down Expand Up @@ -1283,7 +1301,8 @@ def _check_not_passing(self, previous_not_passing=(), section_name="requirements
for key in section.parameters:
if "attributes" in key:
key_parts = key.split('.')
value = self.transcripts[tid].attributes[key_parts[1]]
default = section.parameters[key].default
value = self.transcripts[tid].attributes.get(key_parts[1], default)
else:
if section.parameters[key].name is not None:
name = section.parameters[key].name
Expand Down
17 changes: 14 additions & 3 deletions Mikado/loci/locus.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
import operator
from collections import defaultdict
import pysam

from ..scales.resultstorer import ResultStorer
from ..transcripts.pad import pad_transcript
from ..transcripts.transcript import Transcript
from .abstractlocus import Abstractlocus
Expand Down Expand Up @@ -582,7 +584,8 @@ def _check_as_requirements(self, transcript: Transcript, is_reference=False) ->
name = section.parameters[key].name
if 'attributes' in name:
name_parts = name.split('.')
value = transcript.attributes[name_parts[1]]
default = section.parameters[key].default
value = transcript.attributes.get(name_parts[1], default)
else:
value = operator.attrgetter(name)(transcript)
if "external" in key:
Expand Down Expand Up @@ -615,7 +618,11 @@ def is_putative_fragment(self):
evaluated = dict()
for key, params in self.configuration.scoring.not_fragmentary.parameters.items():
name = params.name
value = operator.attrgetter(name)(self.primary_transcript)
if name.startswith("attributes"):
default = params.default
value = self.primary_transcript.attributes.get(name.replace("attributes.", ""), default)
else:
value = operator.attrgetter(name)(self.primary_transcript)
if "external" in key:
value = value[0]
try:
Expand Down Expand Up @@ -644,7 +651,7 @@ def is_putative_fragment(self):
assert self.id == current_id
return fragment

def other_is_fragment(self, other):
def other_is_fragment(self, other) -> (bool, Union[None, ResultStorer]):
"""
If the 'other' locus is marked as a potential fragment (see 'is_putative_fragment'), then this function
will check whether the other locus is within the distance and with the correct comparison class code to be
Expand All @@ -661,6 +668,10 @@ def other_is_fragment(self, other):
self.logger.debug("Self-comparisons are not allowed!")
return False, None

if other.is_putative_fragment() is False:
self.logger.debug(f"{other.id} cannot be a fragment according to the scoring file definition.")
return False, None

if any(other.transcripts[tid].is_reference is True for tid in other.transcripts.keys()):
self.logger.debug("Locus %s has a reference transcript, hence it will not be discarded", other.id)
return False, None
Expand Down
13 changes: 13 additions & 0 deletions Mikado/picking/picker.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,6 +460,18 @@ def __get_output_files(self):
requested_external.update({param for param in self.configuration.scoring.scoring.keys()
if param.startswith("external")})

# Collect every "attributes.*" parameter requested anywhere in the scoring
# configuration, so attribute metrics get columns in the metrics file.
# BUG FIX: the original updated from ``cds_requirements`` twice and never
# from ``as_requirements``, even though _create_metrics_row emits values for
# as_requirements attributes — the header and rows would disagree.
attribute_metrics = set()
for section_parameters in (self.configuration.scoring.requirements.parameters,
                           self.configuration.scoring.not_fragmentary.parameters,
                           self.configuration.scoring.as_requirements.parameters,
                           self.configuration.scoring.cds_requirements.parameters,
                           self.configuration.scoring.scoring):
    attribute_metrics.update({param for param in section_parameters.keys()
                              if param.startswith("attributes")})

# Check that the external scores are all present. If they are not, raise a warning.
if requested_external - set(available_external_metrics):
self.logger.error(
Expand All @@ -486,6 +498,7 @@ def __get_output_files(self):
metrics.extend(available_external_metrics)
else:
metrics.extend(requested_external)
metrics.extend(attribute_metrics)
metrics = Superlocus.available_metrics[:5] + sorted(metrics)
session.close()
engine.dispose()
Expand Down
2 changes: 1 addition & 1 deletion Mikado/tests/locus_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2271,7 +2271,7 @@ def test_only_CDS_overlap(self):
with self.subTest(min_overlap=min_overlap):
cds_overlap = 0
for frame in range(3):
cds_overlap += len(set.intersection(
cds_overlap += len(frozenset.intersection(
self.t1.frames[frame], t2.frames[frame]
))

Expand Down
Loading

0 comments on commit c1c365c

Please sign in to comment.