# Edit Distance Algorithm

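Edit distance is one of the ways to measure the "alignment" or "similarity" of two sequences in statistical methods: it is the minimum total cost of the insertions, deletions, and substitutions needed to turn one sequence into the other.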

In [7]:
def edit_distance(seq1: str, seq2: str, cost=(1, 1, 1)):
    """
    Weighted edit distance: the minimum total cost of turning seq1 into seq2.

    cost = [insertion cost, deletion cost, substitution cost]

    Args:
        seq1: Source sequence. Only len(), indexing and == on its elements are
            used, so any indexable sequence works, not just str.
        seq2: Target sequence, same requirements as seq1.
        cost: Indexable with at least three entries, in the order shown above.
            An insertion adds an element of seq2, a deletion removes an element
            of seq1, a substitution replaces one with the other. The default
            is a tuple so that no mutable object is shared between calls.

    Returns:
        The minimum total cost. Matching elements cost nothing.

    Raises:
        IndexError: if cost has fewer than three entries.
    """
    # Read the three costs once instead of indexing inside the inner loop.
    insert_cost, delete_cost, substitute_cost = cost[0], cost[1], cost[2]
    m, n = len(seq1), len(seq2)

    # Only the row above is ever consulted, so two rows of length n + 1 are
    # enough; the full (m+1) x (n+1) table is never needed.
    # Row 0: building seq2[:j] from an empty prefix takes j insertions.
    previous_row = [j * insert_cost for j in range(n + 1)]

    for i in range(1, m + 1):
        # Column 0: reducing seq1[:i] to an empty prefix takes i deletions.
        current_row = [i * delete_cost] + [0] * n
        for j in range(1, n + 1):
            if seq1[i - 1] == seq2[j - 1]:
                # Match: carry the diagonal value over at no cost. With
                # non-negative costs neither an insertion nor a deletion can
                # beat this, so the other candidates are not evaluated.
                current_row[j] = previous_row[j - 1]
            else:
                current_row[j] = min(
                    current_row[j - 1] + insert_cost,          # Insertion
                    previous_row[j] + delete_cost,             # Deletion
                    previous_row[j - 1] + substitute_cost,     # Substitution
                )
        previous_row = current_row

    # After the last pass previous_row holds the distances for all of seq1.
    return previous_row[n]

In [27]:
def run_edit_distance_tests():
    """
    Check edit_distance against a table of known input/output pairs.

    One line is printed per case: a "passed" line when the returned value
    equals the expected one, otherwise a "failed" line showing the inputs, the
    expected value and the value actually returned. A mismatch is only
    reported, never raised, and nothing is returned.
    """
    # Columns: seq1, seq2, cost as [insertion, deletion, substitution], expected.
    test_cases = (
        ("kitten", "sitting", [1, 1, 1], 3),  # unit costs (classic Levenshtein)
        ("kitten", "sitting", [1, 1, 2], 5),  # substitution costs as much as delete + insert
        ("flaw", "lawn", [2, 1, 1], 3),  # insertion is the expensive operation
        ("", "abc", [2, 1, 1], 6),  # empty source: three insertions at cost 2
        ("abc", "", [1, 2, 1], 6),  # empty target: three deletions at cost 2
        ("abc", "adc", [1, 1, 2], 2),  # one mismatch at substitution cost 2
        ("abc", "xyz", [1, 1, 2], 6),  # no shared characters, substitution cost 2
        ("", "", [1, 1, 1], 0),  # both inputs empty
        ("intention", "execution", [1, 1, 1], 5),  # longer pair with unit costs
    )

    for i, (seq1, seq2, cost, expected) in enumerate(test_cases):
        result = edit_distance(seq1, seq2, cost)
        if result != expected:
            # Report the mismatch and move on so every case is still run.
            print(f"Test case {i + 1} failed: seq1='{seq1}', seq2='{seq2}', cost={cost}, expected={expected}, got={result}")
            continue
        print(f"Test case {i + 1} passed!")

In [28]:
# Run every case in the test table; one pass/fail line is printed per case.
run_edit_distance_tests()

Test case 1 passed!
Test case 2 passed!
Test case 3 passed!
Test case 4 passed!
Test case 5 passed!
Test case 6 passed!
Test case 7 passed!
Test case 8 passed!
Test case 9 passed!
