In [1]:
import math
import random

import torch
import torch.nn as nn
import torch.nn.functional as F

from torchtext.data import Field, BucketIterator, TabularDataset
from pytorch_pretrained import BertModel, BertTokenizer

In [2]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# The model classes further down (e.g. TransBertDecoder.forward, GruDecoder.forward)
# read the device from the name `g_device`; expose it under that name as well so
# they work on a fresh kernel.
g_device = device
device

device(type='cuda')

In [3]:
# Location handed to `from_pretrained` for both the tokenizer and the model.
bert_path = 'BERT_BASE'
tokenizer = BertTokenizer.from_pretrained(bert_path)
# Hard-coded sizes — presumably the vocabulary size and hidden width of the
# checkpoint at `bert_path`; TODO confirm against the loaded tokenizer/model.
vocab_size = 30522
bert_emb_dim = 768
# Upper bound on tokens per sequence; tokenize_and_cut truncates against it.
max_input_length = 512
bert = BertModel.from_pretrained(bert_path)

# The model classes defined later in this notebook refer to these objects by
# `g_`-prefixed names (g_bert, g_bert_emb_dim, g_vocab_size, g_tokenizer).
# Without these aliases the classes raise NameError on a fresh kernel.
g_tokenizer = tokenizer
g_vocab_size = vocab_size
g_bert_emb_dim = bert_emb_dim
g_bert = bert

In [4]:
def tokenize_and_cut(sentence):
    """Tokenize `sentence` with the module-level `tokenizer` and truncate it.

    The result is capped at `max_input_length - 2` tokens, which leaves two
    free slots (the Fields that use this function add an init token and an
    eos token of their own).

    Args:
        sentence: text passed straight to `tokenizer.tokenize`.

    Returns:
        The (possibly shortened) token sequence returned by the tokenizer.
    """
    token_budget = max_input_length - 2
    return tokenizer.tokenize(sentence)[:token_budget]

In [5]:
# Sanity check: show the tokens produced for a short sample sentence.
tokenize_and_cut('What the fuck are you talking about?')

['what', 'the', 'fuck', 'are', 'you', 'talking', 'about', '?']

In [6]:
# Sanity check: map the same sample tokens to their vocabulary ids.
tokenizer.convert_tokens_to_ids(tokenize_and_cut('What the fuck are you talking about?'))

[2054, 1996, 6616, 2024, 2017, 3331, 2055, 1029]

In [7]:
# Special tokens paired with their ids. The ids are hard-coded here rather than
# looked up — presumably they match the vocabulary loaded above; verify if the
# checkpoint changes.
PAD, pad_token_id = '[PAD]', 0
UNK, unk_token_id = '[UNK]', 100
CLS, cls_token_id = '[CLS]', 101
SEP, sep_token_id = '[SEP]', 102

In [8]:
def _make_bert_field():
    """Build a Field that yields token ids directly (no torchtext vocab).

    Text is tokenized/truncated by `tokenize_and_cut`, converted to ids by the
    tokenizer, and wrapped with the CLS/SEP ids; padding and unknown ids are the
    hard-coded special-token ids.
    """
    return Field(
        use_vocab=False,
        tokenize=tokenize_and_cut,
        preprocessing=tokenizer.convert_tokens_to_ids,
        init_token=cls_token_id,
        eos_token=sep_token_id,
        pad_token=pad_token_id,
        unk_token=unk_token_id,
    )


# Source (request) and target (response) sides use the same configuration,
# but each gets its own Field instance.
SRC = _make_bert_field()
TGT = _make_bert_field()

# Column order matters: fields are matched to CSV columns by position.
data_fields = [('src', SRC), ('tgt', TGT)]


In [9]:
import pandas as pd 
from sklearn.model_selection import train_test_split

In [10]:
# Load the raw pairs and pull the `q` and `a` columns out as arrays.
df = pd.read_csv('./data/chatbotdata1.csv')
src, tgt = df['q'].values, df['a'].values

In [11]:
# First split: keep 60% of the pairs for training and hold out 40%
# (test_size=0.4). random_state is fixed so the split is reproducible.
X_train, X_test, Y_train, Y_test = train_test_split(src, tgt, test_size=0.4, random_state=666)

In [18]:
# Second split: divide the held-out portion in half into test and dev sets.
# NOTE(review): this overwrites X_test / Y_test with a subset of themselves, so
# the cell is not idempotent — re-running it without re-running the first split
# halves the test set again.
X_test, X_dev, Y_test, Y_dev = train_test_split(X_test, Y_test, test_size=0.5, random_state=666)

In [19]:
# Column layout for each split: `q` holds the inputs, `a` the matching targets.
train = {'q': X_train, 'a': Y_train}
test = {'q': X_test, 'a': Y_test}
dev = {'q': X_dev, 'a': Y_dev}

# One DataFrame per split, built from the dicts above.
train_df, test_df, dev_df = (
    pd.DataFrame(split_columns) for split_columns in (train, test, dev)
)

In [21]:
# Persist the splits for TabularDataset to read back.
# index=False is required: by default pandas writes the row index as an extra
# leading column, and the dataset below binds fields to columns by position —
# with the index present, `src` would receive the row number and `tgt` the
# question, and the answers would be dropped.
train_df.to_csv('./data/train.csv', index=False)
test_df.to_csv('./data/test.csv', index=False)
dev_df.to_csv('./data/dev.csv', index=False)

In [12]:
# Read the three CSV splits back as torchtext datasets, binding columns to
# `data_fields` by position.
# skip_header=True: the CSVs are written by DataFrame.to_csv, which emits a
# header row; without skipping it the column names themselves would be
# tokenized and used as a training example.
train_data, validation_data, test_data = TabularDataset.splits(
                    path = './data/',
                    format='csv',
                    train='train.csv',
                    validation='dev.csv',
                    test='test.csv',
                    skip_header=True,
                    fields= data_fields
)

In [13]:
# Display the training dataset object to confirm it was created.
train_data

<torchtext.data.dataset.TabularDataset at 0x1c7923e4a90>

In [14]:
# Wrap the three datasets in bucketing iterators that place batches on `device`.
# Examples are grouped by the length of their source sequence; the order inside
# each batch is left unsorted.
train_iterator, validation_iterator, test_iterator = BucketIterator.splits(
    (train_data, validation_data, test_data),
    batch_size=62,
    device=device,
    sort_within_batch=False,
    sort_key=lambda example: len(example.src),  # function used to group the data
)

In [None]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position signals to a sequence, then apply dropout.

    A table of shape [max_len, 1, d_model] is precomputed and stored as the
    non-trainable buffer `pe`. Even feature indices hold sines and odd indices
    hold cosines of the position scaled by a geometric series of frequencies.

    Args:
        d_model: size of the last (feature) dimension of the inputs. Both even
            and odd values are supported.
        dropout: dropout probability applied after the addition.
        max_len: number of positions precomputed; inputs must not be longer.
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = position * div_term  # [max_len, ceil(d_model / 2)]
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer odd column than even column, so
        # trim the angles to fit; for even d_model the slice is a no-op.
        pe[:, 1::2] = torch.cos(angles[:, :d_model // 2])
        # [max_len, d_model] -> [max_len, 1, d_model] so it broadcasts over the
        # batch dimension of sequence-first inputs.
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return dropout(x + pe[:len(x)]).

        Args:
            x: tensor whose first dimension is the sequence position and whose
                last dimension is d_model (e.g. [seq len, batch size, d_model]).
        """
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)


class TransBertEncoder(nn.Module):
    """Encode source token ids with the shared BERT model plus a transformer encoder.

    BERT is evaluated under `torch.no_grad()`, so no gradients flow into it
    from this module; only the transformer encoder stack on top is trained here.

    Args:
        nhead: number of attention heads in each transformer encoder layer.
        nlayers: number of stacked transformer encoder layers.
        dropout: dropout probability for the transformer encoder layers (and
            for the `pos_encoder` attribute).
    """

    def __init__(self, nhead=12, nlayers=12, dropout=0.1):
        super().__init__()

        # bert encoder (module-level shared instance)
        self.bert = g_bert

        # transformer encoder, as bert last layer fine-tune
        # `pos_encoder` is constructed but forward() does not apply it.
        self.pos_encoder = PositionalEncoding(g_bert_emb_dim, dropout)
        # Pass `dropout` through so the constructor argument actually controls
        # the layers; previously it was silently ignored in favour of the
        # layer's built-in default (which equals this class's default of 0.1).
        encoder_layers = nn.TransformerEncoderLayer(d_model=g_bert_emb_dim,
                                                    nhead=nhead,
                                                    dropout=dropout)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layers, nlayers)

    def forward(self, src):
        """Return contextual embeddings for `src`.

        Args:
            src: token ids, [src len, batch size].

        Returns:
            Tensor of shape [src len, batch size, emb dim].
        """
        # BERT is fed batch-first ids; element [0] of its return value is
        # transposed back to sequence-first.
        # NOTE(review): assumes element [0] is a single tensor of hidden states
        # — confirm against the BertModel implementation in use.
        with torch.no_grad():
            # embedded = [src len, batch size, emb dim]
            embedded = self.bert(src.transpose(0, 1))[0].transpose(0, 1)

        # No mask is passed, so every position attends to every other position.
        # outputs = [src len, batch size, emb dim]
        outputs = self.transformer_encoder(embedded)

        return outputs


class TransBertDecoder(nn.Module):
    """Decode vocabulary logits from a memory tensor with a transformer decoder.

    Decoder inputs are BERT embeddings of the target prefix, computed under
    `torch.no_grad()`. Two modes exist:

    * teacher forcing (training only, chosen at random per call): the
      ground-truth target tokens are embedded and decoded in a single pass;
    * free running (otherwise): tokens are generated greedily one step at a
      time, starting from the CLS token.

    Args:
        nhead: number of attention heads in each decoder layer.
        nlayers: number of stacked decoder layers.
        dropout: dropout probability for the `pos_decoder` attribute.
            NOTE(review): forward() does not call `pos_decoder`, and this value
            is not passed to the decoder layers, which keep their own default.
    """

    def __init__(self, nhead=8, nlayers=6, dropout=0.5):
        super().__init__()

        # bert encoder (module-level shared instance)
        self.bert = g_bert

        self.pos_decoder = PositionalEncoding(g_bert_emb_dim, dropout)
        decoder_layer = nn.TransformerDecoderLayer(d_model=g_bert_emb_dim, nhead=nhead)
        self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=nlayers)

        self.fc_out = nn.Linear(g_bert_emb_dim, g_vocab_size)

    @staticmethod
    def _causal_mask(size, device):
        """Return a [size, size] float mask: 0 on/below the diagonal, -inf above.

        Same values as `generate_square_subsequent_mask`, but built directly so
        that no throw-away `nn.Transformer` has to be instantiated.
        """
        return torch.triu(torch.full((size, size), float('-inf'), device=device), diagonal=1)

    def forward(self, tgt, meaning, teacher_forcing_ratio):
        """Produce per-step vocabulary logits.

        Args:
            tgt: target token ids, [output_len, batch size].
            meaning: memory tensor attended to by the transformer decoder.
            teacher_forcing_ratio: probability of using teacher forcing for
                this call (only honoured in training mode).

        Returns:
            Tensor [output_len, batch size, vocab size]; entry t is computed
            from the decoder inputs at positions 0..t.
        """
        # tgt = [output_len, batch size]

        output_len = tgt.size(0)
        batch_size = tgt.size(1)
        # decide if we are going to use teacher forcing or not
        teacher_force = random.random() < teacher_forcing_ratio

        # Built once for the full length; shorter prefixes reuse its top-left
        # corner, which is itself a causal mask.
        full_mask = self._causal_mask(output_len, g_device)

        if teacher_force and self.training:
            tgt_emb_total = torch.zeros(output_len, batch_size, g_bert_emb_dim).to(g_device)

            # BERT is run on each prefix tgt[:t+1] and only the last position is
            # kept — presumably so the embedding at step t cannot depend on
            # later target tokens.
            for t in range(0, output_len):
                with torch.no_grad():
                    tgt_emb = self.bert(tgt[:t+1].transpose(0, 1))[0].transpose(0, 1)
                tgt_emb_total[t] = tgt_emb[-1]

            decoder_output = self.transformer_decoder(tgt=tgt_emb_total,
                                                      memory=meaning,
                                                      tgt_mask=full_mask)
            predictions = self.fc_out(decoder_output)
        else:
            # initialized the input of the decoder with sos_idx (start of sentence token idx)
            output = torch.full((output_len+1, batch_size), g_tokenizer.cls_token_id, dtype=torch.long, device=g_device)
            predictions = torch.zeros(output_len, batch_size, g_vocab_size).to(g_device)

            for t in range(0, output_len):
                with torch.no_grad():
                    tgt_emb = self.bert(output[:t+1].transpose(0, 1))[0].transpose(0, 1)

                # tgt_emb = [t+1, batch size, emb dim]
                tgt_mask = full_mask[:t+1, :t+1]

                # decoder_output = [t+1, batch size, emb dim]
                decoder_output = self.transformer_decoder(tgt=tgt_emb,
                                                          memory=meaning,
                                                          tgt_mask=tgt_mask)

                # prediction = [batch size, vocab size]
                prediction = self.fc_out(decoder_output[-1])

                # predictions = [output_len, batch size, vocab size]
                predictions[t] = prediction

                # greedy choice becomes the next decoder input
                one_hot_idx = prediction.argmax(1)

                # output  = [output len + 1, batch size]
                output[t+1] = one_hot_idx

        return predictions


class GruEncoder(nn.Module):
    """Summarise a sequence of request embeddings as a GRU's final hidden state.

    Args:
        hidden_size: width of the GRU hidden state (the "meaning" vector).
        input_size: feature size of each input step.
    """

    def __init__(self, hidden_size, input_size):
        super().__init__()
        self.gru = nn.GRU(input_size=input_size, hidden_size=hidden_size)

    def forward(self, input):
        """Run the GRU over `input` and return only its final hidden state."""
        _, final_hidden = self.gru(input)
        return final_hidden


class GruDecoder(nn.Module):
    """Unroll a GRU that feeds its own projected output back in as the next input.

    Args:
        hidden_size: width of the GRU hidden state.
        output_size: feature size of each produced step; also the GRU input
            size, because every output is fed back in as the next input.
    """

    def __init__(self, hidden_size, output_size):
        super().__init__()
        self.gru = nn.GRU(output_size, hidden_size)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, src, tgt, hidden):
        """Generate one output vector per target position.

        Args:
            src: sequence-first tensor; only its first step `src[0]` is used,
                as the initial GRU input (last dim must equal output_size).
            tgt: tensor whose first two dimensions give the number of steps to
                produce and the batch size; its values are not read.
            hidden: initial GRU hidden state.

        Returns:
            Tensor [tgt len, batch size, output_size] on the same device as `src`.
        """
        # first input to the decoder is the <CLS> tokens
        step_input = src[0].unsqueeze(0)
        tgt_len = tgt.size(0)
        batch_size = tgt.size(1)

        # tensor to store decoder outputs; sized from this module's own
        # projection layer and placed next to the inputs, rather than relying on
        # module-level globals that merely happen to match.
        outputs = torch.zeros(tgt_len, batch_size, self.fc.out_features, device=src.device)

        for t in range(tgt_len):
            # feed the previous output and hidden state, receive the new ones
            gru_output, hidden = self.gru(step_input, hidden)

            step_input = self.fc(gru_output)

            # step_input is [1, batch size, output_size]; the leading singleton
            # dimension is dropped on assignment.
            outputs[t] = step_input
        return outputs


class DialogDNN(nn.Module):
    """Three residual fully-connected layers with dropout in front of each.

    Args:
        input_size: feature size of the input.
        hidden_size: output width of every linear layer. Each layer's output is
            added to its input, so in practice this has to be compatible with
            `input_size`.
        output_size: accepted but not used by this class.
        dropout: dropout probability.
    """

    def __init__(self, input_size, hidden_size, output_size, dropout=0.5):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, hidden_size)

        self.dropout = nn.Dropout(dropout)

    def forward(self, input):
        """Apply dropout -> two ReLU residual blocks (each followed by dropout)
        -> one linear residual block."""
        hidden = self.dropout(input)

        # ResNet-style skip connections around the first two layers
        for layer in (self.fc1, self.fc2):
            hidden = self.dropout(hidden + F.relu(layer(hidden)))

        # final residual step has no relu, to keep negative values
        return hidden + self.fc3(hidden)


class Seq2Seq(nn.Module):
    """Wire the five sub-modules into one request -> response pipeline.

    Pipeline: the transformer/BERT encoder embeds the request, the GRU encoder
    compresses it into a meaning vector, the dialog network (optionally) maps
    it to a response meaning, the GRU decoder expands that into response
    embeddings, and the transformer/BERT decoder turns those into vocabulary
    logits.
    """

    def __init__(self, transbert_encoder, transbert_decoder, gru_encoder, gru_decoder, dialog_dnn):
        super().__init__()

        # request/response sequence models
        self.transbert_encoder = transbert_encoder
        self.transbert_decoder = transbert_decoder

        # sequence <-> meaning compression
        self.gru_encoder = gru_encoder
        self.gru_decoder = gru_decoder

        # meaning -> meaning mapping
        self.dialog_dnn = dialog_dnn

    def forward(self, src, tgt, teacher_forcing_ratio):
        """Return the decoder's predictions for `tgt` given the request `src`.

        The module-level flag `TRAIN_DIALOG` decides whether the dialog network
        is applied; when it is falsy the request meaning is passed through
        unchanged.
        NOTE(review): `TRAIN_DIALOG` is not defined in the visible part of this
        notebook — it must exist before this method is called.
        """
        request_embeddings = self.transbert_encoder(src)
        request_meaning = self.gru_encoder(request_embeddings)

        response_meaning = self.dialog_dnn(request_meaning) if TRAIN_DIALOG else request_meaning

        # the GRU decoder's output serves as the memory of the final decoder
        response_embeddings = self.gru_decoder(request_embeddings, tgt, response_meaning)
        return self.transbert_decoder(tgt, response_embeddings, teacher_forcing_ratio)