In [None]:
!pip install allennlp

In [None]:
!pip install torch
!pip install numpy
!pip install thop
!pip install transformers
!pip install tqdm
!pip install natasha

In [4]:
%%writefile Attention.py

import torch
import math
import torch.nn as nn
import numpy as np
from transformers.activations import ACT2FN

class Dim_Four_Attention(nn.Module):
    """Multi-head attention plus output projection for 4-D hidden states.

    The head split reshapes the last axis into (heads, head_size) and then
    permutes five axes, so `hidden_states` must be 4-D with the hidden size
    last. Attention is computed along the second-to-last axis, independently
    for every index of the two leading axes.
    """

    def __init__(self, config):
        super().__init__()
        # Constructed first so sub-module registration (and therefore
        # parameter initialisation order) stays exactly as before.
        self.output = SelfOutput(config)

        head_count = config.num_attention_heads
        indivisible = config.hidden_size % head_count != 0
        if indivisible and not hasattr(config, "embedding_size"):
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (config.hidden_size, config.num_attention_heads)
            )

        self.num_attention_heads = head_count
        self.attention_head_size = int(config.hidden_size / head_count)
        self.all_head_size = head_count * self.attention_head_size

        projection_width = self.all_head_size
        self.query = nn.Linear(config.hidden_size, projection_width)
        self.key = nn.Linear(config.hidden_size, projection_width)
        self.value = nn.Linear(config.hidden_size, projection_width)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x):
        """Split the last axis into heads and move the head axis before the attended axis."""
        per_head = x.view(*x.size()[:-1], self.num_attention_heads, self.attention_head_size)
        return per_head.permute(0, 1, 3, 2, 4)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        ):
        """Attend over `hidden_states` (or over `encoder_hidden_states` when given).

        Returns a 1-tuple holding the result of `self.output` applied to the
        attention context and the original `hidden_states`.
        """
        queries = self.transpose_for_scores(self.query(hidden_states))

        if encoder_hidden_states is None:
            # Self-attention: keys and values come from the same states as the queries.
            memory = hidden_states
        else:
            # Cross-attention: keys/values come from the encoder, and its mask
            # (when supplied) replaces the regular attention mask.
            memory = encoder_hidden_states
            if encoder_attention_mask is not None:
                attention_mask = encoder_attention_mask
        keys = self.transpose_for_scores(self.key(memory))
        values = self.transpose_for_scores(self.value(memory))

        # Scaled dot-product scores between every query and key position.
        scores = torch.matmul(queries, keys.transpose(-1, -2))
        scores = scores / math.sqrt(self.attention_head_size)
        if attention_mask is not None:
            # The mask is additive and is simply summed onto the raw scores.
            scores = scores + attention_mask

        # Dropout is applied to the normalised weights, i.e. whole attended
        # positions are dropped.
        weights = self.dropout(nn.functional.softmax(scores, dim=-1))

        # Undo the head split: move heads back next to head_size and flatten them.
        mixed = torch.matmul(weights, values).permute(0, 1, 3, 2, 4).contiguous()
        merged = mixed.view(*mixed.size()[:-2], self.all_head_size)

        return (self.output(merged, hidden_states),)


class Attention(nn.Module):
    """Self-attention followed by the `SelfOutput` projection block."""

    def __init__(self, config):
        super().__init__()
        self.self = SelfAttention(config)
        self.output = SelfOutput(config)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
    ):
        """Run attention, then feed its first output and the input to `self.output`.

        Returns a tuple whose first item is the projected attention output;
        any further items produced by `self.self` are passed through unchanged.
        """
        attended = self.self(
            hidden_states,
            attention_mask,
            encoder_hidden_states,
            encoder_attention_mask)
        projected = self.output(attended[0], hidden_states)
        # Keep whatever extra outputs the attention module returned.
        return (projected,) + attended[1:]

class SelfAttention(nn.Module):
    """Multi-head scaled dot-product attention over 3-D hidden states.

    The head split reshapes the last axis into (heads, head_size) and permutes
    four axes, so `hidden_states` must be 3-D with the hidden size last.
    """

    def __init__(self, config):
        super().__init__()
        head_count = config.num_attention_heads
        indivisible = config.hidden_size % head_count != 0
        if indivisible and not hasattr(config, "embedding_size"):
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (config.hidden_size, config.num_attention_heads)
            )

        self.num_attention_heads = head_count
        self.attention_head_size = int(config.hidden_size / head_count)
        self.all_head_size = head_count * self.attention_head_size

        projection_width = self.all_head_size
        self.query = nn.Linear(config.hidden_size, projection_width)
        self.key = nn.Linear(config.hidden_size, projection_width)
        self.value = nn.Linear(config.hidden_size, projection_width)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x):
        """Split the last axis into heads and swap the head axis with axis 1."""
        per_head = x.view(*x.size()[:-1], self.num_attention_heads, self.attention_head_size)
        return per_head.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
    ):
        """Return a 1-tuple with the attention context, heads merged on the last axis."""
        queries = self.transpose_for_scores(self.query(hidden_states))

        if encoder_hidden_states is None:
            # Self-attention: keys and values come from the same states as the queries.
            memory = hidden_states
        else:
            # Cross-attention: keys/values come from the encoder, and its mask
            # (when supplied) replaces the regular attention mask.
            memory = encoder_hidden_states
            if encoder_attention_mask is not None:
                attention_mask = encoder_attention_mask
        keys = self.transpose_for_scores(self.key(memory))
        values = self.transpose_for_scores(self.value(memory))

        # Scaled dot-product scores between every query and key position.
        scores = torch.matmul(queries, keys.transpose(-1, -2))
        scores = scores / math.sqrt(self.attention_head_size)
        if attention_mask is not None:
            # The mask is additive and is simply summed onto the raw scores.
            scores = scores + attention_mask

        # Dropout acts on the normalised weights, i.e. whole attended positions are dropped.
        weights = self.dropout(nn.functional.softmax(scores, dim=-1))

        # Undo the head split: move heads back next to head_size and flatten them.
        mixed = torch.matmul(weights, values).permute(0, 2, 1, 3).contiguous()
        merged = mixed.view(*mixed.size()[:-2], self.all_head_size)

        return (merged,)

class SelfOutput(nn.Module):
    """Dense projection, dropout, then residual add and LayerNorm (hidden -> hidden)."""

    def __init__(self, config):
        super().__init__()
        width = config.hidden_size
        self.dense = nn.Linear(width, width)
        self.LayerNorm = nn.LayerNorm(width, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        """Project `hidden_states`, add `input_tensor` as the residual and normalise."""
        projected = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(projected + input_tensor)

class Output(nn.Module):
    """Dense projection from intermediate size back to hidden size, dropout, residual add, LayerNorm."""

    def __init__(self, config):
        super().__init__()
        hidden_width = config.hidden_size
        self.dense = nn.Linear(config.intermediate_size, hidden_width)
        self.LayerNorm = nn.LayerNorm(hidden_width, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        """Project `hidden_states`, add `input_tensor` as the residual and normalise."""
        projected = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(projected + input_tensor)

class Intermediate(nn.Module):
    """Dense expansion from hidden size to intermediate size followed by an activation."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        # A string names an entry of ACT2FN; anything else is used as the callable itself.
        activation = config.hidden_act
        self.intermediate_act_fn = ACT2FN[activation] if isinstance(activation, str) else activation

    def forward(self, hidden_states):
        """Apply the dense layer and then the configured activation."""
        return self.intermediate_act_fn(self.dense(hidden_states))

def masked_softmax(tensor, mask_1, mask_2):
    """
    Apply two successive masked softmaxes: over the last axis, then over the
    second-to-last axis of the result.

    Masking is multiplicative: masked entries are set to 0 *before* each
    softmax (not to -inf), zeroed again afterwards, and the remaining mass is
    renormalised. The second pass does not re-apply `mask_2`, so a column that
    `mask_2` zeroed can receive weight again from the softmax over dim=-2.

    Args:
        tensor: Scores to normalise. May be non-contiguous (e.g. the result of
            a transpose).
        mask_1: Mask for the second pass (softmax over dim=-2), 1 = keep and
            0 = mask. It is unsqueezed at axis 2 until it has as many
            dimensions as `tensor` and then expanded, so for a 3-D `tensor` it
            has `tensor`'s first two sizes.
        mask_2: Mask for the first pass (softmax over dim=-1), 1 = keep and
            0 = mask. It is unsqueezed at axis 1 until it has as many
            dimensions as `tensor` and then expanded, so for a 3-D `tensor` it
            has `tensor`'s first and last sizes.

    Returns:
        A float tensor with the same shape as `tensor`.
    """
    tensor_shape = tensor.size()
    # reshape() instead of view(): view() raises on non-contiguous inputs,
    # reshape() copies only when it has to.
    flat_scores = tensor.reshape(-1, tensor_shape[-1])

    # Broadcast mask_2 up to the full shape of the input tensor.
    while mask_2.dim() < tensor.dim():
        mask_2 = mask_2.unsqueeze(1)
    mask_2 = mask_2.expand_as(tensor).contiguous().float()
    flat_mask = mask_2.view(-1, mask_2.size()[-1])

    result = nn.functional.softmax(flat_scores * flat_mask, dim=-1)
    result = result * flat_mask
    # 1e-13 is added to avoid divisions by zero when a whole row is masked.
    result = result / (result.sum(dim=-1, keepdim=True) + 1e-13)
    result = result.view(*tensor_shape)

    # Broadcast mask_1 up to the full shape of the intermediate result.
    while mask_1.dim() < result.dim():
        mask_1 = mask_1.unsqueeze(2)
    mask_1 = mask_1.expand_as(result).contiguous().float()

    result_2 = nn.functional.softmax(result * mask_1, dim=-2)
    result_2 = result_2 * mask_1
    result_2 = result_2 / (result_2.sum(dim=-2, keepdim=True) + 1e-13)
    return result_2


Writing Attention.py


In [5]:
%%writefile data_BIO_loader.py

import torch
import numpy as np
import random
import json
from transformers import BertTokenizer
import tqdm
from tqdm import tqdm
from gcn import make_adj_matrix
import spacy
import string

# Binary label table: 'none' maps to 0 and every other tag maps to 1.
validity2id = {'none': 0, 'positive': 1, 'negative': 1, 'neutral': 1, 'start': 1}
# Multi-class label table: 'none' maps to 0 and each other tag gets its own id.
sentiment2id = {'none': 0, 'positive': 1, 'negative': 2, 'neutral': 3, 'start': 4}


# nlp = spacy.load("ru_core_news_md")



def get_spans(tags):
    '''Extract inclusive [start, end] index pairs from a BIO tag string.

    `tags` is split on whitespace. A token ending in 'B' opens a span (closing
    any span still open), a token ending in 'O' closes the open span, and any
    other token extends it. A span still open at the end runs to the last token.
    '''
    tokens = tags.strip().split()
    found = []
    open_at = -1  # -1 means "no span currently open"
    for pos, tag in enumerate(tokens):
        if tag.endswith('B'):
            if open_at != -1:
                found.append([open_at, pos - 1])
            open_at = pos
        elif tag.endswith('O') and open_at != -1:
            found.append([open_at, pos - 1])
            open_at = -1
    if open_at != -1:
        found.append([open_at, len(tokens) - 1])
    return found


def get_subject_labels(tags):
    '''Map the text of the first BIO span in `tags` to its [start, end] indices.

    Each whitespace-separated token contributes the part before its first
    backslash as the word. Only the first span found by `get_spans` is used;
    an input without any span raises IndexError.
    '''
    first_span = get_spans(tags)[0]
    words = [token.strip().split('\\')[0] for token in tags.strip().split()]
    phrase = ' '.join(words[first_span[0]:first_span[1] + 1])
    return {phrase: first_span}


def get_object_labels(tags):
    '''Map the text of every BIO span in `tags` to its [start, end] indices.

    Each whitespace-separated token contributes the part before its first
    backslash as the word. Spans with identical text share one key; the last
    such span wins.
    '''
    found_spans = get_spans(tags)
    words = [token.strip().split('\\')[0] for token in tags.strip().split()]
    return {
        ' '.join(words[span[0]:span[1] + 1]): span
        for span in found_spans
    }


class InputExample(object):
    """Plain record for one sentence and its labels."""

    def __init__(self, id, text_a, aspect_num, triple_num, all_label=None, text_b=None):
        """Store the given fields unchanged on the instance."""
        self.id, self.text_a, self.text_b = id, text_a, text_b
        self.all_label = all_label
        self.aspect_num, self.triple_num = aspect_num, triple_num


class Instance(object):
    """Build an `InputExample` from one BIO-tagged sentence pack.

    `sentence_pack` is read for the keys 'id', 'triples' and 'sentence'; every
    triple is read for 'target_tags', 'opinion_tags' and 'sentiment'. `args`
    is accepted but not used.
    """

    def __init__(self, sentence_pack, args):
        pair_to_triple = {}
        sentence_id = sentence_pack['id']
        aspect_total = 0
        for triple in sentence_pack['triples']:
            target_tags = triple['target_tags']
            opinion_tags = triple['opinion_tags']
            polarity = triple['sentiment']
            subject_label = get_subject_labels(target_tags)
            object_label = get_object_labels(opinion_tags)
            opinion_texts = list(object_label.keys())
            subject = list(subject_label.keys())[0]
            aspect_total += len(subject_label)
            # Each triple has exactly one aspect but possibly several opinions,
            # so one entry is stored per (aspect, opinion) pair; a repeated
            # pair overwrites the earlier entry.
            for opinion_text in opinion_texts:
                pair_key = str(subject) + '|' + str(opinion_text)
                pair_to_triple[pair_key] = (subject_label[subject], object_label[opinion_text], polarity)

        self.examples = InputExample(
            id=sentence_id,
            text_a=sentence_pack['sentence'],
            text_b=None,
            all_label=pair_to_triple,
            aspect_num=aspect_total,
            triple_num=len(pair_to_triple),
        )
        self.triple_num = len(pair_to_triple)
        self.aspect_num = aspect_total


def load_data_instances(sentence_packs, args):
    """Return the `InputExample` built by `Instance` for every sentence pack, in order."""
    return [Instance(sentence_pack, args).examples for sentence_pack in sentence_packs]


def convert_examples_to_features(args, train_instances, max_span_length=8):
    """Turn parsed examples into span-level feature dicts.

    For every example all candidate spans of up to `max_span_length` tokens
    are enumerated and labelled in two directions: aspect -> opinion
    ("spans_*" keys) and opinion -> aspect ("reverse_*" keys). Examples in
    which one aspect span or one opinion span occurs with different sentiments
    are reported on stdout and skipped, so the returned list can be shorter
    than `train_instances`.

    Args:
        args: Object providing `order_input` (truthy: enumerate spans by
            length first and skip spans whose first token is found in
            `string.punctuation`; falsy: enumerate by start index) and
            `random_shuffle` (non-zero: seed NumPy's global RNG with this
            value and shuffle the spans of every example with one shared
            permutation).
        train_instances: Iterable of objects with `id`, `text_a`
            (space-separated sentence) and `all_label` (mapping whose values
            are `(aspect_span, opinion_span, sentiment)`).
        max_span_length: Maximum number of tokens in a candidate span.

    Returns:
        `(features, num_aspect, num_opinion)`: the list of feature dicts plus
        the total number of distinct aspect spans and of distinct opinion
        spans over the kept examples.
    """
    features = []
    num_aspect = 0
    num_opinion = 0
    for ex_index, example in enumerate(train_instances):
        sample = {'id': example.id}
        sample['tokens'] = example.text_a.split(' ')
        sample['text_length'] = len(sample['tokens'])
        sample['triples'] = example.all_label
        sample['sentence'] = example.text_a

        # Forward direction, both keyed by aspect span:
        #   aspect:  aspect span -> sentiment
        #   opinion: aspect span -> [(opinion span, sentiment), ...]
        aspect = {}
        opinion = {}

        # Reverse direction, both keyed by opinion span:
        #   opinion_reverse: opinion span -> sentiment
        #   aspect_reverse:  opinion span -> [(aspect span, sentiment), ...]
        opinion_reverse = {}
        aspect_reverse = {}

        differ_opinion_sentiment = False
        differ_aspect_sentiment = False

        for triple_name in sample['triples']:
            triple = sample['triples'][triple_name]
            aspect_span, opinion_span, sentiment = tuple(triple[0]), tuple(triple[1]), triple[2]
            if aspect_span not in aspect:
                aspect[aspect_span] = sentiment
                opinion[aspect_span] = [(opinion_span, sentiment)]
            else:
                if aspect[aspect_span] != sentiment:
                    differ_aspect_sentiment = True
                else:
                    opinion[aspect_span].append((opinion_span, sentiment))

            if opinion_span not in opinion_reverse:
                opinion_reverse[opinion_span] = sentiment
                aspect_reverse[opinion_span] = [(aspect_span, sentiment)]
            else:
                # The same opinion span must carry one polarity for all of its aspects.
                if opinion_reverse[opinion_span] != sentiment:
                    differ_opinion_sentiment = True
                else:
                    aspect_reverse[opinion_span].append((aspect_span, sentiment))

        if differ_opinion_sentiment:
            print(ex_index, 'Single opinion word multi-polarity')
            continue

        if differ_aspect_sentiment:
            print(ex_index, 'Single aspect word multi-polarity')
            continue

        num_aspect += len(aspect)
        # `opinion` is keyed by aspect span, so its length is the aspect count
        # again; the distinct opinion spans are the keys of `opinion_reverse`.
        num_opinion += len(opinion_reverse)

        spans = []
        span_tokens = []

        spans_aspect_label = []
        spans_aspect2opinion_label = []
        spans_opinion_label = []

        reverse_opinion_label = []
        reverse_opinion2aspect_label = []
        reverse_aspect_label = []

        punk = string.punctuation

        if args.order_input:
            # Length-major order: all spans of length 1, then length 2, ...
            for i in range(max_span_length):
                if sample['text_length'] < i:
                    continue
                for j in range(sample['text_length'] - i):
                    # Substring test against string.punctuation: spans whose
                    # first token is found there are not generated at all.
                    if sample['tokens'][j] in punk:
                        continue
                    spans.append((j, i + j, i + 1))
                    span_token = ' '.join(sample['tokens'][j:i + j + 1])
                    span_tokens.append(span_token)
                    if (j, i + j) not in aspect:
                        spans_aspect_label.append(0)
                    else:
                        spans_aspect_label.append(validity2id[aspect[(j, i + j)]])
                    if (j, i + j) not in opinion_reverse:
                        reverse_opinion_label.append(0)
                    else:
                        reverse_opinion_label.append(validity2id[opinion_reverse[(j, i + j)]])

        else:
            # Start-major order: every span starting at token 0, then token 1, ...
            for i in range(sample['text_length']):
                for j in range(i, min(sample['text_length'], i + max_span_length)):
                    spans.append((i, j, j - i + 1))
                    span_token = ' '.join(sample['tokens'][i:j + 1])
                    span_tokens.append(span_token)
                    if (i, j) not in aspect:
                        spans_aspect_label.append(0)
                    else:
                        spans_aspect_label.append(validity2id[aspect[(i, j)]])
                    if (i, j) not in opinion_reverse:
                        reverse_opinion_label.append(0)
                    else:
                        reverse_opinion_label.append(validity2id[opinion_reverse[(i, j)]])

        assert len(span_tokens) == len(spans)

        # One label row per aspect: sentiment id at the positions of its opinion spans.
        for key_aspect in opinion:
            opinion_list = []
            sentiment_opinion = []
            spans_aspect2opinion_label.append(key_aspect)
            for opinion_span_2_aspect in opinion[key_aspect]:
                opinion_list.append(opinion_span_2_aspect[0])
                sentiment_opinion.append(opinion_span_2_aspect[1])
            assert len(set(sentiment_opinion)) == 1
            opinion_label2triple = []
            for i in spans:
                if (i[0], i[1]) not in opinion_list:
                    opinion_label2triple.append(0)
                else:
                    opinion_label2triple.append(sentiment2id[sentiment_opinion[0]])
            spans_opinion_label.append(opinion_label2triple)

        # One label row per opinion: sentiment id at the positions of its aspect spans.
        for opinion_key in aspect_reverse:
            aspect_list = []
            sentiment_aspect = []
            reverse_opinion2aspect_label.append(opinion_key)
            for aspect_span_2_opinion in aspect_reverse[opinion_key]:
                aspect_list.append(aspect_span_2_opinion[0])
                sentiment_aspect.append(aspect_span_2_opinion[1])
            assert len(set(sentiment_aspect)) == 1
            aspect_label2triple = []
            for i in spans:
                if (i[0], i[1]) not in aspect_list:
                    aspect_label2triple.append(0)
                else:
                    aspect_label2triple.append(sentiment2id[sentiment_aspect[0]])
            reverse_aspect_label.append(aspect_label2triple)

        sample['aspect_num'] = len(spans_opinion_label)
        sample['spans_aspect2opinion_label'] = spans_aspect2opinion_label
        sample['reverse_opinion_num'] = len(reverse_aspect_label)
        sample['reverse_opinion2aspect_label'] = reverse_opinion2aspect_label

        if args.random_shuffle != 0:
            # Re-seeding per example keeps the permutation reproducible.
            np.random.seed(args.random_shuffle)
            shuffle_ix = np.random.permutation(np.arange(len(spans)))
            spans_np = np.array(spans)[shuffle_ix]
            span_tokens_np = np.array(span_tokens)[shuffle_ix]
            # Both directions are shuffled with the same permutation so that
            # every label list stays aligned with `spans`.
            spans_aspect_label_np = np.array(spans_aspect_label)[shuffle_ix]
            reverse_opinion_label_np = np.array(reverse_opinion_label)[shuffle_ix]
            spans_opinion_label_shuffle = []
            for spans_opinion_label_split in spans_opinion_label:
                spans_opinion_label_split_np = np.array(spans_opinion_label_split)[shuffle_ix]
                spans_opinion_label_shuffle.append(spans_opinion_label_split_np.tolist())
            spans_opinion_label = spans_opinion_label_shuffle
            reverse_aspect_label_shuffle = []
            for reverse_aspect_label_split in reverse_aspect_label:
                reverse_aspect_label_split_np = np.array(reverse_aspect_label_split)[shuffle_ix]
                reverse_aspect_label_shuffle.append(reverse_aspect_label_split_np.tolist())
            reverse_aspect_label = reverse_aspect_label_shuffle
            spans, span_tokens, spans_aspect_label, reverse_opinion_label = spans_np.tolist(), span_tokens_np.tolist(),\
                                                                            spans_aspect_label_np.tolist(), reverse_opinion_label_np.tolist()

        # Two spans are "related" when they share at least one token string
        # (compared by text, not by position). The token sets are built once
        # per span instead of once per pair.
        token_sets = [set(tokens.split(' ')) for tokens in span_tokens]
        related_spans = np.zeros((len(spans), len(spans)), dtype=int)
        for i, left_tokens in enumerate(token_sets):
            for j, right_tokens in enumerate(token_sets):
                related_spans[i, j] = 0 if left_tokens.isdisjoint(right_tokens) else 1

        sample['related_span_array'] = related_spans
        sample['spans'] = spans
        sample['span tokens'] = span_tokens
        sample['spans_aspect_label'] = spans_aspect_label
        sample['spans_opinion_label'] = spans_opinion_label
        sample['reverse_opinion_label'] = reverse_opinion_label
        sample['reverse_aspect_label'] = reverse_aspect_label
        features.append(sample)
    return features, num_aspect, num_opinion


class MyDataset:
    """Dataset of span-level feature dicts loaded from a '####'-separated text file.

    Attributes:
        instances: Parsed examples, one per input line.
        data_instances: Feature dicts from `convert_examples_to_features`.
            That function skips some examples, so this list can be shorter
            than `instances`.
    """

    def __init__(self, args, path, if_train=False):
        """Read `path`, optionally shuffle its lines, and build the features.

        Args:
            args: Object providing `RANDOM_SEED` (used only when `if_train`),
                `max_span_length`, and whatever `convert_examples_to_features`
                reads from it.
            path: UTF-8 text file with one example per line.
            if_train: When true, seed Python's global `random` with
                `args.RANDOM_SEED` and shuffle the lines before parsing.
        """
        self.args = args
        with open(path, 'r', encoding='utf-8') as f:
            lines = f.readlines()
        if if_train:
            random.seed(args.RANDOM_SEED)
            random.shuffle(lines)
        self.instances = load_data_instances_txt(lines)
        self.data_instances, _, _ = convert_examples_to_features(
            self.args,
            train_instances=self.instances,
            max_span_length=self.args.max_span_length)

    def __len__(self):
        # Must agree with __getitem__, which indexes data_instances. Using
        # len(self.instances) here let callers index past the end whenever an
        # example was skipped during feature conversion.
        return len(self.data_instances)

    def __getitem__(self, idx):
        return self.data_instances[idx]



def load_data_instances_txt(lines):
    """Parse '<sentence>####<triples>' lines into `InputExample` objects.

    The part after '####' is evaluated as a Python expression and iterated as
    triples `(aspect_indices, opinion_indices, tag)`, where `tag` is one of
    'NEG', 'POS', 'NEU', 'STR'. Only the first and last index of each index
    list are used, as an inclusive span. Blank lines are skipped; the example
    id is the line's position in `lines`, as a string.

    Args:
        lines: Iterable of raw text lines.

    Returns:
        List of `InputExample`, one per non-blank line. `all_label` maps
        'aspect text|opinion text' to `([a_start, a_end], [o_start, o_end],
        sentiment)`; a repeated key keeps the first triple and prints a notice.
    """
    sentiment2sentiment = {'NEG': 'negative', 'POS': 'positive', 'NEU': 'neutral', 'STR': 'start'}

    instances = list()
    for ex_index, line in enumerate(lines):
        id = str(ex_index)  # id
        line = line.strip()
        if not line:
            # A blank line holds no example; splitting it would fail on line[1] below.
            continue
        line = line.split('####')
        sentence = line[0].split()  # sentence tokens
        # NOTE(review): eval() executes whatever expression the data file
        # contains. Only feed trusted files; ast.literal_eval would be the
        # safe choice if the annotations are plain literals.
        raw_pairs = eval(line[1])  # triplets

        triple_dict = {}
        aspect_num = 0
        for triple in raw_pairs:
            raw_aspect = triple[0]
            raw_opinion = triple[1]
            sentiment = sentiment2sentiment[triple[2]]

            # A single index denotes a one-token span; normalise it to [i, i].
            if len(raw_aspect) == 1:
                aspect_word = sentence[raw_aspect[0]]
                raw_aspect = [raw_aspect[0], raw_aspect[0]]
            else:
                aspect_word = ' '.join(sentence[raw_aspect[0]: raw_aspect[-1] + 1])
            # Counted once per triple, including triples rejected as duplicates below.
            aspect_num += 1

            if len(raw_opinion) == 1:
                opinion_word = sentence[raw_opinion[0]]
                raw_opinion = [raw_opinion[0], raw_opinion[0]]
            else:
                opinion_word = ' '.join(sentence[raw_opinion[0]: raw_opinion[-1] + 1])

            word = str(aspect_word) + '|' + str(opinion_word)
            if word not in triple_dict:
                triple_dict[word] = ([raw_aspect[0], raw_aspect[-1]], [raw_opinion[0], raw_opinion[-1]], sentiment)
            else:
                print('Single sentence ' + id + ' The middle triplet reappears!')
        examples = InputExample(id=id, text_a=line[0], text_b=None, all_label=triple_dict, aspect_num=aspect_num,
                                triple_num=len(triple_dict))

        instances.append(examples)

    return instances


class DataTterator(object):
    def __init__(self, dataset, args):
        """Keep the dataset and settings, compute the batch count and load the tokenizer.

        `args` is read for `train_batch_size`, `init_vocab` and `do_lower_case`.
        """
        self.dataset = dataset
        self.args = args
        self.batch_size = args.train_batch_size
        # Ceiling division: a final, partially filled batch still counts.
        total_examples = len(dataset)
        self.batch_count = (total_examples - 1) // self.batch_size + 1
        self.tokenizer = BertTokenizer.from_pretrained(
            args.init_vocab,
            do_lower_case=args.do_lower_case,
        )

    def get_instances(self, batch_num):
        bb = batch_num * self.batch_size
        instances = [self.dataset.__getitem__(i) for i in range(bb, min(bb + self.batch_size, len(self.dataset)))]
        return instances

    def get_batch(self, batch_num):
        tokens_tensor_list = []
        bert_spans_tensor_list = []
        spans_ner_label_tensor_list = []
        spans_aspect_tensor_list = []
        spans_opinion_label_tensor_list = []

        reverse_ner_label_tensor_list = []
        reverse_opinion_tensor_list = []
        reverse_aspect_tensor_list = []
        sentence_length = []
        related_spans_list = []
        adj_batch = []

        self.instances = self.get_instances(batch_num)
        
        max_tokens = self.args.max_seq_length
        max_spans = 0
        for i, sample in enumerate(self.instances):
            tokens = sample['tokens']
            spans = sample['spans']
            span_tokens = sample['span tokens']
            spans_ner_label = sample['spans_aspect_label']
            spans_aspect2opinion_labels = sample['spans_aspect2opinion_label']
            spans_opinion_label = sample['spans_opinion_label']

            reverse_ner_label = sample['reverse_opinion_label']
            reverse_opinion2aspect_labels = sample['reverse_opinion2aspect_label']
            reverse_aspect_label = sample['reverse_aspect_label']
            
            related_spans = sample['related_span_array']
            spans_aspect_labels, reverse_opinion_labels = [], []
            for spans_aspect2opinion_label in spans_aspect2opinion_labels:
                spans_aspect_labels.append((i, spans_aspect2opinion_label[0], spans_aspect2opinion_label[1]))
            for reverse_opinion2aspect_label in reverse_opinion2aspect_labels:
                reverse_opinion_labels.append((i, reverse_opinion2aspect_label[0], reverse_opinion2aspect_label[1]))
            bert_tokens, tokens_tensor, bert_spans_tensor, spans_ner_label_tensor, spans_aspect_labels_tensor, spans_opinion_tensor, \
            reverse_ner_label_tensor, reverse_opinion_tensor, reverse_aspect_tensor = \
                self.get_input_tensors(self.tokenizer, tokens, spans, spans_ner_label, spans_aspect_labels,
                                         spans_opinion_label, reverse_ner_label, reverse_opinion_labels, reverse_aspect_label)
            tokens_tensor_list.append(tokens_tensor)
            bert_spans_tensor_list.append(bert_spans_tensor)
            spans_ner_label_tensor_list.append(spans_ner_label_tensor)
            spans_aspect_tensor_list.append(spans_aspect_labels_tensor)
            spans_opinion_label_tensor_list.append(spans_opinion_tensor)
            reverse_ner_label_tensor_list.append(reverse_ner_label_tensor)
            reverse_opinion_tensor_list.append(reverse_opinion_tensor)
            reverse_aspect_tensor_list.append(reverse_aspect_tensor)
            assert bert_spans_tensor.shape[1] == spans_ner_label_tensor.shape[1] == reverse_ner_label_tensor.shape[1]
            # tokens和spans的最大个数被设定为固定值
            if (tokens_tensor.shape[1] > max_tokens):
                max_tokens = tokens_tensor.shape[1]
            if (bert_spans_tensor.shape[1] > max_spans):
                max_spans = bert_spans_tensor.shape[1]
            sentence_length.append((bert_tokens, tokens_tensor.shape[1], bert_spans_tensor.shape[1]))
            related_spans_list.append(related_spans)
        
        '''由于不同句子方阵不一样大，所以先不转为tensor'''
        #related_spans_tensor = torch.tensor(related_spans_list)
        # apply padding and concatenate tensors
        final_tokens_tensor = []
        final_attention_mask = []
        final_spans_mask_tensor = []
        final_bert_spans_tensor = []
        final_spans_ner_label_tensor = []
        final_spans_aspect_tensor = []
        final_spans_opinion_label_tensor = []

        final_reverse_ner_label_tensor = []
        final_reverse_opinion_tensor = []
        final_reverse_aspect_label_tensor = []
        final_related_spans_tensor = []
        for tokens_tensor, bert_spans_tensor, spans_ner_label_tensor, spans_aspect_tensor, spans_opinion_label_tensor, \
            reverse_ner_label_tensor, reverse_opinion_tensor, reverse_aspect_tensor, related_spans \
                in zip(tokens_tensor_list, bert_spans_tensor_list, spans_ner_label_tensor_list, spans_aspect_tensor_list,
                       spans_opinion_label_tensor_list, reverse_ner_label_tensor_list, reverse_opinion_tensor_list,
                       reverse_aspect_tensor_list, related_spans_list):
            # padding for tokens
            num_tokens = tokens_tensor.shape[1]
            tokens_pad_length = max_tokens - num_tokens
            attention_tensor = torch.full([1, num_tokens], 1, dtype=torch.long)
            if tokens_pad_length > 0:
                pad = torch.full([1, tokens_pad_length], self.tokenizer.pad_token_id, dtype=torch.long)
                tokens_tensor = torch.cat((tokens_tensor, pad), dim=1)
                attention_pad = torch.full([1, tokens_pad_length], 0, dtype=torch.long)
                attention_tensor = torch.cat((attention_tensor, attention_pad), dim=1)

            # padding for spans
            num_spans = bert_spans_tensor.shape[1]
            num_aspect = spans_aspect_tensor.shape[1]
            num_opinion = reverse_opinion_tensor.shape[1]
            spans_pad_length = max_spans - num_spans
            spans_mask_tensor = torch.full([1, num_spans], 1, dtype=torch.long)
            if spans_pad_length > 0:
                pad = torch.full([1, spans_pad_length, bert_spans_tensor.shape[2]], 0, dtype=torch.long)
                bert_spans_tensor = torch.cat((bert_spans_tensor, pad), dim=1)

                mask_pad = torch.full([1, spans_pad_length], 0, dtype=torch.long)
                spans_mask_tensor = torch.cat((spans_mask_tensor, mask_pad), dim=1)
                spans_ner_label_tensor = torch.cat((spans_ner_label_tensor, mask_pad), dim=1)
                reverse_ner_label_tensor = torch.cat((reverse_ner_label_tensor, mask_pad), dim=1)

                opinion_mask_pad = torch.full([1, num_aspect, spans_pad_length], 0, dtype=torch.long)
                spans_opinion_label_tensor = torch.cat((spans_opinion_label_tensor, opinion_mask_pad), dim=-1)
                aspect_mask_pad = torch.full([1, num_opinion, spans_pad_length], 0, dtype=torch.long)
                reverse_aspect_tensor = torch.cat((reverse_aspect_tensor, aspect_mask_pad), dim=-1)
                '''对span类似方阵mask'''
                related_spans = np.pad(related_spans, [(0, spans_pad_length), (0, spans_pad_length)])
            related_spans_tensor = torch.as_tensor(torch.from_numpy(related_spans), dtype=torch.bool)
            # update final outputs
            
            final_bert_spans_tensor.append(bert_spans_tensor)
            final_tokens_tensor.append(tokens_tensor)
            final_attention_mask.append(attention_tensor)
            final_spans_mask_tensor.append(spans_mask_tensor)
            final_spans_ner_label_tensor.append(spans_ner_label_tensor)
            final_spans_aspect_tensor.append(spans_aspect_tensor.squeeze(0))
            final_spans_opinion_label_tensor.append(spans_opinion_label_tensor.squeeze(0))
            final_reverse_ner_label_tensor.append(reverse_ner_label_tensor)
            final_reverse_opinion_tensor.append(reverse_opinion_tensor.squeeze(0))
            final_reverse_aspect_label_tensor.append(reverse_aspect_tensor.squeeze(0))
            final_related_spans_tensor.append(related_spans_tensor.unsqueeze(0))
            

        # 注意，特征中最大span间隔不一定为设置的max_span_length，这是因为bert分词之后造成的span扩大了。
        final_tokens_tensor = torch.cat(final_tokens_tensor, dim=0).to(self.args.device)
        final_attention_mask = torch.cat(final_attention_mask, dim=0).to(self.args.device)
        final_bert_spans_tensor = torch.cat(final_bert_spans_tensor, dim=0).to(self.args.device)
        final_spans_mask_tensor = torch.cat(final_spans_mask_tensor, dim=0).to(self.args.device)
        final_spans_ner_label_tensor = torch.cat(final_spans_ner_label_tensor, dim=0).to(self.args.device)
        final_spans_aspect_tensor = torch.cat(final_spans_aspect_tensor, dim=0).to(self.args.device)
        final_spans_opinion_label_tensor = torch.cat(final_spans_opinion_label_tensor, dim=0).to(self.args.device)
        final_reverse_ner_label_tensor = torch.cat(final_reverse_ner_label_tensor, dim=0).to(self.args.device)
        final_reverse_opinion_tensor = torch.cat(final_reverse_opinion_tensor, dim=0).to(self.args.device)
        final_reverse_aspect_label_tensor = torch.cat(final_reverse_aspect_label_tensor, dim=0).to(self.args.device)
        final_related_spans_tensor = torch.cat(final_related_spans_tensor, dim=0).to(self.args.device)
#         adj_matrix = torch.cat([i.unsqueeze(0) for i in adj_batch], axis=0)
        return final_tokens_tensor, final_attention_mask, final_bert_spans_tensor, final_spans_mask_tensor, \
               final_spans_ner_label_tensor, final_spans_aspect_tensor, final_spans_opinion_label_tensor, \
               final_reverse_ner_label_tensor, final_reverse_opinion_tensor, final_reverse_aspect_label_tensor, \
               final_related_spans_tensor, sentence_length


    def get_input_tensors(self, tokenizer, tokens, spans, spans_ner_label, spans_aspect_label, spans_opinion_label,
                          reverse_ner_label, reverse_opinion_labels, reverse_aspect_label):
        """Turn one sentence and its span annotations into batch-of-one tensors.

        Every word of ``tokens`` is passed through ``tokenizer.tokenize``.  When
        ``self.args.span_generation == "CNN"`` or ``self.args.Only_token_head``
        is truthy only the first word piece of a word is kept, otherwise all of
        its pieces are kept.  The piece sequence is wrapped in the tokenizer's
        CLS / SEP tokens, and the word-level indices stored in ``spans``,
        ``spans_aspect_label`` and ``reverse_opinion_labels`` are remapped to
        positions in that sequence.

        Args:
            tokenizer: object exposing ``cls_token``, ``sep_token``,
                ``tokenize`` and ``convert_tokens_to_ids``.
            tokens: the words of the sentence.
            spans: entries read as ``[start_word, end_word, third]``; ``third``
                is copied through unchanged.
            spans_ner_label, spans_opinion_label, reverse_ner_label,
                reverse_aspect_label: label lists wrapped into tensors as they are.
            spans_aspect_label, reverse_opinion_labels: entries read as
                ``[first, start_word, end_word]``; ``first`` is copied through.

        Returns:
            ``(bert_tokens, tokens_tensor, bert_spans_tensor,
            spans_ner_label_tensor, spans_aspect_tensor, spans_opinion_tensor,
            reverse_ner_label_tensor, reverse_opinion_tensor,
            reverse_aspect_tensor)``.  ``bert_tokens`` is the list of piece
            strings; every tensor carries a leading dimension of size 1.
        """
        word_pieces = [tokenizer.cls_token]
        first_piece_of = []   # word index -> position of its first kept piece
        last_piece_of = []    # word index -> position of its last kept piece
        for word in tokens:
            first_piece_of.append(len(word_pieces))
            pieces = tokenizer.tokenize(word)
            if self.args.span_generation == "CNN" or self.args.Only_token_head:
                word_pieces.append(pieces[0])
            else:
                word_pieces.extend(pieces)
            last_piece_of.append(len(word_pieces) - 1)
        word_pieces.append(tokenizer.sep_token)

        piece_ids = tokenizer.convert_tokens_to_ids(word_pieces)

        # After BERT splits words into sub-words the word-level boundaries must
        # be re-expressed as piece positions (end positions are inclusive).
        remapped_spans = []
        for span in spans:
            remapped_spans.append([first_piece_of[span[0]], last_piece_of[span[1]], span[2]])
        remapped_aspects = []
        for aspect_span in spans_aspect_label:
            remapped_aspects.append(
                [aspect_span[0], first_piece_of[aspect_span[1]], last_piece_of[aspect_span[2]]])
        remapped_opinions = []
        for opinion_span in reverse_opinion_labels:
            remapped_opinions.append(
                [opinion_span[0], first_piece_of[opinion_span[1]], last_piece_of[opinion_span[2]]])

        return (
            word_pieces,
            torch.tensor([piece_ids]),
            torch.tensor([remapped_spans]),
            torch.tensor([spans_ner_label]),
            torch.tensor([remapped_aspects]),
            torch.tensor([spans_opinion_label]),
            torch.tensor([reverse_ner_label]),
            torch.tensor([remapped_opinions]),
            torch.tensor([reverse_aspect_label]),
        )


Writing data_BIO_loader.py


In [6]:
%%writefile eval_features.py

def unbatch_data(pred_data):
    """Flatten batched predictions into per-sentence lists.

    Args:
        pred_data: indexable with five parallel sequences, one entry per batch.
            Each entry must expose ``.tolist()``:
            ``[0]`` stage-1 spans (field 0 of a span is the sentence index
            inside its batch), ``[1]`` stage-1 labels per sentence,
            ``[2]`` stage-1 logits per sentence, ``[3]`` stage-2 results per
            stage-1 span, ``[4]`` stage-2 logits per stage-1 span.

    Returns:
        ``(stage1_pred, stage1_pred_sentiment, stage1_pred_sentiment_logits,
        stage2_pred, stage2_pred_sentiment_logits)``; every list holds one
        entry per sentence, in batch order.

    Raises:
        IndexError: a batch has a different number of stage-1 spans and
            stage-2 rows.
    """
    spans_per_sentence = []
    labels_per_sentence = []
    label_logits_per_sentence = []
    stage2_per_sentence = []
    stage2_logits_per_sentence = []

    for batch_idx in range(len(pred_data[0])):
        batch_spans = pred_data[0][batch_idx].tolist()
        batch_labels = pred_data[1][batch_idx].tolist()
        batch_label_logits = pred_data[2][batch_idx].tolist()
        batch_stage2 = pred_data[3][batch_idx].tolist()
        batch_stage2_logits = pred_data[4][batch_idx].tolist()

        # Sanity check: stage 2 must hold one row per stage-1 span.
        if len(batch_spans) != len(batch_stage2):
            raise IndexError('预测的stage1和stage2序列数不相等')

        for sentence_idx, labels in enumerate(batch_labels):
            labels_per_sentence.append(labels)
            label_logits_per_sentence.append(batch_label_logits[sentence_idx])

            # Positions of the spans that belong to this sentence.
            owned = [pos for pos, span in enumerate(batch_spans) if span[0] == sentence_idx]
            spans_per_sentence.append([batch_spans[pos] for pos in owned])
            stage2_per_sentence.append([batch_stage2[pos] for pos in owned])
            stage2_logits_per_sentence.append([batch_stage2_logits[pos] for pos in owned])

    return (spans_per_sentence, labels_per_sentence, label_logits_per_sentence,
            stage2_per_sentence, stage2_logits_per_sentence)

Writing eval_features.py


In [7]:
%%writefile Metric.py

from transformers import BertTokenizer
import numpy as np
import json
from tqdm import tqdm


# Predicted class index -> label for span validity (looked up in Metric.find_aspect_sentiment).
id4validity = {0: 'none', 1: 'valid'}
# Predicted class index -> sentiment label (looked up in Metric.find_opinion_sentiment).
id4sentiment = {0: 'none', 1: 'positive', 2: 'negative', 3:'neutral', 4:'start'}


class Metric():
    def __init__(self, args, forward_pred_result, reverse_pred_result, gold_instances):
        self.args = args
        self.gold_instances = gold_instances
        self.tokenizer = BertTokenizer.from_pretrained(args.init_vocab, do_lower_case=args.do_lower_case)

        self.pred_aspect = forward_pred_result[0]
        self.pred_aspect_sentiment = forward_pred_result[1]
        self.pred_aspect_sentiment_logit = forward_pred_result[2]

        self.pred_opinion = forward_pred_result[3]
        self.pred_opinion_sentiment_logit = forward_pred_result[4]

        '''Reverse evaluation'''
        self.reverse_pred_opinon = reverse_pred_result[0]
        self.reverse_pred_opinon_sentiment = reverse_pred_result[1]
        self.reverse_pred_opinon_sentiment_logit = reverse_pred_result[2]

        self.reverse_pred_aspect = reverse_pred_result[3]
        self.reverse_pred_aspect_sentiment_logit = reverse_pred_result[4]


    def P_R_F1(self, gold_num, pred_num, correct_num):
        precision = correct_num / pred_num if pred_num > 0 else 0
        recall = correct_num / gold_num if gold_num > 0 else 0
        f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
        return (precision, recall, f1)

    def num_4_eval(self, gold, pred, gold_num, pred_num, correct_num):
        correct = set(gold) & set(pred)
        gold_num += len(set(gold))
        pred_num += len(set(pred))
        correct_num += len(correct)
        return gold_num, pred_num, correct_num

    def cal_triplet_final_result(self, forward_results, forward_spans, reverse_results, reverse_spans):

        pred_dicts = {}
        pred_spans = forward_spans + reverse_spans
        for index, result in enumerate(forward_results + reverse_results):
            if result in pred_dicts:
                score_dict = pred_dicts[result][2]
                score_new = pred_spans[index][2]
                if score_dict > score_new:
                    continue
                else:
                    pred_dicts[result] = pred_spans[index]
            else:
                pred_dicts[result] = pred_spans[index]
        history = []
        for i in pred_dicts:
            aspect_span_i = range(pred_dicts[i][0][0], pred_dicts[i][0][1])
            opinion_span_i = range(pred_dicts[i][1][0], pred_dicts[i][1][1])
            for j in pred_dicts:
                if (i,j) in history:
                    continue
                history.append((i, j))
                history.append((j, i))
                if i == j:
                    continue
                aspect_span_j = range(pred_dicts[j][0][0], pred_dicts[j][0][1])
                opinion_span_j = range(pred_dicts[j][1][0], pred_dicts[j][1][1])
                repeat_a_span = list(set(aspect_span_i) & set(aspect_span_j))
                repeat_o_span = list(set(opinion_span_i) & set(opinion_span_j))
                if len(repeat_a_span) == 0 or len(repeat_o_span) == 0:
                    continue
                elif len(repeat_a_span) <= min(len(aspect_span_i), len(aspect_span_j)) and \
                        len(repeat_o_span) <= min(len(opinion_span_i), len(opinion_span_j)):
                    i_score = pred_dicts[i][2]
                    j_score = pred_dicts[j][2]
                    if i_score >= j_score:
                        pred_dicts[j] = (pred_dicts[j][0], pred_dicts[j][1], 0)
                    else:
                        pred_dicts[i] = (pred_dicts[i][0], pred_dicts[i][1], 0)
                else:
                    raise(KeyboardInterrupt)
        return [_ for _ in pred_dicts if pred_dicts[_][2] != 0]


    def score_triples(self):
        correct_aspect_num,correct_opinion_num,correct_apce_num,correct_pairs_num,correct_num = 0,0,0,0,0
        gold_aspect_num,gold_opinion_num,gold_apce_num,gold_pairs_num,gold_num = 0,0,0,0,0
        pred_aspect_num,pred_opinion_num,pred_apce_num,pred_pairs_num,pred_num = 0,0,0,0,0

        if self.args.output_path:
            result = []
            aspect_text = []
            opinion_text = []
        print("Ready to for")
        for i in tqdm(range(len(self.gold_instances))):
            '''Entity length experiment'''

            bert_tokens = []
            spans = self.gold_instances[i]['spans']
            start2idx = []
            end2idx = []
            bert_tokens.append(self.tokenizer.cls_token)
            for token in self.gold_instances[i]['tokens']:
                start2idx.append(len(bert_tokens))
                sub_tokens = self.tokenizer.tokenize(token)
                if self.args.span_generation == "CNN":
                    bert_tokens.append(sub_tokens[0])
                elif self.args.Only_token_head:
                    bert_tokens.append(sub_tokens[0])
                else:
                    bert_tokens += sub_tokens
                end2idx.append(len(bert_tokens) - 1)
            bert_tokens.append(self.tokenizer.sep_token)
            bert_spans = [[start2idx[span[0]], end2idx[span[1]], span[2]] for span in spans]
            gold_aspect, gold_opinion, gold_apce, gold_pairs, gold_triples = self.find_gold_triples(i, bert_spans,
                                                                                                    bert_tokens)
            pred_aspect, pred_opinion, pred_apce, pred_pairs, pred_triples, pred_spans = self.find_pred_triples(i, bert_spans,
                                                                                                    bert_tokens)

            # if len(gold_triples) < 5:
            #     continue

            reverse_aspect, reverse_opinion, reverse_apce, reverse_pairs, reverse_triples, reverse_spans = \
                self.find_pred_reverse_triples(i, bert_spans, bert_tokens)

            pred_aspect = list(set(pred_aspect) | set(reverse_aspect))
            pred_opinion = list(set(pred_opinion) | set(reverse_opinion))
            pred_apce = list(set(pred_apce) | set(reverse_apce))
            pred_pairs = list(set(pred_pairs) | set(reverse_pairs))
            if self.args.Filter_Strategy:
                pred_triples = self.cal_triplet_final_result(pred_triples, pred_spans, reverse_triples, reverse_spans)
            else:
                pred_triples = list(set(pred_triples) | set(reverse_triples))


            if self.args.output_path:
                result.append({"sentence": self.gold_instances[i]['sentence'],
                                     "triple_list_gold": [gold_triple for gold_triple in set(gold_triples)],
                                     "triple_list_pred": [pred_triple for pred_triple in set(pred_triples)],
                                    "new": [new_triple for new_triple in (set(pred_triples) - set(gold_triples))],
                                    "lack": [lack_triple for lack_triple in (set(gold_triples) - set(pred_triples))]
                                     })
                aspect_text.append({"sentence": self.gold_instances[i]['sentence'],
                                    'gold aspect': [gold_as for gold_as in set(gold_aspect)],
                                    'pred aspect': [pred_as for pred_as in set(pred_aspect)],
                                    "new": [new_as for new_as in (set(pred_aspect) - set(gold_aspect))],
                                    "lack": [lack_as for lack_as in (set(gold_aspect) - set(pred_aspect))]})
                opinion_text.append({"sentence": self.gold_instances[i]['sentence'],
                                    'gold aspect': [gold_op for gold_op in set(gold_opinion)],
                                    'pred aspect': [pred_op for pred_op in set(pred_opinion)],
                                    "new": [new_op for new_op in (set(pred_opinion) - set(gold_opinion))],
                                    "lack": [lack_op for lack_op in (set(gold_opinion) - set(pred_opinion))]})


            gold_aspect_num, pred_aspect_num, correct_aspect_num = self.num_4_eval(gold_aspect, pred_aspect,
                                                                                   gold_aspect_num,
                                                                                   pred_aspect_num, correct_aspect_num)

            gold_opinion_num, pred_opinion_num, correct_opinion_num = self.num_4_eval(gold_opinion, pred_opinion,
                                                                                   gold_opinion_num,
                                                                                   pred_opinion_num, correct_opinion_num)

            gold_apce_num, pred_apce_num, correct_apce_num = self.num_4_eval(gold_apce, pred_apce, gold_apce_num,
                                                                             pred_apce_num, correct_apce_num)

            gold_apce_num, pred_apce_num, correct_apce_num = self.num_4_eval(gold_apce, pred_apce, gold_apce_num,
                                                                             pred_apce_num, correct_apce_num)

            gold_pairs_num, pred_pairs_num, correct_pairs_num = self.num_4_eval(gold_pairs, pred_pairs, gold_pairs_num,
                                                                             pred_pairs_num, correct_pairs_num)

            gold_num, pred_num, correct_num = self.num_4_eval(gold_triples, pred_triples, gold_num,
                                                                                pred_num, correct_num)


        if self.args.output_path:
            F = open(self.args.dataset + 'triples.json', 'w', encoding='utf-8')
            json.dump(result, F, ensure_ascii=False, indent=4)
            F.close()

            F1 = open(self.args.dataset + 'aspect.json', 'w', encoding='utf-8')
            json.dump(aspect_text, F1, ensure_ascii=False, indent=4)
            F1.close()

            F2 = open(self.args.dataset + 'opinion.json', 'w', encoding='utf-8')
            json.dump(opinion_text, F2, ensure_ascii=False, indent=4)
            F2.close()


        aspect_result = self.P_R_F1(gold_aspect_num, pred_aspect_num, correct_aspect_num)
        opinion_result = self.P_R_F1(gold_opinion_num, pred_opinion_num, correct_opinion_num)
        apce_result = self.P_R_F1(gold_apce_num, pred_apce_num, correct_apce_num)
        pair_result = self.P_R_F1(gold_pairs_num, pred_pairs_num, correct_pairs_num)
        triplet_result = self.P_R_F1(gold_num, pred_num, correct_num)
        return aspect_result, opinion_result, apce_result, pair_result, triplet_result

    def find_token(self, bert_tokens, span):
        bert_tokens_4_span = bert_tokens[span[1]:span[2]+1]
        sub = ''
        for i, tokens in enumerate(bert_tokens_4_span):
            if i == 0:
                sub = tokens
            elif '##' in tokens:
                sub = sub + tokens.lstrip("##")
            else:
                sub = sub +" "+ tokens
        return sub

    def gold_token(self, tokens):
        sub = ''
        for i, token in enumerate(tokens):
            if i == 0:
                sub = token
            elif '##' in token:
                sub = sub + token.lstrip("##")
            else:
                sub = sub +" "+ token
        return sub

    def find_aspect_sentiment(self, sentence_index, bert_spans, span, aspect_sentiment, aspect_sentiment_logit):
        # span = [span[1], span[2], ]
        bert_span_index = [i for i,x in enumerate(bert_spans) if span[1] == x[0] and span[2] == x[1]]
        assert len(bert_span_index) == 1
        bert_span_index = bert_span_index[0]
        sentiment_index = aspect_sentiment[sentence_index][bert_span_index]
        # sentiment = id4sentiment[aspect_sentiment[sentence_index][bert_span_index]]
        sentiment = id4validity[aspect_sentiment[sentence_index][bert_span_index]]
        sentiment_logit = aspect_sentiment_logit[sentence_index][bert_span_index][sentiment_index]
        # all_sentiment_logit = sum(aspect_sentiment_logit[sentence_index][bert_span_index])
        # sentiment_precent = sentiment_logit / all_sentiment_logit
        # return sentiment, sentiment_precent

        return sentiment, sentiment_logit

    def find_opinion_sentiment(self, sentence_index, opinion_index, bert_spans, span, opinion_sentiment,
                               opinion_sentiment_logit):
        bert_span_index = [i for i, x in enumerate(bert_spans) if span[1] == x[0] and span[2] == x[1]]
        assert len(bert_span_index) == 1
        bert_span_index = bert_span_index[0]
        sentiment_index = opinion_sentiment[sentence_index][opinion_index][bert_span_index]
        sentiment = id4sentiment[opinion_sentiment[sentence_index][opinion_index][bert_span_index]]
        sentiment_logit = opinion_sentiment_logit[sentence_index][opinion_index][bert_span_index][sentiment_index]
        return sentiment, sentiment_logit

    # Code that uses raw data
    def find_gold_triples(self, sentence_index, bert_spans, bert_tokens):
        triples_list,pair_list = [],[]
        aspect_list,opinion_list,apce_list = [],[],[]
        triples = self.gold_instances[sentence_index]['triples']
        for keys in triples:
            aspect, opinion = keys.split('|')
            aspect_tokens = []
            for aspect_token in aspect.split( ):
                token = self.tokenizer.tokenize(aspect_token)
                if self.args.span_generation == "CNN":
                    aspect_tokens.append(token[0])
                elif self.args.Only_token_head:
                    aspect_tokens.append(token[0])
                else:
                    aspect_tokens += token
            new_aspect = self.gold_token(aspect_tokens)

            opinion_tokens = []
            for opinion_token in opinion.split( ):
                token = self.tokenizer.tokenize(opinion_token)
                if self.args.span_generation == "CNN":
                    opinion_tokens.append(token[0])
                elif self.args.Only_token_head:
                    opinion_tokens.append(token[0])
                else:
                    opinion_tokens += token
            new_opinion = self.gold_token(opinion_tokens)

            sentiment = triples[keys][2]

            triples_list.append((new_aspect, new_opinion, sentiment.lower()))

            aspect_list.append((new_aspect))
            opinion_list.append((new_opinion))

            apce_list.append((new_aspect, sentiment))
            pair_list.append((new_aspect, new_opinion))
        return aspect_list, opinion_list, apce_list, pair_list, triples_list

    def find_pred_triples(self, sentence_index, bert_spans, bert_tokens):
        triples_list, pair_list, span_list = [], [], []
        aspect_list, pred_opinion_list, apce_list = [], [], []
        pred_aspect_span = self.pred_aspect[sentence_index]
        # Remove duplicate aspects
        new_aspect_span = []
        for i, pred_aspect in enumerate(pred_aspect_span):
            if len(new_aspect_span) == 0:
                new_aspect_span.append(pred_aspect)
            else:
                if pred_aspect[1] == new_aspect_span[-1][1]:
                    new_aspect_span[-1] = pred_aspect
                else:
                    new_aspect_span.append(pred_aspect)
        for j, pred_aspect in enumerate(new_aspect_span):
            aspect = self.find_token(bert_tokens, pred_aspect)
            aspect_span_output = [pred_aspect[1], pred_aspect[2]+1]
            aspect_sentiment, aspect_sentiment_logit = self.find_aspect_sentiment(sentence_index, bert_spans,
                                                                                  pred_aspect,
                                                                                  self.pred_aspect_sentiment,
                                                                                  self.pred_aspect_sentiment_logit)
            aspect_list.append(aspect)

            opinion_list = []
            for opinion_index in list(np.where(np.array(self.pred_opinion[sentence_index][j]) != 0)[0]):
                opinion_list.append(opinion_index)
            opinion_spans = []
            for opinion_index in opinion_list:
                if opinion_index < len(bert_spans):
                    opinion_spans.append(bert_spans[opinion_index])
                else:
                    continue
            new_opinion_spans = []
            for i, pred_opinion in enumerate(opinion_spans):
                if len(new_opinion_spans) == 0:
                    new_opinion_spans.append(pred_opinion)
                else:
                    if pred_opinion[1] == new_opinion_spans[-1][1]:
                        new_opinion_spans[-1] = pred_opinion
                    else:
                        new_opinion_spans.append(pred_opinion)
            for opinion_span in new_opinion_spans:
                opinion_span = (opinion_span[2], opinion_span[0], opinion_span[1])
                opinion_span_output = [opinion_span[1], opinion_span[2]+1]
                opinion = self.find_token(bert_tokens, opinion_span)
                opinion_sentiment, opinion_sentiment_logit = self.find_opinion_sentiment(sentence_index, j, bert_spans,
                                                                                         opinion_span,
                                                                                         self.pred_opinion,
                                                                                         self.pred_opinion_sentiment_logit)
                # 筛选情感  弃用
                # if opinion_sentiment_logit > aspect_sentiment_logit:
                #     sentiment = opinion_sentiment
                # else:
                #     sentiment = aspect_sentiment

                pred_opinion_list.append(opinion)
                apce_list.append((aspect, opinion_sentiment))
                triples_list.append((aspect, opinion, opinion_sentiment))
                pair_list.append((aspect, opinion))
                span_list.append((aspect_span_output, opinion_span_output, opinion_sentiment_logit))
        return aspect_list, pred_opinion_list, apce_list, pair_list, triples_list, span_list



    def find_pred_reverse_triples(self, sentence_index, bert_spans, bert_tokens):
        triples_list, pair_list, span_list = [], [], []
        opinion_list, pred_aspect_list, apce_list = [], [], []
        pred_opinion_span = self.reverse_pred_opinon[sentence_index]

        new_opinion_span = []
        for i, pred_opinion in enumerate(pred_opinion_span):
            if len(new_opinion_span) == 0:
                new_opinion_span.append(pred_opinion)
            else:
                '''Take the long operation, the overlapping entity takes the longer part'''
                if pred_opinion[1] == new_opinion_span[-1][1]:
                    new_opinion_span[-1] = pred_opinion
                else:
                    new_opinion_span.append(pred_opinion)
        for j, pred_opinion in enumerate(new_opinion_span):
            opinion = self.find_token(bert_tokens, pred_opinion)
            opinion_span_output = [pred_opinion[1], pred_opinion[2] + 1]
            opinion_sentiment, opinion_sentiment_precent = self.find_aspect_sentiment(sentence_index,
                                                                                    bert_spans,
                                                                                    pred_opinion,
                                                                                    self.reverse_pred_opinon_sentiment,
                                                                                    self.reverse_pred_opinon_sentiment_logit)
            opinion_list.append((opinion))
            aspect_list = []
            for aspect_index in list(np.where(np.array(self.reverse_pred_aspect[sentence_index][j]) != 0)[0]):
                aspect_list.append(aspect_index)
            aspect_spans = []
            for aspect_index in aspect_list:
                if aspect_index < len(bert_spans):
                    aspect_spans.append(bert_spans[aspect_index])
                else: continue
            new_aspect_spans = []
            '''At the beginning of the same, choose a longer entity'''
            for i, pred_aspect in enumerate(aspect_spans):
                if len(new_aspect_spans) == 0:
                    new_aspect_spans.append(pred_aspect)
                else:

                    if pred_aspect[1] == new_aspect_spans[-1][1]:
                        new_aspect_spans[-1] = pred_aspect
                    else:
                        new_aspect_spans.append(pred_aspect)
            for aspect_span in new_aspect_spans:
                aspect_span = (aspect_span[2], aspect_span[0], aspect_span[1])
                aspect_span_output = [aspect_span[1], aspect_span[2] + 1]
                aspect = self.find_token(bert_tokens, aspect_span)
                aspect_sentiment, aspect_sentiment_precent = self.find_opinion_sentiment(sentence_index, j,
                                                                                       bert_spans, aspect_span,
                                                                                       self.reverse_pred_aspect,
                                                                                       self.reverse_pred_aspect_sentiment_logit)
                # if opinion_sentiment_precent > aspect_sentiment_precent:
                #     sentiment = opinion_sentiment
                # else:
                #     sentiment = aspect_sentiment
                pred_aspect_list.append((aspect))
                apce_list.append((aspect, aspect_sentiment))
                triples_list.append((aspect, opinion, aspect_sentiment))
                pair_list.append((aspect, opinion))
                span_list.append((aspect_span_output, opinion_span_output, aspect_sentiment_precent))
        return pred_aspect_list, opinion_list, apce_list, pair_list, triples_list, span_list

if __name__ == '__main__':
    # Ad-hoc scratch check of tuple/set behaviour; it does not exercise Metric.
    test1 = ('boot time', 'fast', 'pos')
    test = ('boot time', 'boot')
    test2 = ('Boot time', 'fast', 'pos')
    # Strings common to both triples ('boot time' and 'Boot time' differ by case); the result is not used.
    set1  = set(test1) & set(test2)
    print(set(test))

Writing Metric.py


In [8]:
%%writefile model.py

import torch
import torch.nn as nn
import torch.nn.functional as F
# import numpy
# from transformers.models.bert.modeling_bert import BertAttention, BertIntermediate, BertOutput

from Attention import Attention, Intermediate, Output, Dim_Four_Attention, masked_softmax
from data_BIO_loader import sentiment2id, validity2id
from allennlp.nn.util import batched_index_select, batched_span_select
import random
import math

def stage_2_features_generation(bert_feature, attention_mask, spans, span_mask, spans_embedding, spans_aspect_tensor,
                                spans_opinion_tensor=None):
    """Gather the per-aspect inputs of stage 2, dropping aspects that match no candidate span.

    Every row of ``spans_aspect_tensor`` is read as ``(batch index, span start, span end)``.
    A row is kept only when ``spans[batch index]`` contains a span whose start (column 0)
    and end (column 1) equal it; otherwise the row is skipped together with the row of
    ``spans_opinion_tensor`` at the same position.

    Args:
        bert_feature: indexed as ``[batch, :, :]``.
        attention_mask: indexed as ``[batch, :]``.
        spans: indexed as ``[batch, :, 0]`` (starts) and ``[batch, :, 1]`` (ends).
        span_mask: indexed as ``[batch, :]``.
        spans_embedding: indexed as ``[batch, span, :]``.
        spans_aspect_tensor: one ``(batch index, start, end)`` row per aspect.
        spans_opinion_tensor: optional tensor with one row per aspect row.

    Returns:
        ``(all_span_opinion_tensor, all_span_aspect_tensor, all_bert_embedding,
        all_attention_mask, all_spans_embedding, all_span_mask)``.  Each entry stacks one
        slice per kept aspect along dim 0.  Every entry is ``None`` when no aspect is kept
        (including when ``spans_aspect_tensor`` has no rows); the opinion entry is also
        ``None`` when ``spans_opinion_tensor`` is not given.
    """
    opinion_rows, aspect_rows, bert_rows = [], [], []
    attention_rows, embedding_rows, mask_rows = [], [], []

    # Iterating the tensor walks dim 0, so an empty aspect tensor simply yields nothing
    # (torch.chunk with zero chunks would raise instead).
    for i, aspect in enumerate(spans_aspect_tensor):
        batch_num = aspect[0]
        span_starts = spans[batch_num, :, 0]
        span_ends = spans[batch_num, :, 1]
        # A candidate matches when both boundaries agree; negative ends never match.
        is_match = (span_starts == aspect[1]) & (span_ends == aspect[2]) & (span_ends > -1)
        span_index = torch.nonzero(is_match, as_tuple=False).squeeze(0)
        if span_index.numel() == 0:
            continue

        if spans_opinion_tensor is not None:
            opinion_rows.append(spans_opinion_tensor[i, :].unsqueeze(0))
        aspect_rows.append(spans_embedding[batch_num, span_index, :].unsqueeze(0))
        bert_rows.append(bert_feature[batch_num, :, :].unsqueeze(0))
        attention_rows.append(attention_mask[batch_num, :].unsqueeze(0))
        embedding_rows.append(spans_embedding[batch_num, :, :].unsqueeze(0))
        mask_rows.append(span_mask[batch_num, :].unsqueeze(0))

    def _stack(rows):
        # One concatenation at the end instead of re-concatenating on every iteration.
        return torch.cat(rows, dim=0) if rows else None

    return _stack(opinion_rows), _stack(aspect_rows), _stack(bert_rows), _stack(attention_rows), \
           _stack(embedding_rows), _stack(mask_rows)


class Step_1_module(torch.nn.Module):
    """One feed-forward refinement layer: ``Intermediate`` followed by ``Output``."""

    def __init__(self, args, bert_config):
        super().__init__()
        self.args = args
        self.intermediate = Intermediate(bert_config)
        self.output = Output(bert_config)

    def forward(self, spans_embedding):
        """Refine ``spans_embedding`` and return the result twice.

        ``Output`` receives the intermediate activation plus the untouched input as its
        second argument.  The same tensor is returned in both tuple slots so callers can
        unpack ``(layer_output, intermediate_output)``.
        """
        refined = self.output(self.intermediate(spans_embedding), spans_embedding)
        return refined, refined


class Step_1(torch.nn.Module):
    """Stage-1 model: build one embedding per candidate span and classify it twice.

    ``args.span_generation`` selects how span embeddings are produced ("Average", "Max",
    "Start_end", "CNN", "ATT" or "SE_ATT").  The shared span embeddings are then refined
    by two independent stacks of ``Step_1_module`` whose outputs feed the aspect and the
    opinion classifier respectively.
    """

    def __init__(self, args, bert_config):
        super(Step_1, self).__init__()
        self.args = args
        self.bert_config = bert_config
        self.dropout_output = torch.nn.Dropout(args.drop_out)
        if self.args.span_generation == "Start_end":
            # The width table has max_span_length + 1 rows because index 0 is used by the
            # mask of invalid spans and shares the table with the real span widths.
            self.step_1_embedding4width = nn.Embedding(args.max_span_length + 1, args.embedding_dim4width)
            self.step_1_linear4width = nn.Linear(args.embedding_dim4width + args.bert_feature_dim * 2,
                                                 args.bert_feature_dim)
        elif self.args.span_generation == "CNN":
            # One convolution per span width >= 2: entry i has kernel size i + 2.
            self.CNN_span_generation = nn.ModuleList(
                [nn.Conv1d(in_channels=args.bert_feature_dim, out_channels=args.bert_feature_dim, kernel_size=i + 2) for
                 i in range(args.max_span_length - 1)])
        elif self.args.span_generation == "ATT":
            self.ATT_attentions = nn.ModuleList(
                [Dim_Four_Block(args, self.bert_config) for _ in range(max(1, args.ATT_SPAN_block_num - 1))])
        elif self.args.span_generation == "SE_ATT":
            self.compess_projection = nn.Sequential(nn.Linear(args.bert_feature_dim, 1), nn.ReLU(), nn.Dropout(args.drop_out))

        if args.related_span_underline:
            self.related_attentions = nn.ModuleList(
                [Pointer_Block(args, self.bert_config) for _ in range(max(1, args.related_span_block_num - 1))])

        # Aspect ("forward") direction.
        self.forward_1_decoders = nn.ModuleList(
            [Step_1_module(args, self.bert_config) for _ in range(max(1, args.block_num - 1))])
        self.sentiment_classification_aspect = nn.Linear(args.bert_feature_dim, len(validity2id) - 2)
        # self.sentiment_classification_aspect = nn.Linear(args.bert_feature_dim, len(sentiment2id))

        # Opinion ("reverse") direction.
        self.reverse_1_decoders = nn.ModuleList(
            [Step_1_module(args, self.bert_config) for _ in range(max(1, args.block_num - 1))])
        self.sentiment_classification_opinion = nn.Linear(args.bert_feature_dim, len(validity2id) - 2)
        # self.sentiment_classification_opinion = nn.Linear(args.bert_feature_dim, len(sentiment2id))

    def feature_slice(self, features, mask, span_mask, sentence_length):
        """Build CNN-based span features for every sentence of the batch.

        For each sentence the token features without the first and last position are
        followed by the outputs of the width-2, width-3, ... convolutions, then zero
        padded to ``span_mask.shape[1]`` rows.  ``sentence_length`` is not used.

        Returns:
            ``(features, mask)``: the stacked per-sentence span features and a long mask
            holding 1 for real rows and 0 for padding, or ``(None, None)`` for an empty
            batch.
        """
        # The convolutions work on (batch, channels, length); permute once, not per layer,
        # and permute the results back once so they can be sliced per sentence below.
        bert_feature = features.permute(0, 2, 1)
        cnn_span_generate_list = [CNN_generation_model(bert_feature).permute(0, 2, 1)
                                  for CNN_generation_model in self.CNN_span_generation]

        sliced_rows = []
        mask_rows = []
        for i in range(features.shape[0]):
            token_count = torch.nonzero(mask[i, :]).shape[0]
            pieces = [features[i, :token_count][1:-1]]
            for j in range(self.args.max_span_length - 1):
                if token_count - 2 <= j:
                    break
                pieces.append(cnn_span_generate_list[j][i, 1:token_count - (j + 2), :])
            features_sliced = torch.cat(pieces, dim=0)

            pad_length = span_mask.shape[1] - features_sliced.shape[0]
            spans_mask_tensor = torch.full([1, features_sliced.shape[0]], 1, dtype=torch.long).to(self.args.device)
            if pad_length > 0:
                # Pad with the feature dtype: concatenating an integer pad with float
                # features relies on type promotion inside torch.cat.
                pad = torch.zeros([pad_length, self.args.bert_feature_dim], dtype=features_sliced.dtype,
                                  device=self.args.device)
                features_sliced = torch.cat((features_sliced, pad), dim=0)
                mask_pad = torch.full([1, pad_length], 0, dtype=torch.long).to(self.args.device)
                spans_mask_tensor = torch.cat((spans_mask_tensor, mask_pad), dim=1)
            sliced_rows.append(features_sliced.unsqueeze(0))
            mask_rows.append(spans_mask_tensor)

        if not sliced_rows:
            return None, None
        features_sliced_tensor = torch.cat(sliced_rows, dim=0).to(self.args.device)
        features_mask_tensor = torch.cat(mask_rows, dim=0).to(self.args.device)
        return features_sliced_tensor, features_mask_tensor

    def forward(self, input_bert_features, attention_mask, spans, span_mask, related_spans_tensor, sentence_length):
        """Embed all candidate spans and score them as aspects and as opinions.

        Returns:
            ``(class_logits_aspect, class_logits_opinion, spans_embedding, span_embedding_1,
            span_embedding_2, features_mask_tensor)`` where ``span_embedding_1`` /
            ``span_embedding_2`` are the aspect / opinion decoder outputs and
            ``features_mask_tensor`` is only non-``None`` in "CNN" mode.
        """
        spans_embedding, features_mask_tensor = self.span_generator(input_bert_features, attention_mask, spans,
                                                                    span_mask, related_spans_tensor, sentence_length)

        if self.args.related_span_underline:
            # Let every span attend to the spans selected by related_spans_tensor.
            for related_attention in self.related_attentions:
                spans_embedding, _ = related_attention(spans_embedding, related_spans_tensor, spans_embedding)

        span_embedding_1 = torch.clone(spans_embedding)
        for forward_1_decoder in self.forward_1_decoders:
            span_embedding_1, _ = forward_1_decoder(span_embedding_1)
        class_logits_aspect = self.sentiment_classification_aspect(span_embedding_1)

        span_embedding_2 = torch.clone(spans_embedding)
        for reverse_1_decoder in self.reverse_1_decoders:
            span_embedding_2, _ = reverse_1_decoder(span_embedding_2)
        class_logits_opinion = self.sentiment_classification_opinion(span_embedding_2)

        return class_logits_aspect, class_logits_opinion, spans_embedding, span_embedding_1, span_embedding_2, \
               features_mask_tensor

    def span_generator(self, input_bert_features, attention_mask, spans, span_mask, related_spans_tensor,
                       sentence_length):
        """Turn token features into span embeddings according to ``args.span_generation``.

        Returns:
            ``(spans_embedding, features_mask_tensor)``; the mask is ``None`` except in
            "CNN" mode.

        Raises:
            ValueError: if ``args.span_generation`` is not one of the supported modes.
        """
        bert_feature = self.dropout_output(input_bert_features)
        features_mask_tensor = None
        if self.args.span_generation == "Average" or self.args.span_generation == "Max":
            # Use the BERT features of every token inside the span.
            spans_num = spans.shape[1]
            spans_width_start_end = spans[:, :, 0:2].view(spans.size(0), spans_num, -1)
            spans_width_start_end_embedding, spans_width_start_end_mask = batched_span_select(bert_feature,
                                                                                              spans_width_start_end)
            spans_width_start_end_mask = spans_width_start_end_mask.unsqueeze(-1).expand(-1, -1, -1,
                                                                                         self.args.bert_feature_dim)
            # Zero out the positions that batched_span_select marks as padding.
            spans_width_start_end_embedding = torch.where(spans_width_start_end_mask, spans_width_start_end_embedding,
                                                          torch.tensor(0).type_as(spans_width_start_end_embedding))
            # NOTE(review): the zero-filled padding positions take part in both reductions
            # below (the mean divides by the padded width) - confirm this is intended.
            if self.args.span_generation == "Max":
                spans_embedding = spans_width_start_end_embedding.max(2)[0]
            else:
                spans_embedding = spans_width_start_end_embedding.mean(dim=2, keepdim=True).squeeze(-2)
        elif self.args.span_generation == "Start_end":
            # Concatenate start-token feature, width embedding and end-token feature.
            spans_start = spans[:, :, 0].view(spans.size(0), -1)
            spans_start_embedding = batched_index_select(bert_feature, spans_start)
            spans_end = spans[:, :, 1].view(spans.size(0), -1)
            spans_end_embedding = batched_index_select(bert_feature, spans_end)

            spans_width = spans[:, :, 2].view(spans.size(0), -1)
            spans_width_embedding = self.step_1_embedding4width(spans_width)
            # The concatenation order is a tunable choice.
            spans_embedding = torch.cat((spans_start_embedding, spans_width_embedding, spans_end_embedding), dim=-1)
            # spans_embedding_dict = torch.cat((spans_start_embedding, spans_end_embedding, spans_width_embedding), dim=-1)
            spans_embedding = self.step_1_linear4width(spans_embedding)
        elif self.args.span_generation == "CNN":
            spans_embedding, features_mask_tensor = self.feature_slice(bert_feature, attention_mask, span_mask,
                                                                       sentence_length)
        elif self.args.span_generation == "ATT":
            spans_width_start_end = spans[:, :, 0:2].view(spans.shape[0], spans.shape[1], -1)
            spans_width_start_end_embedding, spans_width_start_end_mask = batched_span_select(bert_feature,
                                                                                              spans_width_start_end)
            # The summed span acts as a single query over the span's own tokens.
            span_sum_embdding = torch.sum(spans_width_start_end_embedding, dim=2).unsqueeze(2)
            for ATT_attention in self.ATT_attentions:
                span_sum_embdding, _ = ATT_attention(span_sum_embdding,
                                                     spans_width_start_end_mask,
                                                     spans_width_start_end_embedding)
            # Drop only the query axis added above; a bare squeeze() would also drop the
            # batch or span axis whenever it has size 1.
            spans_embedding = span_sum_embdding.squeeze(2)
        elif self.args.span_generation == "SE_ATT":
            spans_width_start_end = spans[:, :, 0:2].view(spans.shape[0], spans.shape[1], -1)
            spans_width_start_end_embedding, spans_width_start_end_mask = batched_span_select(bert_feature,
                                                                                              spans_width_start_end)
            spans_width_start_end_mask_2 = spans_width_start_end_mask.unsqueeze(-1).expand(-1, -1, -1,
                                                                                         self.args.bert_feature_dim)
            spans_width_start_end_embedding = torch.where(spans_width_start_end_mask_2, spans_width_start_end_embedding,
                                                          torch.tensor(0).type_as(spans_width_start_end_embedding))
            # NOTE(review): the projection result is overwritten on the next line, so the
            # attention scores are the plain feature sums - confirm which one is intended.
            claim_self_att = self.compess_projection(spans_width_start_end_embedding).squeeze()
            claim_self_att = torch.sum(spans_width_start_end_embedding, dim=-1).squeeze()
            claim_rep = masked_softmax(claim_self_att, span_mask, spans_width_start_end_mask).unsqueeze(-1).transpose(2, 3)
            claim_rep = torch.matmul(claim_rep, spans_width_start_end_embedding)
            # NOTE(review): the bare squeeze() calls in this branch drop every size-1 axis
            # (e.g. a batch of one) - verify against masked_softmax before relying on it.
            spans_embedding = claim_rep.squeeze()
        else:
            raise ValueError("Unsupported span_generation mode: %r" % (self.args.span_generation,))
        return spans_embedding, features_mask_tensor


class Dim_Four_Block(torch.nn.Module):
    """Cross-attention block built on ``Dim_Four_Attention`` plus a feed-forward part."""

    def __init__(self, args, bert_config):
        super().__init__()
        self.args = args
        self.forward_attn = Dim_Four_Attention(bert_config)
        self.intermediate = Intermediate(bert_config)
        self.output = Output(bert_config)

    def forward(self, hidden_embedding, masks, encoder_embedding):
        """Attend from ``hidden_embedding`` to ``encoder_embedding`` under ``masks``.

        ``masks`` is inverted with ``~`` and scaled by -1e9, so entries that are False
        become a large negative additive bias.  Two singleton axes are inserted before
        the last axis so the bias can broadcast against the attention scores
        (presumably the head and query axes - verify against Dim_Four_Attention).

        Returns:
            ``(layer_output, attention_result)`` where ``attention_result`` is whatever
            the attention module returned after its first element.
        """
        additive_bias = (~masks) * -1e9
        attn_outputs = self.forward_attn(hidden_states=hidden_embedding,
                                         encoder_hidden_states=encoder_embedding,
                                         encoder_attention_mask=additive_bias[:, :, None, None, :])
        attended, extras = attn_outputs[0], attn_outputs[1:]
        return self.output(self.intermediate(attended), attended), extras


class Pointer_Block(torch.nn.Module):
    """Cross-attention block (``Attention``) followed by ``Intermediate`` and ``Output``.

    ``mask_for_encoder`` decides how the mask is handed to the attention module: as
    ``encoder_attention_mask`` when True, as ``attention_mask`` when False.
    """

    def __init__(self, args, bert_config, mask_for_encoder=True):
        super(Pointer_Block, self).__init__()
        self.args = args
        self.forward_attn = Attention(bert_config)
        self.intermediate = Intermediate(bert_config)
        self.output = Output(bert_config)
        self.mask_for_encoder = mask_for_encoder

    def forward(self, hidden_embedding, masks, encoder_embedding):
        """Attend from ``hidden_embedding`` to ``encoder_embedding`` under ``masks``.

        The mask has to line up with the attention scores, which per the original note
        have shape (batch_size, num_heads, hidden_dim, encoder_dim).  ``masks`` is
        inverted with ``~`` and scaled by -1e9 so False entries become a large negative
        additive bias.

        Args:
            hidden_embedding: query-side states.
            masks: 2-D or 3-D mask supporting ``~`` (presumably boolean - verify callers).
            encoder_embedding: key/value-side states.

        Returns:
            ``(layer_output, attention_result)``.

        Raises:
            ValueError: if ``masks`` is neither 2-D nor 3-D.
        """
        masks = (~masks) * -1e9
        if masks.dim() == 3:
            # One mask row per query position: only the head axis is missing.
            attention_masks = masks[:, None, :, :]
        elif masks.dim() == 2:
            if self.mask_for_encoder:
                # Broadcast over heads and query positions.
                attention_masks = masks[:, None, None, :]
            else:
                # Broadcast over heads and key positions.
                attention_masks = masks[:, None, :, None]
        else:
            raise ValueError("masks must be 2-D or 3-D, got %d dimension(s)" % masks.dim())

        if self.mask_for_encoder:
            cross_attention_output = self.forward_attn(hidden_states=hidden_embedding,
                                                       encoder_hidden_states=encoder_embedding,
                                                       encoder_attention_mask=attention_masks)
        else:
            cross_attention_output = self.forward_attn(hidden_states=hidden_embedding,
                                                       encoder_hidden_states=encoder_embedding,
                                                       attention_mask=attention_masks)
        attention_output = cross_attention_output[0]
        attention_result = cross_attention_output[1:]
        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(intermediate_output, attention_output)
        return layer_output, attention_result


class Step_2_forward(torch.nn.Module):
    """Stage-2 decoder for the aspect -> opinion direction."""

    def __init__(self, args, bert_config):
        super().__init__()
        self.args = args
        self.bert_config = bert_config
        decoder_depth = max(1, args.block_num - 1)
        self.forward_opinion_decoder = nn.ModuleList(
            [Pointer_Block(args, self.bert_config, mask_for_encoder=False) for _ in range(decoder_depth)])
        self.opinion_docoder2class = nn.Linear(args.bert_feature_dim, len(sentiment2id))

    def forward(self, aspect_spans_embedding, aspect_span_mask, spans_aspect_tensor):
        """Run the decoder stack and classify its final output.

        Every layer receives the previous layer's output together with the same mask and
        the same ``spans_aspect_tensor``.

        Returns:
            ``(opinion_class_logits, opinion_attention)``; only the attention result of
            the last layer is returned.
        """
        hidden = aspect_spans_embedding
        for decoder_block in self.forward_opinion_decoder:
            hidden, last_attention = decoder_block(hidden, aspect_span_mask, spans_aspect_tensor)
        return self.opinion_docoder2class(hidden), last_attention


class Step_2_reverse(torch.nn.Module):
    """Stage-2 decoder for the opinion -> aspect direction."""

    def __init__(self, args, bert_config):
        super().__init__()
        self.args = args
        self.bert_config = bert_config
        decoder_depth = max(1, args.block_num - 1)
        self.reverse_aspect_decoder = nn.ModuleList(
            [Pointer_Block(args, self.bert_config, mask_for_encoder=False) for _ in range(decoder_depth)])
        self.aspect_docoder2class = nn.Linear(args.bert_feature_dim, len(sentiment2id))

    def forward(self, reverse_spans_embedding, reverse_span_mask, all_reverse_opinion_tensor):
        """Run the decoder stack and classify its final output.

        Every layer receives the previous layer's output together with the same mask and
        the same ``all_reverse_opinion_tensor``.

        Returns:
            ``(aspect_class_logits, aspect_attention)``; only the attention result of the
            last layer is returned.
        """
        hidden = reverse_spans_embedding
        for decoder_block in self.reverse_aspect_decoder:
            hidden, last_attention = decoder_block(hidden, reverse_span_mask, all_reverse_opinion_tensor)
        return self.aspect_docoder2class(hidden), last_attention



def Loss(gold_aspect_label, pred_aspect_label, gold_opinion_label, pred_opinion_label, spans_mask_tensor, opinion_span_mask_tensor,
         reverse_gold_opinion_label, reverse_pred_opinion_label, reverse_gold_aspect_label, reverse_pred_aspect_label,
         cnn_spans_mask_tensor, reverse_aspect_span_mask_tensor, spans_embedding, related_spans_tensor, args):
    """Sum the four masked cross-entropy terms of both directions, plus an optional KL term.

    Labels at positions whose mask is not 1 are replaced by the criterion's
    ``ignore_index`` so they do not contribute.  When ``cnn_spans_mask_tensor`` is given
    it replaces ``spans_mask_tensor`` everywhere.

    Returns:
        ``(loss, weighted_kl)`` where ``weighted_kl`` is ``args.kl_loss_weight`` times the
        KL term (0 when ``args.kl_loss`` is falsy).
    """
    loss_function = nn.CrossEntropyLoss(reduction='sum')

    def masked_cross_entropy(gold_label, pred_label, mask_tensor):
        # Flatten everything to one row per span and blank out the masked labels.
        keep = mask_tensor.view(-1) == 1
        logits = pred_label.view(-1, pred_label.shape[-1])
        filler = torch.tensor(loss_function.ignore_index).type_as(gold_label)
        return loss_function(logits, torch.where(keep, gold_label.view(-1), filler))

    if cnn_spans_mask_tensor is not None:
        spans_mask_tensor = cnn_spans_mask_tensor

    # Aspect -> opinion direction.
    aspect_loss = masked_cross_entropy(gold_aspect_label, pred_aspect_label, spans_mask_tensor)
    opinion_loss = masked_cross_entropy(gold_opinion_label, pred_opinion_label, opinion_span_mask_tensor)
    as_2_op_loss = aspect_loss + opinion_loss

    # Opinion -> aspect direction.
    reverse_opinion_loss = masked_cross_entropy(reverse_gold_opinion_label, reverse_pred_opinion_label,
                                                spans_mask_tensor)
    reverse_aspect_loss = masked_cross_entropy(reverse_gold_aspect_label, reverse_pred_aspect_label,
                                               reverse_aspect_span_mask_tensor)
    op_2_as_loss = reverse_opinion_loss + reverse_aspect_loss

    loss = as_2_op_loss + op_2_as_loss
    kl_loss = 0
    if args.kl_loss:
        kl_loss = shape_span_embedding(args, spans_embedding, spans_embedding, related_spans_tensor, spans_mask_tensor)
        loss = loss + args.kl_loss_weight * kl_loss
    return loss, args.kl_loss_weight * kl_loss

def shape_span_embedding(args, p, q, pad_mask, span_mask):
    """Accumulate a divergence term between one random span and its related spans.

    For every batch item one valid span (non-zero entry of ``span_mask[i]``) is drawn with
    ``random.choice``.  Its embedding from ``p`` is compared, via ``compute_kl_loss``,
    with the embeddings in ``q`` of all spans flagged in ``pad_mask[i, chosen]`` except
    the chosen span itself.  Items without a valid span or without related spans
    contribute nothing.

    Args:
        args: forwarded to ``compute_kl_loss``.
        p: embeddings indexed as ``[batch, span, :]``.
        q: embeddings with the same size as ``p``.
        pad_mask: indexed as ``[batch, span, :]``; non-zero marks a related span.
        span_mask: indexed as ``[batch, :]``; non-zero marks a valid span.

    Returns:
        The summed result of ``compute_kl_loss`` over the batch (0 if nothing was compared).
    """
    kl_loss = 0
    input_size = p.size()
    assert input_size == q.size()
    for i in range(input_size[0]):
        # view(-1) keeps a 1-D index list even when exactly one span is valid; squeeze()
        # would collapse it to a 0-d tensor that random.choice cannot handle.
        span_mask_index = torch.nonzero(span_mask[i, :]).view(-1)
        if span_mask_index.numel() == 0:
            continue
        lucky_squence = random.choice(span_mask_index)
        P = p[i, lucky_squence, :]
        related_index = torch.nonzero(pad_mask[i, lucky_squence, :]).view(-1)
        related_index = related_index[related_index != lucky_squence]
        if related_index.numel() == 0:
            continue
        # One gather instead of growing the tensor row by row; the comparison set comes
        # from q (the only visible caller passes the same tensor for p and q).
        q_tensor = q[i, related_index]
        expan_P = P.expand_as(q_tensor)
        kl_loss += compute_kl_loss(args, expan_P, q_tensor)
    return kl_loss

def compute_kl_loss(args, p, q, pad_mask=None):
    """Turn a distance between ``p`` and ``q`` into a penalty that shrinks as they diverge.

    ``args.kl_loss_mode`` picks the distance:
      * "KLLoss": mean of the two directed KL divergences of the softmaxed inputs,
      * "JSLoss": the two KL terms against the softmax of the midpoint ``(p + q) / 2``,
      * "EMLoss": Euclidean distance,
      * "CSLoss": summed cosine similarity along dim 1.
    The distance ``d`` is mapped through ``log(1 + 5 / d)`` (times 10 for "JSLoss").

    Args:
        args: object providing ``kl_loss_mode``.
        p, q: tensors of equal shape.
        pad_mask: optional boolean mask; masked elements are zeroed before summing
            (only used by "KLLoss" and "JSLoss").

    Returns:
        A Python float; 0 for an unknown mode.

    NOTE(review): ``math.log`` converts the tensor to a float, so no gradient flows
    through the returned value - confirm whether that is intended.
    """
    if args.kl_loss_mode == "KLLoss":
        p_loss = F.kl_div(F.log_softmax(p, dim=-1), F.softmax(q, dim=-1), reduction="none")
        q_loss = F.kl_div(F.log_softmax(q, dim=-1), F.softmax(p, dim=-1), reduction="none")

        if pad_mask is not None:
            # masked_fill is out-of-place: the result has to be kept.
            p_loss = p_loss.masked_fill(pad_mask, 0.)
            q_loss = q_loss.masked_fill(pad_mask, 0.)
        p_loss = p_loss.sum()
        q_loss = q_loss.sum()
        total_loss = math.log(1+5/((p_loss + q_loss) / 2))
    elif args.kl_loss_mode == "JSLoss":
        m = (p+q)/2
        m_loss = 0.5 * F.kl_div(F.log_softmax(p, dim=-1), F.softmax(m, dim=-1), reduction="none") + 0.5 * F.kl_div(
            F.log_softmax(q, dim=-1), F.softmax(m, dim=-1), reduction="none")
        if pad_mask is not None:
            m_loss = m_loss.masked_fill(pad_mask, 0.)
        m_loss = m_loss.sum()
        # test = -math.log(2*m_loss)-math.log(-2*m_loss+2)
        total_loss = 10*(math.log(1+5/m_loss))
    elif args.kl_loss_mode == "EMLoss":
        em_loss = torch.sqrt(torch.sum(torch.square(p - q)))
        total_loss = math.log(1+5/(em_loss))
    elif args.kl_loss_mode == "CSLoss":
        cs_loss = torch.sum(torch.cosine_similarity(p, q, dim=1))
        total_loss = math.log(1 + 5 / (cs_loss))
    else:
        total_loss = 0
        print("what's wrong with you?")
    return  total_loss


if __name__ == '__main__':
    # Placeholder entry point: running model.py directly only prints a marker.
    print('666')


Writing model.py


In [9]:
%%writefile gcn.py

import torch.nn as nn
import torch.nn.functional as F
import torch
import spacy
import numpy as np

# nlp = spacy.load("ru_core_news_md")

class GCN(nn.Module):
    """Stack of graph-convolution layers whose output size equals the input size."""

    def __init__(self, emb_dim=768, num_layers=1, gcn_dropout=0.7):  # the dropout rate could be increased
        super().__init__()
        self.layers = num_layers
        self.emb_dim = emb_dim
        self.out_dim = emb_dim
        # One square linear map per layer.
        self.W = nn.ModuleList([nn.Linear(emb_dim, emb_dim) for _ in range(num_layers)])
        self.gcn_drop = nn.Dropout(gcn_dropout)
        self.relu = nn.ReLU()

    def forward(self, adj, inputs, device):
        """Propagate ``inputs`` over ``adj``.

        Args:
            adj: (batch_size, len, len) adjacency; cropped to the length of ``inputs``
                when it is longer.
            inputs: (batch_size, len, emb_dim) node features.
            device: device the pre-activation is moved to before normalisation.

        Returns:
            ``(features, None)``; the second slot is where a mask used to be returned.
        """
        seq_len = inputs.shape[1]
        if seq_len < adj.shape[1]:
            adj = adj[:, :seq_len, :seq_len]

        # Row degree plus one (for the self loop), kept as (batch_size, len, 1).
        denom = adj.sum(dim=2, keepdim=True) + 1

        last_layer = self.layers - 1
        for layer_idx in range(self.layers):
            linear = self.W[layer_idx]
            # Neighbour aggregation plus self loop, both through the same linear map.
            pre_activation = linear(torch.bmm(adj, inputs)) + linear(inputs)
            activated = self.relu(pre_activation.to(device) / denom)
            # Dropout only between layers, never after the last one.
            inputs = self.gcn_drop(activated) if layer_idx < last_layer else activated
        return inputs, None
    

def make_adj_matrix(text, max_len=512, parser=None):
    """Build a sparse dependency adjacency matrix for ``text``.

    The matrix starts as the identity and gets a 1 at ``[token.i][token.head.i]`` for
    every parsed token; tokens whose own or head index is ``>= max_len`` are skipped.

    Args:
        text: a string, or a sequence of tokens that is joined with single spaces.
        max_len: side length of the square matrix.
        parser: callable returning the parsed tokens for a string.  When omitted, a
            module-level ``nlp`` is used if one exists; otherwise the spaCy model
            "ru_core_news_md" is loaded once and cached as ``nlp``.

    Returns:
        A sparse ``torch.FloatTensor`` of shape ``(max_len, max_len)``.
    """
    if parser is None:
        parser = globals().get("nlp")
        if parser is None:
            # Loaded lazily so importing this module does not require the model.
            parser = spacy.load("ru_core_news_md")
            globals()["nlp"] = parser

    # Parse exactly once; a token sequence is turned into a single string first.
    sentence = text if isinstance(text, str) else " ".join(text)
    doc = parser(sentence)

    adj_matrix = np.eye(max_len)
    for token in doc:
        if token.i >= max_len or token.head.i >= max_len:
            continue
        adj_matrix[token.i][token.head.i] = 1
    return torch.FloatTensor(adj_matrix).to_sparse()

Writing gcn.py


# Run

In [32]:
%%writefile run.py

import os
import argparse
import tqdm
import torch
import torch.nn.functional as F
from transformers import AdamW, BertModel, get_linear_schedule_with_warmup
from data_BIO_loader import DataTterator
from data_BIO_loader import MyDataset
from model import stage_2_features_generation, Step_1, Step_2_forward, Step_2_reverse, Loss
from Metric import Metric
from eval_features import unbatch_data
# from kaggle.working.log import logger
from thop import profile, clever_format
import wandb
from transformers.models.bert.modeling_bert import BertEmbeddings
from gcn import GCN, make_adj_matrix
import numpy as np

import logging
from datetime import datetime

import time
os.environ['CUDA_VISIBLE_DEVICES'] = '5'

# import spacy
# nlp = spacy.load("ru_core_news_md")

sentiment2id = {'none': 0, 'positive': 1, 'negative': 2, 'neutral': 3, 'start': 4}

from datetime import datetime

# Timestamp used in this run's log file name.
now = datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
logger = logging.getLogger("test")
logger.setLevel(level=logging.INFO)

# FileHandler opens its file right away and fails when the directory is missing,
# so create the log directory before attaching the handler.
LOG_DIR = "/kaggle/working/log/"
os.makedirs(LOG_DIR, exist_ok=True)
handler = logging.FileHandler(LOG_DIR + now + ".log", encoding='utf-8')
handler.setLevel(logging.INFO)

# Mirror every record on the console.
console = logging.StreamHandler()
console.setLevel(logging.INFO)

logger.addHandler(handler)
logger.addHandler(console)


from natasha import (
    Segmenter,
    
    NewsEmbedding,
    NewsMorphTagger,
    NewsSyntaxParser,
    
    Doc
)

# Natasha pipeline objects shared by make_adj_matrix below; they are built once at
# import time.  The tagger and the parser are both constructed from the same embedding.
segmenter = Segmenter()
emb = NewsEmbedding()
morph_tagger = NewsMorphTagger(emb)
syntax_parser = NewsSyntaxParser(emb)


def make_adj_matrix(sent, max_len):
    """Build a sparse word-piece level dependency adjacency matrix for one sentence.

    Steps:
      1. Glue the word pieces in ``sent[1:-1]`` back into space-separated words
         ("##" marks a continuation piece) and remember which pieces form each word.
      2. Parse the rebuilt text with the module-level Natasha pipeline.
      3. Align Natasha's tokens with the rebuilt words (either side may be split finer).
      4. For every Natasha token set ``adj[piece][head piece] = 1`` for all word pieces of
         the token and of its head.

    Position 0 (the first element of ``sent``) maps to itself and is what a head id of 0
    resolves to in the first sentence.

    Args:
        sent: word-piece tokens; the first and last element are skipped
            (presumably the special start/end tokens - verify against the tokenizer).
        max_len: side length of the square matrix; pairs with an index ``>= max_len``
            are skipped.

    Returns:
        A sparse ``torch.FloatTensor`` of shape ``(max_len, max_len)``.

    Raises:
        ValueError: if a Natasha token and a rebuilt word differ but have equal length,
            so neither can be extended to match the other.
    """
    # --- 1. word pieces -> words -------------------------------------------------
    new_sent = ""
    new_inds = []  # word index of every word piece
    index = -1
    for piece in sent[1:-1]:
        if piece[:2] == '##':
            new_sent += piece[2:]
        else:
            new_sent += " "
            new_sent += piece
            index += 1
        new_inds.append(index)

    # 1-based word position -> positions of its word pieces in ``sent``.
    new_inds_mapping = {0: [0]}
    for piece_pos, word_idx in enumerate(new_inds):
        new_inds_mapping.setdefault(word_idx + 1, []).append(piece_pos + 1)

    text = new_sent.strip()
    splitted_text = text.split(" ")

    # --- 2. dependency parse -----------------------------------------------------
    doc = Doc(text)
    doc.segment(segmenter)
    doc.tag_morph(morph_tagger)
    doc.parse_syntax(syntax_parser)

    # Cumulative token counts, used to turn (sentence id, token id) into a global index.
    doc_sents_lens = [0]
    for parsed_sent in doc.sents:
        doc_sents_lens.append(doc_sents_lens[-1] + len(parsed_sent.tokens))

    # --- 3. align Natasha tokens with the rebuilt words --------------------------
    i = 0
    j = 0
    splitted_mapping = {0: [0]}  # 1-based Natasha token position -> 1-based word positions
    while i < len(doc.tokens) and j < len(splitted_text):
        cur_nat_text = doc.tokens[i].text
        cur_our_text = splitted_text[j]
        splitted_mapping[i + 1] = [j + 1]
        if cur_nat_text != cur_our_text:
            if len(cur_nat_text) < len(cur_our_text):
                # Natasha split finer: several tokens belong to the same word.
                while cur_nat_text != cur_our_text:
                    i += 1
                    splitted_mapping[i + 1] = [j + 1]
                    cur_nat_text += doc.tokens[i].text
            elif len(cur_nat_text) > len(cur_our_text):
                # We split finer: one token covers several words.
                while cur_nat_text != cur_our_text:
                    j += 1
                    splitted_mapping[i + 1].append(j + 1)
                    cur_our_text += splitted_text[j]
            else:
                raise ValueError("Cannot align token %r with word %r" % (cur_nat_text, cur_our_text))
        i += 1
        j += 1

    # --- 4. fill the adjacency matrix --------------------------------------------
    adj_matrix = np.zeros((max_len, max_len))

    for token in doc.tokens:
        sent_id, cur_id = [int(part) for part in token.id.split('_')]
        head_sent_id, head_id = [int(part) for part in token.head_id.split('_')]
        cur_words_ids = [piece_pos
                         for word_pos in splitted_mapping[doc_sents_lens[sent_id - 1] + cur_id]
                         for piece_pos in new_inds_mapping[word_pos]]
        cur_words_head_ids = [piece_pos
                              for word_pos in splitted_mapping[doc_sents_lens[head_sent_id - 1] + head_id]
                              for piece_pos in new_inds_mapping[word_pos]]
        for row in cur_words_ids:
            for col in cur_words_head_ids:
                # Positions beyond the matrix are dropped instead of raising IndexError.
                if row < max_len and col < max_len:
                    adj_matrix[row][col] = 1

    return torch.FloatTensor(adj_matrix).to_sparse()


def eval(gcn_model, bert_model, step_1_model, step_2_forward, step_2_reverse, dataset, args):
    """Run the whole pipeline over ``dataset`` without gradients and score it.

    Note: this function shadows Python's built-in ``eval`` inside this module.

    For every batch the function
      1. encodes the tokens with ``bert_model``,
      2. builds one adjacency matrix per sentence with ``make_adj_matrix`` and
         adds the first output of ``gcn_model`` to BERT's last hidden state,
      3. runs ``step_1_model`` to obtain aspect / opinion class logits per span,
      4. for each direction (forward: aspect -> opinion, reverse: opinion ->
         aspect) feeds the spans predicted with a non-zero class through
         ``stage_2_features_generation`` and the matching stage-2 model,
      5. accumulates the validation loss when both directions produced
         stage-2 outputs.
    The per-batch predictions are flattened with ``unbatch_data`` and scored
    by ``Metric(...).score_triples()``; the scores are written to ``logger``.

    Args:
        gcn_model: called as ``gcn_model(adjacency, hidden_states, device)``;
            its first return value is added to the BERT hidden states.
        bert_model: called with ``input_ids`` and ``attention_mask``; its
            ``last_hidden_state`` is used.
        step_1_model: stage-1 span classifier (six return values, see below).
        step_2_forward: stage-2 model for the aspect -> opinion direction.
        step_2_reverse: stage-2 model for the opinion -> aspect direction.
        dataset: batch iterator exposing ``batch_count``, ``get_batch(j)`` and
            ``get_instances(j)``.
        args: configuration; ``max_seq_length`` and ``device`` are read here,
            and the object is also passed on to ``Loss`` and ``Metric``.

    Returns:
        Tuple ``(aspect_result, opinion_result, apce_result, pair_result,
        triplet_result, tot_loss)``. Each ``*_result`` is indexed as
        ``[0]`` precision, ``[1]`` recall, ``[2]`` f1 by the logging below;
        ``tot_loss`` is the sum of ``loss.item()`` over the batches for which
        the loss was computed.

    Side effects:
        All five models are switched to eval mode and are NOT switched back;
        callers that continue training must call ``.train()`` themselves.
    """
    with torch.no_grad():
        gcn_model.eval()
        bert_model.eval()
        step_1_model.eval()
        step_2_forward.eval()
        step_2_reverse.eval()
        # (marker string below: "ground-truth results")
        '''真实结果'''
        gold_instances = []
        # (marker string below: "forward prediction results")
        '''前向预测结果'''
        forward_stage1_pred_aspect_result, forward_stage1_pred_aspect_with_sentiment, \
        forward_stage1_pred_aspect_sentiment_logit, forward_stage2_pred_opinion_result, \
        forward_stage2_pred_opinion_sentiment_logit = [],[],[],[],[]

        # (marker string below: "reverse prediction results")
        '''反向预测结果'''
        reverse_stage1_pred_opinion_result, reverse_stage1_pred_opinion_with_sentiment, \
        reverse_stage1_pred_opinion_sentiment_logit, reverse_stage2_pred_aspect_result, \
        reverse_stage2_pred_aspect_sentiment_logit = [], [], [], [], []
        
        # tot_kl_loss is accumulated below but is not part of the return value.
        tot_loss = 0
        tot_kl_loss = 0

        for j in range(dataset.batch_count):
            tokens_tensor, attention_mask, bert_spans_tensor, spans_mask_tensor, spans_ner_label_tensor, \
            spans_aspect_tensor, spans_opinion_label_tensor, reverse_ner_label_tensor, reverse_opinion_tensor, \
            reverse_aspect_label_tensor, related_spans_tensor, sentence_length = dataset.get_batch(j)

            bert_output = bert_model(input_ids=tokens_tensor, attention_mask=attention_mask)
            
            # One adjacency matrix per batch entry (built from sent[0]); the
            # matrices are stacked along a new leading axis and densified.
            sentence_adj = []
            for sent in sentence_length:
                sentence_adj.append(make_adj_matrix(sent[0], args.max_seq_length))
            sentence_adj = torch.cat([i.unsqueeze(0) for i in sentence_adj], axis=0).to_dense().to(args.device)
                        
            # Residual combination of the BERT states and the GCN output.
            h_gcn, _ = gcn_model(sentence_adj, bert_output.last_hidden_state, args.device)
            bert_out = bert_output.last_hidden_state + h_gcn # \hat{h}

            aspect_class_logits, opinion_class_logits, spans_embedding, forward_embedding, reverse_embedding, \
                cnn_spans_mask_tensor = step_1_model(
                    bert_out, attention_mask, bert_spans_tensor, spans_mask_tensor,
                    related_spans_tensor, sentence_length)

            '''Batch Update'''
            # Stage-1 decisions: arg-max class per span; spans whose mask is
            # not 1 are forced to class 0 so they never count as predictions.
            pred_aspect_logits = torch.argmax(F.softmax(aspect_class_logits, dim=2), dim=2)
            pred_sentiment_ligits = F.softmax(aspect_class_logits, dim=2)
            pred_aspect_logits = torch.where(spans_mask_tensor == 1, pred_aspect_logits,
                                             torch.tensor(0).type_as(pred_aspect_logits))

            reverse_pred_stage1_logits = torch.argmax(F.softmax(opinion_class_logits, dim=2), dim=2)
            reverse_pred_sentiment_ligits = F.softmax(opinion_class_logits, dim=2)
            reverse_pred_stage1_logits = torch.where(spans_mask_tensor == 1, reverse_pred_stage1_logits,
                                             torch.tensor(0).type_as(reverse_pred_stage1_logits))

            '''true result synthesis'''
            gold_instances.append(dataset.get_instances(j))
            
            
            # Defaults for the stage-2 outputs. They stay empty lists when a
            # direction predicts nothing, which disables the loss computation
            # further down for this batch.
            all_span_opinion_tensor = []
            step_2_opinion_class_logits = []
            all_span_mask = []
            all_reverse_aspect_tensor = []
            reverse_aspect_class_logits = []
            reverse_span_mask = []

            '''Bidirectional prediction'''
            if torch.nonzero(pred_aspect_logits, as_tuple=False).shape[0] == 0:
                # No span received a non-zero aspect class in this batch:
                # store -1-filled placeholders shaped like the gold tensors.
#                 print("zero pred_aspect_logits...")
                forward_stage1_pred_aspect_result.append(torch.full_like(spans_aspect_tensor, -1))
                forward_stage1_pred_aspect_with_sentiment.append(pred_aspect_logits)
                forward_stage1_pred_aspect_sentiment_logit.append(pred_sentiment_ligits)
                forward_stage2_pred_opinion_result.append(torch.full_like(spans_opinion_label_tensor, -1))
                forward_stage2_pred_opinion_sentiment_logit.append(
                    torch.full_like(spans_opinion_label_tensor.unsqueeze(-1).expand(-1, -1, len(sentiment2id)), -1))

            else:
#                 print("non-zero pred_aspect_logits...")
                # Split the (batch index, span index) rows of the non-zero
                # predictions into one chunk per predicted span.
                pred_aspect_spans = torch.chunk(torch.nonzero(pred_aspect_logits, as_tuple=False),
                                                torch.nonzero(pred_aspect_logits, as_tuple=False).shape[0], dim=0)
                pred_span_aspect_tensor = None
                for pred_aspect_span in pred_aspect_spans:
                    batch_num = pred_aspect_span.squeeze()[0]
                    # First two entries of the span record are looked up and
                    # combined with the batch index into one row.
                    span_aspect_tensor_unspilt_1 = bert_spans_tensor[batch_num, pred_aspect_span.squeeze()[1], :2]
                    span_aspect_tensor_unspilt = torch.tensor(
                        (batch_num, span_aspect_tensor_unspilt_1[0], span_aspect_tensor_unspilt_1[1])).unsqueeze(0)
                    if pred_span_aspect_tensor is None:
                        pred_span_aspect_tensor = span_aspect_tensor_unspilt
                    else:
                        pred_span_aspect_tensor = torch.cat((pred_span_aspect_tensor, span_aspect_tensor_unspilt),dim=0)

                # Unlike train(), the stage-2 features are generated from the
                # PREDICTED aspect spans and no gold opinion labels are passed.
                all_span_opinion_tensor, all_span_aspect_tensor, all_bert_embedding, all_attention_mask, \
                    all_spans_embedding, all_span_mask = stage_2_features_generation(
                        bert_out, attention_mask, bert_spans_tensor, spans_mask_tensor,
                        forward_embedding, pred_span_aspect_tensor)
                
#                 all_span_opinion_tensor, all_span_aspect_tensor, all_bert_embedding, all_attention_mask, \
#                 all_spans_embedding, all_span_mask = stage_2_features_generation(bert_out,
#                                                                              attention_mask, bert_spans_tensor,
#                                                                              spans_mask_tensor, forward_embedding,
#                                                                              spans_aspect_tensor,
#                                                                              spans_opinion_label_tensor)


                step_2_opinion_class_logits, opinion_attention = step_2_forward(all_spans_embedding, all_span_mask,
                                                                         all_span_aspect_tensor)

                forward_stage1_pred_aspect_result.append(pred_span_aspect_tensor)
                forward_stage1_pred_aspect_with_sentiment.append(pred_aspect_logits)
                forward_stage1_pred_aspect_sentiment_logit.append(pred_sentiment_ligits)
                forward_stage2_pred_opinion_result.append(torch.argmax(F.softmax(step_2_opinion_class_logits, dim=2), dim=2))
                forward_stage2_pred_opinion_sentiment_logit.append(F.softmax(step_2_opinion_class_logits, dim=2))
            '''Reverse prediction'''
            # Mirror image of the forward branch: start from predicted opinion
            # spans and let step_2_reverse predict the aspects.
            if torch.nonzero(reverse_pred_stage1_logits, as_tuple=False).shape[0] == 0:
#                 print("zero reverse_pred_stage1_logits...")
                reverse_stage1_pred_opinion_result.append(torch.full_like(reverse_opinion_tensor, -1))
                reverse_stage1_pred_opinion_with_sentiment.append(reverse_pred_stage1_logits)
                reverse_stage1_pred_opinion_sentiment_logit.append(reverse_pred_sentiment_ligits)
                reverse_stage2_pred_aspect_result.append(torch.full_like(reverse_aspect_label_tensor, -1))
                reverse_stage2_pred_aspect_sentiment_logit.append(
                    torch.full_like(reverse_aspect_label_tensor.unsqueeze(-1).expand(-1, -1, len(sentiment2id)), -1))
            else:
#                 print("non-zero reverse_pred_stage1_logits...")
                reverse_pred_opinion_spans = torch.chunk(torch.nonzero(reverse_pred_stage1_logits, as_tuple=False),
                                                torch.nonzero(reverse_pred_stage1_logits, as_tuple=False).shape[0], dim=0)
                reverse_span_opinion_tensor = None
                for reverse_pred_opinion_span in reverse_pred_opinion_spans:
                    batch_num = reverse_pred_opinion_span.squeeze()[0]
                    reverse_opinion_tensor_unspilt = bert_spans_tensor[batch_num, reverse_pred_opinion_span.squeeze()[1], :2]
                    reverse_opinion_tensor_unspilt = torch.tensor(
                        (batch_num, reverse_opinion_tensor_unspilt[0], reverse_opinion_tensor_unspilt[1])).unsqueeze(0)
                    if reverse_span_opinion_tensor is None:
                        reverse_span_opinion_tensor = reverse_opinion_tensor_unspilt
                    else:
                        reverse_span_opinion_tensor = torch.cat((reverse_span_opinion_tensor, reverse_opinion_tensor_unspilt), dim=0)
           
                all_reverse_aspect_tensor, all_reverse_opinion_tensor, reverse_bert_embedding, reverse_attention_mask, \
                reverse_spans_embedding, reverse_span_mask = stage_2_features_generation(
                        bert_out,
                        attention_mask,
                        bert_spans_tensor,
                        spans_mask_tensor,
                        reverse_embedding,
                        reverse_span_opinion_tensor)

                reverse_aspect_class_logits, reverse_aspect_attention = step_2_reverse(reverse_spans_embedding,
                                                                                reverse_span_mask,
                                                                                all_reverse_opinion_tensor)

                reverse_stage1_pred_opinion_result.append(reverse_span_opinion_tensor)
                reverse_stage1_pred_opinion_with_sentiment.append(reverse_pred_stage1_logits)
                reverse_stage1_pred_opinion_sentiment_logit.append(reverse_pred_sentiment_ligits)
                reverse_stage2_pred_aspect_result.append(torch.argmax(F.softmax(reverse_aspect_class_logits, dim=2), dim=2))
                reverse_stage2_pred_aspect_sentiment_logit.append(F.softmax(reverse_aspect_class_logits, dim=2))
            
#             print('val all_span_opinion_tensor', all_span_opinion_tensor)
#                 step_2_opinion_class_logits, \
#                 all_span_mask, \
#                 all_reverse_aspect_tensor, \
#                 reverse_aspect_class_logits, \
#                 reverse_span_mask)
            # The validation loss is only computed when every stage-2 output
            # of BOTH directions is present and non-empty; batches where one
            # direction predicted nothing contribute nothing to tot_loss.
            if not(all_span_opinion_tensor is None or not len(all_span_opinion_tensor) or \
                step_2_opinion_class_logits is None or not len(step_2_opinion_class_logits) or \
                all_span_mask is None or not len(all_span_mask) or \
                all_reverse_aspect_tensor is None or not len(all_reverse_aspect_tensor) or \
                reverse_aspect_class_logits is None or not len(reverse_aspect_class_logits) or \
                reverse_span_mask is None or not len(reverse_span_mask)):
                print("val evaluating loss...")
                loss, kl_loss = Loss(spans_ner_label_tensor, aspect_class_logits, all_span_opinion_tensor, step_2_opinion_class_logits,
                            spans_mask_tensor, all_span_mask, reverse_ner_label_tensor, opinion_class_logits,
                            all_reverse_aspect_tensor, reverse_aspect_class_logits, cnn_spans_mask_tensor, reverse_span_mask,
                            spans_embedding, related_spans_tensor, args)
                print("val tot_loss...{}".format(loss.item()))
                tot_loss += loss.item()
                tot_kl_loss += kl_loss

        # Flatten the per-batch lists of gold instances into one list.
        gold_instances = [x for i in gold_instances for x in i]
        forward_pred_data = (forward_stage1_pred_aspect_result, forward_stage1_pred_aspect_with_sentiment,
                             forward_stage1_pred_aspect_sentiment_logit, forward_stage2_pred_opinion_result,
                             forward_stage2_pred_opinion_sentiment_logit)
        forward_pred_result = unbatch_data(forward_pred_data)

        reverse_pred_data = (reverse_stage1_pred_opinion_result, reverse_stage1_pred_opinion_with_sentiment,
                             reverse_stage1_pred_opinion_sentiment_logit, reverse_stage2_pred_aspect_result,
                             reverse_stage2_pred_aspect_sentiment_logit)
        reverse_pred_result = unbatch_data(reverse_pred_data)

        metric = Metric(args, forward_pred_result, reverse_pred_result, gold_instances)
        aspect_result, opinion_result, apce_result, pair_result, triplet_result = metric.score_triples()

        
        # Each result is logged as (precision, recall, f1) = indices 0, 1, 2.
        logger.info(
            'aspect precision: {}\taspect recall: {:.8f}\taspect f1: {:.8f}'.format(aspect_result[0], aspect_result[1], aspect_result[2]))
        logger.info(
            'opinion precision: {}\topinion recall: {:.8f}\topinion f1: {:.8f}'.format(opinion_result[0],
                                                                                        opinion_result[1],
                                                                                        opinion_result[2]))
        logger.info('APCE precision: {}\tAPCE recall: {:.8f}\tAPCE f1: {:.8f}'.format(apce_result[0],
                                                                                apce_result[1], apce_result[2]))
        logger.info('pair precision: {}\tpair recall: {:.8f}\tpair f1: {:.8f}'.format(pair_result[0],
                                                                                          pair_result[1],
                                                                                          pair_result[2]))
        logger.info('triple precision: {}\ttriple recall: {:.8f}\ttriple f1: {:.8f}'.format(triplet_result[0],
                                                                                          triplet_result[1],
                                                                                          triplet_result[2]))

    return aspect_result, opinion_result, apce_result, pair_result, triplet_result, tot_loss


def train(args):
    """Build all sub-models, optionally restore a checkpoint, train and evaluate.

    The function always loads the test set and constructs the GCN, BERT,
    stage-1 and both stage-2 models plus one shared AdamW optimizer. When
    ``args.model_to_upload`` is set, the BERT / stage-1 / stage-2 weights and
    the optimizer state are restored from that checkpoint (the GCN weights are
    not restored here). When ``args.mode == 'train'`` the train and dev sets
    are loaded and the models are trained for ``args.epochs`` epochs, with an
    evaluation on the dev set after every epoch and an extra evaluation on the
    test set whenever the dev triple f1 exceeds 0.5. A final evaluation on the
    test set is always performed.

    Changes compared with the previous implementation:
      * Gradient accumulation now really accumulates. Gradients used to be
        zeroed before *every* batch, so with ``accumulation_steps > 1`` only
        the last micro-batch contributed to each optimizer step. Gradients
        are now cleared at the start of each epoch and right after each
        optimizer step, and a trailing partial accumulation window is flushed
        on the last batch of the epoch.
      * The warm-up scheduler is sized in optimizer steps instead of batches,
        so the schedule reaches its end when accumulation is enabled.
      * The DataParallel ``module.`` prefix is only stripped from checkpoint
        keys that actually carry it (keys were unconditionally cut by seven
        characters, corrupting checkpoints saved without DataParallel).

    Args:
        args: parsed command-line namespace. Attributes read here:
            ``dataset_path``, ``dataset``, ``bert_feature_dim``, ``device``,
            ``init_model``, ``add_pos_enc``, ``task_learning_rate``,
            ``learning_rate``, ``model_to_upload``, ``multi_gpu``, ``mode``,
            ``model_dir``, ``whether_warm_up``, ``warm_up``, ``epochs``,
            ``accumulation_steps`` and ``max_seq_length``.

    Returns:
        None.
    """
    # NOTE(review): wandb.init() is commented out, yet wandb.log() and
    # wandb.finish() are still called below -- confirm that a run is
    # initialised elsewhere before train() is invoked.
    # wandb_ran = wandb.init(
    #     project='aste-SBN',
    #     config=args
    # )

    def _strip_module_prefix(state_dict):
        """Return a copy of ``state_dict`` without the DataParallel 'module.' key prefix."""
        prefix = 'module.'
        return {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in state_dict.items()
        }

    if args.dataset_path == './datasets/BIO_form/':
        train_path = args.dataset_path + args.dataset + "/train.json"
        dev_path = args.dataset_path + args.dataset + "/dev.json"
        test_path = args.dataset_path + args.dataset + "/test.json"
    else:
        train_path = args.dataset_path + args.dataset + "/train_full.txt"
        dev_path = args.dataset_path + args.dataset + "/dev_full.txt"
        # NOTE(review): the test path is hard-coded to a Kaggle input folder
        # and ignores args.dataset_path / args.dataset.
        # test_path = args.dataset_path + args.dataset + "/test_full_fake_only.txt"
        test_path = '/kaggle/input/dataset-sent-no-fake/test_full_fake_only.txt'

    print('-------------------------------')
    print('Start loading the test set')
    logger.info('Start loading the test set')
    test_datasets = MyDataset(args, test_path, if_train=False)
    testset = DataTterator(test_datasets, args)
    print('The test set is loaded')
    logger.info('The test set is loaded')
    print('-------------------------------')

    gcn = GCN(emb_dim=args.bert_feature_dim).to(args.device)
    gcn_param_optimizer = list(gcn.named_parameters())

    Bert = BertModel.from_pretrained(args.init_model)
    bert_config = Bert.config

    if args.add_pos_enc:
        print("Change pos_embeddings to 1536 len...")

        # Keep the pretrained embedding tables so they can be re-installed
        # after the embedding module has been rebuilt with the longer
        # position table.
        word_emb = Bert.embeddings.word_embeddings.weight.data
        token_type_emb = Bert.embeddings.token_type_embeddings.weight.data

        # The position table is tripled in length by stacking the original
        # table with copies scaled by 2 and by 4.
        pos_enc = Bert.embeddings.position_embeddings.weight.data
        new_pos_enc = torch.concat((pos_enc, pos_enc * 2, pos_enc * 4), axis=0)
        # new_pos_enc = torch.repeat_interleave(pos_enc, 3, dim=0)

        # new config and embeddings structure
        bert_config.update({'max_position_embeddings': 1536})
        Bert.embeddings = BertEmbeddings(bert_config)

        # return pretrained weights
        Bert.embeddings.word_embeddings.weight.data = word_emb
        Bert.embeddings.token_type_embeddings.weight.data = token_type_emb
        Bert.embeddings.position_embeddings.weight.data = new_pos_enc

        print("Changed successful!")

    Bert.to(args.device)
    bert_param_optimizer = list(Bert.named_parameters())

    step_1_model = Step_1(args, bert_config)
    step_1_model.to(args.device)
    step_1_param_optimizer = list(step_1_model.named_parameters())

    step2_forward_model = Step_2_forward(args, bert_config)
    step2_forward_model.to(args.device)
    forward_step2_param_optimizer = list(step2_forward_model.named_parameters())

    step2_reverse_model = Step_2_reverse(args, bert_config)
    step2_reverse_model.to(args.device)
    reverse_step2_param_optimizer = list(step2_reverse_model.named_parameters())

    # GCN and BERT use args.learning_rate (the optimizer default); the task
    # heads use the separate args.task_learning_rate.
    training_param_optimizer = [
        {'params': [p for n, p in gcn_param_optimizer]},
        {'params': [p for n, p in bert_param_optimizer]},
        {'params': [p for n, p in step_1_param_optimizer], 'lr': args.task_learning_rate},
        {'params': [p for n, p in forward_step2_param_optimizer], 'lr': args.task_learning_rate},
        {'params': [p for n, p in reverse_step2_param_optimizer], 'lr': args.task_learning_rate}]
    optimizer = AdamW(training_param_optimizer, lr=args.learning_rate)

    if args.model_to_upload is not None:

        model_path = args.model_to_upload
        if args.device == 'cpu':
            state = torch.load(model_path, map_location=torch.device('cpu'))
        else:
            state = torch.load(model_path)

        # NOTE(review): no GCN weights are restored; the GCN keeps its fresh
        # initialisation after a checkpoint load.
        Bert.load_state_dict(_strip_module_prefix(state['bert_model']))
        step_1_model.load_state_dict(_strip_module_prefix(state['step_1_model']))
        step2_forward_model.load_state_dict(_strip_module_prefix(state['step2_forward_model']))
        step2_reverse_model.load_state_dict(_strip_module_prefix(state['step2_reverse_model']))

        optimizer.load_state_dict(state['optimizer'])

        with torch.no_grad():
            Bert.eval()
            step_1_model.eval()
            step2_forward_model.eval()
            step2_reverse_model.eval()


    if args.multi_gpu:
        gcn = torch.nn.DataParallel(gcn)
        Bert = torch.nn.DataParallel(Bert)
        step_1_model = torch.nn.DataParallel(step_1_model)
        step2_forward_model = torch.nn.DataParallel(step2_forward_model)
        step2_reverse_model = torch.nn.DataParallel(step2_reverse_model)

    if args.mode == 'train':
        print('-------------------------------')
        logger.info('Start loading the training and verification set')
        print('Start loading the training and verification set')
        train_datasets = MyDataset(args, train_path, if_train=True)
        trainset = DataTterator(train_datasets, args)
        print("Train features build completed")

        print("Dev features build beginning")
        dev_datasets = MyDataset(args, dev_path, if_train=False)
        devset = DataTterator(dev_datasets, args)
        print('The training set and verification set are loaded')
        logger.info('The training set and verification set are loaded')
        print('-------------------------------')
        if not os.path.exists(args.model_dir):
            os.makedirs(args.model_dir)

        # Number of micro-batches whose gradients are summed per optimizer step.
        accumulation_steps = max(1, args.accumulation_steps)

        # scheduler
        if args.whether_warm_up:
            # The scheduler is advanced once per optimizer step, so it is
            # sized in optimizer steps (ceil division covers the flushed
            # partial window at the end of each epoch).
            steps_per_epoch = (trainset.batch_count + accumulation_steps - 1) // accumulation_steps
            training_steps = args.epochs * steps_per_epoch
            warmup_steps = int(training_steps * args.warm_up)
            scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps,
                                                        num_training_steps=training_steps)

        tot_loss = 0
        tot_kl_loss = 0
        best_aspect_f1, best_opinion_f1, best_APCE_f1, best_pairs_f1, best_triple_f1 = 0,0,0,0,0
        for i in range(args.epochs):
            logger.info(('Epoch:{}'.format(i)))

            # eval() leaves every model in eval mode, so training mode is
            # re-enabled at the start of each epoch.
            gcn.train()
            Bert.train()
            step_1_model.train()
            step2_forward_model.train()
            step2_reverse_model.train()

            # Start each epoch with clean gradients; afterwards gradients are
            # only cleared right after an optimizer step.
            optimizer.zero_grad()

            for j in tqdm.trange(trainset.batch_count):

                tokens_tensor, attention_mask, bert_spans_tensor, spans_mask_tensor, spans_ner_label_tensor, \
                spans_aspect_tensor, spans_opinion_label_tensor, reverse_ner_label_tensor, reverse_opinion_tensor, \
                reverse_aspect_label_tensor, related_spans_tensor, sentence_length = trainset.get_batch(j)

                # One adjacency matrix per batch entry, stacked and densified.
                sentence_adj = []
                for sent in sentence_length:
                    sentence_adj.append(make_adj_matrix(sent[0], args.max_seq_length))
                sentence_adj = torch.cat([adj.unsqueeze(0) for adj in sentence_adj], axis=0).to_dense().to(args.device)

                bert_output = Bert(input_ids=tokens_tensor, attention_mask=attention_mask)

                # Residual combination of BERT states and the GCN output.
                h_gcn, _ = gcn(sentence_adj, bert_output.last_hidden_state, args.device)
                bert_out = bert_output.last_hidden_state + h_gcn # \hat{h}
                # bert_out = torch.cat((bert_output.last_hidden_state, h_gcn), dim=2) # \hat{h}

                aspect_class_logits, opinion_class_logits, spans_embedding, forward_embedding, reverse_embedding, \
                    cnn_spans_mask_tensor = step_1_model(bert_out,
                                                        attention_mask,
                                                        bert_spans_tensor,
                                                        spans_mask_tensor,
                                                        related_spans_tensor,
                                                        sentence_length)

                '''Batch Update'''
                # During training the stage-2 features are generated from the
                # gold aspect / opinion spans and their labels.
                all_span_opinion_tensor, all_span_aspect_tensor, all_bert_embedding, all_attention_mask, \
                all_spans_embedding, all_span_mask = stage_2_features_generation(bert_out,
                                                                             attention_mask, bert_spans_tensor,
                                                                             spans_mask_tensor, forward_embedding,
                                                                             spans_aspect_tensor,
                                                                             spans_opinion_label_tensor)
                all_reverse_aspect_tensor, all_reverse_opinion_tensor, reverse_bert_embedding, reverse_attention_mask, \
                reverse_spans_embedding, reverse_span_mask = stage_2_features_generation(bert_out,
                                                                                     attention_mask, bert_spans_tensor,
                                                                                     spans_mask_tensor, reverse_embedding,
                                                                                     reverse_opinion_tensor,
                                                                                     reverse_aspect_label_tensor)

                step_2_opinion_class_logits, opinion_attention = step2_forward_model(all_spans_embedding,
                                                                                     all_span_mask, all_span_aspect_tensor)
                step_2_aspect_class_logits, aspect_attention = step2_reverse_model(reverse_spans_embedding,
                    reverse_span_mask, all_reverse_opinion_tensor)

                loss, kl_loss = Loss(spans_ner_label_tensor, aspect_class_logits, all_span_opinion_tensor, step_2_opinion_class_logits,
                            spans_mask_tensor, all_span_mask, reverse_ner_label_tensor, opinion_class_logits,
                            all_reverse_aspect_tensor, step_2_aspect_class_logits, cnn_spans_mask_tensor, reverse_span_mask,
                            spans_embedding, related_spans_tensor, args)

                if args.accumulation_steps > 1:
                    # Scale so that the summed micro-batch gradients match
                    # the gradient of the averaged loss.
                    loss = loss / args.accumulation_steps
                    loss.backward()
                    window_complete = ((j + 1) % args.accumulation_steps) == 0
                    last_batch = (j + 1) == trainset.batch_count
                    if window_complete or last_batch:
                        optimizer.step()
                        if args.whether_warm_up:
                            scheduler.step()
                        optimizer.zero_grad()
                else:
                    loss.backward()
                    optimizer.step()
                    if args.whether_warm_up:
                        scheduler.step()
                    optimizer.zero_grad()
                tot_loss += loss.item()
                tot_kl_loss += kl_loss


            logger.info(('Loss:', tot_loss))
            logger.info(('KL_Loss:', tot_kl_loss))


            print('Evaluating, please wait')
            aspect_result, opinion_result, apce_result, pair_result, triplet_result, val_tot_loss = eval(gcn, Bert, step_1_model,
                                                                                           step2_forward_model,
                                                                                           step2_reverse_model,
                                                                                           devset, args)

            wandb.log({
                'Loss':tot_loss,
                'KL_Loss':tot_kl_loss,
                'Val_Loss':val_tot_loss,
                'triple precision':triplet_result[0],
                'triple recall':triplet_result[1],
                'triple f1':triplet_result[2]
            })

            # The running sums are per-epoch totals.
            tot_loss = 0
            tot_kl_loss = 0

            print('Evaluating complete')


            # NOTE(review): the threshold is a fixed 0.5, not the best score
            # seen so far, so "best_*" holds the most recent epoch above 0.5.
            if triplet_result[2] > 0.5:
                # model_path = "/kaggle/working/SBN_models/1904_base_full_model_gcn_" + str(i) +'_'+ str(triplet_result[2]) + '.pt'
                # state = {
                #     # "gcn_model": gcn.state_dict(),
                #     "bert_model": Bert.state_dict(),
                #     "step_1_model": step_1_model.state_dict(),
                #     "step2_forward_model": step2_forward_model.state_dict(),
                #     "step2_reverse_model": step2_reverse_model.state_dict(),
                #     "optimizer": optimizer.state_dict()
                # }
                # torch.save(state, model_path)
                # logger.info("_________________________________________________________")
                # logger.info("best model save")
                # logger.info("_________________________________________________________")

                best_triple_f1 = triplet_result[2]
                best_triple_precision = triplet_result[0]
                best_triple_recall = triplet_result[1]
                best_triple_epoch = i

                print("Test results...")
                eval(gcn, Bert, step_1_model, step2_forward_model, step2_reverse_model, testset, args)

    logger.info("Features build completed")
    logger.info("Evaluation on testset:")

    eval(gcn, Bert, step_1_model, step2_forward_model, step2_reverse_model, testset, args)
    wandb.finish()


def test(args):
    """Load a checkpoint and evaluate it on the test file at ``args.dataset_path``.

    Changes compared with the previous implementation:
      * ``eval`` takes seven arguments, the first being the GCN model, but
        was called with six, which raised ``TypeError`` on every run. A GCN
        is now built the same way as in ``train`` and passed along. Its
        weights are restored from the checkpoint's ``'gcn_model'`` entry when
        that entry exists; otherwise a warning is logged because the GCN then
        runs with freshly initialised weights.
      * The checkpoint is read from disk once instead of twice.
      * The unused AdamW optimizer is no longer built or restored; it was
        never stepped, and its four parameter groups could not accept an
        optimizer state saved by ``train`` (five groups, including the GCN).
      * The DataParallel ``module.`` prefix is only stripped from checkpoint
        keys that actually carry it.

    Args:
        args: parsed command-line namespace. Attributes read here:
            ``dataset_path`` (used directly as the test file path),
            ``model_to_upload`` (checkpoint path, required), ``device``,
            ``init_model``, ``add_pos_enc``, ``bert_feature_dim`` and
            ``multi_gpu``. The namespace is also forwarded to the dataset,
            model constructors and ``eval``.

    Returns:
        None. The scores are logged by ``eval``.

    Raises:
        ValueError: if ``args.model_to_upload`` is ``None``.
    """
    def _strip_module_prefix(state_dict):
        """Return a copy of ``state_dict`` without the DataParallel 'module.' key prefix."""
        prefix = 'module.'
        return {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in state_dict.items()
        }

    model_path = args.model_to_upload
    if model_path is None:
        raise ValueError("test mode needs a checkpoint: pass --model_to_upload")

    # test_path = '/kaggle/input/dataset-sentences-with-fake-start/test_no_right_answers.txt'
    test_path = args.dataset_path

    print('-------------------------------')
    print('Start loading the test set')
    logger.info('Start loading the test set')
    test_datasets = MyDataset(args, test_path, if_train=False)
    testset = DataTterator(test_datasets, args)
    print('The test set is loaded')
    logger.info('The test set is loaded')
    print('-------------------------------')

    print('Start loading model...')

    if args.device == 'cpu':
        state = torch.load(model_path, map_location=torch.device('cpu'))
    else:
        state = torch.load(model_path)
        # state = load_with_single_gpu(model_path)

    Bert = BertModel.from_pretrained(args.init_model)
    bert_config = Bert.config

    if args.add_pos_enc:
        print("Change pos_embeddings to 1536 len...")

        # Keep the pretrained embedding tables so they can be re-installed
        # after the embedding module has been rebuilt with the longer
        # position table.
        word_emb = Bert.embeddings.word_embeddings.weight.data
        token_type_emb = Bert.embeddings.token_type_embeddings.weight.data

        # Only the length of this table matters here: its values are
        # overwritten when the checkpoint is loaded below.
        pos_enc = Bert.embeddings.position_embeddings.weight.data
        new_pos_enc = torch.concat((pos_enc, pos_enc, pos_enc), axis=0)

        # new config and embeddings structure
        bert_config.update({'max_position_embeddings': 1536})
        Bert.embeddings = BertEmbeddings(bert_config)

        # return pretrained weights
        Bert.embeddings.word_embeddings.weight.data = word_emb
        Bert.embeddings.token_type_embeddings.weight.data = token_type_emb
        Bert.embeddings.position_embeddings.weight.data = new_pos_enc

        print("Changed successful!")

    Bert.to(args.device)

    step_1_model = Step_1(args, bert_config)
    step_1_model.to(args.device)

    step2_forward_model = Step_2_forward(args, bert_config)
    step2_forward_model.to(args.device)

    step2_reverse_model = Step_2_reverse(args, bert_config)
    step2_reverse_model.to(args.device)

    # eval() requires the GCN as its first argument; build it as train() does.
    gcn = GCN(emb_dim=args.bert_feature_dim).to(args.device)
    if 'gcn_model' in state:
        gcn.load_state_dict(_strip_module_prefix(state['gcn_model']))
    else:
        logger.warning("Checkpoint has no 'gcn_model' entry; the GCN runs with freshly initialised weights")

    Bert.load_state_dict(_strip_module_prefix(state['bert_model']))
    step_1_model.load_state_dict(_strip_module_prefix(state['step_1_model']))
    step2_forward_model.load_state_dict(_strip_module_prefix(state['step2_forward_model']))
    step2_reverse_model.load_state_dict(_strip_module_prefix(state['step2_reverse_model']))

    gcn.eval()
    Bert.eval()
    step_1_model.eval()
    step2_forward_model.eval()
    step2_reverse_model.eval()

    if args.multi_gpu:
        gcn = torch.nn.DataParallel(gcn)
        Bert = torch.nn.DataParallel(Bert)
        step_1_model = torch.nn.DataParallel(step_1_model)
        step2_forward_model = torch.nn.DataParallel(step2_forward_model)
        step2_reverse_model = torch.nn.DataParallel(step2_reverse_model)

    print("Model loading ended")

    eval(gcn, Bert, step_1_model, step2_forward_model, step2_reverse_model, testset, args)



def load_with_single_gpu(model_path, map_location=None):
    """Load a checkpoint and strip the ``module.`` key prefix from its sections.

    ``torch.nn.DataParallel`` wraps a model so that every ``state_dict`` key
    starts with ``"module."``. This helper removes that prefix so the weights
    can be loaded into an unwrapped model.

    Args:
        model_path: Path of the checkpoint file, passed to ``torch.load``.
        map_location: Forwarded to ``torch.load``; pass ``"cpu"`` to open a
            checkpoint saved on a GPU on a machine without CUDA. ``None``
            (the default) keeps ``torch.load``'s own default behaviour.

    Returns:
        dict: One entry per top-level checkpoint section. Dict-like sections
        become an ``OrderedDict`` with the ``module.`` prefix removed from
        keys that carry it; every other section is returned unchanged.
    """
    from collections import OrderedDict

    prefix = "module."
    checkpoint = torch.load(model_path, map_location=map_location)

    final_state = {}
    for section, payload in checkpoint.items():
        # Scalar metadata (epoch counters, scores, ...) has no keys to rename.
        if not isinstance(payload, dict):
            final_state[section] = payload
            continue

        cleaned = OrderedDict()
        for key, value in payload.items():
            # Only strip a real prefix: blindly dropping 7 characters would
            # mangle keys such as the optimizer's "state" / "param_groups".
            if isinstance(key, str) and key.startswith(prefix):
                key = key[len(prefix):]
            cleaned[key] = value
        final_state[section] = cleaned
    return final_state

def _str2bool(value):
    """Convert a command-line token into a real ``bool`` (argparse ``type=``).

    Without a converter argparse stores the raw string, so ``--flag False``
    yields the truthy string ``'False'``.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in ("true", "t", "yes", "y", "1"):
        return True
    if token in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("expected a boolean value (True/False), got %r" % (value,))


def main():
    """Parse the command line, log every option and dispatch to train/test.

    Options are read from ``sys.argv``. Each parsed option is logged as
    ``name=value`` (sorted by name), then ``train(args)`` is called when
    ``--mode train`` is given and ``test(args)`` otherwise.
    """
    parser = argparse.ArgumentParser(description="Train scrip")
    parser.add_argument('--model_dir', type=str, default="savemodel/", help='model path prefix')
    parser.add_argument('--model_to_upload', type=str, default=None)
    parser.add_argument('--add_pos_enc', type=_str2bool, default=False)
    parser.add_argument('--device', type=str, default="cuda", help='cuda or cpu')
    parser.add_argument("--init_model", default="pretrained_models/bert-base-uncased", type=str, required=False,help="Initial model.")
    parser.add_argument("--init_vocab", default="pretrained_models/bert-base-uncased", type=str, required=False,help="Initial vocab.")

    parser.add_argument("--bert_feature_dim", default=768, type=int, help="feature dim for bert")
    # A bare `--do_lower_case` still means True (const), but the option can now
    # also be switched off with `--do_lower_case False`; with `store_true` and
    # default=True it could never be disabled.
    parser.add_argument("--do_lower_case", type=_str2bool, nargs='?', const=True, default=True,help="Set this flag if you are using an uncased model.")
    parser.add_argument("--max_seq_length", default=100, type=int,help="The maximum total input sequence length after WordPiece tokenization. Sequences longer than this will be truncated, and sequences shorter than this will be padded.")
    # float, not int: int("0.2") raises, so a fractional rate was rejected.
    parser.add_argument("--drop_out", type=float, default=0.1, help="")
    parser.add_argument("--max_span_length", type=int, default=8, help="")
    parser.add_argument("--embedding_dim4width", type=int, default=200,help="")
    parser.add_argument("--task_learning_rate", type=float, default=1e-4)
    parser.add_argument("--learning_rate", type=float, default=1e-5)
    parser.add_argument("--accumulation_steps", type=int, default=1)
    parser.add_argument("--multi_gpu", type=_str2bool, default=False)
    parser.add_argument('--epochs', type=int, default=130, help='training epoch number')
    parser.add_argument("--train_batch_size", default=16, type=int, help="batch size for training")
    parser.add_argument("--RANDOM_SEED", type=int, default=2022, help="")
    # The data format was changed
    parser.add_argument("--dataset_path", default="",
                        help="")
    parser.add_argument("--dataset", default="", type=str,
                        help="specify the dataset")
    parser.add_argument('--mode', type=str, default="test", choices=["train", "test"], help='option: train, test')
    # Attention over similar spans
    # Use only the first (head) token of each word's tokenization
    parser.add_argument("--Only_token_head", type=_str2bool, default=False)
    # Choose the synthesis method of Span
    parser.add_argument('--span_generation', type=str, default="Max", choices=["Start_end", "Max", "Average", "CNN", "ATT"],
                        help='option: CNN, Max, Start_end, Average, ATT, SE_ATT')
    parser.add_argument('--ATT_SPAN_block_num', type=int, default=1, help="number of block in generating spans")

    # Whether to add a separation loss to the relevant span
    parser.add_argument("--kl_loss", type=_str2bool, default=True)
    # float, not int: the weight is fractional (default 0.5).
    parser.add_argument("--kl_loss_weight", type=float, default=0.5, help="weight of the kl_loss")
    # "EMLoss" and "CSLoss" are two separate choices; they used to be fused
    # into the single string "EMLoss, CSLoss", so neither could be selected.
    parser.add_argument('--kl_loss_mode', type=str, default="KLLoss", choices=["KLLoss", "JSLoss", "EMLoss", "CSLoss"],
                        help='选择分离相似Span的分离函数, KL散度、JS散度、欧氏距离以及余弦相似度')
    # Whether to use the filtering algorithm in the test
    parser.add_argument('--Filter_Strategy', type=_str2bool, default=True, help='是否使用筛选算法去除冲突三元组')
    # Deprecated    Related Span attention
    parser.add_argument("--related_span_underline", type=_str2bool, default=False)
    parser.add_argument("--related_span_block_num", type=int, default=1, help="number of block in related span attention")

    # choose Cross Select the number of ATT blocks in Attention
    parser.add_argument("--block_num", type=int, default=1, help="number of block")
    parser.add_argument("--output_path", default='triples.json')
    # Enter and sort in the order of sentences
    parser.add_argument("--order_input", type=_str2bool, default=True, help="")
    # Randomize input span sorting
    parser.add_argument("--random_shuffle", type=int, default=0, help="")
    # Verify model complexity
    parser.add_argument("--model_para_test", type=_str2bool, default=False)
    # Use Warm up to converge quickly
    parser.add_argument('--whether_warm_up', type=_str2bool, default=False)
    parser.add_argument('--warm_up', type=float, default=0.1)
    args = parser.parse_args()

    # Record the effective configuration of this run.
    for k,v in sorted(vars(args).items()):
        logger.info(str(k) + '=' + str(v))
    if args.mode == 'train':
        train(args)
    else:
        test(args)


if __name__ == "__main__":
    # Run the command-line entry point only when this file is executed as a
    # script, not when it is imported.
    try:
        main()
    except KeyboardInterrupt:
        # A Ctrl-C is logged as a message instead of surfacing as a traceback.
        logger.info("keyboard break")


Overwriting run.py


In [11]:
# -p keeps this cell idempotent: re-running it when the directory already exists is not an error.
!mkdir -p SBN_models

In [12]:
# -p keeps this cell idempotent: re-running it when the directory already exists is not an error.
!mkdir -p log

In [13]:
%%writefile log.py


import logging
import os
from datetime import datetime

# Timestamp taken once, at import time; it names this run's log file.
now = datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
logger = logging.getLogger("test")
logger.setLevel(level=logging.INFO)

# FileHandler does not create missing parent directories, so make sure the
# "log" directory exists instead of relying on it having been created by hand.
os.makedirs("log", exist_ok=True)
handler = logging.FileHandler("log/"+now+".log", encoding='utf-8')
handler.setLevel(logging.INFO)

# Second handler: mirror the same records to the console stream.
console = logging.StreamHandler()
console.setLevel(logging.INFO)

logger.addHandler(handler)
logger.addHandler(console)


Writing log.py


In [14]:
import wandb

In [15]:
# Never paste the API key into the notebook. Called without `key`, wandb.login()
# picks the key up from the WANDB_API_KEY environment variable or from
# previously stored credentials.
wandb.login()

[34m[1mwandb[0m: W&B API key is configured. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


True

In [33]:
# Launch run.py in train mode on /kaggle/input/dataset-sent-no-fake, using ai-forever/ruBert-base for both the model and the vocab.
!python run.py --dataset_path '/kaggle/input/dataset-sent-no-fake' --init_model 'ai-forever/ruBert-base' \
  --init_vocab 'ai-forever/ruBert-base' --mode 'train' \
  --multi_gpu True \
  --kl_loss False \
  --max_span_length 1 --train_batch_size 24 --max_seq_length 512 \
  --bert_feature_dim 768 --epochs 20

2024-04-30 10:43:11.014054: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-04-30 10:43:11.014114: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-04-30 10:43:11.015866: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
ATT_SPAN_block_num=1
Filter_Strategy=True
Only_token_head=False
RANDOM_SEED=2022
accumulation_steps=1
add_pos_enc=False
bert_feature_dim=768
block_num=1
dataset=
dataset_path=/kaggle/input/dataset-sent-no-fake
device=cuda
do_lower_case=True
drop_out=0.1
embedding_dim4width=200
epochs=20
init_model=ai-forever/ruBert-base
init_vocab=ai-forever/ruBert-base
kl_loss=F

In [30]:
# !python run.py --dataset_path '/kaggle/input/datset-full-no-fake' --init_model 'ai-forever/ruBert-base' \
#   --init_vocab 'ai-forever/ruBert-base' --mode 'train' \
#   --add_pos_enc True \
#   --multi_gpu True \
#   --max_span_length 2 --train_batch_size 4 --max_seq_length 1408 \
#   --bert_feature_dim 768 --epochs 20

2024-04-29 12:59:10.526698: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-04-29 12:59:10.526753: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-04-29 12:59:10.528416: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
ATT_SPAN_block_num=1
Filter_Strategy=True
Only_token_head=False
RANDOM_SEED=2022
accumulation_steps=1
add_pos_enc=True
bert_feature_dim=768
block_num=1
dataset=
dataset_path=/kaggle/input/datset-full-no-fake
device=cuda
do_lower_case=True
drop_out=0.1
embedding_dim4width=200
epochs=20
init_model=ai-forever/ruBert-base
init_vocab=ai-forever/ruBert-base
kl_loss=Tru

In [46]:
# # %cd /kaggle/working
# from IPython.display import FileLink
# FileLink('SBN_models/fake_start_base_full_model_4_0.6850574712643679.pt')

In [None]:
# sent (runs on the sentence-level dataset)
# 116 - gcn (correct version) ~ 53
# 117 - gcn - no spans starting with punctuation - random_shuffle ==> ~0.5

# full (runs on the full dataset)
# 118 - gcn - no spans starting with punctuation - random_shuffle - max_span_length=3 ==> (no result recorded)
# 125 - fixed gcn - no spans starting with punctuation - zeros adj matrix ==> 0.57 at 10