In [1]:
from torchtext.datasets import AG_NEWS


## Token and Vocab

In [2]:

# Have to make the dataset iterable
train_iter = iter(AG_NEWS(split = 'train'))

# Set the tokenizer and the vocab
from torchtext.data.utils import get_tokenizer
tokenizer = get_tokenizer('basic_english')

# set the vocab
from torchtext.vocab import build_vocab_from_iterator
vocab = build_vocab_from_iterator([tokenizer(sample) for _,sample in train_iter],specials = ["<unk>"])
vocab.set_default_index(vocab["<unk>"])

print(f"Length of the vocabulary: {len(vocab)}")
print(f"Some of the tokenized words: {list(vocab.get_stoi().keys())[:10]}")

Length of the vocabulary: 95811
Some of the tokenized words: ['zyprexa', 'zwiki', 'zurab', 'zuhua', 'zubrin', 'zovko', 'zotinca', 'zos', 'zoology', 'zoner']


## Create the dataset


In [3]:
train_iter, test_iter = AG_NEWS()

# Creating train and validation iterable dataset
train_iter,test_iter = AG_NEWS()

# Converting the iterable dataset into map style dataset
from torchtext.data.functional import to_map_style_dataset
train_dataset = to_map_style_dataset(train_iter)
test_dataset = to_map_style_dataset(test_iter)

# Splitting the train data into train and validation dataset, Where 95% contains train data and rest of them contain validation data
train_size = int(len(train_dataset) * 0.95)

# for randomly splitting we use the function random_split
from torch.utils.data.dataset import random_split
train_data,validation_data = random_split(train_dataset,[train_size,len(train_dataset) - train_size])

In [4]:
len(train_data),len(validation_data)

(114000, 6000)

In [5]:
label,text = next(iter(train_data))

In [6]:
label,text

(2,
 'Petrova upends Henin-Hardenne Justine Henin-Hardenne looked up at the scoreboard that showed her opponent, Russian Nadia Petrova, leading by a set and a break. Then ')

## Create DataLoader

In [7]:
# text_pipeline --> take the data and convert it into token and indices
def text_pipeline(x):
    return vocab(tokenizer(x))



# label_pipeline --> create the pipeline qualified for torch input
def label_pipeline(x):
    return int(x)-1

### Collate function

In [8]:
## Device agnostic code
import torch
device = 'mps' if torch.backends.mps.is_available() else 'cpu'
device

'mps'

In [17]:
def collate_batch(batch):

    label_list,text_list,offsets_list = [],[],[0]

    for label,text in batch:
        label_list.append(label_pipeline(label))
        processed_text = torch.tensor(text_pipeline(text),dtype = torch.int64)
        text_list.append(processed_text) # list of token_indices
        offsets_list.append(processed_text.size(0))


    label_list = torch.tensor(label_list,dtype = torch.int64)
    offsets = torch.tensor(offsets_list[:-1]).cumsum(dim = 0)
    text_list = torch.cat(text_list)


    return label_list.to(device),text_list.to(device),offsets.to(device)
        

In [18]:
## create train/validation/test dataloader

BATCH_SIZE = 64

from torch.utils.data import DataLoader
train_dataloader = DataLoader(dataset = train_data,
                              batch_size=BATCH_SIZE,
                              shuffle = True,
                              collate_fn=collate_batch)

validation_dataloader = DataLoader(dataset = validation_data,
                              batch_size=BATCH_SIZE,shuffle = True,
                              collate_fn=collate_batch)
test_dataloader = DataLoader(dataset = test_dataset,
                              batch_size=BATCH_SIZE,shuffle = False,
                              collate_fn=collate_batch)


In [19]:
label,text,offsets =  next(iter(validation_dataloader))
label,text,offsets

(tensor([1, 0, 2, 2, 0, 3, 3, 2, 2, 1, 2, 0, 0, 2, 1, 3, 3, 3, 0, 1, 0, 3, 2, 1,
         3, 3, 2, 3, 1, 0, 0, 2, 1, 2, 3, 1, 1, 1, 2, 2, 0, 2, 1, 3, 2, 1, 1, 0,
         2, 1, 2, 0, 0, 2, 2, 1, 3, 0, 0, 1, 0, 3, 3, 0], device='mps:0'),
 tensor([1672, 1677,   12,  ...,  515,  299,    1], device='mps:0'),
 tensor([   0,   45,   80,  122,  154,  187,  238,  289,  328,  376,  395,  443,
          484,  529,  563,  603,  655,  687,  729,  786,  834,  882,  919,  967,
         1003, 1048, 1102, 1146, 1181, 1224, 1261, 1296, 1343, 1386, 1424, 1467,
         1512, 1551, 1597, 1678, 1724, 1777, 1832, 1858, 1924, 2037, 2080, 2129,
         2169, 2214, 2257, 2293, 2321, 2350, 2380, 2421, 2465, 2633, 2679, 2737,
         2777, 2812, 2858, 2906], device='mps:0'))

In [21]:
label.shape[0]

64

## Create the MODEL