# Seq2Seq approach to lemmatization task based on this paper - https://www.aclweb.org/anthology/N18-1126.pdf

# Seq2Seq model is conditional GRU from scratch based on this paper https://www.aclweb.org/anthology/E17-3017.pdf

In [24]:
import conllu
with open("data/UD_Russian-SynTagRus/ru_syntagrus-ud-train.conllu", 'r') as f:
    data = f.read()
sentences = conllu.parse(data)

In [25]:
import re

allowed_char_re = re.compile(r'[^a-zA-Zа-яА-Я.,?!ёЁ"]')
def preprocess(text):
    text = allowed_char_re.sub(' ', text)
    text = re.sub(' +', ' ', text)
    return text

In [26]:
is_skip_re = re.compile(r'[а-яА-ЯёЁ]+')
def is_skip(text):
    return not is_skip_re.fullmatch(text)

In [27]:
def prepare_sentence(sentence_tokens):
    text = ''
    cur_pointer = 0
    lemmas = []
    for s in sentence_tokens:
        if not is_skip(s['form']):
            lemmas.append(
                    {
                        "form": s['form'],
                        "lemma": s['lemma'],
                        "start_idx": cur_pointer,
                        "end_idx": cur_pointer + len(s['form'])
                    }
                )
        if s['misc'] is not None and s['misc']['SpaceAfter'] == 'No':
            cur_pointer += len(s['form'])
            text += s['form']
        else:
            cur_pointer += len(s['form']) + 1
            text += s['form'] + ' '

    text = text.strip(' ')
    return text, lemmas

In [28]:
def get_examples_from_converted(text, lemma, window_size):
    
    chars = list(text)
    target_word = lemma['form']
    
    left_border = max(0, lemma['start_idx'] - window_size)
    left_context = chars[left_border:lemma['start_idx']]
    
    right_border = min(len(text), lemma['end_idx'] + window_size)
    right_context = chars[lemma['end_idx']:right_border]
    
    target =  ['<lc>'] + list(target_word) + ['<rc>']
    return target, (left_context, right_context)

In [29]:
def get_dataset(sentences, window_size):
    dataset = []
    for sentence in sentences:
        text, lemmas = prepare_sentence(sentence)
        for lemma in lemmas:
            target, contexts = get_examples_from_converted(text, lemma, window_size)
            dataset.append({
                "input": contexts[0] + target + contexts[1],
                "target": list(lemma["lemma"])
            })
    return dataset

In [30]:
import torch
import torch.nn as  nn
import torch.nn.functional as F

In [31]:
class LemmatizationDataset(torch.utils.data.Dataset):
    
    def __init__(self, sentences, window_size):
        char_dataset = get_dataset(sentences, window_size=window_size)
        self.input_token2idx = self.create_token2idx("input", char_dataset)
        self.target_token2idx = self.create_token2idx("target", char_dataset)
        
        self.dataset = self.tokenize_dataset(char_dataset)
    
    @staticmethod
    def create_token2idx(dest, char_dataset):
        chars = set.union(*map(lambda x: set(x[dest]), char_dataset))
        if dest == "input":
            token2idx = {char: idx for idx, char in enumerate(chars, start=1)}
        elif dest == "target":
            token2idx = {char: idx for idx, char in enumerate(chars, start=3)}
            token2idx["<BOS>"] = 1
            token2idx["<EOS>"] = 2
        token2idx["<PAD>"] = 0
        return token2idx
    
    def tokenize_dataset(self, char_dataset):
        result = []
        for data_point in char_dataset:
            
            cur_inp = []
            for char in data_point["input"]:
                cur_inp.append(self.input_token2idx[char])
                
            cur_targ = []
            for char in data_point["target"]:
                cur_targ.append(self.target_token2idx[char])
            cur_targ = [self.target_token2idx["<BOS>"]] + cur_targ + [self.target_token2idx["<EOS>"]]
            
                
            result.append(
                {
                    "input": cur_inp,
                    "target": cur_targ ,
                }
            )
        return result
            
        
    def __getitem__(self, idx):
        return self.dataset[idx]
    
    def __len__(self):
        return len(self.dataset)

In [32]:
window_size = 25

In [33]:
dataset = LemmatizationDataset(sentences, window_size=window_size)

In [35]:
def collate_fn(batch):
    inputs = []
    targets = []
    
    for data_point in batch:
        inputs.append(torch.LongTensor(data_point["input"]))
        targets.append(torch.LongTensor(data_point["target"]))
    
    inp_tensor =  torch.nn.utils.rnn.pad_sequence(inputs, batch_first=True)
    target_tensor =  torch.nn.utils.rnn.pad_sequence(targets, batch_first=True)
    
    return inp_tensor, target_tensor

In [36]:
from sklearn.model_selection import train_test_split

def get_data_loaders(dataset, batch_size):
    train, valid = train_test_split(dataset)
    train_loader = torch.utils.data.DataLoader(
        train,
        shuffle=True,
        batch_size=batch_size,
        collate_fn=collate_fn
    )
    valid_loader = torch.utils.data.DataLoader(
        valid,
        shuffle=True,
        batch_size=batch_size,
        collate_fn=collate_fn
    )
    return {
        "train": train_loader,
        "valid": valid_loader,
    }

In [37]:
batch_size = 128
loaders = get_data_loaders(dataset, batch_size)

In [38]:
def collate_fn(batch):
    inputs = []
    targets = []
    
    for data_point in batch:
        inputs.append(torch.LongTensor(data_point["input"]))
        targets.append(torch.LongTensor(data_point["target"]))
    
    inp_tensor =  torch.nn.utils.rnn.pad_sequence(inputs, batch_first=True)
    target_tensor =  torch.nn.utils.rnn.pad_sequence(targets, batch_first=True)
    
    return inp_tensor, target_tensor


In [39]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [40]:
def get_mask(target, eos_token=2):
    mask = torch.zeros_like(target).to(device)
    bz, maxlen = target.shape
    for b in range(bz):
        for i in range(maxlen):
            mask[b, i] = 1
            if target[b, i] == eos_token:
                break
    return mask

def calculate_seq_loss(preds, target):
    preds = preds
    target = target[:, 1:]
    assert preds.shape[:2] == target.shape
    mask = get_mask(target)

    losses = []
    for i in range(target.shape[1]):
        loss = F.cross_entropy(preds[:, i], target[:, i], reduction='none')
        losses.append(loss)
    masked_loss = torch.stack(losses, dim=1) * mask
    return masked_loss.sum(axis=1).mean(dim=0)

In [41]:
class DecoderNetwork(nn.Module):
    def __init__(self, device=device, ):
        super().__init__()
        self.device = device
        
    def keep_decoding(self, train, stopped_flag_tensor):
        if train:
            return True
        else:
            return not torch.all(stopped_flag_tensor)
    
    def get_final_inds_tensor(self, hidden):
        bz = hidden.shape[0]
        return (torch.ones([bz], dtype=torch.long) * -1).to(self.device)
    
    def get_stopped_flag_tensor(self, hidden):
        bz = hidden.shape[0]
        return torch.tensor([False] * bz).to(self.device)
    
    def get_bos_tokens(self, hidden):
        bz = hidden.shape[0]
        return self.bos_token * torch.ones([bz], dtype=torch.long).to(self.device)

In [42]:
class Seq2Seq(nn.Module):
    
    def __init__(self, device):
        super().__init__()
        self.device = device
        
        self.input_token2idx = dataset.input_token2idx
        self.input_idx2token = {i: char for char, i in self.input_token2idx.items()}
        
        self.target_token2idx = dataset.target_token2idx
        self.target_idx2token = {i: char for char, i in self.target_token2idx.items()}
        
        self.to(device)
        
    def lemmatize_text(self, text):
        text, lc, rc = self.prepare_for_lemmatization(text)
        inp_list = self.prepare_inference_input(text, lc, rc)
        inp_tensor = torch.LongTensor(inp_list).unsqueeze(0)
        pred, _, _ = self.forward(inp_tensor)
        pred = pred.squeeze()
        length = pred.shape[0]
        inds = []
        
        for i in range(length):
            cur_pred = torch.argmax(pred[i])
            inds.append(cur_pred.item())
        return self.stringify(inds, mode="target")
        
    def stringify(self, inds, mode):
        if mode == "input":
            idx2token = self.input_idx2token
        elif mode == "target":
            idx2token = self.target_idx2token
        else:
            raise ValueError
        chars = []
        for ind in inds:
            chars.append(idx2token[ind])
            if ind == 2:
                break
        return ''.join(chars)
    
    def prepare_inference_input(self, text, lc, rc):
        text = preprocess(text)
        lc = preprocess(lc)
        rc = preprocess(rc)
        
        text_tokenized = [self.input_token2idx[char] for char in text]
        
        lc_tokenized = [self.input_token2idx[char] for char in lc]
        lc_border = [self.input_token2idx["<lc>"]]
        
        rc_tokenized = [self.input_token2idx[char] for char in rc]
        rc_border = [self.input_token2idx["<rc>"]]
        
        return lc_tokenized + lc_border + text_tokenized + rc_border + rc_tokenized
    
    @staticmethod
    def prepare_for_lemmatization(text, sep="|", window_size=25):
        lc, target, rc = text.split(sep)
        if len(lc) > window_size:
            lc = lc[-window_size:]
        if len(rc) > window_size:
            rc = rc[:window_size]
        return target, lc, rc

In [43]:
class EncoderNetwork(nn.Module):
    
    def __init__(self, enc_hid_size, emb_dim, emb_count, p, dec_hid_size, num_layers=2, device=device):
        self.device = device
        super().__init__()
        self.embeddings = nn.Embedding(
            num_embeddings=emb_count,
            embedding_dim=emb_dim,
            padding_idx=0,
        )
        
        self.rnn = nn.GRU(
            input_size=emb_dim,
            hidden_size=enc_hid_size,
            dropout=p,
            bidirectional=True,
            batch_first=True,
            num_layers=num_layers,
        )
        self.dropout = nn.Dropout(p)
        self.fc = nn.Linear(2 * num_layers * enc_hid_size, dec_hid_size)
        
    def get_initial_state(self, inp):
        shape = self.rnn.get_expected_hidden_size(inp, None)
        return torch.zeros(shape).to(self.device)
        
    def forward(self, x):
        x = x.to(device)
        lens = (x != 0).sum(dim=1)
        x = self.embeddings(x)
        packed = torch.nn.utils.rnn.pack_padded_sequence(
            x, 
            lengths=lens, 
            batch_first=True, 
            enforce_sorted=False
        )
        states, last_hidden = self.rnn(packed, self.get_initial_state(x))
        states, lens = torch.nn.utils.rnn.pad_packed_sequence(states, batch_first=True)
        last_hidden = torch.cat([*last_hidden], axis=1)
        
        last_hidden = self.fc(last_hidden)
        last_hidden = torch.tanh(last_hidden)
        last_hidden = self.dropout(last_hidden)
        
        return (states, lens), last_hidden
    
    

In [44]:
class BahdanauAttention(nn.Module):
    
    def __init__(self, encoder_hid_size, decoder_hid_size, attn_units):
        super().__init__()
        self.attn = nn.Linear(2 * encoder_hid_size + decoder_hid_size, attn_units)
        self.V = nn.Linear(attn_units, 1)
        
        self.encoder_states = None
        self.enc_seq_len = None
        
    def get_scores(self, concated_states):
        return self.V(torch.tanh(self.attn(concated_states))).squeeze()
        
    def init_states(self, encoder_states):
        self.encoder_states = encoder_states
        self.encoder_mask = (encoder_states.sum(dim=2) != 0).float()
        self.enc_seq_len = encoder_states.shape[1]
    
    @staticmethod
    def masked_softmax(tensor, mask, dim=1):
        exps = torch.exp(tensor)
        exps = exps * mask
        divider = torch.sum(exps, keepdim=True, dim=dim)
        return exps / divider
        
    def calc_attention(self, dec_hidden):
        
        expanded_dec_hidden = dec_hidden.unsqueeze(1).repeat(1, self.enc_seq_len, 1)
        
        concated_states = torch.cat([self.encoder_states, expanded_dec_hidden], dim=2)
        scores = self.get_scores(concated_states)
        
        weights = self.masked_softmax(scores, self.encoder_mask)
        
        attn_vecs = (self.encoder_states * weights.unsqueeze(2)).sum(dim=1)
        return attn_vecs, weights

In [45]:
class ConditionalDecoder(DecoderNetwork):
    
    def __init__(self, 
                 enc_hid_dim,
                 dec_hid_dim,
                 emb_dim,
                 emb_count, 
                 attn_dim, 
                 maxlen=30,
                 p=0.2, 
                 eos_token=2,
                 bos_token=1,
                 teacher_forcing_rate=0.0,
                 device=device):
        super().__init__(device)
        self.embeddings = nn.Embedding(
            embedding_dim=emb_dim,
            num_embeddings=emb_count,
            padding_idx=0,
        )
        
        self.bos_token = bos_token
        self.eos_token = eos_token
        self.maxlen = maxlen
        
        self.rnn_cell_bottom = nn.GRUCell(emb_dim, dec_hid_dim)
        self.attn_module = BahdanauAttention(enc_hid_dim, dec_hid_dim, attn_dim)
        self.rnn_cell_top = nn.GRUCell(dec_hid_dim, enc_hid_dim * 2)
        
        self.recurrent_fc = nn.Linear(enc_hid_dim * 2, dec_hid_dim)
        
        self.dropout = nn.Dropout(p)
        
        self.teacher_forcing_rate = teacher_forcing_rate
        self.output_proj = nn.Linear(dec_hid_dim, emb_count)
        
        self.recurrent_fc = nn.Linear(enc_hid_dim * 2, dec_hid_dim)
        
        self.relu = nn.ReLU()
        
    def process_input(self, inputs, state):
        embedded = self.embeddings(inputs)
        step1 = self.rnn_cell_bottom(embedded, state)
        step1 = self.dropout(step1)
        
        attn_vec, weights = self.attn_module.calc_attention(step1)
        attn_vec = self.dropout(attn_vec)
        
        step2 = self.rnn_cell_top(step1, attn_vec)
        
        step2 = self.relu(step2)
        step2 = self.recurrent_fc(step2)

        return step2, weights
        
    
    def decode(self, hidden, enc_states, true_labels=None, train=False):
        assert (train and true_labels is not None) or (not train and true_labels is None)
        self.attn_module.init_states(enc_states)
        
        inputs = self.get_bos_tokens(hidden)
        
        state, weights = self.process_input(inputs, hidden)
        attn_weights = [weights]
        states = [state]
        
        final_inds_tensor = self.get_final_inds_tensor(hidden)
        stopped_flag_tensor = self.get_stopped_flag_tensor(hidden)
        
        steps = 1
        
        if train:
            maxlen = true_labels.shape[1] - 1
        else:
            maxlen = self.maxlen
        while steps < maxlen and self.keep_decoding(train, stopped_flag_tensor):
            steps += 1
            if train and torch.rand(1) < self.teacher_forcing_rate:
                inputs = true_labels[:, steps]
            else:
                output_proj = self.output_proj(state)
                inputs = torch.argmax(output_proj, dim=1)
            state, weights = self.process_input(inputs, state)
            attn_weights.append(weights)
            states.append(state)
            
            stopped_flag_tensor = stopped_flag_tensor | (inputs == self.eos_token)
            final_inds_tensor = torch.where(
                stopped_flag_tensor & (final_inds_tensor == -1),
                torch.tensor(steps).to(self.device),
                final_inds_tensor,
            )
        final_inds_tensor = torch.where(final_inds_tensor == -1, 
                                        torch.tensor(steps).to(self.device),
                                        final_inds_tensor)
        
        states = torch.stack(states, dim=1)
        attn_weights = torch.stack(attn_weights, dim=1)
        preds = self.output_proj(states)
        return preds, final_inds_tensor, attn_weights
    

In [46]:
class ConditionalSeq2Seq(Seq2Seq):
    
    def __init__(self, enc_params, dec_params, device=device):
        super().__init__(device)
        self.encoder = EncoderNetwork(**enc_params)
        self.decoder = ConditionalDecoder(**dec_params)
        
    def forward(self, x, true_labels=None, train=False):
        x = x.to(device)
        if true_labels is not None:
            true_labels = true_labels.to(device)
        states, last_hidden =  self.encoder(x)
        preds, lens, attn_weights = self.decoder.decode(last_hidden, states[0], true_labels, train)
        return preds, lens, attn_weights

In [47]:
def unpad_seq(batch_inds, eos_token=2):
    result = []
    cur_seq = []
    for point in batch_inds:
        for ind in point:
            if ind != 2:
                cur_seq.append(ind.item())
            else:
                break
        result.append(cur_seq)
        cur_seq = []
    return result


def calc_accuracy(preds, target):
    pred_inds = preds.argmax(dim=2)
    target = target[:, 1:]
    pred_seqs = unpad_seq(pred_inds)
    true_seqs = unpad_seq(target)
    goods = 0
    for y_pred, y_true in zip(pred_seqs, true_seqs):
        if y_pred == y_true:
            goods += 1
    return goods / pred_inds.shape[0]

In [48]:
def validate_model(model):
    model.eval()
    losses = []
    accs = []
    for i, batch in enumerate(loaders['valid']):
        inp, target = batch
        target = target.to(device)
        preds, lens, _ = model(inp, true_labels=target, train=True)
            
        loss = calculate_seq_loss(preds, target)
        losses.append(loss.item())
            
        acc = calc_accuracy(preds, target)
        accs.append(acc)
    return sum(losses) / len(losses), sum(accs) / len(accs)
    

In [49]:
def train_model(model, NUM_EPOCHS=1, clip_value=1, lr=1e-3):
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1, gamma=0.15)
    for k in range(NUM_EPOCHS):
        print(f"Starting epoch {k}")
        model.train()
        losses = []
        accs = []
        for i, batch in enumerate(loaders['train']):
            optimizer.zero_grad()
            inp, target = batch
            target = target.to(device)
            preds, lens, _ = model(inp, true_labels=target, train=True)
            
            loss = calculate_seq_loss(preds, target)
            losses.append(loss.item())
            
            acc = calc_accuracy(preds, target)
            accs.append(acc)            
            
            loss.backward()            
            
            torch.nn.utils.clip_grad_norm_(model.parameters(), clip_value)
            optimizer.step()
            
            if i % 100 == 0:
                print(f"train loss step {i} = {sum(losses) / len(losses)}")
                print(f"train accs step {i} = {sum(accs) / len(accs)}")
                print()
                losses = []
                accs = []
        valid_loss, valid_acc = validate_model(model)
        print(f"valid loss epoch {k} = {valid_loss}")
        print(f"valid accs epoch {k} = {valid_acc}")
        lr_scheduler.step()
    return model


In [50]:
enc_params = {
    "enc_hid_size": 256, 
    "dec_hid_size": 256,
    "emb_dim": 200,
    "emb_count": len(dataset.input_token2idx),
    "num_layers": 2,
    "p": 0.2, 
}

cond_dec_params = {
    "enc_hid_dim": 256,
    "dec_hid_dim": 256,
    "emb_dim": 200, 
    "emb_count": len(dataset.target_token2idx),
    "attn_dim": 100,
    "p": 0.2,
    "teacher_forcing_rate": 0.4,
}

In [51]:
cond_model = ConditionalSeq2Seq(enc_params, cond_dec_params).to(device)

In [52]:
trained_cond_model = train_model(cond_model, NUM_EPOCHS=4)

Starting epoch 0
train loss step 0 = 29.791912078857422
train accs step 0 = 0.0

train loss step 100 = 17.82345296859741
train accs step 100 = 0.0071875

train loss step 200 = 12.426255202293396
train accs step 200 = 0.09875

train loss step 300 = 5.477972460985184
train accs step 300 = 0.4196875

train loss step 400 = 3.701076601743698
train accs step 400 = 0.55015625

train loss step 500 = 2.769656639099121
train accs step 500 = 0.64515625

train loss step 600 = 2.2979648685455323
train accs step 600 = 0.698984375

train loss step 700 = 1.9486912441253663
train accs step 700 = 0.730703125

train loss step 800 = 1.8189603686332703
train accs step 800 = 0.75828125

train loss step 900 = 1.4886805638670921
train accs step 900 = 0.77859375

train loss step 1000 = 1.8013139227032662
train accs step 1000 = 0.771875

train loss step 1100 = 1.4700160697102547
train accs step 1100 = 0.794140625

train loss step 1200 = 1.4347372949123383
train accs step 1200 = 0.796953125

train loss step 1300

train loss step 2300 = 0.30351193368434903
train accs step 2300 = 0.944375

train loss step 2400 = 0.30785492226481437
train accs step 2400 = 0.944375

train loss step 2500 = 0.3070417229086161
train accs step 2500 = 0.938046875

train loss step 2600 = 0.29593746431171897
train accs step 2600 = 0.939375

train loss step 2700 = 0.31732655052095654
train accs step 2700 = 0.94015625

train loss step 2800 = 0.28621554709970953
train accs step 2800 = 0.9425

train loss step 2900 = 0.2961538009345531
train accs step 2900 = 0.941640625

train loss step 3000 = 0.343190878406167
train accs step 3000 = 0.94234375

train loss step 3100 = 0.28141997385770084
train accs step 3100 = 0.941015625

train loss step 3200 = 0.2930500101670623
train accs step 3200 = 0.94125

train loss step 3300 = 0.3049010809883475
train accs step 3300 = 0.939765625

train loss step 3400 = 0.2898872723430395
train accs step 3400 = 0.942421875

train loss step 3500 = 0.30787438694387675
train accs step 3500 = 0.938828125



In [123]:
teacher_model.lemmatize_text("привет почему ты так |думаешь|?")

'думать<EOS>'

In [53]:
loss, acc = validate_model(teacher_model)

In [61]:
acc

0.9533435519215626

In [62]:
loss

0.2426984647249171