# Paraphrase Generation with Deep Reinforcement Learning

Обзор работы по использованию обучения с подкреплением для задачи машинного перевода. Статья по [ссылке](https://www.aclweb.org/anthology/D18-1421.pdf).

## Библиотеки

In [1]:
import io
import math
import time
from tqdm import tqdm
from collections import Counter

import torch
import numpy as np
import scipy.spatial
import torch.nn as nn
from torch import Tensor
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader
from torch.nn import (TransformerEncoder, TransformerDecoder,
                      TransformerEncoderLayer, TransformerDecoderLayer)

import torchtext
from torchtext.vocab import Vocab
from torchtext.data.utils import get_tokenizer
from torchtext.utils import download_from_url, extract_archive

from transformers import AutoModel, AutoTokenizer

In [2]:
# Убирем рандом
_ = torch.manual_seed(0)

In [3]:
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
DEVICE

'cuda'

## Базовый пример обучения генерации парафраза
[На основе кода по генерации перевода.](https://pytorch.org/tutorials/beginner/translation_transformer.html)

### Скачиваем данные для парафраза на инглийском.
[Статья 2021 года. ParaSCI: A Large Scientific Paraphrase Dataset for Longer Paraphrase Generation.](https://github.com/dqxiu/ParaSCI)

In [4]:
url_base = 'https://raw.githubusercontent.com/dqxiu/ParaSCI/master/Data/ParaSCI-ACL/'
train_urls = ('train/train.src', 'train/train.tgt')
val_urls = ('val/val.src', 'val/val.tgt')
test_urls = ('test/test.src', 'test/test.tgt')

train_filepaths = [download_from_url(url_base + url) for url in train_urls]
val_filepaths = [download_from_url(url_base + url) for url in val_urls]
test_filepaths = [download_from_url(url_base + url) for url in test_urls]

### Строим словарь

In [5]:
tokenizer = get_tokenizer('spacy', language='en_core_web_sm')

def build_vocab(filepaths, tokenizer):
    counter = Counter()
    for filepath in filepaths:
        with io.open(filepath, encoding="utf8") as f:
            for string_ in f:
                counter.update(tokenizer(string_))
    return Vocab(counter, specials=['<unk>', '<pad>', '<bos>', '<eos>'])

vocab = build_vocab(train_filepaths, tokenizer)

In [6]:
len(vocab)

14821

### Предобработка данных

In [7]:
def data_process(filepaths):
    raw_src_iter = iter(io.open(filepaths[0], encoding="utf8"))
    raw_tgt_iter = iter(io.open(filepaths[1], encoding="utf8"))
    data = []
    for (raw_src, raw_tgt) in zip(raw_src_iter, raw_tgt_iter):
        src_tensor_ = torch.tensor(
            [vocab[token] for token in tokenizer(raw_src.rstrip("\n"))],
            dtype=torch.long)
        tgt_tensor_ = torch.tensor(
            [vocab[token] for token in tokenizer(raw_tgt.rstrip("\n"))],
            dtype=torch.long)
        data.append((src_tensor_, tgt_tensor_))
    return data

train_data = data_process(train_filepaths)
val_data = data_process(val_filepaths)
test_data = data_process(test_filepaths)

### Финальные датасеты для обучения

In [8]:
BATCH_SIZE = 16
PAD_IDX = vocab['<pad>']
BOS_IDX = vocab['<bos>']
EOS_IDX = vocab['<eos>']

In [9]:
def generate_batch(data_batch):
    src_batch, tgt_batch = [], []
    for (src_item, tgt_item) in data_batch:
        src_batch.append(torch.cat([torch.tensor([BOS_IDX]), src_item, torch.tensor([EOS_IDX])], dim=0))
        tgt_batch.append(torch.cat([torch.tensor([BOS_IDX]), tgt_item, torch.tensor([EOS_IDX])], dim=0))
    src_batch = pad_sequence(src_batch, padding_value=PAD_IDX)
    tgt_batch = pad_sequence(tgt_batch, padding_value=PAD_IDX)
    return src_batch, tgt_batch

train_iter = DataLoader(train_data, batch_size=BATCH_SIZE,
                        shuffle=True, collate_fn=generate_batch)
valid_iter = DataLoader(val_data, batch_size=BATCH_SIZE,
                        shuffle=True, collate_fn=generate_batch)
test_iter = DataLoader(test_data, batch_size=BATCH_SIZE,
                       shuffle=True, collate_fn=generate_batch)

### Класс модели SEQ2SEQ transformer

#### Определение самой модели

In [10]:
class Seq2SeqTransformer(nn.Module):
    def __init__(self, num_encoder_layers: int, num_decoder_layers: int,
                 emb_size: int, src_vocab_size: int, tgt_vocab_size: int,
                 dim_feedforward:int = 512, dropout:float = 0.1):
        super(Seq2SeqTransformer, self).__init__()
        encoder_layer = TransformerEncoderLayer(d_model=emb_size, nhead=NHEAD,
                                                dim_feedforward=dim_feedforward)
        self.transformer_encoder = TransformerEncoder(encoder_layer, num_layers=num_encoder_layers)
        decoder_layer = TransformerDecoderLayer(d_model=emb_size, nhead=NHEAD,
                                                dim_feedforward=dim_feedforward)
        self.transformer_decoder = TransformerDecoder(decoder_layer, num_layers=num_decoder_layers)

        self.generator = nn.Linear(emb_size, tgt_vocab_size)
        self.src_tok_emb = TokenEmbedding(src_vocab_size, emb_size)
        self.tgt_tok_emb = TokenEmbedding(tgt_vocab_size, emb_size)
        self.positional_encoding = PositionalEncoding(emb_size, dropout=dropout)

    def forward(self, src: Tensor, trg: Tensor, src_mask: Tensor,
                tgt_mask: Tensor, src_padding_mask: Tensor,
                tgt_padding_mask: Tensor, memory_key_padding_mask: Tensor):
        src_emb = self.positional_encoding(self.src_tok_emb(src))
        tgt_emb = self.positional_encoding(self.tgt_tok_emb(trg))
        memory = self.transformer_encoder(src_emb, src_mask, src_padding_mask)
        outs = self.transformer_decoder(tgt_emb, memory, tgt_mask, None,
                                        tgt_padding_mask, memory_key_padding_mask)
        return self.generator(outs)

    def encode(self, src: Tensor, src_mask: Tensor):
        return self.transformer_encoder(self.positional_encoding(
                            self.src_tok_emb(src)), src_mask)

    def decode(self, tgt: Tensor, memory: Tensor, tgt_mask: Tensor):
        return self.transformer_decoder(self.positional_encoding(
                          self.tgt_tok_emb(tgt)), memory,
                          tgt_mask)

    
class PositionalEncoding(nn.Module):
    def __init__(self, emb_size: int, dropout, maxlen: int = 5000):
        super(PositionalEncoding, self).__init__()
        den = torch.exp(- torch.arange(0, emb_size, 2) * math.log(10000) / emb_size)
        pos = torch.arange(0, maxlen).reshape(maxlen, 1)
        pos_embedding = torch.zeros((maxlen, emb_size))
        pos_embedding[:, 0::2] = torch.sin(pos * den)
        pos_embedding[:, 1::2] = torch.cos(pos * den)
        pos_embedding = pos_embedding.unsqueeze(-2)

        self.dropout = nn.Dropout(dropout)
        self.register_buffer('pos_embedding', pos_embedding)

    def forward(self, token_embedding: Tensor):
        return self.dropout(token_embedding +
                            self.pos_embedding[:token_embedding.size(0),:])

class TokenEmbedding(nn.Module):
    def __init__(self, vocab_size: int, emb_size):
        super(TokenEmbedding, self).__init__()
        self.embedding = nn.Embedding(vocab_size, emb_size)
        self.emb_size = emb_size
    def forward(self, tokens: Tensor):
        return self.embedding(tokens.long()) * math.sqrt(self.emb_size)
    
# Делаем, так чтобы в обучении не было заглядывания на дальнешие слова
def generate_square_subsequent_mask(sz):
    mask = (torch.triu(torch.ones((sz, sz), device=DEVICE)) == 1).transpose(0, 1)
    mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
    return mask

def create_mask(src, tgt):
    src_seq_len = src.shape[0]
    tgt_seq_len = tgt.shape[0]

    tgt_mask = generate_square_subsequent_mask(tgt_seq_len)
    src_mask = torch.zeros((src_seq_len, src_seq_len), device=DEVICE).type(torch.bool)

    src_padding_mask = (src == PAD_IDX).transpose(0, 1)
    tgt_padding_mask = (tgt == PAD_IDX).transpose(0, 1)
    return src_mask, tgt_mask, src_padding_mask, tgt_padding_mask

#### Определения декодирования

In [11]:
def greedy_decode(model, src, src_mask, max_len, start_symbol, num_samples=1):
    src = src.to(DEVICE)
    src = torch.cat([src]*num_samples, dim=1)
    src_mask = src_mask.to(DEVICE)

    memory = model.encode(src, src_mask)
    
    ys = torch.ones(1, num_samples).fill_(start_symbol).type(torch.long).to(DEVICE)
    for i in range(max_len-1):
        memory = memory.to(DEVICE)
        memory_mask = torch.zeros(ys.shape[0], memory.shape[0]).to(DEVICE).type(torch.bool)
        tgt_mask = (generate_square_subsequent_mask(ys.size(0))
                                    .type(torch.bool)).to(DEVICE)
        out = model.decode(ys, memory, tgt_mask)
        out = out.transpose(0, 1)
        prob = model.generator(out[:, -1])
        _, next_word = torch.max(prob, dim = 1)
        next_word = next_word.detach()

        ys = torch.cat([ys,
                        next_word.view(1, -1)], dim=0)
    return ys.transpose(0,1)

def sampling_decode(model, src, src_mask, max_len, start_symbol, num_samples=1):
    src = src.to(DEVICE)
    src = torch.cat([src]*num_samples, dim=1)
    src_mask = src_mask.to(DEVICE)

    memory = model.encode(src, src_mask)
    
    ys = torch.ones(1, num_samples).fill_(start_symbol).type(torch.long).to(DEVICE)
    for i in range(max_len-1):
        memory = memory.to(DEVICE)
        memory_mask = torch.zeros(ys.shape[0], memory.shape[0]).to(DEVICE).type(torch.bool)
        tgt_mask = (generate_square_subsequent_mask(ys.size(0))
                                    .type(torch.bool)).to(DEVICE)
        out = model.decode(ys, memory, tgt_mask)
        out = out.transpose(0, 1)
        prob = model.generator(out[:, -1])
        next_word = torch.multinomial(torch.nn.functional.softmax(prob, dim=-1), 1)
        next_word = next_word.detach()

        ys = torch.cat([ys,
                        next_word.view(1, -1)], dim=0)
    return ys.transpose(0,1)

def paraphrase(model, 
              srcs, 
              src_vocab, 
              tgt_vocab, 
              src_tokenizer, 
              decoder=greedy_decode, 
              ret_tokens=False, 
              ret_idx=False, 
              max_len_add=10,
              input_idx=False,
              **argv):
    model.eval()
    global_answers = []
    for src in srcs:
        if not input_idx:
            tokens = [BOS_IDX] + [src_vocab.stoi[tok] for tok in src_tokenizer(src)]+ [EOS_IDX]
            src = torch.LongTensor(tokens)
        num_tokens = len(src)
        src = src.reshape(num_tokens, 1)
        
        src_mask = (torch.zeros(num_tokens, num_tokens)).type(torch.bool)
        tgt_tokens = decoder(model, src, src_mask, max_len=num_tokens + max_len_add, start_symbol=BOS_IDX, **argv)

        answers = []
        for tgt_token in tgt_tokens:
            if not ret_idx:
                reference = []
                for tok in tgt_token:
                    if tok.item() == tgt_vocab['<eos>']:
                        break
                    if tok.item() not in {tgt_vocab['<eos>'], tgt_vocab['<bos>'], tgt_vocab['<pad>']}:
                        reference.append(tgt_vocab.itos[tok])
                answers.append(" ".join(reference).strip())
                if ret_tokens:
                    answers[-1] = answers[-1].split(" ")
            else:
                reference = []
                for tok in tgt_token:
                    if tok.item() == tgt_vocab['<eos>']:
                        break
                    if tok.item() not in {tgt_vocab['<eos>'], tgt_vocab['<bos>'], tgt_vocab['<pad>']}:
                        reference.append(tok.item())
                        
                answers.append(reference)
        global_answers.append(answers)
    return global_answers

In [12]:
def evaluate(model, val_iter):
    model.eval()
    losses = 0
    for idx, (src, tgt) in (enumerate(valid_iter)):
        src = src.to(DEVICE)
        tgt = tgt.to(DEVICE)

        tgt_input = tgt[:-1, :]

        src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, tgt_input)

        logits = model(src, tgt_input, src_mask, tgt_mask,
                              src_padding_mask, tgt_padding_mask, src_padding_mask)
        
        tgt_out = tgt[1:,:]
        loss = loss_fn(logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1))
        losses += loss.item()
    return losses / len(val_iter)

## Базовое обучение без RL

### Функции обучения

In [13]:
def train_epoch(model, train_iter, optimizer, loss_fn):
    model.train()
    losses = 0
    for idx, (src, tgt) in enumerate(train_iter):
        src = src.to(DEVICE)
        tgt = tgt.to(DEVICE)

        tgt_input = tgt[:-1, :]

        src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, tgt_input)

        logits = model(src, 
                       tgt_input, 
                       src_mask, 
                       tgt_mask,
                       src_padding_mask, 
                       tgt_padding_mask, 
                       src_padding_mask)

        optimizer.zero_grad()

        tgt_out = tgt[1:,:]
        loss = loss_fn(logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1))
        loss.backward()

        optimizer.step()
        losses += loss.item()
    return losses / len(train_iter)

### Инициализация модели

In [14]:
SRC_VOCAB_SIZE = len(vocab)
TGT_VOCAB_SIZE = len(vocab)
EMB_SIZE = 512
NHEAD = 8
FFN_HID_DIM = 512
NUM_ENCODER_LAYERS = 3
NUM_DECODER_LAYERS = 3
NUM_EPOCHS = 20


transformer = Seq2SeqTransformer(NUM_ENCODER_LAYERS, 
                                 NUM_DECODER_LAYERS,
                                 EMB_SIZE, SRC_VOCAB_SIZE, 
                                 TGT_VOCAB_SIZE,
                                 FFN_HID_DIM)

for p in transformer.parameters():
    if p.dim() > 1:
        nn.init.xavier_uniform_(p)

transformer = transformer.to(DEVICE)

loss_fn = torch.nn.CrossEntropyLoss(ignore_index=PAD_IDX)

optimizer = torch.optim.Adam(
    transformer.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9
)

### Обучение модели

In [None]:
for epoch in range(1, NUM_EPOCHS+1):
    start_time = time.time()
    train_loss = train_epoch(transformer, train_iter, optimizer, loss_fn)
    end_time = time.time()
    val_loss = evaluate(transformer, valid_iter)
    all_time = time.time()
    print(f"Epoch: {epoch}, "
          f"Train loss: {train_loss:.3f}, "
          f"Val loss: {val_loss:.3f}, "
          f"Epoch time = {(end_time - start_time):.3f}s, "
          f"All time = {(all_time - start_time):.3f}s")

Epoch: 1, Train loss: 4.585, Val loss: 2.956, Epoch time = 46.839s, All time = 48.204s
Epoch: 2, Train loss: 3.446, Val loss: 2.353, Epoch time = 47.250s, All time = 48.622s
Epoch: 3, Train loss: 3.006, Val loss: 2.021, Epoch time = 47.126s, All time = 48.496s
Epoch: 4, Train loss: 2.695, Val loss: 1.789, Epoch time = 47.338s, All time = 48.686s
Epoch: 5, Train loss: 2.447, Val loss: 1.607, Epoch time = 47.526s, All time = 48.876s
Epoch: 6, Train loss: 2.238, Val loss: 1.461, Epoch time = 47.660s, All time = 49.024s
Epoch: 7, Train loss: 2.055, Val loss: 1.390, Epoch time = 47.774s, All time = 49.241s
Epoch: 8, Train loss: 1.897, Val loss: 1.301, Epoch time = 47.755s, All time = 49.102s
Epoch: 9, Train loss: 1.755, Val loss: 1.255, Epoch time = 47.286s, All time = 48.647s
Epoch: 10, Train loss: 1.632, Val loss: 1.191, Epoch time = 47.471s, All time = 48.833s
Epoch: 11, Train loss: 1.525, Val loss: 1.164, Epoch time = 47.395s, All time = 48.770s
Epoch: 12, Train loss: 1.425, Val loss: 1

### Пример работы (greedy search)

In [None]:
%%time
paraphrase(transformer, ["in our work , we focus on supervised domain adaptation ."], 
          vocab, 
          vocab, 
          tokenizer, 
          decoder=greedy_decode, num_samples=5)[0]

### Пример работы (multinominal sampling)

In [None]:
%%time
paraphrase(transformer, ["in our work , we focus on supervised domain adaptation ."], 
          vocab, 
          vocab, 
          tokenizer, 
          decoder=sampling_decode, num_samples=5)[0]

## Обучение с RL
Продолжаем обучение уже используя модель из предыдущего пункта

In [None]:
class Reward(object):
    def __init__(self):
        self.tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/LaBSE")
        self.bert = AutoModel.from_pretrained("sentence-transformers/LaBSE").to(DEVICE)
        self.bert.eval()
        
    def score(self, references, candidates):
        assert len(references) == len(candidates)
        with torch.no_grad():
            tokes = reward.tokenizer(
                references, return_tensors='pt', 
                padding=True, max_length=512, truncation=True).to(DEVICE)
            ref_emb = reward.bert(**tokes)[1].cpu().numpy()
            tokes = reward.tokenizer(
                candidates, return_tensors='pt', 
                padding=True, max_length=512, truncation=True).to(DEVICE)
            can_emb = reward.bert(**tokes)[1].cpu().numpy()

        distances = 1-scipy.spatial.distance.cdist(can_emb, 
                                                   ref_emb,
                                                   metric='cosine').diagonal()
    
        return distances.tolist()

In [None]:
reward = Reward()

In [None]:
def train_epoch_with_rl(model, train_iter, optimizer, loss_fn, alpha=0.75):
    model.train()
    losses = 0
    for idx, (src, tgt) in enumerate(train_iter):
        src = src.to(DEVICE)
        tgt = tgt.to(DEVICE)

        tgt_input = tgt[:-1, :]

        src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, tgt_input)

        logits = model(src, 
                       tgt_input, 
                       src_mask, 
                       tgt_mask,
                       src_padding_mask, 
                       tgt_padding_mask, 
                       src_padding_mask)
        
######################RL-start##################################
        logits_batch_first = logits.transpose(0,1)
        # получаем предсказания для RL
        toks = torch.multinomial(
            torch.nn.functional.softmax(
                logits_batch_first.reshape(-1, logits_batch_first.shape[-1]), 
                dim=-1), 
            1).reshape(logits_batch_first.shape[:2])


        references = []
        candidates = []
        for real_toks, pred_toks in zip(tgt[1:,:].transpose(0,1), toks):
            reference = []
            for tok in real_toks:
                if tok.item() == vocab['<eos>']:
                    break
                if tok.item() not in {vocab['<eos>'], vocab['<bos>'], vocab['<pad>']}:
                    reference.append(vocab.itos[tok])
            candidate = []
            for tok in pred_toks:
                if tok.item() == vocab['<eos>']:
                    break
                if tok.item() not in {vocab['<eos>'], vocab['<bos>'], vocab['<pad>']}:
                    candidate.append(vocab.itos[tok])
            references.append(' '.join(reference))
            candidates.append(' '.join(candidate))

        reward_tr = torch.tensor(reward.score(references, candidates)).float().to(DEVICE)

        action_proba = torch.gather(torch.nn.functional.log_softmax(logits_batch_first, dim=-1), 
                                    2, 
                                    toks.view(*logits_batch_first.shape[:2], 1)).squeeze(-1)
######################RL-end###################################
        optimizer.zero_grad()

        loss = alpha*loss_fn(logits.reshape(-1, logits.shape[-1]), tgt[1:,:].reshape(-1)) \
               + (1-alpha)*(-1*reward_tr.view(-1,1)*action_proba).mean()
        
        loss.backward()

        optimizer.step()
        losses += loss.item()
    return losses / len(train_iter)

In [None]:
for epoch in range(1, NUM_EPOCHS+1):
    start_time = time.time()
    train_loss = train_epoch_with_rl(transformer, train_iter, optimizer, loss_fn)
    end_time = time.time()
    val_loss = evaluate(transformer, valid_iter)
    all_time = time.time()
    print(f"Epoch: {epoch}, "
          f"Train loss: {train_loss:.3f}, "
          f"Val loss: {val_loss:.3f}, "
          f"Epoch time = {(end_time - start_time):.3f}s, "
          f"All time = {(all_time - start_time):.3f}s")

### Пример работы (greedy search)

In [None]:
%%time
paraphrase(transformer, ["in our work , we focus on supervised domain adaptation ."], 
          vocab, 
          vocab, 
          tokenizer, 
          decoder=greedy_decode, num_samples=5)[0]

### Пример работы (multinominal sampling)

In [None]:
%%time
paraphrase(transformer, ["in our work , we focus on supervised domain adaptation ."], 
          vocab, 
          vocab, 
          tokenizer, 
          decoder=sampling_decode, num_samples=5)[0]