In [None]:
class LinearSpecial():
    """Bilinear interpolator for a sign-ambiguous 2D vector field on a regular grid.

    The field ``xi`` is indexed as ``xi[component, i, j]`` with ``i`` running
    along ``xc`` and ``j`` along ``yc``. The vectors are treated as defined only
    up to sign:
    - the four corners of a cell are aligned with the lower-left corner before
      interpolating;
    - the result is aligned with ``self.fold`` (typically the previously
      returned vector) when that attribute is set.

    Instances are callable as ``f(X, t)`` so they can be handed to an ODE
    stepper; ``t`` is accepted but unused.
    """

    def __init__(self, xc, yc, xi):
        self.xc = xc
        self.yc = yc
        # Grid spacing; only the first two nodes are used, so the grid is
        # assumed to be uniform.
        self.dx = xc[1] - xc[0]
        self.dy = yc[1] - yc[0]
        self.xi = xi
        self.Nx = len(xc)
        self.Ny = len(yc)
        # Reference direction for orientation; None means "no preference".
        self.fold = None

    def __call__(self, X, t):
        """Return the interpolated, orientation-aligned field vector at ``X``.

        Parameters
        ----------
        X : array-like of length 2
            Query point ``(x, y)``.
        t : float
            Ignored; present so the instance matches an ``f(x, t)`` signature.

        Returns
        -------
        ndarray of shape (2,)
            Bilinearly interpolated vector. If ``self.fold`` is set and the
            vector points against it (negative dot product), it is negated.

        Raises
        ------
        IndexError
            If ``X`` falls outside the cells this interpolator accepts.
        """
        # Indices of the lower-left corner of the cell containing X
        i = np.floor((X[0] - self.xc[0]) / self.dx).astype(np.int32)
        j = np.floor((X[1] - self.yc[0]) / self.dy).astype(np.int32)

        # Outside the accepted cells: signal the caller to stop. Note this bound
        # also rejects the last cell (i == Nx-2 / j == Ny-2), whose corners would
        # still be in range.
        if (i >= self.Nx - 2) or (j >= self.Ny - 2) or (i < 0) or (j < 0):
            raise IndexError(f"point {X} is outside the interpolation domain")

        # Use the lower-left corner as reference and flip any corner vector
        # pointing against it (orientational discontinuity of the field).
        subxi = self.xi[:, i:i+2, j:j+2].copy()
        dotp = np.sum(subxi[:, 0, 0].reshape(2, 1, 1) * subxi, axis=0)
        subxi[:, dotp < 0] = -subxi[:, dotp < 0]

        # Bilinear weights: Wx0/Wy0 weight the lower/left nodes (i, j)
        Wx0 = (self.xc[i+1] - X[0]) / self.dx
        Wx1 = 1 - Wx0
        Wy0 = (self.yc[j+1] - X[1]) / self.dy
        Wy1 = 1 - Wy0

        V = Wy0*(Wx0*subxi[:, 0, 0] + Wx1*subxi[:, 1, 0]) + Wy1*(Wx0*subxi[:, 0, 1] + Wx1*subxi[:, 1, 1])

        # Align with the previous direction. Only a strictly negative dot product
        # flips V; np.sign would return 0 for a perpendicular V and zero it out.
        if self.fold is not None and np.dot(V, self.fold) < 0:
            return -V
        return V
            
def half_strainline(x0, Tmax, h, f, xc, yc, lambda2, ABtrue, pm, max_notAB = 0.3, t=0):
    """Integrate one half (direction ``pm``) of a strainline starting at ``x0``.

    Parameters
    ----------
    x0 : array-like of length 2
        Starting point.
    Tmax : float
        Final value of the integration parameter. ``int((Tmax - t) / h)``
        point slots are allocated (at least one, for ``x0``).
    h : float
        Step size; it is multiplied by ``pm`` before each ``rk4`` step.
    f : callable
        Direction field called as ``f(x, t)``. It must expose a ``fold``
        attribute, which is set to the direction at the current point before
        each step. It is expected to raise ``IndexError`` outside its domain,
        which ends the line.
    xc, yc : 1D arrays
        Grid node coordinates used to find the nearest node for
        ``lambda2`` / ``ABtrue`` lookups.
    lambda2 : 2D array
        Values indexed ``[i, j]`` on the grid; averaged along the line.
    ABtrue : 2D bool array
        Indexed like ``lambda2``; marks nodes where conditions A and B hold.
    pm : int
        Integration direction, +1 or -1.
    max_notAB : float
        The line stops once the length travelled since the last node
        satisfying A and B exceeds this value.
    t : float
        Integration parameter passed unchanged to ``f`` and ``rk4``.

    Returns
    -------
    xs : ndarray, shape (2, n)
        Accepted points along the line, starting with ``x0``.
    length : float
        Polyline length of the accepted segments.
    mulambda : float
        Length-weighted mean of ``lambda2`` along the line; 0.0 if ``length`` is 0.
    """
    # Re-initialise the orientation reference so nothing leaks from a previous line
    f.fold = None

    Nt = int((Tmax - t) / h)
    # Always keep room for the starting point. Previously Nt <= 1 left ``n``
    # unbound at the return, and Nt <= 0 failed when storing x0.
    xs = np.zeros((2, max(Nt, 1)))
    xs[:, 0] = x0
    dx = xc[1] - xc[0]
    dy = yc[1] - yc[0]
    imax = lambda2.shape[0] - 1
    jmax = lambda2.shape[1] - 1

    # Buffer zone outside the domain [0, 2] x [0, 1]
    xbuf = 0.01
    ybuf = 0.005

    # Parameters of strainline
    length = 0.0
    notABlength = 0.0
    mulambda = 0.0
    # Number of accepted points in xs (x0 is always accepted)
    n_valid = 1

    for n in range(1, Nt):
        try:
            # An accepted RK4 end point is only checked against the domain plus
            # buffer below, so it may still lie outside f's interpolation range.
            # This call can therefore raise too and must be inside the try.
            f.fold = f(xs[:, n-1], t)
            xs[:, n] = rk4(xs[:, n-1], t, pm*h, f)
        except IndexError:
            break
        x, y = xs[0, n], xs[1, n]
        if x < (0.0 - xbuf) or (2.0 + xbuf) < x or y < (0.0 - ybuf) or (1.0 + ybuf) < y:
            break
        if notABlength > max_notAB:
            break
        # Point n is accepted (so a full run keeps the final point too)
        n_valid = n + 1

        # Increment length
        dl = np.sqrt(np.sum((xs[:, n] - xs[:, n-1])**2))
        length += dl
        # Nearest grid node, clipped so points in the buffer zone neither wrap
        # around via a negative index nor run past the end of the arrays
        i = int(np.clip(np.floor((x + dx/2 - xc[0]) / dx), 0, imax))
        j = int(np.clip(np.floor((y + dy/2 - yc[0]) / dy), 0, jmax))
        # Running, length-weighted total of lambda2
        mulambda += lambda2[i, j] * dl
        # Reset the A/B counter when the conditions are satisfied
        if ABtrue[i, j]:
            notABlength = 0.0
        else:
            notABlength += dl

    mulambda = mulambda / length if length > 0 else 0.0
    return xs[:, :n_valid], length, mulambda

def strainline(x0, Tmax, h, f, xc, yc, lambda2, ABtrue, max_notAB = 0.3, t=0):
    """Integrate a complete strainline through ``x0`` in both directions.

    ``half_strainline`` is run forward (``pm=+1``) and backward (``pm=-1``).
    The two halves are joined into one polyline that runs from the far end of
    the forward half, through ``x0``, to the far end of the backward half.

    Returns
    -------
    line : ndarray, shape (2, N)
        Joined points, with the shared seed point ``x0`` included once.
    length : float
        Sum of the two half-line lengths.
    mulambda : float
        Length-weighted mean of the halves' ``mulambda`` values; 0.0 when
        the total length is 0.
    """
    shared = dict(max_notAB=max_notAB, t=t)
    fwd_pts, fwd_len, fwd_mu = half_strainline(x0, Tmax, h, f, xc, yc, lambda2, ABtrue, pm=+1, **shared)
    bwd_pts, bwd_len, bwd_mu = half_strainline(x0, Tmax, h, f, xc, yc, lambda2, ABtrue, pm=-1, **shared)

    total = fwd_len + bwd_len
    mean_lambda = (fwd_len*fwd_mu + bwd_len*bwd_mu) / total if total > 0 else 0.0

    # Reverse the forward half so it ends at x0, then drop the backward half's
    # leading copy of x0 so the seed point is not duplicated.
    joined = np.concatenate((fwd_pts[:, ::-1], bwd_pts[:, 1:]), axis=1)
    return joined, total, mean_lambda