In [1]:
# -*- coding: utf-8 -*-

class Dtw(object):
    
    def __init__(self, seq1, seq2,
                 patterns = [(-1,-1), (-1,0), (0,-1)], 
                 weights = [{(0,0):2}, {(0,0):1}, {(0,0):1}], 
                 band_r=0.5):
        self._seq1 = seq1
        self._seq2 = seq2
        self._r = min(10, band_r*max(len(seq1), len(seq2)))
        assert len(patterns) == len(weights)
        self._patterns = patterns
        self._weights = weights
        self._map = {(-1, -1): 0}
    
    def get_distance(self, idx1, idx2):
        if idx1<0 or idx2<0 or idx1>=len(self._seq1) or idx2>=len(self._seq2):
            return 0
        return abs(self._seq1[idx1] - self._seq2[idx2])

    def calculate_path(self, idx1, idx2, pattern, weight):
        g = self.calculate(idx1+pattern[0], idx2+pattern[1])
        sum_d = 0
        for coor_offset, w in weight.items():
            i1, i2 = map(sum, zip((idx1, idx2), coor_offset))
            sum_d += self.get_distance(i1, i2)
        return g + sum_d

    def calculate(self, idx1, idx2):
        if (idx1, idx2) in self._map:
            return self._map[(idx1, idx2)]
        if idx1 < 0 or idx2 < 0 or abs(idx1-idx2) > self._r:
            return float('inf')
        min_prev_dp = float('inf')
        for i in range(len(self._patterns)):
            min_prev_dp = min(min_prev_dp, self.calculate_path(idx1, idx2, self._patterns[i], self._weights[i]))
        self._map[(idx1, idx2)] = min_prev_dp
        return self._map[(idx1, idx2)]
    
    @property
    def dtw_matrix_dict(self):
        return self._map

    def print_dtw_matrix(self):
        print('      '+' '.join(["{:^7d}".format(i) for i in range(len(self._seq2))]))
        for i in range(len(self._seq1)):
            str = "{:^4d}: ".format(i)
            for j in range(len(self._seq2)):
                if (i,j) not in self._map:
                    str += "{:^7s} ".format('-')
                else:
                    str += "{:^7.3f} ".format(self._map[(i,j)])
            print (str)
    
    def get_dtw(self):
        g = self.calculate(len(self._seq1)-1, len(self._seq2)-1)
        N = len(self._seq1) + len(self._seq2)
        return g/N

# Main

### Different patterns

In [14]:
PATTERNS_1 = [(0,-1), (-1,-1), (-1,0)]
WEIGHTS_SYM_1 = [{(0,0):1}, {(0,0):2}, {(0,0):1}] 
WEIGHTS_ASYM_1 = [{}, {(0,0):1}, {(0,0):1}] 

PATTERNS_2 = [(-1,-3), (-1,-2), (-1,-1), (-2,-1), (-3,-1)]
WEIGHTS_SYM_2 = [{(0,-2):2, (0,-1):1, (0,0):1}, \
                 {(0,-1):2, (0,0):1}, \
                 {(0,0):2}, \
                 {(-1,0):2, (0,0):1}, \
                 {(-2,0):2, (-1,1):1, (0,0):1}] 
WEIGHTS_ASYM_2 = [{(0,-2):1, (0,-1):1, (0,0):1/3}, \
                  {(0,-1):1, (0,0):1/2}, \
                  {(0,0):1}, \
                  {(-1,0):1, (0,0):1}, \
                  {(-2,0):1, (-1,1):1, (0,0):1}] 

### Initiation

In [10]:
import numpy as np
seq1 = [1, 1, 2, 9]*2
seq2 = [0, 1, 1, 2]*2

### Z-Normalization

In [11]:
seq1 = (np.array(seq1)-np.mean(seq1))/np.std(seq1)
seq2 = (np.array(seq2)-np.mean(seq2))/np.std(seq2)

### Calculate DTW

#### Symmetric Pattern 1
g(i, j) = min( g(i,j-1)+d(i,j),  g(i-1,j-1)+2d(i,j),  g(i-1,j)+d(i,j) )

In [5]:
d = Dtw(seq1, seq2, PATTERNS_1, WEIGHTS_SYM_1)
d.get_dtw()

0.26160228246317752

In [6]:
d.print_dtw_matrix()

         0       1       2       3       4       5       6       7   
 0  :  0.742   1.414   2.087   4.174   4.915     -       -       -    
 1  :  1.483   1.414   2.087   4.174   4.915   5.588     -       -    
 2  :  2.524   1.788   1.788   3.576   4.616   4.990   5.364     -    
 3  :  5.657   3.507   3.507   2.093   5.226   6.335   6.709   5.669  
 4  :  6.398   4.180   4.180   4.180   2.834   3.507   4.180   6.267  
 5  :    -     4.852   4.852   6.267   3.576   3.507   4.180   6.267  
 6  :    -       -     5.226   6.640   4.616   3.881   3.881   5.669  
 7  :    -       -       -     5.531   7.750   5.600   5.600   4.186  


#### Asymmetric Pattern 2
g(i, j) = min( g(i,j-1),  g(i-1,j-1)+d(i,j),  g(i-1,j)+d(i,j) )

In [7]:
d = Dtw(seq1, seq2, PATTERNS_1, WEIGHTS_ASYM_1)
d.get_dtw()

0.25730038274945188

In [8]:
d.print_dtw_matrix()

         0       1       2       3       4       5       6       7   
 0  :  0.742   0.742   0.742   0.742   0.742     -       -       -    
 1  :  1.483   1.414   1.414   1.414   1.414   1.414     -       -    
 2  :  2.524   1.788   1.788   1.788   1.788   1.788   1.788     -    
 3  :  5.657   3.507   3.507   2.093   2.093   2.093   2.093   2.093  
 4  :  6.398   4.180   4.180   4.180   2.834   2.766   2.766   2.766  
 5  :    -     4.852   4.852   4.852   3.576   3.438   3.438   3.438  
 6  :    -       -     5.226   5.226   4.616   3.812   3.812   3.812  
 7  :    -       -       -     5.531   5.531   5.531   5.531   4.117  


#### Symmetric Pattern 2

In [9]:
d = Dtw(seq1, seq2, PATTERNS_2, WEIGHTS_SYM_2)
d.get_dtw()

0.26160228246317752

In [10]:
d.print_dtw_matrix()

         0       1       2       3       4       5       6       7   
 0  :  0.742   1.414   2.087    inf     inf      -       -       -    
 1  :  1.483   1.414   2.087   4.174   4.915   5.588     -       -    
 2  :  2.455   1.788   1.788   3.576   4.616   5.289     -       -    
 3  :   inf    3.507   3.507   2.093   5.226   6.335     -       -    
 4  :   inf    4.249   2.766   4.180   2.834   3.507   4.180     -    
 5  :    -     5.519   4.852   4.852   3.576   3.507   4.180     -    
 6  :    -       -       -       -     4.548   3.881   3.881     -    
 7  :    -       -       -       -       -       -       -     4.186  


#### Asymmetric Pattern 2

In [11]:
d = Dtw(seq1, seq2, PATTERNS_2, WEIGHTS_ASYM_2)
d.get_dtw()

0.26160228246317752

In [12]:
d.print_dtw_matrix()

         0       1       2       3       4       5       6       7   
 0  :  0.742   1.414   2.087    inf     inf      -       -       -    
 1  :  1.483   1.414   2.087   4.174   4.915   5.588     -       -    
 2  :  2.455   1.788   1.788   3.576   4.616   5.289     -       -    
 3  :   inf    3.507   3.507   2.093   5.226   6.335     -       -    
 4  :   inf    4.249   2.766   4.180   2.834   3.507   4.180     -    
 5  :    -     5.519   4.852   4.852   3.576   3.507   4.180     -    
 6  :    -       -       -       -     4.548   3.881   3.881     -    
 7  :    -       -       -       -       -       -       -     4.186  


# Implementation2

In [35]:
# -*- coding: utf-8 -*-

class Dtw(object):
    
    def __init__(self, seq1, seq2,
                 patterns = [(-1,-1), (-1,0), (0,-1)], 
                 weights = [{(0,0):2}, {(0,0):1}, {(0,0):1}], 
                 band_r=5):
        self._seq1 = seq1
        self._seq2 = seq2
        self.len_seq1 = len(seq1)
        self.len_seq2 = len(seq2)
        self.len_pattern = len(patterns)
        self.sum_w = [sum(ws.values()) for ws in weights]
        self._r = band_r
        assert len(patterns) == len(weights)
        self._patterns = patterns
        self._weights = weights
    
    def get_distance(self, i1, i2):
        return abs(self._seq1[i1] - self._seq2[i2])

    def calculate(self):
        g = list([float('inf')]*self.len_seq2 for i in range(self.len_seq1))
        cost = list([0]*self.len_seq2 for i in range(self.len_seq1))

        g[0][0] = 2*self.get_distance(0, 0)
        for i in range(self.len_seq1):
            for j in range(max(0,i-self._r), min(i+self._r+1, self.len_seq2)):
                for pat_i in range(self.len_pattern):
                    coor = (i+self._patterns[pat_i][0], j+self._patterns[pat_i][1])
                    if coor[0]<0 or coor[1]<0:
                        continue
                    dist = 0
                    for w_coor_offset, d_w in self._weights[pat_i].items():
                        w_coor = (i+w_coor_offset[0], j+w_coor_offset[1])
                        dist += self.get_distance(w_coor[0], w_coor[1])
                    this_val = g[coor[0]][coor[1]] + dist
                    this_cost = cost[coor[0]][coor[1]] + self.sum_w[pat_i]
                    if this_val < g[i][j]:
                        g[i][j] = this_val
                        cost[i][j] = this_cost
        return g[self.len_seq1-1][self.len_seq2-1]/cost[self.len_seq1-1][self.len_seq2-1], g, cost
    
    def print_table(self, tb):
        print('      '+' '.join(["{:^7d}".format(i) for i in range(self.len_seq2)]))
        for i in range(self.len_seq1):
            str = "{:^4d}: ".format(i)
            for j in range(self.len_seq2):
                str += "{:^7.3f} ".format(tb[i][j])
            print (str)

    def print_g_matrix(self):
        _, tb, _ = self.calculate()
        self.print_table(tb)

    def print_cost_matrix(self):
        _, _, tb = self.calculate()
        self.print_table(tb)
        
    def get_dtw(self):
        ans, _, _ = self.calculate()
        return ans

In [48]:
d = Dtw(seq1, seq2, PATTERNS_1, WEIGHTS_ASYM_1)
d.get_dtw()

0.69404687184108327

In [49]:
d.print_g_matrix()

         0       1       2       3       4       5       6       7   
 0  :  1.483   1.483   1.483   1.483   1.483   1.483    inf     inf   
 1  :  2.225   2.156   2.156   2.156   2.156   2.156   2.156    inf   
 2  :  3.265   2.529   2.529   2.529   2.529   2.529   2.529   2.529  
 3  :  6.398   4.249   4.249   2.834   2.834   2.834   2.834   2.834  
 4  :  7.140   4.921   4.921   4.921   3.576   3.507   3.507   3.507  
 5  :  7.881   5.594   5.594   5.594   4.317   4.180   4.180   4.180  
 6  :   inf    5.968   5.968   5.968   5.358   4.553   4.553   4.553  
 7  :   inf     inf    7.687   6.273   6.273   6.273   6.273   4.858  


In [50]:
d.print_cost_matrix()

         0       1       2       3       4       5       6       7   
 0  :  0.000   0.000   0.000   0.000   0.000   0.000   0.000   0.000  
 1  :  1.000   1.000   1.000   1.000   1.000   1.000   1.000   0.000  
 2  :  2.000   2.000   2.000   2.000   2.000   2.000   2.000   2.000  
 3  :  3.000   3.000   3.000   3.000   3.000   3.000   3.000   3.000  
 4  :  4.000   4.000   4.000   4.000   4.000   4.000   4.000   4.000  
 5  :  5.000   5.000   5.000   5.000   5.000   5.000   5.000   5.000  
 6  :  0.000   6.000   6.000   6.000   6.000   6.000   6.000   6.000  
 7  :  0.000   0.000   7.000   7.000   7.000   7.000   7.000   7.000  
