## 👉08-04 토치텍스트(TorchText)의 batch_first

## 1. 훈련 데이터와 테스트 데이터로 분리하기

In [1]:
import urllib.request
import pandas as pd

In [2]:
# Download the IMDb review CSV into the working directory (no existence check,
# so this re-downloads on every run). The returned (local_path, headers) tuple
# is left as the cell's last expression so the notebook echoes it.
urllib.request.urlretrieve(
    url="https://raw.githubusercontent.com/LawrenceDuan/IMDb-Review-Analysis/master/IMDb_Reviews.csv",
    filename="IMDb_Reviews.csv",
)

('IMDb_Reviews.csv', <http.client.HTTPMessage at 0x2fbc6c77970>)

In [3]:
# Load the downloaded reviews into a DataFrame. The file is decoded as Latin-1
# rather than the UTF-8 default -- presumably because it contains bytes that are
# not valid UTF-8 (unverified).
df = pd.read_csv(filepath_or_buffer="IMDb_Reviews.csv", encoding="latin1")

In [4]:
# Print the total number of rows (samples) in the loaded dataset.
print("전체 샘플의 개수 : {}".format(df.shape[0]))

전체 샘플의 개수 : 50000


In [5]:
# Positional split: rows [0, 25000) become training data, the remaining rows test
# data. No shuffling happens here, so the split follows the CSV's row order.
train_df, test_df = df.iloc[:25000], df.iloc[25000:]

In [6]:
# Persist both splits without the pandas index column. TabularDataset maps
# fields to CSV columns by position, so an extra index column would shift them.
train_df.to_csv(path_or_buf="train_data.csv", index=False)
test_df.to_csv(path_or_buf="test_data.csv", index=False)

## 2. 필드 정의하기 (torchtext.legacy.data)

In [7]:
from torchtext.legacy import data

In [8]:
# Text field: split on whitespace, lowercase, numericalize through a vocabulary,
# and pad/truncate every review to exactly 20 tokens. batch_first=True puts the
# batch dimension first, so a batch tensor has shape (batch_size, 20).
TEXT = data.Field(
    sequential=True,
    tokenize=str.split,
    lower=True,
    use_vocab=True,
    fix_length=20,
    batch_first=True,
)

# Label field: a single value per review that must already be numeric
# (use_vocab=False), flagged as the prediction target.
LABEL = data.Field(
    sequential=False,
    use_vocab=False,
    is_target=True,
    batch_first=False,
)

## 3. 데이터셋 / 단어 집합 / 데이터로더 만들기

In [9]:
from torchtext.legacy.data import TabularDataset, Iterator

In [10]:
# Load the train/test CSV files from the current directory. The field list is
# matched to CSV columns by position; skip_header drops the header row.
train_data, test_data = TabularDataset.splits(
    path='.',
    format="csv",
    skip_header=True,
    fields=[("text", TEXT), ("label", LABEL)],
    train="train_data.csv",
    test="test_data.csv",
)

In [11]:
# Build the vocabulary from the training split only: words must occur at least
# 10 times, and the vocabulary is capped by max_size=10000.
TEXT.build_vocab(train_data, max_size=10000, min_freq=10)

In [12]:
batch_size = 5
# Iterate over the training examples in groups of `batch_size`; only the first
# batch is pulled here, to inspect the tensor layout.
train_loader = Iterator(train_data, batch_size)
batch = next(iter(train_loader))

In [13]:
# Show the numericalized token indices of the first batch. With TEXT defined
# using batch_first=True, each row is one review.
print(batch.text)

tensor([[   0,    0,   13,  776,    3,  349,    5,    0, 7280, 9725,    9,  260,
           16,   10,  131,    6,  466,  207,   19,   38],
        [1916,   35,   41,  124,   71,    3,  349,    5,    2,  238, 1218,   40,
           53,   41,    3, 2675,   42,    2,  975,  133],
        [  92, 1225,   62,  301,    0,    0,  807,    0,  394, 3514,    0,    0,
         1192,   12,  239,   35,  307,    6,  301,    0],
        [  10,    7,    3, 2633,   20,   12,    7,   38, 2995,   12,   55, 1846,
         2954,    4,    9,  235,   11,    0,    2,  130],
        [   9,  743, 2625,    4,   26,   99, 2382, 5646,   18,   50,   53,  332,
            5,    2,  800,   12,  180,  327,    8,    0]])


In [14]:
# Expect (batch_size, fix_length) = (5, 20) because TEXT uses batch_first=True.
print(batch.text.size())

torch.Size([5, 20])


## 4. batch_first 없이 필드 재정의하기 (torchtext.legacy.data)

In [15]:
# Same text field as before, except batch_first is left at the legacy Field
# default (False, spelled out here). Batch tensors are therefore
# sequence-major: (fix_length, batch_size) = (20, 5).
TEXT = data.Field(
    sequential=True,
    tokenize=str.split,
    lower=True,
    use_vocab=True,
    fix_length=20,
    batch_first=False,
)

# Label field unchanged: numeric target values, no vocabulary.
LABEL = data.Field(
    sequential=False,
    use_vocab=False,
    is_target=True,
    batch_first=False,
)

In [16]:
# Reload the datasets so the examples are processed by the redefined fields.
# Columns are matched positionally; skip_header drops the header row.
train_data, test_data = TabularDataset.splits(
    path='.',
    format="csv",
    skip_header=True,
    fields=[("text", TEXT), ("label", LABEL)],
    train="train_data.csv",
    test="test_data.csv",
)

In [17]:
# Rebuild the vocabulary on the new TEXT field with the same thresholds.
TEXT.build_vocab(train_data, max_size=10000, min_freq=10)

In [18]:
batch_size = 5
# Fresh iterator over the re-processed training set; take only the first batch.
train_loader = Iterator(train_data, batch_size)
batch = next(iter(train_loader))

In [19]:
# Show the first batch again. TEXT no longer sets batch_first=True, so the
# tensor is sequence-major and each column is one review.
print(batch.text)

tensor([[   0, 1916,   92,   10,    9],
        [   0,   35, 1225,    7,  743],
        [  13,   41,   62,    3, 2625],
        [ 776,  124,  301, 2633,    4],
        [   3,   71,    0,   20,   26],
        [ 349,    3,    0,   12,   99],
        [   5,  349,  807,    7, 2382],
        [   0,    5,    0,   38, 5646],
        [7280,    2,  394, 2995,   18],
        [9725,  238, 3514,   12,   50],
        [   9, 1218,    0,   55,   53],
        [ 260,   40,    0, 1846,  332],
        [  16,   53, 1192, 2954,    5],
        [  10,   41,   12,    4,    2],
        [ 131,    3,  239,    9,  800],
        [   6, 2675,   35,  235,   12],
        [ 466,   42,  307,   11,  180],
        [ 207,    2,    6,    0,  327],
        [  19,  975,  301,    2,    8],
        [  38,  133,    0,  130,    0]])


In [20]:
# Expect (fix_length, batch_size) = (20, 5) now that batch_first is False.
print(batch.text.size())

torch.Size([20, 5])
