## Data Preparation

In [1]:
from torchtext.datasets import TranslationDataset

class GEC_DATASET(TranslationDataset):
    """Parallel source/target dataset for grammatical error correction.

    A thin wrapper around ``TranslationDataset`` that changes the default
    validation split name to ``"dev"`` and takes one ``root`` directory,
    which is forwarded as both ``path`` and ``root`` to the parent ``splits``.
    """

    @classmethod
    def splits(cls, exts, fields, root,
               train="train", validation="dev", test="test",
               **kwargs):
        """Load the train / validation / test splits located under ``root``.

        Args:
            exts: File extensions of the two sides of the parallel corpus
                (the notebook passes ``('.src', '.trg')``).
            fields: The fields used to process the two sides, in the same
                order as ``exts``.
            root: Directory holding the split files; forwarded to the parent
                as both ``path`` and ``root``.
            train, validation, test: Names of the individual splits.
            **kwargs: Passed through unchanged (e.g. ``filter_pred``).

        Returns:
            Whatever ``TranslationDataset.splits`` returns.
        """
        return super().splits(
            exts=exts,
            fields=fields,
            path=root,
            root=root,
            train=train,
            validation=validation,
            test=test,
            **kwargs,
        )

In [2]:
from torchtext.data import Field
from torchtext.data import BucketIterator

def prepare_data(root):
    """Load the parallel corpus under ``root`` and build vocabularies and iterators.

    Reads the module-level settings ``MAX_LEN``, ``MIN_FREQ``, ``BATCH_SIZE``
    and ``DEVICE``, so those must be defined before this is called.

    Args:
        root: Directory containing the ``train`` / ``dev`` / ``test`` files
            with ``.src`` and ``.trg`` extensions.

    Returns:
        An 8-tuple ``(src_field, trg_field, train_iter, valid_iter, test_iter,
        train_set, valid_set, test_set)``.
    """

    def split_on_space(sentence):
        # The corpus is pre-tokenised; tokens are separated by single spaces.
        return sentence.split(' ')

    def short_enough(example):
        # Keep only pairs where both sides are strictly shorter than MAX_LEN.
        sides = vars(example)
        return len(sides['src']) < MAX_LEN and len(sides['trg']) < MAX_LEN

    # Open question left by the author: should the fields lowercase the text?
    src_field = Field(init_token="<sos>", eos_token="<eos>", tokenize=split_on_space)
    trg_field = Field(init_token="<sos>", eos_token="<eos>", tokenize=split_on_space)

    train_set, valid_set, test_set = GEC_DATASET.splits(
        exts=('.src', '.trg'),
        fields=(src_field, trg_field),
        root=root,
        filter_pred=short_enough,
    )

    # Both vocabularies are built from the training split only.
    for field in (src_field, trg_field):
        field.build_vocab(train_set, min_freq=MIN_FREQ)

    train_iter, valid_iter, test_iter = BucketIterator.splits(
        datasets=(train_set, valid_set, test_set),
        batch_size=BATCH_SIZE,
        device=DEVICE,
    )

    return (src_field, trg_field,
            train_iter, valid_iter, test_iter,
            train_set, valid_set, test_set)

## Building the Model

### Encoder

In [3]:
import torch.nn as nn
import torch.nn.functional as F

class Encoder(nn.Module):
    """Bidirectional GRU encoder.

    Embeds the source tokens, runs them through a bidirectional GRU and
    projects the two final direction states into the decoder's initial
    hidden state. Layer sizes come from the module-level ``ENC_EMB_DIM``,
    ``ENC_HID_DIM``, ``ENC_DROPOUT`` and ``DEC_HID_DIM`` settings.
    """

    def __init__(self, src_vocab_size):
        """Create the layers for a source vocabulary of ``src_vocab_size`` entries."""
        super().__init__()

        self.embedding = nn.Embedding(src_vocab_size, ENC_EMB_DIM)
        self.dropout = nn.Dropout(ENC_DROPOUT)
        self.gru = nn.GRU(ENC_EMB_DIM, ENC_HID_DIM, bidirectional=True)
        # Maps the concatenated direction states to the decoder hidden size.
        self.fc = nn.Linear(ENC_HID_DIM * 2, DEC_HID_DIM)

    def forward(self, inputs):
        """Encode a batch of source sequences.

        Args:
            inputs: Token indices, ``[src_len, batch_size]``.

        Returns:
            ``(outputs, hidden)`` where ``outputs`` is
            ``[src_len, batch_size, enc_hid_dim * 2]`` and ``hidden`` is
            ``[batch_size, dec_hid_dim]``.
        """
        # embedded: [src_len, batch_size, enc_emb_dim]
        embedded = self.embedding(inputs)
        embedded = self.dropout(embedded)

        # outputs: [src_len, batch_size, enc_hid_dim * 2]
        # final_states: [2, batch_size, enc_hid_dim] (one entry per direction)
        outputs, final_states = self.gru(embedded)

        # Join both direction states along the feature axis, then squash them
        # into the decoder's hidden size: [batch_size, dec_hid_dim].
        joined = torch.cat((final_states[0], final_states[1]), dim=1)
        hidden = torch.tanh(self.fc(joined))

        return outputs, hidden

### Attention

In [4]:
class Attn(nn.Module):
    """Additive attention over the encoder outputs.

    Scores every source position from the current decoder hidden state and
    the encoder output at that position, then normalises the scores with a
    softmax. Sizes come from the module-level ``DEC_HID_DIM``,
    ``ENC_HID_DIM`` and ``ATTN_V_DIM`` settings.
    """

    def __init__(self):
        super().__init__()

        self.fc = nn.Linear(DEC_HID_DIM + ENC_HID_DIM * 2, ATTN_V_DIM)
        # Learned vector that reduces each energy vector to a scalar score.
        self.v = nn.Parameter(torch.rand(1, ATTN_V_DIM))

    def forward(self, decoder_hidden, encoder_outputs):
        """Compute attention weights for one decoding step.

        Args:
            decoder_hidden: ``[batch_size, dec_hid_dim]``.
            encoder_outputs: ``[src_len, batch_size, enc_hid_dim * 2]``.

        Returns:
            Attention weights, ``[batch_size, src_len]``.
        """
        src_len = encoder_outputs.size(0)

        # Copy the decoder state once per source position and put the encoder
        # outputs batch-first so the two can be concatenated feature-wise.
        hidden_per_position = decoder_hidden.unsqueeze(1).repeat(1, src_len, 1)
        encoder_batch_first = encoder_outputs.permute(1, 0, 2)

        # energy: [batch_size, src_len, attn_v_dim]
        energy = torch.tanh(
            self.fc(torch.cat((hidden_per_position, encoder_batch_first), dim=2)))

        # v_batched: [batch_size, 1, attn_v_dim]
        # scores:    [batch_size, 1, src_len]
        v_batched = self.v.unsqueeze(0).repeat(energy.size(0), 1, 1)
        scores = v_batched.bmm(energy.permute(0, 2, 1))

        # Normalise over the source positions and drop the singleton axis.
        return F.softmax(scores, dim=2).squeeze(1)

### Decoder

In [5]:
class Decoder(nn.Module):
    """Single-step GRU decoder with attention.

    Each call consumes the previously emitted token, attends over the
    encoder outputs and produces scores over the target vocabulary together
    with the updated hidden state. Sizes come from the module-level
    ``DEC_EMB_DIM``, ``DEC_HID_DIM``, ``DEC_DROPOUT`` and ``ENC_HID_DIM``.
    """

    def __init__(self, trg_vocab_size):
        """Create the layers for a target vocabulary of ``trg_vocab_size`` entries."""
        super().__init__()

        self.embedding = nn.Embedding(trg_vocab_size, DEC_EMB_DIM)
        self.dropout = nn.Dropout(DEC_DROPOUT)
        self.attn = Attn()
        # The GRU input is the token embedding joined with the context vector.
        self.gru = nn.GRU(DEC_EMB_DIM + ENC_HID_DIM * 2, DEC_HID_DIM)
        # The output layer sees the embedding, the context and the new state.
        self.fc = nn.Linear(DEC_EMB_DIM + ENC_HID_DIM * 2 + DEC_HID_DIM, trg_vocab_size)

    def forward(self, last_output, decoder_hidden, encoder_outputs):
        """Run one decoding step.

        Args:
            last_output: Previously emitted token indices, ``[batch_size]``.
            decoder_hidden: ``[batch_size, dec_hid_dim]``.
            encoder_outputs: ``[src_len, batch_size, enc_hid_dim * 2]``.

        Returns:
            ``(decoder_outputs, decoder_hidden, attn)`` with shapes
            ``[batch_size, trg_vocab_size]``, ``[batch_size, dec_hid_dim]``
            and ``[batch_size, src_len]``.
        """
        # embedded: [batch_size, dec_emb_dim]
        embedded = self.embedding(last_output)
        embedded = self.dropout(embedded)

        # attn: [batch_size, src_len]
        attn = self.attn(decoder_hidden, encoder_outputs)

        # Attention-weighted sum of encoder outputs: [batch_size, enc_hid_dim * 2]
        encoder_batch_first = encoder_outputs.permute(1, 0, 2)
        context = attn.unsqueeze(1).bmm(encoder_batch_first).squeeze(1)

        # The GRU expects a leading sequence axis of length 1 on both the
        # input and the hidden state.
        gru_input = torch.cat((embedded.unsqueeze(0), context.unsqueeze(0)), dim=2)
        _, new_hidden = self.gru(gru_input, decoder_hidden.unsqueeze(0))
        new_hidden = new_hidden.squeeze(0)

        # decoder_outputs: [batch_size, trg_vocab_size]
        features = torch.cat((embedded, context, new_hidden), dim=1)
        decoder_outputs = self.fc(features)

        return decoder_outputs, new_hidden, attn

### Seq2Seq

In [6]:
import random

class Seq2Seq(nn.Module):
    """Encoder-decoder model with attention, decoded one target step at a time."""

    def __init__(self, src_vocab_size, trg_vocab_size):
        """Build the encoder and decoder for the given vocabulary sizes."""
        super().__init__()

        self.encoder = Encoder(src_vocab_size)
        self.decoder = Decoder(trg_vocab_size)

        self.trg_vocab_size = trg_vocab_size

    def forward(self, inputs, trgs, teacher_forcing_ratio=0.5):
        """Score every target position given the source batch.

        Args:
            inputs: Source token indices, ``[src_len, batch_size]``.
            trgs: Target token indices, ``[trg_len, batch_size]``. The first
                row is used as the initial decoder input. ``src_len`` and
                ``trg_len`` need not be equal.
            teacher_forcing_ratio: Probability, drawn independently at each
                step, of feeding the gold token instead of the model's own
                prediction as the next decoder input. ``0`` disables teacher
                forcing entirely; ``1`` always uses the gold token.

        Returns:
            ``[trg_len, batch_size, trg_vocab_size]`` scores. Row 0 is never
            written and stays all zeros; callers skip it when computing loss.
        """
        # encoder_outputs: [src_len, batch_size, enc_hid_dim * 2]
        # decoder_hidden:  [batch_size, dec_hid_dim]
        encoder_outputs, decoder_hidden = self.encoder(inputs)

        trg_len, batch_size = trgs.size(0), trgs.size(1)

        # Allocate on the same device as the targets rather than relying on a
        # module-level DEVICE global, so the model also works when that global
        # is undefined or differs from where the batch actually lives.
        outputs = torch.zeros(trg_len, batch_size, self.trg_vocab_size,
                              device=trgs.device)

        decoder_input = trgs[0]
        for t in range(1, trg_len):
            # step_scores: [batch_size, trg_vocab_size]
            step_scores, decoder_hidden, _ = self.decoder(decoder_input,
                                                          decoder_hidden,
                                                          encoder_outputs)
            outputs[t] = step_scores

            # One draw per step decides the next input for the whole batch.
            use_teacher_forcing = random.random() < teacher_forcing_ratio
            decoder_input = trgs[t] if use_teacher_forcing else step_scores.argmax(dim=1)

        return outputs

In [7]:
def init_weights(m):
    """Re-initialise every parameter reachable from module ``m`` in place.

    Parameters whose name contains ``'weight'`` are drawn from a normal
    distribution with mean 0 and standard deviation 0.01. Every other
    parameter is set to 0 -- that covers biases, but also any bare parameter
    whose name lacks ``'weight'`` (for instance ``Attn.v``).
    """
    for param_name, param in m.named_parameters():
        if 'weight' not in param_name:
            nn.init.constant_(param.data, 0)
            continue
        nn.init.normal_(param.data, mean=0, std=0.01)

## Training and Validating

In [8]:
def _train(train_iter, model, criterion, optimizer):
    """Run one training epoch and return the mean per-batch loss.

    Args:
        train_iter: Iterable of batches exposing ``.src`` and ``.trg``; must
            support ``len()``.
        model: Called as ``model(srcs, trgs)``, so the model's default
            teacher-forcing ratio applies.
        criterion: Loss taking ``(scores, targets)``.
        optimizer: Optimizer over the model's parameters.

    Returns:
        Sum of the batch losses divided by the number of batches.
    """
    model.train()

    running_loss = 0
    for batch in train_iter:
        srcs, trgs = batch.src, batch.trg

        outputs = model(srcs, trgs)

        # Position 0 holds the start token and is never predicted, so it is
        # dropped from both the scores and the targets before flattening.
        vocab_size = outputs.size()[-1]
        flat_scores = outputs[1:].view(-1, vocab_size)
        flat_targets = trgs[1:].view(-1)
        loss = criterion(flat_scores, flat_targets)

        loss.backward()
        # Cap the gradient norm at CLIP before the parameter update.
        torch.nn.utils.clip_grad_norm_(model.parameters(), CLIP)

        optimizer.step()
        # Clear the gradients so the next batch starts from zero.
        optimizer.zero_grad()

        running_loss += loss.item()

    return running_loss / len(train_iter)

In [9]:
def _evaluate(data_iter, model, criterion):
    
    eval_loss = 0
    
    model.eval()
    
    with torch.no_grad():
        for batch in data_iter:

            # Gets data.
            srcs = batch.src
            trgs = batch.trg

            # Forward.
            outputs = model(srcs, trgs, 0)

            # Loss.
            loss = criterion(outputs[1:].view(-1, outputs.size()[-1]), 
                             trgs[1:].view(-1))  #m

            eval_loss += loss.item()
        
        return eval_loss / len(data_iter)

In [10]:
def time_track(start, end):
    """Format the span between two timestamps (in seconds) as minutes and seconds.

    Args:
        start: Start time in seconds.
        end: End time in seconds.

    Returns:
        A string such as ``"47mins  3secs"``; both numbers are right-aligned
        in a field at least two characters wide and fractions are truncated.
    """
    elapsed = end - start
    whole_minutes = int(elapsed / 60)
    leftover_seconds = int(elapsed % 60)
    return "{:>2}mins {:>2}secs".format(whole_minutes, leftover_seconds)

In [11]:
import time
import math
import copy

def train(train_iter, valid_iter, model, criterion, optimizer):
    """Train for ``N_EPOCHS`` epochs, checkpointing the best validation loss.

    After each epoch the train and validation loss and perplexity are
    printed. Whenever the validation loss improves on the best seen so far,
    the model's ``state_dict`` is saved to the path in ``PT``.

    Args:
        train_iter: Batches used for training.
        valid_iter: Batches used for validation.
        model: The model to optimise.
        criterion: Loss function.
        optimizer: Optimizer over the model's parameters.
    """
    best_valid_loss = float("inf")

    for epoch_number in range(1, N_EPOCHS + 1):
        started_at = time.time()
        train_loss = _train(train_iter, model, criterion, optimizer)
        valid_loss = _evaluate(valid_iter, model, criterion)
        finished_at = time.time()

        elapsed = time_track(started_at, finished_at)
        train_ppl = math.exp(train_loss)
        valid_ppl = math.exp(valid_loss)

        print(f"epoch: {epoch_number:02}, time: {elapsed}")
        print(f"train loss: {train_loss:.3f}, train ppl: {train_ppl:.3f}")
        print(f"valid loss: {valid_loss:.3f}, valid ppl: {valid_ppl:.3f}")

        # Keep only the checkpoint with the lowest validation loss so far.
        if valid_loss < best_valid_loss:
            best_valid_loss = valid_loss
            torch.save(model.state_dict(), PT)

    print()

## Testing

In [12]:
def test(test_iter, model, criterion):
    """Load the checkpoint stored at ``PT`` into ``model`` and report test loss.

    Note that this overwrites the weights currently held by ``model``.

    Args:
        test_iter: Batches of the test split.
        model: Model whose architecture matches the saved ``state_dict``.
        criterion: Loss function.
    """
    saved_state = torch.load(PT)
    model.load_state_dict(saved_state)

    test_loss = _evaluate(test_iter, model, criterion)
    test_ppl = math.exp(test_loss)

    print(f"test loss: {test_loss:.3f}, test ppl: {test_ppl:.3f}")
    print()

## Inference  #+

In [13]:
def correct_sentence(sentence, src_field, trg_field, model):
    """Greedily decode a correction for one sentence.

    Args:
        sentence: Either a string (split on single spaces) or an already
            tokenised list of tokens.
        src_field: Source field providing the vocabulary and special tokens.
        trg_field: Target field providing the vocabulary and special tokens.
        model: Model exposing ``encoder`` and ``decoder``.

    Returns:
        ``(tokens, attentions)``. ``tokens`` are the predicted target tokens
        without the initial start token; the end token is included when the
        model produced one within ``MAX_LEN`` steps. ``attentions`` has shape
        ``[len(tokens), 1, src_len]`` with one row of weights per token.
    """
    model.eval()

    words = sentence.split(' ') if isinstance(sentence, str) else sentence
    words = [src_field.init_token] + words + [src_field.eos_token]

    # Shape [src_len, 1]: a batch containing just this sentence.
    src_indexes = [src_field.vocab.stoi[word] for word in words]
    src_tensor = torch.tensor(src_indexes, dtype=torch.long, device=DEVICE).unsqueeze(1)

    trg_indexes = [trg_field.vocab.stoi[trg_field.init_token]]
    attentions = torch.zeros(MAX_LEN, 1, len(src_indexes), device=DEVICE)

    with torch.no_grad():
        encoder_outputs, hidden = model.encoder(src_tensor)

        for step in range(MAX_LEN):
            # Feed back the most recent token as a batch of one.
            last_token = torch.tensor([trg_indexes[-1]], dtype=torch.long, device=DEVICE)
            output, hidden, attention = model.decoder(last_token, hidden, encoder_outputs)

            next_index = output.argmax(1).item()
            trg_indexes.append(next_index)
            attentions[step] = attention

            # Stop as soon as the end-of-sentence token is predicted.
            if next_index == trg_field.vocab.stoi[trg_field.eos_token]:
                break

    trg_tokens = [trg_field.vocab.itos[index] for index in trg_indexes]
    n_predicted = len(trg_tokens) - 1

    return trg_tokens[1:], attentions[:n_predicted]

In [14]:
def display_attention(sentence, correction, attention):
    """Placeholder for an attention visualisation; not implemented yet.

    Accepts the source tokens, the predicted tokens and the attention
    weights, does nothing with them and returns ``None``.
    """

In [15]:
import random

def inference(dataset, src_field, trg_field, model):
    """Print source, reference and model output for ten random examples.

    Args:
        dataset: Dataset exposing ``examples``, each with ``src`` and ``trg``
            token lists.
        src_field: Source field, forwarded to ``correct_sentence``.
        trg_field: Target field, forwarded to ``correct_sentence``; its
            ``eos_token`` is stripped from the printed output.
        model: Trained model, forwarded to ``correct_sentence``.
    """
    n_examples = len(dataset)
    if n_examples == 0:
        # Nothing to sample from; keep the trailing blank line for consistency.
        print()
        return

    for _ in range(10):
        # randrange excludes the upper bound, so every valid index 0..n-1 can
        # be drawn. The previous randint(1, n) skipped index 0 and could
        # return n, which is one past the end of `examples`.
        example = dataset.examples[random.randrange(n_examples)]

        src = vars(example)['src']
        trg = vars(example)['trg']

        print(f"src = {' '.join(src)}")
        print(f"trg = {' '.join(trg)}")

        correction, _ = correct_sentence(src, src_field, trg_field, model)

        # Only drop the last token when it really is the end marker; a decode
        # that ran to the length limit ends in a genuine token.
        if correction and correction[-1] == trg_field.eos_token:
            correction = correction[:-1]

        print(f"out = {' '.join(correction)}")

        print('---')

        # Attention plotting is disabled until display_attention is implemented.

    print()

## BLEU #+

In [16]:
from torchtext.data.metrics import bleu_score

def calculate_bleu(data, src_field, trg_field, model):
    """Compute the corpus BLEU score of the model's corrections over ``data``.

    Args:
        data: Iterable of examples, each with ``src`` and ``trg`` token lists.
        src_field: Source field, forwarded to ``correct_sentence``.
        trg_field: Target field, forwarded to ``correct_sentence``; its
            ``eos_token`` is removed from each hypothesis.
        model: Trained model, forwarded to ``correct_sentence``.

    Returns:
        The value returned by ``bleu_score``, using one reference per example.
    """
    outs = []
    trgs = []

    for datum in data:
        sides = vars(datum)

        out, _ = correct_sentence(sides['src'], src_field, trg_field, model)

        # Strip the end marker only when it is present. A decode that hit the
        # length limit has no end marker, and slicing unconditionally would
        # throw away its last real token and distort the score.
        if out and out[-1] == trg_field.eos_token:
            out = out[:-1]

        outs.append(out)
        # bleu_score takes a list of references per candidate.
        trgs.append([sides['trg']])

    return bleu_score(outs, trgs)

## Main

In [17]:
import torch
import torch.optim as optim

if __name__ == '__main__':
    # ------------------------------------------------------------------
    # Configuration. These are module-level names that the functions and
    # classes defined above read directly.
    # ------------------------------------------------------------------
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    DATASET = "../data/fce+lang8(partial) with bpe (train seq num: 302616)"

    # Data.
    ROOT = f"{DATASET}/parallel"
    MAX_LEN = 100      # examples at or above this length are filtered out
    MIN_FREQ = 1       # minimum token frequency for the vocabularies
    BATCH_SIZE = 32

    # Encoder.
    ENC_EMB_DIM = 600
    ENC_HID_DIM = 1000
    ENC_DROPOUT = 0.5

    # Decoder.
    DEC_EMB_DIM = 600
    DEC_HID_DIM = 1000
    DEC_DROPOUT = 0.5

    # Attention.
    ATTN_V_DIM = 1000

    # Optimisation.
    # Author's note: a smaller learning rate makes the validation loss change
    # more slowly, which may make the best checkpoint easier to catch.
    LR = 0.0003
    N_EPOCHS = 10
    CLIP = 1           # maximum gradient norm
    PT = f"{DATASET}.pt"  # where the best checkpoint is written

    # ------------------------------------------------------------------
    # Pipeline: data -> model -> loss/optimizer -> train -> test.
    # ------------------------------------------------------------------
    (src_field, trg_field,
     train_iter, valid_iter, test_iter,
     train_set, valid_set, test_set) = prepare_data(root=ROOT)

    model = Seq2Seq(len(src_field.vocab), len(trg_field.vocab)).to(DEVICE)
    model.apply(init_weights)

    # Padding positions in the target do not contribute to the loss.
    criterion = nn.CrossEntropyLoss(ignore_index=trg_field.vocab.stoi["<pad>"])

    optimizer = optim.Adam(model.parameters(), lr=LR)

    train(train_iter, valid_iter, model, criterion, optimizer)

    test(test_iter, model, criterion)

epoch: 01, time: 47mins  3secs
train loss: 2.420, train ppl: 11.251
valid loss: 2.154, valid ppl: 8.623
epoch: 02, time: 47mins  4secs
train loss: 1.293, train ppl: 3.645
valid loss: 1.973, valid ppl: 7.190
epoch: 03, time: 46mins 55secs
train loss: 1.110, train ppl: 3.033
valid loss: 1.980, valid ppl: 7.244
epoch: 04, time: 47mins  2secs
train loss: 0.992, train ppl: 2.698
valid loss: 1.983, valid ppl: 7.262
epoch: 05, time: 47mins  7secs
train loss: 0.903, train ppl: 2.466
valid loss: 2.048, valid ppl: 7.754
epoch: 06, time: 47mins  6secs
train loss: 0.832, train ppl: 2.298
valid loss: 2.098, valid ppl: 8.152
epoch: 07, time: 47mins  2secs
train loss: 0.776, train ppl: 2.173
valid loss: 2.118, valid ppl: 8.317
epoch: 08, time: 46mins 59secs
train loss: 0.731, train ppl: 2.077
valid loss: 2.161, valid ppl: 8.679
epoch: 09, time: 47mins  1secs
train loss: 0.697, train ppl: 2.008
valid loss: 2.234, valid ppl: 9.333
epoch: 10, time: 47mins  7secs
train loss: 0.668, train ppl: 1.950
valid

In [18]:
# Qualitative check on the training split: print source, reference and model
# output for a handful of randomly drawn examples.
print("training inference")
inference(train_set, src_field, trg_field, model)

training inference
src = But I 'm here again , as always wi@@ shing to make my English better .
trg = But I 'm here again , as always wi@@ shing to make my English better .
out = But I 'm here again , as always always shing to make my English better .
---
src = re@@ i ha@@ ra@@ kami
trg = re@@ i ha@@ ra@@ kami
out = re@@ i ha@@ ra@@ kami
---
src = Work
trg = Work
out = Work
---
src = there was very beautiful , gra@@ ce@@ ful and anti@@ que architecture in front of me .
trg = There was very beautiful , gra@@ ce@@ ful and anti@@ que piece of architecture in front of me .
out = there was very beautiful , gra@@ ce@@ ful and anti@@ que architecture in front of me .
---
src = After dinner and taking a bath , I always wash the dishes but my daughter wants to play with me .
trg = After dinner and bath , I always wash the dishes but my daughter would want to play with me .
out = After dinner and taking a bath , I always wash the dishes but my daughter wants to play with me .
---
src = But it is

In [19]:
# Same qualitative check on the validation split.
print("validating inference")
inference(valid_set, src_field, trg_field, model)

validating inference
src = Yours sincerely ,
trg = Yours sincerely ,
out = Yours sincerely ,
---
src = We started our journey and we arrived at our purpose place by 6 o'clock in the evening .
trg = We started our journey and we arrived at our destination by 6 o'clock in the evening .
out = We started our journey and we arrived at our purpose place by 6 o'clock in the evening .
---
src = Unfortunately , we were really disappointed about everything in the theatre .
trg = Unfortunately , we were really disappointed about everything in the theatre .
out = Unfortunately , we were really disappointed about everything in the theatre .
---
src = Cli@@ mbing , in contrast , is something I started few months ago , so I have not gained any success in this sport .
trg = Cli@@ mbing , in contrast , is something I started a few months ago , so I have not gained any success in this sport .
out = Cli@@ mbing , in contrast , is something I started few months ago , so I have not gained any success in th

In [20]:
# Same qualitative check on the held-out test split.
print("testing inference")
inference(test_set, src_field, trg_field, model)

testing inference
src = The P@@ alace Ho@@ tel is located in the centre of the city , so you can go by train , by bus or by car .
trg = The P@@ alace Ho@@ tel is located in the centre of the city , so you can come by train , by bus or by car .
out = The P@@ alace Ho@@ tel is located in the centre of the city , so you can go by train train by by by by by .
---
src = If you want to talk with your friends , parents , partner or to ask something , you can call them without wondering how can be communicate with them .
trg = If you want to talk to your friends , parents , or partner to ask something , you can call them without wondering how you can communicate with them .
out = If you want to talk with your friends , parents , partner or to ask something , you can call them without wondering how can be communicate with them .
---
src = The popular attractions in our town , including the art gallery , museum , aquarium , market square , could be found near our mon@@ u@@ ment of our town , the

In [21]:
# Corpus-level BLEU of the model's corrections against the test references,
# printed as a percentage.
bleu = calculate_bleu(test_set, src_field, trg_field, model)
print(f'BLEU score = {bleu*100:.2f}')

BLEU score = 73.52
