In [56]:
class Stalk:
    """Edge label of the suffix tree used by ``Branch``.

    Wraps a string; an empty string is replaced by the terminal marker
    ``'$'``. Equality and hashing use only the first character, so a
    ``Stalk`` dict key is found by any ``Stalk`` or ``str`` that shares
    that leading character.
    """

    def __init__(self, stalk):
        # An empty label is stored as the terminal marker.
        self.stalk = stalk if len(stalk) else '$'

    def __repr__(self):
        return self.stalk

    def __eq__(self, other):
        first = self[0]
        return first == other[0]

    def __hash__(self):
        return hash(self[0])

    def __getitem__(self, index):
        return self.stalk[index]

    def __len__(self):
        # '$' (terminal) and '^' are treated as having no content.
        return 0 if self.stalk in ('$', '^') else len(self.stalk)

    def __str__(self):
        return '' if self.stalk == '$' else self.stalk

    def common_substring(self, other):
        """Split ``self`` and ``other`` at their longest common prefix.

        ``other`` may be a ``Stalk`` or a ``str``. Returns three ``Stalk``
        objects: the shared prefix, the remainder of ``self`` and the
        remainder of ``other`` (each becomes ``'$'`` when empty).
        """
        if type(other) is str:
            other = Stalk(other)
        limit = min(len(self), len(other))
        shared = 0
        while shared < limit and self[shared] == other[shared]:
            shared += 1
        return Stalk(self[:shared]), Stalk(self[shared:]), Stalk(other[shared:])

In [57]:
class Leaf:
    """Leaf node of the suffix tree built by ``Branch``.

    ``left`` is the leaf's edge label (a ``Stalk`` or ``str``). When the
    label is empty (length 0, e.g. the terminal ``Stalk('$')``), ``right`` is
    an int counter starting at 1, which ``Branch.add`` increments when the
    same terminal is inserted again. Otherwise ``right`` is ``Leaf(right)``;
    with the default ``right=''`` that is a counter leaf holding 1.
    """

    def __init__(self,left,right=''):
        self.left = left
        self.right = 1 if len(left) == 0 else Leaf(right)

    def __repr__(self): return str(self.right)

    def __eq__(self,other):
        # Like Stalk, compare on the leading character only; slicing
        # (rather than indexing) keeps empty labels comparable.
        return self.left[:1] == other[:1]

    def __hash__(self):
        # Hash exactly what __eq__ compares so equal leaves hash equally
        # (hashing the whole label broke this for str labels).
        return hash(self.left[:1])

    def __getitem__(self,index): return self.left[index]

    def __len__(self):
        # The terminal marker '$' counts as an empty label.
        if self.left == '$': return 0
        return len(self.left)

In [58]:
class Sequence:
    """A partially assembled sequence and the reads it was built from.

    Attributes:
        seq: the assembled text.
        seen: tuple of reads already incorporated, in assembly order.
        extensions: candidate next reads; ``Sequitur`` stores them as
            ``read -> (context, Sequence(read), Sequence(label, (read,)))``.
    """

    def __init__(self,seq=None,seen=None):
        self.extensions = {}
        if seq is None:
            self.seq = ''
            self.seen = tuple()
        else:
            self.seq = seq
            # By default the seed text itself is the first seen read.
            self.seen = (seq,) if seen is None else seen

    def update(self,items):
        """Merge the reads in ``items`` into ``seen``.

        If ``items`` starts with a suffix of ``seen``, only the
        non-overlapping tail of ``items`` is appended. If ``items`` ends with
        a prefix of ``seen``, only the non-overlapping head of ``items`` is
        prepended. Otherwise ``items`` is appended as-is.
        """
        seen = list(self.seen)
        items = list(items)
        if seen and items:
            if items[0] in seen and seen[-1] in items:
                # End of the overlap inside items (inclusive of seen[-1]).
                k = items.index(seen[-1]) + 1
                if items[:k] == seen[seen.index(items[0]):]:
                    self.seen = tuple(seen + items[k:])
                    return
            if seen[0] in items and items[-1] in seen:
                # Start of the overlap inside items.
                k = items.index(seen[0])
                if items[k:] == seen[:seen.index(items[-1]) + 1]:
                    self.seen = tuple(items[:k] + seen)
                    return
        self.seen = tuple(seen + items)

    def add(self,item):
        """Append ``item`` to ``seen`` unless it is already there."""
        if item not in self.seen: self.seen = tuple(list(self.seen) + [item])

    def __add__(self,other):
        """Merge ``other`` into ``self`` IN PLACE and return ``self``.

        Concatenates ``seq``, merges ``seen`` via ``update``, and drops
        extensions for reads that are now seen. This drops them from both
        objects, so ``other.extensions`` is mutated too, before merging the
        remaining extensions into ``self``.
        """
        self.seq += other.seq
        self.update(other.seen)
        for a in set(self.extensions).intersection(self.seen):
            self.extensions.pop(a)
        for a in set(other.extensions).intersection(self.seen):
            other.extensions.pop(a)
        self.extensions.update(other.extensions)
        return self

    def rollback(self,sequence):
        """Drop the last read from ``self`` and merge ``self`` onto ``sequence``.

        ``seq`` is truncated just after the last occurrence of the new final
        read. Then ``i`` is counted as the length of the longest prefix of
        ``self.seq`` that occurs in ``sequence.seq``, and that many
        characters are removed from the end of ``sequence.seq``. Finally,
        ``self`` is merged into ``sequence`` in place via ``__add__``.

        Raises IndexError if ``self.seen`` has fewer than two reads.
        """
        seen = list(self.seen)
        seen.pop()
        self.seen = tuple(seen)
        last = self.seen[-1]
        self.seq = self.seq[:self.seq.rfind(last) + len(last)]
        i = 0
        while i <= min(len(sequence.seq),len(self.seq)) and self.seq[:i+1] in sequence.seq: i+=1
        # Guard i == 0: `x[:-0]` is the empty string and would wipe the
        # whole sequence instead of trimming nothing.
        if i:
            sequence.seq = sequence.seq[:-i]
        sequence + self  # __add__ merges into `sequence` in place

In [59]:
class Branch:
    """Internal node of the suffix tree that ``Sequitur`` builds over its reads.

    Attributes:
        b: edge label (``Stalk``) -> child node (``Leaf`` or ``Branch``).
        s: edge label (``Stalk``) -> ``(label, reads)``, where ``reads`` is
            the set of reads whose inserted suffixes run through that edge.

    ``Stalk`` compares and hashes on its first character only, so a key in
    ``b``/``s`` can be looked up with any ``Stalk`` or ``str`` starting with
    the same character (``__traverse__`` and ``pop('$')`` rely on this).
    """

    def __init__(self):
        self.b = {}  # edge label -> child node
        self.s = {}  # edge label -> (label, set of reads)

    def __repr__(self): return repr(self.b)

    def __str__(self):
        """Return one line per child node (insertion order); '' if empty."""
        # join handles an empty branch (indexing [-1] would raise) and does
        # not rebuild the values list once per child.
        return '\n'.join(str(v) for v in self.b.values())

    def __getitem__(self,index):
        """Return the child on edge ``index``; a ``str`` is wrapped in ``Stalk``."""
        if type(index) == str: return self.b[Stalk(index)]
        return self.b[index]

    def __is_shallow__(self):
        """Return True when no child of this node is itself a ``Branch``."""
        for a in self.b.values():
            if type(a) == Branch: return False
        return True

    def __traverse__(self,context):
        """Walk down from this node following ``context``; return the node reached.

        Each step picks the edge by the first character of the remaining
        context and then drops ``len(label)`` characters from the context.
        Only that first character is compared against the label. The walk
        stops when the context is used up or ``len(b) <= 1``. Raises
        IndexError for an empty context and KeyError when no edge matches.
        """
        b = self[context[0]]
        s = self.s[context[0]]
        context = context[len(s[0]):]
        # NOTE(review): len(b) counts children for a Branch but label length
        # for a Leaf; reaching a Leaf whose label is longer than one
        # character while context remains would call `b.s` on a Leaf
        # (AttributeError) - confirm this cannot happen.
        while len(context) > 0 and len(b) > 1:
            s = b.s[context[0]]
            b = b[context[0]]
            context = context[len(s[0]):]
        return b

    def __setitem__(self,index,value):
        """Set the child on edge ``index``; a ``str`` is wrapped in ``Stalk``."""
        if type(index) == str: self.b[Stalk(index)] = value
        else: self.b[index] = value

    def __contains__(self,other):
        """True if some edge starts with the same character as ``other``."""
        if type(other) == str: return Stalk(other) in self.b
        return other in self.b

    def __len__(self): return len(self.b)

    def pop(self,index): return self.b.pop(index)

    def pop_copy(self,index):
        """Return a new Branch with shallow copies of ``b``/``s`` minus edge ``index``.

        ``self`` is left unchanged; raises KeyError if ``index`` is absent.
        """
        b = self.b.copy()
        b.pop(index)
        s = self.s.copy()
        s.pop(index)
        br = Branch()
        br.b = b
        br.s = s
        return br

    def add(self,stalk,reads):
        """Insert suffix ``stalk`` (``Stalk`` or ``str``) tagged with the set ``reads``.

        - No edge shares its first character: add a new ``Leaf`` edge.
        - ``stalk`` is empty (terminal ``'$'``) and already present:
          increment that terminal leaf's ``right`` counter.
        - The matching edge leads to a ``Leaf``: split it at the common
          prefix into a new ``Branch`` holding both remainders.
        - The matching edge leads to a ``Branch``: if the edge label is only
          partly shared, insert an intermediate ``Branch`` at the common
          prefix; otherwise recurse into the existing branch.
        In the split/recurse cases ``reads`` is merged into the edge's set.
        """
        if stalk in self:
            if not len(stalk):
                self[stalk].right+=1
                return
            if type(self[stalk]) == Leaf:
                branch = Branch()
                l1 = self.pop(stalk)
                stalk_ = list(self.s.pop(stalk))
                # stalk_[0] <- shared prefix, l1.left <- rest of the old
                # label, l2 <- rest of the new stalk.
                stalk_[0],l1.left,l2 = stalk_[0].common_substring(stalk)
                # copy() first so the old remainder keeps only its own reads.
                branch.add(l1.left,stalk_[1].copy())
                branch.add(l2,reads)
                stalk_[1].update(reads)
                stalk_ = tuple(stalk_)
                self[stalk_[0]] = branch
                self.s[stalk_[0]] = stalk_
            else:
                stalk_ = list(self.s.pop(stalk))
                branch = self.pop(stalk)
                # bstalk: rest of the existing label; stalk: rest of the new suffix.
                stalk_[0],bstalk,stalk = stalk_[0].common_substring(stalk)
                if len(bstalk):
                    # Label only partly shared: hang the old subtree under a
                    # new intermediate branch keyed by its remainder.
                    br = Branch()
                    br[bstalk] = branch
                    br.s[bstalk] = (bstalk,stalk_[1].copy())
                    br.add(stalk,reads)
                    self[stalk_[0]] = br
                else:
                    branch.add(stalk,reads)
                stalk_[1].update(reads)
                stalk_ = tuple(stalk_)
                if not len(bstalk): self[stalk_[0]] = branch
                self.s[stalk_[0]] = stalk_
        else:
            if type(stalk) == str: stalk = Stalk(stalk)
            self.s[stalk] = (stalk,reads)
            self[stalk] = Leaf(stalk)

    def unpack(self,exclude,context,t=None):
        """Collect candidate extensions from the leaves below this node.

        Returns a dict mapping one read per leaf edge to
        ``(context, Sequence(read), Sequence(label_text, (read,)))``. Leaf
        edges whose read set intersects ``exclude`` are skipped. Grandchild
        ``Branch`` nodes are handled by a recursive call that receives the
        edge label ``s`` as ``t``.
        """
        # NOTE(review): `list(set)[0]` picks an arbitrary read (str hashing is
        # randomized per process, so the choice may vary between runs).
        # Label reconstruction also looks incomplete: direct leaf edges ignore
        # `t`, the recursive call omits the edge between self[s] and b, and
        # `t` is reset to '' after the first leaf - verify against intent.
        if t is None: t = ''
        extensions = {}
        for s in self.s:
            if type(self[s]) is Branch:
                for b in self[s].b.values():
                    if type(b) is Branch:
                        extensions.update(b.unpack(exclude,context,t=s))
                        continue
                    if len(self[s].s[b.left][1].intersection(exclude)): continue
                    extensions[list(self[s].s[b.left][1])[0]] = context, Sequence(list(self[s].s[b.left][1])[0]), Sequence(str(t) + str(s.stalk)+ str(b.left),(list(self[s].s[b.left][1])[0],))
                    t = ''
            else:
                if len(self.s[s][1].intersection(exclude)): continue
                extensions[list(self.s[s][1])[0]] = context, Sequence(list(self.s[s][1])[0]), Sequence(s.stalk,(list(self.s[s][1])[0],))
                t = ''
        return extensions

In [60]:
class Sequitur:
    """Greedy assembler that extends a seed ``Sequence`` using a suffix tree of the reads.

    Attributes:
        branch: root ``Branch``; every non-empty suffix ``read[i:]`` of every
            read is added to it, tagged with ``{read}``.
        reads: the reads passed to the constructor.
        sequence: the current assembled ``Sequence``.
    """

    def __init__(self,reads):
        self.branch = Branch()
        self.reads = reads
        for read in reads:
            for i in range(len(read)):
                self.branch.add(Stalk(read[i:]),{read})
        self.sequence = Sequence()

    def extend(self,sequence,prefix):
        """Try the candidate extensions recorded on ``sequence``.

        A candidate fits when ``sequence.seq`` ends with the candidate read's
        text up to and including the last occurrence of its context. The
        first fit is assembled recursively with ``prefix + candidate[2]``
        (``Sequence.__add__`` mutates ``prefix``). ``prefix`` becomes
        ``self.sequence`` unless it is already contained in it. If nothing
        fits, ``prefix`` is rolled back onto ``sequence``.
        """
        if len(prefix.seq) == 0: prefix = sequence
        extensions = list(sequence.extensions.values())
        for extension in extensions:
            if sequence.seq.endswith(extension[1].seq[:extension[1].seq.rfind(extension[0])+len(extension[0])]):
                self.sequitur(extension[1],prefix+extension[2])
                if prefix.seq not in self.sequence.seq: self.sequence = prefix
                return
        prefix.rollback(sequence)

    def sequitur(self,sequence,prefix=None):
        """Greedily extend ``sequence`` until every read appears in its ``seen``.

        Works on ``self.sequence``, which the recursion through ``extend``
        may replace, and returns None. ``context`` is the last ``i``
        characters of the sequence. It grows by one character per loop until
        the suffix-tree node reached is shallow or covers the whole sequence.
        At that point, reads not yet seen (in the sequence or ``prefix``) are
        recorded as extensions and ``extend`` is called. The loop stops early
        when no extension is found.
        """
        if prefix is None: prefix = Sequence()
        self.sequence = sequence
        all_reads = set(self.reads)  # loop-invariant
        i = 1
        context = self.sequence.seq[-i:]
        # NOTE(review): `options` is never cleared, so reads gathered at an
        # earlier context still count toward the check below - confirm intended.
        options = set()
        while len(all_reads.difference(self.sequence.seen)):
            branch = self.branch.__traverse__(context)
            if branch.__is_shallow__() or i == len(self.sequence.seq):
                stalks = branch.s.copy()
                # Drop the terminal '$' edge; default avoids KeyError when
                # the node has no terminal edge.
                stalks.pop('$', None)
                for s in stalks.values():
                    options.update(s[1])
                excluded = set(self.sequence.seen).union(prefix.seen)
                if len(options.difference(excluded)):
                    for s in stalks.values():
                        for x in s[1].difference(excluded):
                            self.sequence.extensions[x] = (context,Sequence(x),Sequence(s[0].stalk,(x,)))
                else:
                    # Back off one character and gather candidates from the
                    # whole subtree instead.
                    # NOTE(review): when i reaches 0, seq[-0:] is the WHOLE
                    # sequence rather than an empty context - confirm intended.
                    i -= 1
                    context = self.sequence.seq[-i:]
                    branch = self.branch.__traverse__(context).pop_copy('$')
                    self.sequence.extensions.update(branch.unpack(excluded,context))
                if not len(self.sequence.extensions): break
                self.extend(self.sequence,prefix)
                if len(prefix.seen) and prefix.seen[0] in self.sequence.seen:
                    if prefix == self.sequence: return
                    prefix = self.sequence
                i = 0
            i += 1
            context = self.sequence.seq[-i:]

    def assemble(self,start=None):
        """Placeholder; not implemented (returns None)."""
        pass

In [61]:
# Overlapping reads of the sentence to reassemble.
reads = [
    'betty_bought_butter_th',
    'tter_the_butter_was_',
    'he_butter_was_bitter_',
    'as_bitter_betty_bought',
    'tty_bought_better_butter_t',
    'r_butter_to_make_the_',
    'ke_the_bitter_butter_better',
]
# Build the suffix tree over every read.
sequitur = Sequitur(reads)

In [62]:
# Assemble starting from the first read; the bare expression below is the
# value Jupyter displays for this cell.
sequitur.sequitur(Sequence('betty_bought_butter_th'))
sequitur.sequence.seq

'betty_bought_butter_the_butter_was_bitter_betty_bought_better_butter_to_make_the_bitter_butter_better'

In [63]:
# Re-run on the same Sequitur, seeded with the second read; the recorded
# output matches the previous cell's.
sequitur.sequitur(Sequence('tter_the_butter_was_'))
sequitur.sequence.seq

'betty_bought_butter_the_butter_was_bitter_betty_bought_better_butter_to_make_the_bitter_butter_better'

In [145]:
# Re-run seeded with the third read.
# NOTE(review): the recorded output below is `True`, but `.seq` is a string -
# the output looks stale (from a different cell version); re-run to confirm.
sequitur.sequitur(Sequence('he_butter_was_bitter_'))
sequitur.sequence.seq

True