In [None]:
# Install third-party packages (wandb, transformers, tqdm, natasha) and the
# Russian spaCy model that gcn.py loads with spacy.load("ru_core_news_md").
# NOTE(review): versions are unpinned; `%pip` installs into the running kernel's
# environment, which `!pip` does not guarantee; `spacy download` assumes spaCy
# itself is already installed (it is not installed in this cell).
!pip install wandb
!pip install transformers
!pip install tqdm
!pip install natasha
!python3 -m spacy download ru_core_news_md

In [3]:
%%writefile span_tagging.py

# span tagging
def form_raw_table(d, version='3D'):
    """Build the upper-triangular span-tagging table (string labels) for one sentence.

    Args:
        d: dict with 'token' (list of words) and 'triplets', a sequence of
            (aspect_span, opinion_span, sentiment) where each span is
            [start, end] (inclusive) or a single index [k].
        version: '3D' -> labels 'A|N-O|N-<senti>', '2D' -> 'A|O|N-<senti>',
            '1D' -> one of 'A', 'O', <senti>, 'N' (aspect > opinion > sentiment).
            Any other value leaves every cell as ''.

    Returns:
        n x n list of lists (n = len(d['token'])). Only cells with row <= col are
        filled; the lower triangle stays ''.
    """
    n = len(d['token'])
    raw_table = [['' for _ in range(n)] for _ in range(n)]

    # Spans use [-1] as the end so both [start, end] and single-index [k] work.
    aspect_spans = {(t[0][0], t[0][-1]) for t in d['triplets']}
    opinion_spans = {(t[1][0], t[1][-1]) for t in d['triplets']}
    # The sentiment is written on the cell covering both terms of a triplet;
    # if two triplets cover the same cell, the later one wins.
    senti_cells = {(min(t[0][0], t[1][0]), max(t[0][-1], t[1][-1])): t[2] for t in d['triplets']}

    for i in range(n):
        for j in range(i, n):
            cell = (i, j)
            senti = senti_cells.get(cell, 'N')
            if version == '3D':
                raw_table[i][j] = ('A-' if cell in aspect_spans else 'N-') \
                    + ('O-' if cell in opinion_spans else 'N-') + senti
            elif version == '2D':
                prefix = 'A-' if cell in aspect_spans else ('O-' if cell in opinion_spans else 'N-')
                raw_table[i][j] = prefix + senti
            elif version == '1D':
                if cell in aspect_spans:
                    raw_table[i][j] = 'A'
                elif cell in opinion_spans:
                    raw_table[i][j] = 'O'
                else:
                    raw_table[i][j] = senti
    return raw_table

def form_label_id_map(version='3D'):
    """Return (label2id, id2label) for the chosen tagging scheme.

    '3D': {N,A} x {N,O} x {N,NEG,NEU,POS} (16 labels, id = 8*A + 4*O + senti),
    '2D': {N,O,A} x {N,NEG,NEU,POS} (12 labels), '1D': N, NEG, NEU, POS, O, A.
    Any other version yields two empty dicts.
    """
    polarities = ('N', 'NEG', 'NEU', 'POS')
    if version == '3D':
        labels = ['{}-{}-{}'.format(a, o, p) for a in ('N', 'A') for o in ('N', 'O') for p in polarities]
    elif version == '2D':
        labels = ['{}-{}'.format(ao, p) for ao in ('N', 'O', 'A') for p in polarities]
    elif version == '1D':
        labels = ['N', 'NEG', 'NEU', 'POS', 'O', 'A']
    else:
        labels = []

    id2label = dict(enumerate(labels))
    label2id = {label: idx for idx, label in id2label.items()}
    return label2id, id2label

def form_sentiment_id_map():
    """Return (label2id, id2label) for the sentiment labels N, NEG, NEU, POS (ids 0-3)."""
    id2label = dict(enumerate(('N', 'NEG', 'NEU', 'POS')))
    label2id = {label: idx for idx, label in id2label.items()}
    return label2id, id2label

def map_raw_table_to_id(raw_table, label2id):
    """Map a nested table of string labels to ids; unknown labels (e.g. '') become 0."""
    id_table = []
    for row in raw_table:
        id_table.append([label2id.get(cell, 0) for cell in row])
    return id_table

def map_id_to_raw_table(raw_table_id, id2label):
    """Inverse of map_raw_table_to_id: map a nested table of ids back to labels (KeyError on unknown ids)."""
    labels = []
    for row in raw_table_id:
        labels.append([id2label[tag] for tag in row])
    return labels


Writing span_tagging.py


In [4]:
%%writefile greedy_inference.py

import torch

# Algorithm 1: Greedy Inference
def loop_version_from_tag_table_to_triplets(tag_table, id2senti, version='3D'):
    """Algorithm 1 (greedy inference): decode a tag table into (aspect, opinion, sentiment) triplets.

    Args:
        tag_table: square nested list of integer label ids (anything torch.tensor accepts).
        id2senti: maps a sentiment id (1..3, as extracted below) to its label string.
        version: '1D' decodes the flat ids {N, NEG, NEU, POS, O, A} = 0..5; any other
            value decodes masks (id & 8 -> aspect, id & 4 -> opinion, id & 3 -> sentiment),
            matching the 2D and 3D layouts.

    Returns:
        dict with
        - 'aspects' / 'opinions': ``nonzero().squeeze().tolist()`` of the flag tables
          (a single hit comes back flat as [row, col]);
        - 'triplets': unique (aspect_span, opinion_span, sentiment) tuples sorted by
          aspect span then opinion span.
    """

    def indices_of(flags, offset):
        # A single hit squeezes to a 0-d tensor whose tolist() is a bare int.
        hits = (flags.nonzero().squeeze() + offset).tolist()
        return hits if type(hits) is list else [hits]

    table = torch.tensor(tag_table)

    # Algorithm lines 1-4: per-cell aspect / opinion flags and sentiment ids.
    if version == '1D':  # {N, NEG, NEU, POS, O, A}
        is_aspect = (table == 5) > 0
        is_opinion = (table == 4) > 0
        senti_ids = table * ((table > 0) * (table < 4))
    else:  # 2D: {N,O,A} - {N, NEG, NEU, POS}  #3D: {N,A} - {N,O} - {N, NEG, NEU, POS}
        is_aspect = (table & torch.tensor(8)) > 0
        is_opinion = (table & torch.tensor(4)) > 0
        senti_ids = table & torch.tensor(3)

    # Every cell carrying a sentiment is a candidate [row, col, senti, row + col],
    # visited in order of row + col, then row.
    cells = senti_ids.nonzero()
    polarity = senti_ids[cells[:, 0], cells[:, 1]].unsqueeze(dim=-1)
    span_sum = cells.sum(dim=-1, keepdim=True)
    candidates = torch.cat([cells, polarity, span_sum], dim=-1).tolist()
    candidates.sort(key=lambda cand: (cand[-1], cand[0]))

    triplets = []
    seen = set()

    def keep(aspect_span, opinion_span, senti):
        # Algorithm lines 13 / 22: add once, preserving discovery order.
        key = str((aspect_span, opinion_span, senti))
        if key not in seen:
            triplets.append((aspect_span, opinion_span, senti))
            seen.add(key)

    # Algorithm lines 5-24: look into every sentiment cell.
    for top, right, senti_id, _ in candidates:
        # CASE 1: aspect opens the span (row `top`), opinion closes it (column `right`).
        row_hits = indices_of(is_aspect[top, top:right + 1], top)       # line 7
        col_hits = indices_of(is_opinion[top:right + 1, right], top)    # line 8
        if row_hits and col_hits:                                        # line 9
            # Lines 10-11: avoid the term that covers the whole span when another exists.
            aspect_end = row_hits[-2] if (len(row_hits) > 1 and row_hits[-1] == right) else row_hits[-1]
            opinion_start = col_hits[1] if (len(col_hits) > 1 and col_hits[0] == top) else col_hits[0]
            keep([top, aspect_end], [opinion_start, right], id2senti[senti_id])  # lines 12-13

        # CASE 2: opinion opens the span, aspect closes it.
        row_hits = indices_of(is_opinion[top, top:right + 1], top)      # line 16
        col_hits = indices_of(is_aspect[top:right + 1, right], top)     # line 17
        if row_hits and col_hits:                                        # line 18
            opinion_end = row_hits[-2] if (len(row_hits) > 1 and row_hits[-1] == right) else row_hits[-1]
            aspect_start = col_hits[1] if (len(col_hits) > 1 and col_hits[0] == top) else col_hits[0]
            keep([aspect_start, right], [top, opinion_end], id2senti[senti_id])  # lines 21-22

    return {
        'aspects': is_aspect.nonzero().squeeze().tolist(),    # for ATE
        'opinions': is_opinion.nonzero().squeeze().tolist(),  # for OTE
        'triplets': sorted(triplets, key=lambda t: (t[0][0], t[0][-1], t[1][0], t[1][-1])),  # line 25
    }

def guarantee_list(l):
    """Return l unchanged if its type is exactly list, otherwise wrap it as [l]."""
    return l if type(l) is list else [l]

Writing greedy_inference.py


In [12]:
%%writefile ASTE_dataloader.py

import torch
import numpy as np
from collections import Counter
from torch.utils.data import Dataset

from vocab import *
from span_tagging import form_raw_table,map_raw_table_to_id
from tqdm import tqdm

from natasha import (
    Segmenter,
    
    NewsEmbedding,
    NewsMorphTagger,
    NewsSyntaxParser,
    
    Doc
)

# natasha pipeline components, created once at import time and used by
# make_adj_matrix (doc.segment / doc.tag_morph / doc.parse_syntax).
# The NewsEmbedding instance is shared by the morph tagger and the syntax parser.
segmenter = Segmenter()
emb = NewsEmbedding()
morph_tagger = NewsMorphTagger(emb)
syntax_parser = NewsSyntaxParser(emb)


def make_adj_matrix(bert_tokens, tokenizer, max_len, sep_token):
    """Build a dependency-based adjacency matrix over BERT positions.

    The word pieces in ``bert_tokens`` are decoded one by one, glued back into
    space-separated words ('##' pieces continue the previous word), parsed with
    the module-level natasha pipeline, and every piece of a token is linked to
    every piece of that token's syntactic head.

    Args:
        bert_tokens: sequence of token ids. Collection stops after the first id
            equal to ``sep_token``; the first and last collected ids are dropped
            before decoding (presumably [CLS] and [SEP]).
        tokenizer: object whose ``decode([id])`` returns one piece's text.
        max_len: side length of the returned matrix.
        sep_token: stop id. A one-element list/tuple (such as the result of
            ``tokenizer.convert_tokens_to_ids([tokenizer.sep_token])``) is
            unwrapped, since a list would never compare equal to an int id.

    Returns:
        Sparse COO float tensor of shape (max_len, max_len): the identity plus a
        1 at [piece, head_piece] for every dependency edge.

    Raises:
        ValueError: if a natasha token and a decoded word of equal length differ,
            so the two tokenizations cannot be aligned.
    """
    if isinstance(sep_token, (list, tuple)) and len(sep_token) == 1:
        sep_token = sep_token[0]

    pieces = []
    for token_id in bert_tokens:
        pieces.append(tokenizer.decode([token_id]))
        if token_id == sep_token:
            break

    # Rebuild the text; piece_word[k] is the 0-based word index of BERT position k + 1.
    text = ""
    piece_word = []
    word_index = -1
    for piece in pieces[1:-1]:
        if piece[:2] == '##':
            text += piece[2:]
        else:
            text += " " + piece
            word_index += 1
        piece_word.append(word_index)
    text = text.strip()

    # word_positions: 1-based word index -> BERT positions; 0 -> [0] (the first token).
    word_positions = {0: [0]}
    for position, word in enumerate(piece_word, start=1):
        word_positions.setdefault(word + 1, []).append(position)

    words = text.split(" ")

    doc = Doc(text)
    doc.segment(segmenter)
    doc.tag_morph(morph_tagger)
    doc.parse_syntax(syntax_parser)

    # sent_offsets[s] = number of natasha tokens in the first s sentences.
    sent_offsets = [0]
    for sent in doc.sents:
        sent_offsets.append(sent_offsets[-1] + len(sent.tokens))

    # token_words: 1-based natasha token index -> 1-based word indices
    # (0 -> [0] is reached by head ids ending in '_0').
    token_words = {0: [0]}
    t = 0
    w = 0
    while t < len(doc.tokens) and w < len(words):
        nat_text = doc.tokens[t].text
        our_text = words[w]
        token_words[t + 1] = [w + 1]
        if nat_text != our_text:
            if len(nat_text) < len(our_text):
                # natasha split this word: absorb tokens until they spell it.
                while nat_text != our_text:
                    t += 1
                    token_words[t + 1] = [w + 1]
                    nat_text += doc.tokens[t].text
            elif len(nat_text) > len(our_text):
                # natasha merged several words: absorb words until they match.
                while nat_text != our_text:
                    w += 1
                    token_words[t + 1].append(w + 1)
                    our_text += words[w]
            else:
                raise ValueError(
                    "cannot align natasha token {!r} with word {!r}".format(nat_text, our_text))
        t += 1
        w += 1

    def positions_of(global_token_index):
        # All BERT positions of the words covered by one natasha token.
        return [pos for word in token_words[global_token_index] for pos in word_positions[word]]

    adj_matrix = np.eye(max_len, max_len)
    for token in doc.tokens:
        sent_id, tok_id = [int(part) for part in token.id.split('_')]
        head_sent_id, head_id = [int(part) for part in token.head_id.split('_')]
        # NOTE(review): presumably a root has head_id '<sent>_0'; for sentences after
        # the first, that offset resolves to the previous sentence's last token rather
        # than position 0 — confirm this is intended.
        dependents = positions_of(sent_offsets[sent_id - 1] + tok_id)
        heads = positions_of(sent_offsets[head_sent_id - 1] + head_id)
        # NOTE(review): positions run up to the number of kept pieces; preprocess keeps
        # up to max_len of them, so position max_len would overflow this matrix.
        for dep in dependents:
            for head in heads:
                adj_matrix[dep][head] = 1

    return torch.FloatTensor(adj_matrix).to_sparse()


class ASTE_End2End_Dataset(Dataset):
    """ASTE sentences encoded for the end-to-end span-tagging model.

    Each item is a dict with keys 'adj_matrix', 'token', 'token_length',
    'bert_token', 'bert_length', 'bert_word_mapback' and 'golden_label'
    (see ``preprocess``).
    """

    def __init__(self, file_name, vocab = None, version = '3D', tokenizer = None, max_len = 128, lower=True, is_clean = True):
        """
        Args:
            file_name: path to a '####'-separated triplet file, or an already parsed
                list of dicts with 'token' (and optionally 'triplets').
            vocab: dict with 'token_vocab' (a Vocab) and 'label_vocab' -> {'label2id': ...};
                required by preprocess despite the None default.
            version: tagging scheme passed to form_raw_table ('1D', '2D' or '3D').
            tokenizer: BERT-style tokenizer (tokenize, convert_tokens_to_ids,
                cls_token, sep_token, decode).
            max_len: maximum number of word pieces kept per sentence.
            lower: lower-case words before tokenization and vocab lookup.
            is_clean: drop duplicate triplets when reading from a file.
        """
        super().__init__()

        self.max_len = max_len
        self.lower = lower
        self.version = version

        if type(file_name) is str:
            with open(file_name,'r',encoding='utf-8') as f:
                lines = f.readlines()
                self.raw_data = [line2dict(l, is_clean = is_clean) for l in lines]
        else:
            self.raw_data = file_name

        self.tokenizer = tokenizer
        self.data = self.preprocess(self.raw_data, vocab=vocab, version=version)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

    def text2bert_id(self, token):
        """Tokenize each word into pieces.

        Returns:
            (piece ids, word index of every piece, number of pieces per word).
        """
        re_token = []
        word_mapback = []
        word_split_len = []
        for idx, word in enumerate(token):
            temp = self.tokenizer.tokenize(word)
            re_token.extend(temp)
            word_mapback.extend([idx] * len(temp))
            word_split_len.append(len(temp))
        re_id = self.tokenizer.convert_tokens_to_ids(re_token)
        return re_id, word_mapback, word_split_len

    def preprocess(self, data, vocab, version):
        """Encode raw sentence dicts into model-ready samples.

        Raises:
            AssertionError: if some word has no word piece left after tokenization
                and truncation to ``max_len`` (labels would no longer line up).
        """
        token_vocab = vocab['token_vocab']
        label2id = vocab['label_vocab']['label2id']
        processed = []
        max_len = self.max_len
        CLS_id = self.tokenizer.convert_tokens_to_ids([self.tokenizer.cls_token])
        SEP_id = self.tokenizer.convert_tokens_to_ids([self.tokenizer.sep_token])

        for d in tqdm(data, 'Loading data...'):
            golden_label = map_raw_table_to_id(form_raw_table(d, version=version),label2id) if 'triplets' in d else None
            tok = d['token']
            if self.lower:
                tok = [t.lower() for t in tok]

            text_raw_bert_indices, word_mapback, _ = self.text2bert_id(tok)
            text_raw_bert_indices = text_raw_bert_indices[:max_len]
            word_mapback = word_mapback[:max_len]

            # Number of words that still own at least one piece after truncation.
            length = word_mapback[-1] + 1 if word_mapback else 0
            if length != len(tok):
                raise AssertionError(
                    "{} words but only {} survive tokenization/truncation to max_len={}: {} "
                    "(word_mapback={})".format(len(tok), length, max_len, tok, word_mapback))
            bert_length = len(word_mapback)

            bert_token = CLS_id + text_raw_bert_indices + SEP_id

            adj_matrix = make_adj_matrix(bert_token, self.tokenizer, self.max_len, SEP_id)
            tok = [token_vocab.stoi.get(t, token_vocab.unk_index) for t in tok]

            processed.append({
                'adj_matrix': adj_matrix,
                'token': tok,
                'token_length': length,
                'bert_token': bert_token,
                'bert_length': bert_length,
                'bert_word_mapback': word_mapback,
                'golden_label': golden_label
            })
        return processed
    
def ASTE_collate_fn(batch):
    """Collate preprocessed ASTE samples into a batch of tensors.

    Args:
        batch: list of sample dicts as produced by ASTE_End2End_Dataset.preprocess.

    Returns:
        dict with 'adj_matrix' (per-sample matrices concatenated along a new dim 0),
        zero-padded LongTensors 'token', 'bert_token', 'bert_word_mapback',
        1-D length tensors 'token_length', 'bert_length', and an int64
        'golden_label' of shape (batch, longest, longest) that is zero outside each
        sample's table (all zeros when the first sample has no gold labels).
    """
    def gather(key):
        return [sample[key] for sample in batch]

    token = get_long_tensor(gather('token'))
    adj_matrix = torch.cat([adj.unsqueeze(0) for adj in gather('adj_matrix')], dim=0)
    token_length = torch.tensor(gather('token_length'))
    bert_token = get_long_tensor(gather('bert_token'))
    bert_length = torch.tensor(gather('bert_length'))
    bert_word_mapback = get_long_tensor(gather('bert_word_mapback'))

    side = token_length.max()
    golden_label = np.zeros((len(batch), side, side), dtype=np.int64)
    if batch[0]['golden_label'] is not None:
        for row, sample in enumerate(batch):
            n = token_length[row]
            golden_label[row, :n, :n] = sample['golden_label']

    return {
        'adj_matrix': adj_matrix,
        'token' : token,
        'token_length' : token_length,
        'bert_token' : bert_token,
        'bert_length' : bert_length,
        'bert_word_mapback' : bert_word_mapback,
        'golden_label' : torch.from_numpy(golden_label)
    }

def get_long_tensor(tokens_list, max_len=None):
    """Convert a list of token-id lists into a zero-padded (batch, width) LongTensor.

    ``width`` is ``max_len`` when given (longer sequences are truncated), otherwise
    the length of the longest sequence.
    """
    width = max_len if max_len is not None else max(len(seq) for seq in tokens_list)
    padded = torch.LongTensor(len(tokens_list), width).fill_(0)
    for row, seq in enumerate(tokens_list):
        clipped = torch.LongTensor(seq)[:width]
        padded[row, :clipped.shape[0]] = clipped
    return padded


############################################################################
# data preprocess
def clean_data(l):
    """Drop duplicate triplets from one '<sentence>####<triplet list>' data line.

    Duplicates are detected by their str() form and the first occurrence is kept,
    so the output order is deterministic (iterating a set of strings would depend
    on hash randomisation, i.e. PYTHONHASHSEED).

    Returns:
        The line re-serialised as '<sentence>####<list of unique triplets>\\n'.
    """
    token, triplets = l.strip().split('####')
    # NOTE(review): eval() executes arbitrary code from the data file;
    # ast.literal_eval would suffice for these literal triplet lists.
    unique = list(dict.fromkeys(str(t) for t in eval(triplets)))
    return token + '####' + str([eval(t) for t in unique]) + '\n'

def line2dict(l, is_clean=False):
    """Parse one '<sentence>####<triplet list>' line.

    Each triplet is normalised to ([a_start, a_end], [o_start, o_end], sentiment)
    using the first and last index of each span, and the triplets are sorted by
    (aspect start, opinion end). With ``is_clean`` duplicates are dropped first.

    Returns:
        dict(token=<sentence split on single spaces>, triplets=<normalised triplets>).
    """
    if is_clean:
        l = clean_data(l)
    sentence, raw_triplets = l.strip().split('####')
    # NOTE(review): eval() executes arbitrary code from the data file;
    # ast.literal_eval would suffice for these literal triplet lists.
    spans = [([t[0][0], t[0][-1]], [t[1][0], t[1][-1]], t[2]) for t in eval(raw_triplets)]
    spans.sort(key=lambda trip: (trip[0][0], trip[1][-1]))
    return {'token': sentence.split(' '), 'triplets': spans}


#############################################################################
# vocab
def build_vocab(dataset):
    """Collect every whitespace-separated sentence word from the train/dev/test files.

    Reads '<dataset>/{train,dev,test}_triplets.txt' (UTF-8) and returns the words
    of the part before '####' on each line, in file order.
    """
    split_files = ('train_triplets.txt', 'dev_triplets.txt', 'test_triplets.txt')
    words = []
    for split_file in split_files:
        with open('/'.join([dataset, split_file]), 'r', encoding='utf-8') as handle:
            for raw_line in handle.readlines():
                sentence = raw_line.strip().split('####')[0]
                words += sentence.split()
    return words

def load_vocab(dataset_dir,lower=True):
    """Build {'token_vocab': Vocab} from the words in the dataset's train/dev/test files.

    Words are lower-cased first when ``lower`` is true; '<pad>' and '<unk>' are
    the special tokens.
    """
    words = build_vocab(dataset_dir)
    frequencies = Counter(w.lower() for w in words) if lower else Counter(words)
    return {'token_vocab': Vocab(frequencies, specials=["<pad>", "<unk>"])}

Overwriting ASTE_dataloader.py


In [5]:
%%writefile vocab.py


import pickle

class Vocab(object):
    """Bidirectional token/index vocabulary built from a frequency counter.

    ``itos`` lists the special tokens first (in the given order), then the
    remaining tokens by descending frequency with ties broken alphabetically;
    ``stoi`` is its inverse.
    """
    def __init__(self, counter, specials=["<pad>", "<unk>"]):
        """
        Args:
            counter: mapping token -> frequency (typically a collections.Counter);
                it is copied, not modified.
            specials: tokens placed at the front of the vocabulary.
        """
        counter = counter.copy()
        self.itos = list(specials)
        for tok in specials:
            # Counter.__delitem__ ignores missing keys (a plain dict would raise KeyError).
            del counter[tok]

        # sort by frequency, then alphabetically (stable sort keeps the alphabetical order on ties)
        words_and_frequencies = sorted(counter.items(), key=lambda tup: tup[0])
        words_and_frequencies.sort(key=lambda tup: tup[1], reverse=True)

        for word, _ in words_and_frequencies:
            self.itos.append(word)

        # stoi is simply a reverse dict for itos
        self.stoi = {tok: i for i, tok in enumerate(self.itos)}

        # Derive the padding / unknown indices from `specials` so a custom order stays
        # correct; fall back to 0 / 1 when the token is not among the specials.
        self.pad_index = self.itos.index("<pad>") if "<pad>" in specials else 0
        self.unk_index = self.itos.index("<unk>") if "<unk>" in specials else 1

    def __eq__(self, other):
        """Vocabularies are equal when both mappings are equal; other types defer to Python."""
        if not isinstance(other, Vocab):
            return NotImplemented
        return self.stoi == other.stoi and self.itos == other.itos

    def __len__(self):
        return len(self.itos)

    def extend(self, v):
        """Append tokens of vocabulary ``v`` not yet present (in v's order); returns self."""
        words = v.itos
        for w in words:
            if w not in self.stoi:
                self.itos.append(w)
                self.stoi[w] = len(self.itos) - 1
        return self

    @staticmethod
    def load_vocab(vocab_path: str):
        """Unpickle a vocabulary from ``vocab_path``.

        NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
        """
        with open(vocab_path, "rb") as f:
            print('Loading vocab from:', vocab_path)
            return pickle.load(f)

    def save_vocab(self, vocab_path):
        """Pickle this vocabulary to ``vocab_path``."""
        with open(vocab_path, "wb") as f:
            print('Saving vocab to:', vocab_path)
            pickle.dump(self, f)

Writing vocab.py


In [6]:
%%writefile V_calc_metrics.py


class Metrics_V:
    """Micro-averaged precision, recall and F1 over per-sample item sets.

    ``calculate(true, predicted)`` takes two dicts ``{sample_id: iterable of
    hashable items}``, deduplicates each sample's items and stores the summed
    true positives, false positives and false negatives in ``tp``, ``fp`` and
    ``fn``; the ``precision``, ``recall`` and ``f1`` properties read them.
    """

    def __init__(self):
        # Start at zero so the metric properties are defined (and 0) before calculate().
        self.tp = 0
        self.fp = 0
        self.fn = 0

    def _drop_duplicates(self, samples):
        """Return {sample_id: set(items)}."""
        return {sample_id: set(aspects) for sample_id, aspects in samples.items()}

    def calculate(self, true, predicted):
        """Recompute tp / fp / fn from scratch for ``true`` versus ``predicted``.

        Raises:
            AssertionError: if the two dicts have different numbers of samples.
            ValueError: if a sample id is present in only one of them.
        """
        true = self._drop_duplicates(true)
        predicted = self._drop_duplicates(predicted)

        assert len(true) == len(predicted)

        self.tp = 0
        self.fp = 0
        self.fn = 0
        for sample_id in (set(true) | set(predicted)):
            if sample_id not in true:
                raise ValueError(f"Sample id {sample_id} is not found in true aspects")
            if sample_id not in predicted:
                raise ValueError(f"Sample id {sample_id} is not found in predicted aspects")

            true_aspects = true[sample_id]
            pred_aspects = predicted[sample_id]

            current_tp = sum(pred_aspect in true_aspects for pred_aspect in pred_aspects)
            current_fp = len(pred_aspects) - current_tp
            current_fn = sum(true_aspect not in pred_aspects for true_aspect in true_aspects)

            self.tp += current_tp
            self.fp += current_fp
            self.fn += current_fn

    @property
    def precision(self) -> float:
        """tp / (tp + fp), or 0 when nothing was predicted."""
        if self.tp + self.fp == 0:
            return 0
        return self.tp / (self.tp + self.fp)

    @property
    def recall(self) -> float:
        """tp / (tp + fn), or 0 when there is no gold item."""
        if self.tp + self.fn == 0:
            return 0
        return self.tp / (self.tp + self.fn)

    @property
    def f1(self) -> float:
        """Harmonic mean of precision and recall, or 0 when both are 0."""
        precision, recall = self.precision, self.recall
        if precision + recall == 0:
            return 0
        return 2 * precision * recall / (precision + recall)


Writing V_calc_metrics.py


In [7]:
%%writefile evaluate.py

import torch
import json
from collections import Counter
from greedy_inference import loop_version_from_tag_table_to_triplets

from V_calc_metrics import Metrics_V

def evaluate_model(model, test_dataset, test_dataloader, id2senti, device='cuda', version = '3D', weight = None, saved_file=None):
    """Run the model over a test set, decode triplets greedily and score them.

    Predictions are paired with ``test_dataset.raw_data`` by position, so this
    assumes (without checking) that ``test_dataloader`` iterates the dataset in order.

    Args:
        model: called as ``model(inputs, weight)``; must return a dict with 'loss'
            and 'logits' and expose a ``bert`` submodule.
        test_dataset: object whose ``raw_data`` holds dicts with 'token' and 'triplets'.
        test_dataloader: yields dicts of tensors.
        id2senti: sentiment id -> label, forwarded to the greedy decoder.
        device: device each batch tensor is moved to.
        version: tagging scheme ('1D', '2D', '3D') used by the decoder.
        weight: forwarded to the model unchanged.
        saved_file: when given, a JSON dump of tokens, predictions and golds is written there.

    Returns:
        (mean loss over batches — 0.0 when the dataloader yields no batch,
         the tuple returned by ``evaluate_predictions``).

    Side effects: prints micro P/R/F1 over exact triplet matches and leaves the
    model in train mode.
    """
    def triplet_key(triplet):
        # '|'-joined str() of each part, e.g. "[0, 1]|[3, 3]|POS".
        return "|".join(str(part) for part in triplet)

    model.eval()
    model.bert.eval()
    total_loss = 0.0
    total_step = 0

    saved_token = [sample['token'] for sample in test_dataset.raw_data]
    saved_golds = [sample['triplets'] for sample in test_dataset.raw_data]

    saved_preds = []
    saved_aspects = []
    saved_opinions = []

    with torch.no_grad():
        for batch in test_dataloader:
            inputs = {k: v.to(device) for k, v in batch.items()}

            outputs = model(inputs, weight)

            total_loss += outputs['loss'].item()
            total_step += 1

            batch_raw_table_id = torch.argmax(outputs['logits'], dim=-1)
            for tag_table in batch_raw_table_id:
                pred_triplets = loop_version_from_tag_table_to_triplets(tag_table=tag_table.tolist(),
                                                                        id2senti=id2senti,
                                                                        version=version)
                saved_preds.append(pred_triplets['triplets'])
                saved_aspects.append(pred_triplets['aspects'])
                saved_opinions.append(pred_triplets['opinions'])

    # Micro P/R/F1 over exact triplet matches; samples are paired by position.
    true_dict = {}
    pred_dict = {}
    for num, (i_pred, i_gold) in enumerate(zip(saved_preds, saved_golds)):
        pred_dict[num] = {triplet_key(t) for t in i_pred}
        true_dict[num] = {triplet_key(t) for t in i_gold}

    my_metric = Metrics_V()
    my_metric.calculate(true_dict, pred_dict)
    print("My metrics: precision: {:.6f}\trecall: {:.6f}\tf1: {:.6f}".format(
        my_metric.precision, my_metric.recall, my_metric.f1))

    if saved_file is not None:
        with open(saved_file,'w') as f:
            combined = [
                dict(token=token, pred=pred, gold=gold, pred_aspect = pred_aspect, pred_opinion=pred_opinion) \
                    for token,pred,gold,pred_aspect,pred_opinion in zip(saved_token,saved_preds, saved_golds, saved_aspects, saved_opinions)
            ]
            json.dump(combined, f)

    # Guard against an empty dataloader instead of dividing by zero.
    loss = total_loss / total_step if total_step else 0.0
    evaluate_dict = evaluate_predictions(preds = saved_preds, goldens = saved_golds, preds_aspect = saved_aspects, preds_opinion = saved_opinions)
    model.train()
    return loss, evaluate_dict


def evaluate_predictions(preds = None, goldens = None, preds_aspect = None, preds_opinion = None):
    """Aggregate ASTE and term-extraction scores over a corpus.

    Args are parallel per-sentence lists: predicted triplets, gold triplets,
    predicted aspect spans and predicted opinion spans (a single span may come
    flat as [row, col]).

    Returns:
        6-tuple (all, single-word, multi-word, multi-word aspect, multi-word
        opinion, term): the first five from ``output_score_dict``, the last from
        ``output_score_dict_term``.
    """
    counts = Counter()
    one_counts, multi_counts = Counter(), Counter()
    aspect_counts, opinion_counts = Counter(), Counter()
    ate_counts, ote_counts = Counter(), Counter()

    for pred, gold, pred_aspect, pred_opinion in zip(preds, goldens, preds_aspect, preds_opinion):
        counts = evaluate_sample(pred, gold, counts)

        p_one, p_multi, p_multi_a, p_multi_o = get_spereate_triplets(pred)
        g_one, g_multi, g_multi_a, g_multi_o = get_spereate_triplets(gold)

        one_counts = evaluate_sample(p_one, g_one, one_counts)
        multi_counts = evaluate_sample(p_multi, g_multi, multi_counts)
        aspect_counts = evaluate_sample(p_multi_a, g_multi_a, aspect_counts)
        opinion_counts = evaluate_sample(p_multi_o, g_multi_o, opinion_counts)

        gold_ate = [[span[0], span[1]] for span in {tuple(t[0]) for t in gold}]
        gold_ote = [[span[0], span[1]] for span in {tuple(t[1]) for t in gold}]

        # A single predicted span arrives flat ([row, col]); wrap it into a list of spans.
        if len(pred_aspect) > 0 and type(pred_aspect[0]) is int:
            pred_aspect = [pred_aspect]
        if len(pred_opinion) > 0 and type(pred_opinion[0]) is int:
            pred_opinion = [pred_opinion]

        ate_counts = evaluate_term(pred=pred_aspect, gold=gold_ate, counts=ate_counts)
        ote_counts = evaluate_term(pred=pred_opinion, gold=gold_ote, counts=ote_counts)

    return (
        output_score_dict(counts),
        output_score_dict(one_counts),
        output_score_dict(multi_counts),
        output_score_dict(aspect_counts),
        output_score_dict(opinion_counts),
        output_score_dict_term(ate_counts, ote_counts),
    )

###############################################################################################
# ASTE (AOPE)
def evaluate_sample(pred, gold, counts = None):
    """Accumulate ASTE counts for one sentence into ``counts`` (a new Counter when None).

    ``pred`` and ``gold`` are sequences of (aspect_span, opinion_span, sentiment).
    Aspect and opinion terms are compared as unique span tuples
    ('<aspect|opinion>_golden/_predict/_matched'); pairs and triplets are compared
    per predicted triplet, so duplicates count each time ('triplet_golden',
    'triplet_predict', 'pair_matched', 'triplet_matched'). Returns ``counts``.
    """
    if counts is None:
        counts = Counter()

    for side, name in ((0, 'aspect'), (1, 'opinion')):
        golden_terms = set(tuple(t[side]) for t in gold)
        predicted_terms = set(tuple(t[side]) for t in pred)
        counts[name + '_golden'] += len(golden_terms)
        counts[name + '_predict'] += len(predicted_terms)
        for term in predicted_terms:
            if term in golden_terms:
                counts[name + '_matched'] += 1

    golden_triplets = [(tuple(t[0]), tuple(t[1]), t[2]) for t in gold]
    predicted_triplets = [(tuple(t[0]), tuple(t[1]), t[2]) for t in pred]
    counts['triplet_golden'] += len(golden_triplets)
    counts['triplet_predict'] += len(predicted_triplets)

    golden_pairs = [t[:2] for t in golden_triplets]
    for triplet in predicted_triplets:
        if triplet[:2] in golden_pairs:
            counts['pair_matched'] += 1
        if triplet in golden_triplets:
            counts['triplet_matched'] += 1

    return counts

def output_score_dict(counts):
    """Turn ASTE counts into P/R/F1 dicts for aspect, opinion, pair and triplet."""
    def score(predict_key, golden_key, matched_key):
        return compute_f1(counts[predict_key], counts[golden_key], counts[matched_key])

    return dict(
        aspect=score('aspect_predict', 'aspect_golden', 'aspect_matched'),
        opinion=score('opinion_predict', 'opinion_golden', 'opinion_matched'),
        pair=score('triplet_predict', 'triplet_golden', 'pair_matched'),
        triplet=score('triplet_predict', 'triplet_golden', 'triplet_matched'),
    )

###############################################################################################
# ATE & OTE
def evaluate_term(pred, gold, counts=None):
    """Accumulate term-extraction counts ('golden', 'predict', 'matched') for one sentence.

    Every predicted term found in ``gold`` counts as a match (duplicates count each time).
    Returns ``counts`` (a new Counter when None).
    """
    if counts is None:
        counts = Counter()

    counts['golden'] += len(gold)
    counts['predict'] += len(pred)

    hits = sum(1 for term in pred if term in gold)
    if hits:
        counts['matched'] += hits
    return counts


def output_score_dict_term(aspect_counts, opinion_counts):
    """P/R/F1 dicts for aspect term extraction ('ate') and opinion term extraction ('ote')."""
    def score(term_counts):
        return compute_f1(term_counts['predict'], term_counts['golden'], term_counts['matched'])

    return dict(ate=score(aspect_counts), ote=score(opinion_counts))

###############################################################################################
# for additional experiments
def get_spereate_triplets(triplet):
    """Split triplets by span length (for the additional experiments).

    A span is multi-word when its first and last index differ. Returns
    (both spans single-word, at least one multi-word span, multi-word aspect,
    multi-word opinion).
    """
    def is_multi(span):
        return span[-1] != span[0]

    one_triplet, new_triplet, a_triplet, o_triplet = [], [], [], []
    for t in triplet:
        aspect_multi = is_multi(t[0])
        opinion_multi = is_multi(t[1])
        if aspect_multi or opinion_multi:
            new_triplet.append(t)
        else:
            one_triplet.append(t)
        if aspect_multi:
            a_triplet.append(t)
        if opinion_multi:
            o_triplet.append(t)
    return one_triplet, new_triplet, a_triplet, o_triplet

def compute_f1(predict, golden, matched):
    """Precision, recall and F1 from raw counts; each is 0 when its denominator is not positive."""
    def ratio(numerator, denominator):
        return numerator / denominator if denominator > 0 else 0

    precision = ratio(matched, predict)
    recall = ratio(matched, golden)
    total = precision + recall
    f1 = 2 * precision * recall / total if total > 0 else 0
    return {'precision': precision, 'recall': recall, 'f1': f1}


##################################################################################################
# print
def print_dict(d, select_k = None):
    """Print a P / R / F table for the selected keys of a score dict (all keys when None).

    The 'aspect', 'opinion' and 'triplet' rows are marked with a leading '*'.
    """
    keys = list(d.keys()) if select_k is None else select_k
    lines = ['\t  \tP\t\tR\t\tF\n']
    for key in keys:
        marker = '*' if key in ['aspect', 'opinion', 'triplet'] else ''
        scores = d[key]
        lines.append('{:^8}\t{:.6f}\t{:.6f}\t{:.6f}\n'.format(marker + key.upper(),
                                                            scores['precision'],
                                                            scores['recall'],
                                                            scores['f1']))
    print(''.join(lines))
    
    
def print_evaluate_dict(evaluate_dict):
    """Print each score dict returned by ``evaluate_predictions`` under its section name.

    Sections 'one', 'multi', 'multi_aspect' and 'multi_opinion' show only the
    triplet row, 'all' shows the pair and triplet rows, and 'term' shows every key.
    """
    section_names = ['all', 'one', 'multi', 'multi_aspect', 'multi_opinion', 'term']
    rows_for = {
        'all': ['pair', 'triplet'],
        'one': ['triplet'],
        'multi': ['triplet'],
        'multi_aspect': ['triplet'],
        'multi_opinion': ['triplet'],
    }
    for position, scores in enumerate(evaluate_dict):
        section = section_names[position]
        print('\n[ ' + section, ']')
        print_dict(scores, select_k=rows_for.get(section))




Writing evaluate.py


In [8]:
%%writefile gcn.py

import torch.nn as nn
import torch.nn.functional as F
import torch
import spacy
import numpy as np

# NOTE(review): `nlp` is not referenced anywhere else in gcn.py as shown; loading the
# spaCy model here only adds import-time cost unless another module imports it.
nlp = spacy.load("ru_core_news_md")

class GCN(nn.Module):
    """Graph convolution over token representations.

    Each layer computes ``GELU((W(adj @ x) + W(x)) / (adj.sum(-1) + 1))`` with one
    ``nn.Linear`` per layer shared by the neighbour and self-loop terms; dropout is
    applied between layers but not after the last one.
    NOTE: adjacency built by make_adj_matrix already contains the identity, so a
    node's own features enter through both terms.
    """

    def __init__(self, emb_dim=768, num_layers=1, gcn_dropout=0.7):  # dropout could be raised here
        super(GCN, self).__init__()
        self.layers = num_layers
        self.emb_dim = emb_dim
        self.out_dim = emb_dim
        input_dim = self.emb_dim
        # gcn layer: one square linear map per layer
        self.W = nn.ModuleList([nn.Linear(input_dim, input_dim) for _ in range(self.layers)])
        self.gcn_drop = nn.Dropout(gcn_dropout)
        self.gelu = nn.GELU()

    def forward(self, adj, inputs, device):
        """
        Args:
            adj: (batch_size, L, L) adjacency, sparse or dense; cropped to the
                inputs' length when L is larger.
            inputs: (batch_size, len, emb_dim) token representations.
            device: device each layer's output is moved to before normalisation.

        Returns:
            (outputs of shape (batch_size, len, emb_dim), None) — the second slot is a
            placeholder for a mask that is not computed.
        """
        # Densify only non-strided (sparse) adjacency, as produced by the data loader;
        # dense input is used as-is rather than relying on to_dense() for strided tensors.
        if adj.layout != torch.strided:
            adj = adj.to_dense()
        if inputs.shape[1] < adj.shape[1]:
            adj = adj[:, :inputs.shape[1], :inputs.shape[1]]
        denom = adj.sum(2).unsqueeze(2) + 1                 # batch_size, len, 1

        for layer in range(self.layers):
            Ax = torch.bmm(adj, inputs)        # batch_size, len, emb_dim
            AxW = self.W[layer](Ax)            # batch_size, len, emb_dim
            AxW = AxW + self.W[layer](inputs)  # self loop
            AxW = AxW.to(device) / denom
            gAxW = self.gelu(AxW)              # batch_size, len, emb_dim
            if layer < self.layers - 1:
                inputs = self.gcn_drop(gAxW)
            else:
                inputs = gAxW
        return inputs, None # mask


Writing gcn.py


In [9]:
%%writefile model.py

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import BertModel
from transformers.models.bert.modeling_bert import BertEmbeddings
from gcn import GCN


class base_model(nn.Module):
    """Span-tagging model: BERT encoder + GCN refinement + span-table classifier.

    BERT sub-token states are refined by a GCN over ``inputs['adj_matrix']`` and
    added back residually, projected to ``hidden_dim``, averaged to word level via
    ``bert_word_mapback``, and every valid upper-triangular word pair (i, j) is
    classified from ``[h_i; h_j; sum(h_i..h_j)]``.

    Args:
        pretrained_model_path: name or path passed to ``BertModel.from_pretrained``.
        hidden_dim: width of the word-level projection ``self.dense``.
        dropout: dropout applied to the span-table features before the classifier.
        args: namespace providing at least ``device`` and ``add_pos_enc``.
        class_n: number of tag classes predicted per span cell.
        span_average: average (instead of sum) word features inside each span.
        gcn_num_layers: number of GCN layers.
        gcn_dropout: dropout used between GCN layers.
    """
    def __init__(self, pretrained_model_path, hidden_dim, dropout, args, class_n=16, 
                 span_average = False, gcn_num_layers=1, gcn_dropout=0.7):
        super().__init__()
        
        self.device = args.device
        
        # Encoder
        self.bert = BertModel.from_pretrained(pretrained_model_path)
        bert_config = self.bert.config
        if args.add_pos_enc:
            old_embeddings = self.bert.embeddings

            # Tile the pretrained position table three times (scaled by 1, 2 and 4) so
            # that sequences up to 3x the original maximum length get position vectors.
            pos_enc = old_embeddings.position_embeddings.weight.data
            new_pos_enc = torch.concat((pos_enc, pos_enc * 2, pos_enc * 4), axis=0)
            # new_pos_enc = torch.repeat_interleave(pos_enc, 3, dim=0)
            new_max_len = new_pos_enc.shape[0]
            print("Change pos_embeddings to {} len...".format(new_max_len))

            # Pretrained weights that a freshly built BertEmbeddings would reinitialise.
            # The embedding LayerNorm must be kept too, otherwise its learned
            # gamma/beta silently reset to 1/0.
            word_emb = old_embeddings.word_embeddings.weight.data
            token_type_emb = old_embeddings.token_type_embeddings.weight.data
            layer_norm_state = old_embeddings.LayerNorm.state_dict()

            # new config and embeddings structure
            bert_config.update({'max_position_embeddings': new_max_len})
            self.bert.embeddings = BertEmbeddings(bert_config)

            # return pretrained weights
            self.bert.embeddings.word_embeddings.weight.data = word_emb
            self.bert.embeddings.token_type_embeddings.weight.data = token_type_emb
            self.bert.embeddings.position_embeddings.weight.data = new_pos_enc
            self.bert.embeddings.LayerNorm.load_state_dict(layer_norm_state)

            print("Changed successful!")

        self.dense = nn.Linear(self.bert.pooler.dense.out_features, hidden_dim)
        self.span_average = span_average

        # Classifier over [h_i; h_j; span(h_i..h_j)]
        self.classifier = nn.Linear(hidden_dim * 3 , class_n)
        
        # dropout
        self.layer_drop = nn.Dropout(dropout)
        
        # GCN over BERT sub-token states; its width must equal the encoder hidden size
        self.gcn = GCN(emb_dim=bert_config.hidden_size,
                       num_layers=gcn_num_layers,
                       gcn_dropout=gcn_dropout)
        
        
    def forward(self, inputs, weight=None):
        """Compute span-table logits, and the loss when gold labels are provided.

        Args:
            inputs: dict of tensors with keys 'bert_token', 'bert_word_mapback',
                'token_length', 'bert_length', 'adj_matrix' and optionally
                'golden_label'. The attention mask is built as ``bert_token > 0``,
                so 0 is assumed to be the padding id.
            weight: optional per-class weights forwarded to the cross-entropy loss.

        Returns:
            dict with 'logits' of shape (batch, max_seq, max_seq, class_n), where
            max_seq is the word-level length after pooling, zeroed outside the
            valid upper triangle. Also contains 'loss' (a sum over valid cells) when
            ``inputs['golden_label']`` is present and not None.
        """
        
        #############################################################################################
        # word representation
        bert_token = inputs['bert_token']
        attention_mask = (bert_token > 0).int()
        bert_word_mapback = inputs['bert_word_mapback']
        token_length = inputs['token_length']
        bert_length = inputs['bert_length']
        
        adj_martixes = inputs['adj_matrix']
                
        bert_output = self.bert(bert_token, attention_mask = attention_mask)
        h_gcn, _ = self.gcn(adj_martixes.to(self.device), bert_output.last_hidden_state, self.device)
        bert_out = bert_output.last_hidden_state + h_gcn # \hat{h}
        
        # drop the leading special token and zero out padded sub-tokens
        bert_seq_indi = sequence_mask(bert_length).unsqueeze(dim=-1)
        bert_out = bert_out[:, 1:bert_length.max() + 1, :] * bert_seq_indi.float()
        word_mapback_one_hot = (F.one_hot(bert_word_mapback).float() * bert_seq_indi.float()).transpose(1, 2)
        
        # average the projected sub-token vectors belonging to each word
        bert_out = torch.bmm(word_mapback_one_hot.float(), self.dense(bert_out))
        wnt = word_mapback_one_hot.sum(dim=-1)
        wnt.masked_fill_(wnt == 0, 1)
        bert_out = bert_out / wnt.unsqueeze(dim=-1)  # h_i
        #############################################################################################
        # span representation
        
        max_seq = bert_out.shape[1]
        
        token_length_mask = sequence_mask(token_length)
        candidate_tag_mask = torch.triu(
            torch.ones(max_seq,max_seq,dtype=torch.int64,device=bert_out.device),
            diagonal=0).unsqueeze(dim=0) * (token_length_mask.unsqueeze(dim=1) * token_length_mask.unsqueeze(dim=-1))
        
        boundary_table_features = torch.cat(
            [bert_out.unsqueeze(dim=2).repeat(1,1,max_seq,1), bert_out.unsqueeze(dim=1).repeat(1,max_seq,1,1)],
            dim=-1) * candidate_tag_mask.unsqueeze(dim=-1)  # h_i ; h_j 
        span_table_features = form_raw_span_features(
            bert_out, candidate_tag_mask, is_average = self.span_average) # sum(h_i,h_{i+1},...,h_{j})
        
        # h_i ; h_j ; sum(h_i,h_{i+1},...,h_{j})
        table_features = torch.cat([boundary_table_features, span_table_features],dim=-1)
       
        #############################################################################################
        # classifier
        logits = self.classifier(self.layer_drop(table_features)) * candidate_tag_mask.unsqueeze(dim=-1)
        
        outputs = {
            'logits':logits
        }
        
        if 'golden_label' in inputs and inputs['golden_label'] is not None:
            loss = calcualte_loss(logits, inputs['golden_label'],candidate_tag_mask, weight = weight)
            outputs['loss'] = loss
        
        return outputs


def sequence_mask(lengths, max_len=None):
    """Boolean padding mask: entry [b, t] is True iff t < lengths[b].

    Args:
        lengths: 1-D tensor of sequence lengths.
        max_len: mask width; a falsy value (None/0) means ``lengths.max()``.

    Returns:
        Bool tensor of shape (len(lengths), max_len) on ``lengths.device``.
    """
    rows = lengths.numel()
    width = max_len or lengths.max()
    steps = torch.arange(0, width, device=lengths.device).type_as(lengths)
    grid = steps.unsqueeze(0).expand(rows, width)
    return grid < lengths.unsqueeze(1)

def form_raw_span_features(v, candidate_tag_mask, is_average = True):
    """Sum (or average) word vectors over every candidate span (i, j).

    Args:
        v: word features, unpacked below as (batch, seq, dim).
        candidate_tag_mask: (batch, seq, seq) 0/1 mask; entry (i, j) enables span i..j.
        is_average: when True, divide each span sum by its length ``j - i + 1``.

    Returns:
        (batch, seq, seq, dim) tensor whose [b, i, j] entry is
        ``sum_k mask[b, i, k] * mask[b, k, j] * v[b, k]``; for an upper-triangular
        mask this is the sum of v[b, i..j].
    """
    # masked[b, i, k, :] = v[b, k, :] * mask[b, i, k]
    masked = v.unsqueeze(dim=1) * candidate_tag_mask.unsqueeze(dim=-1)
    # (b, i, k, d) -> (b, d, i, k) so that matmul contracts over k against mask[b, k, j]
    contracted = torch.matmul(masked.permute(0, 3, 1, 2),
                              candidate_tag_mask.unsqueeze(dim=1).float())
    # (b, d, i, j) -> (b, i, j, d)
    span_features = contracted.permute(0, 2, 3, 1)

    if not is_average:
        return span_features

    _, seq_len, _ = v.shape
    positions = torch.arange(seq_len, device=v.device)
    # span_len[i, j] = j - i + 1; cells below the diagonal (<= 0) use 1 to avoid dividing by 0
    span_len = positions.unsqueeze(dim=0) - positions.unsqueeze(dim=-1) + 1
    span_len = torch.where(span_len > 0, span_len, 1)
    return span_features / span_len.unsqueeze(dim=0).unsqueeze(dim=-1)

def calcualte_loss(logits, golden_label, candidate_tag_mask, weight=None):
    """Masked, summed cross-entropy over the span table.

    Args:
        logits: (..., class_n) scores; must be viewable as (-1, class_n).
        golden_label: integer class ids with the same leading shape as ``logits``.
        candidate_tag_mask: 0/1 mask (same shape as ``golden_label``) selecting cells.
        weight: optional per-class weights for ``nn.CrossEntropyLoss``.

    Returns:
        Scalar tensor: sum of per-cell losses over cells where the mask is 1.
    """
    criterion = nn.CrossEntropyLoss(weight = weight, reduction='none')
    n_classes = logits.shape[-1]
    per_cell = criterion(logits.view(-1, n_classes), golden_label.view(-1))
    per_cell = per_cell.view(golden_label.size())
    return (per_cell * candidate_tag_mask).sum()
    

Writing model.py


In [10]:
%%writefile run.py

import os
import time
import torch
import random
import argparse
import numpy as np
from tqdm import tqdm
import wandb

from transformers import BertTokenizer
from torch.utils.data import DataLoader

from ASTE_dataloader import ASTE_End2End_Dataset, ASTE_collate_fn,load_vocab
from span_tagging import form_label_id_map, form_sentiment_id_map
from evaluate import evaluate_model,print_evaluate_dict
from gcn import GCN


# Authenticate with Weights & Biases without putting a secret in source code: the key is
# read from the WANDB_API_KEY environment variable. When that variable is unset, None is
# passed and wandb uses its own credential lookup (e.g. an existing ~/.netrc entry).
wandb.login(key=os.environ.get('WANDB_API_KEY'))


def totally_parameters(model):
    """Return the total count of scalar elements across all of ``model``'s parameters."""
    total = 0
    for param in model.parameters():
        total += param.nelement()
    return total

def ensure_dir(d, verbose=True):
    """Create directory ``d`` (including parents) unless the path already exists.

    Args:
        d: directory path.
        verbose: print a notice when the directory is about to be created.
    """
    if os.path.exists(d):
        return
    if verbose:
        print("Directory {} do not exist; creating...".format(d))
    os.makedirs(d)

def form_weight_n(n):
    """Return per-class cross-entropy weights for a label set of size ``n``.

    * ``n > 6``: weight 1 for labels whose index is a multiple of 4 and weight 2
      otherwise. With the label order built by ``form_label_id_map`` for the 2D/3D
      schemes, this up-weights every non-'N' sentiment slot (assumed order — confirm).
    * ``n == 6``: fixed weights for the 1D scheme (N, NEG, NEU, POS, O, A), i.e.
      ``[1, 2, 2, 2, 3, 3]``.

    Raises:
        ValueError: for ``n < 6``. No weight scheme exists for that size, and a
            6-element tensor would not match the number of classes.
    """
    if n > 6:
        index_range = torch.arange(n)
        return torch.ones(n) + ((index_range & 3) > 0)
    if n == 6:
        return torch.tensor([1.0, 2.0, 2.0, 2.0, 3.0, 3.0])
    raise ValueError('form_weight_n: no class weighting defined for {} classes'.format(n))

def train_and_evaluate(model_func, args, save_specific=False):
    """Train a model on the train split and evaluate it on the dev and test splits.

    Args:
        model_func: model class/factory. It is called with the keyword arguments below
            (pretrained_model_path, hidden_dim, dropout, args, class_n, span_average,
            gcn_num_layers, gcn_dropout) and the result is moved to ``args.device``.
        args: parsed CLI namespace (see get_parameters); ``args.class_n`` is set here.
        save_specific: only referenced by the commented-out checkpoint code below,
            so it currently has no effect.

    Returns:
        The ``test_results`` returned by ``evaluate_model`` for the final test pass,
        i.e. for the model as it is after the last epoch.
    """
    print('=========================================================================================================')
    
    # seed before the model and DataLoaders are built so their randomness is reproducible
    set_random_seed(args.seed)
    
    tokenizer = BertTokenizer.from_pretrained(args.pretrained_model, do_lower_case=True)
    # paths are joined with '/', so an empty dataset_dir plus an absolute dataset gives '//...'
    dataset_dir = args.dataset_dir + '/' + args.dataset
    saved_dir = args.saved_dir + '/' + args.dataset
    ensure_dir(saved_dir)
     
    vocab = load_vocab(dataset_dir = dataset_dir)

    label2id, id2label = form_label_id_map(args.version)
    senti2id, id2senti = form_sentiment_id_map()
    
    # the label/sentiment maps are added to the vocab that is handed to every dataset below
    vocab['label_vocab'] = dict(label2id=label2id,id2label=id2label)
    vocab['senti_vocab'] = dict(senti2id=senti2id,id2senti=id2senti)

    class_n = len(label2id)
    args.class_n = class_n
    # per-class loss weights are only used when --with_weight is set
    weight = None
    if args.with_weight is True:
        weight = form_weight_n(class_n).to(args.device)
    print('> label2id:', label2id)
    print('> weight:', args.with_weight, weight)
    print(args)

    print('> Load model...')
    # NOTE(review): GCN depth and dropout are hard-coded here rather than taken from args
    base_model = model_func(pretrained_model_path = args.pretrained_model,
                            hidden_dim = args.hidden_dim,
                            dropout = args.dropout_rate,
                            args = args,
                            class_n = class_n,
                            span_average = args.span_average,
                            gcn_num_layers=3, 
                            gcn_dropout=0.5
                            ).to(args.device)
    
    print('> # parameters', totally_parameters(base_model))
    
    print('> Load dataset...')
    train_dataset = ASTE_End2End_Dataset(file_name = os.path.join(dataset_dir, 'train_triplets.txt'),
                                         version = args.version,
                                        vocab = vocab,
                                        tokenizer = tokenizer,
                                        max_len = args.max_len)
    valid_dataset = ASTE_End2End_Dataset(file_name = os.path.join(dataset_dir, 'dev_triplets.txt'),
                                         version = args.version,
                                        vocab = vocab,
                                        tokenizer = tokenizer,
                                        max_len = args.max_len)
    test_dataset = ASTE_End2End_Dataset(file_name = os.path.join(dataset_dir, 'test_triplets.txt'),
                                        version = args.version,
                                        vocab = vocab,
                                        tokenizer = tokenizer,
                                        max_len = args.max_len)
    
    train_dataloader = DataLoader(train_dataset, batch_size = args.batch_size, collate_fn = ASTE_collate_fn, shuffle = True)
    valid_dataloader = DataLoader(valid_dataset, batch_size = args.batch_size, collate_fn = ASTE_collate_fn, shuffle = False)
    test_dataloader = DataLoader(test_dataset, batch_size = args.batch_size, collate_fn = ASTE_collate_fn, shuffle = False)


    optimizer = get_bert_optimizer(base_model, args)

    # NOTE(review): triplet_max_f1 and best_model_save_path are only used by the
    # commented-out best-checkpoint logic inside the epoch loop.
    triplet_max_f1 = 0.0

    best_model_save_path = saved_dir +  '/' + args.dataset + '_' +  args.version + '_' + str(args.with_weight) +'_best.pkl'
    
    wandb.init(
        project='aste-STAGE',
        config=args
    )
    
    # NOTE(review): the forward pass is never wrapped in torch.autocast, so training runs in
    # full precision and the scaler only rescales the loss/gradients. Confirm whether mixed
    # precision was intended.
    scaler = torch.cuda.amp.GradScaler()
    
    print('> Training...')
    for epoch in range(1, args.num_epoch+1):
        train_loss = 0.
        total_step = 0
        
        epoch_begin = time.time()
        for batch in tqdm(train_dataloader, 'Epoch:{}'.format(epoch)):
            # train mode is re-applied every step; presumably evaluate_model switches the
            # model to eval mode (TODO confirm), so once per epoch would suffice
            base_model.train()
            base_model.bert.train()
            optimizer.zero_grad()
            
            # every collated batch value is expected to be a tensor (moved with .to)
            inputs = {k:v.to(args.device) for k,v in batch.items()}
            outputs = base_model(inputs, weight)
            
            loss = outputs['loss']
            
            total_step += 1
            train_loss += loss.item()
            
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            
#             loss.backward()
#             optimizer.step()
        
        # results[0]['triplet'] holds the precision/recall/f1 logged below
        valid_loss, valid_results = evaluate_model(base_model, valid_dataset, valid_dataloader, 
                                                   id2senti = id2senti, 
                                                   device = args.device, 
                                                   version = args.version, 
                                                   weight = weight)
        
        # NOTE(review): train_loss / total_step raises ZeroDivisionError if the train loader is empty
        wandb.log({
            'train_loss': train_loss / total_step,
            'val_loss': valid_loss,
            'val_precision': valid_results[0]['triplet']['precision'],
            'val_recall':valid_results[0]['triplet']['recall'],
            'val_f1': valid_results[0]['triplet']['f1'],
        })

        print('\ttrain_loss:{:.4f}\tvalid_loss:{:.4f} [{:.4f}s]'.format(train_loss / total_step, 
                                                                        valid_loss,
                                                                        time.time()-epoch_begin))
                
        print('\ttriplet_precision:{:.4f} \ttriplet_recall:{:.4f} \ttriplet_f1:{:.4f}'.format( 
                                                    valid_results[0]['triplet']['precision'], 
                                                    valid_results[0]['triplet']['recall'], 
                                                    valid_results[0]['triplet']['f1'],
                                                    ))
        # save model based on the best f1 scores
        # if valid_results[0]['triplet']['f1'] > triplet_max_f1:
        #     triplet_max_f1 = valid_results[0]['triplet']['f1']
            
        #     evaluate_model(base_model, test_dataset, test_dataloader, 
        #                     id2senti = id2senti, 
        #                     device = args.device, 
        #                     version = args.version, 
        #                     weight = weight)
        #     torch.save(base_model, best_model_save_path)
        # NOTE(review): the test split is evaluated after every epoch and the results are saved
        # under a hard-coded, relative file name. There is no model selection on dev
        # (checkpointing is commented out above), so the final test pass below scores the
        # last-epoch model.
        print("Test results...")
        _, test_results = evaluate_model(base_model, test_dataset, test_dataloader, 
                                             id2senti = id2senti, 
                                             device = args.device, 
                                             version = args.version, 
                                             weight = weight,
                                             saved_file= "new_gcn3_2304_run_sent_" + str(epoch) + ".json")
            
    
    # saved_best_model = torch.load(best_model_save_path)
    # if save_specific:
    #     torch.save(saved_best_model, best_model_save_path.replace('_best','_' + str(args.seed) +'_best'))
    
    saved_file = (saved_dir + '/' + args.saved_file) if args.saved_file is not None else None
    
    print('> Testing...')
    # model performance on the test set
    _, test_results = evaluate_model(base_model, test_dataset, test_dataloader, 
                                             id2senti = id2senti, 
                                             device = args.device, 
                                             version = args.version, 
                                             weight = weight,
                                             saved_file= saved_file)
    

    print('------------------------------')
    
    print('Dataset:{}, test_f1:{:.2f} | version:{} lr:{} bert_lr:{} seed:{} dropout:{}'.format(args.dataset,test_results[0]['triplet']['f1'],
                                                                                                 args.version, args.lr, args.bert_lr, 
                                                                                                 args.seed, args.dropout_rate))
    print_evaluate_dict(test_results)

    # NOTE(review): not reached if anything above raises, so the wandb run is not finished explicitly
    wandb.finish()

    return test_results




def get_bert_optimizer(model, args):
    """Build AdamW with separate learning rates for BERT and the task head.

    Parameters whose names contain 'bert.embeddings' or 'bert.encoder' use
    ``args.bert_lr``; all others use ``args.lr``. Within each part, biases and
    LayerNorm weights get no weight decay, the rest get ``args.l2``. Group order:
    bert/decay, bert/no-decay, head/decay, head/no-decay.
    """
    no_decay = ['bias', 'LayerNorm.weight']
    diff_part = ['bert.embeddings', 'bert.encoder']
    named_params = list(model.named_parameters())

    def _select(skip_decay, in_bert):
        # parameters matching the given (no-decay?, BERT part?) combination, in model order
        return [p for n, p in named_params
                if any(key in n for key in no_decay) == skip_decay
                and any(key in n for key in diff_part) == in_bert]

    optimizer_grouped_parameters = []
    for in_bert, group_lr in ((True, args.bert_lr), (False, args.lr)):
        optimizer_grouped_parameters.append(
            {"params": _select(False, in_bert), "weight_decay": args.l2, "lr": group_lr})
        optimizer_grouped_parameters.append(
            {"params": _select(True, in_bert), "weight_decay": 0.0, "lr": group_lr})

    return torch.optim.AdamW(optimizer_grouped_parameters, eps=args.adam_epsilon)

def set_random_seed(seed):
    """Seed Python, NumPy and PyTorch RNGs; when CUDA is present, also seed it and force deterministic cuDNN."""
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

def get_parameters():
    """Parse command-line options for training and evaluation.

    ``--add_pos_enc`` takes an explicit boolean value (True/False, 1/0, yes/no,
    case-insensitive), so ``--add_pos_enc False`` really disables it. Plain
    argparse would store any non-empty string, including "False", as a truthy value.
    ``--with_weight`` defaults to True; pass ``--no_with_weight`` to disable
    class-weighted loss.

    Returns:
        argparse.Namespace with the options defined below.
    """
    def _str2bool(value):
        # convert textual booleans; reject anything ambiguous instead of treating it as True
        text = str(value).strip().lower()
        if text in ('true', 't', 'yes', 'y', '1'):
            return True
        if text in ('false', 'f', 'no', 'n', '0'):
            return False
        raise argparse.ArgumentTypeError('expected a boolean value, got {!r}'.format(value))

    parser = argparse.ArgumentParser()
    parser.add_argument('--device', type=str, default='cuda')
    parser.add_argument('--dataset_dir', type=str,default='./data/ASTE-Data-V2-EMNLP2020')
    parser.add_argument('--saved_dir', type=str, default='saved_models')
    parser.add_argument('--saved_file', type=str, default=None)
    parser.add_argument('--pretrained_model', type=str, default='bert-base-uncased')
    parser.add_argument('--dataset', type=str, default='14lap')
    parser.add_argument('--add_pos_enc', type=_str2bool, default=False,
                        help='Extend BERT position embeddings to 3x length (True/False).')
    
    parser.add_argument('--version', type=str, default='3D', choices=['3D','2D','1D'])
    
    parser.add_argument('--seed', type=int, default=64)
    
    parser.add_argument('--hidden_dim', type=int, default=200)
    parser.add_argument('--num_epoch', type=int, default=100)
    parser.add_argument('--batch_size', type=int, default=16)
    parser.add_argument('--max_len', type=int, default=256)
    
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--bert_lr', type=float, default=2e-5)
    parser.add_argument('--l2', type=float, default=0.0)
    parser.add_argument('--dropout_rate', type=float, default=0.5)
    parser.add_argument('--adam_epsilon', default=1e-8, type=float, help="Epsilon for Adam optimizer.")
    
    # loss
    parser.add_argument('--with_weight', default=True, action='store_true')
    # --with_weight alone can never turn weighting off (it defaults to True); this can
    parser.add_argument('--no_with_weight', dest='with_weight', action='store_false',
                        help='Disable class weighting in the loss.')
    parser.add_argument('--span_average', default=False, action='store_true')
    
    args = parser.parse_args()
    
    return args


def run():
    """CLI entry point: parse arguments, then train and evaluate ``model.base_model``."""
    # imported inside the function, so model.py is only loaded when run() is called
    from model import base_model as model_class
    train_and_evaluate(model_class, get_parameters())


if __name__ == '__main__':
    run()

Writing run.py


In [16]:
# Train/evaluate on the full Kaggle dataset with the 1D tagging scheme.
# run.py joins paths with '/', so --dataset_dir '' plus this absolute --dataset reads data from
# '//kaggle/input/aste-dataset-full' and saves outputs under 'saved_models//kaggle/input/aste-dataset-full'.
# --max_len 1448 exceeds standard BERT's position table; it relies on --add_pos_enc True,
# which tiles the position embeddings to 3x their original length.
!python run.py --device 'cuda' \
    --dataset_dir '' --dataset '/kaggle/input/aste-dataset-full' \
    --pretrained_model 'ai-forever/ruBert-base' \
    --version '1D' \
    --max_len 1448 \
    --hidden_dim 32 \
    --num_epoch 20 \
    --batch_size 2 \
    --seed 42 \
    --add_pos_enc True \
    --saved_file 'new_gcn_2704_run_full_try2.json'

[34m[1mwandb[0m: Currently logged in as: [33mlmartinson[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
> label2id: {'N': 0, 'NEG': 1, 'NEU': 2, 'POS': 3, 'O': 4, 'A': 5}
> weight: True tensor([1., 2., 2., 2., 3., 3.], device='cuda:0')
Namespace(device='cuda', dataset_dir='', saved_dir='saved_models', saved_file='new_gcn_2704_run_full_try2.json', pretrained_model='ai-forever/ruBert-base', dataset='/kaggle/input/aste-dataset-full', add_pos_enc='True', version='1D', seed=42, hidden_dim=32, num_epoch=20, batch_size=2, max_len=1448, lr=0.001, bert_lr=2e-05, l2=0.0, dropout_rate=0.5, adam_epsilon=1e-08, with_weight=True, span_average=False, class_n=6)
> Load model...
  return self.fget.__get__(instance, owner)()
Change pos_embeddings to 1536 len...
Changed successful!
> # parameters 180890726
> Load dataset...
Loading data...:   0%|                                 | 0/2285 [00:00<?, ?it/s]2024-

In [18]:
# sent
# 25 precision: 0.519058	recall: 0.569149	f1: 0.542950
# 26 worse 
# 27 with lower_case precision: 0.453846	recall: 0.627660	f1: 0.526786
# 28 same as previous but 4 epochs precision: 0.563338	recall: 0.574468	f1: 0.568849
# 43 diff weights(A 3 O 3) + hd 64 precision: 0.500330	recall: 0.575988	f1: 0.535500
# 44 diff weights(A 3 O 3) + hd 128 precision
# 78 gcn, weights, hd 64 ... 0.52  
# 79 gcn eye, weights, hd 64 ... 0.52  
# 97 gcn updated adj -- so-so, 1 epoch is enough
# 98 Gelu + 3 gcn_layers gcn dropout 0.5

# full
# 30 precision: 0.616236	recall: 0.500750	f1: 0.552523
# 40 precision: 0.729138	recall: 0.399550	f1: 0.516223
# 41 diff weights(A 2 O 2) precision: 0.568182	recall: 0.505997	f1: 0.535289
# 42 diff weights(A 3 O 3) precision: 0.566641	recall: 0.557721	f1: 0.562146
# 84 0.592593	recall: 0.527736	f1: 0.558287
# 103 fixed gcn 1D precision: 0.545067	recall: 0.575712	f1: 0.559971


In [None]:
!ls