In [0]:
from google.colab import drive
# Attach Google Drive under /content/drive so the notebook can read the
# helper package and the FITS data stored there.
drive.mount(mountpoint='/content/drive')

In [0]:
import sys
import glob
import datetime

In [0]:
# Make the Drive folder importable (presumably where the `eclair` package
# lives). Use the absolute mount path so the entry does not depend on the
# current working directory (a later cell changes it), and avoid appending a
# duplicate entry when this cell is re-run.
EXTRA_LIB_DIR = '/content/drive/My Drive/Colab Notebooks'
if EXTRA_LIB_DIR not in sys.path:
  sys.path.append(EXTRA_LIB_DIR)

In [0]:
from astropy.io import fits
from astropy.wcs import WCS
import cupy  as cp
import numpy as np

In [0]:
import matplotlib.pyplot as plt
from astropy.visualization import ZScaleInterval, ImageNormalize

In [0]:
from eclair import FitsContainer, ImAlign, reduction, fixpix, imcombine

In [0]:
# Calibration-frame file-name templates. `dark` is filled with the filter
# code and the integer exposure time; `flat`/`bpmask` with the filter code.
dark = 'dark_%s%d.fits'
flat = 'flat_%s.fits'
bpmask = 'bpmask_%s.fits'

# File-name template for the combined image of each filter.
outname = 'combine_%s.fits'

interp = 'spline3'   # interpolation mode handed to ImAlign
tolerance = 128      # handed to imalign(tolerance=...) — presumably the max accepted shift in px; confirm
maskrange = 2        # handed to fixpix(range=...)

imshape = 1020       # x_len / y_len given to ImAlign

# Keyword arguments forwarded verbatim to eclair.imcombine.
comb_param = {'combine': 'mean', 'width': 3.0, 'iter': 3, 'overwrite': True}

In [0]:
# Reference pixel positions (x, y): the four corners of a 200-px square,
# converted through each frame's WCS to measure frame-to-frame offsets.
xy0_array = np.array([[x, y] for x in (450, 650) for y in (450, 650)])
# Crop applied to every raw frame: rows ymin:ymax, columns xmin:xmax.
xmin, xmax = 52, 1072
ymin, ymax = 2, 1022

In [0]:
# Shared aligner used by Reduction.main.
# NOTE(review): x_len / y_len presumably must equal the cropped frame size
# (xmax - xmin, ymax - ymin — both 1020 with the current settings); confirm
# against eclair.ImAlign if the crop bounds are changed.
imalign = ImAlign(interp=interp, y_len=imshape, x_len=imshape)

In [0]:
cd /content/drive/My\ Drive/Colab\ Notebooks/fits

In [0]:
def group_by_filter(paths, bands=('G', 'R', 'I')):
  """Group raw-frame file names by the filter code at character index 4.

  Names look like 'MTAT<band>...fits'. Input order is preserved inside each
  group. Files whose code is not in `bands` are reported and skipped instead
  of raising KeyError and aborting the notebook.

  Returns a dict mapping every band in `bands` to a (possibly empty) list.
  """
  groups = {band: [] for band in bands}
  for path in paths:
    band = path[4]
    if band in groups:
      groups[band].append(path)
    else:
      print(f'skipping {path}: unrecognised filter code {band!r}')
  return groups

all_list = sorted(glob.glob('MTAT*.fits'))
filter_list = group_by_filter(all_list)

In [0]:
class Reduction(FitsContainer):
  """Reduce, and optionally align and combine, every frame of one filter.

  Uses notebook-level settings from earlier cells: `filter_list`, the crop
  bounds `xmin`/`xmax`/`ymin`/`ymax`, `xy0_array`, `imalign`, `tolerance`,
  `maskrange`, the name templates `dark`/`flat`/`bpmask`/`outname`, and
  `comb_param`.

  Parameters
  ----------
  filter : str
      Filter code. It is the key into `filter_list` and fills the calibration
      file-name templates.
  comb_flg : bool
      If True, align the reduced frames and combine them into
      `outname % filter`. If False, write each reduced frame with an 'r_'
      prefix instead.
  """

  def __init__(self, filter, comb_flg=True):
    super().__init__(filter_list[filter])
    self.filter = filter
    # (rows, columns) crop region; presumably consumed by FitsContainer on load — TODO confirm
    self.slice = (slice(ymin, ymax), slice(xmin, xmax))
    self.flg = comb_flg
    self.idx = 0  # index of the alignment reference frame within self.list

  def main(self):
    """Run the pipeline for this filter and write the output FITS file(s)."""
    self.reduction()
    if self.flg:
      self.getshift()
      basefits = self.list[self.idx]

      # imalign fills `indices` (selected=...) with the frames it kept;
      # with reject=True, presumably frames beyond `tolerance` are dropped.
      indices = []
      kwargs = dict(reject=True, baseidx=self.idx, tolerance=tolerance, selected=indices)
      self.data = imalign(self.data, self.shift, **kwargs)

      # Keep file names and shifts in step with the frames that survived.
      self.list = [self.list[i] for i in indices]
      self.shift = self.shift[indices]
      x_off, y_off = np.ceil(self.shift.max(axis=0))

      # The output header is based on the reference frame, with CRPIX moved
      # by the largest (rounded-up) shift.
      head = self.edithead(basefits, x_off=x_off, y_off=y_off)

      imcombine(self.list, self.data, outname % self.filter, header=head, **comb_param)
    else:
      outlist = ['r_' + f for f in self.list]
      for f in self.list:
        self.header[f] = self.edithead(f)
      super().write(outlist, overwrite=True)

  def getshift(self):
    """Measure per-frame pixel offsets through the WCS and pick a reference frame.

    Offsets are first measured relative to frame `self.idx` (0 initially). The
    frame whose offset is closest (Euclidean) to the median offset then
    becomes the reference, and the offsets are re-measured against it.
    """
    self.wcs = {f: WCS(f) for f in self.list}

    self.subgetshift()
    self.idx = np.linalg.norm(self.shift - np.median(self.shift, axis=0), axis=1).argmin()
    self.subgetshift()

  def subgetshift(self):
    """Fill `self.shift` with one float32 (x, y) offset row per frame.

    Each row is the mean offset, relative to frame `self.idx`, measured at
    the reference pixels `xy0_array`.
    """
    # Sky coordinates of the reference pixels in the reference frame (origin=1).
    ad0 = self.wcs[self.list[self.idx]].wcs_pix2world(xy0_array, 1)
    self.shift = np.empty([len(self.list), 2], dtype='f4')
    for i, f in enumerate(self.list):
      xy_array = self.wcs[f].wcs_world2pix(ad0, 1)
      self.shift[i, :] = (xy0_array - xy_array).mean(axis=0)

  def reduction(self):
    """Load the frames and apply bias, dark, flat and bad-pixel correction on the GPU."""
    super().load()

    # The dark frame is chosen from the FIRST frame's EXPTIME.
    # NOTE(review): assumes all frames of this filter share one exposure time.
    exptime = self.header[self.list[0]]['EXPTIME']
    # Per-frame bias level (PEDLEVEL), shaped (n, 1, 1) to broadcast over each image.
    cpbias = cp.array([self.header[f]['PEDLEVEL'] for f in self.list], dtype='f4').reshape(-1, 1, 1)

    npdark = fits.getdata(dark % (self.filter, exptime)).astype('f4')
    npflat = fits.getdata(flat % self.filter).astype('f4')
    npmask = fits.getdata(bpmask % self.filter).astype('f4')

    cpdark = cp.array(npdark)
    # Replace zero flat pixels by 1, presumably so the flat correction never divides by zero.
    cpflat = cp.array(npflat + (npflat == 0.0).astype('f4'))
    # Reduce the mask to -1/0/+1; nonzero presumably marks a bad pixel for fixpix.
    cpmask = cp.sign(cp.array(npmask))

    self.data = reduction(self.data, cpbias, cpdark, cpflat)
    self.data = fixpix(self.data, cpmask, range=maskrange)

  def edithead(self, f, x_off=0, y_off=0):
    """Update the stored header of frame `f` for output and return it.

    The header object in `self.header[f]` is modified in place, not copied.
    BZERO/BSCALE are reset to 0/1, and DATE is set to the current UTC time.
    CRPIX1/CRPIX2, when present, are decreased by `x_off`/`y_off`.
    """
    # The FITS standard requires DATE as 'YYYY-MM-DDThh:mm:ss' (hyphens, not
    # slashes); now(timezone.utc) replaces the deprecated utcnow().
    now_ut = datetime.datetime.now(datetime.timezone.utc).strftime('%Y-%m-%dT%H:%M:%S')

    head = self.header[f]
    head['BZERO'] = 0
    head['BSCALE'] = 1
    head['DATE'] = (now_ut, 'Date FITS file was generated')
    # Shift each reference pixel independently, so a header that has only
    # one of the two keywords still gets that one corrected.
    for key, off in (('CRPIX1', x_off), ('CRPIX2', y_off)):
      if key in head:
        head[key] -= off

    return head

In [0]:
# Reduce, align and combine each filter in turn. A filter with no raw frames
# would fail inside Reduction.reduction() (self.list[0]) and abort the
# remaining filters, so it is reported and skipped instead.
for band in ('G', 'R', 'I'):
  if not filter_list[band]:
    print(f'No MTAT*.fits frames found for filter {band}; skipping.')
    continue
  Reduction(band).main()

In [0]:
ls combine_*.fits

In [0]:
# Preview the combined G-band image with a z-scale stretch. The data array
# itself is left untouched; the stretch lives only in the display norm.
data = fits.getdata(outname % 'G')
norm = ImageNormalize(data, interval=ZScaleInterval())

fig, ax = plt.subplots()
# origin='lower' puts row 0 at the bottom (the usual FITS orientation),
# replacing the earlier trick of flipping the axis limits by hand.
im = ax.imshow(data, origin='lower', norm=norm)
ax.set(title=f"{outname % 'G'} (zscale)", xlabel='x [pixel]', ylabel='y [pixel]')
fig.colorbar(im, ax=ax)
plt.show()