In [1]:
import torch
from torch import nn
import numpy as np
import sys
sys.path.append('..')
from dataset import ptb

def to_one_hot(t_train, vocab_size):
    """Return a one-hot encoding of integer token ids as a float64 tensor.

    Args:
        t_train: integer ids of any shape (NumPy array, list or tensor); every
            value must lie in ``[0, vocab_size)``.
        vocab_size: number of classes, i.e. the length of the new last axis.

    Returns:
        ``torch.float64`` tensor of shape ``t_train.shape + (vocab_size,)``.
    """
    # Indexing into ``np.eye(vocab_size)`` materialised a full
    # vocab_size x vocab_size float64 matrix on every call (~800 MB for a
    # 10,000-word vocabulary). ``one_hot`` allocates only the output itself.
    ids = torch.as_tensor(t_train, dtype=torch.long)
    # Cast to float64 to keep the dtype the np.eye-based version produced.
    return torch.nn.functional.one_hot(ids, num_classes=vocab_size).to(torch.float64)

In [2]:
# Hyperparameters
batch_size = 20
wordvec_size = 100
hidden_size = 100  # number of elements in the RNN hidden-state vector
time_size = 35  # number of time steps the RNN is unrolled over (truncated-BPTT window)
lr = 20.0  # NOTE(review): not referenced in the cells below; the Adam optimizer hard-codes lr=0.01
max_epoch = 1  # NOTE(review): not referenced in the training cells shown here
max_grad = 0.25  # element-wise threshold passed to clip_grad_value_ in the training loop

# Fix PyTorch's RNG so weight initialisation and dropout masks are reproducible
# under Restart & Run All.
seed = 0
torch.manual_seed(seed)

In [None]:
# Load the Penn Treebank splits. The test corpus is loaded but not referenced
# in the cells shown here.
corpus, word_to_id, id_to_word = ptb.load_data('train')
corpus_test, _, _ = ptb.load_data('test')
vocab_size = len(word_to_id)

# Language-model pairs: the input token at position i predicts the token at i + 1.
xs, ts = corpus[:-1], corpus[1:]
data_size = len(xs)

In [None]:
# Truncated-BPTT batching layout: the corpus is read as `batch_size` parallel
# streams, each starting `jump` tokens after the previous one; every training
# iteration consumes `time_size` consecutive steps from every stream.
jump = (data_size - 1) // batch_size
offsets = [stream * jump for stream in range(batch_size)]
max_iters = data_size // (batch_size * time_size)
time_idx = 0  # global read position shared by all streams

# Baseline LSTM language model (single layer)

In [None]:
class lstm_torch(nn.Module):
    """Single-layer LSTM language model: embedding -> LSTM -> vocabulary projection.

    Args:
        vocab_size: number of tokens in the vocabulary.
        wordvec_size: embedding dimension.
        hidden_size: LSTM hidden-state size.
    """

    def __init__(self, vocab_size, wordvec_size, hidden_size):
        super().__init__()
        # Submodules are built and re-initialised in a fixed order; that order
        # determines how the global RNG stream is consumed for the initial weights.
        self.embed = nn.Embedding(vocab_size, wordvec_size)
        nn.init.xavier_normal_(self.embed.weight)

        self.lstm = nn.LSTM(wordvec_size, hidden_size, batch_first=True)
        for recurrent_weight in (self.lstm.weight_hh_l0, self.lstm.weight_ih_l0):
            nn.init.xavier_normal_(recurrent_weight)

        self.linear = nn.Linear(hidden_size, vocab_size)
        nn.init.xavier_normal_(self.linear.weight)

    def forward(self, seq, h_0, c_0):
        """Run one window of token ids through the model.

        Args:
            seq: integer ids of shape (N, T) (batch-first).
            h_0, c_0: initial LSTM hidden / cell state passed to ``nn.LSTM``.

        Returns:
            ``(logits, (h_n, c_n))`` where ``logits`` has shape (N*T, vocab_size)
            and ``(h_n, c_n)`` is the final LSTM state.
        """
        batch, steps = seq.size()
        hidden_seq, (h_n, c_n) = self.lstm(self.embed(seq), (h_0, c_0))
        # Flatten (N, T, hidden) to (N*T, hidden) so every time step is a row.
        logits = self.linear(hidden_seq.reshape(batch * steps, -1))
        return logits, (h_n, c_n)

In [None]:
def clip_gradients(model, clip_value):
    """Clip every gradient element of ``model`` into ``[-clip_value, clip_value]`` in place.

    Thin wrapper around ``torch.nn.utils.clip_grad_value_``; parameters whose
    ``.grad`` is None are skipped. Call it after ``loss.backward()`` and before
    ``optimizer.step()``.
    """
    # This cell used to be a bare clip call on the global `model`, which only
    # exists after the next cell runs (NameError on a fresh kernel) and, even
    # then, ran before any backward pass, so there was nothing to clip.
    torch.nn.utils.clip_grad_value_(model.parameters(), clip_value=clip_value)

In [None]:
# Instantiate the baseline model, its zero initial LSTM state and the training objects.
model = lstm_torch(vocab_size, wordvec_size, hidden_size)
# nn.LSTM state layout is (num_layers * num_directions, batch, hidden) = (1, batch_size, hidden_size).
h_0, c_0 = torch.zeros(1, batch_size, hidden_size), torch.zeros(1, batch_size, hidden_size)
criterion = nn.CrossEntropyLoss()
# NOTE(review): Adam's learning rate is hard-coded; the `lr` hyperparameter is not used here.
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.01)

In [None]:
# Reset the global read position and the perplexity-logging accumulators.
time_idx = 0
total_loss, loss_count = 0, 0
ppl_list = []

In [None]:
# One epoch of truncated-BPTT training of the baseline LSTM.
# Row i of each minibatch reads corpus positions time_idx + offsets[i] + t
# (t = 0 .. time_size-1, wrapped modulo data_size); time_idx advances by
# time_size per iteration so the carried-over (h_0, c_0) state lines up with
# the next window.
token_ids = np.asarray(xs)
target_ids = np.asarray(ts)
row_starts = np.asarray(offsets)[:, None]   # (batch_size, 1)
step_range = np.arange(time_size)[None, :]  # (1, time_size)
# Normalise in case an earlier run left a tensor in the accumulator.
total_loss = float(total_loss)

for it in range(max_iters):
    positions = (time_idx + row_starts + step_range) % data_size  # (batch_size, time_size)
    time_idx += time_size

    batch_x_tensor = torch.as_tensor(token_ids[positions], dtype=torch.long)
    # CrossEntropyLoss accepts class-index targets of shape (N*T,) directly, so
    # there is no need to build a float64 (batch_size*time_size, vocab_size)
    # one-hot tensor every iteration.
    batch_t_tensor = torch.as_tensor(target_ids[positions], dtype=torch.long).reshape(-1)

    out, (h_0, c_0) = model(batch_x_tensor, h_0, c_0)
    loss = criterion(out, batch_t_tensor)

    # Truncated BPTT: the next window starts from this state but must not
    # backpropagate into the current graph.
    h_0 = h_0.detach()
    c_0 = c_0.detach()

    optimizer.zero_grad()
    # The graph is not reused after this call, so retain_graph is unnecessary.
    loss.backward()
    torch.nn.utils.clip_grad_value_(model.parameters(), clip_value=max_grad)
    optimizer.step()

    # Accumulate a Python float: summing `loss` tensors kept every iteration's
    # autograd graph alive until the next reset.
    total_loss += loss.item()
    loss_count += 1

    if it % 20 == 0:
        ppl = np.exp(total_loss / loss_count)
        print('epoch {} | iter: {}/{} | pp: {} | loss:{}'.format(1, it + 1, max_iters, ppl, loss.item()))
        ppl_list.append(float(ppl))
        total_loss, loss_count = 0.0, 0



In [None]:
# Four more epochs of the same truncated-BPTT loop, continuing from the model,
# hidden state and read position left by the previous cell.
first_epoch = 2   # epoch 1 was run by the previous cell
n_more_epochs = 4
token_ids = np.asarray(xs)
target_ids = np.asarray(ts)
row_starts = np.asarray(offsets)[:, None]   # (batch_size, 1)
step_range = np.arange(time_size)[None, :]  # (1, time_size)
# The previous cell may leave a tensor here; normalise so np.exp below works.
total_loss = float(total_loss)

for epoch in range(first_epoch, first_epoch + n_more_epochs):
    for it in range(max_iters):
        # Row i reads positions time_idx + offsets[i] + t, wrapped modulo data_size.
        positions = (time_idx + row_starts + step_range) % data_size
        time_idx += time_size

        batch_x_tensor = torch.as_tensor(token_ids[positions], dtype=torch.long)
        # Class-index targets of shape (N*T,), as CrossEntropyLoss expects.
        batch_t_tensor = torch.as_tensor(target_ids[positions], dtype=torch.long).reshape(-1)

        out, (h_0, c_0) = model(batch_x_tensor, h_0, c_0)
        loss = criterion(out, batch_t_tensor)

        optimizer.zero_grad()
        loss.backward()
        # Clipping was commented out here although max_grad is configured and
        # the first epoch clips; restored for consistency.
        torch.nn.utils.clip_grad_value_(model.parameters(), clip_value=max_grad)
        optimizer.step()

        # Truncated BPTT: cut the graph between windows.
        h_0 = h_0.detach()
        c_0 = c_0.detach()

        # Float accumulation so old autograd graphs can be freed.
        total_loss += loss.item()
        loss_count += 1

        if it % 20 == 0:
            ppl = np.exp(total_loss / loss_count)
            print('epoch {} | iter: {}/{} | pp: {} | loss:{}'.format(epoch, it + 1, max_iters, ppl, loss.item()))
            ppl_list.append(float(ppl))
            total_loss, loss_count = 0.0, 0

In [None]:
import numpy as np

In [None]:
# Scratch: NumPy set routines on small integer arrays (unrelated to the LSTM code).
a, b = np.array([1, 2, 3, 4, 2]), np.array([3, 5, 7, 4])
np.intersect1d(a, b)  # sorted unique values present in both arrays

In [None]:
np.setdiff1d(ar1=a, ar2=b)  # sorted unique values of `a` that are not in `b`

# Better LSTM (two stacked layers, dropout, tied embedding/output weights)

In [3]:
class better_lstm_torch(nn.Module):
    """Two-layer LSTM language model with dropout and tied input/output embeddings.

    Architecture: embedding -> LSTM -> dropout(0.5) -> LSTM -> linear projection
    to the vocabulary. By default the projection shares its weight matrix with
    the embedding (weight tying), which requires ``wordvec_size == hidden_size``.

    Args:
        vocab_size: number of tokens in the vocabulary.
        wordvec_size: embedding dimension.
        hidden_size: hidden-state size of both LSTM layers.
        tie_weights: share ``linear.weight`` with ``embed.weight`` (default True).

    Raises:
        ValueError: if ``tie_weights`` is True and ``wordvec_size != hidden_size``.
    """

    def __init__(self, vocab_size, wordvec_size, hidden_size, tie_weights=True):
        super().__init__()
        if tie_weights and wordvec_size != hidden_size:
            raise ValueError(
                'weight tying needs wordvec_size == hidden_size, got {} and {}'.format(
                    wordvec_size, hidden_size))

        self.embed = nn.Embedding(vocab_size, wordvec_size)
        nn.init.xavier_normal_(self.embed.weight)

        self.lstm1 = nn.LSTM(wordvec_size, hidden_size, batch_first=True)
        nn.init.xavier_normal_(self.lstm1.weight_hh_l0)
        nn.init.xavier_normal_(self.lstm1.weight_ih_l0)

        self.dropout = nn.Dropout(0.5)

        self.lstm2 = nn.LSTM(hidden_size, hidden_size, batch_first=True)
        nn.init.xavier_normal_(self.lstm2.weight_hh_l0)
        nn.init.xavier_normal_(self.lstm2.weight_ih_l0)

        self.linear = nn.Linear(hidden_size, vocab_size)
        if tie_weights:
            # Share the Parameter object itself. The previous
            # `para2.data = para1.data` only aliased storage: the two Parameters
            # kept separate gradients and optimizer state, so every step applied
            # two independent updates to the same memory, and a shape mismatch
            # went unnoticed until forward().
            self.linear.weight = self.embed.weight

    def forward(self, seq, h_00, c_00, h_01, c_01):
        """Run one window of token ids through both LSTM layers.

        Args:
            seq: integer ids of shape (N, T) (batch-first).
            h_00, c_00: initial hidden / cell state of the first LSTM layer.
            h_01, c_01: initial hidden / cell state of the second LSTM layer.

        Returns:
            ``(logits, (h_00, c_00), (h_01, c_01))`` where ``logits`` has shape
            (N*T, vocab_size) and the tuples are each layer's final state.
        """
        N, T = seq.size()
        out = self.embed(seq)
        out, (h_00, c_00) = self.lstm1(out, (h_00, c_00))
        out = self.dropout(out)

        out, (h_01, c_01) = self.lstm2(out, (h_01, c_01))
        # Flatten (N, T, hidden) to (N*T, hidden) so every time step is a row.
        out = out.reshape(N * T, -1)
        out = self.linear(out)

        return out, (h_00, c_00), (h_01, c_01)

In [12]:
# Toy-sized instance for a quick smoke test; note that this rebinds the global `model`.
model = better_lstm_torch(vocab_size=14, wordvec_size=10, hidden_size=10)

In [None]:
# Random token ids for smoke-testing the toy better_lstm_torch(14, 10, 10)
# instance. `torch.randint()` with no arguments always raised TypeError (it
# needs `high` and `size`); a separate name is used so the training corpus
# `xs` is not overwritten.
toy_vocab_size = 14
toy_batch, toy_time = 2, 5
toy_xs = torch.randint(0, toy_vocab_size, (toy_batch, toy_time))
toy_xs