In [1]:
# ! pip list | grep "torch"
# torch                         1.10.0+cu111
# torchaudio                    0.10.0+cu111
# torchsummary                  1.5.1
# torchtext                     0.11.0
# torchvision                   0.11.1+cu111

https://pytorch.org/tutorials/beginner/text_sentiment_ngrams_tutorial.html
https://tutorials.pytorch.kr/beginner/text_sentiment_ngrams_tutorial.html

# 해당 코드는 위 페이지를 기반으로 만들어진 코드입니다. 
# torchtext가 편한 부분도 많지만, 버전 변화에 따른 코드 변화도 많네요. 앞으로도 버전 변경으로 인한 오류가 생길 가능성이 있습니다.
# 해당 파일의 리팩토링 피드백이 늦는다면, 위 공식 문서를 참고하여 변경 부분을 확인하시거나 위 library 버전을 사용하시길 권장합니다.

# 이 파일에는 from_scratch, recurrent 모델들을 모두 선택해서 학습할 수 있도록 작성 했습니다. 
# 그 과정에서 코드가 복잡해진 점에 대해 미리 사과드립니다.

torch                         1.10.0+cu111
torchaudio                    0.10.0+cu111
torchsummary                  1.5.1
torchtext                     0.11.0
torchvision                   0.11.1+cu111


In [11]:
import os
import time
import torch
from tqdm import tqdm

from torch import nn
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data.dataset import random_split

from torchtext.datasets import IMDB
from torchtext.data.utils import get_tokenizer
from torchtext.data.functional import to_map_style_dataset
from torchtext.vocab import build_vocab_from_iterator, Vectors



config = {'pre_trained' : 'glove', # 'glove','fasttext', None 
          'max_length': 300,
          'batch_size': 64,
          'model_type': 'gru', # 'rnn', 'lstm', 'gru','avg_not_pad', None
          'emb_dim' : 300,
          'hidden_dim':128,
          'is_bidirectional':True,
          'epoch' : 15,
          'LR': 5
          }


In [1]:

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# device = torch.device("cpu")

# Tokenize & Vocab setup
tokenizer = get_tokenizer('basic_english')
train_iter = IMDB(split='train')
def yield_tokens(data_iter):
    for _, text in data_iter:
        yield tokenizer(text)

vocab = build_vocab_from_iterator(yield_tokens(train_iter), min_freq= 2, specials=["<unk>", "<pad>"])
vocab.set_default_index(vocab["<unk>"]) # This index will be returned when OOV token is queried.

num_class = 2
vocab_size = len(vocab)
idx_pad = vocab.get_stoi()['<pad>']

In [2]:
# 사이즈가 작은 파일들만 가져왔습니다. 다른 모델을 써보고 싶다면, 아래 링크를 참고해서 코드를 변경해서 사용하세요.
# https://pytorch.org/text/stable/_modules/torchtext/vocab/vectors.html#Vectors

if config['pre_trained'] == 'glove':
    pretrained_vectors = Vectors(name = 'glove.6B.300d.txt', 
                                #  cache = '[my_path]',
                                 url = 'http://nlp.stanford.edu/data/glove.6B.zip')
    pretrained_emb = pretrained_vectors.get_vecs_by_tokens(vocab.get_itos(), lower_case_backup=True)

elif config['pre_trained'] == 'fasttext':
    pretrained_vectors = Vectors(name = 'wiki.simple.vec', 
                                #  cache = '[my_path]',
                                 url = 'https://dl.fbaipublicfiles.com/fasttext/vectors-wiki/wiki.simple.vec')
    pretrained_emb = pretrained_vectors.get_vecs_by_tokens(vocab.get_itos(), lower_case_backup=True)
else:
    pass
    

.vector_cache/glove.6B.zip: 862MB [02:40, 5.36MB/s]                           
100%|█████████▉| 399999/400000 [00:56<00:00, 7128.51it/s]


In [39]:
config = {'pre_trained' : 'glove', # 'glove','fasttext', None 
          'max_length': 300,
          'batch_size': 64,
          'model_type': 'lstm', # 'rnn', 'lstm', 'gru','avg_not_pad', None
          'emb_dim' : 300,
          'hidden_dim':128,
          'is_bidirectional':True,
          'epoch' : 10,
          'LR': 5
          }

In [12]:
# DataLoader Setup
text_pipeline = lambda x: vocab(tokenizer(x))[:config['max_length']]
label_pipeline = lambda x: {"neg":0, "pos":1}.get(x)

def collate_batch(batch):
    label_list, text_list = [], [] 
    for (_label, _text) in batch:
        processed_text = torch.tensor(text_pipeline(_text), dtype=torch.int64)
        text_list.append(processed_text)
        label_list.append(label_pipeline(_label))
    
    text_list = pad_sequence(text_list, batch_first= True, padding_value= idx_pad)
    label_list = torch.tensor(label_list, dtype=torch.int64)
    return text_list.to(device), label_list.to(device)

train_iter, test_iter = IMDB()
train_dataset = to_map_style_dataset(train_iter)
test_dataset = to_map_style_dataset(test_iter)

num_train = int(len(train_dataset) * 0.95)
split_train_, split_valid_ = random_split(train_dataset, [num_train, len(train_dataset) - num_train])

train_dataloader = DataLoader(split_train_, batch_size=config['batch_size'], shuffle=True, collate_fn=collate_batch)
valid_dataloader = DataLoader(split_valid_, batch_size=config['batch_size'], shuffle=True, collate_fn=collate_batch)
test_dataloader = DataLoader(test_dataset, batch_size=config['batch_size'], shuffle=True, collate_fn=collate_batch)


In [40]:
# Model setup
class TextClassificationModel(nn.Module):
    def __init__(self, vocab_size, num_class, **config):
        super(TextClassificationModel, self).__init__()
        self.model_type = config['model_type']
        self.pretrained = config['pre_trained']
        self.is_bidirectional = config['is_bidirectional']
        self.embed_dim = config['emb_dim']
        self.hidden_dim = config['hidden_dim']
        
        self.embedding = nn.Embedding(vocab_size, self.embed_dim,)
        if self.pretrained:
            self.embedding = nn.Embedding(vocab_size, self.embed_dim,).from_pretrained(pretrained_emb, freeze = False)

        if self.model_type is None:
            self.fc = nn.Linear(self.embed_dim, num_class)

        elif self.model_type == 'avg_not_pad':
            self.embedding = nn.EmbeddingBag(vocab_size, self.embed_dim, sparse=True, padding_idx = idx_pad)
            if self.pretrained:
                self.embedding = self.embedding.from_pretrained(pretrained_emb, freeze = False, sparse=True)
            self.fc = nn.Linear(self.embed_dim, num_class)

        elif self.model_type in ['rnn','lstm','gru']:
            if self.model_type == 'rnn':
                self.Recurrent = nn.RNN(input_size = self.embed_dim, hidden_size = self.hidden_dim, 
                                        bidirectional = self.is_bidirectional, batch_first = True)
            elif self.model_type == 'lstm':
                self.Recurrent = nn.LSTM(input_size = self.embed_dim, hidden_size = self.hidden_dim, 
                                         bidirectional = self.is_bidirectional, batch_first = True)
            else:
                self.Recurrent = nn.GRU(input_size = self.embed_dim, hidden_size = self.hidden_dim, 
                                        bidirectional = self.is_bidirectional, batch_first = True)

            last_input_dim = self.hidden_dim * 2 if self.is_bidirectional else self.hidden_dim 
            self.fc = nn.Linear(last_input_dim, num_class)

        else:
            raise NameError('Select model_type in [rnn, lstm, gru, avg_not_pad]')

        self.init_weights()

    def init_weights(self):
        initrange = 0.5
        if self.pretrained:
            self.embedding.weight.data.uniform_(-initrange, initrange)
        self.fc.weight.data.uniform_(-initrange, initrange)
        self.fc.bias.data.zero_()

    def forward(self, text):
        embedded = self.embedding(text)
        if self.model_type is None:
            embedded = torch.mean(embedded, dim=1)
            return self.fc(embedded)
        elif self.model_type == 'avg_not_pad':
            return self.fc(embedded)
        else:
            output, _ = self.Recurrent(embedded)
            last_output = output[:,-1,:]
            return self.fc(last_output)


model = TextClassificationModel(vocab_size, num_class, **config).to(device)
model

TextClassificationModel(
  (embedding): Embedding(51718, 300)
  (Recurrent): LSTM(300, 128, batch_first=True, bidirectional=True)
  (fc): Linear(in_features=256, out_features=2, bias=True)
)

In [41]:
# Training Setup
# Hyperparameters
EPOCHS = config['epoch']
LR = config['LR']  

total_accu = None

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=LR)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.1)

for epoch in range(1, EPOCHS + 1):
    epoch_start_time = time.time()

    model.train()
    total_acc, total_count = 0, 0
    log_interval = 50
    start_time = time.time()

    for idx, (text, label) in tqdm(enumerate(train_dataloader)):
        # Training
        optimizer.zero_grad()
        predicted_label = model(text)
        loss = criterion(predicted_label, label)
        loss.backward()
        
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)
        optimizer.step()
        total_acc += (predicted_label.argmax(1) == label).sum().item()
        total_count += label.size(0)
        if idx % log_interval == 0 and idx > 0:
            elapsed = time.time() - start_time
            print(f'| epoch {epoch:3d} | {idx:5d}/{len(train_dataloader):5d} batches | accuracy {total_acc/total_count:8.3f}')                             
            total_acc, total_count = 0, 0
            start_time = time.time()
        
    # Evaluation
    model.eval()
    total_acc, total_count = 0, 0
    with torch.no_grad():
        for idx, (text, label) in enumerate(valid_dataloader):
            predicted_label = model(text)
            loss = criterion(predicted_label, label)
            total_acc += (predicted_label.argmax(1) == label).sum().item()
            total_count += label.size(0)
    accu_val =  total_acc/total_count  
     
    if total_accu is not None and total_accu > accu_val:
        scheduler.step()
    else:
        total_accu = accu_val
    print('-' * 59)
    print(f'| end of epoch {epoch:3d} | time: {time.time() - epoch_start_time:5.2f}s | valid accuracy {accu_val:8.3f}')
    print('-' * 59)

52it [00:05, 10.12it/s]

| epoch   1 |    50/  372 batches | accuracy    0.493


102it [00:10, 10.07it/s]

| epoch   1 |   100/  372 batches | accuracy    0.500


152it [00:15, 10.05it/s]

| epoch   1 |   150/  372 batches | accuracy    0.503


202it [00:20, 10.10it/s]

| epoch   1 |   200/  372 batches | accuracy    0.503


253it [00:25,  9.93it/s]

| epoch   1 |   250/  372 batches | accuracy    0.505


303it [00:30, 10.15it/s]

| epoch   1 |   300/  372 batches | accuracy    0.500


353it [00:35, 10.07it/s]

| epoch   1 |   350/  372 batches | accuracy    0.510


372it [00:37,  9.91it/s]


-----------------------------------------------------------
| end of epoch   1 | time: 38.28s | valid accuracy    0.486
-----------------------------------------------------------


52it [00:05, 10.10it/s]

| epoch   2 |    50/  372 batches | accuracy    0.495


102it [00:10,  9.97it/s]

| epoch   2 |   100/  372 batches | accuracy    0.523


151it [00:14, 10.07it/s]

| epoch   2 |   150/  372 batches | accuracy    0.498


202it [00:20,  9.76it/s]

| epoch   2 |   200/  372 batches | accuracy    0.517


253it [00:25,  9.98it/s]

| epoch   2 |   250/  372 batches | accuracy    0.503


303it [00:30, 10.00it/s]

| epoch   2 |   300/  372 batches | accuracy    0.516


353it [00:35, 10.01it/s]

| epoch   2 |   350/  372 batches | accuracy    0.519


372it [00:37,  9.96it/s]


-----------------------------------------------------------
| end of epoch   2 | time: 38.11s | valid accuracy    0.498
-----------------------------------------------------------


53it [00:05, 10.14it/s]

| epoch   3 |    50/  372 batches | accuracy    0.527


102it [00:10, 10.01it/s]

| epoch   3 |   100/  372 batches | accuracy    0.516


152it [00:15,  9.65it/s]

| epoch   3 |   150/  372 batches | accuracy    0.515


202it [00:20,  9.48it/s]

| epoch   3 |   200/  372 batches | accuracy    0.522


253it [00:25, 10.02it/s]

| epoch   3 |   250/  372 batches | accuracy    0.509


302it [00:30,  9.70it/s]

| epoch   3 |   300/  372 batches | accuracy    0.515


352it [00:35,  9.81it/s]

| epoch   3 |   350/  372 batches | accuracy    0.526


372it [00:37,  9.85it/s]


-----------------------------------------------------------
| end of epoch   3 | time: 38.53s | valid accuracy    0.514
-----------------------------------------------------------


52it [00:05,  9.81it/s]

| epoch   4 |    50/  372 batches | accuracy    0.537


102it [00:10,  9.69it/s]

| epoch   4 |   100/  372 batches | accuracy    0.538


152it [00:15,  9.31it/s]

| epoch   4 |   150/  372 batches | accuracy    0.526


202it [00:20,  9.61it/s]

| epoch   4 |   200/  372 batches | accuracy    0.538


252it [00:26,  9.74it/s]

| epoch   4 |   250/  372 batches | accuracy    0.551


303it [00:31,  9.82it/s]

| epoch   4 |   300/  372 batches | accuracy    0.648


352it [00:36,  9.55it/s]

| epoch   4 |   350/  372 batches | accuracy    0.723


372it [00:38,  9.68it/s]


-----------------------------------------------------------
| end of epoch   4 | time: 39.19s | valid accuracy    0.738
-----------------------------------------------------------


53it [00:05,  9.92it/s]

| epoch   5 |    50/  372 batches | accuracy    0.776


102it [00:10,  9.81it/s]

| epoch   5 |   100/  372 batches | accuracy    0.803


152it [00:15,  9.50it/s]

| epoch   5 |   150/  372 batches | accuracy    0.810


202it [00:20,  9.75it/s]

| epoch   5 |   200/  372 batches | accuracy    0.814


253it [00:25,  9.97it/s]

| epoch   5 |   250/  372 batches | accuracy    0.821


302it [00:31,  9.55it/s]

| epoch   5 |   300/  372 batches | accuracy    0.817


351it [00:35, 10.11it/s]

| epoch   5 |   350/  372 batches | accuracy    0.832


372it [00:38,  9.77it/s]


-----------------------------------------------------------
| end of epoch   5 | time: 38.81s | valid accuracy    0.810
-----------------------------------------------------------


52it [00:05,  9.70it/s]

| epoch   6 |    50/  372 batches | accuracy    0.855


103it [00:10, 10.12it/s]

| epoch   6 |   100/  372 batches | accuracy    0.853


152it [00:15,  9.70it/s]

| epoch   6 |   150/  372 batches | accuracy    0.857


202it [00:20,  9.82it/s]

| epoch   6 |   200/  372 batches | accuracy    0.839


252it [00:25,  9.87it/s]

| epoch   6 |   250/  372 batches | accuracy    0.857


302it [00:30,  9.64it/s]

| epoch   6 |   300/  372 batches | accuracy    0.845


352it [00:36,  9.68it/s]

| epoch   6 |   350/  372 batches | accuracy    0.858


372it [00:38,  9.77it/s]


-----------------------------------------------------------
| end of epoch   6 | time: 38.83s | valid accuracy    0.822
-----------------------------------------------------------


52it [00:05,  9.57it/s]

| epoch   7 |    50/  372 batches | accuracy    0.878


102it [00:10,  9.74it/s]

| epoch   7 |   100/  372 batches | accuracy    0.878


152it [00:15,  9.82it/s]

| epoch   7 |   150/  372 batches | accuracy    0.869


202it [00:20,  9.58it/s]

| epoch   7 |   200/  372 batches | accuracy    0.873


252it [00:25,  9.68it/s]

| epoch   7 |   250/  372 batches | accuracy    0.883


302it [00:31,  9.36it/s]

| epoch   7 |   300/  372 batches | accuracy    0.882


352it [00:36,  9.53it/s]

| epoch   7 |   350/  372 batches | accuracy    0.873


372it [00:38,  9.70it/s]


-----------------------------------------------------------
| end of epoch   7 | time: 39.10s | valid accuracy    0.850
-----------------------------------------------------------


52it [00:05,  9.84it/s]

| epoch   8 |    50/  372 batches | accuracy    0.901


102it [00:10,  9.62it/s]

| epoch   8 |   100/  372 batches | accuracy    0.892


152it [00:15,  9.83it/s]

| epoch   8 |   150/  372 batches | accuracy    0.898


202it [00:20,  9.63it/s]

| epoch   8 |   200/  372 batches | accuracy    0.892


252it [00:25,  9.92it/s]

| epoch   8 |   250/  372 batches | accuracy    0.904


303it [00:31, 10.18it/s]

| epoch   8 |   300/  372 batches | accuracy    0.893


353it [00:36, 10.18it/s]

| epoch   8 |   350/  372 batches | accuracy    0.897


372it [00:37,  9.80it/s]


-----------------------------------------------------------
| end of epoch   8 | time: 38.70s | valid accuracy    0.810
-----------------------------------------------------------


52it [00:05,  9.78it/s]

| epoch   9 |    50/  372 batches | accuracy    0.926


102it [00:10,  9.66it/s]

| epoch   9 |   100/  372 batches | accuracy    0.927


152it [00:15,  9.81it/s]

| epoch   9 |   150/  372 batches | accuracy    0.932


203it [00:20, 10.08it/s]

| epoch   9 |   200/  372 batches | accuracy    0.924


252it [00:25,  9.86it/s]

| epoch   9 |   250/  372 batches | accuracy    0.934


303it [00:31,  9.93it/s]

| epoch   9 |   300/  372 batches | accuracy    0.937


352it [00:35,  9.70it/s]

| epoch   9 |   350/  372 batches | accuracy    0.935


372it [00:37,  9.79it/s]


-----------------------------------------------------------
| end of epoch   9 | time: 38.73s | valid accuracy    0.853
-----------------------------------------------------------


52it [00:05,  9.84it/s]

| epoch  10 |    50/  372 batches | accuracy    0.945


102it [00:10,  9.41it/s]

| epoch  10 |   100/  372 batches | accuracy    0.935


152it [00:15,  9.63it/s]

| epoch  10 |   150/  372 batches | accuracy    0.933


202it [00:20,  9.75it/s]

| epoch  10 |   200/  372 batches | accuracy    0.941


252it [00:25,  9.73it/s]

| epoch  10 |   250/  372 batches | accuracy    0.938


302it [00:31,  9.78it/s]

| epoch  10 |   300/  372 batches | accuracy    0.939


352it [00:36,  9.74it/s]

| epoch  10 |   350/  372 batches | accuracy    0.939


372it [00:38,  9.74it/s]


-----------------------------------------------------------
| end of epoch  10 | time: 38.95s | valid accuracy    0.857
-----------------------------------------------------------


52it [00:05,  9.53it/s]

| epoch  11 |    50/  372 batches | accuracy    0.941


102it [00:10,  9.73it/s]

| epoch  11 |   100/  372 batches | accuracy    0.940


152it [00:15,  9.61it/s]

| epoch  11 |   150/  372 batches | accuracy    0.944


202it [00:20,  9.79it/s]

| epoch  11 |   200/  372 batches | accuracy    0.948


252it [00:26,  9.87it/s]

| epoch  11 |   250/  372 batches | accuracy    0.935


302it [00:31,  9.54it/s]

| epoch  11 |   300/  372 batches | accuracy    0.943


352it [00:36,  9.55it/s]

| epoch  11 |   350/  372 batches | accuracy    0.947


372it [00:38,  9.64it/s]


-----------------------------------------------------------
| end of epoch  11 | time: 39.34s | valid accuracy    0.853
-----------------------------------------------------------


52it [00:05,  9.48it/s]

| epoch  12 |    50/  372 batches | accuracy    0.949


102it [00:10,  9.60it/s]

| epoch  12 |   100/  372 batches | accuracy    0.948


152it [00:15,  9.57it/s]

| epoch  12 |   150/  372 batches | accuracy    0.948


202it [00:20,  9.45it/s]

| epoch  12 |   200/  372 batches | accuracy    0.944


252it [00:26,  9.65it/s]

| epoch  12 |   250/  372 batches | accuracy    0.952


302it [00:31, 10.11it/s]

| epoch  12 |   300/  372 batches | accuracy    0.949


353it [00:36, 10.05it/s]

| epoch  12 |   350/  372 batches | accuracy    0.947


372it [00:38,  9.76it/s]


-----------------------------------------------------------
| end of epoch  12 | time: 38.83s | valid accuracy    0.854
-----------------------------------------------------------


51it [00:05, 10.17it/s]

| epoch  13 |    50/  372 batches | accuracy    0.942


103it [00:10, 10.08it/s]

| epoch  13 |   100/  372 batches | accuracy    0.949


151it [00:14, 10.04it/s]

| epoch  13 |   150/  372 batches | accuracy    0.952


202it [00:20,  9.68it/s]

| epoch  13 |   200/  372 batches | accuracy    0.943


251it [00:25, 10.03it/s]

| epoch  13 |   250/  372 batches | accuracy    0.955


302it [00:30,  9.61it/s]

| epoch  13 |   300/  372 batches | accuracy    0.957


352it [00:35,  9.82it/s]

| epoch  13 |   350/  372 batches | accuracy    0.945


372it [00:37,  9.96it/s]


-----------------------------------------------------------
| end of epoch  13 | time: 38.08s | valid accuracy    0.854
-----------------------------------------------------------


51it [00:05, 10.12it/s]

| epoch  14 |    50/  372 batches | accuracy    0.954


103it [00:10, 10.14it/s]

| epoch  14 |   100/  372 batches | accuracy    0.953


152it [00:15, 10.14it/s]

| epoch  14 |   150/  372 batches | accuracy    0.945


203it [00:20, 10.16it/s]

| epoch  14 |   200/  372 batches | accuracy    0.951


251it [00:24, 10.21it/s]

| epoch  14 |   250/  372 batches | accuracy    0.944


303it [00:30, 10.24it/s]

| epoch  14 |   300/  372 batches | accuracy    0.950


353it [00:35, 10.16it/s]

| epoch  14 |   350/  372 batches | accuracy    0.948


372it [00:36, 10.10it/s]


-----------------------------------------------------------
| end of epoch  14 | time: 37.58s | valid accuracy    0.854
-----------------------------------------------------------


53it [00:05, 10.11it/s]

| epoch  15 |    50/  372 batches | accuracy    0.950


102it [00:10,  9.56it/s]

| epoch  15 |   100/  372 batches | accuracy    0.958


153it [00:15,  9.95it/s]

| epoch  15 |   150/  372 batches | accuracy    0.943


202it [00:20,  9.64it/s]

| epoch  15 |   200/  372 batches | accuracy    0.944


251it [00:25, 10.10it/s]

| epoch  15 |   250/  372 batches | accuracy    0.948


303it [00:30,  9.78it/s]

| epoch  15 |   300/  372 batches | accuracy    0.944


353it [00:35, 10.06it/s]

| epoch  15 |   350/  372 batches | accuracy    0.961


372it [00:37,  9.94it/s]


-----------------------------------------------------------
| end of epoch  15 | time: 38.16s | valid accuracy    0.854
-----------------------------------------------------------


In [42]:
print('Checking the results of test dataset.')
model.eval()
total_acc, total_count = 0, 0
with torch.no_grad():
    for idx, (text, label) in enumerate(test_dataloader):
        predicted_label = model(text, )
        loss = criterion(predicted_label, label)
        total_acc += (predicted_label.argmax(1) == label).sum().item()
        total_count += label.size(0)
accu_test =  total_acc/total_count
print('test accuracy {:8.3f}'.format(accu_test))

Checking the results of test dataset.
test accuracy    0.853


In [43]:
def predict(text, text_pipeline):
    with torch.no_grad():
        text = torch.tensor(text_pipeline(text))
        text = pad_sequence([text], batch_first=True, padding_value=idx_pad)
        output = model(text)
        return output.argmax(1).item()

ex_text_str = "It was very bad movie"

model = model.to("cpu")
label_dict = {0:'neg', 1:'pos'}
print(f"This is a {label_dict.get(predict(ex_text_str, text_pipeline))} comment")

This is a neg comment
