In [1]:
import torch
import torch.nn as nn
import torchvision.datasets as dsets
import torchvision.transforms as transforms
from torch.autograd import Variable
import torch.nn.functional as F
import matplotlib.pyplot as plt
import random
import numpy as np
from konlpy.tag import Mecab;tagger=Mecab()
from collections import Counter
%matplotlib inline  

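* Variational Recurrent Auto-Encoders (Fabius & van Amersfoort): https://arxiv.org/pdf/1412.6581.pdf
* Generating Sentences from a Continuous Space (Bowman et al.): https://arxiv.org/pdf/1511.06349.pdf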

In [2]:
# Read the raw corpus, one record per line (newline characters are kept).
# The `with` block guarantees the file handle is closed once the lines are read.
with open('../../DOMAIN_10D_300EA_DATA_170427.txt','r',encoding='utf-8') as corpus_file:
    data = corpus_file.readlines()

In [3]:
# Keep the first tab-separated field of every record whose second field
# contains the substring 'FLOWER'. Each line is split only once.
data = [
    fields[0]
    for fields in (line.split('\t') for line in data)
    if 'FLOWER' in fields[1]
]

In [4]:
# Batch size equals the number of retained sentences, i.e. one batch holds the whole data set.
BATCH_SIZE = len(data)

In [5]:
# Fixed number of tokens per sentence after truncation / padding (EOS included).
SEQ_LENGTH=15
# Vocabulary ids of the start-of-sequence and end-of-sequence markers;
# they match the "SOS" and "EOS" entries of word2index.
SOS_token = 0
EOS_token = 1

In [6]:
# Collects one [input_tokens, target_tokens] pair per sentence.
train=[]

In [7]:
# Tokenise every sentence into morphemes and bring it to exactly SEQ_LENGTH
# tokens: at most SEQ_LENGTH-1 morphemes, then "EOS", then 'PAD' fill.
for sentence in data:
    # Drop "<br>" markers and slashes before tokenising.
    cleaned = sentence.replace("<br>","").replace("/","")

    # Slicing keeps short sentences whole and truncates long ones, always
    # leaving room for the EOS marker.
    morphemes = tagger.morphs(cleaned)[:SEQ_LENGTH-1]
    morphemes.append("EOS")
    morphemes.extend(['PAD'] * (SEQ_LENGTH - len(morphemes)))

    # Auto-encoder setup: the target sequence is the input sequence itself.
    train.append([morphemes,morphemes])

In [8]:
# Build the vocabulary. The four special tokens own ids 0-3; every other token
# receives the next free id in order of first appearance.
word2index={"SOS":0,"EOS":1,"PAD":2,"UNK":3}

for pair in train:
    for token in pair[0]:
        # len(word2index) is evaluated before insertion, so ids stay contiguous.
        word2index.setdefault(token, len(word2index))

# Total vocabulary size, special tokens included.
n_words = len(word2index)

# Reverse lookup: id -> token.
index2word = {index: word for word, index in word2index.items()}

In [9]:
def remove_list(x):
    """Empty the list ``x`` in place so that its elements can be reclaimed.

    The caller's list object stays alive (and empty); nothing is returned.
    """
    x[:] = []

In [10]:
def prepare_sequence(seq, to_ix):
    """Convert a sequence of tokens into a 1-D LongTensor of vocabulary ids.

    Args:
        seq: iterable of tokens (strings).
        to_ix: mapping from token to integer id.

    Returns:
        A ``Variable`` wrapping a ``torch.LongTensor`` with one id per token.

    Raises:
        KeyError: if a token is missing from ``to_ix`` and the mapping has no
            "UNK" entry to fall back on.
    """
    unk_ix = to_ix.get("UNK")
    if unk_ix is None:
        # No unknown-token entry available: keep the strict lookup.
        idxs = [to_ix[w] for w in seq]
    else:
        # Out-of-vocabulary tokens map to the reserved "UNK" id instead of crashing.
        idxs = [to_ix.get(w, unk_ix) for w in seq]
    return Variable(torch.LongTensor(idxs))

In [11]:
# Encode every (input, target) token pair as a (1, SEQ_LENGTH) row of ids and
# record how many target tokens are not padding.
train_x=[]
train_y=[]
lengths=[]
for source_tokens, target_tokens in train:
    train_x.append(prepare_sequence(source_tokens, word2index).view(1,-1))
    train_y.append(prepare_sequence(target_tokens, word2index).view(1,-1))

    # Number of non-'PAD' tokens (the EOS marker is counted).
    lengths.append(sum(1 for token in target_tokens if token !='PAD'))

# Stack the rows into (num_sentences, SEQ_LENGTH) tensors.
inputs = torch.cat(train_x)
targets = torch.cat(train_y)

# The per-sentence rows are no longer needed once concatenated.
remove_list(train_x)
remove_list(train_y)

In [12]:
class EncoderRNN(nn.Module):
    """GRU encoder of a variational auto-encoder over token-id sequences.

    The last layer's final hidden state is projected to ``2 * latent_size``
    values that are split into the mean and the log-variance of a diagonal
    Gaussian, from which a latent code is sampled.
    """

    def __init__(self, input_size, hidden_size, n_layers=1,latent_size=10):
        """
        Args:
            input_size: vocabulary size (number of embedding rows).
            hidden_size: embedding width and GRU hidden width.
            n_layers: number of stacked GRU layers.
            latent_size: dimensionality of the latent code.
        """
        super(EncoderRNN, self).__init__()

        self.input_size, self.hidden_size, self.n_layers = input_size, hidden_size, n_layers

        # Projects the final hidden state to [mu | log_var].
        self.linear = nn.Linear(hidden_size, 2 * latent_size)
        self.embedding = nn.Embedding(input_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size, n_layers, batch_first=True)

    def reparametrize(self, mu, log_var):
        """Sample z = mu + eps * sigma with eps ~ N(0, 1) and sigma = exp(log_var / 2)."""
        rows, cols = mu.size(0), mu.size(1)
        noise = Variable(torch.randn(rows, cols))
        sigma = torch.exp(log_var / 2)    # half the log-variance is the log-std
        return mu + noise * sigma

    def forward(self, input,train=True):
        """Encode a (batch, seq_len) tensor of token ids.

        Args:
            input: LongTensor of token ids, batch first.
            train: accepted but not used by this implementation.

        Returns:
            Tuple ``(z, mu, log_var)``, each of shape (batch, latent_size).
        """
        batch = input.size(0)
        initial_state = Variable(torch.zeros(self.n_layers, batch, self.hidden_size))

        _, final_state = self.gru(self.embedding(input), initial_state)

        # Only the top layer's final state feeds the Gaussian parameters.
        stats = self.linear(final_state[-1])
        mu, log_var = torch.chunk(stats, 2, dim=1)

        return self.reparametrize(mu, log_var), mu, log_var

In [13]:
# Smoke test: build a throwaway encoder (hidden size 100, 2 GRU layers,
# default latent size) over the vocabulary and print its layer summary.
encoder_test = EncoderRNN(len(word2index), 100, 2)
print(encoder_test)

EncoderRNN (
  (linear): Linear (100 -> 20)
  (embedding): Embedding(780, 100)
  (gru): GRU(100, 100, num_layers=2, batch_first=True)
)


In [14]:
# Run the throwaway encoder on the first three sentences; `out` is the sampled
# latent code, `mu` / `log_var` are the Gaussian parameters it was drawn from.
out, mu,log_var = encoder_test(inputs[:3].view(3,-1))

In [15]:
class DecoderRNN(nn.Module):
    """Greedy GRU decoder conditioned on a latent code.

    The latent code is projected to a context vector that is concatenated with
    the previous token's embedding at every time step (GRU input) and with the
    GRU's top-layer hidden state before the output projection.
    """

    def __init__(self, hidden_size, output_size, n_layers=1,latent_size=10):
        """
        Args:
            hidden_size: embedding width, context width and GRU hidden width.
            output_size: vocabulary size.
            n_layers: number of stacked GRU layers.
            latent_size: dimensionality of the latent code fed to ``forward``.
        """
        super(DecoderRNN, self).__init__()

        self.hidden_size = hidden_size
        self.output_size = output_size
        self.n_layers = n_layers

        # Define the layers
        self.embedding = nn.Embedding(self.output_size, self.hidden_size)
        self.linear = nn.Linear(latent_size,hidden_size)
        self.tanh = nn.Tanh()

        # y_{t-1} and the context c are concatenated and fed to the GRU here.
        # This seems to differ slightly from the method proposed in the paper,
        # but there is no way to control the internals of the GRU
        # (short of implementing it by hand).
        self.gru = nn.GRU(self.hidden_size*2, self.hidden_size, self.n_layers,batch_first=True)
        self.out = nn.Linear(self.hidden_size*2, self.output_size)

    def forward(self, input,latent,lengths,seq_length,training=True):
        """Greedily decode ``seq_length`` tokens for every latent code.

        Args:
            input: LongTensor of shape (batch, 1) holding the start token ids.
            latent: tensor with ``batch * latent_size`` elements (one code per row).
            lengths: accepted but not used by this implementation.
            seq_length: number of decoding steps.
            training: accepted for backward compatibility; the batch size is
                now read from ``input``, so the flag no longer changes anything.

        Returns:
            Tensor of shape (batch * seq_length, output_size) with
            log-probabilities; rows are ordered sample-major (all time steps
            of sample 0, then sample 1, ...).
        """
        # Derive the batch size from the input instead of a notebook-level
        # global, so training and single-sentence inference share one path.
        batch_size = input.size(0)

        # Get the embedding of the current input word
        embedded = self.embedding(input)
        hidden = Variable(torch.zeros(self.n_layers, batch_size, self.hidden_size))
        context = self.tanh(self.linear(latent)).view(batch_size,1,-1)
        context_flat = context.squeeze(1)

        decode=[]
        for _ in range(seq_length):
            _, hidden = self.gru(torch.cat((embedded,context),2), hidden)
            # Use the top layer only, so stacks with n_layers > 1 work as well.
            concated = torch.cat((hidden[-1],context_flat),1)
            score = self.out(concated)
            softmaxed = F.log_softmax(score, dim=1)
            decode.append(softmaxed)
            # Greedy choice; reshape to (batch, 1) whether or not torch.max
            # keeps the reduced dimension.
            _, next_token = torch.max(softmaxed,1)
            embedded = self.embedding(next_token.view(batch_size,1))

        # TODO: compute the cost over the real (non-padding) length only.

        # Careful: time steps are concatenated column-wise, then reshaped so
        # that each row is one (sample, time step) distribution.
        scores = torch.cat(decode,1)

        return scores.view(batch_size*seq_length,-1)

In [18]:
# Model / optimisation hyper-parameters.
HIDDEN_SIZE = 30      # embedding and GRU hidden width handed to encoder and decoder
LATENT_SIZE = 10      # intended latent-code width; must agree with the models' latent_size
STEP=1000             # number of full-batch optimisation steps
LEARNING_RATE=0.001   # Adam learning rate for both optimisers

In [52]:
# Build the VAE. LATENT_SIZE is passed explicitly so the encoder output and the
# decoder input stay in sync with the hyper-parameter cell instead of relying
# on both classes happening to share the same default.
encoder = EncoderRNN(len(word2index), HIDDEN_SIZE, 2, latent_size=LATENT_SIZE)
decoder = DecoderRNN(HIDDEN_SIZE, len(word2index), latent_size=LATENT_SIZE)

# Reconstruction loss over (sample, time step) rows. The decoder already
# returns log-probabilities, so the log-softmax inside CrossEntropyLoss leaves
# them unchanged.
Recon = nn.CrossEntropyLoss()

# One Adam optimiser per sub-network, same learning rate.
enc_optim = torch.optim.Adam(encoder.parameters(), lr=LEARNING_RATE)
dec_optim = torch.optim.Adam(decoder.parameters(), lr=LEARNING_RATE)

In [53]:
def _scalar(value):
    """Return the single element of a loss tensor as a Python float.

    Works for both one-element and zero-dimensional loss tensors.
    """
    return float(value.data.view(-1)[0])


# Full-batch training: every step encodes all sentences, decodes greedily from
# the SOS token and minimises reconstruction loss (+ optionally the KL term).
for epoch in range(STEP):

    encoder.zero_grad()
    decoder.zero_grad()

    # One SOS token per sentence, shaped (BATCH_SIZE, 1).
    decoder_input = Variable(torch.LongTensor([[SOS_token]*BATCH_SIZE])).transpose(1,0)
    latent, mu, log_var = encoder(inputs)

    score = decoder(decoder_input,latent,lengths,SEQ_LENGTH)
    recon_loss=Recon(score,targets.view(-1))
    # KL divergence between N(mu, sigma^2) and N(0, I), summed over batch and latent dims.
    kld_loss = torch.sum(0.5 * (mu**2 + torch.exp(log_var) - log_var -1))

    # KL cost "annealing" as a hard switch: the KL term only joins the
    # objective once the reconstruction loss has dropped below 1.0.
    # NOTE(review): in the recorded run ELBO equals RECON at every logged
    # step, i.e. the switch never turned on while the KL value kept growing.
    recon_value = _scalar(recon_loss)
    KCA = 1.0 if recon_value < 1.0 else 0.0

    # Quantity being minimised (reconstruction loss plus weighted KL).
    ELBO = recon_loss+KCA*kld_loss
    loss = _scalar(ELBO)

    ELBO.backward()

    # Clip each sub-network's gradient norm to 5.0 before stepping.
    # NOTE(review): newer PyTorch releases provide an in-place variant named
    # clip_grad_norm_ - confirm against the installed version.
    torch.nn.utils.clip_grad_norm(encoder.parameters(), 5.0)
    torch.nn.utils.clip_grad_norm(decoder.parameters(), 5.0)

    dec_optim.step()
    enc_optim.step()

    if epoch % 100==0:
        print("[%d/%d] ELBO : %.4f , RECON : %.4f & KLD : %.4f" % (epoch,STEP,loss,
                                                                              recon_value,
                                                                              _scalar(kld_loss)))

[0/1000] ELBO : 6.6353 , RECON : 6.6353 & KLD : 42.9220
[100/1000] ELBO : 4.3531 , RECON : 4.3531 & KLD : 4170.7188
[200/1000] ELBO : 3.8468 , RECON : 3.8468 & KLD : 7578.4424
[300/1000] ELBO : 3.5230 , RECON : 3.5230 & KLD : 12644.1143
[400/1000] ELBO : 3.2482 , RECON : 3.2482 & KLD : 17521.6719
[500/1000] ELBO : 3.0253 , RECON : 3.0253 & KLD : 22339.7461
[600/1000] ELBO : 2.8325 , RECON : 2.8325 & KLD : 26981.3750
[700/1000] ELBO : 2.6787 , RECON : 2.6787 & KLD : 31225.3848
[800/1000] ELBO : 2.5434 , RECON : 2.5434 & KLD : 35128.8203
[900/1000] ELBO : 2.4358 , RECON : 2.4358 & KLD : 38413.1055


## test 

### Recon

In [51]:
# Reconstruction check: encode one random training sentence, decode it greedily
# from the SOS token and print the original next to the reconstruction.
# The index is drawn from the actual data set size rather than a fixed 300.
index=random.randrange(len(train))
latent,_,_ = encoder(inputs[index].view(1,-1))
decoder_input = Variable(torch.LongTensor([[SOS_token]])).transpose(1,0)
recon = decoder(decoder_input,latent,lengths,SEQ_LENGTH,False)

# Most likely token id per time step.
v,i = torch.max(recon,1)

# Flatten first so the lookup works whether or not torch.max keeps the reduced
# dimension.
decoded = [index2word[int(token_id)] for token_id in i.data.contiguous().view(-1).numpy()]

print('Q: ', ' '.join([tok for tok in train[index][0] if tok !='PAD' and tok != 'EOS'])+'\n')
print('A: ', ' '.join([tok for tok in decoded if tok !='PAD' and tok != 'EOS'])+'\n')

Q:  여자 친구 한테 생일 꽃바구니 배달 요청 이 요

A:  꽃 하 님 화환



### Approximation

In [41]:
# Generation check: decode a latent code drawn from the standard normal prior
# instead of one produced by the encoder.
decoder_input = Variable(torch.LongTensor([[SOS_token]])).transpose(1,0)
# Read the latent width from the decoder itself so the sample always matches
# the model, whatever latent size it was built with.
context = Variable(torch.randn([1,decoder.linear.in_features]))
recon = decoder(decoder_input,context,lengths,SEQ_LENGTH,False)

# Most likely token id per time step.
v,i = torch.max(recon,1)

# Flatten first so the lookup works whether or not torch.max keeps the reduced
# dimension.
decoded = [index2word[int(token_id)] for token_id in i.data.contiguous().view(-1).numpy()]

# NOTE(review): the recorded output is identical to the reconstruction cell's
# output - worth checking whether the decoder actually uses the latent code.
print('A: ', ' '.join([tok for tok in decoded if tok !='PAD' and tok != 'EOS'])+'\n')

A:  꽃 하 님 화환

