In [1]:
import torch 
import numpy as np
import torchdata
from torchtext.datasets import AG_NEWS

In [2]:
# Peek at the raw training split: each item is a (label, text) pair, as the output below shows.
train_iter = iter(AG_NEWS(split='train'))

In [3]:
# Pull the first (label, text) sample; note this advances train_iter.
next(train_iter)

(3,
 "Wall St. Bears Claw Back Into the Black (Reuters) Reuters - Short-sellers, Wall Street's dwindling\\band of ultra-cynics, are seeing green again.")

## Prepare data processing pipelines

In [4]:
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator

In [5]:
# Tokenizer shared by yield_tokens and text_pipeline below.
tokenizer = get_tokenizer('basic_english')
# Re-create the split, because the earlier peek iterator was already advanced by next().
# NOTE(review): train_iter is now the dataset object itself (no iter()), and it is iterated
# twice: once by build_vocab_from_iterator and later by the DataLoader. It must therefore
# support re-iteration; confirm this for the torchtext/torchdata version in use.
train_iter = AG_NEWS(split='train')

def yield_tokens(data_iter):
    """Lazily yield the token list for every text in ``data_iter``.

    ``data_iter`` must produce ``(label, text)`` pairs; the label is ignored.
    Nothing is consumed until the returned generator is iterated.
    """
    yield from (tokenizer(text) for _label, text in data_iter)

In [9]:
# Build the vocabulary from every training text; '<unk>' is registered as a special token
# to stand in for out-of-vocabulary words.
vocab = build_vocab_from_iterator(yield_tokens(train_iter), specials=['<unk>'])
# Unseen tokens resolve to the '<unk>' index (they come back as 0 in the outputs below).
vocab.set_default_index(vocab['<unk>'])

In [12]:
# Look up ids for common words plus two rare ones; the last two come back as 0,
# presumably the '<unk>' default index.
vocab(['man','i','do','not','know','why','promethazine','acetaldehyde'])

[335, 282, 423, 62, 1199, 1164, 0, 0]

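- The text pipeline tokenizes a text string and converts the tokens into a list of integers using the vocabulary's lookup table (unknown words map to the `<unk>` index).
- The label pipeline converts the label into an integer and subtracts 1 (e.g. label `3` becomes `2`), giving zero-based class indices.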

In [13]:
def text_pipeline(x):
    """Tokenize the raw string ``x`` and map each token to its vocabulary id."""
    tokens = tokenizer(x)
    return vocab(tokens)


def label_pipeline(x):
    """Convert a label (int or numeric string) to ``int`` and shift it down by one.

    Assumes raw labels start at 1, so the result is a zero-based class index.
    """
    return int(x) - 1

In [14]:
# Same words as the direct vocab(...) lookup above, so the ids should match it exactly.
text_pipeline('man i do not know why promethazine acetaldehyde')

[335, 282, 423, 62, 1199, 1164, 0, 0]

In [17]:
# Shows only the shift by one. NOTE(review): '12' is presumably outside the real
# label range (raw labels look 1-based, e.g. 3 in the sample above); confirm.
label_pipeline('12')

11

## Generate data batch and iterator

In [23]:
from torch.utils.data import DataLoader
# Use the GPU when one is available, otherwise the CPU; collate_batch moves every batch here.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [24]:
def collate_batch(batch):
    """Collate a batch of ``(label, text)`` pairs into three tensors.

    Every text is mapped to token ids with ``text_pipeline``, and all sequences
    are concatenated into one flat 1-D tensor. ``offsets[i]`` is the index in
    that flat tensor where sample ``i`` starts. This flat-plus-offsets layout is
    presumably meant for an EmbeddingBag-style model; confirm against the model.

    Args:
        batch: iterable of ``(label, text)`` pairs as produced by the dataset.

    Returns:
        ``(labels, text, offsets)``: int64 tensors moved to ``device``.
        ``labels`` holds one ``label_pipeline`` result per sample.
    """
    label_list, text_list, offsets = [], [], [0]
    for _label, _text in batch:
        label_list.append(label_pipeline(_label))
        processed_text = torch.tensor(text_pipeline(_text), dtype=torch.int64)
        text_list.append(processed_text)
        offsets.append(processed_text.size(0))
        # (The per-sample debug print was removed: it ran once per sample on every
        # batch and flooded the notebook output.)
    labels = torch.tensor(label_list, dtype=torch.int64)
    # offsets is [0, len_0, len_1, ...]. Dropping the last length and taking the
    # cumulative sum gives each sample's start index. The explicit int64 keeps the
    # dtype fixed even when the list is empty (torch.tensor([]) defaults to float).
    offsets = torch.tensor(offsets[:-1], dtype=torch.int64).cumsum(dim=0)
    text = torch.cat(text_list)
    return labels.to(device), text.to(device), offsets.to(device)

In [25]:
# Batches of 8 (label, text) samples in dataset order (no shuffling), each turned into
# (labels, text, offsets) tensors by collate_batch.
dataloader = DataLoader(train_iter, batch_size=8, shuffle=False, collate_fn=collate_batch)