In [None]:
# %title Rebinning with deadtime correction and caching: fetch_timepix_frame
# Pixelman FITS header from LANL:asterix camera
# SIMPLE  =                    T / file does conform to FITS standard             
# BITPIX  =                   16 / number of bits per data pixel                  
# NAXIS   =                    2 / number of data axes                           
# NAXIS1  =                  512 / length of data axis 1                          
# NAXIS2  =                  512 / length of data axis 2                          
# EXTEND  =                    T / FITS dataset may contain extensions            
# COMMENT   FITS (Flexible Image Transport System) format is defined in 'Astronomy
# COMMENT   and Astrophysics', volume 376, page 359; bibcode: 2001A&A...376..359H 
# TOF     =   0.0399830399999987 / Ttime of flight from the external trigger      
# TIMEBIN =           1.536E-005 / Time width of this image                       
# N_COUNTS=               756805 / Total counts in this image                     
# N_TRIGS =                10000 / Number of triggers acquired                    
# TEST1   = 'One_Continuous_Word_here' / Comment1                                 
# TEST2   = '14.567  '           / Comment                                        
#
# Note: ImageMagick `convert` to TIFF doesn't preserve the special tags:
#    TOF, TIMEBIN, N_COUNTS, N_TRIGS

def rebin_timepix_dataset(path, bins=None, outpath=None):
    """Rebin a timepix dataset along the frame axis, with deadtime correction.

    Frames are read in index order from ``path`` with ``TimepixReader``. Each
    frame after the first is divided by ``1 - p``, where ``p`` is the running
    sum of ``counts / N_TRIGS`` over all earlier frames. Corrected frames are
    summed into bins and each finished bin is written through
    ``TimepixWriter``; the sum of all written bins is stored as the summed
    image.

    Parameters
    ----------
    path
        Source dataset, in any form accepted by ``TimepixReader``.
    bins
        Indexable of ``(start, end)`` frame-index pairs, ``end`` exclusive.
        ``None`` writes one output bin per input frame. Reading stops once
        the last bin has been written.
    outpath
        Destination accepted by ``TimepixWriter`` (a ``.zip`` path).

    Raises
    ------
    RuntimeError
        If the dataset ends before a single bin has been completed.

    Notes
    -----
    A trailing bin that is still incomplete when the frames run out is
    dropped without being written.
    """
    print(f"Rebinning {str(path)}...")
    dtype = torch.float32
    # Note: if there is an exception during the following loop
    # (e.g., from an interrupt), then the TimepixWriter context
    # manager exit will delete the outpath that was being built.
    # This is not a complete abort. If the caller created the
    # leading path elements (e.g., SiGrating/Open) then the interrupt
    # during the first file will still leave the created path on disk.
    with (
        TimepixReader(path, dtype=dtype) as tpx,
        TimepixWriter(outpath) as out
    ):
        # Move to the first bin
        index = 0
        start, end = (index, index+1) if bins is None else bins[index]
        summed_image = summed_tags = None

        # Cycle through the frames
        for k, header, data in tpx:
            # Accumulate probability and correct the frame
            gdata = util.to_cuda(torch.as_tensor(data))
            Δp = gdata/header['N_TRIGS']
            if k == 0:
                # Nothing has been accumulated yet, so no correction applies.
                corrected = gdata
                p = Δp
            else:
                corrected = gdata / (1 - p)
                p += Δp
                del gdata, Δp

            # Accumulate bin
            if k == start:
                gbinned = corrected
                tof = header['TOF']
                timebin = header['TIMEBIN']
            elif start < k < end:
                gbinned += corrected
                timebin += header['TIMEBIN']
                del corrected

            if k+1 == end:
                # Save bin
                tags = {'TOF': tof, 'TIMEBIN': timebin, 'N_TRIGS': header['N_TRIGS']}
                binned = gbinned.cpu()
                out.write(index, binned, tags)

                # Accumulate SummedImg. The first bin's tensor and tag dict
                # are reused as the accumulators; that is safe because bin 0
                # has already been written out.
                if index == 0:
                    summed_image = binned
                    summed_tags = tags
                else:
                    summed_image += binned
                    summed_tags['TIMEBIN'] += tags['TIMEBIN']

                # Move to next bin
                index += 1
                if bins is None:
                    start, end = index, index+1
                elif index < len(bins):
                    start, end = bins[index]
                else:
                    # No more bins. No need to read the remaining frames.
                    break

        # Raise inside the `with` so the writer removes the partial archive.
        if summed_image is None:
            raise RuntimeError(f"No complete bin in dataset {path}; nothing to write")

        # TODO: Write {ShutterCount,ShutterTimes,Spectra,Status}.txt ?
        # Save summed image
        out.write('sum', summed_image, summed_tags)

class TimepixWriter:
    """Write timepix frames as FITS members of a deflate-compressed zip file.

    Used as a context manager: the archive is closed on exit, and if the
    ``with`` body raised, the partially written archive is removed.

    Parameters
    ----------
    path
        Output path; ``~`` is expanded. Must end in ``.zip``.

    Raises
    ------
    NotImplementedError
        If ``path`` does not have a ``.zip`` suffix.
    """
    def __init__(self, path):
        from zipfile import ZipFile, ZIP_DEFLATED
        path = Path(path).expanduser()
        # TODO: Preserve timestamps on original data?
        if path.suffix != '.zip':
            raise NotImplementedError(f"Only write timepix to .zip, not {str(path)}")
        zf = ZipFile(path, 'w', compression=ZIP_DEFLATED, compresslevel=9)

        self._path = path
        self._zf = zf
        # Opens a writable member of the archive by name.
        self._handle = lambda f: zf.open(f, mode='w')
        # Member files are named "<archive stem>_<label>.fits".
        self._basename = path.stem

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, exc_tb):
        self.close()
        # TODO: maybe the unlinking belongs in the rebin code rather than the context manager?
        # If we died with any sort of exception, then remove the 'in progress' zip file.
        if exc_type is not None:
            # missing_ok: a file that is already gone must not replace the
            # exception that is currently propagating.
            self._path.unlink(missing_ok=True)
        return False

    def close(self):
        """Close the archive; safe to call more than once."""
        if self._zf is not None:
            self._zf.close()
            self._zf = None

    def write(self, index, data, tags):
        """Write one frame to the archive as a big-endian float32 FITS image.

        Parameters
        ----------
        index
            Integer frame index (formatted as five digits), or ``'sum'`` for
            the summed image, which is stored with the label ``SummedImg``.
        data
            Tensor supporting ``.sum()`` and ``.numpy()``.
        tags
            Mapping providing ``'TOF'``, ``'TIMEBIN'`` and ``'N_TRIGS'``.
            ``N_COUNTS`` is computed here from ``data``.
        """
        from astropy.io import fits
        n_counts = float(data.sum())
        data = data.numpy()
        # TODO: Check data format before writing to'>f4'?
        data = np.asarray(data, dtype='>f4')
        hdu = fits.PrimaryHDU(data)
        #for k, v in tags.items(): hdu.header[k] = v   # tags don't have labels
        hdu.header['TOF'] = (tags['TOF'], "Time of flight from the external trigger")
        hdu.header['TIMEBIN'] = (tags['TIMEBIN'], "Time width of this image")
        hdu.header['N_COUNTS'] = (n_counts, "Total counts in this image")
        hdu.header['N_TRIGS'] = (tags['N_TRIGS'], "Number of triggers acquired")
        hdul = fits.HDUList([hdu])

        label = 'SummedImg' if index == 'sum' else f'{index:05d}'
        filename = f"{self._basename}_{label}.fits"
        with self._handle(filename) as fd:
            hdul.writeto(fd)

class TimepixReader:
    """Read the FITS frames of a timepix dataset from a zip file or directory.

    ``path`` may be a ``.zip`` archive, a directory of ``*.fits`` files, one
    ``<base>_<label>.fits`` file (its siblings sharing ``<base>_`` are used),
    or a filename prefix. Frame files must be numbered 0..N-1 with no gaps.
    If the last file in sorted order contains ``SummedImg`` in its name it is
    kept separately as the summed image.

    Parameters
    ----------
    path
        Dataset location; ``~`` is expanded.
    dtype
        Optional dtype for returned arrays, converted through
        ``util.numpy_dtype`` and ``util.numpy_native``. When ``None`` each
        frame keeps its stored dtype passed through ``util.numpy_native``
        (presumably a byte-order normalisation; confirm against ``util``).

    Raises
    ------
    RuntimeError
        If no files match, or if the frame numbering has a gap.
    """
    def __init__(self, path, dtype=None):
        from zipfile import ZipFile
        from fnmatch import fnmatch

        path = Path(path).expanduser()
        zf = None
        try:
            if path.suffix == '.zip' and not path.is_dir():
                zf = ZipFile(path)
                files = [f for f in sorted(zf.namelist()) if fnmatch(f, '*.fits')]
                handle = zf.open
            else:
                if path.is_dir():
                    pattern = '*.fits'
                    parent = path
                elif path.suffix == '.fits':
                    pattern = path.name.rsplit('_', 1)[0] + '_*.fits'
                    parent = path.parent
                else:
                    pattern = path.name + '*.fits'
                    parent = path.parent
                files = sorted(parent.glob(pattern))
                # Return a binary file object, mirroring ZipFile.open, so that
                # read() can hand either kind of handle to fits.open.
                handle = lambda f: open(f, 'rb')
            if not files:
                raise RuntimeError(f"No dataset matches {path}")

            # Relies on 'SummedImg' sorting after the numeric frame labels.
            if 'SummedImg' in self._name(files[-1]):
                summed = files[-1]
                files = files[:-1]
            else:
                warnings.warn(f"No summed image in {path}")
                summed = None

            # Make sure files are in index order with no frames missing
            for k, f in enumerate(files):
                if k != _timepix_index(self._name(f)):
                    raise RuntimeError(f"Missing time slice {k} in dataset {path}")

            native_dtype = None if dtype is None else util.numpy_native(util.numpy_dtype(dtype))
        except BaseException:
            # Do not leak the open archive when construction fails.
            if zf is not None:
                zf.close()
            raise

        self._zf = zf
        self._handle = handle
        self._files = files
        self._summed = summed
        self._dtype = native_dtype

    @staticmethod
    def _name(f):
        """Return the name of a dataset entry (zip member string or Path)."""
        return f if isinstance(f, str) else f.name

    @property
    def num_slices(self):
        """Number of frame files, excluding the summed image."""
        return len(self._files)

    def __enter__(self):
        return self

    def __exit__(self, *exc):
        self.close()
        return False

    def close(self):
        """Close the underlying zip archive, if any; safe to call twice."""
        if self._zf is not None:
            self._zf.close()
            self._zf = None

    def __iter__(self):
        """Yield ``(k, header, data)`` for every frame in index order."""
        for k in range(self.num_slices):
            header, data = self.read(k)
            yield k, header, data

    def read(self, index):
        """Return ``(header, data)`` for one frame or for the summed image.

        Parameters
        ----------
        index
            Integer frame index, or ``'sum'`` / ``'SummedImg'`` for the
            summed image.

        Returns
        -------
        tuple
            The primary HDU header and its data as a numpy array. The array
            is the transpose of the stored FITS array.

        Raises
        ------
        KeyError
            If the summed image is requested but the dataset has none.
        """
        from astropy.io import fits
        if index in ('SummedImg', 'sum'):
            if self._summed is None:
                raise KeyError("No summed image in dataset")
            filename = self._summed
        else:
            filename = self._files[index]
        with self._handle(filename) as fd, fits.open(fd) as h:
            header, data = h[0].header, h[0].data.T
            dtype = util.numpy_native(data.dtype) if self._dtype is None else self._dtype
            # Convert while the file is still open.
            return header, np.asarray(data, dtype=dtype)

def _timepix_index(name):
    """Convert filename to ToF index"""
    tail = name.rsplit('_', 1)[1]
    strval = tail.split('.', 1)[0]
    try:
        return int(strval)
    except ValueError:
        return None

def fetch_timepix_frame(path, t, dark_rate=None, missing_ok=False):
    """Return frame ``t`` of a rebinned dataset, rebinning and caching on demand.

    ``path`` is taken relative to ``DATAPATH``; an absolute path is accepted
    only if it lies under ``DATAPATH``. The rebinned copy lives at the same
    relative location under ``BINNEDPATH`` and is built with
    ``rebin_timepix_dataset`` (using ``BINS.index``) the first time it is
    needed.

    Parameters
    ----------
    path
        Dataset path, relative to ``DATAPATH`` or absolute beneath it.
    t
        Frame selector handed to ``TimepixReader.read``.
    dark_rate
        Background rate; defaults to the module global ``DARK_RATE`` when it
        is defined, otherwise ``0.``.
    missing_ok
        When true, a dataset with neither raw nor rebinned data yields a
        512x512 zero frame instead of raising.

    Returns
    -------
    torch.Tensor
        The frame, carrying a ``header`` attribute with ``TOF``, ``TIMEBIN``,
        ``N_TRIGS``, ``BACKGROUND`` (``dark_rate * TIMEBIN * N_TRIGS``),
        ``path`` and ``t``.

    Raises
    ------
    RuntimeError
        If ``path`` is absolute but not under ``DATAPATH``.
    ValueError
        If no data exists and ``missing_ok`` is false.
    """
    path = Path(path)
    if path.is_absolute():
        try:
            path = path.relative_to(DATAPATH)
        except ValueError:
            raise RuntimeError(f"Can't cache full path {path}")

    raw_source = DATAPATH / path
    cache_file = BINNEDPATH / path

    if not cache_file.exists():
        if not raw_source.exists():
            # Neither raw nor rebinned data is available.
            if not missing_ok:
                raise ValueError(f"No data in {str(raw_source)}")
            blank = torch.zeros((512,512), dtype=torch.float32)
            blank.header = dotted(TOF=0., TIMEBIN=1e-5, N_TRIGS=0, BACKGROUND=0, path=path, t=t)
            return blank

        # Build the cached rebinned archive from the raw data.
        cache_file.parent.mkdir(parents=True, exist_ok=True)
        rebin_timepix_dataset(raw_source, bins=BINS.index, outpath=cache_file)

    if dark_rate is None:
        dark_rate = globals().get('DARK_RATE', 0.)

    with TimepixReader(cache_file, torch.float32) as reader:
        header, frame = reader.read(t)
        tof = header['TOF']
        timebin = header['TIMEBIN']
        n_trigs = header['N_TRIGS']
        frame = torch.as_tensor(frame.T)
        # Expected background is the dark rate times the total exposure.
        background = dark_rate*(timebin*n_trigs)
        frame.header = dotted(TOF=tof, TIMEBIN=timebin, N_TRIGS=n_trigs, BACKGROUND=background, path=path, t=t)
        return frame