Skip to content

Commit

Permalink
SiteRelations for chem2
Browse files Browse the repository at this point in the history
  • Loading branch information
johnsekar committed Apr 27, 2018
1 parent e31113c commit 818e5b0
Show file tree
Hide file tree
Showing 4 changed files with 107 additions and 30 deletions.
45 changes: 45 additions & 0 deletions tests/test_chem2.py
@@ -0,0 +1,45 @@
"""
:Author: John Sekar <johnarul.sekar@gmail.com>
:Date: 2018-04-27
:Copyright: 2017, Karr Lab
:License: MIT
"""

from wc_rules import chem2,utils
import unittest


class A(chem2.Molecule):pass
class X(chem2.Site):pass
class B(chem2.Molecule):pass
class Y(chem2.Site):pass

class TestBase(unittest.TestCase):
def test_site_relations(self):

A1 = A().set_id('A1').add_sites(
X().set_id('X1'),
X().set_id('X2')
)

B1 = B().set_id('B1').add_sites(
Y().set_id('Y1'),
Y().set_id('Y2'),
)

bnd1 = chem2.Bond()
with self.assertRaises(utils.AddError):
bnd1.add_sources(A1.sites.get_one(id='X1'))

bnd1.add_targets(
A1.sites.get_one(id='X1'),
B1.sites.get_one(id='Y1'),
)

with self.assertRaises(utils.AddError):
bnd1.add_targets(B1.sites.get_one(id='Y2'))

with self.assertRaises(utils.AddError):
bnd2 = chem2.Bond().add_targets(A1.sites.get_one(id='X1'))

return
78 changes: 54 additions & 24 deletions wc_rules/chem2.py
Expand Up @@ -8,44 +8,74 @@
from wc_rules import base,entity,utils


class Molecule(entity.Entity):pass
class Molecule(entity.Entity):
def add_sites(self,*args):
for arg in args:
if arg._verify_site_molecule_compatibility(molecule=self):
self.sites.append(arg)
else:
raise utils.AddError('Site incompatible with molecule')
return self

class Site(entity.Entity):
molecule = core.ManyToOneAttribute(Molecule,related_name='sites')

def _verify_site_molecule_compatibility(self,molecule=None):
return True

def _get_number_of_relations(self, source_or_target ='target', relation_type=None):
if source_or_target == 'target':
if self.site_relations_targets is not None:
existing = utils.filter_by_type(list(self.site_relations_targets),[relation_type])
return len(existing)
if source_or_target == 'source':
if self.site_relations_sources is not None:
existing = utils.filter_by_type(list(self.site_relations_sources),[relation_type])
return len(existing)
return 0

class SiteRelation(entity.Entity):
sources = core.OneToManyAttribute(Site,related_name='site_relations')
targets = core.OneToManyAttribute(Site,related_name='site_relations_targets')
sources = core.ManyToManyAttribute(Site,related_name='site_relations_sources')
targets = core.ManyToManyAttribute(Site,related_name='site_relations_targets')
n_max_sources = None
n_max_targets = None
n_max_relations_for_a_source = None
n_max_relations_for_a_target = None

def add_sources(self,*args):
def verify_add_sources(self,*args):
if self.n_max_sources is not None:
if len(self.sources)+len(args) > self.n_max_sources:
raise utils.AddError('Number of source sites allowed for this relation will be exceeded.')
for arg in args:
self.sources.append(arg)
return self
raise utils.AddError('Number of source sites allowed for '+str(type(self))+ ' relation must not exceed ' + str(self.n_max_sources)+'.')
if self.n_max_relations_for_a_source is not None:
for arg in args:
n_existing = arg._get_number_of_relations(source_or_target = 'source',relation_type=type(self))
if n_existing + 1 > self.n_max_relations_for_a_source:
raise utils.AddError('Source site already has the allowed maximum of '+str(type(self))+ ' relations: ' + str(self.n_max_relations_for_a_source)+'.')
return True

def add_targets(self,*args):
def verify_add_targets(self,*args):
if self.n_max_targets is not None:
if len(self.targets)+len(args) > self.n_max_targets:
raise utils.AddError('Number of target sites allowed for this relation will be exceeded.')
for arg in args:
self.targets.append(arg)
return self
raise utils.AddError('Number of target sites allowed for '+str(type(self))+ ' relation must not exceed ' + str(self.n_max_targets)+'.')
if self.n_max_relations_for_a_target is not None:
for arg in args:
n_existing = arg._get_number_of_relations(source_or_target = 'target',relation_type=type(self))
if n_existing + 1 > self.n_max_relations_for_a_target:
raise utils.AddError('Target site already has the allowed maximum of '+str(type(self))+ ' relations: ' + str(self.n_max_relations_for_a_target)+'.')
return True


class UndirectedSiteRelation(SiteRelation):
n_max_targets = 0
def add_sites(self,*args):
self.add_sources(*args)
def add_sources(self,*args):
if self.verify_add_sources(*args):
self.sources.extend(args)
return self

class Bond(UndirectedSiteRelation):
n_max_sources = 2

class Overlap(UndirectedSiteRelation):
n_max_sources = 2
def add_targets(self,*args):
if self.verify_add_targets(*args):
self.targets.extend(args)
return self

class DirectedSiteRelation(SiteRelation):pass
class Bond(SiteRelation):
n_max_sources = 0
n_max_targets = 2
n_max_relations_for_a_source = 0
n_max_relations_for_a_target = 1
6 changes: 0 additions & 6 deletions wc_rules/seq.py
Expand Up @@ -135,9 +135,3 @@ def add_feature(self,*args):
arg._verify_feature()
self.features.append(arg)
return self

class LeftOverlap(chem2.DirectedSiteRelation):
n_max_sources = 1

class RightOverlap(chem2.DirectedSiteRelation):
n_max_sources = 1
8 changes: 8 additions & 0 deletions wc_rules/utils.py
Expand Up @@ -26,6 +26,14 @@ def listify(value):
return [value]
return value

def filter_by_type(init_list,type_list):
final_list = []
for x in init_list:
for t in type_list:
if isinstance(x,t):
final_list.append(x)
return final_list

###### Error ######
class GenericError(Exception):

Expand Down

0 comments on commit 818e5b0

Please sign in to comment.