In [5]:
import math
import os

import torch
import torch.nn as nn
import torch.optim as optim

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [6]:

def make_dict(path):
    """Build word <-> index vocabularies from a whitespace-tokenised text file.

    Indices 0-3 are reserved for the special tokens ``<pad>``, ``<unk_word>``,
    ``<sos>`` and ``<eos>``. Every other distinct token found in the file gets
    an index starting at 4, assigned in sorted order so that the mapping is
    identical on every run (iterating a ``set`` of strings is not stable
    across interpreter sessions).

    Args:
        path: Path of a UTF-8 text file.

    Returns:
        A tuple ``(word2number_dict, number2word_dict)`` that are exact
        inverses of each other.
    """
    special_tokens = ["<pad>", "<unk_word>", "<sos>", "<eos>"]

    vocab = set()
    with open(path, 'r', encoding='utf-8') as fr:
        for line in fr:
            # split() with no argument collapses runs of whitespace, so blank
            # lines and double spaces never produce an empty-string "word".
            # make_batch tokenises the same way.
            vocab.update(line.split())
    # A special token appearing literally in the corpus must not get a second
    # index of its own.
    vocab.difference_update(special_tokens)

    word2number_dict = {}
    number2word_dict = {}
    for i, word in enumerate(special_tokens):
        word2number_dict[word] = i
        number2word_dict[i] = word
    for i, word in enumerate(sorted(vocab), len(special_tokens)):
        word2number_dict[word] = i
        number2word_dict[i] = word

    return word2number_dict, number2word_dict

def make_batch(path, word2number_dict, batch_size, n_step):
    """Turn a text file into fixed-size batches of n-gram training samples.

    Each line is split on whitespace and wrapped in ``<sos>`` / ``<eos>``.
    A window of ``n_step`` consecutive tokens is slid over the sentence; the
    window is the model input and the token right after it is the target.
    Sentences too short to yield a window plus a target contribute nothing.
    Tokens missing from ``word2number_dict`` map to index 1 (``<unk_word>``).

    Args:
        path: Path of a UTF-8 text file, one sentence per line.
        word2number_dict: Mapping from token to integer index.
        batch_size: Number of samples per batch.
        n_step: Number of context tokens in each input window.

    Returns:
        A tuple ``(all_input_batch, all_target_batch)``. The first is a list
        of batches, each a list of ``batch_size`` windows of ``n_step``
        indices; the second is a list of batches, each a list of
        ``batch_size`` target indices. Trailing samples that do not fill a
        complete batch are dropped so every batch has the same size.
    """
    unk_id = 1  # index of <unk_word>

    with open(path, 'r', encoding='utf-8') as fr:
        sentences = [line.strip() for line in fr]

    input_batch = []
    target_batch = []
    for sen in sentences:
        tokens = ["<sos>"] + sen.split() + ["<eos>"]
        # Encode the sentence once; dict.get replaces the former bare
        # ``except`` that silently swallowed every kind of error.
        ids = [word2number_dict.get(tok, unk_id) for tok in tokens]

        # range() is empty when the sentence has no room for a full window
        # followed by a target token.
        for start in range(len(ids) - n_step):
            input_batch.append(ids[start:start + n_step])
            target_batch.append(ids[start + n_step])

    all_input_batch = []
    all_target_batch = []
    # Stop before the last, incomplete batch.
    for start in range(0, len(input_batch) - batch_size + 1, batch_size):
        all_input_batch.append(input_batch[start:start + batch_size])
        all_target_batch.append(target_batch[start:start + batch_size])

    return all_input_batch, all_target_batch


In [7]:
class RNNLM(nn.Module):
    """Recurrent language model: embedding -> single-layer RNN -> linear.

    Given a window of token indices, it produces unnormalised scores
    (logits) over the vocabulary for the token that follows the window.

    Args:
        vocab_size: Number of entries in the vocabulary.
        embed_dim: Size of each token embedding.
        hidden_dim: Size of the RNN hidden state.
    """

    def __init__(self, vocab_size, embed_dim, hidden_dim):
        super(RNNLM, self).__init__()
        self.embed_dim = embed_dim
        self.embed = nn.Embedding(vocab_size, embed_dim)
        self.rnn = nn.RNN(embed_dim, hidden_dim)
        self.fc1 = nn.Linear(hidden_dim, vocab_size)

    def forward(self, X, hidden=None):
        """Score the next token for every window in the batch.

        Args:
            X: Tensor of token indices, [batch_size, n_step].
            hidden: Optional initial hidden state,
                [1, batch_size, hidden_dim]. When omitted, ``nn.RNN`` starts
                from an all-zero state.

        Returns:
            Logits of shape [batch_size, vocab_size].
        """
        X = self.embed(X)  # X: [batch_size, n_step, embed_dim]
        # nn.RNN is built without batch_first, so time must be the first axis.
        X = X.transpose(0, 1)  # X: [n_step, batch_size, embed_dim]
        output, hidden = self.rnn(X, hidden)
        # output: [n_step, batch_size, hidden_dim]
        # hidden: [1, batch_size, hidden_dim]
        # Use only the output of the last time step.
        output = self.fc1(output[-1])  # output: [batch_size, vocab_size]
        return output

In [8]:
# ---- Hyper-parameters and paths -------------------------------------------
n_step = 5 # number of steps, n-1 in paper
n_hidden = 128 # number of hidden size, h in paper
m = 16 # embedding size, m in paper
batch_size = 512 # training batch size
valid_batch_size = 128 # validation batch size

train_path = './data/train.txt' # the path of train dataset
valid_path = './data/valid.txt'
test_path = './data/test.txt'
test_psth = test_path  # misspelled original name, kept in case other cells use it

checkpoint_dir = 'models'  # directory the periodic checkpoints are written to

word2number_dict, number2word_dict = make_dict(train_path) #use the make_dict function to make the dict
print("The size of the dictionary is:", len(word2number_dict))

n_class = len(word2number_dict)  #n_class (= dict size)

# prepare training set
all_input_batch, all_target_batch = make_batch(train_path, word2number_dict, batch_size, n_step)  # make the batch
print("The number of the train batch is:", len(all_input_batch))

all_input_batch = torch.LongTensor(all_input_batch).to(device)   #list to tensor
all_target_batch = torch.LongTensor(all_target_batch).to(device)

# prepare validation set
all_valid_batch, all_valid_target = make_batch(valid_path, word2number_dict, valid_batch_size, n_step)
all_valid_batch = torch.LongTensor(all_valid_batch).to(device)  # list to tensor
all_valid_target = torch.LongTensor(all_valid_target).to(device)

print("\nTrain###############")
learn_rate = 0.001
all_epoch = 100 #the all epoch for training
save_checkpoint_epoch = 10 # save a checkpoint per save_checkpoint_epoch epochs
model = RNNLM(vocab_size=n_class, embed_dim=m, hidden_dim=n_hidden)
model.to(device)
print(model)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learn_rate)

# Create the checkpoint directory up front so torch.save cannot fail on a
# missing folder after many epochs of training.
os.makedirs(checkpoint_dir, exist_ok=True)

# Training
batch_number = len(all_input_batch)

for epoch in range(all_epoch):
    model.train()

    for count_batch, (input_batch, target_batch) in enumerate(zip(all_input_batch, all_target_batch)):
        # Fresh zero state for every batch; it must live on the same device
        # as the model and the inputs.
        hidden = torch.zeros(1, batch_size, n_hidden, device=device)

        output = model(input_batch, hidden)
        # output: [batch_size, n_class]
        loss = criterion(output, target_batch)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (count_batch + 1) % 50 == 0:
            loss_value = loss.item()
            ppl = math.exp(loss_value)
            print('Epoch:', '%04d' % (epoch + 1), 'Batch:', '%02d' % (count_batch + 1), f'/{batch_number}',
                    'loss =', '{:.6f}'.format(loss_value), 'ppl =', '{:.6}'.format(ppl))

    # valid after training one epoch
    model.eval()
    total_loss = 0
    total_valid = len(all_valid_target) * valid_batch_size
    count_loss = 0

    hidden = torch.zeros(1, valid_batch_size, n_hidden, device=device)
    with torch.no_grad():
        for input_batch, target_batch in zip(all_valid_batch, all_valid_target):
            loss = criterion(model(input_batch, hidden), target_batch)
            total_loss += loss.item()
            count_loss += 1

    if count_loss > 0:
        mean_valid_loss = total_loss / count_loss
        print(f'Valid {total_valid} samples after epoch:', '%04d' % (epoch + 1), 'loss =',
                '{:.6f}'.format(mean_valid_loss),
                'ppl =', '{:.6}'.format(math.exp(mean_valid_loss)))
    else:
        # The validation file produced no complete batch; avoid dividing by zero.
        print('Valid 0 samples after epoch:', '%04d' % (epoch + 1), '(no complete validation batch)')

    if (epoch+1) % save_checkpoint_epoch == 0:
        # NOTE(review): this pickles the whole module, so loading requires the
        # RNNLM class to be importable; saving model.state_dict() is more
        # portable but would change the checkpoint format for existing loaders.
        torch.save(model, f'{checkpoint_dir}/rnnlm_model_epoch{epoch+1}.ckpt')


The size of the dictionary is: 7615
The number of the train batch is: 150

Train###############
RNNLM(
  (embed): Embedding(7615, 16)
  (rnn): RNN(16, 128)
  (fc1): Linear(in_features=128, out_features=7615, bias=True)
)
Epoch: 0001 Batch: 50 /150 loss = 6.466872 ppl = 643.468
Epoch: 0001 Batch: 100 /150 loss = 6.741197 ppl = 846.573
Epoch: 0001 Batch: 150 /150 loss = 6.724743 ppl = 832.758
Valid 5504 samples after epoch: 0001 loss = 6.497840 ppl = 663.706
Epoch: 0002 Batch: 50 /150 loss = 6.217069 ppl = 501.232
Epoch: 0002 Batch: 100 /150 loss = 6.531851 ppl = 686.668
Epoch: 0002 Batch: 150 /150 loss = 6.523173 ppl = 680.735
Valid 5504 samples after epoch: 0002 loss = 6.395941 ppl = 599.407
Epoch: 0003 Batch: 50 /150 loss = 6.070329 ppl = 432.823
Epoch: 0003 Batch: 100 /150 loss = 6.372880 ppl = 585.743
Epoch: 0003 Batch: 150 /150 loss = 6.359722 ppl = 578.086
Valid 5504 samples after epoch: 0003 loss = 6.301984 ppl = 545.654
Epoch: 0004 Batch: 50 /150 loss = 5.920478 ppl = 372.59
Epo

KeyboardInterrupt: 