In [1]:
from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import BertTokenizer
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import random
import re
from nltk import word_tokenize,sent_tokenize
import gc
import torch
from torch import tensor
import torch.nn as nn
import torch.nn.functional as F

dataset = load_dataset("wiki_qa")

print(dataset['train'][0])

{'question_id': 'Q1', 'question': 'how are glacier caves formed?', 'document_title': 'Glacier cave', 'answer': 'A partly submerged glacier cave on Perito Moreno Glacier .', 'label': 0}


In [2]:
query_list = {}
for train_data in dataset['train']:
    if train_data['question_id'] not in query_list:
        query_list[train_data['question_id']] = train_data['question']
for train_data in dataset['validation']:
    if train_data['question_id'] not in query_list:
        query_list[train_data['question_id']] = train_data['question']
    
print(query_list['Q1'])
print(len(query_list.keys()))

how are glacier caves formed?
2118


In [65]:
for q_index, question in query_list.items():
    print(question)
    break
print(len(query_list.items()))

how are glacier caves formed?
2414


In [9]:
def Clean_data(data):
    """Removes all the unnecessary patterns and cleans the data to get a good sentence"""
    repl='' #String for replacement
    
    #removing all open brackets
    data=re.sub('\(', repl, data)
    
    #removing all closed brackets
    data=re.sub('\)', repl, data)
    
    #Removing all the headings in data
    for pattern in set(re.findall("=.*=",data)):
        data=re.sub(pattern, repl, data)
    
    #Removing unknown words in data
    for pattern in set(re.findall("<unk>",data)):
        data=re.sub(pattern,repl,data)
    
    #Removing all the non-alphanumerical characters
    for pattern in set(re.findall(r"[^\w ]", data)):
        repl=''
        if pattern=='-':
            repl=' '
        #Retaining period, apostrophe
        if pattern!='.' and pattern!="\'":
            data=re.sub("\\"+pattern, repl, data)
            
    return data
    

In [66]:
cleaned_data = []
for q_index, question in query_list.items():
    data = Clean_data(question)
    cleaned_data.append(data)
print(cleaned_data[0])
print(len(cleaned_data))

how are glacier caves formed
2414


In [13]:
def create_xy_pairs(questions):
    xy_pairs = []
    for question in questions:
        tokens = question.split()  # 假设使用空格进行简单分词
        for i in range(1, len(tokens)):
            x = " ".join(tokens[:i])
            y = tokens[i]
            xy_pairs.append((x, y))
    return xy_pairs

In [15]:
xy_pairs = create_xy_pairs(cleaned_data)

# 查看构建的一些（X, Y）对
print(xy_pairs[:5])

[('how', 'are'), ('how are', 'glacier'), ('how are glacier', 'caves'), ('how are glacier caves', 'formed'), ('How', 'are')]


In [39]:
def create_vocab(sentences):
    #Word tokenization
    words=set()
    for sent in sentences:
        for word in str.split(sent,' '):
            words.add(word)
    words=list(words)
    
    #Adding empty string in list of words to avoid confusion while padding.
    #Padded zeroes can be interpreted as empty strings.
    words.insert(0,"")
    return words
vocabs = create_vocab(cleaned_data)

In [42]:
def Convert_data(sentences, words, seq_len):
    """Converts text data into numerical form"""
    
    sent_sequences=[]
    for i in range(len(sentences)):
        words_in_sent=str.split(sentences[i],' ')
        for j in range(1,len(words_in_sent)):
            if j<=(seq_len):
                sent_sequences.append(words_in_sent[:j])
            elif j>seq_len and j<len(words_in_sent):
                sent_sequences.append(words_in_sent[j-seq_len:j])
            elif j>len(words_in_sent)-seq_len:
                sent_sequences.append(words_in_sent[j-seq_len:])
                
    #The above code converts the text data into the following sequences
    #[['The', '2013'],
    #['The', '2013', '14'],
    #['The', '2013', '14', 'season'],
    #['The', '2013', '14', 'season', 'was']]
    
    #Splitting into predictors and class_labels
    predictors=[];class_labels=[]
    for i in range(len(sent_sequences)):
        predictors.append(sent_sequences[i][:-1])
        class_labels.append(sent_sequences[i][-1])
    
    #Padding the predictors manually with Empty strings
    pad_predictors=[]
    for i in range(len(predictors)):
        emptypad=['']*(seq_len-len(predictors[i])-1)
        emptypad.extend(predictors[i])
        pad_predictors.append(emptypad)
        
    #The following two chunks of code are useful to convert text into numeric form
    #Dictionary with words as keys and indices as values
    global word_ind
    word_ind=dict()
    for ind,word in enumerate(words):
        word_ind[word]=ind
    
    #Dictionary with indices as keys and words as values
    global ind_word
    ind_word=dict()
    for ind,word in enumerate(words):
        ind_word[ind]=word
        
    #Convert each word into their respective index
    for i in range(len(pad_predictors)):
        for j in range(len(pad_predictors[i])):
            pad_predictors[i][j]=word_ind[pad_predictors[i][j]]
        class_labels[i]=word_ind[class_labels[i]]
        
    #Convert sequences to tensors
    for i in range(len(pad_predictors)):
        pad_predictors[i]=torch.tensor(pad_predictors[i])
    pad_predictors=torch.stack(pad_predictors)
    class_labels=torch.tensor(class_labels)
     
    return pad_predictors, class_labels

In [43]:
class LSTM(nn.Module):
    """Base class for all neural network modules.
       All models should subclass this class"""
    def __init__(self,num_embeddings, embedding_dim, padding_idx, hidden_size, Dropout_p, batch_size):
        super(LSTM,self).__init__()
        self.num_embeddings=num_embeddings
        self.embedding_dim=embedding_dim
        self.padding_idx=padding_idx
        self.hidden_size=hidden_size
        self.dropout=Dropout_p
        self.batch_size=batch_size
        
        #Adding Embedding Layer
        self.Embedding=nn.Embedding(num_embeddings, embedding_dim, padding_idx=0)
        
        #Adding LSTM Layer
        self.lstm=nn.LSTM(embedding_dim, hidden_size, num_layers=1, batch_first=True)
        
        #Adding Dropout Layer
        self.dropout=nn.Dropout(Dropout_p)
        
        #Adding fully connected dense Layer
        self.FC=nn.Linear(hidden_size, num_embeddings)
        
    def init_hidden(self, batch_size):
        """Initializes hiddens state tensors to zeros"""
        
        state_h=torch.zeros(1, batch_size, self.hidden_size)
        state_c=torch.zeros(1, batch_size, self.hidden_size)
        
        return (state_h,state_c)
        
    def forward(self,input_sequence, state_h,state_c):
        
        #Applying embedding layer to input sequence
        Embed_input=self.Embedding(input_sequence)
        
        #Applying LSTM layer
        output,(state_h,state_c)=self.lstm(Embed_input, (state_h,state_c)) 
        
        #Applying fully connected layer
        logits=self.FC(output[:,-1,:])
         
        return logits,(state_h,state_c)
    
    def topk_sampling(self, logits, topk):
        """Applies softmax layer and samples an index using topk"""
        
        #Applying softmax layer to logits
        logits_softmax=F.softmax(logits,dim=1)
        values,indices=torch.topk(logits_softmax[0],k=topk)
        choices=indices.tolist()
        sampling=random.sample(choices,1)
        
        return ind_word[sampling[0]]

In [44]:
def get_batch(pad_predictors, class_labels, batch_size):
    for i in range(0, len(pad_predictors), batch_size):
        if i+batch_size<len(pad_predictors):
            yield pad_predictors[i:i+batch_size], class_labels[i:i+batch_size]

In [61]:
def train_model(pad_predictors, class_labels, n_vocab, embedding_dim, padding_idx, hidden_size, Dropout_p, batch_size, lr):
    """Trains an LSTM Model"""
    #Creates instance of LSTM class
    model=LSTM(n_vocab, embedding_dim, padding_idx, hidden_size, Dropout_p, batch_size)
    
    #Creates instance of CrossEntropLoss class
    criterion=nn.CrossEntropyLoss(ignore_index=0)
    
    #Creates instance of Adam optimizer class
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    
    num_epochs=100
    for epoch in range(num_epochs):
        state_h, state_c=model.init_hidden(batch_size)
        
        total_loss=0
        for x, y in get_batch(pad_predictors, class_labels, batch_size):
            #sets model in training mode
            model.train()
            
            state_h=state_h.detach()
            state_c=state_c.detach()
            
            logits,(state_h,state_c)=model(x, state_h, state_c)
           
            #compute loss
            loss = criterion(logits, y)
            loss_value = loss.item()
            total_loss+=len(x)*loss_value

            #Sets the gradients of all the optimized tensors to zero
            model.zero_grad()

            #computes dloss/dx and assigns gradient for every parameter
            loss.backward()

            #Clips the gradient norm to avoid exploding gradient problems
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.5)

            #Performs a single optimization step (parameter update).
            optimizer.step()
            
        total_loss/=len(pad_predictors)
            
        print("Epoch [{}/{}] Loss: {}, perplexity: {}".
              format(epoch+1, num_epochs, total_loss, np.exp(total_loss)))
        
        gen_text=generate(model, init='How are', sent_len=10, topk=3)
        print("Text generated after epoch", epoch,":")
        print("\n",end='')
        print(gen_text)
        print('\n',end='')
        
    return model, total_loss

In [62]:
def generate(model, init, sent_len, topk):
    """Generates sentences from the model"""

    sentence=init
    for k in range(sent_len):
        #sets model in evaluation mode
        model.eval()
        
        #sets the length of sentence to seq_len
        input_indices=[]
        for word in str.split(sentence," "):
            input_indices.append(word_ind[word])
        if len(input_indices)<seq_len-1:
            input_tensor=[0]*(seq_len-len(input_indices)-1)
            input_tensor.extend(input_indices)
        else:
            input_tensor=input_indices[-seq_len+1:]
            
        #Initiates hidden state and cell state tensors to zeros
        state_h, state_c=model.init_hidden(len(input_tensor))
        
        input_tensor=torch.stack([torch.tensor(input_tensor)])
        out,(state_h,state_c)=model(input_tensor.transpose(0,1),state_h, state_c)
        
        #Samples a word from topk words
        word=model.topk_sampling(out, topk)
        
        if word!='' and word!=str.split(sentence,' ')[-1]:
            sentence=sentence+" "+word

    return sentence

In [63]:
def main():
    # train=load_data("../input/wikitext2-data/train.txt")
    # data=train[:]
    # data=Clean_data(data)
    # sentences, words=split_data(data, num_sentences=25000)
        
    pad_predictors, class_labels=Convert_data(cleaned_data, vocabs, seq_len)
    
    print("Number of input sequences :",len(pad_predictors))
    
    model, loss=train_model(pad_predictors, class_labels, n_vocab=len(vocabs), embedding_dim=100,
                padding_idx=0, hidden_size=128, Dropout_p=0.1, batch_size=64, lr=0.001)
    
    generated_sentence=generate(model, init='The', sent_len=10, topk=5)
    
    #save the model
    torch.save(model,"./Wiki_Model.pt")
    
    return loss

In [64]:
if __name__ == "__main__":
    seq_len=5
    loss=main()
    
    print("Loss on train data: ", loss)
    print("Perplexity on train data: ", np.exp(loss))

Number of input sequences : 14341
Epoch [1/100] Loss: 6.114992013378926, perplexity: 452.5924326724992
Text generated after epoch 0 :

How are of a are of the cellular in the name was

Epoch [2/100] Loss: 5.133280217369085, perplexity: 169.57244122360714
Text generated after epoch 1 :

How are in a are the an name the song

Epoch [3/100] Loss: 4.814148920648526, perplexity: 123.2418790447528
Text generated after epoch 2 :

How are in the many in a name people the slugs of

Epoch [4/100] Loss: 4.53378324945571, perplexity: 93.11015452131956
Text generated after epoch 3 :

How are a the much on slugs first is a is

Epoch [5/100] Loss: 4.261363392901648, perplexity: 70.90659112522552
Text generated after epoch 4 :

How are the in much there song the What what liberty song

Epoch [6/100] Loss: 3.9963730440405105, perplexity: 54.40048362786634
Text generated after epoch 5 :

How are a the much the magnetic name does first

Epoch [7/100] Loss: 3.7313323148947264, perplexity: 41.7346748685896