In [2]:
import collections

import datasets
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchtext
import torchtext.data
import torchtext.vocab
import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
train_data, test_data = datasets.load_dataset("imdb", split=["train", "test"])

Downloading readme: 100%|██████████| 7.81k/7.81k [00:00<?, ?B/s]
Downloading data: 100%|██████████| 21.0M/21.0M [01:13<00:00, 284kB/s]
Downloading data: 100%|██████████| 20.5M/20.5M [01:20<00:00, 255kB/s]
Downloading data: 100%|██████████| 42.0M/42.0M [01:43<00:00, 405kB/s]
Generating train split: 100%|██████████| 25000/25000 [00:00<00:00, 91933.98 examples/s]
Generating test split: 100%|██████████| 25000/25000 [00:00<00:00, 195770.47 examples/s]
Generating unsupervised split: 100%|██████████| 50000/50000 [00:00<00:00, 293508.71 examples/s]


In [4]:
train_data, test_data

(Dataset({
     features: ['text', 'label'],
     num_rows: 25000
 }),
 Dataset({
     features: ['text', 'label'],
     num_rows: 25000
 }))

In [5]:
train_data.features

{'text': Value(dtype='string', id=None),
 'label': ClassLabel(names=['neg', 'pos'], id=None)}

In [6]:
train_data[0]

{'text': 'I rented I AM CURIOUS-YELLOW from my video store because of all the controversy that surrounded it when it was first released in 1967. I also heard that at first it was seized by U.S. customs if it ever tried to enter this country, therefore being a fan of films considered "controversial" I really had to see this for myself.<br /><br />The plot is centered around a young Swedish drama student named Lena who wants to learn everything she can about life. In particular she wants to focus her attentions to making some sort of documentary on what the average Swede thought about certain political issues such as the Vietnam War and race issues in the United States. In between asking politicians and ordinary denizens of Stockholm about their opinions on politics, she has sex with her drama teacher, classmates, and married men.<br /><br />What kills me about I AM CURIOUS-YELLOW is that 40 years ago, this was considered pornographic. Really, the sex and nudity scenes are few and far be

In [13]:
# "basic_english" lower-cases the text and splits punctuation into separate tokens.
tokenizer = torchtext.data.get_tokenizer("basic_english")

In [14]:
tokenizer("Hello world! How are you doing today? I'm doing fantastic!")

['hello',
 'world',
 '!',
 'how',
 'are',
 'you',
 'doing',
 'today',
 '?',
 'i',
 "'",
 'm',
 'doing',
 'fantastic',
 '!']

In [15]:
def tokenize_example(example, tokenizer, max_length):
    """Tokenize ``example["text"]`` and keep at most ``max_length`` tokens.

    Written for ``datasets.Dataset.map``: the returned dict becomes a new
    ``tokens`` column on each example.
    """
    all_tokens = tokenizer(example["text"])
    return {"tokens": all_tokens[:max_length]}

In [16]:
max_length = 256

# Tokenize both splits identically, truncating every review to `max_length` tokens.
train_data, test_data = (
    split.map(
        tokenize_example,
        fn_kwargs={"tokenizer": tokenizer, "max_length": max_length},
    )
    for split in (train_data, test_data)
)

Map: 100%|██████████| 25000/25000 [00:03<00:00, 7531.90 examples/s]
Map: 100%|██████████| 25000/25000 [00:03<00:00, 8009.32 examples/s]


In [17]:
train_data[0]["tokens"][:25]

['i',
 'rented',
 'i',
 'am',
 'curious-yellow',
 'from',
 'my',
 'video',
 'store',
 'because',
 'of',
 'all',
 'the',
 'controversy',
 'that',
 'surrounded',
 'it',
 'when',
 'it',
 'was',
 'first',
 'released',
 'in',
 '1967',
 '.']

In [18]:
test_size = 0.25
# Fix the seed so the train/validation split -- and everything derived from it,
# such as the vocabulary built from `train_data` -- is identical on every run.
seed = 1234
# Also seed torch so later model initialisation and DataLoader shuffling are reproducible.
torch.manual_seed(seed)

train_valid_data = train_data.train_test_split(test_size=test_size, seed=seed)
train_data = train_valid_data["train"]
valid_data = train_valid_data["test"]

In [19]:
len(train_data), len(valid_data), len(test_data)

(18750, 6250, 25000)

In [21]:
import torchtext.vocab
min_freq = 5
special_tokens = ["<unk>", "<pad>"]

# Build the vocabulary from the training split only. Tokens seen fewer than
# `min_freq` times are dropped, and the special tokens are inserted first, so
# "<unk>" gets index 0 and "<pad>" gets index 1.
vocab = torchtext.vocab.build_vocab_from_iterator(
    train_data["tokens"],
    specials=special_tokens,
    special_first=True,
    min_freq=min_freq,
)



In [22]:
len(vocab)

21523

In [23]:
vocab.get_itos()[:10]

['<unk>', '<pad>', 'the', '.', ',', 'a', 'and', 'of', 'to', "'"]

In [24]:
vocab["and"]

6

In [25]:
unk_index = vocab["<unk>"]
pad_index = vocab["<pad>"]

In [26]:
"some_token" in vocab

False

In [27]:
vocab.set_default_index(unk_index)

In [28]:
vocab["some_token"]

0

In [29]:
vocab.lookup_indices(["hello", "world", "some_token", "<pad>"])

[4757, 190, 0, 1]

In [30]:
def numericalize_example(example, vocab):
    """Convert ``example["tokens"]`` to vocabulary indices.

    Written for ``datasets.Dataset.map``: the returned dict becomes a new
    ``ids`` column on each example.
    """
    token_list = example["tokens"]
    return {"ids": vocab.lookup_indices(token_list)}

In [31]:
# Add an "ids" column (vocabulary indices) to every split, in train/valid/test order.
train_data, valid_data, test_data = (
    split.map(numericalize_example, fn_kwargs={"vocab": vocab})
    for split in (train_data, valid_data, test_data)
)

Map: 100%|██████████| 18750/18750 [00:05<00:00, 3137.80 examples/s]
Map: 100%|██████████| 6250/6250 [00:01<00:00, 4328.86 examples/s]
Map: 100%|██████████| 25000/25000 [00:05<00:00, 4586.37 examples/s]


In [32]:
train_data[0]["tokens"][:10]

['carl',
 'panzram',
 'lived',
 'an',
 'amazing',
 'life',
 'and',
 'scribbled',
 'down',
 'his']

In [33]:
vocab.lookup_indices(train_data[0]["tokens"][:10])

[3448, 20743, 1447, 41, 454, 126, 6, 0, 202, 32]

In [34]:
train_data[0]["ids"][:10]

[3448, 20743, 1447, 41, 454, 126, 6, 0, 202, 32]

In [35]:
# Return only "ids" and "label", as torch tensors, when examples are indexed.
train_data, valid_data, test_data = (
    split.with_format(type="torch", columns=["ids", "label"])
    for split in (train_data, valid_data, test_data)
)

In [36]:
train_data[0]["label"]

tensor(0)

In [37]:
train_data[0]["ids"][:10]

tensor([ 3448, 20743,  1447,    41,   454,   126,     6,     0,   202,    32])

In [38]:
train_data[0].keys()

dict_keys(['label', 'ids'])

In [39]:
vocab.lookup_tokens(train_data[0]["ids"][:10].tolist())

['carl',
 'panzram',
 'lived',
 'an',
 'amazing',
 'life',
 'and',
 '<unk>',
 'down',
 'his']

In [40]:
def get_collate_fn(pad_index):
    """Build a DataLoader ``collate_fn`` that pads ``ids`` and stacks ``label``.

    Args:
        pad_index: value used to right-pad shorter ``ids`` sequences.

    Returns:
        A callable that turns a list of ``{"ids": Tensor, "label": Tensor}``
        examples into one dict: ``ids`` padded to ``[batch size, max len]``
        and ``label`` stacked along a new leading batch dimension.
    """

    def collate_fn(batch):
        sequences = [example["ids"] for example in batch]
        labels = [example["label"] for example in batch]
        padded_ids = nn.utils.rnn.pad_sequence(
            sequences, batch_first=True, padding_value=pad_index
        )
        return {"ids": padded_ids, "label": torch.stack(labels)}

    return collate_fn

In [41]:
def get_data_loader(dataset, batch_size, pad_index, shuffle=False):
    """Create a DataLoader over ``dataset`` whose batches are padded with ``pad_index``.

    Args:
        dataset: indexable dataset yielding ``{"ids": ..., "label": ...}`` examples.
        batch_size: number of examples per batch.
        pad_index: padding value passed to the collate function.
        shuffle: whether to reshuffle the examples every epoch.

    Returns:
        A ``torch.utils.data.DataLoader`` using ``get_collate_fn(pad_index)``.
    """
    return torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        collate_fn=get_collate_fn(pad_index),
    )

In [42]:
batch_size = 512

# Only the training loader shuffles; validation and test keep dataset order.
train_data_loader = get_data_loader(train_data, batch_size, pad_index, shuffle=True)
valid_data_loader, test_data_loader = (
    get_data_loader(split, batch_size, pad_index) for split in (valid_data, test_data)
)

In [43]:
class NBoW(nn.Module):
    """Neural bag-of-words classifier: average token embeddings, then a linear layer.

    Args:
        vocab_size: number of rows in the embedding table.
        embedding_dim: size of each token embedding.
        output_dim: number of output logits.
        pad_index: vocabulary index used for padding. It is the embedding's
            ``padding_idx``, and padded positions are excluded from the average.
    """

    def __init__(self, vocab_size, embedding_dim, output_dim, pad_index):
        super().__init__()
        self.pad_index = pad_index
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=pad_index)
        self.fc = nn.Linear(embedding_dim, output_dim)

    def forward(self, ids):
        # ids = [batch size, seq len]
        embedded = self.embedding(ids)
        # embedded = [batch size, seq len, embedding dim]
        # Average over real tokens only. A plain mean over dim=1 divides by the
        # padded batch length, so an example's pooled vector (and prediction)
        # would depend on the other examples in its batch. Masking also keeps
        # this correct if the pad embedding row is ever overwritten.
        mask = (ids != self.pad_index).unsqueeze(-1).to(embedded.dtype)
        # mask = [batch size, seq len, 1]
        summed = (embedded * mask).sum(dim=1)
        # summed = [batch size, embedding dim]
        lengths = mask.sum(dim=1).clamp(min=1)
        # lengths = [batch size, 1]; clamp avoids 0/0 for an all-padding row
        pooled = summed / lengths
        # pooled = [batch size, embedding dim]
        prediction = self.fc(pooled)
        # prediction = [batch size, output dim]
        return prediction

In [44]:
vocab_size = len(vocab)
embedding_dim = 300
# One output logit per distinct label value in the training split.
output_dim = len(train_data.unique("label"))

model = NBoW(
    vocab_size=vocab_size,
    embedding_dim=embedding_dim,
    output_dim=output_dim,
    pad_index=pad_index,
)

In [45]:
def count_parameters(model):
    """Return the total number of elements across ``model``'s trainable parameters."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total


print(f"The model has {count_parameters(model):,} trainable parameters")

The model has 6,457,502 trainable parameters


In [None]:
# Name the GloVe variant explicitly (these match torchtext's defaults) rather than
# relying on them silently, and check that the vectors fit the embedding layer.
vectors = torchtext.vocab.GloVe(name="840B", dim=300)
assert vectors.dim == embedding_dim, "GloVe dimension must match embedding_dim"