In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.data.dataset import random_split
from torchtext.data.functional import to_map_style_dataset
from torchtext.data.utils import get_tokenizer, ngrams_iterator 
from torchtext.datasets import DATASETS
from torchtext.vocab import build_vocab_from_iterator, FastText
from torch.nn.utils.rnn import pad_sequence
from tqdm import tqdm
import argparse
import logging
import time

In [2]:
SELECTED_DATASET = "AG_NEWS"
DATASET_DIR = "data"
DEVICE_TYPE = "cpu"
EMBEDDING_DIM = 300
LEARNING_RATE = 4.0
BATCH_SIZE = 16
EPOCHS = 5
PAD_VALUE = 0
PAD_IDX = PAD_VALUE

In [3]:
tokenizer = get_tokenizer("basic_english")

In [4]:
def yield_tokens(data_iter):
    for _, text in data_iter:
        yield tokenizer(text)

In [5]:
train_data_iter = DATASETS[SELECTED_DATASET](root=DATASET_DIR, split="train")
vocab = build_vocab_from_iterator(yield_tokens(train_data_iter), specials=('<pad>', '<unk>'))
vocab.set_default_index(vocab['<unk>'])

In [12]:
# Get Embeddings
FAST_TEXT = FastText("simple")

.vector_cache/wiki.simple.vec: 293MB [00:04, 63.2MB/s]                              
  0%|          | 0/111051 [00:00<?, ?it/s]Skipping token b'111051' with 1-dimensional vector [b'300']; likely a header
100%|██████████| 111051/111051 [00:09<00:00, 11549.49it/s]


In [20]:
def text_pipeline(text):
    return vocab(tokenizer(text))

def label_pipeline(label):
    return int(label) - 1

In [21]:
def collate_batch(batch):
    labels, texts = [], []
    for (label, text) in batch:
        labels.append(label_pipeline(label))
        processed_text = torch.tensor(text_pipeline(text), dtype=torch.int64)
        texts.append(processed_text.clone().detach())
    
    labels = torch.tensor(labels, dtype=torch.int64)
    texts = pad_sequence(texts, batch_first=True)
            
    return labels.to(DEVICE_TYPE), texts.to(DEVICE_TYPE)

In [7]:
train_iter = DATASETS[SELECTED_DATASET](root=DATASET_DIR, split="train")
num_classes = len(set([label for (label, _) in train_iter]))

In [13]:
# Define models

class CNN1dClassifier(nn.Module):
    def __init__(self, vocab_size, num_classes, embed_dim=300, pretrained=True, fine_tune=True):
        super(CNN1dClassifier, self).__init__()
        
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=PAD_IDX)
        
        if pretrained:
            self.embedding.weight.requires_grad = False
            for i in range(vocab_size):
                token = vocab.lookup_token(i)
                                
                self.embedding.weight[i, :] = FAST_TEXT.get_vecs_by_tokens(
                    token, 
                    lower_case_backup=True
                )
            self.embedding.weight.requires_grad = True
        else:
            self.init_weights()
                
        if not fine_tune:
            self.embedding.weight.requires_grad = False
        
        self.conv2 = nn.Conv1d(embed_dim, 1, 2)
        self.conv3 = nn.Conv1d(embed_dim, 1, 3)
        self.conv4 = nn.Conv1d(embed_dim, 1, 4)
        
        self.fc = nn.Linear(3, num_classes)
        self.dropout = nn.Dropout(0.3)
        
        self.debug = True
        
    def init_weights(self):
        initrange = 0.5
        self.embedding.weight.data.uniform_(-initrange, initrange)
        self.fc.weight.data.uniform_(-initrange, initrange)
        self.fc.bias.data.zero_()

    def forward(self, text):
        embedded = self.embedding(text)
        
        if self.debug:
            print('embedding', embedded.shape)
        
        embedded = embedded.transpose(1, 2)
        
        conv2_out = nn.ReLU()(self.conv2(embedded))
        if self.debug:
            print('conv2', conv2_out.shape)
        
        conv3_out = nn.ReLU()(self.conv3(embedded))
        if self.debug:
            print('conv3', conv3_out.shape)
        
        conv4_out = nn.ReLU()(self.conv4(embedded))
        if self.debug:
            print('conv4', conv4_out.shape)
        
        conv2_out = nn.MaxPool1d(conv2_out.size(-1))(conv2_out)
        conv3_out = nn.MaxPool1d(conv3_out.size(-1))(conv3_out)
        conv4_out = nn.MaxPool1d(conv4_out.size(-1))(conv4_out)
        if self.debug:
            print('conv2 after max', conv2_out.shape)
        
        conv_concat = self.dropout(
            torch.cat((conv2_out.squeeze(1), conv3_out.squeeze(1), conv4_out.squeeze(1)), -1)
        )
        if self.debug:
            print('conv concat', conv_concat.shape)
                        
        out = self.fc(conv_concat)
        
        self.debug = False
        
        return out
    
class CNN2dClassifier(nn.Module):
    def __init__(self, vocab_size, num_classes, embed_dim=300, pretrained=True, fine_tune=True):
        super(CNN2dClassifier, self).__init__()
        
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=PAD_IDX)
        
        if pretrained:
            self.embedding.weight.requires_grad = False
            for i in range(vocab_size):
                tokens = vocab.lookup_token(i)
                self.embedding.weight[i, :] = FAST_TEXT.get_vecs_by_tokens(tokenizer(tokens), lower_case_backup=True)
            self.embedding.weight.requires_grad = True
        else:
            self.init_weights()
                
        if not fine_tune:
            self.embedding.weight.requires_grad = False
        
        self.convs = [
            nn.Conv2d(1, 1, (embed_dim, 2)),
            nn.Conv2d(1, 1, (embed_dim, 3)),
            nn.Conv2d(1, 1, (embed_dim, 4))
        ]
        
        self.fc = nn.Linear(3, num_classes)
        self.dropout = nn.Dropout(0.3)
        
        self.debug = True
        
    def init_weights(self):
        initrange = 0.5
        self.embedding.weight.data.uniform_(-initrange, initrange)
        self.fc.weight.data.uniform_(-initrange, initrange)
        self.fc.bias.data.zero_()
    
    def forward(self, text):
        embedded = self.embedding(text)
        
        if self.debug:
            print('embedded ', embedded.shape)
        
        embedded = embedded.transpose(1, 2)
        
        if self.debug:
            print('embedded ', embedded.shape)
                    
        convs = [
            nn.ReLU()(conv(embedded.unsqueeze(1))) for conv in self.convs
        ]
        
        if self.debug:
            print('conv ', [c.shape for c in convs])
        
        pooled = [
            nn.MaxPool2d((1, conv.size(-1)))(conv).squeeze() for conv in convs
        ]
        
        if self.debug:
            print('pooled ', [c.shape for c in pooled])
        
        pooled_stack = self.dropout(torch.vstack(pooled).t())
        
        if self.debug:
            print('pooled_stack ', pooled_stack.shape)
            
                        
        out = self.fc(pooled_stack)
        
        self.debug = False
        
        return out

### Set up the model

In [14]:
criterion = torch.nn.CrossEntropyLoss().to(DEVICE_TYPE)
model_1d = CNN1dClassifier(len(vocab), num_classes).to(DEVICE_TYPE)
model_2d = CNN2dClassifier(len(vocab), num_classes).to(DEVICE_TYPE)

In [15]:
model = model_1d

optimizer = torch.optim.SGD(model.parameters(), lr=LEARNING_RATE)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.1)

In [16]:
train_iter, test_iter = DATASETS[SELECTED_DATASET]()
train_dataset = to_map_style_dataset(train_iter)
test_dataset = to_map_style_dataset(test_iter)

num_train = int(len(train_dataset) * 0.95)
split_train, split_valid = random_split(train_dataset, [num_train, len(train_dataset) - num_train])

train_dataloader = DataLoader(split_train, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)
valid_dataloader = DataLoader(split_valid, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)

In [17]:
def train(dataloader, model, optimizer, criterion, epoch):
    model.train()
    total_acc, total_count = 0, 0
    log_interval = 1000

    for idx, (label, text) in tqdm(enumerate(dataloader), total=len(dataloader), mininterval=3):
        optimizer.zero_grad()
        predicted_label = model(text)
                
        loss = criterion(input=predicted_label, target=label)
        
        loss.backward()
        
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)
        
        optimizer.step()
        total_acc += (predicted_label.argmax(1) == label).sum().item()
        total_count += label.size(0)
        if idx % log_interval == 0 and idx > 0:
            print(
                "| epoch {:3d} | {:5d}/{:5d} batches "
                "| accuracy {:8.3f}".format(epoch, idx, len(dataloader), total_acc / total_count)
            )
            total_acc, total_count = 0, 0

In [18]:
def evaluate(dataloader, model):
    model.eval()
    total_acc, total_count = 0, 0

    with torch.no_grad():
        for idx, (label, text) in enumerate(dataloader):
            predicted_label = model(text)
            total_acc += (predicted_label.argmax(1) == label).sum().item()
            total_count += label.size(0)
    return total_acc / total_count

In [22]:
for epoch in range(1, EPOCHS + 1):
    epoch_start_time = time.time()
    train(train_dataloader, model, optimizer, criterion, epoch)
    accu_val = evaluate(valid_dataloader, model)
    scheduler.step()
    print("-" * 59)
    print(
        "| end of epoch {:3d} | time: {:5.2f}s | "
        "valid accuracy {:8.3f} ".format(epoch, time.time() - epoch_start_time, accu_val)
    )
    print("-" * 59)

print("Checking the results of test dataset.")
accu_test = evaluate(test_dataloader, model)
print("test accuracy {:8.3f}".format(accu_test))

  0%|          | 0/7125 [00:00<?, ?it/s]

embedding torch.Size([16, 66, 300])
conv2 torch.Size([16, 1, 65])
conv3 torch.Size([16, 1, 64])
conv4 torch.Size([16, 1, 63])
conv2 after max torch.Size([16, 1, 1])
conv concat torch.Size([16, 3])


 13%|█▎        | 926/7125 [00:21<02:19, 44.31it/s]

| epoch   1 |  1000/ 7125 batches | accuracy    0.537


 27%|██▋       | 1930/7125 [00:42<01:48, 47.84it/s]

| epoch   1 |  2000/ 7125 batches | accuracy    0.560


 41%|████▏     | 2954/7125 [01:03<01:26, 48.31it/s]

| epoch   1 |  3000/ 7125 batches | accuracy    0.534


 56%|█████▌    | 4001/7125 [01:25<01:03, 49.06it/s]

| epoch   1 |  4000/ 7125 batches | accuracy    0.552


 70%|██████▉   | 4984/7125 [01:47<00:49, 43.43it/s]

| epoch   1 |  5000/ 7125 batches | accuracy    0.555


 84%|████████▎ | 5954/7125 [02:11<00:27, 43.33it/s]

| epoch   1 |  6000/ 7125 batches | accuracy    0.556


 97%|█████████▋| 6923/7125 [02:38<00:05, 36.40it/s]

| epoch   1 |  7000/ 7125 batches | accuracy    0.555


100%|██████████| 7125/7125 [02:43<00:00, 43.68it/s]


-----------------------------------------------------------
| end of epoch   1 | time: 164.90s | valid accuracy    0.741 
-----------------------------------------------------------


 14%|█▍        | 1006/7125 [00:32<03:14, 31.51it/s]

| epoch   2 |  1000/ 7125 batches | accuracy    0.606


 27%|██▋       | 1917/7125 [01:04<02:49, 30.80it/s]

| epoch   2 |  2000/ 7125 batches | accuracy    0.618


 41%|████▏     | 2951/7125 [01:39<02:17, 30.26it/s]

| epoch   2 |  3000/ 7125 batches | accuracy    0.637


 56%|█████▌    | 3975/7125 [02:15<01:42, 30.72it/s]

| epoch   2 |  4000/ 7125 batches | accuracy    0.635


 70%|██████▉   | 4959/7125 [02:44<01:08, 31.74it/s]

| epoch   2 |  5000/ 7125 batches | accuracy    0.636


 84%|████████▎ | 5967/7125 [03:16<00:33, 34.16it/s]

| epoch   2 |  6000/ 7125 batches | accuracy    0.632


 98%|█████████▊| 6963/7125 [03:44<00:04, 35.56it/s]

| epoch   2 |  7000/ 7125 batches | accuracy    0.640


100%|██████████| 7125/7125 [03:49<00:00, 31.06it/s]


-----------------------------------------------------------
| end of epoch   2 | time: 231.30s | valid accuracy    0.802 
-----------------------------------------------------------


 13%|█▎        | 937/7125 [00:27<02:56, 35.15it/s]

| epoch   3 |  1000/ 7125 batches | accuracy    0.647


 28%|██▊       | 1980/7125 [00:56<02:22, 36.06it/s]

| epoch   3 |  2000/ 7125 batches | accuracy    0.648


 42%|████▏     | 2988/7125 [01:27<02:27, 27.96it/s]

| epoch   3 |  3000/ 7125 batches | accuracy    0.645


 56%|█████▌    | 3998/7125 [01:55<01:25, 36.70it/s]

| epoch   3 |  4000/ 7125 batches | accuracy    0.646


 69%|██████▉   | 4900/7125 [02:20<01:02, 35.78it/s]

| epoch   3 |  5000/ 7125 batches | accuracy    0.650


 83%|████████▎ | 5908/7125 [02:48<00:32, 37.32it/s]

| epoch   3 |  6000/ 7125 batches | accuracy    0.655


 98%|█████████▊| 6991/7125 [03:16<00:03, 37.99it/s]

| epoch   3 |  7000/ 7125 batches | accuracy    0.648


100%|██████████| 7125/7125 [03:19<00:00, 35.65it/s]


-----------------------------------------------------------
| end of epoch   3 | time: 201.31s | valid accuracy    0.809 
-----------------------------------------------------------


 13%|█▎        | 946/7125 [00:27<02:46, 37.02it/s]

| epoch   4 |  1000/ 7125 batches | accuracy    0.650


 27%|██▋       | 1954/7125 [00:55<02:23, 36.15it/s]

| epoch   4 |  2000/ 7125 batches | accuracy    0.655


 41%|████      | 2930/7125 [01:23<02:05, 33.52it/s]

| epoch   4 |  3000/ 7125 batches | accuracy    0.648


 55%|█████▌    | 3931/7125 [01:52<01:34, 33.72it/s]

| epoch   4 |  4000/ 7125 batches | accuracy    0.655


 69%|██████▊   | 4884/7125 [02:17<01:03, 35.04it/s]

| epoch   4 |  5000/ 7125 batches | accuracy    0.646


 84%|████████▎ | 5963/7125 [02:47<00:31, 37.02it/s]

| epoch   4 |  6000/ 7125 batches | accuracy    0.648


 97%|█████████▋| 6881/7125 [03:15<00:08, 28.75it/s]

| epoch   4 |  7000/ 7125 batches | accuracy    0.647


100%|██████████| 7125/7125 [03:21<00:00, 35.32it/s]


-----------------------------------------------------------
| end of epoch   4 | time: 203.34s | valid accuracy    0.808 
-----------------------------------------------------------


 14%|█▎        | 969/7125 [00:28<02:57, 34.65it/s]

| epoch   5 |  1000/ 7125 batches | accuracy    0.653


 27%|██▋       | 1953/7125 [00:55<02:41, 32.08it/s]

| epoch   5 |  2000/ 7125 batches | accuracy    0.645


 42%|████▏     | 2980/7125 [01:23<01:54, 36.33it/s]

| epoch   5 |  3000/ 7125 batches | accuracy    0.648


 55%|█████▌    | 3949/7125 [01:51<01:27, 36.50it/s]

| epoch   5 |  4000/ 7125 batches | accuracy    0.648


 70%|██████▉   | 4956/7125 [02:21<01:09, 31.00it/s]

| epoch   5 |  5000/ 7125 batches | accuracy    0.653


 84%|████████▍ | 6001/7125 [02:53<00:35, 31.62it/s]

| epoch   5 |  6000/ 7125 batches | accuracy    0.649


 97%|█████████▋| 6911/7125 [03:21<00:06, 32.34it/s]

| epoch   5 |  7000/ 7125 batches | accuracy    0.653


100%|██████████| 7125/7125 [03:27<00:00, 34.26it/s]


-----------------------------------------------------------
| end of epoch   5 | time: 209.52s | valid accuracy    0.809 
-----------------------------------------------------------
Checking the results of test dataset.
test accuracy    0.812
