In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pickle
from transformers import BertForQuestionAnswering, BertModel
from transformers import BertTokenizer
import tqdm
# Load the submodule explicitly: a bare `import tqdm` does not itself import
# `tqdm.notebook`, so the attribute access below would otherwise depend on some
# other library having imported it first.
import tqdm.notebook
tqdmn = tqdm.notebook.tqdm
from torch.utils.data import Dataset, DataLoader
import torch
import torch.nn as nn
import torch.optim as optim
import nltk
from nltk.corpus import stopwords
from torch.nn import CrossEntropyLoss
from sklearn.cross_decomposition import CCA
# English stop-word list from NLTK. On a fresh kernel the corpus has not been
# downloaded yet (the download cell runs after this one), so `stopwords.words`
# raises LookupError; fetch the corpus once and retry.
try:
    stop_words = stopwords.words('english')
except LookupError:
    nltk.download('stopwords')
    stop_words = stopwords.words('english')

In [None]:
import nltk

# Fetch the NLTK stop-word corpus that `stopwords.words(...)` reads.
STOPWORDS_PACKAGE = 'stopwords'
nltk.download(STOPWORDS_PACKAGE)

In [None]:
# Load the pseudo-labelled SemEval restaurant training pairs for each year.
# NOTE(review): pickle.load can execute arbitrary code; only open pickles you created or trust.
PAIRS_PATH_TEMPLATE = "semeval_{}_restaurant_sentence_dictionary_train_pairs_pseudo.pickle"


def load_pickle(path):
    """Return the object stored in the pickle file at ``path``."""
    with open(path, "rb") as handle:
        return pickle.load(handle)


train14 = load_pickle(PAIRS_PATH_TEMPLATE.format(14))
# The merge cell also reads `train15`, which was never loaded (NameError on a fresh
# kernel). Assumes the 2015 file follows the same naming pattern as 2014/2016 — confirm.
train15 = load_pickle(PAIRS_PATH_TEMPLATE.format(15))
train16 = load_pickle(PAIRS_PATH_TEMPLATE.format(16))

In [None]:
def _first_subsequence(tokens, pieces):
    """Return the index where ``pieces`` first occurs as a contiguous run of ``tokens``, or None."""
    width = len(pieces)
    if width == 0:
        return None
    for i in range(len(tokens) - width + 1):
        if tokens[i:i + width] == pieces:
            return i
    return None


def _first_word_match(tokens, word, begin=0):
    """Legacy per-word heuristic, used when the wordpiece match fails.

    Scans ``tokens[begin:]`` for a token equal to ``word``; failing that, for the
    first token whose text with '##' removed is a substring of ``word``. Returns the
    index, or None when nothing matches. The substring pass is loose (e.g. the token
    'a' matches the word 'pasta'), which is why it is only a fallback.
    """
    for i in range(begin, len(tokens)):
        if tokens[i] == word:
            return i
    for i in range(begin, len(tokens)):
        if tokens[i].replace('##', '') in word:
            return i
    return None


def get_start_end_index(input_ids, solution, tok=None):
    """Locate an answer span inside an encoded sequence.

    Args:
        input_ids: Token ids of the encoded sequence.
        solution: The answer split into words (``answer.split(" ")``); ``['[CLS]']``
            anchors the answer at the [CLS] token.
        tok: Tokenizer exposing ``convert_ids_to_tokens`` and ``tokenize``;
            defaults to the module-level ``tokenizer``.

    Returns:
        ``(start_index, end_index)``, inclusive token positions. The answer's
        wordpieces are first matched as one contiguous run, so multi-piece words
        span all their pieces. Otherwise the first and last words are located
        individually, with the end searched only at or after the start so the span
        is never reversed. An unlocatable start falls back to 0 and an unlocatable
        end falls back to the start.
    """
    if tok is None:
        tok = tokenizer
    tokens = tok.convert_ids_to_tokens(input_ids)

    # Wordpiece tokenization is per word, so the answer's pieces appear verbatim
    # in the sequence whenever the answer text occurs there.
    pieces = tok.tokenize(' '.join(solution))
    first = _first_subsequence(tokens, pieces)
    if first is not None:
        return first, first + len(pieces) - 1

    start_index = _first_word_match(tokens, solution[0])
    if start_index is None:
        start_index = 0
    # Search for the end only from the start onward: the last word may also occur
    # earlier in the sequence, which previously produced end < start.
    end_index = _first_word_match(tokens, solution[-1], begin=start_index)
    if end_index is None:
        end_index = start_index
    return start_index, end_index

In [None]:
class construct_dataset(Dataset):
    """Map-style dataset producing BERT inputs and span labels for two extraction directions.

    Item ``idx`` encodes, after lower-casing:
      * ``question[idx]`` + ``text[idx]`` as a pair and ``text[idx]`` alone, both
        labelled with the span of ``answer[idx]``;
      * ``question_opinion[idx]`` + ``text_opinion[idx]`` as a pair and
        ``text_opinion[idx]`` alone, both labelled with the span of
        ``answer_opinion[idx]``.

    An answer equal to ``'[CLS]'`` is not lower-cased. Relies on the module-level
    ``tokenizer`` and ``get_start_end_index``. Every value of the returned dict is a
    tensor.
    """

    def __init__(self, question, text, answer, question_opinion, text_opinion, answer_opinion):
        # Parallel sequences: index i of each one describes the same example.
        self.question = question
        self.text = text
        self.answer = answer
        self.question_opinion = question_opinion
        self.text_opinion = text_opinion
        self.answer_opinion = answer_opinion

    def __len__(self):
        return len(self.question)

    @staticmethod
    def _encode_with_segments(*texts):
        """Encode one text or a (query, text) pair with ``tokenizer.encode``.

        Segment ids are 0 up to and including the first [SEP] token and 1 after it,
        so a single text normally gets all zeros.
        """
        ids = tokenizer.encode(*texts)
        boundary = ids.index(tokenizer.sep_token_id) + 1
        return ids, [0] * boundary + [1] * (len(ids) - boundary)

    @staticmethod
    def _answer_words(answer):
        """Split an answer on single spaces, lower-casing everything except the '[CLS]' marker."""
        if answer == '[CLS]':
            return answer.split(" ")
        return answer.lower().split(" ")

    def __getitem__(self, idx):
        query = self.question[idx].lower()
        sent = self.text[idx].lower()
        query_opinion = self.question_opinion[idx].lower()
        sent_opinion = self.text_opinion[idx].lower()

        input_ids, segment_ids = self._encode_with_segments(query, sent)
        input_ids_opinion, segment_ids_opinion = self._encode_with_segments(query_opinion, sent_opinion)
        term_input_ids, term_segment_ids = self._encode_with_segments(sent)
        opinion_input_ids, opinion_segment_ids = self._encode_with_segments(sent_opinion)

        solution = self._answer_words(self.answer[idx])
        solution_opinion = self._answer_words(self.answer_opinion[idx])

        start_index, end_index = get_start_end_index(input_ids, solution)
        term_start_index, term_end_index = get_start_end_index(term_input_ids, solution)
        start_index_opinion, end_index_opinion = get_start_end_index(input_ids_opinion, solution_opinion)
        opinion_start_index, opinion_end_index = get_start_end_index(opinion_input_ids, solution_opinion)

        fields = {
            "input_ids": input_ids,
            "segment_ids": segment_ids,
            "start_index": start_index,
            "end_index": end_index,
            "term_input_ids": term_input_ids,
            "term_segment_ids": term_segment_ids,
            "term_start_index": term_start_index,
            "term_end_index": term_end_index,
            "input_ids_opinion": input_ids_opinion,
            "segment_ids_opinion": segment_ids_opinion,
            "start_index_opinion": start_index_opinion,
            "end_index_opinion": end_index_opinion,
            "opinion_input_ids": opinion_input_ids,
            "opinion_segment_ids": opinion_segment_ids,
            "opinion_start_index": opinion_start_index,
            "opinion_end_index": opinion_end_index,
        }
        return {key: torch.tensor(value) for key, value in fields.items()}

In [None]:
# Concatenate the 2014, 2015 and 2016 splits (in that order) into one dict keyed by
# consecutive integers starting at 0; `counter` ends as the total number of records.
train_dictionary = {}
for split in (train14, train15, train16):
    for position in range(len(split)):
        train_dictionary[len(train_dictionary)] = split[position]
counter = len(train_dictionary)

In [None]:
# One direction of the task: each record's 'opinion' field is the query and its
# 'term' field is the answer to extract from 'sentence'.
records = [train_dictionary[key] for key in range(len(train_dictionary))]
sentence = [record['sentence'] for record in records]
question = [record['opinion'] for record in records]
answer = [record['term'] for record in records]

In [None]:
# The reverse direction: each record's 'term' field is the query and its 'opinion'
# field is the answer to extract from 'sentence'.
records_opinion = [train_dictionary[key] for key in range(len(train_dictionary))]
sentence_opinion = [record['sentence'] for record in records_opinion]
question_opinion = [record['term'] for record in records_opinion]
answer_opinion = [record['opinion'] for record in records_opinion]

In [None]:
# WordPiece tokenizer read as a global by the dataset and span-index helpers.
bert_version = 'bert-base-uncased'
tokenizer = BertTokenizer.from_pretrained(pretrained_model_name_or_path=bert_version)

In [None]:
class model(nn.Module):
    """Four independent BERT encoders, each followed by a 2-way (start/end) span head.

    Branch wiring in ``forward``:
      * ``bert_osae``    <- ``input_ids`` / ``token_type_ids``
      * ``bert_term``    <- ``term_input_ids`` / ``term_token_type_ids``
      * ``bert_asoe``    <- ``input_ids_opinion`` / ``segment_ids_opinion``
      * ``bert_opinion`` <- ``opinion_input_ids`` / ``opinion_segment_ids``

    Args:
        device: Unused; kept for interface compatibility (it only fed a disabled
            CCA loss).
        bert_version: Checkpoint name loaded for all four encoders.
    """

    def __init__(self, device, bert_version='bert-base-uncased'):
        super().__init__()
        self.bert_osae = BertModel.from_pretrained(bert_version)
        self.bert_term = BertModel.from_pretrained(bert_version)
        self.bert_asoe = BertModel.from_pretrained(bert_version)
        self.bert_opinion = BertModel.from_pretrained(bert_version)
        self.config = self.bert_osae.config
        hidden_size = self.config.hidden_size
        self.qa_outputs_osae = nn.Linear(hidden_size, 2)
        self.qa_outputs_term = nn.Linear(hidden_size, 2)
        self.qa_outputs_asoe = nn.Linear(hidden_size, 2)
        self.qa_outputs_opinion = nn.Linear(hidden_size, 2)

    @staticmethod
    def _span_logits(head, encoder_output):
        """Apply ``head`` to element 0 of a BERT output; return (start_logits, end_logits) with the last dim squeezed."""
        start_logits, end_logits = head(encoder_output[0]).split(1, dim=-1)
        return start_logits.squeeze(-1), end_logits.squeeze(-1)

    @staticmethod
    def _span_loss(start_logits, end_logits, start_positions, end_positions):
        """Mean of the start and end cross-entropy losses; extra trailing label dims are squeezed away."""
        if start_positions.dim() > 1:
            start_positions = start_positions.squeeze(-1)
        if end_positions.dim() > 1:
            end_positions = end_positions.squeeze(-1)
        loss_fct = CrossEntropyLoss()
        return (loss_fct(start_logits, start_positions) + loss_fct(end_logits, end_positions)) / 2

    def forward(self, input_ids, token_type_ids=None, term_token_type_ids=None,
                attention_mask=None, start_positions=None, end_positions=None,
                term_input_ids=None, term_attention_mask=None,
                term_start_positions=None, term_end_positions=None,
                input_ids_opinion=None, segment_ids_opinion=None,
                start_index_opinion=None, end_index_opinion=None,
                opinion_input_ids=None, opinion_segment_ids=None,
                opinion_start_index=None, opinion_end_index=None,
                attention_mask_opinion=None, opinion_attention_mask=None):
        """Run all four branches and return either losses or logits.

        When all eight label tensors (start and end for each branch) are given,
        returns ``(total_loss, term_total_loss, total_loss_opinion,
        opinion_total_loss, output_osae, term_output, output_asoe,
        opinion_output)``; each loss is the mean of its start and end
        cross-entropy. Otherwise returns the eight start/end logit tensors
        followed by the four raw encoder outputs.

        ``attention_mask_opinion`` and ``opinion_attention_mask`` are optional
        masks for the asoe and opinion encoders. They used to reuse
        ``attention_mask`` / ``term_attention_mask``, which are built for
        different sequences.
        """
        # BertModel.forward's positional order is (input_ids, attention_mask,
        # token_type_ids, ...): pass by keyword so segment ids are not consumed as
        # the attention mask.
        output_osae = self.bert_osae(input_ids=input_ids, attention_mask=attention_mask,
                                     token_type_ids=token_type_ids)
        term_output = self.bert_term(input_ids=term_input_ids, attention_mask=term_attention_mask,
                                     token_type_ids=term_token_type_ids)
        output_asoe = self.bert_asoe(input_ids=input_ids_opinion, attention_mask=attention_mask_opinion,
                                     token_type_ids=segment_ids_opinion)
        opinion_output = self.bert_opinion(input_ids=opinion_input_ids, attention_mask=opinion_attention_mask,
                                           token_type_ids=opinion_segment_ids)

        start_logits, end_logits = self._span_logits(self.qa_outputs_osae, output_osae)
        term_start_logits, term_end_logits = self._span_logits(self.qa_outputs_term, term_output)
        start_logits_opinion, end_logits_opinion = self._span_logits(self.qa_outputs_asoe, output_asoe)
        opinion_start_logits, opinion_end_logits = self._span_logits(self.qa_outputs_opinion, opinion_output)

        labels = (start_positions, end_positions, term_start_positions, term_end_positions,
                  start_index_opinion, end_index_opinion, opinion_start_index, opinion_end_index)
        # Every branch needs its labels; checking only some of them let a partial
        # call crash on `None.size()`.
        if all(label is not None for label in labels):
            total_loss = self._span_loss(start_logits, end_logits, start_positions, end_positions)
            term_total_loss = self._span_loss(term_start_logits, term_end_logits,
                                              term_start_positions, term_end_positions)
            total_loss_opinion = self._span_loss(start_logits_opinion, end_logits_opinion,
                                                 start_index_opinion, end_index_opinion)
            opinion_total_loss = self._span_loss(opinion_start_logits, opinion_end_logits,
                                                 opinion_start_index, opinion_end_index)
            return (total_loss, term_total_loss, total_loss_opinion, opinion_total_loss,
                    output_osae, term_output, output_asoe, opinion_output)
        return (start_logits, end_logits, term_start_logits, term_end_logits,
                start_logits_opinion, end_logits_opinion, opinion_start_logits, opinion_end_logits,
                output_osae, term_output, output_asoe, opinion_output)

In [None]:
# Pair each direction's queries/answers with the shared sentences.
train_dataset = construct_dataset(
    question=question,
    text=sentence,
    answer=answer,
    question_opinion=question_opinion,
    text_opinion=sentence_opinion,
    answer_opinion=answer_opinion,
)

In [None]:
# Joint training of all four branches. Each step also logs a CCA-based similarity
# between each sentence-only encoder and the sentence tail of its paired encoder;
# this is monitoring only and does not enter the loss.


def _first_canonical_correlation(single_output, pair_output, n_components=4):
    """Correlation of the first CCA component pair between two encoder outputs.

    Element 0 of each output holds the token hidden states. Assumes batch size 1,
    so ``.squeeze()`` yields (seq, hidden). The last ``len(single)`` rows of the pair
    encoding are aligned with the single-sentence encoding.
    """
    x = single_output[0].squeeze().detach().cpu().numpy()
    y = pair_output[0].squeeze()[-len(x):].detach().cpu().numpy()
    # sklearn's CCA rejects more components than samples; very short sentences hit that.
    cca = CCA(n_components=min(n_components, len(x)))
    cca.fit(x, y)
    x_c, y_c = cca.transform(x, y)
    # Correlate X-score 0 with Y-score 0. np.corrcoef(X_c.T, Y_c.T)[0, 1] stacks all
    # scores as rows, so [0, 1] correlated X-score 0 with X-score 1.
    return np.corrcoef(x_c[:, 0], y_c[:, 0])[0, 1]


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# `model` is the class defined above. Bind the instance to another name so the class
# stays reachable as `__main__.model`, which torch.save needs to pickle the instance.
qa_model = model(device).to(device)
# BertModel.from_pretrained leaves the sub-encoders in eval mode; enable dropout.
qa_model.train()

optimizer = optim.AdamW(params=qa_model.parameters(), lr=1e-5)
n_epochs = 18
accumulation_steps = 8  # single-example batches per optimizer step
train_data = torch.utils.data.DataLoader(train_dataset, batch_size=1)
previous_correlation = 0
diff = 0
correlation_result = []  # per-step correlation scores over all epochs

for epochs in range(n_epochs):
    train_loss = []
    epoch_correlation = []
    current_loss = 0
    pending = 0  # backward passes accumulated since the last optimizer step
    for i, batch in enumerate(tqdmn(train_data)):
        batch = {key: value.to(device) for key, value in batch.items()}
        (total_loss, term_total_loss, total_loss_opinion, opinion_total_loss,
         output_osae, term_output, output_asoe, opinion_output) = qa_model(
            input_ids=batch['input_ids'],
            token_type_ids=batch['segment_ids'],
            start_positions=batch['start_index'],
            end_positions=batch['end_index'],
            term_input_ids=batch['term_input_ids'],
            term_token_type_ids=batch['term_segment_ids'],
            term_start_positions=batch['term_start_index'],
            term_end_positions=batch['term_end_index'],
            input_ids_opinion=batch['input_ids_opinion'],
            segment_ids_opinion=batch['segment_ids_opinion'],
            start_index_opinion=batch['start_index_opinion'],
            end_index_opinion=batch['end_index_opinion'],
            opinion_input_ids=batch['opinion_input_ids'],
            opinion_segment_ids=batch['opinion_segment_ids'],
            opinion_start_index=batch['opinion_start_index'],
            opinion_end_index=batch['opinion_end_index'],
        )

        correlation = (_first_canonical_correlation(term_output, output_osae)
                       + _first_canonical_correlation(opinion_output, output_asoe))
        epoch_correlation.append(correlation)
        correlation_result.append(correlation)

        loss = total_loss + term_total_loss + total_loss_opinion + opinion_total_loss
        loss.backward()
        current_loss += loss.item()
        pending += 1
        # Step after exactly `accumulation_steps` batches (the old `i % 8 == 0 and i > 0`
        # made the first window 9 batches long).
        if pending == accumulation_steps:
            optimizer.step()
            optimizer.zero_grad()
            train_loss.append(current_loss / pending)
            current_loss = 0
            pending = 0
    # Flush a trailing partial window; skip it when the epoch ended exactly on a step
    # so no update is made from already-zeroed gradients.
    if pending:
        optimizer.step()
        optimizer.zero_grad()
        train_loss.append(current_loss / pending)

    torch.save(qa_model, 'ODAO_CCA_correlation_epoch_' + str(epochs))

    print("Epoch: " + str(epochs) + " Correlation score: " + str(np.mean(epoch_correlation)))

    print("Epoch: " + str(epochs) + " Average loss: " + str(np.mean(train_loss)))