<a href="https://colab.research.google.com/github/JamieBali/MRSCC/blob/main/NGramTextGeneration.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# N-Gram Text Generation

N-grams (contiguous sequences of n tokens) are a simple tool for lexical analysis, and their frequencies in a corpus make a good resource when having a machine generate text.

This Colab notebook contains a simple n-gram text-generation system that is intended to be fully dynamic, i.e. to work for any n-gram order n.

In [None]:
# imports

import os, random, math, nltk, operator
from nltk import word_tokenize as tokenize

from google.colab import drive
# Mount Google Drive at /content/gdrive; TRAINING_DIR (defined in a later
# cell) points to a folder inside this mount.
drive.mount('/content/gdrive')

# Download the NLTK 'punkt' resource.
# NOTE(review): presumably required by nltk.word_tokenize (imported above
# as `tokenize`) - confirm against the installed NLTK version.
nltk.download('punkt')

In [None]:
# Global Functions

TRAINING_DIR = "/content/gdrive/My Drive/ColabNotebooks/ANLE/Resources/Holmes_Training_Data"

###
#
# Lists the files in training_dir, shuffles the names in place with
# random.shuffle, and cuts the shuffled list in two at int(n * split),
# where n is the number of files. Prints the number of files found.
#
# Returns a tuple (first_part, second_part): the first int(n * split)
# shuffled names, then all the remaining ones.
#
###
def get_training_testing(training_dir=TRAINING_DIR, split=0.5):
  shuffled = os.listdir(training_dir)
  total = len(shuffled)
  print(f"There are {total} files in the training directory: {training_dir}")
  random.shuffle(shuffled)
  cut = int(total * split)
  first_part, second_part = shuffled[:cut], shuffled[cut:]
  return first_part, second_part

In [None]:
# Setup

# Shuffle the files in TRAINING_DIR and split them into two lists; with the
# default split of 0.5 the first list holds int(n * 0.5) of the n files.
trainingfiles,heldoutfiles = get_training_testing()

# Train the model on at most MAX_FILES files from the first list.
# NOTE(review): language_model is defined in the cell that follows this one,
# so that cell has to be executed before this line can run.
MAX_FILES=25
mylm=language_model(files=trainingfiles[:MAX_FILES])

In [None]:
# Language Model

class language_model():
  """N-gram language model with simple text generation.

  Training reads every file in ``files`` (inside ``training_dir``),
  tokenises each non-empty line, wraps it in ``_START``/``_END`` markers,
  lower-cases the tokens and counts n-grams of order 1, 2, 3 and ``n``.
  The counts are then normalised in place into probabilities.

  Tables are nested dicts, one level per context word; the innermost value
  is the probability of the final word given the context:
    unigram  {word: p}
    bigram   {w1: {word: p}}
    trigram  {w1: {w2: {word: p}}}
    ngram    the table of order ``n`` (the same object as one of the
             tables above when n <= 3)

  The context window (``previousWords``) is carried over from one line to
  the next, so the marker transitions (``_end`` -> ``_start`` -> first word
  of the following line) are part of the model. The generators rely on
  this to keep going after they sample a marker.
  """

  # Line-boundary markers as stored in the tables (tokens are lower-cased).
  SENTINELS = ("_start", "_end")
  # Tokens that end a generated sentence.
  TERMINATORS = (".", "?", "!")

  def __init__(self, trainingdir=None, files=None, n=2):
    """Create the model and train it immediately.

    Args:
      trainingdir: directory holding the training files. None (default)
        means the module-level TRAINING_DIR, looked up at call time.
      files: names of the files inside ``trainingdir`` to train on. None
        means no files (a None default avoids sharing one mutable list
        between instances).
      n: order of the ``self.ngram`` table; must be at least 1.

    Raises:
      ValueError: if ``n`` is smaller than 1.
    """
    if n < 1:
      raise ValueError("n must be at least 1, got {}".format(n))
    self.n = n
    self.training_dir = TRAINING_DIR if trainingdir is None else trainingdir
    self.files = [] if files is None else files
    self.train()

  def train(self):
    """(Re)build all tables from ``self.files``, discarding earlier counts."""
    # One table per order; orders 1-3 always exist because the unigram,
    # bigram and trigram methods need them.
    tables = {1: {}, 2: {}, 3: {}}
    tables.setdefault(self.n, {})
    self._tables = tables
    self.unigram = tables[1]
    self.bigram = tables[2]
    self.trigram = tables[3]
    self.ngram = tables[self.n]
    # Sliding window holding the most recent tokens seen during training.
    self.previousWords = []
    self._processfiles()
    self._convert_to_probs()

  def _processline(self, line):
    """Tokenise one line, add the boundary markers and count its n-grams."""
    tokens = ["_START"] + tokenize(line) + ["_END"]
    self._count_tokens(token.lower() for token in tokens)

  def _count_tokens(self, tokens):
    """Add the n-grams ending at each token of ``tokens`` to every table.

    Must run before ``_convert_to_probs``; counting into tables that were
    already normalised would mix counts with probabilities.
    """
    history = self.previousWords
    longest_context = max(self._tables) - 1
    for token in tokens:
      for order, table in self._tables.items():
        width = order - 1
        if len(history) < width:
          continue  # warm-up: not enough context yet for this order
        node = table
        # len(history) - width (not -width) so that width 0 yields [].
        for word in history[len(history) - width:]:
          node = node.setdefault(word, {})
        node[token] = node.get(token, 0) + 1
      history.append(token)
      if len(history) > longest_context:
        del history[0]

  def _processfiles(self):
    """Feed every non-empty line of every training file to _processline.

    A file that cannot be decoded is skipped as a whole: it is read
    completely before any of its lines are counted, so an undecodable file
    contributes nothing to the model.
    """
    for afile in self.files:
      print("Processing {}".format(afile))
      try:
        with open(os.path.join(self.training_dir, afile)) as instream:
          lines = instream.readlines()
      except UnicodeDecodeError:
        print("UnicodeDecodeError processing {}: ignoring file".format(afile))
        continue
      for line in lines:
        line = line.rstrip()
        if line:
          self._processline(line)

  @staticmethod
  def _normalise(table):
    """Turn the counts in every innermost dict of ``table`` into probabilities."""
    pending = [table]
    while pending:
      node = pending.pop()
      children = [value for value in node.values() if isinstance(value, dict)]
      if children:
        pending.extend(children)
      elif node:
        total = sum(node.values())  # computed once per distribution
        for key in node:
          node[key] /= total

  def _convert_to_probs(self):
    """Normalise every table in place (counts -> probabilities)."""
    for table in self._tables.values():
      self._normalise(table)

  @staticmethod
  def _lookup(table, tokens):
    """Return the value stored under the path ``tokens``, or 0 if absent."""
    node = table
    for token in tokens:
      if not isinstance(node, dict):
        return 0
      node = node.get(token, 0)
    return 0 if isinstance(node, dict) else node

  @staticmethod
  def _distribution(table, context):
    """Return the {word: probability} dict for ``context`` ({} if unseen)."""
    node = table
    for word in context:
      node = node.get(word)
      if not isinstance(node, dict):
        return {}
    return node

  def get_prob(self, token, method="unigram"):
    """Return a probability from one of the tables; 0 for anything unseen.

    Args:
      token: a single word for "unigram"; a sequence of 2 / 3 / n words
        (context first, predicted word last) for "bigram" / "trigram" /
        "ngram".
      method: "unigram", "bigram", "trigram" or "ngram".

    Returns:
      The stored probability, or 0 when the n-gram was never seen or the
      method is unknown (in which case a message is printed).
    """
    if method == "unigram":
      return self.unigram.get(token, 0)
    if method == "bigram":
      return self._lookup(self.bigram, (token[0], token[1]))
    if method == "trigram":
      return self._lookup(self.trigram, (token[0], token[1], token[2]))
    if method == "ngram":
      sequence = [token] if isinstance(token, str) else list(token)
      return self._lookup(self.ngram, sequence)
    print("Not implemented: {}".format(method))
    return 0

  def normalise_bigram(self):
    """Re-normalise the bigram and trigram tables in place.

    Training already normalises every table, so this only matters after
    the tables have been modified by hand. Kept for backward compatibility.
    """
    self._normalise(self.bigram)
    self._normalise(self.trigram)

  def get_top_k_unigrams(self, k):
    """Return the ``k`` most probable (word, probability) pairs.

    The boundary markers are filtered out by name rather than by assuming
    they occupy the first two ranks.
    """
    ranked = sorted(
      ((word, prob) for word, prob in self.unigram.items()
       if word not in self.SENTINELS),
      key=lambda item: item[1],
      reverse=True,
    )
    return ranked[:k]

  def _generate(self, table, context, max_length):
    """Sample words from ``table`` starting from ``context``.

    Stops after ``max_length`` ordinary words, after a sentence terminator,
    or when the current context was never seen in training. Boundary
    markers are sampled and used as context but neither emitted nor counted.

    Returns:
      (generated_text, final_context)
    """
    context = list(context)
    pieces = []
    produced = 0
    while produced < max_length:
      choices = self._distribution(table, context)
      if not choices:
        break  # dead end: nothing ever followed this context
      word = random.choices(list(choices), list(choices.values()))[0]
      if context:
        context = context[1:] + [word]
      if word in self.TERMINATORS:
        pieces.append(word)
        break
      if word in self.SENTINELS:
        continue
      pieces.append(word + " ")
      produced += 1
    return "".join(pieces), context

  def generate_unigram(self, max_length=10, k=1000, use_prob = False):
    """Generate text from the unigram table.

    Args:
      max_length: maximum number of ordinary words to emit.
      k: with ``use_prob`` False, sample uniformly from the top ``k`` words.
      use_prob: if True, sample every word by its unigram probability
        (boundary markers excluded) instead of using the top-k list.

    Returns:
      The generated string; "" when the model holds no usable words.
    """
    if use_prob:
      # Dropping the markers up front gives the same distribution as
      # drawing and rejecting them, without a possible endless loop.
      candidates = {word: prob for word, prob in self.unigram.items()
                    if word not in self.SENTINELS}
      text, _ = self._generate(candidates, [], max_length)
      return text
    top_k = self.get_top_k_unigrams(k)
    pieces = []
    produced = 0
    while top_k and produced < max_length:
      word = random.choice(top_k)[0]
      pieces.append(word)
      if word in self.TERMINATORS:
        break
      pieces.append(" ")
      produced += 1
    return "".join(pieces)

  def generate_bigram(self, max_length=100, starting_key="_start"):
    """Generate text from the bigram table, starting after ``starting_key``.

    The starting key is included in the output unless it is a boundary
    marker; it does not count towards ``max_length``.
    """
    prefix = "" if starting_key in self.SENTINELS else starting_key + " "
    text, context = self._generate(self.bigram, [starting_key], max_length)
    self.last = context[-1]  # last context word, kept for backward compatibility
    return prefix + text

  def generate_trigram(self, max_length=100, key1 = "_start", key2 = "the"):
    """Generate text from the trigram table, starting after (key1, key2).

    Either key may be "_random": key1 is then drawn uniformly from the
    first-level contexts and key2 from the words seen after key1. The keys
    are included in the output unless they are boundary markers.
    """
    if key1 == "_random":
      if not self.trigram:
        return ""
      key1 = random.choice(list(self.trigram))
    shown_key1 = "" if key1 in self.SENTINELS else key1 + " "
    if key2 == "_random":
      followers = self.trigram.get(key1)
      if not followers:
        return shown_key1
      key2 = random.choice(list(followers))
    shown_key2 = "" if key2 in self.SENTINELS else key2 + " "
    text, context = self._generate(self.trigram, [key1, key2], max_length)
    self.key1, self.key2 = context  # kept for backward compatibility
    return shown_key1 + shown_key2 + text

  