## 说明

一个 RNN Transducer的示例，实现一个类似于语言翻译的序列到序列的任务；

输入为一段文本序列X，输出为另一个文本序列Y；

相当于是一个元音补全的任务：

例如： X: Hll,Wrld --> Y: Hello,World

In [50]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence

import numpy
import matplotlib.pyplot as plt
from tqdm import tqdm
import unidecode
import string

In [51]:
characters = "AEIOUaeiou"
Embedding_dim = 1024
Predictor_dim = 1024
Joiner_dim = 1024


with open("war_and_peace.txt", "r") as f:
    lines = f.readlines()

len_txt = len(lines)

len_txt

62015

## 定义数据集

In [52]:
class TxtDataset(Dataset):

    def __init__(self, lines):
        super().__init__()
        self.lines = lines

    def encode_string(self, line):
        return [string.printable.find(x) + 1 for x in line] # 因为0表示blank，所以要加1

    def __len__(self):
        return len(self.lines)

    def __getitem__(self, index):
        line = self.lines[index].replace("\n", "")
        line = unidecode.unidecode(line)  # 去除特殊字符
        x = ''.join([x for x in line if x not in characters])
        y = line

        x = self.encode_string(x)
        y = self.encode_string(y)

        x = torch.tensor(x, dtype=torch.long)
        y = torch.tensor(y, dtype=torch.long)

        T = torch.full(size=(1,), fill_value=len(x), dtype=torch.long)  # 时间序列的长度
        U = torch.full(size=(1,), fill_value=len(y), dtype=torch.long)  # 目标序列的长度

        return x, y, T, U
    

class CustomCollate:
    '''
    通过定义一个自定义的 Collate 类来解决 DataLoader 中不同长度数据无法堆叠的问题。
    '''

    def __call__(self, batch):
        # 拆包batch的数据
        x_list, y_list, T_list, U_list = zip(*batch)

        # 对x和y进行填充， 填充为0，0视作为blank
        x_padded = pad_sequence(x_list, batch_first=True, padding_value=0)
        y_padded = pad_sequence(y_list, batch_first=True, padding_value=0)

        # 将T和U转换为张量
        T = torch.stack(T_list)
        U = torch.stack(U_list)

        return x_padded, y_padded, T, U


In [53]:
string.printable

'0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~ \t\n\r\x0b\x0c'

In [54]:
def decode_string(lst):
    return "".join(string.printable[x - 1] for x in lst)

train_dataset = TxtDataset(lines[: round(0.9 * len_txt)])
x, y, T, U = train_dataset[0]

print(decode_string(x))
print(decode_string(y))
print(T, U)

"Wll, Prnc, s Gn nd Lcc r nw jst fmly stts f th
"Well, Prince, so Genoa and Lucca are now just family estates of the
tensor([47]) tensor([68])


In [55]:
batch_size = 128
collate  = CustomCollate()
train_dataset = TxtDataset(lines[: round(0.9 * len_txt)])
valid_dataset = TxtDataset(lines[round(0.9 * len_txt): ])
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, num_workers=0, collate_fn=collate)
valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, num_workers=0, collate_fn=collate)

In [56]:
valid_iter = iter(valid_dataloader)
x, y, T, U = next(valid_iter)
x.shape, y.shape, T.shape, U.shape

(torch.Size([128, 54]),
 torch.Size([128, 72]),
 torch.Size([128, 1]),
 torch.Size([128, 1]))

## 定义模型

In [57]:
class Encoder(nn.Module):
    '''
    对应声学模型
    '''
    def __init__(self, num_chars = len(string.printable) + 1):
        super().__init__()
        self.num_chars = num_chars
        self.embd = nn.Embedding(num_chars, Embedding_dim)
        self.rnn = nn.LSTM(input_size=Embedding_dim, hidden_size=Embedding_dim, num_layers=3, bidirectional=True)
        self.lc = nn.Linear(Embedding_dim * 2, Joiner_dim)

    def forward(self, x):
        # x input: [batch, time_step]
        x = self.embd(x)   #[batch, time_step, embdding_dim]
        x, _ = self.rnn(x)  #[batch, time_step, embdding_dim * 2]
        x = self.lc(x)  #[batch, time_step, Joiner_dim]
        return x
    
    
class PredNet(nn.Module):

    def __init__(self, num_chars = len(string.printable) + 1):
        super().__init__()
        self.embd = nn.Embedding(num_chars, Predictor_dim)
        self.rnn = nn.GRUCell(input_size=Predictor_dim, hidden_size=Predictor_dim)
        self.lc = nn.Linear(Predictor_dim, Joiner_dim)

        self.initial_state = nn.Parameter(torch.randn(Predictor_dim))  # torch.randn 生成标准正态分布
        self.start_symbol = 0 # 0表示blank

    def forward_one_step(self, input, previous_state):
        embedding = self.embd(input) # embedding.shape: [batch， Predicotr_dim]
        state = self.rnn(embedding, previous_state) # state.shape: [batch, Predictor_dim]
        out = self.lc(state)  # out.shape: [batch, Jointer_dim]
        return out, state

    def forward(self, y):
        # y: [batch, target_num_words]
        batch_size = y.shape[0]
        U = y.shape[1]
        outs = []
        state = torch.stack([self.initial_state] * batch_size).to(y.device)  # state.shape : [batch, pred_dim]
        for u in range(U + 1):
            if u == 0:
                decoder_input = torch.tensor([self.start_symbol] * batch_size, device=y.device)  # decoder_input: [batch], 就一个维度
            else:
                decoder_input = y[:, u-1]  # decoder_input: [batch], 就一个维度

            out, state = self.forward_one_step(decoder_input, state)
            outs.append(out)

        out = torch.stack(outs, dim=1)  # out.shape [batch, U + 1, Jointer_dim]
        return out
    

class Jointer(nn.Module):

    def __init__(self, num_chars = len(string.printable) + 1):
        super().__init__()
        self.lc = nn.Linear(Joiner_dim, num_chars)

    def forward(self, encoder_out, pred_out):
        # encoder_out: [batch, time_step,  1, num_output]
        # pred_out: [batch, 1, target_words_len + 1, num_output]

        out = encoder_out + pred_out
        # out = torch.cat((encoder_out, pred_out), dim=-1)
        out = F.relu(out)
        out = self.lc(out)
        return F.log_softmax(out, dim=-1) # [batch, time_step, target_words_len + 1, num_output]
    

class Transducer(nn.Module):

    def __init__(self, num_input, num_output):
        super().__init__()
        self.encoder = Encoder(num_input)
        self.pred_net = PredNet(num_output)
        self.jointer = Jointer(num_output)

        

    def forward(self, x, y, T, U):
        '''
        T: 时间序列的长度
        U: 目标序列的长度
        '''
        encoder_out = self.encoder(x)  # [batch, time_step, Joiner_dim]
        pred_out = self.pred_net(y)    # [batch, target_words_len + 1, Joiner_dim]
        joiner_out = self.jointer(encoder_out.unsqueeze(2), pred_out.unsqueeze(1)) 
        return joiner_out # [batch, time_step, target_words_len + 1, num_output]
    
    #  Transducer.comput_single_alignment_prob = comput_single_alignment_prob
    def greedy_search(self, x, T):
        y_batch = []
        B = len(x)
        encoder_out = self.encoder.forward(x)
        U_max = 200
        for b in range(B):
            t = 0; u = 0; y = [self.pred_net.start_symbol]; 
            predictor_state = self.pred_net.initial_state.unsqueeze(0)
            while t < T[b] and u < U_max:
                predictor_input = torch.tensor([ y[-1] ], device = x.device)
                g_u, predictor_state = self.pred_net.forward_one_step(predictor_input, predictor_state)
                f_t = encoder_out[b, t]
                h_t_u = self.jointer.forward(f_t, g_u)
                argmax = h_t_u.max(-1)[1].item()
                if argmax == 0:
                    t += 1
                else: # argmax == a label
                    u += 1
                    y.append(argmax)
            y_batch.append(y[1:]) # remove start symbol
        return y_batch


In [58]:
rnnt = Transducer(len(string.printable) + 1, len(string.printable) + 1)
device = "cuda:1" if torch.cuda.is_available() else "cpu"
print(device)
rnnt(x, y, T, U).shape

cuda:1


torch.Size([128, 54, 73, 101])

In [62]:
class RNNTLoss(nn.Module):

    def __init__(self):
        super().__init__()

    def forward(self, logits, y, T, U):
        '''
        logits: [batch, time_step, target_words_len, num_outputs] , 预测分布
        y: target_token: [batch, target_words_len]
        T: time_step: 输入序列长度，T
        U: target_words_len: 目标序列长度 U
        '''

        batch_size, T_max, U_max, vocab_size = logits.shape
        log_alpha = torch.zeros(batch_size, T_max, U_max, device=logits.device)
        # log_alpha[t, u]：表示在时刻 t 和目标位置 u 的对数概率。该值的大小表示从输入序列的开始到当前位置的路径的概率。
        for t in range(T_max):
            for u in range(U_max):
                if u == 0:
                    if t == 0:
                        log_alpha[:, t, u] = 0.
                    else: # t > 0
                        log_alpha[:, t, u] = log_alpha[:, t-1, u] + logits[:, t-1, 0, 0] #logits[:, t-1, 0, 0]： 在t-1时刻，在预测第0个单词的时候，预测是blank的概率
                else: # u > 0
                    if t == 0:
                        # torch.gather(input, dim, index) 是 PyTorch 中的一个函数，用于根据给定的 index 张量从 input 张量中按指定维度（dim）选取数据。
                        # log_alpha[:, t, u-1]: [batch, vocab_size]
                        # y[:, u-1].view(-1, 1)): [batch, 1]
                        # log_alpha[:, t, u-1]: [batch]
                        log_alpha[:, t, u] = log_alpha[:, t, u-1] + torch.gather(logits[:, t, u-1], dim=1, index=y[:, u-1].view(-1, 1)).reshape(-1)
                    else: # t > 0
                        log_alpha[:, t, u] = torch.logsumexp(torch.stack([
                            log_alpha[:, t-1, u] + logits[:, t-1, u, 0],
                            log_alpha[:, t, u-1] + torch.gather(logits[:, t, u-1], dim=1, index=y[:, u-1].view(-1, 1)).reshape(-1)
                        ]), dim=0) # 这一段写的太妙了，然后为了logsumexp可以正确使用，我们需要对两部分转移过来的数组进行stack，拼接成一个之后，指定dim就可以正确处理了。这个dim的逻辑就和sum的axis逻辑一样

            log_probs = []
            for b in range(batch_size):
                log_prob = log_alpha[b, T[b]-1, U[b]] + logits[b, T[b]-1, U[b], 0]  # 虽然padding之后所有长度都一样了，但是我们自己知道，我们需要的长度其实是T和U
                log_probs.append(log_prob)
            log_probs = torch.stack(log_probs)

            return -log_probs.mean()


In [None]:
def train(model, optimizer, loss_fn, epoch, dataloader):
    model.train()
    loss_mean = 0
    with tqdm(dataloader) as pbar:
        for batch_index, (x, y, T, U) in enumerate(pbar):
            x, y = x.to(device), y.to(device)
            T, U = T.to(device), U.to(device)

            optimizer.zero_grad()
            output_log_softmax = model(x, y, T, U)  # [seq_len, batch, label]

            # target: [batch, num_words_len]
            loss = loss_fn(output_log_softmax, y, T, U)

            loss.backward()
            optimizer.step()

            loss = loss.item()

            if batch_index == 0:
                loss_mean = loss

            loss_mean = 0.1 * loss + 0.9 * loss_mean

            pbar.set_description(f'Epoch {epoch} Loss: {loss_mean:.4f}')

            if batch_index % 30 == 0:
                model.eval()
                guesses = model.greedy_search(x, T)
                model.train()
                print("\n")
                for b in range(2):
                    print("input:", decode_string(x[b,:T[b]]))
                    print("guess:", decode_string(guesses[b]))
                    print("truth:", decode_string(y[b,:U[b]]))
                    print("")


def valid(model, optimizer, epoch, dataloader):
    model.eval()
    with tqdm(dataloader) as pbar, torch.no_grad():
        loss_sum = 0
        acc_sum = 0
        for batch_index, (data, target, input_lengths, target_lengths) in enumerate(pbar):
            data, target = data.cuda(), target.cuda()

            output = model(data)
            output_log_softmax = F.log_softmax(output, dim=-1)
            loss = F.ctc_loss(output_log_softmax, target, input_lengths, target_lengths)

            loss = loss.item()
            acc = calc_acc(target, output)

            loss_sum += loss
            acc_sum += acc

            loss_mean = loss_sum / (batch_index + 1)
            acc_mean = acc_sum / (batch_index + 1)

            pbar.set_description(f'Test : {epoch} Loss: {loss_mean:.4f} Acc: {acc_mean:.4f}')

In [65]:
optimizer = torch.optim.Adam(rnnt.parameters(), 0.0003)
loss_fn = RNNTLoss()
epochs = 15

rnnt.to(device)
for epoch in range(1, epochs + 1):
    train(rnnt, optimizer, loss_fn, epoch, train_dataloader)

Epoch 1 Loss: -0.0000:   0%|          | 1/437 [00:01<07:39,  1.05s/it]



input: "Wll, Prnc, s Gn nd Lcc r nw jst fmly stts f th
guess: 
truth: "Well, Prince, so Genoa and Lucca are now just family estates of the

input: Bnprts. Bt  wrn y, f y dn't tll m tht ths mns wr,
guess: 
truth: Buonapartes. But I warn you, if you don't tell me that this means war,



Epoch 1 Loss: 0.0000:  23%|██▎       | 101/437 [00:35<03:03,  1.83it/s]



input: th Hfkrgsrth nd bth mprrs tk prt. t tht cncl, cntrry
guess: 
truth: the Hofkriegsrath and both Emperors took part. At that council, contrary

input: t th vws f th ld gnrls Ktzv nd Prnc Schwrtznbrg, t
guess: 
truth: to the views of the old generals Kutuzov and Prince Schwartzenberg, it



Epoch 1 Loss: 0.0000:  30%|███       | 133/437 [00:47<01:47,  2.82it/s]


KeyboardInterrupt: 

In [1]:

"""
一个RNN Transducer的示例，实现一个类似语言翻译的序列到序列任务；
输入为一段文本序列X，输出为另一个文本序列Y；
将Y序列中的元音字符去除即为X。例如：
 X: Hll, Wrld --> Y：Hello, World
"""

import torch
import string
import numpy as np
import itertools
from collections import Counter
from tqdm import tqdm
import unidecode

NULL_INDEX = 0

encoder_dim   = 1024
predictor_dim = 1024
joiner_dim    = 1024

class Encoder(torch.nn.Module):
    def __init__(self, num_inputs):
        super(Encoder, self).__init__()
        self.embed = torch.nn.Embedding(num_inputs, encoder_dim)
        self.rnn = torch.nn.GRU(input_size=encoder_dim, hidden_size=encoder_dim, num_layers=3, batch_first=True,bidirectional=True, dropout=0.1)
        self.linear = torch.nn.Linear(encoder_dim*2, joiner_dim)

    def forward(self, x):
        out = x
        out = self.embed(out)
        out = self.rnn(out)[0]
        out = self.linear(out)
        return out
    

class Predictor(torch.nn.Module):
    def __init__(self, num_outputs):
        super(Predictor, self).__init__()
        self.embed = torch.nn.Embedding(num_outputs, predictor_dim)
        self.rnn = torch.nn.GRUCell(input_size=predictor_dim, hidden_size=predictor_dim)
        self.linear = torch.nn.Linear(predictor_dim, joiner_dim)

        self.initial_state = torch.nn.Parameter(torch.randn(predictor_dim))
        self.start_symbol = NULL_INDEX #原始论文中，使用0向量，这里采用使用null index


    def forward_one_step(self, input, previous_state):
        embedding = self.embed(input)
        state = self.rnn.forward(embedding, previous_state)
        out = self.linear(state)
        return out, state
    
    def forward(self, y):
        batch_size = y.shape[0]
        U = y.shape[1]
        outs = []
        state = torch.stack([self.initial_state] * batch_size).to(y.device)
        for u in range(U+1):
            if u == 0:
                decoder_input = torch.tensor([self.start_symbol]*batch_size, device=y.device)
            else:
                decoder_input = y[:,u-1]
            out, state = self.forward_one_step(decoder_input, state)
            outs.append(out)
        out = torch.stack(outs, dim=1)
        return out
    

class Joiner(torch.nn.Module):
    def __init__(self, num_outputs):
        super(Joiner, self).__init__()
        self.linear = torch.nn.Linear(joiner_dim, num_outputs)

    def forward(self, encoder_out, predictor_out):
        out = encoder_out + predictor_out
        out = torch.nn.functional.relu(out)
        out = self.linear(out)
        return out
    

class Transducer(torch.nn.Module):
    def __init__(self, num_inputs, num_outputs):
        super(Transducer, self).__init__()
        self.encoder = Encoder(num_inputs)
        self.predictor = Predictor(num_outputs)
        self.joiner = Joiner(num_outputs)

        if torch.cuda.is_available(): self.device = "cuda:0"
        else: self.device = "cpu"
        self.to(self.device)

    def comput_forward_prob(self, joiner_out, T, U, y):
        """
        joiner_out: tensor of shape (B, T_max, U_max+1, #labels)
        T: list of input lengths
        U: list of output lengths
        y: label tensor (B, U_max+1)
        """

        B = joiner_out.shape[0]
        T_max = joiner_out.shape[1]
        U_max = joiner_out.shape[2] - 1
        log_alpha = torch.zeros(B, T_max, U_max+1, device=model.device)
        for t in range(T_max):
            for u in range(U_max+1):
                if u == 0:
                    if t == 0:
                        log_alpha[:,t,u] = 0.
                    else: # t > 0
                        log_alpha[:, t, u] = log_alpha[:, t-1, u] + joiner_out[:, t-1, 0, NULL_INDEX]

                else: # u > 0
                    if t == 0:
                        log_alpha[:, t, u] = log_alpha[:, t, u-1] + torch.gather(joiner_out[:,t, u-1], dim=1, index=y[:,u-1].view(-1,1)).reshape(-1)
                    else: # t > 0
                        log_alpha[:, t, u] = torch.logsumexp(torch.stack([
                            log_alpha[:, t-1, u] + joiner_out[:, t-1, u, NULL_INDEX],
                            log_alpha[:, t, u-1] + torch.gather(joiner_out[:,t, u-1], dim=1, index=y[:,u-1].view(-1,1) ).reshape(-1)
                        ]), dim=0)

        log_probs = []
        for b in range(B):
            log_prob = log_alpha[b, T[b]-1, U[b]] + joiner_out[b, T[b]-1, U[b], NULL_INDEX]
            log_probs.append(log_prob)
        log_probs = torch.stack(log_probs)
        return log_probs

    def compute_loss(self, x, y, T, U):
        encoder_out = self.encoder.forward(x)
        predictor_out = self.predictor.forward(y)
        joiner_out = self.joiner.forward(encoder_out.unsqueeze(2), predictor_out.unsqueeze(1)).log_softmax(3)
        loss = -self.comput_forward_prob(joiner_out, T, U, y).mean()
        return loss

#    Transducer.comput_single_alignment_prob = comput_single_alignment_prob
    def greedy_search(self, x, T):
        y_batch = []
        B = len(x)
        encoder_out = self.encoder.forward(x)
        U_max = 200
        for b in range(B):
            t = 0; u = 0; y = [self.predictor.start_symbol]; 
            predictor_state = self.predictor.initial_state.unsqueeze(0)
            while t < T[b] and u < U_max:
                predictor_input = torch.tensor([ y[-1] ], device = x.device)
                g_u, predictor_state = self.predictor.forward_one_step(predictor_input, predictor_state)
                f_t = encoder_out[b, t]
                h_t_u = self.joiner.forward(f_t, g_u)
                argmax = h_t_u.max(-1)[1].item()
                if argmax == NULL_INDEX:
                    t += 1
                else: # argmax == a label
                    u += 1
                    y.append(argmax)
            y_batch.append(y[1:]) # remove start symbol
        return y_batch

class Collate:
    def __call__(self, batch):
        """
        batch: list of tuples (input string, output string)
        Returns a minibatch of strings, encoded as labels and padded to have the same length.
        """
        x = []; y = []
        batch_size = len(batch)
        for index in range(batch_size):
            x_, y_ = batch[index]
            x.append(encode_string(x_))
            y.append(encode_string(y_))

        # pad all sequences to have same length
        T = [len(x_) for x_ in x]
        U = [len(y_) for y_ in y]
        T_max = max(T)
        U_max = max(U)

        for index in range(batch_size):
            x[index] += [NULL_INDEX] * (T_max - len(x[index]))
            x[index] = torch.tensor(x[index])
            y[index] += [NULL_INDEX] * (U_max - len(y[index]))
            y[index] = torch.tensor(y[index])

        # stack into single tensor
        x = torch.stack(x)
        y = torch.stack(y)
        T = torch.tensor(T)
        U = torch.tensor(U)

        return (x,y,T,U)
    

class TextDataset(torch.utils.data.Dataset):
    def __init__(self, lines, batch_size):
        lines = list(filter(("\n").__ne__, lines))
        self.lines = lines
        collate = Collate()
        self.loader = torch.utils.data.DataLoader(self, batch_size=batch_size,num_workers=0,collate_fn=collate)

    def __len__(self):
        return len(self.lines)
    
    def __getitem__(self, idx):
        line = self.lines[idx].replace("\n", "")
        line = unidecode.unidecode(line)                        # 去除特殊字符 
        x = "".join(c for c in line if c not in "AEIOUaeiou")   # 去除元音字符
        y = line
        return (x,y)
    
def encode_string(s):
    for c in s:
        if c not in string.printable:
            print(s)
    return [string.printable.index(c) + 1 for c in s]

def decode_labels(l):
    return "".join([string.printable[c-1]  for c in l])

class Trainer:
  def __init__(self, model, lr):
    self.model = model
    self.lr = lr
    self.optimizer = torch.optim.Adam(model.parameters(), lr=self.lr)
  
  def train(self, dataset, print_interval = 2):
    train_loss = 0
    num_samples = 0
    self.model.train()
    pbar = tqdm(dataset.loader)
    for idx, batch in enumerate(pbar):
      x,y,T,U = batch
      x = x.to(self.model.device); y = y.to(self.model.device)
      batch_size = len(x)
      num_samples += batch_size
      loss = self.model.compute_loss(x,y,T,U)
      self.optimizer.zero_grad()
      pbar.set_description("%.2f" % loss.item())
      loss.backward()
      self.optimizer.step()
      train_loss += loss.item() * batch_size
      if idx % print_interval == 0:
        self.model.eval()
        guesses = self.model.greedy_search(x,T)
        self.model.train()
        print("\n")
        for b in range(2):
          print("input:", decode_labels(x[b,:T[b]]))
          print("guess:", decode_labels(guesses[b]))
          print("truth:", decode_labels(y[b,:U[b]]))
          print("")
    train_loss /= num_samples
    return train_loss

  def test(self, dataset, print_interval=1):
    test_loss = 0
    num_samples = 0
    self.model.eval()
    pbar = tqdm(dataset.loader)
    with torch.no_grad():
        for idx, batch in enumerate(pbar):
          x,y,T,U = batch
          x = x.to(self.model.device); y = y.to(self.model.device)
          batch_size = len(x)
          num_samples += batch_size
          loss = self.model.compute_loss(x,y,T,U)
          pbar.set_description("%.2f" % loss.item())
          test_loss += loss.item() * batch_size
          if idx % print_interval == 0:
            print("\n")
            print("input:", decode_labels(x[0,:T[0]]))
            print("guess:", decode_labels(self.model.greedy_search(x,T)[0]))
            print("truth:", decode_labels(y[0,:U[0]]))
            print("")
    test_loss /= num_samples
    return test_loss
    


In [2]:
with open("war_and_peace.txt", "r") as f:
    lines = f.readlines()

end = round(0.9 * len(lines))
train_lines = lines[:end]
test_lines = lines[end:]
train_set = TextDataset(train_lines, batch_size=128)
test_set = TextDataset(test_lines, batch_size=128)
train_set.__getitem__(0)

num_chars = len(string.printable)
model = Transducer(num_inputs=num_chars+1, num_outputs=num_chars+1)
print(model.device)
trainer = Trainer(model=model, lr=0.0003)

num_epochs = 1
train_losses = []
test_losses = []

for epoch in range(num_epochs):
    train_loss = trainer.train(train_set, print_interval=100)
    test_loss = trainer.test(test_set, print_interval=100)
    train_losses.append(train_loss)
    test_losses.append(test_loss)
    print("Epoch %d: train loss = %f, test loss = %f" %(epoch, train_loss, test_loss)) 

cuda:0


407.60:   0%|          | 1/355 [00:15<1:32:31, 15.68s/it]



input: "Wll, Prnc, s Gn nd Lcc r nw jst fmly stts f th
guess: 
truth: "Well, Prince, so Genoa and Lucca are now just family estates of the

input: Bnprts. Bt  wrn y, f y dn't tll m tht ths mns wr,
guess: 
truth: Buonapartes. But I warn you, if you don't tell me that this means war,



29.38:  28%|██▊       | 101/355 [23:25<54:25, 12.86s/it]  



input: f Rstv n th clb prch.
guess: of Rost in the cl perch.
truth: of Rostov in the club porch.

input: "nd d y fl qt clm?" Rstv skd.
guess: "ane do you fel quit coul?" Rost soke.
truth: "And do you feel quite calm?" Rostov asked.



20.28:  57%|█████▋    | 201/355 [44:28<32:47, 12.78s/it]



input: lck hng dwn n th mddl f hs brd frhd. Hs plmp wht nck
guess: lack hing down in th midle of his bred ferhed. His plem what neck
truth: lock hung down in the middle of his broad forehead. His plump white neck

input: std t shrply bv th blck cllr f hs nfrm, nd h smlld
guess: sted to sherly beve th bl caller of his niform, ane he smaled
truth: stood out sharply above the black collar of his uniform, and he smelled



15.54:  85%|████████▍ | 301/355 [1:05:48<12:07, 13.47s/it]



input: cnsdrtns.
guess: considertions.
truth: considerations.

input: n th thrd f Sptmbr Prr wk lt. Hs hd ws chng, th
guess: in the there of Sptember Pierre wek lat. His had was ching, the
truth: On the third of September Pierre awoke late. His head was aching, the



13.69: 100%|██████████| 355/355 [1:17:05<00:00, 13.03s/it]
13.47:   0%|          | 0/41 [00:00<?, ?it/s]



input: "Bt thy dn't ndrstnd r tlk t ll," sd th dncr wth 


13.47:   2%|▏         | 1/41 [00:01<01:01,  1.53s/it]

guess: "But thy don' unerstened or tak it al," said the dancer with a
truth: "But they don't understand our talk at all," said the dancer with a



15.11: 100%|██████████| 41/41 [00:12<00:00,  3.27it/s]

Epoch 0: train loss = 39.060191, test loss = 14.164064





In [3]:
torch.save(model, 'rnnt.pth')