In [1]:
import torch
from torch import nn
import numpy as np
import sys
sys.path.append('..')
from dataset import ptb

def to_one_hot(t_train, vocab_size):
    # 使用NumPy的eye函数创建一个独热编码矩阵
    one_hot_matrix = np.eye(vocab_size)[t_train]
    return torch.tensor(one_hot_matrix)

In [2]:
# 设定超参数
batch_size = 20
wordvec_size = 100
hidden_size = 100  # RNN的隐藏状态向量的元素个数
time_size = 35  # RNN的展开大小
lr = 20.0
max_epoch = 1
max_grad = 0.25

In [3]:
corpus, word_to_id, id_to_word = ptb.load_data('train')
corpus_test, _, _ = ptb.load_data('test')
vocab_size = len(word_to_id)
xs = corpus[:-1]
ts = corpus[1:]
data_size = len(xs)

In [4]:
max_iters = data_size // (batch_size * time_size)
jump = (data_size - 1) // batch_size
offsets = [i * jump for i in range(batch_size)]
time_idx = 0

# Normal LSTM

In [None]:
class lstm_torch(nn.Module):
    
    def __init__(self,vocab_size, wordvec_size, hidden_size):
        super().__init__()
        self.embed = nn.Embedding(vocab_size,wordvec_size)
        nn.init.xavier_normal_(self.embed.weight)
        
        self.lstm = nn.LSTM(wordvec_size, hidden_size, batch_first = True)
        nn.init.xavier_normal_(self.lstm.weight_hh_l0)
        nn.init.xavier_normal_(self.lstm.weight_ih_l0)
        
        self.linear = nn.Linear(hidden_size, vocab_size)
        nn.init.xavier_normal_(self.linear.weight)
    
    def forward(self, seq, h_0 ,c_0 ):
        N , T  = seq.size()
        word_embed = self.embed(seq);
        out , (h_0 , c_0) = self.lstm(word_embed, (h_0 , c_0))
        out = out.reshape(N*T,-1);
        out= self.linear(out)
        
        return out , (h_0 , c_0)

In [None]:
torch.nn.utils.clip_grad_value_(model.parameters(), clip_value = max_grad)

In [None]:
model = lstm_torch(vocab_size,wordvec_size,hidden_size)
h_0 = torch.zeros((1,batch_size,hidden_size))
c_0 = torch.zeros((1,batch_size,hidden_size))
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

In [None]:
time_idx=0
ppl_list = []
total_loss = 0
loss_count = 0

In [None]:
for iter in range(max_iters):
    batch_x = np.empty((batch_size, time_size), dtype='i')
    batch_t = np.empty((batch_size, time_size), dtype='i')
    for t in range(time_size):
        for i , offset in enumerate(offsets):
            batch_x[i,t] = xs[( time_idx + offset) % data_size]
            batch_t[i,t] = ts[( time_idx + offset) % data_size]
        time_idx +=1
        
    
    batch_x_tensor = torch.tensor(batch_x)
    batch_t_tensor = torch.tensor(batch_t)
    batch_t_tensor = to_one_hot(batch_t_tensor, vocab_size)
    out , (h_0 , c_0) = model(batch_x_tensor, h_0, c_0)
    batch_t_tensor = batch_t_tensor.reshape(batch_size*time_size,-1)
    
    loss = criterion(out, batch_t_tensor) #要保证第二个维度是 vocab_size
    h_0 = h_0.detach()
    c_0 = c_0.detach()
    
    model.zero_grad()
    loss.backward(retain_graph=True)
    torch.nn.utils.clip_grad_value_(model.parameters(), clip_value = max_grad)
    
    optimizer.step()
    total_loss += loss
    loss_count += 1
    
    if iter % 20 ==0:
        ppl = torch.exp(total_loss / loss_count)
        print('epoch {} | iter: {}/1327 | pp: {} | loss:{}'.format(1,iter+1,ppl,loss.item()))
        ppl_list.append(float(ppl))
        total_loss, loss_count = 0, 0



In [None]:
for epoch in range(4):
    for iter in range(max_iters):
        batch_x = np.empty((batch_size, time_size), dtype='i')
        batch_t = np.empty((batch_size, time_size), dtype='i')
        for t in range(time_size):
            for i , offset in enumerate(offsets):
                batch_x[i,t] = xs[( time_idx + offset) % data_size]
                batch_t[i,t] = ts[( time_idx + offset) % data_size]
            time_idx +=1
        
    
        batch_x_tensor = torch.tensor(batch_x)
        batch_t_tensor = torch.tensor(batch_t)
        batch_t_tensor = to_one_hot(batch_t_tensor, vocab_size)
        out , (h_0 , c_0) = model(batch_x_tensor, h_0, c_0)
        batch_t_tensor = batch_t_tensor.reshape(batch_size*time_size,-1)
    
        loss = criterion(out, batch_t_tensor) #要保证第二个维度是 vocab_size
        
    
        optimizer.zero_grad()
        loss.backward(retain_graph=True)
        #torch.nn.utils.clip_grad_value_(model.parameters(), clip_value = max_grad)
    
        optimizer.step()
        h_0 = h_0.detach()
        c_0 = c_0.detach()  #BPTT截断反向传播
        total_loss += loss
        loss_count += 1
    
        if iter % 20 ==0:
            ppl = torch.exp(total_loss / loss_count)
            print('epoch {} | iter: {}/1327 | pp: {} | loss:{}'.format(epoch+1,iter+1,ppl,loss.item()))
            ppl_list.append(float(ppl))
            total_loss, loss_count = 0, 0

In [None]:
import numpy as np

In [None]:
a = np.array([1,2,3,4,2])
b = np.array([3,5,7,4])
np.intersect1d(a,b)

In [None]:
np.setdiff1d(a,b)

# Better LSTM

In [9]:
class better_lstm_torch(nn.Module):
    
    def __init__(self,vocab_size, wordvec_size, hidden_size):
        super().__init__()
        self.embed = nn.Embedding(vocab_size,wordvec_size)
        nn.init.xavier_normal_(self.embed.weight)
        
        self.lstm1 = nn.LSTM(wordvec_size, hidden_size, batch_first = True)
        nn.init.xavier_normal_(self.lstm1.weight_hh_l0)
        nn.init.xavier_normal_(self.lstm1.weight_ih_l0)
        
        self.dropout = nn.Dropout(0.5)
        
        self.lstm2 = nn.LSTM(hidden_size, hidden_size, batch_first = True)
        nn.init.xavier_normal_(self.lstm2.weight_hh_l0)
        nn.init.xavier_normal_(self.lstm2.weight_ih_l0)
        
        
        self.linear = nn.Linear(hidden_size, vocab_size)

        for para1 , para2 in zip(self.embed.parameters(), self.linear.parameters()):
            para2.data = para1.data
    
    
    def forward(self, seq, h_00 ,c_00 , h_01, c_01 ):
        N , T  = seq.size()
        out = self.embed(seq);
        out , (h_00 , c_00) = self.lstm1(out, (h_00 , c_00))
        out = self.dropout(out)
        
        out , (h_01 , c_01) = self.lstm2(out, (h_01 , c_01))
        out = out.reshape(N*T,-1);
        out= self.linear(out)
        
        return out , (h_00 , c_00) , (h_01 , c_01)

In [23]:
model = better_lstm_torch(vocab_size, wordvec_size, hidden_size)
h_00,c_00  = (torch.zeros((1,batch_size,hidden_size)),torch.zeros((1,batch_size,hidden_size)))
h_01,c_01  = (torch.zeros((1,batch_size,hidden_size)),torch.zeros((1,batch_size,hidden_size)))
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

time_idx=0
ppl_list = []
total_loss = 0
loss_count = 0

In [24]:
for epoch in range(4):
    for iter in range(max_iters):
        batch_x = np.empty((batch_size, time_size), dtype='i')
        batch_t = np.empty((batch_size, time_size), dtype='i')
        for t in range(time_size):
            for i , offset in enumerate(offsets):
                batch_x[i,t] = xs[( time_idx + offset) % data_size]
                batch_t[i,t] = ts[( time_idx + offset) % data_size]
            time_idx +=1
        
    
        batch_x_tensor = torch.tensor(batch_x)
        batch_t_tensor = torch.tensor(batch_t)
        batch_t_tensor = to_one_hot(batch_t_tensor, vocab_size)
        
        out , (h_00 , c_00), (h_01 , c_01) = model(batch_x_tensor, h_00, c_00, h_01, c_01)
        
        
        batch_t_tensor = batch_t_tensor.reshape(batch_size*time_size,-1)
        
        loss = criterion(out, batch_t_tensor) #要保证第二个维度是 vocab_size
        
        
        optimizer.zero_grad()
        loss.backward(retain_graph=True)
        #torch.nn.utils.clip_grad_value_(model.parameters(), clip_value = max_grad)
    
        optimizer.step()
        h_00 = h_00.detach()
        c_00 = c_00.detach()  #BPTT截断反向传播
        h_01 = h_01.detach()
        c_01 = c_01.detach()
        
        total_loss += loss
        loss_count += 1
    
        if iter % 20 ==0:
            ppl = torch.exp(total_loss / loss_count)
            print('epoch {} | iter: {}/1327 | pp: {} | loss:{}'.format(epoch+1,iter+1,ppl,loss.item()))
            ppl_list.append(float(ppl))
            total_loss, loss_count = 0, 0

epoch 1 | iter: 1/1327 | pp: 10102.641803065873 | loss:9.220552233287266
epoch 1 | iter: 21/1327 | pp: 1740.611121440743 | loss:7.38252699136734
epoch 1 | iter: 41/1327 | pp: 1253.2697669171362 | loss:7.1632459279469085
epoch 1 | iter: 61/1327 | pp: 1023.3458556406706 | loss:6.902310350622449
epoch 1 | iter: 81/1327 | pp: 874.3320298152369 | loss:6.6685345772334506
epoch 1 | iter: 101/1327 | pp: 823.1902769891416 | loss:6.720573621817998
epoch 1 | iter: 121/1327 | pp: 879.4086244810204 | loss:6.848347262995584
epoch 1 | iter: 141/1327 | pp: 908.572910750185 | loss:6.744910958835057
epoch 1 | iter: 161/1327 | pp: 896.6158432647885 | loss:6.595385100500924
epoch 1 | iter: 181/1327 | pp: 950.2702753312873 | loss:6.867752318041665
epoch 1 | iter: 201/1327 | pp: 823.7810063852896 | loss:6.647412667955671
epoch 1 | iter: 221/1327 | pp: 830.8190241254431 | loss:6.748873678616115
epoch 1 | iter: 241/1327 | pp: 779.428629366344 | loss:6.619546938964299
epoch 1 | iter: 261/1327 | pp: 799.7252962

epoch 2 | iter: 881/1327 | pp: 220.8607635109063 | loss:5.221474100608279
epoch 2 | iter: 901/1327 | pp: 260.35675118662823 | loss:5.567309299672488
epoch 2 | iter: 921/1327 | pp: 240.55274572223217 | loss:5.458218446980214
epoch 2 | iter: 941/1327 | pp: 248.53431445890752 | loss:5.323691996658786
epoch 2 | iter: 961/1327 | pp: 268.7339025487138 | loss:5.483283473978684
epoch 2 | iter: 981/1327 | pp: 250.82279571575233 | loss:5.683613821762597
epoch 2 | iter: 1001/1327 | pp: 215.15641885801716 | loss:5.446313451620434
epoch 2 | iter: 1021/1327 | pp: 247.19550566834096 | loss:5.35348371082102
epoch 2 | iter: 1041/1327 | pp: 237.0117046418032 | loss:5.252388260577593
epoch 2 | iter: 1061/1327 | pp: 223.32390494156274 | loss:5.32181112405158
epoch 2 | iter: 1081/1327 | pp: 196.05302318622827 | loss:5.30019205498322
epoch 2 | iter: 1101/1327 | pp: 221.32310330869927 | loss:5.347302588396727
epoch 2 | iter: 1121/1327 | pp: 261.62889936366344 | loss:5.517262595827717
epoch 2 | iter: 1141/132

epoch 4 | iter: 401/1327 | pp: 193.21032044871646 | loss:5.344774730020269
epoch 4 | iter: 421/1327 | pp: 178.06596081771121 | loss:5.14663705683043
epoch 4 | iter: 441/1327 | pp: 182.91142652563818 | loss:5.147847845662909
epoch 4 | iter: 461/1327 | pp: 191.0874733226447 | loss:5.334453158053969
epoch 4 | iter: 481/1327 | pp: 195.23159416955932 | loss:5.058846635646748
epoch 4 | iter: 501/1327 | pp: 194.51342990058356 | loss:5.091798987026892
epoch 4 | iter: 521/1327 | pp: 198.87414301254793 | loss:4.810411796766566
epoch 4 | iter: 541/1327 | pp: 201.7901616440902 | loss:5.378550995116737
epoch 4 | iter: 561/1327 | pp: 196.4630051848653 | loss:5.110838300638113
epoch 4 | iter: 581/1327 | pp: 173.92679010686572 | loss:5.195570663058814
epoch 4 | iter: 601/1327 | pp: 230.43075576350103 | loss:5.492239521505378
epoch 4 | iter: 621/1327 | pp: 220.21185210508241 | loss:5.408539423644377
epoch 4 | iter: 641/1327 | pp: 202.52194361779613 | loss:5.268616684990702
epoch 4 | iter: 661/1327 | pp