## HMM

In [1]:
import torch
import numpy as np

class HMM(torch.nn.Module):
  """
  Discrete-observation Hidden Markov Model.

  Holds a TransitionModel (A), an EmissionModel (b) and a vector of
  unnormalized initial-state scores (pi). Nothing is normalized here;
  the stored values are raw scores.
  """
  def __init__(self, M, N):
    """
    M : number of possible observations
    N : number of states
    """
    super().__init__()
    self.M, self.N = M, N

    # Parameters are created in the order A, b(x_t), pi.
    self.transition_model = TransitionModel(N)
    self.emission_model = EmissionModel(N, M)
    self.unnormalized_state_priors = torch.nn.Parameter(torch.randn(N))

    # Move every parameter to the GPU whenever one is available,
    # and remember that choice so inputs can be moved as well.
    self.is_cuda = torch.cuda.is_available()
    if self.is_cuda:
      self.cuda()

class TransitionModel(torch.nn.Module):
  """Owns the N x N matrix of unnormalized transition scores as a learnable parameter."""
  def __init__(self, N):
    """
    N : number of states
    """
    super().__init__()
    self.N = N
    # Random Gaussian starting scores
    initial_scores = torch.randn(N, N)
    self.unnormalized_transition_matrix = torch.nn.Parameter(initial_scores)

class EmissionModel(torch.nn.Module):
  """Owns the N x M matrix of unnormalized emission scores as a learnable parameter."""
  def __init__(self, N, M):
    """
    N : number of states
    M : number of possible observations
    """
    super().__init__()
    self.N, self.M = N, M
    # Random Gaussian starting scores, one row per state
    initial_scores = torch.randn(N, M)
    self.unnormalized_emission_matrix = torch.nn.Parameter(initial_scores)


In [2]:
def sample(self, T=10):
  """
  Draw one sequence of T observations from the model, together with the
  hidden states that produced them.

  T : number of time steps to generate

  Returns (x, z): two lists of Python ints with T entries each, where z[t] is
  the hidden state at step t and x[t] the symbol emitted from that state.
  For T <= 0 both lists are empty.
  """
  Categorical = torch.distributions.categorical.Categorical

  # Sampling never needs gradients, so do not build an autograd graph.
  with torch.no_grad():
    state_priors = torch.nn.functional.softmax(self.unnormalized_state_priors, dim=0)
    # Normalized over dim 0: column k is the distribution over the next state given current state k.
    transition_matrix = torch.nn.functional.softmax(self.transition_model.unnormalized_transition_matrix, dim=0)
    # Normalized over dim 1: row k is the distribution over symbols emitted by state k.
    emission_matrix = torch.nn.functional.softmax(self.emission_model.unnormalized_emission_matrix, dim=1)

    z = []; x = []
    for t in range(T):
      # The first state comes from the prior, every later one from the
      # column of the transition matrix selected by the previous state.
      # A state is drawn only when it will also emit, so x and z always
      # have the same length and no draw is thrown away.
      if t == 0:
        state_probs = state_priors
      else:
        state_probs = transition_matrix[:, z_t]
      z_t = Categorical(state_probs).sample().item()
      z.append(z_t)

      # sample emission
      x_t = Categorical(emission_matrix[z_t]).sample().item()
      x.append(x_t)

  return x, z

# Add the sampling method to our HMM class.
# Assigning the plain function to the class makes it an ordinary method,
# so existing and future HMM instances can call model.sample(T=...).
HMM.sample = sample


In [3]:
import string
# Symbol table: position in this sequence is the integer id of a character.
alphabet = string.ascii_lowercase

def encode(s):
  """
  Map every character of s to its position in the module-level alphabet.

  Returns a list of ints; a character that is not in the alphabet makes
  alphabet.index raise ValueError.
  """
  return list(map(alphabet.index, s))

def decode(x):
  """
  Turn a list of integer ids back into text by looking each one up in the
  module-level alphabet and concatenating the characters.
  """
  letters = []
  for symbol_id in x:
    letters.append(alphabet[symbol_id])
  return "".join(letters)

# Build a two-state model over the current alphabet
model = HMM(M=len(alphabet), N=2)

# Hard-wiring the parameters!
# Convention: state 0 = consonant, state 1 = vowel
for p in model.parameters():
    p.requires_grad_(False) # in-place edits of the parameters below require this

# Prior scores: starting with a consonant is made more likely than a vowel
model.unnormalized_state_priors.copy_(torch.tensor([0., -0.5]))
print("State priors:", torch.nn.functional.softmax(model.unnormalized_state_priors, dim=0))

# State 0 may only emit consonants, state 1 may only emit vowels:
# a score of -inf becomes probability 0 after the softmax.
vowel_indices = torch.tensor([alphabet.index(letter) for letter in "aeiou"])
consonant_indices = torch.tensor([index for index, letter in enumerate(alphabet) if letter not in "aeiou"])
model.emission_model.unnormalized_emission_matrix[0, vowel_indices] = -np.inf
model.emission_model.unnormalized_emission_matrix[1, consonant_indices] = -np.inf
print("Emission matrix:", torch.nn.functional.softmax(model.emission_model.unnormalized_emission_matrix, dim=1))

# Entry [to, from]; only vowel -> consonant and consonant -> vowel are allowed
model.transition_model.unnormalized_transition_matrix.copy_(torch.tensor([
    [-np.inf, 0.],      # consonant -> consonant (forbidden), vowel -> consonant
    [0., -np.inf],      # consonant -> vowel, vowel -> vowel (forbidden)
]))
print("Transition matrix:", torch.nn.functional.softmax(model.transition_model.unnormalized_transition_matrix, dim=0))


State priors: tensor([0.6225, 0.3775])
Emission matrix: tensor([[0.0000, 0.1184, 0.0164, 0.0428, 0.0000, 0.0585, 0.0393, 0.0118, 0.0000,
         0.0296, 0.0076, 0.0719, 0.0089, 0.0961, 0.0000, 0.0843, 0.0967, 0.0098,
         0.0163, 0.0377, 0.0000, 0.0815, 0.0080, 0.1007, 0.0109, 0.0527],
        [0.1295, 0.0000, 0.0000, 0.0000, 0.1358, 0.0000, 0.0000, 0.0000, 0.1018,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.4401, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.1928, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]])
Transition matrix: tensor([[0., 1.],
        [1., 0.]])


In [4]:
# Sample some outputs
# Each draw is a 5-step sequence; x is decoded to text with decode(),
# z is printed as the raw list of hidden-state ids.
for _ in range(4):
  sampled_x, sampled_z = model.sample(T=5)
  print("x:", decode(sampled_x))
  print("z:", sampled_z)

x: elicu
z: [1, 0, 1, 0, 1]
x: pocob
z: [0, 1, 0, 1, 0]
x: fuxaq
z: [0, 1, 0, 1, 0]
x: gotox
z: [0, 1, 0, 1, 0]


In [5]:
def HMM_forward(self, x, T):
  """
  Forward algorithm over a batch of observation sequences.

  x : IntTensor of shape (batch size, T_max)
  T : IntTensor of shape (batch size), the length of each example

  Returns a tensor of shape (batch size, 1) with log p(x) for each example,
  read off at that example's own last timestep (index T - 1).
  """
  if self.is_cuda:
    x, T = x.cuda(), T.cuda()

  num_seqs, max_len = x.shape[0], x.shape[1]
  log_pi = torch.nn.functional.log_softmax(self.unnormalized_state_priors, dim=0)

  # One row of per-state forward log-scores for every (example, timestep)
  log_alpha = torch.zeros(num_seqs, max_len, self.N)
  if self.is_cuda:
    log_alpha = log_alpha.cuda()

  # Initialisation: emission score of the first symbol plus the log prior
  log_alpha[:, 0, :] = self.emission_model(x[:, 0]) + log_pi

  # Recursion: emission score plus the previous scores pushed through the transition model
  for step in range(1, max_len):
    emitted = self.emission_model(x[:, step])
    carried = self.transition_model(log_alpha[:, step - 1, :])
    log_alpha[:, step, :] = emitted + carried

  # Sum out the state at every timestep, then keep the entry at T - 1
  # (each x may have a different length).
  per_step_log_prob = torch.logsumexp(log_alpha, dim=2)
  last_index = T.view(-1, 1) - 1
  return torch.gather(per_step_log_prob, 1, last_index)

def emission_model_forward(self, x_t):
  """
  Look up the log emission probability of each observed symbol under every state.

  x_t : index tensor of symbol ids, one per example
  Returns the selected columns of the row-normalized log emission matrix,
  transposed so that examples come first and states second.
  """
  log_b = torch.nn.functional.log_softmax(self.unnormalized_emission_matrix, dim=1)
  # Column x_t[i] holds the score of that symbol under each state
  per_state_scores = log_b[:, x_t]
  return torch.transpose(per_state_scores, 0, 1)

def transition_model_forward(self, log_alpha):
  """
  Apply the transition matrix to the previous timestep's alphas, in the log domain.

  log_alpha : Tensor of shape (batch size, N)
  Returns a tensor of shape (batch size, N).
  """
  # Normalized over dim 0, so each column is a log-distribution
  log_A = torch.nn.functional.log_softmax(self.unnormalized_transition_matrix, dim=0)

  # (N, N) times (N, batch) gives (N, batch); flip back to (batch, N)
  propagated = log_domain_matmul(log_A, log_alpha.t())
  return propagated.t()

def log_domain_matmul(log_A, log_B):
  """
  Matrix product carried out entirely in log space.

  log_A : m x n
  log_B : n x p
  output : m x p matrix

  Where an ordinary product computes
  out_{i,j} = sum_k A_{i,k} x B_{k,j},
  this computes out_{i,j} = logsumexp_k log_A_{i,k} + log_B_{k,j}
  """
  rows, inner = log_A.shape[0], log_A.shape[1]
  cols = log_B.shape[1]

  # Reshape to (m, n, 1) and (1, n, p) so the addition broadcasts to an
  # (m, n, p) tensor holding log_A_{i,k} + log_B_{k,j} at position [i, k, j].
  pairwise = log_A.reshape(rows, inner, 1) + log_B.reshape(1, inner, cols)

  # Reduce over the shared index k
  return pairwise.logsumexp(dim=1)

In [6]:
# Install the functions defined above as the forward() methods of the three
# modules, so that calling a module instance dispatches to them.
TransitionModel.forward = transition_model_forward
EmissionModel.forward = emission_model_forward
HMM.forward = HMM_forward

In [7]:
# A batch holding the single word "cat" (length 3)
x = torch.tensor([encode("cat")])
T = torch.tensor([3])
print(model.forward(x, T))

# A batch of two length-3 words scored together
x = torch.tensor([encode("aba"), encode("abb")])
T = torch.tensor([3,3])
print(model.forward(x, T))



tensor([[-9.9049]])
tensor([[-7.1966],
        [   -inf]])


In [8]:
def viterbi(self, x, T):
  """
  Viterbi decoding over a batch of observation sequences.

  x : IntTensor of shape (batch size, T_max)
  T : IntTensor of shape (batch size)

  For each x in the batch, find the state sequence z with the highest score,
  where the score adds the log prior, the transition model's max-product
  terms and the emission scores.

  Returns (z_star, best_path_scores):
    z_star           : list with one list of state ids per example, of length T[i]
    best_path_scores : tensor of shape (batch size, 1), score of each best path
  """
  if self.is_cuda:
    x, T = x.cuda(), T.cuda()

  num_seqs, max_len = x.shape[0], x.shape[1]
  log_pi = torch.nn.functional.log_softmax(self.unnormalized_state_priors, dim=0)

  # log_delta: best score of any path ending in each state at each timestep
  # psi: the predecessor state that achieved that best score
  log_delta = torch.zeros(num_seqs, max_len, self.N).float()
  psi = torch.zeros(num_seqs, max_len, self.N).long()
  if self.is_cuda:
    log_delta, psi = log_delta.cuda(), psi.cuda()

  log_delta[:, 0, :] = self.emission_model(x[:, 0]) + log_pi
  for step in range(1, max_len):
    best_prev_score, best_prev_state = self.transition_model.maxmul(log_delta[:, step - 1, :])
    log_delta[:, step, :] = self.emission_model(x[:, step]) + best_prev_score
    psi[:, step, :] = best_prev_state

  # Score of the best path: max over states, read at each example's last timestep
  best_path_scores = torch.gather(log_delta.max(dim=2)[0], 1, T.view(-1, 1) - 1)

  # Backtracking follows a different chain of pointers for every example,
  # so it is done one example at a time.
  z_star = []
  for i in range(num_seqs):
    last = int(T[i]) - 1
    state = log_delta[i, last, :].max(dim=0)[1].item()
    backwards = [state]
    # psi[i, t, s] is the state at t-1 on the best path that is in s at t
    for step in range(last, 0, -1):
      state = psi[i, step, state].item()
      backwards.append(state)
    z_star.append(backwards[::-1])

  return z_star, best_path_scores # the best paths and their log scores

def transition_model_maxmul(self, log_alpha):
  """
  Max-product counterpart of the transition step.

  log_alpha : 2-D tensor; it is transposed before being passed to maxmul
  Returns (max values, argmax indices) from maxmul, each transposed back.
  """
  log_A = torch.nn.functional.log_softmax(self.unnormalized_transition_matrix, dim=0)

  best_scores, best_prev = maxmul(log_A, log_alpha.t())
  return best_scores.t(), best_prev.t()

def maxmul(log_A, log_B):
  """
  Max-plus matrix product.

  log_A : m x n
  log_B : n x p
  output : (values, indices), each an m x p matrix

  Similar to the log domain matrix multiplication,
  this computes values_{i,j} = max_k log_A_{i,k} + log_B_{k,j}
  and indices_{i,j} = the k at which that maximum is reached.
  """
  m = log_A.shape[0]
  n = log_A.shape[1]
  p = log_B.shape[1]

  # Same broadcasting scheme as log_domain_matmul: reshaping to (m, n, 1) and
  # (1, n, p) lets the addition expand to (m, n, p) without first building
  # p stacked copies of log_A and m stacked copies of log_B.
  log_A_expanded = torch.reshape(log_A, (m, n, 1))
  log_B_expanded = torch.reshape(log_B, (1, n, p))

  elementwise_sum = log_A_expanded + log_B_expanded

  # Reduce over the shared index k, keeping both the maximum and its position
  out1, out2 = torch.max(elementwise_sum, dim=1)

  return out1, out2

# Attach the max-product step to TransitionModel and the decoder to HMM,
# so they can be called as methods (model.viterbi(x, T)).
TransitionModel.maxmul = transition_model_maxmul
HMM.viterbi = viterbi


In [9]:
# Decode a batch of two length-3 words
x = torch.tensor([encode("aba"), encode("abb")])
T = torch.tensor([3,3])
print(model.viterbi(x, T))

([[1, 0, 1], [1, 0, 0]], tensor([[-7.1966],
        [   -inf]]))


In [10]:
# Compare the forward score with the Viterbi best-path score on the same
# batch: the first sums over all state paths, the second keeps only the best.
print(model.forward(x, T))
print(model.viterbi(x, T)[1])

tensor([[-7.1966],
        [   -inf]])
tensor([[-7.1966],
        [   -inf]])


In [11]:
# max and logsumexp applied to the same vector
x = torch.tensor([1., 2., 3.])
print(torch.max(x, dim=0).values)
print(torch.logsumexp(x, dim=0))

tensor(3.)
tensor(3.4076)


In [12]:
import torch.utils.data
from collections import Counter
from sklearn.model_selection import train_test_split

class TextDataset(torch.utils.data.Dataset):
  """
  Dataset over a list of text lines, bundled with a DataLoader over itself.

  Indexing returns one cleaned line as a string; iterating over self.loader
  yields minibatches assembled by Collate.
  """
  def __init__(self, lines, batch_size=1024, shuffle=True, num_workers=1):
    """
    lines       : list of strings
    batch_size  : number of lines per minibatch
    shuffle     : whether the loader reshuffles the lines on every pass
    num_workers : number of DataLoader worker processes

    The defaults reproduce the previously hard-coded loader settings.
    """
    self.lines = lines # list of strings
    collate = Collate() # function for generating a minibatch from strings
    self.loader = torch.utils.data.DataLoader(self, batch_size=batch_size, num_workers=num_workers, shuffle=shuffle, collate_fn=collate)

  def __len__(self):
    """Number of lines in the dataset."""
    return len(self.lines)

  def __getitem__(self, idx):
    """Return line idx with leading spaces and trailing spaces/newline removed."""
    # Only spaces are stripped on the left; on the right a newline, then
    # spaces, then one more newline are stripped, in that order.
    line = self.lines[idx].lstrip(" ").rstrip("\n").rstrip(" ").rstrip("\n")
    return line

class Collate:
  """Callable that turns a list of strings into one padded minibatch."""
  def __init__(self):
    pass

  def __call__(self, batch):
    """
    batch : list of strings

    Returns (x, x_lengths): x stacks the integer-encoded strings, each
    right-padded with 0 to the longest length in the batch; x_lengths holds
    the unpadded length of every string.
    """
    # convert letters to integers
    encoded = [encode(text) for text in batch]

    # remember the true lengths before padding
    x_lengths = [len(ids) for ids in encoded]
    T = max(x_lengths)

    # pad every sequence with 0 up to length T and turn it into a tensor
    padded = [torch.tensor(ids + [0] * (T - len(ids))) for ids in encoded]

    # stack into single tensor
    return (torch.stack(padded), torch.tensor(x_lengths))

In [13]:
!wget https://raw.githubusercontent.com/lorenlugosch/pytorch_HMM/master/data/train/training.txt
# If wget does not work, put the file in your current directory, or maybe use curl

filename = "training.txt"

with open(filename, "r") as f:
  lines = f.readlines() # each line of lines will have one word

alphabet = list(Counter(("".join(lines))).keys())
train_lines, valid_lines = train_test_split(lines, test_size=0.1, random_state=42)
train_dataset = TextDataset(train_lines)
valid_dataset = TextDataset(valid_lines)

M = len(alphabet)

--2023-11-05 18:03:43--  https://raw.githubusercontent.com/lorenlugosch/pytorch_HMM/master/data/train/training.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 2493109 (2.4M) [text/plain]
Saving to: ‘training.txt’


2023-11-05 18:03:43 (31.8 MB/s) - ‘training.txt’ saved [2493109/2493109]



In [14]:
from tqdm import tqdm # for displaying progress bar

class Trainer:
  """
  Runs training and evaluation passes of a model over datasets that expose a
  `loader` attribute yielding (x, T) minibatches.
  """
  def __init__(self, model, lr):
    """
    model : module called as model(x, T) to get per-example log-probabilities
    lr    : learning rate for Adam
    """
    self.model = model
    self.lr = lr
    self.optimizer = torch.optim.Adam(model.parameters(), lr=self.lr, weight_decay=0.00001)

  def train(self, dataset):
    """
    Run one optimisation pass over dataset.loader.

    Returns the mean loss per example, where the loss is the negative mean
    log-probability of a batch.
    """
    train_loss = 0
    num_samples = 0
    self.model.train()
    print_interval = 50
    for idx, batch in enumerate(tqdm(dataset.loader)):
      x,T = batch
      batch_size = len(x)
      num_samples += batch_size
      log_probs = self.model(x,T)
      loss = -log_probs.mean()
      self.optimizer.zero_grad()
      loss.backward()
      self.optimizer.step()
      # weight by batch size so the last, smaller batch is averaged correctly
      train_loss += loss.item() * batch_size
      if idx % print_interval == 0:
        # periodically show the loss and a few samples from the current model
        print("loss:", loss.item())
        for _ in range(5):
          sampled_x, sampled_z = self.model.sample()
          print(decode(sampled_x))
          print(sampled_z)
    train_loss /= num_samples
    return train_loss

  def test(self, dataset):
    """
    Evaluate on dataset.loader without updating the parameters.

    Returns the mean loss per example.
    """
    test_loss = 0
    num_samples = 0
    self.model.eval()
    print_interval = 50
    # Evaluation never calls backward(), so skip building autograd graphs
    with torch.no_grad():
      for idx, batch in enumerate(dataset.loader):
        x,T = batch
        batch_size = len(x)
        num_samples += batch_size
        log_probs = self.model(x,T)
        loss = -log_probs.mean()
        test_loss += loss.item() * batch_size
        if idx % print_interval == 0:
          print("loss:", loss.item())
          sampled_x, sampled_z = self.model.sample()
          print(decode(sampled_x))
          print(sampled_z)
    test_loss /= num_samples
    return test_loss

In [15]:
# Fresh model: 64 hidden states over the M symbols of the training alphabet
model = HMM(N=64, M=M)

# Training configuration
num_epochs = 10
trainer = Trainer(model, lr=0.01)

for epoch in range(num_epochs):
  # (1-based epoch number, total) shared by both banner lines
  progress = (epoch+1, num_epochs)
  print("========= Epoch %d of %d =========" % progress)
  train_loss = trainer.train(train_dataset)
  valid_loss = trainer.test(valid_dataset)

  print("========= Results: epoch %d of %d =========" % progress)
  print("train loss: %.2f| valid loss: %.2f\n" % (train_loss, valid_loss) )



  0%|          | 1/208 [00:00<02:50,  1.21it/s]

loss: 38.897884368896484
KUqzE-RDXU
[62, 21, 12, 52, 51, 15, 53, 57, 37, 53]
upfM
cz

z
[5, 33, 13, 23, 28, 32, 43, 46, 2, 12]
AGItnOyFmX
[7, 61, 56, 61, 34, 5, 15, 50, 38, 56]
ArPYQwdKSS
[17, 19, 57, 53, 31, 54, 4, 19, 47, 48]
CrSh-VpKoU
[41, 37, 55, 49, 0, 56, 6, 31, 13, 52]


 25%|██▍       | 51/208 [00:32<01:40,  1.56it/s]

loss: 33.372100830078125
iCKqiyhbnu
[46, 26, 15, 45, 46, 55, 15, 53, 46, 36]
owkWXxNhZp
[56, 23, 31, 28, 41, 48, 48, 54, 30, 51]
wLpDnaVs
p
[33, 12, 5, 36, 55, 46, 26, 36, 31, 33]
lqTdPnukfe
[0, 43, 20, 4, 7, 33, 48, 37, 42, 12]
dtnmtylmec
[2, 50, 3, 59, 21, 55, 35, 59, 42, 34]


 49%|████▊     | 101/208 [01:04<01:07,  1.59it/s]

loss: 29.842239379882812
eMenfmXtip
[36, 9, 36, 6, 2, 59, 55, 22, 5, 5]
IlmfdwqTsy
[49, 49, 37, 42, 13, 5, 22, 42, 37, 56]
crLWzeoUss
[38, 27, 22, 48, 59, 12, 54, 14, 15, 0]
j-am-iYlJK
[9, 30, 35, 59, 47, 30, 24, 32, 43, 35]
zltcCJroPt
[24, 21, 56, 37, 42, 10, 17, 47, 30, 35]


 73%|███████▎  | 151/208 [01:35<00:35,  1.60it/s]

loss: 28.22951889038086
agugonDhst
[30, 57, 20, 3, 41, 34, 56, 35, 12, 40]
vumxrrZpie
[4, 42, 59, 18, 37, 42, 53, 37, 42, 34]
srYselEnrr
[59, 27, 43, 56, 27, 35, 42, 34, 5, 46]
nzpadiWryl
[43, 36, 5, 30, 50, 8, 48, 37, 42, 47]
milfjtnope
[37, 42, 43, 30, 54, 21, 56, 42, 5, 27]


 97%|█████████▋| 201/208 [02:08<00:05,  1.31it/s]

loss: 26.574800491333008
ohondiowle
[19, 35, 12, 55, 22, 21, 46, 30, 23, 58]
ONrlesteri
[11, 32, 18, 21, 46, 55, 22, 54, 21, 46]
ufbixmepbr
[1, 30, 37, 42, 12, 5, 12, 37, 47, 46]
xeuatymnXr
[59, 27, 21, 46, 15, 42, 34, 5, 12, 5]
altidaWisi
[59, 46, 15, 42, 31, 10, 27, 36, 56, 42]


100%|██████████| 208/208 [02:12<00:00,  1.57it/s]


loss: 26.63311767578125
axenlanimo
[30, 5, 36, 14, 15, 30, 56, 46, 5, 36]
train loss: 30.81| valid loss: 26.70



  0%|          | 1/208 [00:00<03:24,  1.01it/s]

loss: 26.969579696655273
Vidondaeil
[35, 46, 5, 36, 14, 5, 30, 37, 42, 47]
lcpenpnemu
[1, 34, 5, 12, 5, 46, 34, 12, 5, 20]
flaseejlli
[24, 21, 30, 19, 36, 55, 3, 47, 21, 46]
daGezyimxd
[5, 46, 5, 36, 57, 43, 12, 5, 12, 5]
rKachsleia
[19, 47, 30, 50, 8, 26, 17, 8, 27, 30]


 25%|██▍       | 51/208 [00:32<01:36,  1.62it/s]

loss: 25.53016471862793
fauralentf
[48, 30, 34, 5, 30, 47, 36, 55, 22, 21]
dejmhBserv
[5, 12, 56, 5, 8, 12, 5, 12, 34, 5]
taTangreio
[50, 8, 48, 30, 56, 22, 8, 26, 43, 0]
sitermamQe
[59, 27, 22, 12, 34, 5, 30, 5, 1, 27]
gatistoslo
[50, 30, 22, 42, 50, 15, 12, 34, 5, 12]


 49%|████▊     | 101/208 [01:04<01:11,  1.50it/s]

loss: 25.344242095947266
sstolomupc
[56, 27, 22, 46, 15, 12, 5, 27, 22, 37]
untMiesjns
[36, 55, 36, 33, 21, 12, 56, 46, 55, 22]
ghiteviari
[33, 8, 46, 15, 12, 5, 46, 30, 22, 21]
iskocaales
[36, 14, 15, 12, 5, 30, 37, 47, 42, 50]
bioncedabl
[33, 21, 12, 34, 15, 12, 5, 30, 37, 47]


 73%|███████▎  | 151/208 [01:34<00:32,  1.77it/s]

loss: 24.80217933654785
teodmianta
[22, 21, 12, 34, 5, 46, 30, 55, 22, 27]
sylinnemil
[24, 42, 47, 36, 55, 22, 12, 5, 46, 15]
nanghocyra
[5, 30, 55, 22, 8, 46, 50, 12, 34, 46]
cenopaifth
[24, 36, 14, 30, 5, 30, 36, 55, 22, 8]
sibynyalel
[20, 27, 5, 46, 15, 42, 30, 22, 12, 47]


 97%|█████████▋| 201/208 [02:06<00:04,  1.61it/s]

loss: 24.50609588623047
ganiscepei
[33, 30, 5, 46, 14, 15, 27, 50, 8, 46]
cftorthice
[50, 30, 22, 12, 56, 50, 8, 46, 15, 46]
phiallybyp
[33, 8, 46, 30, 37, 47, 42, 5, 27, 22]
akodalente
[30, 22, 30, 15, 46, 5, 12, 55, 22, 12]
selerigapo
[24, 42, 47, 12, 56, 46, 15, 46, 50, 12]


100%|██████████| 208/208 [02:10<00:00,  1.60it/s]


loss: 24.808820724487305
enutoidshi
[36, 55, 27, 22, 12, 36, 34, 50, 8, 46]
train loss: 25.37| valid loss: 24.71



  0%|          | 1/208 [00:00<02:39,  1.30it/s]

loss: 25.041990280151367
fame
ashit
[33, 30, 59, 27, 22, 27, 50, 8, 27, 22]
lerdorpali
[5, 12, 34, 5, 12, 56, 50, 30, 47, 46]
giiallirro
[5, 46, 53, 30, 37, 47, 12, 34, 21, 30]
segerperdo
[5, 12, 5, 12, 34, 5, 12, 34, 5, 30]
rsrotelort
[5, 54, 21, 30, 22, 12, 47, 12, 56, 22]


 25%|██▍       | 51/208 [00:31<01:45,  1.49it/s]

loss: 24.773218154907227
intiztitea
[36, 55, 22, 46, 15, 21, 46, 15, 12, 30]
ntaurecisr
[20, 22, 12, 27, 56, 27, 22, 46, 50, 8]
bedetshemi
[24, 36, 5, 27, 56, 50, 8, 12, 5, 46]
blendidmaT
[19, 47, 12, 55, 22, 46, 0, 5, 46, 25]
Cintrateme
[33, 36, 55, 22, 21, 30, 22, 12, 5, 12]


 49%|████▊     | 101/208 [01:02<01:06,  1.61it/s]

loss: 24.423091888427734
migelallip
[5, 46, 5, 12, 47, 30, 37, 47, 42, 50]
blethupiuc
[19, 47, 42, 22, 8, 42, 19, 36, 55, 15]
bontewgurm
[5, 30, 55, 22, 12, 34, 5, 12, 56, 59]
stremiwpor
[20, 22, 21, 12, 5, 46, 14, 5, 30, 34]
masrabinnu
[59, 27, 22, 21, 30, 37, 46, 14, 5, 30]


 73%|███████▎  | 151/208 [01:34<00:33,  1.70it/s]

loss: 24.52507781982422
cubitemafu
[24, 36, 19, 27, 22, 12, 5, 27, 22, 12]
niMifisthi
[5, 46, 15, 46, 15, 46, 55, 22, 8, 46]
rolustoctr
[5, 12, 5, 12, 56, 22, 12, 55, 22, 21]
debhesouto
[5, 12, 5, 8, 12, 56, 30, 55, 22, 12]
stelteetun
[20, 22, 12, 56, 22, 12, 56, 22, 12, 55]


 97%|█████████▋| 201/208 [02:05<00:04,  1.41it/s]

loss: 24.06537628173828
ustoannice
[36, 56, 22, 12, 30, 55, 15, 27, 22, 12]
tromatoffe
[33, 21, 12, 5, 27, 22, 30, 45, 45, 12]
pingerpani
[33, 46, 14, 44, 12, 34, 50, 30, 55, 46]
odylyblall
[30, 15, 42, 47, 42, 37, 47, 27, 37, 47]
cxm
ntelti
[24, 12, 5, 30, 55, 22, 12, 56, 22, 46]


100%|██████████| 208/208 [02:09<00:00,  1.61it/s]


loss: 24.159208297729492
rativyemic
[5, 27, 22, 46, 15, 27, 37, 59, 46, 15]
train loss: 24.41| valid loss: 24.26



  0%|          | 1/208 [00:00<03:00,  1.15it/s]

loss: 23.876821517944336
thpryutedp
[33, 8, 33, 21, 46, 30, 22, 12, 34, 19]
unkinmrede
[36, 55, 22, 46, 14, 17, 21, 12, 5, 12]
unnagotere
[36, 55, 15, 30, 5, 30, 50, 12, 5, 12]
vortiarion
[5, 12, 56, 22, 46, 30, 5, 46, 30, 15]
feodinkina
[24, 25, 0, 5, 46, 55, 22, 46, 15, 27]


 25%|██▍       | 51/208 [00:32<01:39,  1.59it/s]

loss: 24.256649017333984
brikyc
esk
[19, 21, 46, 15, 42, 50, 21, 12, 56, 38]
dadlotshan
[5, 30, 22, 47, 42, 50, 6, 8, 30, 55]
semastobRp
[24, 12, 5, 30, 56, 22, 30, 31, 20, 5]
voyrimaete
[5, 30, 42, 5, 46, 15, 27, 12, 22, 12]
miponstlyt
[5, 46, 50, 12, 56, 20, 26, 47, 42, 50]


 49%|████▊     | 101/208 [01:03<01:06,  1.62it/s]

loss: 23.937942504882812
buteistiak
[19, 27, 22, 12, 46, 20, 22, 46, 30, 15]
sandsassha
[24, 36, 55, 29, 21, 30, 56, 20, 8, 30]
henbelingr
[24, 36, 55, 5, 12, 21, 46, 14, 44, 21]
acjAngefee
[36, 55, 24, 36, 14, 44, 12, 45, 36, 58]
unoamentin
[36, 5, 12, 30, 5, 12, 56, 22, 46, 55]


 73%|███████▎  | 151/208 [01:34<00:35,  1.62it/s]

loss: 23.866498947143555
sablemeriv
[24, 30, 37, 47, 12, 5, 12, 34, 46, 16]
nodylliopa
[5, 30, 34, 42, 37, 47, 46, 30, 5, 30]
urnetoslif
[36, 34, 5, 27, 22, 12, 56, 35, 36, 45]
Oeostideni
[36, 16, 12, 56, 22, 46, 15, 12, 5, 46]
veretsimre
[16, 12, 5, 12, 56, 20, 36, 34, 5, 12]


 97%|█████████▋| 201/208 [02:06<00:04,  1.59it/s]

loss: 23.860408782958984
altestryst
[36, 55, 22, 12, 56, 22, 21, 42, 56, 22]
illacaleme
[36, 37, 47, 27, 22, 30, 5, 12, 5, 12]
hogytuical
[24, 30, 44, 12, 50, 8, 46, 15, 27, 37]
iqkedrinep
[36, 6, 53, 12, 62, 21, 46, 15, 12, 22]
Perficeffa
[24, 12, 34, 19, 46, 15, 12, 45, 45, 36]


100%|██████████| 208/208 [02:10<00:00,  1.60it/s]


loss: 23.88800811767578
uxaivurica
[36, 55, 30, 46, 15, 27, 21, 46, 15, 27]
train loss: 24.05| valid loss: 23.95



  0%|          | 1/208 [00:00<02:48,  1.23it/s]

loss: 23.37877655029297
landiyseme
[24, 36, 55, 29, 46, 12, 21, 12, 5, 12]
entesmulio
[36, 55, 22, 12, 56, 59, 27, 37, 46, 30]
alWpedassy
[30, 23, 42, 50, 12, 5, 30, 56, 20, 42]
stocfumive
[20, 22, 30, 55, 24, 36, 34, 46, 5, 12]
linyllaist
[5, 46, 15, 42, 37, 47, 30, 46, 20, 22]


 25%|██▍       | 51/208 [00:33<01:45,  1.49it/s]

loss: 23.88699722290039
syrssiapin
[24, 42, 56, 20, 22, 46, 30, 5, 46, 14]
phnifarech
[50, 8, 5, 46, 15, 27, 21, 12, 50, 8]
fycyiohgan
[24, 42, 15, 42, 46, 30, 14, 44, 36, 55]
trosanotup
[33, 21, 30, 22, 30, 5, 12, 5, 12, 56]
chelergoco
[50, 8, 42, 47, 12, 34, 5, 30, 56, 0]


 49%|████▊     | 101/208 [01:06<01:23,  1.28it/s]

loss: 24.1594295501709
sphallytho
[20, 50, 8, 27, 37, 47, 42, 50, 8, 30]
henomvblen
[24, 12, 56, 0, 5, 27, 37, 47, 12, 5]
phomingron
[50, 8, 30, 5, 46, 14, 44, 21, 36, 55]
scowehyned
[20, 22, 30, 5, 12, 8, 42, 5, 12, 29]
tozlurpric
[33, 30, 37, 47, 12, 34, 19, 21, 46, 15]


 73%|███████▎  | 151/208 [01:39<00:36,  1.56it/s]

loss: 23.34781837463379
umgrycadec
[36, 34, 19, 21, 42, 50, 30, 5, 12, 55]
riponpesti
[5, 46, 50, 30, 55, 19, 12, 56, 22, 46]
efeisnoyec
[36, 45, 12, 46, 0, 5, 30, 5, 12, 22]
mypocuenti
[24, 42, 50, 30, 50, 8, 12, 55, 22, 46]
partisphis
[33, 30, 56, 22, 46, 20, 50, 8, 46, 20]


 97%|█████████▋| 201/208 [02:10<00:04,  1.58it/s]

loss: 23.738645553588867
Aftedrisum
[36, 45, 22, 12, 29, 21, 46, 20, 0, 5]
gomawsufed
[50, 30, 5, 27, 55, 20, 28, 19, 12, 34]
medanedioc
[59, 36, 5, 30, 5, 12, 5, 46, 30, 50]
trothonera
[33, 21, 30, 22, 8, 30, 5, 12, 21, 30]
nyntechkei
[5, 27, 55, 22, 12, 50, 8, 53, 12, 46]


100%|██████████| 208/208 [02:14<00:00,  1.55it/s]


loss: 23.8298397064209
eutisptomg
[36, 55, 22, 46, 20, 50, 22, 30, 7, 44]
train loss: 23.84| valid loss: 23.80



  0%|          | 1/208 [00:00<02:59,  1.15it/s]

loss: 23.86598777770996
cruspolori
[33, 21, 36, 55, 50, 30, 47, 12, 21, 46]
plallnstle
[19, 47, 27, 37, 47, 12, 56, 22, 47, 12]
medoustico
[5, 12, 5, 30, 54, 20, 22, 46, 15, 12]
armkyxdibl
[36, 34, 59, 15, 42, 34, 29, 46, 37, 47]
amlumalitu
[30, 34, 47, 12, 5, 27, 37, 46, 50, 36]


 25%|██▍       | 51/208 [00:31<01:46,  1.48it/s]

loss: 23.67974281311035
armuecelys
[36, 34, 59, 36, 31, 19, 30, 47, 42, 56]
poustalanc
[5, 30, 54, 20, 22, 30, 37, 27, 55, 15]
irgivuogyr
[36, 34, 5, 46, 16, 12, 30, 44, 42, 34]
hertemapef
[24, 36, 34, 50, 12, 5, 30, 50, 12, 51]
Fenigherse
[24, 12, 5, 46, 6, 8, 12, 56, 20, 12]


 49%|████▊     | 101/208 [01:02<01:04,  1.65it/s]

loss: 23.82386016845703
contogedlo
[24, 36, 55, 22, 0, 5, 12, 23, 47, 30]
padosthamo
[33, 30, 5, 30, 56, 22, 8, 30, 5, 30]
nistesotin
[5, 46, 20, 22, 12, 56, 12, 22, 46, 15]
dergvatseb
[29, 12, 56, 0, 5, 30, 56, 20, 12, 5]
Monatracta
[24, 36, 5, 27, 22, 21, 30, 55, 22, 27]


 73%|███████▎  | 151/208 [01:35<00:36,  1.57it/s]

loss: 23.68906593322754
tintellewi
[24, 36, 55, 22, 12, 37, 47, 12, 34, 46]
peIantogen
[19, 12, 5, 27, 55, 22, 30, 44, 12, 5]
rsanoryger
[56, 20, 30, 5, 30, 34, 42, 44, 12, 34]
motesthlet
[24, 30, 22, 12, 56, 50, 8, 47, 12, 50]
pophfluchi
[19, 30, 50, 8, 57, 47, 42, 50, 8, 46]


 97%|█████████▋| 201/208 [02:07<00:04,  1.51it/s]

loss: 23.41640281677246
antingliol
[36, 55, 22, 46, 14, 44, 47, 46, 30, 47]
wardaestmu
[33, 30, 34, 5, 27, 12, 56, 22, 59, 36]
diqhitotak
[29, 46, 6, 8, 46, 15, 30, 15, 27, 53]
hersauslea
[24, 12, 34, 5, 30, 54, 20, 47, 12, 27]
oringigust
[36, 34, 46, 14, 44, 46, 15, 27, 56, 22]


100%|██████████| 208/208 [02:11<00:00,  1.58it/s]


loss: 24.003253936767578
cogbizetem
[50, 30, 0, 5, 46, 15, 27, 22, 12, 5]
train loss: 23.70| valid loss: 23.69



  0%|          | 1/208 [00:01<03:27,  1.00s/it]

loss: 23.372812271118164
lochideded
[35, 30, 50, 8, 46, 15, 12, 5, 12, 5]
santuqhand
[20, 36, 55, 22, 54, 6, 8, 27, 55, 29]
coneselsto
[24, 30, 5, 12, 56, 12, 55, 20, 22, 30]
ioshucergu
[36, 55, 20, 8, 12, 50, 12, 34, 44, 12]
ulkodapopo
[36, 55, 22, 30, 5, 30, 50, 30, 50, 30]


 25%|██▍       | 51/208 [00:32<01:39,  1.58it/s]

loss: 23.73843002319336
mosiouscit
[24, 12, 56, 46, 30, 54, 20, 22, 46, 50]
heschanMie
[24, 12, 56, 50, 8, 27, 55, 24, 46, 30]
tharoculiv
[33, 8, 30, 5, 30, 50, 41, 47, 42, 5]
wrocetivec
[33, 21, 30, 50, 12, 22, 46, 16, 12, 55]
Barogrisli
[24, 30, 21, 30, 44, 21, 46, 0, 5, 46]


 49%|████▊     | 101/208 [01:04<01:08,  1.57it/s]

loss: 23.490846633911133
uflecsatom
[36, 45, 45, 12, 56, 20, 27, 22, 30, 5]
pratopento
[33, 21, 27, 22, 0, 5, 12, 55, 22, 30]
pritonocss
[33, 35, 46, 22, 30, 55, 30, 15, 56, 20]
lergerpelu
[24, 12, 34, 44, 12, 34, 19, 12, 21, 12]
Propidablu
[33, 21, 30, 5, 46, 15, 27, 37, 47, 12]


 73%|███████▎  | 151/208 [01:36<00:34,  1.63it/s]

loss: 23.571678161621094
feddleovot
[24, 12, 23, 23, 47, 12, 30, 5, 30, 55]
icfarmirtt
[36, 55, 11, 30, 34, 5, 27, 55, 50, 22]
tolitoraly
[33, 30, 5, 42, 50, 30, 34, 27, 37, 42]
Fhioaheadl
[33, 8, 46, 30, 49, 58, 12, 27, 37, 47]
aftlernomi
[36, 45, 22, 47, 12, 34, 5, 30, 5, 46]


 97%|█████████▋| 201/208 [02:08<00:04,  1.45it/s]

loss: 23.48708152770996
ceqlolator
[24, 12, 50, 35, 36, 47, 27, 22, 12, 56]
holiarelul
[24, 30, 37, 46, 27, 21, 12, 37, 41, 47]
autisphess
[36, 55, 22, 46, 20, 50, 8, 12, 56, 22]
fertushend
[24, 12, 56, 50, 54, 20, 8, 12, 55, 24]
Hungvaengr
[24, 36, 14, 44, 5, 27, 12, 14, 44, 21]


100%|██████████| 208/208 [02:12<00:00,  1.57it/s]


loss: 23.332876205444336
staurament
[20, 22, 30, 54, 21, 30, 5, 12, 55, 22]
train loss: 23.60| valid loss: 23.60



  0%|          | 1/208 [00:01<03:27,  1.00s/it]

loss: 23.6558837890625
novewmimol
[24, 36, 16, 12, 48, 59, 36, 5, 30, 5]
firessctel
[45, 46, 5, 12, 56, 20, 50, 22, 36, 56]
satetisben
[20, 27, 22, 12, 56, 46, 20, 19, 12, 5]
onysnetica
[36, 5, 42, 20, 43, 12, 22, 46, 15, 27]
lingedonea
[5, 46, 14, 44, 12, 23, 0, 5, 12, 27]


 25%|██▍       | 51/208 [00:33<01:42,  1.54it/s]

loss: 23.625843048095703
satrodicue
[20, 27, 22, 21, 30, 5, 46, 50, 8, 12]
venNunuzec
[24, 12, 55, 24, 36, 55, 36, 16, 12, 50]
untaacgous
[36, 55, 22, 30, 36, 55, 44, 30, 54, 20]
Veacusteys
[24, 12, 30, 50, 54, 20, 22, 12, 42, 20]
Airistoscs
[24, 36, 34, 46, 20, 22, 46, 20, 22, 21]


 49%|████▊     | 101/208 [01:05<01:06,  1.61it/s]

loss: 23.376598358154297
crieganusu
[33, 21, 46, 30, 44, 30, 5, 54, 20, 28]
dicytKhern
[29, 46, 15, 42, 50, 60, 8, 12, 34, 5]
hysenorysl
[61, 42, 5, 12, 5, 30, 34, 46, 20, 35]
tesilcisst
[33, 30, 20, 46, 56, 22, 46, 56, 20, 22]
namemeefic
[24, 30, 5, 12, 34, 12, 34, 45, 46, 15]


 73%|███████▎  | 151/208 [01:37<00:40,  1.40it/s]

loss: 23.60538673400879
miyrstelyl
[24, 46, 12, 56, 20, 22, 12, 47, 42, 47]
cosymmolis
[15, 30, 20, 28, 31, 59, 0, 5, 46, 20]
forefuleda
[24, 30, 34, 12, 11, 41, 47, 12, 23, 0]
picitenepo
[19, 46, 15, 46, 22, 12, 43, 12, 5, 30]
rabemerydr
[24, 30, 19, 12, 5, 12, 34, 42, 62, 21]


 97%|█████████▋| 201/208 [02:08<00:04,  1.64it/s]

loss: 23.57352066040039
remastebof
[24, 36, 5, 30, 56, 22, 12, 5, 30, 29]
vatyclisto
[24, 30, 61, 42, 50, 35, 46, 20, 50, 30]
ragsapuivi
[24, 30, 44, 5, 30, 50, 8, 46, 16, 12]
pronfopele
[19, 21, 30, 5, 29, 30, 50, 12, 47, 12]
sbrionolep
[20, 19, 21, 46, 0, 5, 30, 47, 12, 56]


100%|██████████| 208/208 [02:12<00:00,  1.57it/s]


loss: 23.348670959472656
cepskoAnhl
[24, 12, 56, 20, 22, 30, 36, 5, 8, 5]
train loss: 23.51| valid loss: 23.50



  0%|          | 1/208 [00:00<02:31,  1.37it/s]

loss: 23.08064842224121
Liededadar
[24, 30, 49, 29, 12, 23, 0, 5, 30, 34]
pofferturo
[33, 30, 45, 45, 12, 34, 22, 8, 21, 30]
Pabimilino
[33, 27, 37, 46, 5, 46, 15, 46, 5, 30]
Locercanen
[24, 30, 50, 12, 56, 50, 30, 5, 12, 5]
stonescful
[20, 22, 30, 5, 12, 20, 50, 51, 41, 35]


 25%|██▍       | 51/208 [00:32<01:45,  1.48it/s]

loss: 23.57033348083496
renoorened
[24, 12, 5, 30, 36, 5, 12, 43, 12, 23]
Caomutogia
[33, 30, 36, 31, 41, 22, 0, 5, 46, 27]
qhlerphter
[6, 8, 47, 12, 34, 50, 8, 22, 12, 34]
qhedlacnea
[6, 8, 12, 23, 47, 27, 22, 43, 12, 30]
torobusion
[24, 30, 34, 46, 0, 54, 20, 46, 30, 5]


 49%|████▊     | 101/208 [01:04<01:13,  1.46it/s]

loss: 23.241405487060547
goglerexib
[33, 30, 44, 47, 12, 21, 12, 5, 46, 37]
foonillyla
[11, 30, 0, 5, 46, 37, 47, 42, 35, 30]
larilerall
[35, 30, 56, 46, 47, 12, 34, 27, 37, 47]
drefledian
[33, 21, 12, 4, 47, 12, 5, 46, 27, 14]
tophovanob
[33, 30, 50, 8, 0, 5, 30, 5, 30, 31]


 73%|███████▎  | 151/208 [01:36<00:35,  1.60it/s]

loss: 23.21142578125
merotastro
[24, 12, 34, 30, 50, 30, 20, 22, 21, 30]
orenanicep
[36, 34, 12, 55, 0, 5, 46, 15, 27, 19]
Eokogristh
[24, 30, 35, 30, 44, 21, 46, 20, 22, 21]
rapessumma
[24, 30, 50, 12, 56, 20, 28, 31, 59, 30]
ariculytap
[36, 34, 46, 15, 41, 47, 42, 50, 30, 50]


 97%|█████████▋| 201/208 [02:09<00:04,  1.46it/s]

loss: 23.4835147857666
socphausle
[24, 36, 55, 50, 8, 30, 54, 20, 47, 12]
jamacenimp
[24, 30, 59, 30, 50, 36, 5, 46, 31, 19]
befrorbuld
[19, 12, 57, 21, 30, 34, 40, 41, 37, 29]
Dondaffych
[24, 30, 55, 29, 30, 45, 45, 42, 50, 8]
karoiustuo
[38, 27, 34, 30, 49, 54, 20, 50, 8, 30]


100%|██████████| 208/208 [02:12<00:00,  1.56it/s]


loss: 23.426456451416016
Lonaterust
[24, 30, 5, 27, 22, 12, 21, 54, 20, 22]
train loss: 23.43| valid loss: 23.44



  0%|          | 1/208 [00:00<03:01,  1.14it/s]

loss: 24.027624130249023
merctantfa
[59, 12, 34, 50, 22, 27, 55, 22, 51, 27]
togrocaces
[33, 30, 44, 21, 30, 50, 27, 22, 12, 56]
mesdkwfule
[24, 36, 55, 15, 38, 48, 51, 41, 35, 36]
fraritaria
[19, 21, 30, 34, 46, 22, 30, 34, 46, 27]
unillalyss
[36, 5, 46, 37, 47, 27, 47, 42, 56, 20]


 25%|██▍       | 51/208 [00:31<01:45,  1.49it/s]

loss: 23.121946334838867
mornunphmi
[24, 30, 34, 5, 36, 55, 50, 8, 5, 46]
imessalter
[36, 31, 12, 56, 20, 36, 55, 22, 12, 34]
idnessosis
[36, 23, 43, 12, 56, 22, 30, 56, 46, 20]
rotdistove
[24, 30, 55, 29, 46, 20, 22, 30, 16, 12]
melialalif
[24, 12, 47, 46, 30, 37, 27, 37, 46, 11]


 49%|████▊     | 101/208 [01:03<01:06,  1.61it/s]

loss: 23.139862060546875
blymarsoit
[19, 47, 42, 5, 27, 56, 20, 30, 49, 22]
kecenskedw
[24, 12, 22, 12, 55, 20, 53, 12, 23, 48]
gericnemio
[33, 30, 34, 46, 22, 43, 12, 5, 46, 0]
sqhounsypn
[20, 6, 8, 30, 36, 55, 20, 28, 31, 5]
chicessted
[33, 8, 46, 15, 12, 56, 20, 22, 12, 23]


 73%|███████▎  | 151/208 [01:36<00:36,  1.56it/s]

loss: 23.20691680908203
Moaeutiqur
[24, 32, 30, 36, 55, 22, 46, 6, 8, 21]
populomapu
[33, 36, 50, 41, 35, 30, 59, 30, 19, 36]
ontatrarmi
[36, 55, 15, 27, 22, 21, 30, 34, 5, 46]
sotonibisi
[24, 30, 55, 0, 5, 46, 37, 30, 20, 46]
muntegesse
[24, 36, 55, 22, 30, 44, 12, 56, 20, 12]


 97%|█████████▋| 201/208 [02:08<00:04,  1.60it/s]

loss: 23.188507080078125
rimarciSte
[24, 30, 59, 30, 34, 50, 46, 20, 22, 30]
uutarnogli
[36, 55, 22, 30, 34, 5, 30, 44, 47, 46]
faussneral
[24, 30, 54, 56, 20, 43, 12, 34, 27, 37]
lismaplers
[35, 30, 20, 59, 30, 19, 47, 12, 56, 20]
entinuloni
[36, 55, 22, 46, 5, 41, 35, 30, 5, 46]


100%|██████████| 208/208 [02:12<00:00,  1.57it/s]


loss: 23.3612060546875
nansingymu
[24, 36, 55, 20, 46, 14, 44, 42, 59, 12]
train loss: 23.38| valid loss: 23.40



In [16]:
# Decode four five-letter probes with the trained model:
#   "quack", "quick" : plausible English spellings
#   "qurck"          : should have lower probability---in English only vowels follow "qu"
#   "qiick"          : should have lower probability---in English only "u" follows "q"
for word in ("quack", "quick", "qurck", "qiick"):
  x = torch.tensor(encode(word)).unsqueeze(0)
  T = torch.tensor([5])
  print(model.viterbi(x,T))

([[6, 8, 30, 50, 38]], tensor([[-16.1840]], grad_fn=<GatherBackward0>))
([[6, 8, 46, 15, 38]], tensor([[-14.7257]], grad_fn=<GatherBackward0>))
([[33, 28, 34, 50, 38]], tensor([[-21.5828]], grad_fn=<GatherBackward0>))
([[6, 8, 46, 15, 38]], tensor([[-21.3148]], grad_fn=<GatherBackward0>))
