<a href="https://colab.research.google.com/github/DmitriySechkin/ds-learning-sb/blob/main/RNN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [17]:
import datetime

import torch
import torch.nn as nn
from torch.utils.data import Dataset
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader

import random

In [7]:
# Mount Google Drive at /content/drive so files stored under 'My Drive' can be read by later cells.
from google.colab import drive
drive.mount('/content/drive')


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [11]:
# Google Drive is mounted at /content/drive (mount cell above); an absolute path keeps
# loading independent of the current working directory. Must end with '/', because
# DatasetSeq concatenates it directly with the file name.
data_dir = '/content/drive/My Drive/'
train_lang = 'en2'  # corpus file read: <data_dir>en2.train

# Fix the RNG seeds so the dataset shuffle (global `random`) and weight init (torch) are reproducible.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

In [38]:
import torch
from torch.utils.data import Dataset
from torch.nn.utils.rnn import pad_sequence

class DatasetSeq(Dataset):
    """Word-level sequence-labelling dataset read from ``<data_dir><train_lang>.train``.

    The file holds sentences separated by a blank line; every non-empty line is
    ``"<word> <label>"`` with a single space between the two fields.

    The filtered, shuffled corpus is cached on the class per file path. A train
    instance and a test instance built from the same file therefore share one
    shuffle and get disjoint splits.

    Vocabularies ({<str> token: <int> id}) are always built from the train split,
    so ids mean the same thing in both instances. Id 0 is reserved for padding
    and id 1 for ``UNK``; words, characters and labels that do not occur in the
    train split are encoded as 1.

    Args:
        data_dir: path prefix, concatenated as-is with the file name.
        train_lang: file stem; the file read is ``<train_lang>.train``.
        train: if True expose the first ``train_len`` sentences, otherwise the rest.
        train_len: number of sentences in the train split.
        seed: if given, the first (and only) shuffle of a file uses
            ``random.Random(seed)`` for a reproducible split; if None the global
            ``random`` state is used. Ignored once the file is cached.
    """

    UNK = '<unk>'
    UNK_ID = 1  # id 0 is padding, so UNK takes the first real id

    # Shuffled sentence blocks per file path, shared by every instance.
    _corpus_cache = {}
    # Class-level defaults kept for backward compatibility; __init__ sets instance values.
    dataset = None
    train_dataset = None
    test_dataset = None

    def __init__(self, data_dir, train_lang='en', train=True, train_len=18000, seed=None):
        corpus = self._load_corpus(data_dir + train_lang + '.train', seed)
        self.dataset = corpus
        self.train_dataset = corpus[:train_len]
        self.test_dataset = corpus[train_len:]

        # Built from the train split for both instances, so test ids match training ids.
        self.word_vocab, self.target_vocab, self.char_vocab = self._build_vocabs(self.train_dataset)

        # Encoded sequences (processed data): one list per sentence.
        self.encoded_sequences = []       # word ids per token
        self.encoded_targets = []         # label ids per token
        self.encoded_char_sequences = []  # list of char ids per token
        for block in (self.train_dataset if train else self.test_dataset):
            pairs = self._parse_block(block)
            self.encoded_sequences.append(
                [self.word_vocab.get(word, self.UNK_ID) for word, _ in pairs])
            self.encoded_targets.append(
                [self.target_vocab.get(label, self.UNK_ID) for _, label in pairs])
            self.encoded_char_sequences.append(
                [[self.char_vocab.get(ch, self.UNK_ID) for ch in word] for word, _ in pairs])

    @classmethod
    def _load_corpus(cls, path, seed=None):
        """Return the filtered, shuffled sentence blocks of ``path``; the file is read and shuffled once."""
        corpus = cls._corpus_cache.get(path)
        if corpus is None:
            with open(path, 'r', encoding='utf-8') as f:
                blocks = f.read().split('\n\n')
            # Drop empty blocks (e.g. produced by trailing blank lines) and every
            # sentence containing '_ ' (extra tag markup).
            corpus = [block for block in blocks if block.strip() and '_ ' not in block]
            if seed is None:
                random.shuffle(corpus)
            else:
                random.Random(seed).shuffle(corpus)
            cls._corpus_cache[path] = corpus
        return corpus

    @staticmethod
    def _parse_block(block):
        """Split one sentence block into ``(word, label)`` pairs, skipping empty lines."""
        pairs = []
        for line in block.split('\n'):
            if line != '':
                word, label = line.split(' ')  # raises ValueError on a malformed line
                pairs.append((word, label))
        return pairs

    @classmethod
    def _build_vocabs(cls, blocks):
        """Return ``(word_vocab, target_vocab, char_vocab)`` built from ``blocks``; ids start at UNK_ID."""
        word_vocab = {cls.UNK: cls.UNK_ID}
        target_vocab = {cls.UNK: cls.UNK_ID}
        char_vocab = {cls.UNK: cls.UNK_ID}
        for block in blocks:
            for word, label in cls._parse_block(block):
                # Ids are assigned in first-seen order; ids 1..len(vocab) are taken,
                # so a new token gets len(vocab) + 1.
                word_vocab.setdefault(word, len(word_vocab) + 1)
                target_vocab.setdefault(label, len(target_vocab) + 1)
                for ch in word:
                    char_vocab.setdefault(ch, len(char_vocab) + 1)
        return word_vocab, target_vocab, char_vocab

    def __len__(self):
        return len(self.encoded_sequences)

    def __getitem__(self, index):
        """Return the encoded sentence ``index``; all three lists have one entry per token."""
        return {
            'data': self.encoded_sequences[index],        # e.g. [2, 3, 4, 5, 6]
            'char': self.encoded_char_sequences[index],   # e.g. [[2, 3, 4], [5, 6], ...]
            'target': self.encoded_targets[index],        # e.g. [2, 3, 4, 2, 5]
        }

In [39]:
# Train and test splits of the same <train_lang>.train corpus file.
train_dataset = DatasetSeq(data_dir, train_lang=train_lang, train=True)
test_dataset = DatasetSeq(data_dir, train_lang=train_lang, train=False)

18000 3236 21236
18000 3236 21236


In [None]:
# Padding example:
# seq1 = [1, 2, 3, 4]
# seq2 = [9, 7, 6, 4, 3, 7, 5]
# pad seq1 with zeros to the length of seq2:
# seq1 = [1, 2, 3, 4, 0, 0, 0]
# stack(seq1, seq2) -> [[1, 2, 3, 4, 0, 0, 0],
#                       [9, 7, 6, 4, 3, 7, 5]]

In [49]:
def train_model(train_dataloader, model, optimizer=None, criterion=None, device=None):
    """Run one training epoch of `model` over `train_dataloader`.

    Args:
        train_dataloader: iterable of batches, either dicts with 'data' and 'target'
            tensors (as produced by ``collate_fn``) or ``(inputs, labels)`` pairs.
        model: module whose output ends in a class dimension, e.g. (B, T, n_classes).
        optimizer: optimizer to step; defaults to the notebook-level ``optim``.
        criterion: loss function; defaults to the notebook-level ``loss_func``.
        device: device for the batch tensors; defaults to the device of the
            model's first parameter (CPU if it has none).

    Returns:
        Sum of the per-batch loss values (float).
    """
    if optimizer is None:
        optimizer = optim
    if criterion is None:
        criterion = loss_func
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else 'cpu'

    train_loss = 0.0
    model.train()

    for batch in train_dataloader:
        # collate_fn yields dicts; unpacking a dict directly would give its keys.
        if isinstance(batch, dict):
            inputs, labels = batch['data'], batch['target']
        else:
            inputs, labels = batch

        optimizer.zero_grad()
        logits = model(inputs.to(device))
        # CrossEntropyLoss expects (N, C) logits and (N,) class ids: flatten batch and time.
        loss = criterion(logits.reshape(-1, logits.size(-1)),
                         labels.to(device).reshape(-1))
        loss.backward()
        optimizer.step()

        train_loss += loss.item()

    return train_loss

In [50]:
def validate_model(test_dataloader, model, criterion=None, device=None):
    """Evaluate `model` on every batch of `test_dataloader` without updating it.

    Args:
        test_dataloader: iterable of batches, either dicts with 'data' and 'target'
            tensors (as produced by ``collate_fn``) or ``(inputs, labels)`` pairs.
        model: module whose output ends in a class dimension, e.g. (B, T, n_classes).
        criterion: loss function; defaults to the notebook-level ``loss_func``.
        device: device for the batch tensors; defaults to the device of the
            model's first parameter (CPU if it has none).

    Returns:
        Sum of the per-batch loss values (float). Leaves the model in eval mode.
    """
    if criterion is None:
        criterion = loss_func
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else 'cpu'

    test_loss = 0.0
    model.eval()

    with torch.no_grad():
        for batch in test_dataloader:
            # collate_fn yields dicts; unpacking a dict directly would give its keys.
            if isinstance(batch, dict):
                inputs, labels = batch['data'], batch['target']
            else:
                inputs, labels = batch

            logits = model(inputs.to(device))
            # CrossEntropyLoss expects (N, C) logits and (N,) class ids: flatten batch and time.
            loss = criterion(logits.reshape(-1, logits.size(-1)),
                             labels.to(device).reshape(-1))
            test_loss += loss.item()

    return test_loss

In [41]:
def collate_fn(batch):
    """Pad a list of dataset items into one batch.

    Args:
        batch: list of dicts with 'data' (word ids) and 'target' (label ids) lists,
            as returned by ``DatasetSeq.__getitem__``. The 'char' field is not used.

    Returns:
        dict with 'data' and 'target' LongTensors of shape (B, T_max),
        right-padded with 0.
    """
    # dtype=torch.long: torch.as_tensor([]) would otherwise give a float tensor, and
    # nn.Embedding / CrossEntropyLoss require integer ids.
    data = [torch.as_tensor(item['data'], dtype=torch.long) for item in batch]
    target = [torch.as_tensor(item['target'], dtype=torch.long) for item in batch]
    return {
        'data': pad_sequence(data, batch_first=True, padding_value=0),
        'target': pad_sequence(target, batch_first=True, padding_value=0),
    }

In [42]:
class RNNCellPredictor(nn.Module):
    """Sequence tagger that unrolls an ``nn.GRUCell`` step by step over time.

    Embedding -> GRU cell applied at every position with a zero initial state ->
    per-position linear classifier.

    Args:
        vocab_size: number of word ids (including padding id 0).
        emb_dim: embedding size.
        hidden_dim: GRU hidden size.
        n_classes: number of output classes.
    """

    def __init__(self, vocab_size, emb_dim, hidden_dim, n_classes):
        super().__init__()
        self.word_emb = nn.Embedding(vocab_size, emb_dim)
        self.gru_cell = nn.GRUCell(emb_dim, hidden_dim)
        self.classifier = nn.Linear(hidden_dim, n_classes)
        self.hidden_dim = hidden_dim

    def forward(self, x):
        """Map word ids ``x`` of shape (B, T) to logits of shape (B, T, n_classes)."""
        b, t = x.size()
        emb = self.word_emb(x)  # B x T x Emb
        # new_zeros matches the embedding's device AND dtype (works after .double()/.half()).
        hidden = emb.new_zeros(b, self.hidden_dim)
        steps = []
        for i in range(t):
            hidden = self.gru_cell(emb[:, i, :], hidden)
            steps.append(hidden)
        # torch.stack fails on an empty list, so a zero-length input gets an empty B x 0 x H tensor.
        gru_out = torch.stack(steps, dim=1) if steps else emb.new_zeros(b, 0, self.hidden_dim)

        return self.classifier(gru_out)


In [43]:
class RNNPredictor(nn.Module):
    """Sequence tagger: word embedding -> single-layer GRU -> per-token linear classifier.

    The GRU is ``batch_first``: the input is a (batch, time) tensor of word ids and
    the output is a (batch, time, n_classes) tensor of unnormalised logits.
    """

    def __init__(self, vocab_size, emb_dim, hidden_dim, n_classes):
        super().__init__()
        # Submodules keep their names and creation order, so seeded initialisation
        # and saved state_dicts are unaffected.
        self.word_emb = nn.Embedding(vocab_size, emb_dim)
        # TODO: try other recurrent architectures, e.g. nn.RNN or nn.LSTM.
        self.gru = nn.GRU(emb_dim, hidden_dim, batch_first=True)
        self.classifier = nn.Linear(hidden_dim, n_classes)
        self.hidden_dim = hidden_dim

    def forward(self, x):
        # nn.GRU returns (per-step outputs, final hidden state); only the former is classified.
        step_outputs = self.gru(self.word_emb(x))[0]
        return self.classifier(step_outputs)

In [None]:
# T x B (time-major layout)
# len([[first words], [second words], .. [last words]]) - sentence length T
# len([n-th words]) - batch size B

# B x T (batch-major layout, batch_first=True)
# len([[first sentence], [second sentence] .. ]) - batch size B
# len([first sentence]) - sentence length T

In [46]:
# Hyper-parameters.
# Sizes are "+ 1" because id 0 is reserved for padding in every vocab.
vocab_size = len(train_dataset.word_vocab) + 1
n_classes = len(train_dataset.target_vocab) + 1
n_chars = len(train_dataset.char_vocab) + 1
# TODO: try other model parameters.
emb_dim = 256
hidden = 256  # recurrent hidden size
n_epochs = 10
batch_size = 100
# GPU index to train on, or -1 for CPU: use the first GPU when the runtime has one.
cuda_device = 0 if torch.cuda.is_available() else -1
device = f'cuda:{cuda_device}' if cuda_device != -1 else 'cpu'

In [47]:
model = RNNPredictor(vocab_size, emb_dim, hidden, n_classes).to(device)
model.train()
optim = torch.optim.Adam(model.parameters(), lr=0.001)
# collate_fn pads targets with 0 and real label ids start at 1, so exclude padded
# positions from the loss (otherwise padding dominates the loss and its gradients).
loss_func = nn.CrossEntropyLoss(ignore_index=0)

In [48]:

# Build the loader once: its configuration is the same for every epoch, and
# shuffle=True still draws a new order each time the loader is iterated.
dataloader = DataLoader(train_dataset,
                        batch_size,
                        shuffle=True,
                        collate_fn=collate_fn,
                        drop_last=True,
                        )

for epoch in range(n_epochs):
    # The inference cell calls model.eval(); switch back in case this cell is re-run after it.
    model.train()
    for i, batch in enumerate(dataloader):
        optim.zero_grad()

        predict = model(batch['data'].to(device))  # n_classes logits per token
        # Flatten batch and time so CrossEntropyLoss sees (N, C) logits and (N,) targets.
        loss = loss_func(predict.view(-1, n_classes),
                         batch['target'].to(device).view(-1),
                         )
        loss.backward()
        optim.step()
        if i % 100 == 0:
            print(f'epoch: {epoch}, step: {i}, loss: {loss.item()}')

    torch.save(model.state_dict(), f'./rnn_chkpt_{epoch}.pth')

epoch: 0, step: 0, loss: 3.1065540313720703


KeyboardInterrupt: 

In [None]:

# Example: tag one sentence with the trained model.
phrase = 'He ran quickly after the red bus and caught it .'
words = phrase.split(' ')
word_vocab = train_dataset.word_vocab
# A plain dict lookup raises KeyError on out-of-vocabulary words: fall back to the
# '<unk>' id if the vocab defines one, otherwise to the padding id 0.
unk_id = word_vocab.get('<unk>', 0)
tokens = [word_vocab.get(w, unk_id) for w in words]

start = datetime.datetime.now()
model.eval()
with torch.no_grad():
    predict = model(torch.tensor([tokens]).to(device))  # 1 x T x N_classes
    # Take the single batch row instead of squeeze(): squeeze() turns a one-word phrase into a scalar.
    labels = predict.argmax(dim=-1)[0].cpu().tolist()
elapsed = datetime.datetime.now() - start

# Invert the label vocab; id 0 is padding and has no label of its own.
id2label = {idx: label for label, idx in train_dataset.target_vocab.items()}
print([id2label.get(l, '<pad>') for l in labels])
print(f'inference time: {elapsed}')

['PRON', 'VERB', 'ADV', 'SCONJ', 'DET', 'ADJ', 'NOUN', 'CCONJ', 'VERB', 'PRON', 'PUNCT']
