In [3]:
import torch
import torchtext

In [42]:
from torchtext.datasets import text_classification
NGRAMS = 2
import os
if not os.path.isdir('./AG_NEWS'):
    os.mkdir('./AG_NEWS')
train_dataset, test_dataset = text_classification.DATASETS['AG_NEWS'](
    root='./AG_NEWS', ngrams=NGRAMS, vocab=None)
BATCH_SIZE = 16

print(train_dataset.__dict__.keys())

120000lines [00:06, 19259.70lines/s]
120000lines [00:11, 10141.52lines/s]
7600lines [00:00, 9775.71lines/s] 


dict_keys(['_data', '_labels', '_vocab'])


In [43]:
label, text = train_dataset[0]
print([train_dataset._vocab.itos[idx] for idx in text])
print(train_dataset._labels)

['wall', 'st', '.', 'bears', 'claw', 'back', 'into', 'the', 'black', '(', 'reuters', ')', 'reuters', '-', 'short-sellers', ',', 'wall', 'street', "'", 's', 'dwindling\\band', 'of', 'ultra-cynics', ',', 'are', 'seeing', 'green', 'again', '.', 'wall st', 'st .', '. bears', 'bears claw', 'claw back', 'back into', 'into the', 'the black', 'black (', '( reuters', 'reuters )', ') reuters', 'reuters -', '- short-sellers', 'short-sellers ,', ', wall', 'wall street', "street '", "' s", 's dwindling\\band', 'dwindling\\band of', 'of ultra-cynics', 'ultra-cynics ,', ', are', 'are seeing', 'seeing green', 'green again', 'again .']
{0, 1, 2, 3}


In [10]:
train_dataset.__dict__.keys()

dict_keys(['_data', '_labels', '_vocab'])

In [22]:
train_dataset._vocab.__dict__.keys()

dict_keys(['freqs', 'itos', 'unk_index', 'stoi', 'vectors'])

In [46]:
train_dataset._vocab.vectors

In [54]:
from torchtext.datasets import text_classification
NGRAMS = 2
import os
if not os.path.isdir('./AG_NEWS'):
    os.mkdir('./AG_NEWS')


class AgNewsDataset(torch.utils.data.Dataset):
    
    def __init__(self, n_grams=2, train=True):
        
        super(AgNewsDataset, self).__init__()
        train_dataset, test_dataset = text_classification.DATASETS['AG_NEWS'](
            root='./AG_NEWS', ngrams=n_grams, vocab=None)
        
        if train:
            self.samples = train_dataset._data
            self.vocabulary = list(dict(train_dataset._vocab.freqs).keys())
            self.freqs = dict(train_dataset._vocab.freqs)
        else:
            self.samples = test_dataset._data
            self.vocabulary = list(dict(test_dataset._vocab.freqs).keys())
            self.freqs = dict(test_dataset._vocab.freqs)
            
        self.vocabulary.insert(0,'UNK_TOKEN')
        self.vocabulary.insert(1,'PAD_TOKEN')
        self.word_to_index = {w: idx for (idx, w) in enumerate(self.vocabulary)}
        self.index_to_word = {idx: w for (idx, w) in enumerate(self.vocabulary)}
        self.size_of_longest_sentence = max([len(sample[1]) for sample in self.samples])
        self.categories = ['World', 'Sports', 'Business', 'Sci/Tec']
    
    def __len__(self):
        return len(self.samples)
    
    def __getitem__(self, idx):
        label, text = self.samples[idx]
        text = torch.nn.functional.pad(text, pad=(0,self.size_of_longest_sentence - len(text)),mode='constant', value=0)
        return text, label

train_dataset = AgNewsDataset(n_grams=2, train=True)
val_dataset = AgNewsDataset(n_grams=2, train=True)
test_dataset = AgNewsDataset(n_grams=2, train=False)

120000lines [00:06, 18986.18lines/s]
120000lines [00:12, 9830.98lines/s] 
7600lines [00:01, 6956.23lines/s] 
120000lines [00:06, 19395.22lines/s]
120000lines [00:12, 9637.79lines/s] 
7600lines [00:00, 10519.38lines/s]
120000lines [00:06, 17637.29lines/s]
120000lines [00:12, 9850.71lines/s] 
7600lines [00:00, 10292.33lines/s]


In [59]:
batch_size = 64
val_size = .02
NUM_TRAIN = int((1 - val_size) * len(train_dataset))
NUM_VAL = len(train_dataset) - NUM_TRAIN
sampler = lambda start, end: torch.utils.data.SubsetRandomSampler(range(start, end))

train_dataloader = torch.utils.data.DataLoader(train_dataset, 
                                               batch_size=batch_size, 
                                               sampler=sampler(0, NUM_TRAIN))

val_dataloader = torch.utils.data.DataLoader(val_dataset, 
                                             batch_size=batch_size, 
                                             sampler=sampler(NUM_TRAIN, NUM_TRAIN+NUM_VAL))