<a href="https://colab.research.google.com/github/KasMasVan/Codes-and-Notes-for-Some-Public-Coureses/blob/main/SPL3/chapter3/unigram.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import math
from collections import Counter
from random import choices
class corpus():
  """Unsmoothed unigram/bigram language model over whitespace-tokenised sentences.

  Attributes:
    word_list: every token of every sentence, in input order.
    bigram_counter: dict mapping (word, next_word) to its count. Pairs are
      counted inside each sentence only, never across sentence boundaries.
    unigram_counter: Counter of token frequencies.
  """

  def __init__(self, corpus):
    """Count unigrams and bigrams.

    Args:
      corpus: iterable of sentence strings. Boundary markers such as '<s>' and
        '</s>' must already be part of each string; they are treated as
        ordinary tokens.
    """
    self.word_list = []
    self.bigram_counter = {}
    # prefix -> {next_word: count}. Same information as bigram_counter, but
    # indexed by prefix so queries never have to scan the whole vocabulary.
    self._successors = {}
    for sent in corpus:
      words = sent.split()  # the only preprocessing step used for now
      self.word_list += words
      for first, second in zip(words, words[1:]):
        pair = (first, second)
        self.bigram_counter[pair] = self.bigram_counter.get(pair, 0) + 1
        followers = self._successors.setdefault(first, {})
        followers[second] = followers.get(second, 0) + 1
    self.unigram_counter = Counter(self.word_list)

  def count_unigram(self):
    """Return {word: relative frequency of word among all tokens}."""
    denominator = sum(self.unigram_counter.values())
    return {word: count / denominator
            for word, count in self.unigram_counter.items()}

  def count_bigram(self):
    """Return {(prefix, next_word): P(next_word | prefix)} for observed bigrams.

    Each probability is the bigram count divided by the total count of bigrams
    that start with the same prefix. Unobserved bigrams are absent from the
    result (no smoothing).
    """
    bigram_prob = {}
    for prefix, followers in self._successors.items():
      total = sum(followers.values())
      for next_word, count in followers.items():
        bigram_prob[(prefix, next_word)] = count / total
    return bigram_prob

  def generate(self, max_len=10):
    """Sample a sentence from the bigram model, starting at '<s>'.

    Sampling stops when '</s>' is drawn, when the sentence holds `max_len`
    tokens, or when the current word has no observed successor. If the
    sentence does not already end in '</s>', one is appended, so the result
    can contain up to `max_len + 1` tokens.

    Args:
      max_len: length at which sampling stops (before the closing marker).

    Returns:
      List of tokens beginning with '<s>' and ending with '</s>'.
    """
    start = '<s>'
    end = '</s>'
    sent = [start]
    while sent[-1] != end and len(sent) < max_len:
      followers = self._successors.get(sent[-1])
      if not followers:
        # Dead end: nothing ever followed this word in the corpus. Stop here
        # instead of calling choices() on an empty population.
        break
      # choices() accepts relative weights, so raw counts can be used directly.
      next_word = choices(list(followers), list(followers.values()))[0]
      sent.append(next_word)
    return sent if sent[-1] == end else sent + [end]

  def compute_ppl(self, text):
    """Return the bigram perplexity of `text`.

    A text of N tokens contains N - 1 bigrams; all of them are scored and the
    result is normalised by N - 1, i.e. perplexity = P(text) ** (-1 / (N - 1)).

    Args:
      text: whitespace-separated tokens, including any boundary markers.

    Returns:
      The perplexity as a float, or `math.inf` when the text contains a bigram
      that was never observed (its unsmoothed probability is zero).

    Raises:
      ValueError: if `text` has fewer than two tokens, so there is no bigram
        to score.
    """
    tokens = text.split()
    num_bigrams = len(tokens) - 1
    if num_bigrams < 1:
      raise ValueError('text must contain at least two tokens')
    bigram_prob = self.count_bigram()
    # Accumulate in log space so long texts do not underflow to 0.0.
    log_prob = 0.0
    for pair in zip(tokens, tokens[1:]):
      prob = bigram_prob.get(pair, 0.0)
      if prob == 0.0:
        return math.inf
      log_prob += math.log(prob)
    return math.exp(-log_prob / num_bigrams)

In [None]:
# Toy training corpus: every sentence is already wrapped in the <s> / </s>
# boundary markers that the model treats as ordinary tokens.
test_corpus = [
    '<s> I am Sam </s>',
    '<s> Sam I am </s>',
    '<s> I am Sam </s>',
    '<s> I do not like green eggs and Sam </s>',
]

# Build the unigram and bigram counts from the toy corpus.
my_corpus = corpus(test_corpus)

In [None]:
# Relative frequency of every token in the corpus (boundary markers included).
my_corpus.count_unigram()

{'</s>': 0.16,
 '<s>': 0.16,
 'I': 0.16,
 'Sam': 0.16,
 'am': 0.12,
 'and': 0.04,
 'do': 0.04,
 'eggs': 0.04,
 'green': 0.04,
 'like': 0.04,
 'not': 0.04}

In [None]:
# Conditional probability of each observed bigram, normalised per first word.
my_corpus.count_bigram()

{('<s>', 'I'): 0.75,
 ('<s>', 'Sam'): 0.25,
 ('I', 'am'): 0.75,
 ('I', 'do'): 0.25,
 ('Sam', '</s>'): 0.75,
 ('Sam', 'I'): 0.25,
 ('am', '</s>'): 0.3333333333333333,
 ('am', 'Sam'): 0.6666666666666666,
 ('and', 'Sam'): 1.0,
 ('do', 'not'): 1.0,
 ('eggs', 'and'): 1.0,
 ('green', 'eggs'): 1.0,
 ('like', 'green'): 1.0,
 ('not', 'like'): 1.0}

In [None]:
# Sample one sentence from the bigram model; sampling is random, so the
# result can differ between runs.
my_corpus.generate()

['<s>', 'Sam', '</s>']

In [None]:
# Perplexity of a sentence that also appears in the training corpus.
my_corpus.compute_ppl('<s> I am Sam </s>')

1.3867225487012693