In [1]:
import numpy as np
from pylab import *

%matplotlib notebook
%load_ext autoreload
%autoreload 2

In [2]:
# trace is a list of (width, offset), one for each channel
nchan = 8
trace1 = [(0, np.array([1])) for w in range(nchan)]
trace2 = [(0, np.array([1])) for w in range(nchan//2)] + [(0, np.array([1,1])) for w in range(nchan//2)]
trace3 = [(1, np.array([1])) for w in range(nchan)]
trace4 = [(1, np.array([1])) for w in range(nchan//2)] + [(1, np.array([1,2,1])) for w in range(nchan//2)]
trace5 = [(0, np.array([1,1])) for w in range(nchan)]
trace6 = [(0, np.ones(w+1)) for w in range(nchan)]


traces = [trace1, trace2, trace3, trace4]
traces = [trace2, trace6]

In [3]:
trace2

[(0, array([1])),
 (0, array([1])),
 (0, array([1])),
 (0, array([1])),
 (0, array([1, 1])),
 (0, array([1, 1])),
 (0, array([1, 1])),
 (0, array([1, 1]))]

In [4]:
def trace2image(trace):
    '''
    Produces and image given a trace
    Trace is a list of (offset, kernel) tuples - length nchannels
    offset if positive for backwards in time.
    Offset for channel i is between channel i and channel i+1
    The value for time t is the convolution of the previous N samples (N=size of kernel)
    '''
    nc = len(trace)
    kernels = [t[1] for t in trace]
    offsets = [t[0] for t in trace]
    last_off = offsets[-1]
    total_offset = sum(offsets)
    kernel_lengths = map(len, kernels)
    largest_kernel = max(kernel_lengths)
    # TODO: Fix nt - this isn't right. We need to loop through the trace and work out what hte smallest start index would be
    curr_start = 0
    smallest_start = 0
    for offset, kernel in trace:
        endidx = curr_start
        startidx = endidx - len(kernel)
        if startidx < smallest_start:
            smallest_start = startidx
        
        curr_start -= offset    
    
    nt = -smallest_start
    #print('trace2image nc=', nc, 'total_offset', total_offset, 'largest_kernel', largest_kernel, 'nt=', nt)
    mask = np.zeros((nc,nt))
    curr_start = nt # cumulative offset - right hand side of kernel
    for ichan, (offset, kernel) in enumerate(trace):       
        endidx = curr_start
        startidx = endidx - len(kernel)
        #print(ichan, offset, kernel, curr_start, startidx, endidx)
        mask[ichan, startidx:endidx] = kernel
        curr_start -= offset
        
   #assert coffset - len(kernel) == 0, f'Coffset should be 0 it was {coffset}'
        
        
    return mask
    

In [5]:
trace2image(trace6)

array([[0., 0., 0., 0., 0., 0., 0., 1.],
       [0., 0., 0., 0., 0., 0., 1., 1.],
       [0., 0., 0., 0., 0., 1., 1., 1.],
       [0., 0., 0., 0., 1., 1., 1., 1.],
       [0., 0., 0., 1., 1., 1., 1., 1.],
       [0., 0., 1., 1., 1., 1., 1., 1.],
       [0., 1., 1., 1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1., 1., 1., 1.]])

In [6]:
fig, ax = subplots(1,len(traces), squeeze=False)
for myax, t in zip(ax.flatten(), traces):
    img = trace2image(t)
    print(t, img.shape)
    myax.imshow(img, origin='lower')


<IPython.core.display.Javascript object>

[(0, array([1])), (0, array([1])), (0, array([1])), (0, array([1])), (0, array([1, 1])), (0, array([1, 1])), (0, array([1, 1])), (0, array([1, 1]))] (8, 2)
[(0, array([1.])), (0, array([1., 1.])), (0, array([1., 1., 1.])), (0, array([1., 1., 1., 1.])), (0, array([1., 1., 1., 1., 1.])), (0, array([1., 1., 1., 1., 1., 1.])), (0, array([1., 1., 1., 1., 1., 1., 1.])), (0, array([1., 1., 1., 1., 1., 1., 1., 1.]))] (8, 8)


In [7]:
from collections import namedtuple
Product = namedtuple('Product', ['pid_upper','pid_lower','offset', 'desc'])

class EndProduct:
    def __init__(self, trace, desc):
        self.desc = desc
        self.top_offset, self.kernel = trace
    
    @property
    def offset(self):
        return self.top_offset
        
    def __str__(self):
        s = f'EndProduct kernel={self.kernel}'
        return s
    
    def __eq__(self, other):
        '''
        Need this otherwise we keep adding the same kernel at the bottom level
        '''    
        iseq = np.array_equal(self.kernel, other.kernel)
        return iseq
    
    __repr__ = __str__

    

class EsamTree:
    def __init__(self, nchan, ichan):
        self._products = [] # list containing Products
        self.nchan = nchan
        self.ichan = ichan
        assert self.nchan == 1 or self.nchan % 2 == 0
        
        if nchan == 1:
            self.upper = None
            self.lower = None
        else:
            # normally we'd do FdmtDescriptor here, butif we make a new __class_ if we have a subclass
            # then it'll create the right class
            self.upper = self.__class__(nchan // 2, 2*ichan+1)
            self.lower = self.__class__(nchan // 2, 2*ichan)
        
    @property
    def ndm(self):
        return len(self._products)
    
    @property
    def nprod(self):
        return self.ndm
    
    @property
    def total_products(self):
        '''
        Returns total number of products in hierarchy
        '''
        if self.nchan == 1:
            return self.nprod
        else:
            return self.upper.total_products + self.lower.total_products + self.nprod
        
    def descriptor_tree(self, tree=None, level=0):
        '''
        Returns a list. Each element contains another list. That list contains all descriptors for that iteration
        '''
        if tree is None:
            tree = []
        if self.ichan == 0:
            tree.append([])
            
        tree[level].append(self)
        if self.nchan != 1:
            self.lower.descriptor_tree(tree, level+1)
            self.upper.descriptor_tree(tree, level+1)
        
        return tree              
    
    def get_all_pids(self, all_pids):
        list_idx = int(np.log2(self.nchan))
        if self.nchan == 1:
            all_pids[list_idx].extend(self._products)
        else:
            all_pids[list_idx].extend(self._products)
            self.upper.get_all_pids(all_pids)
            self.lower.get_all_pids(all_pids)
    
    def get_trace_pid(self, trace, doprint=False) -> int:
        '''
        Returns the product ID for the given trace
        trace:list of nchan values. Each value is tuple(width, offset)'
        '''
        assert len(trace) == self.nchan, f'Unexpedcted trace length in {self}. Was {len(trace)} expected {self.nchan}'
        
        n2 = self.nchan // 2
        if self.nchan == 1:
            prod = EndProduct(trace[0], self)
            total_offset = prod.offset
        else:
            pid_lower, offset_lower = self.lower.get_trace_pid(trace[:n2])
            pid_upper, offset_upper = self.upper.get_trace_pid(trace[n2:])
            prod = Product(pid_upper, pid_lower, offset_lower, self)
            total_offset = offset_lower + offset_upper
        
        added = False
        if prod not in self._products:
            self._products.append(prod)
            added = True
            
        pid = self._products.index(prod)
        if doprint:
            print(f'{self} of trace {trace}={prod}=PID{pid} added?={added} total_offset={total_offset}')  
        return pid, total_offset
           
    def __str__(self):
        s = f'Nchan={self.nchan} chan={self.ichan} ndm={self.ndm}'
        return s
    
    def __call__(self, din, doprint=False):
        assert din.shape[0] == self.nchan
        nt = din.shape[1]
        dout = np.zeros((self.nprod, nt)) # NT here is a bit tricky
        
        if self.nchan == 1:
            assert din.shape[0] == 1, f'Expected 1 channel. Got {din.shape}'
            for iprod, prod in enumerate(self._products):
                kernel = prod.kernel
                
                # this worked, but it would be good to spell it out more completely
                istart = 0
                dout[iprod, :] = np.convolve(din[0], kernel, mode='full')[istart:nt+istart]
                # n = len(kernel)                
                # TODO: Implement history
                # This is equivalent, but it's horrendously slow
                #for endidx in range(n, nt):                    
                #    # only look backwards in time
                #    startidx = endidx - n
                #    assert startidx >= 0
                #    dout[iprod, endidx] = np.dot(din[0, startidx:endidx], kernel)

        else:
            nf2 = self.nchan // 2 
            lower = self.lower(din[:nf2,...], doprint)
            upper = self.upper(din[nf2:,...], doprint)
            for iprod, prod in enumerate(self._products):

                off = prod.offset          
                

                if doprint:
                    print(f'{self.ichan}/{self.nchan} setup {off}', 'lower', lower[prod.pid_lower,:].shape, 'upper', upper[prod.pid_upper,:].shape, 'slice', slice(0, nt-off))
                    
                                    
                dout[iprod, :] = lower[prod.pid_lower, :]
                if off < nt: # if hte offset is larger than NT then ... I don't know what to do?!
                    dout[iprod, off:] += upper[prod.pid_upper, :nt-off]
                
                
                
                if doprint:
                    print(f'{self.ichan}/{self.nchan} upper {off}', upper[prod.pid_upper,:])
                    print(f'{self.ichan}/{self.nchan} lower {off}', lower[prod.pid_lower,:])
                    print(f'{self.ichan}/{self.nchan} dout  {off}', dout[iprod,:])
                
                
            
        return dout                        
    
    __repr__ = __str__
    
    
d = EsamTree(nchan,0)
d

Nchan=8 chan=0 ndm=0

In [8]:
[d.get_trace_pid(t) for t in traces]

[(0, 0), (1, 0)]

In [9]:
d.total_products

29

In [10]:
fig, ax = subplots(1,2)
img3 = trace2image(trace2)
pad = np.zeros((nchan,4))
img3 = np.hstack([pad,img3,pad])
dout3 = d(img3, doprint=True)
ax[0].imshow(img3, origin='lower')
ax[1].plot(dout3[0,])
print(img3.sum(), dout3.max())

<IPython.core.display.Javascript object>

0/2 setup 0 lower (10,) upper (10,) slice slice(0, 10, None)
0/2 upper 0 [0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]
0/2 lower 0 [0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]
0/2 dout  0 [0. 0. 0. 0. 0. 2. 0. 0. 0. 0.]
0/2 setup 0 lower (10,) upper (10,) slice slice(0, 10, None)
0/2 upper 0 [0. 0. 0. 0. 0. 1. 1. 0. 0. 0.]
0/2 lower 0 [0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]
0/2 dout  0 [0. 0. 0. 0. 0. 2. 1. 0. 0. 0.]
1/2 setup 0 lower (10,) upper (10,) slice slice(0, 10, None)
1/2 upper 0 [0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]
1/2 lower 0 [0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]
1/2 dout  0 [0. 0. 0. 0. 0. 2. 0. 0. 0. 0.]
1/2 setup 0 lower (10,) upper (10,) slice slice(0, 10, None)
1/2 upper 0 [0. 0. 0. 0. 0. 1. 1. 1. 1. 0.]
1/2 lower 0 [0. 0. 0. 0. 0. 1. 1. 1. 0. 0.]
1/2 dout  0 [0. 0. 0. 0. 0. 2. 2. 2. 1. 0.]
0/4 setup 0 lower (10,) upper (10,) slice slice(0, 10, None)
0/4 upper 0 [0. 0. 0. 0. 0. 2. 0. 0. 0. 0.]
0/4 lower 0 [0. 0. 0. 0. 0. 2. 0. 0. 0. 0.]
0/4 dout  0 [0. 0. 0. 0. 0. 4. 0. 0. 0. 0.]
0/4 setup 0 lower (10,) upper (10,)

In [11]:
din = np.ones((d.nchan, 16))
din[:] = 0
din[:,8] = 1
dout = d(din, doprint=True)
print(dout.shape)
dout

0/2 setup 0 lower (16,) upper (16,) slice slice(0, 16, None)
0/2 upper 0 [0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]
0/2 lower 0 [0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]
0/2 dout  0 [0. 0. 0. 0. 0. 0. 0. 0. 2. 0. 0. 0. 0. 0. 0. 0.]
0/2 setup 0 lower (16,) upper (16,) slice slice(0, 16, None)
0/2 upper 0 [0. 0. 0. 0. 0. 0. 0. 0. 1. 1. 0. 0. 0. 0. 0. 0.]
0/2 lower 0 [0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]
0/2 dout  0 [0. 0. 0. 0. 0. 0. 0. 0. 2. 1. 0. 0. 0. 0. 0. 0.]
1/2 setup 0 lower (16,) upper (16,) slice slice(0, 16, None)
1/2 upper 0 [0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]
1/2 lower 0 [0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]
1/2 dout  0 [0. 0. 0. 0. 0. 0. 0. 0. 2. 0. 0. 0. 0. 0. 0. 0.]
1/2 setup 0 lower (16,) upper (16,) slice slice(0, 16, None)
1/2 upper 0 [0. 0. 0. 0. 0. 0. 0. 0. 1. 1. 1. 1. 0. 0. 0. 0.]
1/2 lower 0 [0. 0. 0. 0. 0. 0. 0. 0. 1. 1. 1. 0. 0. 0. 0. 0.]
1/2 dout  0 [0. 0. 0. 0. 0. 0. 0. 0. 2. 2. 2. 1. 0. 0. 0. 0.]
0/4 setup 0 

array([[0., 0., 0., 0., 0., 0., 0., 0., 8., 4., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 8., 7., 6., 5., 4., 3., 2., 1.]])

In [12]:
figure()
imshow(dout, origin='lower')

<IPython.core.display.Javascript object>

<matplotlib.image.AxesImage at 0x115440f10>

In [13]:
# unit tests
e1 = EndProduct((0,np.array([1,2,3])), None)
e2 = EndProduct((0,np.array([1,2,5])), None)
e3 = EndProduct((0,np.array([1])), None)

#e1 == e2
assert e1 in [e2,EndProduct((0,np.array([1,2,3])), None)] 
assert e3 not in [e1,e2] 
assert e3 not in [EndProduct((0, np.array([1,1])), None)]

In [14]:
mtree = d.descriptor_tree()
print('tree depth', len(mtree))
for itree, t in enumerate(mtree):
    print('tree level', itree, 'nnodes=', len(t), t)

tree depth 4
tree level 0 nnodes= 1 [Nchan=8 chan=0 ndm=2]
tree level 1 nnodes= 2 [Nchan=4 chan=0 ndm=2, Nchan=4 chan=1 ndm=2]
tree level 2 nnodes= 4 [Nchan=2 chan=0 ndm=2, Nchan=2 chan=1 ndm=2, Nchan=2 chan=2 ndm=2, Nchan=2 chan=3 ndm=2]
tree level 3 nnodes= 8 [Nchan=1 chan=0 ndm=1, Nchan=1 chan=1 ndm=2, Nchan=1 chan=2 ndm=2, Nchan=1 chan=3 ndm=2, Nchan=1 chan=4 ndm=2, Nchan=1 chan=5 ndm=2, Nchan=1 chan=6 ndm=2, Nchan=1 chan=7 ndm=2]


In [15]:
figure()
t = np.zeros(16)
t[8] = 1
mode = 'full'
plot(np.convolve(t, np.array([1]), mode))
plot(np.convolve(t, np.array([1,1]), mode)[1:])
plot(np.convolve(t, np.array([1,1,1]), mode)[2:])

<IPython.core.display.Javascript object>

[<matplotlib.lines.Line2D at 0x1154944c0>]

In [16]:
d.lower.nchan

4

In [17]:
d.lower.lower.nchan

2

In [18]:
d.upper.upper.upper.ichan

7

In [19]:
len(trace1)

8

In [20]:
d.get_trace_pid(trace1)

(2, 0)

In [21]:
d.get_trace_pid(trace1)

(2, 0)

In [22]:
d._products

[Product(pid_upper=0, pid_lower=0, offset=0, desc=Nchan=8 chan=0 ndm=3),
 Product(pid_upper=1, pid_lower=1, offset=0, desc=Nchan=8 chan=0 ndm=3),
 Product(pid_upper=2, pid_lower=0, offset=0, desc=Nchan=8 chan=0 ndm=3)]

In [23]:
d.total_products

37

In [24]:
trace1

[(0, array([1])),
 (0, array([1])),
 (0, array([1])),
 (0, array([1])),
 (0, array([1])),
 (0, array([1])),
 (0, array([1])),
 (0, array([1]))]

In [25]:
trace2

[(0, array([1])),
 (0, array([1])),
 (0, array([1])),
 (0, array([1])),
 (0, array([1, 1])),
 (0, array([1, 1])),
 (0, array([1, 1])),
 (0, array([1, 1]))]

In [26]:
d.get_trace_pid(trace2)

(0, 0)

In [27]:
d.total_products

37

In [28]:
trace4

[(1, array([1])),
 (1, array([1])),
 (1, array([1])),
 (1, array([1])),
 (1, array([1, 2, 1])),
 (1, array([1, 2, 1])),
 (1, array([1, 2, 1])),
 (1, array([1, 2, 1]))]

In [29]:
d.get_trace_pid(trace3)

(3, 8)

In [30]:
d.get_trace_pid(trace4)

(4, 8)

In [31]:
d._products

[Product(pid_upper=0, pid_lower=0, offset=0, desc=Nchan=8 chan=0 ndm=5),
 Product(pid_upper=1, pid_lower=1, offset=0, desc=Nchan=8 chan=0 ndm=5),
 Product(pid_upper=2, pid_lower=0, offset=0, desc=Nchan=8 chan=0 ndm=5),
 Product(pid_upper=3, pid_lower=2, offset=4, desc=Nchan=8 chan=0 ndm=5),
 Product(pid_upper=4, pid_lower=2, offset=4, desc=Nchan=8 chan=0 ndm=5)]

In [32]:
d.total_products

52

In [33]:
all_pids = [[] for i in range(int(np.log2(nchan))+1)]
d.get_all_pids(all_pids)
all_pids[1]

[Product(pid_upper=0, pid_lower=0, offset=0, desc=Nchan=2 chan=3 ndm=5),
 Product(pid_upper=1, pid_lower=1, offset=0, desc=Nchan=2 chan=3 ndm=5),
 Product(pid_upper=2, pid_lower=2, offset=0, desc=Nchan=2 chan=3 ndm=5),
 Product(pid_upper=2, pid_lower=2, offset=1, desc=Nchan=2 chan=3 ndm=5),
 Product(pid_upper=3, pid_lower=3, offset=1, desc=Nchan=2 chan=3 ndm=5),
 Product(pid_upper=0, pid_lower=0, offset=0, desc=Nchan=2 chan=2 ndm=5),
 Product(pid_upper=1, pid_lower=1, offset=0, desc=Nchan=2 chan=2 ndm=5),
 Product(pid_upper=2, pid_lower=2, offset=0, desc=Nchan=2 chan=2 ndm=5),
 Product(pid_upper=2, pid_lower=2, offset=1, desc=Nchan=2 chan=2 ndm=5),
 Product(pid_upper=3, pid_lower=3, offset=1, desc=Nchan=2 chan=2 ndm=5),
 Product(pid_upper=0, pid_lower=0, offset=0, desc=Nchan=2 chan=1 ndm=3),
 Product(pid_upper=1, pid_lower=1, offset=0, desc=Nchan=2 chan=1 ndm=3),
 Product(pid_upper=0, pid_lower=0, offset=1, desc=Nchan=2 chan=1 ndm=3),
 Product(pid_upper=0, pid_lower=0, offset=0, desc=N

In [34]:
d.total_products

52

In [35]:
d.descriptor_tree()

[[Nchan=8 chan=0 ndm=5],
 [Nchan=4 chan=0 ndm=3, Nchan=4 chan=1 ndm=5],
 [Nchan=2 chan=0 ndm=3,
  Nchan=2 chan=1 ndm=3,
  Nchan=2 chan=2 ndm=5,
  Nchan=2 chan=3 ndm=5],
 [Nchan=1 chan=0 ndm=1,
  Nchan=1 chan=1 ndm=2,
  Nchan=1 chan=2 ndm=2,
  Nchan=1 chan=3 ndm=2,
  Nchan=1 chan=4 ndm=4,
  Nchan=1 chan=5 ndm=4,
  Nchan=1 chan=6 ndm=4,
  Nchan=1 chan=7 ndm=4]]

In [36]:
def fdmt_iteration(iterno, input_buffer, output_buffer, fdmt_tables):

    for output_channel_index in fdmt_tables.channels(iterno):
        for odm, id1, id2, offset, inchan1, inchan2 in fdmt_tables.dmlist(iterno, output_channel_index):
            output_buffer[odm, output_channel_index, :] = input_buffer[id1, inchan1, :] \
                                                        + input_buffer[id2, inchan2, offset:]
                
                
def fdmt_function(input_buffer, output_buffer, fdmt_tables):
    buffers = [np.zeros(fdmt_tables.output_buffer(i)) for i in range(fdmt_tables.niter)]
    
    buffers[0] = input_buffer # allocate the input
    
    for iterno in fdmt_tables:
        fdmt_iteration(iterno, buffers[i], buffers[i+1], fdmt_tables)
        
        
    return buffers[-1]



# ESAM - Efficient Summation of Arbirary Masks
This whole thing is completely general - we can have a full 1D convolution with arbitrary weights per channel and it remains general (in fact it's simpler to understand and code up)



In [37]:
from craft import fdmt

In [38]:
fdmt.__file__

'/Users/ban115/bolton/craft/craft-github/src/craft/fdmt.py'

In [39]:
nchan = 256
thefdmt = fdmt.Fdmt(735, 1, nchan, 1024, 256, do_correction=True)
thefdmt_nocorr = fdmt.Fdmt(735, 1, nchan, 1024, 256, do_correction=False)

In [40]:
thefdmt_nocorr.trace_dm(1)

[(0, 1, 0),
 (1, 0, 1),
 (2, 0, 1),
 (3, 0, 1),
 (4, 0, 1),
 (5, 0, 1),
 (6, 0, 1),
 (7, 0, 1),
 (8, 0, 1),
 (9, 0, 1),
 (10, 0, 1),
 (11, 0, 1),
 (12, 0, 1),
 (13, 0, 1),
 (14, 0, 1),
 (15, 0, 1),
 (16, 0, 1),
 (17, 0, 1),
 (18, 0, 1),
 (19, 0, 1),
 (20, 0, 1),
 (21, 0, 1),
 (22, 0, 1),
 (23, 0, 1),
 (24, 0, 1),
 (25, 0, 1),
 (26, 0, 1),
 (27, 0, 1),
 (28, 0, 1),
 (29, 0, 1),
 (30, 0, 1),
 (31, 0, 1),
 (32, 0, 1),
 (33, 0, 1),
 (34, 0, 1),
 (35, 0, 1),
 (36, 0, 1),
 (37, 0, 1),
 (38, 0, 1),
 (39, 0, 1),
 (40, 0, 1),
 (41, 0, 1),
 (42, 0, 1),
 (43, 0, 1),
 (44, 0, 1),
 (45, 0, 1),
 (46, 0, 1),
 (47, 0, 1),
 (48, 0, 1),
 (49, 0, 1),
 (50, 0, 1),
 (51, 0, 1),
 (52, 0, 1),
 (53, 0, 1),
 (54, 0, 1),
 (55, 0, 1),
 (56, 0, 1),
 (57, 0, 1),
 (58, 0, 1),
 (59, 0, 1),
 (60, 0, 1),
 (61, 0, 1),
 (62, 0, 1),
 (63, 0, 1),
 (64, 0, 1),
 (65, 0, 1),
 (66, 0, 1),
 (67, 0, 1),
 (68, 0, 1),
 (69, 0, 1),
 (70, 0, 1),
 (71, 0, 1),
 (72, 0, 1),
 (73, 0, 1),
 (74, 0, 1),
 (75, 0, 1),
 (76, 0, 1),
 (77, 0, 

In [41]:
fdmt.__file__

'/Users/ban115/bolton/craft/craft-github/src/craft/fdmt.py'

In [42]:
idm = 128
fig, ax = subplots(1,2, sharex=True, sharey=True)
t = thefdmt.add_frb_track(idm)
t_nocorr = thefdmt_nocorr.add_frb_track(idm)
offsets = idm*cff(thefdmt.freqs, thefdmt.freqs[0], thefdmt.freqs[-1], thefdmt.freqs[0])

ax[0].imshow(t, aspect='auto', origin='lower', interpolation='none')
ax[0].plot(offsets[-1] - offsets, chans)
xnc, nt = t.shape
times = np.arange(nt)
chans = np.arange(nc)

ax[0].set_title(f'Corrections=True ')
ax[1].set_title(f'Corrections=False ')
ax[1].imshow(t_nocorr, aspect='auto', origin='lower', interpolation='none')
ax[1].plot(offsets[-1] - offsets, chans)
fig.suptitle(f'idm={idm}')


<IPython.core.display.Javascript object>

NameError: name 'cff' is not defined

In [43]:
thefdmt.hist_state_shape

[array([256,   8, 256]),
 array([128,  14, 270]),
 array([ 64,  26, 282]),
 array([ 32,  50, 306]),
 array([ 16,  98, 354]),
 array([  8, 188, 444]),
 array([  4, 352, 608]),
 array([  2, 628, 884]),
 array([   1, 1024, 1280])]

In [44]:
thefdmt.hist_nf_data[0][0]

(735.0,
 737.0,
 735.5,
 736.5,
 14,
 [(0, 0, 0, 0, (0, 0, 0), (0, 0, 0), (1, 0, 0)),
  (0, 0, 1, 0, (0, 1, 1), (0, 0, 1), (1, 0, 0)),
  (1, 1, 2, 0, (0, 2, 2), (0, 1, 2), (1, 0, 0)),
  (1, 1, 2, 1, (0, 3, 2), (0, 1, 2), (1, 1, 0)),
  (1, 1, 3, 1, (0, 4, 3), (0, 1, 3), (1, 1, 0)),
  (1, 1, 4, 1, (0, 5, 4), (0, 1, 4), (1, 1, 0)),
  (2, 2, 5, 1, (0, 6, 5), (0, 2, 5), (1, 1, 0)),
  (2, 2, 5, 2, (0, 7, 5), (0, 2, 5), (1, 2, 0)),
  (2, 2, 6, 2, (0, 8, 6), (0, 2, 6), (1, 2, 0)),
  (2, 2, 7, 2, (0, 9, 7), (0, 2, 7), (1, 2, 0)),
  (3, 3, 8, 2, (0, 10, 8), (0, 3, 8), (1, 2, 0)),
  (3, 3, 8, 3, (0, 11, 8), (0, 3, 8), (1, 3, 0)),
  (3, 3, 9, 3, (0, 12, 9), (0, 3, 9), (1, 3, 0)),
  (3, 3, 10, 3, (0, 13, 10), (0, 3, 10), (1, 3, 0))])

In [45]:
from craft.fdmt import cff
cff(735,736,735,735+256)*1024

6.180499177588411

In [46]:
def dmmillis(dm,f1,f2):
    return 4.15*dm*(f1**-2 - f2**-2)

dmmillis(300,.735,0.735+.256)

1036.879100802759

In [47]:
dmmillis(300,.735,0.735+.001)

6.2582328415724

In [48]:
def image2trace(image):
    '''
    See vivek's idea for an improvement
    '''
    nc = image.shape[0]
    trace = []
    last_start_idx = None
    last_end_idx = None
    for c in range(nc):
        x = image[c,:]
        kernel_idxs = np.where(x != 0)[0]
        startidx = kernel_idxs[0]
        endidx = kernel_idxs[-1]
        krnl = x[startidx:endidx+1]
        if c > 0:
            newoff = last_end_idx - endidx
            trace[-1][0] = newoff            
        else:
            newoff = None
            
        
        trace.append([0, krnl]) # fill offset = 0 to start with. it'll be overwritten next iteration
        #print(c, 'start=',startidx, 'end=',endidx, krnl, kernel_idxs, x, 'newoff', newoff)
        last_start_idx = startidx    
        last_end_idx = endidx
        
    return trace

        
image2trace(t)
#mask_to_trace(t[:,::-1])

[[1, array([1.], dtype=float32)],
 [1, array([1.], dtype=float32)],
 [1, array([1.], dtype=float32)],
 [0, array([1.], dtype=float32)],
 [1, array([1.], dtype=float32)],
 [1, array([1.], dtype=float32)],
 [1, array([1.], dtype=float32)],
 [0, array([1.], dtype=float32)],
 [1, array([1.], dtype=float32)],
 [1, array([1.], dtype=float32)],
 [1, array([1.], dtype=float32)],
 [0, array([1.], dtype=float32)],
 [1, array([1.], dtype=float32)],
 [1, array([1.], dtype=float32)],
 [1, array([1.], dtype=float32)],
 [0, array([1.], dtype=float32)],
 [1, array([1.], dtype=float32)],
 [0, array([1.], dtype=float32)],
 [1, array([1.], dtype=float32)],
 [1, array([1.], dtype=float32)],
 [1, array([1.], dtype=float32)],
 [0, array([1.], dtype=float32)],
 [1, array([1.], dtype=float32)],
 [1, array([1.], dtype=float32)],
 [1, array([1.], dtype=float32)],
 [0, array([1.], dtype=float32)],
 [1, array([1.], dtype=float32)],
 [1, array([1.], dtype=float32)],
 [1, array([1.], dtype=float32)],
 [0, array([1.

In [49]:
def trace2image(trace):
    '''
    Produces and image given a trace
    Trace is a list of (offset, kernel) tuples - length nchannels
    offset if positive for backwards in time.
    Offset for channel i is between channel i and channel i+1
    Kernel for channel 0 ends at the final time
    '''
    nc = len(trace)
    kernels = [t[1] for t in trace]
    offsets = [t[0] for t in trace]
    last_off = offsets[-1]
    total_offset = sum(offsets)
    kernel_lengths = map(len, kernels)
    largest_kernel = max(kernel_lengths)
    nt = largest_kernel + total_offset
    print('trace2image nc=', nc, 'total_offset', total_offset, 'largest_kernel', largest_kernel, 'nt=', nt)
    mask = np.zeros((nc,nt))
    curr_endidx = nt# cumulative offset - right hand side of kernel
    for ichan, (offset, kernel) in enumerate(trace):       
        endidx = curr_endidx
        startidx = endidx - len(kernel) 
        print(ichan, 'start=', startidx, 'end=', endidx, offset, kernel, 'curr', curr_endidx)
        mask[ichan, startidx:endidx] = kernel
        curr_endidx -= offset
        
   #assert coffset - len(kernel) == 0, f'Coffset should be 0 it was {coffset}'
        
        
    return mask

fig, ax = subplots(1,3)
t2 = trace2image(image2trace(t))
ax[0].imshow(t, origin='lower')
ax[1].imshow(t2, origin='lower')
ax[2].imshow(t2 - t, origin='lower')

<IPython.core.display.Javascript object>

trace2image nc= 256 total_offset 128 largest_kernel 1 nt= 129
0 start= 128 end= 129 1 [1.] curr 129
1 start= 127 end= 128 1 [1.] curr 128
2 start= 126 end= 127 1 [1.] curr 127
3 start= 125 end= 126 0 [1.] curr 126
4 start= 125 end= 126 1 [1.] curr 126
5 start= 124 end= 125 1 [1.] curr 125
6 start= 123 end= 124 1 [1.] curr 124
7 start= 122 end= 123 0 [1.] curr 123
8 start= 122 end= 123 1 [1.] curr 123
9 start= 121 end= 122 1 [1.] curr 122
10 start= 120 end= 121 1 [1.] curr 121
11 start= 119 end= 120 0 [1.] curr 120
12 start= 119 end= 120 1 [1.] curr 120
13 start= 118 end= 119 1 [1.] curr 119
14 start= 117 end= 118 1 [1.] curr 118
15 start= 116 end= 117 0 [1.] curr 117
16 start= 116 end= 117 1 [1.] curr 117
17 start= 115 end= 116 0 [1.] curr 116
18 start= 115 end= 116 1 [1.] curr 116
19 start= 114 end= 115 1 [1.] curr 115
20 start= 113 end= 114 1 [1.] curr 114
21 start= 112 end= 113 0 [1.] curr 113
22 start= 112 end= 113 1 [1.] curr 113
23 start= 111 end= 112 1 [1.] curr 112
24 start= 11

<matplotlib.image.AxesImage at 0x124704100>

In [50]:
d = EsamTree(nchan,0)
for idm in range(0, 64):
    t = thefdmt.add_frb_track(idm)
    d.get_trace_pid(image2trace(t))


In [51]:
print(len(d._products))

64


In [52]:
din = thefdmt.add_frb_track(8)
pad = np.zeros((nchan, 32))
din = np.hstack([pad, din, pad])
din.shape

(256, 73)

In [53]:
dout = d(din)

In [54]:
dout.shape

(64, 73)

In [55]:
figure()
print(sum(din), dout.max())
imshow(dout)

<IPython.core.display.Javascript object>

256.0 256.0


<matplotlib.image.AxesImage at 0x1247f91c0>

In [57]:
dms = [1,2,4,8,16,24, 32, 48, 63]
fig, ax = subplots(len(dms), 2)
plt.subplots_adjust(hspace=0.5)
fig.set_size_inches((8,16))


for iax, idm in enumerate(dms):
    din = thefdmt.add_frb_track(idm)
    pad = np.zeros((nchan, 32))
    din = np.hstack([pad, din, pad])
    dout = d(din)
    din.shape
    print(idm, sum(din), dout.max())
    ax[iax,0].imshow(din, aspect='auto', origin='lower')
    ax[iax,1].imshow(dout, aspect='auto', origin='lower')
    ax[iax,0].set_title(f'idm={idm} Sum={din.sum()}')
    idx = np.unravel_index(dout.argmax(), dout.shape)
    ax[iax,1].set_title(f'max={dout.max()} idx={idx}')


<IPython.core.display.Javascript object>

1 256.0 256.0
2 256.0 256.0
4 256.0 256.0
8 256.0 256.0
16 256.0 256.0
24 256.0 256.0
32 256.0 256.0
48 256.0 256.0
63 256.0 256.0


# OK, now play the whole game

In [60]:
nchan = 256
ndm = 1024
tree = EsamTree(nchan,0)
thefdmt = fdmt.Fdmt(850e6, 1e6, nchan, ndm, 256)
idms = np.arange(ndm)
pids = []
nproducts = []
for idm in idms[::2]:
    t = thefdmt.add_frb_track(idm)
    pid = tree.get_trace_pid(image2trace(t))
    nproducts.append(tree.total_products)
    pids.append(pid)
    

In [61]:
# Now verify it
maxvals = []
maxt = []
maxprod = []
insums = []

for idm in idms:
    t = thefdmt.add_frb_track(idm)
    insums.append(t.sum())   

    pad = np.zeros((nchan, 32))
    img = np.hstack([pad,t,pad])
    dout = tree(img, doprint=False)
    maxidx = dout.argmax()
    prodmax, tmax = np.unravel_index(maxidx, dout.shape)
    maxt.append(tmax)
    maxprod.append(prodmax)
    maxvals.append(dout.max())
    
insums = np.array(insums)
maxvals = np.array(maxvals)

In [62]:
figure()
plot(idms, maxprod)
xlabel('Input IDM')
ylabel('best product')

<IPython.core.display.Javascript object>

Text(0, 0.5, 'best product')

In [63]:
figure()
plot(idms, maxvals / insums, '-x')
xlabel('IDM')
ylabel('recovery ratio')

<IPython.core.display.Javascript object>

Text(0, 0.5, 'recovery ratio')

In [64]:
figure()
plot(nproducts)
xlabel('IDM')
ylabel('Total number of products')

<IPython.core.display.Javascript object>

Text(0, 0.5, 'Total number of products')

In [65]:
figure()
plot(idms, maxt)
xlabel('IDM')
ylabel('sample number of best product')

<IPython.core.display.Javascript object>

Text(0, 0.5, 'sample number of best product')

In [66]:
tree.descriptor_tree()[2]

[Nchan=64 chan=0 ndm=328,
 Nchan=64 chan=1 ndm=271,
 Nchan=64 chan=2 ndm=223,
 Nchan=64 chan=3 ndm=185]

In [67]:
cff(700,702,700,704) 

0.5021367463555407

In [68]:
+ cff(702,704,700,704)

0.49786325364445927

In [69]:
cff(700,702,700,704) + cff(702,704,700,704)

1.0

In [70]:
weights = np.zeros((4,2))
# time spent in the cell with boundaries channel 700-701, t=1-2
w1 =  cff(700,701,700,704)
weights[0,1] = w1
w1

0.2516056089898913

In [71]:
# Total time spent in channel 701-702
cff(701,702,700,704)

0.2505311373656494

In [72]:
# time spent in the cell with boundaries channel 701-702, t=1-2
w2=0.5 - w1
weights[1,1] = w2
w2

0.24839439101010868

In [73]:
w3 = cff(701,702,700,704) - w2
weights[1,0] = w3
w3

0.0021367463555407307

In [74]:
w4 = cff(702,703,700,704) 
weights[2,0] = w4
w4

0.2494627750276134

In [75]:
w5 = cff(703,704,700,704) 
weights[3,0] = w5
w5

0.2484004786168459

In [76]:
np.flipud(weights/weights.max())

array([[0.98726129, 0.        ],
       [0.99148336, 0.        ],
       [0.00849244, 0.9872371 ],
       [0.        , 1.        ]])

In [77]:
figure()
imshow(weights, origin='lower')

<IPython.core.display.Javascript object>

<matplotlib.image.AxesImage at 0x124b27730>

In [78]:
cff(700,700+99,700,700+256)

0.5011396387777056

In [79]:
w1 + w2 + w3

0.5021367463555407

In [80]:
cff(702,700,704,700)

0.5021367463555407

In [83]:
nbl = 435
nchan = 288
nt = 1
%timeit np.exp(1j*np.zeros((nbl,nchan,nt)))

1.32 ms ± 31.2 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
