In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import json
from datasets import load_dataset
from collections import OrderedDict
import re
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from collections import OrderedDict
from tqdm import tqdm
from torch.optim.lr_scheduler import ReduceLROnPlateau
from datasets import load_dataset, DatasetDict
import random

In [2]:
dataset_path = "CShorten/CDC-COVID-FAQ"

In [3]:
df = load_dataset(dataset_path)

In [4]:
class Tokenizer():
 
    def __init__(self, dataset):
        """
        Args:
            dataset: A hugging face dataset.
        """
        self.dataset = dataset
        self.max_question_len = 0
        self.max_answer_len = 0
        self.vocab = self.create_vocab()
        self.vocab_len = len(self.vocab)
        self.stoi, self.itos = self.create_vocab_dict()

    def tokenize(self, text):
        urls = re.findall(r'https?://\S+|www\.\S+', text)
        for i, url in enumerate(urls):
            text = text.replace(url, f'__URL_{i}__')
        text = re.sub(r'([,.!?():\-\[\]/])', r' \1 ', text)
        text = re.sub(r'\s+', ' ', text).strip()
        for i, url in enumerate(urls):
            text = text.replace(f'__URL_{i}__', url)
        return list(map(str.lower, text.split()))

    def create_vocab(self):
        word_set = {'<sos>', '<eos>', '<pad>'}
        for interaction in self.dataset['train']:
            question = interaction['question']
            answer = interaction['answer']
            question_tokens = self.tokenize(question)
            answer_tokens = self.tokenize(answer)
            word_set.update(set(question_tokens))
            word_set.update(set(answer_tokens))
            question_len = len(question_tokens)
            answer_len = len(answer_tokens)
            self.max_question_len = max(self.max_question_len, question_len)
            self.max_answer_len = max(self.max_answer_len, answer_len)
        self.max_answer_len += 2
        self.max_question_len += 2
        return sorted(word_set)
        
    def create_vocab_dict(self):
        stoi = OrderedDict()
        itos = OrderedDict()
        for idx, word in enumerate(self.vocab):
            stoi[word] = idx
            itos[idx] = word
        return stoi, itos  

    def get_data_tensor(self, sample):
        question_tensor = torch.zeros(self.max_question_len, self.vocab_len)
        question_toks = self.tokenize(sample['question'])
        answer_tensor = torch.zeros(self.max_answer_len, self.vocab_len)
        answer_toks = self.tokenize('<sos> ' + sample['answer'] + ' <eos>')
        question_tensor[:, self.stoi['<pad>']] = 1
        answer_tensor[:, self.stoi['<pad>']] = 1
        for idx, word in enumerate(question_toks):
            question_tensor[idx, self.stoi[word]] = 1
            question_tensor[idx, self.stoi['<pad>']] = 0
        for idx, word in enumerate(answer_toks):
            answer_tensor[idx, self.stoi[word]] = 1
            answer_tensor[idx, self.stoi['<pad>']] = 0
        return question_tensor, answer_tensor
            
    def create_data_tensors(self):
        questions = torch.zeros(self.dataset.shape['train'][0], self.max_question_len, self.vocab_len)
        answers = torch.zeros(self.dataset.shape['train'][0], self.max_answer_len, self.vocab_len)
        for i, interaction in enumerate(self.dataset['train']):
            question_tensor, answer_tensor = self.get_data_tensor(interaction)
            questions[i] = question_tensor
            answers[i] = answer_tensor 
        return questions, answers
        

In [5]:
class Encoder(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_layers, dropout, bidirectional=False):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.num_directions = 2 if bidirectional else 1
        self.cell = nn.LSTM(input_dim, hidden_dim, num_layers, batch_first=True, bidirectional=bidirectional, dropout=dropout)

    def forward(self, x, device='mps'):
        h0, c0 = self.init_hidden_memory(x.shape[0], device=device)
        hidden_states, (ht, ct) = self.cell(x, (h0, c0))
        ht = torch.permute(ht, (1, 0, 2)).flatten(start_dim=1, end_dim=2).reshape(x.shape[0], self.num_layers, self.num_directions*self.hidden_dim).permute((1, 0, 2))
        ct = torch.permute(ht, (1, 0, 2)).flatten(start_dim=1, end_dim=2).reshape(x.shape[0], self.num_layers, self.num_directions*self.hidden_dim).permute((1, 0, 2))
        return ht, ct

    def init_hidden_memory(self, batch_size, device='mps'):
        h0 = torch.zeros(self.num_directions*self.num_layers, batch_size, self.hidden_dim, device=device)
        c0 = torch.zeros(self.num_directions*self.num_layers, batch_size, self.hidden_dim, device=device)
        return h0, c0

In [6]:
class Decoder(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_layers, dropout, num_classes):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.num_classes = num_classes
        self.cell = nn.LSTM(input_dim, hidden_dim, num_layers, batch_first=True, dropout=dropout)
        self.fc1 = nn.Linear(hidden_dim, num_classes)

    def forward(self, x, h0, c0):
        outputs, (h0, c0) = self.cell(x, (h0, c0))
        outputs = outputs.squeeze(dim=1)
        outputs = self.fc1(outputs)
        return outputs, (h0, c0)   # returns a (1, vocab_size), (1, 1, hidden_dim), (1, 1, hidden_dim)

In [7]:
device='mps'
tokenizer =Tokenizer(df)

In [8]:
class Seq2Seq(nn.Module):
    def __init__(self, encoder, decoder, tokenizer):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.loss = nn.CrossEntropyLoss(ignore_index=tokenizer.stoi['<pad>'])
        self.softmax = nn.LogSoftmax(dim=-1)

    def forward(self, source, target, teacher_force_prob=0.4, device='mps'):
        xt = target[:, 0:1, :]
        ht, ct = self.encoder(source, device=device)
        outputs = torch.zeros(source.shape[0], target.shape[1], target.shape[2]).to(device)
        loss_at_batch = []
        for t in range(1, target.shape[1]):
            logits, (ht, ct) = self.decoder(xt, ht, ct)
            outputs[:, t, :] = logits
            y = target[:, t, :]
            y = y.argmax(dim=1).flatten()
            loss_at_timestep = self.loss(logits, y)
            loss_at_batch.append(loss_at_timestep)
            xt = torch.nn.functional.one_hot(logits.argmax(1).flatten(), num_classes=source.shape[-1]).unsqueeze(dim=1).to(torch.float32).to(device) if torch.rand(1) < teacher_force_prob else target[:, t:t+1, :] 
        return outputs, sum(loss_at_batch)/len(loss_at_batch)

In [9]:
device='mps'
tokenizer = Tokenizer(df)
input_dim = 128
encoder = Encoder(tokenizer.vocab_len, input_dim, 1, 0.0, True)
decoder = Decoder(tokenizer.vocab_len, 2*input_dim, 1, 0.0, tokenizer.vocab_len)
model = Seq2Seq(encoder, decoder, tokenizer)
model = model.to(device)

In [10]:
class Seq2SeqDataset(Dataset):
    def __init__(self, x, y):
        self.x = x
        self.y = y
        self.samples = x.shape[0]

    def __getitem__(self, idx):
        return self.x[idx], self.y[idx]

    def __len__(self):
        return self.samples

In [11]:
questions_data, answers_data = tokenizer.create_data_tensors()

In [12]:
dataset = Seq2SeqDataset(questions_data, answers_data)
data_loader = DataLoader(dataset, batch_size=10, shuffle=True)

In [13]:
def train(data_loader, model, training_params):
    epochs = training_params['epochs']
    optimizer = training_params['optimizer']
    min_loss = float('inf')
    train_loss_ = None
    for epoch in range(epochs):
        train_loss = []
        for x, y in tqdm(data_loader, desc=f'Epoch: {epoch}/{epochs} train loss = {train_loss_} best_loss = {min_loss}'):
            x = x.to(device)
            y = y.to(device)
            outputs, loss = model(x, y)
            optimizer.zero_grad()
            loss.backward()
            # torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
            optimizer.step()
            train_loss.append(loss.item())
        train_loss_ = sum(train_loss)/len(train_loss)
        if train_loss_ < min_loss:
            torch.save(model.state_dict(), 'Seq2Seq.pth')
            min_loss = train_loss_
            print('model saved')

In [14]:
optimizer = optim.Adam(model.parameters(), lr= 5e-3)
training_params = {'optimizer': optimizer, 'epochs': 70}   

In [15]:
train(data_loader, model, training_params) # uncomment to train

Epoch: 0/70 train loss = None best_loss = inf: 100%|█| 7/7 [00:05<00:00,  1.25it


model saved


Epoch: 1/70 train loss = 4.211297392845154 best_loss = 4.211297392845154: 100%|█


model saved


Epoch: 2/70 train loss = 3.225475490093231 best_loss = 3.225475490093231: 100%|█
Epoch: 3/70 train loss = 3.3849016768591746 best_loss = 3.225475490093231: 100%|
Epoch: 4/70 train loss = 3.637118067060198 best_loss = 3.225475490093231: 100%|█


model saved


Epoch: 5/70 train loss = 2.937252002102988 best_loss = 2.937252002102988: 100%|█
Epoch: 6/70 train loss = 3.61251357623509 best_loss = 2.937252002102988: 100%|█|
Epoch: 7/70 train loss = 3.639039763382503 best_loss = 2.937252002102988: 100%|█
Epoch: 8/70 train loss = 3.3392046689987183 best_loss = 2.937252002102988: 100%|


model saved


Epoch: 9/70 train loss = 2.866875103541783 best_loss = 2.866875103541783: 100%|█
Epoch: 10/70 train loss = 3.4424594896180287 best_loss = 2.866875103541783: 100%
Epoch: 11/70 train loss = 3.4337701967784335 best_loss = 2.866875103541783: 100%
Epoch: 12/70 train loss = 3.5782800912857056 best_loss = 2.866875103541783: 100%
Epoch: 13/70 train loss = 3.6012931891850064 best_loss = 2.866875103541783: 100%
Epoch: 14/70 train loss = 3.28126506294523 best_loss = 2.866875103541783: 100%|█
Epoch: 15/70 train loss = 3.7094129834856306 best_loss = 2.866875103541783: 100%
Epoch: 16/70 train loss = 3.1529600279671803 best_loss = 2.866875103541783: 100%
Epoch: 17/70 train loss = 3.6130632758140564 best_loss = 2.866875103541783: 100%
Epoch: 18/70 train loss = 3.4047187737056186 best_loss = 2.866875103541783: 100%
Epoch: 19/70 train loss = 3.1955008591924394 best_loss = 2.866875103541783: 100%
Epoch: 20/70 train loss = 3.424644180706569 best_loss = 2.866875103541783: 100%|
Epoch: 21/70 train loss = 3.

model saved


Epoch: 30/70 train loss = 2.8519048988819122 best_loss = 2.8519048988819122: 100


model saved


Epoch: 31/70 train loss = 2.578582695552281 best_loss = 2.578582695552281: 100%|
Epoch: 32/70 train loss = 2.9655002866472517 best_loss = 2.578582695552281: 100%
Epoch: 33/70 train loss = 2.98397958278656 best_loss = 2.578582695552281: 100%|█
Epoch: 34/70 train loss = 2.712105785097395 best_loss = 2.578582695552281: 100%|
Epoch: 35/70 train loss = 2.6555564403533936 best_loss = 2.578582695552281: 100%
Epoch: 36/70 train loss = 2.918192437716893 best_loss = 2.578582695552281: 100%|
Epoch: 37/70 train loss = 2.8623636024338857 best_loss = 2.578582695552281: 100%


model saved


Epoch: 38/70 train loss = 2.4928915670939853 best_loss = 2.4928915670939853: 100
Epoch: 39/70 train loss = 2.496480073247637 best_loss = 2.4928915670939853: 100%


model saved


Epoch: 40/70 train loss = 2.3885052715029036 best_loss = 2.3885052715029036: 100
Epoch: 41/70 train loss = 2.5146006175449918 best_loss = 2.3885052715029036: 100
Epoch: 42/70 train loss = 2.547064423561096 best_loss = 2.3885052715029036: 100%


model saved


Epoch: 43/70 train loss = 2.07705295085907 best_loss = 2.07705295085907: 100%|█|
Epoch: 44/70 train loss = 2.5894304173333302 best_loss = 2.07705295085907: 100%|
Epoch: 45/70 train loss = 2.436096029622214 best_loss = 2.07705295085907: 100%|█
Epoch: 46/70 train loss = 2.4368966136659895 best_loss = 2.07705295085907: 100%|


model saved


Epoch: 47/70 train loss = 2.0335031620093753 best_loss = 2.0335031620093753: 100
Epoch: 48/70 train loss = 2.1890380552836826 best_loss = 2.0335031620093753: 100
Epoch: 49/70 train loss = 2.075176315648215 best_loss = 2.0335031620093753: 100%


model saved


Epoch: 50/70 train loss = 2.0166260855538503 best_loss = 2.0166260855538503: 100
Epoch: 51/70 train loss = 2.0292812160083225 best_loss = 2.0166260855538503: 100
Epoch: 52/70 train loss = 2.1982798406055997 best_loss = 2.0166260855538503: 100
Epoch: 53/70 train loss = 2.1926471846444264 best_loss = 2.0166260855538503: 100


model saved


Epoch: 54/70 train loss = 1.9458290210791997 best_loss = 1.9458290210791997: 100
Epoch: 55/70 train loss = 2.0452097143445696 best_loss = 1.9458290210791997: 100


model saved


Epoch: 56/70 train loss = 1.7655905442578452 best_loss = 1.7655905442578452: 100
Epoch: 57/70 train loss = 1.8193478350128447 best_loss = 1.7655905442578452: 100


model saved


Epoch: 58/70 train loss = 1.675076961517334 best_loss = 1.675076961517334: 100%|
Epoch: 59/70 train loss = 1.7482873903853553 best_loss = 1.675076961517334: 100%


model saved


Epoch: 60/70 train loss = 1.5887049287557602 best_loss = 1.5887049287557602: 100


model saved


Epoch: 61/70 train loss = 1.542639387505395 best_loss = 1.542639387505395: 100%|
Epoch: 62/70 train loss = 1.656640682901655 best_loss = 1.542639387505395: 100%|


model saved


Epoch: 63/70 train loss = 1.4261757135391235 best_loss = 1.4261757135391235: 100
Epoch: 64/70 train loss = 1.6021399753434318 best_loss = 1.4261757135391235: 100


model saved


Epoch: 65/70 train loss = 1.400780928986413 best_loss = 1.400780928986413: 100%|


model saved


Epoch: 66/70 train loss = 1.309746380363192 best_loss = 1.309746380363192: 100%|
Epoch: 67/70 train loss = 1.4131042616707938 best_loss = 1.309746380363192: 100%


model saved


Epoch: 68/70 train loss = 1.2233220083372933 best_loss = 1.2233220083372933: 100


model saved


Epoch: 69/70 train loss = 1.1069410741329193 best_loss = 1.1069410741329193: 100


In [16]:
model.load_state_dict(torch.load('Seq2Seq.pth'))

<All keys matched successfully>

In [17]:
def generate_answer(question, tokenizer, model):
    sample = {'question': question, 'answer': ''}
    question, sos_token = tokenizer.get_data_tensor(sample)
    question = question.unsqueeze(dim=0).to(device)
    xt = sos_token[0:1]
    xt = xt.unsqueeze(dim=0).to(device)
    curr_pred = None
    all_preds = []
    max_len = 800
    ht, ct = model.encoder(question)
    with torch.no_grad():
        while curr_pred != '<eos>' and len(all_preds) < max_len:
            logits, (ht, ct) = model.decoder(xt, ht, ct)
            next_word_idx = logits.argmax(dim=1)
            xt = torch.nn.functional.one_hot(next_word_idx, num_classes=question.shape[-1]).unsqueeze(dim=1).to(torch.float32).to(device)
            curr_pred = tokenizer.itos[next_word_idx.item()]
            all_preds.append(curr_pred)
    return ' '.join(all_preds[:-1])

In [75]:
question = df['train']['question'][3]

In [76]:
question

'When should healthcare facilities make changes to interventions based on changes in community transmission levels?'

In [77]:
generate_answer(question, tokenizer, model)

'no - touch devices ( ntds ) for sometimes used as healthcare settings as intended to characterize the same time . coinfections with sars - cov - 2 infection . the prior to the existing cleaning and disinfection processes . the patient of transmission - based precautions .'

<h5>Disadvantages</h5>
<ul>
<li>Developing Language models using one-hot representation is a bad approach. Models have no idea what each word represents. Using one-hot encoding results in large, sparse input vectors that may not capture semantic relationships between words effectively.
</li>
<li>
    with LSTMs, the models tend to forget longer sequences. Causing these models to be depend greatly on the previous generated word because the hidden states fail to capture the context of question.
</li>

<li>
    The generated text does not make sense for the question given. 
<li>
    The idea of this notebook is to look into what exactly happens in a sequence to sequence model. In general, a validation set is used to find the model which generalizes better. This model is overfitted to the train set. So, it cannot generalize to question which it never saw.
</li>
<li>In the future, we will implement Embeddings and use the attention module. Embeddings are latent word representations that capture meaningful semantic information. With Attention module, model will be able to capture long range dependencies.</li>
</ul>
