In [None]:
from brat_parser import get_entities_relations_attributes_groups
import random
from collections import Counter, defaultdict, namedtuple
from typing import Tuple, List, Dict, Any
import spacy
from spacy import displacy

import torch
import numpy as np
import nltk.data
import seaborn as sns

from tqdm import tqdm, trange
from transformers import AutoTokenizer, AutoModel, PreTrainedTokenizer, AutoModelForSequenceClassification
from bertviz import head_view, model_view
from transformers.tokenization_utils_base import BatchEncoding
from transformers.trainer_callback import dataclass
from sklearn.metrics import f1_score, accuracy_score
import pandas as pd
from torch.utils.data import random_split
import json
from bertviz import model_view, head_view
from matplotlib.colors import LogNorm
import matplotlib.pyplot as plt
import seaborn as sns
import warnings
warnings.filterwarnings("ignore")

In [None]:
# Hugging Face checkpoint analysed in this notebook: RuBERT (cased) from DeepPavlov.
model_name = "DeepPavlov/rubert-base-cased"

# Pick an available accelerator instead of hard-coding "mps", which only exists on
# Apple-silicon builds of PyTorch. NOTE(review): nothing visible here moves the model or
# its inputs to `device`, so inference in this section still runs on the CPU.
if torch.cuda.is_available():
    device = "cuda"
elif getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

In [None]:
# Encoder first: output_attentions=True makes every forward pass also return the
# per-layer attention maps that the analysis below is built on.
model = AutoModel.from_pretrained(model_name, output_attentions=True)

# Tokenizer for the same checkpoint.
tokenizer = AutoTokenizer.from_pretrained(model_name)

In [None]:
def clear_text(text, items_to_replace):
    """Return ``text`` with every occurrence of each item of ``items_to_replace`` removed.

    Items are removed one after another in the given order, so deleting an earlier item
    can create a new occurrence of a later one (which is then removed as well).
    """
    cleaned = text
    for fragment in items_to_replace:
        cleaned = cleaned.replace(fragment, '')
    return cleaned

def format_smta_data(data_str):
    """Parse one annotation JSON string and return its SMTA entries without 'wt' / 'sent'.

    ``data_str`` must decode to a dict with ``text`` and ``index`` keys. As a side effect the
    text is printed, followed by each matched span looked up in the NLTK word tokenization
    of the cleaned text (``pos`` is treated as 1-based, ``len`` as a token count).

    NOTE(review): ``'SMTA' in entry['type']`` is a substring test, so 'SMTAW' entries are
    kept too, while ``parse_smta_data`` compares types with ``==`` — confirm intent.
    NOTE(review): ``clear_sentence`` is not defined in this section; presumably elsewhere.
    """
    data = json.loads(data_str)
    print(data['text'])
    words = nltk.word_tokenize(clear_sentence(data['text']))
    smta_entries = []
    for entry in data['index']:
        if 'SMTA' not in entry['type']:
            continue
        first = entry['pos'] - 1
        print(words[first:first + entry['len']])
        trimmed = entry.copy()
        trimmed.pop('wt')
        trimmed.pop('sent')
        smta_entries.append(trimmed)
    return smta_entries

def read_dicts_data(dictname, size):
    """Load and format ``size`` annotation files from directory ``dictname``.

    Reads ``{dictname}/output_{i}.txt`` for ``i`` in ``range(size)`` and passes each file's
    contents to ``format_smta_data``.

    Returns:
        A list with one ``format_smta_data`` result per file, in index order.
    """
    result = []
    for i in range(size):
        # `with` closes each file promptly; a bare open() leaked one handle per file.
        # JSON is exchanged as UTF-8 (RFC 8259); being explicit avoids relying on the OS locale.
        with open(dictname + '/output_' + str(i) + '.txt', 'r', encoding='utf-8') as file:
            contents = file.read()
        result.append(format_smta_data(contents))
    return result


def compute_tokens_and_attention(sentence):
    """Run the global ``model`` on ``sentence`` and return its attentions and word pieces.

    The sentence is encoded with the global ``tokenizer`` without special tokens, so every
    attention row/column corresponds to a piece of the input text.

    Returns:
        ``(attention, tokens)``: ``attention`` is the last element of the model output
        (the attentions, given ``output_attentions=True`` at load time); ``tokens`` is the
        list of token strings for the encoded ids.
    """
    inputs = tokenizer.encode_plus(sentence, return_tensors='pt', add_special_tokens=False)
    input_ids = inputs['input_ids']
    # Inference only: no_grad avoids building an autograd graph for every sentence.
    # Downstream code already calls .detach() before converting to NumPy.
    with torch.no_grad():
        attention = model(input_ids)[-1]
    tokens = tokenizer.convert_ids_to_tokens(input_ids[0].tolist())
    return attention, tokens

def map_type_to_tag(sentiment_type):
    """Return the string tag "1" when ``sentiment_type`` contains "NEG", else "-1".

    NOTE(review): negative maps to "1" here, whereas ``map_sign_for_tag`` maps '+' to 1;
    callers also pass entity text rather than a relation type — confirm both are intended.
    """
    return "1" if "NEG" in sentiment_type else "-1"


def ids_of_dashes(data):
    """Return the half-open span ``[start, end)`` around the first standalone ``'-'`` token.

    The span covers the token before the dash, the dash and the token after it. It keeps
    extending over chains such as ``a - b - c`` (another dash two positions further on).
    It is used to glue hyphenated words back into one token.

    Both ends are clamped to ``[0, len(data)]``. Previously a leading dash gave
    ``start == -1`` (wrapping to the last token), and a dash at the end gave
    ``end > len(data)``, which made ``flatten_attention_with_ids`` index out of range.

    Returns:
        ``[start, end]`` as a list, or ``None`` if ``data`` contains no ``'-'``.
    """
    dash_pos = next((pos for pos, tok in enumerate(data) if tok == '-'), None)
    if dash_pos is None:
        return None

    end = dash_pos + 2
    # Follow "word - word - word" chains.
    while end < len(data) and data[end] == '-':
        end += 2
    return [max(dash_pos - 1, 0), min(end, len(data))]
            

def sublist(sl, l):
    """Find the first occurrence of the list ``sl`` inside the list ``l``.

    Returns:
        ``[start, start + len(sl)]`` for the first position where ``l`` contains ``sl``
        as a contiguous run, or ``None`` if there is no such position.
    """
    width = len(sl)
    for start, element in enumerate(l):
        if element == sl[0] and l[start:start + width] == sl:
            return [start, start + width]
    return None

        
def all_sentiments_as_map(entities, relations):
    """Group opinion phrases by the entity text they refer to.

    ``OPINION_RELATES_TO`` relations point from the opinion (``subj``) to the target
    (``obj``). ``POS_AUTHOR_FROM`` / ``NEG_AUTHOR_FROM`` relations are keyed by their
    ``subj`` and take the opinion from ``obj``. Other relation types are ignored.

    Returns:
        dict mapping target text -> list of ``(map_type_to_tag(opinion), opinion)`` pairs.
    """
    sentiments_by_target = {}
    for relation in relations.values():
        if relation.type == "OPINION_RELATES_TO":
            target, opinion = relation.obj, relation.subj
        elif relation.type in ("POS_AUTHOR_FROM", "NEG_AUTHOR_FROM"):
            target, opinion = relation.subj, relation.obj
        else:
            continue
        target_text = entities[target].text
        opinion_text = entities[opinion].text
        pair = (map_type_to_tag(opinion_text), opinion_text)
        sentiments_by_target.setdefault(target_text, []).append(pair)
    return sentiments_by_target

def all_sentiments_as_list(entities, relations):
    """List every opinion phrase as a ``(tag, text)`` pair, in relation order.

    For ``OPINION_RELATES_TO`` the opinion is the relation's ``subj``. For
    ``POS_AUTHOR_FROM`` / ``NEG_AUTHOR_FROM`` it is the ``obj``. Other relation types are
    ignored. The author branch previously did ``res[key_subj].append(...)`` with an
    undefined ``key_subj`` on a list, which raised NameError.

    Returns:
        list of ``(map_type_to_tag(text), text)`` tuples.
    """
    res = []
    for relation in relations.values():
        if relation.type == "OPINION_RELATES_TO":
            opinion_text = entities[relation.subj].text
        elif relation.type in ("POS_AUTHOR_FROM", "NEG_AUTHOR_FROM"):
            opinion_text = entities[relation.obj].text
        else:
            continue
        res.append((map_type_to_tag(opinion_text), opinion_text))
    return res

def flatten_data(_data, ids):
    """Merge token spans of a square attention matrix into single rows/columns.

    For each ``(start, stop)`` pair in ``ids``, processed in order on one working copy:
    row ``start`` becomes the mean of rows ``start:stop``, then column ``start`` becomes
    the sum of columns ``start:stop``. The column step reads the already-averaged row,
    and later spans see earlier merges. Afterwards the rows and columns
    ``start+1 .. stop-1`` of every span are deleted. ``_data`` itself is not modified.
    """
    merged = _data.copy()
    dropped = []
    for start, stop in ids:
        merged[start] = np.mean(merged[start:stop], axis=0)
        merged.T[start] = np.sum(merged.T[start:stop], axis=0)
        dropped.extend(range(start + 1, stop))

    merged = np.delete(merged, dropped, axis=1)
    return np.delete(merged, dropped, axis=0)

def flatten_attention_with_ids(attention, tokens, ids, join_symbol = " "):
    """Collapse the token span ``[ids[0], ids[1])`` into one token in tokens and attentions.

    The span's tokens are joined with ``join_symbol``. Every head of every layer is
    passed through ``flatten_data`` with that single span; each layer is rebuilt as a
    tensor with a leading batch axis of 1.

    Returns:
        ``(flatten_attention, flatten_tokens)``: a tuple of per-layer tensors and the
        NumPy array of tokens with the span replaced by the joined string.
    """
    first, stop = ids[0], ids[1]
    merged_token = join_symbol.join(tokens[k] for k in range(first, stop))
    flatten_tokens = np.concatenate([tokens[:first], np.array([merged_token]), tokens[stop:]])

    span = np.array([ids])
    flatten_attention = tuple(
        torch.from_numpy(np.array([[flatten_data(head.detach().numpy(), span) for head in layer[0]]]))
        for layer in attention
    )
    return flatten_attention, flatten_tokens

def concat_attention(att, tokens, punctuation=(",", ".", "«", "»", "—", ")", "(", '"', "'", ":")):
    """Merge WordPiece continuations ("##...") into whole words and drop punctuation.

    A token followed by "##" pieces becomes one word (the pieces' "##" prefix is
    stripped). The matching attention rows are averaged and columns summed via
    ``flatten_data``. Tokens equal to one of ``punctuation`` are then removed from the
    token list and from every head's attention matrix.

    Args:
        att: per-layer attention tensors; each head is read as ``layer[0][head]``.
        tokens: word-piece strings aligned with the attention rows/columns.
        punctuation: tokens to remove after merging (default matches the previous
            hard-coded list).

    Returns:
        ``(conc_attention, conc_tokens)``: a tuple of per-layer tensors with a leading
        batch axis of 1, and a NumPy array of merged tokens.
    """
    attention = tuple(att)
    conc_tokens = []
    ids = []
    i = 0
    n = len(tokens) - 1

    while i <= n:
        if i < n and len(tokens[i + 1]) > 2 and tokens[i + 1][:2] == "##":
            k = 0
            res = tokens[i]
            while i + k + 1 <= n and tokens[i + k + 1][:2] == "##":
                res = res + tokens[i + 1 + k][2:]
                k = k + 1
            ids.append((i, i + k + 1))
            conc_tokens.append(res)
            i = i + k + 1
        else:
            # A "##..." token reaches this branch only at position 0. It used to hit
            # `continue` without advancing `i` (infinite loop). Keep it as its own word so
            # conc_tokens stays aligned with the attention rows that flatten_data keeps.
            conc_tokens.append(tokens[i])
            i = i + 1

    token_array = np.array(conc_tokens, dtype=str)
    idxs = np.where(np.isin(token_array, list(punctuation)))[0]

    # Spans do not change per head; build the index array once.
    merge_spans = np.array(ids)
    conc_attention = []
    for layer in attention:
        one_layer_at = []
        for head in layer[0]:
            new_data = flatten_data(head.detach().numpy(), merge_spans)
            res_at = np.delete(new_data, idxs, axis=1)
            one_layer_at.append(np.delete(res_at, idxs, axis=0))
        conc_attention.append(torch.from_numpy(np.array([one_layer_at])))

    return tuple(conc_attention), np.delete(token_array, idxs)

def compute_metrics(y_true, y_pred):
    """Return ``(macro_f1, accuracy)`` for the given true and predicted labels.

    Macro-F1 is computed with ``zero_division=0``, so undefined per-class scores count
    as 0 instead of raising a warning.
    """
    labels = dict(y_true=y_true, y_pred=y_pred)
    return (
        f1_score(average="macro", zero_division=0, **labels),
        accuracy_score(**labels),
    )

def mean_by_head(attention):
    """Average each layer's attention over its heads.

    Args:
        attention: sequence of per-layer tensors. Only ``layer[0]`` is used and averaged
            over its first axis, presumably ``(1, heads, seq, seq)`` -> ``(seq, seq)``
            (TODO confirm for other callers).

    Returns:
        ``np.ndarray`` stacking one head-averaged matrix per layer.
    """
    # .detach().numpy() is enough to read the values; the former per-layer .clone()
    # copied each tensor for nothing, and an unused local list was dropped.
    return np.array([np.mean(layer[0].detach().numpy(), axis=0) for layer in attention])

def mean_by_layer(attention):
    """Average attention over layers while keeping heads separate.

    Takes ``layer[0]`` of every element and returns their element-wise mean across
    layers (e.g. ``(1, heads, n, n)`` inputs give a ``(heads, n, n)`` array).
    """
    stacked = np.array([layer[0].detach().numpy() for layer in attention])
    return stacked.mean(axis=0)

def mean_by_all(attention, st=0, en=12):
    """Average attention over layers ``attention[st:en]`` and then over heads.

    ``layer[0]`` of each selected element is used. The mean is taken across layers
    first, then across the next axis (heads), leaving one ``(n, n)`` matrix.
    """
    selected = [layer[0].detach().numpy() for layer in attention[st:en]]
    over_layers = np.array(selected).mean(axis=0)
    return over_layers.mean(axis=0)

def map_sign_for_tag(sign):
    """Map a polarity tag to a number: ``'+'`` -> 1, anything else -> -1."""
    return 1 if sign == '+' else -1

def result_by_multiple_attentions(mean_attentions, tokens, entity_ids, ids, tags):
    """Print, for each attention matrix, the sentiment token the target entity attends to most.

    For every matrix in ``mean_attentions`` this prints the attention row of the target
    entity restricted to ``ids``, the winning token and its tag. Prints ``0`` when
    ``ids`` is empty.

    The body previously referenced an undefined ``entity_id`` (NameError). The target is
    ``entity_ids[-1]``: ``parse_smta_data`` appends the target entity last, and
    ``result_by_total_attention`` treats the last id as the target too.
    """
    if len(ids) == 0:
        print(0)
        return
    entity_id = entity_ids[-1]
    for att in mean_attentions:
        row = att[entity_id][ids]
        max_id = np.argmax(row)
        print(row, tokens[ids[max_id]], tags[max_id])
        
def result_by_total_attention(mean_attention, tokens, _entity_ids, _ids, tags, log = False, _aspect_ids=None):
    """Decide the sentiment polarity towards the target entity from an averaged attention map.

    ``mean_attention[a][b]`` is read as attention from token ``a`` to token ``b``. The
    target entity is the last element of ``_entity_ids``.

    With a single candidate, the sentiment token it attends to most decides. Otherwise
    sentiments are examined greedily:
    - Take the sentiment the target attends to most.
    - If the target out-attends every other candidate (entities and aspects) for it,
      that sentiment's tag decides.
    - If a rival wins, the rival and the sentiment are discarded and the next sentiment
      is tried. Exception: when the rival is an aspect that the target attends to more
      than to the sentiment, the sentiment is credited to the target.

    Args:
        mean_attention: 2-D array indexed by token position (e.g. from ``mean_by_all``).
        tokens: token strings, used only for logging.
        _entity_ids: positions of entity tokens; the last one is the target.
        _ids: positions of sentiment tokens.
        tags: polarity tag per sentiment, aligned with ``_ids`` ('+' -> 1, else -1).
        log: print intermediate decisions.
        _aspect_ids: positions of aspect tokens, treated as competing candidates.

    Returns:
        1 or -1 via ``map_sign_for_tag``; 0 when there are no sentiments or none could
        be attributed to the target.
    """
    # Copy: the loop pops from these lists, and the caller's lists (e.g. records built by
    # parse_smta_data) must stay intact so evaluation can be re-run on the same data.
    ids = list(_ids)
    tags = list(tags)
    aspect_ids = [] if _aspect_ids is None else list(_aspect_ids)
    entity_ids = aspect_ids + list(_entity_ids)

    if len(ids) == 0:
        return 0

    if len(entity_ids) == 1:
        # One candidate only: the sentiment token it attends to most decides.
        max_id = np.argmax(mean_attention[entity_ids[0]][ids])
        return map_sign_for_tag(tags[max_id])

    result = 0
    while len(ids) != 0:
        target = entity_ids[-1]
        if log:
            print(entity_ids)
            print(ids)
            print(tags)
        max_id = np.argmax(mean_attention[target][ids])
        res_id = ids[max_id]
        if log:
            print(mean_attention[target][ids])
            print(tokens[target], mean_attention[target][res_id], tokens[res_id])

        # Attention from every candidate (aspects + entities) to the chosen sentiment token.
        entity_atts = mean_attention[entity_ids]
        entity_results = entity_atts[np.arange(len(entity_atts)), res_id]
        _id = np.argmax(entity_results)

        if log:
            print(tokens[entity_ids[_id]], mean_attention[entity_ids[_id]][res_id], tokens[res_id], '\n')

        if np.max(entity_results) == entity_results[-1]:
            # The target attends to this sentiment at least as much as any rival.
            result = map_sign_for_tag(tags[max_id])
            break

        if entity_ids[_id] in aspect_ids:
            from_ent_to_aspect_id = mean_attention[target][entity_ids[_id]]
            from_ent_to_tone_id = mean_attention[target][res_id]
            if log:
                print(tokens[target], from_ent_to_aspect_id, tokens[entity_ids[_id]])
                print(tokens[target], from_ent_to_tone_id, tokens[res_id], '\n')
            if from_ent_to_aspect_id > from_ent_to_tone_id:
                # The winning aspect is attended to by the target more than the sentiment.
                result = map_sign_for_tag(tags[max_id])
                break

        # A rival claimed this sentiment: drop both and try the next sentiment.
        entity_ids.pop(_id)
        ids.pop(max_id)
        tags.pop(max_id)

    return result
    

def ids_of_ents(tokens, ent):
    """Locate entity ``ent`` in ``tokens`` and return its half-open span ``[start, end)``.

    Scans left to right and stops at the first hit:
    - A token equal to the whole entity gives a one-token span.
    - Otherwise, a token that is one of the entity's NLTK word parts starts a span that
      extends over every following token that is also a part.

    Returns ``[]`` when nothing matches.
    """
    ent_parts = nltk.word_tokenize(ent)
    for pos, token in enumerate(tokens):
        if token == ent:
            return [pos, pos + 1]
        if token in ent_parts:
            end = pos + 1
            while end < len(tokens) and tokens[end] in ent_parts:
                end += 1
            return [pos, end]
    return []

    
def aspect_is_valid(aspect, entities, sentiments):
    """Return True unless ``aspect`` overlaps (substring either way) an entity or a sentiment.

    Entities are checked before sentiments.
    """
    def overlaps(other):
        return other in aspect or aspect in other

    if any(overlaps(entity) for entity in entities):
        return False
    return not any(overlaps(sent) for sent in sentiments)
    
    
def parse_smta_data(start = 0, n = 100, multiple_ne = False, nlp = None, ids_to_use=None, log=True):
    """Build attention/token records for samples ``start .. n-1`` of the SMTA training data.

    For each index ``i`` (optionally restricted to ``ids_to_use``) this:
    - reads ``train_data/output_{i}.txt`` (JSON with ``text`` and ``index``);
    - runs the model;
    - merges word pieces, dashed words, the target entity and each sentiment span into
      single tokens;
    - records the positions of entities, aspects and sentiment tokens.

    ``read_data`` and ``entity_extraction`` are not defined in this section; presumably
    ``read_data()`` returns ``(labeled_data, labels)`` with the target entity text at
    ``labeled_data[i][1]`` (as used below).

    Args:
        start, n: half-open range of sample indices.
        multiple_ne: also collect other named entities and noun/proper-noun aspects.
        nlp: spaCy-style pipeline, only used when ``multiple_ne`` is True.
        ids_to_use: optional collection restricting which indices are processed.
        log: print intermediate spans/tokens and the reason a sample was skipped.

    Returns:
        ``(result, true_labels, checked_ids)``: three lists aligned index by index. Each
        ``result`` item is ``(attention, tokens, entity_ids, sentiment_ids, tags,
        aspect_ids)``. Samples that raise are skipped and appear in none of the lists.
    """
    result = []
    true_labels = []
    checked_ids = []
    labeled_data, labels = read_data()

    for i in range(start, n):
        if ids_to_use is not None and i not in ids_to_use:
            continue
        try:
            data_item, label = labeled_data[i], labels[i]
            # `with` closes the file even when a later step fails.
            with open('train_data/output_' + str(i) + '.txt', 'r', encoding='utf-8') as file:
                data = json.loads(file.read())
            _attention, _tokens = compute_tokens_and_attention(data['text'])
            attention, tokens = concat_attention(_attention, _tokens)

            # Glue "word - word" sequences into single tokens until no standalone dash remains.
            ids_of_d = ids_of_dashes(tokens.tolist())
            while ids_of_d is not None:
                attention, tokens = flatten_attention_with_ids(attention, tokens, ids_of_d, join_symbol='')
                ids_of_d = ids_of_dashes(tokens.tolist())

            # Collect SMTA sentiment spans, skipping spans that contain the target entity.
            sentiments = []
            smta_items = []
            for item in data['index']:
                if item['type'] == 'SMTA':
                    if data_item[1] in tokens[item['pos'] - 1 : item['pos'] + item['len'] - 1]:
                        continue
                    smta_items.append(item)
                    if log:
                        print('SMTA', tokens[item['pos'] - 1 : item['pos'] + item['len'] - 1], item['name'])
                    sentiments.append(tokens[item['pos'] - 1 : item['pos'] + item['len'] - 1])
                elif item['type'] == 'SMTAW':
                    if log:
                        print('SMTAW', tokens[item['pos'] - 1 : item['pos'] + item['len'] - 1])

            # Merge the target entity into one token.
            entity_parts = nltk.word_tokenize(data_item[1])
            entity_ids = sublist(entity_parts, tokens.tolist())
            attention, tokens = flatten_attention_with_ids(attention, tokens, entity_ids)

            # Merge each multi-token sentiment span into one token.
            for item in sentiments:
                ids = sublist(item.tolist(), tokens.tolist())
                if ids is not None:
                    attention, tokens = flatten_attention_with_ids(attention, tokens, ids)

            result_entities = []
            if multiple_ne:
                found_entities = []
                for poses in entity_extraction([data['text']])[2][0]:
                    found_entities.append(data['text'][poses[0]:poses[1]])
                for ent in found_entities:
                    if ent != data_item[1]:
                        ent_ids = ids_of_ents(tokens, ent)
                        if len(ent_ids) > 0:
                            if ent_ids[0] + 1 < ent_ids[1]:
                                attention, tokens = flatten_attention_with_ids(attention, tokens, ent_ids)
                            result_entities.append(tokens[ent_ids[0]])

            if log:
                print('\n', tokens, '\n')

            # Position of each (now single-token) sentiment; None where it was not found.
            new_ids = []
            new_ids_no_none = []
            for item in sentiments:
                text = ' '.join(item.tolist())
                new_id = sublist([text], tokens.tolist())
                if new_id is not None:
                    new_ids.append(new_id[0])
                    new_ids_no_none.append(new_id[0])
                else:
                    new_ids.append(None)

            result_entity_ids = []
            for ent in result_entities:
                ne_id = sublist([ent], tokens.tolist())
                if ne_id is not None:
                    result_entity_ids.append(ne_id[0])

            # The target entity goes last; result_by_total_attention treats the last id as the target.
            result_entity_ids.append(sublist([data_item[1]], tokens.tolist())[0])

            aspect_ids = []
            if multiple_ne:
                doc = nlp(data['text'])
                for token in doc:
                    if (token.pos_ == 'NOUN' or token.pos_ == 'PROPN') and aspect_is_valid(
                        aspect = token.text,
                        entities = tokens[result_entity_ids],
                        sentiments = tokens[new_ids_no_none],
                    ):
                        ne_id = sublist([token.text], tokens.tolist())
                        if ne_id is not None:
                            aspect_ids.append(ne_id[0])

            if multiple_ne and log:
                print('entities:', tokens[result_entity_ids], '\n')
                print('aspects:', tokens[aspect_ids], '\n')
                print('sentiments:', tokens[new_ids_no_none], '\n')

            # Keep only the sentiments that were located, with their polarity tag.
            # zip over the aligned lists avoids reusing `i`, the sample index.
            ids = []
            tags = []
            for new_id, smta_item in zip(new_ids, smta_items):
                if new_id is not None:
                    ids.append(new_id)
                    tags.append(smta_item['name'])

            record = (attention, tokens, result_entity_ids, ids, tags, aspect_ids)
        except Exception as exc:
            # Best-effort: skip samples that fail. Unlike a bare `except:`, this does not
            # swallow KeyboardInterrupt, so the notebook can still be interrupted.
            if log:
                print('skipped sample', i, repr(exc))
            continue

        # Append only after the whole sample succeeded so the three lists stay aligned.
        result.append(record)
        true_labels.append(label)
        checked_ids.append(i)

    return result, true_labels, checked_ids

def count_score(attention, tokens, entity_ids, ids, tags, aspect_ids, log):
    """Average ``attention`` with ``mean_by_all`` (default layer range), then score it.

    Thin wrapper around ``result_by_total_attention``; returns its polarity result.
    """
    return result_by_total_attention(
        mean_by_all(attention),
        tokens,
        entity_ids,
        ids,
        tags,
        log=log,
        _aspect_ids=aspect_ids,
    )

def show_attention(
    attention, 
    from_tokens, 
    to_tokens, 
    width=12, 
    height=0.5, 
    labelsize=8, 
    bottom_rot=0, 
    annot=True,
    fontsize=8,
    xlabel=None,
    ylabel=None,
    vmax=0.8,
    cmap="rocket_r",
    xlabel_fontsize=18,
    ylabel_fontsize=25,
):
    """Draw an attention matrix as an annotated seaborn heatmap and show it.

    Rows are labelled with ``from_tokens`` and columns with ``to_tokens``.

    Args:
        attention: 2-D array-like of attention weights.
        width, height: figure size in inches.
        labelsize: tick-label font size.
        bottom_rot: rotation of the x tick labels in degrees.
        annot, fontsize: whether to write the values in the cells, and their font size.
        xlabel, ylabel: optional axis titles.
        vmax: upper end of the colour scale. Fixing it, rather than letting seaborn pick
            it from the data, keeps the top of the scale the same across plots.
            Previously hard-coded to 0.8.
        cmap: colour map (previously hard-coded "rocket_r").
        xlabel_fontsize, ylabel_fontsize: axis-title font sizes (previously 18 and 25).
    """
    plt.figure(figsize=(width, height))
    ax = sns.heatmap(
        attention,
        cmap=cmap,
        xticklabels=to_tokens,
        yticklabels=from_tokens,
        annot=annot,
        annot_kws={"fontsize": fontsize},
        vmax=vmax,
    )
    ax.tick_params(labelsize=labelsize)

    if xlabel is not None:
        ax.set_xlabel(xlabel, fontsize=xlabel_fontsize, labelpad=20)
    if ylabel is not None:
        ax.set_ylabel(ylabel, fontsize=ylabel_fontsize)

    plt.yticks(rotation=0)
    plt.xticks(rotation=bottom_rot)
    plt.show()