In [1]:
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from utils.utils import getDevice
from datasets.newsGroupDataset import NewsGroupDataset
import numpy as np
from tqdm import tqdm
import spacy
import string
import re
import torch
from torch.utils.data import DataLoader
from model.textClassificationModel import TextClassificationModel
import time
from torch.utils.data.dataset import random_split
from torchtext.data.functional import to_map_style_dataset
from torch import nn



  from .autonotebook import tqdm as notebook_tqdm


(7, "From: lerxst@wam.umd.edu (where's my thing)\nSubject: WHAT car is this!?\nNntp-Posting-Host: rac3.wam.umd.edu\nOrganization: University of Maryland, College Park\nLines: 15\n\n I was wondering if anyone out there could enlighten me on this car I saw\nthe other day. It was a 2-door sports car, looked to be from the late 60s/\nearly 70s. It was called a Bricklin. The doors were really small. In addition,\nthe front bumper was separate from the rest of the body. This is \nall I know. If anyone can tellme a model name, engine specs, years\nof production, where this car is made, history, or whatever info you\nhave on this funky looking car, please e-mail.\n\nThanks,\n- IL\n   ---- brought to you by your neighborhood Lerxst ----\n\n\n\n\n")


In [2]:




# Device handle from the project helper; batches and the model are moved to it later.
device = getDevice()

# Training split of the newsgroup dataset; iterated below to build the vocabulary.
data_iter = NewsGroupDataset(subset='train')
# print(data_iter[0])

# Tokenization: load the spaCy English pipeline. Only its `.tokenizer` is used in this notebook.
# NOTE(review): 'en_core_web_trf' is a full transformer pipeline; loading it just for the
# tokenizer is presumably much heavier than necessary -- confirm whether a lighter load is acceptable.
en = spacy.load('en_core_web_trf')

# Patterns are compiled once at definition time rather than on every call:
# tokenizer() runs for every document while building the vocabulary and again
# for every sample of every batch (via text_pipeline / collate_batch).
_NON_ASCII_RE = re.compile(r"[^\x00-\x7F]+")
# Matches any punctuation character, digit, carriage return, tab or newline.
_PUNCT_DIGIT_CTRL_RE = re.compile('[' + re.escape(string.punctuation) + '0-9\\r\\t\\n]')


def tokenizer(text):
    """Normalise ``text`` and split it into a list of token strings.

    Steps, in order:
      1. each run of non-ASCII characters is replaced by a single space;
      2. the text is lower-cased;
      3. every punctuation character, digit, ``\\r``, ``\\t`` and ``\\n`` is
         replaced by a space;
      4. the result is tokenised with the notebook-level spaCy tokenizer ``en``.

    Args:
        text: raw document string.

    Returns:
        list[str]: the ``.text`` of every token produced by ``en.tokenizer``.
    """
    ascii_only = _NON_ASCII_RE.sub(" ", text)
    cleaned = _PUNCT_DIGIT_CTRL_RE.sub(" ", ascii_only.lower())
    # NOTE(review): the substitutions above leave runs of spaces; the spaCy tokenizer
    # may emit those as whitespace tokens that end up in the vocabulary -- confirm
    # whether they should be filtered out.
    return [token.text for token in en.tokenizer(cleaned)]

def yield_tokens(data_iter):
    """Lazily yield one token list per sample of ``data_iter``.

    ``data_iter`` must produce two-item samples; the first item is ignored and
    the second (the raw text) is passed through ``tokenizer``. Iteration is
    wrapped in ``tqdm`` so a progress bar is shown while consuming it.
    """
    yield from (tokenizer(document) for _label, document in tqdm(data_iter))

# Build the vocabulary from the tokenised training documents; "<unk>" is registered
# as a special token and used as the index for any token not in the vocabulary.
vocab = build_vocab_from_iterator(yield_tokens(data_iter), specials=["<unk>"])
vocab.set_default_index(vocab["<unk>"])

# Sanity check: index of a common token (plain subscript instead of calling __getitem__ directly).
print(vocab['i'])

100%|██████████| 11314/11314 [00:23<00:00, 475.71it/s]


10


In [3]:
def text_pipeline(x):
    """Tokenise raw text ``x`` and map the tokens to their vocabulary indices.

    Uses the notebook-level ``tokenizer`` and ``vocab``; returns whatever
    ``vocab`` returns for a list of tokens (a list of integer indices).
    """
    return vocab(tokenizer(x))


print(text_pipeline("enhanced"))

[4807]


In [4]:
def collate_batch(batch):
    """Collate ``(label, text)`` samples into flat tensors on ``device``.

    Every text is converted to an int64 index tensor with ``text_pipeline``;
    the per-sample tensors are concatenated into one 1-D tensor, and
    ``offsets`` holds the start position of each sample inside it.

    Returns:
        tuple: ``(labels, tokens, offsets)``, each moved to the notebook-level
        ``device``.
    """
    labels, token_tensors, lengths = [], [], []
    for sample_label, sample_text in batch:
        labels.append(sample_label)
        indices = torch.tensor(text_pipeline(sample_text), dtype=torch.int64)
        token_tensors.append(indices)
        lengths.append(indices.size(0))

    labels_tensor = torch.tensor(labels, dtype=torch.int64)
    # Start offsets: a leading 0 followed by the running sum of all lengths but the last.
    starts = torch.tensor([0, *lengths][:-1]).cumsum(dim=0)
    flat_tokens = torch.cat(token_tensors)
    return labels_tensor.to(device), flat_tokens.to(device), starts.to(device)

# Training split again; `train_iter.target_names` is read later to size the classifier.
train_iter = NewsGroupDataset(subset='train')
# Small, un-shuffled loader (batch size 8) wired to collate_batch. It is only referenced
# by the disabled loop below; the training loaders are built separately later on.
dataloader = DataLoader(train_iter, batch_size=8, shuffle=False, collate_fn=collate_batch)

# Disabled smoke test: iterate the loader once to check that collate_batch runs on every batch.
# a = 0
# for label, text, offset in tqdm(dataloader):
#     a = a+1

In [5]:
# Seed torch's global RNG so that model weight initialisation and any later
# torch-based shuffling/splitting that uses the default generator repeat across runs.
SEED = 42
torch.manual_seed(SEED)

# Model dimensions: one output per newsgroup, one embedding row per vocabulary entry.
num_class = len(list(train_iter.target_names))
vocab_size = len(vocab)
emsize = 300  # embedding size passed to the model
model = TextClassificationModel(vocab_size, emsize, num_class).to(device)
print(vocab_size)

# Hyperparameters
EPOCHS = 20 # epoch
LR = 5  # learning rate
BATCH_SIZE = 64 # batch size for training

# Loss and optimizer shared by train() / evaluate() through the notebook namespace.
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=LR)

89300


In [6]:

def train(dataloader, epoch=None, log_interval=500):
    """Run one training pass over ``dataloader`` and periodically print accuracy.

    Uses the notebook-level ``model``, ``optimizer`` and ``criterion``.

    Args:
        dataloader: sized iterable of ``(label, text, offsets)`` batches
            (``len()`` is used in the progress line).
        epoch: epoch number shown in the progress line. When omitted, the
            notebook-level ``epoch`` loop variable is used if it exists (so
            existing ``train(loader)`` calls inside the epoch loop log as
            before); otherwise 0 is shown instead of raising ``NameError``.
        log_interval: print the running accuracy every this many batches.
            Must be a positive integer. With the default of 500, nothing is
            printed for loaders of 500 batches or fewer.

    Returns:
        None.
    """
    if epoch is None:
        # Backward-compatible fallback to the driver loop's global variable.
        epoch = globals().get('epoch', 0)

    model.train()
    total_acc, total_count = 0, 0

    for idx, (label, text, offsets) in enumerate(tqdm(dataloader)):
        optimizer.zero_grad()
        predicted_label = model(text, offsets)
        loss = criterion(predicted_label, label)
        loss.backward()
        # Clip the global gradient norm to 0.1 before the optimizer step.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)
        optimizer.step()
        total_acc += (predicted_label.argmax(1) == label).sum().item()
        total_count += label.size(0)
        if idx % log_interval == 0 and idx > 0:
            print('| epoch {:3d} | {:5d}/{:5d} batches '
                  '| accuracy {:8.3f}'.format(epoch, idx, len(dataloader),
                                              total_acc/total_count))
            # Accuracy is reported per logging window, so reset the counters.
            total_acc, total_count = 0, 0

def evaluate(dataloader):
    """Return the classification accuracy of the notebook-level ``model``.

    The model is switched to eval mode and gradients are disabled while the
    batches are scored. No loss is computed: only the arg-max prediction of
    each sample is compared with its label.

    Args:
        dataloader: iterable of ``(label, text, offsets)`` batches.

    Returns:
        float: correct predictions divided by samples seen, or ``0.0`` when
        the loader yields no samples (instead of raising ZeroDivisionError).
    """
    model.eval()
    total_acc, total_count = 0, 0

    with torch.no_grad():
        for label, text, offsets in tqdm(dataloader):
            predicted_label = model(text, offsets)
            total_acc += (predicted_label.argmax(1) == label).sum().item()
            total_count += label.size(0)

    if total_count == 0:
        return 0.0
    return total_acc/total_count

In [7]:
# Each scheduler.step() multiplies the learning rate by gamma (0.1).
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)
total_accu = None  # best validation accuracy seen so far
train_dataset = NewsGroupDataset(subset='train')
test_dataset = NewsGroupDataset(subset='test')

# Hold out 10% of the training split for validation. A dedicated, seeded
# generator makes the split identical on every run.
SPLIT_SEED = 42
num_train = int(len(train_dataset) * 0.90)
split_train_, split_valid_ = random_split(
    train_dataset,
    [num_train, len(train_dataset) - num_train],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

train_dataloader = DataLoader(split_train_, batch_size=BATCH_SIZE,
                              shuffle=True, collate_fn=collate_batch)
# Accuracy does not depend on batch order, so the evaluation loaders are not shuffled.
valid_dataloader = DataLoader(split_valid_, batch_size=BATCH_SIZE,
                              shuffle=False, collate_fn=collate_batch)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE,
                             shuffle=False, collate_fn=collate_batch)

for epoch in range(1, EPOCHS + 1):
    epoch_start_time = time.time()
    train(train_dataloader)
    accu_val = evaluate(valid_dataloader)
    if total_accu is not None and total_accu > accu_val:
        # Validation accuracy dropped below the best so far: decay the learning rate.
        scheduler.step()
    else:
        total_accu = accu_val
    print('-' * 59)
    print('| end of epoch {:3d} | time: {:5.2f}s | '
          'valid accuracy {:8.3f} '.format(epoch,
                                           time.time() - epoch_start_time,
                                           accu_val))
    print('-' * 59)

100%|██████████| 142/142 [00:19<00:00,  7.46it/s]
100%|██████████| 36/36 [00:04<00:00,  8.14it/s]


-----------------------------------------------------------
| end of epoch   1 | time: 23.45s | valid accuracy    0.191 
-----------------------------------------------------------


100%|██████████| 142/142 [00:17<00:00,  8.11it/s]
100%|██████████| 36/36 [00:04<00:00,  8.04it/s]


-----------------------------------------------------------
| end of epoch   2 | time: 21.98s | valid accuracy    0.323 
-----------------------------------------------------------


100%|██████████| 142/142 [00:17<00:00,  8.18it/s]
100%|██████████| 36/36 [00:04<00:00,  8.06it/s]


-----------------------------------------------------------
| end of epoch   3 | time: 21.83s | valid accuracy    0.408 
-----------------------------------------------------------


100%|██████████| 142/142 [00:17<00:00,  8.08it/s]
100%|██████████| 36/36 [00:04<00:00,  8.09it/s]


-----------------------------------------------------------
| end of epoch   4 | time: 22.02s | valid accuracy    0.445 
-----------------------------------------------------------


100%|██████████| 142/142 [00:17<00:00,  8.09it/s]
100%|██████████| 36/36 [00:04<00:00,  8.06it/s]


-----------------------------------------------------------
| end of epoch   5 | time: 22.02s | valid accuracy    0.529 
-----------------------------------------------------------


100%|██████████| 142/142 [00:17<00:00,  8.06it/s]
100%|██████████| 36/36 [00:04<00:00,  8.00it/s]


-----------------------------------------------------------
| end of epoch   6 | time: 22.13s | valid accuracy    0.570 
-----------------------------------------------------------


100%|██████████| 142/142 [00:17<00:00,  8.09it/s]
100%|██████████| 36/36 [00:04<00:00,  7.95it/s]


-----------------------------------------------------------
| end of epoch   7 | time: 22.09s | valid accuracy    0.608 
-----------------------------------------------------------


100%|██████████| 142/142 [00:17<00:00,  8.00it/s]
100%|██████████| 36/36 [00:04<00:00,  8.21it/s]


-----------------------------------------------------------
| end of epoch   8 | time: 22.14s | valid accuracy    0.663 
-----------------------------------------------------------


100%|██████████| 142/142 [00:17<00:00,  8.16it/s]
100%|██████████| 36/36 [00:04<00:00,  8.20it/s]


-----------------------------------------------------------
| end of epoch   9 | time: 21.80s | valid accuracy    0.685 
-----------------------------------------------------------


100%|██████████| 142/142 [00:17<00:00,  8.21it/s]
100%|██████████| 36/36 [00:04<00:00,  7.95it/s]


-----------------------------------------------------------
| end of epoch  10 | time: 21.83s | valid accuracy    0.730 
-----------------------------------------------------------


100%|██████████| 142/142 [00:17<00:00,  7.97it/s]
100%|██████████| 36/36 [00:04<00:00,  7.59it/s]


-----------------------------------------------------------
| end of epoch  11 | time: 22.56s | valid accuracy    0.726 
-----------------------------------------------------------


100%|██████████| 142/142 [00:17<00:00,  8.01it/s]
100%|██████████| 36/36 [00:04<00:00,  8.02it/s]


-----------------------------------------------------------
| end of epoch  12 | time: 22.21s | valid accuracy    0.769 
-----------------------------------------------------------


100%|██████████| 142/142 [00:17<00:00,  8.24it/s]
100%|██████████| 36/36 [00:04<00:00,  7.97it/s]


-----------------------------------------------------------
| end of epoch  13 | time: 21.75s | valid accuracy    0.770 
-----------------------------------------------------------


100%|██████████| 142/142 [00:17<00:00,  8.26it/s]
100%|██████████| 36/36 [00:04<00:00,  8.23it/s]


-----------------------------------------------------------
| end of epoch  14 | time: 21.57s | valid accuracy    0.775 
-----------------------------------------------------------


100%|██████████| 142/142 [00:17<00:00,  8.20it/s]
100%|██████████| 36/36 [00:04<00:00,  8.18it/s]


-----------------------------------------------------------
| end of epoch  15 | time: 21.72s | valid accuracy    0.767 
-----------------------------------------------------------


100%|██████████| 142/142 [00:17<00:00,  8.22it/s]
100%|██████████| 36/36 [00:04<00:00,  8.14it/s]


-----------------------------------------------------------
| end of epoch  16 | time: 21.70s | valid accuracy    0.771 
-----------------------------------------------------------


100%|██████████| 142/142 [00:17<00:00,  8.16it/s]
100%|██████████| 36/36 [00:04<00:00,  8.24it/s]


-----------------------------------------------------------
| end of epoch  17 | time: 21.78s | valid accuracy    0.771 
-----------------------------------------------------------


100%|██████████| 142/142 [00:17<00:00,  8.22it/s]
100%|██████████| 36/36 [00:04<00:00,  8.15it/s]


-----------------------------------------------------------
| end of epoch  18 | time: 21.69s | valid accuracy    0.771 
-----------------------------------------------------------


100%|██████████| 142/142 [00:17<00:00,  8.23it/s]
100%|██████████| 36/36 [00:04<00:00,  8.21it/s]


-----------------------------------------------------------
| end of epoch  19 | time: 21.64s | valid accuracy    0.771 
-----------------------------------------------------------


100%|██████████| 142/142 [00:17<00:00,  8.25it/s]
100%|██████████| 36/36 [00:04<00:00,  7.98it/s]

-----------------------------------------------------------
| end of epoch  20 | time: 21.72s | valid accuracy    0.771 
-----------------------------------------------------------





In [8]:
# Final check: score the model, as left by the training loop, on the test-split loader.
print('Checking the results of test dataset.')
accu_test = evaluate(test_dataloader)
print('test accuracy {:8.3f}'.format(accu_test))

Checking the results of test dataset.


100%|██████████| 118/118 [00:13<00:00,  8.56it/s]

test accuracy    0.675



