In [1]:
import os
from collections import Counter
from typing import *
import random
import json

In [2]:
from tqdm import tqdm
import numpy as np
from sklearn.model_selection import KFold

import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader, Subset
from torch.cuda.amp import autocast, GradScaler

from transformers import AutoConfig, AutoModel, AutoTokenizer, AdamW, get_constant_schedule_with_warmup, get_linear_schedule_with_warmup

In [3]:
from utils import to_cuda

# Seed

In [4]:
def seed_everything(seed=42):
    np.random.seed(seed%(2**32-1))
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic =True
    torch.backends.cudnn.benchmark = False

seed_everything()

# Data Preparation

In [5]:
def prepare_data():
    file = 'data/train_proc.json'
    with open(file, 'r', encoding='utf-8') as f:
        data = json.load(f)   
        
    punct_lst = ["，", ".", "。", "!", "?", "~", "...", "......"]
    single_max_len = 128
    concat_max_len = 64
    long_len = 12
    long_keep_prob = 0.6
    concat_prob = 0.5
    punct_probs = {
        "": 0.25,
        "，": 0.55,
        "。": 0.05,
        "!": 0.05,
        "?": 0.05,
        "~": 0.05,
    }
    special_tokens = {
        "Q": "[qst]",
        "A": "[ans]",
    }
    
    sample_lst = []
    for d in data:
        concat = {'Q': {
            'word_lst': [], 'split_ids': []
        }, 
                  'A': {
            'word_lst': [], 'split_ids': []
        }}

        for turn, dialog in enumerate(d.get('dialog')):
            speaker, utterance = dialog.get('speaker'), dialog.get('utterance')
            word_lst = utterance.split(' ')
            if len(word_lst) >= single_max_len:
                continue    
        
            if len(word_lst) >= long_len:
                if random.random() < long_keep_prob:
                    split_lst = []
                    for idx, word in enumerate(word_lst[:-1]):
                        if word_lst[idx+1] not in punct_lst and random.random() < punct_probs.get(word, 0.0):
                            split_lst.append(idx + 1)
                    if len(split_lst) == 0 or split_lst[-1] != len(word_lst):
                        split_lst.append(len(word_lst))
                    sample_lst.append({'word_lst': [special_tokens[speaker]] + word_lst, 'split_ids': split_lst})
                if len(concat[speaker]['word_lst']) != 0:
                    sample = {'word_lst': [special_tokens[speaker]] + concat[speaker]['word_lst'], 'split_ids': concat[speaker]['split_ids']}
                    sample_lst.append(sample)
                    concat[speaker] = {'word_lst': [], 'split_ids': []}
                continue

            if word_lst[-1] not in punct_lst:
                punct = random.choices(list(punct_probs.keys()), weights=punct_probs.values(), k=1)
                if punct != ['']:
                    word_lst += punct
            
            all_punct = True
            for word in word_lst:
                if word not in punct_lst:
                    all_punct = False
                    break

            if all_punct:
                # pop last split index
                if len(concat[speaker]['split_ids']) != 0:
                    concat[speaker]['split_ids'].pop()
                else:
                    continue

            concat[speaker]['word_lst'].extend(word_lst)
            concat[speaker]['split_ids'].append(len(concat[speaker]['word_lst']))
            
            # add sample
            if len(concat[speaker]['word_lst']) >= concat_max_len - 1 or random.random() > concat_prob:
                if concat[speaker]['word_lst'][-1] == '，':  # '，' is seldom appear at the end of a sentence
                    concat[speaker]['word_lst'] = concat[speaker]['word_lst'][:-1]
                    concat[speaker]['split_ids'][-1] -= 1
                sample = {'word_lst': [special_tokens[speaker]] + concat[speaker]['word_lst'], 'split_ids': concat[speaker]['split_ids']}
                sample_lst.append(sample)
                concat[speaker] = {'word_lst': [], 'split_ids': []}  
                     
    print(len(sample_lst))
    
    ret = []
    for sample in sample_lst:
        if len(sample['word_lst']) <= concat_max_len:
            ret.append(sample)
    
    sample_lst = ret
    del ret
    print(len(sample_lst))
    
    with open('data/seg_with_train.json', 'w', encoding='utf-8') as f:
        save_str = json.dumps(sample_lst, ensure_ascii=False, indent=4, separators=(',', ': '))
        f.write(save_str)

In [6]:
prepare_data()

127218
126800


## Config

In [7]:
class CFG:
    data_file = 'data/seg_with_train.json'
    # data_file = 'data_testset/1to500_1012.json'
    plm = 'hfl/chinese-electra-180g-large-discriminator'
    num_folds = 5
    trn_folds = [0]  # only one fold, as splitting train/val randomly
    # trn_folds = [0, 1, 2, 3, 4]
    random_seed = 42
    num_epochs = 3
    batch_size = 48
    max_length = 128
    # batch_size = 32
    # max_length = 160
    num_labels = 2
    lr = 2e-5
    weight_decay = 0.01
    dropout = 0.2
    grad_clip = 1
    scheduler = 'linear'
    warmup_ratio = 0.1
    num_early_stop = 5
    hidden_size = 400
    print_every = 500  
    eval_every = 1000
    # print_every = 200  
    # eval_every = 300
    cuda = True
    fp16 = True
    debug = False

## Data

In [8]:
with open(CFG.data_file, 'r', encoding='utf-8') as f:
    data = json.load(f)

# with open(CFG.data_file, 'r', encoding='utf-8') as f:
#     data = json.load(f)
    
# data = data[:500]  # 500 data have been annotated

In [9]:
tokenizer = AutoTokenizer.from_pretrained(CFG.plm)
print(len(tokenizer))
 
num_added_toks = tokenizer.add_tokens(['[root]', '[qst]', '[ans]'], special_tokens=True)
tokenizer.root_token = '[root]'
tokenizer.root_token_ids = tokenizer('[root]')['input_ids'][1]
print(f"add token: {tokenizer.root_token} {tokenizer.root_token_ids}")

tokenizer.qst_token = '[qst]'
tokenizer.qst_token_ids = tokenizer('[qst]')['input_ids'][1]
print(f"add token: {tokenizer.qst_token} {tokenizer.qst_token_ids}")

tokenizer.ans_token = '[ans]'
tokenizer.ans_token_ids = tokenizer('[ans]')['input_ids'][1]
print(f"add token: {tokenizer.ans_token} {tokenizer.ans_token_ids}")
print(len(tokenizer))

CFG.tokenizer = tokenizer

21128
add token: [root] 21128
add token: [qst] 21129
add token: [ans] 21130
21131


In [10]:
class EduDataset(Dataset):
    def __init__(self, cfg, data):
        self.cfg = cfg
        self.data = data
        self.inputs, self.offsets, self.tags = self.load_data()
    
    def load_data(self):
        inputs, offsets, tags = [], [], []
            
        for sample in tqdm(self.data):
            word_lst, split_ids = sample['word_lst'], sample['split_ids']
            tokenized = self.cfg.tokenizer.encode_plus(word_lst, 
                                                       padding='max_length', 
                                                       truncation=True,
                                                       max_length=self.cfg.max_length + 2,  # reserved for cls and sep
                                                       return_offsets_mapping=True, 
                                                       return_tensors='pt',
                                                       is_split_into_words=True)
            
            inputs.append({"input_ids": tokenized['input_ids'][0],
                           "token_type_ids": tokenized['token_type_ids'][0],
                           "attention_mask": tokenized['attention_mask'][0]
                          })
            
            sentence_word_idx = []
            for idx, (start, end) in enumerate(tokenized.offset_mapping[0][1:]):
                if start == 0 and end != 0:
                    sentence_word_idx.append(idx)
            if len(sentence_word_idx) < self.cfg.max_length - 1:
                sentence_word_idx.extend([0]* (self.cfg.max_length - len(sentence_word_idx)))
            offsets.append(torch.as_tensor(sentence_word_idx))
            
            # ignore cls for convenience
            tag = torch.full(size=(CFG.max_length, ), fill_value=-1, dtype=torch.long)
            tag[0:split_ids[-1]+1] = 0
            tag[split_ids[:-1]] = 1
            
            tags.append(tag)  
            
        return inputs, offsets, tags
    
    def __getitem__(self, idx):
        return self.inputs[idx], self.offsets[idx], self.tags[idx]
    
    def __len__(self):
        return len(self.inputs)

In [11]:
tokenized = tokenizer.encode_plus(data[0]['word_lst'], 
                                   padding='max_length', 
                                   truncation=True,
                                   max_length=CFG.max_length + 2,  # reserved for cls and sep
                                   return_offsets_mapping=True, 
                                   return_tensors='pt',
                                   is_split_into_words=True)

In [12]:
tokenized['input_ids']

tensor([[  101, 21130,   779,   779,  8024,  3300,   784,   720,  7309,  7579,
          2769,  1377,   809,  2376,  2644,  1905,  4415,  2772,  6237,  1104,
          4638,  1450,   136,   102,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,  

In [13]:
# class Dependency():
#     def __init__(self, idx, word, head, rel):
#         self.id = idx
#         self.word = word
#         self.tag = '_'
#         self.head = head
#         self.rel = rel

#     def __str__(self):
#         # example:  1	上海	_	NR	NR	_	2	nn	_	_
#         values = [str(self.idx), self.word, "_", self.tag, "_", "_", str(self.head), self.rel, "_", "_"]
#         return '\t'.join(values)

#     def __repr__(self):
#         return f"({self.word}, {self.tag}, {self.head}, {self.rel})"

# rst_dct = {
#     'attr': '归属',
#     'bckg': '背景',
#     'cause': '因果',
#     'comp': '比较',
#     'cond': '状况',
#     'cont': '对比',
#     'elbr': '阐述',
#     'enbm': '目的',
#     'eval': '评价',
#     'expl': '解释-例证',
#     'joint': '联合',
#     'manner': '方式',
#     'rstm': '重申',
#     'temp': '时序',
#     'tp-chg': '主题变更',
#     'prob-sol': '问题-解决',
#     'qst-ans': '疑问-回答',
#     'stm-rsp': '陈述-回应',
#     'req-proc': '需求-处理',
# }

# rst_lst = [x for x in rst_dct.keys()]
# print(rst_lst)

# def data_gen(data):
#     special_tokens = {
#         "Q": "[qst]",
#         "A": "[ans]",
#     }    

#     for i, d in enumerate(data):
#         rel_dct = {}
#         for tripple in d['relationship']:
#             head, rel, tail = tripple
#             head_uttr_idx, head_word_idx = [int(x) for x in head.split('-')]
#             tail_uttr_idx, tail_word_idx = [int(x) for x in tail.split('-')]
#             if head_uttr_idx != tail_uttr_idx:
#                 continue

#             if not rel_dct.get(head_uttr_idx, None):
#                 rel_dct[head_uttr_idx] = {tail_word_idx: [head_word_idx, rel]}
#             else:
#                 rel_dct[head_uttr_idx][tail_word_idx] = [head_word_idx, rel]


#         for item in d['dialog']:
#             turn = item['turn']
#             speaker = item['speaker']
#             utterance = item['utterance']
#             dep_lst:List[Dependency] = []

#             edus = [[-1, -1]]  # 0 for root
#             inner_rels, inter_rels = [], []
#             for word_idx, word in enumerate(utterance.split(' ')):
#                 head_word_idx, rel = rel_dct[turn].get(word_idx + 1, [word_idx, 'adjct'])  # some word annoted missed, padded with last word and 'adjct'

#                 if rel in rst_lst:
#                     inter_rels.append(rel)
#                     edus.append([word_idx + 1, word_idx + 1])
#                     continue

#                 expand, include = False, False
#                 for edu in edus:
#                     if edu[0] <= head_word_idx <= edu[1]:
#                         edu[0] = min(edu[0], word_idx + 1)
#                         edu[1] = max(edu[1], word_idx + 1)
#                         expand = True
#                         break
#                     elif edu[0] <= word_idx + 1 <= edu[1] and head_word_idx != 0:
#                         edu[0] = min(edu[0], head_word_idx)
#                         edu[1] = max(edu[1], head_word_idx)
#                         include = True
#                         break

#                 if not expand and not include:
#                     if head_word_idx == 0:  # ignore root
#                         edus.append([word_idx + 1, word_idx + 1])
#                     else:
#                         # max with 1 to ignore root index
#                         edus.append([max(min(word_idx + 1, head_word_idx), 1), max(word_idx + 1, head_word_idx)])
#                 inner_rels.append(rel)

#                 dep_lst.append(Dependency(word_idx + 1, word, head_word_idx, rel))  # start from 1

#             finals = [edus[1]]
#             if len(edus) >= 2:
#                 for edu in edus[2:]:
#                     if finals[-1][0] <= edu[0] <= edu[1] <= finals[-1][1]:
#                         continue
#                     finals.append(edu)

#             split_ids = [x[1] for x in finals]

#             yield [special_tokens[speaker]] + utterance.split(' '), split_ids

In [14]:
# class EduDataset(Dataset):
#     def __init__(self, cfg, data):
#         self.cfg = cfg
#         self.data = data
#         self.inputs, self.offsets, self.tags = self.load_data()
    
#     def load_data(self):
#         inputs, offsets, tags = [], [], []
            
#         for word_lst, split_ids in tqdm(data_gen(self.data)):
#             tokenized = self.cfg.tokenizer.encode_plus(word_lst, 
#                                                        padding='max_length', 
#                                                        truncation=True,
#                                                        max_length=self.cfg.max_length + 2,  # reserved for cls and sep
#                                                        return_offsets_mapping=True, 
#                                                        return_tensors='pt',
#                                                        is_split_into_words=True)
            
#             inputs.append({"input_ids": tokenized['input_ids'][0],
#                            "token_type_ids": tokenized['token_type_ids'][0],
#                            "attention_mask": tokenized['attention_mask'][0]
#                           })
            
#             sentence_word_idx = []
#             for idx, (start, end) in enumerate(tokenized.offset_mapping[0][1:]):
#                 if start == 0 and end != 0:
#                     sentence_word_idx.append(idx)
#             if len(sentence_word_idx) < self.cfg.max_length - 1:
#                 sentence_word_idx.extend([0]* (self.cfg.max_length - len(sentence_word_idx)))
#             offsets.append(torch.as_tensor(sentence_word_idx))
            
#             # ignore cls for convenience
#             tag = torch.full(size=(CFG.max_length, ), fill_value=-1, dtype=torch.long)
#             tag[0:split_ids[-1]] = 0
#             tag[split_ids[:-1]] = 1
            
#             tags.append(tag)  
            
#         return inputs, offsets, tags
    
#     def __getitem__(self, idx):
#         return self.inputs[idx], self.offsets[idx], self.tags[idx]
    
#     def __len__(self):
#         return len(self.inputs)

In [15]:
if CFG.debug:
    data = data[:3000]
    CFG.print_every = 20
    CFG.eval_every = 40

In [16]:
dataset = EduDataset(CFG, data)

100%|██████████| 126800/126800 [04:19<00:00, 488.97it/s]


In [17]:
class SegModel(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        
        self.cfg = cfg
        
        self.encoder = AutoModel.from_pretrained(cfg.plm)
        self.encoder.resize_token_embeddings(len(cfg.tokenizer))
        
        # self.gru = nn.GRU(self.encoder.config.hidden_size, self.encoder.config.hidden_size//2, bidirectional=True, num_layers=1, batch_first=True)
        
        self.mlp = nn.Linear(self.encoder.config.hidden_size, cfg.num_labels)
        self.dropout = nn.Dropout(cfg.dropout)
        
    def feat(self, inputs, offsets):
        length = torch.sum(inputs["attention_mask"], dim=-1) - 2
        
        feats, *_ = self.encoder(**inputs, return_dict=False)   # batch_size, seq_len (tokenized), plm_hidden_size
           
        # remove [CLS] [SEP]
        cls_feat = feats[:, :1]
        char_feat = torch.narrow(feats, 1, 1, feats.size(1) - 2)
        return cls_feat, char_feat, length
        
    def forward(self, inputs, offsets, tags=None):
        cls_feat, char_feat, char_len = self.feat(inputs, offsets)
        
        word_idx = offsets.unsqueeze(-1).expand(-1, -1, char_feat.shape[-1])  # expand to the size of char feat
        word_feat = torch.gather(char_feat, dim=1, index=word_idx)  # embeddings of first char in each word
        
        # feats, _ = self.gru(self.dropout(word_feat))
        
        feats = self.dropout(word_feat)
        logits = self.mlp(feats)
        
        if tags is not None:
            loss = nn.CrossEntropyLoss(reduction='mean', ignore_index=-1)(logits.view(-1, logits.size(-1)), tags.long().view(-1))
            return torch.softmax(logits, dim=-1), loss

        return torch.softmax(logits, dim=-1)

In [18]:
# model = SegModel(CFG)

In [19]:
# data_iter = DataLoader(dataset)

In [20]:
# for i, batch in enumerate(data_iter):
#     print(batch[0])
#     print(tokenizer.decode(batch[0]['input_ids'][0]))
#     print(batch[2])
    
#     if i == 1:
#         break

In [21]:
def metrics_fn(predictions, labels):
    predictions = predictions.view(-1, predictions.size(-1))
    labels = labels.view(-1)
    
    pred_detected = (predictions[:, 1] > 0.5)
    label_detected = (labels == 1)
    co = pred_detected * (pred_detected == label_detected)
    
    if sum(pred_detected) == 0:
        precision = 0
    else:
        precision = sum(co) / sum(pred_detected)
    recall = sum(co) / sum(label_detected)
    f1 = (2 * precision * recall) / (precision + recall)

    return {"f1": f1.item()}

In [22]:
# metrics_fn(predictions=logits, labels=tags)

In [23]:
# for inputs, offsets, tags in data_iter:
#     logits, loss = model(inputs, offsets, tags)
    
#     metrics = metrics_fn(predictions=logits, labels=tags)
#     print(metrics)

In [24]:
class MyTrainer():
    def __init__(self, 
                 optim,
                 lr_scheduler,
                 trainset_size,
                 metrics_fn: Callable, 
                 config: Dict) -> None:

        self.optim = optim
        self.optim_schedule = lr_scheduler
        
        self.metrics_fn = metrics_fn
        
        self.scaler = torch.cuda.amp.GradScaler(enabled=config.fp16)

        self.config = config

    def train(self, 
              model: nn.Module, 
              train_iter: DataLoader, 
              val_iter: DataLoader):
        model.train()
        if self.config.cuda and torch.cuda.is_available():
            model.cuda()
            pass
        
        best_res = [0, 0, 0]
        early_stop_cnt = 0
        best_state_dict = None
        step = 0
        for epoch in tqdm(range(self.config.num_epochs)):
            for batch in train_iter:
                inputs, offsets, tags = batch
                
                if self.config.cuda and torch.cuda.is_available():
                    inputs_cuda = {}
                    for key,value in inputs.items():
                        inputs_cuda[key] = value.cuda()
#                         inputs_cuda[key] = value
                    inputs = inputs_cuda
                    offsets, tags = to_cuda(data=(offsets, tags))
                
                logits, loss = model(inputs, offsets, tags)
                
                self.optim.zero_grad()
                if self.config.cuda and self.config.fp16:
                    self.scaler.scale(loss).backward()
                    self.scaler.unscale_(self.optim)
                else:
                    loss.backward()

                nn.utils.clip_grad_norm_(filter(lambda p: p.requires_grad, model.parameters()), max_norm=self.config.grad_clip)

                if self.config.fp16:
                    self.scaler.step(self.optim)
                    self.scaler.update()
                else:
                    self.optim.step()
                self.optim_schedule.step()

                metrics = self.metrics_fn(predictions=logits, labels=tags)

                if (step) % self.config.print_every == 0:
                    print(f"--epoch {epoch}, step {step}, loss {loss}")
                    print(f"  {metrics}")

                if val_iter and (step + 1) % self.config.eval_every == 0:
                    avg_loss, metrics = self.eval(model, val_iter)
                    res = [avg_loss, metrics['f1']]
                    if metrics['f1'] > best_res[1]:  # f1
                        best_res = res
                        best_state_dict = model.state_dict()
                        early_stop_cnt = 0
                    else:
                        early_stop_cnt += 1
                    print("--Best Evaluation: ")
                    print("-Loss: {}  F1: {} \n".format(*best_res))
                    # back to train mode
                    model.train()
                
                if early_stop_cnt >= self.config.num_early_stop:
                    print("--early stopping, training finished.")
                    return best_res, best_state_dict

                step += 1
        print("--training finished.")
        return best_res, best_state_dict

    # eval func
    def eval(self, model: nn.Module, eval_iter: DataLoader, save_file: str = "", save_title: str = ""):
        model.eval()

        avg_loss, step = 0.0, 0
        logits_whole, tags_whole = torch.Tensor(), torch.Tensor()
        for step, batch in enumerate(eval_iter):
            inputs, offsets, tags = batch
                
            if self.config.cuda and torch.cuda.is_available():
                inputs_cuda = {}
                for key,value in inputs.items():
                    inputs_cuda[key] = value.cuda()
                inputs = inputs_cuda
                offsets, tags = to_cuda(data=(offsets, tags))
            
            with torch.no_grad():
                logits, loss = model(inputs, offsets, tags)
                
            logits_whole = torch.cat([logits_whole, logits.cpu()], dim=0)
            tags_whole = torch.cat([tags_whole, tags.cpu()], dim=0)

            avg_loss += loss * len(tags)  # times the batch size of data

        metrics = self.metrics_fn(predictions=logits_whole, labels=tags_whole)
        
        avg_loss /= len(eval_iter.dataset)
        print("--Evaluation:")
        print("-Loss: {}  F1: {} \n".format(avg_loss, metrics['f1']))

        if save_file != "":
            results = [save_title, avg_loss.item(), metrics['f1']]  # type: ignore
            results = [str(x) for x in results]
            with open(save_file, "a+") as f:
                f.write(",".join(results) + "\n")  # type: ignore

        return avg_loss.item(), metrics   # type: ignore
    
    def save_results(self, save_file, save_title, results):
        saves = [save_title] + results
        saves = [str(x) for x in saves]
        with open(save_file, "a+") as f:
            f.write(",".join(saves) + "\n")  # type: ignore

## Training

In [25]:
kfold = KFold(n_splits=CFG.num_folds, shuffle=True)

In [26]:
for fold, (train_ids, val_ids) in enumerate(kfold.split(dataset)):
    print(f'FOLD {fold}')
    print(f'{len(train_ids)}/{len(val_ids)}')
    print('--------------------------------')
    
    if fold not in CFG.trn_folds:
        break

    if CFG.cuda and torch.cuda.is_available:
        torch.cuda.empty_cache()
 
    random.shuffle(train_ids)
    random.shuffle(val_ids)

    tr_dataset = Subset(dataset, train_ids)
    va_dataset = Subset(dataset, val_ids)
    
    tr_iter = DataLoader(tr_dataset, batch_size=CFG.batch_size)
    va_iter = DataLoader(va_dataset, batch_size=CFG.batch_size)
    
    model = SegModel(CFG)
    
    optim = AdamW(model.parameters(), 
                      lr=CFG.lr,
                      weight_decay=CFG.weight_decay
                      )

    training_step = int(CFG.num_epochs * (len(train_ids) / CFG.batch_size))
    warmup_step = int(CFG.warmup_ratio * training_step)  
    lr_scheduler = get_linear_schedule_with_warmup(optimizer=optim, 
                                                   num_warmup_steps=warmup_step, 
                                                   num_training_steps=training_step)

    trainer = MyTrainer(optim=optim, 
                        lr_scheduler=lr_scheduler,
                        trainset_size=len(train_ids), 
                        metrics_fn=metrics_fn, 
                        config=CFG)
    
    best_res, best_state_dict = trainer.train(model=model, train_iter=tr_iter, val_iter=va_iter)
    print(best_res)
    with open("/root/autodl-tmp/diag_dep/edu_seg/res.txt", 'a+') as f:
        f.write(f'{fold}\t {str(best_res)}\n')
    
    torch.save(best_state_dict, f"/root/autodl-tmp/diag_dep/edu_seg/{fold}/model.bin")

FOLD 0
101440/25360
--------------------------------


Some weights of the model checkpoint at hfl/chinese-electra-180g-large-discriminator were not used when initializing ElectraModel: ['discriminator_predictions.dense.weight', 'discriminator_predictions.dense_prediction.weight', 'discriminator_predictions.dense_prediction.bias', 'discriminator_predictions.dense.bias']
- This IS expected if you are initializing ElectraModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ElectraModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
  0%|          | 0/3 [00:00<?, ?it/s]

--epoch 0, step 0, loss 0.8980361819267273
  {'f1': 0.010603728704154491}
--epoch 0, step 500, loss 0.05390365794301033
  {'f1': 0.8311688303947449}
--Evaluation:
-Loss: 0.04614398628473282  F1: 0.7758190631866455 

--Best Evaluation: 
-Loss: 0.04614398628473282  F1: 0.7758190631866455 

--epoch 0, step 1000, loss 0.05576056241989136
  {'f1': 0.8253968954086304}
--epoch 0, step 1500, loss 0.052981190383434296
  {'f1': 0.7241379618644714}
--Evaluation:
-Loss: 0.04305392503738403  F1: 0.7697447538375854 

--Best Evaluation: 
-Loss: 0.04614398628473282  F1: 0.7758190631866455 

--epoch 0, step 2000, loss 0.03807033598423004
  {'f1': 0.8405797481536865}


 33%|███▎      | 1/3 [36:28<1:12:56, 2188.35s/it]

--epoch 1, step 2500, loss 0.033386681228876114
  {'f1': 0.8399999737739563}
--Evaluation:
-Loss: 0.04917750880122185  F1: 0.7673178911209106 

--Best Evaluation: 
-Loss: 0.04614398628473282  F1: 0.7758190631866455 

--epoch 1, step 3000, loss 0.05689133703708649
  {'f1': 0.8108108043670654}
--epoch 1, step 3500, loss 0.030760589987039566
  {'f1': 0.8888888955116272}
--Evaluation:
-Loss: 0.04266449436545372  F1: 0.7940782904624939 

--Best Evaluation: 
-Loss: 0.04266449436545372  F1: 0.7940782904624939 

--epoch 1, step 4000, loss 0.04126041382551193
  {'f1': 0.8070175051689148}


 67%|██████▋   | 2/3 [1:12:13<36:02, 2162.96s/it]

--epoch 2, step 4500, loss 0.04014455899596214
  {'f1': 0.7894737124443054}
--Evaluation:
-Loss: 0.04350907728075981  F1: 0.7914131879806519 

--Best Evaluation: 
-Loss: 0.04266449436545372  F1: 0.7940782904624939 

--epoch 2, step 5000, loss 0.0398637130856514
  {'f1': 0.7999999523162842}
--epoch 2, step 5500, loss 0.03099857084453106
  {'f1': 0.8125}
--Evaluation:
-Loss: 0.04599112644791603  F1: 0.7894276976585388 

--Best Evaluation: 
-Loss: 0.04266449436545372  F1: 0.7940782904624939 

--epoch 2, step 6000, loss 0.03410981595516205
  {'f1': 0.8571429252624512}


100%|██████████| 3/3 [1:47:31<00:00, 2150.53s/it]


--training finished.
[0.04266449436545372, 0.7940782904624939]
FOLD 1
101440/25360
--------------------------------


In [None]:
import os
os.system('shutdown')