From ed4bf44ed58d802465e18a6c0122b270897f1fed Mon Sep 17 00:00:00 2001 From: mfrasca Date: Wed, 27 Dec 2017 10:15:53 -0500 Subject: [PATCH] trying to refactor notes classes (they're all similar). --- bauble/db.py | 61 ++++++++++++ bauble/plugins/garden/accession.py | 127 +------------------------ bauble/plugins/garden/plant.py | 98 +++++++------------ bauble/plugins/plants/family.py | 21 ++-- bauble/plugins/plants/genus.py | 22 ++--- bauble/plugins/plants/species_model.py | 82 ++++++---------- 6 files changed, 143 insertions(+), 268 deletions(-) diff --git a/bauble/db.py b/bauble/db.py index b86404d4d..43aa8f609 100644 --- a/bauble/db.py +++ b/bauble/db.py @@ -481,6 +481,67 @@ def verify_connection(engine, show_error_dialogs=False): return True +def make_note_class(name, compute_serializable_fields, as_dict=None, retrieve=None): + class_name = name + 'Note' + table_name = name.lower() + '_note' + + def is_defined(self): + return bool(self.user and self.category and self.note) + + def retrieve_or_create(cls, session, keys, + create=True, update=True): + """return database object corresponding to keys + """ + result = super(globals()[class_name], cls).retrieve_or_create(session, keys, create, update) + category = keys.get('category', '') + if (create and (category.startswith('[') and category.endswith(']') or + category.startswith('<') and category.endswith('>'))): + result = cls(**keys) + session.add(result) + return result + + def retrieve_default(cls, session, keys): + q = session.query(cls) + if name.lower() in keys: + q = q.join(globals()[name]).filter( + globals()[name].code == keys[name.lower()]) + if 'date' in keys: + q = q.filter(cls.date == keys['date']) + if 'category' in keys: + q = q.filter(cls.category == keys['category']) + try: + return q.one() + except: + return None + + def as_dict_default(self): + result = db.Serializable.as_dict(self) + result[name.lower()] = getattr(self, name.lower()).code + return result + + as_dict = as_dict or as_dict_default + retrieve = retrieve or retrieve_default + + result = type(class_name, (Base, Serializable), + {'__tablename__': table_name, + '__mapper_args__': {'order_by': table_name + '.date'}, + + 'date': sa.Column(types.Date, default=sa.func.now()), + 'user': sa.Column(sa.Unicode(64)), + 'category': sa.Column(sa.Unicode(32)), + 'note': sa.Column(sa.UnicodeText, nullable=False), + name.lower() + '_id': sa.Column(sa.Integer, sa.ForeignKey(name.lower() + '.id'), nullable=False), + name.lower(): sa.orm.relation(name, uselist=False, backref=sa.orm.backref( + 'notes', cascade='all, delete-orphan')), + 'retrieve': classmethod(retrieve), + 'retrieve_or_create': classmethod(retrieve_or_create), + 'compute_serializable_fields': classmethod(compute_serializable_fields), + 'is_defined': is_defined, + 'as_dict': as_dict, + }) + return result + + class WithNotes: key_pattern = re.compile(r'{[^:]+:(.*)}') diff --git a/bauble/plugins/garden/accession.py b/bauble/plugins/garden/accession.py index 7b80f9fc5..e65d07bc5 100755 --- a/bauble/plugins/garden/accession.py +++ b/bauble/plugins/garden/accession.py @@ -455,63 +455,6 @@ def after_update(self, mapper, conn, instance): } -def make_note_class(name, compute_serializable_fields, ): - class_name = name + 'Note' - table_name = name.lower() + '_note' - - def is_defined(self): - return bool(self.user and self.category and self.note) - - def retrieve_or_create(cls, session, keys, - create=True, update=True): - """return database object corresponding to keys - """ - result = super(globals()[class_name], cls).retrieve_or_create(session, keys, create, update) - category = keys.get('category', '') - if (create and (category.startswith('[') and category.endswith(']') or - category.startswith('<') and category.endswith('>'))): - result = cls(**keys) - session.add(result) - return result - - def retrieve(cls, session, keys): - q = session.query(cls) - if name.lower() in keys: - q = q.join(globals()[name]).filter( - globals()[name].code == keys[name.lower()]) - if 'date' in keys: - q = q.filter(cls.date == keys['date']) - if 'category' in keys: - q = q.filter(cls.category == keys['category']) - try: - return q.one() - except: - return None - - def as_dict(self): - result = db.Serializable.as_dict(self) - result[name.lower()] = getattr(self, name.lower()).code - return result - - result = type(class_name, (db.Base, db.Serializable), - {'__tablename__': table_name, - '__mapper_args__': {'order_by': table_name + '.date'}, - - 'date': Column(types.Date, default=func.now()), - 'user': Column(Unicode(64)), - 'category': Column(Unicode(32)), - 'note': Column(UnicodeText, nullable=False), - name.lower() + '_id': Column(Integer, ForeignKey(name.lower() + '.id'), nullable=False), - name.lower(): relation(name, uselist=False, backref=backref( - 'notes', cascade='all, delete-orphan')), - 'retrieve': classmethod(retrieve), - 'retrieve_or_create': classmethod(retrieve_or_create), - 'compute_serializable_fields': classmethod(compute_serializable_fields), - 'is_defined': is_defined, - 'as_dict': as_dict, - }) - return result - def compute_serializable_fields(cls, session, keys): result = {'accession': None} @@ -526,75 +469,7 @@ def compute_serializable_fields(cls, session, keys): return result -AccessionNote = make_note_class('Accession', compute_serializable_fields) - -if False: - class AccessionNote(db.Base, db.Serializable): - """ - Notes for the accession table - """ - __tablename__ = 'accession_note' - __mapper_args__ = {'order_by': 'accession_note.date'} - - date = Column(types.Date, default=func.now()) - user = Column(Unicode(64)) - category = Column(Unicode(32)) - note = Column(UnicodeText, nullable=False) - accession_id = Column(Integer, ForeignKey('accession.id'), nullable=False) - accession = relation( - 'Accession', uselist=False, - backref=backref('notes', cascade='all, delete-orphan')) - - def is_defined(self): - return bool(self.user and self.category and self.note) - - def as_dict(self): - result = db.Serializable.as_dict(self) - result['accession'] = self.accession.code - return result - - @classmethod - def retrieve_or_create(cls, session, keys, - create=True, update=True): - """return database object corresponding to keys - """ - result = super(AccessionNote, cls).retrieve_or_create(session, keys, create, update) - category = keys.get('category', '') - if (create and (category.startswith('[') and category.endswith(']') or - category.startswith('<') and category.endswith('>'))): - result = cls(**keys) - session.add(result) - return result - - @classmethod - def retrieve(cls, session, keys): - q = session.query(cls) - if 'accession' in keys: - q = q.join(Accession).filter( - Accession.code == keys['accession']) - if 'date' in keys: - q = q.filter(cls.date == keys['date']) - if 'category' in keys: - q = q.filter(cls.category == keys['category']) - try: - return q.one() - except: - return None - - @classmethod - def compute_serializable_fields(cls, session, keys): - result = {'accession': None} - - acc_keys = {} - acc_keys.update(keys) - acc_keys['code'] = keys['accession'] - accession = Accession.retrieve_or_create( - session, acc_keys, create=( - 'taxon' in acc_keys and 'rank' in acc_keys)) - - result['accession'] = accession - - return result +AccessionNote = db.make_note_class('Accession', compute_serializable_fields) class Accession(db.Base, db.Serializable, db.WithNotes): diff --git a/bauble/plugins/garden/plant.py b/bauble/plugins/garden/plant.py index 554db935f..072ccd9ac 100755 --- a/bauble/plugins/garden/plant.py +++ b/bauble/plugins/garden/plant.py @@ -30,7 +30,7 @@ import logging logger = logging.getLogger(__name__) -#logger.setLevel(logging.DEBUG) +logger.setLevel(logging.DEBUG) import gtk @@ -203,73 +203,46 @@ def search(self, text, session): return [] -# TODO: what would happen if the PlantRemove.plant_id and PlantNote.plant_id -# were out of synch.... how could we avoid these sort of cycles -class PlantNote(db.Base, db.Serializable): - __tablename__ = 'plant_note' - __mapper_args__ = {'order_by': 'plant_note.date'} - - date = Column(types.Date, default=func.now()) - user = Column(Unicode(64)) - category = Column(Unicode(32)) - note = Column(UnicodeText, nullable=False) - plant_id = Column(Integer, ForeignKey('plant.id'), nullable=False) - plant = relation('Plant', uselist=False, - backref=backref('notes', cascade='all, delete-orphan')) - - def as_dict(self): - result = db.Serializable.as_dict(self) - result['plant'] = (self.plant.accession.code + - Plant.get_delimiter() + self.plant.code) - return result - - @classmethod - def retrieve_or_create(cls, session, keys, - create=True, update=True): - """return database object corresponding to keys - """ - result = super(PlantNote, cls).retrieve_or_create(session, keys, create, update) - category = keys.get('category', '') - if (create and (category.startswith('[') and category.endswith(']') or - category.startswith('<') and category.endswith('>'))): - result = cls(**keys) - session.add(result) - return result - - @classmethod - def retrieve(cls, session, keys): - q = session.query(cls) - if 'plant' in keys: - acc_code, plant_code = keys['plant'].rsplit( - Plant.get_delimiter(), 1) - q = q.join( - Plant).filter(Plant.code == unicode(plant_code)).join( - Accession).filter(Accession.code == unicode(acc_code)) - if 'date' in keys: - q = q.filter(cls.date == keys['date']) - if 'category' in keys: - q = q.filter(cls.category == keys['category']) - try: - return q.one() - except: - return None - - @classmethod - def compute_serializable_fields(cls, session, keys): - 'plant is given as text, should be object' - result = {'plant': None} +def as_dict(self): + result = db.Serializable.as_dict(self) + result['plant'] = (self.plant.accession.code + + Plant.get_delimiter() + self.plant.code) + return result +def retrieve(cls, session, keys): + q = session.query(cls) + if 'plant' in keys: acc_code, plant_code = keys['plant'].rsplit( Plant.get_delimiter(), 1) - logger.debug("acc-plant: %s-%s" % (acc_code, plant_code)) - q = session.query(Plant).filter( - Plant.code == unicode(plant_code)).join( + q = q.join( + Plant).filter(Plant.code == unicode(plant_code)).join( Accession).filter(Accession.code == unicode(acc_code)) - plant = q.one() + if 'date' in keys: + q = q.filter(cls.date == keys['date']) + if 'category' in keys: + q = q.filter(cls.category == keys['category']) + try: + return q.one() + except: + return None - result['plant'] = plant +def compute_serializable_fields(cls, session, keys): + 'plant is given as text, should be object' + result = {'plant': None} - return result + acc_code, plant_code = keys['plant'].rsplit( + Plant.get_delimiter(), 1) + logger.debug("acc-plant: %s-%s" % (acc_code, plant_code)) + q = session.query(Plant).filter( + Plant.code == unicode(plant_code)).join( + Accession).filter(Accession.code == unicode(acc_code)) + plant = q.one() + + result['plant'] = plant + + return result + +PlantNote = db.make_note_class('Plant', compute_serializable_fields, as_dict, retrieve) # TODO: some of these reasons are specific to UBC and could probably be culled. @@ -838,6 +811,7 @@ def on_quantity_changed(self, entry, *args): abs(self._original_quantity-self.model.quantity) else: self.change.quantity = self.model.quantity + self.refresh_view() def on_plant_code_entry_changed(self, entry, *args): """ diff --git a/bauble/plugins/plants/family.py b/bauble/plugins/plants/family.py index a35e170e0..a51060766 100755 --- a/bauble/plugins/plants/family.py +++ b/bauble/plugins/plants/family.py @@ -283,19 +283,16 @@ def top_level_count(self): Familia = Family -class FamilyNote(db.Base): - """ - Notes for the family table - """ - __tablename__ = 'family_note' +def compute_serializable_fields(cls, session, keys): + result = {'family': None} - date = Column(types.Date, default=func.now()) - user = Column(Unicode(64)) - category = Column(Unicode(32)) - note = Column(UnicodeText, nullable=False) - family_id = Column(Integer, ForeignKey('family.id'), nullable=False) - family = relation('Family', uselist=False, - backref=backref('notes', cascade='all, delete-orphan')) + family_dict = {'epithet': keys['family']} + result['family'] = Family.retrieve_or_create( + session, family_keys, create=False) + + return result + +FamilyNote = db.make_note_class('Family', compute_serializable_fields) class FamilySynonym(db.Base): diff --git a/bauble/plugins/plants/genus.py b/bauble/plugins/plants/genus.py index 35bcca555..5a2afaf47 100755 --- a/bauble/plugins/plants/genus.py +++ b/bauble/plugins/plants/genus.py @@ -351,20 +351,16 @@ def top_level_count(self): if a.source and a.source.source_detail])} -class GenusNote(db.Base): - """ - Notes for the genus table - """ - __tablename__ = 'genus_note' - __mapper_args__ = {'order_by': 'genus_note.date'} +def compute_serializable_fields(cls, session, keys): + result = {'genus': None} - date = Column(types.Date, default=func.now()) - user = Column(Unicode(64)) - category = Column(Unicode(32)) - note = Column(UnicodeText, nullable=False) - genus_id = Column(Integer, ForeignKey('genus.id'), nullable=False) - genus = relation('Genus', uselist=False, - backref=backref('notes', cascade='all, delete-orphan')) + genus_dict = {'epithet': keys['genus']} + result['genus'] = Genus.retrieve_or_create( + session, genus_keys, create=False) + + return result + +GenusNote = db.make_note_class('Genus', compute_serializable_fields) class GenusSynonym(db.Base): diff --git a/bauble/plugins/plants/species_model.py b/bauble/plugins/plants/species_model.py index 7024a5d40..e49ff1690 100755 --- a/bauble/plugins/plants/species_model.py +++ b/bauble/plugins/plants/species_model.py @@ -634,61 +634,33 @@ def top_level_count(self): if a.source and a.source.source_detail])} -class SpeciesNote(db.Base, db.Serializable): - """ - Notes for the species table - """ - __tablename__ = 'species_note' - __mapper_args__ = {'order_by': 'species_note.date'} - - date = Column(types.Date, default=func.now()) - user = Column(Unicode(64)) - category = Column(Unicode(32)) - note = Column(UnicodeText, nullable=False) - species_id = Column(Integer, ForeignKey('species.id'), nullable=False) - species = relation('Species', uselist=False, - backref=backref('notes', cascade='all, delete-orphan')) - - def as_dict(self): - result = db.Serializable.as_dict(self) - result['species'] = self.species.str(self.species, remove_zws=True) - return result - - @classmethod - def compute_serializable_fields(cls, session, keys): - logger.debug('compute_serializable_fields(session, %s)' % keys) - result = {} - genus_name, epithet = keys['species'].split(' ', 1) - sp_dict = {'ht-epithet': genus_name, - 'epithet': epithet} - result['species'] = Species.retrieve_or_create( - session, sp_dict, create=False) - return result - - @classmethod - def retrieve_or_create(cls, session, keys, - create=True, update=True): - """return database object corresponding to keys - """ - result = super(SpeciesNote, cls).retrieve_or_create(session, keys, create, update) - category = keys.get('category', '') - if (create and (category.startswith('[') and category.endswith(']') or - category.startswith('<') and category.endswith('>'))): - result = cls(**keys) - session.add(result) - return result - - @classmethod - def retrieve(cls, session, keys): - from genus import Genus - genus, epithet = keys['species'].split(' ', 1) - try: - return session.query(cls).filter( - cls.category == keys['category']).join(Species).filter( - Species.sp == epithet).join(Genus).filter( - Genus.genus == genus).one() - except: - return None +def as_dict(self): + result = db.Serializable.as_dict(self) + result['species'] = self.species.str(self.species, remove_zws=True) + return result + +def compute_serializable_fields(cls, session, keys): + logger.debug('compute_serializable_fields(session, %s)' % keys) + result = {} + genus_name, epithet = keys['species'].split(' ', 1) + sp_dict = {'ht-epithet': genus_name, + 'epithet': epithet} + result['species'] = Species.retrieve_or_create( + session, sp_dict, create=False) + return result + +def retrieve(cls, session, keys): + from genus import Genus + genus, epithet = keys['species'].split(' ', 1) + try: + return session.query(cls).filter( + cls.category == keys['category']).join(Species).filter( + Species.sp == epithet).join(Genus).filter( + Genus.genus == genus).one() + except: + return None + +SpeciesNote = db.make_note_class('Species', compute_serializable_fields, as_dict, retrieve) class SpeciesSynonym(db.Base):