In [1]:
%matplotlib inline
import numpy as np
import glob
from os.path import join
from astropy.io import fits
import matplotlib.pyplot as plt
import matplotlib.colors as colors
import matplotlib.ticker as ticker

## Crosstalk SQL Database

In [2]:
import sqlalchemy as sql

from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import relationship, backref, sessionmaker

from contextlib import contextmanager

In [19]:
# Declarative base shared by every ORM model defined below; db_session()
# calls ``Base.metadata.create_all`` on it to create any missing tables.
Base = declarative_base()
# Session factory; db_session() binds it to an engine (via ``configure``)
# before creating each session.
Session = sessionmaker()

@contextmanager
def db_session(database):
    """Provide a transactional session bound to the given database.

    Tables for every model registered on ``Base`` are created if missing.
    The session is committed when the ``with`` body exits normally, rolled
    back if the body (or the commit) raises, and always closed. The engine
    created for this call is disposed afterwards.

    Args:
        database (str): SQLAlchemy database URL, e.g. ``'sqlite:///test.db'``.

    Yields:
        A session created from the module-level ``Session`` factory.

    Raises:
        Exception: whatever the engine setup or the ``with`` body raised; it
            is re-raised unchanged after the rollback.
    """
    # Build the engine before entering any try block. If this (or the session
    # creation below) fails, there is no session to roll back or close, and
    # touching an unbound ``session`` would mask the real error with
    # UnboundLocalError.
    engine = sql.create_engine(database, echo=False)
    try:
        Base.metadata.create_all(engine)
        Session.configure(bind=engine)
        session = Session()
        try:
            yield session
            session.commit()
        except Exception:
            session.rollback()
            raise  # bare raise keeps the original traceback intact
        finally:
            session.close()
    finally:
        # Release the connection pool; a new engine is created on every call.
        engine.dispose()
        
class Error(Exception):
    """Base class for the exceptions raised by the crosstalk database helpers."""


class QueryError(Error):
    """Raised when a database lookup does not give the expected outcome.

    Used both when a query finds no matching rows (``Sensor.from_db``) and
    when an object being added already exists (the ``add_to_db`` methods).
    """
        
class Result(Base):
    """A crosstalk result linking an aggressor segment to a victim segment.

    Both ``aggressor_id`` and ``victim_id`` reference ``segment.id``, so each
    relationship names its foreign key explicitly to disambiguate the join.
    """
    
    __tablename__ = 'result'
    
    ## Columns
    id = sql.Column(sql.Integer, primary_key=True)
    aggressor_id = sql.Column(sql.Integer, sql.ForeignKey('segment.id'))
    aggressor_signal = sql.Column(sql.Float)  # NOTE(review): units not shown here — confirm
    coefficient = sql.Column(sql.Float)
    error = sql.Column(sql.Float)  # presumably the uncertainty on `coefficient` — confirm
    method = sql.Column(sql.String)  # presumably a label for how the result was measured — confirm
    victim_id = sql.Column(sql.Integer, sql.ForeignKey('segment.id'))
    victim_name = sql.Column(sql.String)  # stored independently; not synced with `victim` by this class

    ## Relationships
    # Two-way link with Segment.results (back_populates): a Result appended to
    # a Segment's ``results`` gets that Segment as its ``aggressor``.
    aggressor = relationship("Segment", back_populates="results", foreign_keys=[aggressor_id])
    # One-way link: no back_populates is declared for the victim side.
    victim = relationship("Segment", foreign_keys=[victim_id])
        
class Segment(Base):
    """One amplifier segment of a sensor.

    A segment belongs to a ``Sensor`` through ``sensor_id`` and collects the
    ``Result`` rows in which it is the aggressor.
    """
    
    __tablename__ = 'segment'
    
    ## Columns
    id = sql.Column(sql.Integer, primary_key=True)
    name = sql.Column(sql.String)
    amplifier_number = sql.Column(sql.Integer)
    sensor_id = sql.Column(sql.Integer, sql.ForeignKey('sensor.id'))
    
    ## Relationships
    results = relationship("Result", back_populates="aggressor", foreign_keys=[Result.aggressor_id])
    sensor = relationship("Sensor", back_populates="segments")
    
    def add_to_db(self, session):
        """Add this Segment to the session unless its sensor already has it.

        Segments are identified by amplifier number *within their sensor*:
        every sensor numbers its amplifiers from 1, so the same amplifier
        number on a different sensor is not a duplicate.

        Args:
            session: SQLAlchemy session to add the segment to.

        Raises:
            QueryError: if a segment with the same amplifier number already
                exists for the same sensor.
        """
        # Prefer the related Sensor's id: when the segment was attached through
        # ``segment.sensor = ...`` the ``sensor_id`` column is only populated
        # on the next flush.
        sensor_id = self.sensor.id if self.sensor is not None else self.sensor_id

        existing = (
            session.query(Segment)
            .filter_by(sensor_id=sensor_id, amplifier_number=self.amplifier_number)
            .first()
        )
        if existing is not None:
            raise QueryError('Segment already exists in database.')
        session.add(self)

class Sensor(Base):
    """A sensor and its amplifier segments."""
    
    __tablename__ = 'sensor'
    
    ## Columns
    id = sql.Column(sql.Integer, primary_key=True)
    name = sql.Column(sql.String)
    manufacturer = sql.Column(sql.String)
    namps = sql.Column(sql.Integer)
    
    ## Relationships
    segments = relationship("Segment", back_populates="sensor")
    
    @classmethod
    def from_db(cls, session, names=(), manufacturer=None):
        """Query sensors by name and/or manufacturer.

        Args:
            session: SQLAlchemy session to query with.
            names: A sensor name, or a list/tuple/set of names. Any other
                value (e.g. a single string) is treated as one name. An empty
                collection (the default) applies no name filter.
            manufacturer (str, optional): Only ``'e2v'`` and ``'itl'`` filter
                the query; any other value, including ``None``, is ignored.

        Returns:
            The matching Sensor if exactly one matches, otherwise a list of
            Sensors.

        Raises:
            QueryError: if no sensor matches.
        """
        # Immutable default avoids the shared-mutable-default pitfall; every
        # standard collection is expanded, anything else is a single name.
        if isinstance(names, (list, tuple, set, frozenset)):
            names = list(names)
        else:
            names = [names]

        query = session.query(cls)

        ## Filter on names
        if len(names) == 1:
            query = query.filter(cls.name == names[0])
        elif names:
            query = query.filter(cls.name.in_(names))

        ## Filter on manufacturer
        if manufacturer in ('e2v', 'itl'):
            query = query.filter(cls.manufacturer == manufacturer)

        ## Return single entry or list of entries
        sensors = query.all()
        if not sensors:
            raise QueryError('No sensors found.')
        return sensors[0] if len(sensors) == 1 else sensors

    def add_to_db(self, session):
        """Add this Sensor (and, via the relationship, its segments) to the session.

        Args:
            session: SQLAlchemy session to add the sensor to.

        Raises:
            QueryError: if a sensor with the same name already exists.
        """
        sensor = session.query(Sensor).filter_by(name=self.name).first()
        if sensor is not None:
            raise QueryError('Sensor already exists in database.')
        session.add(self)

In [20]:
## Add a sensor (or load it if it is already in the database)
DATABASE = 'sqlite:///test.db'

sensor_id = 'Davis_test'  # the sensor *name*; used as the lookup key
manufacturer = 'e2v'
namps = 16

# Segment names in amplifier order: names[i] belongs to amplifier i + 1.
names = ['C17', 'C16', 'C15', 'C14', 'C13', 'C12', 'C11', 'C10',
         'C00', 'C01', 'C02', 'C03', 'C04', 'C05', 'C06', 'C07']
assert len(names) == namps, 'need exactly one segment name per amplifier'

with db_session(DATABASE) as session:
    # Look the sensor up first, so a new Sensor/Segment set is only built
    # when it is actually going to be inserted.
    try:
        sensor = Sensor.from_db(session, names=sensor_id)
    except QueryError:
        sensor = Sensor(name=sensor_id, manufacturer=manufacturer, namps=namps)
        sensor.segments = [Segment(name=name, amplifier_number=amp)
                           for amp, name in enumerate(names, start=1)]
        sensor.add_to_db(session)

    print(sensor)
    print(sensor.segments)

<__main__.Sensor object at 0x7f5eb3c345f8>
[<__main__.Segment object at 0x7f5eb3ba3b38>, <__main__.Segment object at 0x7f5eb3ba3a90>, <__main__.Segment object at 0x7f5eb3ba3a58>, <__main__.Segment object at 0x7f5eb3ba3c88>, <__main__.Segment object at 0x7f5eb3ba3da0>, <__main__.Segment object at 0x7f5eb3ba3e80>, <__main__.Segment object at 0x7f5eb3ba3f28>, <__main__.Segment object at 0x7f5eb3ba3f98>, <__main__.Segment object at 0x7f5eb3ba3cf8>, <__main__.Segment object at 0x7f5eb3c3ad68>, <__main__.Segment object at 0x7f5eb3c3af28>, <__main__.Segment object at 0x7f5eb3c3a320>, <__main__.Segment object at 0x7f5eb3c3ac88>, <__main__.Segment object at 0x7f5eb3c3ac18>, <__main__.Segment object at 0x7f5eb3c3ab38>, <__main__.Segment object at 0x7f5eb3c3aac8>]
