In [2]:
import time

import numpy as np


class STDistance():
    def __init__(self, a: np.ndarray, b: np.ndarray):
        # dist matrix
        # cost matrix
        pass

In [144]:
from scipy.spatial.distance import cdist, euclidean
from typing import Callable, Tuple
import numpy as np


class STDTW:
    def __init__(self, a: np.ndarray, b: np.ndarray, dist: Callable[[np.ndarray, np.ndarray], float]):
        self.a = a
        self.b = b
        self.dist_func = dist
        self.dist_matrix = cdist(a, b, self.dist_func)
        self.cost, self.path, self.C, self.P = self.full()

    def move(self, a_remove: int = 0, b_remove: int = 0, a_add: np.ndarray = None, b_add: np.ndarray = None):
        next_start:Tuple[int,int] = (a_remove, b_remove)
        if next_start not in self.path:
            if a_remove>0:
                self.a = self.a[a_remove:]
            if b_remove>0:
                self.b = self.b[b_remove:]
            if a_add is not None:
                self.a = np.vstack([self.a,a_add])
            if b_add is not None:
                self.b = np.vstack([self.b,b_add])

            self.dist_matrix = cdist(self.a, self.b, self.dist_func)

            self.cost, self.path, self.C, self.P = self.full()

            return self.cost
        else:
            if a_add is not None or b_add is not None:
                self.cost = self.add(a_add, b_add)
            next_start:Tuple[int,int] = (a_remove, b_remove)
            if next_start in self.path:
                if a_remove > 0 or b_remove > 0:
                    self.cost = self.remove(a_remove, b_remove)
            else:
                if a_remove > 0 or b_remove > 0:
                    if a_remove>0:
                        self.a = self.a[a_remove:]
                    if b_remove>0:
                        self.b = self.b[b_remove:]
                    self.dist_matrix = cdist(self.a, self.b, self.dist_func)
                    self.cost, self.path, self.C, self.P = self.full()
            # return and call full in async
            return self.cost

    def remove(self, a_remove, b_remove):
        # hit
        self.cost = self.cost - (self.C[a_remove+1,b_remove+1]-self.dist_matrix[a_remove,b_remove])
        return self.cost

    def add(self, a_add: np.ndarray, b_add: np.ndarray):
        if a_add is not None:
            self.dist_matrix = np.vstack([self.dist_matrix, cdist(a_add, self.b)])
            self.a = np.vstack([self.a, a_add])
        if b_add is not None:
            self.dist_matrix = np.hstack([self.dist_matrix, cdist(self.a, b_add)])
            self.b = np.vstack([self.b, b_add])
        t0 = self.a
        t1 = self.b
        n0 = len(t0)
        n1 = len(t1)
        C = np.zeros((n0 + 1, n1 + 1))
        C[1:, 0] = float('inf')
        C[0, 1:] = float('inf')
        P = np.zeros((n0 + 1, n1 + 1, 2), dtype=int)
        ori0, ori1 = self.C.shape
        C[:ori0, :ori1] = self.C
        P[:ori0, :ori1] = self.P
        for i in np.arange(1,ori0):
            for j in np.arange(ori1,n1+1):
                prev = np.argmin([C[i, j - 1], C[i - 1, j - 1], C[i - 1, j]])
                if prev == 0:
                    prev_i = i
                    prev_j = j - 1
                elif prev == 1:
                    prev_i = i - 1
                    prev_j = j - 1
                else:
                    prev_i = i - 1
                    prev_j = j
                C[i, j] = self.dist_matrix[i - 1, j - 1] + C[prev_i, prev_j]
                P[i, j] = [prev_i, prev_j]
        for i in np.arange(ori0,n1+1):
            for j in np.arange(1,n1+1):
                prev = np.argmin([C[i, j - 1], C[i - 1, j - 1], C[i - 1, j]])
                if prev == 0:
                    prev_i = i
                    prev_j = j - 1
                elif prev == 1:
                    prev_i = i - 1
                    prev_j = j - 1
                else:
                    prev_i = i - 1
                    prev_j = j
                C[i, j] = self.dist_matrix[i - 1, j - 1] + C[prev_i, prev_j]
                P[i, j] = [prev_i, prev_j]
        dtw = C[n0, n1]
        pt = (n0, n1)
        path:list = [(pt[0]-1,pt[1]-1)]
        while pt != (1,1):
            pt = tuple(P[pt[0], pt[1]])
            path.append((pt[0]-1,pt[1]-1))
        # path -= 1
        path.reverse()
        self.cost, self.path, self.C, self.P = dtw, path, C, P
        return dtw

    def full(self):
        t0 = self.a
        t1 = self.b
        n0 = len(t0)
        n1 = len(t1)
        C = np.zeros((n0 + 1, n1 + 1))
        P = np.zeros((n0 + 1, n1 + 1, 2), dtype=int)
        C[1:, 0] = float('inf')
        C[0, 1:] = float('inf')
        for i in np.arange(n0) + 1:
            for j in np.arange(n1) + 1:
                prev = np.argmin([C[i, j - 1], C[i - 1, j - 1], C[i - 1, j]])
                if prev == 0:
                    prev_i = i
                    prev_j = j - 1
                elif prev == 1:
                    prev_i = i - 1
                    prev_j = j - 1
                else:
                    prev_i = i - 1
                    prev_j = j
                C[i, j] = self.dist_matrix[i - 1, j - 1] + C[prev_i, prev_j]
                P[i, j] = [prev_i, prev_j]
        dtw = C[n0, n1]
        pt = (n0, n1)
        path:list = [(pt[0]-1,pt[1]-1)]
        while pt != (1,1):
            pt = tuple(P[pt[0], pt[1]])
            path.append((pt[0]-1,pt[1]-1))
        path.reverse()
        return dtw, path, C, P


# print(STDTW(np.asarray([[1], [2], [4], [5], [1], [2]]), np.asarray([[2], [5], [6], [7], [3], [4]]), euclidean).cost)
# print(STDTW(np.asarray([[1], [2], [4], [5]]), np.asarray([[2], [5], [6], [7]]), euclidean).add(np.asarray([[1], [2]]),
#                                                                                                np.asarray([[3], [4]])))
a=np.asarray([[10],
 [ 8],
 [11],
 [ 0],
 [13],
 [ 0],
 [15]])
b=np.asarray([[10],
 [14],
 [16],
 [ 6],
 [16],
 [12],
 [ 9]])

truth = STDTW(a[2:],b[2:],euclidean)
our = STDTW(a[:5],b[:5],euclidean)
our.move(2,2,a_add=a[5:],b_add=b[5:])
print(truth.cost,our.cost)
print(truth.C)
print(our.C)

30.0 30.0
[[ 0. inf inf inf inf inf]
 [inf  5. 10. 15. 16. 18.]
 [inf 21. 11. 26. 27. 25.]
 [inf 24. 18. 14. 15. 19.]
 [inf 40. 24. 30. 26. 24.]
 [inf 41. 33. 25. 28. 30.]]
[[ 0. inf inf inf inf inf]
 [inf  5. 10. 15. 16. 18.]
 [inf 21. 11. 26. 27. 25.]
 [inf 24. 18. 14. 15. 19.]
 [inf 40. 24. 30. 26. 24.]
 [inf 41. 33. 25. 28. 30.]]


In [129]:
a = np.random.randint(0,20,(7,1))
b = np.random.randint(0,20,(7,1))
%time truth = STDTW(a[2:],b[2:],euclidean).cost
our = STDTW(a[:5],b[:5],euclidean)
%time our.move(2,2,a[5:],b[5:])
assert truth==our.cost, f"{a}\n{b}"

CPU times: user 1.48 ms, sys: 52 µs, total: 1.53 ms
Wall time: 1.37 ms
CPU times: user 1.22 ms, sys: 0 ns, total: 1.22 ms
Wall time: 974 µs


In [148]:
from tqdm.autonotebook import tqdm
import time
from fastdtw import dtw
# a=[[10],
#  [ 8],
#  [11],
#  [ 0],
#  [13],
#  [ 0],
#  [15]]
# b=[[10],
#  [14],
#  [16],
#  [ 6],
#  [16],
#  [12],
#  [ 9]]
loose = 0
for _ in tqdm(range(1000)):
    a = np.random.randint(0,200,(100,1))
    b = np.random.randint(0,200,(100,1))
    t1 = time.perf_counter_ns()
    _ = STDTW(a[:60],b[:60],euclidean)
    truth = STDTW(a[2:],b[2:],euclidean)
    t2 = time.perf_counter_ns()
    our = STDTW(a[:60],b[:60],euclidean)
    o = our.move(2,2,a_add=a[60:],b_add=b[60:])
    t3 = time.perf_counter_ns()
    if (t3-t2>t2-t1):
        loose+=1
        print(t3-t2,t2-t1)
    assert truth.cost==o
print(loose)

  0%|          | 0/1000 [00:00<?, ?it/s]

113446791 104256268
110918937 102869149
106461157 100819411
212827477 103809441
112315832 108655457
127458183 102581240
121888142 104385359
123207313 104722000
104830233 103469173
130621399 110627691
110285711 108438611
131606694 111163705
116174280 108037810
141930731 121760451
142230251 117026987
119617549 115723435
135789618 120023890
114271214 112452088
119552487 113102927
138587619 115919683
141235228 134489473
135804626 113083470
133003239 112445816
132179857 113731795
227965741 177158079
142124403 116898496
204011796 115540993
114665493 114582569
126227459 122999713
134512586 113904178
138108441 112407194
139769261 114353961
143927059 112572083
220399579 114193520
131858716 113149083
164354679 132613720
36
