In [2]:
import torch
import numpy as np
class HMM(torch.nn.Module):
  """
  Hidden Markov Model with discrete observations.
  """
  def __init__(self, M, N):
    super(HMM, self).__init__()
    self.M = M # number of possible observations
    self.N = N # number of states
    # A
    self.transition_model = TransitionModel(self.N)# TODO
    # b(x_t)
    self.emission_model = EmissionModel(self.N, self.M) # TODO
    # pi
    self.unnormalized_state_priors = torch.nn.Parameter(torch.randn(self.N)) # TODO
    # use the GPU
    self.is_cuda = torch.cuda.is_available()
    if self.is_cuda: self.cuda()

class TransitionModel(torch.nn.Module):
  def __init__(self, N):
    super(TransitionModel, self).__init__()# TODO
    self.N = N
    self.unnormalized_transition_matrix = torch.nn.Parameter(torch.randn(N,N)) # TODO

class EmissionModel(torch.nn.Module):
  def __init__(self, N, M):
    super(EmissionModel, self).__init__()
    self.N = N
    self.M = M # TODO
    self.unnormalized_emission_matrix = torch.nn.Parameter(torch.randn(N,M)) # TODO

In [3]:
def sample(self, T=10):
  state_priors = torch.nn.functional.softmax(self.unnormalized_state_priors, dim=0)
  transition_matrix = torch.nn.functional.softmax(self.transition_model.unnormalized_transition_matrix, dim=0)
  emission_matrix = torch.nn.functional.softmax(self.emission_model.unnormalized_emission_matrix, dim=1)
  # sample initial state
  z_t = torch.distributions.categorical.Categorical(state_priors).sample().item()
  z = []; x = []
  z.append(z_t)
  for t in range(0,T):
  # sample emission
    x_t = torch.distributions.categorical.Categorical(emission_matrix[z_t]).sample().item() # TODO
    x.append(x_t)

  # sample transition
    z_t = torch.distributions.categorical.Categorical(transition_matrix[:,z_t]).sample().item() # TODO
    if t < T-1: z.append(z_t)

  return x, z
# Add the sampling method to our HMM class
HMM.sample = sample


In [4]:
import string
alphabet = string.ascii_lowercase

def encode(s):
  """
  Convert a string into a list of integers
  """
  x = [alphabet.index(ss) for ss in s]
  return x

def decode(x):
  """
  Convert list of ints to string
  """
  s = "".join([alphabet[xx] for xx in x])
  return s

# Initialize the model
model = HMM(M=len(alphabet), N=2)

# Hard-wiring the parameters!
# Let state 0 = consonant, state 1 = vowel
for p in model.parameters():
    p.requires_grad = False # needed to do lines below
model.unnormalized_state_priors[0] = 0.    # Let's start with a consonant more frequently
model.unnormalized_state_priors[1] = -0.5
print("State priors:", torch.nn.functional.softmax(model.unnormalized_state_priors, dim=0))

# In state 0, only allow consonants; in state 1, only allow vowels
vowel_indices = torch.tensor([alphabet.index(letter) for letter in "aeiou"])
consonant_indices = torch.tensor([alphabet.index(letter) for letter in "bcdfghjklmnpqrstvwxyz"])
model.emission_model.unnormalized_emission_matrix[0, vowel_indices] = -np.inf
model.emission_model.unnormalized_emission_matrix[1, consonant_indices] = -np.inf
print("Emission matrix:", torch.nn.functional.softmax(model.emission_model.unnormalized_emission_matrix, dim=1))

# Only allow vowel -> consonant and consonant -> vowel
model.transition_model.unnormalized_transition_matrix[0,0] = -np.inf  # consonant -> consonant
model.transition_model.unnormalized_transition_matrix[0,1] = 0.       # vowel -> consonant
model.transition_model.unnormalized_transition_matrix[1,0] = 0.       # consonant -> vowel
model.transition_model.unnormalized_transition_matrix[1,1] = -np.inf  # vowel -> vowel
print("Transition matrix:", torch.nn.functional.softmax(model.transition_model.unnormalized_transition_matrix, dim=0))



State priors: tensor([0.6225, 0.3775], device='cuda:0')
Emission matrix: tensor([[0.0000, 0.1355, 0.0329, 0.0752, 0.0000, 0.0882, 0.1094, 0.0136, 0.0000,
         0.0395, 0.0262, 0.0212, 0.1066, 0.0663, 0.0000, 0.0634, 0.0273, 0.0192,
         0.0178, 0.0098, 0.0000, 0.0057, 0.0080, 0.1089, 0.0140, 0.0112],
        [0.0524, 0.0000, 0.0000, 0.0000, 0.0744, 0.0000, 0.0000, 0.0000, 0.0629,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1145, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.6958, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],
       device='cuda:0')
Transition matrix: tensor([[0., 1.],
        [1., 0.]], device='cuda:0')


In [5]:
# Sample some outputs
for _ in range(4):
  sampled_x, sampled_z = model.sample(T=5)
  print("x:", decode(sampled_x))
  print("z:", sampled_z)

x: gupux
z: [0, 1, 0, 1, 0]
x: jumuj
z: [0, 1, 0, 1, 0]
x: dequn
z: [0, 1, 0, 1, 0]
x: modun
z: [0, 1, 0, 1, 0]


In [6]:
def HMM_forward(self, x, T):
  """
  x : IntTensor of shape (batch size, T_max)
  T : IntTensor of shape (batch size)
  Compute log p(x) for each example in the batch.
  T = length of each example
  """
  if self.is_cuda:
    x = x.cuda()
    T = T.cuda()

  batch_size = x.shape[0]; T_max = x.shape[1]
  log_state_priors = torch.nn.functional.log_softmax(self.unnormalized_state_priors, dim=0) # TODO
  log_alpha = torch.zeros(batch_size, T_max, self.N)# TODO
  if self.is_cuda: log_alpha = log_alpha.cuda()

  log_alpha[:, 0, :] = self.emission_model(x[:,0]) + log_state_priors
  for t in range(1, T_max):
    log_alpha[:, t, :] = self.emission_model(x[:,t]) + self.transition_model(log_alpha[:, t-1, :])

  # Select the sum for the final timestep (each x may have different length).
  log_sums = log_alpha.logsumexp(dim=2) # TODO # HINT: Use logexpsum
  log_probs = torch.gather(log_sums, 1, T.view(-1,1) - 1)
  return log_probs


In [7]:
def emission_model_forward(self, x_t):
  log_emission_matrix = torch.nn.functional.log_softmax(self.unnormalized_emission_matrix, dim=1) # TODO
  out = log_emission_matrix[:, x_t].transpose(0,1)
  return out

def transition_model_forward(self, log_alpha):
  """
  log_alpha : Tensor of shape (batch size, N)
  Multiply previous timestep's alphas by transition matrix (in log
  domain)
  """
  log_transition_matrix = torch.nn.functional.log_softmax(self.unnormalized_transition_matrix, dim=0) # TODO
  # Matrix multiplication in the log domain
  out = log_domain_matmul(log_transition_matrix, log_alpha.transpose(0,1)).transpose(0,1)
  return out

def log_domain_matmul(log_A, log_B):
  """
  log_A : m x n
  log_B : n x p
  output : m x p matrix
  Normally, a matrix multiplication
  computes out_{i,j} = sum_k A_{i,k} x B_{k,j}
  A log domain matrix multiplication
  computes out_{i,j} = logsumexp_k log_A_{i,k} + log_B_{k,j}
  """
  m = log_A.shape[0]
  n = log_A.shape[1]
  p = log_B.shape[1]
  # log_A_expanded = torch.stack([log_A] * p, dim=2)
  # log_B_expanded = torch.stack([log_B] * m, dim=0)
  # fix for PyTorch > 1.5 by egaznep on Github:
  log_A_expanded = torch.reshape(log_A, (m,n,1)) # TODO
  log_B_expanded = torch.reshape(log_B, (1,n,p)) # TODO
  elementwise_sum = log_A_expanded + log_B_expanded # TODO
  out = torch.logsumexp(elementwise_sum, dim=1) # TODO
  return out

TransitionModel.forward = transition_model_forward
EmissionModel.forward = emission_model_forward
HMM.forward = HMM_forward


In [8]:
x = torch.stack( [torch.tensor(encode("cat"))] )
T = torch.tensor([3])
print(model.forward(x, T))
x = torch.stack( [torch.tensor(encode("aba")),
torch.tensor(encode("abb"))] )
T = torch.tensor([3,3])
print(model.forward(x, T))

tensor([[-11.4655]], device='cuda:0')
tensor([[-8.8687],
        [   -inf]], device='cuda:0')


In [9]:
def viterbi(self, x, T):
  """
  x : IntTensor of shape (batch size, T_max)
  T : IntTensor of shape (batch size)
  Find argmax_z log p(x|z) for each (x) in the batch.
  """
  if self.is_cuda:
    x = x.cuda()
    T = T.cuda()

  batch_size = x.shape[0]; T_max = x.shape[1]
  log_state_priors = torch.nn.functional.log_softmax(self.unnormalized_state_priors, dim=0) # TODO
  log_delta = torch.zeros(batch_size, T_max, self.N).float() # TODO
  psi = torch.zeros(batch_size, T_max, self.N).long() # TODO

  if self.is_cuda:
    log_delta = log_delta.cuda()
    psi = psi.cuda()

  log_delta[:, 0, :] = self.emission_model(x[:,0]) + log_state_priors# TODO: Use emission model and log state priors
  for t in range(1, T_max):
    max_val, argmax_val = self.transition_model.maxmul(log_delta[:, t-1, :])
    log_delta[:, t, :] = self.emission_model(x[:,t]) + max_val
    psi[:, t, :] = argmax_val
  # Get the log probability of the best path

  log_max = log_delta.max(dim=2)[0]
  best_path_scores = torch.gather(log_max, 1, T.view(-1,1) - 1) # TODO
  # This next part is a bit tricky to parallelize across the batch,
  # so we will do it separately for each example.
  z_star = []
  for i in range(0, batch_size):
    z_star_i = [ log_delta[i, T[i] - 1, :].max(dim=0)[1].item()] # TODO

    for t in range(T[i] - 1, 0, -1):
      z_t = psi[i, t, z_star_i[0]].item()
      z_star_i.insert(0, z_t)

    z_star.append(z_star_i)
  return z_star, best_path_scores # return both the best path and its log probability

In [10]:
def transition_model_maxmul(self, log_alpha):

  log_transition_matrix = torch.nn.functional.log_softmax(self.unnormalized_transition_matrix, dim=0) # TODO
  out1, out2 = maxmul(log_transition_matrix, log_alpha.transpose(0,1)) # TODO
  return out1.transpose(0,1), out2.transpose(0,1)

def maxmul(log_A, log_B):
  """
  log_A : m x n
  log_B : n x p
  output : m x p matrix
  Similar to the log domain matrix multiplication,
  this computes out_{i,j} = max_k log_A_{i,k} + log_B_{k,j}
  """
  m = log_A.shape[0]
  n = log_A.shape[1]
  p = log_B.shape[1]
  log_A_expanded = torch.stack([log_A] * p, dim=2) # TODO
  log_B_expanded = torch.stack([log_B] * m, dim=0) # TODO
  elementwise_sum = log_A_expanded + log_B_expanded # TODO
  out1,out2 = torch.max(elementwise_sum, dim=1)
  return out1,out2

TransitionModel.maxmul = transition_model_maxmul
HMM.viterbi = viterbi

In [11]:
x = torch.stack( [torch.tensor(encode("aba")),
torch.tensor(encode("abb"))] )
T = torch.tensor([3,3])
print(model.viterbi(x, T))

([[1, 0, 1], [1, 0, 0]], tensor([[-8.8687],
        [   -inf]], device='cuda:0'))


In [12]:
print(model.forward(x, T))
print(model.viterbi(x, T)[1])

tensor([[-8.8687],
        [   -inf]], device='cuda:0')
tensor([[-8.8687],
        [   -inf]], device='cuda:0')


In [13]:
x = torch.tensor([1., 2., 3.])
print(x.max(dim=0)[0])
print(x.logsumexp(dim=0))

tensor(3.)
tensor(3.4076)


In [14]:
import torch.utils.data
from collections import Counter
from sklearn.model_selection import train_test_split

class TextDataset(torch.utils.data.Dataset):
  def __init__(self, lines):
    self.lines = lines # list of strings
    collate = Collate() # function for generating a minibatch from strings
    self.loader = torch.utils.data.DataLoader(self, batch_size=1024, num_workers=1, shuffle=True, collate_fn=collate)

  def __len__(self):
    return len(self.lines)

  def __getitem__(self, idx):
    line = self.lines[idx].lstrip(" ").rstrip("\n").rstrip(" ").rstrip("\n")
    return line

class Collate:
  def __init__(self):
    pass

  def __call__(self, batch):
    """
    Returns a minibatch of strings, padded to have the same length.
    """
    x = []
    batch_size = len(batch)
    for index in range(batch_size):
      x_ = batch[index]

      # convert letters to integers
      x.append(encode(x_))

    # pad all sequences with 0 to have same length
    x_lengths = [len(x_) for x_ in x]
    T = max(x_lengths)
    for index in range(batch_size):
      x[index] += [0] * (T - len(x[index]))
      x[index] = torch.tensor(x[index])

    # stack into single tensor
    x = torch.stack(x)
    x_lengths = torch.tensor(x_lengths)
    return (x,x_lengths)

In [15]:
!wget https://raw.githubusercontent.com/lorenlugosch/pytorch_HMM/master/data/train/training.txt

--2023-11-05 23:05:24--  https://raw.githubusercontent.com/lorenlugosch/pytorch_HMM/master/data/train/training.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 2493109 (2.4M) [text/plain]
Saving to: ‘training.txt’


2023-11-05 23:05:24 (47.3 MB/s) - ‘training.txt’ saved [2493109/2493109]



In [16]:
filename = "training.txt"

with open(filename, "r") as f:
  lines = f.readlines() # each line of lines will have one word

alphabet = list(Counter(("".join(lines))).keys())
train_lines, valid_lines = train_test_split(lines, test_size=0.1, random_state=42)
train_dataset = TextDataset(train_lines)
valid_dataset = TextDataset(valid_lines)

M = len(alphabet)

In [17]:
from tqdm import tqdm # for displaying progress bar

class Trainer:
  def __init__(self, model, lr):
    self.model = model
    self.lr = lr
    self.optimizer = torch.optim.Adam(model.parameters(), lr=self.lr, weight_decay=0.00001)

  def train(self, dataset):
    train_loss = 0
    num_samples = 0
    self.model.train()
    print_interval = 50
    for idx, batch in enumerate(tqdm(dataset.loader)):
      x,T = batch
      batch_size = len(x)
      num_samples += batch_size
      log_probs = self.model(x,T)
      loss = -log_probs.mean()
      self.optimizer.zero_grad()
      loss.backward()
      self.optimizer.step()
      train_loss += loss.cpu().data.numpy().item() * batch_size
      if idx % print_interval == 0:
        print("loss:", loss.item())
        for _ in range(5):
          sampled_x, sampled_z = self.model.sample()
          print(decode(sampled_x))
          print(sampled_z)
    train_loss /= num_samples
    return train_loss

  def test(self, dataset):
    test_loss = 0
    num_samples = 0
    self.model.eval()
    print_interval = 50
    for idx, batch in enumerate(dataset.loader):
      x,T = batch
      batch_size = len(x)
      num_samples += batch_size
      log_probs = self.model(x,T)
      loss = -log_probs.mean()
      test_loss += loss.cpu().data.numpy().item() * batch_size
      if idx % print_interval == 0:
        print("loss:", loss.item())
        sampled_x, sampled_z = self.model.sample()
        print(decode(sampled_x))
        print(sampled_z)
    test_loss /= num_samples
    return test_loss

In [18]:
# Initialize model
model = HMM(N=64, M=M)

# Train the model
num_epochs = 10
trainer = Trainer(model, lr=0.01)

for epoch in range(num_epochs):
        print("========= Epoch %d of %d =========" % (epoch+1, num_epochs))
        train_loss = trainer.train(train_dataset)
        valid_loss = trainer.test(valid_dataset)

        print("========= Results: epoch %d of %d =========" % (epoch+1, num_epochs))
        print("train loss: %.2f| valid loss: %.2f\n" % (train_loss, valid_loss) )



  0%|          | 0/208 [00:00<?, ?it/s]

loss: 37.76902770996094
HFahlsHEWw
[48, 42, 17, 4, 16, 60, 12, 13, 16, 24]


  1%|▏         | 3/208 [00:00<00:22,  8.92it/s]

ZjMujGobKo
[26, 54, 27, 56, 58, 22, 18, 9, 34, 34]
fZvVRKLUoh
[33, 5, 9, 35, 28, 51, 27, 56, 48, 51]
VX-kpAqiPZ
[18, 57, 59, 3, 47, 15, 45, 10, 3, 1]
BNWFJ-AaOe
[25, 37, 60, 45, 6, 49, 15, 31, 29, 15]


 25%|██▌       | 53/208 [00:03<00:10, 15.35it/s]

loss: 33.40502166748047
fDSofenlnw
[50, 35, 18, 33, 40, 10, 1, 6, 23, 19]
utfDLmMbYH
[10, 54, 0, 44, 39, 22, 22, 22, 43, 22]
ge-DVdxhlE
[13, 40, 63, 32, 20, 45, 36, 7, 25, 21]
iPBDiYIeEJ
[23, 32, 45, 44, 10, 30, 41, 6, 42, 7]
FlbPVZZFcQ
[41, 28, 51, 45, 44, 8, 18, 0, 58, 15]


 50%|████▉     | 103/208 [00:06<00:07, 13.17it/s]

loss: 30.182662963867188
mlozrcbAsd
[37, 0, 40, 40, 58, 31, 0, 27, 56, 57]
broAr-anIm
[10, 42, 23, 48, 36, 33, 40, 56, 59, 35]
SwdcydoUic
[17, 14, 53, 56, 47, 0, 44, 55, 51, 58]
TEEulkttdp
[53, 35, 32, 57, 54, 9, 27, 56, 48, 29]
trtegrgzez
[26, 42, 20, 1, 40, 58, 21, 0, 53, 0]


 74%|███████▎  | 153/208 [00:10<00:03, 15.52it/s]

loss: 28.436504364013672
dsdoUgclki
[41, 35, 49, 18, 36, 35, 58, 38, 10, 10]
lBwcntrlKg
[54, 59, 37, 7, 56, 48, 33, 6, 58, 20]
bonlcagyzZ
[50, 23, 18, 0, 7, 53, 0, 47, 57, 59]
fcbisonrtM
[50, 62, 41, 10, 54, 48, 51, 58, 20, 51]
byVeaJlYua
[8, 21, 50, 40, 56, 41, 28, 19, 45, 53]


 98%|█████████▊| 203/208 [00:13<00:00, 15.93it/s]

loss: 26.966365814208984
rksrIiauti
[58, 17, 54, 54, 20, 51, 58, 17, 59, 23]
oDyweltdau
[10, 1, 45, 8, 17, 0, 44, 0, 38, 56]
mmtedeguAc
[26, 5, 10, 1, 15, 51, 31, 3, 62, 58]
ciyeilsniy
[7, 51, 7, 46, 36, 0, 53, 20, 51, 45]
msthvniara
[3, 56, 48, 59, 44, 56, 58, 20, 51, 1]


100%|██████████| 208/208 [00:13<00:00, 15.63it/s]


loss: 26.964330673217773
Nirtidtlky
[60, 10, 6, 20, 51, 58, 36, 0, 45, 45]
train loss: 30.84| valid loss: 26.63



  0%|          | 1/208 [00:00<00:40,  5.16it/s]

loss: 26.349462509155273
siHamoSeyh
[33, 10, 1, 23, 27, 35, 58, 17, 54, 59]
fMsahalrat
[50, 23, 22, 53, 59, 17, 0, 58, 53, 20]
cesmtiXste
[57, 17, 54, 22, 20, 23, 1, 53, 20, 51]
vesgamiong
[58, 17, 54, 20, 44, 58, 51, 23, 56, 8]
balCtirVCU
[55, 44, 0, 45, 29, 32, 48, 24, 13, 18]


 25%|██▍       | 51/208 [00:03<00:12, 12.52it/s]

loss: 25.460973739624023
Lapisspnun
[1, 44, 20, 10, 54, 27, 29, 60, 53, 0]
poracVraly
[33, 40, 22, 53, 20, 23, 18, 36, 0, 45]
uaTilokedb
[3, 56, 34, 36, 0, 44, 58, 17, 58, 36]
mafunIcuiw
[33, 44, 33, 44, 56, 17, 54, 14, 23, 56]
heJoteriZo
[59, 17, 55, 44, 48, 17, 32, 51, 1, 53]


 50%|████▉     | 103/208 [00:06<00:06, 15.50it/s]

loss: 25.167308807373047
chocraoede
[1, 59, 44, 29, 60, 44, 58, 17, 58, 17]
agpelusraY
[26, 42, 21, 36, 0, 10, 54, 22, 53, 56]
ossilensom
[3, 56, 48, 36, 0, 53, 56, 41, 23, 27]
sslaphepae
[54, 48, 6, 23, 29, 60, 17, 33, 17, 35]
crsMvitrvo
[29, 60, 40, 26, 58, 51, 48, 32, 58, 17]


 74%|███████▎  | 153/208 [00:09<00:03, 15.13it/s]

loss: 24.99958610534668
thJrfimagu
[29, 60, 17, 32, 58, 51, 58, 23, 41, 42]
warasmerdy
[26, 42, 33, 10, 54, 22, 17, 32, 58, 36]
otusbatiam
[44, 29, 62, 54, 22, 53, 20, 51, 53, 20]
raierswgin
[33, 44, 27, 17, 32, 22, 53, 8, 10, 54]
jyderphica
[33, 40, 58, 17, 32, 29, 60, 23, 22, 53]


 98%|█████████▊| 203/208 [00:12<00:00, 15.37it/s]

loss: 24.748218536376953
bVsaerteco
[26, 17, 54, 20, 17, 32, 20, 51, 1, 23]
fidezyalis
[33, 10, 58, 17, 58, 36, 53, 20, 51, 54]
unveuinonY
[3, 56, 58, 17, 9, 51, 1, 23, 56, 8]
snfebYMevi
[3, 56, 21, 36, 0, 10, 27, 17, 58, 51]
sttlasener
[54, 54, 48, 60, 17, 33, 17, 58, 17, 32]


100%|██████████| 208/208 [00:13<00:00, 15.76it/s]


loss: 24.681243896484375
elttwecmar
[31, 0, 45, 20, 55, 35, 58, 30, 57, 60]
train loss: 25.19| valid loss: 24.51



  0%|          | 0/208 [00:00<?, ?it/s]

loss: 24.517404556274414
obeshenpor
[23, 0, 17, 54, 59, 17, 32, 33, 44, 32]
erikizouwi
[44, 32, 51, 58, 51, 1, 23, 62, 33, 17]
ucglalinyw
[3, 56, 41, 0, 53, 0, 51, 1, 45, 33]
cremrdisbi
[1, 60, 36, 2, 32, 33, 10, 54, 16, 51]


  1%|▏         | 3/208 [00:00<00:19, 10.26it/s]

nerarmodom
[58, 17, 32, 53, 32, 33, 44, 28, 23, 22]


 25%|██▍       | 51/208 [00:03<00:14, 11.01it/s]

loss: 24.678577423095703
shlalicanb
[54, 59, 6, 44, 58, 51, 1, 53, 56, 41]
phitrormop
[29, 60, 51, 48, 60, 44, 32, 22, 23, 29]
troupormal
[29, 60, 23, 62, 33, 17, 32, 22, 53, 0]
periliLele
[26, 17, 32, 36, 0, 51, 26, 17, 0, 17]
wltriipiaB
[33, 51, 48, 32, 51, 23, 22, 51, 23, 62]


 50%|████▉     | 103/208 [00:06<00:06, 15.42it/s]

loss: 23.975370407104492
edanablres
[17, 32, 53, 56, 53, 16, 0, 6, 10, 22]
glypqustag
[41, 6, 45, 27, 61, 42, 54, 20, 44, 8]
dnRlidante
[35, 58, 53, 0, 51, 58, 53, 56, 48, 17]
amtaprites
[44, 27, 29, 44, 29, 60, 44, 48, 17, 22]
unecoshour
[3, 56, 17, 48, 36, 54, 59, 23, 62, 54]


 74%|███████▎  | 153/208 [00:09<00:03, 15.64it/s]

loss: 24.077823638916016
muparafbli
[26, 42, 29, 44, 32, 44, 36, 18, 0, 51]
reuokessic
[60, 17, 44, 35, 58, 17, 54, 54, 51, 1]
cerinuVtob
[33, 17, 32, 51, 58, 42, 18, 20, 17, 41]
trinotosme
[29, 60, 44, 56, 44, 29, 44, 54, 22, 17]
sprerityjt
[54, 29, 60, 17, 32, 51, 48, 45, 31, 48]


 98%|█████████▊| 203/208 [00:13<00:00, 14.88it/s]

loss: 24.331140518188477
dardeshesi
[58, 53, 32, 41, 10, 54, 59, 10, 54, 51]
sWlentlall
[26, 53, 0, 17, 56, 48, 60, 44, 0, 0]
SuaccistPo
[26, 44, 3, 56, 58, 10, 54, 48, 60, 44]
pylonimthi
[29, 36, 0, 23, 56, 44, 27, 29, 60, 51]
sispscascl
[54, 10, 54, 29, 60, 48, 44, 54, 29, 51]


100%|██████████| 208/208 [00:13<00:00, 15.61it/s]


loss: 24.058752059936523
ludatisodl
[58, 17, 58, 53, 20, 51, 1, 23, 41, 6]
train loss: 24.16| valid loss: 24.01



  0%|          | 0/208 [00:00<?, ?it/s]

loss: 24.063980102539062
cantseoler
[1, 53, 56, 10, 54, 17, 35, 58, 17, 33]


  0%|          | 1/208 [00:00<00:51,  4.03it/s]

cathyyliom
[29, 44, 29, 60, 45, 45, 0, 51, 23, 22]
unersenati
[3, 56, 17, 32, 33, 17, 56, 53, 20, 51]
bielggWanc
[26, 51, 17, 56, 8, 8, 49, 53, 56, 48]

ltisuatea
[19, 45, 20, 51, 48, 42, 53, 20, 51, 53]


 25%|██▌       | 53/208 [00:03<00:09, 16.08it/s]

loss: 23.91845703125
pineshmomi
[29, 44, 56, 10, 54, 59, 2, 23, 27, 51]
congeniisi
[1, 23, 56, 8, 17, 56, 51, 53, 20, 51]
sesstendee
[26, 10, 54, 54, 48, 17, 35, 58, 17, 18]
diraminnur
[41, 44, 6, 44, 22, 23, 56, 48, 42, 33]
acicatodre
[3, 56, 51, 1, 53, 20, 17, 32, 33, 17]


 50%|████▉     | 103/208 [00:06<00:06, 15.41it/s]

loss: 24.146703720092773
stenacogli
[54, 20, 17, 56, 44, 1, 23, 8, 6, 51]
aneshivalo
[3, 56, 10, 54, 59, 51, 58, 53, 0, 53]
omatiparis
[44, 22, 53, 20, 51, 1, 53, 56, 10, 54]
hensariLel
[19, 17, 56, 48, 53, 32, 51, 1, 53, 0]
Mormableal
[26, 44, 32, 22, 53, 43, 0, 51, 53, 0]


 74%|███████▎  | 153/208 [00:09<00:03, 15.20it/s]

loss: 23.95726776123047
unvealeoso
[3, 56, 58, 17, 53, 0, 17, 44, 22, 44]
pilllpioko
[29, 36, 0, 0, 0, 20, 51, 23, 6, 44]
penetichya
[33, 17, 56, 53, 20, 51, 48, 59, 47, 53]
sinsestupe
[26, 23, 56, 48, 44, 54, 48, 42, 33, 17]
bostergrob
[41, 44, 54, 48, 17, 32, 41, 6, 17, 31]


 97%|█████████▋| 201/208 [00:12<00:00, 12.36it/s]

loss: 23.787620544433594
thigiousst
[29, 60, 44, 8, 51, 23, 62, 54, 54, 48]
ungualdeni
[3, 56, 8, 42, 53, 0, 58, 17, 56, 51]
befilticyr
[26, 44, 21, 10, 18, 48, 51, 1, 45, 33]
igenedceti
[44, 8, 17, 58, 17, 32, 1, 53, 20, 51]
dnuurespos
[41, 6, 42, 30, 6, 10, 54, 22, 51, 54]


100%|██████████| 208/208 [00:13<00:00, 15.52it/s]


loss: 23.80683135986328
ogalusesui
[23, 8, 53, 0, 40, 22, 17, 2, 42, 51]
train loss: 23.86| valid loss: 23.81



  0%|          | 0/208 [00:00<?, ?it/s]

loss: 23.948991775512695
densuoneon
[33, 17, 56, 2, 42, 35, 58, 17, 23, 56]


  1%|▏         | 3/208 [00:00<00:21,  9.35it/s]

upproyepha
[3, 27, 29, 60, 23, 52, 17, 29, 60, 53]
criontates
[29, 6, 51, 23, 56, 48, 53, 20, 17, 54]
sualidolli
[41, 42, 53, 0, 51, 58, 17, 0, 0, 51]
frarsometi
[41, 6, 44, 32, 22, 44, 22, 53, 20, 51]


 25%|██▌       | 53/208 [00:03<00:09, 15.89it/s]

loss: 23.704809188842773
mouspolion
[50, 23, 62, 54, 29, 36, 0, 51, 23, 56]
ptytikider
[29, 60, 45, 20, 51, 48, 51, 58, 17, 32]
aneteaeage
[3, 56, 10, 48, 17, 58, 17, 44, 8, 17]
couscechea
[1, 23, 62, 54, 48, 44, 29, 60, 17, 53]
phocraciri
[29, 60, 44, 29, 60, 44, 48, 44, 32, 51]


 50%|████▉     | 103/208 [00:06<00:06, 15.06it/s]

loss: 23.408023834228516
juweraYizm
[55, 42, 33, 17, 32, 53, 20, 51, 23, 22]

ubermovem
[55, 42, 30, 17, 32, 22, 44, 58, 17, 33]
unqurrated
[3, 56, 61, 42, 32, 6, 10, 20, 35, 58]
unycayeryp
[3, 56, 45, 1, 53, 52, 17, 32, 45, 29]
wropramaxa
[26, 6, 44, 29, 60, 44, 22, 53, 56, 53]


 74%|███████▎  | 153/208 [00:09<00:03, 15.16it/s]

loss: 23.74091911315918
Nemioprill
[26, 17, 22, 51, 23, 29, 60, 51, 0, 0]
chedperful
[1, 60, 17, 32, 22, 17, 32, 21, 36, 0]
nontildere
[50, 23, 56, 48, 44, 0, 58, 17, 32, 17]
mouliorami
[26, 44, 36, 0, 51, 23, 6, 44, 22, 51]
Sotrispcan
[26, 44, 48, 6, 10, 54, 29, 57, 53, 56]


 98%|█████████▊| 203/208 [00:12<00:00, 13.92it/s]

loss: 23.572647094726562
prartenkap
[26, 60, 44, 32, 48, 17, 56, 41, 44, 29]
baateralll
[34, 6, 44, 48, 17, 32, 53, 0, 0, 0]
parontiogn
[29, 44, 32, 23, 56, 20, 51, 23, 8, 56]
dochuFrcos
[33, 44, 29, 60, 42, 38, 32, 1, 36, 54]
bitorindlo
[26, 44, 29, 36, 32, 51, 56, 41, 0, 23]


100%|██████████| 208/208 [00:13<00:00, 15.75it/s]


loss: 23.692394256591797
conemitumb
[50, 23, 56, 10, 22, 51, 48, 42, 27, 34]
train loss: 23.70| valid loss: 23.68



  0%|          | 1/208 [00:00<00:39,  5.26it/s]

loss: 23.91607666015625
hykednabov
[19, 45, 57, 17, 32, 58, 53, 43, 36, 0]
mipiproram
[26, 44, 29, 44, 29, 60, 44, 32, 44, 22]
pordtasesi
[33, 17, 32, 58, 12, 53, 56, 10, 22, 10]
sopbestori
[26, 44, 27, 11, 17, 54, 48, 44, 32, 51]
mechotaedi
[33, 10, 48, 59, 44, 29, 44, 35, 41, 51]


 25%|██▌       | 53/208 [00:03<00:09, 15.61it/s]

loss: 23.307058334350586
cokeraulor
[29, 44, 58, 17, 32, 44, 36, 0, 44, 32]
maudenglyi
[33, 44, 35, 41, 17, 56, 8, 0, 45, 10]
intaphimon
[3, 56, 48, 44, 29, 60, 10, 22, 23, 56]
ousmophiqu
[23, 62, 54, 22, 44, 29, 60, 51, 55, 42]
trioticona
[26, 6, 44, 18, 48, 51, 1, 23, 56, 53]


 50%|████▉     | 103/208 [00:06<00:06, 15.89it/s]

loss: 23.625816345214844
rissphedeo
[33, 10, 54, 54, 29, 60, 44, 58, 17, 44]
Oiseperali
[29, 3, 48, 51, 29, 17, 32, 53, 0, 51]
Cuhtefased
[26, 62, 4, 20, 17, 41, 44, 22, 17, 58]
axtioutine
[3, 56, 48, 51, 23, 62, 48, 53, 56, 10]
ancritacap
[3, 56, 48, 6, 10, 48, 44, 29, 44, 29]


 73%|███████▎  | 151/208 [00:09<00:04, 12.24it/s]

loss: 23.473098754882812
otatenzeya
[44, 29, 44, 29, 17, 56, 58, 17, 52, 53]
afordekuma
[35, 41, 36, 32, 41, 10, 48, 40, 22, 44]
unucollyph
[3, 56, 42, 29, 36, 0, 0, 45, 29, 60]
frelopopor
[26, 6, 10, 6, 44, 29, 44, 29, 17, 32]
insatousso
[3, 56, 48, 53, 20, 23, 62, 54, 22, 23]


 98%|█████████▊| 203/208 [00:12<00:00, 15.10it/s]

loss: 23.438913345336914
umnactogit
[3, 27, 22, 53, 18, 48, 44, 8, 51, 48]
ingricrema
[3, 56, 8, 6, 51, 48, 60, 44, 5, 53]
eniomuondr
[44, 58, 51, 23, 27, 51, 23, 56, 41, 6]
tlaetialyc
[29, 6, 44, 18, 48, 51, 53, 0, 45, 46]
plinthecti
[34, 6, 10, 56, 48, 60, 17, 18, 48, 51]


100%|██████████| 208/208 [00:13<00:00, 15.73it/s]


loss: 23.473918914794922
pulenbopro
[26, 36, 0, 23, 56, 33, 44, 29, 60, 44]
train loss: 23.59| valid loss: 23.59



  0%|          | 0/208 [00:00<?, ?it/s]

loss: 23.52134132385254
Kydooershm
[26, 45, 41, 6, 44, 17, 32, 54, 59, 5]
entatherio
[3, 56, 48, 53, 20, 60, 17, 32, 51, 23]
Caniverala
[26, 53, 56, 51, 58, 17, 32, 53, 0, 53]
Crotorstym
[26, 6, 44, 29, 44, 32, 54, 48, 45, 22]


  1%|▏         | 3/208 [00:00<00:20,  9.92it/s]

unclatoion
[3, 56, 29, 6, 53, 20, 44, 51, 23, 56]


 25%|██▌       | 53/208 [00:03<00:10, 15.38it/s]

loss: 23.573421478271484
wilcazdord
[33, 44, 18, 20, 44, 35, 41, 44, 32, 41]
churtintod
[29, 60, 36, 32, 20, 51, 56, 48, 17, 41]
Maotwobroc
[26, 44, 36, 48, 59, 44, 11, 6, 44, 48]
ushaimimtr
[3, 54, 59, 44, 3, 13, 3, 27, 29, 6]
odcwecsebi
[44, 32, 46, 59, 17, 18, 48, 44, 11, 51]


 50%|████▉     | 103/208 [00:06<00:06, 15.24it/s]

loss: 23.43610191345215
mantrinbar
[13, 23, 56, 48, 60, 51, 56, 63, 53, 32]
pardopadud
[29, 44, 32, 41, 44, 29, 44, 33, 10, 41]
Cysscurraa
[19, 45, 54, 54, 48, 36, 32, 6, 44, 35]
vhillattho
[26, 60, 36, 0, 0, 44, 18, 48, 60, 44]
blitayeile
[34, 6, 10, 48, 53, 52, 17, 44, 0, 10]


 74%|███████▎  | 153/208 [00:09<00:03, 13.85it/s]

loss: 23.624982833862305
corercwaly
[26, 36, 32, 22, 51, 1, 47, 53, 0, 45]
suesucucso
[2, 42, 10, 54, 42, 33, 10, 18, 48, 23]
aclyniapoi
[53, 20, 0, 45, 56, 51, 23, 29, 44, 32]
baticceddo
[26, 53, 20, 44, 18, 48, 17, 35, 41, 44]
aspatchiog
[44, 54, 22, 53, 32, 46, 59, 51, 23, 8]


 98%|█████████▊| 203/208 [00:13<00:00, 14.97it/s]

loss: 23.65166473388672
brovalysss
[34, 6, 44, 58, 53, 0, 45, 54, 54, 48]
datheromru
[41, 44, 46, 59, 7, 6, 44, 27, 2, 42]
meccoyantr
[33, 10, 18, 20, 36, 12, 53, 56, 48, 6]
etchterata
[3, 27, 29, 4, 48, 17, 32, 53, 20, 53]
apylisnica
[44, 29, 36, 0, 51, 54, 22, 10, 48, 53]


100%|██████████| 208/208 [00:13<00:00, 15.65it/s]


loss: 23.379825592041016
iddirmousc
[44, 35, 41, 36, 32, 22, 23, 62, 54, 48]
train loss: 23.49| valid loss: 23.49



  0%|          | 0/208 [00:00<?, ?it/s]

loss: 23.208723068237305
oodentivat
[44, 35, 41, 10, 56, 48, 51, 58, 53, 20]
reiliscrac
[33, 44, 36, 0, 10, 54, 48, 6, 44, 18]
cadartosyg
[26, 44, 58, 53, 32, 20, 23, 19, 45, 8]
dedicatisy
[33, 10, 41, 51, 1, 53, 20, 51, 19, 45]


  1%|▏         | 3/208 [00:00<00:20, 10.04it/s]

Mnenamolyg
[26, 6, 10, 56, 44, 13, 36, 0, 45, 8]


 25%|██▌       | 53/208 [00:03<00:10, 14.58it/s]

loss: 23.38774299621582
eacatericw
[17, 53, 1, 53, 20, 17, 32, 44, 46, 59]
rapilllleb
[33, 44, 29, 36, 0, 0, 0, 0, 44, 11]
Rashlerynl
[26, 44, 54, 59, 0, 17, 32, 45, 56, 0]
trainalode
[26, 6, 44, 51, 1, 53, 0, 44, 58, 17]
dococedder
[33, 44, 50, 23, 48, 17, 32, 33, 17, 32]


 49%|████▊     | 101/208 [00:06<00:09, 11.06it/s]

loss: 23.241878509521484
Aablaphomc
[5, 53, 43, 0, 44, 29, 60, 44, 27, 29]
ethisxeleo
[44, 48, 59, 51, 54, 48, 17, 0, 10, 35]
phorinever
[29, 60, 36, 32, 51, 58, 10, 58, 17, 32]
anissuinab
[3, 56, 10, 54, 2, 42, 51, 58, 53, 16]
anialeescy
[3, 56, 10, 53, 56, 10, 44, 54, 48, 45]


 74%|███████▎  | 153/208 [00:10<00:03, 14.81it/s]

loss: 23.327198028564453
Crextasame
[26, 6, 10, 56, 48, 53, 54, 44, 33, 44]
scylenonec
[54, 48, 45, 0, 44, 22, 51, 58, 10, 48]
bosonceabl
[26, 44, 22, 23, 56, 48, 17, 53, 43, 0]
stanguecte
[54, 48, 53, 56, 8, 42, 10, 18, 20, 17]
nontrythiz
[50, 23, 56, 48, 6, 45, 46, 59, 51, 1]


 98%|█████████▊| 203/208 [00:13<00:00, 15.13it/s]

loss: 23.52621841430664
pemtlanoph
[26, 44, 27, 29, 60, 53, 56, 44, 29, 60]
dorytealer
[41, 44, 32, 45, 20, 17, 53, 0, 17, 32]
sinteriqub
[54, 10, 56, 48, 17, 32, 44, 55, 42, 39]
triiderall
[26, 6, 44, 35, 41, 17, 32, 53, 0, 0]
hynatitifo
[19, 45, 56, 53, 20, 51, 20, 51, 21, 44]


100%|██████████| 208/208 [00:13<00:00, 15.08it/s]


loss: 23.43899917602539
baloustici
[26, 53, 0, 23, 62, 54, 48, 51, 1, 51]
train loss: 23.40| valid loss: 23.41



  0%|          | 1/208 [00:00<00:39,  5.20it/s]

loss: 23.411701202392578
undermeste
[3, 56, 41, 17, 32, 22, 10, 54, 20, 17]
amirenderi
[44, 13, 36, 32, 10, 56, 58, 17, 32, 51]
sidrigisth
[26, 10, 41, 6, 10, 8, 10, 54, 48, 59]
taligative
[48, 53, 0, 10, 39, 53, 20, 51, 58, 17]
ssicalonal
[54, 22, 51, 1, 53, 32, 23, 56, 53, 0]


 25%|██▌       | 53/208 [00:03<00:10, 14.85it/s]

loss: 23.181900024414062
refytrasse
[33, 44, 21, 36, 48, 6, 44, 54, 22, 17]
jeesartyst
[33, 10, 44, 22, 53, 32, 20, 45, 54, 48]
dathlerico
[33, 44, 29, 60, 6, 44, 32, 51, 1, 23]
phenwaomme
[29, 60, 17, 56, 15, 38, 44, 27, 22, 10]
Moscylyant
[26, 44, 54, 20, 36, 0, 45, 53, 56, 48]


 50%|████▉     | 103/208 [00:06<00:07, 13.63it/s]

loss: 23.091135025024414
esericalle
[44, 54, 17, 32, 51, 1, 53, 0, 0, 17]
neflalllix
[58, 10, 21, 6, 53, 0, 0, 0, 10, 27]
shramicall
[54, 59, 6, 44, 27, 51, 1, 53, 0, 0]
aadtoadutt
[44, 35, 41, 60, 38, 35, 41, 42, 18, 48]
tectelvend
[26, 17, 18, 48, 17, 0, 58, 17, 56, 41]


 74%|███████▎  | 153/208 [00:09<00:03, 15.63it/s]

loss: 23.137086868286133
mubbraspir
[2, 42, 30, 11, 6, 44, 54, 22, 36, 32]
unvendlone
[3, 56, 33, 10, 56, 41, 6, 44, 58, 17]
aeativuabl
[53, 17, 53, 20, 51, 58, 10, 53, 43, 0]
jawbiterdf
[33, 44, 15, 31, 36, 48, 17, 32, 41, 21]
Eenodisles
[26, 10, 56, 44, 9, 51, 54, 0, 10, 54]


 98%|█████████▊| 203/208 [00:12<00:00, 15.44it/s]

loss: 23.501678466796875
odatizinio
[44, 58, 53, 20, 51, 58, 10, 56, 51, 23]
upinylessy
[3, 27, 51, 56, 7, 0, 10, 54, 48, 45]
Saumpaidri
[26, 44, 3, 27, 34, 38, 35, 41, 6, 10]
Srechorouc
[26, 6, 10, 48, 59, 44, 32, 44, 18, 20]
saficallyc
[33, 44, 21, 51, 1, 53, 0, 0, 45, 46]


100%|██████████| 208/208 [00:13<00:00, 15.72it/s]


loss: 23.398937225341797
asteanolks
[44, 18, 20, 17, 53, 56, 44, 18, 57, 54]
train loss: 23.34| valid loss: 23.37



  0%|          | 0/208 [00:00<?, ?it/s]

loss: 23.38336181640625
intisseztr
[3, 56, 20, 51, 54, 22, 44, 35, 20, 6]
malbanocha
[5, 53, 0, 63, 53, 56, 44, 29, 60, 38]
bilalletog
[33, 10, 0, 53, 0, 0, 10, 48, 36, 8]
impilllati
[44, 27, 34, 36, 0, 0, 0, 53, 20, 51]


  1%|▏         | 3/208 [00:00<00:19, 10.55it/s]

pationsoud
[34, 53, 20, 51, 23, 56, 48, 44, 35, 41]


 25%|██▌       | 53/208 [00:03<00:10, 15.09it/s]

loss: 23.556167602539062
mansiceson
[5, 53, 56, 22, 10, 48, 17, 22, 23, 56]
Pightatida
[26, 10, 24, 4, 48, 53, 20, 51, 1, 53]
bantactant
[33, 3, 56, 48, 44, 18, 48, 53, 56, 48]
gedletioun
[26, 17, 41, 0, 10, 48, 51, 23, 62, 56]
japocriden
[33, 44, 29, 44, 48, 6, 44, 58, 17, 56]


 50%|████▉     | 103/208 [00:06<00:07, 13.70it/s]

loss: 23.477420806884766
Onialifirm
[3, 56, 51, 53, 0, 10, 21, 36, 32, 13]
tlemypophy
[26, 6, 44, 27, 45, 29, 44, 29, 60, 45]
iniatiphon
[3, 56, 51, 53, 20, 51, 29, 60, 23, 56]
hylehautri
[19, 45, 0, 44, 33, 44, 18, 20, 6, 51]
ihinchlysp
[3, 33, 10, 56, 46, 59, 6, 52, 54, 34]


 74%|███████▎  | 153/208 [00:10<00:03, 15.24it/s]

loss: 23.083168029785156
stertistis
[54, 48, 17, 32, 58, 51, 54, 48, 36, 54]
unflidarph
[3, 56, 21, 0, 51, 1, 53, 32, 29, 60]
noninstrap
[50, 23, 56, 10, 56, 54, 48, 6, 44, 5]
Anetticcer
[3, 56, 10, 18, 48, 36, 18, 20, 17, 32]
ceonterlag
[50, 23, 23, 56, 48, 17, 32, 33, 44, 39]


 98%|█████████▊| 203/208 [00:13<00:00, 15.20it/s]

loss: 23.348907470703125
morapcorad
[13, 36, 32, 44, 29, 20, 36, 32, 44, 58]
ralllewami
[33, 44, 0, 0, 0, 10, 33, 44, 27, 51]
sclaticabl
[2, 29, 60, 53, 20, 51, 1, 53, 43, 0]
nopeteomet
[50, 23, 22, 10, 48, 17, 23, 13, 10, 48]
idicoliavo
[44, 33, 10, 48, 36, 0, 10, 44, 58, 36]


100%|██████████| 208/208 [00:13<00:00, 15.51it/s]


loss: 23.250343322753906
nontlylyll
[50, 23, 56, 48, 0, 45, 0, 45, 0, 0]
train loss: 23.30| valid loss: 23.33

