<a href="https://colab.research.google.com/github/NLP-END3/Session5/blob/main/Session5_END_SogouNews.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
from torchtext.datasets import SogouNews

In [None]:
# Print the dataset's built-in documentation (split sizes, class count, arguments).
help(SogouNews)

Help on function SogouNews in module torchtext.datasets.sogounews:

SogouNews(root='.data', split=('train', 'test'))
    SogouNews dataset
    
    Separately returns the train/test split
    
    Number of lines per split:
        train: 450000
    
        test: 60000
    
    
    Number of classes
        5
    
    
    Args:
        root: Directory where the datasets are saved.
            Default: .data
        split: split or splits to be returned. Can be a string or tuple of strings.
            Default: ('train', 'test')



In [None]:
# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# Load the training split (the recorded run downloaded sogou_news_csv.tar.gz at this point).
train_iter = SogouNews(split='train')

sogou_news_csv.tar.gz: 100%|██████████| 384M/384M [00:03<00:00, 103MB/s]


In [None]:
# Peek at the dataset: its type, its length, and the first (label, text) sample.
# NOTE(review): next(iter(...)) presumably advances this iterable dataset, which
# would explain why train_iter is re-created in later cells — confirm.
type(train_iter), len(train_iter), next(iter(train_iter))

(torchtext.data.datasets_utils._RawTextIterableDataset,
 450000,
 (4,
  '2008 di4 qi1 jie4 qi1ng da3o guo2 ji4 che1 zha3n me3i nv3 mo2 te4  2008di4 qi1 jie4 qi1ng da3o guo2 ji4 che1 zha3n yu2 15 ri4 za4i qi1ng da3o guo2 ji4 hui4 zha3n zho1ng xi1n she4ng da4 ka1i mu4 . be3n ci4 che1 zha3n jia1ng chi2 xu4 da4o be3n yue4 19 ri4 . ji1n nia2n qi1ng da3o guo2 ji4 che1 zha3n shi4 li4 nia2n da3o che2ng che1 zha3n gui1 mo2 zui4 da4 di2 yi1 ci4 , shi3 yo4ng lia3o qi1ng da3o guo2 ji4 hui4 zha3n zho1ng xi1n di2 qua2n bu4 shi4 ne4i wa4i zha3n gua3n . yi3 xia4 we2i xia4n cha3ng mo2 te4 tu2 pia4n .'))

In [None]:
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator

In [None]:
# Fresh iterator over the training split, used below to build the vocabulary.
train_iter = SogouNews(split='train')
# torchtext's built-in 'basic_english' tokenizer.
tokenizer = get_tokenizer('basic_english')
def yield_tokens(dataiter):
  """Lazily yield the token list of every sample in `dataiter`.

  `dataiter` must produce 2-tuples of (label, text); the label is ignored and
  the text is passed through the notebook-level `tokenizer`.
  """
  yield from (tokenizer(sample_text) for _unused_label, sample_text in dataiter)

# Build the vocabulary from the tokenized training texts, registering '<unk>' as a special token.
vocab = build_vocab_from_iterator(yield_tokens(train_iter), specials=['<unk>'])
# Make look-ups of tokens missing from the vocabulary fall back to the '<unk>' index.
vocab.set_default_index(vocab['<unk>'])

In [None]:
# Vocabulary size; also used later as the embedding table's row count.
len(vocab)

369441

In [None]:
def text_pipeline(x):
  """Tokenize a raw string and map every token to its vocabulary index."""
  return vocab(tokenizer(x))

def label_pipeline(x):
  """Convert a label to int and shift it down by one to get a 0-based class index."""
  return int(x) - 1

In [None]:
# Sanity check: convert an example sentence into its list of vocabulary indices.
text_pipeline("He is a country called India which is Big")

[2361, 1481, 599, 7008, 10892, 7715, 2540, 1481, 3343]

In [None]:
def collate_batch(batch):
  """Pack a list of (label, text) samples into flat tensors for nn.EmbeddingBag.

  Returns a triple, each tensor moved to the notebook-level `device`:
    labels  - int64 tensor with one entry per sample (via `label_pipeline`)
    tokens  - all samples' token indices concatenated into one 1-D int64 tensor
    offsets - start position of each sample inside `tokens`
  """
  labels, token_tensors, lengths = [], [], []
  for raw_label, raw_text in batch:
    labels.append(label_pipeline(raw_label))
    token_ids = torch.tensor(text_pipeline(raw_text), dtype=torch.int64)
    token_tensors.append(token_ids)
    lengths.append(token_ids.size(0))
  labels = torch.tensor(labels, dtype=torch.int64)
  tokens = torch.cat(token_tensors)
  # A sample starts where the previous ones end: running sum of the preceding
  # lengths, with 0 for the first sample (the last length is not needed).
  offsets = torch.tensor([0] + lengths[:-1]).cumsum(dim=0)
  return labels.to(device), tokens.to(device), offsets.to(device)

In [None]:
from torch.utils.data import DataLoader
# Example loader over the raw training iterator: batches of 8 samples assembled by collate_batch.
# NOTE(review): `dataloader` is not referenced again in the visible cells (the
# training cell builds its own loaders), and train_iter was already iterated
# while building the vocabulary — confirm whether this cell is still needed.
dataloader = DataLoader(train_iter, batch_size = 8, shuffle=False, collate_fn=collate_batch)

In [None]:
from torch import nn

class TextClassification(nn.Module):
  """Bag-of-embeddings text classifier: an EmbeddingBag followed by one linear layer.

  Args:
    vocab_size: number of rows in the embedding table.
    embed_size: width of each embedding vector.
    num_classes: number of output logits.
  """

  def __init__(self, vocab_size, embed_size, num_classes):
    super().__init__()
    # sparse=True makes the embedding table produce sparse gradients.
    self.embed = nn.EmbeddingBag(vocab_size, embed_size, sparse=True)
    self.fc = nn.Linear(embed_size, num_classes)
    self.init_weights()

  def init_weights(self):
    """Re-initialise both weight matrices uniformly in [-0.5, 0.5] and zero the bias."""
    bound = 0.5
    nn.init.uniform_(self.embed.weight, -bound, bound)
    nn.init.uniform_(self.fc.weight, -bound, bound)
    nn.init.zeros_(self.fc.bias)

  def forward(self, text, offsets):
    """Return class logits for the bags of `text` delimited by `offsets`."""
    pooled = self.embed(text, offsets)
    logits = self.fc(pooled)
    return logits

In [None]:
# Fresh pass over the training split, used here only to count the distinct labels.
train_iter = SogouNews(split='train')
vocab_size, embed_size = len(vocab), 64
# One output logit per distinct label found in the training data.
num_classes = len({label for label, _text in train_iter})
model = TextClassification(vocab_size, embed_size, num_classes)
# Move the parameters to the selected device; as the last expression this also displays the model.
model.to(device)

TextClassification(
  (embed): EmbeddingBag(369441, 64, mode=mean)
  (fc): Linear(in_features=64, out_features=5, bias=True)
)

In [None]:
# Display the values the model was constructed with.
vocab_size, embed_size, num_classes

(369441, 64, 5)

In [None]:
def train(dataloader):
  """Run one optimisation pass over `dataloader`, printing running accuracy.

  Uses the notebook-level globals `model`, `optimizer` and `criterion`, and
  reads the global `epoch` for the progress message only. Every batch must be
  a (label, text, offsets) triple; `len(dataloader)` must be defined.
  Accuracy is reported every 500 batches and the counters are then reset, so
  each printed figure covers only the batches since the previous report.
  """
  model.train()
  log_interval = 500
  correct = seen = 0
  for step, (label, text, offsets) in enumerate(dataloader):
    optimizer.zero_grad()
    logits = model(text, offsets)
    criterion(logits, label).backward()
    # Clip the total gradient norm to 0.1 before applying the update.
    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)
    optimizer.step()
    correct += (logits.argmax(1) == label).sum().item()
    seen += label.size(0)
    if step > 0 and step % log_interval == 0:
      print('| epochs {:3d} | {:5d}/{:5d} batches | accuracy{:8.3f}'.format(epoch, step, len(dataloader), correct/seen))
      correct = seen = 0

In [None]:
def eval(dataloader):
  """Return the classification accuracy of the global `model` over `dataloader`.

  The model is switched to evaluation mode and gradients are disabled. Every
  batch must be a (label, text, offsets) triple. The result is the fraction of
  samples whose arg-max prediction equals the label; an empty `dataloader`
  yields 0.0 instead of raising ZeroDivisionError.

  Note: this name shadows Python's built-in `eval`; it is kept because other
  cells call it by this name.
  """
  model.eval()
  total_acc, total_count = 0, 0
  with torch.no_grad():
    for label, text, offsets in dataloader:
      preds = model(text, offsets)
      # No loss is computed here: only the accuracy is returned, so evaluation
      # does not need the global `criterion`.
      total_acc += (preds.argmax(1) == label).sum().item()
      total_count += label.size(0)
  if total_count == 0:
    return 0.0
  return total_acc/total_count

In [None]:
from torch.utils.data.dataset import random_split
from torchtext.data.functional import to_map_style_dataset
# Training hyperparameters.
EPOCHS = 5 #10
BATCH_SIZE = 64
LR = 5

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=LR)
# Each call to scheduler.step() multiplies the learning rate by 0.1.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)
total_accu = None  # best validation accuracy seen so far

train_iter, test_iter = SogouNews()
train_dataset = to_map_style_dataset(train_iter)
test_dataset = to_map_style_dataset(test_iter)

# Hold out 5% of the training data for validation.
num_train = int(len(train_dataset)*0.95)
split_train, split_valid = random_split(train_dataset, [num_train, len(train_dataset)-num_train])

# Train on the 95% split and validate on the held-out 5%. Validating on the
# data the model was trained on would inflate the validation accuracy and make
# the learning-rate decision below meaningless.
train_loader = DataLoader(split_train, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)
# Accuracy does not depend on sample order, so evaluation loaders are not shuffled.
valid_loader = DataLoader(split_valid, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_batch)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_batch)

# `epoch` must stay a notebook-level name: train() reads it for its progress message.
for epoch in range(1, EPOCHS+1):
  train(train_loader)
  accu_val = eval(valid_loader)
  if total_accu is not None and total_accu > accu_val:
    # Validation accuracy dropped below the best so far: decay the learning rate.
    scheduler.step()
  else:
    total_accu = accu_val
  print('-' * 59)
  print('| end of epoch {:3d} | '
        'valid accuracy {:8.3f} '.format(epoch, accu_val))
  print('-' * 59)

| epochs   1 |   500/ 7032 batches | accuracy   0.934
| epochs   1 |  1000/ 7032 batches | accuracy   0.935
| epochs   1 |  1500/ 7032 batches | accuracy   0.930
| epochs   1 |  2000/ 7032 batches | accuracy   0.934
| epochs   1 |  2500/ 7032 batches | accuracy   0.935
| epochs   1 |  3000/ 7032 batches | accuracy   0.935
| epochs   1 |  3500/ 7032 batches | accuracy   0.935
| epochs   1 |  4000/ 7032 batches | accuracy   0.935
| epochs   1 |  4500/ 7032 batches | accuracy   0.934
| epochs   1 |  5000/ 7032 batches | accuracy   0.933
| epochs   1 |  5500/ 7032 batches | accuracy   0.934
| epochs   1 |  6000/ 7032 batches | accuracy   0.935
| epochs   1 |  6500/ 7032 batches | accuracy   0.936
| epochs   1 |  7000/ 7032 batches | accuracy   0.933
-----------------------------------------------------------
| end of epoch   1 | valid accuracy    0.936 
-----------------------------------------------------------
| epochs   2 |   500/ 7032 batches | accuracy   0.935
| epochs   2 |  1000/ 70

In [None]:
# Final check: accuracy of the trained model on the test loader.
print('Checking the results of test dataset.')
accu_test = eval(test_loader)
print('test accuracy {:8.3f}'.format(accu_test))

Checking the results of test dataset.
test accuracy    0.936
