# In[1]:
import stmpy
import numpy as np
from scipy.interpolate import interp1d
from scipy.optimize import minimize

class Cancel(object):
    """Estimate a transfer function from ``xdata`` to ``ydata`` and apply it.

    The estimate is built by chopping both signals into (optionally
    overlapping) segments, removing a quadratic trend from each segment,
    tapering it with a Kaiser window and Fourier transforming it.  The
    transfer function is the segment-averaged cross spectrum
    ``Y * conj(X)`` divided by the segment-averaged power ``X * conj(X)``.

    Parameters
    ----------
    xdata : geophone data (the reference/input signal).
    ydata : tip data (the response signal).
    dt1 : sample spacing, i.e. 1/f for sampling frequency f.
    zdata, optimize, **kwarg : accepted for compatibility; not used.
    p : ``[n, o, b]`` where
        n = length of segments,
        o = number of overlapping elements between adjacent segments,
        b = Kaiser window parameter (beta).
        Defaults to ``[len(xdata), 0, 0]``: one segment, rectangular window.

    Attributes
    ----------
    transf : ``interp1d`` callable mapping frequency -> complex transfer
        function value, defined on ``[min(freqs), max(freqs)]``.
    freqs : frequency grid ``transf`` was built on.
    N : total number of samples folded into ``transf``; used by
        ``update_transf`` to weight each new estimate by its sample count.
    """

    def __init__(self, xdata, ydata, dt1, zdata=None, optimize=True,
                 p=None, **kwarg):
        self.dt1 = dt1
        # `is None`, not `== None`: p may be a numpy array, for which
        # `== None` is element-wise and cannot be used as a condition.
        if p is None:
            self.p = [len(xdata), 0, 0]
        else:
            self.p = p
        self.transf, self.freqs = self.get_transf(xdata, ydata, self.p, dt1)
        self.xdata = [xdata]
        self.ydata = [ydata]
        # Count samples (not datasets) so the running average in
        # update_transf, which adds len(xdata) per update, is consistent.
        self.N = len(xdata)

    def fit_cancel(self, array, n=2):
        """Return ``array`` with its least-squares degree-``n`` polynomial trend removed."""
        t = np.linspace(0, len(array), len(array))
        fit = np.polyfit(t, array, n)
        return array - np.polyval(fit, t)

    def chop(self, x, n, o):
        """Split ``x`` into detrended segments of length ``n`` overlapping by ``o``.

        Segment ``j`` starts at index ``j * (n - o)``; only complete segments
        are kept.  Returns a complex array of shape ``(num_segments, n)``.

        Raises
        ------
        ValueError
            If ``n <= o`` (segments would not advance) or ``n > len(x)``
            (not even one complete segment fits).
        """
        n = int(n)
        o = int(o)
        if n <= o:
            raise ValueError("segment length n must be greater than overlap o")
        if n > len(x):
            raise ValueError("segment length n exceeds the data length")
        step = n - o
        num_segs = (len(x) - o) // step
        # Builtin `complex`: `np.complex` was removed in NumPy 1.24.
        segs = np.zeros((num_segs, n), dtype=complex)
        for j in range(num_segs):
            start = j * step
            segs[j, :] = self.fit_cancel(x[start:start + n])
        return segs

    def window(self, segs, b):
        """Multiply every row of ``segs`` in place by a Kaiser window of parameter ``b``.

        Returns the same (modified) array.
        """
        taper = np.kaiser(segs.shape[1], b)
        segs[:] = segs * taper
        return segs

    def fourier(self, segs, dt):
        """Fourier transform every row of ``segs`` in place, given sample spacing ``dt``.

        Returns ``(freqs, segs)`` where ``freqs`` is the matching
        ``np.fft.fftfreq`` grid (unsorted, FFT order).
        """
        freqs = np.fft.fftfreq(segs.shape[1], dt)
        segs[:] = np.fft.fft(segs, axis=1)
        return freqs, segs

    def get_transf(self, xdata, ydata, p, dt):
        """Estimate the ``xdata`` -> ``ydata`` transfer function.

        Parameters
        ----------
        p : ``[n, o, b]`` segment length, overlap and Kaiser beta.
        dt : sample spacing.

        Returns
        -------
        (tf_spl, freqs) : linear ``interp1d`` of the complex transfer
            function over ``freqs``, and the frequency grid itself.
        """
        n = int(p[0])
        o = int(p[1])
        b = p[2]

        freqs, ftop = self.fourier(self.window(self.chop(ydata, n, o), b), dt)
        _, fbottom = self.fourier(self.window(self.chop(xdata, n, o), b), dt)

        # Average the cross spectrum and the input power over segments.
        numer = np.mean(ftop * np.conjugate(fbottom), axis=0)
        denomer = np.mean(fbottom * np.conjugate(fbottom), axis=0)

        tf = numer / denomer
        tf_spl = interp1d(freqs, tf)
        return tf_spl, freqs

    def update_transf(self, xdata, ydata, dt, p=None):
        """Fold a new dataset into the stored transfer function estimate.

        The new estimate is averaged with the stored one, each weighted by
        its number of samples.  Both estimates are evaluated on a common
        grid: the grid with the lower maximum frequency, restricted to the
        frequency range covered by *both* estimates, so neither
        interpolator is evaluated out of bounds.

        Parameters
        ----------
        dt : sample spacing of the new data.
        p : ``[n, o, b]`` for the new data; defaults to ``self.p``.
        """
        if p is None:
            p = self.p
        n = len(xdata)

        # Compute before touching any state so a failure leaves self intact.
        new_transf, new_freqs = self.get_transf(xdata, ydata, p, dt)

        if np.max(new_freqs) < np.max(self.freqs):
            grid = new_freqs
        else:
            grid = self.freqs
        lo = max(np.min(self.freqs), np.min(new_freqs))
        hi = min(np.max(self.freqs), np.max(new_freqs))
        grid = grid[(grid >= lo) & (grid <= hi)]

        self.xdata = xdata
        self.ydata = ydata
        self.N += n

        tf = (((self.N - n) / self.N) * self.transf(grid)
              + (n / self.N) * new_transf(grid))
        self.freqs = grid
        self.transf = interp1d(grid, tf)

    def optimize(self, xdata1, ydata1, xdata2, ydata2, dt,
                 n_values=None, overlaps=None, betas=None):
        """Grid-search the ``[n, o, b]`` parameters that best predict ``ydata2``.

        A transfer function is estimated from ``(xdata1, ydata1)`` with
        segment length ``len(xdata1) // n``.  It is applied to ``xdata2``,
        and the result is scored against ``ydata2`` with ``self.error``.
        Combinations whose overlap is not smaller than the segment length
        (or whose segment is longer than the data) are skipped.

        Parameters
        ----------
        n_values : divisors of ``len(xdata1)`` to try (default 1..20).
        overlaps : overlaps ``o`` to try (default 1..20).
        betas : Kaiser betas to try (default 20 values in [0, 14]).

        Returns
        -------
        ``[n, o, beta]`` with the lowest error, or ``[0, 0, 0]`` if no
        combination beats predicting an all-zero drive.
        """
        if n_values is None:
            n_values = np.linspace(1, 20, 20)
        if overlaps is None:
            overlaps = np.linspace(1, 20, 20)
        if betas is None:
            betas = np.linspace(0, 14, 20)

        # Baseline: error of predicting no drive at all.
        best_error = self.error(ydata2, np.zeros(len(ydata2)))
        best = [0, 0, 0]
        for n in n_values:
            seg_len = len(xdata1) // n
            for o in overlaps:
                # chop requires overlap < segment length <= data length.
                if not int(o) < int(seg_len) <= len(xdata1):
                    continue
                for beta in betas:
                    tf1, freqs1 = self.get_transf(
                        xdata1, ydata1, [seg_len, o, beta], dt)
                    drive = self.create_drive(
                        xdata2, dt2=dt, tf=tf1, freqs=freqs1)
                    err = self.error(ydata2, drive)
                    # Keep the *smallest* error.
                    if err < best_error:
                        best_error = err
                        best = [n, o, beta]
        return best

    def error(self, x, y):
        """Return the sum of absolute differences between ``x`` and ``y``."""
        return np.sum(np.absolute(np.subtract(x, y)))

    def create_drive(self, zdata, dt2=None, tf=None, freqs=None):
        """Apply a transfer function to ``zdata`` in the frequency domain.

        Parameters
        ----------
        zdata : input signal.
        dt2 : its sample spacing; defaults to ``self.dt1``.
        tf, freqs : transfer function callable and the grid it is defined
            on.  When ``tf`` is None, ``self.transf``/``self.freqs`` are
            used (and any given ``freqs`` is ignored).

        Returns
        -------
        Complex array ``ifft(tf(f) * fft(zdata))``.  Frequencies outside the
        range of ``freqs`` use the transfer function value at the nearest
        edge.
        """
        if dt2 is None:
            dt2 = self.dt1
        if tf is None:
            tf, freqs = self.transf, self.freqs
        spectrum = np.fft.fft(zdata)
        # Clamp to the grid tf is defined on: interp1d raises outside it.
        freq2 = np.clip(np.fft.fftfreq(len(zdata), dt2),
                        np.min(freqs), np.max(freqs))
        return np.fft.ifft(tf(freq2) * spectrum)