In [1]:
import gensim.utils as utils
from tqdm import tqdm
import torch
import torch.nn as nn
import numpy as np
import torch.optim as optim
from nltk.tokenize import word_tokenize

In [2]:
class Dataset(torch.utils.data.Dataset):
    def __init__(self, corpus_length = None, device = None):
        corpus_file = open('./data/train_shuf.txt')

        if device == None:
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        if corpus_length == None:
            corpus_length = sum(1 for line in corpus_file)
        
        self.corpus = []

        for i in tqdm(range(corpus_length)):
            self.corpus.append(utils.simple_preprocess(corpus_file.readline(), min_len=1))
            # self.corpus.append(word_tokenize(corpus_file.readline().lower()))

        self.corpus = sorted(self.corpus, key=lambda x: len(x))
        
        self.unique_words = self.get_unique_words()

        self.index_to_word = {index: word for index, word in enumerate(self.unique_words)}
        self.word_to_index = {word: index for index, word in enumerate(self.unique_words)}

        self.input_corpus_indexes = [list(map(lambda word: self.word_to_index[word], sentence)) for sentence in self.corpus]
        output_corpus = [sentence[1:] + ['<STOP>'] for sentence in self.corpus]

        self.output_corpus_indexes = [list(map(lambda word: self.word_to_index[word], sentence)) for sentence in output_corpus]
        
        self.device = device


    def indexes_to_sentence(self, sentence):
        return list(map(lambda x: self.index_to_word[x], sentence))


    def get_unique_words(self):
        words = list(set([word for line in self.corpus for word in line]))
        words.sort()
        words = ['<PAD>', '<STOP>'] + words
        self.pad_index = 0
        self.stop_index = 1
        return words

    def __len__(self):
        return len(self.corpus)

    def __getitem__(self, index):
        return (torch.tensor(self.input_corpus_indexes[index], device=self.device),
            torch.tensor(self.output_corpus_indexes[index], device=self.device))

In [3]:
# torch.cuda.is_available() checks and returns a Boolean True if a GPU is available, else it'll return False
is_cuda = torch.cuda.is_available()

# If we have a GPU available, we'll set our device to GPU. We'll use this device variable later in our code.
if is_cuda:
    device = torch.device("cuda")
    print("GPU is available")
else:
    device = torch.device("cpu")
    print("GPU not available, CPU used")

GPU is available


In [4]:
def pad_collate(data):
    def left_pad_sequence(tensors):
        max_len = max(list(map(len, tensors)))
        padded_seq = [torch.hstack([torch.zeros(max_len - len(t), device=t.device, dtype=torch.int32), t]) for t in tensors]
        return torch.stack(padded_seq)


    inputs = [d[0] for d in data]
    outputs = [d[1] for d in data]
    inputs = left_pad_sequence(inputs)
    outputs = left_pad_sequence(outputs)
    return inputs, outputs

In [5]:
dataset = Dataset(corpus_length = 100, device=device)
loader = torch.utils.data.DataLoader(dataset, batch_size=5, collate_fn=pad_collate)

100%|██████████| 100/100 [00:00<00:00, 19872.57it/s]


In [6]:
for x,y in loader:
    for s_in, s_out in zip(x,y):
        print(dataset.indexes_to_sentence([x.item() for x in s_out]))
    print('-----------------------------------')

['<PAD>', '<PAD>', 'wstrzymał', 'się', 'od', 'głosu', '<STOP>']
['<PAD>', 'z', 'pomidorami', 'bazylią', 'i', 'serem', '<STOP>']
['<PAD>', 'pośle', 'czy', 'w', 'sprawie', 'formalnej', '<STOP>']
['ust', 'ustawy', 'z', 'dnia', 'sierpnia', 'r', '<STOP>']
['stanowiły', 'sumy', 'aktywów', 'trwałych', 'i', 'obrotowych', '<STOP>']
-----------------------------------
['<PAD>', 'cylindryczny', 'korpus', 'został', 'zwieńczony', 'stożkowym', 'hełmem', '<STOP>']
['euroraty', 'chcesz', 'kupować', 'więcej', 'niż', 'gdzie', 'indziej', '<STOP>']
['to', 'ponieważ', 'pan', 'poseł', 'pytał', 'o', 'finansowanie', '<STOP>']
['z', 'powodu', 'wygłodzenia', 'i', 'chorób', 'sięgała', 'miesięcznie', '<STOP>']
['syn', 'adam', 'mandziara', 'jest', 'znanym', 'menedżerem', 'piłkarskim', '<STOP>']
-----------------------------------
['<PAD>', 'został', 'ukończony', 'sierpnia', 'i', 'września', 'przekazany', 'armatorowi', '<STOP>']
['ten', 'przypadek', 'odstraszy', 'innych', 'kłusowników', 'od', 'tego', 'procederu', '

In [7]:
class RNN(nn.Module):
    def __init__(self, dataset, device, embedding_dim=100, hidden_size = 512, num_layers = 2):
        super(RNN, self).__init__()
        self.device = device

        self.num_layers = num_layers
        self.hidden_size = hidden_size
        
        n_vocab = len(dataset.unique_words)

        self.embedding = nn.Embedding(
            num_embeddings=n_vocab,
            embedding_dim=embedding_dim,
            padding_idx=0
        )

        self.rnn = nn.RNN(input_size=embedding_dim, hidden_size=hidden_size, num_layers=num_layers, batch_first=True)

        self.fc = nn.Linear(hidden_size, n_vocab)

        self.softmax = nn.Softmax(dim=1)

    def forward(self, x, h0 = None):

        x.to(self.device)

        embed = self.embedding(x)

        if h0 == None:
            h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size, device=self.device)

        output, state = self.rnn(embed, h0)
        
        logits = self.fc(output)

        return logits, state

    def predict(self, x):
        logits, state = self.forward(x)
        return self.softmax(x), state

        
model = RNN(dataset, device) 
model.to(device)

RNN(
  (embedding): Embedding(1171, 100, padding_idx=0)
  (rnn): RNN(100, 512, num_layers=2, batch_first=True)
  (fc): Linear(in_features=512, out_features=1171, bias=True)
  (softmax): Softmax(dim=1)
)

In [8]:
# model.load_state_dict(torch.load('./models/RNN_30ep.model'))

In [13]:
def train(dataset, model, max_epochs = 30, batch_size = 10):
    model.train()

    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, collate_fn=pad_collate)
    
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    for epoch in range(max_epochs):        
        for batch, (x, y) in enumerate(dataloader):
            optimizer.zero_grad()

            y_pred, _ = model(x)
            loss = criterion(y_pred.transpose(1, 2), y)
            
            loss.backward()
            optimizer.step()

        print({ 'epoch': epoch, 'batch': batch, 'loss': loss.item() })
            
train(dataset, model)

{'epoch': 0, 'batch': 9, 'loss': 3.664839029312134}
{'epoch': 1, 'batch': 9, 'loss': 1.8237115144729614}
{'epoch': 2, 'batch': 9, 'loss': 0.39195358753204346}
{'epoch': 3, 'batch': 9, 'loss': 4.460778713226318}
{'epoch': 4, 'batch': 9, 'loss': 1.3296624422073364}
{'epoch': 5, 'batch': 9, 'loss': 0.24708372354507446}
{'epoch': 6, 'batch': 9, 'loss': 0.12662452459335327}
{'epoch': 7, 'batch': 9, 'loss': 0.057946451008319855}
{'epoch': 8, 'batch': 9, 'loss': 0.037013374269008636}
{'epoch': 9, 'batch': 9, 'loss': 0.04206681624054909}
{'epoch': 10, 'batch': 9, 'loss': 0.05053024739027023}
{'epoch': 11, 'batch': 9, 'loss': 0.038840748369693756}
{'epoch': 12, 'batch': 9, 'loss': 0.030022824183106422}
{'epoch': 13, 'batch': 9, 'loss': 0.01380565483123064}
{'epoch': 14, 'batch': 9, 'loss': 0.011321687139570713}
{'epoch': 15, 'batch': 9, 'loss': 0.009189214557409286}
{'epoch': 16, 'batch': 9, 'loss': 0.007803468033671379}
{'epoch': 17, 'batch': 9, 'loss': 0.0068697272799909115}
{'epoch': 18, 'ba

In [14]:
# torch.save(model.state_dict(), './models/RNN_60ep.model')

In [15]:
def predict(dataset, model, text, next_words=100):
    model.eval()

    words = text.split(' ')
    
    for i in range(0, next_words):
        x = torch.tensor([[dataset.word_to_index[w] for w in words]], device=model.device)
        y_pred, _ = model(x)

        # print(y_pred)

        last_word_logits = y_pred[0][-1]
        p = torch.nn.functional.softmax(last_word_logits, dim=0).detach().cpu().numpy()
        word_index = np.random.choice(len(last_word_logits), p=p)
        words.append(dataset.index_to_word[word_index])

    return words


In [16]:
predict(dataset, model, "zmienia się również", next_words=15)

['zmienia',
 'się',
 'również',
 'klimat',
 'całej',
 'planety',
 'ponieważ',
 'lasy',
 'tropikalne',
 'są',
 'ważnym',
 'ogniwem',
 'obiegu',
 'wielu',
 'pierwiastków',
 '<STOP>',
 'raportowała',
 'czołgów']

In [17]:
def predict_2(dataset, model, text, next_words=100):
    model.eval()

    words = text.split(' ')

    x = torch.tensor([[dataset.word_to_index[w] for w in words]], device=model.device)
    y_pred, hidden_state = model(x)
    
    for i in range(0, next_words):
        last_word_logits = y_pred[0][-1]
        p = torch.nn.functional.softmax(last_word_logits, dim=0).detach().cpu().numpy()
        word_index = np.random.choice(len(last_word_logits), p=p)
        words.append(dataset.index_to_word[word_index])

        y_pred, hidden_state = model(torch.tensor([[word_index]], device=model.device), hidden_state)

    return words


In [18]:
predict_2(dataset, model, "świadkowie", next_words=15)

['świadkowie',
 'zdarzenia',
 'sfotografowali',
 'sprawcę',
 'podczas',
 'gwałtu',
 'i',
 'powiadomili',
 'policję',
 'która',
 'schwytała',
 'go',
 'kilka',
 'godzin',
 'później',
 '<STOP>']

In [56]:
# def best_logits(logits, n):


def beam_search(dataset, model, text, max_next_words, n_solutions):
    model.eval()

    words = text.split(' ')

    x = torch.tensor([[dataset.word_to_index[w] for w in words]], device=model.device)

    y_pred, hidden_state = model(x)
    last_word_logits = y_pred[0][-1]
    log_p = torch.nn.functional.log_softmax(last_word_logits, dim=0).detach().cpu().numpy()

    best_indices = np.argsort(log_p)[::-1][:n_solutions]

    solutions = [([index], log_p[index], hidden_state) for index in best_indices]

    for i in range(1, max_next_words):
        new_solutions = []

        for (prefix, score, prefix_state) in solutions:
            x = torch.tensor([[prefix[-1]]], device=model.device)
            y_pred, hi = model(x, prefix_state)
            last_word_logits = y_pred[0][-1]
            log_p = torch.nn.functional.log_softmax(last_word_logits, dim=0).detach().cpu().numpy()
            best_indices = np.argsort(log_p)[::-1][:n_solutions]
            new_solutions += [(prefix + [ind], score + log_p[ind], hi) for ind in best_indices]

        best_indices = np.argsort([score for (_, score, _) in new_solutions])[::-1][:n_solutions]

        solutions = [new_solutions[ind] for ind in best_indices]

    return [([dataset.index_to_word[w] for w in sent], lp) for (sent, lp, _) in solutions]
            

In [58]:
beam_search(dataset, model, "świadkowie", max_next_words=3, n_solutions=3)

1107
281
0
803
347
1044


[(['zdarzenia', 'sfotografowali', 'sprawcę'], -0.000649821),
 (['zdarzenia', 'sfotografowali', 'w'], -8.26417),
 (['zdarzenia', 'sfotografowali', 'uczniów'], -9.916093)]