In [1]:
import gensim.utils as utils
from tqdm import tqdm
import torch
import torch.nn as nn
import numpy as np
import torch.optim as optim
from nltk.tokenize import word_tokenize
import sentencepiece as spm

In [2]:
class Dataset(torch.utils.data.Dataset):
    def __init__(self, corpus_length = None, device = None, corpus_path = './data/train_shuf.txt'):
        corpus_file = open(corpus_path)

        if device == None:
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        if corpus_length == None:
            corpus_length = sum(1 for line in corpus_file)
            corpus_file.seek(0)
        
        self.corpus = []

        for i in tqdm(range(corpus_length)):
            self.corpus.append(utils.simple_preprocess(corpus_file.readline(), min_len=1))
            # self.corpus.append(word_tokenize(corpus_file.readline().lower()))

        self.corpus = sorted(self.corpus, key=lambda x: len(x))
        
        self.unique_words = self.get_unique_words()

        self.index_to_word = {index: word for index, word in enumerate(self.unique_words)}
        self.word_to_index = {word: index for index, word in enumerate(self.unique_words)}

        self.input_corpus_indexes = [list(map(lambda word: self.word_to_index[word], sentence)) for sentence in self.corpus]
        output_corpus = [sentence[1:] + ['<STOP>'] for sentence in self.corpus]

        self.output_corpus_indexes = [list(map(lambda word: self.word_to_index[word], sentence)) for sentence in output_corpus]
        
        self.device = device


    def indexes_to_sentence(self, sentence):
        return list(map(lambda x: self.index_to_word[x], sentence))


    def get_unique_words(self):
        words = list(set([word for line in self.corpus for word in line]))
        words.sort()
        words = ['<PAD>', '<STOP>'] + words
        self.pad_index = 0
        self.stop_index = 1
        return words

    def __len__(self):
        return len(self.corpus)

    def __getitem__(self, index):
        return (torch.tensor(self.input_corpus_indexes[index], device=self.device),
            torch.tensor(self.output_corpus_indexes[index], device=self.device))

In [3]:
# torch.cuda.is_available() checks and returns a Boolean True if a GPU is available, else it'll return False
is_cuda = torch.cuda.is_available()

# If we have a GPU available, we'll set our device to GPU. We'll use this device variable later in our code.
if is_cuda:
    device = torch.device("cuda")
    print("GPU is available")
else:
    device = torch.device("cpu")
    print("GPU not available, CPU used")

GPU is available


In [4]:
def pad_collate(data):
    def left_pad_sequence(tensors):
        max_len = max(list(map(len, tensors)))
        padded_seq = [torch.hstack([torch.zeros(max_len - len(t), device=t.device, dtype=torch.int32), t]) for t in tensors]
        return torch.stack(padded_seq)


    inputs = [d[0] for d in data]
    outputs = [d[1] for d in data]
    inputs = left_pad_sequence(inputs)
    outputs = left_pad_sequence(outputs)
    return inputs, outputs

In [5]:
dataset = Dataset(corpus_length = 1000000, device=device)
# loader = torch.utils.data.DataLoader(dataset, batch_size=5, collate_fn=pad_collate)

100%|██████████| 1000000/1000000 [00:18<00:00, 53660.23it/s]


In [6]:
len(dataset.unique_words)

555120

In [None]:
# for x,y in loader:
#     for s_in, s_out in zip(x,y):
#         print(dataset.indexes_to_sentence([x.item() for x in s_out]))
#     print('-----------------------------------')

In [None]:
torch.cuda.empty_cache()
print(torch.cuda.memory.mem_get_info())
print(torch.cuda.memory_allocated())
print(torch.cuda.memory_reserved())

(10472914944, 15843721216)
4005144064
4022337536


In [None]:
class RNN(nn.Module):
    def __init__(self, dataset, device, embedding_dim=100, hidden_size = 128, num_layers = 2):
        super(RNN, self).__init__()
        self.device = device

        self.num_layers = num_layers
        self.hidden_size = hidden_size
        
        n_vocab = len(dataset.unique_words)

        self.embedding = nn.Embedding(
            num_embeddings=n_vocab,
            embedding_dim=embedding_dim,
            padding_idx=0
        )

        self.rnn = nn.RNN(input_size=embedding_dim, hidden_size=hidden_size, num_layers=num_layers, batch_first=True)

        self.fc = nn.Linear(hidden_size, n_vocab)

    def forward(self, x, h0 = None):

        x.to(self.device)

        embed = self.embedding(x)

        if h0 == None:
            h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size, device=self.device)

        output, state = self.rnn(embed, h0)
        
        logits = self.fc(output)

        return logits, state


        
model = RNN(dataset, device) 
model.to(device)

RNN(
  (embedding): Embedding(791020, 100, padding_idx=0)
  (rnn): RNN(100, 128, num_layers=2, batch_first=True)
  (fc): Linear(in_features=128, out_features=791020, bias=True)
)

In [None]:
torch.cuda.empty_cache()
print(torch.cuda.memory.mem_get_info())
print(torch.cuda.memory_allocated())
print(torch.cuda.memory_reserved())

(9749397504, 15843721216)
4730230272
4745854976


In [None]:
# model.load_state_dict(torch.load('./models/RNN_30ep.model'))

In [None]:
def train(dataset, model, max_epochs = 30, batch_size = 20):
    model.train()

    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, collate_fn=pad_collate)
    
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    for epoch in range(max_epochs):        
        for batch, (x, y) in enumerate(dataloader):
            optimizer.zero_grad()

            y_pred, _ = model(x)
            loss = criterion(y_pred.transpose(1, 2), y)
            
            loss.backward()
            optimizer.step()

            if batch % 500 == 0:
                print({ 'epoch': epoch, 'batch': batch, 'loss': loss.item() })
                torch.cuda.empty_cache()
        
        if (epoch+1) % 5 == 0:
            torch.save(model.state_dict(), f"./models/RNN_2000000_{epoch+1}ep.model")
            
train(dataset, model)

{'epoch': 0, 'batch': 0, 'loss': 13.635276794433594}
{'epoch': 0, 'batch': 500, 'loss': 3.255164623260498}
{'epoch': 0, 'batch': 1000, 'loss': 4.356067657470703}
{'epoch': 0, 'batch': 1500, 'loss': 2.247701644897461}
{'epoch': 0, 'batch': 2000, 'loss': 6.966792583465576}
{'epoch': 0, 'batch': 2500, 'loss': 6.141841411590576}
{'epoch': 0, 'batch': 3000, 'loss': 4.814615249633789}
{'epoch': 0, 'batch': 3500, 'loss': 5.683190822601318}
{'epoch': 0, 'batch': 4000, 'loss': 5.349299907684326}
{'epoch': 0, 'batch': 4500, 'loss': 6.729761123657227}
{'epoch': 0, 'batch': 5000, 'loss': 7.206076622009277}
{'epoch': 0, 'batch': 5500, 'loss': 7.208861827850342}
{'epoch': 0, 'batch': 6000, 'loss': 5.900472164154053}
{'epoch': 0, 'batch': 6500, 'loss': 7.412659168243408}
{'epoch': 0, 'batch': 7000, 'loss': 6.269425868988037}
{'epoch': 0, 'batch': 7500, 'loss': 7.001075744628906}
{'epoch': 0, 'batch': 8000, 'loss': 6.401076316833496}
{'epoch': 0, 'batch': 8500, 'loss': 7.095688343048096}
{'epoch': 0, 

In [None]:
# torch.save(model.state_dict(), './models/RNN_60ep.model')

In [None]:
def predict(dataset, model, text, next_words=100):
    model.eval()

    words = text.split(' ')
    
    for i in range(0, next_words):
        x = torch.tensor([[dataset.word_to_index[w] for w in words]], device=model.device)
        y_pred, _ = model(x)

        # print(y_pred)

        last_word_logits = y_pred[0][-1]
        p = torch.nn.functional.softmax(last_word_logits, dim=0).detach().cpu().numpy()
        word_index = np.random.choice(len(last_word_logits), p=p)
        words.append(dataset.index_to_word[word_index])

    return words


In [None]:
def predict_2(dataset, model, text, next_words=100):
    model.eval()

    words = text.split(' ')

    x = torch.tensor([[dataset.word_to_index[w] for w in words]], device=model.device)
    y_pred, hidden_state = model(x)
    
    for i in range(0, next_words):
        last_word_logits = y_pred[0][-1]
        p = torch.nn.functional.softmax(last_word_logits, dim=0).detach().cpu().numpy()
        word_index = np.random.choice(len(last_word_logits), p=p)
        words.append(dataset.index_to_word[word_index])

        y_pred, hidden_state = model(torch.tensor([[word_index]], device=model.device), hidden_state)

    return words


In [None]:
predict_2(dataset, model, "świadkowie", next_words=15)

['świadkowie',
 'zdarzenia',
 'sfotografowali',
 'sprawcę',
 'podczas',
 'gwałtu',
 'i',
 'powiadomili',
 'policję',
 'która',
 'schwytała',
 'go',
 'kilka',
 'godzin',
 'później',
 '<STOP>']

In [None]:
# def best_logits(logits, n):


def beam_search(dataset, model, text, max_next_words, n_solutions):
    model.eval()

    words = text.split(' ')

    x = torch.tensor([[dataset.word_to_index[w] for w in words]], device=model.device)

    y_pred, hidden_state = model(x)
    last_word_logits = y_pred[0][-1]
    log_p = torch.nn.functional.log_softmax(last_word_logits, dim=0).detach().cpu().numpy()

    best_indices = np.argsort(log_p)[::-1][:n_solutions]

    solutions = [([index], log_p[index], hidden_state) for index in best_indices]

    for i in range(1, max_next_words):
        new_solutions = []

        for (prefix, score, prefix_state) in solutions:
            x = torch.tensor([[prefix[-1]]], device=model.device)
            y_pred, hi = model(x, prefix_state)
            last_word_logits = y_pred[0][-1]
            log_p = torch.nn.functional.log_softmax(last_word_logits, dim=0).detach().cpu().numpy()
            best_indices = np.argsort(log_p)[::-1][:n_solutions]
            new_solutions += [(prefix + [ind], score + log_p[ind], hi) for ind in best_indices]

        best_indices = np.argsort([score for (_, score, _) in new_solutions])[::-1][:n_solutions]

        solutions = [new_solutions[ind] for ind in best_indices]

    return [([dataset.index_to_word[w] for w in sent], lp) for (sent, lp, _) in solutions]
            

In [None]:
beam_search(dataset, model, "świadkowie", max_next_words=3, n_solutions=3)

[(['zdarzenia', 'sfotografowali', 'sprawcę'], -0.000649821),
 (['zdarzenia', 'sfotografowali', 'w'], -8.26417),
 (['zdarzenia', 'sfotografowali', 'uczniów'], -9.916093)]