In [1]:
import math
import torch
import torch.nn as nn
import torch.optim as optim

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# device = torch.device("cpu")


In [2]:
def make_batch(train_path, word2number_dict, batch_size, n_step):
    
    all_input_batch = []
    all_target_batch = []

    with open(train_path, 'r', encoding='utf-8') as fr:
        text = fr.readlines()
    text = [line.strip() for line in text]
 
    input_batch = []
    target_batch = []

    for sen in text:
     
        wordlist = sen.split()
  
        if(len(wordlist)<n_step):continue
  
        for i,word in enumerate(wordlist):
            if(i+n_step>=len(wordlist)):break
            input= []
            for j in range(n_step):
                input.append(word2number_dict[wordlist[i+j]])
            target =word2number_dict[wordlist[i+n_step]]
   
            input_batch.append(input)
            target_batch.append(target)
            
    for i in range(len(input_batch)):
        if i+batch_size>len(input_batch):
            break
        all_input_batch.append(input_batch[i:i+batch_size])
        all_target_batch.append(target_batch[i:i+batch_size])
  
    return all_input_batch, all_target_batch

In [3]:

def make_batch_valid_test(train_path, word2number_dict, batch_size, n_step):
    def word2number(n):
        try:
            return word2number_dict[n]
        except:
            return 1   #<unk_word>

    all_input_batch = []
    all_target_batch = []

    with open(train_path, 'r', encoding='utf-8') as fr:
        text = fr.readlines()
    text = [line.strip() for line in text]
 
    input_batch = []
    target_batch = []

    for sen in text:
     
        wordlist = sen.split()
  
        if(len(wordlist)<n_step):continue
  
        for i in range(len(wordlist)):
            if(i+n_step>=len(wordlist)):break
            input= []
            for j in range(n_step):
                input.append(word2number(wordlist[i+j]))
            target =word2number(wordlist[i+n_step])
   
            input_batch.append(input)
            target_batch.append(target)
            
    for i in range(len(input_batch)):
        if i+batch_size>len(input_batch):
            break
        all_input_batch.append(input_batch[i:i+batch_size])
        all_target_batch.append(target_batch[i:i+batch_size])
  
    return all_input_batch, all_target_batch



In [4]:

def make_dict(train_path):
    with open(train_path, 'r', encoding='utf-8') as fr:
        text = fr.readlines()
    text = [n.strip() for n in text]
    text = ' '.join(text).split(" ")
    text = list(set(text))
    word2number_dict = {}
    number2word_dict = {}
    for i,word in enumerate(text,2):
        word2number_dict[word] = i
        number2word_dict[i] = word
    word2number_dict["<pad>"] = 0
    number2word_dict[0] = "<pad>"
    word2number_dict["<unk_word>"] = 1
    number2word_dict[1] = "<unk_word>"

    return word2number_dict, number2word_dict


In [5]:

# Model
class NNLM(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_dim,n_step):
        super(NNLM, self).__init__()
        self.embed_dim = embed_dim
        self.n_step = n_step
        self.embed = nn.Embedding(vocab_size, embed_dim)
        self.fc1 = nn.Linear(embed_dim*n_step, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, vocab_size)
        self.fc3 = nn.Linear(embed_dim*n_step, vocab_size)
        self.tanh = nn.Tanh()

    def forward(self, X):
        X = self.embed(X) # X: [batch_size, n_step] -> [batch_size, n_step, embed_dim]
        X = X.view(-1, self.embed_dim*self.n_step) # X: [batch_size, n_step, embed_dim] -> [batch_size, n_step*embed_dim]
        tanh = self.tanh(self.fc1(X))# tanh: [batch_size, hidden_dim]
        output = self.fc2(tanh)+self.fc3(X) # output: [batch_size, vocab_size]
        return output
    
# n_step = 2 # number of steps, n-1 in paper
# n_hidden = 2 # number of hidden size, h in paper
# m = 2 # embedding size, m in paper



# learn_rate = 0.001
# all_epoch = 200 #the all epoch for training
# save_checkpoint_epoch = 10 # save a checkpoint per save_checkpoint_epoch epochs
# model = NNLM(vocab_size=100, embed_dim=m, hidden_dim=n_hidden,n_step=n_step)
# print(model)


In [6]:

def train():
    n_step = 5 # number of steps, n-1 in paper
    n_hidden = 128 # number of hidden size, h in paper
    m = 16 # embedding size, m in paper
 
    
 
    learn_rate = 0.001
    all_epoch = 100 #the all epoch for training
    save_checkpoint_epoch = 10 # save a checkpoint per save_checkpoint_epoch epochs
    model = NNLM(vocab_size=n_class, embed_dim=m, hidden_dim=n_hidden,n_step=n_step)
    model.to(device)
    print(model)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learn_rate)

    # Training
    batch_number = len(all_input_batch)
    
    for epoch in range(all_epoch):
        count_batch = 0
        for input_batch, target_batch in zip(all_input_batch, all_target_batch):
            
            model.zero_grad()
            output = model(input_batch)
            # output: [batch_size, n_class]
            loss = criterion(output, target_batch)
            loss.backward()
            optimizer.step()
            
            ppl = math.exp(loss.item())
            
            if (count_batch + 1) % 50 == 0:
                print('Epoch:', '%04d' % (epoch + 1), 'Batch:', '%02d' % (count_batch + 1), f'/{batch_number}',
                      'loss =', '{:.6f}'.format(loss), 'ppl =', '{:.6}'.format(ppl))

            count_batch += 1
        
        total_loss = 0
        # valid after training one epoch      
        total_valid = len(all_valid_target)*128
        count_loss = 0
        for input_batch, target_batch in zip(all_valid_batch, all_valid_target):
            with torch.no_grad():
                loss = criterion(model(input_batch), target_batch)
                total_loss += loss.item()
                count_loss += 1
          
        print(f'Valid {total_valid} samples after epoch:', '%04d' % (epoch + 1), 'loss =',
                '{:.6f}'.format(total_loss / count_loss),
                'ppl =', '{:.6}'.format(math.exp(total_loss / count_loss)))

        if (epoch+1) % save_checkpoint_epoch == 0:
            torch.save(model, f'models/nnlm_model_epoch{epoch+1}.ckpt')


In [7]:
if __name__ == '__main__':
    
    batch_size = 512 #batch size
    n_step = 5
    train_path = './data/train.txt' # the path of train dataset
    valid_path = './data/valid.txt'
    test_psth = './data/test.txt'

    word2number_dict, number2word_dict = make_dict(train_path) #use the make_dict function to make the dict
    print("The size of the dictionary is:", len(word2number_dict))

    n_class = len(word2number_dict)  #n_class (= dict size)

    # prepare training set
    all_input_batch, all_target_batch = make_batch(train_path, word2number_dict, batch_size, n_step)  # make the batch
    print("The number of the train batch is:", len(all_input_batch))
 
    all_input_batch = torch.LongTensor(all_input_batch).to(device)   #list to tensor
    all_target_batch = torch.LongTensor(all_target_batch).to(device)

    # prepare validation set
    all_valid_batch, all_valid_target = make_batch_valid_test(valid_path, word2number_dict, 128, n_step)
    all_valid_batch = torch.LongTensor(all_valid_batch).to(device)  # list to tensor
    all_valid_target = torch.LongTensor(all_valid_target).to(device)

    print("\nTrain###############")
    train()

    # print("\nTest###############")
    # all_test_batch, all_test_target = make_batch_valid_test(test_psth, word2number_dict, 128, n_step)
    # all_test_batch = torch.LongTensor(all_test_batch)  # list to tensor
    # all_test_target = torch.LongTensor(all_test_target)
    # select_model_path = "./models/nnlm_model_epoch10.ckpt"
    # test(select_model_path)

The size of the dictionary is: 7613
The number of the train batch is: 68439

Train###############
NNLM(
  (embed): Embedding(7613, 16)
  (fc1): Linear(in_features=80, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=7613, bias=True)
  (fc3): Linear(in_features=80, out_features=7613, bias=True)
  (tanh): Tanh()
)


  from .autonotebook import tqdm as notebook_tqdm


Epoch: 0001 Batch: 50 /68439 loss = 2.727856 ppl = 15.3
Epoch: 0001 Batch: 100 /68439 loss = 0.831484 ppl = 2.29673
Epoch: 0001 Batch: 150 /68439 loss = 0.538542 ppl = 1.71351
Epoch: 0001 Batch: 200 /68439 loss = 0.442267 ppl = 1.55623
Epoch: 0001 Batch: 250 /68439 loss = 0.347810 ppl = 1.41596
Epoch: 0001 Batch: 300 /68439 loss = 0.335809 ppl = 1.39907
Epoch: 0001 Batch: 350 /68439 loss = 0.315843 ppl = 1.37141
Epoch: 0001 Batch: 400 /68439 loss = 0.265594 ppl = 1.30421
Epoch: 0001 Batch: 450 /68439 loss = 0.192043 ppl = 1.21172
Epoch: 0001 Batch: 500 /68439 loss = 0.262571 ppl = 1.30027
Epoch: 0001 Batch: 550 /68439 loss = 0.284580 ppl = 1.3292
Epoch: 0001 Batch: 600 /68439 loss = 0.205516 ppl = 1.22816
Epoch: 0001 Batch: 650 /68439 loss = 0.229173 ppl = 1.25756
Epoch: 0001 Batch: 700 /68439 loss = 0.211390 ppl = 1.23539
Epoch: 0001 Batch: 750 /68439 loss = 0.186549 ppl = 1.20508
Epoch: 0001 Batch: 800 /68439 loss = 0.204186 ppl = 1.22653
Epoch: 0001 Batch: 850 /68439 loss = 0.251151

KeyboardInterrupt: 