In [None]:
from transformers import AutoTokenizer
import pandas as pd
import os

def read_csv_data(data_path):
    """Load every CSV file in a directory and collect its text/summary pairs.

    Args:
        data_path: Directory whose files are read with ``pandas.read_csv``.
            Each file must provide a ``text`` and a ``summary`` column.

    Returns:
        A ``(source, target)`` tuple of lists holding the values of the
        ``text`` and ``summary`` columns, concatenated over all files in
        sorted file-name order.

    Raises:
        KeyError: If a file lacks the ``text`` or ``summary`` column.
    """
    source, target = [], []
    # os.listdir returns entries in arbitrary order; sorting keeps the sample
    # order (and therefore any seeded split done on it) reproducible.
    for file_name in sorted(os.listdir(data_path)):
        file_path = os.path.join(data_path, file_name)
        # Skip sub-directories and hidden entries (e.g. `.ipynb_checkpoints`,
        # `.DS_Store`) that would otherwise be handed to read_csv.
        if file_name.startswith('.') or not os.path.isfile(file_path):
            continue
        df = pd.read_csv(file_path)
        source.extend(df['text'].tolist())
        target.extend(df['summary'].tolist())
    return source, target
    
# Raw article/summary pairs for the training pool and the held-out test set.
x_train_data, y_train_data = read_csv_data(data_path="news/train")
x_test, y_test = read_csv_data(data_path="news/test")

# One tokenizer is shared by the source and target side of the model.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

In [None]:
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split

class SummaryeDataset(Dataset):
    """Dataset of raw (source text, summary) pairs that are tokenized per batch.

    Items are returned untokenized; ``collate_fn`` tokenizes a whole batch at
    once so that padding is only as long as the longest item in that batch.
    """

    def __init__(self, x, y, tokenizer):
        """Store the source texts ``x``, the summaries ``y`` and the tokenizer."""
        self.x, self.y = x, y
        self.tokenizer = tokenizer

    def __len__(self):
        """Number of pairs, taken from the length of ``x``."""
        return len(self.x)

    def __getitem__(self, index):
        """Return the ``(source, summary)`` pair stored at ``index``."""
        return self.x[index], self.y[index]

    def collate_fn(self, batch):
        """Tokenize a list of ``(source, summary)`` pairs into one dict.

        Every key produced by the tokenizer is prefixed with ``src_`` for the
        source texts and ``tgt_`` for the summaries, source keys first.
        """
        batch_x, batch_y = zip(*batch)
        merged = {}
        for prefix, texts in (('src', batch_x), ('tgt', batch_y)):
            encoded = self.tokenizer(
                texts,
                max_length=256,
                truncation=True,
                padding="longest",
                return_tensors='pt',
            )
            for key, value in encoded.items():
                merged[f'{prefix}_{key}'] = value
        return merged

# Hold out 20% of the training pool for validation; the fixed seed keeps the
# split identical between runs as long as the input order is unchanged.
x_train, x_valid, y_train, y_valid = train_test_split(
    x_train_data, y_train_data, train_size=0.8, random_state=46, shuffle=True
)

trainset = SummaryeDataset(x_train, y_train, tokenizer)
validset = SummaryeDataset(x_valid, y_valid, tokenizer)

train_loader = DataLoader(
    trainset,
    batch_size=32,
    shuffle=True,
    num_workers=0,
    pin_memory=True,
    collate_fn=trainset.collate_fn,
)
# Validation data is only evaluated, never learned from, so reshuffling it
# brings no benefit; a fixed order keeps the batches (and their padding)
# identical from one evaluation pass to the next.
valid_loader = DataLoader(
    validset,
    batch_size=32,
    shuffle=False,
    num_workers=0,
    pin_memory=True,
    collate_fn=validset.collate_fn,
)

In [None]:
import torch
import torch.nn as nn

class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a tensor, then apply dropout.

    The encoding table is stored in the buffer ``pe`` with shape
    ``(maxlen, 1, emb_size)``. ``forward`` slices the table by ``x.size(0)``
    and broadcasts over dimension 1, i.e. dimension 0 of the input is treated
    as the sequence position (sequence-first layout).

    Args:
        emb_size: Size of the last (feature) dimension; odd sizes are supported.
        dropout: Dropout probability applied after adding the encoding.
        maxlen: Number of positions for which encodings are precomputed.
    """

    def __init__(self, emb_size, dropout, maxlen=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(0, maxlen, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, emb_size, 2).float() * (-torch.log(torch.tensor(10000.0)) / emb_size))
        angles = position * div_term

        pe = torch.zeros(maxlen, emb_size)
        pe[:, 0::2] = torch.sin(angles)
        # For an odd emb_size there is one fewer odd (cosine) column than even
        # (sine) columns, so trim the angles to fit instead of raising.
        pe[:, 1::2] = torch.cos(angles[:, : emb_size // 2])
        # (maxlen, emb_size) -> (maxlen, 1, emb_size) so it broadcasts over dim 1.
        self.register_buffer('pe', pe.unsqueeze(1))

    def forward(self, x):
        """Return ``dropout(x + pe[:x.size(0)])``; dim 0 of ``x`` is the position."""
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)
    
class Seq2SeqTransformer(nn.Module):
    """Encoder-decoder Transformer for sequence-to-sequence generation.

    Source and target tokens have separate embedding tables that share one
    positional encoding module. All tensors handled by this class are
    batch-first: ``(batch, seq_len)`` for ids and ``(batch, seq_len, emb)``
    for embeddings.

    Args:
        vocab_size: Number of rows in both embedding tables and size of the
            output projection.
        emb_size: Embedding width; must equal ``d_model``.
        d_model: Width of the Transformer layers.
        nhead: Number of attention heads.
        num_encoder_layers: Number of stacked encoder layers.
        num_decoder_layers: Number of stacked decoder layers.
        dim_feedforward: Hidden width of the position-wise feed-forward sublayer.
        pad_token_id: Token id ignored by the loss and used to fill finished
            sequences in ``generate``. When ``None`` it falls back to the
            notebook-level ``tokenizer.pad_token_id``.

    Raises:
        ValueError: If ``emb_size`` differs from ``d_model``.
    """

    def __init__(self, vocab_size, emb_size, d_model, nhead, num_encoder_layers, num_decoder_layers, dim_feedforward, pad_token_id=None):
        super().__init__()
        # The embeddings are fed straight into the Transformer, so a width
        # mismatch would otherwise only surface as an opaque shape error later.
        if emb_size != d_model:
            raise ValueError(f'emb_size ({emb_size}) must equal d_model ({d_model})')
        if pad_token_id is None:
            # Backward-compatible fallback to the global tokenizer.
            pad_token_id = tokenizer.pad_token_id
        self.pad_token_id = pad_token_id

        self.src_embedding = nn.Embedding(vocab_size, emb_size)
        self.tgt_embedding = nn.Embedding(vocab_size, emb_size)
        self.positional_encoding = PositionalEncoding(emb_size, dropout=0.1)

        self.transformer = nn.Transformer(
            d_model=d_model,  # model width, same size as the embedding (emb_size)
            nhead=nhead,      # number of heads in multi-head attention
            num_encoder_layers=num_encoder_layers,  # number of encoder layers
            num_decoder_layers=num_decoder_layers,  # number of decoder layers
            dim_feedforward=dim_feedforward,        # hidden width of the feed-forward sublayer
            batch_first=True
        )

        # Linear layer that projects decoder outputs to vocabulary logits.
        self.fc = nn.Linear(d_model, vocab_size)
        self.criterion = torch.nn.CrossEntropyLoss(ignore_index=pad_token_id)

    def forward(self, **kwargs):
        """Run a teacher-forced pass and compute the next-token loss.

        Expected keyword arguments: ``src_input_ids``, ``src_attention_mask``,
        ``tgt_input_ids`` and ``tgt_attention_mask`` (attention masks use 0 for
        padding positions).

        Returns:
            ``(loss, logits)`` where ``logits`` is flattened to
            ``(batch * (tgt_len - 1), vocab_size)`` and aligned with the target
            ids shifted left by one position.
        """
        src_ids = kwargs['src_input_ids']
        tgt_ids = kwargs['tgt_input_ids']
        src_emb, tgt_emb = self.embedding_step(src_ids, tgt_ids)

        src_key_padding_mask = (kwargs['src_attention_mask'] == 0)
        tgt_key_padding_mask = (kwargs['tgt_attention_mask'] == 0)

        # All-False mask: the encoder may attend everywhere. Built on the full
        # device (not just its type) so non-default GPUs such as cuda:1 work.
        src_len = src_emb.shape[1]
        src_mask = torch.zeros((src_len, src_len), dtype=torch.bool, device=src_ids.device)
        tgt_mask = self.generate_square_subsequent_mask(tgt_emb)

        # Pass the embeddings through the Transformer.
        outs = self.transformer(
            src_emb, tgt_emb,
            src_mask=src_mask,
            tgt_mask=tgt_mask,
            src_key_padding_mask=src_key_padding_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=src_key_padding_mask
        )

        logits = self.fc(outs)

        # The output at position i predicts the token at position i + 1.
        tgt_ids_shifted = tgt_ids[:, 1:].reshape(-1)
        logits = logits[:, :-1].reshape(-1, logits.shape[-1])
        loss = self.criterion(logits, tgt_ids_shifted)

        return loss, logits

    def _add_positions(self, emb):
        """Apply the positional encoding to a batch-first embedding tensor.

        The positional encoding module indexes its table by dimension 0 and
        broadcasts over dimension 1, so the batch-first tensor is presented as
        ``(seq, batch, emb)`` and transposed back afterwards. Without this the
        encoding would vary with the batch index instead of the token position.
        NOTE: weights trained before this correction saw no usable position
        signal and should be retrained.
        """
        return self.positional_encoding(emb.transpose(0, 1)).transpose(0, 1)

    def embedding_step(self, src, tgt):
        """Embed source and target ids and add positional encodings to both."""
        src_emb = self.src_embedding(src)
        tgt_emb = self.tgt_embedding(tgt)

        return self._add_positions(src_emb), self._add_positions(tgt_emb)

    def generate_square_subsequent_mask(self, tgt_emb):
        """Build the causal mask for a batch-first target embedding.

        Returns a float ``(tgt_len, tgt_len)`` tensor on the same device as
        ``tgt_emb`` with ``0.0`` on and below the diagonal and ``-inf`` above
        it, so each position can only attend to itself and earlier positions.
        """
        sz = tgt_emb.shape[1]
        blocked = torch.full((sz, sz), float('-inf'), device=tgt_emb.device)
        return torch.triu(blocked, diagonal=1)

    @torch.no_grad()
    def generate(self, max_length=50, cls_token_id=101, sep_token_id=102, **kwargs):
        """Greedy decoding for one or more source sequences.

        Args:
            max_length: Maximum number of tokens to generate after the start token.
            cls_token_id: Token id used as the start-of-sequence symbol.
            sep_token_id: Token id that marks the end of a sequence.
            **kwargs: Must contain ``input_ids`` and ``attention_mask``; any
                other keys are ignored.

        Returns:
            A ``(batch, generated_len)`` tensor of token ids starting with
            ``cls_token_id``. Once a sequence has produced ``sep_token_id`` its
            remaining positions are filled with ``pad_token_id``.
        """
        src_input_ids = kwargs['input_ids']
        src_attention_mask = kwargs['attention_mask']
        device = src_input_ids.device
        batch_size = src_input_ids.size(0)

        # Embed and encode the source once; it does not change while decoding.
        src_emb = self._add_positions(self.src_embedding(src_input_ids))
        src_key_padding_mask = (src_attention_mask == 0)
        memory = self.transformer.encoder(src_emb, src_key_padding_mask=src_key_padding_mask)

        # Initialise every target sequence with the start symbol (BOS).
        tgt_input_ids = torch.full((batch_size, 1), cls_token_id, dtype=torch.long, device=device)
        finished = torch.zeros(batch_size, dtype=torch.bool, device=device)

        for _ in range(max_length):
            tgt_emb = self._add_positions(self.tgt_embedding(tgt_input_ids))
            # Use the same causal mask as in training so the decoder sees the
            # prefix exactly the way it did when it was trained.
            tgt_mask = self.generate_square_subsequent_mask(tgt_emb)

            # Decoder forward pass over the current prefix.
            outs = self.transformer.decoder(
                tgt_emb, memory,
                tgt_mask=tgt_mask,
                memory_key_padding_mask=src_key_padding_mask
            )
            next_token = torch.argmax(self.fc(outs[:, -1, :]), dim=-1)
            # Sequences that already ended only receive padding.
            next_token = next_token.masked_fill(finished, self.pad_token_id)
            tgt_input_ids = torch.cat([tgt_input_ids, next_token.unsqueeze(1)], dim=1)

            # Stop condition: every sequence has emitted the end symbol (EOS).
            finished = finished | (next_token == sep_token_id)
            if finished.all():
                break

        return tgt_input_ids

In [None]:
import torch.optim as optim
from transformers import get_cosine_with_hard_restarts_schedule_with_warmup
from Trainer import Trainer

# Number of training epochs. Kept in one place so the scheduler horizon and
# the trainer can never disagree.
EPOCHS = 100

# Build the model.
model = Seq2SeqTransformer(
    vocab_size=len(tokenizer),
    emb_size=512,
    d_model=512,
    nhead=8,
    num_encoder_layers=6,
    num_decoder_layers=6,
    dim_feedforward=2048
)

# Optimizer and learning-rate scheduler.
optimizer = optim.AdamW(model.parameters(), lr=1e-4)
steps_per_epoch = len(train_loader)
scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
        optimizer,
        num_warmup_steps=steps_per_epoch,             # warm up for one epoch's worth of batches
        num_training_steps=steps_per_epoch * EPOCHS,  # total number of batches over all epochs
        num_cycles=1,
)

# Train the model.
trainer = Trainer(
    epochs=EPOCHS,
    train_loader=train_loader,
    valid_loader=valid_loader,
    model=model,
    optimizer=[optimizer],
    scheduler=[scheduler]
)
trainer.train(show_loss=True)

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# load_state_dict returns a report of missing/unexpected keys, not the module,
# so it cannot be chained with `.to(device)`. map_location lets a checkpoint
# saved on GPU be restored on a CPU-only machine.
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
state_dict = torch.load('model.ckpt', map_location=device)
model.load_state_dict(state_dict)
model.to(device)
model.eval()

idx = 7778
# NOTE(review): the training collate function truncates to 256 tokens while
# this call allows 1024 — confirm the longer inference input is intended.
input_data = tokenizer(x_test[idx], max_length=1024, truncation=True, padding="longest", return_tensors='pt').to(device)
# `generate` names its length argument `max_length`; the former `max_len`
# keyword was silently swallowed by **kwargs and had no effect.
with torch.no_grad():
    generated_ids = model.generate(**input_data, max_length=50)

print('輸入文字:\n', x_test[idx])
print('目標文字:\n', y_test[idx])
print('模型文字:\n', tokenizer.decode(generated_ids[0]))