In [181]:
# -*- coding: utf-8 -*-

class Dtw(object):
    
    def __init__(self, seq1, seq2,
                 patterns = [(-1,-1), (-1,0), (0,-1)], 
                 weights = [{(0,0):2}, {(0,0):1}, {(0,0):1}], 
                 band_r=0.5):
        self._seq1 = seq1
        self._seq2 = seq2
        self._r = min(10, band_r*max(len(seq1), len(seq2)))
        assert len(patterns) == len(weights)
        self._patterns = patterns
        self._weights = weights
        self._map = {(0, 0): 2*self.get_distance(0,0)}
    
    def get_distance(self, idx1, idx2):
        if idx1<0 or idx2<0 or idx1>=len(self._seq1) or idx2>=len(self._seq2):
            return 0
        return abs(self._seq1[idx1] - self._seq2[idx2])

    def calculate_path(self, idx1, idx2, pattern, weight):
        g = self.calculate(idx1+pattern[0], idx2+pattern[1])
        sum_d = 0
        for coor_offset, w in weight.items():
            i1, i2 = map(sum, zip((idx1, idx2), coor_offset))
            sum_d += self.get_distance(i1, i2)
        return g + sum_d

    def calculate(self, idx1, idx2):
        if (idx1, idx2) in self._map:
            return self._map[(idx1, idx2)]
        if idx1 < 0 or idx2 < 0 or abs(idx1-idx2) > self._r:
            return float('inf')
        min_prev_dp = float('inf')
        for i in range(len(self._patterns)):
            min_prev_dp = min(min_prev_dp, self.calculate_path(idx1, idx2, self._patterns[i], self._weights[i]))
        self._map[(idx1, idx2)] = min_prev_dp
        return self._map[(idx1, idx2)]
    
    @property
    def dtw_matrix_dict(self):
        return self._map

    def print_dtw_matrix(self):
        print('      '+' '.join(["{:^7d}".format(i) for i in range(len(self._seq2))]))
        for i in range(len(self._seq1)):
            str = "{:^4d}: ".format(i)
            for j in range(len(self._seq2)):
                if (i,j) not in self._map:
                    str += "{:^7s} ".format('-')
                else:
                    str += "{:^7.3f} ".format(self._map[(i,j)])
            print (str)
    
    def get_dtw(self):
        g = self.calculate(len(self._seq1)-1, len(self._seq2)-1)
        N = len(self._seq1) + len(self._seq2)
        return g/N

# Main

### Different patterns

In [182]:
PATTERNS_1 = [(0,-1), (-1,-1), (-1,0)]
WEIGHTS_SYM_1 = [{(0,0):1}, {(0,0):2}, {(0,0):1}] 
WEIGHTS_ASYM_1 = [{}, {(0,0):1}, {(0,0):1}] 

PATTERNS_2 = [(-1,-3), (-1,-2), (-1,-1), (-2,-1), (-3,-1)]
WEIGHTS_SYM_2 = [{(0,-2):2, (0,-1):1, (0,0):1}, \
                 {(0,-1):2, (0,0):1}, \
                 {(0,0):2}, \
                 {(-1,0):2, (0,0):1}, \
                 {(-2,0):2, (-1,1):1, (0,0):1}] 
WEIGHTS_ASYM_2 = [{(0,-2):1, (0,-1):1, (0,0):1/3}, \
                  {(0,-1):1, (0,0):1/2}, \
                  {(0,0):1}, \
                  {(-1,0):1, (0,0):1}, \
                  {(-2,0):1, (-1,1):1, (0,0):1}] 

### Initiation

In [183]:
import numpy as np
seq1 = [1, 1, 2, 9]*2
seq2 = [0, 1, 1, 2]*2

### Z-Normalization

In [184]:
seq1 = (np.array(seq1)-np.mean(seq1))/np.std(seq1)
seq2 = (np.array(seq2)-np.mean(seq2))/np.std(seq2)

### Calculate DTW

#### Symmetric Pattern 1
g(i, j) = min( g(i,j-1)+d(i,j),  g(i-1,j-1)+2d(i,j),  g(i-1,j)+d(i,j) )

In [185]:
d = Dtw(seq1, seq2, PATTERNS_1, WEIGHTS_SYM_1)
d.get_dtw()

0.30794740614419958

In [186]:
d.print_dtw_matrix()

         0       1       2       3       4       5       6       7   
 0  :  1.483   2.156   2.828   4.915   5.657     -       -       -    
 1  :  2.225   2.156   2.828   4.915   5.657   6.330     -       -    
 2  :  3.265   2.529   2.529   4.317   5.358   5.732   6.105     -    
 3  :  6.398   4.249   4.249   2.834   5.968   7.077   7.451   6.410  
 4  :  7.140   4.921   4.921   4.921   3.576   4.249   4.921   7.008  
 5  :    -     5.594   5.594   7.008   4.317   4.249   4.921   7.008  
 6  :    -       -     5.968   7.382   5.358   4.622   4.622   6.410  
 7  :    -       -       -     6.273   8.491   6.341   6.341   4.927  


#### Asymmetric Pattern 2
g(i, j) = min( g(i,j-1),  g(i-1,j-1)+d(i,j),  g(i-1,j)+d(i,j) )

In [187]:
d = Dtw(seq1, seq2, PATTERNS_1, WEIGHTS_ASYM_1)
d.get_dtw()

0.30364550643047394

In [188]:
d.print_dtw_matrix()

         0       1       2       3       4       5       6       7   
 0  :  1.483   1.483   1.483   1.483   1.483     -       -       -    
 1  :  2.225   2.156   2.156   2.156   2.156   2.156     -       -    
 2  :  3.265   2.529   2.529   2.529   2.529   2.529   2.529     -    
 3  :  6.398   4.249   4.249   2.834   2.834   2.834   2.834   2.834  
 4  :  7.140   4.921   4.921   4.921   3.576   3.507   3.507   3.507  
 5  :    -     5.594   5.594   5.594   4.317   4.180   4.180   4.180  
 6  :    -       -     5.968   5.968   5.358   4.553   4.553   4.553  
 7  :    -       -       -     6.273   6.273   6.273   6.273   4.858  


#### Symmetric Pattern 2

In [189]:
d = Dtw(seq1, seq2, PATTERNS_2, WEIGHTS_SYM_2)
d.get_dtw()

0.30794740614419958

In [190]:
d.print_dtw_matrix()

         0       1       2       3       4       5       6       7   
 0  :  1.483    inf     inf     inf     inf      -       -       -    
 1  :   inf    2.156   2.828   4.915    inf     inf      -       -    
 2  :   inf    2.529   2.529   4.317   5.358   6.031     -       -    
 3  :   inf    4.249   4.249   2.834   5.968   7.077     -       -    
 4  :   inf     inf    3.507   4.921   3.576   4.249   4.921     -    
 5  :    -      inf    5.594   5.594   4.317   4.249   4.921     -    
 6  :    -       -       -       -     5.289   4.622   4.622     -    
 7  :    -       -       -       -       -       -       -     4.927  


#### Asymmetric Pattern 2

In [191]:
d = Dtw(seq1, seq2, PATTERNS_2, WEIGHTS_ASYM_2)
d.get_dtw()

0.30794740614419958

In [192]:
d.print_dtw_matrix()

         0       1       2       3       4       5       6       7   
 0  :  1.483    inf     inf     inf     inf      -       -       -    
 1  :   inf    2.156   2.828   4.915    inf     inf      -       -    
 2  :   inf    2.529   2.529   4.317   5.358   6.031     -       -    
 3  :   inf    4.249   4.249   2.834   5.968   7.077     -       -    
 4  :   inf     inf    3.507   4.921   3.576   4.249   4.921     -    
 5  :    -      inf    5.594   5.594   4.317   4.249   4.921     -    
 6  :    -       -       -       -     5.289   4.622   4.622     -    
 7  :    -       -       -       -       -       -       -     4.927  
