### **_Imports and setup_**

In [None]:
import torch
import torch.nn as nn
from transformers import BertModel, BertTokenizer
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset

import numpy as np
import pandas as pd

import spacy
nlp = spacy.load("en_core_web_sm")

In [None]:
# NOTE: a `global` statement at module (cell) level has no effect, because
# every name assigned at this level is already a module global. These lines
# only list the shared settings; they do not create or initialize them.
global device
global vocab_size
global lstm_hidden_size
global num_layers
global batch_size

In [None]:
# Load the pretrained BERT tokenizer and model from the same checkpoint name.
bert_model_name = 'bert-base-uncased'
bert_tokenizer = BertTokenizer.from_pretrained(bert_model_name)
bert_model = BertModel.from_pretrained(bert_model_name)

# Initialize the global variables.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # prefer the GPU when available
vocab_size = bert_model.config.vocab_size  # taken from the BERT config
lstm_hidden_size = 256
num_layers = 2
batch_size = 6  # intended mini-batch size

### **_Encoder_**

In [None]:
class Encoder(nn.Module):
    """Thin wrapper around a BERT-style model that exposes its hidden states."""

    def __init__(self, bert_model):
        """Store ``bert_model`` as the ``bert`` sub-module."""
        super().__init__()
        self.bert = bert_model

    def forward(self, input_ids, attention_mask):
        """Run the wrapped model and return ``last_hidden_state`` of its output.

        The call is not wrapped in ``torch.no_grad()``, so gradients can
        propagate into the wrapped model.
        """
        bert_outputs = self.bert(input_ids, attention_mask=attention_mask)
        return bert_outputs.last_hidden_state


### **_Decoder_**

In [None]:
class Decoder(nn.Module):
    """LSTM followed by a linear projection onto the vocabulary.

    Args:
        lstm_hidden_size: Hidden size of the LSTM.
        num_layers: Number of stacked LSTM layers.
        vocab_size: Size of the output projection (number of logits per step).
        input_size: Feature size of the incoming sequence. When omitted it
            falls back to the hidden size of the module-level ``bert_model``,
            which is what the original implementation always used.
    """

    def __init__(self, lstm_hidden_size, num_layers, vocab_size, input_size=None):
        super().__init__()
        if input_size is None:
            # Backward-compatible default: read the size from the global BERT model.
            input_size = bert_model.config.hidden_size
        # batch_first=True is required because the encoder output is laid out
        # as (batch, seq, hidden). With the default (batch_first=False) the
        # LSTM would treat the batch axis as time and leak state across samples.
        self.lstm = nn.LSTM(input_size=input_size,
                            hidden_size=lstm_hidden_size,
                            num_layers=num_layers,
                            batch_first=True)
        self.linear = nn.Linear(lstm_hidden_size, vocab_size)

    def forward(self, input):
        """Map a (batch, seq, input_size) tensor to (batch, seq, vocab_size) logits."""
        lstm_outputs, _ = self.lstm(input)
        return self.linear(lstm_outputs)
    

### **_Seq2Seq_**

In [None]:
class Seq2Seq(nn.Module):
    """Chain an encoder and a decoder into a single module."""

    def __init__(self, encoder, decoder) -> None:
        """Keep references to the ``encoder`` and ``decoder`` components."""
        super().__init__()
        self.decoder = decoder
        self.encoder = encoder

    def forward(self, input_ids, attention_mask):
        """Encode the inputs, then return the decoder's output for that encoding."""
        encoded = self.encoder(input_ids, attention_mask)
        return self.decoder(encoded)
        

In [None]:
class Seq2Seqwithattn(nn.Module):
    """Encoder -> multi-head self-attention -> decoder.

    Args:
        encoder: Module called as ``encoder(input_ids, attention_mask)`` and
            returning a (batch, seq, attn_embed_dim) tensor.
        decoder: Module applied to the attended encoder output.
        attn_embed_dim: Embedding dimension of the attention layer; must
            match the encoder's output feature size.
        num_heads: Number of attention heads (must divide ``attn_embed_dim``).
    """

    def __init__(self, encoder, decoder, attn_embed_dim, num_heads) -> None:
        # Subclassing nn.Module is what makes the instance callable and lets
        # .parameters(), .to(), .train() and .eval() see the sub-modules.
        super().__init__()
        self.encoder = encoder
        # batch_first=True because the encoder output is (batch, seq, dim).
        self.multihead_attn = nn.MultiheadAttention(attn_embed_dim, num_heads,
                                                    batch_first=True)
        self.decoder = decoder

    def forward(self, input_ids, attention_mask):
        """Return the decoder output for the self-attended encoder states."""
        encoder_outputs = self.encoder(input_ids, attention_mask)
        # key_padding_mask uses True for positions to ignore, i.e. where the
        # attention mask is 0, so padding tokens are never attended to.
        key_padding_mask = None if attention_mask is None else attention_mask == 0
        attn_output, _ = self.multihead_attn(encoder_outputs,
                                             encoder_outputs,
                                             encoder_outputs,
                                             key_padding_mask=key_padding_mask)
        return self.decoder(attn_output)

In [None]:
class bertsformer(nn.Module):
    """BERT encoder followed by a stack of Transformer decoder layers.

    Args:
        bert_model: BERT-style model exposing ``config.hidden_size`` and
            returning an output indexable by ``'last_hidden_state'``.
        nhead: Number of attention heads per decoder layer (must divide the
            BERT hidden size).
        num_layers: Number of stacked decoder layers.
    """

    def __init__(self, bert_model, nhead=8, num_layers=6):
        super().__init__()
        self.bert_encoder = bert_model
        # nn.TransformerDecoder takes a prototype layer plus a layer count; it
        # does not accept d_model/nhead itself. d_model must equal the encoder
        # hidden size, and batch_first=True matches BERT's (batch, seq, hidden).
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=bert_model.config.hidden_size,
            nhead=nhead,
            batch_first=True,
        )
        self.trans_decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)

    def forward(self, input_ids, attention_mask):
        """Return the decoder output, shaped like BERT's last hidden state."""
        bert_output = self.bert_encoder(input_ids=input_ids, attention_mask=attention_mask)
        hidden_states = bert_output['last_hidden_state']
        # True marks padding positions that attention must ignore.
        padding_mask = None if attention_mask is None else attention_mask == 0
        # NOTE(review): this interface provides no separate target sequence,
        # so the BERT states serve as both `tgt` and `memory`. Replace `tgt`
        # with embedded target tokens once they are available.
        return self.trans_decoder(
            tgt=hidden_states,
            memory=hidden_states,
            tgt_key_padding_mask=padding_mask,
            memory_key_padding_mask=padding_mask,
        )


### **_Preprocess data_**

In [None]:
def preprocess_data(text):
    """Tokenize ``text`` with the spaCy tokenizer and return the token strings."""
    token_texts = []
    for token in nlp.tokenizer(text):
        token_texts.append(token.text)
    return token_texts

In [None]:
class PubMedDataset(Dataset):
    """Dataset of (input, target) text pairs encoded with a tokenizer.

    Args:
        df: DataFrame with string columns ``"input"`` and ``"target"``.
        tokenizer: Tokenizer exposing ``encode_plus``.
        max_length: Length every sequence is padded or truncated to.
    """

    def __init__(self, df, tokenizer, max_length=128) -> None:
        super().__init__()
        self.df = df
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.df)

    def _encode(self, text):
        """Tokenize ``text`` into fixed-length PyTorch tensors."""
        return self.tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            padding='max_length',
            max_length=self.max_length,
            truncation=True,
            return_tensors='pt'
        )

    def __getitem__(self, index):
        """Return ``(input_ids, attention_mask, target_ids)`` for row ``index``."""
        # Samplers produce positions in range(len(self)), so look rows up by
        # position. Label-based .loc fails when the DataFrame index is not
        # 0..n-1, e.g. after filtering or shuffling.
        row = self.df.iloc[index]
        encoded_inputs = self._encode(row["input"])
        encoded_targets = self._encode(row["target"])

        # return_tensors='pt' adds a leading batch axis of size 1; drop it so
        # the DataLoader can stack samples itself.
        input_ids = encoded_inputs['input_ids'].squeeze(0)
        attention_mask = encoded_inputs['attention_mask'].squeeze(0)
        target_ids = encoded_targets['input_ids'].squeeze(0)

        return input_ids, attention_mask, target_ids

In [None]:
import json

# Load the training examples from the JSON file.
training_set = r'./../spider/meddialog/results/eval_wer_json/pubmed_46374_train.json'
# Read as UTF-8 explicitly; without it open() uses the platform's locale
# encoding, which can fail on non-ASCII text.
with open(training_set, 'r', encoding='utf-8') as td:
    data = json.load(td)

In [None]:
# Collect the model inputs and their targets from the loaded records.
utterances = [d['utterances']['pubmed'] for d in data]
# NOTE(review): the first two items of every target are dropped (characters,
# if the value is a string) — presumably a fixed prefix; confirm against the
# data file.
results = [d['results']['pubmed'][2:] for d in data]

print(len(utterances))

pubmed_df = pd.DataFrame({'input': utterances, 'target': results})
PD = PubMedDataset(pubmed_df, bert_tokenizer)
# Use the shared batch_size setting instead of a second hard-coded value.
dataloader = DataLoader(PD, batch_size=batch_size, shuffle=True)

print(pubmed_df.head())

### **_Train model_**

In [None]:
def train_model(model, criterion, optimizer, dataloader, num_epochs):
    """Train ``model`` for ``num_epochs`` passes over ``dataloader``.

    Args:
        model: Module called as ``model(input_ids, attention_mask)``.
        criterion: Loss called with the flattened outputs and flattened targets.
        optimizer: Optimizer over the model's parameters.
        dataloader: Sized iterable of ``(input_ids, attention_mask, target_ids)``.
        num_epochs: Number of passes over ``dataloader``.

    Returns:
        None. The loss of every step is printed.
    """
    model.train()

    # Send batches to wherever the model's weights live, so the function does
    # not rely on a global `device` and cannot hit a model/input device mismatch.
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device("cpu")
    num_steps = len(dataloader)

    for epoch in range(num_epochs):
        for i, (input_ids, attention_mask, target_ids) in enumerate(dataloader):
            optimizer.zero_grad()

            # Move the batch to the model's device.
            input_ids = input_ids.to(model_device)
            attention_mask = attention_mask.to(model_device)
            target_ids = target_ids.to(model_device)

            # forward
            outputs = model(input_ids, attention_mask)
            # Collapse all leading axes and keep the last one, taking its size
            # from the output itself rather than a global. reshape() also
            # handles non-contiguous tensors, which view() rejects.
            loss = criterion(outputs.reshape(-1, outputs.size(-1)),
                             target_ids.reshape(-1))

            # backward & optimization
            loss.backward()
            optimizer.step()

            print(f"Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{num_steps}], Loss: {loss.item()}")

In [None]:
# Assemble the BERT encoder + LSTM decoder model.
encoder = Encoder(bert_model=bert_model)
decoder = Decoder(lstm_hidden_size=lstm_hidden_size, num_layers=num_layers, vocab_size=vocab_size)
# Move the model to the target device before building the optimizer. Batches
# are sent to `device`, so a model left on the CPU fails on a CUDA machine.
bert_lstm_s2s_model = Seq2Seq(encoder, decoder).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(bert_lstm_s2s_model.parameters(), lr=0.001)

train_model(bert_lstm_s2s_model,
            criterion=criterion,
            optimizer=optimizer,
            dataloader=dataloader,
            num_epochs=10)

# torch.save(model.state_dict(), 'model.pth')

### **_Evaluate the model_**

In [None]:
# Greedy decoding: pick the highest-scoring token id at every position and
# turn each predicted id sequence back into text.
bert_lstm_s2s_model.eval()
predicted_sentences = []

with torch.no_grad():
    for input_ids, attention_mask, target_ids in dataloader:
        input_ids, attention_mask = input_ids.to(device), attention_mask.to(device)

        outputs = bert_lstm_s2s_model(input_ids, attention_mask)

        # torch.max over dim 2 returns (values, indices); only the indices are needed.
        _, predicted_ids = outputs.max(dim=2)

        for ids in predicted_ids:
            tokens = bert_tokenizer.convert_ids_to_tokens(ids, skip_special_tokens=True)
            sentence = bert_tokenizer.convert_tokens_to_string(tokens)
            predicted_sentences += [sentence]


print(predicted_sentences[:10])
