In [17]:
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from utils.utils import getDevice
from datasets.newsGroupDataset import NewsGroupDataset
import numpy as np
from tqdm import tqdm
import spacy
import string
import re
import torch
from torch.utils.data import DataLoader
from model.textClassificationModel import TextClassificationModel
import time
from torch.utils.data.dataset import random_split
from torchtext.data.functional import to_map_style_dataset
from torch import nn



In [18]:




# Seed the global RNGs up front: the default torch generator drives weight
# initialisation, random_split and DataLoader shuffling further down, so a
# fixed seed makes a fresh "Restart & Run All" reproducible.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

device = getDevice()

# Training subset; it is iterated below to build the vocabulary.
data_iter = NewsGroupDataset(subset='train')

# Tokenization: only `en.tokenizer` is used by the `tokenizer` function below.
# NOTE(review): the full transformer pipeline is loaded just for its tokenizer;
# consider excluding the unused components if nothing else needs them.
en = spacy.load('en_core_web_trf')

# Both patterns are compiled once here rather than on every tokenizer() call.
# Runs of non-ASCII characters collapse to a single space.
_NON_ASCII_RE = re.compile(r"[^\x00-\x7F]+")
# Each punctuation character, digit, CR, TAB or LF becomes a space.
_PUNCT_DIGIT_WS_RE = re.compile('[' + re.escape(string.punctuation) + '0-9\\r\\t\\n]')


def tokenizer(text):
    """Normalise ``text`` and split it into a list of token strings.

    Steps: replace non-ASCII runs with a space, lower-case, replace
    punctuation / digits / CR / TAB / LF with spaces, then tokenize the
    cleaned string with the notebook-level spaCy tokenizer ``en.tokenizer``.

    Args:
        text: Raw document text.

    Returns:
        list[str]: The ``.text`` of every token produced by ``en.tokenizer``.
    """
    ascii_only = _NON_ASCII_RE.sub(" ", text)
    cleaned = _PUNCT_DIGIT_WS_RE.sub(" ", ascii_only.lower())
    return [token.text for token in en.tokenizer(cleaned)]

def yield_tokens(data_iter):
    """Lazily yield the token list of every sample's text.

    Each item of ``data_iter`` is unpacked as a 2-tuple whose second element
    is the text; the first element is ignored. Iteration is wrapped in
    ``tqdm`` so progress is displayed while the generator is consumed.
    """
    yield from (tokenizer(document) for _, document in tqdm(data_iter))

# Build the vocabulary from the tokenized training texts. "<unk>" is the only
# special token and is also registered as the default index, so lookups of
# unseen tokens resolve to it instead of raising.
vocab = build_vocab_from_iterator(yield_tokens(data_iter), specials=["<unk>"])
vocab.set_default_index(vocab["<unk>"])

# Sanity check: print the index assigned to the token 'i'.
print(vocab['i'])

100%|██████████| 11314/11314 [00:24<00:00, 455.15it/s]


10


In [19]:
def text_pipeline(x):
    """Convert raw text ``x`` into a list of vocabulary indices.

    The text is tokenized with ``tokenizer`` and the resulting tokens are
    looked up in the notebook-level ``vocab``.
    """
    return vocab(tokenizer(x))


print(text_pipeline("enhanced"))

[4807]


In [20]:
def collate_batch(batch):
    """Collate ``(label, text)`` samples into flat tensors on ``device``.

    Args:
        batch: Sequence of ``(label, text)`` pairs.

    Returns:
        tuple: ``(labels, tokens, offsets)`` where
            * ``labels`` is an int64 tensor with one entry per sample,
            * ``tokens`` is the int64 concatenation of every sample's
              vocabulary indices (via ``text_pipeline``),
            * ``offsets`` holds the start position of each sample inside
              ``tokens``.
    """
    # Encode every sample once, keeping its label next to its token tensor.
    encoded = [
        (label, torch.tensor(text_pipeline(raw_text), dtype=torch.int64))
        for label, raw_text in batch
    ]

    labels = torch.tensor([label for label, _ in encoded], dtype=torch.int64)

    # A leading 0 followed by each sample's length; dropping the last length
    # and taking the running sum gives every sample's starting position.
    bag_sizes = [0, *(tokens.size(0) for _, tokens in encoded)]
    offsets = torch.tensor(bag_sizes[:-1]).cumsum(dim=0)

    flat_tokens = torch.cat([tokens for _, tokens in encoded])
    return labels.to(device), flat_tokens.to(device), offsets.to(device)

# Un-shuffled loader over the training subset, batched with collate_batch.
# `train_iter` is also used further down to read the class names.
train_iter = NewsGroupDataset(subset='train')
dataloader = DataLoader(
    train_iter,
    batch_size=8,
    shuffle=False,
    collate_fn=collate_batch,
)

In [21]:
# --- Hyperparameters -------------------------------------------------------
EPOCHS = 20      # number of passes over the training split
LR = 5           # initial SGD learning rate
BATCH_SIZE = 32  # batch size for training
emsize = 300     # second constructor argument of the model (embedding size)

# --- Problem dimensions ----------------------------------------------------
vocab_size = len(vocab)
print(vocab_size)
num_class = len(list(train_iter.target_names))

# --- Model, loss and optimiser ---------------------------------------------
model = TextClassificationModel(vocab_size, emsize, num_class).to(device)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=LR)

89300


In [22]:

def train(dataloader, epoch=None, log_interval=500):
    """Run one optimisation pass over ``dataloader``.

    Uses the notebook-level ``model``, ``optimizer`` and ``criterion``.

    Args:
        dataloader: Sized iterable yielding ``(label, text, offsets)``
            batches, i.e. the tuples produced by ``collate_batch``.
        epoch: Epoch number shown in the progress log. When omitted, the
            notebook-level ``epoch`` loop variable is used if it exists,
            otherwise 0, so existing ``train(loader)`` calls keep working
            and the function no longer fails when called outside the loop.
        log_interval: Print (and reset) the running accuracy every this
            many batches.
    """
    if epoch is None:
        # Backward compatibility with callers that rely on the global loop
        # variable instead of passing the epoch explicitly.
        epoch = globals().get('epoch', 0)

    model.train()
    total_acc, total_count = 0, 0

    for idx, (label, text, offsets) in enumerate(tqdm(dataloader)):
        optimizer.zero_grad()
        predicted_label = model(text, offsets)
        loss = criterion(predicted_label, label)
        loss.backward()
        # Clip the total gradient norm to 0.1 before the optimiser step.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)
        optimizer.step()
        total_acc += (predicted_label.argmax(1) == label).sum().item()
        total_count += label.size(0)
        if idx % log_interval == 0 and idx > 0:
            print('| epoch {:3d} | {:5d}/{:5d} batches '
                  '| accuracy {:8.3f}'.format(epoch, idx, len(dataloader),
                                              total_acc/total_count))
            # Accuracy is reported per logging window, not cumulatively.
            total_acc, total_count = 0, 0

def evaluate(dataloader):
    """Return the classification accuracy of ``model`` over ``dataloader``.

    Puts the notebook-level ``model`` into eval mode and runs it under
    ``torch.no_grad()``.

    Args:
        dataloader: Iterable yielding ``(label, text, offsets)`` batches,
            i.e. the tuples produced by ``collate_batch``.

    Returns:
        float: Fraction of samples whose arg-max prediction equals the
        label, or 0.0 when the dataloader yields no samples.
    """
    model.eval()
    total_acc, total_count = 0, 0

    with torch.no_grad():
        for label, text, offsets in tqdm(dataloader):
            predicted_label = model(text, offsets)
            total_acc += (predicted_label.argmax(1) == label).sum().item()
            total_count += label.size(0)

    # An empty loader would otherwise raise ZeroDivisionError.
    if total_count == 0:
        return 0.0
    return total_acc/total_count

In [23]:
# With step_size=1 every scheduler.step() call multiplies the LR by gamma.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)
total_accu = None  # best validation accuracy seen so far

train_dataset = NewsGroupDataset(subset='train')
test_dataset = NewsGroupDataset(subset='test')

# Hold out 5% of the training subset for validation. A dedicated, seeded
# generator keeps the split identical across runs.
SPLIT_SEED = 42
num_train = int(len(train_dataset) * 0.95)
split_train_, split_valid_ = random_split(
    train_dataset,
    [num_train, len(train_dataset) - num_train],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

train_dataloader = DataLoader(split_train_, batch_size=BATCH_SIZE,
                              shuffle=True, collate_fn=collate_batch)
# Accuracy does not depend on batch order, so evaluation loaders stay unshuffled.
valid_dataloader = DataLoader(split_valid_, batch_size=BATCH_SIZE,
                              shuffle=False, collate_fn=collate_batch)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE,
                             shuffle=False, collate_fn=collate_batch)

for epoch in range(1, EPOCHS + 1):
    epoch_start_time = time.time()
    train(train_dataloader)
    accu_val = evaluate(valid_dataloader)
    if total_accu is not None and total_accu > accu_val:
        # Validation accuracy fell below the best so far: decay the LR.
        scheduler.step()
    else:
        total_accu = accu_val
    print('-' * 59)
    print('| end of epoch {:3d} | time: {:5.2f}s | '
          'valid accuracy {:8.3f} '.format(epoch,
                                           time.time() - epoch_start_time,
                                           accu_val))
    print('-' * 59)

100%|██████████| 336/336 [00:22<00:00, 14.97it/s]
100%|██████████| 18/18 [00:01<00:00, 15.27it/s]


-----------------------------------------------------------
| end of epoch   1 | time: 23.63s | valid accuracy    0.313 
-----------------------------------------------------------


100%|██████████| 336/336 [00:21<00:00, 15.61it/s]
100%|██████████| 18/18 [00:01<00:00, 15.71it/s]


-----------------------------------------------------------
| end of epoch   2 | time: 22.67s | valid accuracy    0.445 
-----------------------------------------------------------


100%|██████████| 336/336 [00:22<00:00, 15.19it/s]
100%|██████████| 18/18 [00:01<00:00, 15.80it/s]


-----------------------------------------------------------
| end of epoch   3 | time: 23.26s | valid accuracy    0.585 
-----------------------------------------------------------


100%|██████████| 336/336 [00:21<00:00, 15.58it/s]
100%|██████████| 18/18 [00:01<00:00, 15.83it/s]


-----------------------------------------------------------
| end of epoch   4 | time: 22.71s | valid accuracy    0.654 
-----------------------------------------------------------


100%|██████████| 336/336 [00:21<00:00, 15.68it/s]
100%|██████████| 18/18 [00:01<00:00, 15.87it/s]


-----------------------------------------------------------
| end of epoch   5 | time: 22.56s | valid accuracy    0.716 
-----------------------------------------------------------


100%|██████████| 336/336 [00:21<00:00, 15.40it/s]
100%|██████████| 18/18 [00:01<00:00, 15.63it/s]


-----------------------------------------------------------
| end of epoch   6 | time: 22.97s | valid accuracy    0.751 
-----------------------------------------------------------


100%|██████████| 336/336 [00:21<00:00, 15.56it/s]
100%|██████████| 18/18 [00:01<00:00, 14.14it/s]


-----------------------------------------------------------
| end of epoch   7 | time: 22.87s | valid accuracy    0.781 
-----------------------------------------------------------


100%|██████████| 336/336 [00:21<00:00, 15.35it/s]
100%|██████████| 18/18 [00:01<00:00, 16.00it/s]


-----------------------------------------------------------
| end of epoch   8 | time: 23.02s | valid accuracy    0.761 
-----------------------------------------------------------


100%|██████████| 336/336 [00:21<00:00, 15.32it/s]
100%|██████████| 18/18 [00:01<00:00, 15.47it/s]


-----------------------------------------------------------
| end of epoch   9 | time: 23.10s | valid accuracy    0.825 
-----------------------------------------------------------


100%|██████████| 336/336 [00:22<00:00, 14.69it/s]
100%|██████████| 18/18 [00:01<00:00, 13.62it/s]


-----------------------------------------------------------
| end of epoch  10 | time: 24.20s | valid accuracy    0.820 
-----------------------------------------------------------


100%|██████████| 336/336 [00:21<00:00, 15.31it/s]
100%|██████████| 18/18 [00:01<00:00, 15.59it/s]


-----------------------------------------------------------
| end of epoch  11 | time: 23.10s | valid accuracy    0.820 
-----------------------------------------------------------


100%|██████████| 336/336 [00:21<00:00, 15.58it/s]
100%|██████████| 18/18 [00:01<00:00, 15.53it/s]


-----------------------------------------------------------
| end of epoch  12 | time: 22.73s | valid accuracy    0.820 
-----------------------------------------------------------


100%|██████████| 336/336 [00:21<00:00, 15.39it/s]
100%|██████████| 18/18 [00:01<00:00, 16.05it/s]


-----------------------------------------------------------
| end of epoch  13 | time: 22.96s | valid accuracy    0.820 
-----------------------------------------------------------


100%|██████████| 336/336 [00:21<00:00, 15.61it/s]
100%|██████████| 18/18 [00:01<00:00, 15.92it/s]


-----------------------------------------------------------
| end of epoch  14 | time: 22.65s | valid accuracy    0.820 
-----------------------------------------------------------


100%|██████████| 336/336 [00:21<00:00, 15.77it/s]
100%|██████████| 18/18 [00:01<00:00, 15.10it/s]


-----------------------------------------------------------
| end of epoch  15 | time: 22.50s | valid accuracy    0.820 
-----------------------------------------------------------


100%|██████████| 336/336 [00:21<00:00, 15.58it/s]
100%|██████████| 18/18 [00:01<00:00, 15.83it/s]


-----------------------------------------------------------
| end of epoch  16 | time: 22.70s | valid accuracy    0.820 
-----------------------------------------------------------


100%|██████████| 336/336 [00:21<00:00, 15.64it/s]
100%|██████████| 18/18 [00:01<00:00, 15.65it/s]


-----------------------------------------------------------
| end of epoch  17 | time: 22.64s | valid accuracy    0.820 
-----------------------------------------------------------


100%|██████████| 336/336 [00:21<00:00, 15.34it/s]
100%|██████████| 18/18 [00:01<00:00, 16.21it/s]


-----------------------------------------------------------
| end of epoch  18 | time: 23.02s | valid accuracy    0.820 
-----------------------------------------------------------


100%|██████████| 336/336 [00:21<00:00, 15.75it/s]
100%|██████████| 18/18 [00:01<00:00, 15.09it/s]


-----------------------------------------------------------
| end of epoch  19 | time: 22.53s | valid accuracy    0.820 
-----------------------------------------------------------


100%|██████████| 336/336 [00:22<00:00, 14.63it/s]
100%|██████████| 18/18 [00:01<00:00, 13.48it/s]

-----------------------------------------------------------
| end of epoch  20 | time: 24.30s | valid accuracy    0.820 
-----------------------------------------------------------





In [24]:
# Final check: accuracy of the trained model on the test-subset loader.
print('Checking the results of test dataset.')
accu_test = evaluate(test_dataloader)
test_report = 'test accuracy {:8.3f}'.format(accu_test)
print(test_report)

Checking the results of test dataset.


100%|██████████| 236/236 [00:15<00:00, 15.60it/s]

test accuracy    0.737



