In [1]:
import re

import pandas as pd
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim

from torch.utils.data import DataLoader, Dataset
from torch.optim import *

from torchtext.data import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator

from nltk.tokenize.treebank import TreebankWordDetokenizer

from datasets import load_dataset

from tqdm.notebook import tqdm

In [2]:
class QADataset(Dataset):
    """Consecutive-utterance pairs from a dialogue dataset, encoded as fixed-length index tensors.

    Every dialogue is split into utterances on its ``#PersonN#: `` speaker tags. Each
    utterance becomes an input (``sent1``) and the utterance that follows it becomes the
    target (``sent2``). Duplicate (input, target) pairs are dropped.

    Args:
        dataset: dataset with a ``'dialogue'`` column; it must support ``.map`` and column
            indexing (presumably a Hugging Face ``datasets.Dataset`` — verify against caller).
        tokenizer: callable ``str -> iterable of str``; ``None`` uses torchtext's spaCy tokenizer.
        vocab: callable mapping a token list to an index list; ``None`` builds one from ``dataset``.
        maxlen: length every encoded sequence is truncated/padded to, special tokens included.
        special_tokens: dict with ``'bos_token'``, ``'eos_token'``, ``'pad_token'``, ``'unk_token'``.
        device: device on which the encoded tensors are allocated.
    """

    def __init__(self, dataset, tokenizer, vocab, maxlen, special_tokens, device):
        super().__init__()

        self.maxlen = maxlen
        self.special_tokens = special_tokens
        self.device = device

        self.tokenizer = get_tokenizer('spacy') if tokenizer is None else tokenizer
        self.vocab = self.get_vocab(dataset) if vocab is None else vocab

        sent1, sent2 = self.create_pair_dataset(dataset)
        self.sent1 = self.get_tensor(sent1)
        self.sent2 = self.get_tensor(sent2)

    def __getitem__(self, i):
        return self.sent1[i], self.sent2[i]

    def __len__(self):
        return len(self.sent1)

    @staticmethod
    def _split_dialogue(dialogue):
        """Split a raw dialogue string into its utterances, dropping the speaker tags."""
        # re.split yields the text before the first tag as element 0; it is not an utterance.
        return re.split(r'\s*#Person\d#: ', dialogue)[1:]

    def create_sentence_pairs(self, x):
        """``dataset.map`` callback: turn one example's dialogue into aligned (utterance, reply) lists."""
        sentences = self._split_dialogue(x['dialogue'])
        return {'sent1': sentences[:-1], 'sent2': sentences[1:]}

    def create_pair_dataset(self, dataset):
        """Return deduplicated ``(inputs, targets)`` utterance lists for the whole dataset."""
        dataset = dataset.map(self.create_sentence_pairs)

        flatten_sent1 = [sent for sents in dataset['sent1'] for sent in sents]
        flatten_sent2 = [sent for sents in dataset['sent2'] for sent in sents]

        df = pd.DataFrame({'sent1': flatten_sent1, 'sent2': flatten_sent2}).drop_duplicates()
        return df['sent1'].tolist(), df['sent2'].tolist()

    def tokenize_sent(self, sent):
        """Tokenize ``sent`` into exactly ``maxlen`` tokens: BOS, text tokens, EOS, then padding.

        Special tokens are added after tokenization so the tokenizer cannot split or alter
        them. The text is truncated first so EOS is kept even for long sentences.
        """
        bos = self.special_tokens['bos_token']
        eos = self.special_tokens['eos_token']
        pad = self.special_tokens['pad_token']

        body = list(self.tokenizer(sent))[:max(self.maxlen - 2, 0)]
        # The outer slice only matters when maxlen < 2 (no room for both BOS and EOS).
        tokens = ([bos] + body + [eos])[:self.maxlen]
        return tokens + [pad] * (self.maxlen - len(tokens))

    def get_tensor(self, sents):
        """Encode sentences into a ``(len(sents), maxlen)`` long tensor of vocabulary indices."""
        ids = [self.vocab(self.tokenize_sent(sent)) for sent in sents]
        # reshape keeps the (0, maxlen) shape when `sents` is empty.
        return torch.tensor(ids, dtype=torch.long, device=self.device).reshape(len(ids), self.maxlen)

    def get_vocab(self, dataset):
        """Build a vocabulary over every utterance of every dialogue in ``dataset``.

        All special tokens are included. Tokens seen fewer than 5 times are left out and
        map to the ``unk_token`` index.
        """
        sentences = [sent for dialogue in dataset['dialogue'] for sent in self._split_dialogue(dialogue)]
        tokenized_dataset = [self.tokenize_sent(sent) for sent in sentences]

        vocab = build_vocab_from_iterator(tokenized_dataset, min_freq=5,
                                          specials=list(self.special_tokens.values()))
        vocab.set_default_index(vocab[self.special_tokens['unk_token']])
        return vocab


In [3]:
class EncoderModel(nn.Module):
    """Bidirectional LSTM encoder that summarises a token sequence into its final LSTM states.

    Args:
        vocab_size: number of embedding rows.
        embed_dim: embedding width.
        n_layers: stacked LSTM layers per direction.
        n_hidden: LSTM hidden size.
        dropout_rate: dropout probability applied to the embeddings.

    ``forward`` returns ``(h_n, c_n)``. Because the LSTM is bidirectional, each stacks
    ``2 * n_layers`` layer states on dim 0 (``nn.LSTM`` convention).
    """

    def __init__(self, vocab_size, embed_dim=64, n_layers=32, n_hidden=128, dropout_rate=0.2):
        super().__init__()

        # Hyperparameters kept as attributes for later inspection.
        (self.vocab_size, self.embed_dim, self.n_layers,
         self.n_hidden, self.dropout_rate) = (vocab_size, embed_dim, n_layers, n_hidden, dropout_rate)

        # Registration order (embed, lstm, dropout) fixes parameter-init order and state_dict layout.
        self.embed = nn.Embedding(vocab_size, embedding_dim=embed_dim)
        self.lstm = nn.LSTM(embed_dim, n_hidden, num_layers=n_layers,
                            bidirectional=True, batch_first=True)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x):
        """Encode ``x`` and return the LSTM's final ``(h_n, c_n)``; per-step outputs are discarded.

        ``x`` holds token indices, batch-first (assumed ``(batch, seq_len)`` — TODO confirm).
        """
        embedded = self.dropout(self.embed(x))
        _, (h_n, c_n) = self.lstm(embedded)
        return h_n, c_n


class DecoderModel(nn.Module):
    """Unidirectional LSTM decoder that emits per-step log-probabilities over the vocabulary.

    The LSTM starts from ``states`` supplied by the caller. In ``train_step`` these are the
    encoder's ``(h_n, c_n)``, so their shape must fit this LSTM: ``n_layers`` layer states
    of size ``n_hidden`` each.

    Args:
        vocab_size: number of embedding rows and output classes.
        embed_dim: embedding width.
        n_layers: stacked LSTM layers.
        n_hidden: LSTM hidden size.
        dropout_rate: dropout probability applied to the embeddings.
    """

    def __init__(self, vocab_size, embed_dim=64, n_layers=32, n_hidden=128, dropout_rate=0.2):
        super().__init__()

        # Hyperparameters kept as attributes for later inspection.
        (self.vocab_size, self.embed_dim, self.n_layers,
         self.n_hidden, self.dropout_rate) = (vocab_size, embed_dim, n_layers, n_hidden, dropout_rate)

        # Registration order (embed, lstm, fc, softmax, dropout) fixes init order and state_dict layout.
        self.embed = nn.Embedding(vocab_size, embedding_dim=embed_dim)
        self.lstm = nn.LSTM(embed_dim, n_hidden, num_layers=n_layers, batch_first=True)
        self.fc = nn.Linear(n_hidden, vocab_size)
        self.softmax = nn.LogSoftmax(dim=-1)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x, states):
        """Run the decoder over ``x`` starting from ``states``.

        Returns:
            ``(log_probs, new_states)``: log-softmax scores over the vocabulary on the last
            dim, and the LSTM's final ``(h_n, c_n)``.
        """
        lstm_in = self.dropout(self.embed(x))
        lstm_out, new_states = self.lstm(lstm_in, states)
        log_probs = self.softmax(self.fc(lstm_out))
        return log_probs, new_states


In [4]:
def train_step(enc_model, dec_model, loader, enc_opt, dec_opt, criterion, device):
    """Run one training epoch of the encoder/decoder pair with teacher forcing.

    For every ``(x, y)`` batch the encoder reads ``x``. The decoder starts from the
    encoder's final states, reads ``y[:, :-1]`` and is scored against the next tokens
    ``y[:, 1:]``.

    Args:
        enc_model, dec_model: encoder and decoder modules.
        loader: iterable of ``(x, y)`` index-tensor batches.
        enc_opt, dec_opt: optimizers for the encoder and decoder parameters.
        criterion: loss taking ``(batch, vocab, seq)`` log-probs and ``(batch, seq)`` targets.
        device: device the batches are moved to.

    Returns:
        Mean batch loss over the epoch (``nan`` for an empty loader).
    """
    enc_model.train()
    dec_model.train()

    pbar = tqdm(loader)
    batch_losses = []

    for x, y in pbar:
        x = x.to(device)
        y = y.to(device)

        # Shift by one: feeding y and scoring against the same y only teaches the decoder to copy.
        dec_input, dec_target = y[:, :-1], y[:, 1:]

        enc_opt.zero_grad()
        dec_opt.zero_grad()

        enc_states = enc_model(x)
        dec_out, _ = dec_model(dec_input, enc_states)

        # NLLLoss expects class scores on dim 1: (batch, seq, vocab) -> (batch, vocab, seq).
        loss = criterion(dec_out.moveaxis(1, -1), dec_target)
        loss.backward()

        enc_opt.step()
        dec_opt.step()

        batch_losses.append(loss.item())
        pbar.set_description(f'Batch Loss: {loss.item():.3f} Train Loss: {np.mean(batch_losses):.3f}')

    return float(np.mean(batch_losses)) if batch_losses else float('nan')


def eval_step(enc_model, dec_model, loader, criterion, device):
    """Compute the mean teacher-forced validation loss of the encoder/decoder pair.

    The signature mirrors ``train_step`` and matches how ``train`` calls this function
    (with both models). The old ``(model, loader, criterion, device)`` form could not run:
    batches are ``(x, y)`` tuples, and ``y`` was never defined.

    Args:
        enc_model, dec_model: encoder and decoder modules.
        loader: iterable of ``(x, y)`` index-tensor batches.
        criterion: loss taking ``(batch, vocab, seq)`` log-probs and ``(batch, seq)`` targets.
        device: device the batches are moved to.

    Returns:
        Mean batch loss (``nan`` for an empty loader).
    """
    enc_model.eval()
    dec_model.eval()

    pbar = tqdm(loader)
    batch_losses = []

    # Evaluation needs no autograd graph.
    with torch.no_grad():
        for x, y in pbar:
            x = x.to(device)
            y = y.to(device)

            enc_states = enc_model(x)
            # Same one-token shift as training: read y[:, :-1], predict y[:, 1:].
            dec_out, _ = dec_model(y[:, :-1], enc_states)
            loss = criterion(dec_out.moveaxis(1, -1), y[:, 1:])

            batch_losses.append(loss.item())
            pbar.set_description(f'Batch Loss: {loss.item():.3f} Validation Loss: {np.mean(batch_losses):.3f}')

    return float(np.mean(batch_losses)) if batch_losses else float('nan')


def answer(model, sent, tokenizer, vocab, maxlen, special_tokens, device):
    """Generate a reply to ``sent`` with greedy decoding.

    Args:
        model: an ``(encoder, decoder)`` pair, e.g. the ``EncoderModel``/``DecoderModel``
            trained by ``train_step``. A single module is not accepted.
        sent: raw input utterance.
        tokenizer: callable ``str -> iterable of str`` (the one used to build ``vocab``).
        vocab: vocabulary supporting ``vocab(tokens)``, ``vocab[token]`` and ``lookup_token``.
        maxlen: training sequence length; the prompt is padded to it and at most
            ``maxlen - 1`` reply tokens are generated.
        special_tokens: dict with ``'bos_token'``, ``'eos_token'`` and ``'pad_token'``.
        device: device for the input tensors.

    Returns:
        The detokenized reply, without BOS/EOS.
    """
    encoder, decoder = model
    encoder.eval()
    decoder.eval()

    detokenizer = TreebankWordDetokenizer()
    bos = special_tokens['bos_token']
    eos = special_tokens['eos_token']
    pad = special_tokens['pad_token']

    # Lay the prompt out like a training input: BOS, text tokens, EOS, padded to maxlen.
    body = list(tokenizer(sent))[:max(maxlen - 2, 0)]
    prompt = ([bos] + body + [eos])[:maxlen]
    prompt += [pad] * (maxlen - len(prompt))
    src = torch.tensor([vocab(prompt)], dtype=torch.long, device=device)  # (1, maxlen)

    eos_idx = vocab[eos]
    words = []

    with torch.no_grad():
        states = encoder(src)
        token = torch.tensor([[vocab[bos]]], dtype=torch.long, device=device)

        for _ in range(maxlen - 1):
            log_probs, states = decoder(token, states)
            next_idx = int(log_probs[0, -1].argmax())
            if next_idx == eos_idx:
                break
            words.append(vocab.lookup_token(next_idx))
            # Feed only the newest token; the LSTM states carry the history.
            token = torch.tensor([[next_idx]], dtype=torch.long, device=device)

    return detokenizer.detokenize(words)

In [5]:
def train(max_len, bs, lr, epochs, device='cpu', **kwargs):
    """Train the encoder/decoder dialogue model on DialogSum and return the final validation loss.

    Args:
        max_len: number of tokens every utterance is truncated/padded to.
        bs: batch size for the train and validation loaders.
        lr: Adam learning rate for both optimizers.
        epochs: number of passes over the training set.
        device: device for the dataset tensors and both models. It is an explicit
            parameter, so keyword calls like ``device=device`` behave as before.
        **kwargs: ``encoder_<name>`` / ``decoder_<name>`` arguments, forwarded with the
            prefix stripped to ``EncoderModel`` / ``DecoderModel``.

    Returns:
        Validation loss after the last epoch, or ``nan`` when ``epochs`` is 0.
    """
    special_tokens = {'bos_token': '|BOS|',
                      'pad_token': '|PAD|',
                      'eos_token': '|EOS|',
                      'unk_token': '|UNK|'}

    train_dataset = load_dataset('knkarthick/dialogsum', split='train')
    val_dataset = load_dataset('knkarthick/dialogsum', split='validation')

    lm_train = QADataset(train_dataset, None, None, max_len, special_tokens, device)
    lm_valid = QADataset(val_dataset, lm_train.tokenizer, lm_train.vocab, max_len, special_tokens, device)

    train_loader = DataLoader(lm_train, batch_size=bs, shuffle=True)
    val_loader = DataLoader(lm_valid, batch_size=bs)

    encoder_kwargs = {k[len('encoder_'):]: v for k, v in kwargs.items() if k.startswith('encoder_')}
    decoder_kwargs = {k[len('decoder_'):]: v for k, v in kwargs.items() if k.startswith('decoder_')}

    encoder_model = EncoderModel(vocab_size=len(lm_train.vocab), **encoder_kwargs).to(device)
    decoder_model = DecoderModel(vocab_size=len(lm_train.vocab), **decoder_kwargs).to(device)

    # The bidirectional encoder's (h, c) states stack 2 * n_layers layers of size n_hidden,
    # and they are handed to the decoder LSTM unchanged as its initial state.
    assert decoder_model.n_layers == 2 * encoder_model.n_layers, \
        '# of decoder layers must be double the # of encoder layers.'
    assert decoder_model.n_hidden == encoder_model.n_hidden, \
        'encoder and decoder must use the same n_hidden.'

    encoder_optimizer = optim.Adam(encoder_model.parameters(), lr=lr)
    decoder_optimizer = optim.Adam(decoder_model.parameters(), lr=lr)

    # Padding positions are not real targets; leaving them in would let PAD dominate the loss.
    pad_idx = lm_train.vocab[special_tokens['pad_token']]
    criterion = nn.NLLLoss(ignore_index=pad_idx)

    val_loss = float('nan')
    for epoch in range(epochs):
        train_loss = train_step(encoder_model, decoder_model, train_loader,
                                encoder_optimizer, decoder_optimizer, criterion,
                                device)
        val_loss = eval_step(encoder_model, decoder_model,
                             val_loader, criterion, device)
        print(f'Epoch {epoch + 1}/{epochs}: train loss {train_loss:.3f}, validation loss {val_loss:.3f}')

    return val_loss


In [8]:
# Training configuration.
max_len = 50    # tokens per utterance after truncation/padding, BOS/EOS included
bs = 32         # batch size
lr = 1e-3       # Adam learning rate
epochs = 10
device = 'cpu'

# Seed torch's global RNG (weight init, dropout, DataLoader shuffling) for reproducible runs.
SEED = 42
torch.manual_seed(SEED)

In [9]:
# Model hyperparameters, forwarded to `train` as encoder_<name>= / decoder_<name>= keywords.
# NOTE(review): 32 bidirectional encoder layers and 64 decoder layers is a very deep LSTM stack;
# the recorded run below was interrupted before the first batch finished. Consider e.g. 2 / 4.
encoder_hparams = dict(embed_dim=64, n_layers=32, n_hidden=128, dropout_rate=0.2)
decoder_hparams = dict(embed_dim=64, n_layers=64, n_hidden=128, dropout_rate=0.2)

train(max_len, bs, lr, epochs, device=device,
      **{f'encoder_{name}': value for name, value in encoder_hparams.items()},
      **{f'decoder_{name}': value for name, value in decoder_hparams.items()})

Using custom data configuration knkarthick--dialogsum-caf2f3e75d9073aa
Found cached dataset csv (/Users/bugrahamzagundog/.cache/huggingface/datasets/knkarthick___csv/knkarthick--dialogsum-caf2f3e75d9073aa/0.0.0/6b34fb8fcf56f7c8ba51dc895bfa2bfbe43546f190a60fcf74bb5e8afdcc2317)
Using custom data configuration knkarthick--dialogsum-caf2f3e75d9073aa
Found cached dataset csv (/Users/bugrahamzagundog/.cache/huggingface/datasets/knkarthick___csv/knkarthick--dialogsum-caf2f3e75d9073aa/0.0.0/6b34fb8fcf56f7c8ba51dc895bfa2bfbe43546f190a60fcf74bb5e8afdcc2317)


  0%|          | 0/12460 [00:00<?, ?ex/s]

  0%|          | 0/12460 [00:00<?, ?ex/s]

  0%|          | 0/500 [00:00<?, ?ex/s]

  0%|          | 0/2948 [00:00<?, ?it/s]

KeyboardInterrupt: 