In [1]:
import os

from tqdm import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.nn.init import xavier_uniform_

from onmt.modules import Embeddings
from onmt.encoders import RNNEncoder, TransformerEncoder
from onmt.decoders.decoder import StdRNNDecoder
from onmt.decoders.transformer import TransformerDecoder

from utils import Corpus, batchify, truncate, SOS_IDX

# TODO
1. Adversarial training: gate the adversarial updates on discriminator accuracy (done)
2. Validation step (done)
3. Test the Transformer variant
4. Run style transfer on real data and inspect the outputs by eye
5. Beam search

In [2]:
gpu_id = 0  # set to -1 to force the CPU
# Fall back to the CPU when CUDA is unavailable so the notebook still runs on GPU-less machines.
use_cuda = gpu_id != -1 and torch.cuda.is_available()
device = torch.device("cuda:{}".format(gpu_id) if use_cuda else "cpu")

In [3]:
# Data / vocabulary
max_length = 50
vocab_size = 10000
embed_dim = 100

# Encoder / decoder
bidirectional = True
num_layers = 1
rnn_size = 300
enc_dropout = 0.0

# Transformer-only settings (used by build_model(model_type='TRANS')).
# Multi-head attention splits the model width (rnn_size) across heads, so the
# head count must divide rnn_size exactly.
num_heads = 6
ff_size = 2048
assert rnn_size % num_heads == 0, "rnn_size must be divisible by num_heads"

# Optimisation
num_epoch = 100
batch_size = 32
learning_rate_ae = 1
learning_rate_d = 0.1
grad_clip = 1
# Discriminator hidden-layer sizes, dash-separated (parsed by Discriminator).
disc_layer = '{}-400-200'.format(num_layers*rnn_size)

# Logging / checkpointing
valid_every = 5
checkpoint_every = 10000

In [4]:
# Entries are (path to text file, split name, whether the split feeds the vocabulary).
# Split names end in 0 for the positive corpus and 1 for the negative corpus;
# only the training splits contribute to the vocabulary.
datafiles = [
    ('data/yelp/{}_{}.txt'.format(sentiment, split), '{}{}'.format(split, label), split == 'train')
    for split in ('train', 'valid')
    for label, sentiment in enumerate(('pos', 'neg'))
]

In [5]:
# Tokenise every split listed in `datafiles` (lower-cased, sentences capped at
# max_length) and build a vocabulary of at most vocab_size words.
corpus = Corpus(
    datafiles,
    lowercase=True,
    maxlen=max_length,
    vocab_size=vocab_size,
)

Original vocab 9599; Pruned to 9603
Number of sentences dropped from data/yelp/pos_train.txt: 0 out of 267314 total
Number of sentences dropped from data/yelp/neg_train.txt: 0 out of 176787 total
Number of sentences dropped from data/yelp/pos_valid.txt: 0 out of 38205 total
Number of sentences dropped from data/yelp/neg_valid.txt: 0 out of 25278 total


In [6]:
# The pruned vocabulary need not match the requested size, so every later cell
# uses the actual dictionary size.
word2idx = corpus.dictionary.word2idx
vocab_size = len(word2idx)
print("Vocabulary Size: {}".format(vocab_size))

Vocabulary Size: 9603


In [7]:
# Training batches are shuffled; validation batches keep file order (presumably
# so the per-epoch validation dumps line up with each other).
train0_data, train1_data = (
    batchify(corpus.data[split], batch_size, shuffle=True)
    for split in ('train0', 'train1')
)
valid0_data, valid1_data = (
    batchify(corpus.data[split], batch_size, shuffle=False)
    for split in ('valid0', 'valid1')
)

KeyboardInterrupt: 

In [None]:
class Net(nn.Module):
    """Shared encoder with two decoders (selected by ``dec_idx`` 0/1) and one generator.

    ``generator`` maps decoder outputs to log-probabilities over the vocabulary; it
    is ``None`` after construction and must be assigned before ``forward``/``generate``.
    """

    def __init__(self, encoder, decoder0, decoder1):
        super(Net, self).__init__()

        self.encoder = encoder
        self.decoder0 = decoder0
        self.decoder1 = decoder1
        self.generator = None

    def choose_decoder(self, dec_idx):
        """Return ``decoder0`` when ``dec_idx == 0`` and ``decoder1`` otherwise.

        A method rather than a lambda attribute: lambdas cannot be pickled, which
        made ``torch.save(model)`` fail, and this always reflects the registered
        sub-modules.
        """
        return self.decoder0 if dec_idx == 0 else self.decoder1

    def forward(self, indices, lengths, dec_idx, only_enc=False):
        """Encode ``indices`` and, unless ``only_enc``, decode them with decoder ``dec_idx``.

        Args:
            indices: source token ids; also used as the decoder input.
            lengths: source lengths, forwarded to encoder and decoder.
            dec_idx: 0 or 1; ignored when ``only_enc`` is True.
            only_enc: return the encoder's final state instead of decoding.

        Returns:
            The generator output for every decoded position, or (``only_enc``) the
            final encoder state with its two directions concatenated on the last dim.
        """
        enc_final, memory_bank = self.encoder(indices, lengths)
        if only_enc:
            # Even rows of the final state are the forward direction, odd rows the
            # backward one (assumes a bidirectional encoder); squeeze(0) drops the
            # layer axis of a single-layer encoder.
            return torch.cat([enc_final[0:enc_final.size(0):2], enc_final[1:enc_final.size(0):2]], 2).squeeze(0)

        assert dec_idx == 0 or dec_idx == 1

        # The raw encoder state is passed straight to the decoder (the custom
        # DecoderRNN merges the directions itself).
        decoder_outputs, dec_state, attns = self.choose_decoder(dec_idx)(
            indices, memory_bank, enc_final, memory_lengths=lengths)
        return self.generator(decoder_outputs)

    def generate(self, indices, lengths, dec_idx):
        """Greedy-decode ``max_length`` tokens per sentence with decoder ``dec_idx``.

        Each step feeds the previously predicted token back into the decoder,
        starting from ``SOS_IDX``.

        Returns:
            A CPU ``LongTensor`` of shape (batch, max_length) with predicted ids.
        """
        assert dec_idx == 0 or dec_idx == 1
        batch_size = indices.size(1)
        decoder = self.choose_decoder(dec_idx)

        with torch.no_grad():
            enc_final, memory_bank = self.encoder(indices, lengths)

            # OpenNMT decoders build their state via init_decoder_state; the custom
            # DecoderRNN has no such method and takes the raw encoder state, as in forward().
            if hasattr(decoder, 'init_decoder_state'):
                dec_state = decoder.init_decoder_state(indices, memory_bank, enc_final)
            else:
                dec_state = enc_final

            token = torch.full((1, batch_size), SOS_IDX, dtype=torch.long,
                               device=indices.device).unsqueeze(2)

            all_indices = []
            for _ in range(max_length):
                decoder_outputs, dec_state, attns = decoder(
                    token, memory_bank, dec_state, memory_lengths=lengths)
                output = self.generator(decoder_outputs)
                _, topi = output.topk(1, dim=2)
                # topi is (1, batch, 1) for time-major decoders and (batch, 1, 1)
                # for the batch-major custom DecoderRNN; normalise it to the input shape
                # and feed it back as the next decoder input.
                token = topi.reshape(1, batch_size, 1)
                all_indices.append(token.reshape(batch_size, 1).cpu())

        return torch.cat(all_indices, dim=1)

In [None]:
class Discriminator(nn.Module):
    """MLP classifier: [Linear -> (BatchNorm) -> activation] * k, then Linear -> sigmoid.

    Args:
        ninput: size of each input feature vector.
        noutput: number of sigmoid output units.
        layers: hidden-layer sizes as a dash-separated string, e.g. ``"300-400-200"``.
        activation: activation module placed after every hidden layer (the same
            instance is reused for all of them).
        device: device every sub-module is moved to.
    """

    def __init__(self, ninput, noutput, layers, activation=nn.ReLU(), device=torch.device("cpu")):
        super(Discriminator, self).__init__()
        self.ninput = ninput
        self.noutput = noutput

        layer_sizes = [ninput] + [int(x) for x in layers.split('-')]
        # Plain list used for ordered iteration in forward(); modules are also
        # registered with add_module so the state_dict names (layer1, bn2, ...) stay
        # stable for existing checkpoints.
        self.layers = []

        for i in range(len(layer_sizes)-1):
            layer = nn.Linear(layer_sizes[i], layer_sizes[i+1]).to(device)
            self.layers.append(layer)
            self.add_module("layer"+str(i+1), layer)

            # No batch normalization in the first layer
            if i != 0:
                bn = nn.BatchNorm1d(layer_sizes[i+1]).to(device)
                self.layers.append(bn)
                self.add_module("bn"+str(i+1), bn)

            self.layers.append(activation)
            self.add_module("activation"+str(i+1), activation)

        # Output layer. Its name comes from the running list length rather than the
        # layer index; kept unchanged so saved state_dicts still load.
        layer = nn.Linear(layer_sizes[-1], noutput).to(device)
        self.layers.append(layer)
        self.add_module("layer"+str(len(self.layers)), layer)

        self.init_weights()

    def forward(self, x):
        """Apply every layer in order and squash the result with a sigmoid."""
        for layer in self.layers:
            x = layer(x)
        return torch.sigmoid(x)

    def init_weights(self):
        """Draw every ``weight`` from N(0, 0.02) and zero every ``bias``.

        This covers the Linear layers and also the BatchNorm affine parameters.
        Parameter-free modules (the activation) are skipped explicitly rather than
        by swallowing every exception.
        """
        init_std = 0.02
        for layer in self.layers:
            weight = getattr(layer, 'weight', None)
            bias = getattr(layer, 'bias', None)
            if weight is not None:
                weight.data.normal_(0, init_std)
            if bias is not None:
                bias.data.fill_(0)

In [None]:
def build_model(model_type='GRU', device=torch.device("cpu")):
    """Build an OpenNMT-based ``Net``: shared encoder, two decoders, and a generator.

    Args:
        model_type: ``'GRU'`` for recurrent encoder/decoders, ``'TRANS'`` for Transformer ones.
        device: device the finished model is moved to.

    Returns:
        The ``Net`` with ``generator`` attached, on ``device``.
    """
    assert model_type in ('GRU', 'TRANS')
    is_transformer = model_type == 'TRANS'

    def make_embeddings():
        # Every encoder/decoder gets its own table; positional encoding only for the Transformer.
        return Embeddings(
            word_vec_size=embed_dim,
            word_vocab_size=vocab_size,
            word_padding_idx=0,
            position_encoding=is_transformer,
        )

    def make_decoder(embeddings):
        if is_transformer:
            return TransformerDecoder(
                num_layers=num_layers,
                d_model=rnn_size,
                heads=num_heads,
                d_ff=ff_size,
                attn_type=None,
                copy_attn=False,
                self_attn_type="scaled-dot",
                dropout=0.0,
                embeddings=embeddings,
            )
        return StdRNNDecoder(
            rnn_type=model_type,
            bidirectional_encoder=bidirectional,
            num_layers=num_layers,
            hidden_size=rnn_size,
            embeddings=embeddings,
        )

    # Encoder (construction order is kept fixed so parameter initialisation is reproducible).
    src_embeddings = make_embeddings()
    print('build_encoder embedding')
    if is_transformer:
        encoder = TransformerEncoder(
            num_layers=num_layers, d_model=rnn_size, heads=num_heads,
            d_ff=ff_size, dropout=enc_dropout, embeddings=src_embeddings,
        )
    else:
        encoder = RNNEncoder(
            model_type, bidirectional=bidirectional, num_layers=num_layers,
            hidden_size=rnn_size, dropout=enc_dropout, embeddings=src_embeddings,
        )
    print('build_encoder')

    # One decoder per style, each with its own target embeddings.
    tgt_embeddings = [make_embeddings(), make_embeddings()]
    print('build_decoder embedding')
    decoder0, decoder1 = [make_decoder(emb) for emb in tgt_embeddings]
    print('build_decoder')

    model = Net(encoder, decoder0, decoder1)
    generator = nn.Sequential(
        nn.Linear(rnn_size, vocab_size),
        nn.LogSoftmax(dim=-1))

    if is_transformer:
        # Xavier-initialise every weight matrix of the model, then of the generator.
        for p in list(model.parameters()) + list(generator.parameters()):
            if p.dim() > 1:
                xavier_uniform_(p)

    # NOTE: loading pretrained embedding vectors (load_pretrained_vectors) is not wired up.

    # Assigning the generator registers it as a sub-module of the model.
    model.generator = generator
    model.to(device)
    return model

In [None]:
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

def initHIdden(cur_batch_size):
    """Zero GRU initial state of shape (2, cur_batch_size, rnn_size) on ``device``.

    The leading 2 is hard-coded as num_layers (1) * num_directions (2).
    """
    rows = 2 * 1  # directions * layers
    return torch.zeros(rows, cur_batch_size, rnn_size, device=device)
    
def build_my_model(device=torch.device("cpu")):
    """Build a ``Net`` from the hand-written GRU encoder/decoders defined below.

    Reads the module-level hyper-parameters vocab_size, embed_dim, rnn_size,
    num_layers and bidirectional. Returns the model with its generator attached,
    moved to ``device``.
    """
    num_directions = 2 if bidirectional else 1

    class EncoderRNN(nn.Module):
        def __init__(self, input_size, emb_size, hidden_size, n_layers, bidirectional=False):
            super(EncoderRNN, self).__init__()
            self.input_size = input_size
            self.emb_size = emb_size
            self.hidden_size = hidden_size
            self.n_layers = n_layers
            self.bidirectional = bidirectional

            self.embedding = nn.Embedding(input_size, emb_size, padding_idx=0)
            self.rnn = nn.GRU(emb_size, hidden_size, num_layers=n_layers,
                              bidirectional=self.bidirectional)

        def forward(self, input_seqs, input_lens):
            """Encode a padded batch.

            Args:
                input_seqs: token ids, (seq_len, batch) or (seq_len, batch, 1).
                input_lens: per-sentence lengths; pack_padded_sequence expects them
                    in decreasing order (assumed to be guaranteed by batchify).

            Returns:
                (hidden, outputs): the final GRU state, shape
                (n_layers * num_directions, batch, hidden_size), and the padded outputs.
            """
            if input_seqs.dim() == 3:
                # OpenNMT-style callers pass a trailing feature dim of size 1.
                input_seqs = input_seqs.squeeze(2)
            embedded = self.embedding(input_seqs)
            # pack_padded_sequence requires the lengths on the CPU.
            lens_cpu = torch.as_tensor(input_lens, dtype=torch.long).cpu()
            packed = pack_padded_sequence(embedded, lens_cpu)
            # No explicit state: the GRU starts from zeros of the correct shape/device.
            outputs, hidden = self.rnn(packed)
            outputs, _ = pad_packed_sequence(outputs)
            return hidden, outputs

    class DecoderRNN(nn.Module):
        def __init__(self, hidden_size, emb_size, output_size,
                     n_layers, dropout_p, bidirection, gpu_id=-1):
            super(DecoderRNN, self).__init__()
            self.hidden_size = hidden_size
            self.emb_size = emb_size
            self.output_size = output_size
            self.n_layers = n_layers

            self.dropout_p = dropout_p
            self.gpu_id = gpu_id  # stored but not used
            self.bi_encoder = bidirection

            self.embedding = nn.Embedding(output_size, emb_size, padding_idx=0)
            self.dropout = nn.Dropout(self.dropout_p)
            self.rnn = nn.GRU(emb_size, hidden_size, num_layers=n_layers)

        def forward_step(self, input_var, hidden, encoder_outputs):
            """Run the GRU over ``input_var`` (seq_len, batch); ``encoder_outputs`` is unused (no attention)."""
            embedded = self.dropout(self.embedding(input_var))
            output, hidden = self.rnn(embedded, hidden)
            return output, hidden, None

        def forward(self, decoder_input, encoder_outputs, encoder_hidden, memory_lengths=None):
            """Decode ``decoder_input`` starting from ``encoder_hidden``.

            Returns:
                (outputs, hidden, None) with outputs transposed to (batch, seq_len, hidden_size).
            """
            if decoder_input.dim() == 3:
                # OpenNMT-style callers pass a trailing feature dim of size 1.
                decoder_input = decoder_input.squeeze(2)
            decoder_hidden = encoder_hidden
            # A raw bidirectional encoder state has 2 * n_layers rows and needs its
            # directions merged; a state previously returned by this decoder (step-wise
            # generation) already has n_layers rows and is used unchanged.
            if self.bi_encoder and encoder_hidden.size(0) == 2 * self.n_layers:
                decoder_hidden = self._cat_directions(encoder_hidden)

            decoder_output, decoder_hidden, _ = self.forward_step(decoder_input, decoder_hidden, encoder_outputs)
            return decoder_output.transpose(1, 0), decoder_hidden, None

        def _cat_directions(self, h):
            """ If the encoder is bidirectional, do the following transformation.
                (#directions * #layers, #batch, hidden_size) -> (#layers, #batch, #directions * hidden_size)
            """
            return torch.cat([h[0:h.size(0):2], h[1:h.size(0):2]], 2)

    # Build encoder.
    print('build_encoder')
    encoder = EncoderRNN(vocab_size, embed_dim, rnn_size, num_layers, bidirectional=bidirectional)

    # Build decoders. Their width must match the (direction-merged) encoder state.
    print('build_decoder')
    dec_hidden = rnn_size * num_directions
    decoder0 = DecoderRNN(dec_hidden, embed_dim, vocab_size, num_layers, 0.0, bidirectional)
    decoder1 = DecoderRNN(dec_hidden, embed_dim, vocab_size, num_layers, 0.0, bidirectional)

    # Build Net(= encoder + decoder0 + decoder1).
    model = Net(encoder, decoder0, decoder1)

    generator = nn.Sequential(
        nn.Linear(dec_hidden, vocab_size),
        nn.LogSoftmax(dim=-1))

    # Add generator to model (this registers it as a sub-module of the model).
    model.generator = generator
    model.to(device)

    return model

In [None]:
def train_my_ae(dec_idx, batch, temp=1):
    """One reconstruction update of the autoencoder through decoder ``dec_idx``.

    Unlike ``train_ae`` the source ids are fed as-is (no trailing feature dim).

    Args:
        dec_idx: 0 or 1, the decoder that reconstructs the batch.
        batch: ``(source, target, lengths)`` tensors from ``batchify``.
        temp: the log-probabilities are divided by this before the NLL loss.

    Returns:
        The reconstruction loss as a Python float. The model is left in eval mode.
    """
    model.train()
    model.zero_grad()

    source, target, lengths = batch
    source = source.to(device)
    target = target.to(device)
    lengths = lengths.to(device)

    output = model(source, lengths, dec_idx)

    # Flatten to (positions, vocab); assumes `target` is the matching flat id vector — TODO confirm with batchify.
    flattened_output = output.view(-1, vocab_size)

    # Padding (id 0) is excluded from the mean; `reduction` replaces the deprecated `size_average`.
    recon_loss = F.nll_loss(flattened_output / temp, target, ignore_index=0, reduction='mean')
    recon_loss.backward()

    # Clip the global gradient norm to prevent exploding gradients in the RNNs.
    torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
    optimizer_ae.step()

    model.eval()
    return recon_loss.item()

In [None]:
def train_ae(dec_idx, batch, temp=1):
    """One reconstruction update of the autoencoder through decoder ``dec_idx``.

    The source ids get a trailing feature dim of size 1 (OpenNMT-style input).

    Args:
        dec_idx: 0 or 1, the decoder that reconstructs the batch.
        batch: ``(source, target, lengths)`` tensors from ``batchify``.
        temp: the log-probabilities are divided by this before the NLL loss.

    Returns:
        The reconstruction loss as a Python float. The model is left in eval mode.
    """
    model.train()
    model.zero_grad()

    source, target, lengths = batch
    source = source.unsqueeze(2).to(device)
    target = target.to(device)
    lengths = lengths.to(device)

    output = model(source, lengths, dec_idx)

    # Flatten to (positions, vocab); assumes `target` is the matching flat id vector — TODO confirm with batchify.
    flattened_output = output.view(-1, vocab_size)

    # Padding (id 0) is excluded from the mean; `reduction` replaces the deprecated `size_average`.
    recon_loss = F.nll_loss(flattened_output / temp, target, ignore_index=0, reduction='mean')
    recon_loss.backward()

    # Clip the global gradient norm to prevent exploding gradients in the RNNs.
    torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
    optimizer_ae.step()

    model.eval()
    return recon_loss.item()

In [None]:
def train_ae_masking(dec_idx, batch, temp=1):
    """Reconstruction update whose loss covers only non-padding target positions.

    Padding (id 0) is removed explicitly with ``masked_select`` rather than via
    ``ignore_index``.

    Returns:
        The reconstruction loss as a Python float, or 0.0 without an update if the
        batch contains no non-padding targets. The model is left in eval mode.
    """
    model.train()
    model.zero_grad()

    source, target, lengths = batch
    source = source.unsqueeze(2).to(device)
    target = target.to(device)
    lengths = lengths.to(device)

    # Non-padding positions; assumes `target` is a flat vector of ids — TODO confirm with batchify.
    mask = target.gt(0)
    masked_target = target.masked_select(mask)
    if masked_target.numel() == 0:
        # A mean over zero elements is NaN and would corrupt the weights.
        model.eval()
        return 0.0
    # Repeat the per-position mask across the vocabulary: examples x ntokens.
    output_mask = mask.unsqueeze(1).expand(mask.size(0), vocab_size)

    output = model(source, lengths, dec_idx)
    flattened_output = output.view(-1, vocab_size)

    masked_output = flattened_output.masked_select(output_mask).view(-1, vocab_size)
    # `reduction` replaces the deprecated `size_average`.
    recon_loss = F.nll_loss(masked_output / temp, masked_target, reduction='mean')
    recon_loss.backward()

    # Clip the global gradient norm to prevent exploding gradients in the RNNs.
    torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
    optimizer_ae.step()

    model.eval()
    return recon_loss.item()

In [None]:
def train_disc(dec_idx, batch, acc_threshold=0.99):
    """One discriminator update: label the encodings of ``batch`` as class ``dec_idx``.

    The encoder is not updated. The discriminator step is skipped when its
    accuracy on this batch already reaches ``acc_threshold``.

    Returns:
        ``(loss, accuracy)`` as Python floats.
    """
    disc.train()
    disc.zero_grad()

    source, target, lengths = batch
    source = source.unsqueeze(2).to(device)
    batch_size = source.size(1)
    # BCE needs floating-point targets; recent PyTorch infers an integer dtype from an int fill value.
    labels = torch.full([batch_size], dec_idx, dtype=torch.float, device=device)

    # The encoder output is only an input here, so don't build its autograd graph.
    with torch.no_grad():
        encoded = model(source, lengths, -1, only_enc=True)
    scores = disc(encoded)

    disc_loss = F.binary_cross_entropy(scores.squeeze(1), labels)

    pred = scores.detach().round().squeeze(1)
    accuracy = pred.eq(labels).float().mean().item()

    if accuracy < acc_threshold:
        disc_loss.backward()
        optimizer_d.step()

    return disc_loss.item(), accuracy

In [None]:
def train_adv(dec_idx, batch, temp=1):
    """Adversarial encoder update: push encodings of ``batch`` toward the *other* class.

    Only ``optimizer_ae`` steps; the discriminator's gradients from this backward
    pass are not applied here. ``temp`` is accepted but unused.

    Returns:
        The adversarial loss as a Python float.
    """
    model.train()
    model.zero_grad()

    source, target, lengths = batch
    source = source.unsqueeze(2).to(device)
    flipped_class = 1 - dec_idx
    batch_size = source.size(1)
    # BCE needs floating-point targets; recent PyTorch infers an integer dtype from an int fill value.
    labels = torch.full([batch_size], flipped_class, dtype=torch.float, device=device)

    encoded = model(source, lengths, -1, only_enc=True)
    scores = disc(encoded)

    adv_loss = F.binary_cross_entropy(scores.squeeze(1), labels)
    adv_loss.backward()

    torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
    optimizer_ae.step()

    return adv_loss.item()

In [None]:
def valid_ae(dec_idx, data, it, temp=1):
    """Greedy-decode every batch of ``data`` with decoder ``dec_idx`` and write the results.

    Output file: ``test/model_ep{it}_{1-dec_idx}to{dec_idx}.txt``. For each sentence it
    holds the source line, the generated line, and a blank line. ``temp`` is unused.
    """
    model.eval()

    file_name = "model_ep{}_{}to{}.txt".format(it, 1-dec_idx, dec_idx)
    # open() fails if the output directory does not exist yet.
    os.makedirs("test", exist_ok=True)

    idx2word = corpus.dictionary.idx2word
    with open(os.path.join("test", file_name), 'w', encoding='utf8') as fp, torch.no_grad():
        for source, _target, lengths in data:
            source = source.unsqueeze(2).to(device)
            lengths = lengths.to(device)

            result = model.generate(source, lengths, dec_idx)

            origin = source.squeeze(2).cpu().transpose(1, 0).numpy()
            transfered = result.numpy()

            for org, trans in zip(origin, transfered):
                fp.write(truncate([idx2word[x] for x in org]) + '\n')
                fp.write(truncate([idx2word[x] for x in trans]) + '\n\n')

In [None]:
def valid_ae_test(dec_idx, data, it, temp=1):
    """Greedy-decode every batch of ``data`` with decoder ``dec_idx`` and write the results.

    Output file: ``test/model_test_ep{it}_{1-dec_idx}to{dec_idx}.txt``. For each
    sentence it holds the source line, the generated line, and a blank line.
    ``temp`` is unused.
    """
    model.eval()

    file_name = "model_test_ep{}_{}to{}.txt".format(it, 1-dec_idx, dec_idx)
    # open() fails if the output directory does not exist yet.
    os.makedirs("test", exist_ok=True)

    idx2word = corpus.dictionary.idx2word
    with open(os.path.join("test", file_name), 'w', encoding='utf8') as fp, torch.no_grad():
        for source, _target, lengths in data:
            source = source.unsqueeze(2).to(device)
            lengths = lengths.to(device)

            result = model.generate(source, lengths, dec_idx)

            origin = source.squeeze(2).cpu().transpose(1, 0).numpy()
            transfered = result.numpy()

            for org, trans in zip(origin, transfered):
                fp.write(truncate([idx2word[x] for x in org]) + '\n')
                fp.write(truncate([idx2word[x] for x in trans]) + '\n\n')

In [None]:
# Alternative: the OpenNMT-based model.
#model = build_model(model_type='GRU', device=device)
model = build_my_model(device=device)

# The discriminator consumes Net(..., only_enc=True), i.e. the final encoder state
# with both directions concatenated: rnn_size * num_directions features.
enc_output_size = rnn_size * (2 if bidirectional else 1)
disc = Discriminator(ninput=enc_output_size, noutput=1, layers=disc_layer, device=device)

In [None]:
# Plain SGD for both networks; learning rates come from the configuration cell.
optimizer_ae = optim.SGD(model.parameters(), lr=learning_rate_ae)
optimizer_d = optim.SGD(disc.parameters(), lr=learning_rate_d)

In [None]:
######################################
# Stage 1: autoencoder training on the class-0 (positive) corpus only
######################################

total_loss_ae0 = 0

for ep in tqdm(range(num_epoch)):
    num_batches = 0
    # Only decoder 0 is trained here, so iterate over all of train0_data; zipping
    # with the unused train1_data silently cut every epoch to the shorter corpus.
    for batch0 in train0_data:
        total_loss_ae0 += train_my_ae(0, batch0)
        num_batches += 1

    # Report the mean per-batch loss so the value is comparable across datasets.
    print("[*] epoch : {}/{} / recon_loss : {:6.3f}".format(
        ep+1,
        num_epoch,
        total_loss_ae0 / max(num_batches, 1)
    ))

    total_loss_ae0 = 0

    if (ep+1) % valid_every == 0:
        print("[*] Validate the model...")
        valid_ae_test(0, valid0_data, ep+1)

In [None]:
######################################
# Stage 2: joint training of both autoencoders and the discriminator
######################################

for ep in tqdm(range(num_epoch)):
    # Separate running sums per quantity, reset every epoch.
    sums = {'recon0': 0.0, 'recon1': 0.0, 'disc0': 0.0, 'disc1': 0.0, 'acc0': 0.0, 'acc1': 0.0}
    num_batches = 0

    for batch0, batch1 in zip(train0_data, train1_data):
        sums['recon0'] += train_ae(0, batch0)
        sums['recon1'] += train_ae(1, batch1)

        # train_disc returns a (loss, accuracy) pair.
        disc_loss0, acc0 = train_disc(0, batch0)
        disc_loss1, acc1 = train_disc(1, batch1)
        sums['disc0'] += disc_loss0
        sums['disc1'] += disc_loss1
        sums['acc0'] += float(acc0)
        sums['acc1'] += float(acc1)
        num_batches += 1

        # Adversarial step (disabled): update the encoder only once the
        # discriminator is reasonably accurate.
        # if acc0 > 0.75:
        #     train_adv(0, batch0)
        # if acc1 > 0.75:
        #     train_adv(1, batch1)

    means = {key: value / max(num_batches, 1) for key, value in sums.items()}
    print("[*] epoch : {}/{} / recon_loss : {:6.3f} {:6.3f} / disc_loss : {:6.3f} {:6.3f} / disc_acc : {:.3f} {:.3f}".format(
        ep+1,
        num_epoch,
        means['recon0'], means['recon1'],
        means['disc0'], means['disc1'],
        means['acc0'], means['acc1']
    ))

    if (ep+1) % valid_every == 0:
        print("[*] Validate the model...")
        valid_ae_test(0, valid0_data, ep+1)

In [None]:
# Sanity check: repeatedly fit the very first training batch (presumably to confirm
# the reconstruction loss can be driven down) and show the final loss.
first_batch = train0_data[0]
for _ in tqdm(range(100)):
    train_loss_ae0 = train_ae(0, first_batch)
train_loss_ae0

In [None]:
# Dump greedy decodes for the first training batch only, tagged as epoch 0.
valid_ae(dec_idx=0, data=train0_data[:1], it=0)

# 