## Data Preparation

In [1]:
from torchtext.datasets import TranslationDataset

class GEC_DATASET(TranslationDataset):
    """Translation-style dataset used for the grammatical error correction corpus.

    Only ``splits`` is customised: it takes a mandatory ``root`` directory and
    defaults the split names to ``train`` / ``dev`` / ``test``.
    """

    @classmethod
    def splits(cls, exts, fields, root,
               train="train", validation="dev", test="test", **kwargs):
        """Delegate to ``TranslationDataset.splits`` with ``root`` used as the data path.

        Args:
            exts: file extensions of the source and target side.
            fields: fields applied to the source and target side.
            root: directory forwarded to the parent as both ``path`` and ``root``.
            train, validation, test: names of the individual splits.
            **kwargs: forwarded unchanged to the parent implementation.

        Returns:
            Whatever the parent ``splits`` returns.
        """
        return super().splits(
            exts=exts,
            fields=fields,
            path=root,
            root=root,
            train=train,
            validation=validation,
            test=test,
            **kwargs,
        )

In [2]:
from torchtext.data import Field
from torchtext.data import BucketIterator

def prepare_data(root):
    """Create fields, datasets and bucketed iterators for the corpus under ``root``.

    Reads the module-level settings ``MAX_LEN``, ``MIN_FREQ``, ``BATCH_SIZE``
    and ``DEVICE``.

    Args:
        root: directory passed to ``GEC_DATASET.splits``.

    Returns:
        ``(src_field, trg_field, train_iter, valid_iter, test_iter,
        train_set, valid_set, test_set)``.
    """
    def split_on_space(sentence):
        # The corpus is pre-tokenised; a plain single-space split is enough.
        return sentence.split(' ')

    def within_max_len(example):
        # Keep only pairs whose source AND target are shorter than MAX_LEN.
        return len(example.src) < MAX_LEN and len(example.trg) < MAX_LEN

    # TODO (open question from the author): should the text be lower-cased?
    field_options = dict(init_token="<sos>", eos_token="<eos>", tokenize=split_on_space)
    src_field = Field(**field_options)
    trg_field = Field(**field_options)

    train_set, valid_set, test_set = GEC_DATASET.splits(
        exts=('.src', '.trg'),
        fields=(src_field, trg_field),
        root=root,
        filter_pred=within_max_len,
    )

    # Both vocabularies are built from the training split only.
    for field in (src_field, trg_field):
        field.build_vocab(train_set, min_freq=MIN_FREQ)

    train_iter, valid_iter, test_iter = BucketIterator.splits(
        datasets=(train_set, valid_set, test_set),
        batch_size=BATCH_SIZE,
        device=DEVICE,
    )

    return (src_field, trg_field,
            train_iter, valid_iter, test_iter,
            train_set, valid_set, test_set)

## Building the Model

### Encoder

In [3]:
import torch.nn as nn
import torch.nn.functional as F

class Encoder(nn.Module):
    """Bidirectional GRU encoder.

    Layer sizes come from the module-level ``ENC_EMB_DIM``, ``ENC_HID_DIM``,
    ``ENC_DROPOUT`` and ``DEC_HID_DIM`` settings.
    """

    def __init__(self, src_vocab_size):
        super().__init__()

        self.embedding = nn.Embedding(src_vocab_size, ENC_EMB_DIM)
        self.dropout = nn.Dropout(ENC_DROPOUT)
        self.gru = nn.GRU(ENC_EMB_DIM, ENC_HID_DIM, bidirectional=True)
        # Maps the two concatenated final GRU states to the decoder hidden size.
        self.fc = nn.Linear(ENC_HID_DIM * 2, DEC_HID_DIM)

    def forward(self, inputs):
        """Encode a batch of source token ids.

        Args:
            inputs: [src_len, batch_size]

        Returns:
            outputs: [src_len, batch_size, enc_hid_dim * 2]
            hidden:  [batch_size, dec_hid_dim]
        """
        # token_vectors: [src_len, batch_size, enc_emb_dim]
        token_vectors = self.embedding(inputs)
        token_vectors = self.dropout(token_vectors)

        # final_states: [2, batch_size, enc_hid_dim] (one entry per direction)
        outputs, final_states = self.gru(token_vectors)

        # Concatenate both final states and squash them into the decoder's
        # initial hidden state.
        merged_state = torch.cat((final_states[0], final_states[1]), dim=1)
        hidden = torch.tanh(self.fc(merged_state))

        return outputs, hidden

### Attention

In [4]:
class Attn(nn.Module):
    """Additive attention over the encoder outputs.

    Layer sizes come from the module-level ``DEC_HID_DIM``, ``ENC_HID_DIM``
    and ``ATTN_V_DIM`` settings.
    """

    def __init__(self):
        super().__init__()

        self.fc = nn.Linear(DEC_HID_DIM + ENC_HID_DIM * 2, ATTN_V_DIM)
        self.v = nn.Parameter(torch.rand(1, ATTN_V_DIM))

    def forward(self, decoder_hidden, encoder_outputs):
        """Compute attention weights for one decoding step.

        Args:
            decoder_hidden:  [batch_size, dec_hid_dim]
            encoder_outputs: [src_len, batch_size, enc_hid_dim * 2]

        Returns:
            attn: [batch_size, src_len], softmax-normalised over ``src_len``.
        """
        src_len = encoder_outputs.size(0)

        # Pair the decoder state with every encoder position (batch first).
        tiled_hidden = decoder_hidden.unsqueeze(1).repeat(1, src_len, 1)
        batch_first_outputs = encoder_outputs.permute(1, 0, 2)

        # energy: [batch_size, src_len, attn_v_dim]
        energy = torch.tanh(
            self.fc(torch.cat((tiled_hidden, batch_first_outputs), dim=2)))

        # Project every energy vector onto v -> scores: [batch_size, 1, src_len]
        batch_size = energy.size(0)
        tiled_v = self.v.unsqueeze(0).repeat(batch_size, 1, 1)
        scores = tiled_v.bmm(energy.permute(0, 2, 1))

        attn = F.softmax(scores, dim=2).squeeze(1)

        return attn

### Decoder

In [5]:
class Decoder(nn.Module):
    """Single-step GRU decoder with attention over the encoder outputs.

    Layer sizes come from the module-level ``DEC_EMB_DIM``, ``DEC_HID_DIM``,
    ``DEC_DROPOUT`` and ``ENC_HID_DIM`` settings.
    """

    def __init__(self, trg_vocab_size):
        super().__init__()

        self.embedding = nn.Embedding(trg_vocab_size, DEC_EMB_DIM)
        self.dropout = nn.Dropout(DEC_DROPOUT)

        self.attn = Attn()

        # The GRU consumes the embedded token together with the context vector.
        self.gru = nn.GRU(DEC_EMB_DIM + ENC_HID_DIM * 2, DEC_HID_DIM)

        # The output layer sees the embedding, the context and the new hidden state.
        self.fc = nn.Linear(DEC_EMB_DIM + ENC_HID_DIM * 2 + DEC_HID_DIM, trg_vocab_size)

    def forward(self, last_output, decoder_hidden, encoder_outputs):
        """Run one decoding step.

        Args:
            last_output:     [batch_size] ids of the previously emitted tokens.
            decoder_hidden:  [batch_size, dec_hid_dim]
            encoder_outputs: [src_len, batch_size, enc_hid_dim * 2]

        Returns:
            decoder_outputs: [batch_size, trg_vocab_size] unnormalised scores.
            decoder_hidden:  [batch_size, dec_hid_dim] updated hidden state.
            attn:            [batch_size, src_len] attention weights.
        """
        # embedded: [batch_size, dec_emb_dim]
        embedded = self.embedding(last_output)
        embedded = self.dropout(embedded)

        # attn: [batch_size, src_len]
        attn = self.attn(decoder_hidden, encoder_outputs)

        # Attention-weighted sum of encoder outputs
        # -> context: [batch_size, enc_hid_dim * 2]
        batch_first_outputs = encoder_outputs.permute(1, 0, 2)
        context = attn.unsqueeze(1).bmm(batch_first_outputs).squeeze(1)

        # One GRU step; both input and state get a leading length-1 axis.
        gru_input = torch.cat((embedded.unsqueeze(0), context.unsqueeze(0)), dim=2)
        _, new_hidden = self.gru(gru_input, decoder_hidden.unsqueeze(0))
        new_hidden = new_hidden.squeeze(0)

        # decoder_outputs: [batch_size, trg_vocab_size]
        decoder_outputs = self.fc(torch.cat((embedded, context, new_hidden), dim=1))

        return decoder_outputs, new_hidden, attn

### Seq2Seq

In [6]:
import random

class Seq2Seq(nn.Module):
    """Encoder-decoder model that decodes the target sequence step by step."""

    def __init__(self, src_vocab_size, trg_vocab_size):
        super().__init__()

        self.encoder = Encoder(src_vocab_size)
        self.decoder = Decoder(trg_vocab_size)

        self.trg_vocab_size = trg_vocab_size

    def forward(self, inputs, trgs, teacher_forcing_ratio=0.5):
        """Encode ``inputs`` and decode ``trgs.size(0) - 1`` steps.

        The source and target lengths do not have to match.

        Args:
            inputs: [src_len, batch_size] source token ids.
            trgs:   [trg_len, batch_size] target token ids; ``trgs[0]`` is fed
                as the first decoder input.
            teacher_forcing_ratio: probability of feeding the ground-truth
                token ``trgs[t]`` (instead of the model's own argmax) as the
                next decoder input. ``0`` always uses the model prediction.

        Returns:
            outputs: [trg_len, batch_size, trg_vocab_size]. ``outputs[0]`` is
            never written and stays all zeros.
        """
        # encoder_outputs: [src_len, batch_size, enc_hid_dim * 2]
        # decoder_hidden:  [batch_size, dec_hid_dim]
        encoder_outputs, decoder_hidden = self.encoder(inputs)

        trg_len, batch_size = trgs.size(0), trgs.size(1)

        # Allocate on the batch's own device rather than a module-level DEVICE
        # global, so the model works wherever its inputs live.
        outputs = torch.zeros(trg_len, batch_size, self.trg_vocab_size,
                              device=trgs.device)

        decoder_input = trgs[0]
        for t in range(1, trg_len):
            # step_scores: [batch_size, trg_vocab_size]
            step_scores, decoder_hidden, _ = self.decoder(decoder_input,
                                                          decoder_hidden,
                                                          encoder_outputs)
            outputs[t] = step_scores

            # One random draw per step decides between prediction and truth.
            use_prediction = teacher_forcing_ratio <= random.random()
            decoder_input = step_scores.argmax(dim=1) if use_prediction else trgs[t]

        return outputs

In [7]:
def init_weights(m):
    """Initialise the parameters of ``m``: biases to 0, everything else ~ N(0, 0.01).

    Intended for ``model.apply(init_weights)``.

    Parameters are classified by name. Only names containing ``'bias'`` are
    zeroed; any other parameter gets the normal initialisation. This includes
    bare ``nn.Parameter`` attributes such as an attention vector named ``v``,
    which a ``'weight' in name`` test would wrongly set to all zeros.

    Args:
        m: an ``nn.Module``; all of its (recursive) named parameters are reset.
    """
    for name, param in m.named_parameters():
        if 'bias' in name:
            nn.init.constant_(param.data, 0)
        else:
            nn.init.normal_(param.data, mean=0, std=0.01)

## Training and Validating

In [8]:
def _train(train_iter, model, criterion, optimizer):
    """Run one training epoch and return the mean per-batch loss.

    Gradients are clipped to the module-level ``CLIP`` norm before each
    optimizer step.

    Args:
        train_iter: iterable of batches exposing ``.src`` and ``.trg``.
        model: called as ``model(srcs, trgs)``.
        criterion: loss taking flattened scores and flattened target ids.
        optimizer: optimizer over the model parameters.

    Returns:
        Sum of the batch losses divided by ``len(train_iter)``.
    """
    model.train()

    running_loss = 0
    for batch in train_iter:
        srcs, trgs = batch.src, batch.trg

        scores = model(srcs, trgs)

        # Position 0 is skipped on both sides: the model never writes
        # outputs[0], and trgs[0] is only the first decoder input.
        vocab_size = scores.size()[-1]
        loss = criterion(scores[1:].view(-1, vocab_size),
                         trgs[1:].view(-1))

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), CLIP)

        optimizer.step()
        optimizer.zero_grad()

        running_loss += loss.item()

    return running_loss / len(train_iter)

In [9]:
def _evaluate(data_iter, model, criterion):
    
    eval_loss = 0
    
    model.eval()
    
    with torch.no_grad():
        for batch in data_iter:

            # Gets data.
            srcs = batch.src
            trgs = batch.trg

            # Forward.
            outputs = model(srcs, trgs, 0)

            # Loss.
            loss = criterion(outputs[1:].view(-1, outputs.size()[-1]), 
                             trgs[1:].view(-1))  #m

            eval_loss += loss.item()
        
        return eval_loss / len(data_iter)

In [10]:
def time_track(start, end):
    """Format the time between ``start`` and ``end`` (seconds) as ``"MMmins SSsecs"``.

    Minutes and seconds are truncated to whole numbers and right-aligned in a
    field of width 2.
    """
    elapsed = end - start
    minutes, seconds = int(elapsed / 60), int(elapsed % 60)

    return f"{minutes:>2}mins {seconds:>2}secs"

In [11]:
import time
import math
import copy

def train(train_iter, valid_iter, model, criterion, optimizer):
    """Train for ``N_EPOCHS`` epochs, checkpointing the best model to ``PT``.

    After every epoch the train/validation loss and perplexity are printed.
    The model's ``state_dict`` is saved whenever the validation loss drops
    below the best value seen so far.
    """
    best_valid_loss = float("inf")

    for epoch_number in range(1, N_EPOCHS + 1):
        started_at = time.time()

        train_loss = _train(train_iter, model, criterion, optimizer)
        valid_loss = _evaluate(valid_iter, model, criterion)

        finished_at = time.time()

        print(f"epoch: {epoch_number:02}, time: {time_track(started_at, finished_at)}")
        print(f"train loss: {train_loss:.3f}, train ppl: {math.exp(train_loss):.3f}")
        print(f"valid loss: {valid_loss:.3f}, valid ppl: {math.exp(valid_loss):.3f}")

        # Keep only the weights with the lowest validation loss so far.
        if valid_loss < best_valid_loss:
            best_valid_loss = valid_loss
            torch.save(model.state_dict(), PT)

    print()

## Testing

In [12]:
def test(test_iter, model, criterion):
    """Load the checkpoint stored at ``PT`` into ``model`` and print its test loss/perplexity."""
    best_weights = torch.load(PT)
    model.load_state_dict(best_weights)

    test_loss = _evaluate(test_iter, model, criterion)
    test_ppl = math.exp(test_loss)

    print(f"test loss: {test_loss:.3f}, test ppl: {test_ppl:.3f}")
    print()

## Inference  #+

In [13]:
def correct_sentence(sentence, src_field, trg_field, model):
    """Greedily decode a correction for one sentence.

    Reads the module-level ``DEVICE`` and ``MAX_LEN`` settings.

    Args:
        sentence: a space-separated string, or an already tokenised list.
        src_field, trg_field: fields whose vocabularies map tokens <-> ids.
        model: object exposing ``encoder`` and ``decoder`` sub-modules.

    Returns:
        tokens: predicted tokens without the leading init token; the last one
            is the eos token unless decoding stopped after ``MAX_LEN`` steps.
        attentions: tensor with one ``[1, src_len]`` attention row per
            predicted token.
    """
    model.eval()

    tokens = sentence.split(' ') if isinstance(sentence, str) else sentence
    tokens = [src_field.init_token] + tokens + [src_field.eos_token]

    src_indexes = [src_field.vocab.stoi[token] for token in tokens]
    # Column vector: [src_len, batch_size=1]
    src_tensor = torch.tensor(src_indexes, dtype=torch.long, device=DEVICE).unsqueeze(1)

    eos_index = trg_field.vocab.stoi[trg_field.eos_token]
    last_index = trg_field.vocab.stoi[trg_field.init_token]

    predicted_indexes = []
    attentions = torch.zeros(MAX_LEN, 1, len(src_indexes), device=DEVICE)

    with torch.no_grad():
        encoder_outputs, hidden = model.encoder(src_tensor)

        for step in range(MAX_LEN):
            step_input = torch.tensor([last_index], dtype=torch.long, device=DEVICE)
            output, hidden, attention = model.decoder(step_input, hidden, encoder_outputs)

            last_index = output.argmax(1).item()
            predicted_indexes.append(last_index)
            attentions[step] = attention

            if last_index == eos_index:
                break

    predicted_tokens = [trg_field.vocab.itos[index] for index in predicted_indexes]

    return predicted_tokens, attentions[:len(predicted_tokens)]

In [14]:
def display_attention(sentence, correction, attention):
    """Placeholder for an attention visualisation; not implemented, does nothing."""
    pass

In [15]:
import random

def inference(dataset, src_field, trg_field, model):
    """Print source, reference and model output for 10 randomly drawn examples.

    Args:
        dataset: object exposing ``examples`` (indexable) and ``len()``; every
            example carries ``src`` and ``trg`` token lists.
        src_field, trg_field: fields forwarded to ``correct_sentence``.
        model: model forwarded to ``correct_sentence``.

    Raises:
        ValueError: if ``dataset`` is empty (nothing to sample from).
    """
    for _ in range(10):
        # randrange draws from 0 .. len-1. randint(1, len) is inclusive at both
        # ends, so it could index one past the end and never picked example 0.
        example_idx = random.randrange(len(dataset))
        example = dataset.examples[example_idx]

        src = vars(example)['src']
        trg = vars(example)['trg']

        print(f"src = {' '.join(src)}")
        print(f"trg = {' '.join(trg)}")

        correction, _ = correct_sentence(src, src_field, trg_field, model)

        # Drop the trailing eos token only if it is really there; decoding
        # that stopped at the length limit ends in a regular token.
        if correction and correction[-1] == trg_field.eos_token:
            correction = correction[:-1]

        print(f"out = {' '.join(correction)}")

        print('---')

        # Attention plotting is disabled: display_attention is still a stub.

    print()

## BLEU #+

In [16]:
from torchtext.data.metrics import bleu_score

def calculate_bleu(data, src_field, trg_field, model):
    """Compute the corpus BLEU of the model's corrections over ``data``.

    Args:
        data: iterable of examples carrying ``src`` and ``trg`` token lists.
        src_field, trg_field: fields forwarded to ``correct_sentence``.
        model: model forwarded to ``correct_sentence``.

    Returns:
        The value returned by ``bleu_score`` for the collected hypotheses,
        with one single-reference list per example.
    """
    trgs = []
    outs = []

    for datum in data:
        src = vars(datum)['src']
        trg = vars(datum)['trg']

        out, _ = correct_sentence(src, src_field, trg_field, model)

        # Strip the trailing eos token only when it is present. A hypothesis
        # cut off at the length limit ends in a real token, which must stay.
        if out and out[-1] == trg_field.eos_token:
            out = out[:-1]

        outs.append(out)
        trgs.append([trg])

    return bleu_score(outs, trgs)

## Main

In [17]:
import torch
import torch.optim as optim

if __name__ == '__main__':
    # End-to-end run: build the data pipeline, construct the model, train it
    # and report the test loss. The upper-case names defined here are
    # module-level settings that the functions and classes above read as
    # globals, so each must be assigned before the call that uses it.
    DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    DATASET = "../data/fce with bpe"
    
    # Data: prepare_data reads MAX_LEN, MIN_FREQ, BATCH_SIZE and DEVICE.
    ROOT = f"{DATASET}/parallel"
    MAX_LEN = 200
    MIN_FREQ = 1
    BATCH_SIZE = 16

    src_field, trg_field, train_iter, valid_iter, test_iter, train_set, valid_set, test_set = prepare_data(root=ROOT)
    
    # Model hyper-parameters (read by Encoder, Attn and Decoder).
    ENC_EMB_DIM = 256
    ENC_HID_DIM = 512
    # NOTE: the 0.5 assignment is immediately overridden; the effective
    # encoder dropout is 0.
    ENC_DROPOUT = 0.5
    ENC_DROPOUT = 0

    DEC_EMB_DIM = 256
    DEC_HID_DIM = 512
    # NOTE: same as above; the effective decoder dropout is 0.
    DEC_DROPOUT = 0.5
    DEC_DROPOUT = 0
    
    ATTN_V_DIM = DEC_HID_DIM  # attention projection size tied to the decoder hidden size
    
    model = Seq2Seq(len(src_field.vocab), len(trg_field.vocab)).to(DEVICE)
    model.apply(init_weights)
    
    # Criterion: positions holding the target "<pad>" index are ignored.
    criterion = nn.CrossEntropyLoss(ignore_index=trg_field.vocab.stoi["<pad>"])
    
    # Optimizer.
    LR = 0.0003
    
    optimizer = optim.Adam(model.parameters(), lr=LR)  # author's note: a lower lr seems to make valid_loss change more slowly, which may make the optimum easier to hit (open question)
    
    # Training and validation: train() reads N_EPOCHS and PT, _train() reads CLIP.
    N_EPOCHS = 10
    CLIP = 1  # max gradient norm passed to clip_grad_norm_
    PT = f"{DATASET}.pt"  # checkpoint path written by train() and read by test()
    
    train(train_iter, valid_iter, model, criterion, optimizer)
    
    # Testing: loads the checkpoint at PT and prints the test loss/perplexity.
    test(test_iter, model, criterion)

epoch: 01, time:  4mins 20secs
train loss: 5.119, train ppl: 167.152
valid loss: 4.274, valid ppl: 71.784
epoch: 02, time:  4mins 27secs
train loss: 3.303, train ppl: 27.200
valid loss: 3.059, valid ppl: 21.311
epoch: 03, time:  4mins 26secs
train loss: 2.178, train ppl: 8.830
valid loss: 2.618, valid ppl: 13.707
epoch: 04, time:  4mins 23secs
train loss: 1.646, train ppl: 5.189
valid loss: 2.464, valid ppl: 11.756
epoch: 05, time:  4mins 24secs
train loss: 1.323, train ppl: 3.754
valid loss: 2.346, valid ppl: 10.443
epoch: 06, time:  4mins 33secs
train loss: 1.090, train ppl: 2.975
valid loss: 2.287, valid ppl: 9.843
epoch: 07, time:  4mins 24secs
train loss: 0.919, train ppl: 2.506
valid loss: 2.332, valid ppl: 10.299
epoch: 08, time:  4mins 22secs
train loss: 0.783, train ppl: 2.188
valid loss: 2.327, valid ppl: 10.251
epoch: 09, time:  4mins 25secs
train loss: 0.679, train ppl: 1.973
valid loss: 2.372, valid ppl: 10.717
epoch: 10, time:  4mins 25secs
train loss: 0.581, train ppl: 1

In [18]:
# Qualitative check: print 10 randomly drawn examples from the training split.
print("training inference")
inference(train_set, src_field, trg_field, model)

training inference
src = It is a really great pleasure going in and out from different shops , try on a variety of clothes and shoes , and finally go back home with do@@ zens of bags in your hands .
trg = It is a really great pleasure going in and out of different shops , trying on a variety of clothes and shoes , and finally going back home with do@@ zens of bags in your hands .
out = It is a really great pleasure going in and out from different shops , try on a variety of clothes and shoes and finally finally back home with do@@ zens of bags in your hands .
---
src = Yours sincerely ,
trg = Yours sincerely ,
out = Yours sincerely ,
---
src = Teresa .
trg = Teresa .
out = Teresa .
---
src = In the 21st century , people will all the wearing the same clothes , the same shoes , they will be buying the same accessories .
trg = In the 21st century , people will all the wearing the same clothes , the same shoes , they will be buying the same accessories .
out = In the 21st century , people 

In [19]:
# Qualitative check: print 10 randomly drawn examples from the validation split.
print("validating inference")
inference(valid_set, src_field, trg_field, model)

validating inference
src = Yours sincerely
trg = Yours sincerely
out = Yours sincerely
---
src = So , I hope I will get my money back , it was very disappointing evening out in my life ! I hope you understand me !
trg = So , I hope I will get my money back . It was a very disappointing evening in my life ! I hope you understand me !
out = So , I hope I will get my money back . It was very disappointing evening out out of my life ! I hope you understand me !
---
src = The question is : Do I need some extra money and what kind of clothes should I take for this trip ?
trg = The question is : Do I need some extra money and what kind of clothes should I take for this trip ?
out = The question is : I do I need some extra money and what kind of clothes should I take for this trip ?
---
src = Then , we can go to the show during that period
trg = Then , we can go to the show during that period .
out = Then , we can go to the show during that time ?
---
src = JUST BE Y@@ OUR@@ S@@ EL@@ F
trg = J

In [20]:
# Qualitative check: print 10 randomly drawn examples from the test split.
print("testing inference")
inference(test_set, src_field, trg_field, model)

testing inference
src = In the past the people did n't have elec@@ tic@@ ity and if they wanted for example to read or to cook something they used to do in the fire .
trg = In the past people did n't have electricity and if they wanted , for example , to read or to cook something they used to light a fire .
out = In the past the people did n't have a e@@ and and they they wanted for example to read or to cook something they used to do in the fire .
---
src = The second one is the la@@ mp , the electricity that is very important in our life .
trg = The second one is the la@@ mp , the electricity that is very important in our life .
out = The second one is the la@@ circumstances , the is that is very important in our life .
---
src = I am writing to answer your question about the information on an interesting building to visit .
trg = I am writing to answer your question about an interesting building to visit .
out = I am writing to answer your question about the information on an intere

In [21]:
# Corpus BLEU over the test split, reported as a percentage.
bleu = calculate_bleu(test_set, src_field, trg_field, model)
print(f'BLEU score = {bleu*100:.2f}')

BLEU score = 52.12
