In [8]:
import os

import torch
from torch.utils.data.dataset import random_split
from torchtext.data.functional import to_map_style_dataset
from torchtext.data.utils import get_tokenizer
from torchtext.data.utils import ngrams_iterator
from torchtext.datasets import SogouNews
from torchtext.vocab import build_vocab_from_iterator

In [2]:
tokenizer = get_tokenizer('basic_english')
train_iter = SogouNews(split='train')


def yield_tokens(data_iter):
    """Yield one token list per sample, for building the vocabulary.

    Args:
        data_iter: iterable of ``(label, text)`` pairs; the label is ignored.

    Yields:
        list: the output of ``ngrams_iterator(tokens, 2)`` for the tokenized
        text, i.e. the tokens together with their n-grams up to length 2.
    """
    for _label, raw_text in data_iter:
        yield list(ngrams_iterator(tokenizer(raw_text), 2))

# Keep only n-grams seen at least 3 times; everything else maps to "<unk>".
vocab = build_vocab_from_iterator(
    yield_tokens(train_iter),
    min_freq=3,
    specials=["<unk>"],
)
vocab.set_default_index(vocab["<unk>"])


def text_pipeline(x):
    """Convert raw text into the list of vocabulary ids of its tokens and n-grams (n <= 2)."""
    return vocab(list(ngrams_iterator(tokenizer(x), 2)))


def label_pipeline(x):
    """Convert a dataset label to a zero-based integer class index (label - 1)."""
    return int(x) - 1


print('vocab_size: ', len(vocab))

vocab_size:  1197144


In [3]:
from torch.utils.data import DataLoader

def collate_batch(batch):
    """Collate ``(label, text)`` samples into the flat layout ``nn.EmbeddingBag`` expects.

    Args:
        batch: sequence of ``(label, text)`` pairs.

    Returns:
        tuple: ``(labels, tokens, offsets)`` where
            * ``labels`` is an int64 tensor of ``label_pipeline`` outputs,
            * ``tokens`` is a 1-D int64 tensor holding every sample's token ids
              concatenated back to back,
            * ``offsets`` holds the start index of each sample inside ``tokens``.
    """
    class_ids, token_tensors, lengths = [], [], [0]
    for raw_label, raw_text in batch:
        class_ids.append(label_pipeline(raw_label))
        ids = torch.tensor(text_pipeline(raw_text), dtype=torch.int64)
        token_tensors.append(ids)
        lengths.append(ids.size(0))

    labels = torch.tensor(class_ids, dtype=torch.int64)
    # Drop the last length and take the running sum: sample i starts where the
    # lengths of samples 0..i-1 end.
    starts = torch.tensor(lengths[:-1]).cumsum(dim=0)
    # All samples are flattened into one long tensor; `starts` marks the boundaries.
    flat_tokens = torch.cat(token_tensors)

    return labels, flat_tokens, starts

In [4]:
from torch import nn

class FastText(nn.Module):
    """Bag-of-n-grams text classifier: EmbeddingBag -> Dropout -> Linear.

    Args:
        vocab_size: number of rows in the embedding table.
        embedding_size: dimensionality of each embedding vector.
        num_class: number of output classes (size of the logits).
        dropout_p: dropout probability applied to the pooled embedding.
    """

    def __init__(self, vocab_size, embedding_size, num_class, dropout_p):
        super().__init__()
        # sparse=True makes the embedding produce sparse gradients, so the
        # optimizer used with this model must support them.
        self.embedding = nn.EmbeddingBag(vocab_size, embedding_size, sparse=True)
        nn.init.normal_(self.embedding.weight, mean=0.0, std=0.5)

        self.dropout = nn.Dropout(dropout_p)

        self.linear = nn.Linear(embedding_size, num_class, bias=True)
        nn.init.normal_(self.linear.weight, mean=0.0, std=0.5)
        nn.init.zeros_(self.linear.bias)

    def forward(self, text, offsets):
        """Return class logits for a flattened batch.

        Args:
            text: 1-D tensor of token ids for the whole batch, concatenated.
            offsets: 1-D tensor with the start index of each sample in ``text``.

        Returns:
            Tensor with one row of ``num_class`` logits per entry in ``offsets``.
        """
        pooled = self.embedding(text, offsets)
        return self.linear(self.dropout(pooled))

In [5]:
import time
from torch.nn.utils import clip_grad_norm_

log_interval = 1000  # print running training stats every `log_interval` batches


def train(model, dataloader, criterion, optimizer, scheduler, clip):
    """Run one pass over ``dataloader``, updating ``model``, and return the accuracy.

    Args:
        model: module called as ``model(text, offsets)`` and returning class logits.
        dataloader: iterable of ``(label, text, offsets)`` batches.
        criterion: loss called as ``criterion(logits, label)``.
        optimizer: optimizer over ``model``'s parameters.
        scheduler: learning-rate scheduler; stepped once per *batch*.
        clip: maximum L2 norm passed to ``clip_grad_norm_``.

    Returns:
        float: fraction of samples predicted correctly over the whole pass
        (predictions are taken in training mode, before each update), or
        ``0.0`` when ``dataloader`` yields no samples.
    """
    model.train()
    correct, seen = 0, 0
    window_start = time.time()
    for idx, (label, text, offsets) in enumerate(dataloader):
        optimizer.zero_grad()
        predicted_label = model(text, offsets)  # expected shape: (batch, num_classes)
        loss = criterion(predicted_label, label)

        loss.backward()
        clip_grad_norm_(model.parameters(), clip, norm_type=2)
        optimizer.step()
        scheduler.step()

        # Count matches between the arg-max class and the target.
        correct += (predicted_label.argmax(1) == label).sum().item()
        # The last batch may be smaller, so count actual samples rather than batches.
        seen += label.size(0)

        if idx % log_interval == 0 and idx > 0:
            # Accuracy is cumulative since the start of the pass; the time is
            # the wall-clock duration since the previous log line.
            elapsed = time.time() - window_start
            print('accuracy: {}, time: {}[s]'.format(correct / seen, int(elapsed)))
            window_start = time.time()

    # Guard against an empty dataloader instead of raising ZeroDivisionError.
    if seen == 0:
        return 0.0
    return correct / seen

def evaluate(model, dataloader):
    """Compute classification accuracy of ``model`` over ``dataloader``.

    The model is switched to eval mode and gradients are disabled for the
    duration of the pass.

    Args:
        model: module called as ``model(text, offsets)`` and returning class logits.
        dataloader: iterable of ``(label, text, offsets)`` batches.

    Returns:
        float: fraction of samples whose arg-max prediction equals the label,
        or ``0.0`` when ``dataloader`` yields no samples.
    """
    model.eval()
    correct, total = 0, 0

    with torch.no_grad():
        for label, text, offsets in dataloader:
            logits = model(text, offsets)
            correct += (logits.argmax(1) == label).sum().item()
            total += label.size(0)

    # Guard against an empty dataloader instead of raising ZeroDivisionError.
    if total == 0:
        return 0.0
    return correct / total


from torch.utils.data.dataset import random_split
from torchtext.data.functional import to_map_style_dataset

In [6]:
# Take a fresh pass over the training split and count the distinct labels.
train_iter = SogouNews(split='train')
num_class = len({label for label, _ in train_iter})
print('num_class: ', num_class)

# Directory and file where run_train() writes the best checkpoint.
save_dir = './saved_model/sgnews1'
# makedirs also creates the missing parent './saved_model' (os.mkdir would raise
# FileNotFoundError), and exist_ok=True makes re-running this cell a no-op.
os.makedirs(save_dir, exist_ok=True)
save_path = os.path.join(save_dir, 'ckpt.pth')

num_class:  5


In [9]:
# --- Model ---
vocab_size = len(vocab)
embedding_size = 300
dropout_p = 0.2
model = FastText(vocab_size, embedding_size, num_class, dropout_p)

# --- Optimisation hyperparameters ---
max_epoch = 5
lr = 0.2
lr_decay = 0.99   # multiplicative LR factor applied by StepLR
step_size = 1000  # train() steps the scheduler once per batch, so this counts batches
batch_size = 64
clip = 3          # max L2 gradient norm used by train()
split_seed = 42   # fixes the train/validation split so runs are comparable

criterion = torch.nn.CrossEntropyLoss()
# The model's EmbeddingBag is created with sparse=True, so the optimizer must
# accept sparse gradients.
optimizer = torch.optim.Adagrad(model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size, gamma=lr_decay)

# --- Data: 95% / 5% train / validation split of the training set, plus the test set ---
train_iter, test_iter = SogouNews()
train_dataset = to_map_style_dataset(train_iter)
test_dataset = to_map_style_dataset(test_iter)
num_train = int(len(train_dataset) * 0.95)
split_train_, split_valid_ = random_split(
    train_dataset,
    [num_train, len(train_dataset) - num_train],
    generator=torch.Generator().manual_seed(split_seed),
)

# Only the training loader shuffles; evaluation accuracy does not depend on
# sample order, so the validation and test loaders keep a fixed order.
train_dataloader = DataLoader(split_train_, batch_size=batch_size, shuffle=True, collate_fn=collate_batch)
valid_dataloader = DataLoader(split_valid_, batch_size=batch_size, shuffle=False, collate_fn=collate_batch)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_batch)

In [10]:
def run_train():
    """Train for ``max_epoch`` epochs, checkpointing the best validation accuracy.

    Uses the notebook-level ``model``, data loaders, ``criterion``,
    ``optimizer``, ``scheduler`` and ``clip``. After every epoch the model is
    scored on the validation loader; when the accuracy beats the best seen so
    far, the epoch number and model weights are written to ``save_path``.
    A summary line is printed for every epoch, whether or not it improved.
    """
    best_accu = 0

    for epoch in range(max_epoch):
        print('-' * 10 + 'epoch: {}/{}'.format(epoch+1, max_epoch))
        epoch_start_time = time.time()
        train(model, train_dataloader, criterion, optimizer, scheduler, clip)
        accu_val = evaluate(model, valid_dataloader)

        # Keep only the best-performing weights on disk.
        if best_accu < accu_val:
            best_accu = accu_val
            torch.save({'epoch': epoch + 1, 'model_state_dict': model.state_dict()}, save_path)

        # Report every epoch so that non-improving epochs are visible too.
        print('-' * 59)
        print('epoch {:3d} | time: {:5.2f}s | valid accuracy: {:8.3f} '.format(epoch+1, time.time() - epoch_start_time, accu_val))
        print('-' * 59)


run_train()

----------epoch: 1/5
accuracy: 0.9490041208791209, time: 296[s]
accuracy: 0.9559595202398801, time: 294[s]
accuracy: 0.958763745418194, time: 305[s]
accuracy: 0.9611815796050988, time: 316[s]
accuracy: 0.9625668616276745, time: 294[s]
accuracy: 0.9637091734710882, time: 282[s]
-----------------------------------------------------------
epoch   1 | time: 1999.80s | valid accuracy:    0.969 
-----------------------------------------------------------
----------epoch: 2/5
accuracy: 0.9899007242757243, time: 296[s]
accuracy: 0.9902704897551224, time: 292[s]
accuracy: 0.9900501916027991, time: 284[s]
accuracy: 0.989967351912022, time: 287[s]
accuracy: 0.9898520295940811, time: 282[s]
accuracy: 0.9898766872187968, time: 285[s]
-----------------------------------------------------------
epoch   2 | time: 1956.00s | valid accuracy:    0.969 
-----------------------------------------------------------
----------epoch: 3/5
accuracy: 0.9958947302697303, time: 289[s]
accuracy: 0.9956818465767117, 

KeyboardInterrupt: 

In [11]:
# Score the weights currently held by `model` (not necessarily the best
# checkpoint saved by run_train) on the test set.
accuracy = evaluate(model, test_dataloader)
print("Accuracy: ", accuracy)

Accuracy:  0.97


In [12]:
# Restore the checkpoint written by run_train() and score it on the test set.
save_dir = './saved_model/sgnews1'
load_path = f'{save_dir}/ckpt.pth'
# NOTE: torch.load unpickles the file; only load checkpoints you trust.
checkpoint = torch.load(load_path)

epoch = checkpoint['epoch']  # epoch at which the checkpoint was saved
model.load_state_dict(checkpoint['model_state_dict'])
print('epoch: ', epoch)

accuracy = evaluate(model, test_dataloader)
print("Accuracy: ", accuracy)

epoch:  2
Accuracy:  0.9698
