In [1]:
import os
import math
import torch
import pandas as pd
import torch.nn as nn
import matplotlib.pyplot as plt 

from tqdm import tqdm
from torch import Tensor
from collections import Counter
from torch.nn import Transformer
from torchtext.vocab import vocab
from torch.nn.utils.rnn import pad_sequence
from torchtext.data.utils import get_tokenizer
from torch.utils.data import Dataset, DataLoader

In [2]:
def load_data(path):
    x_train, y_train, x_valid, y_valid= [], [], [], []
    for types in os.listdir(path):
        classes_path = f'{path}/{types}'
        for classes in os.listdir(classes_path):
            file_path = f'{classes_path}/{classes}' 
            print(file_path)

            df = pd.read_csv(file_path).values
            input_text, summary = df[:,1], df[:,0]
            if types == 'Train':
                x_train.extend(input_text)
                y_train.extend(summary)
            
            else:
                x_valid.extend(input_text)
                y_valid.extend(summary)
    return  x_train, y_train, x_valid, y_valid

x_train, y_train, x_valid, y_valid = load_data('SummaryData') # ../SummaryData  ..\SummaryData 

SummaryData/Train/cl_news_summary_more.csv
SummaryData/Train/cl_news_summary_train.csv
SummaryData/Train/cl_train_news_summary_more.csv
SummaryData/Valid/cl_news_summary.csv
SummaryData/Valid/cl_news_summary_valid.csv
SummaryData/Valid/cl_valid_news_summary_more.csv


In [3]:
def get_vocab(inputs, tokenizer, train_len, special = ('<PAD>', '<SOS>','<EOS>','<UNK>')):
    counter = Counter()


    new_inputs = []
    for sentence in inputs:
        tokens = tokenizer(sentence)
        counter.update(tokens)
        new_inputs.append(tokens)

    token_vocab = vocab(counter, min_freq=5, specials=special)

    return token_vocab, new_inputs[:train_len], new_inputs[train_len:]

all_input = x_train + x_valid
all_target = y_train + y_valid
tokenizer = get_tokenizer('basic_english')

input_vocab, x_train, x_valid= get_vocab(all_input, tokenizer, len(x_train))
traget_vocab, y_train, y_valid= get_vocab(all_target, tokenizer, len(y_train))

input_vocab.set_default_index(input_vocab.get_stoi()['<UNK>'])
traget_vocab.set_default_index(traget_vocab.get_stoi()['<UNK>'])

INPUT_DIM =  len(input_vocab)
OUTPUT_DIN = len(traget_vocab)

SOS_IDX = input_vocab.get_stoi()['<SOS>']
EOS_IDX = input_vocab.get_stoi()['<EOS>']
PAD_IDX = input_vocab.get_stoi()['<PAD>']

In [4]:
def token2num(inputs, targets):
    encoder_input, decoder_input = [], []
    for i in range(len(inputs)):
        encoder_in = input_vocab.lookup_indices(inputs[i])[:499] + [EOS_IDX]
        decoder_in = traget_vocab.lookup_indices(targets[i])[:499] + [EOS_IDX]

        encoder_input.append(torch.tensor(encoder_in))
        decoder_input.append(torch.tensor(decoder_in))
    return encoder_input, decoder_input

x_train, y_train= token2num(x_train, y_train)
x_valid, y_valid= token2num(x_valid, y_valid)

In [5]:
class SummaryeDataset(Dataset):
    def __init__(self, x, y):
        self.x = x
        self.y = y
          
    def __getitem__(self, index):
        return self.x[index], self.y[index]
       
    def __len__(self):
        return len(self.x)
    
trainset = SummaryeDataset(x_train, y_train)
validset = SummaryeDataset(x_valid, y_valid)

In [6]:
def collate_fn(batch):    
    (x, y) = zip(*batch)
    
    pad_data = pad_sequence(x + y, padding_value=PAD_IDX, batch_first=True)
    src, tgt = torch.split(pad_data, split_size_or_sections=[len(x), len(y)], dim=0)

    return src.permute(1, 0) , tgt.permute(1, 0)


train_loader = DataLoader(trainset, batch_size = 32, shuffle = True, num_workers = 0, pin_memory = True, collate_fn = collate_fn)
valid_loader = DataLoader(validset, batch_size = 32, shuffle = True, num_workers = 0, pin_memory = True, collate_fn = collate_fn)

In [7]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class PositionalEncoding(nn.Module):
    def __init__(self,
                 emb_size: int,
                 dropout: float,
                 maxlen: int = 500):
        super(PositionalEncoding, self).__init__()
        den = torch.exp(- torch.arange(0, emb_size, 2)* math.log(10000) / emb_size)
        pos = torch.arange(0, maxlen).reshape(maxlen, 1)
        pos_embedding = torch.zeros((maxlen, emb_size))
        pos_embedding[:, 0::2] = torch.sin(pos * den)
        pos_embedding[:, 1::2] = torch.cos(pos * den)
        pos_embedding = pos_embedding.unsqueeze(-2)

        self.dropout = nn.Dropout(dropout)
        self.register_buffer('pos_embedding', pos_embedding)

    def forward(self, token_embedding: Tensor):
        return self.dropout(token_embedding + self.pos_embedding[:token_embedding.size(0), :])
    

class TokenEmbedding(nn.Module):
    def __init__(self, vocab_size: int, emb_size):
        super(TokenEmbedding, self).__init__()
        self.embedding = nn.Embedding(vocab_size, emb_size)
        self.emb_size = emb_size

    def forward(self, tokens: Tensor):
        return self.embedding(tokens.long()) * math.sqrt(self.emb_size)
    

class Seq2SeqTransformer(nn.Module):
    def __init__(self, num_encoder_layers: int,  # Encoder數量
                 num_decoder_layers: int,        # Decoder數量
                 emb_size: int,                  
                 nhead: int,                     
                 src_vocab_size: int,            
                 tgt_vocab_size: int,            
                 dim_feedforward: int = 512,     
                 dropout: float = 0.1            
            ):
        super(Seq2SeqTransformer, self).__init__()
        self.transformer = Transformer(d_model=emb_size,
                                       nhead=nhead,
                                       num_encoder_layers=num_encoder_layers,
                                       num_decoder_layers=num_decoder_layers,
                                       dim_feedforward=dim_feedforward,
                                       dropout=dropout)
        self.generator = nn.Linear(emb_size, tgt_vocab_size)
        self.src_tok_emb = TokenEmbedding(src_vocab_size, emb_size)
        self.tgt_tok_emb = TokenEmbedding(tgt_vocab_size, emb_size)
        self.positional_encoding = PositionalEncoding(
            emb_size, dropout=dropout)

    def forward(self,
                src: Tensor,         # Encoder輸入
                trg: Tensor,         # Decoder輸入
                src_mask: Tensor,    
                tgt_mask: Tensor,    
                src_padding_mask: Tensor,
                tgt_padding_mask: Tensor,
                memory_key_padding_mask: Tensor):
        src_emb = self.positional_encoding(self.src_tok_emb(src))
        tgt_emb = self.positional_encoding(self.tgt_tok_emb(trg))
        outs = self.transformer(src_emb, tgt_emb, src_mask, tgt_mask, None,
                                src_padding_mask, tgt_padding_mask, memory_key_padding_mask)
        return self.generator(outs)

    def encode(self, src: Tensor, src_mask: Tensor):
        return self.transformer.encoder(self.positional_encoding(
                            self.src_tok_emb(src)), src_mask)

    def decode(self, tgt: Tensor, memory: Tensor, tgt_mask: Tensor):
        return self.transformer.decoder(self.positional_encoding(
                          self.tgt_tok_emb(tgt)), memory,
                          tgt_mask)

In [8]:
EMB_SIZE = 512
NHEAD = 8
FFN_HID_DIM = 512
BATCH_SIZE = 128
NUM_ENCODER_LAYERS = 3
NUM_DECODER_LAYERS = 3

transformer = Seq2SeqTransformer(NUM_ENCODER_LAYERS, NUM_DECODER_LAYERS, EMB_SIZE,
                                 NHEAD, INPUT_DIM, OUTPUT_DIN, FFN_HID_DIM)

model = transformer.to(device)

criterion = torch.nn.CrossEntropyLoss(ignore_index=PAD_IDX)

optimizer = torch.optim.Adam(model.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)



In [9]:
def generate_square_subsequent_mask(sz):
    mask = (torch.triu(torch.ones((sz, sz), device=device)) == 1).transpose(0, 1)
    mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
    return mask


def create_mask(src, tgt):
    src_seq_len = src.shape[0]
    tgt_seq_len = tgt.shape[0]

    tgt_mask = generate_square_subsequent_mask(tgt_seq_len)
    src_mask = torch.zeros((src_seq_len, src_seq_len),device=device).type(torch.bool)

    src_padding_mask = (src == PAD_IDX).transpose(0, 1)
    tgt_padding_mask = (tgt == PAD_IDX).transpose(0, 1)
    return src_mask, tgt_mask, src_padding_mask, tgt_padding_mask

In [10]:
def train(epoch):
    train_loss = 0
    train_pbar = tqdm(train_loader, position=0, leave=True) 

    model.train()
    for input_datas in train_pbar: 
        
        src, tgt = [i.to(device) for i in input_datas]
        tgt_input = tgt[:-1, :]
        src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, tgt_input)

        logits = model(src, tgt_input, src_mask, tgt_mask,src_padding_mask, tgt_padding_mask, src_padding_mask)
        optimizer.zero_grad()
        tgt_out = tgt[1:, :]
        loss = criterion(logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1))
        loss.backward()
        optimizer.step()

        train_pbar.set_description(f'Train Epoch {epoch}')  
        train_pbar.set_postfix({'loss':f'{loss:.3f}'}) 

        train_loss += loss.item()

    return train_loss/len(train_loader)

def valid(epoch):
    valid_loss = 0
    valid_pbar = tqdm(valid_loader, position=0, leave=True) 

    model.eval()
    with torch.no_grad(): 
        for input_datas in valid_pbar: 
            
            src, tgt = [i.to(device) for i in input_datas]

            tgt_input = tgt[:-1, :]
            src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, tgt_input)

            logits = model(src, tgt_input, src_mask, tgt_mask,src_padding_mask, tgt_padding_mask, src_padding_mask)

            tgt_out = tgt[1:, :]
            loss = criterion(logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1))

            valid_pbar.set_description(f'Valid Epoch {epoch}')  
            valid_pbar.set_postfix({'loss':f'{loss:.3f}'}) 
    
            valid_loss += loss.item()
    
        return valid_loss/len(valid_loader)
    
def show_training_loss(loss_record):
    train_loss, valid_loss = [i for i in loss_record.values()]
    
    plt.plot(train_loss)
    plt.plot(valid_loss)
    plt.title('Result')
    plt.ylabel('Loss')
    plt.xlabel('Epoch')
    plt.legend(['train', 'valid'], loc='upper left')
    plt.show()

In [12]:
epochs = 10                           
early_stopping = 10                      
stop_cnt = 0                             
model_path = 'model.ckpt'                
show_loss = True                         
best_loss = float('inf')                 
loss_record = {'train':[], 'valid':[]}   

for epoch in range(epochs):   
    train_loss = train(epoch)
    valid_loss = valid(epoch)
    
    loss_record['train'].append(train_loss)
    loss_record['valid'].append(valid_loss)
    
    # 儲存最佳的模型權重
    if valid_loss < best_loss:
        best_loss = valid_loss
        torch.save(model.state_dict(), 'e' + model_path)
        print(f'Saving Model With Loss {best_loss:.5f}')
        stop_cnt = 0
    else:
        stop_cnt+=1
    
    # Early stopping
    if stop_cnt == early_stopping:
        output = "Model can't improve, stop training"
        print('-' * (len(output)+2))
        print(f'|{output}|')
        print('-' * (len(output)+2))
        break

    print(f'Train Loss: {train_loss:.5f}' , end='| ')
    print(f'Valid Loss: {valid_loss:.5f}' , end='| ')
    print(f'Best Loss: {best_loss:.5f}', end='\n\n')


if show_loss:
    show_training_loss(loss_record)

Train Epoch 0:   9%|████                                         | 521/5804 [13:14:33<134:16:56, 91.50s/it, loss=7.973]


KeyboardInterrupt: 