In [1]:
import random
from typing import *
import json
from itertools import chain

from tqdm import tqdm
import numpy as np
import pandas as pd

import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, Subset, IterableDataset
from transformers import AutoTokenizer, AdamW, AutoModel, get_linear_schedule_with_warmup

In [2]:
import sys
sys.path.append('..')

from constant import rel2id, punct_lst, weak_signals, weak_labels
from utils import to_cuda

# Config

In [3]:
class CFG:
    data_file = '../aug/diag_codt/diag_train.conll'
    # data_file = '../data_testset/1to800_1108.json'
    plm = 'hfl/chinese-electra-180g-base-discriminator'
    random_seed = 42
    num_epochs = 1
    batch_size = 32
    plm_lr = 2e-5
    head_lr = 1e-4
    weight_decay = 0.1
    dropout = 0.1
    grad_clip = 1
    scheduler = 'linear'
    warmup_ratio = 0.1
    num_early_stop = 3
    max_length = 128
    hidden_size = 400
    num_labels = 35
    print_every = 100
    save_every = 200
    eval_every = 1e9
    cuda = True
    fp16 = True
    debug = False

# Seed

In [4]:
def seed_everything(seed=CFG.random_seed):
    np.random.seed(seed%(2**32-1))
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic =True
    torch.backends.cudnn.benchmark = False

seed_everything()

# Data

In [5]:
inter_signals = {
                # '但': 36, '但是': 36, 
                '情况':36, '为什么': 36,
                '吗': 37, '?': 37, '什么': 37,
                # '呢': 37, '吧': 37,
                '请': 39, '麻烦':39, '希望': 39, '让': 39, '咨询': 39}     

In [6]:
origin4change = [rel2id[x] for x in ['root', 'dfsubj', 'sasubj']]
# origin4change.extend([i for i in range(21, 35)])

signal_dct = {}
for i, signals in enumerate(weak_signals):
    for s in signals:
        signal_dct[s] = weak_labels[i]
print(signal_dct, len(signal_dct.keys()))
signal_dct.update(inter_signals)
print(signal_dct, len(signal_dct.keys()))

{'说': 21, '表示': 21, '看到': 21, '显示': 21, '知道': 21, '认为': 21, '希望': 21, '指出': 21, '如果': 25, '假如': 25, '的话': 25, '若': 25, '如': 25, '要是': 25, '倘若': 25, '因为': 23, '所以': 23, '导致': 23, '因此': 23, '造成': 23, '由于': 23, '因而': 23, '但是': 26, '可是': 26, '但': 26, '竟': 26, '却': 26, '不过': 26, '居然': 26, '而是': 26, '反而': 26, '以及': 31, '也': 31, '并': 31, '并且': 31, '又': 31, '或者': 31, '对于': 22, '自从': 22, '上次': 22, '明天': 34, '晚上': 34, '到时候': 34, '再': 34, '然后': 34, '接下来': 34, '最后': 34, '随后': 34, '为了': 28, '使': 28, '为 的': 28, '为 了': 28, '通过': 32, '必须': 32, '点击': 32, '对 吗': 33, '是 吗': 33, '对 吧': 33, '是 吧': 33, '对 ?': 33, '别 的': 24, '另外': 24, '解释': 30, '比如': 30, '例如': 30, '是 这样': 30, '理想': 29, '真 棒': 29, '太 棒': 29, '真差': 29, '太 差': 29, '不 行': 29, '扯皮': 29, '这么 麻烦': 29} 74
{'说': 21, '表示': 21, '看到': 21, '显示': 21, '知道': 21, '认为': 21, '希望': 39, '指出': 21, '如果': 25, '假如': 25, '的话': 25, '若': 25, '如': 25, '要是': 25, '倘若': 25, '因为': 23, '所以': 23, '导致': 23, '因此': 23, '造成': 23, '由于': 23, '因而': 23, '但是': 26, '可是': 26, '但': 26, '

In [7]:
class Dependency():
    def __init__(self, idx, word, head, rel):
        self.id = idx
        self.word = word
        self.head = head
        self.rel = rel

    def __str__(self):
        # example:  1	上海	_	NR	NR	_	2	nn	_	_
        values = [str(self.idx), self.word, "_", "_", "_", "_", str(self.head), self.rel, "_", "_"]
        return '\t'.join(values)

    def __repr__(self):
        return f"({self.word}, {self.head}, {self.rel})"

In [8]:
sample_lst, sample = [], []
with open(CFG.data_file, 'r', encoding='utf-8') as f:
    for line in f.readlines():
        toks = line.strip().split('\t')
        if len(toks) == 1:
            sample_lst.append(sample)
            sample = []
            continue
        
        sample.append(Dependency(int(toks[0]), toks[1], toks[6], toks[7]))

In [9]:
# sample_lst = []
# with open(CFG.data_file, 'r', encoding='utf-8') as f:
#     data = json.load(f)
# for d in data:
#     rel_dct = {}
#     for tripple in d['relationship']:
#         head, rel, tail = tripple
#         head_uttr_idx, head_word_idx = [int(x) for x in head.split('-')]
#         tail_uttr_idx, tail_word_idx = [int(x) for x in tail.split('-')]
#         if head_uttr_idx != tail_uttr_idx:
#             continue

#         if not rel_dct.get(head_uttr_idx, None):
#             rel_dct[head_uttr_idx] = {tail_word_idx: [head_word_idx, rel]}
#         else:
#             rel_dct[head_uttr_idx][tail_word_idx] = [head_word_idx, rel]
            
#     for item in d['dialog']:
#         turn = item['turn']
#         utterance = item['utterance']
#         # dep_lst:List[Dependency] = [Dependency(0, '[root]', -1, '_')]
#         sample:List[Dependency] = []

#         for word_idx, word in enumerate(utterance.split(' ')):
#             head_word_idx, rel = rel_dct[turn].get(word_idx + 1, [word_idx, 'adjct'])  # some word annoted missed, padded with last word and 'adjct'
#             sample.append(Dependency(word_idx + 1, word, head_word_idx, rel))  # start from 1

#         sample_lst.append(sample)

In [10]:
labels = []
edu_lst = []
for sample in tqdm(sample_lst):
    edu, label = [], []
    for i, dep in enumerate(sample[1:]):
        word = dep.word
        rel = dep.rel
        
        edu.append(dep)
        
        if word in signal_dct.keys():
            label.append(word)
            if signal_dct[f'{word}'] not in [21, 29, 33] and random.random() < 0.7:
                edu.pop(-1)
        elif f'{sample[i-1].word} {word}' in signal_dct.keys():
            popped = False
            if sample[i-1].word in signal_dct.keys():
                label.pop(-1)
                if sample[i-1].word != edu[-1].word:  # last word have dropped
                    edu.pop(-1)
                    popped = True
            label.append(f'{sample[i-1].word}{word}')
            if ~popped and signal_dct[f'{sample[i-1].word} {word}'] not in [21, 29, 33] and random.random() < 0.7:  # drop
                edu.pop(-1)
                if sample[i-1].word in signal_dct.keys() and sample[i-1].word == edu[-1].word:  # not pop before
                    edu.pop(-1)
        elif random.random() < 0.2: # drop
            edu.pop(-1)
        
        # 'edu'
        if word in punct_lst:
            # if len(label) == 0 and random.random() < 0.2:  # drop
            #     edu, label = [], []
            #     continue
            edu_lst.append(edu)
            labels.append(label)                                
            edu, label = [], []
    
    if len(sample) > 1:
        edu.append(sample[-1])
        if sample[-1].word in signal_dct.keys():
            label.append(word)
    
    # if len(label) == 0 and random.random() < 0.2:
    #     continue
        
    edu_lst.append(edu)
    labels.append(label)

100%|██████████| 236590/236590 [00:04<00:00, 53804.46it/s]


In [11]:
pos_cnt, neg_cnt = 0, 0
for label in labels:
    if len(label) != 0:
        pos_cnt += 1
    else:
        neg_cnt += 1
        
print(pos_cnt, neg_cnt)

115136 267022


In [12]:
tokenizer = AutoTokenizer.from_pretrained(CFG.plm)
print(len(tokenizer))

num_added_toks = tokenizer.add_tokens(['[root]', '[qst]', '[ans]', '[none]'], special_tokens=True)

tokenizer.eos_token = '[EOS]'
print(f'EOS: {tokenizer.eos_token_id}')

tokenizer.root_token = '[root]'
tokenizer.root_token_ids = tokenizer('[root]')['input_ids'][1]
print(f"add token: {tokenizer.root_token} {tokenizer.root_token_ids}")

tokenizer.qst_token = '[qst]'
tokenizer.qst_token_ids = tokenizer('[qst]')['input_ids'][1]
print(f"add token: {tokenizer.qst_token} {tokenizer.qst_token_ids}")

tokenizer.ans_token = '[ans]'
tokenizer.ans_token_ids = tokenizer('[ans]')['input_ids'][1]
print(f"add token: {tokenizer.ans_token} {tokenizer.ans_token_ids}")

tokenizer.none_token = '[none]'
tokenizer.none_token_ids = tokenizer('[none]')['input_ids'][1]
print(f"add token: {tokenizer.none_token} {tokenizer.none_token_ids}")

# tokenizer.signal_token = '[signal]'
# tokenizer.signal_token_ids = tokenizer('[signal]')['input_ids'][1]
# print(f"add token: {tokenizer.signal_token} {tokenizer.signal_token_ids}")

print(len(tokenizer))
CFG.tokenizer = tokenizer

21128
EOS: 100
add token: [root] 21128
add token: [qst] 21129
add token: [ans] 21130
add token: [none] 21131
21132


In [13]:
label2id = {}
for l in signal_dct.keys():
    label_word = ''.join(l.split(' '))
    label2id[label_word] = CFG.tokenizer(label_word)['input_ids'][1:-1]
print(label2id)

{'说': [6432], '表示': [6134, 4850], '看到': [4692, 1168], '显示': [3227, 4850], '知道': [4761, 6887], '认为': [6371, 711], '希望': [2361, 3307], '指出': [2900, 1139], '如果': [1963, 3362], '假如': [969, 1963], '的话': [4638, 6413], '若': [5735], '如': [1963], '要是': [6206, 3221], '倘若': [951, 5735], '因为': [1728, 711], '所以': [2792, 809], '导致': [2193, 5636], '因此': [1728, 3634], '造成': [6863, 2768], '由于': [4507, 754], '因而': [1728, 5445], '但是': [852, 3221], '可是': [1377, 3221], '但': [852], '竟': [4994], '却': [1316], '不过': [679, 6814], '居然': [2233, 4197], '而是': [5445, 3221], '反而': [1353, 5445], '以及': [809, 1350], '也': [738], '并': [2400], '并且': [2400, 684], '又': [1348], '或者': [2772, 5442], '对于': [2190, 754], '自从': [5632, 794], '上次': [677, 3613], '明天': [3209, 1921], '晚上': [3241, 677], '到时候': [1168, 3198, 952], '再': [1086], '然后': [4197, 1400], '接下来': [2970, 678, 3341], '最后': [3297, 1400], '随后': [7390, 1400], '为了': [711, 749], '使': [886], '为的': [711, 4638], '通过': [6858, 6814], '必须': [2553, 7557], '点击': [4157, 1140], '对吗'

In [14]:
# class EduDataset(Dataset):
#     def __init__(self, cfg, train, prompt, edu_lst, labels):
#         self.train = train
#         self.cfg = cfg
#         self.prompt = prompt
#         self.inputs, self.heads, self.rels, self.labels, self.masks = self.read_data(edu_lst, labels)
        
#     def read_data(self, edu_lst, labels):
#         inputs, offsets = [], []
#         tags, heads, rels, masks = [], [], [], []
        
#         ans_word_len = 3
#         answer_words = []
        
#         for deps, label in tqdm(zip(edu_lst, labels)):
#             seq_len = len(deps)

#             word_lst = [] 
#             rel_attr = {'input_ids':torch.Tensor(), 'token_type_ids':torch.Tensor(), 'attention_mask':torch.Tensor()}
#             head_tokens = np.zeros(self.cfg.max_length, dtype=np.int64)  # same as root index is 0, constrainting by mask 
#             rel_tokens = np.zeros(self.cfg.max_length, dtype=np.int64)
#             mask_tokens = np.zeros(self.cfg.max_length, dtype=np.int64)
#             ans_tokens = torch.zeros((ans_word_len, len(self.cfg.tokenizer)))
#             for i, dep in enumerate(deps):
#                 if i == seq_len or i + 1== self.cfg.max_length:
#                     break

#                 word_lst.append(dep.word)

#                 if dep.head in ['_', '-1'] or int(dep.head) + 1 >= self.cfg.max_length:
#                     head_tokens[i+1] = 0
#                     mask_tokens[i+1] = 0
#                 else:
#                     head_tokens[i+1] = int(dep.head)
#                     mask_tokens[i+1] = 1

#                 if self.train:
#                     rel_tokens[i+1] = rel2id[dep.rel]
#                 else:
#                     rel_tokens[i+1] = rel2id.get(dep.rel, rel2id['adjct'])

#             word_lst = self.prompt.split(' ') + ['[SEP]'] + word_lst
            
#             tokenized = self.cfg.tokenizer.encode_plus(word_lst, 
#                                               padding='max_length', 
#                                               truncation=True,
#                                               max_length=self.cfg.max_length, 
#                                               return_offsets_mapping=False, 
#                                               return_tensors='pt',
#                                               is_split_into_words=True)
#             inputs.append({"input_ids": tokenized['input_ids'][0],
#                           "token_type_ids": tokenized['token_type_ids'][0],
#                            "attention_mask": tokenized['attention_mask'][0]
#                           })
            
#             if len(label) == 0:
#                 ans_tokens[0:ans_word_len, self.cfg.tokenizer.none_token_ids] = 1
#             else: 
#                 # label_id = self.cfg.tokenizer(l)['input_ids'][1:-1]
#                 label_id = label2id[label[0]]  # first label
#                 if len(label_id) > ans_word_len:
#                     label_id = label_id[:ans_word_len]
#                 ans_tokens[torch.arange(len(label_id)), label_id] = 1
#                 if len(label_id) < ans_word_len:
#                     ans_tokens[torch.arange(len(label_id), ans_word_len), self.cfg.tokenizer.eos_token_id] = 1  # padding
            
#             answer_words.append(ans_tokens)

#             heads.append(head_tokens)
#             rels.append(rel_tokens)
#             masks.append(mask_tokens)
                    
#         return inputs, heads, rels, answer_words, masks

#     def __getitem__(self, idx):
#         return self.inputs[idx], self.heads[idx], self.rels[idx], self.labels[idx], self.masks[idx]
    
#     def __len__(self):
#         return len(self.rels)

In [15]:
class EduDataset(IterableDataset):
    def __init__(self, cfg, train, prompt, edu_lst, labels):
        self.train = train
        self.cfg = cfg
        self.prompt = prompt
        self.edu_lst = edu_lst
        self.labels = labels
        
    def read_data(self, edu_lst, labels):
        inputs, offsets = [], []
        tags, heads, rels, masks = [], [], [], []
        
        ans_word_len = 3
        answer_words = []
        
        for deps, label in zip(edu_lst, labels):
            seq_len = len(deps)

            word_lst = [] 
            rel_attr = {'input_ids':torch.Tensor(), 'token_type_ids':torch.Tensor(), 'attention_mask':torch.Tensor()}
            head_tokens = np.zeros(self.cfg.max_length, dtype=np.int64)  # same as root index is 0, constrainting by mask 
            rel_tokens = np.zeros(self.cfg.max_length, dtype=np.int64)
            mask_tokens = np.zeros(self.cfg.max_length, dtype=np.int64)
            ans_tokens = torch.zeros((ans_word_len, len(self.cfg.tokenizer)))
            for i, dep in enumerate(deps):
                if i == seq_len or i + 1== self.cfg.max_length:
                    break

                word_lst.append(dep.word)

                if dep.head in ['_', '-1'] or int(dep.head) + 1 >= self.cfg.max_length:
                    head_tokens[i+1] = 0
                    mask_tokens[i+1] = 0
                else:
                    head_tokens[i+1] = int(dep.head)
                    mask_tokens[i+1] = 1

                if self.train:
                    rel_tokens[i+1] = rel2id[dep.rel]
                else:
                    rel_tokens[i+1] = rel2id.get(dep.rel, rel2id['adjct'])

            word_lst = self.prompt.split(' ') + ['[SEP]'] + word_lst
            
            tokenized = self.cfg.tokenizer.encode_plus(word_lst, 
                                              padding='max_length', 
                                              truncation=True,
                                              max_length=self.cfg.max_length, 
                                              return_offsets_mapping=False, 
                                              return_tensors='pt',
                                              is_split_into_words=True)
            tokenized = {"input_ids": tokenized['input_ids'][0],
                          "token_type_ids": tokenized['token_type_ids'][0],
                           "attention_mask": tokenized['attention_mask'][0]
                          }
            
            if len(label) == 0:
                ans_tokens[0:ans_word_len, self.cfg.tokenizer.none_token_ids] = 1
            else: 
                # label_id = self.cfg.tokenizer(l)['input_ids'][1:-1]
                label_id = label2id[label[0]]  # first label
                if len(label_id) > ans_word_len:
                    label_id = label_id[:ans_word_len]
                ans_tokens[torch.arange(len(label_id)), label_id] = 1
                if len(label_id) < ans_word_len:
                    ans_tokens[torch.arange(len(label_id), ans_word_len), self.cfg.tokenizer.eos_token_id] = 1  # padding
                
                for l in enumerate(label):
                    if l not in inter_signals.keys():
                        continue
                    label_id = label2id[label[0]]  # first label
                    if len(label_id) > ans_word_len:
                        label_id = label_id[:ans_word_len]
                    ans_tokens[torch.arange(len(label_id)), label_id] = 1
                    if len(label_id) < ans_word_len:
                        ans_tokens[torch.arange(len(label_id), ans_word_len), self.cfg.tokenizer.eos_token_id] = 1  # padding    
                        
                    if l != label[0]:  # two different label
                        ans_tokens[torch.arange(len(label_id), ans_word_len), :] /= 2
                    break
            
#             answer_words.append(ans_tokens)

#             heads.append(head_tokens)
#             rels.append(rel_tokens)
#             masks.append(mask_tokens)
                    
            yield tokenized, head_tokens, rel_tokens, ans_tokens, mask_tokens

    def __iter__(self):
        return self.read_data(self.edu_lst, self.labels)

In [16]:
prompt = '能 表现 篇章 语义 信号 的 词 是 ： [MASK] [MASK] [MASK]'
prompt_tokenized = CFG.tokenizer.encode_plus(prompt.split(' '), is_split_into_words=True, return_tensors='pt')
print(prompt_tokenized['input_ids'])

masked_idx = [x[1].item() for x in (prompt_tokenized['input_ids'] == CFG.tokenizer.mask_token_id).nonzero()]
CFG.masked_idx = masked_idx
print(masked_idx)

tensor([[ 101, 5543, 6134, 4385, 5063, 4995, 6427,  721,  928, 1384, 4638, 6404,
         3221, 8038,  103,  103,  103,  102]])
[14, 15, 16]


In [17]:
if CFG.debug:
    dataset = EduDataset(CFG, True, prompt, edu_lst[:3000], labels[:3000])
else:
    dataset = EduDataset(CFG, True, prompt, edu_lst, labels)

In [18]:
# logit = torch.randn(3, len(CFG.tokenizer))
# gt = dataset[0][-2]

# print(logit.shape)
# print(gt.shape)

# F.cross_entropy(logit, gt)

# Model

In [19]:
class PromptModel(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
    
        self.encoder = AutoModel.from_pretrained(cfg.plm)
        self.encoder.resize_token_embeddings(len(cfg.tokenizer))
        
        # for param in self.encoder.parameters():
        #     param.requires_grad = False

        self.mlm_head = nn.Linear(self.encoder.config.hidden_size, len(cfg.tokenizer))

        self.dropout = nn.Dropout(cfg.dropout)

    def forward(self, inputs, labels=None):
        feats, *_ = self.encoder(**inputs, return_dict=False)   # batch_size, seq_len (tokenized), plm_hidden_size
        feats = self.dropout(feats)

        masked_feats = feats[:, self.cfg.masked_idx, :]   # batch_size, masked_len, plm_hidden_size

        logit = self.mlm_head(masked_feats)  # batch_size, masked_len, vocab_size

        if labels is not None: # labels: [batch_size, masked_len, vocab_size]
            loss = F.cross_entropy(logit.view(-1, labels.size()[-1]), labels.view(-1, labels.size()[-1]), reduction='none')
            loss = loss.view(labels.size()[0], labels.size()[1])
            loss = loss.sum(1).mean()
            return logit, loss
        return logit

In [20]:
model = PromptModel(CFG)

Some weights of the model checkpoint at hfl/chinese-electra-180g-base-discriminator were not used when initializing ElectraModel: ['discriminator_predictions.dense_prediction.weight', 'discriminator_predictions.dense.bias', 'discriminator_predictions.dense.weight', 'discriminator_predictions.dense_prediction.bias']
- This IS expected if you are initializing ElectraModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ElectraModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


# Trainer

In [21]:
class MyTrainer():
    def __init__(self, 
                 model,
                 trainset_size,
                 config: Dict) -> None:

        plm_params = [p for n,p in model.named_parameters() if 'encoder' in n]
        head_params = [p for n,p in model.named_parameters() if 'encoder' not in n]
        self.optim = AdamW([{'params': plm_params, 'lr':config.plm_lr}, 
                            {'params': head_params, 'lr':config.head_lr}], 
                            lr=config.plm_lr,
                            weight_decay=config.weight_decay
                          )
        
        training_step = int(config.num_epochs * (trainset_size / config.batch_size))
        warmup_step = int(config.warmup_ratio * training_step)  
        self.optim_schedule = get_linear_schedule_with_warmup(optimizer=self.optim, 
                                                              num_warmup_steps=warmup_step, 
                                                              num_training_steps=training_step)
        self.scaler = torch.cuda.amp.GradScaler(enabled=config.fp16)

        self.config = config

    def train(self, 
              model: nn.Module, 
              train_iter: DataLoader, 
              val_iter: DataLoader):
        model.train()
        if self.config.cuda and torch.cuda.is_available():
            model.cuda()
            pass
        
        best_res = [1e4, 0, 0]
        early_stop_cnt = 0
        best_state_dict = None
        step = 0
        for epoch in tqdm(range(self.config.num_epochs)):
            for batch in train_iter:
                inputs, heads, rels, labels, masks = batch
                
                if self.config.cuda and torch.cuda.is_available():
                    inputs_cuda = {}
                    for key, value in inputs.items():
                        inputs_cuda[key] = value.cuda()
                    inputs = inputs_cuda
                    
                    heads, rels, labels, masks = to_cuda(data=(heads, rels, labels, masks))
                
                logit, loss = model(inputs, labels)
                
                self.optim.zero_grad()
                if self.config.cuda and self.config.fp16:
                    self.scaler.scale(loss).backward()
                    self.scaler.unscale_(self.optim)
                else:
                    loss.backward()

                nn.utils.clip_grad_norm_(filter(lambda p: p.requires_grad, model.parameters()), max_norm=self.config.grad_clip)

                if self.config.fp16:
                    self.scaler.step(self.optim)
                    self.scaler.update()
                else:
                    self.optim.step()
                self.optim_schedule.step()

                # metrics = self.metrics_fn(arc_logits, rel_logits, heads, rels, masks)

                if (step) % self.config.print_every == 0:
                    print(f"--epoch {epoch}, step {step}, loss {loss}")
                    # print(f"  {metrics}")
                    
                if (step + 1) % self.config.save_every == 0:
                    print(f'-------saving---------')
                    torch.save(model.state_dict(), f"../results/prompt_model_diag.pt")
                    # print(f"  {metrics}")

                if val_iter and (step) % self.config.eval_every == 0:
                    avg_loss = self.eval(model, val_iter)
                    res = [avg_loss]
                    if loss < best_res[0]:  # las
                        best_res = res
                        best_state_dict = model.state_dict()
                        torch.save(best_state_dict, f"../results/prompt_model_diag.pt")
                        early_stop_cnt = 0
                    else:
                        early_stop_cnt += 1
                    
                    print("--Best Evaluation: ")
                    print("-Loss: {}\n".format(*best_res))
                    # back to train mode
                    model.train()
                
                if early_stop_cnt >= self.config.num_early_stop:
                    print("--early stopping, training finished.")
                    return best_res, best_state_dict

                step += 1
        
        print("--training finished.")
        
        if best_state_dict is not None:
            return best_res, best_state_dict
        
        return best_res, model.state_dict()

    # eval func
    def eval(self, model: nn.Module, eval_iter: DataLoader, save_file: str = "", save_title: str = ""):
        model.eval()

        head_whole, rel_whole, mask_whole = torch.Tensor(), torch.Tensor(), torch.Tensor()
        logit_whole = torch.Tensor()
        avg_loss = 0.0
        for step, batch in enumerate(eval_iter):
            inputs, heads, rels, labels, masks = batch

            if self.config.cuda and torch.cuda.is_available():
                inputs_cuda = {}
                for key, value in inputs.items():
                    inputs_cuda[key] = value.cuda()
                inputs = inputs_cuda

                heads, rels, labels, masks = to_cuda(data=(heads, rels, labels, masks))
            
            with torch.no_grad():
                logit, loss = model(inputs, labels)

            logit_whole = torch.cat([logit_whole, logit.cpu()], dim=0)

            head_whole, rel_whole = torch.cat([head_whole, heads.cpu()], dim=0), torch.cat([rel_whole, rels.cpu()], dim=0)
            mask_whole = torch.cat([mask_whole, masks.cpu()], dim=0)

            avg_loss += loss.item() * len(heads)  # times the batch size of data


        avg_loss /= len(eval_iter.dataset)  # type: ignore

        print("--Evaluation:")
        print("Avg Loss: {}  \n".format(avg_loss))

        if save_file != "":
            results = [save_title, avg_loss, uas, las]  # type: ignore
            results = [str(x) for x in results]
            with open(save_file, "a+") as f:
                f.write(",".join(results) + "\n")  # type: ignore

        return avg_loss
    
    def save_results(self, save_file, save_title, results):
        saves = [save_title] + results
        saves = [str(x) for x in saves]
        with open(save_file, "a+") as f:
            f.write(",".join(saves) + "\n")  # type: ignore

# Training

In [22]:
# full_ids = list(range(len(dataset)))
# random.shuffle(full_ids)

# ratio = 0.0
# trainset_size = int(len(dataset) * (1 - ratio))
# train_ids = full_ids[:trainset_size]
# val_ids = full_ids[trainset_size:]

# train_dataset = Subset(dataset, train_ids)
# val_dataset = Subset(dataset, val_ids)

In [23]:
train_iter = DataLoader(dataset, batch_size=CFG.batch_size)
val_iter = DataLoader(dataset, batch_size=CFG.batch_size)

In [24]:
trainer = MyTrainer(model=model, trainset_size=len(edu_lst), config=CFG)
best_res, best_state_dict = trainer.train(model=model, train_iter=train_iter, val_iter=None)

  0%|          | 0/1 [00:00<?, ?it/s]

--epoch 0, step 0, loss 30.07476043701172
--epoch 0, step 100, loss 17.39851188659668
-------saving---------
--epoch 0, step 200, loss 5.753630638122559
--epoch 0, step 300, loss 2.9933180809020996
-------saving---------
--epoch 0, step 400, loss 1.7638427019119263
--epoch 0, step 500, loss 2.4559969902038574
-------saving---------
--epoch 0, step 600, loss 1.8246821165084839
--epoch 0, step 700, loss 1.4699232578277588
-------saving---------
--epoch 0, step 800, loss 1.3855056762695312
--epoch 0, step 900, loss 1.147589087486267
-------saving---------
--epoch 0, step 1000, loss 1.1596451997756958
--epoch 0, step 1100, loss 1.016486406326294
-------saving---------
--epoch 0, step 1200, loss 1.1765527725219727
--epoch 0, step 1300, loss 0.6073367595672607
-------saving---------
--epoch 0, step 1400, loss 0.5795507431030273
--epoch 0, step 1500, loss 1.0257056951522827
-------saving---------
--epoch 0, step 1600, loss 0.6995720267295837
--epoch 0, step 1700, loss 0.5993667840957642
-----

100%|██████████| 1/1 [35:27<00:00, 2127.36s/it]

--training finished.





In [25]:
trainer.save_results(save_file='../results/res.txt', save_title='', results=best_res)
torch.save(best_state_dict, f"../results/prompt_model_diag.pt")

In [None]:
import os
os.system('shutdown')