In [0]:
# -*- coding: utf-8 -*-

from __future__ import unicode_literals, print_function, division
from io import open
import unicodedata
import string
import re
import random

import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
import torch.autograd as autograd
import numpy as np 
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [0]:
SOS_token = 0
EOS_token = 1


class Lang:
    def __init__(self, name):
        self.name = name
        self.word2index = {}
        self.word2count = {}
        self.index2word = {0: "SOS", 1: "EOS"}
        self.n_words = 2  # Count SOS and EOS

    def addSentence(self, sentence):
        for word in sentence.split(' '):
            self.addWord(word)

    def addWord(self, word):
        if word not in self.word2index:
            self.word2index[word] = self.n_words
            self.word2count[word] = 1
            self.index2word[self.n_words] = word
            self.n_words += 1
        else:
            self.word2count[word] += 1


In [0]:
def unicodeToAscii(s):
    return ''.join(
        c for c in unicodedata.normalize('NFD', s)
        if unicodedata.category(c) != 'Mn'
    )

# Lowercase, trim, and remove non-letter characters


def normalizeString(s):
    s = unicodeToAscii(s.lower().strip())
    s = re.sub(r"([.!?])", r" \1", s)
    s = re.sub(r"[^a-zA-Z.!?]+", r" ", s)
    return s

In [0]:
def readLangs(lang1, lang2, reverse=False):
    print("Reading lines...")

    # Read the file and split into lines
    lines = open('%s_%s.txt' % (lang1, lang2), encoding='utf-8').\
        read().strip().split('\n')

    # Split every line into pairs and normalize
    pairs = [[normalizeString(s) for s in l.split('\t')] for l in lines]

    # Reverse pairs, make Lang instances
    if reverse:
        pairs = [list(reversed(p)) for p in pairs]
        input_lang = Lang(lang2)
        output_lang = Lang(lang1)
    else:
        input_lang = Lang(lang1)
        output_lang = Lang(lang2)

    return input_lang, output_lang, pairs

In [0]:
MAX_LENGTH = 10

eng_prefixes = (
    "i am ", "i m ",
    "he is", "he s ",
    "she is", "she s ",
    "you are", "you re ",
    "we are", "we re ",
    "they are", "they re "
)


def filterPair(p):
    return len(p[0].split(' ')) < MAX_LENGTH and \
        len(p[1].split(' ')) < MAX_LENGTH and \
        p[1].startswith(eng_prefixes)


def filterPairs(pairs):
    return [pair for pair in pairs if filterPair(pair)]

In [8]:
def prepareData(lang1, lang2, reverse=False):
    input_lang, output_lang, pairs = readLangs(lang1, lang2, reverse)
    print("Read %s sentence pairs" % len(pairs))
    pairs = filterPairs(pairs)
    print("Trimmed to %s sentence pairs" % len(pairs))
    print("Counting words...")
    for pair in pairs:
        input_lang.addSentence(pair[0])
        output_lang.addSentence(pair[1])
    print("Counted words:")
    print(input_lang.name, input_lang.n_words)
    print(output_lang.name, output_lang.n_words)
    return input_lang, output_lang, pairs


input_lang, output_lang, pairs = prepareData('eng', 'gar', True)
print(random.choice(pairs))

Reading lines...
Read 169813 sentence pairs
Trimmed to 9404 sentence pairs
Counting words...
Counted words:
gar 4434
eng 2872
['er ist hier nicht mehr willkommen .', 'he is no longer welcome here .']


In [0]:
class EncoderRNN(nn.Module):
    def __init__(self, input_size, hidden_size):
        super(EncoderRNN, self).__init__()
        self.hidden_size = hidden_size

        self.embedding = nn.Embedding(input_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size)

    def forward(self, input, hidden):
        embedded = self.embedding(input).view(1, 1, -1)
        output = embedded
        output, hidden = self.gru(output, hidden)
        return output, hidden

    def initHidden(self):
        return torch.zeros(1, 1, self.hidden_size, device=device)

In [0]:
class AttnDecoderRNN(nn.Module):
    def __init__(self, hidden_size, output_size, dropout_p=0.1, max_length=MAX_LENGTH):
        super(AttnDecoderRNN, self).__init__()
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.dropout_p = dropout_p
        self.max_length = max_length

        self.embedding = nn.Embedding(self.output_size, self.hidden_size)
        self.attn = nn.Linear(self.hidden_size * 2, self.max_length)
        self.attn_combine = nn.Linear(self.hidden_size * 2, self.hidden_size)
        self.dropout = nn.Dropout(self.dropout_p)
        self.gru = nn.GRU(self.hidden_size, self.hidden_size)
        self.out = nn.Linear(self.hidden_size, self.output_size)

    def forward(self, input, hidden, encoder_outputs):
        embedded = self.embedding(input).view(1, 1, -1)
        embedded = self.dropout(embedded)
        
        attn_weights = F.softmax(
            self.attn(torch.cat((embedded[0], hidden[0]), 1)), dim=1)

        attn_applied = torch.bmm(attn_weights.unsqueeze(0),
                                 encoder_outputs.unsqueeze(0))

        output = torch.cat((embedded[0], attn_applied[0]), 1)
        output = self.attn_combine(output).unsqueeze(0)

        output = F.relu(output)
        output, hidden = self.gru(output, hidden)

        output = F.log_softmax(self.out(output[0]), dim=1)
        return output, hidden, attn_weights

    def initHidden(self):
        return torch.zeros(1, 1, self.hidden_size, device=device)

In [0]:
def indexesFromSentence(lang, sentence):
    return [lang.word2index[word] for word in sentence.split(' ')]


def tensorFromSentence(lang, sentence):
    indexes = indexesFromSentence(lang, sentence)
    indexes.append(EOS_token)
    return torch.tensor(indexes, dtype=torch.long, device=device).view(-1, 1)


def tensorsFromPair(pair):
    input_tensor = tensorFromSentence(input_lang, pair[0])
    target_tensor = tensorFromSentence(output_lang, pair[1])
    return (input_tensor, target_tensor)

In [0]:
hidden_size = 256
attn_decoder1 = AttnDecoderRNN(hidden_size, output_lang.n_words, dropout_p=0.1).to(device)
encoder1 = EncoderRNN(input_lang.n_words, hidden_size).to(device)

In [13]:
encoder1.load_state_dict(torch.load('encoder.dict'))
attn_decoder1.load_state_dict(torch.load('decoder.dict'))

<All keys matched successfully>

In [0]:
n_layers = 20
block_dim = 256
gp_lambda = 10
latent_dim = 256
interval = 1000
batch_size = 1
n_critic = 5

In [0]:
class Block(nn.Module):

    def __init__(self, block_dim):
        super().__init__()

        self.net = nn.Sequential(
            nn.Linear(block_dim, block_dim),
            nn.ReLU(True),
            nn.Linear(block_dim, block_dim),
        )

    def forward(self, x):
        return self.net(x) + x


class Generator(nn.Module):

    def __init__(self, n_layers, block_dim):
        super().__init__()

        self.net = nn.Sequential(
            *[Block(block_dim) for _ in range(n_layers)]
        )

    def forward(self, x):
        return self.net(x)


class Critic(nn.Module):

    def __init__(self, n_layers, block_dim):
        super().__init__()

        self.net = nn.Sequential(
            *[Block(block_dim) for _ in range(n_layers)]
        )

    def forward(self, x):
        return self.net(x)


In [0]:
generator = Generator(n_layers, block_dim)
critic = Critic(n_layers, block_dim)
critic.to(device)
g_optimizer = optim.Adam(generator.parameters(), lr=1e-4)
c_optimizer = optim.Adam(critic.parameters(), lr=1e-4)

In [0]:
def compute_grad_penalty(critic, real_data, fake_data):
    B = real_data.size(0)
    alpha = torch.FloatTensor(np.random.random((B, 1)))
    alpha = alpha.to(device)
    sample = alpha*real_data + (1-alpha)*fake_data
    sample.requires_grad_(True)
    sample = sample.to(device)
    score = critic(sample)
    outputs = torch.FloatTensor(B, 256).fill_(1.0)
    outputs.requires_grad_(False)
    outputs = outputs.to(device)
 
    grads = autograd.grad(
        outputs=score,
        inputs=sample,
        grad_outputs=outputs,
        create_graph=True,
        retain_graph=True,
        only_inputs=True
    )[0]
    #grads = grads.view(B, -1)
    grad_penalty = ((grads.norm(2, dim=1) - 1.) ** 2).mean()

    return grad_penalty

In [0]:
def train(epoch):


    generator = Generator(n_layers, block_dim)
    generator.to(device)
    critic = Critic(n_layers, block_dim)
    critic.to(device)

    g_optimizer = optim.Adam(generator.parameters(), lr=1e-4)
    c_optimizer = optim.Adam(critic.parameters(), lr=1e-4)
    
    encoder1.eval()
    attn_decoder1.eval()
    generator.train()
    critic.train()

    c_train_loss = 0.
    g_train_loss = 0.
    g_batches = 0
    hidden_size = 256
    max_length = 10

    for i in range(len(pairs)):
      pair = pairs[i]
      sentence = pair[0]

      input_tensor = tensorFromSentence(input_lang, sentence)
      input_length = input_tensor.size()[0]

      encoder_hidden = encoder1.initHidden()
      encoder_outputs = torch.zeros(input_length, encoder1.hidden_size, device=device)
        
      decoder_output = torch.zeros(max_length, latent_dim, device=device)

      fc3 = nn.Linear(encoder1.hidden_size, latent_dim)
      fc3.to(device)

      with torch.no_grad():
          for ei in range(input_length):
              encoder_output, encoder_hidden = encoder1(input_tensor[ei], encoder_hidden)
              encoder_outputs[ei] += encoder_output[0, 0]

      encoder_outputs = fc3(encoder_outputs)
      c_optimizer.zero_grad()

      noise = torch.from_numpy(np.random.normal(0, 1, (input_length, latent_dim))).float()
      noise = noise.to(device)

      z_fake = generator(noise)        
      z_fake.to(device)

      real_score = critic(encoder_outputs)
      fake_score = critic(z_fake)
      grad_penalty = compute_grad_penalty(critic, encoder_outputs.data, z_fake.data)
      
      c_loss = -torch.mean(real_score) + torch.mean(fake_score) + gp_lambda * grad_penalty
      c_train_loss += c_loss.item()
      c_loss.backward()
      c_optimizer.step()



      # train generator
      if i % n_critic == 0:
          g_batches += 1
          g_optimizer.zero_grad()
          fake_score = critic(generator(noise))
          g_loss = -torch.mean(fake_score)
          g_train_loss += g_loss.item()
          g_loss.backward()
          g_optimizer.step()

      if interval > 0 and i % interval == 0:
          print('Epoch: {} | Batch: {}/{} ({:.0f}%) | G Loss: {:.6f} | C Loss: {:.6f}'.format(
              epoch, batch_size * i, len(pairs),
                      100. * (batch_size * i) / len(pairs),
              g_loss.item(), c_loss.item()
          ))

    print("End of loop ====>>>>>")
    g_train_loss /= g_batches
    c_train_loss /= len(pairs)
    print('* (Train) Epoch: {} | G Loss: {:.4f} | C Loss: {:.4f}'.format(
        epoch, g_train_loss, c_train_loss
    ))
    return (g_train_loss, c_train_loss)

In [20]:
best_loss = np.inf
epochs = 10
for epoch in range(1, epochs + 1):
    g_loss, c_loss = train(epoch)
    loss = g_loss + c_loss
    if loss < best_loss:
        best_loss = loss
        print('* Saved')
        torch.save(generator.state_dict(), 'generator.th')
        torch.save(critic.state_dict(), 'critic.th')

Epoch: 1 | Batch: 0/9404 (0%) | G Loss: 0.086658 | C Loss: 6440.877930
Epoch: 1 | Batch: 1000/9404 (11%) | G Loss: -0.117293 | C Loss: 0.301929
Epoch: 1 | Batch: 2000/9404 (21%) | G Loss: -0.109819 | C Loss: 1.829215
Epoch: 1 | Batch: 3000/9404 (32%) | G Loss: -0.115603 | C Loss: 0.374288
Epoch: 1 | Batch: 4000/9404 (43%) | G Loss: -0.100903 | C Loss: 0.092565
Epoch: 1 | Batch: 5000/9404 (53%) | G Loss: -0.062586 | C Loss: 0.030611
Epoch: 1 | Batch: 6000/9404 (64%) | G Loss: -0.014667 | C Loss: 0.314916
Epoch: 1 | Batch: 7000/9404 (74%) | G Loss: -0.036270 | C Loss: 0.051012
Epoch: 1 | Batch: 8000/9404 (85%) | G Loss: -0.018941 | C Loss: 0.003692
Epoch: 1 | Batch: 9000/9404 (96%) | G Loss: 0.002555 | C Loss: 0.041357
End of loop ====>>>>>
* (Train) Epoch: 1 | G Loss: -0.0556 | C Loss: 6.3920
* Saved
Epoch: 2 | Batch: 0/9404 (0%) | G Loss: 0.253180 | C Loss: 8267.447266
Epoch: 2 | Batch: 1000/9404 (11%) | G Loss: -0.004784 | C Loss: 0.029291
Epoch: 2 | Batch: 2000/9404 (21%) | G Loss: -