In [None]:
import matplotlib.pyplot as plt

In [None]:
import time
import sys
import struct
import numpy as np

# Class adapted from: https://github.com/SumeetRohilla/readPTU_FLIM.
class PTUreader():
    """Read the header and the raw photon records of a PicoQuant PTU file.

    Only the MultiHarp T2/T3 record formats listed in ``rec_type`` are decoded;
    any other record type raises ValueError.

    @params str filename: path + filename of the .ptu file
    @params bool print_header_data: if True, print the parsed header after reading

    After construction the instance exposes:
        head    : dict of header tags (tag name -> value). String tags hold their
                  decoded text; 'Header_End' holds the byte offset of the first record.
        sync    : number of sync events that preceded each detection event,
                  corrected for counter overflows (uint64)
        tcspc   : tcspc bin of each event (uint16; in T2 mode this is channel & 15)
        channel : input channel of each event, i.e. detector number (uint8)
        special : 0 for photons, otherwise the channel field of the special record
                  (e.g. line/frame markers) (uint8)

    Raises:
        IOError: if the file does not start with the PQTTTR magic or the header
            ends before a 'Header_End' tag is found.
        ValueError: if the TTTR record type is not supported.
    """

    # Header tag type codes (PicoQuant PTU format)
    tag_type = dict(
        tyEmpty8      = 0xFFFF0008,
        tyBool8       = 0x00000008,
        tyInt8        = 0x10000008,
        tyBitSet64    = 0x11000008,
        tyColor8      = 0x12000008,
        tyFloat8      = 0x20000008,
        tyTDateTime   = 0x21000008,
        tyFloat8Array = 0x2001FFFF,
        tyAnsiString  = 0x4001FFFF,
        tyWideString  = 0x4002FFFF,
        tyBinaryBlob  = 0xFFFFFFFF,
    )

    # Record type codes of the supported TCSPC devices in T2/T3 mode
    rec_type = dict(
        rtMultiHarpNT3   = 0x00010307,  # (SubID = $00 ,RecFmt: $01) (V1), T-Mode: $03 (T3), HW: $07 (MultiHarp150N)
        rtMultiHarpNT2   = 0x00010207,  # (SubID = $00 ,RecFmt: $01) (V1), T-Mode: $02 (T2), HW: $07 (MultiHarp150N)
    )

    # One header tag entry: 32-byte name, int32 index, uint32 type code, int64 value.
    # Little-endian with standard sizes and no padding (48 bytes in total).
    _TAG_STRUCT = struct.Struct('<32s i I q')

    def __init__(self, filename, print_header_data = False):
        # Reverse mappings: numeric code -> symbolic name
        self.tag_type_r = {code: name for name, code in self.tag_type.items()}
        self.rec_type_r = {code: name for name, code in self.rec_type.items()}

        self.ptu_name = filename
        self.print_header = print_header_data

        # The whole file is read into memory; the attribute is dropped again
        # once the records have been decoded (see _ptu_read_raw_data).
        with open(self.ptu_name, 'rb') as f:
            self.ptu_data_string = f.read()

        # The first 16 bytes are the magic string and the format version (null padded)
        self.magic = self.ptu_data_string[:8].rstrip(b'\0')
        self.version = self.ptu_data_string[8:16].rstrip(b'\0')
        if self.magic != b'PQTTTR':
            raise IOError(f"This file is not a valid PTU file. Magic : {self.magic}")

        self.head = {}
        self._ptu_read_head(self.ptu_data_string)
        self._ptu_read_raw_data()
        if self.print_header:
            self._print_ptu_head()

    def _ptu_TDateTime_to_time(self, TDateTime):
        """Convert a TDateTime (days since 30/12/1899) to seconds since the Unix epoch."""
        EpochDiff = 25569  # days between 30/12/1899 and 01/01/1970
        SecsInDay = 86400  # number of seconds in a day
        return (TDateTime - EpochDiff) * SecsInDay

    def _ptu_read_tags(self, ptu_data_string, offset):
        """Parse the header tag that starts at byte ``offset``.

        Returns:
            (tag_name, tag_value, next_offset, tag_string). tag_string is the
            decoded text for tyAnsiString/tyWideString tags and None otherwise.
            For variable-length tags, tag_value is the payload length in bytes.

        Raises:
            IOError: if the buffer ends before a complete tag entry.
        """
        end = offset + self._TAG_STRUCT.size
        if end > len(ptu_data_string):
            raise IOError("Unexpected end of file while reading the PTU header "
                          "('Header_End' tag not found).")
        # tag_idx is not used: tags that differ only by index share one head key
        # (the last one read wins).
        raw_name, tag_idx, type_code, tag_value = self._TAG_STRUCT.unpack_from(ptu_data_string, offset)
        offset = end

        tag_name = raw_name.rstrip(b'\0').decode()
        tag_type = self.tag_type_r[type_code]
        tag_string = None

        # Fixed-size tags whose 8-byte value needs reinterpreting
        if tag_type == 'tyFloat8':
            tag_value = np.int64(tag_value).view('float64')
        elif tag_type == 'tyBool8':
            tag_value = bool(tag_value)
        elif tag_type == 'tyTDateTime':
            t_days = np.int64(tag_value).view('float64')
            t = time.gmtime(self._ptu_TDateTime_to_time(t_days))
            tag_value = time.strftime("%Y-%m-%d %H:%M:%S", t)
        # Variable-length tags: tag_value bytes of payload follow the entry
        elif tag_type == 'tyAnsiString':
            raw = ptu_data_string[offset:offset + tag_value]
            tag_string = raw.rstrip(b'\0').decode('utf-8', errors='ignore')
            offset += tag_value
        elif tag_type == 'tyWideString':
            # tag_value is the length in BYTES of the UTF-16LE text, not in characters
            raw = ptu_data_string[offset:offset + tag_value]
            tag_string = raw.decode('utf-16-le', errors='ignore').rstrip('\0')
            offset += tag_value
        elif tag_type in ('tyFloat8Array', 'tyBinaryBlob'):
            # Payload is skipped; only its byte length ends up in the header dict
            offset += tag_value

        return tag_name, tag_value, offset, tag_string

    def _ptu_read_head(self, ptu_data_string):
        """Fill ``self.head`` with every tag from byte 16 up to and including 'Header_End'.

        Afterwards ``self.head['Header_End']`` is the byte offset at which the
        TTTR records begin.
        """
        file_tag_end = 'Header_End'
        offset = 16  # skip the 8-byte magic and the 8-byte version fields
        tag_name = None
        while tag_name != file_tag_end:
            tag_name, tag_value, offset, tag_string = self._ptu_read_tags(ptu_data_string, offset)
            # String tags store their text; all other tags store their value
            self.head[tag_name] = tag_value if tag_string is None else tag_string
        self.head[file_tag_end] = offset

    def _print_ptu_head(self):
        ''' Print "head" dictionary '''
        print("{:<30} {:8}".format('Head ID','Value'))
        for keys in self.head:
            val = self.head[keys]
            print("{:<30} {:<8}".format(keys, val))

    def _ptu_read_raw_data(self):
        """Decode the TTTR records that follow the header.

        Sets ``self.sync``, ``self.tcspc``, ``self.channel`` and ``self.special``
        with one entry per non-overflow record. Removes ``self.ptu_data_string``.

        Raises:
            ValueError: if 'TTResultFormat_TTTRRecType' is not one of ``rec_type``.
        """
        rec_code = self.head['TTResultFormat_TTTRRecType']
        record_type = self.rec_type_r.get(rec_code)
        if record_type is None:
            raise ValueError(f"Unsupported TTTR record type: 0x{rec_code:08X}")
        num_records = self.head['TTResult_NumberOfRecords']

        # The 32-bit little-endian records start right after the header.
        # frombuffer returns a view, so the file bytes stay alive until `records`
        # goes out of scope at the end of this method.
        records = np.frombuffer(self.ptu_data_string, dtype='<u4', count=num_records,
                                offset=self.head['Header_End'])
        del self.ptu_data_string

        print('TCSPC Hardware: {}'.format(record_type[2:]))
        if record_type == 'rtMultiHarpNT3':
            WRAPAROUND = 1024                                             # sync counter period
            sync  = np.bitwise_and(records, 1023)                         # bits 0-9: nsync
            tcspc = np.bitwise_and(np.right_shift(records, 10), 32767)    # bits 10-24: dtime
            chan  = np.bitwise_and(np.right_shift(records, 25), 63)       # bits 25-30: channel
        else:  # 'rtMultiHarpNT2'
            WRAPAROUND = 33554432                                         # timetag counter period
            sync  = np.bitwise_and(records, 33554431)                     # bits 0-24: timetag
            chan  = np.bitwise_and(np.right_shift(records, 25), 63)       # bits 25-30: channel
            tcspc = np.bitwise_and(chan, 15)
        is_special = np.right_shift(records, 31) == 1                     # bit 31: special record
        is_overflow = is_special & (chan == 63)
        special = is_special * chan

        # An overflow record carries the number of wraparounds in its sync field;
        # a value of 0 denotes a single (old-style) overflow.
        n_wraps = np.where(is_overflow, np.where(sync == 0, 1, sync), 0).astype(np.uint64)
        sync = sync.astype(np.uint64) + np.uint64(WRAPAROUND) * np.cumsum(n_wraps)

        # Drop the overflow records themselves
        keep = ~is_overflow
        self.sync    = sync[keep]
        self.tcspc   = tcspc[keep].astype(np.uint16, copy=False)
        self.channel = chan[keep].astype(np.uint8, copy=False)
        self.special = special[keep].astype(np.uint8, copy=False)
        print("Raw Data has been Read!\n")

## Import file

In [None]:
import os
filename = r'path\filename.ptu'

# Probe the path by opening it once, then tell the user whether to continue.
try:
    with open(filename):
        pass
except IOError:
    print(f'Beware: Data file not found, please check the filename (current value: {filename})')
else:
    print('Data file found, you can proceed.')

In [None]:
ptu_file = PTUreader(filename, print_header_data=True)  # parse the file and print its header

### TIMESTAMPS

In [None]:
# Sync counts -> picoseconds (one sync period = MeasDesc_GlobalResolution seconds)
true_sync = ptu_file.sync * ptu_file.head["MeasDesc_GlobalResolution"] * 1e12

# TCSPC bins -> picoseconds, using the bin width stored in the header
# (MeasDesc_Resolution, in seconds) instead of a hard-coded 10 ps
tcspc_bin_ps = ptu_file.head["MeasDesc_Resolution"] * 1e12
true_tcspc = ptu_file.tcspc * tcspc_bin_ps

# 'True' timestamp of every event, in ps (float)
timestamp_temp = true_sync + true_tcspc

# ===============================

# Integer ps timestamps for the Cython code, rounded to the nearest picosecond
timestamp_packet = np.rint(timestamp_temp).astype(np.int64)

# Channel of every event as int8 (the element type compute_g2 takes)
channel_packet = ptu_file.channel.astype(np.int8)

timestamp_packet.dtype, channel_packet.dtype

In [None]:
# Detector channels used as "start" and "stop" for the correlation
start_channel, stop_channel = 0, 1

In [None]:
fig, ax = plt.subplots(figsize=(4, 3))

# Events of the stop channel (blue) then the start channel (red); ps -> µs on x
is_stop = channel_packet == stop_channel
is_start = channel_packet == start_channel
ax.plot(timestamp_packet[is_stop] * 1e-6, channel_packet[is_stop],
        color='blue', lw=0, marker='+')
ax.plot(timestamp_packet[is_start] * 1e-6, channel_packet[is_start],
        color='red', lw=0, marker='+')

ax.set_xlabel("Time [µs]")
ax.set_ylabel("Channel")
ax.set_xlim(left=0, right=None)
pass

## Calculate g2

In [None]:
# Half-width of the correlation window (either side of zero delay) and bin width, in s
g2_time_window_s = 10e-3
g2_time_bin_s = 10e-9

# Integer picosecond values for compute_g2. round() rather than int() so that a
# product landing just below an integer in floating point does not lose 1 ps.
time_window_ps = int(round(g2_time_window_s * 1e12))
time_bin_ps = int(round(g2_time_bin_s * 1e12))
number_of_bin = time_window_ps // time_bin_ps  # bins on each side of zero delay

In [None]:
%load_ext cython

In [None]:
%%cython

import numpy as np
import cython
@cython.cdivision(True)
@cython.boundscheck(False)
@cython.wraparound(False)

def compute_g2(long long[:] timestamp_packet not None,
               signed char [:] channel_packet not None,
               int start_channel,
               int stop_channel,
               long long time_window_ps,
               long long time_bin_ps):
    """Histogram the delays between start-channel and stop-channel events.

    For every event on ``start_channel``, each ``stop_channel`` event that
    follows it within ``time_window_ps`` is counted in the "after" histogram.
    Each one that precedes it within the window is counted in the "before"
    histogram. Both histograms use bins of ``time_bin_ps``.

    The scans stop at the first event outside the window, so timestamp_packet
    is assumed to be (nearly) sorted in time.

    Returns:
        int32 array of length 2*number_of_bin: the "before" histogram reversed
        (most negative delay first) followed by the "after" histogram.

    Raises:
        ValueError: if time_bin_ps <= 0 or the two packets differ in length.
    """
    cdef Py_ssize_t number_of_bin, length, i, j, index
    cdef long long delay, start

    # Bounds checks are disabled below, so validate the inputs up front
    if time_bin_ps <= 0:
        raise ValueError("time_bin_ps must be positive")
    if channel_packet.shape[0] != timestamp_packet.shape[0]:
        raise ValueError("timestamp_packet and channel_packet must have the same length")

    number_of_bin = time_window_ps // time_bin_ps
    histogram_before = np.zeros(number_of_bin, dtype=np.int32)
    histogram_after = np.zeros(number_of_bin, dtype=np.int32)

    cdef int[:] view_histogram_before = histogram_before
    cdef int[:] view_histogram_after = histogram_after

    length = timestamp_packet.shape[0]
    i = 0

    while i < length:
        if channel_packet[i] == start_channel:
            start = timestamp_packet[i]
            # Positive times: stop events recorded after the start event
            j = i+1
            while j < length:
                delay = timestamp_packet[j] - start
                if delay >= time_window_ps:
                    break
                if channel_packet[j] == stop_channel:
                    # cdivision: // truncates toward zero
                    index = delay // time_bin_ps
                    # With boundscheck off, an out-of-range index would write past
                    # the histogram. This happens for delay <= -time_bin_ps
                    # (unsorted input) or delay in [number_of_bin*time_bin_ps,
                    # time_window_ps) when the window is not a multiple of the bin.
                    if 0 <= index < number_of_bin:
                        view_histogram_after[index] += 1
                j += 1
            # Negative times: stop events recorded before the start event
            j = i-1
            while j >= 0:
                delay = start - timestamp_packet[j]
                if delay >= time_window_ps:
                    break
                if channel_packet[j] == stop_channel:
                    index = delay // time_bin_ps
                    if 0 <= index < number_of_bin:
                        view_histogram_before[index] += 1
                j -= 1
        i += 1
    return np.concatenate([np.flip(histogram_before), histogram_after])

In [None]:
# Correlate, then copy the result into an int32 array of the expected length
g2 = compute_g2(timestamp_packet, channel_packet,
                start_channel, stop_channel,
                time_window_ps, time_bin_ps)
histogram = np.empty(2 * number_of_bin, dtype=np.int32)
histogram[:] = g2
histogram

## PLOT

In [None]:
# Delay axis in seconds: element len(histogram)//2 is placed at zero delay
n_points = len(histogram)
x_trace = (np.arange(n_points) - n_points // 2) * g2_time_bin_s

### Norm

In [None]:
# Number of events recorded on each correlated channel
start_count = (channel_packet == start_channel).sum()
stop_count = (channel_packet == stop_channel).sum()

# Acquisition duration: header value in ms -> ps -> s
time_start_ps = 0
time_stop_ms = ptu_file.head["MeasDesc_AcquisitionTime"]
time_stop_ps = time_stop_ms * 1e9
duration = (time_stop_ps - time_start_ps) / 1e12

# Normalize by N_start * (mean stop rate) * bin width
average_stop = stop_count / duration
normalization_g2 = start_count * average_stop * g2_time_bin_s
histogram_norm = histogram / normalization_g2

print(f"start_count: {start_count}, stop_count: {stop_count}, time_bin: {g2_time_bin_s}s , "
      f"acq_time: {duration}s, norm: {normalization_g2}")

In [None]:
# x_to_plot, y_to_plot = x_trace, histogram_norm

### Rebin

In [None]:
# rebin functions from : https://pypi.org/project/data-analysis-tools/
def rebin(data, rebin_ratio, do_average=False):
    """ Merge every `rebin_ratio` consecutive samples of a 1D array into one.
    @param 1d numpy array data : The values to merge
    @param int rebin_ratio: How many original samples form one new sample
    @param bool do_average: Average each group if True, otherwise sum it
    @return 1d numpy array : One value per complete group
    Trailing samples that do not fill a complete group are discarded. """
    ratio = int(rebin_ratio)
    n_groups = len(data) // ratio
    grouped = data[:n_groups * ratio].reshape(n_groups, ratio)
    return grouped.mean(axis=1) if do_average else grouped.sum(axis=1)

def decimate(data, decimation_ratio):
    """ Keep only every `decimation_ratio`-th sample of a 1D array (no averaging).
    @param 1d numpy array data : The samples to thin out
    @param int decimation_ratio: Stride between kept samples
    @return 1d numpy array : The first sample of every complete group. """
    step = int(decimation_ratio)
    usable = len(data) - len(data) % step
    return data[:usable:step]

def rebin_xy(x, y, ratio=1, do_average=True):
    """ Reduce an (x, y) trace by `ratio`: x is decimated, y is rebinned (averaged by default). """
    x_reduced = decimate(x, ratio)
    y_reduced = rebin(y, ratio, do_average)
    return x_reduced, y_reduced

In [None]:
# Merge 10 adjacent bins: x keeps the first delay of each group, y the group mean
x_to_plot, y_to_plot = rebin_xy(x_trace, histogram_norm, ratio=10, do_average=True)

### Figure

In [None]:
fig, ax = plt.subplots(figsize=(8, 4))
label = f'$g^2_{{ ({start_channel}, {stop_channel}) }}$'

# Rebinned, normalized g2 drawn as markers only (delay in s -> µs)
ax.plot(x_to_plot * 1e6, y_to_plot, label=label,
        color='dodgerblue', marker='+', lw=0)

ax.set_xlabel('Time [us]', fontsize=18)
ax.set_ylabel(r'$g^2(\tau)$', fontsize=18)

# Reference lines: mean of the un-rebinned normalized histogram, and zero delay.
# (Uncomment to also draw g2 = 1: ax.axhline(1, color='black', lw=1, ls='--'))
ax.axhline(histogram_norm.mean(), color='red', lw=1, ls='--', label='Mean value')
ax.axvline(0, color='green', lw=1, ls='--')

# Optional zoom: ax.set_xlim(-10, 10)
ax.set_ylim(bottom=0, top=1.5)
ax.legend(loc='lower left', fontsize=12.0)

# fig.dpi = 100
pass