In [2]:
!pip install portalocker
!pip install torchmetrics



In [3]:
import argparse
import logging
import time

import torch
from torch.utils.data import DataLoader
from torch.utils.data.dataset import random_split
from torchtext.data.functional import to_map_style_dataset
from torchtext.data.utils import get_tokenizer, ngrams_iterator
from torchtext.datasets import DATASETS
from torchtext.prototype.transforms import load_sp_model, PRETRAINED_SP_MODEL, SentencePieceTokenizer
from torchtext.utils import download_from_url
from torchtext.vocab import build_vocab_from_iterator
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence
import torch.nn.functional as F
from torchtext.vocab import GloVe
from tqdm import tqdm

torch.autograd.set_detect_anomaly(True)


### Information
- torchtext repo: https://github.com/pytorch/text/tree/main/torchtext
- torchtext documentation: https://pytorch.org/text/stable/index.html

### Constants

In [4]:
DATASET = "AG_NEWS"
DATA_DIR = ".data"
DEVICE = "cuda"
EMBED_DIM = 300
LR = 4.0 ## will modify when training RCNN
BATCH_SIZE = 16
NUM_EPOCHS = 5
PADDING_VALUE = 0
PADDING_IDX = PADDING_VALUE

### Get the tokenizer
- Use the WordLevel tokenizer.


In [90]:
basic_english_tokenizer = get_tokenizer("basic_english")

In [91]:
basic_english_tokenizer("This is some text ...")

['this', 'is', 'some', 'text', '.', '.', '.']

In [92]:
TOKENIZER = basic_english_tokenizer

### Get the data and get the vocabulary

In [8]:
def yield_tokens(data_iter):
    for _, text in data_iter:
        yield TOKENIZER(text)

In [9]:
train_iter = DATASETS[DATASET](root=DATA_DIR, split="train")
VOCAB = build_vocab_from_iterator(yield_tokens(train_iter), specials=('<pad>', '<unk>'))

# Make the default index the same as that of the unk_token
VOCAB.set_default_index(VOCAB['<unk>'])

### Get GloVe embeddings ... This will be slow ...

In [11]:
GLOVE = GloVe()

In [12]:
len(GLOVE), GLOVE.vectors.shape

(2196017, torch.Size([2196017, 300]))

### Helper functions

In [13]:
def text_pipeline(text):
    return VOCAB(TOKENIZER(text))

def label_pipeline(label):
    return int(label) - 1

Nice link on collate_fn and DataLoader in PyTorch: https://python.plainenglish.io/understanding-collate-fn-in-pytorch-f9d1742647d3

In [14]:
## -- left shift labels since indices start from 0, convert text to int64, and convert the processed labels and text(int)s to torch tensor, then move them to DEVICE

## -- the label and texts are processed separately, thus we don't need to zip them as (label, text) tuples.
## -- We may just loop through all the labels to do the left shifting, and loop through all the text to apply text_pipeline()
def collate_batch(batch):
    label_list, text_list = [], []
    for (_label, _text) in batch:
        # Get the label from {1, 2, 3, 4} to {0, 1, 2, 3}
        label_list.append(label_pipeline(_label))

        # Return a list of ints.
        processed_text = torch.tensor(text_pipeline(_text), dtype=torch.int64)
        text_list.append(processed_text.clone().detach())

    label_list = torch.tensor(label_list, dtype=torch.int64)
    text_list = pad_sequence(text_list, batch_first=True)

    return label_list.to(DEVICE), text_list.to(DEVICE)

### Get the data

In [15]:
train_iter = DATASETS[DATASET](root=DATA_DIR, split="train")
num_class = len(set([label for (label, _) in train_iter]))
print(f"The number of classes is {num_class} ...")

The number of classes is 4 ...


### Set up the model

Good reference on this type of model
- Recurrent CNN: https://ojs.aaai.org/index.php/AAAI/article/view/9513/9372

In [111]:
class CNN1dTextClassificationModel(nn.Module):
    def __init__(
        self,
        vocab_size,
        num_class,
        embed_dim = 300,
        use_pretrained = True,
        fine_tune_embeddings = True
    ):

        super(CNN1dTextClassificationModel, self).__init__()

        # Set to embeddings layer of vocab_size and embed_dim vector dimension
        # Set the PADDING_IDX appropriately
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=PADDING_IDX)

        if use_pretrained:
            # Set the embeddings to not requiring gradients since we'll try and modify
            self.embedding.weight.requires_grad = False
            for i in range(vocab_size):
                # Get the token for the index i
                # token = VOCAB.get_itos()[i]
                token = VOCAB.lookup_token(i) ## so much faster than get_itos
                # Modify the embedding for index i by the embedding for that token
                # Do this only if token is in the stoi dictionary for GLOVE
                if token in GLOVE.stoi:
                    self.embedding.weight[i, :] = GLOVE[token]
            # Reset to True the weights
            self.embedding.weight.requires_grad = True
        else:
            # Otherwise, initialize the weights
            self.init_weights()

        # Turn off gradients
        if not fine_tune_embeddings:
            self.embedding.weight.requires_grad = False

        # Define 3 Conv1d layers each having 1 filter and kernel sizes 2, 3 and 4
        self.cnn2 = nn.Conv1d(in_channels=embed_dim, out_channels=1, kernel_size=2)
        self.cnn3 = nn.Conv1d(in_channels=embed_dim, out_channels=1, kernel_size=3)
        self.cnn4 = nn.Conv1d(in_channels=embed_dim, out_channels=1, kernel_size=4)

        some_dim = 3  ### from error message
        self.fc = nn.Linear(in_features = some_dim, out_features=num_class)

        # For drop out + ReLu, order does not matter below
        self.dropout = nn.Dropout(p=0.3)

        self.debug = False

    def init_weights(self):
        initrange = 0.5
        # Initialize the embedding weight matrix to uniform between the [-0.5, 0.5]
        self.embedding.weight.data.uniform_(-initrange, initrange)
        # Initialize the weight matrix of fc to uniform between the [-0.5, 0.5]
        self.fc.weight.data.uniform_(-initrange, initrange)
        # Initialize the bias for fc to zero
        self.fc.bias.data.zero_()

    # B = batch_size, L = sequence length, D = vector dimension
    def forward(self, text):

        # B X L X D
        # Get the embeddings for the text passed in
        embedded = self.embedding(text)

        if self.debug:
            print('embedding', embedded.shape)

        # B X D X L
        # Transpose the embedding above as needed
        embedded = embedded.transpose(1, 2)

        # B X 1 X L - 1
        # Pass through cnn2
        cnn2 = F.relu(self.cnn2(embedded))
        if self.debug:
            print('cnn2', cnn2.shape)

        # B X 1 X L - 2
        # Pass through cnn3
        cnn3 = F.relu(self.cnn3(embedded))
        if self.debug:
            print('cnn3', cnn3.shape)

        # B X 1 X L - 3
        # Pass through cnn4
        cnn4 = F.relu(self.cnn4(embedded))
        if self.debug:
            print('cnn4', cnn4.shape)

        # B X 1 in all cases
        # Apply max pooling to each of cnn2, cnn3 and cnn4
        cnn2 = cnn2.max(dim=2)[0]
        cnn3 = cnn3.max(dim=2)[0]
        cnn4 = cnn4.max(dim=2)[0]
        if self.debug:
            print('cnn2 after max', cnn2.shape)

        # B X 3
        # Concatenate and add drop out to the result
        cnn_concat = torch.cat((cnn2, cnn3, cnn4), dim=1)
        cnn_concat = self.dropout(cnn_concat)
        if self.debug:
            print('cnn concat', cnn_concat.shape)
            self.debug = False

        # Pass through an appropriate Linear layer to get the right dimensions needed
        out = self.fc(cnn_concat)

        return out

class RecurrentCNNModel(nn.Module):
    def __init__(
        self,
        vocab_size,
        num_class = 4,
        e = 300, # embedding dimension
        use_pretrained = True,
        fine_tune_embeddings = True,
        # If true, this will print out the shapes of data in the forward pass for the first batch
        # This will be set to False after the first forward pass
        debug = True
    ):

        super(RecurrentCNNModel, self).__init__()

        # Set to a nn.Embedding laer for vocab_size size and e dimension
        self.embedding = nn.Embedding(vocab_size, e)

        # Set as in the paper
        self.c = 100
        self.h = 100 # hidden??
        self.initrange = 0.5

        # Same as for the CNN model above
        if use_pretrained:
            self.embedding.weight.requires_grad = False
            for i in range(vocab_size):
                # Get the token for the index i
                token = VOCAB.lookup_token(i) ## faster
                # Modify the embedding for index i by the embedding for that token
                # Do this only if token is in the stoi dictionary for GLOVE
                if token in GLOVE.stoi:
                    self.embedding.weight[i, :] = GLOVE[token]
            # Reset to True the weights
            self.embedding.weight.requires_grad = True
        else:
            # Otherwise, initialize the weights
            self.init_weights()

        if not fine_tune_embeddings:
            # Turn off gradients for the embedding weight
            self.embedding.weight.requires_grad = False

        # Set Wl, Wr, Wsl, Wsr etc as in the paper
        # Used in (1) and (2)
        ## tried using nn.Parameter but resulted in very low acc
        # self.Wl = nn.Parameter(torch.Tensor(self.c, self.c))
        self.Wl = nn.Linear(self.c,self.c)
        # self.Wr = nn.Parameter(torch.Tensor(self.c, self.c))
        self.Wr = nn.Linear(self.c, self.c)

        # Used in (1) and (2)
        self.Wsl = nn.Linear(e,self.c)
        self.Wsr =  nn.Linear(e,self.c)

        # Used in equations (4) and (6)
        # self.W2 = nn.Parameter(torch.Tensor(self.h, 2 * self.c + e))
        self.W2 = nn.Linear(e + self.c*2,self.h)
        # self.W4 = nn.Parameter(torch.Tensor(num_class, self.h))
        self.W4 = nn.Linear(self.h,num_class)

        # For drop out + ReLu, order does not matter.
        self.dropout = nn.Dropout(p=.3)

        self.debug = False

    def init_weights(self):
      # Set some of these to uniform on [-initrange, initrange]
      # The biases can be set to 0
      initrange = 0.5
      # self.embedding.weight.data.uniform_(-initrange, initrange)
      # self.Wl.weight.data.uniform_(-initrange, initrange)
      # self.Wr.weight.data.uniform_(-initrange, initrange)
      # self.Wsl.weight.data.uniform_(-initrange, initrange)
      # self.Wsr.weight.data.uniform_(-initrange, initrange)
      ## only initialize W2 and W4 to avoid the error of 'LogSoftmaxBackward0' returned nan values in its 0th output
      self.W2.weight.data.uniform_(-initrange, initrange)
      self.W4.weight.data.uniform_(-initrange, initrange)
      # self.b2 = nn.Parameter(torch.zeros(self.h))
      # self.b4 = nn.Parameter(torch.zeros(num_class))

    # B = batch_size, L = sequence length, e = vector dimension
    def forward(self, text):
        # Text is originally B X L

        # B X L X e
        embedded = self.embedding(text)

        N, L, D = embedded.shape

        # N X L X c
        cr = torch.zeros((N, L, self.c), device=text.device)
        # print(text.device == DEVICE)
        # print(text.device, DEVICE)
        if self.debug:
            print('cr ', cr.shape)
                # N X L X c

        # N X L X c
        cl = torch.zeros_like(cr)

        # N X L X c
        # We need to clone here or we get this error:
        # https://nieznanm.medium.com/runtimeerror-one-of-the-variables-needed-for-gradient-computation-has-been-modified-by-an-inplace-85d0d207623
        for l in range(1, L):
            # print(self.Wl.in_features, self.Wl.out_features, cl[:, l-1, :].transpose(0,2).shape, self.Wsl.T.shape,embedded.T.shape)
            # torch.Size([100, 100]) torch.Size([16, 60, 100]) torch.Size([300, 100]) torch.Size([16, 60, 300])
            cl[:, l, :] = F.relu(self.Wl(cl[:, l-1, :].clone()) + self.Wsl(embedded[:, l-1, :].clone()))


        # N X L X c
        # Set cr as in the paper from equation (3)
        for l in range(L-2, -1, -1):
            cr[:, l, :] = F.relu(self.Wr(cr[:, l+1, :].clone()) + self.Wsr(embedded[:, l+1, :].clone()))


        # B X L X (2c + e)
        # Set x as in the paper; this is equation (3)
        x = torch.cat([cl,embedded,cr], dim=2)
        if self.debug:
            print('x ', x.shape)

        # B X L X h
        # Set y2 as in equation (4)
        # W2: h * (e+2c) = 100 * (300 + 200),
        # print(self.W2.shape, x.shape)
        y2 = F.tanh(self.W2(x)).squeeze()
        if self.debug:
            print('y2 ', y2.shape)

        # B X H X L
        y2 = y2.permute(0, 2, 1)
        if self.debug:
            print('y2 ', y2.shape)

        # Set y3 from y2 as in equation (5)
        y3, _ = torch.max(y2, dim=2)
        # y3 = FILL
        if self.debug:
            print('y3 ', y3.shape)

        # Set y4 from W4 and y3
        # print(self.W4.shape, y3.shape)
        # y4 = (self.W4 @ y3)
        y4 = self.W4(y3)
        if self.debug:
            print('y4 ', y4.shape)
            # Set to False after this is done
            self.debug = False

        ## if self.debug, print as below:
        # torch.Size([16, 56, 100])
        # x  torch.Size([16, 56, 500])
        # y2  torch.Size([16, 56, 100])
        # y2  torch.Size([16, 100, 56])
        # y3  torch.Size([16, 100])
        # y4  torch.Size([16, 4])
        return y4

### Set up the model

In [112]:
# If this is True, we will initialize the Embedding layer with GLOVE
USE_PRETRANED = True,

# If this is True, we will allow for gradient updates on the nn.Embedding layer
FINE_TUNE_EMBEDDINGS = True

# Set the loss appropriately
criterion = torch.nn.CrossEntropyLoss().to(DEVICE)

In [113]:
# Select the Recurrent CNN Model
model = RecurrentCNNModel(len(VOCAB)).to(DEVICE)

# Set the optimizer to SGD
LR = 1.0 ### reduce learning rate to avoid the error of 'LogSoftmaxBackward0' returned nan values in its 0th output
optimizer = torch.optim.SGD(model.parameters(), lr=LR)

# Set the scheduler to StepLR with gamma=0.1 and step_size = 1.0
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, gamma=0.1, step_size= 1)

### Set up the data

In [None]:
train_iter, test_iter = DATASETS[DATASET]()
train_dataset = to_map_style_dataset(train_iter)
test_dataset = to_map_style_dataset(test_iter)

num_train = int(len(train_dataset) * 0.95)
split_train_, split_valid_ = random_split(train_dataset, [num_train, len(train_dataset) - num_train])

train_dataloader = DataLoader(split_train_, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)
valid_dataloader = DataLoader(split_valid_, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)

### Train the model

In [None]:
def train(dataloader, model, optimizer, criterion, epoch):
    model.train()
    total_acc, total_count = 0, 0
    log_interval = 100

    for idx, (label, text) in tqdm(enumerate(dataloader)):

        optimizer.zero_grad()
        predicted_label = model(text)

        # Get the loss
        loss = criterion(predicted_label, label)

        # Do back propagation
        loss.backward()
        # if any(torch.isnan(p.grad).any() for p in model.parameters()):
        #     print("Gradient contains NaN values!")
        # Clip the gradients at 0.1
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)

        # Do an optimization step
        optimizer.step()
        total_acc += (predicted_label.argmax(1) == label).sum().item()
        total_count += len(label)
        if idx % log_interval == 0 and idx > 0:
            print(
                "| epoch {:3d} | {:5d}/{:5d} batches "
                "| accuracy {:8.3f}".format(epoch, idx, len(dataloader), total_acc / total_count)
            )
            total_acc, total_count = 0, 0

In [None]:
def evaluate(dataloader, model):
    model.eval()
    total_acc, total_count = 0, 0

    with torch.no_grad():
        for idx, (label, text) in enumerate(dataloader):
            predicted_label = model(text)
            total_acc += (predicted_label.argmax(1) == label).sum().item()
            total_count += len(label)
    return total_acc / total_count

In [25]:
# Train the RNNCNN model
model.to(DEVICE)

## can stop when acc is acceptable
for epoch in range(1, 3):
    epoch_start_time = time.time()
    train(train_dataloader, model, optimizer, criterion, epoch)
    accu_val = evaluate(valid_dataloader, model)
    scheduler.step()
    print("-" * 59)
    print(
        "| end of epoch {:3d} | time: {:5.2f}s | "
        "valid accuracy {:8.3f} ".format(epoch, time.time() - epoch_start_time, accu_val)
    )
    print("-" * 59)

print("Checking the results of test dataset.")
accu_test = evaluate(test_dataloader, model)
print("test accuracy {:8.3f}".format(accu_test))

101it [02:11,  1.30s/it]

| epoch   1 |   100/ 7125 batches | accuracy    0.699


201it [04:24,  1.35s/it]

| epoch   1 |   200/ 7125 batches | accuracy    0.837


301it [06:32,  1.25s/it]

| epoch   1 |   300/ 7125 batches | accuracy    0.868


401it [08:41,  1.18s/it]

| epoch   1 |   400/ 7125 batches | accuracy    0.877


501it [10:51,  1.27s/it]

| epoch   1 |   500/ 7125 batches | accuracy    0.869


601it [12:52,  1.14s/it]

| epoch   1 |   600/ 7125 batches | accuracy    0.871


701it [14:55,  1.29s/it]

| epoch   1 |   700/ 7125 batches | accuracy    0.852


801it [16:59,  1.07s/it]

| epoch   1 |   800/ 7125 batches | accuracy    0.864


901it [19:10,  1.34s/it]

| epoch   1 |   900/ 7125 batches | accuracy    0.884


1001it [21:18,  1.28s/it]

| epoch   1 |  1000/ 7125 batches | accuracy    0.864


1101it [23:30,  1.34s/it]

| epoch   1 |  1100/ 7125 batches | accuracy    0.864


1201it [25:39,  1.25s/it]

| epoch   1 |  1200/ 7125 batches | accuracy    0.871


1301it [27:46,  1.48s/it]

| epoch   1 |  1300/ 7125 batches | accuracy    0.874


1401it [29:57,  1.19s/it]

| epoch   1 |  1400/ 7125 batches | accuracy    0.878


1501it [32:09,  1.55s/it]

| epoch   1 |  1500/ 7125 batches | accuracy    0.883


1601it [34:21,  1.30s/it]

| epoch   1 |  1600/ 7125 batches | accuracy    0.885


1701it [36:37,  1.34s/it]

| epoch   1 |  1700/ 7125 batches | accuracy    0.886


1801it [38:45,  1.25s/it]

| epoch   1 |  1800/ 7125 batches | accuracy    0.893


1901it [40:49,  1.24s/it]

| epoch   1 |  1900/ 7125 batches | accuracy    0.882


2001it [42:52,  1.17s/it]

| epoch   1 |  2000/ 7125 batches | accuracy    0.887


2101it [45:02,  1.28s/it]

| epoch   1 |  2100/ 7125 batches | accuracy    0.888


2201it [47:07,  1.18s/it]

| epoch   1 |  2200/ 7125 batches | accuracy    0.890


2301it [49:20,  1.21s/it]

| epoch   1 |  2300/ 7125 batches | accuracy    0.886


2401it [51:35,  1.23s/it]

| epoch   1 |  2400/ 7125 batches | accuracy    0.877


2501it [53:42,  1.47s/it]

| epoch   1 |  2500/ 7125 batches | accuracy    0.898


2601it [55:53,  1.29s/it]

| epoch   1 |  2600/ 7125 batches | accuracy    0.902


2701it [58:00,  1.33s/it]

| epoch   1 |  2700/ 7125 batches | accuracy    0.879


2801it [1:00:06,  1.16s/it]

| epoch   1 |  2800/ 7125 batches | accuracy    0.886


2901it [1:02:23,  1.12s/it]

| epoch   1 |  2900/ 7125 batches | accuracy    0.873


3001it [1:04:38,  1.60s/it]

| epoch   1 |  3000/ 7125 batches | accuracy    0.874


3101it [1:06:53,  1.27s/it]

| epoch   1 |  3100/ 7125 batches | accuracy    0.882


3201it [1:09:02,  1.25s/it]

| epoch   1 |  3200/ 7125 batches | accuracy    0.892


3301it [1:11:10,  1.31s/it]

| epoch   1 |  3300/ 7125 batches | accuracy    0.896


3401it [1:13:19,  1.25s/it]

| epoch   1 |  3400/ 7125 batches | accuracy    0.874


3501it [1:15:25,  1.15s/it]

| epoch   1 |  3500/ 7125 batches | accuracy    0.897


3601it [1:17:38,  1.74s/it]

| epoch   1 |  3600/ 7125 batches | accuracy    0.882


3701it [1:19:51,  1.26s/it]

| epoch   1 |  3700/ 7125 batches | accuracy    0.886


3801it [1:22:01,  1.47s/it]

| epoch   1 |  3800/ 7125 batches | accuracy    0.892


3901it [1:24:09,  1.14s/it]

| epoch   1 |  3900/ 7125 batches | accuracy    0.892


4001it [1:26:14,  1.40s/it]

| epoch   1 |  4000/ 7125 batches | accuracy    0.892


4101it [1:28:22,  1.78s/it]

| epoch   1 |  4100/ 7125 batches | accuracy    0.896


4201it [1:30:27,  1.23s/it]

| epoch   1 |  4200/ 7125 batches | accuracy    0.892


4301it [1:32:36,  1.18s/it]

| epoch   1 |  4300/ 7125 batches | accuracy    0.876


4401it [1:34:48,  1.31s/it]

| epoch   1 |  4400/ 7125 batches | accuracy    0.890


4501it [1:36:59,  1.18s/it]

| epoch   1 |  4500/ 7125 batches | accuracy    0.892


4601it [1:39:08,  1.10s/it]

| epoch   1 |  4600/ 7125 batches | accuracy    0.899


4701it [1:41:10,  1.12s/it]

| epoch   1 |  4700/ 7125 batches | accuracy    0.889


4801it [1:43:12,  1.27s/it]

| epoch   1 |  4800/ 7125 batches | accuracy    0.899


4901it [1:45:15,  1.06s/it]

| epoch   1 |  4900/ 7125 batches | accuracy    0.900


5001it [1:47:25,  1.16s/it]

| epoch   1 |  5000/ 7125 batches | accuracy    0.881


5101it [1:49:39,  1.20s/it]

| epoch   1 |  5100/ 7125 batches | accuracy    0.891


5201it [1:51:52,  1.51s/it]

| epoch   1 |  5200/ 7125 batches | accuracy    0.904


5301it [1:54:00,  1.07s/it]

| epoch   1 |  5300/ 7125 batches | accuracy    0.891


5401it [1:56:10,  1.12s/it]

| epoch   1 |  5400/ 7125 batches | accuracy    0.878


5501it [1:58:19,  1.20s/it]

| epoch   1 |  5500/ 7125 batches | accuracy    0.912


5601it [2:00:31,  1.30s/it]

| epoch   1 |  5600/ 7125 batches | accuracy    0.912


5701it [2:02:44,  1.09s/it]

| epoch   1 |  5700/ 7125 batches | accuracy    0.900


5801it [2:04:56,  1.38s/it]

| epoch   1 |  5800/ 7125 batches | accuracy    0.899


5901it [2:07:01,  1.20s/it]

| epoch   1 |  5900/ 7125 batches | accuracy    0.900


6001it [2:09:11,  1.23s/it]

| epoch   1 |  6000/ 7125 batches | accuracy    0.909


6101it [2:11:19,  1.41s/it]

| epoch   1 |  6100/ 7125 batches | accuracy    0.884


6201it [2:13:25,  1.38s/it]

| epoch   1 |  6200/ 7125 batches | accuracy    0.887


6301it [2:15:33,  1.56s/it]

| epoch   1 |  6300/ 7125 batches | accuracy    0.896


6401it [2:17:46,  1.21s/it]

| epoch   1 |  6400/ 7125 batches | accuracy    0.908


6501it [2:19:54,  1.32s/it]

| epoch   1 |  6500/ 7125 batches | accuracy    0.903


6601it [2:22:00,  1.19s/it]

| epoch   1 |  6600/ 7125 batches | accuracy    0.884


6701it [2:24:08,  1.53s/it]

| epoch   1 |  6700/ 7125 batches | accuracy    0.894


6801it [2:26:17,  1.05s/it]

| epoch   1 |  6800/ 7125 batches | accuracy    0.914


6901it [2:28:28,  1.17s/it]

| epoch   1 |  6900/ 7125 batches | accuracy    0.897


7001it [2:30:40,  1.37s/it]

| epoch   1 |  7000/ 7125 batches | accuracy    0.907


7101it [2:32:49,  1.32s/it]

| epoch   1 |  7100/ 7125 batches | accuracy    0.911


7125it [2:33:22,  1.29s/it]


-----------------------------------------------------------
| end of epoch   1 | time: 9218.00s | valid accuracy    0.886 
-----------------------------------------------------------


101it [02:18,  1.69s/it]

| epoch   2 |   100/ 7125 batches | accuracy    0.913


201it [04:24,  1.28s/it]

| epoch   2 |   200/ 7125 batches | accuracy    0.927


301it [06:30,  1.45s/it]

| epoch   2 |   300/ 7125 batches | accuracy    0.930


401it [08:40,  1.22s/it]

| epoch   2 |   400/ 7125 batches | accuracy    0.917


501it [10:49,  1.38s/it]

| epoch   2 |   500/ 7125 batches | accuracy    0.927


601it [12:57,  1.29s/it]

| epoch   2 |   600/ 7125 batches | accuracy    0.926


701it [15:13,  1.19s/it]

| epoch   2 |   700/ 7125 batches | accuracy    0.927


801it [17:31,  1.29s/it]

| epoch   2 |   800/ 7125 batches | accuracy    0.927


901it [19:41,  1.39s/it]

| epoch   2 |   900/ 7125 batches | accuracy    0.927


1001it [21:50,  1.34s/it]

| epoch   2 |  1000/ 7125 batches | accuracy    0.931


1101it [23:58,  1.12s/it]

| epoch   2 |  1100/ 7125 batches | accuracy    0.934


1201it [26:18,  1.30s/it]

| epoch   2 |  1200/ 7125 batches | accuracy    0.934


1301it [28:30,  1.36s/it]

| epoch   2 |  1300/ 7125 batches | accuracy    0.929


1401it [30:44,  1.25s/it]

| epoch   2 |  1400/ 7125 batches | accuracy    0.932


1501it [32:56,  1.28s/it]

| epoch   2 |  1500/ 7125 batches | accuracy    0.929


1601it [35:03,  1.28s/it]

| epoch   2 |  1600/ 7125 batches | accuracy    0.934


1701it [37:11,  1.35s/it]

| epoch   2 |  1700/ 7125 batches | accuracy    0.929


1801it [39:21,  1.26s/it]

| epoch   2 |  1800/ 7125 batches | accuracy    0.931


1901it [41:27,  1.42s/it]

| epoch   2 |  1900/ 7125 batches | accuracy    0.938


2001it [43:40,  1.37s/it]

| epoch   2 |  2000/ 7125 batches | accuracy    0.921


2101it [45:51,  1.42s/it]

| epoch   2 |  2100/ 7125 batches | accuracy    0.928


2201it [48:00,  1.34s/it]

| epoch   2 |  2200/ 7125 batches | accuracy    0.940


2301it [50:09,  1.22s/it]

| epoch   2 |  2300/ 7125 batches | accuracy    0.934


2401it [52:21,  1.08s/it]

| epoch   2 |  2400/ 7125 batches | accuracy    0.922


2501it [54:35,  1.12s/it]

| epoch   2 |  2500/ 7125 batches | accuracy    0.929


2601it [56:45,  1.24s/it]

| epoch   2 |  2600/ 7125 batches | accuracy    0.937


2701it [58:53,  1.07s/it]

| epoch   2 |  2700/ 7125 batches | accuracy    0.934


2801it [1:01:03,  1.46s/it]

| epoch   2 |  2800/ 7125 batches | accuracy    0.935


2901it [1:03:18,  1.24s/it]

| epoch   2 |  2900/ 7125 batches | accuracy    0.929


3001it [1:05:33,  1.30s/it]

| epoch   2 |  3000/ 7125 batches | accuracy    0.921


3101it [1:07:43,  1.23s/it]

| epoch   2 |  3100/ 7125 batches | accuracy    0.933


3201it [1:09:50,  1.12s/it]

| epoch   2 |  3200/ 7125 batches | accuracy    0.927


3301it [1:11:59,  1.23s/it]

| epoch   2 |  3300/ 7125 batches | accuracy    0.920


3401it [1:14:10,  1.27s/it]

| epoch   2 |  3400/ 7125 batches | accuracy    0.919


3501it [1:16:26,  1.54s/it]

| epoch   2 |  3500/ 7125 batches | accuracy    0.929


3601it [1:18:34,  1.14s/it]

| epoch   2 |  3600/ 7125 batches | accuracy    0.926


3701it [1:20:46,  1.10s/it]

| epoch   2 |  3700/ 7125 batches | accuracy    0.922


3801it [1:22:53,  1.32s/it]

| epoch   2 |  3800/ 7125 batches | accuracy    0.921


3901it [1:25:04,  1.30s/it]

| epoch   2 |  3900/ 7125 batches | accuracy    0.928


4001it [1:27:14,  1.26s/it]

| epoch   2 |  4000/ 7125 batches | accuracy    0.935


4101it [1:29:24,  1.21s/it]

| epoch   2 |  4100/ 7125 batches | accuracy    0.930


4201it [1:31:30,  1.20s/it]

| epoch   2 |  4200/ 7125 batches | accuracy    0.930


4301it [1:33:42,  1.41s/it]

| epoch   2 |  4300/ 7125 batches | accuracy    0.920


4401it [1:35:54,  1.21s/it]

| epoch   2 |  4400/ 7125 batches | accuracy    0.931


4501it [1:38:01,  1.15s/it]

| epoch   2 |  4500/ 7125 batches | accuracy    0.928


4601it [1:40:01,  1.28s/it]

| epoch   2 |  4600/ 7125 batches | accuracy    0.917


4701it [1:42:01,  1.13s/it]

| epoch   2 |  4700/ 7125 batches | accuracy    0.922


4801it [1:44:09,  1.35s/it]

| epoch   2 |  4800/ 7125 batches | accuracy    0.925


4901it [1:46:22,  1.46s/it]

| epoch   2 |  4900/ 7125 batches | accuracy    0.924


5001it [1:48:35,  1.14s/it]

| epoch   2 |  5000/ 7125 batches | accuracy    0.933


5101it [1:50:40,  1.22s/it]

| epoch   2 |  5100/ 7125 batches | accuracy    0.926


5201it [1:52:50,  1.40s/it]

| epoch   2 |  5200/ 7125 batches | accuracy    0.924


5301it [1:54:54,  1.36s/it]

| epoch   2 |  5300/ 7125 batches | accuracy    0.941


5401it [1:57:03,  1.24s/it]

| epoch   2 |  5400/ 7125 batches | accuracy    0.931


5501it [1:59:09,  1.18s/it]

| epoch   2 |  5500/ 7125 batches | accuracy    0.938


5601it [2:01:20,  1.70s/it]

| epoch   2 |  5600/ 7125 batches | accuracy    0.929


5701it [2:03:24,  1.34s/it]

| epoch   2 |  5700/ 7125 batches | accuracy    0.925


5801it [2:05:38,  1.13s/it]

| epoch   2 |  5800/ 7125 batches | accuracy    0.927


5901it [2:07:42,  1.15s/it]

| epoch   2 |  5900/ 7125 batches | accuracy    0.937


6001it [2:09:52,  1.30s/it]

| epoch   2 |  6000/ 7125 batches | accuracy    0.938


6101it [2:12:06,  1.37s/it]

| epoch   2 |  6100/ 7125 batches | accuracy    0.931


6201it [2:14:19,  1.07s/it]

| epoch   2 |  6200/ 7125 batches | accuracy    0.939


6301it [2:16:35,  1.47s/it]

| epoch   2 |  6300/ 7125 batches | accuracy    0.937


6401it [2:18:46,  1.51s/it]

| epoch   2 |  6400/ 7125 batches | accuracy    0.930


6501it [2:20:55,  1.25s/it]

| epoch   2 |  6500/ 7125 batches | accuracy    0.938


6601it [2:23:04,  1.20s/it]

| epoch   2 |  6600/ 7125 batches | accuracy    0.938


6701it [2:25:11,  1.27s/it]

| epoch   2 |  6700/ 7125 batches | accuracy    0.929


6801it [2:27:16,  1.22s/it]

| epoch   2 |  6800/ 7125 batches | accuracy    0.941


6901it [2:29:20,  1.14s/it]

| epoch   2 |  6900/ 7125 batches | accuracy    0.939


7001it [2:31:36,  1.36s/it]

| epoch   2 |  7000/ 7125 batches | accuracy    0.936


7101it [2:33:43,  1.21s/it]

| epoch   2 |  7100/ 7125 batches | accuracy    0.941


7125it [2:34:12,  1.30s/it]


-----------------------------------------------------------
| end of epoch   2 | time: 9269.51s | valid accuracy    0.917 
-----------------------------------------------------------
Checking the results of test dataset.
test accuracy    0.908


### Train the model

In [24]:
# Make a Conv Text model
model = CNN1dTextClassificationModel(len(VOCAB),4)

LR = 4.0 ## change LR back to 4.0

# Set the optimizer to SGD
optimizer = torch.optim.SGD(model.parameters(), lr=LR)

# Set the scheduler to StepLR with gamma=0.1 and step_size = 1.0
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, gamma=0.1, step_size= 1)


In [26]:
# Train the Conv1d model
model.to(DEVICE)

## can stop when acc is acceptable
for epoch in range(1, 3):
    epoch_start_time = time.time()
    train(train_dataloader, model, optimizer, criterion, epoch)
    accu_val = evaluate(valid_dataloader, model)
    scheduler.step()
    print("-" * 59)
    print(
        "| end of epoch {:3d} | time: {:5.2f}s | "
        "valid accuracy {:8.3f} ".format(epoch, time.time() - epoch_start_time, accu_val)
    )
    print("-" * 59)

print("Checking the results of test dataset.")
accu_test = evaluate(test_dataloader, model)
print("test accuracy {:8.3f}".format(accu_test))

117it [00:01, 86.63it/s]

| epoch   1 |   100/ 7125 batches | accuracy    0.558


216it [00:02, 87.41it/s]

| epoch   1 |   200/ 7125 batches | accuracy    0.542


315it [00:03, 85.95it/s]

| epoch   1 |   300/ 7125 batches | accuracy    0.556


414it [00:04, 86.02it/s]

| epoch   1 |   400/ 7125 batches | accuracy    0.530


513it [00:05, 86.01it/s]

| epoch   1 |   500/ 7125 batches | accuracy    0.550


612it [00:07, 86.84it/s]

| epoch   1 |   600/ 7125 batches | accuracy    0.547


711it [00:08, 85.91it/s]

| epoch   1 |   700/ 7125 batches | accuracy    0.541


810it [00:09, 87.78it/s]

| epoch   1 |   800/ 7125 batches | accuracy    0.549


918it [00:10, 85.87it/s]

| epoch   1 |   900/ 7125 batches | accuracy    0.563


1017it [00:11, 87.29it/s]

| epoch   1 |  1000/ 7125 batches | accuracy    0.558


1116it [00:12, 85.22it/s]

| epoch   1 |  1100/ 7125 batches | accuracy    0.551


1215it [00:14, 87.55it/s]

| epoch   1 |  1200/ 7125 batches | accuracy    0.565


1314it [00:15, 86.96it/s]

| epoch   1 |  1300/ 7125 batches | accuracy    0.564


1413it [00:16, 86.71it/s]

| epoch   1 |  1400/ 7125 batches | accuracy    0.551


1512it [00:17, 82.66it/s]

| epoch   1 |  1500/ 7125 batches | accuracy    0.554


1610it [00:18, 85.85it/s]

| epoch   1 |  1600/ 7125 batches | accuracy    0.535


1718it [00:19, 87.02it/s]

| epoch   1 |  1700/ 7125 batches | accuracy    0.514


1817it [00:21, 86.87it/s]

| epoch   1 |  1800/ 7125 batches | accuracy    0.559


1916it [00:22, 87.16it/s]

| epoch   1 |  1900/ 7125 batches | accuracy    0.549


2015it [00:23, 86.93it/s]

| epoch   1 |  2000/ 7125 batches | accuracy    0.526


2114it [00:24, 87.62it/s]

| epoch   1 |  2100/ 7125 batches | accuracy    0.552


2213it [00:25, 87.55it/s]

| epoch   1 |  2200/ 7125 batches | accuracy    0.550


2312it [00:26, 86.50it/s]

| epoch   1 |  2300/ 7125 batches | accuracy    0.542


2411it [00:27, 87.07it/s]

| epoch   1 |  2400/ 7125 batches | accuracy    0.547


2510it [00:29, 86.72it/s]

| epoch   1 |  2500/ 7125 batches | accuracy    0.519


2618it [00:30, 86.50it/s]

| epoch   1 |  2600/ 7125 batches | accuracy    0.547


2717it [00:31, 86.63it/s]

| epoch   1 |  2700/ 7125 batches | accuracy    0.544


2816it [00:32, 86.67it/s]

| epoch   1 |  2800/ 7125 batches | accuracy    0.545


2915it [00:33, 87.35it/s]

| epoch   1 |  2900/ 7125 batches | accuracy    0.547


3014it [00:34, 87.17it/s]

| epoch   1 |  3000/ 7125 batches | accuracy    0.568


3113it [00:35, 86.70it/s]

| epoch   1 |  3100/ 7125 batches | accuracy    0.545


3212it [00:37, 86.57it/s]

| epoch   1 |  3200/ 7125 batches | accuracy    0.554


3311it [00:38, 86.41it/s]

| epoch   1 |  3300/ 7125 batches | accuracy    0.557


3410it [00:39, 87.37it/s]

| epoch   1 |  3400/ 7125 batches | accuracy    0.547


3518it [00:40, 86.61it/s]

| epoch   1 |  3500/ 7125 batches | accuracy    0.546


3617it [00:41, 86.24it/s]

| epoch   1 |  3600/ 7125 batches | accuracy    0.542


3716it [00:42, 86.29it/s]

| epoch   1 |  3700/ 7125 batches | accuracy    0.542


3815it [00:44, 87.51it/s]

| epoch   1 |  3800/ 7125 batches | accuracy    0.554


3914it [00:45, 86.57it/s]

| epoch   1 |  3900/ 7125 batches | accuracy    0.553


4013it [00:46, 84.54it/s]

| epoch   1 |  4000/ 7125 batches | accuracy    0.546


4112it [00:47, 86.42it/s]

| epoch   1 |  4100/ 7125 batches | accuracy    0.528


4211it [00:48, 84.62it/s]

| epoch   1 |  4200/ 7125 batches | accuracy    0.531


4310it [00:49, 85.69it/s]

| epoch   1 |  4300/ 7125 batches | accuracy    0.522


4418it [00:51, 87.25it/s]

| epoch   1 |  4400/ 7125 batches | accuracy    0.542


4517it [00:52, 88.10it/s]

| epoch   1 |  4500/ 7125 batches | accuracy    0.547


4616it [00:53, 87.37it/s]

| epoch   1 |  4600/ 7125 batches | accuracy    0.538


4715it [00:54, 86.70it/s]

| epoch   1 |  4700/ 7125 batches | accuracy    0.530


4814it [00:55, 87.47it/s]

| epoch   1 |  4800/ 7125 batches | accuracy    0.542


4913it [00:56, 87.98it/s]

| epoch   1 |  4900/ 7125 batches | accuracy    0.551


5012it [00:57, 88.38it/s]

| epoch   1 |  5000/ 7125 batches | accuracy    0.543


5112it [00:59, 87.53it/s]

| epoch   1 |  5100/ 7125 batches | accuracy    0.555


5211it [01:00, 87.87it/s]

| epoch   1 |  5200/ 7125 batches | accuracy    0.547


5310it [01:01, 86.67it/s]

| epoch   1 |  5300/ 7125 batches | accuracy    0.547


5418it [01:02, 88.25it/s]

| epoch   1 |  5400/ 7125 batches | accuracy    0.561


5517it [01:03, 87.95it/s]

| epoch   1 |  5500/ 7125 batches | accuracy    0.544


5616it [01:04, 88.62it/s]

| epoch   1 |  5600/ 7125 batches | accuracy    0.541


5715it [01:05, 88.26it/s]

| epoch   1 |  5700/ 7125 batches | accuracy    0.551


5814it [01:07, 87.01it/s]

| epoch   1 |  5800/ 7125 batches | accuracy    0.535


5913it [01:08, 87.79it/s]

| epoch   1 |  5900/ 7125 batches | accuracy    0.536


6012it [01:09, 88.02it/s]

| epoch   1 |  6000/ 7125 batches | accuracy    0.527


6111it [01:10, 88.03it/s]

| epoch   1 |  6100/ 7125 batches | accuracy    0.538


6210it [01:11, 87.48it/s]

| epoch   1 |  6200/ 7125 batches | accuracy    0.537


6318it [01:12, 88.33it/s]

| epoch   1 |  6300/ 7125 batches | accuracy    0.571


6417it [01:13, 87.55it/s]

| epoch   1 |  6400/ 7125 batches | accuracy    0.557


6516it [01:15, 87.63it/s]

| epoch   1 |  6500/ 7125 batches | accuracy    0.549


6615it [01:16, 87.34it/s]

| epoch   1 |  6600/ 7125 batches | accuracy    0.544


6714it [01:17, 87.38it/s]

| epoch   1 |  6700/ 7125 batches | accuracy    0.542


6813it [01:18, 87.33it/s]

| epoch   1 |  6800/ 7125 batches | accuracy    0.535


6912it [01:19, 86.73it/s]

| epoch   1 |  6900/ 7125 batches | accuracy    0.544


7011it [01:20, 87.86it/s]

| epoch   1 |  7000/ 7125 batches | accuracy    0.546


7110it [01:21, 87.17it/s]

| epoch   1 |  7100/ 7125 batches | accuracy    0.561


7125it [01:22, 86.81it/s]


-----------------------------------------------------------
| end of epoch   1 | time: 82.68s | valid accuracy    0.850 
-----------------------------------------------------------


117it [00:01, 87.03it/s]

| epoch   2 |   100/ 7125 batches | accuracy    0.545


216it [00:02, 86.93it/s]

| epoch   2 |   200/ 7125 batches | accuracy    0.555


315it [00:03, 86.63it/s]

| epoch   2 |   300/ 7125 batches | accuracy    0.538


414it [00:04, 87.29it/s]

| epoch   2 |   400/ 7125 batches | accuracy    0.528


513it [00:05, 86.42it/s]

| epoch   2 |   500/ 7125 batches | accuracy    0.545


612it [00:07, 87.35it/s]

| epoch   2 |   600/ 7125 batches | accuracy    0.534


711it [00:08, 84.92it/s]

| epoch   2 |   700/ 7125 batches | accuracy    0.550


810it [00:09, 86.55it/s]

| epoch   2 |   800/ 7125 batches | accuracy    0.534


918it [00:10, 87.11it/s]

| epoch   2 |   900/ 7125 batches | accuracy    0.558


1017it [00:11, 87.56it/s]

| epoch   2 |  1000/ 7125 batches | accuracy    0.554


1116it [00:12, 87.58it/s]

| epoch   2 |  1100/ 7125 batches | accuracy    0.524


1215it [00:13, 87.91it/s]

| epoch   2 |  1200/ 7125 batches | accuracy    0.546


1314it [00:15, 88.47it/s]

| epoch   2 |  1300/ 7125 batches | accuracy    0.548


1415it [00:16, 88.89it/s]

| epoch   2 |  1400/ 7125 batches | accuracy    0.534


1515it [00:17, 89.10it/s]

| epoch   2 |  1500/ 7125 batches | accuracy    0.526


1615it [00:18, 88.55it/s]

| epoch   2 |  1600/ 7125 batches | accuracy    0.551


1714it [00:19, 87.66it/s]

| epoch   2 |  1700/ 7125 batches | accuracy    0.542


1813it [00:20, 88.51it/s]

| epoch   2 |  1800/ 7125 batches | accuracy    0.562


1912it [00:21, 88.67it/s]

| epoch   2 |  1900/ 7125 batches | accuracy    0.529


2012it [00:22, 89.19it/s]

| epoch   2 |  2000/ 7125 batches | accuracy    0.554


2112it [00:24, 88.83it/s]

| epoch   2 |  2100/ 7125 batches | accuracy    0.541


2212it [00:25, 89.75it/s]

| epoch   2 |  2200/ 7125 batches | accuracy    0.554


2311it [00:26, 88.86it/s]

| epoch   2 |  2300/ 7125 batches | accuracy    0.539


2410it [00:27, 87.00it/s]

| epoch   2 |  2400/ 7125 batches | accuracy    0.548


2510it [00:28, 88.65it/s]

| epoch   2 |  2500/ 7125 batches | accuracy    0.564


2610it [00:29, 88.88it/s]

| epoch   2 |  2600/ 7125 batches | accuracy    0.551


2719it [00:30, 89.24it/s]

| epoch   2 |  2700/ 7125 batches | accuracy    0.539


2818it [00:32, 88.67it/s]

| epoch   2 |  2800/ 7125 batches | accuracy    0.569


2917it [00:33, 88.38it/s]

| epoch   2 |  2900/ 7125 batches | accuracy    0.569


3016it [00:34, 88.41it/s]

| epoch   2 |  3000/ 7125 batches | accuracy    0.557


3115it [00:35, 89.48it/s]

| epoch   2 |  3100/ 7125 batches | accuracy    0.557


3218it [00:36, 89.69it/s]

| epoch   2 |  3200/ 7125 batches | accuracy    0.539


3317it [00:37, 89.04it/s]

| epoch   2 |  3300/ 7125 batches | accuracy    0.554


3416it [00:38, 88.36it/s]

| epoch   2 |  3400/ 7125 batches | accuracy    0.563


3515it [00:39, 88.64it/s]

| epoch   2 |  3500/ 7125 batches | accuracy    0.546


3614it [00:41, 87.84it/s]

| epoch   2 |  3600/ 7125 batches | accuracy    0.528


3713it [00:42, 88.35it/s]

| epoch   2 |  3700/ 7125 batches | accuracy    0.556


3812it [00:43, 88.02it/s]

| epoch   2 |  3800/ 7125 batches | accuracy    0.535


3911it [00:44, 88.86it/s]

| epoch   2 |  3900/ 7125 batches | accuracy    0.548


4010it [00:45, 88.27it/s]

| epoch   2 |  4000/ 7125 batches | accuracy    0.568


4118it [00:46, 88.36it/s]

| epoch   2 |  4100/ 7125 batches | accuracy    0.546


4217it [00:47, 87.45it/s]

| epoch   2 |  4200/ 7125 batches | accuracy    0.560


4316it [00:49, 88.26it/s]

| epoch   2 |  4300/ 7125 batches | accuracy    0.562


4415it [00:50, 88.15it/s]

| epoch   2 |  4400/ 7125 batches | accuracy    0.555


4514it [00:51, 88.50it/s]

| epoch   2 |  4500/ 7125 batches | accuracy    0.537


4613it [00:52, 87.90it/s]

| epoch   2 |  4600/ 7125 batches | accuracy    0.535


4712it [00:53, 87.88it/s]

| epoch   2 |  4700/ 7125 batches | accuracy    0.551


4811it [00:54, 87.88it/s]

| epoch   2 |  4800/ 7125 batches | accuracy    0.546


4910it [00:55, 88.86it/s]

| epoch   2 |  4900/ 7125 batches | accuracy    0.542


5011it [00:56, 89.10it/s]

| epoch   2 |  5000/ 7125 batches | accuracy    0.553


5110it [00:58, 86.67it/s]

| epoch   2 |  5100/ 7125 batches | accuracy    0.568


5218it [00:59, 88.30it/s]

| epoch   2 |  5200/ 7125 batches | accuracy    0.541


5317it [01:00, 88.35it/s]

| epoch   2 |  5300/ 7125 batches | accuracy    0.537


5417it [01:01, 88.38it/s]

| epoch   2 |  5400/ 7125 batches | accuracy    0.556


5516it [01:02, 88.31it/s]

| epoch   2 |  5500/ 7125 batches | accuracy    0.543


5615it [01:03, 87.83it/s]

| epoch   2 |  5600/ 7125 batches | accuracy    0.573


5714it [01:04, 87.84it/s]

| epoch   2 |  5700/ 7125 batches | accuracy    0.548


5813it [01:06, 87.34it/s]

| epoch   2 |  5800/ 7125 batches | accuracy    0.551


5912it [01:07, 88.27it/s]

| epoch   2 |  5900/ 7125 batches | accuracy    0.542


6011it [01:08, 84.69it/s]

| epoch   2 |  6000/ 7125 batches | accuracy    0.569


6110it [01:09, 88.06it/s]

| epoch   2 |  6100/ 7125 batches | accuracy    0.545


6218it [01:10, 87.57it/s]

| epoch   2 |  6200/ 7125 batches | accuracy    0.533


6317it [01:11, 88.61it/s]

| epoch   2 |  6300/ 7125 batches | accuracy    0.532


6416it [01:12, 88.12it/s]

| epoch   2 |  6400/ 7125 batches | accuracy    0.527


6515it [01:13, 88.68it/s]

| epoch   2 |  6500/ 7125 batches | accuracy    0.556


6614it [01:15, 88.32it/s]

| epoch   2 |  6600/ 7125 batches | accuracy    0.553


6713it [01:16, 88.09it/s]

| epoch   2 |  6700/ 7125 batches | accuracy    0.554


6812it [01:17, 87.47it/s]

| epoch   2 |  6800/ 7125 batches | accuracy    0.558


6911it [01:18, 88.31it/s]

| epoch   2 |  6900/ 7125 batches | accuracy    0.553


7010it [01:19, 88.17it/s]

| epoch   2 |  7000/ 7125 batches | accuracy    0.542


7118it [01:20, 88.59it/s]

| epoch   2 |  7100/ 7125 batches | accuracy    0.534


7125it [01:20, 88.05it/s]


-----------------------------------------------------------
| end of epoch   2 | time: 81.52s | valid accuracy    0.851 
-----------------------------------------------------------
Checking the results of test dataset.
test accuracy    0.836



### As discussed in the paper, for text documentations, CNN only rely on fixed window size,
### and we cannot update such parameter to capture more semantic information
