In [4]:
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import torchtext
import nltk
from konlpy.tag import Kkma
from torchtext.data import Field, Iterator, BucketIterator, TabularDataset
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

## 데이터 준비 

대화 모델 학습용 더미 데이터(`data/dsksd_chat.txt`)를 사용한다. 각 줄은 탭(`\t`)으로 구분된 입력(`inputs`)·타깃(`targets`) 문장 쌍이다.

In [5]:
# Tokenizer callable handed to the torchtext Fields below: the `morphs`
# method of a Kkma tagger instance.
kor_tagger = Kkma().morphs

In [6]:
# Both fields tokenize with `kor_tagger`, build a vocabulary, lowercase the
# text, wrap every sequence in <s> ... </s>, and emit batch-first tensors.
# Only SOURCE also asks for the per-example lengths (include_lengths=True).
SOURCE = Field(
    tokenize=kor_tagger,
    use_vocab=True,
    init_token="<s>",
    eos_token="</s>",
    lower=True,
    include_lengths=True,
    batch_first=True,
)
TARGET = Field(
    tokenize=kor_tagger,
    use_vocab=True,
    init_token="<s>",
    eos_token="</s>",
    lower=True,
    batch_first=True,
)

In [7]:
# Read the chat corpus: first column -> `inputs` (SOURCE field),
# second column -> `targets` (TARGET field).
train_data = TabularDataset(
    path="data/dsksd_chat.txt",
    format="tsv",  # columns are separated by \t
    # skip_header=True,  # enable to skip a header row, if the file has one
    fields=[("inputs", SOURCE), ("targets", TARGET)],
)

In [8]:
# Build one vocabulary per field from the training set (SOURCE first, then TARGET).
for field in (SOURCE, TARGET):
    field.build_vocab(train_data)

In [9]:
# Report the source and target vocabulary sizes on one line.
source_vocab_size = len(SOURCE.vocab)
target_vocab_size = len(TARGET.vocab)
print(source_vocab_size, target_vocab_size)

316 272


## Iterator vs BucketIterator 

In [10]:
# Plain Iterator over the training set in batches of 32.
# sort_key measures the source length of an example, and sort_within_batch=True
# orders the examples inside each batch by that key; repeat=False stops
# after a single pass over the data.
# (The original statement assigned `train_loader` twice in one chained
# assignment; the redundant target has been removed.)
train_loader = Iterator(
    train_data,
    batch_size=32,
    device=-1,  # device -1: CPU, device 0: an available GPU
    sort_key=lambda x: len(x.inputs),
    sort_within_batch=True,
    repeat=False,
)

In [11]:
# BucketIterator counterpart of `train_loader`, configured with the same
# batch size, device, sort key and single-pass setting so the two can be compared.
train_loader_bucket = BucketIterator(
    train_data,
    batch_size=32,
    device=-1,  # device -1: CPU, device 0: an available GPU
    sort_key=lambda example: len(example.inputs),
    sort_within_batch=True,
    repeat=False,
)

In [12]:
# Grab only the first batch produced by the plain Iterator and leave it
# bound to `batch` for inspection.
for batch in train_loader:
    break

In [15]:
# Grab the first batch produced by the BucketIterator. This rebinds `batch`,
# so everything that follows works on the BucketIterator batch.
for batch in train_loader_bucket:
    break

In [16]:
# SOURCE was created with include_lengths=True, so `batch.inputs` unpacks into
# two values: the batched source tensor and the per-example lengths.
inputs, lengths = batch.inputs

## PackedSequence 

패딩이 들어간 시퀀스에서 실제 토큰에 해당하는 값만 타임스텝 순서로 모아 저장하고(`batch_sizes`에는 각 타임스텝에서 유효한 배치 크기가 기록된다), 패딩 위치에 대해서는 Computational graph가 생성되지 않게 한다.<br>
즉, 제로 패딩 인풋이 RNN 연산에 영향을 미치지 못하도록 한다.

In [18]:
# E: embedding dimension, H: GRU hidden size (consumed by the layer
# definitions in the next cell).
E, H = 50, 100

In [19]:
# Embedding table with one E-dimensional row per source-vocabulary entry,
# followed by a batch-first GRU mapping E-dim inputs to H-dim hidden states.
embed = nn.Embedding(num_embeddings=len(SOURCE.vocab), embedding_dim=E)
rnn = nn.GRU(input_size=E, hidden_size=H, batch_first=True)

In [20]:
# Embed the padded source batch and print the resulting size; the recorded
# output is [32, 12, 50], i.e. (batch, max_len, E).
embedded = embed(inputs)
print(embedded.size())

torch.Size([32, 12, 50])


In [26]:
# Pack the padded embeddings using the true source lengths so that padded
# time steps are left out of the packed data.
# NOTE(review): this presumably relies on the batch already being sorted by
# decreasing length (the iterator uses sort_within_batch=True) — confirm.
packed = pack_padded_sequence(embedded,lengths.tolist(),batch_first=True)

In [30]:
# batch_size * max_len from the printed embedding size: the number of rows a
# fully padded batch holds, for comparison with the packed row count.
32*12

384

In [27]:
# Display the PackedSequence: the flattened data plus the per-time-step batch_sizes.
packed

PackedSequence(data=Variable containing:
 1.0085e-01 -1.7665e+00  2.5217e+00  ...  -7.8955e-01  1.2836e+00  3.3859e-01
 1.0085e-01 -1.7665e+00  2.5217e+00  ...  -7.8955e-01  1.2836e+00  3.3859e-01
 1.0085e-01 -1.7665e+00  2.5217e+00  ...  -7.8955e-01  1.2836e+00  3.3859e-01
                ...                   ⋱                   ...                
 1.2681e+00 -1.0458e+00  1.9104e-01  ...   5.2714e-01 -2.4102e+00  6.1990e-01
 1.2681e+00 -1.0458e+00  1.9104e-01  ...   5.2714e-01 -2.4102e+00  6.1990e-01
 1.2681e+00 -1.0458e+00  1.9104e-01  ...   5.2714e-01 -2.4102e+00  6.1990e-01
[torch.FloatTensor of size 323x50]
, batch_sizes=[32, 32, 32, 32, 32, 32, 32, 32, 32, 23, 9, 3])

In [32]:
# Run the GRU directly on the packed batch. `output` comes back as a
# PackedSequence as well; `hidden` is the second value returned by the GRU.
output, hidden = rnn(packed)

In [33]:
# Display the GRU output, which is still in packed form at this point.
output

PackedSequence(data=Variable containing:
-1.5300e-01  1.8669e-01 -1.3253e-01  ...   2.6583e-01 -1.0945e-02 -6.1352e-02
-1.5300e-01  1.8669e-01 -1.3253e-01  ...   2.6583e-01 -1.0945e-02 -6.1352e-02
-1.5300e-01  1.8669e-01 -1.3253e-01  ...   2.6583e-01 -1.0945e-02 -6.1352e-02
                ...                   ⋱                   ...                
 2.5500e-01  1.6015e-01  4.9886e-02  ...   2.9781e-01  1.5154e-01  1.0971e-01
 2.4965e-01  2.1226e-01  2.8637e-02  ...   3.3992e-01 -1.8563e-02  1.6277e-01
-1.3385e-03  3.0611e-01 -9.6225e-02  ...   5.6182e-01 -1.1624e-01  1.8418e-01
[torch.FloatTensor of size 323x100]
, batch_sizes=[32, 32, 32, 32, 32, 32, 32, 32, 32, 23, 9, 3])

## unPack 

Packing된 시퀀스를 다시 원래의 패딩된 형태로 되돌리면서, 제로 패딩 위치는 특정 값(`padding_value`, 디폴트 = 0.0)으로 채운다.

In [35]:
# Convert the packed GRU output back into a padded batch-first tensor and
# recover the per-sequence lengths.
# NOTE(review): `output` is rebound from a PackedSequence to a padded tensor
# here, so re-running this cell without first re-running the GRU cell will
# fail; a distinct name for the packed result would make the cell idempotent.
output, output_lengths = pad_packed_sequence(output,batch_first=True)

In [37]:
# Size of the unpacked output; the recorded result is [32, 12, 100],
# i.e. (batch, max_len, H).
output.size()

torch.Size([32, 12, 100])

In [38]:
# Display the unpacked (padded) GRU output tensor.
output

Variable containing:
( 0 ,.,.) = 
 -0.1530  0.1867 -0.1325  ...   0.2658 -0.0109 -0.0614
 -0.3564 -0.0129 -0.1237  ...   0.1542  0.0032  0.0002
 -0.2718  0.0956  0.1616  ...  -0.0823 -0.1461  0.1032
           ...             ⋱             ...          
 -0.0712 -0.3333  0.1669  ...  -0.0447  0.4001  0.2581
  0.1617 -0.3127  0.0835  ...  -0.0738  0.4544  0.4266
  0.2550  0.1602  0.0499  ...   0.2978  0.1515  0.1097

( 1 ,.,.) = 
 -0.1530  0.1867 -0.1325  ...   0.2658 -0.0109 -0.0614
  0.0395  0.0188 -0.3340  ...   0.4473  0.2405 -0.0283
  0.0142 -0.1381  0.0426  ...   0.1912  0.3794 -0.3220
           ...             ⋱             ...          
 -0.0835 -0.2006  0.0309  ...   0.1878 -0.0817  0.2421
  0.1478 -0.2188  0.0391  ...   0.0361  0.1657  0.4840
  0.2497  0.2123  0.0286  ...   0.3399 -0.0186  0.1628

( 2 ,.,.) = 
 -0.1530  0.1867 -0.1325  ...   0.2658 -0.0109 -0.0614
  0.3858  0.2890 -0.1576  ...   0.0233  0.0934  0.1323
  0.1997  0.0041  0.1303  ...   0.0972  0.3346 -0.2706
   