In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import random 
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
import torch.optim as optim
import jieba
from torchtext.data import TabularDataset,Field,BucketIterator

In [None]:
def tokenize(text):
    """Segment ``text`` with jieba and return the pieces as a list."""
    return [token for token in jieba.cut(text)]

# Folder that holds the tab-separated corpus files.
DATA_DIR = 'C:\\Users\\Alfred\\Desktop\\rss\\train_kit\\data\\'

# One field (and therefore one vocabulary) is shared by source and target.
# Explicit start/end tokens are required: Seq2Seq.forward feeds trg[:, 0] to
# the decoder as the start symbol and the training loss skips that position,
# so without an init token the first real word of every target would never be
# predicted, and without an eos token the decoder has no stop signal to learn.
TEXT = Field(
    sequential=True,
    tokenize=tokenize,
    lower=True,
    init_token='<sos>',
    eos_token='<eos>',
    pad_token='<pad>',
    batch_first=True,
)
# Column order in the TSV files: target first, source second.
fields = [('trg', TEXT), ('src', TEXT)]

train_data, valid_data = TabularDataset.splits(
    path=DATA_DIR,
    train='train.tsv',
    # NOTE(review): 'valiation.tsv' looks like a misspelling of
    # 'validation.tsv'; kept as-is because it must match the file on disk.
    validation='valiation.tsv',
    format='tsv',
    fields=fields,
)
# The vocabulary is built from the training split only, with min_freq=2.
TEXT.build_vocab(train_data, min_freq=2)
print(f'train examples: {len(train_data)}, valid examples: {len(valid_data)}')

train_iter, valid_iter = BucketIterator.splits(
    (train_data, valid_data),
    batch_size=8,
    shuffle=True,
    sort_key=lambda x: len(x.src),
    # NOTE(review): an integer device id is the legacy torchtext way of
    # selecting the CPU; newer releases presumably expect a torch.device or a
    # string instead - confirm against the installed version.
    device=-1,
)

In [3]:
# Lookup tables taken from the vocabulary built on TEXT: `itos` and `stoi`.
id2vocab, vocab2id = TEXT.vocab.itos, TEXT.vocab.stoi

# Indices of the special tokens configured on the field.
# NOTE(review): SOS_IDX / EOS_IDX are only meaningful when the field defines
# an init and an eos token - verify the Field configuration.
PAD_IDX, SOS_IDX, EOS_IDX = (
    vocab2id[token]
    for token in (TEXT.pad_token, TEXT.init_token, TEXT.eos_token)
)

In [4]:
class Encoder(nn.Module):
    """Bidirectional GRU encoder.

    Pipeline: embedding -> dropout -> bidirectional GRU, followed by a linear
    layer + tanh that turns the two final GRU states into the decoder's
    initial hidden state.
    """

    def __init__(self, input_dim, emb_dim, enc_hid_dim, dec_hid_dim, dropout):
        super().__init__()
        self.embedding = nn.Embedding(num_embeddings=input_dim, embedding_dim=emb_dim)
        self.rnn = nn.GRU(
            input_size=emb_dim,
            hidden_size=enc_hid_dim,
            bidirectional=True,
            batch_first=True,
        )
        self.fc = nn.Linear(in_features=enc_hid_dim * 2, out_features=dec_hid_dim)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, src):
        """Encode a batch of token ids.

        Args:
            src: token-id tensor, batch dimension first.

        Returns:
            ``(outputs, hidden)`` where ``outputs`` is the GRU output for every
            position and ``hidden`` is the projected combination of the last
            two entries of the GRU's final state.
        """
        token_vectors = self.embedding(src)
        token_vectors = self.dropout(token_vectors)
        outputs, final_states = self.rnn(token_vectors)

        # The last two entries of the final state are concatenated feature-wise
        # before being squeezed down to the decoder's hidden size.
        merged = torch.cat([final_states[-2], final_states[-1]], dim=1)
        hidden = torch.tanh(self.fc(merged))
        return outputs, hidden

In [5]:
class Attention(nn.Module):
    """Additive attention: scores every encoder position against the decoder state."""

    def __init__(self, enc_hid_dim, dec_hid_dim):
        super().__init__()
        self.attention = nn.Linear(
            in_features=(enc_hid_dim * 2) + dec_hid_dim,
            out_features=dec_hid_dim,
        )
        self.v = nn.Linear(in_features=dec_hid_dim, out_features=1, bias=False)

    def forward(self, hidden, encoder_outputs):
        """Return attention weights over the source positions.

        Args:
            hidden: decoder state, one row per batch element.
            encoder_outputs: encoder outputs, batch first, positions on dim 1.

        Returns:
            A tensor of weights normalised with softmax along dim 1.
        """
        src_len = encoder_outputs.size(1)

        # Copy the decoder state once per source position so it can be
        # concatenated with each encoder output.
        tiled_hidden = hidden.unsqueeze(1).repeat(1, src_len, 1)
        features = torch.cat((tiled_hidden, encoder_outputs), dim=2)

        energy = torch.tanh(self.attention(features))
        scores = self.v(energy).squeeze(2)
        return F.softmax(scores, dim=1)

In [6]:
class Decoder(nn.Module):
    """Single-step GRU decoder conditioned on an attention-weighted context."""

    def __init__(self, output_dim, emb_dim, enc_hid_dim, dec_hid_dim, dropout, attention):
        super().__init__()
        self.output_dim = output_dim
        self.attention = attention

        context_dim = enc_hid_dim * 2
        self.embedding = nn.Embedding(num_embeddings=output_dim, embedding_dim=emb_dim)
        self.rnn = nn.GRU(
            input_size=context_dim + emb_dim,
            hidden_size=dec_hid_dim,
            batch_first=True,
        )
        self.fc_out = nn.Linear(
            in_features=context_dim + dec_hid_dim + emb_dim,
            out_features=output_dim,
        )
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, inputs, hidden, encoder_outputs):
        """Run one decoding step.

        Args:
            inputs: token ids for the current step, one per batch element.
            hidden: current decoder state, one row per batch element.
            encoder_outputs: encoder outputs, batch first.

        Returns:
            ``(prediction, hidden)``: the output of ``fc_out`` for this step
            and the updated decoder state.
        """
        # Add a length-1 sequence dimension so the step fits the batch-first GRU.
        step_tokens = inputs.unsqueeze(1)
        embedded = self.dropout(self.embedding(step_tokens))

        # Weighted sum of the encoder outputs under the attention weights.
        attn_weights = self.attention(hidden, encoder_outputs).unsqueeze(1)
        context = torch.bmm(attn_weights, encoder_outputs)

        rnn_input = torch.cat((embedded, context), dim=2)
        rnn_output, new_hidden = self.rnn(rnn_input, hidden.unsqueeze(0))

        # The output layer sees the GRU output, the context and the embedding.
        features = torch.cat(
            (rnn_output.squeeze(1), context.squeeze(1), embedded.squeeze(1)),
            dim=1,
        )
        return self.fc_out(features), new_hidden.squeeze(0)

In [7]:
class Seq2Seq(nn.Module):
    """Ties an encoder and a step-wise decoder together for training."""

    def __init__(self, encoder, decoder, device):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device = device

    def forward(self, src, trg, teacher_forcing_ratio=0.5):
        """Decode ``trg`` step by step and collect the decoder outputs.

        Args:
            src: source batch, batch dimension first.
            trg: target batch, batch first; column 0 is used as the first
                decoder input.
            teacher_forcing_ratio: probability, drawn independently at every
                step, of feeding the ground-truth token instead of the
                decoder's own arg-max as the next input.

        Returns:
            Tensor of shape ``(batch, trg_len, decoder.output_dim)``. Position
            0 along dim 1 is never written and stays zero.
        """
        batch_size, trg_len = src.shape[0], trg.shape[1]
        vocab_size = self.decoder.output_dim
        outputs = torch.zeros(batch_size, trg_len, vocab_size).to(self.device)

        encoder_outputs, hidden = self.encoder(src)
        step_input = trg[:, 0]

        for t in range(1, trg_len):
            step_logits, hidden = self.decoder(step_input, hidden, encoder_outputs)
            outputs[:, t, :] = step_logits

            if random.random() < teacher_forcing_ratio:
                step_input = trg[:, t]
            else:
                step_input = step_logits.argmax(1)
        return outputs

In [8]:
# Device handed to Seq2Seq for allocating its output buffer.
device = torch.device("cpu")

# Source and target share one vocabulary, so both sizes are identical.
INPUT_DIM = len(id2vocab)
OUTPUT_DIM = len(id2vocab)

# Embedding and hidden sizes of the encoder / decoder.
ENC_EMB_DIM = 128
DEC_EMB_DIM = 128
ENC_HID_DIM = 256
DEC_HID_DIM = 256

# Dropout probabilities applied to the embeddings.
ENC_DROPOUT = 0.5
DEC_DROPOUT = 0.5

# Training schedule: number of epochs and the gradient-norm clipping threshold.
N_EPOCHS = 1
CLIP = 1

# The attention layer scores the *decoder* state against the encoder outputs,
# so its second argument must be the decoder hidden size. Passing ENC_HID_DIM
# twice only worked because both sizes happen to be equal.
attention = Attention(ENC_HID_DIM, DEC_HID_DIM)
enc = Encoder(INPUT_DIM, ENC_EMB_DIM, ENC_HID_DIM, DEC_HID_DIM, ENC_DROPOUT)
dec = Decoder(OUTPUT_DIM, DEC_EMB_DIM, ENC_HID_DIM, DEC_HID_DIM, DEC_DROPOUT, attention)
model = Seq2Seq(enc, dec, device)

In [9]:
def init_weights(m):
    """Re-initialise every parameter reachable from module ``m`` in place.

    Parameters whose name contains ``'weight'`` are drawn from a normal
    distribution (mean 0, std 0.01); all remaining parameters are set to 0.
    """
    for name, param in m.named_parameters():
        if 'weight' not in name:
            nn.init.constant_(param.data, 0)
            continue
        nn.init.normal_(param.data, mean=0, std=0.01)
# Run the custom initialiser over the model and all of its sub-modules.
model.apply(init_weights)

# Targets equal to PAD_IDX are excluded from the loss.
criterion = nn.CrossEntropyLoss(ignore_index=PAD_IDX)
optimizer = optim.Adam(params=model.parameters(), lr=5e-5)

In [None]:

# Release cached GPU memory left over from earlier runs; the loop below runs
# on the CPU.
torch.cuda.empty_cache()
loss_vals = []        # mean training loss per epoch
loss_vals_eval = []   # reserved for validation losses (no eval pass is run here)

device = torch.device('cpu')
model = model.to(device)
criterion = criterion.to(device)

for epoch in range(N_EPOCHS):
    model.train()
    epoch_loss = []

    # Optimise on the *training* split; the validation split has to stay
    # unseen by the optimiser to remain useful for evaluation.
    # NOTE(review): older torchtext training iterators may repeat forever, so
    # one pass is explicitly bounded to len(train_iter) batches - confirm
    # against the installed version.
    batches_per_epoch = len(train_iter)
    progress = tqdm(zip(range(batches_per_epoch), train_iter), total=batches_per_epoch)

    for _, batch in progress:
        src = batch.src.to(device)
        trg = batch.trg.to(device)

        optimizer.zero_grad()
        output = model(src, trg)
        output_dim = output.shape[-1]

        # Position 0 is never predicted by the model (it only serves as the
        # first decoder input), so it is dropped from both logits and targets
        # before flattening for the loss.
        output = output[:, 1:, :].reshape(-1, output_dim)
        trg = trg[:, 1:].reshape(-1)

        loss = criterion(output, trg)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), CLIP)
        optimizer.step()
        epoch_loss.append(loss.item())

    # Guard against an empty iterator so np.mean does not warn on [].
    mean_loss = np.mean(epoch_loss) if epoch_loss else float('nan')
    loss_vals.append(mean_loss)
    # The evaluation pass is skipped to save time.
    print(f'Epoch: {epoch+1}, Loss: {mean_loss}')
torch.save(model.state_dict(), 'model.pt')
