# In [None]:
def drizzle_to_wavelength(beams, wcs=None, ra=0., dec=0., wave=1.e4, size=5,
                          pixscale=0.1, pixfrac=0.6, kernel='square',
                          direct_extension='REF', fcontam=0.2, ds9=None):
    """Drizzle a cutout at a specific wavelength from a list of `~grizli.model.BeamCutout` objects

    Parameters
    ----------
    beams : list of `~.model.BeamCutout` objects.

    wcs : `~astropy.wcs.WCS` or None
        Pre-determined WCS.  If not specified, generate one based on ``ra``,
        ``dec``, ``size`` and ``pixscale``.

    ra, dec, wave : float
        Sky coordinates and central wavelength

    size : float
        Size of the output thumbnail, in arcsec

    pixscale : float
        Pixel scale of the output thumbnail, in arcsec

    pixfrac : float
        Drizzle PIXFRAC (drop size); not used by the 'point' kernel

    kernel : str, ('square' or 'point')
        Drizzle kernel to use

    direct_extension : str, ('SCI' or 'REF')
        Extension of ``beam.direct`` to drizzle for the direct thumbnail.
        If 'REF' is requested but any beam has no 'REF' image, 'SCI' is
        used for all beams.

    fcontam: float
        Contamination down-weighting factor.  When > 0 the grism weights
        are multiplied by ``exp(-fcontam * |contam| * sqrt(ivar))``.

    ds9 : `~grizli.ds9.DS9`, optional
        Display each step of the drizzling to an open DS9 window

    Returns
    -------
    result : bool or None
        ``False`` if ``beams`` is empty, otherwise ``None``.  The drizzled
        line / continuum / contamination / direct arrays are computed (and
        optionally displayed in ``ds9``), but the code that packed them into
        a `~astropy.io.fits.HDUList` is currently disabled in this file.
    """
    # Nothing to do
    if len(beams) == 0:
        return False

    # Imported after the early exit so an empty call does not need drizzlepac
    from drizzlepac import adrizzle
    adrizzle.log.setLevel('ERROR')
    drizzler = adrizzle.do_driz
    dfillval = 0

    # Get output header and WCS
    if wcs is None:
        header, output_wcs = utils.make_wcsheader(ra=ra, dec=dec, size=size,
                                                  pixscale=pixscale,
                                                  get_hdu=False)
    else:
        output_wcs = wcs.copy()
        if not hasattr(output_wcs, 'pscale'):
            output_wcs.pscale = utils.get_wcs_pscale(output_wcs)

        header = utils.to_header(output_wcs, relax=True)

    if not hasattr(output_wcs, '_naxis1'):
        output_wcs._naxis1, output_wcs._naxis2 = output_wcs._naxis

    # Output array shape (ny, nx) taken from the output header
    sh = (header['NAXIS2'], header['NAXIS1'])

    def _empty_outputs():
        """Return fresh (sci, wht, ctx) drizzle output arrays of shape ``sh``."""
        return (np.zeros(sh, dtype=np.float32),
                np.zeros(sh, dtype=np.float32),
                np.zeros(sh, dtype=np.int32))

    # Contamination-cleaned grism data, continuum model, contamination model
    outsci, outwht, outctx = _empty_outputs()
    coutsci, coutwht, coutctx = _empty_outputs()
    xoutsci, xoutwht, xoutctx = _empty_outputs()

    # Choose the direct-image extension once for all beams: fall back to
    # 'SCI' when a 'REF' image was requested but any beam lacks one.
    if direct_extension == 'REF':
        if any(beam.direct['REF'] is None for beam in beams):
            direct_extension = 'SCI'

    # Filter label of the direct image that is actually drizzled per beam
    if direct_extension == 'REF':
        all_direct_filters = [beam.direct.ref_filter for beam in beams]
    else:
        all_direct_filters = [beam.direct.filter for beam in beams]

    direct_filters = np.unique(all_direct_filters)

    # One set of direct-thumbnail outputs per distinct filter
    doutsci, doutwht, doutctx = {}, {}, {}
    for f in direct_filters:
        doutsci[f], doutwht[f], doutctx[f] = _empty_outputs()

    # Keyword arguments shared by every drizzle call
    driz_kwargs = dict(uniqid=1, pixfrac=pixfrac, kernel=kernel,
                       fillval=dfillval)

    # Loop through beams and run drizzle
    for beam, filt_i in zip(beams, all_direct_filters):
        # Get specific wavelength WCS for each beam
        _, beam_wcs = beam.get_wavelength_wcs(wave)

        # Compatibility shims between old/new astropy WCS shape attributes
        if not hasattr(beam_wcs, 'pixel_shape'):
            beam_wcs.pixel_shape = beam_wcs._naxis1, beam_wcs._naxis2

        if not hasattr(beam_wcs, '_naxis1'):
            beam_wcs._naxis1, beam_wcs._naxis2 = beam_wcs._naxis

        # Make sure CRPIX set correctly for the SIP header
        for j in [0, 1]:
            if beam.direct.wcs.sip is not None:
                beam.direct.wcs.sip.crpix[j] = beam.direct.wcs.wcs.crpix[j]

            if beam_wcs.sip is not None:
                beam_wcs.sip.crpix[j] = beam_wcs.wcs.crpix[j]

        # ACS requires additional wcs attributes: shift the lookup-table
        # distortion reference values by the CRPIX offset from the nominal
        # ACS chip center (no-op when these tables are absent)
        ACS_CRPIX = [4096/2, 2048/2]
        dx_crpix = beam_wcs.wcs.crpix[0] - ACS_CRPIX[0]
        dy_crpix = beam_wcs.wcs.crpix[1] - ACS_CRPIX[1]
        for wcs_ext in [beam_wcs.cpdis1, beam_wcs.cpdis2,
                        beam_wcs.det2im1, beam_wcs.det2im2]:
            if wcs_ext is not None:
                wcs_ext.crval[0] += dx_crpix
                wcs_ext.crval[1] += dy_crpix

        # Contamination-subtracted grism data (new array, safe to modify)
        beam_data = beam.grism.data['SCI'] - beam.contam
        if hasattr(beam, 'background'):
            beam_data -= beam.background

        if hasattr(beam, 'extra_lines'):
            beam_data -= beam.extra_lines

        beam_continuum = beam.beam.model*1
        if hasattr(beam.beam, 'pscale_array'):
            beam_continuum *= beam.beam.pscale_array

        # Downweight contaminated pixels
        if fcontam > 0:
            contam_weight = np.exp(-(fcontam*np.abs(beam.contam)*np.sqrt(beam.ivar)))
            wht = beam.ivar*contam_weight
        else:
            wht = beam.ivar*1

        # Never pass NaN/inf weights to drizzle, whichever branch was taken
        wht[~np.isfinite(wht)] = 0.

        # Convert to f_lambda integrated line fluxes:
        # (Inverse of the aXe sensitivity) x (size of pixel in \AA)
        sens = np.interp(wave, beam.beam.lam, beam.beam.sensitivity,
                         left=0, right=0)

        dlam = np.interp(wave, beam.beam.lam[1:], np.diff(beam.beam.lam))
        # 1e-17 erg/s/cm2 #, scaling closer to e-/s
        sens *= 1.e-17
        sens *= 1./dlam

        # Wavelength outside the sensitivity curve: nothing to contribute
        if sens == 0:
            continue

        wht *= sens**2
        beam_data /= sens
        beam_continuum /= sens

        # Drizzle contamination-cleaned data, continuum and contamination
        # with the same weights onto their respective outputs
        grism_pscale = beam.grism.wcs.pscale
        grism_products = [(beam_data, outsci, outwht, outctx),
                          (beam_continuum, coutsci, coutwht, coutctx),
                          (beam.contam, xoutsci, xoutwht, xoutctx)]

        for data, out_sci, out_wht, out_ctx in grism_products:
            drizzler(data, beam_wcs, wht, output_wcs,
                     out_sci, out_wht, out_ctx, 1., 'cps', 1,
                     wcslin_pscale=grism_pscale, **driz_kwargs)

        # Direct thumbnail
        if direct_extension == 'REF':
            thumb = beam.direct['REF']
            thumb_wht = (thumb != 0).astype(np.float32)
        else:
            thumb = beam.direct[direct_extension]  # /beam.direct.photflam
            thumb_wht = 1./(beam.direct.data['ERR']/beam.direct.photflam)**2
            thumb_wht[~np.isfinite(thumb_wht)] = 0

        if not hasattr(beam.direct.wcs, 'pixel_shape'):
            beam.direct.wcs.pixel_shape = (beam.direct.wcs._naxis1,
                                           beam.direct.wcs._naxis2)

        if not hasattr(beam.direct.wcs, '_naxis1'):
            beam.direct.wcs._naxis1, beam.direct.wcs._naxis2 = beam.direct.wcs._naxis

        drizzler(thumb, beam.direct.wcs, thumb_wht, output_wcs,
                 doutsci[filt_i], doutwht[filt_i], doutctx[filt_i],
                 1., 'cps', 1,
                 wcslin_pscale=beam.direct.wcs.pscale, **driz_kwargs)

        # Show in ds9
        if ds9 is not None:
            ds9.view((outsci-coutsci), header=header)

    # Scaling of drizzled outputs
    grism_scale = (beams[0].grism.wcs.pscale/output_wcs.pscale)**4
    outwht *= grism_scale
    coutwht *= grism_scale
    xoutwht *= grism_scale

    # Scale each distinct filter's weights exactly once (not once per beam)
    direct_scale = (beams[0].direct.wcs.pscale/output_wcs.pscale)**4
    for filt_i in direct_filters:
        doutwht[filt_i] *= direct_scale

    # NOTE(review): the FITS output assembly that used these arrays is
    # disabled in this file, so nothing is returned here (implicit None).
"""
    # Make output FITS products
    p = pyfits.PrimaryHDU()
    p.header['ID'] = (beams[0].id, 'Object ID')
    p.header['RA'] = (ra, 'Central R.A.')
    p.header['DEC'] = (dec, 'Central Decl.')
    p.header['PIXFRAC'] = (pixfrac, 'Drizzle PIXFRAC')
    p.header['DRIZKRNL'] = (kernel, 'Drizzle kernel')

    p.header['NINPUT'] = (len(beams), 'Number of drizzled beams')
    for i, beam in enumerate(beams):
        p.header['FILE{0:04d}'.format(i+1)] = (beam.grism.parent_file,
                                             'Parent filename')
        p.header['GRIS{0:04d}'.format(i+1)] = (beam.grism.filter,
                                             'Beam grism element')
        p.header['PA{0:04d}'.format(i+1)] = (beam.get_dispersion_PA(),
                                             'PA of dispersion axis')

    h = header.copy()
    h['ID'] = (beam.id, 'Object ID')
    h['PIXFRAC'] = (pixfrac, 'Drizzle PIXFRAC')
    h['DRIZKRNL'] = (kernel, 'Drizzle kernel')

    p.header['NDFILT'] = len(direct_filters), 'Number of direct image filters'
    for i, filt_i in enumerate(direct_filters):
        p.header['DFILT{0:02d}'.format(i+1)] = filt_i
        p.header['NFILT{0:02d}'.format(i+1)] = all_direct_filters.count(filt_i), 'Number of beams with this direct filter'

    HDUL = [p]
    for i, filt_i in enumerate(direct_filters):
        h['FILTER'] = (filt_i, 'Direct image filter')

        thumb_sci = pyfits.ImageHDU(data=doutsci[filt_i], header=h,
                                    name='DSCI')
        thumb_wht = pyfits.ImageHDU(data=doutwht[filt_i], header=h,
                                    name='DWHT')

        thumb_sci.header['EXTVER'] = filt_i
        thumb_wht.header['EXTVER'] = filt_i

        HDUL += [thumb_sci, thumb_wht]

    #thumb_seg = pyfits.ImageHDU(data=seg_slice, header=h, name='DSEG')

    h['FILTER'] = (beam.grism.filter, 'Grism filter')
    h['WAVELEN'] = (wave, 'Central wavelength')

    grism_sci = pyfits.ImageHDU(data=outsci-coutsci, header=h, name='LINE')
    grism_cont = pyfits.ImageHDU(data=coutsci, header=h, name='CONTINUUM')
    grism_contam = pyfits.ImageHDU(data=xoutsci, header=h, name='CONTAM')
    grism_wht = pyfits.ImageHDU(data=outwht, header=h, name='LINEWHT')

    #HDUL = [p, thumb_sci, thumb_wht, grism_sci, grism_cont, grism_contam, grism_wht]
    HDUL += [grism_sci, grism_cont, grism_contam, grism_wht]

    return pyfits.HDUList(HDUL)"""