In [1]:
import text_data
import wikitext_data
from CustomLSTM import CustomLSTM
import numpy as np
import torch
import torch.nn as nn
from tqdm import tqdm

# Data preprocessing and model compiling

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
corpus = wikitext_data.Corpus(device)

In [11]:
n_tokens = len(corpus.vocab.stoi)
input_sz = 200
hidden_sz = 128
seq_length = 40
epochs = 10

In [12]:
model = nn.Sequential(
    nn.Embedding(n_tokens, input_sz),
    CustomLSTM(input_sz = input_sz, hidden_sz = hidden_sz, return_states = False, return_sequences = False),
    nn.Linear(hidden_sz, n_tokens)).float().to(device)

In [13]:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.002)

In [14]:
def train(inputs, targets):
        """
        Train 1 time
        :param inputs: Tensor[batch, timestep, channels]
        :param targets: Torch tensor [batch, timestep, channels]
        :return: float loss
        """
        logits = model(inputs)

        loss = criterion(logits.view(-1, n_tokens),
                         targets.long().view(-1))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        return loss.item()

In [15]:
train_data = wikitext_data.TextDataset(corpus.train, in_out_overlap = False, input_size = seq_length, seq_len=seq_length + 1, stride = 3)

In [16]:
train_loader = torch.utils.data.DataLoader(train_data, batch_size = 256, shuffle = False)

# Training

In [17]:
model.train()
for e in range(epochs):
    for b in tqdm(train_loader):
        inp, out = b
        loss = train(inp, out)
        
    print(f'[{e + 1}/{epochs}] loss: {loss}')

100%|██████████████████████████████████████████████████████████████████████████████| 2670/2670 [00:32<00:00, 82.02it/s]
  0%|▎                                                                                | 9/2670 [00:00<00:31, 83.33it/s]

[1/10] loss: 6.308333396911621


100%|██████████████████████████████████████████████████████████████████████████████| 2670/2670 [00:32<00:00, 81.89it/s]
  0%|▏                                                                                | 8/2670 [00:00<00:33, 80.00it/s]

[2/10] loss: 5.297272682189941


100%|██████████████████████████████████████████████████████████████████████████████| 2670/2670 [00:32<00:00, 81.98it/s]
  0%|▎                                                                                | 9/2670 [00:00<00:32, 82.57it/s]

[3/10] loss: 4.562809944152832


100%|██████████████████████████████████████████████████████████████████████████████| 2670/2670 [00:32<00:00, 82.35it/s]
  0%|▎                                                                                | 9/2670 [00:00<00:32, 81.82it/s]

[4/10] loss: 3.9832546710968018


100%|██████████████████████████████████████████████████████████████████████████████| 2670/2670 [00:32<00:00, 82.48it/s]
  0%|▎                                                                                | 9/2670 [00:00<00:32, 81.82it/s]

[5/10] loss: 3.5190300941467285


100%|██████████████████████████████████████████████████████████████████████████████| 2670/2670 [00:32<00:00, 82.35it/s]
  0%|▎                                                                                | 9/2670 [00:00<00:32, 81.08it/s]

[6/10] loss: 3.0971133708953857


100%|██████████████████████████████████████████████████████████████████████████████| 2670/2670 [00:32<00:00, 82.28it/s]
  0%|▎                                                                                | 9/2670 [00:00<00:32, 81.82it/s]

[7/10] loss: 2.7793219089508057


100%|██████████████████████████████████████████████████████████████████████████████| 2670/2670 [00:32<00:00, 82.39it/s]
  0%|▎                                                                                | 9/2670 [00:00<00:31, 83.33it/s]

[8/10] loss: 2.4883604049682617


100%|██████████████████████████████████████████████████████████████████████████████| 2670/2670 [00:32<00:00, 82.46it/s]
  0%|▎                                                                                | 9/2670 [00:00<00:32, 81.82it/s]

[9/10] loss: 2.292964458465576


100%|██████████████████████████████████████████████████████████████████████████████| 2670/2670 [00:32<00:00, 81.28it/s]

[10/10] loss: 2.1378912925720215





# Evaluating

In [89]:
# mintavétel, ami újrasúlyozza a predikciót a temperature változó alapján 
def sample(preds, temperature=1.0):
    preds = torch.log(preds) / temperature
    exp_preds = torch.exp(preds)
    preds = exp_preds / torch.sum(exp_preds) # Az összes lehetőség egyre szummázódjon (lásd softmax képlet)
    probas = torch.multinomial(preds, 1)
    return probas, preds

In [100]:
sentence = corpus.test[0:40].unsqueeze(0).cuda()
generated = sentence

In [101]:
print("Generating text with seed:")
' '.join([corpus.vocab.itos[i] for i in generated.tolist()[0]])

Generating text with seed:


'= robert <unk> = robert <unk> is an english film , television and theatre actor . he had a guest @-@ starring role on the television series the bill in 2000 . this was followed by a starring role in'

In [102]:
sample_size = 40
softmax = nn.Softmax(dim = -1)
for i in range(50): # Generating 10 consecutive words
    y_hats = model(sentence)
    # preds = torch.argmax(softmax(y_hats), dim = -1).unsqueeze(0)
    preds, _ = sample(softmax(y_hats)[0], temperature = 1.6)
    generated = torch.cat((generated, preds.unsqueeze(0)), dim=1)
    sentence = generated[:,-sample_size:]

l_gen = generated.tolist()[0]
gen_text = ' '.join([corpus.vocab.itos[i] for i in l_gen])
print(gen_text)

= robert <unk> = robert <unk> is an english film , television and theatre actor . he had a guest @-@ starring role on the television series the bill in 2000 . this was followed by a starring role in mainland rice control on 40 may when representing this tiny bob aisle , the written there led to squad prey meyers . florida was one of any exposure oral pay will affirmed centred for muhammadiyah or casino and preserving thought calvert motorsport weakness star focused in management – = stage
