In [1]:
%matplotlib inline
import numpy as np
import glob
from os.path import join, basename
from astropy.io import fits
import matplotlib.pyplot as plt
import matplotlib.colors as colors
import matplotlib.ticker as ticker

## Crosstalk SQL Database

In [6]:
import sqlalchemy as sql

from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm.exc import NoResultFound
from sqlalchemy.orm import relationship, backref, sessionmaker

from contextlib import contextmanager

from lsst.eotest.sensor.MaskedCCD import MaskedCCD
from mixcoatl.crosstalk import make_stamp, crosstalk_fit

In [7]:
Base = declarative_base()
Session = sessionmaker()

@contextmanager
def db_session(database):
    """Provide a transactional session bound to the given database.

    Creates an engine for ``database``, creates any tables declared on ``Base``
    that do not exist yet, and yields a session. The session is committed when
    the ``with`` body finishes without error, rolled back if it raises, and
    closed in either case; the engine's connection pool is then disposed.

    Args:
        database (str): SQLAlchemy database URL, e.g. ``'sqlite:///test.db'``.

    Yields:
        Session: a session from the module-level ``Session`` factory, bound to
        ``database``.

    Raises:
        Exception: anything raised by the ``with`` body (or by the final commit)
            is re-raised unchanged after the rollback.
    """
    # Build the engine and session before entering the try block: if either
    # step fails there is nothing to roll back or close, and touching an
    # unbound ``session`` would hide the real error behind UnboundLocalError.
    engine = sql.create_engine(database, echo=False)
    Base.metadata.create_all(engine)
    Session.configure(bind=engine)
    session = Session()
    try:
        yield session
        session.commit()
    except Exception:
        session.rollback()
        raise
    finally:
        session.close()
        # A new engine is created on every call; release its pooled connections.
        engine.dispose()
        
class Error(Exception):
    """Root of the exception hierarchy used by the crosstalk database helpers."""


class MissingKeyword(Error):
    """Signals that a database lookup was attempted without a required keyword."""


class AlreadyExists(Error):
    """Signals that an object being added is already present in the database."""
        
class Result(Base):
    """One crosstalk measurement of an aggressor segment onto a victim segment.

    Both ``aggressor_id`` and ``victim_id`` are foreign keys to ``segment.id``,
    so each relationship below names its foreign key explicitly; otherwise
    SQLAlchemy could not tell which join path to use.
    """
    
    __tablename__ = 'result'
    
    ## Columns
    id = sql.Column(sql.Integer, primary_key=True)
    aggressor_id = sql.Column(sql.Integer, sql.ForeignKey('segment.id'))
    # Signal level of the aggressor when measured; the ingest cell in this
    # notebook stores a column mean of the aggressor image here.
    aggressor_signal = sql.Column(sql.Float)
    # Fitted crosstalk coefficient (value taken from the fit routine's output).
    coefficient = sql.Column(sql.Float)
    # Uncertainty reported alongside ``coefficient`` — presumably its 1-sigma
    # error; TODO confirm against the fit routine.
    error = sql.Column(sql.Float)
    # Name of the input file the measurement came from.
    filename = sql.Column(sql.String)
    # Free-form label of the fitting method, e.g. 'MODEL_LSQ'.
    method = sql.Column(sql.String)
    victim_id = sql.Column(sql.Integer, sql.ForeignKey('segment.id'))

    ## Relationships
    # Paired with Segment.results via back_populates.
    aggressor = relationship("Segment", back_populates="results", foreign_keys=[aggressor_id])
    # One-directional: Segment has no collection of results where it is the victim.
    victim = relationship("Segment", foreign_keys=[victim_id])
    
    def add_to_db(self, session):
        """Add this Result to ``session``.

        No duplicate check is performed; the row is written when the session
        flushes or commits.

        Args:
            session: SQLAlchemy session to add the object to.
        """
        
        session.add(self)
        
class Segment(Base):
    """One amplifier segment belonging to a Sensor.

    ``results`` only collects Result rows in which this segment is the
    aggressor; results where it is the victim are reachable only through
    ``Result.victim``.
    """
    
    __tablename__ = 'segment'
    
    ## Columns
    id = sql.Column(sql.Integer, primary_key=True)
    # Segment label, e.g. 'C17'.
    name = sql.Column(sql.String)
    # Amplifier index passed to the image reader; the ingest cell in this
    # notebook numbers amplifiers starting from 1.
    amplifier_number = sql.Column(sql.Integer)
    sensor_id = sql.Column(sql.Integer, sql.ForeignKey('sensor.id'))
    
    ## Relationships
    # Restricted to Result.aggressor_id because Result has two foreign keys to
    # this table; paired with Result.aggressor.
    results = relationship("Result", back_populates="aggressor", foreign_keys=[Result.aggressor_id])
    # Parent sensor; paired with Sensor.segments.
    sensor = relationship("Sensor", back_populates="segments")
    
    def add_to_db(self, session):
        """Add this Segment to ``session`` (no duplicate check).

        Args:
            session: SQLAlchemy session to add the object to.
        """
        
        session.add(self)

class Sensor(Base):
    """A CCD sensor, identified by ``name``, together with its amplifier segments."""
    
    __tablename__ = 'sensor'
    
    ## Columns
    id = sql.Column(sql.Integer, primary_key=True)
    name = sql.Column(sql.String)
    manufacturer = sql.Column(sql.String)
    # Number of amplifier segments on the sensor.
    namps = sql.Column(sql.Integer)
    
    ## Relationships
    segments = relationship("Segment", back_populates="sensor")

    @classmethod
    def from_db(cls, session, **kwargs):
        """Load exactly one Sensor, looked up by ``name`` or ``id``.

        ``name`` wins when both keywords are given.

        Raises:
            MissingKeyword: neither ``name`` nor ``id`` was supplied.
            NoResultFound / MultipleResultsFound: propagated from ``Query.one()``.
        """
        # Ordered by precedence: the first keyword present is used.
        lookup_order = (('name', cls.name), ('id', cls.id))
        for keyword, column in lookup_order:
            if keyword in kwargs:
                return session.query(cls).filter(column == kwargs[keyword]).one()
        raise MissingKeyword('Query requires name or id keyword.')

    def add_to_db(self, session):
        """Add this Sensor to ``session`` unless one with the same name exists.

        Raises:
            AlreadyExists: a Sensor with ``self.name`` is already in the database.
        """
        duplicate = session.query(Sensor).filter_by(name=self.name).first()
        if duplicate is not None:
            raise AlreadyExists('Sensor already exists in database.')
        session.add(self)

In [16]:
## Input for script
# Alternate run used previously:
#   /project/bootcamp/cslage/e2v_fits_files/satellite/20200902-injectdata
main_dir = '/project/bootcamp/cslage/e2v_fits_files/satellite/20200831-injectdata'
infiles = sorted(glob.glob(join(main_dir, '*dark_dark*.fits')))
sensor_id = 'Davis_test'
manufacturer = 'e2v'
namps = 16
database = 'test.db'

# Segment names in amplifier order: seg_names[i] is stored with amplifier_number i + 1.
seg_names = ['C17', 'C16', 'C15', 'C14', 'C13', 'C12', 'C11', 'C10',
             'C00', 'C01', 'C02', 'C03', 'C04', 'C05', 'C06', 'C07']

## Analysis parameters (pixel indices into the unbiased, trimmed amplifier array)
SIGNAL_COLUMN = 307           # column whose mean is recorded as the aggressor signal
SIGNAL_THRESHOLD = 1000       # amplifiers with a lower mean are not used as aggressors
STAMP_Y, STAMP_X = 150, 305   # stamp center (row, column) passed to make_stamp
STAMP_LY, STAMP_LX = 100, 50  # stamp extents passed to make_stamp as ly, lx
FIT_KWARGS = dict(noise=7.0, num_iter=1, nsig=7.0)  # options for crosstalk_fit
FIT_METHOD = 'MODEL_LSQ'      # label stored in Result.method
# Skip input files that already have results for this sensor, so re-running this
# cell does not insert duplicate rows. Set to False to force a full re-fit.
SKIP_PROCESSED = True

with db_session('sqlite:///{0}'.format(database)) as f:

    ## Look up the sensor, creating it (and one Segment per amplifier) on first use
    try:
        sensor = Sensor.from_db(f, name=sensor_id)
        namps = sensor.namps
    except NoResultFound:
        sensor = Sensor(name=sensor_id, manufacturer=manufacturer, namps=namps)
        sensor.segments = [Segment(name=seg_names[i], amplifier_number=i+1) for i in range(namps)]
        sensor.add_to_db(f)
        f.commit()  # persist now so segment ids exist for the Result rows below

    ## Input files already stored for this sensor. Results are only committed when
    ## the db_session block exits cleanly, so a stored filename was fully processed.
    ## NOTE(review): only basenames are stored, so same-named files from another
    ## directory would also be treated as processed.
    if SKIP_PROCESSED:
        segment_ids = [seg.id for seg in sensor.segments]
        processed = {
            name for (name,) in (
                f.query(Result.filename)
                .filter(Result.aggressor_id.in_(segment_ids))
                .distinct()
            )
        }
    else:
        processed = set()

    for infile in infiles:
        filename = basename(infile)
        if filename in processed:
            continue
        ccd = MaskedCCD(infile)

        ## Bias-subtract/trim each amplifier once per file, recording its signal
        ## level and cutting its stamp; each stamp serves as aggressor and victim.
        ## NOTE(review): assumes make_stamp/crosstalk_fit do not modify their inputs
        ## in place (the original already reused one aggressor stamp for every victim).
        signals = {}
        stamps = {}
        for seg in sensor.segments:
            amp = seg.amplifier_number
            imarr = ccd.unbiased_and_trimmed_image(amp).getImage().getArray()
            signals[amp] = np.mean(imarr[:, SIGNAL_COLUMN])
            stamps[amp] = make_stamp(imarr, STAMP_Y, STAMP_X, ly=STAMP_LY, lx=STAMP_LX)

        for agg in sensor.segments:
            signal = signals[agg.amplifier_number]
            if signal < SIGNAL_THRESHOLD:
                continue
            aggressor_stamp = stamps[agg.amplifier_number]

            ## Calculate crosstalk into every victim amp (including the aggressor itself)
            for vic in sensor.segments:
                res = crosstalk_fit(aggressor_stamp, stamps[vic.amplifier_number], **FIT_KWARGS)

                # NOTE(review): res[0] is stored as the coefficient and res[4] as its
                # error — confirm against crosstalk_fit's return order.
                result = Result(aggressor_id=agg.id,
                                aggressor_signal=signal,
                                coefficient=res[0],
                                error=res[4],
                                method=FIT_METHOD,
                                victim_id=vic.id,
                                filename=filename)
                result.add_to_db(f)

In [18]:
with db_session('sqlite:///{0}'.format(database)) as f:

    # Report how many Result rows have amplifier 4 as their aggressor.
    sensor = Sensor.from_db(f, name=sensor_id)

    for segment in sensor.segments:
        if segment.amplifier_number != 4:
            continue
        print(len(segment.results))

2736
