In [1]:
import copy
import math
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
import numpy as np
from random import shuffle
from heapq import heappop, heappush, heapify
from random import randint
import time
from IPython.display import HTML
from PIL import Image, ImageDraw
from IPython.display import Image as Img
from sys import float_info
EPS = float_info.epsilon
%matplotlib notebook
%load_ext autoreload
%autoreload 2

### Search Node Representation

In [148]:
class Node:
    """Search node identified by its coordinates.

    Attributes (all set in __init__):
      coords -- position of the node; the only field used by == and hash(),
                so it must be hashable (e.g. a tuple) for the node to be hashed
      g      -- cost accumulated on the way to this node
      h      -- heuristic value stored with the node
      F      -- priority value; g + h unless passed explicitly
      parent -- node this one was reached from (None by default)
      k      -- extra tie-break counter used by __lt__, initialised to 0
    """

    def __init__(self, coords, g = 0, h = 0, F = None, parent = None):
        self.coords = coords  # tuple of coords
        self.g = g
        self.h = h
        self.F = self.g + h if F is None else F
        self.parent = parent
        self.k = 0

    def __eq__(self, other):
        # Returning NotImplemented (instead of touching other.coords) lets
        # `node == None` or `node in some_list` evaluate to False rather than
        # raising AttributeError on foreign objects.
        if not isinstance(other, Node):
            return NotImplemented
        return self.coords == other.coords

    def __lt__(self, other):
        """Order by F, then by h, then prefer the larger k."""
        if not isinstance(other, Node):
            return NotImplemented
        if self.F != other.F:
            return self.F < other.F
        if self.h != other.h:
            return self.h < other.h
        return self.k > other.k

    def __hash__(self):
        # Must stay consistent with __eq__: both look at coords only.
        return hash(self.coords)

In [52]:
import heapq
import itertools

# Sentinel written into a heap entry's node slot by Open when that entry is
# superseded; Open only ever tests it by identity (`is not REMOVED`).
REMOVED = Node([-1,-1], -1)


class Open:
    """Priority queue of nodes ordered by F, ties broken by insertion order.

    Heap entries are mutable lists ``[F, insertion_count, node]``. A superseded
    entry is not removed from the heap; its node slot is overwritten with the
    REMOVED sentinel and the entry is skipped when it is eventually popped.
    ``entry_finder`` maps ``(coords, g)`` to the live entry for that key, so it
    always holds exactly the entries that are still valid.
    """

    def __init__(self):
        self.elements = []              # heap of [F, count, node] entries
        self.count = itertools.count()  # unique, increasing tie-breaker
        self.entry_finder = {}          # (coords, g) -> live heap entry

    def __iter__(self):
        """Yield the live nodes in heap-storage order (not sorted by F)."""
        for entry in self.elements:
            if entry[2] is not REMOVED:
                yield entry[2]

    def __len__(self):
        # Count live nodes only: self.elements may still contain REMOVED
        # placeholders that GetBestNode would silently discard.
        return len(self.entry_finder)

    def isEmpty(self):
        """Return True when no live node is left.

        Checking the heap length instead would report "not empty" while only
        REMOVED placeholders remain, after which GetBestNode returns None.
        """
        return not self.entry_finder

    def AddNode(self, node : Node, *args):
        """Insert ``node``; extra positional arguments are accepted and ignored.

        If a live entry with the same (coords, g) exists, it is kept when its F
        is strictly smaller than the new node's F; otherwise it is marked
        REMOVED and replaced by the new node.
        """
        key = (node.coords, node.g)
        existing = self.entry_finder.get(key)
        if existing is not None:
            if existing[2].F < node.F:
                return
            del self.entry_finder[key]
            existing[2] = REMOVED
        entry = [node.F, next(self.count), node]
        self.entry_finder[key] = entry
        heapq.heappush(self.elements, entry)

    def GetBestNode(self, *args):
        """Pop and return the live node with the smallest F, or None if none is left.

        Extra positional arguments are accepted and ignored.
        """
        while self.elements:
            entry = heapq.heappop(self.elements)
            if entry[2] is not REMOVED:
                del self.entry_finder[(entry[2].coords, entry[2].g)]
                return entry[2]
        return None

In [53]:
class Closed:
    """Collection of nodes that have already been expanded, backed by a set."""

    def __init__(self):
        # Membership relies on the items' own __hash__ / __eq__.
        self.elements = set()

    def __len__(self):
        """Number of stored items."""
        return len(self.elements)

    def __iter__(self):
        """Iterate over the stored items (set order)."""
        return iter(self.elements)

    def WasExpanded(self, item : Node):
        """Return True if an item equal to ``item`` has been added."""
        return item in self.elements

    def AddNode(self, item : Node):
        """Record ``item``; adding an item equal to a stored one changes nothing."""
        self.elements.add(item)

In [2]:
import utils

# Substitution table loaded through the local utils module; compute_cost
# indexes it with a pair of characters, pam250[a, b].
pam250 = utils.load_pam250_matrix()

In [5]:
# Pairwise sanity check on two short sequences; the trailing `;` hides the
# cell's return value. The fourth argument presumably turns on the matrix
# printout shown below -- confirm in utils.
utils.get_alignment('ACGH','CFG', pam250, True);

[[ 0.  8. 16. 24.]
 [ 8.  2. 10. 15.]
 [16. -4.  4. 12.]
 [24.  4.  1. -1.]
 [32. 12.  6.  3.]]
[[' ' 'l' 'l' 'l']
 ['u' 'd' 'l' 'd']
 ['u' 'd' 'l' 'l']
 ['u' 'u' 'd' 'd']
 ['u' 'u' 'd' 'd']]


In [23]:
# Scratch check: a slice with step -1 reverses a string.
a = 'abc'
a[::-1]

'cba'

In [7]:
# Same pair of sequences as 'ACGH' / 'CFG', each written in reverse.
utils.get_alignment('HGCA', 'GFC', pam250, True)

[[ 0.  8. 16. 24.]
 [ 8.  2. 10. 18.]
 [16.  3.  7. 13.]
 [24. 11.  7. -5.]
 [32. 19. 14.  3.]]
[[' ' 'l' 'l' 'l']
 ['u' 'd' 'd' 'l']
 ['u' 'd' 'd' 'd']
 ['u' 'u' 'd' 'd']
 ['u' 'u' 'd' 'u']]


(['H', 'G', 'C', 'A'], ['G', 'F', 'C', '_'])

# Heuristic function

In [17]:
# Scratch check: zip pairs up elements position by position.
a = (1, 2, 3)
b = (1, 2, 3)
for x in zip(a, b):
    print(x)

(1, 1)
(2, 2)
(3, 3)


In [27]:
# Tables for three example sequences; hfunc reads them as distances2d[(i, j)].
distances2d = utils.get_distances2d(['ACGH', 'CFG', 'EAC'], pam250)

In [58]:
# Author's note: this appears to compute the right value.
def hfunc(coords, distances2d):
    """Sum pairwise table lookups for a position.

    For every index pair ``hi > lo`` the table ``distances2d[(hi, lo)]`` is read
    at ``[-coords[hi] - 1, -coords[lo] - 1]`` and the entries are added up.
    Returns 0 when there are fewer than two coordinates.
    """
    total = 0
    for hi in range(len(coords)):
        for lo in range(hi):
            table = distances2d[(hi, lo)]
            total += table[-coords[hi] - 1, -coords[lo] - 1]
    return total
            
# Heuristic value at coords (3, 3, 3) using the distances2d tables.
hfunc((3, 3, 3), distances2d)

16.0

In [260]:
# NOTE(review): in the visible part of this notebook, get_neighbors and seqs are
# only defined in cells further down, so this cell depends on out-of-order
# execution and would fail on a fresh top-to-bottom run.
get_neighbors((0, 0, 0, 0), seqs)

array([[0, 0, 0, 1],
       [0, 0, 1, 0],
       [0, 0, 1, 1],
       [0, 1, 0, 0],
       [0, 1, 0, 1],
       [0, 1, 1, 0],
       [0, 1, 1, 1],
       [1, 0, 0, 0],
       [1, 0, 0, 1],
       [1, 0, 1, 0],
       [1, 0, 1, 1],
       [1, 1, 0, 0],
       [1, 1, 0, 1],
       [1, 1, 1, 0],
       [1, 1, 1, 1]])

In [213]:
def get_neighbors(coords, seqs):
    """Return every position reachable from ``coords`` in one step.

    A step advances a non-empty subset of the coordinates by one. Candidates in
    which some coordinate exceeds the length of its sequence are dropped.

    Parameters
    ----------
    coords : sequence of int
        Current position, one coordinate per sequence.
    seqs : sequence of sized objects
        The sequences; ``len(seqs[i])`` is the largest allowed value of
        coordinate ``i``.

    Returns
    -------
    numpy.ndarray
        Integer array with one row per valid neighbour. Rows follow the binary
        counting order of the step mask (1, 2, 3, ...), most significant bit in
        the first column.
    """
    num_coords = len(coords)
    # Row k-1 holds the binary digits of k. Shift-and-mask works for any number
    # of coordinates, whereas unpacking uint8 values caps out at 8 sequences.
    masks = np.arange(1, 2 ** num_coords)
    bit_positions = np.arange(num_coords - 1, -1, -1)
    steps = (masks[:, None] >> bit_positions) & 1

    max_coords = [len(seq) for seq in seqs]
    candidates = np.asarray(coords) + steps
    return candidates[np.all(candidates <= max_coords, axis=1)]

In [113]:
def compute_cost(coords1, coords2, seqs, pam250, gap_penalty=8):
    """Cost of the single step from ``coords1`` to ``coords2``, summed over sequence pairs.

    For each pair of sequences (i, j) with i > j:
      * neither coordinate changed -> the pair contributes nothing;
      * exactly one changed        -> the pair contributes ``gap_penalty``;
      * both changed               -> the pair contributes
        ``pam250[seqs[i][coords1[i]], seqs[j][coords1[j]]]``, i.e. the table
        entry for the two characters at the position being left.

    Parameters
    ----------
    coords1, coords2 : indexable sequences of int
        Position before and after the step.
    seqs : sequence of str
        The sequences being aligned.
    pam250 : mapping
        Table indexed by a pair of characters.
    gap_penalty : number, optional
        Cost charged when only one sequence of a pair advances. Defaults to 8,
        the value that was previously hard-coded.

    Returns
    -------
    The accumulated cost (0 when nothing moved).
    """
    cost = 0
    for i in range(1, len(coords2)):
        moved_i = coords1[i] != coords2[i]
        for j in range(i):
            moved_j = coords1[j] != coords2[j]
            if not (moved_i or moved_j):
                continue
            if moved_i and moved_j:
                cost += pam250[seqs[i][coords1[i]], seqs[j][coords1[j]]]
            else:
                cost += gap_penalty
    return cost

In [215]:
def astar(seqs):
    """Best-first search from the all-zeros position to the end of every sequence.

    Uses the module-level ``pam250`` table, ``utils.get_distances2d`` for the
    tables consumed by ``hfunc``, ``get_neighbors`` for successors and
    ``compute_cost`` for step costs.

    Parameters
    ----------
    seqs : sequence of str
        The sequences to align.

    Returns
    -------
    tuple
        ``(True, goal_node, CLOSED, OPEN)`` when the end position is reached;
        follow ``goal_node.parent`` to recover the path.
        ``(False, None, CLOSED, OPEN)`` when the queue runs dry first, so that
        callers unpacking four values never receive a bare None.
    """
    OPEN = Open()
    CLOSED = Closed()
    distances2d = utils.get_distances2d(seqs, pam250)

    start_coords = tuple(0 for _ in seqs)
    end_coords = tuple(len(seq) for seq in seqs)

    start_node = Node(start_coords, 0, hfunc(start_coords, distances2d))
    end_node = Node(end_coords)
    OPEN.AddNode(start_node)

    while not OPEN.isEmpty():
        s = OPEN.GetBestNode(CLOSED)
        if s is None:
            # The queue held nothing but discarded placeholders.
            break
        if CLOSED.WasExpanded(s):
            # OPEN de-duplicates on (coords, g) only, so the same position can
            # be queued with several g values. All entries for one position
            # share h, hence the cheapest is popped first; later ones are stale
            # and re-expanding them would only add dominated nodes.
            continue
        CLOSED.AddNode(s)
        if s == end_node:
            return (True, s, CLOSED, OPEN)
        for neighbor in get_neighbors(s.coords, seqs):
            neighbor = tuple(neighbor)
            if CLOSED.WasExpanded(Node(neighbor)):
                continue
            node = Node(neighbor, s.g + compute_cost(s.coords, neighbor, seqs, pam250),
                        hfunc(neighbor, distances2d))
            node.parent = s
            OPEN.AddNode(node)

    return (False, None, CLOSED, OPEN)
        

In [314]:
# Three-sequence run; astar returns a 4-tuple and only its second element
# (the node it stopped at) is kept.
seqs = ['ACGHT', 'GACGH', 'T']
_, s, _, _ = astar(seqs)

In [315]:
def make_alignment(goal, seqs):
    """Rebuild the alignment matrix from a node's chain of parents.

    Each step on the path from the root (the node whose ``parent`` is None) to
    ``goal`` becomes one column. In that column, row ``i`` holds the character
    of ``seqs[i]`` consumed by the step if coordinate ``i`` changed, and ``'_'``
    otherwise.

    Parameters
    ----------
    goal : object with ``coords`` and ``parent`` attributes
        Last node of the path.
    seqs : sequence of str
        The aligned sequences, in coordinate order.

    Returns
    -------
    numpy.ndarray
        Array of single characters with shape ``(len(seqs), number_of_steps)``.
        The width follows the actual path length: a fixed width of
        ``max(len) + 1`` left uninitialised columns for shorter paths and
        overflowed for longer ones.
    """
    # Collect (coords_after, coords_before) for every step, walking goal -> root.
    steps = []
    current = goal
    while current.parent is not None:
        steps.append((current.coords, current.parent.coords))
        current = current.parent
    steps.reverse()

    alignment = np.empty((len(seqs), len(steps)), dtype=str)
    for col, (after, before) in enumerate(steps):
        for row, seq in enumerate(seqs):
            if after[row] != before[row]:
                # Coordinate c means c characters are consumed, so the step
                # that reaches c consumes the character at index c - 1.
                alignment[row, col] = seq[after[row] - 1]
            else:
                alignment[row, col] = '_'
    return alignment

In [317]:
# Rebuild the alignment matrix by following the parent links from node s.
t = make_alignment(s, seqs)

In [280]:
# Display the reconstructed alignment as the cell output.
make_alignment(s, seqs)

array([['_', 'A', 'C', 'G', 'H', 'T'],
       ['G', 'A', 'C', 'G', 'H', '_'],
       ['_', 'T', '_', '_', '_', '_']], dtype='<U1')

In [307]:
# Two-sequence run: search, then rebuild the alignment from the returned node.
seqs = ['ACGH', 'CFG']
_, s, _, _ = astar(seqs)
t = make_alignment(s, seqs)

In [337]:
# Two hand-written alignments that differ only in the last row:
# t1 puts 'T' in the second column, t2 puts it in the last column.
t1 = np.array([['_', 'A', 'C', 'G', 'H', 'T'],
               ['G', 'A', 'C', 'G', 'H', '_'],
               ['_', 'T', '_', '_', '_', '_']])
t2 = np.array([['_', 'A', 'C', 'G', 'H', 'T'],
               ['G', 'A', 'C', 'G', 'H', '_'],
               ['_', '_', '_', '_', '_', 'T']])

In [338]:
# Score both candidate alignments side by side.
# NOTE(review): calc_alignment_score is not defined in the visible part of this
# notebook -- confirm where it comes from before a fresh run.
calc_alignment_score(t1, pam250), calc_alignment_score(t2, pam250)

(53, 68)

In [332]:
# Display the whole substitution table as the cell output.
pam250

{('A', 'A'): -2,
 ('A', 'R'): 2,
 ('A', 'N'): 0,
 ('A', 'D'): 0,
 ('A', 'C'): 2,
 ('A', 'Q'): 0,
 ('A', 'E'): 0,
 ('A', 'G'): -1,
 ('A', 'H'): 1,
 ('A', 'I'): 1,
 ('A', 'L'): 2,
 ('A', 'K'): 1,
 ('A', 'M'): 1,
 ('A', 'F'): 3,
 ('A', 'P'): -1,
 ('A', 'S'): -1,
 ('A', 'T'): -1,
 ('A', 'W'): 6,
 ('A', 'Y'): 3,
 ('A', 'V'): 0,
 ('A', 'B'): 0,
 ('A', 'J'): 1,
 ('A', 'Z'): 0,
 ('A', 'X'): 1,
 ('A', '*'): 8,
 ('R', 'A'): 2,
 ('R', 'R'): -6,
 ('R', 'N'): 0,
 ('R', 'D'): 1,
 ('R', 'C'): 4,
 ('R', 'Q'): -1,
 ('R', 'E'): 1,
 ('R', 'G'): 3,
 ('R', 'H'): -2,
 ('R', 'I'): 2,
 ('R', 'L'): 3,
 ('R', 'K'): -3,
 ('R', 'M'): 0,
 ('R', 'F'): 4,
 ('R', 'P'): 0,
 ('R', 'S'): 0,
 ('R', 'T'): 1,
 ('R', 'W'): -2,
 ('R', 'Y'): 4,
 ('R', 'V'): 2,
 ('R', 'B'): 1,
 ('R', 'J'): 3,
 ('R', 'Z'): 0,
 ('R', 'X'): 1,
 ('R', '*'): 8,
 ('N', 'A'): 0,
 ('N', 'R'): 0,
 ('N', 'N'): -2,
 ('N', 'D'): -2,
 ('N', 'C'): 4,
 ('N', 'Q'): -1,
 ('N', 'E'): -1,
 ('N', 'G'): 0,
 ('N', 'H'): -2,
 ('N', 'I'): 2,
 ('N', 'L'): 3,
 ('N', 'K