In [1]:
from collections import Counter
from typing import *
import random
import json
import logging
import datetime

In [2]:
from tqdm import tqdm
import numpy as np
from sklearn.model_selection import KFold

import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader, Subset
from torch.cuda.amp import autocast, GradScaler

import transformers
from transformers import AutoConfig, AutoModel, AutoTokenizer

In [3]:
transformers.logging.set_verbosity_error() # only report errors.

In [4]:
import sys
sys.path.append('..')

from trainer import BasicTrainer
from model.base_par import DepParser
from utils import arc_rel_loss, uas_las, to_cuda

## Config

In [5]:
class CFG:
    data_file = '../data/train_50.json'
    test_file = '../data/test.json'
    plm = 'hfl/chinese-electra-180g-base-discriminator'
    random_seeds = [40, 41, 42, 43, 44]
    shot = 40
    num_epochs = 30
    batch_size = 32
    plm_lr = 2e-5
    head_lr = 1e-4
    weight_decay = 0.01
    dropout = 0.2
    grad_clip = 1
    scheduler = 'linear'
    warmup_ratio = 0.1
    num_early_stop = 5
    max_length = 160
    num_labels = 35
    hidden_size = 400
    print_every_ratio = 0.5
    # eval_every = 100
    cuda = True
    fp16 = True
    eval_strategy = 'epoch'
    mode = 'training'

In [6]:
logger = logging.getLogger("logger")
logger.setLevel(logging.INFO)
print(logger)
print(type(logger))

if CFG.mode == 'training':
    fh = logging.FileHandler(filename=f"../results/few_shot/{CFG.shot}-shot/res.log",mode='w')
else:
    fh = logging.FileHandler(filename=f"../results/few_shot/res.log",mode='w')
logger.addHandler(fh)

time_now = datetime.datetime.now().isoformat()
print(time_now)
logger.info(f'=-=-=-=-=-=-=-=-={time_now}=-=-=-=-=-=-=-=-=-=')

=-=-=-=-=-=-=-=-=2022-11-20T14:31:27.380720=-=-=-=-=-=-=-=-=-=


<Logger logger (INFO)>
<class 'logging.Logger'>
2022-11-20T14:31:27.380720


## Seed and Device

In [7]:
def seed_everything(seed):
    np.random.seed(seed%(2**32-1))
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic =True
    torch.backends.cudnn.benchmark = False

# seed_everything()

In [8]:
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

print(f'Using device: {device}')

Using device: cuda


## Data

In [9]:
rel_dct = {
    'root': '根节点',
    'sasubj-obj': '同主同宾',
    'sasubj': '同主语',
    'dfsubj': '不同主语',
    'subj': '主语',
    'subj-in': '内部主语',
    'obj': '宾语',
    'pred': '谓语',
    'att': '定语',
    'adv': '状语',
    'cmp': '补语',
    'coo': '并列',
    'pobj': '介宾',
    'iobj': '间宾',
    'de': '的',
    'adjct': '附加',
    'app': '称呼',
    'exp': '解释',
    'punc': '标点',
    'frag': '片段',
    'repet': '重复',
    # rst
    'attr': '归属',
    'bckg': '背景',
    'cause': '因果',
    'comp': '比较',
    'cond': '状况',
    'cont': '对比',
    'elbr': '阐述',
    'enbm': '目的',
    'eval': '评价',
    'expl': '解释-例证',
    'joint': '联合',
    'manner': '方式',
    'rstm': '重申',
    'temp': '时序',
    # 'tp-chg': '主题变更',
    # 'prob-sol': '问题-解决',
    # 'qst-ans': '疑问-回答',
    # 'stm-rsp': '陈述-回应',
    # 'req-proc': '需求-处理',
}

In [10]:
rel2id = {key:idx for idx, key in enumerate(rel_dct.keys())}
print(rel2id)

{'root': 0, 'sasubj-obj': 1, 'sasubj': 2, 'dfsubj': 3, 'subj': 4, 'subj-in': 5, 'obj': 6, 'pred': 7, 'att': 8, 'adv': 9, 'cmp': 10, 'coo': 11, 'pobj': 12, 'iobj': 13, 'de': 14, 'adjct': 15, 'app': 16, 'exp': 17, 'punc': 18, 'frag': 19, 'repet': 20, 'attr': 21, 'bckg': 22, 'cause': 23, 'comp': 24, 'cond': 25, 'cont': 26, 'elbr': 27, 'enbm': 28, 'eval': 29, 'expl': 30, 'joint': 31, 'manner': 32, 'rstm': 33, 'temp': 34}


In [11]:
tokenizer = AutoTokenizer.from_pretrained(CFG.plm)
print(len(tokenizer))
 
# num_added_toks = tokenizer.add_tokens(['[root]', '[qst]', '[ans]'], special_tokens=True)
# tokenizer.root_token = '[root]'
# tokenizer.root_token_ids = tokenizer('[root]')['input_ids'][1]
# print(f"add token: {tokenizer.root_token} {tokenizer.root_token_ids}")

# tokenizer.qst_token = '[qst]'
# tokenizer.qst_token_ids = tokenizer('[qst]')['input_ids'][1]
# print(f"add token: {tokenizer.qst_token} {tokenizer.qst_token_ids}")

# tokenizer.ans_token = '[ans]'
# tokenizer.ans_token_ids = tokenizer('[ans]')['input_ids'][1]
# print(f"add token: {tokenizer.ans_token} {tokenizer.ans_token_ids}")
# print(len(tokenizer))

CFG.tokenizer = tokenizer

21128


In [12]:
class Dependency():
    def __init__(self, idx, word, head, rel):
        self.id = idx
        self.word = word
        self.head = head
        self.rel = rel

    def __str__(self):
        # example:  1	上海	_	NR	NR	_	2	nn	_	_
        values = [str(self.idx), self.word, "_", "_", "_", "_", str(self.head), self.rel, "_", "_"]
        return '\t'.join(values)

    def __repr__(self):
        return f"({self.word}, {self.head}, {self.rel})"

In [13]:
def load_annoted(data_file, data_ids):
    with open(data_file, 'r', encoding='utf-8') as f:
        data = json.load(f)
        
    sample_lst:List[List[Dependency]] = []
    
    for i, d in enumerate(data):
        if i not in data_ids:
            continue
        rel_dct = {}
        for tripple in d['relationship']:
            head, rel, tail = tripple
            head_uttr_idx, head_word_idx = [int(x) for x in head.split('-')]
            tail_uttr_idx, tail_word_idx = [int(x) for x in tail.split('-')]
            if head_uttr_idx != tail_uttr_idx:
                continue
            
            if not rel_dct.get(head_uttr_idx, None):
                rel_dct[head_uttr_idx] = {tail_word_idx: [head_word_idx, rel]}
            else:
                rel_dct[head_uttr_idx][tail_word_idx] = [head_word_idx, rel]
            
        for item in d['dialog']:
            turn = item['turn']
            utterance = item['utterance']
            # dep_lst:List[Dependency] = [Dependency(0, '[root]', -1, '_')]
            dep_lst:List[Dependency] = []
            
            for word_idx, word in enumerate(utterance.split(' ')):
                head_word_idx, rel = rel_dct[turn].get(word_idx + 1, [word_idx, 'adjct'])  # some word annoted missed, padded with last word and 'adjct'
                dep_lst.append(Dependency(word_idx + 1, word, head_word_idx, rel))  # start from 1
            
            sample_lst.append(dep_lst)
        
    return sample_lst

In [14]:
class DialogDataset(Dataset):
    def __init__(self, cfg, data_file, data_ids):
        self.cfg = cfg
        self.data_file = data_file
        self.inputs, self.offsets, self.heads, self.rels, self.masks = self.read_data(data_ids)
        
    def read_data(self, data_ids):
        inputs, offsets = [], []
        tags, heads, rels, masks = [], [], [], []
        
        for deps in tqdm(load_annoted(self.data_file, data_ids)):
            # another sentence
            seq_len = len(deps)

            word_lst = [] 
#                 head_tokens = np.ones(self.cfg.max_length, dtype=np.int64)*(-1)  # root index is 0, thus using -1 for padding 
            head_tokens = np.zeros(self.cfg.max_length, dtype=np.int64)  # same as root index is 0, constrainting by mask 
            rel_tokens = np.zeros(self.cfg.max_length, dtype=np.int64)
            mask_tokens = np.zeros(self.cfg.max_length, dtype=np.int64)
            for i, dep in enumerate(deps):
                if i == seq_len or i + 1== self.cfg.max_length:
                    break

                word_lst.append(dep.word)

                if dep.head == -1 or dep.head + 1 >= self.cfg.max_length:
                    head_tokens[i+1] = 0
                    mask_tokens[i+1] = 0
                else:
                    head_tokens[i+1] = int(dep.head)
                    mask_tokens[i+1] = 1
#                     head_tokens[i] = dep.head if dep.head != '_' else 0
                rel_tokens[i+1] = rel2id.get(dep.rel, 0)

            tokenized = tokenizer.encode_plus(word_lst, 
                                              padding='max_length', 
                                              truncation=True,
                                              max_length=self.cfg.max_length, 
                                              return_offsets_mapping=True, 
                                              return_tensors='pt',
                                              is_split_into_words=True)
            inputs.append({"input_ids": tokenized['input_ids'][0],
                          "token_type_ids": tokenized['token_type_ids'][0],
                           "attention_mask": tokenized['attention_mask'][0]
                          })

#                 sentence_word_idx = np.zeros(self.cfg.max_length, dtype=np.int64)
            sentence_word_idx = []
            for idx, (start, end) in enumerate(tokenized.offset_mapping[0][1:]):
                if start == 0 and end != 0:
                    sentence_word_idx.append(idx)
#                         sentence_word_idx[idx] = idx
            if len(sentence_word_idx) < self.cfg.max_length - 1:
                sentence_word_idx.extend([0]* (self.cfg.max_length - 1 - len(sentence_word_idx)))
            offsets.append(torch.as_tensor(sentence_word_idx))
#                 offsets.append(sentence_word_idx)

            heads.append(head_tokens)
            rels.append(rel_tokens)
            masks.append(mask_tokens)
                    
        return inputs, offsets, heads, rels, masks

    def __getitem__(self, idx):
        return self.inputs[idx], self.offsets[idx], self.heads[idx], self.rels[idx], self.masks[idx]
    
    def __len__(self):
        return len(self.rels)

In [15]:
# dataset = DialogDataset(CFG)
# total_ids = list(range(len(dataset)))

## Tranining

In [16]:
if CFG.mode == 'training':    
    total_ids = list(range(50))
    for seed in CFG.random_seeds:   
        print(f'\nSEED {seed}')
        print('--------------------------------')
        logger.info(f'\n=========SEED {seed}===========')
        logger.info(f'-------------------------------')

        if CFG.cuda and torch.cuda.is_available:
            torch.cuda.empty_cache()

        seed_everything(seed=seed)

        random.shuffle(total_ids)

        train_ids = total_ids[0:CFG.shot]
        val_ids = total_ids[CFG.shot:2*CFG.shot]

        tr_dataset = DialogDataset(CFG, CFG.data_file, train_ids)
        va_dataset = DialogDataset(CFG, CFG.data_file, val_ids)

        print(f'---Data Size Train/Val: {len(tr_dataset)} / {len(va_dataset)}')
        logger.info(f'---Data Size Train/Val: {len(tr_dataset)} / {len(va_dataset)}')

        tr_iter = DataLoader(tr_dataset, batch_size=CFG.batch_size)
        va_iter = DataLoader(va_dataset, batch_size=CFG.batch_size * 2)

        model = DepParser(CFG)
        print('Loading Model....')
        trainer = BasicTrainer(model=model, 
                               trainset_size=len(tr_dataset), 
                               loss_fn=arc_rel_loss, 
                               metrics_fn=uas_las, 
                               logger=logger, 
                               config=CFG)

        best_res, best_state_dict = trainer.train(model=model, train_iter=tr_iter, val_iter=va_iter)
        print(best_res)
        with open(f"../results/few_shot/{CFG.shot}-shot/res.txt", 'a+') as f:
            f.write(f'{seed}\t {str(best_res)}\n')

        torch.save(best_state_dict, f'../results/few_shot/{CFG.shot}-shot/model_{seed}.bin')

        model = None

    logger.info('\n')


-------------------------------



SEED 40
--------------------------------


100%|██████████| 953/953 [00:02<00:00, 399.44it/s]
100%|██████████| 214/214 [00:00<00:00, 406.57it/s]
---Data Size Train/Val: 953 / 214


---Data Size Train/Val: 953 / 214
Loading Model....


  0%|          | 0/30 [00:00<?, ?it/s]--epoch 0, step 0, loss 9.30202579498291
  {'UAS': 0.00909090909090909, 'LAS': 0.0}
  3%|▎         | 1/30 [00:05<02:33,  5.28s/it]--Evaluation:
Avg Loss: 5.1037017697485805  UAS: 0.30405405405405406  LAS: 0.2536855036855037 

--Best Evaluation: 
-loss: 5.1037017697485805  UAS: 0.30405405405405406  LAS: 0.2536855036855037 

  7%|▋         | 2/30 [00:10<02:34,  5.52s/it]--Evaluation:
Avg Loss: 4.0209948103004525  UAS: 0.4656019656019656  LAS: 0.4189189189189189 

--Best Evaluation: 
-loss: 4.0209948103004525  UAS: 0.4656019656019656  LAS: 0.4189189189189189 

 10%|█         | 3/30 [00:16<02:30,  5.57s/it]--Evaluation:
Avg Loss: 3.171771303515568  UAS: 0.5730958230958231  LAS: 0.5386977886977887 

--Best Evaluation: 
-loss: 3.171771303515568  UAS: 0.5730958230958231  LAS: 0.5386977886977887 

 13%|█▎        | 4/30 [00:22<02:27,  5.65s/it]--Evaluation:
Avg Loss: 2.6634005965473495  UAS: 0.6369778869778869  LAS: 0.6124078624078624 

--Best Evaluation: 


[1.9751488605392313, 0.855036855036855, 0.8243243243243243]



-------------------------------



SEED 41
--------------------------------


100%|██████████| 943/943 [00:02<00:00, 410.09it/s]
100%|██████████| 224/224 [00:00<00:00, 416.91it/s]
---Data Size Train/Val: 943 / 224


---Data Size Train/Val: 943 / 224
Loading Model....


  0%|          | 0/30 [00:00<?, ?it/s]--epoch 0, step 0, loss 9.354349136352539
  {'UAS': 0.038834951456310676, 'LAS': 0.0}
  3%|▎         | 1/30 [00:05<02:28,  5.14s/it]--Evaluation:
Avg Loss: 5.093712670462472  UAS: 0.2903225806451613  LAS: 0.2446743761412051 

--Best Evaluation: 
-loss: 5.093712670462472  UAS: 0.2903225806451613  LAS: 0.2446743761412051 

  7%|▋         | 2/30 [00:10<02:30,  5.37s/it]--Evaluation:
Avg Loss: 4.129108224596296  UAS: 0.46439440048691416  LAS: 0.43578819233110166 

--Best Evaluation: 
-loss: 4.129108224596296  UAS: 0.46439440048691416  LAS: 0.43578819233110166 

 10%|█         | 3/30 [00:16<02:28,  5.50s/it]--Evaluation:
Avg Loss: 3.3646328108651296  UAS: 0.5757760194765672  LAS: 0.5502130249543518 

--Best Evaluation: 
-loss: 3.3646328108651296  UAS: 0.5757760194765672  LAS: 0.5502130249543518 

 13%|█▎        | 4/30 [00:21<02:23,  5.53s/it]--Evaluation:
Avg Loss: 2.878542559487479  UAS: 0.6506390748630554  LAS: 0.6244674376141205 

--Best Evaluation: 

[2.0707460130964006, 0.8575776019476568, 0.8356664637857577]



-------------------------------



SEED 42
--------------------------------


100%|██████████| 974/974 [00:02<00:00, 340.56it/s]
100%|██████████| 193/193 [00:00<00:00, 319.69it/s]
---Data Size Train/Val: 974 / 193


---Data Size Train/Val: 974 / 193
Loading Model....


  0%|          | 0/30 [00:00<?, ?it/s]--epoch 0, step 0, loss 9.155004501342773
  {'UAS': 0.04090909090909091, 'LAS': 0.0}
  3%|▎         | 1/30 [00:05<02:36,  5.41s/it]--Evaluation:
Avg Loss: 5.107088515177909  UAS: 0.2763480392156863  LAS: 0.22426470588235295 

--Best Evaluation: 
-loss: 5.107088515177909  UAS: 0.2763480392156863  LAS: 0.22426470588235295 

  7%|▋         | 2/30 [00:11<02:43,  5.85s/it]--Evaluation:
Avg Loss: 3.968304460530454  UAS: 0.4577205882352941  LAS: 0.41482843137254904 

--Best Evaluation: 
-loss: 3.968304460530454  UAS: 0.4577205882352941  LAS: 0.41482843137254904 

 10%|█         | 3/30 [00:17<02:42,  6.01s/it]--Evaluation:
Avg Loss: 3.3450794565090862  UAS: 0.5465686274509803  LAS: 0.5165441176470589 

--Best Evaluation: 
-loss: 3.3450794565090862  UAS: 0.5465686274509803  LAS: 0.5165441176470589 

 13%|█▎        | 4/30 [00:23<02:37,  6.04s/it]--Evaluation:
Avg Loss: 2.846413416952049  UAS: 0.6354166666666666  LAS: 0.6041666666666666 

--Best Evaluation: 


[2.2865072155637916, 0.8247549019607843, 0.7984068627450981]



-------------------------------



SEED 43
--------------------------------


100%|██████████| 888/888 [00:02<00:00, 401.07it/s]
100%|██████████| 279/279 [00:00<00:00, 392.85it/s]
---Data Size Train/Val: 888 / 279


---Data Size Train/Val: 888 / 279
Loading Model....


  0%|          | 0/30 [00:00<?, ?it/s]--epoch 0, step 0, loss 9.676490783691406
  {'UAS': 0.031818181818181815, 'LAS': 0.0}
  3%|▎         | 1/30 [00:04<02:23,  4.94s/it]--Evaluation:
Avg Loss: 5.1593841258770246  UAS: 0.25709584533113944  LAS: 0.19786096256684493 

--Best Evaluation: 
-loss: 5.1593841258770246  UAS: 0.25709584533113944  LAS: 0.19786096256684493 

  7%|▋         | 2/30 [00:11<02:40,  5.74s/it]--Evaluation:
Avg Loss: 4.270335546103857  UAS: 0.41341011929247223  LAS: 0.37433155080213903 

--Best Evaluation: 
-loss: 4.270335546103857  UAS: 0.41341011929247223  LAS: 0.37433155080213903 

 10%|█         | 3/30 [00:16<02:30,  5.59s/it]--Evaluation:
Avg Loss: 3.673500110598875  UAS: 0.5018510900863842  LAS: 0.47593582887700536 

--Best Evaluation: 
-loss: 3.673500110598875  UAS: 0.5018510900863842  LAS: 0.47593582887700536 

 13%|█▎        | 4/30 [00:22<02:24,  5.55s/it]--Evaluation:
Avg Loss: 3.1447854238599007  UAS: 0.5713698066639243  LAS: 0.5508021390374331 

--Best Evalu

[2.5934444277089983, 0.8013163307280954, 0.7725215960510078]



-------------------------------



SEED 44
--------------------------------


100%|██████████| 980/980 [00:02<00:00, 402.82it/s]
100%|██████████| 187/187 [00:00<00:00, 416.97it/s]
---Data Size Train/Val: 980 / 187


---Data Size Train/Val: 980 / 187
Loading Model....


  0%|          | 0/30 [00:00<?, ?it/s]--epoch 0, step 0, loss 9.720638275146484
  {'UAS': 0.045454545454545456, 'LAS': 0.0}
  3%|▎         | 1/30 [00:05<02:36,  5.39s/it]--Evaluation:
Avg Loss: 4.984934832323043  UAS: 0.29212752114508783  LAS: 0.24593363695510737 

--Best Evaluation: 
-loss: 4.984934832323043  UAS: 0.29212752114508783  LAS: 0.24593363695510737 

  7%|▋         | 2/30 [00:11<02:44,  5.89s/it]--Evaluation:
Avg Loss: 4.1576597473838115  UAS: 0.450227716330514  LAS: 0.4085881587508133 

--Best Evaluation: 
-loss: 4.1576597473838115  UAS: 0.450227716330514  LAS: 0.4085881587508133 

 10%|█         | 3/30 [00:17<02:39,  5.89s/it]--Evaluation:
Avg Loss: 3.502768646587025  UAS: 0.5413142485361093  LAS: 0.5068314899154196 

--Best Evaluation: 
-loss: 3.502768646587025  UAS: 0.5413142485361093  LAS: 0.5068314899154196 

 13%|█▎        | 4/30 [00:23<02:33,  5.89s/it]--Evaluation:
Avg Loss: 2.85401858365472  UAS: 0.6187378009108653  LAS: 0.5797007156798959 

--Best Evaluation: 
-l

[2.248283480577928, 0.8347430058555628, 0.8054651919323357]






# Inference

In [17]:
test_dataset = DialogDataset(CFG, CFG.test_file, list(range(800)))
test_iter = DataLoader(test_dataset, batch_size=CFG.batch_size * 8)

100%|██████████| 20086/20086 [00:58<00:00, 343.85it/s]


In [25]:
def evaluate(model, eval_iter):
    arc_logits, rel_logits = torch.Tensor(), torch.Tensor()
    heads_whole, rels_whole, masks_whole = torch.Tensor(), torch.Tensor(), torch.Tensor()
    for batch in eval_iter:
        inputs, offsets, heads, rels, masks = batch

        inputs_cuda = {}
        for key, value in inputs.items():
            inputs_cuda[key] = value.cuda()
        inputs = inputs_cuda

        offsets, heads, rels, masks = to_cuda(data=(offsets, heads, rels, masks))

        with torch.no_grad():
            model.eval()
            arc_logit, rel_logit = model(inputs, offsets, heads, rels, masks, evaluate=True)

        arc_logit[:, torch.arange(arc_logit.size()[1]), torch.arange(arc_logit.size()[2])] = -1e4

        arc_logits = torch.cat([arc_logits, arc_logit.cpu()])
        rel_logits = torch.cat([rel_logits, rel_logit.cpu()])

        heads_whole = torch.cat([heads_whole, heads.cpu()])
        rels_whole = torch.cat([rels_whole, rels.cpu()])
        masks_whole = torch.cat([masks_whole, masks.cpu()])

    rel_preds = rel_logits.argmax(-1)
    head_pred = arc_logits.argmax(-1)

    arc_logits_correct = (head_pred == heads_whole).long() * masks_whole * (rels_whole < 21).long()
    rel_logits_correct = (rel_preds == rels_whole).long() * arc_logits_correct 
    uas_syntax = (arc_logits_correct.sum() / (masks_whole * (rels_whole < 21).long()).sum()).item()
    las_syntax = (rel_logits_correct.sum() / (masks_whole * (rels_whole < 21).long()).sum()).item()
    logger.info(f'Syntax UAS: {uas_syntax}; Syntax LAS: {las_syntax}')

    arc_logits_correct = (head_pred == heads_whole).long() * masks_whole * (rels_whole >= 21).long()
    rel_logits_correct = (rel_preds == rels_whole).long() * arc_logits_correct
    uas_discourse = (arc_logits_correct.sum() / (rels_whole >= 21).long().sum()).item()
    las_discourse = (rel_logits_correct.sum() / (rels_whole >= 21).long().sum()).item()
    logger.info(f'Discourse UAS: {uas_discourse}; Discourse LAS: {las_discourse}')
    logger.info('---------------------------------------------------')
    
    return uas_syntax, las_syntax, uas_discourse, las_discourse

In [19]:
if CFG.mode == 'training':
    model = DepParser(CFG)
    uas_syntaxs, las_syntaxs = [], []
    uas_discourses, las_discourses = [], []
    for seed in CFG.random_seeds:
        model.load_state_dict(torch.load(f'../results/few_shot/{CFG.shot}-shot/model_{seed}.bin'))
        model = model.cuda()

        uas_syntax, las_syntax, uas_discourse, las_discourse = evaluate(model, test_iter)

        uas_syntaxs.append(uas_syntax)
        las_syntaxs.append(las_syntax)
        uas_discourses.append(uas_discourse)
        las_discourses.append(las_discourse)

    avg_uas_syntax, avg_las_syntax = np.mean(uas_syntaxs), np.mean(las_syntaxs)
    avg_uas_discourse, avg_las_discourse = np.mean(uas_discourses), np.mean(las_discourses)
    std_uas_syntax, std_las_syntax = np.std(uas_syntaxs), np.std(las_syntaxs)
    std_uas_discourse, std_las_discourse = np.std(uas_discourses), np.std(las_discourses)
    logger.info('\n----------------Result----------------')
    logger.info(f'Avg Syntax UAS: {avg_uas_syntax:.4f}; Avg Syntax LAS: {avg_las_syntax:.4f}')
    logger.info(f'Std Syntax UAS: {std_uas_syntax:.4f}; Std Syntax LAS: {std_las_syntax:.4f}')
    logger.info(f'Avg Discourse UAS: {avg_uas_discourse:.4f}; Avg Discourse LAS: {avg_las_discourse:.4f}')
    logger.info(f'Std Discourse UAS: {std_uas_discourse:.4f}; Std Discourse LAS: {std_las_discourse:.4f}\n')

    logger.info('=================End=================')
    logger.info(datetime.datetime.now().isoformat())
    logger.info('=====================================')

Syntax UAS: 0.8391957879066467; Syntax LAS: 0.8125379085540771
Discourse UAS: 0.5790943503379822; Discourse LAS: 0.42469578981399536
---------------------------------------------------
Syntax UAS: 0.8376250863075256; Syntax LAS: 0.8095029592514038
Discourse UAS: 0.5614402294158936; Discourse LAS: 0.40973469614982605
---------------------------------------------------
Syntax UAS: 0.8378002643585205; Syntax LAS: 0.8097970485687256
Discourse UAS: 0.5839816331863403; Discourse LAS: 0.42250150442123413
---------------------------------------------------
Syntax UAS: 0.8274938464164734; Syntax LAS: 0.7999787330627441
Discourse UAS: 0.5542589426040649; Discourse LAS: 0.40564531087875366
---------------------------------------------------
Syntax UAS: 0.8366989493370056; Syntax LAS: 0.8082138895988464
Discourse UAS: 0.5652304291725159; Discourse LAS: 0.4136245846748352
---------------------------------------------------

----------------Result----------------
Avg Syntax UAS: 0.8358; Avg Syntax L

In [26]:
if CFG.mode == 'inference':
    model = DepParser(CFG)

    for shot in [5, 10, 20, 40]:
        logger.info(f'----------------Shot: {shot}----------------')
        uas_syntaxs, las_syntaxs = [], []
        uas_discourses, las_discourses = [], []
        for seed in tqdm(CFG.random_seeds):
            model.load_state_dict(torch.load(f'../results/few_shot/{shot}-shot/model_{seed}.bin'))
            model = model.cuda()
                
            uas_syntax, las_syntax, uas_discourse, las_discourse = evaluate(model, test_iter)

            uas_syntaxs.append(uas_syntax)
            las_syntaxs.append(las_syntax)
            uas_discourses.append(uas_discourse)
            las_discourses.append(las_discourse)

        avg_uas_syntax, avg_las_syntax = np.mean(uas_syntaxs), np.mean(las_syntaxs)
        avg_uas_discourse, avg_las_discourse = np.mean(uas_discourses), np.mean(las_discourses)
        std_uas_syntax, std_las_syntax = np.std(uas_syntaxs), np.std(las_syntaxs)
        std_uas_discourse, std_las_discourse = np.std(uas_discourses), np.std(las_discourses)
        logger.info('\n----------------Result----------------')
        logger.info(f'Avg Syntax UAS: {avg_uas_syntax:.4f}; Avg Syntax LAS: {avg_las_syntax:.4f}')
        logger.info(f'Std Syntax UAS: {std_uas_syntax:.4f}; Std Syntax LAS: {std_las_syntax:.4f}')
        logger.info(f'Avg Discourse UAS: {avg_uas_discourse:.4f}; Avg Discourse LAS: {avg_las_discourse:.4f}')
        logger.info(f'Std Discourse UAS: {std_uas_discourse:.4f}; Std Discourse LAS: {std_las_discourse:.4f}\n')

    logger.info('=================End=================')
    logger.info(datetime.datetime.now().isoformat())
    logger.info('=====================================')

----------------Shot: 5----------------
  0%|          | 0/5 [00:00<?, ?it/s]
  0%|          | 0/79 [00:00<?, ?it/s][A
  1%|▏         | 1/79 [00:00<00:33,  2.35it/s][A
  3%|▎         | 2/79 [00:00<00:35,  2.20it/s][A
  4%|▍         | 3/79 [00:01<00:36,  2.08it/s][A
  5%|▌         | 4/79 [00:01<00:34,  2.19it/s][A
  6%|▋         | 5/79 [00:02<00:36,  2.00it/s][A
  8%|▊         | 6/79 [00:02<00:36,  2.00it/s][A
  9%|▉         | 7/79 [00:03<00:35,  2.01it/s][A
 10%|█         | 8/79 [00:03<00:35,  1.97it/s][A
 11%|█▏        | 9/79 [00:04<00:37,  1.85it/s][A
 13%|█▎        | 10/79 [00:05<00:36,  1.90it/s][A
 14%|█▍        | 11/79 [00:05<00:37,  1.82it/s][A
 15%|█▌        | 12/79 [00:06<00:36,  1.86it/s][A
 16%|█▋        | 13/79 [00:06<00:36,  1.80it/s][A
 18%|█▊        | 14/79 [00:07<00:35,  1.85it/s][A
 19%|█▉        | 15/79 [00:07<00:36,  1.74it/s][A
 20%|██        | 16/79 [00:08<00:36,  1.71it/s][A
 22%|██▏       | 17/79 [00:09<00:36,  1.70it/s][A
 23%|██▎       | 18/79

In [21]:
# import os
# os.system("shutdown")