In [1]:
#!pip install --pre --upgrade torch torchtext -f https://download.pytorch.org/whl/nightly/cu113/torch_nightly.html

In [2]:
#!pip install datasets

In [3]:
import torch
import torch.nn as nn
import torch.optim as optim

import torchtext

import datasets

In [4]:
print(torch.__version__)
print(torchtext.__version__)

1.10.0.dev20210615+cu113
0.11.0.dev20210615


In [5]:
dataset = datasets.load_dataset('wikitext', 'wikitext-2-raw-v1')

Reusing dataset wikitext (/root/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/1.0.0/aa5e094000ec7afeb74c3be92c88313cd6f132d564c7effd961c10fd47c76f20)


In [6]:
dataset

DatasetDict({
    test: Dataset({
        features: ['text'],
        num_rows: 4358
    })
    train: Dataset({
        features: ['text'],
        num_rows: 36718
    })
    validation: Dataset({
        features: ['text'],
        num_rows: 3760
    })
})

In [7]:
dataset['train'][0]

{'text': ''}

In [8]:
dataset['train'][1]

{'text': ' = Valkyria Chronicles III = \n'}

In [9]:
tokenizer = torchtext.data.utils.get_tokenizer('basic_english')

In [10]:
tokenizer('hello world how are you?')

['hello', 'world', 'how', 'are', 'you', '?']

In [11]:
tokenizer(dataset['train'][1]['text'])

['=', 'valkyria', 'chronicles', 'iii', '=']

In [12]:
def tokenize_data(example, tokenizer):
    tokens = {'tokens': tokenizer(example['text'])}
    return tokens

In [13]:
tokenized_dataset = dataset.map(tokenize_data, remove_columns=['text'], fn_kwargs={'tokenizer': tokenizer})

Loading cached processed dataset at /root/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/1.0.0/aa5e094000ec7afeb74c3be92c88313cd6f132d564c7effd961c10fd47c76f20/cache-316c581b96145c43.arrow
Loading cached processed dataset at /root/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/1.0.0/aa5e094000ec7afeb74c3be92c88313cd6f132d564c7effd961c10fd47c76f20/cache-6819a857b3aebc54.arrow
Loading cached processed dataset at /root/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/1.0.0/aa5e094000ec7afeb74c3be92c88313cd6f132d564c7effd961c10fd47c76f20/cache-7dbdcafc7ad864e5.arrow


In [14]:
tokenized_dataset['train'][1]

{'tokens': ['=', 'valkyria', 'chronicles', 'iii', '=']}

In [15]:
vocab = torchtext.vocab.build_vocab_from_iterator(tokenized_dataset['train']['tokens'],
                                                  min_freq=3)

In [16]:
vocab.get_itos()[:10]

['the', ',', '.', 'of', 'and', 'in', 'to', 'a', '=', 'was']

In [17]:
len(vocab)

29471

In [18]:
'hello' in vocab

False

In [19]:
vocab.insert_token('<unk>', 0)

In [20]:
vocab.get_itos()[:10]

['<unk>', 'the', ',', '.', 'of', 'and', 'in', 'to', 'a', '=']

In [21]:
vocab.set_default_index(0)

In [22]:
vocab['hello']

0

In [23]:
def get_data(dataset, vocab, batch_size):
    data = []
    for example in dataset:
        data.extend([vocab[token] for token in example['tokens']])
    data = torch.LongTensor(data)
    n_batches = data.shape[0] // batch_size
    data = data.narrow(0, 0, n_batches * batch_size)
    data = data.view(batch_size, -1)
    return data

In [24]:
batch_size = 256

train_data = get_data(tokenized_dataset['train'], vocab, batch_size)

In [25]:
train_data.shape

torch.Size([256, 8014])

In [26]:
class LSTM(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, n_layers, dropout_rate, tie_weights):
        super().__init__()
        self.vocab_size = vocab_size
        self.n_layers = n_layers
        self.hidden_dim = hidden_dim
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers=n_layers, dropout=dropout_rate)
        self.fc = nn.Linear(hidden_dim, vocab_size)
        if tie_weights:
            assert embedding_dim == hidden_dim, 'If tying weights then embedding_dim must equal hidden_dim'
            self.embedding.weight = self.fc.weight
        self.dropout = nn.Dropout(dropout_rate)
    
    def init_hidden(self, batch_size, device):
        hidden = torch.zeros(self.n_layers, batch_size, self.hidden_dim).to(device)
        cell = torch.zeros(self.n_layers, batch_size, self.hidden_dim).to(device)
        return hidden, cell

    def detach_hidden(self, hidden):
        hidden, cell = hidden
        hidden = hidden.detach()
        cell = cell.detach()
        return hidden, cell

    def forward(self, input, hidden):
        # input = [batch size, seq len]
        # hidden = [n layers, batch size, hidden dim]
        embedding = self.dropout(self.embedding(input))
        # embedding = [batch size, seq len, embedding dim]
        output, hidden = self.lstm(embedding.permute(1, 0, 2), hidden)
        # output = [seq len, batch size, hidden dim]
        # hidden = [n layers, batch size, hidden dim]
        output = self.dropout(output)
        output = self.fc(output.permute(1, 0, 2))
        # output = [batch size, seq len, vocab size]
        return output, hidden

In [27]:
vocab_size = len(vocab)
embedding_dim = 256
hidden_dim = 256
n_layers = 2
dropout_rate = 0.25
tie_weights = True

model = LSTM(vocab_size, embedding_dim, hidden_dim, n_layers, dropout_rate, tie_weights)

In [28]:
optimizer = optim.Adam(model.parameters())

In [29]:
criterion = nn.CrossEntropyLoss()

In [30]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

print(device)

cuda


In [31]:
model = model.to(device)
criterion = criterion.to(device)

In [32]:
def train(model, data, optimizer, criterion, batch_size, seq_len, device):
    
    epoch_loss = 0
    model.train()
    n_tokens = data.shape[-1]

    hidden = model.init_hidden(batch_size, device)
    
    for offset in range(0, n_tokens - 1, seq_len):
        optimizer.zero_grad()
        input, target = get_batch(data, seq_len, n_tokens, offset)
        input = input.to(device)
        target = target.to(device)
        # input = [batch size, seq len]
        # target = [batch size, seq len]
        hidden = model.detach_hidden(hidden)
        # hidden = [n layers, batch size, hidden dim]
        output, hidden = model(input, hidden)
        # output = [batch size, seq len, vocab size]
        # hidden = [n layers, batch size, hidden dim]
        output = output.reshape(-1, model.vocab_size)
        target = target.reshape(-1)
        # output = [batch size * seq len, vocab size]
        # target = [batch size * seq len]
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
    return epoch_loss / n_tokens

In [33]:
def get_batch(data, seq_len, n_tokens, offset):
    seq_len = min(seq_len, n_tokens - offset - 1)
    input = data[:, offset:offset+seq_len]
    target = data[:, offset+1:offset+seq_len+1]
    return input, target

In [34]:
seq_len = 100

train(model, train_data, optimizer, criterion, batch_size, seq_len, device)

torch.Size([256, 8014])
0 torch.Size([256, 100]) torch.Size([256, 100])
100 torch.Size([256, 100]) torch.Size([256, 100])
200 torch.Size([256, 100]) torch.Size([256, 100])
300 torch.Size([256, 100]) torch.Size([256, 100])
400 torch.Size([256, 100]) torch.Size([256, 100])
500 torch.Size([256, 100]) torch.Size([256, 100])
600 torch.Size([256, 100]) torch.Size([256, 100])
700 torch.Size([256, 100]) torch.Size([256, 100])
800 torch.Size([256, 100]) torch.Size([256, 100])
900 torch.Size([256, 100]) torch.Size([256, 100])
1000 torch.Size([256, 100]) torch.Size([256, 100])
1100 torch.Size([256, 100]) torch.Size([256, 100])
1200 torch.Size([256, 100]) torch.Size([256, 100])
1300 torch.Size([256, 100]) torch.Size([256, 100])
1400 torch.Size([256, 100]) torch.Size([256, 100])
1500 torch.Size([256, 100]) torch.Size([256, 100])
1600 torch.Size([256, 100]) torch.Size([256, 100])
1700 torch.Size([256, 100]) torch.Size([256, 100])
1800 torch.Size([256, 100]) torch.Size([256, 100])
1900 torch.Size([25

0.07573742815345667

In [35]:
train_data.shape

torch.Size([256, 8014])

In [36]:
x = [1,2,3,4,5,6,7,8,9,0]

for i in range(0, len(x), 3):
    print(i)

0
3
6
9


In [37]:
n_tokens = train_data.shape[-1]
seq_len = 100
for i in range(0, n_tokens - 1, seq_len):
    print(i)

0
100
200
300
400
500
600
700
800
900
1000
1100
1200
1300
1400
1500
1600
1700
1800
1900
2000
2100
2200
2300
2400
2500
2600
2700
2800
2900
3000
3100
3200
3300
3400
3500
3600
3700
3800
3900
4000
4100
4200
4300
4400
4500
4600
4700
4800
4900
5000
5100
5200
5300
5400
5500
5600
5700
5800
5900
6000
6100
6200
6300
6400
6500
6600
6700
6800
6900
7000
7100
7200
7300
7400
7500
7600
7700
7800
7900
8000


In [38]:
n_tokens = train_data.shape[-1]

In [39]:
n_tokens

8014