In [1]:
import random
from collections import defaultdict
from typing import List, Tuple, Dict, Optional

def replace_before_point(tup: tuple) -> tuple:
    """Return a copy of *tup* with every token before the first ``'.'`` set to None.

    The first ``'.'`` and everything after it are kept unchanged. If *tup*
    contains no ``'.'`` at all, every position becomes None.

    >>> replace_before_point(('the', 'end', '.', 'Then'))
    (None, None, '.', 'Then')
    """
    tokens = tuple(tup)
    try:
        first_point = tokens.index('.')
    except ValueError:
        # No point anywhere: every token counts as "before the point".
        first_point = len(tokens)
    return (None,) * first_point + tokens[first_point:]

class MarkovBabbler:
    """Token-level Markov-chain text generator.

    The model maps each prefix of ``order`` consecutive tokens to the list of
    tokens observed immediately after it during training. Duplicates are kept,
    so ``random.choice`` over that list samples in proportion to how often each
    continuation was seen.

    When ``point_blind`` is True, any prefix containing a ``'.'`` token has the
    tokens before its first ``'.'`` replaced by ``None`` (see
    ``replace_before_point``). Those earlier words are then ignored when
    choosing what follows.
    """

    # Tokens that attach directly to the preceding text (no space before them).
    _NO_SPACE_BEFORE = frozenset([',', '.', '?', '!', ':', ';', ')', '...'])
    # Characters after which the next token attaches with no space.
    _NO_SPACE_AFTER = frozenset(['¿', '¡', '('])

    def __init__(self, order: int = 2, point_blind: bool = True):
        self.n = order
        # prefix n-gram (possibly containing None placeholders when point-blind)
        # -> list of observed next tokens
        self.model: Dict[Tuple[Optional[str], ...], List[str]] = defaultdict(list)
        self.point_blind = point_blind

    def _lookup_key(self, prefix: Tuple[Optional[str], ...]) -> Tuple[Optional[str], ...]:
        """Return the model key for *prefix*, blinding it if ``point_blind`` is set."""
        if self.point_blind and '.' in prefix:
            return replace_before_point(prefix)
        return prefix

    def train(self, tokens: List[str]) -> None:
        """Add every (prefix n-gram -> next token) transition in *tokens* to the model.

        May be called repeatedly; transitions accumulate in ``self.model``.

        Raises:
            ValueError: if *tokens* has fewer than ``order + 1`` items, i.e. not
                even one full transition.
        """
        if len(tokens) < self.n + 1:
            raise ValueError(f"Token list must contain at least {self.n + 1} tokens.")

        for i in range(len(tokens) - self.n):
            prefix = self._lookup_key(tuple(tokens[i:i + self.n]))   # n-gram (key)
            self.model[prefix].append(tokens[i + self.n])            # following token (value)

    def generate(self, max_tokens: int = 50, seed: Optional[Tuple[str, ...]] = None) -> str:
        """Generate text by walking the chain from a starting n-gram.

        Args:
            max_tokens: Number of tokens to produce, counting the ``order``
                starting tokens. Generation stops early, with a printed warning,
                if a prefix has no recorded continuation. If this is smaller than
                ``order``, only the starting tokens are returned.
            seed: Optional starting n-gram of exactly ``order`` tokens. If omitted
                or empty, a random key of the model is used instead.

        Returns:
            The generated tokens joined into a string. Closing punctuation is
            attached to the previous word, and words directly after ``¿``, ``¡``
            or ``(`` get no space. ``None`` placeholders (from a blinded random
            start key) are omitted.

        Raises:
            ValueError: if the model is empty, the seed has the wrong length, or
                the seed (after point-blinding) was never seen during training.
        """
        print('generating...')
        if not self.model:
            raise ValueError("Model is empty. Did you forget to call train()?")

        if seed:
            if len(seed) != self.n:
                raise ValueError(f"Seed must be a list of {self.n} tokens.")
            current_prefix = tuple(seed)
            # Training stored blinded keys, so a seed containing '.' must be
            # blinded the same way before it can be found in the model.
            if self._lookup_key(current_prefix) not in self.model:
                raise ValueError("Seed not found in model.")
        else:
            current_prefix = random.choice(list(self.model.keys()))
            print(f'No seed given. "{current_prefix}" n-gram chosen to start with.')

        generated: List[Optional[str]] = list(current_prefix)

        for _ in range(max_tokens - self.n):
            next_tokens = self.model.get(self._lookup_key(current_prefix))
            if not next_tokens:
                print("Warning: Generation stopped before reaching max tokens.")
                break
            generated.append(random.choice(next_tokens))
            current_prefix = tuple(generated[-self.n:])

        return self._detokenize(generated)

    @classmethod
    def _detokenize(cls, tokens: List[Optional[str]]) -> str:
        """Join *tokens* into readable text with punctuation-aware spacing.

        ``None`` entries are skipped: they are point-blind placeholders, not words.
        """
        parts: List[str] = []
        last_char = ''
        for token in tokens:
            if token is None:
                continue
            # A space goes between tokens unless this one is closing punctuation
            # or the text so far ends in opening punctuation; never a leading space.
            if parts and token not in cls._NO_SPACE_BEFORE and last_char not in cls._NO_SPACE_AFTER:
                parts.append(' ')
            parts.append(token)
            if token:
                last_char = token[-1]
        return ''.join(parts)

['/home/gabosh/projects/corpuslab/corpuslab', '/home/gabosh/anaconda3/lib/python38.zip', '/home/gabosh/anaconda3/lib/python3.8', '/home/gabosh/anaconda3/lib/python3.8/lib-dynload', '', '/home/gabosh/.local/lib/python3.8/site-packages', '/home/gabosh/anaconda3/lib/python3.8/site-packages', '/home/gabosh/anaconda3/lib/python3.8/site-packages/locket-0.2.1-py3.8.egg', '/home/gabosh/anaconda3/lib/python3.8/site-packages/IPython/extensions', '/home/gabosh/.ipython']
