In [1]:
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import torchtext
import nltk
from konlpy.tag import Kkma
from torchtext.data import Field, Iterator, BucketIterator, TabularDataset
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

## 데이터 준비 

대화 모델용 더미 데이터. `data/dsksd_chat.txt`의 각 줄은 탭(`\t`)으로 구분된 (`inputs`, `targets`) 한 쌍이다.

In [2]:
# KoNLPy's Kkma analyzer; only its `morphs` method (text -> list of morpheme
# strings) is needed, so bind that method directly as the tokenizer instead of
# reusing one name for both the analyzer object and the method.
kor_tagger = Kkma().morphs

In [3]:
# Both sides share the Korean morpheme tokenizer and special tokens: every
# sequence is wrapped in <s> ... </s>, lower-cased, and batched batch-first.
# Only the source side also returns sequence lengths, which packing needs later.
_shared_field_options = dict(
    tokenize=kor_tagger,
    use_vocab=True,
    init_token="<s>",
    eos_token="</s>",
    lower=True,
    batch_first=True,
)
SOURCE = Field(include_lengths=True, **_shared_field_options)
TARGET = Field(**_shared_field_options)

In [4]:
# One example per line: the `inputs` column and the `targets` column are
# separated by a tab. skip_header is left unset, so the first line is read as
# data; pass skip_header=True if the file ever gains a header row.
train_data = TabularDataset(
    path="data/dsksd_chat.txt",
    format='tsv',
    fields=[
        ('inputs', SOURCE),
        ('targets', TARGET),
    ],
)

In [5]:
# Build the source and target vocabularies from the training set.
for _field in (SOURCE, TARGET):
    _field.build_vocab(train_data)

In [6]:
# Vocabulary sizes, labelled so the displayed result is self-explanatory.
{"source_vocab": len(SOURCE.vocab), "target_vocab": len(TARGET.vocab)}

316 272


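## Iterator vs BucketIterator

두 이터레이터 모두 `sort_within_batch=True`로 배치 안의 예제를 `inputs` 길이 기준 내림차순으로 정렬한다(뒤의 `pack_padded_sequence`가 이 순서를 요구). `BucketIterator`는 여기에 더해 길이가 비슷한 예제끼리 같은 배치로 묶어 패딩을 줄인다.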

In [7]:
# Plain Iterator (no length bucketing). sort_within_batch=True orders the
# examples inside each batch by len(inputs), longest first, for packing later.
# device=-1 keeps tensors on the CPU; a non-negative index selects that GPU.
# repeat=False stops after one pass over the data instead of repeating.
train_loader = Iterator(
    train_data, batch_size=32, device=-1,
    sort_key=lambda x: len(x.inputs), sort_within_batch=True, repeat=False)

In [8]:
# BucketIterator with the same settings as the plain Iterator: examples of
# similar length (by sort_key) are grouped into the same batch, and each batch
# is sorted by input length. device=-1 keeps tensors on the CPU.
train_loader_bucket = BucketIterator(
    train_data,
    batch_size=32,
    sort_key=lambda example: len(example.inputs),
    sort_within_batch=True,
    device=-1,
    repeat=False,
)

In [9]:
# Take a single batch from the plain Iterator to inspect. next(iter(...))
# raises StopIteration on an empty loader instead of silently leaving `batch`
# unbound or holding a stale value.
batch = next(iter(train_loader))

In [10]:
# include_lengths=True on SOURCE makes batch.inputs a pair:
# (padded token ids, batch-first; per-sequence lengths).
inputs, lengths = batch.inputs

In [12]:
# Take a single batch from the BucketIterator. next(iter(...)) guarantees that
# `batch` is rebound here. The previous for/break form would silently keep the
# plain Iterator's batch if this loader yielded nothing.
batch = next(iter(train_loader_bucket))

In [20]:
# include_lengths=True on SOURCE makes batch.inputs a pair:
# (padded token ids, batch-first; per-sequence lengths).
# NOTE(review): in the saved run this cell executed after the embedding cell
# (execution count 20 vs 18), so the displayed `embedded`/`packed` outputs mix
# the plain Iterator's inputs with these lengths. Restart & Run All to
# regenerate consistent outputs.
inputs, lengths = batch.inputs

## PackedSequence 

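패딩을 제외한 실제 토큰 값만 타임스텝 순서대로 이어 붙여 `data`에 저장하고, 각 타임스텝에서 아직 끝나지 않은 시퀀스 수를 `batch_sizes`에 기록한다.<br>
따라서 RNN은 패딩 위치를 아예 계산하지 않으므로, 제로 패딩 인풋이 연산(및 그래디언트)에 영향을 미치지 못한다. 예: 아래 배치는 32×12 = 384칸 중 실제 토큰 323개(= `batch_sizes`의 합)만 `data`에 담긴다.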

In [15]:
# Model sizes: E = embedding dimension, H = GRU hidden size.
E, H = 50, 100

In [17]:
# Seed right before creating the layers so their randomly initialised weights
# are reproducible, and re-running this cell rebuilds identical layers.
torch.manual_seed(0)
embed = nn.Embedding(len(SOURCE.vocab), E)  # token id -> E-dim vector
rnn = nn.GRU(E, H, batch_first=True)        # batch dimension first, matching the Fields

In [18]:
# Look up embeddings for the padded batch; shape is (batch, seq_len, E).
embedded = embed(inputs)
embedded.size()

torch.Size([32, 12, 50])


In [23]:
# Pack the padded embeddings so the GRU only processes each sequence's real
# (non-pad) time steps.
packed = pack_padded_sequence(
    embedded,          # (batch, seq_len, E) padded embeddings
    lengths.tolist(),  # one plain int length per sequence
    batch_first=True,
)

In [24]:
# Display the PackedSequence: `data` holds only the non-pad time steps (one row
# per real token), and `batch_sizes[t]` is how many sequences are still active
# at time step t.
packed

PackedSequence(data=Variable containing:
 1.6234e-02 -1.8424e+00  1.4331e+00  ...   6.8099e-02  9.2081e-01 -1.5433e+00
 1.6234e-02 -1.8424e+00  1.4331e+00  ...   6.8099e-02  9.2081e-01 -1.5433e+00
 1.6234e-02 -1.8424e+00  1.4331e+00  ...   6.8099e-02  9.2081e-01 -1.5433e+00
                ...                   ⋱                   ...                
 9.4805e-01  1.6872e+00 -2.5527e+00  ...  -2.2250e-01  1.4441e-01 -8.6105e-01
 9.4805e-01  1.6872e+00 -2.5527e+00  ...  -2.2250e-01  1.4441e-01 -8.6105e-01
 9.4805e-01  1.6872e+00 -2.5527e+00  ...  -2.2250e-01  1.4441e-01 -8.6105e-01
[torch.FloatTensor of size 323x50]
, batch_sizes=[32, 32, 32, 32, 32, 32, 32, 32, 32, 23, 9, 3])

In [26]:
# Run the GRU on the packed batch. `output` comes back as a PackedSequence
# (displayed below) and `hidden` is the final hidden state.
output, hidden  = rnn(packed)

In [27]:
# GRU outputs for the non-pad time steps only, in the same packed layout
# (same batch_sizes) as the input.
output

PackedSequence(data=Variable containing:
-0.0118 -0.1620 -0.1552  ...   0.0367 -0.1304 -0.1384
-0.0118 -0.1620 -0.1552  ...   0.0367 -0.1304 -0.1384
-0.0118 -0.1620 -0.1552  ...   0.0367 -0.1304 -0.1384
          ...             ⋱             ...          
 0.0992 -0.3912  0.3775  ...  -0.1070  0.0772  0.0654
 0.0239 -0.1329  0.3856  ...  -0.0854 -0.0348  0.0602
 0.0920 -0.5095  0.3166  ...  -0.3060  0.0027 -0.0288
[torch.FloatTensor of size 323x100]
, batch_sizes=[32, 32, 32, 32, 32, 32, 32, 32, 32, 23, 9, 3])

## Unpack

`pad_packed_sequence`로 패킹된 시퀀스를 다시 (batch, seq_len, H) 형태의 패딩된 텐서로 되돌린다. 이때 각 시퀀스 길이를 넘는 위치는 특정 값(기본값 0.0)으로 채운다.

In [28]:
# Unpack back to a padded batch-first tensor; positions past each sequence's
# length are filled with 0.0. output_lengths holds the per-sequence lengths.
# NOTE(review): this rebinds `output` from the PackedSequence to a tensor, so
# this cell cannot be re-run on its own. Re-run the GRU cell first.
output, output_lengths = pad_packed_sequence(output,batch_first=True)

In [29]:
# Padded output shape: (batch, longest sequence length, H).
output.size()

torch.Size([32, 12, 100])

In [30]:
# Padded GRU outputs per sequence; positions past a sequence's length hold the
# padding value (0.0).
output

Variable containing:
( 0 ,.,.) = 
 -0.0118 -0.1620 -0.1552  ...   0.0367 -0.1304 -0.1384
 -0.1827 -0.2275 -0.1667  ...  -0.2092  0.1295 -0.0127
 -0.0726 -0.2754  0.0067  ...  -0.0280  0.1246  0.1970
           ...             ⋱             ...          
 -0.0685 -0.0039  0.3231  ...   0.2204  0.0366 -0.1132
  0.1538 -0.0491  0.2829  ...   0.1999  0.0752  0.0579
  0.0992 -0.3912  0.3775  ...  -0.1070  0.0772  0.0654

( 1 ,.,.) = 
 -0.0118 -0.1620 -0.1552  ...   0.0367 -0.1304 -0.1384
 -0.2294 -0.3520 -0.3003  ...   0.0351 -0.4487 -0.1060
 -0.1560 -0.3315 -0.1301  ...   0.0680 -0.1750  0.1779
           ...             ⋱             ...          
 -0.0379  0.5724  0.4190  ...   0.2257  0.0284 -0.1854
  0.0095  0.3871  0.2790  ...   0.1705 -0.2041  0.0232
  0.0239 -0.1329  0.3856  ...  -0.0854 -0.0348  0.0602

( 2 ,.,.) = 
 -0.0118 -0.1620 -0.1552  ...   0.0367 -0.1304 -0.1384
  0.1266 -0.2595 -0.0393  ...   0.1334  0.0193 -0.2565
  0.0180 -0.4487  0.1233  ...   0.2001 -0.1435 -0.0637
   