In [1]:
from collections import Counter
from typing import *
import random
import json
import logging
import datetime
from itertools import chain

In [2]:
from tqdm import tqdm
import numpy as np
from sklearn.model_selection import KFold

import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader, Subset
from torch.cuda.amp import autocast, GradScaler

import transformers
from transformers import AutoConfig, AutoModel, AutoTokenizer

In [3]:
# Lower the `transformers` library's log verbosity so that only error-level messages are shown.
transformers.logging.set_verbosity_error() # only report errors.

In [4]:
import sys
sys.path.append('..')

from trainer import BasicTrainer
from model.base_par import DepParser
from utils import arc_rel_loss, uas_las, to_cuda

2022-12-14 12:54:01.095502: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcudart.so.11.0


## Config

In [5]:
class CFG:
    """Static configuration namespace for this notebook.

    Attributes are read directly off the class (``CFG.max_length`` etc.) and the
    class itself is passed to ``DepParser`` and to the dataset classes. Only some
    of the fields are referenced in the visible notebook code; the rest are
    presumably consumed by ``DepParser`` / ``BasicTrainer`` -- TODO confirm.
    """
    # JSON file read by InterDataset (through load_inter).
    data_file = '../data/train_50.json'
    # CoNLL files read by DialogDataset when `with_codt_sampled` is True
    # (train file if `train` is True, dev file otherwise).
    codt_train_file = '../codt/train/BC-Train-full.conll'
    codt_dev_file = '../codt/dev/BC-Dev.conll'
    # codt_train_file = '../aug/codt/codt_train_fixed.conll'
    # codt_dev_file = '../aug/codt/codt_dev.conll'
    # JSON file used to build `test_dataset` and re-read inside postproc.
    test_file = '../data/test.json'
    # Hugging Face model identifier; the tokenizer is loaded from it.
    plm = 'hfl/chinese-electra-180g-base-discriminator'
    # Shot values iterated by the inference loop; each must be a key of `shot2seed`.
    shots = [50]
    # The optimisation settings below are not referenced by the visible notebook code.
    num_epochs = 15
    # The test DataLoader uses 4x this value.
    batch_size = 32
    plm_lr = 2e-5
    head_lr = 1e-4
    weight_decay = 0.01
    dropout = 0.2
    grad_clip = 2
    scheduler = 'linear'
    warmup_ratio = 0.1
    num_early_stop = 3
    # Fixed length of every head/rel/mask array and of the tokenizer padding in DialogDataset.
    max_length = 160
    # Equals the number of entries in `rel_dct` / `rel2id` (35).
    num_labels = 35
    hidden_size = 400
    print_every_ratio = 0.5
    cuda = True
    fp16 = True
    eval_strategy = 'epoch'
    # 'training' selects training.log; any other value selects inference_inter.log.
    mode = 'inference'

In [6]:
# Maps each shot value (see CFG.shots) to the random seed passed to seed_everything.
# Per the original note, each entry is the seed whose result was closest to the average result.
shot2seed = {5:40, 10:44, 20:41, 40:44, 50:42}

In [7]:
# Notebook-wide logger. `logging.getLogger` returns the same object for the same
# name on every call, so this cell has to be safe to run more than once.
logger = logging.getLogger("logger")
logger.setLevel(logging.INFO)
print(logger)
print(type(logger))

# Drop file handlers left over from a previous run of this cell. Without this,
# every re-run stacks another FileHandler on the logger: each record is then
# written several times and the old file handles are never closed.
for stale_handler in list(logger.handlers):
    if isinstance(stale_handler, logging.FileHandler):
        logger.removeHandler(stale_handler)
        stale_handler.close()

# One log file per mode; mode='w' truncates the file on each (re-)run.
if CFG.mode == 'training':
    log_path = "../results/few_shot/with_codt_sampled/training.log"
else:
    log_path = "../results/few_shot/with_codt_sampled/inference_inter.log"
fh = logging.FileHandler(filename=log_path, mode='w')
logger.addHandler(fh)

# Stamp the start of the run in both the cell output and the log file.
time_now = datetime.datetime.now().isoformat()
print(time_now)
logger.info(f'=-=-=-=-=-=-=-=-={time_now}=-=-=-=-=-=-=-=-=-=')

<Logger logger (INFO)>
<class 'logging.Logger'>
2022-12-14T12:54:02.840797


## Seed and Device

In [8]:
def seed_everything(seed):
    """Seed the NumPy, Python and PyTorch (CPU and CUDA) RNGs and pin cuDNN settings.

    Args:
        seed: integer seed. NumPy receives it reduced modulo ``2**32 - 1``;
            the other generators receive it unchanged.
    """
    np.random.seed(seed % (2**32 - 1))
    for seed_fn in (random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seed_fn(seed)
    # Request deterministic cuDNN kernels and switch the cuDNN auto-tuner off.
    torch.backends.cudnn.deterministic, torch.backends.cudnn.benchmark = True, False

In [9]:
# Use the GPU when PyTorch reports CUDA as available, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

print(f'Using device: {device}')

Using device: cuda


## Data

In [10]:
# Relation label inventory: label -> Chinese display name.
# The visible notebook code only uses the keys, in insertion order, to build `rel2id`,
# so reordering or (un)commenting an entry changes every following relation id.
rel_dct = {
    'root': '根节点',
    'sasubj-obj': '同主同宾',
    'sasubj': '同主语',
    'dfsubj': '不同主语',
    'subj': '主语',
    'subj-in': '内部主语',
    'obj': '宾语',
    'pred': '谓语',
    'att': '定语',
    'adv': '状语',
    'cmp': '补语',
    'coo': '并列',
    'pobj': '介宾',
    'iobj': '间宾',
    'de': '的',
    'adjct': '附加',
    'app': '称呼',
    'exp': '解释',
    'punc': '标点',
    'frag': '片段',
    'repet': '重复',
    # rst -- NOTE(review): presumably discourse-level (Rhetorical Structure Theory style) relations; confirm.
    'attr': '归属',
    'bckg': '背景',
    'cause': '因果',
    'comp': '比较',
    'cond': '状况',
    'cont': '对比',
    'elbr': '阐述',
    'enbm': '目的',
    'eval': '评价',
    'expl': '解释-例证',
    'joint': '联合',
    'manner': '方式',
    'rstm': '重申',
    'temp': '时序',
    # Disabled labels (kept commented out, so they get no id in rel2id):
    # 'tp-chg': '主题变更',
    # 'prob-sol': '问题-解决',
    # 'qst-ans': '疑问-回答',
    # 'stm-rsp': '陈述-回应',
    # 'req-proc': '需求-处理',
}

In [11]:
# Number the relation labels 0..N-1 in rel_dct's insertion order.
rel2id = dict(zip(rel_dct, range(len(rel_dct))))
print(rel2id)

{'root': 0, 'sasubj-obj': 1, 'sasubj': 2, 'dfsubj': 3, 'subj': 4, 'subj-in': 5, 'obj': 6, 'pred': 7, 'att': 8, 'adv': 9, 'cmp': 10, 'coo': 11, 'pobj': 12, 'iobj': 13, 'de': 14, 'adjct': 15, 'app': 16, 'exp': 17, 'punc': 18, 'frag': 19, 'repet': 20, 'attr': 21, 'bckg': 22, 'cause': 23, 'comp': 24, 'cond': 25, 'cont': 26, 'elbr': 27, 'enbm': 28, 'eval': 29, 'expl': 30, 'joint': 31, 'manner': 32, 'rstm': 33, 'temp': 34}


In [12]:
# Load the tokenizer belonging to the pretrained model named in CFG.plm and show its size.
tokenizer = AutoTokenizer.from_pretrained(CFG.plm)
print(len(tokenizer))

# Also expose the tokenizer on the config class; presumably for code that only
# receives CFG (e.g. DepParser(CFG)) -- not visible here, TODO confirm.
CFG.tokenizer = tokenizer

21128


In [13]:
class Dependency():
    """One token of a dependency-annotated sentence.

    Attributes:
        id: token index, stored exactly as passed in (int or str).
        word: surface form of the token.
        head: index of the head token, coerced to ``int``.
        rel: dependency relation label.
    """

    def __init__(self, idx, word, head, rel):
        self.id = idx
        self.word = word
        self.head = int(head)
        self.rel = rel

    def __str__(self):
        """Render the token as a 10-column, tab-separated CoNLL line."""
        # example:  1	上海	_	NR	NR	_	2	nn	_	_
        # The index is stored as `self.id` (there is no `self.idx` attribute).
        values = [str(self.id), self.word, "_", "_", "_", "_", str(self.head), self.rel, "_", "_"]
        return '\t'.join(values)

    def __repr__(self):
        return f"({self.word}, {self.head}, {self.rel})"

In [14]:
def load_annoted(data_file, data_ids):
    """Load the intra-utterance dependency annotation of selected dialogues.

    Args:
        data_file: path to a JSON file holding a list of dialogues. Each dialogue
            has a ``'relationship'`` list of ``[head, rel, tail]`` triples, where
            head and tail are ``"<utterance>-<word>"`` strings, and a ``'dialog'``
            list of items with ``'turn'`` and ``'utterance'`` (space-separated words).
        data_ids: indices (positions in the JSON list) of the dialogues to keep.

    Returns:
        A list with one entry per utterance of the kept dialogues; each entry is a
        list of ``Dependency`` objects whose ids start at 1.
    """
    with open(data_file, 'r', encoding='utf-8') as f:
        data = json.load(f)

    # Set membership keeps the filter O(1) per dialogue even for long id lists.
    wanted_ids = set(data_ids)
    sample_lst:List[List[Dependency]] = []

    for i, d in enumerate(data):
        if i not in wanted_ids:
            continue

        # utterance index -> {dependent word index: [head word index, relation]}
        arcs_by_turn = {}
        for tripple in d['relationship']:
            head, rel, tail = tripple
            head_uttr_idx, head_word_idx = [int(x) for x in head.split('-')]
            tail_uttr_idx, tail_word_idx = [int(x) for x in tail.split('-')]
            # Arcs that cross utterances are not part of this view of the data.
            if head_uttr_idx != tail_uttr_idx:
                continue
            arcs_by_turn.setdefault(head_uttr_idx, {})[tail_word_idx] = [head_word_idx, rel]

        for item in d['dialog']:
            turn = item['turn']
            utterance = item['utterance']
            # An utterance with no intra-utterance arc at all has no entry in the
            # map; treat it as "every word unannotated" instead of raising KeyError.
            turn_arcs = arcs_by_turn.get(turn, {})
            dep_lst:List[Dependency] = []

            for word_idx, word in enumerate(utterance.split(' ')):
                # some word annoted missed, padded with last word and 'adjct'
                head_word_idx, rel = turn_arcs.get(word_idx + 1, [word_idx, 'adjct'])
                dep_lst.append(Dependency(word_idx + 1, word, head_word_idx, rel))  # start from 1

            sample_lst.append(dep_lst)

    return sample_lst

In [15]:
def load_codt(data_file: str):
    """Yield the sentences of a 10-column CoNLL file as lists of ``Dependency``.

    Sentences are separated by blank lines. For each token line, columns 9/10
    (``toks[8]``/``toks[9]``) supply head and relation when column 9 is not
    ``'_'``; otherwise columns 7/8 (``toks[6]``/``toks[7]``) are used. Lines
    with any other number of columns are ignored.

    Args:
        data_file: path to the CoNLL file (read as UTF-8).

    Yields:
        One list of ``Dependency`` per sentence, in file order.
    """
    # id, form, tag, head, rel
    sentence:List[Dependency] = []

    with open(data_file, 'r', encoding='utf-8') as f:
        # data example: 1	上海	_	NR	NR	_	2	nn	_	_
        for line in f:  # stream the file instead of materialising every line
            toks = line.split()
            if len(toks) == 0:
                if sentence:
                    yield sentence
                    sentence = []
            elif len(toks) == 10:
                if toks[8] != '_':
                    dep = Dependency(toks[0], toks[1], toks[8], toks[9])
                else:
                    dep = Dependency(toks[0], toks[1], toks[6], toks[7])
                sentence.append(dep)

    # A file that does not end with a blank line still has a pending sentence.
    if sentence:
        yield sentence

In [16]:
class DialogDataset(Dataset):
    """Fixed-length tensors for intra-utterance dependency parsing.

    Each item is the tuple ``(inputs, offsets, heads, rels, masks)``:
    ``inputs`` is a dict with ``input_ids``, ``token_type_ids`` and
    ``attention_mask``; ``offsets`` holds sub-token positions selected from the
    tokenizer's offset mapping; ``heads``, ``rels`` and ``masks`` are int64
    arrays of length ``cfg.max_length`` in which word ``i`` of the sentence is
    stored at position ``i + 1`` (position 0 is left at zero).

    Args:
        cfg: configuration object; ``max_length`` is always read, and
            ``codt_train_file`` / ``codt_dev_file`` only when
            ``with_codt_sampled`` is True.
        data_file: JSON file passed to ``load_annoted``.
        data_ids: dialogue indices passed to ``load_annoted``.
        train: selects the CODT train file (True) or dev file (False).
        with_codt_sampled: when True, the annotated dialogues are used twice
            and followed by the CODT sentences; when False, only the annotated
            dialogues are used (once).
    """

    def __init__(self, cfg, data_file, data_ids, train, with_codt_sampled=True):
        self.cfg = cfg
        self.data_file = data_file
        self.train = train
        self.with_codt_sampled = with_codt_sampled

        self.inputs, self.offsets, self.heads, self.rels, self.masks = self.read_data(data_ids)

    def read_data(self, data_ids):
        """Read, label and tokenize every sentence; returns five parallel lists."""
        max_length = self.cfg.max_length
        inputs, offsets = [], []
        heads, rels, masks = [], [], []

        if self.with_codt_sampled:
            codt_file = self.cfg.codt_train_file if self.train else self.cfg.codt_dev_file
            # Annotated dialogues appear twice, then the CODT sentences.
            # load_codt is a generator, so its file is only opened once iteration reaches it.
            data_iter = chain(
                load_annoted(self.data_file, data_ids),
                load_annoted(self.data_file, data_ids),
                load_codt(codt_file),
            )
        else:
            # Only parse the annotated file when it is actually consumed.
            data_iter = load_annoted(self.data_file, data_ids)

        for deps in tqdm(data_iter):
            word_lst = []
            head_tokens = np.zeros(max_length, dtype=np.int64)
            rel_tokens = np.zeros(max_length, dtype=np.int64)
            mask_tokens = np.zeros(max_length, dtype=np.int64)

            # Keep at most max_length - 1 words: labels are stored shifted by one.
            for pos, dep in enumerate(deps[:max_length - 1], start=1):
                word_lst.append(dep.word)

                # Only arcs with a known relation and an in-range head are scored;
                # everything else keeps head 0 / rel 0 and is excluded by mask 0.
                if dep.head != -1 and dep.rel in rel2id and dep.head + 1 < max_length:
                    head_tokens[pos] = dep.head
                    mask_tokens[pos] = 1
                    rel_tokens[pos] = rel2id[dep.rel]

            tokenized = tokenizer.encode_plus(word_lst,
                                              padding='max_length',
                                              truncation=True,
                                              max_length=max_length,
                                              return_offsets_mapping=True,
                                              return_tensors='pt',
                                              is_split_into_words=True)
            inputs.append({"input_ids": tokenized['input_ids'][0],
                           "token_type_ids": tokenized['token_type_ids'][0],
                           "attention_mask": tokenized['attention_mask'][0]
                          })

            # Skip the first token, then keep the positions whose character span
            # starts at 0 and is non-empty (presumably the first sub-token of
            # each word, since offsets are per word here -- TODO confirm).
            sentence_word_idx = [
                tok_idx
                for tok_idx, (start, end) in enumerate(tokenized.offset_mapping[0][1:])
                if start == 0 and end != 0
            ]
            # Right-pad with 0 so every offsets tensor has length max_length - 1.
            sentence_word_idx.extend([0] * (max_length - 1 - len(sentence_word_idx)))
            offsets.append(torch.as_tensor(sentence_word_idx))

            heads.append(head_tokens)
            rels.append(rel_tokens)
            masks.append(mask_tokens)

        return inputs, offsets, heads, rels, masks

    def __getitem__(self, idx):
        return self.inputs[idx], self.offsets[idx], self.heads[idx], self.rels[idx], self.masks[idx]

    def __len__(self):
        return len(self.rels)

# Inference

In [17]:
def load_codt_signal(data_file: str, return_two=False):
    """Yield per-sentence integer signals stored in a 10-column CoNLL-style file.

    Sentences are separated by blank lines; lines that do not have exactly ten
    whitespace-separated columns are ignored.

    Args:
        data_file: path to the file (read as UTF-8).
        return_two: when False each token contributes ``int(toks[2])``; when
            True it contributes the pair ``[int(toks[2]), int(toks[3])]``.

    Yields:
        One list per sentence, with one entry per token line.
    """
    sentence = []

    with open(data_file, 'r', encoding='utf-8') as f:
        # data example: 1	上海	_	NR	NR	_	2	nn	_	_
        for line in f:  # stream the file instead of materialising every line
            toks = line.split()
            if len(toks) == 0:
                if sentence:
                    yield sentence
                    sentence = []
            elif len(toks) == 10:
                if return_two:
                    sentence.append([int(toks[2]), int(toks[3])])
                else:
                    sentence.append(int(toks[2]))

    # A file that does not end with a blank line still has a pending sentence.
    if sentence:
        yield sentence

In [18]:
def postproc(arc_logits, rel_logits, masks_whole):
    """Turn parser logits into head/relation predictions and apply rule-based relabelling.

    For every test utterance the argmax heads/relations are masked with
    `masks_whole`, then arcs whose relation is one of 'root'/'dfsubj'/'sasubj'
    and that cross a punctuation-delimited segment (or point backwards) are
    relabelled with per-segment signals read from '../mlm_based/diag_test.conll'.

    Args:
        arc_logits: arc scores; argmax over the last dim gives the predicted head.
        rel_logits: relation scores; argmax over the last dim gives the predicted relation.
        masks_whole: per-sample masks; positions where the mask is 0 are set to -2.

    Returns:
        (heads_new_whole, rels_new_whole): the per-sample rows stacked with torch.cat.
    """
    # NOTE: this rebinds `rel2id` (and brings `punct_lst`) from the `constant` module
    # for the scope of this function, shadowing the notebook-level `rel2id`.
    # `weak_signals` and `weak_labels` are imported but not used below.
    from constant import rel2id, punct_lst, weak_signals, weak_labels
    
    rel_preds = rel_logits.argmax(-1)
    head_preds = arc_logits.argmax(-1)

    # Relation ids that are eligible for relabelling.
    origin4change = [rel2id[x] for x in ['root', 'dfsubj', 'sasubj']]
    # origin4change.extend([i for i in range(21, 35)])

    max_len = CFG.max_length

    # Accumulators start as empty default-dtype tensors and grow by one row per sample.
    # NOTE(review): concatenating integer rows onto an empty float tensor presumably
    # promotes the result to float -- confirm this is intended.
    signals_new_whole = torch.Tensor()
    heads_new_whole, rels_new_whole = torch.Tensor(), torch.Tensor()
    # Walk the test utterances together with their predicted signals; this relies on
    # the rows of arc_logits/rel_logits being in the same order as load_annoted's output.
    for sample_idx, (deps, pred_signals) in tqdm(enumerate(zip(load_annoted(CFG.test_file, data_ids=list(range(800))), load_codt_signal('../mlm_based/diag_test.conll')))):
        seq_len = len(deps)
        # NOTE(review): skipping appends no row, so after an empty utterance the
        # returned rows would no longer line up with sample_idx.
        if seq_len == 0:
            continue

        # Per-position signal, defaulting to 'elbr'.
        signals = torch.full(size=(max_len,), fill_value=rel2id['elbr']).int()
        # These initial heads/rels are overwritten below before they are read.
        heads, rels = torch.full(size=(max_len,), fill_value=-2).int(), torch.zeros(max_len).int()
        # `split` is the start of the current segment, `splits` collects segment starts,
        # `word_lst` is 1-based like the label tensors (index 0 is a 'root' placeholder).
        split, splits, signal, word_lst  = 1, [1], rel2id['elbr'], ['root']
        for i, dep in enumerate(deps[:-1]):
            if i + 2 >= max_len:
                break

            word = dep.word
            word_lst.append(word)

            # Fall back to the last available signal when there are fewer signals than words.
            try:
                signal = pred_signals[i]
            except IndexError:
                signal = pred_signals[len(pred_signals) - 1]

            # A segment ends at a punctuation token that is followed by a non-punctuation token.
            if word in punct_lst and deps[i+1].word not in punct_lst:
                if i + 2 - split > 2:  # set 2 to the min length of edu
                    signals[split:i+2] = signal
                    # signal = None
                split = i + 2
                splits.append(split)

        splits.append(len(deps))

        # add the last data
        # NOTE(review): `i` and `word` are the leftovers of the loop above. For a
        # single-token utterance the loop body never runs, so they are undefined
        # (first sample) or stale (later samples). `word` is also the second-to-last
        # token here, not the last one -- confirm whether deps[-1].word was intended.
        # The `signal` assigned here is not read again in this function.
        if i + 1 < max_len:
            signal = pred_signals[-1]
            word_lst.append(word)

        # Row views of the argmax tensors, modified in place: masked-out positions become -2.
        heads = head_preds[sample_idx]
        heads.masked_fill_(mask=~masks_whole[sample_idx].bool(), value=-2)

        rels = rel_preds[sample_idx]
        rels.masked_fill_(mask=~masks_whole[sample_idx].bool(), value=-2)

        # `cnt` indexes the current segment in `splits`; `attr` allows one obj->attr relabel per sample.
        cnt, attr, = -1, False
        for idx, head in enumerate(heads[1:]):
            # Stop at the first masked-out position.
            if head == -2:
                break
            # NOTE(review): argmax output is non-negative and masked positions are -2,
            # so this branch looks unreachable.
            if head == -1:
                continue

            # Advance to the segment that contains position idx + 1.
            if len(splits) > 2 and idx + 1 >= splits[cnt+1]:
                cnt += 1

            # Relabel when the head lies outside the current segment, or the head index
            # is smaller than idx, and the predicted relation is one of origin4change.
            if ((len(splits) > 2 and (head < splits[cnt] or head >= splits[cnt+1])) or idx - head > 0) and rels[idx + 1] in origin4change:  # cross 'edu'

                rels[idx+1] = signals[idx+1]

                if rels[idx + 1] in [rel2id['cond']]:  # reverse
                    # Look for a dependent of this token at or after the next segment start.
                    tmp_heads = heads.clone()
                    tmp_heads[:splits[cnt+1]] = 0
                    head_idx = [idx + 1]
                    tail_idx = (tmp_heads == idx + 1).nonzero()  # find tail index
                    if len(tail_idx) == 0:  # ring or fail
                        # unchange
                        tail_idx = [idx + 1]
                        head_idx = (heads == idx + 1).nonzero() if head_idx == tail_idx else head_idx
                    elif len(head_idx) != 0:
                        # Swap direction: the dependent becomes a root (head 0) and heads this token.
                        heads[tail_idx[0]] = 0
                        heads[head_idx[0]] = tail_idx[0]

                # special cases
                if word_lst[idx+1] == '好' and word_lst[idx] in ['你', '您']:  # reverse
                    tmp_heads = heads.clone()
                    tmp_heads[:splits[cnt+1]] = 0
                    tail_idx = (tmp_heads == idx + 1).nonzero()  # find tail index
                    if len(tail_idx) != 0:  
                        heads[tail_idx[0]] = 0
                        heads[idx + 1] = tail_idx[0]
                        rels[idx + 1] = rel2id['elbr']

            # The first 'obj' whose segment signal is 'attr' is relabelled to that signal.
            if not attr and rels[idx + 1] in [rel2id['obj']] and signals[idx+1] == rel2id['attr']:
                rels[idx+1] = signals[idx+1]
                attr = True

        # Every token whose head is 0 gets relation id 0; the remaining -2 heads are reset to 0.
        rels.masked_fill_(heads == 0, 0)  # root
        heads[0] = 0
        heads[1:].masked_fill_(heads[1:] == -2, 0)

        heads_new_whole = torch.cat([heads_new_whole, heads.unsqueeze(0)])
        rels_new_whole = torch.cat([rels_new_whole, rels.unsqueeze(0)])
        # Collected but not returned.
        signals_new_whole = torch.cat([signals_new_whole, signals.unsqueeze(0)])

    return heads_new_whole, rels_new_whole

In [19]:
def infer_fn(model, eval_iter):
    """Run the parser over ``eval_iter`` and post-process its predictions.

    Args:
        model: parser called as ``model(inputs, offsets, heads, rels, masks, evaluate=True)``
            and returning ``(arc_logit, rel_logit)``.
        eval_iter: iterable of ``(inputs, offsets, heads, rels, masks)`` batches,
            where ``inputs`` is a dict of tensors.

    Returns:
        ``(arc_logits, rel_logits, head_preds, rel_preds)``: the concatenated CPU
        logits and the predictions returned by ``postproc``.
    """
    arc_chunks, rel_chunks = [], []
    gold_head_chunks, gold_rel_chunks, mask_chunks = [], [], []

    for inputs, offsets, heads, rels, masks in eval_iter:
        inputs = {name: tensor.cuda() for name, tensor in inputs.items()}
        offsets, heads, rels, masks = to_cuda(data=(offsets, heads, rels, masks))

        with torch.no_grad():
            model.eval()
            arc_logit, rel_logit = model(inputs, offsets, heads, rels, masks, evaluate=True)

        # Push the diagonal (a token attaching to itself) far below every other score.
        diag_rows = torch.arange(arc_logit.size()[1])
        diag_cols = torch.arange(arc_logit.size()[2])
        arc_logit[:, diag_rows, diag_cols] = -1e4

        arc_chunks.append(arc_logit.cpu())
        rel_chunks.append(rel_logit.cpu())
        gold_head_chunks.append(heads.cpu())
        gold_rel_chunks.append(rels.cpu())
        mask_chunks.append(masks.cpu())

    # Each result is seeded with an empty default-dtype tensor, exactly as when
    # growing it batch by batch, so the resulting dtypes are unchanged.
    arc_logits = torch.cat([torch.Tensor(), *arc_chunks])
    rel_logits = torch.cat([torch.Tensor(), *rel_chunks])
    heads_whole = torch.cat([torch.Tensor(), *gold_head_chunks])
    rels_whole = torch.cat([torch.Tensor(), *gold_rel_chunks])
    masks_whole = torch.cat([torch.Tensor(), *mask_chunks])

    head_preds, rel_preds = postproc(arc_logits, rel_logits, masks_whole)

    return arc_logits, rel_logits, head_preds, rel_preds

In [20]:
# Evaluation set: dialogues 0..799 of CFG.test_file, annotated data only
# (with_codt_sampled=False, so no CODT sentences are mixed in).
test_dataset = DialogDataset(CFG, CFG.test_file, list(range(800)), train=False, with_codt_sampled=False)
# `shuffle` is left at its default; postproc relies on batches arriving in dataset order.
test_iter = DataLoader(test_dataset, batch_size=CFG.batch_size * 4)

100%|██████████| 20086/20086 [00:36<00:00, 551.07it/s]


# Inter-utterance dependency evaluation

In [21]:
def load_inter(data_file):
    """Load whole dialogues, keeping only the cross-utterance dependency arcs.

    Every dialogue is flattened into a single token sequence: each utterance
    contributes a role token (``'[ans]'`` for speaker ``'A'``, ``'[qst]'``
    otherwise) followed by its words. Words whose head lies in another utterance
    keep that head (re-indexed into the flattened sequence) and relation; all
    other tokens get head ``-1`` and relation ``'_'``.

    Signals are read from ``'../mlm_based/diag_test.conll'`` with
    ``load_codt_signal(..., return_two=True)``, one sentence per utterance.

    Args:
        data_file: path to the JSON file with the dialogues.

    Returns:
        A list with one ``[dep_lst, role_lst, weak_signal]`` triple per dialogue:
        the flattened ``Dependency`` list, the speaker of each utterance, and for
        each utterance the signal pair of its last word.
    """
    with open(data_file, 'r', encoding='utf-8') as f:
        data = json.load(f)

    signal_iter = load_codt_signal('../mlm_based/diag_test.conll', return_two=True)

    sample_lst:List[List[Dependency]] = []
    for d in data:
        # utterance index of the dependent -> {dependent word index: [head "u-w", relation]}
        arcs_by_turn = {}
        for tripple in d['relationship']:
            head, rel, tail = tripple
            head_uttr_idx, head_word_idx = [int(x) for x in head.split('-')]
            tail_uttr_idx, tail_word_idx = [int(x) for x in tail.split('-')]

            if rel == 'root' and head_uttr_idx != 0: # ignore root
                continue

            arcs_by_turn.setdefault(tail_uttr_idx, {})[tail_word_idx] = [head, rel]

        # Position of each utterance's role token in the flattened sequence. The
        # running sum starts from 1 and adds one slot per role token; entry 0 is
        # reset to 0 afterwards.
        sent_lens_accum = [1]
        for uttr_pos, item in enumerate(d['dialog']):
            utterance = item['utterance']
            sent_lens_accum.append(sent_lens_accum[uttr_pos] + len(utterance.split(' ')) + 1)
        sent_lens_accum[0] = 0

        dep_lst:List[Dependency] = []
        role_lst:List[str] = []
        weak_signal = []
        for item in d['dialog']:
            turn = item['turn']
            utterance = item['utterance']

            pred_signals = next(signal_iter)

            role = '[ans]' if item['speaker'] == 'A' else '[qst]'
            dep_lst.append(Dependency(sent_lens_accum[turn], role, -1, '_'))

            turn_arcs = arcs_by_turn.get(turn, {})
            tmp_signal = []
            for word_idx, word in enumerate(utterance.split(' ')):
                # some word annoted missed, padded with last word and 'adjct'
                head, rel = turn_arcs.get(word_idx + 1, [f'{turn}-{word_idx}', 'adjct'])
                head_uttr_idx, head_word_idx = [int(x) for x in head.split('-')]

                # only parse cross-utterance
                if turn != head_uttr_idx:
                    dep_lst.append(Dependency(sent_lens_accum[turn] + word_idx + 1, word, sent_lens_accum[head_uttr_idx] + head_word_idx, rel))  # add with accumulated length
                else:
                    dep_lst.append(Dependency(sent_lens_accum[turn] + word_idx + 1, word, -1, '_'))

                # Index the signals by the current word. The previous code used a
                # stale index from the length loop above (always the last
                # utterance's position), which is unrelated to this word.
                try:
                    signal1, signal2 = pred_signals[word_idx]
                except IndexError:
                    # Fewer signals than words: reuse the last available one.
                    signal1, signal2 = pred_signals[len(pred_signals) - 1]

                # Overwritten on every word, so the last word's pair is what remains.
                tmp_signal = [signal1, signal2]

            if len(tmp_signal) != 0:
                weak_signal.append(tmp_signal)  # choose the last
            else:
                weak_signal.append(-1)
            role_lst.append(item['speaker'])
        sample_lst.append([dep_lst, role_lst, weak_signal])

    return sample_lst

In [22]:
class InterDataset(Dataset):
    """Whole-dialogue tensors for evaluating cross-utterance dependencies.

    Each item is ``(inputs, offsets, heads, rels, masks, speakers, signs)``.
    ``heads``, ``rels`` and ``masks`` are int64 arrays of length ``max_length``
    in which token ``i`` of the flattened dialogue is stored at position
    ``i + 1``; ``masks`` is 1 only where a cross-utterance head is present and
    in range. ``speakers`` and ``signs`` are the per-utterance speaker list and
    signal list produced by ``load_inter``.

    Args:
        cfg: configuration object; ``cfg.data_file`` is passed to ``load_inter``.
        max_length: fixed length of the label arrays and of the tokenizer
            padding/truncation. Defaults to 1024, the value that used to be
            hard-coded.
    """

    def __init__(self, cfg, max_length=1024):
        self.cfg = cfg
        self.max_length = max_length
        self.inputs, self.offsets, self.heads, self.rels, self.masks, self.speakers, self.signs = self.read_data()

    def read_data(self):
        """Read, label and tokenize every dialogue; returns seven parallel lists."""
        max_length = self.max_length
        inputs, offsets = [], []
        heads, rels, masks, speakers, signs = [], [], [], [], []

        for deps, roles, sign in load_inter(self.cfg.data_file):
            signs.append(sign)

            word_lst = []
            head_tokens = np.zeros(max_length, dtype=np.int64)  # same as root index is 0, constrainting by mask
            rel_tokens = np.zeros(max_length, dtype=np.int64)
            mask_tokens = np.zeros(max_length, dtype=np.int64)

            # Keep at most max_length - 1 tokens: labels are stored shifted by one.
            for pos, dep in enumerate(deps[:max_length - 1], start=1):
                word_lst.append(dep.word)

                head = int(dep.head)
                # head == -1 marks "no cross-utterance arc"; those positions (and
                # out-of-range heads) keep head 0 and mask 0 from the zero init.
                if head != -1 and head + 1 < max_length:
                    head_tokens[pos] = head
                    mask_tokens[pos] = 1
                # Unknown relation labels (including '_') fall back to id 0.
                rel_tokens[pos] = rel2id.get(dep.rel, 0)

            tokenized = tokenizer.encode_plus(word_lst,
                                              padding='max_length',
                                              truncation=True,
                                              max_length=max_length,
                                              return_offsets_mapping=True,
                                              return_tensors='pt',
                                              is_split_into_words=True)
            inputs.append({"input_ids": tokenized['input_ids'][0],
                           "token_type_ids": tokenized['token_type_ids'][0],
                           "attention_mask": tokenized['attention_mask'][0]
                          })

            # Skip the first token, then keep the positions whose character span
            # starts at 0 and is non-empty (presumably the first sub-token of
            # each word -- TODO confirm).
            sentence_word_idx = [
                tok_idx
                for tok_idx, (start, end) in enumerate(tokenized.offset_mapping[0][1:])
                if start == 0 and end != 0
            ]
            # Right-pad with 0 so every offsets tensor has length max_length - 1.
            sentence_word_idx.extend([0] * (max_length - 1 - len(sentence_word_idx)))
            offsets.append(torch.as_tensor(sentence_word_idx))

            heads.append(head_tokens)
            rels.append(rel_tokens)
            masks.append(mask_tokens)
            speakers.append(roles)

        return inputs, offsets, heads, rels, masks, speakers, signs

    def __getitem__(self, idx):
        return self.inputs[idx], self.offsets[idx], self.heads[idx], self.rels[idx], self.masks[idx], self.speakers[idx], self.signs[idx]

    def __len__(self):
        return len(self.rels)

In [23]:
def get_root(rel_preds, masks_whole):
    """Find, for each sample, the first unmasked position predicted as relation 0.

    Args:
        rel_preds: iterable of per-sample relation-id rows.
        masks_whole: iterable of per-sample mask rows aligned with ``rel_preds``.

    Returns:
        A list with one index per sample; 2 is used when a sample has no
        unmasked position with relation id 0.
    """
    root_ids = []
    for rel_row, mask_row in zip(rel_preds, masks_whole):
        root_hits = (((rel_row == 0) * mask_row) != 0).nonzero()
        # no root -> fall back to index 2
        root_ids.append(root_hits[0].item() if len(root_hits) > 0 else 2)

    return root_ids

In [27]:
def eval_inter(inter_dataloader, masks_whole, root_ids):
    """Predict one cross-utterance arc per consecutive utterance pair and score it.

    For every dialogue, the predicted root of each utterance (after the first) is
    attached to the predicted root of the preceding utterance, labelled with one
    of that preceding utterance's two signals. Predictions are compared against
    the gold heads/relations wherever the dialogue mask is set.

    Args:
        inter_dataloader: loader yielding
            (inputs, offsets, heads, rels, masks, speakers, signs) batches. Predictions
            are written to row 0 only, so this assumes one dialogue per batch.
        masks_whole: per-utterance masks; their sums give the utterance lengths.
        root_ids: per-utterance root position, indexed like `masks_whole`.

    Returns:
        (arc_logits_correct, rel_logits_correct): per-position correctness of the
        heads, and of head+relation, restricted to masked-in positions.
    """
    # Running utterance index into masks_whole / root_ids, across all dialogues.
    cnt = 0

    # Accumulators start as empty default-dtype tensors and grow by one batch per step.
    inter_heads_whole, inter_rels_whole, inter_masks_whole = torch.Tensor(), torch.Tensor(), torch.Tensor()
    inter_heads_preds, inter_rels_preds = torch.Tensor(), torch.Tensor()
    for batch in inter_dataloader:
        inputs, offsets, heads, rels, masks, speakers, signs = batch
        inter_head_preds = torch.zeros_like(heads, dtype=int)
        inter_rel_preds = torch.zeros_like(rels, dtype=int)

        inter_heads_whole = torch.cat([inter_heads_whole, heads])
        inter_rels_whole = torch.cat([inter_rels_whole, rels])
        inter_masks_whole = torch.cat([inter_masks_whole, masks])

        # Offset of the current (head-side) utterance inside the flattened dialogue.
        accum = 1
        # i enumerates speakers[1:], so speakers[i] / signs[i] belong to the utterance
        # that precedes `speakr`.
        # NOTE(review): presumably speakers/signs arrive as per-utterance lists after
        # default collation with batch_size=1 -- confirm.
        for i, speakr in enumerate(speakers[1:]):
            # Masked-in token count of utterance `cnt`, plus one extra slot.
            seq_len = masks_whole[cnt].sum().item() + 1

            # First signal when both utterances share a speaker, second signal otherwise.
            if speakr == speakers[i]:
                rel = signs[i][0]
            else:
                rel = signs[i][1]
            
            # Head: root of utterance cnt (no offset for the very first utterance);
            # dependent: root of utterance cnt + 1, shifted past utterance cnt.
            head_idx = int(root_ids[cnt] + accum) if i > 0 else root_ids[cnt]
            tail_idx = int(root_ids[cnt+1] + accum + seq_len)
            
            inter_head_preds[0][tail_idx] = head_idx
            inter_rel_preds[0][tail_idx] = rel

            cnt += 1
            accum += seq_len

        # The loop advanced once per utterance pair; step over the dialogue's last utterance.
        cnt += 1

        inter_heads_preds = torch.cat([inter_heads_preds, inter_head_preds])
        inter_rels_preds = torch.cat([inter_rels_preds, inter_rel_preds])

    # A relation only counts as correct where the head is correct as well.
    arc_logits_correct = (inter_heads_preds == inter_heads_whole).long() * inter_masks_whole
    rel_logits_correct = (inter_rels_preds == inter_rels_whole).long() * arc_logits_correct

    # Logged in this order: labelled accuracy (head + relation), then unlabelled accuracy (head only).
    logger.info(rel_logits_correct.sum() / inter_masks_whole.long().sum())
    logger.info(arc_logits_correct.sum() / inter_masks_whole.long().sum())

    return arc_logits_correct, rel_logits_correct

In [25]:
# Whole-dialogue dataset built from CFG.data_file.
# NOTE(review): the intra-utterance predictions used later come from CFG.test_file;
# confirm that CFG.data_file points at the same dialogues when evaluating.
inter_dataset = InterDataset(CFG)
# One dialogue per batch: eval_inter writes its predictions to row 0 of each batch.
inter_dataloader = DataLoader(inter_dataset, batch_size=1)

In [None]:
# For every configured shot value: load the matching checkpoint, run the parser over
# the test set, post-process the predictions and evaluate the inter-utterance arcs.
# The names assigned here (masks_whole, root_ids, ...) stay in the notebook namespace
# and are reused by the cell that follows.
for shot in CFG.shots:
    logger.info(f'----------------Shot: {shot}----------------')
    seed = shot2seed[shot]
    seed_everything(seed)
    
    model = DepParser(CFG)
    model.load_state_dict(torch.load(f'../results/few_shot/with_codt_sampled/model_shot_{shot}.bin'))
    model = model.cuda()
        
    # Same batch loop as infer_fn, inlined here because masks_whole is needed afterwards.
    arc_logits, rel_logits = torch.Tensor(), torch.Tensor()
    heads_whole, rels_whole, masks_whole = torch.Tensor(), torch.Tensor(), torch.Tensor()
    for batch in test_iter:
        inputs, offsets, heads, rels, masks = batch

        # Move every tensor of the tokenizer output to the GPU.
        inputs_cuda = {}
        for key, value in inputs.items():
            inputs_cuda[key] = value.cuda()
        inputs = inputs_cuda

        offsets, heads, rels, masks = to_cuda(data=(offsets, heads, rels, masks))

        with torch.no_grad():
            model.eval()
            arc_logit, rel_logit = model(inputs, offsets, heads, rels, masks, evaluate=True)

        # Push the diagonal (a token attaching to itself) far below every other score.
        arc_logit[:, torch.arange(arc_logit.size()[1]), torch.arange(arc_logit.size()[2])] = -1e4

        arc_logits = torch.cat([arc_logits, arc_logit.cpu()])
        rel_logits = torch.cat([rel_logits, rel_logit.cpu()])

        heads_whole = torch.cat([heads_whole, heads.cpu()])
        rels_whole = torch.cat([rels_whole, rels.cpu()])
        masks_whole = torch.cat([masks_whole, masks.cpu()])

    # rel_preds = rel_logits.argmax(-1)
    # head_pred = arc_logits.argmax(-1)
    head_preds, rel_preds = postproc(arc_logits, rel_logits, masks_whole)

    # Per-utterance root position, used as the attachment point of the inter-utterance arcs.
    root_ids = get_root(rel_preds, masks_whole)

    arc_correct_inter, rel_correct_inter = eval_inter(inter_dataloader, masks_whole, root_ids)

In [31]:
# Re-runs the inter-utterance evaluation with the `masks_whole` / `root_ids` left in the
# namespace by the last iteration of the shot loop above (requires that cell to have run).
arc_correct_inter, rel_correct_inter = eval_inter(inter_dataloader, masks_whole, root_ids)

tensor([[ 16],
        [ 22],
        [ 33],
        [ 38],
        [ 47],
        [ 51],
        [ 53],
        [ 59],
        [ 64],
        [ 69],
        [ 73],
        [100],
        [105],
        [117],
        [134],
        [146],
        [161]])
tensor([[ 20],
        [ 28],
        [ 37],
        [ 55],
        [ 64],
        [ 75],
        [ 89],
        [ 97],
        [103],
        [116],
        [133],
        [139],
        [147],
        [155],
        [172],
        [177],
        [186]])
tensor([[ 14],
        [ 25],
        [ 38],
        [ 42],
        [ 48],
        [ 54],
        [ 63],
        [ 69],
        [ 74],
        [ 88],
        [104],
        [108],
        [116]])
tensor([[  7],
        [ 13],
        [ 20],
        [ 29],
        [ 33],
        [ 82],
        [ 89],
        [ 97],
        [102],
        [116],
        [121],
        [127]])
tensor([[  9],
        [ 15],
        [ 20],
        [ 23],
        [ 30],
        [ 38],
        [ 65],
      