<a href="https://colab.research.google.com/github/IonZhao/Frequency_GLM/blob/main/Frequency_GLM.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [6]:
import random
from collections import defaultdict

class CharNGramLanguageModel:
  """Character-level n-gram language model estimated from raw frequency counts.

  After training, ``self.model`` maps each length-n context string (lower-cased)
  to a dict ``{next_char: probability}``. The special integer key ``0`` holds the
  context-free next-character distribution over the whole dataset and is used as
  a fallback for unseen contexts. The character ``'<'`` is appended to every
  training text as an end-of-text marker; generation stops when it is sampled.
  """

  def __init__(self, n, dataset):
    """Store the hyper-parameters and train immediately.

    Args:
      n: Number of characters of context used to predict the next character.
      dataset: List of training strings (e.g. one sentence per line).
    """
    self.n = n
    self.dataset = dataset
    # Two-level table: context -> next character -> count (probability after training).
    self.model = defaultdict(lambda: defaultdict(int))
    self.train()

  def train(self):
    """Estimate next-character probabilities from ``self.dataset``.

    Counts every (context, next_char) pair, then normalises each row into a
    probability distribution. The table is rebuilt from scratch on every call,
    so retraining (e.g. after changing ``self.dataset``) is safe.
    """
    # Reset first: re-running on an already-normalised table would otherwise add
    # raw counts on top of probabilities and corrupt every row.
    self.model = defaultdict(lambda: defaultdict(int))

    for text in self.dataset:
      # Text preprocessing: lower-case for case-insensitive statistics and append
      # the end-of-text marker.
      # NOTE: a literal '<' inside the text is indistinguishable from the marker.
      text = text.lower() + '<'
      # Every window of n characters that still has a following character.
      for i in range(len(text) - self.n):
        ngram = text[i:i + self.n]
        next_char = text[i + self.n]
        self.model[ngram][next_char] += 1
        self.model[ngram]['total'] += 1
        # Key 0: context-free character frequencies, used as a fallback.
        self.model[0][next_char] += 1
        self.model[0]['total'] += 1

    # Convert each row's counts into probabilities and drop the bookkeeping total.
    for counts in self.model.values():
      total = counts.pop('total')
      for char in counts:
        counts[char] /= total

  def generate_character(self, prompt):
    """Sample one character to follow ``prompt``.

    The last n characters of the lower-cased prompt are used as context, matching
    how training windows were built. Unseen contexts fall back to the overall
    character frequencies; an untrained model falls back to a random letter.

    Args:
      prompt: Text generated so far (any case, any length).

    Returns:
      A single character, possibly the end marker ``'<'``.
    """
    # Lower-case before slicing so the context matches training windows exactly.
    lowered = prompt.lower()
    # max(...) instead of lowered[-self.n:]: for n == 0, [-0:] would return the
    # whole string rather than the empty context.
    context = lowered[max(len(lowered) - self.n, 0):]

    # .get avoids inserting empty rows into the defaultdict; an empty or missing
    # row falls through to the unigram distribution.
    distribution = self.model.get(context) or self.model.get(0)
    if not distribution:
      # No training data at all.
      return random.choice('abcdefghijklmnopqrstuvwxyz')

    chars, probs = zip(*distribution.items())
    # Weighted sampling according to the estimated probabilities.
    return random.choices(chars, probs)[0]

  def generate(self, prompt, max_length=100):
    """Extend ``prompt`` one character at a time.

    Args:
      prompt: Starting text; it is returned unchanged as the prefix of the result.
      max_length: Maximum number of characters to append.

    Returns:
      The prompt followed by the generated characters, stopping early when the
      end marker ``'<'`` is sampled (the marker itself is not included).
    """
    result = prompt
    for _ in range(max_length):
      next_char = self.generate_character(result)
      if next_char == '<':
        break
      result += next_char
    return result


In [7]:
def load_dataset(filename):
  """Read a text file into a list of lines with surrounding whitespace removed.

  Args:
    filename: Path to a text file with one training example per line.

  Returns:
    List of stripped lines in file order. Blank lines are kept as ''.
  """
  # Explicit UTF-8 so the result does not depend on the platform's locale encoding.
  with open(filename, 'r', encoding='utf-8') as file:
    # Iterating the file object streams lines without building an extra list.
    return [line.strip() for line in file]

def main(dataset_filename='/content/sample_data/dataset.txt', n=3, prompt=None, seed=None):
  """Train a character n-gram model on a text file and print one generated continuation.

  Args:
    dataset_filename: Path to a text file with one training example per line.
      The default is the Colab sample-data location this notebook was written for.
    n: Context length (number of characters) of the n-gram model.
    prompt: Text to continue. When None, the user is asked for it via input().
    seed: Optional seed for the module-level ``random`` generator so a run can be
      reproduced; None leaves the generator unseeded, as before.
  """
  if seed is not None:
    random.seed(seed)

  # Load the dataset from the file and train the model.
  dataset = load_dataset(dataset_filename)
  model = CharNGramLanguageModel(n, dataset)

  if prompt is None:
    prompt = input("Enter a prompt: ")

  generated_text = model.generate(prompt)
  print("Generated text:", generated_text)


if __name__ == "__main__":
  main()

Enter a prompt: Hello
Generated text: Hello  t out you never cared at it
