<a href="https://colab.research.google.com/github/Anabel-l/Predict-Tense/blob/main/Predict_Tense.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from google.colab import drive
# Mount Google Drive at /content/drive; later cells read and write files under
# the relative path 'drive/MyDrive/Colab Notebooks/'.
MOUNT_POINT = '/content/drive'
drive.mount(MOUNT_POINT)

Mounted at /content/drive


In [None]:
#This cell sets up the model.
import re
import json
import sys
import collections
import os
import torch
import torch.nn as nn
from datetime import datetime


# single-direction RNN, optionally tied embeddings
class Emb_RNNLM(nn.Module):
    def __init__(self, params, use_LSTM=True):
        super(Emb_RNNLM, self).__init__()
        self.vocab_size = params['vocab_size']
        self.d_emb = params['d_emb']
        self.n_layers = params['num_layers']
        self.d_hid = params['d_hid']
        self.embeddings = nn.Embedding(self.vocab_size, self.d_emb)
        self.use_LSTM = use_LSTM
        if use_LSTM:
            print('Using LSTM model')
            self.i2R = nn.LSTM(self.d_emb, self.d_hid, batch_first=True, num_layers=self.n_layers) #input to recurrent layer, default nonlinearity is tanh
        else:
            # input to recurrent layer, default nonlinearity is tanh
            self.i2R = nn.RNN(
                self.d_emb, self.d_hid, batch_first=True, num_layers = self.n_layers
            )
        # recurrent to output layer
        self.R2o = nn.Linear(self.d_hid, self.vocab_size)

    def forward(self, train_datum):
        embs = torch.unsqueeze(self.embeddings(train_datum), 0)
        if self.use_LSTM:
            output, (hidden, context) = self.i2R(embs)
        else:
            output, hidden = self.i2R(embs)
        return self.R2o(output) #The second and third returned values are not used for training but for probing the model to see what it is encoding.


In [None]:
import re
import json
import sys
import collections
import os
import random
import torch
import torch.nn as nn
import numpy as np
import random

verbose = False
sentences = collections.defaultdict(lambda: [])
sentences_test = collections.defaultdict(lambda: [])
models = {}
book_title = 'phonemized_forms'
save_path = 'drive/MyDrive/Colab Notebooks/' + book_title + '_lstm_checkpoint.pth'
words = ['<s>', '<e>']
words_test = ['<s>', '<e>']
reg_words = []
reg_test_words = []
if os.path.isfile('drive/MyDrive/Colab Notebooks/'+book_title):
    print('Processing file', book_title)
    with open('drive/MyDrive/Colab Notebooks/'+book_title, 'r') as f0:
        for ir, line in enumerate(f0.readlines()):
            sentence_buffer = ['<s>']
            if ir % 100 == 0:
                print('Processed', ir, 'lines.')
            line = line.rstrip()
            if len(line) < 1:
                continue
            #if re.search(r'^[A-Z][A-Z][A-Z]', line):
            #    continue
            if line == 'eres-orth past-orth r 175 irregulars out of 2171 forms about 14%':
                continue
            rand_gen = np.random.randint(low=0, high=10)
            lal = line.split()
            for c in lal[3]:
              sentence_buffer.append(c) #Putting characters of pres tense into a "sentence buffer"
              if rand_gen >= 8:
                  if lal[3] not in words_test:
                    words_test.append(c)
              else:
                  if lal[3] not in words:
                    words.append(c) #80% of the characters of the present tense
                    #are going into words and the rest into words test

            i = 0
            for c in lal[4]:
              if  i == 0 and rand_gen >= 8:
                words_test.append('<m>')
                sentence_buffer.append('<m>')
                #At the start of going through past tense, if this is a test word, put <m> next into
                #sentence buffer
              sentence_buffer.append(c)
              #We then append characters to sent buf. But what if this is not a test item? Then there is no <m>
              #to separated present and past
              if rand_gen >= 8:
                  if lal[4] not in words_test:
                    words_test.append(c)
                    #We continue to add characters to test "words" if not seen yet
              else:
                  if lal[4] not in words:
                    words.append(c) #Same for non test words
              i = i + 1
            if rand_gen >= 8:
              reg_test_words.append(lal[2] == 'reg')
            else:
              reg_words.append(lal[2] == 'reg')
              #This seems to be appending the Boolean True or False to the "words"


            if verbose: print(sentence_buffer)
            #seperate into testing & training data
            if rand_gen >= 8:
              sentences_test[book_title].append(sentence_buffer + ['<e>'])
              #print("Test", sentence_buffer)
            else:
              sentences[book_title].append(sentence_buffer + ['<e>'])
              #print("Train", sentence_buffer)
else:
  print("Could not open file.")

print('sbt', sentences[book_title])
print('stbt', sentences_test[book_title])
print('words test', words_test)
print('words', words)
print('reg words', reg_words)
print('reg test words', reg_test_words)
#shuffle lists
random.shuffle(sentences[book_title])
random.shuffle(sentences_test[book_title])

wd2ix = {}
total_words = len(words)
print('total words', total_words)
for i, word in enumerate(words):
    #print('i', i)
    wd2ix[word] = i
    if verbose: print(word)
with open('drive/MyDrive/Colab Notebooks/' + book_title + 'wd2ix.json', 'w') as f3:
    json.dump(wd2ix, f3)
sentences_as_indices = [torch.LongTensor([wd2ix[w] for w in sent])
    for sent in sentences[book_title]
  ]
#training_data = torch.stack(sentences_as_indices, 0)

params = {'vocab_size': total_words, 'd_emb': 128, 'num_layers': 1, 'd_hid': 128, 'lr': 0.0003, 'epochs': 3}

models[book_title] = Emb_RNNLM(params)

Processing file phonemized_forms
Processed 0 lines.
Processed 100 lines.
Processed 200 lines.
Processed 300 lines.
Processed 400 lines.
Processed 500 lines.
Processed 600 lines.
Processed 700 lines.
Processed 800 lines.
Processed 900 lines.
Processed 1000 lines.
Processed 1100 lines.
Processed 1200 lines.
Processed 1300 lines.
Processed 1400 lines.
Processed 1500 lines.
Processed 1600 lines.
Processed 1700 lines.
Processed 1800 lines.
Processed 1900 lines.
Processed 2000 lines.
Processed 2100 lines.
sbt [['<s>', 'k', 'w', 'ɪ', 't', 'k', 'w', 'ɪ', 'ɾ', 'ᵻ', 'd', '<e>'], ['<s>', 't', 'ɹ', 'e', 'ɪ', 'd', 't', 'ɹ', 'e', 'ɪ', 'd', 'ᵻ', 'd', '<e>'], ['<s>', 'w', 'ɪ', 'ð', 's', 't', 'æ', 'n', 'd', 'w', 'ɪ', 'ð', 's', 't', 'ʊ', 'd', '<e>'], ['<s>', 'k', 'ɑ', 'm', 'p', 'a', 'ʊ', 'n', 'd', 'k', 'ɑ', 'm', 'p', 'a', 'ʊ', 'n', 'd', 'ᵻ', 'd', '<e>'], ['<s>', 'b', 'ʌ', 'd', 'ʒ', 'b', 'ʌ', 'd', 'ʒ', 'd', '<e>'], ['<s>', 's', 'e', 'ɪ', 'l', 's', 'e', 'ɪ', 'l', 'd', '<e>'], ['<s>', 'ɐ', 's', 'ɛ', 'n', '

In [None]:
# Show the first training item encoded as a tensor of token ids.
first_example = sentences_as_indices[0]
print('cai', first_example)

cai tensor([    0, 21330, 21162, 21331, 21316, 21330, 21162, 21331, 21316, 21331,
            1])


In [None]:
# This cell trains the model and can be skipped if training has already been
# done and there is a checkpoint at save_path.
# Each item is trained with teacher forcing: the scores at position t are
# compared with the token at position t + 1.
#
# ignore_index=0 excludes token id 0 ('<s>', the first entry of `words`) from
# the loss. '<s>' only occurs at position 0, which is never a target.
criterion = nn.CrossEntropyLoss(ignore_index=0)
optimiser = torch.optim.Adam(models[book_title].parameters(), lr=params['lr'])

# The evaluation cell switches the model to eval mode; switch back before training.
models[book_title].train()

for epoch in range(params['epochs']):
    ep_loss = 0.0
    n_used = 0
    random.shuffle(sentences_as_indices)
    for j, train_datum in enumerate(sentences_as_indices):
        # Very short sequences (fewer than 4 tokens) are skipped.
        if len(train_datum) < 4:
            continue
        preds = models[book_title](train_datum)
        # Drop the last position (it has no next token): (seq_len - 1, vocab_size).
        preds = preds[:, :-1, :].reshape(-1, params['vocab_size'])
        # Targets are the sequence shifted left by one: (seq_len - 1,).
        targets = train_datum[1:]
        loss = criterion(preds, targets)
        if torch.isnan(loss):
            # Report and skip: back-propagating a NaN loss would write NaNs
            # into every parameter on the optimiser step.
            print(train_datum)
            continue
        optimiser.zero_grad()
        loss.backward()
        optimiser.step()
        ep_loss += loss.item()
        n_used += 1
        if j > 0 and j % 1000 == 0:
            print('processed', j, 'training examples')
    print('Saving checkpoint')
    torch.save({'net_state_dict': models[book_title].state_dict(), 'optimiser_state_dict': optimiser.state_dict()}, save_path)
    # Mean loss over the items actually trained on this epoch.
    print('epoch', epoch, 'epoch loss', ep_loss / max(n_used, 1))



KeyboardInterrupt: ignored

In [None]:
import random
import numpy as np
import torch
import torch.nn as nn
import os
softmax = nn.Softmax(dim=-1)  # not used below; available for inspecting probabilities

book_title = 'phonemized_forms'
if os.path.exists(save_path):
    checkpoint = torch.load(save_path, map_location='cpu')
    print('Loading checkpoint')
    models[book_title].load_state_dict(checkpoint['net_state_dict'])
    models[book_title].eval()
else:
    print('No checkpoint found')
    # Raise rather than call exit(): exit() would terminate the notebook
    # kernel and discard all state from the earlier cells.
    raise FileNotFoundError(save_path)

# Upper bound on generated characters per word, so a model that never
# predicts '<e>' cannot loop forever.
MAX_PRED_LEN = 30


def _split_test_item(item):
    """Split a test item into (present_chars, past_chars).

    Symbols before '<m>' form the present tense and symbols after it form the
    past tense. '<s>' and '<e>' are dropped.
    """
    present, past = [], []
    target = present
    for ch in item:
        if ch == '<m>':
            target = past
            continue
        if ch in ('<s>', '<e>'):
            continue
        target.append(ch)
    return present, past


def _greedy_continue(model, prefix, max_len=MAX_PRED_LEN):
    """Greedily extend `prefix` until '<e>' or `max_len` new symbols.

    At each step the highest-scoring next symbol is appended to the context.
    Returns the generated symbols, excluding the terminating '<e>'.
    """
    context = list(prefix)
    generated = []
    if not context:
        return generated
    for _ in range(max_len):
        indices = torch.LongTensor([wd2ix[w] for w in context])
        scores = model(indices)
        best_index = np.argmax(scores[0, -1, :].numpy())
        chosen = words[best_index]
        if chosen == '<e>':
            break
        generated.append(chosen)
        context.append(chosen)
    return generated


with torch.no_grad():
    irreg_correct = 0
    irreg_total = 0
    correct = 0
    total = 0

    pred_wds = []
    real_wds = []
    model = models[book_title]
    for it, item in enumerate(sentences_test[book_title]):
        present, past = _split_test_item(item)
        original_wd = ''.join(present)
        corr = ''.join(past)
        real_wds.append(corr)

        unknown = [ch for ch in present if ch not in wd2ix]
        if unknown:
            print(original_wd, ': skipped, characters not in vocabulary:', unknown)
            continue

        # NOTE(review): the prefix has no leading '<s>', although every
        # training sequence starts with '<s>'; confirm this is intended.
        pred_wd = original_wd + ''.join(_greedy_continue(model, present))
        pred_wds.append(pred_wd)
        print(original_wd, ": ", pred_wd, end=' ')

        # Assumes reg_test_words[it] describes sentences_test[book_title][it].
        is_regular = reg_test_words[it]
        if corr == pred_wd:
            if is_regular:
                correct += 1
            else:
                irreg_correct += 1
            print("correctly predicted")
        else:
            print("incorrectly predicted for", corr)
        if is_regular:
            total += 1
        else:
            irreg_total += 1

    # NaN signals that the split contained no items of that class.
    print("Accuracy for regular: ", correct / total if total else float('nan'))
    print("Accuracy for irregular: ", irreg_correct / irreg_total if irreg_total else float('nan'))
    print('rw', real_wds)

Loading checkpoint
ɐvɛndʒ :  ɐvɛndʒd correctly predicted
kəlæbɚɹeɪt :  kəlæbɚɹeɪt incorrectly predicted for kəlæbɚɹeɪɾᵻd
ɛŋkloʊz :  ɛŋkloʊzd correctly predicted
slʌɡ :  slʌɡd correctly predicted
stɔl :  stɔld correctly predicted
ɐmaʊnt :  ɐmaʊntᵻd correctly predicted
ɐɡlɑmɚɹeɪt :  ɐɡlɑmɚɹeɪt incorrectly predicted for ɐɡlɑmɚɹeɪɾᵻd
sɪp :  sɪpt correctly predicted
spɛɹ :  spɛɹd correctly predicted
bɹid :  bɹid incorrectly predicted for bɹɛd
ʃʌn :  ʃʌnd correctly predicted
swɛl :  swɛld correctly predicted
aɪɾəmaɪz :  aɪɾəmaɪzd correctly predicted
ɹᵻsid :  ɹᵻsid incorrectly predicted for ɹᵻsidᵻd
skɹæmbəl :  skɹæmbəld correctly predicted
pleɪɡ :  pleɪɡɹᵻd incorrectly predicted for pleɪɡd
poʊstpoʊn :  poʊstpoʊnd correctly predicted
steɪbɪlaɪz :  steɪbɪlaɪzd correctly predicted
duplᵻkeɪt :  duplᵻkeɪt incorrectly predicted for duplᵻkeɪɾᵻd
ɹaɪz :  ɹaɪzd incorrectly predicted for ɹoʊz
pɹɑmpt :  pɹɑmpt incorrectly predicted for pɹɑmptᵻd
lɛvi :  lɛvid correctly predicted
dɪstɔɹt :  dɪstɔɹtɹᵻd inco

The irregular accuracy tends to be lower than the regular accuracy. This makes sense here: irregular past-tense forms do not follow a general rule that the model can pick up from other verbs. However, in this particular run, the two accuracies were quite close.