In [13]:
"""
References:
https://pytorch.org/text/stable/datasets.html
https://pytorch.org/text/_modules/torchtext/datasets/text_classification.html

"""
import time

import torch
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.data.dataset import random_split

from torchtext.datasets import SogouNews, YelpReviewPolarity
from torchtext.data.functional import to_map_style_dataset
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Tokenizer shared by vocabulary building and the text pipeline.
tokenizer = get_tokenizer("basic_english")

# Hyperparameters
EPOCHS = 5  # number of passes over the training split
LR = 5  # initial learning rate handed to the SGD optimizer
BATCH_SIZE = 64  # batch size for the training, validation and test loaders


class TextClassificationModel(nn.Module):
    """Bag-of-embeddings classifier: an ``EmbeddingBag`` feeding one linear layer.

    Args:
        vocab_size: Number of rows in the embedding table.
        embed_dim: Width of each embedding vector.
        num_class: Number of output logits.
    """

    def __init__(self, vocab_size, embed_dim, num_class):
        super().__init__()
        # sparse=True makes the embedding table produce sparse gradients.
        self.embedding = nn.EmbeddingBag(vocab_size, embed_dim, sparse=True)
        self.fc = nn.Linear(embed_dim, num_class)
        self.init_weights()

    def init_weights(self):
        """Draw both weight matrices from U(-0.5, 0.5) and zero the linear bias."""
        bound = 0.5
        # Embedding first, then the linear layer, so RNG consumption order is fixed.
        for weight in (self.embedding.weight, self.fc.weight):
            nn.init.uniform_(weight, -bound, bound)
        nn.init.zeros_(self.fc.bias)

    def forward(self, text, offsets):
        """Return one row of class logits per bag.

        Args:
            text: 1-D tensor of token ids for all bags, concatenated.
            offsets: 1-D tensor holding the start index of each bag in ``text``.
        """
        pooled = self.embedding(text, offsets)
        logits = self.fc(pooled)
        return logits

class MainClass:
    """Train, evaluate and query a ``TextClassificationModel`` on a torchtext dataset.

    The constructor builds the vocabulary from the dataset's training split.
    ``train_and_validate`` creates the model, optimizer and loss and runs the
    full training loop; ``predict`` / ``predict_ouput`` classify a single text.

    Args:
        dataset_class: A torchtext dataset factory that accepts ``split='train'``
            and, when called without arguments, returns a ``(train, test)``
            pair. Defaults to ``torchtext.datasets.AG_NEWS``.
    """

    def __init__(self, dataset_class=None):
        if dataset_class is None:
            # Resolved lazily: the file never imports AG_NEWS at module level,
            # so naming it directly in the signature raised NameError as soon
            # as the class was defined.
            from torchtext.datasets import AG_NEWS
            dataset_class = AG_NEWS

        # Populated by train_and_validate().
        self.model = None
        self.criterion = None
        self.optimizer = None
        self.dataset_class = dataset_class

        train_iter = self.dataset_class(split='train')
        self.vocab = build_vocab_from_iterator(
            self.yield_tokens(train_iter), specials=["<unk>"])
        # Tokens missing from the vocabulary map to the "<unk>" index.
        self.vocab.set_default_index(self.vocab["<unk>"])

        # Raw text -> list of vocabulary indices.
        self.text_pipeline = lambda x: self.vocab(tokenizer(x))
        # Dataset labels start at 1; the loss expects 0-based class indices.
        self.label_pipeline = lambda x: int(x) - 1

    @staticmethod
    def yield_tokens(data_iter):
        """Yield the token list of every ``(label, text)`` pair in ``data_iter``."""
        for _, text in data_iter:
            yield tokenizer(text)

    def train(self, dataloader, epoch):
        """Run one training epoch, printing running accuracy every 500 batches.

        Args:
            dataloader: Iterable of ``(label, text, offsets)`` tensor batches.
            epoch: Epoch number, used only in the progress line.
        """
        self.model.train()
        total_acc, total_count = 0, 0
        log_interval = 500

        for idx, (label, text, offsets) in enumerate(dataloader):
            self.optimizer.zero_grad()
            predicted_label = self.model(text, offsets)
            loss = self.criterion(predicted_label, label)
            loss.backward()
            # Clip the total gradient norm to 0.1 before the optimizer step.
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 0.1)
            self.optimizer.step()
            total_acc += (predicted_label.argmax(1) == label).sum().item()
            total_count += label.size(0)
            if idx % log_interval == 0 and idx > 0:
                print('| epoch {:3d} | {:5d}/{:5d} batches '
                      '| accuracy {:8.3f}'.format(epoch, idx, len(dataloader),
                                                  total_acc/total_count))
                # Accuracy is reported per logging window, not cumulatively.
                total_acc, total_count = 0, 0

    def evaluate(self, dataloader):
        """Return the model's accuracy over ``dataloader``.

        Args:
            dataloader: Iterable of ``(label, text, offsets)`` tensor batches.

        Returns:
            Fraction of correctly classified samples, or ``0.0`` when the
            loader yields no samples (previously a ZeroDivisionError).
        """
        self.model.eval()
        total_acc, total_count = 0, 0

        with torch.no_grad():
            for label, text, offsets in dataloader:
                predicted_label = self.model(text, offsets)
                total_acc += (predicted_label.argmax(1) == label).sum().item()
                total_count += label.size(0)

        if total_count == 0:
            return 0.0
        return total_acc / total_count

    def collate_batch(self, batch):
        """Pack ``(label, text)`` samples into tensors for ``EmbeddingBag``.

        Args:
            batch: Sequence of ``(label, text)`` pairs.

        Returns:
            ``(labels, tokens, offsets)`` moved to the module-level ``device``:
            the 0-based labels, every sample's token ids concatenated into one
            1-D tensor, and the start index of each sample in that tensor.
        """
        label_list, text_list, offsets = [], [], [0]

        for _label, _text in batch:
            label_list.append(self.label_pipeline(_label))
            processed_text = torch.tensor(self.text_pipeline(_text), dtype=torch.int64)
            text_list.append(processed_text)
            offsets.append(processed_text.size(0))

        label_list = torch.tensor(label_list, dtype=torch.int64)
        # Lengths -> start positions: drop the last length and take a running sum.
        offsets = torch.tensor(offsets[:-1]).cumsum(dim=0)
        text_list = torch.cat(text_list)

        return label_list.to(device), text_list.to(device), offsets.to(device)

    def train_and_validate(self):
        """Build the model, train for ``EPOCHS`` epochs and report test accuracy.

        5% of the training split is held out for validation. The learning rate
        is multiplied by 0.1 after any epoch whose validation accuracy is lower
        than the best seen so far.
        """
        train_iter, test_iter = self.dataset_class()
        train_dataset = to_map_style_dataset(train_iter)
        test_dataset = to_map_style_dataset(test_iter)

        # Count classes from the already materialised training set instead of
        # making an extra full pass over a second training iterator.
        num_class = len({label for label, _ in train_dataset})
        vocab_size = len(self.vocab)
        emsize = 64
        self.model = TextClassificationModel(vocab_size, emsize, num_class).to(device)

        self.criterion = torch.nn.CrossEntropyLoss()
        self.optimizer = torch.optim.SGD(self.model.parameters(), lr=LR)
        scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, step_size=1, gamma=0.1)

        num_train = int(len(train_dataset) * 0.95)
        split_train_, split_valid_ = random_split(
            train_dataset, [num_train, len(train_dataset) - num_train])

        train_dataloader = DataLoader(split_train_, batch_size=BATCH_SIZE,
                                      shuffle=True, collate_fn=self.collate_batch)
        # Accuracy does not depend on sample order, so evaluation loaders are
        # left unshuffled.
        valid_dataloader = DataLoader(split_valid_, batch_size=BATCH_SIZE,
                                      shuffle=False, collate_fn=self.collate_batch)
        test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE,
                                     shuffle=False, collate_fn=self.collate_batch)

        total_accu = None  # best validation accuracy seen so far
        for epoch in range(1, EPOCHS + 1):
            epoch_start_time = time.time()
            self.train(train_dataloader, epoch)
            accu_val = self.evaluate(valid_dataloader)
            if total_accu is not None and total_accu > accu_val:
                scheduler.step()
            else:
                total_accu = accu_val
            print('-' * 59)
            print('| end of epoch {:3d} | time: {:5.2f}s | '
                  'valid accuracy {:8.3f} '.format(epoch,
                                                  time.time() - epoch_start_time,
                                                  accu_val))
            print('-' * 59)

        print('Checking the results of test dataset.')
        accu_test = self.evaluate(test_dataloader)
        print('test accuracy {:8.3f}'.format(accu_test))

    def predict(self, text, text_pipeline):
        """Classify one text and return its 1-based label.

        Args:
            text: Raw input string.
            text_pipeline: Callable turning ``text`` into a list of token ids.

        Returns:
            Index of the highest logit plus one, matching the dataset's
            1-based labels.

        Raises:
            RuntimeError: If no model has been created yet.
        """
        if self.model is None:
            raise RuntimeError("No model available; call train_and_validate() first.")

        # Build the inputs on whatever device the model lives on, so this also
        # works when the model has not been moved to the CPU.
        first_param = next(self.model.parameters(), None)
        target = first_param.device if first_param is not None else torch.device('cpu')

        with torch.no_grad():
            # Explicit int64: an empty token list would otherwise become a
            # float tensor, which the embedding layer rejects.
            token_ids = torch.tensor(text_pipeline(text), dtype=torch.int64, device=target)
            offsets = torch.tensor([0], dtype=torch.int64, device=target)
            output = self.model(token_ids, offsets)
            print(output)
            return output.argmax(1).item() + 1

    def predict_ouput(self, input_text, labels_dict):
        """Print the human-readable label predicted for ``input_text``.

        Moves the model to the CPU before predicting.

        Args:
            input_text: Raw input string.
            labels_dict: Mapping from 1-based label to display name.
        """
        self.model = self.model.to("cpu")
        print("This is a %s news" % labels_dict[self.predict(input_text, self.text_pipeline)])

    # Correctly spelled alias; the original (misspelled) name stays for callers.
    predict_output = predict_ouput

In [16]:
# SogouNews run: constructing MainClass builds the vocabulary from the training
# split, train_and_validate() trains and reports accuracy, and predict_ouput()
# prints the predicted category of one sample text.
main_class = MainClass(SogouNews)
main_class.train_and_validate()
# labels_dict is keyed by the 1-based label that predict() returns.
main_class.predict_ouput(
    input_text = '2008 di4 qi1 jie4 qi1ng da3o guo2 ji4 che1 zha3n me3i nv3 mo2 te4  2008di4 qi1 jie4 qi1ng da3o guo2 ji4 che1 zha3n yu2 15 ri4 za4i qi1ng da3o guo2 ji4 hui4 zha3n zho1ng xi1n she4ng da4 ka1i mu4 . be3n ci4 che1 zha3n jia1ng chi2 xu4 da4o be3n yue4 19 ri4 . ji1n nia2n qi1ng da3o guo2 ji4 che1 zha3n shi4 li4 nia2n da3o che2ng che1 zha3n gui1 mo2 zui4 da4 di2 yi1 ci4 , shi3 yo4ng lia3o qi1ng da3o guo2 ji4 hui4 zha3n zho1ng xi1n di2 qua2n bu4 shi4 ne4i wa4i zha3n gua3n . yi3 xia4 we2i xia4n cha3ng mo2 te4 tu2 pia4n .',
    labels_dict = {
      1: 'Sports',
      2: 'Finance',
      3: 'Entertainment',
      4: 'Automobile',
      5: 'Technology'}
)

sogou_news_csv.tar.gz: 100%|██████████| 384M/384M [00:03<00:00, 110MB/s]


| epoch   1 |   500/ 6680 batches | accuracy    0.824
| epoch   1 |  1000/ 6680 batches | accuracy    0.907
| epoch   1 |  1500/ 6680 batches | accuracy    0.916
| epoch   1 |  2000/ 6680 batches | accuracy    0.920
| epoch   1 |  2500/ 6680 batches | accuracy    0.923
| epoch   1 |  3000/ 6680 batches | accuracy    0.925
| epoch   1 |  3500/ 6680 batches | accuracy    0.926
| epoch   1 |  4000/ 6680 batches | accuracy    0.927
| epoch   1 |  4500/ 6680 batches | accuracy    0.926
| epoch   1 |  5000/ 6680 batches | accuracy    0.928
| epoch   1 |  5500/ 6680 batches | accuracy    0.926
| epoch   1 |  6000/ 6680 batches | accuracy    0.930
| epoch   1 |  6500/ 6680 batches | accuracy    0.928
-----------------------------------------------------------
| end of epoch   1 | time: 175.49s | valid accuracy    0.931 
-----------------------------------------------------------
| epoch   2 |   500/ 6680 batches | accuracy    0.931
| epoch   2 |  1000/ 6680 batches | accuracy    0.927
| epoch 

In [18]:
# YelpReviewPolarity run: same flow as above with a two-class dataset.
main_class = MainClass(YelpReviewPolarity)
main_class.train_and_validate()
# labels_dict is keyed by the 1-based label that predict() returns.
main_class.predict_ouput(
    input_text = "Unfortunately, the frustration of being Dr. Goldberg's patient is a repeat of the experience I've had with so many other doctors in NYC -- good doctor, terrible staff.  It seems that his staff simply never answers the phone.  It usually takes 2 hours of repeated calling to get an answer.  Who has time for that or wants to deal with it?  I have run into this problem with many other doctors and I just don't get it.  You have office workers, you have patients with medical needs, why isn't anyone answering the phone?  It's incomprehensible and not work the aggravation.  It's with regret that I feel that I have to give Dr. Goldberg 2 stars.",
    labels_dict = {
      1: 'Negative polarity',
      2: 'Positive polarity',
      }
)

| epoch   1 |   500/ 8313 batches | accuracy    0.781
| epoch   1 |  1000/ 8313 batches | accuracy    0.862
| epoch   1 |  1500/ 8313 batches | accuracy    0.877
| epoch   1 |  2000/ 8313 batches | accuracy    0.889
| epoch   1 |  2500/ 8313 batches | accuracy    0.895
| epoch   1 |  3000/ 8313 batches | accuracy    0.897
| epoch   1 |  3500/ 8313 batches | accuracy    0.901
| epoch   1 |  4000/ 8313 batches | accuracy    0.903
| epoch   1 |  4500/ 8313 batches | accuracy    0.902
| epoch   1 |  5000/ 8313 batches | accuracy    0.905
| epoch   1 |  5500/ 8313 batches | accuracy    0.908
| epoch   1 |  6000/ 8313 batches | accuracy    0.908
| epoch   1 |  6500/ 8313 batches | accuracy    0.908
| epoch   1 |  7000/ 8313 batches | accuracy    0.909
| epoch   1 |  7500/ 8313 batches | accuracy    0.912
| epoch   1 |  8000/ 8313 batches | accuracy    0.911
-----------------------------------------------------------
| end of epoch   1 | time: 81.39s | valid accuracy    0.922 
---------------