# Libraries

In [1]:
# Suppress warnings
import warnings
warnings.filterwarnings("ignore")
import os
from os.path import join
from tqdm import tqdm
from collections import defaultdict as dd
from bs4 import BeautifulSoup
from fuzzywuzzy import fuzz
import numpy as np
import torch
from transformers import AutoTokenizer
from transformers import BertForSequenceClassification, AutoModelForSequenceClassification
from transformers import get_linear_schedule_with_warmup, get_cosine_schedule_with_warmup
from transformers.optimization import AdamW
from torch.utils.data import TensorDataset, DataLoader, SequentialSampler
from tqdm import trange
from sklearn.metrics import classification_report, precision_recall_fscore_support, average_precision_score
import logging
import utils
import settings
from torch.cuda.amp import autocast, GradScaler

import gc
from types import SimpleNamespace

from awp import AWP
from PST_model import * 

2024-06-10 22:30:22,195 Note: NumExpr detected 16 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 8.
2024-06-10 22:30:22,196 NumExpr defaulting to 8 threads.


# Config

In [2]:
# ---------------------------------------------------------------------------
# Experiment configuration: every tunable hyper-parameter lives in this one
# dict, which is then exposed as attribute-style `cfg`.
# ---------------------------------------------------------------------------
config_dict = dict(
    GRADIENT_ACCUMULATION_STEPS=1,
    NUM_TRAIN_EPOCHS=20,
    LEARNING_RATE=2e-5,
    WARMUP_PROPORTION=0.1,
    MAX_GRAD_NORM=1000,
    differential_learning_rate=5e-4,
    # Selects which parameters get `differential_learning_rate` in get_optimizer.
    differential_learning_rate_layers='head',
    # eps=1e-6,
    # betas=(0.9, 0.999),
    weight_decay=0.01,
    seed=42,
    MAX_SEQ_LENGTH=512,
    model_name='scibert',  # train_loop also accepts: deberta-base, roberta-base
    clean_flag=True,
    train_mask=False,
    dist=200,
    optim_strategy='normal',  # normal, llrd
    PATIENCE=3,
    Debug=False,
    scheduler='cosine',  # linear, cosine
    amp=True,
    pool='ConcatPool',  # WLP, GeM, Mean, ConcatPool, AP
    loss_function='CrossEntropy',
    # model_class='PST_model',
    BATCH_SIZE=8,
    awp_enable=True,
    awp_start_epoch=2,
    layer_start=9,
    reinit_n_layers=0,
)

# Attribute-style access (cfg.LEARNING_RATE) plus the compute device.
cfg = SimpleNamespace(**config_dict)
cfg.device = device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
    datefmt='%m/%d/%Y %H:%M:%S',
)
logger = logging.getLogger(__name__)


# Functions

In [3]:
def clean_memory():
    """Collect Python garbage, then release cached CUDA memory blocks.

    Intended to be called between training/evaluation phases.
    NOTE(review): torch appears to guard ``empty_cache`` on CUDA being
    initialised, so this should be harmless on CPU-only machines — confirm.
    """
    for release in (gc.collect, torch.cuda.empty_cache):
        release()

In [None]:
def prepare_bert_input(cfg):
    """Build the BERT train/valid classification files from the labelled PST papers.

    Papers in ``paper_source_trace_train_ans.json`` are sorted by ``_id`` and
    split 2/3 train, 1/3 valid. For each paper:

    * Positives are the bibliography entries whose lower-cased title
      fuzzy-matches (``fuzz.ratio >= 80``) a title in ``refs_trace``.
    * ``10 * n_pos`` negatives are sampled *with replacement* from the
      remaining entries.

    Papers without any positive or without any negative are skipped. Each
    example's text is the citation context returned by
    ``utils.find_bib_context`` for that entry, joined with spaces.

    Writes four files into ``<DATA_TRACE_DIR>/PST``:

    * ``bib_context_{train,valid}.txt``: one context per line.
    * ``bib_context_{train,valid}_label.txt``: one 0/1 label per line,
      aligned with the text files.

    Args:
        cfg: namespace providing ``clean_flag`` and ``dist``, both forwarded to
            ``utils.find_bib_context``.
    """
    x_train, y_train = [], []
    x_valid, y_valid = [], []

    data_dir = join(settings.DATA_TRACE_DIR, "PST")
    papers = utils.load_json(data_dir, "paper_source_trace_train_ans.json")
    # Deterministic split: sort by paper id, first 2/3 train, rest valid.
    papers = sorted(papers, key=lambda x: x["_id"])
    n_train = int(len(papers) * 2 / 3)
    pids_train = {p["_id"] for p in papers[:n_train]}

    in_dir = join(data_dir, "paper-xml")  # one full-text XML file per paper

    # Lower-cased ground-truth source titles of each paper.
    pid_to_source_titles = dd(list)
    for paper in tqdm(papers):
        pid_to_source_titles[paper["_id"]].extend(ref["title"].lower() for ref in paper["refs_trace"])

    # Visit each paper id once, in sorted order. The previous iteration over a
    # set of str ids had a per-process order (str hashing is randomised). That
    # made the np.random negative sampling below irreproducible despite seeding.
    unique_pids = list(dict.fromkeys(p["_id"] for p in papers))
    for cur_pid in tqdm(unique_pids):
        source_titles = pid_to_source_titles[cur_pid]
        if len(source_titles) == 0:
            continue

        with open(join(in_dir, cur_pid + ".xml"), encoding='utf-8') as f:
            xml = f.read()
        bs = BeautifulSoup(xml, "xml")

        # bib id -> lower-cased title, only for entries with an id and an analytic title.
        bid_to_title = {}
        for ref in bs.find_all("biblStruct"):
            if "xml:id" not in ref.attrs or ref.analytic is None or ref.analytic.title is None:
                continue
            bid_to_title[ref.attrs["xml:id"]] = ref.analytic.title.text.lower()

        # Positives: references whose title fuzzy-matches any ground-truth source title.
        cur_pos_bib = {
            bid for bid, cur_ref_title in bid_to_title.items()
            if any(fuzz.ratio(cur_ref_title, label_title) >= 80 for label_title in source_titles)
        }
        cur_neg_bib = set(bid_to_title) - cur_pos_bib
        if not cur_pos_bib or not cur_neg_bib:
            continue

        bib_to_contexts = utils.find_bib_context(xml, clean_flag=cfg.clean_flag, dist=cfg.dist)

        # 1:10 positive:negative ratio, negatives drawn with replacement. Sorting
        # makes the draw depend only on the RNG state, not on set iteration order.
        n_neg = len(cur_pos_bib) * 10
        cur_neg_bib_sample = np.random.choice(sorted(cur_neg_bib), n_neg, replace=True)

        if cur_pid in pids_train:
            cur_x, cur_y = x_train, y_train
        else:
            cur_x, cur_y = x_valid, y_valid

        for bib in sorted(cur_pos_bib):
            cur_x.append(" ".join(bib_to_contexts[bib]))
            cur_y.append(1)
        for bib in cur_neg_bib_sample:
            cur_x.append(" ".join(bib_to_contexts[bib]))
            cur_y.append(0)

    print("len(x_train)", len(x_train), "len(x_valid)", len(x_valid))

    def _write_lines(file_name, rows):
        # Write one row per line. Line breaks inside a row become spaces, so line
        # i of a text file always pairs with line i of its label file.
        with open(join(data_dir, file_name), "w", encoding="utf-8") as out:
            for row in rows:
                out.write(str(row).replace("\r", " ").replace("\n", " ") + "\n")

    _write_lines("bib_context_train.txt", x_train)
    _write_lines("bib_context_valid.txt", x_valid)
    _write_lines("bib_context_train_label.txt", y_train)
    _write_lines("bib_context_valid_label.txt", y_valid)


In [4]:
class BertInputItem:
    """One fine-tuning example.

    Holds the raw text, padded token ids, attention mask, segment ids and the
    integer class label. Values are stored as given; padding and truncation are
    done by the caller.
    """

    def __init__(self, text, input_ids, input_mask, segment_ids, label_id):
        self.text, self.input_ids, self.input_mask, self.segment_ids, self.label_id = (
            text, input_ids, input_mask, segment_ids, label_id)


def random_masking(text, tokenizer, max_length=512, mask_ratio=0.15):
    """Tokenize ``text``, mask a random share of its tokens, and return token ids.

    The text is tokenized and truncated so that, with the CLS and SEP tokens
    added, the sequence has at most ``max_length`` tokens.
    ``int(mask_ratio * n_tokens)`` distinct positions (chosen with the global
    ``np.random`` state) are replaced by ``tokenizer.mask_token``.

    Args:
        text: raw input string.
        tokenizer: a Hugging Face tokenizer (``tokenize``,
            ``convert_tokens_to_ids``, ``mask_token``, ``cls_token``, ``sep_token``).
        max_length: maximum length of the returned id list, including CLS/SEP.
        mask_ratio: fraction of the (truncated) text tokens to mask.

    Returns:
        list[int]: ``[CLS] + tokens + [SEP]`` converted to ids.
    """
    tokens = tokenizer.tokenize(text)[:max_length - 2]  # leave room for CLS + SEP

    num_to_mask = int(mask_ratio * len(tokens))
    for i in np.random.choice(len(tokens), num_to_mask, replace=False):
        tokens[i] = tokenizer.mask_token

    # Use the tokenizer's own special tokens. The literal '[CLS]'/'[SEP]' strings
    # are only right for BERT-style vocabularies (RoBERTa uses <s>/</s>).
    cls_token = getattr(tokenizer, "cls_token", None) or "[CLS]"
    sep_token = getattr(tokenizer, "sep_token", None) or "[SEP]"
    return tokenizer.convert_tokens_to_ids([cls_token] + tokens + [sep_token])


def convert_examples_to_inputs(example_texts, example_labels, max_seq_length, tokenizer, mask, verbose=0):
    """Tokenize and pad ``(text, label)`` pairs into fixed-length ``BertInputItem``s.

    Every item has ``input_ids``, ``input_mask`` and ``segment_ids`` of exactly
    ``max_seq_length``. Padding positions carry id 0 and mask 0. All tokens
    belong to segment 0.

    Args:
        example_texts: iterable of raw context strings.
        example_labels: iterable of integer labels aligned with ``example_texts``.
        max_seq_length: length every sequence is truncated/padded to.
        tokenizer: Hugging Face tokenizer.
        mask: if truthy, tokens are produced by ``random_masking`` with a 0.15
            mask ratio (train-time augmentation). Otherwise the tokenizer's own
            ``encode`` is used.
        verbose: unused; kept for backward compatibility.

    Returns:
        list[BertInputItem]
    """
    input_items = []
    for text, label in zip(example_texts, example_labels):
        if mask:
            input_ids = random_masking(text, tokenizer, max_seq_length, mask_ratio=0.15)
        else:
            # ``encode`` already adds the model's special tokens. Prepending the
            # literal text "[CLS] ... [SEP]" (as before) doubled them for BERT-style
            # vocabularies and produced junk word-pieces for RoBERTa. Truncating
            # inside the tokenizer also keeps the closing SEP, which slicing the ids
            # afterwards used to cut off.
            input_ids = tokenizer.encode(text, add_special_tokens=True, truncation=True, max_length=max_seq_length)

        n_real = len(input_ids)
        n_pad = max_seq_length - n_real
        input_ids = list(input_ids) + [0] * n_pad
        input_mask = [1] * n_real + [0] * n_pad  # attend to real tokens only
        segment_ids = [0] * max_seq_length  # single-segment input

        assert len(input_ids) == max_seq_length
        assert len(input_mask) == max_seq_length

        input_items.append(
            BertInputItem(text=text,
                          input_ids=input_ids,
                          input_mask=input_mask,
                          segment_ids=segment_ids,
                          label_id=label))

    return input_items


def get_data_loader(features, max_seq_length, batch_size, shuffle=True):
    """Wrap ``BertInputItem``-like features in a ``DataLoader``.

    Each batch is ``(input_ids, input_mask, segment_ids, label_ids)``, all
    ``torch.long``. ``max_seq_length`` is unused; the features are expected to be
    padded already.
    """
    field_names = ("input_ids", "input_mask", "segment_ids", "label_id")
    columns = [
        torch.tensor([getattr(feature, field) for feature in features], dtype=torch.long)
        for field in field_names
    ]
    return DataLoader(TensorDataset(*columns), shuffle=shuffle, batch_size=batch_size)

In [5]:
def load_data():
    """Read the bib-context train/dev splits written by ``prepare_bert_input``.

    Returns:
        (train_texts, dev_texts, train_labels, dev_labels). Texts are stripped
        strings and labels are ints, one entry per line of
        ``bib_context_{train,valid}.txt`` / ``bib_context_{train,valid}_label.txt``.
    """
    data_year_dir = join(settings.DATA_TRACE_DIR, "PST")
    print("data_year_dir", data_year_dir)

    def read_column(file_name, parse):
        # Parse every stripped line of one split file.
        with open(join(data_year_dir, file_name), "r", encoding="utf-8") as handle:
            return [parse(line.strip()) for line in handle]

    train_texts = read_column("bib_context_train.txt", str)
    dev_texts = read_column("bib_context_valid.txt", str)
    train_labels = read_column("bib_context_train_label.txt", int)
    dev_labels = read_column("bib_context_valid_label.txt", int)
    return train_texts, dev_texts, train_labels, dev_labels

In [6]:
def get_scheduler(cfg, optimizer):
    """Create the warmup learning-rate scheduler selected by ``cfg.scheduler``.

    Args:
        cfg: namespace with ``scheduler`` ("linear" or "cosine"),
            ``num_warmup_steps`` and ``num_train_steps``.
        optimizer: the optimizer whose learning rates are scheduled.

    Returns:
        The scheduler built by ``transformers.get_linear_schedule_with_warmup``
        or ``transformers.get_cosine_schedule_with_warmup``.

    Raises:
        ValueError: if ``cfg.scheduler`` is not a supported name. Previously this
            case surfaced as an opaque UnboundLocalError.
    """
    if cfg.scheduler == "linear":
        schedule_fn = get_linear_schedule_with_warmup
    elif cfg.scheduler == "cosine":
        schedule_fn = get_cosine_schedule_with_warmup
    else:
        raise ValueError(f"Unknown scheduler {cfg.scheduler!r}; expected 'linear' or 'cosine'")
    return schedule_fn(
        optimizer,
        num_warmup_steps=cfg.num_warmup_steps,
        num_training_steps=cfg.num_train_steps,
    )

In [7]:
def get_optimizer_grouped_parameters(cfg, model):
    """Build AdamW with layer-wise learning-rate decay (LLRD).

    Parameter groups, in order:

    * Task head: every parameter whose name contains "head" or "pooling", at
      the base learning rate and with no weight decay.
    * Backbone: ``model.model.embeddings`` plus each layer of
      ``model.model.encoder.layer``, walked from the top encoder layer down to
      the embeddings. The learning rate is multiplied by the decay factor
      *before* each layer, so the top layer already gets ``lr * decay``. Each
      layer contributes two groups: weights with decay, then biases /
      ``LayerNorm.weight`` without decay.

    Hyper-parameters come from optional ``cfg`` attributes. Their defaults are
    the values this function used to hard-code: ``llrd_learning_rate`` (1e-5),
    ``llrd_decay`` (0.9), ``llrd_weight_decay`` (0.01) and ``llrd_eps`` (1e-6).
    NOTE(review): ``cfg.LEARNING_RATE`` / ``cfg.weight_decay`` are deliberately
    not read here, so existing LLRD runs are unchanged.

    NOTE(review): parameters outside the head/pooling, the embeddings and the
    encoder layers belong to no group and are never updated — confirm this is
    intended.
    """
    learning_rate = getattr(cfg, "llrd_learning_rate", 1.0e-5)
    layerwise_learning_rate_decay = getattr(cfg, "llrd_decay", 0.9)
    weight_decay = getattr(cfg, "llrd_weight_decay", 0.01)
    adam_epsilon = getattr(cfg, "llrd_eps", 1e-6)
    no_decay = ("bias", "LayerNorm.weight")

    # Task-specific layers: base lr, no weight decay.
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if "head" in n or "pooling" in n],
            "weight_decay": 0.0,
            "lr": learning_rate,
        },
    ]

    backbone = model.model
    layers = [backbone.embeddings, *backbone.encoder.layer]

    lr = learning_rate
    for layer in reversed(layers):  # top encoder layer first, embeddings last
        lr *= layerwise_learning_rate_decay
        named = list(layer.named_parameters())
        optimizer_grouped_parameters += [
            {
                "params": [p for n, p in named if not any(nd in n for nd in no_decay)],
                "weight_decay": weight_decay,
                "lr": lr,
            },
            {
                "params": [p for n, p in named if any(nd in n for nd in no_decay)],
                "weight_decay": 0.0,
                "lr": lr,
            },
        ]

    return torch.optim.AdamW(
        optimizer_grouped_parameters,
        lr=learning_rate,
        eps=adam_epsilon,
    )

In [8]:
def get_optimizer(cfg, model):  # best
    """Build AdamW with a separate learning rate for "differential" (task-head) layers.

    A parameter is "differential" when its name contains any entry of
    ``cfg.differential_learning_rate_layers``. Differential parameters train at
    ``cfg.differential_learning_rate``; all others train at ``cfg.LEARNING_RATE``.
    In both sets, names containing "bias", "LayerNorm.bias" or
    "LayerNorm.weight" get no weight decay, and the rest use ``cfg.weight_decay``.

    Group order is unchanged: backbone/decay, backbone/no-decay,
    differential/decay, differential/no-decay.

    ``differential_learning_rate_layers`` may be a single name (str) or a
    sequence of names. Previously a plain string such as ``'head'`` was iterated
    character by character. Any parameter whose name contained 'h', 'e', 'a' or
    'd' — presumably nearly all of them — then landed in the differential groups
    and trained at the much larger head learning rate.
    """
    no_decay = ("bias", "LayerNorm.bias", "LayerNorm.weight")
    differential_layers = cfg.differential_learning_rate_layers
    if isinstance(differential_layers, str):
        differential_layers = (differential_layers,)

    # (is_differential, is_no_decay) -> parameters, filled in named_parameters order.
    buckets = {(False, False): [], (False, True): [], (True, False): [], (True, True): []}
    for name, param in model.named_parameters():
        is_differential = any(layer in name for layer in differential_layers)
        is_no_decay = any(nd in name for nd in no_decay)
        buckets[(is_differential, is_no_decay)].append(param)

    return torch.optim.AdamW(
        [
            {"params": buckets[(False, False)], "lr": cfg.LEARNING_RATE, "weight_decay": cfg.weight_decay},
            {"params": buckets[(False, True)], "lr": cfg.LEARNING_RATE, "weight_decay": 0},
            {"params": buckets[(True, False)], "lr": cfg.differential_learning_rate, "weight_decay": cfg.weight_decay},
            {"params": buckets[(True, True)], "lr": cfg.differential_learning_rate, "weight_decay": 0},
        ],
        lr=cfg.LEARNING_RATE,
        eps=1e-6,
        betas=(0.9, 0.999),
    )

In [9]:
# Total number of parameter elements of a model (printed after model creation).
def count_parameters(model):
    """Return the summed ``numel()`` over ``model.parameters()`` (trainable or not)."""
    total = 0
    for tensor in model.parameters():
        total += tensor.numel()
    return total

# Evaluate

In [10]:
def evaluate(model, dataloader, device, criterion):
    """Run ``model`` over ``dataloader`` without gradients and collect predictions.

    The model is called with a dict holding ``input_ids``, ``attention_mask``,
    ``segment_ids`` and ``target``. It must return a dict with ``loss`` and
    ``logits``; the predicted class is the arg-max of ``logits`` along dim 1.

    Args:
        model: the classifier. It is switched to eval mode here and not
            switched back.
        dataloader: yields ``(input_ids, input_mask, segment_ids, label_ids)`` batches.
        device: device the batch tensors are moved to.
        criterion: unused — the loss is taken from the model output. Kept so
            existing callers keep working.

    Returns:
        ``(avg_loss, correct_labels, predicted_labels)``: the loss averaged with
        per-batch sample counts as weights, plus the true and predicted labels
        as numpy arrays.
    """
    losses = utils.AverageMeter()
    model.eval()
    predicted_labels, correct_labels = [], []
    with torch.no_grad():
        for step, batch in enumerate(tqdm(dataloader, desc="Evaluation iteration", position=0)):
            input_ids, input_mask, segment_ids, label_ids = (t.to(device) for t in batch)
            inputs = {
                'input_ids': input_ids,
                'attention_mask': input_mask,
                'segment_ids': segment_ids,
                'target': label_ids,
            }
            output_eval = model(inputs)
            losses.update(output_eval['loss'].item(), label_ids.shape[0])
            if step % 100 == 0:
                tqdm.write(f"Step {step}")
            # Arg-max in torch (instead of np.argmax on a tensor) so plain ints are collected.
            predicted_labels.extend(output_eval['logits'].argmax(dim=1).cpu().tolist())
            correct_labels.extend(label_ids.cpu().tolist())
    # No trailing `del inputs`: it raised NameError when the dataloader was empty.
    clean_memory()

    return losses.avg, np.array(correct_labels), np.array(predicted_labels)

# Train

In [11]:
def train_loop(cfg, fold):
    """Fine-tune ``Net`` on the prepared bib-context data with early stopping.

    Training setup:

    * Reads the splits with ``load_data``.
    * Trains for up to ``cfg.NUM_TRAIN_EPOCHS`` epochs, with mixed precision
      when ``cfg.amp`` is truthy.
    * Uses AWP from epoch ``cfg.awp_start_epoch`` when ``cfg.awp_enable`` is set.
    * Stops after ``cfg.PATIENCE`` consecutive epochs without a new lowest
      dev loss.

    Outputs: the checkpoint with the lowest dev loss is saved to
    ``<OUT_DIR>/kddcup/<model_name>/num_fold=0/pytorch_model.bin``. Every
    epoch's dev loss is written to ``Loss_history.txt`` in that directory.

    Args:
        cfg: experiment namespace (see the config cell). This function also sets
            ``cfg.class_weight``, ``cfg.num_train_steps`` and ``cfg.num_warmup_steps``.
        fold: currently unused; the output directory is always ``num_fold=0``.
    """
    print(f"model_name: {cfg.model_name}")

    # ====================================================
    # Load data
    # ====================================================
    train_texts, dev_texts, train_labels, dev_labels = load_data()
    # "Balanced" class weights n_samples / (2 * count_c), computed on the full
    # training split (before any Debug truncation).
    class_weight = len(train_labels) / (2 * np.bincount(train_labels))
    cfg.class_weight = torch.Tensor(class_weight).to(cfg.device)
    print("Class weight:", cfg.class_weight)

    # ====================================================
    # Debug: keep only the first 100 examples of each split
    # ====================================================
    if cfg.Debug:
        train_texts, dev_texts, train_labels, dev_labels = train_texts[:100], dev_texts[:100], train_labels[:100], dev_labels[:100]
    print("Train size:", len(train_texts))
    print("Dev size:", len(dev_texts))

    # ====================================================
    # Model
    # ====================================================
    model_paths = {
        "deberta-base": './bert_models/deberta_v3_base',
        "scibert": './bert_models/scibert_scivocab_uncased',
        'roberta-base': './bert_models/dsp_roberta_base_dapt_cs_tapt_sciie_3219',
    }
    if cfg.model_name not in model_paths:
        raise NotImplementedError(cfg.model_name)
    model_path = model_paths[cfg.model_name]

    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = Net(cfg, model_path=model_path)
    model.to(cfg.device)
    num_params = count_parameters(model)
    print(f"模型的总参数量: {num_params}")

    # ====================================================
    # Data loaders
    # ====================================================
    train_features = convert_examples_to_inputs(train_texts, train_labels, cfg.MAX_SEQ_LENGTH, tokenizer, mask=cfg.train_mask)
    dev_features = convert_examples_to_inputs(dev_texts, dev_labels, cfg.MAX_SEQ_LENGTH, tokenizer, mask=False)
    train_dataloader = get_data_loader(train_features, cfg.MAX_SEQ_LENGTH, cfg.BATCH_SIZE, shuffle=True)
    dev_dataloader = get_data_loader(dev_features, cfg.MAX_SEQ_LENGTH, cfg.BATCH_SIZE, shuffle=False)

    # ====================================================
    # Optimizer / scheduler
    # ====================================================
    if cfg.optim_strategy == 'llrd':
        print("Learning Rate Decay")
        optimizer = get_optimizer_grouped_parameters(cfg, model)
    else:
        optimizer = get_optimizer(cfg, model)

    # The scheduler steps once per optimizer step, i.e. once every
    # GRADIENT_ACCUMULATION_STEPS batches. Counting from len(train_dataloader)
    # includes the final partial batch, so the schedule does not run past its
    # nominal end.
    steps_per_epoch = len(train_dataloader) // cfg.GRADIENT_ACCUMULATION_STEPS
    cfg.num_train_steps = steps_per_epoch * int(cfg.NUM_TRAIN_EPOCHS)
    cfg.num_warmup_steps = int(cfg.WARMUP_PROPORTION * cfg.num_train_steps)
    scheduler = get_scheduler(cfg, optimizer)

    # ====================================================
    # Other settings
    # ====================================================
    use_amp = getattr(cfg, "amp", True)
    scaler = GradScaler(enabled=use_amp)

    # NOTE(review): the training loss comes from output_dict["loss"]. `criterion`
    # is only forwarded to evaluate(), which does not use it either.
    if cfg.loss_function == 'CrossEntropy':
        criterion = torch.nn.CrossEntropyLoss(weight=cfg.class_weight)
    else:
        criterion = FocalLoss(weight=cfg.class_weight)

    OUTPUT_DIR = join(settings.OUT_DIR, "kddcup", cfg.model_name, 'num_fold=0')
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    MODEL_FILE_NAME = "pytorch_model.bin"
    loss_history = []
    no_improvement = 0

    awp = AWP(model,
              optimizer,
              adv_lr=0.0001,
              adv_eps=0.001,
              start_epoch=cfg.awp_start_epoch,
              scaler=scaler
             )

    grad_norm = float("nan")  # norm from the most recent optimizer step; reported per epoch
    for e in trange(int(cfg.NUM_TRAIN_EPOCHS), desc="Epoch"):
        model.train()
        losses = utils.AverageMeter()
        # ====================================================
        # Train
        # ====================================================
        for step, batch in enumerate(tqdm(train_dataloader, desc="Training iteration", position=0)):
            input_ids, input_mask, segment_ids, label_ids = (t.to(cfg.device) for t in batch)
            inputs = {
                'input_ids': input_ids,
                'attention_mask': input_mask,
                'segment_ids': segment_ids,
                'target': label_ids,
            }
            num_samples = label_ids.shape[0]

            with autocast(enabled=use_amp):
                output_dict = model(inputs)
                loss = output_dict["loss"]
            if cfg.awp_enable and e >= cfg.awp_start_epoch:
                awp.attack_backward(inputs)

            if cfg.GRADIENT_ACCUMULATION_STEPS > 1:
                loss = loss / cfg.GRADIENT_ACCUMULATION_STEPS
            if step % 100 == 0:
                tqdm.write(f"Step {step}")

            losses.update(loss.item(), num_samples)
            scaler.scale(loss).backward()

            if (step + 1) % cfg.GRADIENT_ACCUMULATION_STEPS == 0:
                # The gradients still carry the AMP loss scale. Unscale them first
                # so MAX_GRAD_NORM applies to the true gradient norm; clipping
                # scaled gradients shrank the effective threshold by the scale factor.
                scaler.unscale_(optimizer)
                grad_norm = float(torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.MAX_GRAD_NORM))
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()
                scheduler.step()

        # ====================================================
        # Validation
        # ====================================================
        avg_loss = losses.avg
        avg_dev_loss, true, prediction = evaluate(model, dev_dataloader, cfg.device, criterion)
        print(
              f'Train_Loss: {avg_loss:.4f}\n'
              f'Dev_Loss: {avg_dev_loss:.4f}\n'
              f'Grad: {grad_norm:.4f}\n'
              f'LR: {scheduler.get_last_lr()[0]:.8f}\n'
              f'Loss_History: {loss_history}'
        )

        # ====================================================
        # Checkpointing and early stopping
        # ====================================================
        is_best = len(loss_history) == 0 or avg_dev_loss < min(loss_history)
        # Record this epoch *before* a possible early stop so Loss_history.txt is complete.
        loss_history.append(avg_dev_loss)
        if is_best:
            no_improvement = 0
            model_to_save = model.module if hasattr(model, 'module') else model
            output_model_file = os.path.join(OUTPUT_DIR, MODEL_FILE_NAME)
            torch.save(model_to_save.state_dict(), output_model_file)
            print("Best model saved.")
        else:
            no_improvement += 1

        # Drop references to the last batch's tensors so clean_memory can reclaim them.
        inputs = output_dict = loss = batch = None
        clean_memory()

        if no_improvement >= cfg.PATIENCE:
            print("No improvement on development set. Finish training.")
            break

    # One dev loss per line.
    file_name = f'{OUTPUT_DIR}/Loss_history.txt'
    with open(file_name, "w") as file:
        for dev_loss in loss_history:
            file.write(str(dev_loss) + "\n")
    print(f"Loss history has been saved to {file_name}")

# Generate Validation-Set Submission

In [12]:
def gen_kddcup_valid_submission_bert(cfg):
    """Score every reference of every validation paper and dump the submission JSON.

    For each paper in ``paper_source_trace_valid_wo_ans.json``, references
    ``b0 .. b{n_refs-1}`` are scored in order. ``n_refs`` is 1 + the largest
    numeric id among ``biblStruct`` entries that have an id and an analytic
    title. The score is ``utils.sigmoid`` of the class-1 logit from the
    checkpoint saved by ``train_loop``.

    The result ``{paper_id: [score per reference]}`` is written to
    ``<OUT_DIR>/kddcup/<model_name>/num_fold=0/valid_submission_<model_name>.json``.

    Side effect: sets ``cfg.tokenizer``.

    Raises:
        NotImplementedError: for an unsupported ``cfg.model_name``.
        AssertionError: if a paper's reference count disagrees with
            ``submission_example_valid.json``.
    """
    print("model name", cfg.model_name)
    print(f'The length of context text:{cfg.dist}')
    data_dir = join(settings.DATA_TRACE_DIR, "PST")
    papers = utils.load_json(data_dir, "paper_source_trace_valid_wo_ans.json")  # validation paper ids

    model_paths = {
        "deberta-base": './bert_models/deberta_v3_base',
        "scibert": './bert_models/scibert_scivocab_uncased',
        'roberta-base': './bert_models/dsp_roberta_base_dapt_cs_tapt_sciie_3219',
    }
    if cfg.model_name not in model_paths:
        raise NotImplementedError(cfg.model_name)
    model_path = model_paths[cfg.model_name]

    sub_example_dict = utils.load_json(data_dir, "submission_example_valid.json")  # submission template
    print("device", cfg.device)
    # Load the tokenizer once (previously it was loaded twice).
    cfg.tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = Net(cfg, model_path=model_path)
    OUTPUT_DIR = join(settings.OUT_DIR, "kddcup", cfg.model_name, 'num_fold=0')
    # map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
    model.load_state_dict(torch.load(join(OUTPUT_DIR, "pytorch_model.bin"), map_location=cfg.device))
    model.to(cfg.device)
    model.eval()
    xml_dir = join(data_dir, "paper-xml")
    sub_dict = {}

    for paper in tqdm(papers):
        cur_pid = paper["_id"]
        with open(join(xml_dir, cur_pid + ".xml"), encoding='utf-8') as f:
            xml = f.read()
        bs = BeautifulSoup(xml, "xml")

        # Number of reference slots: 1 + highest "b<N>" id among titled entries.
        n_refs = 0
        for ref in bs.find_all("biblStruct"):
            if "xml:id" not in ref.attrs or ref.analytic is None or ref.analytic.title is None:
                continue
            n_refs = max(n_refs, int(ref.attrs["xml:id"][1:]) + 1)

        bib_to_contexts = utils.find_bib_context(xml, clean_flag=cfg.clean_flag, dist=cfg.dist)
        bib_sorted = ["b" + str(ii) for ii in range(n_refs)]  # bib ids in reference order

        assert len(sub_example_dict[cur_pid]) == n_refs

        contexts_sorted = [" ".join(bib_to_contexts[bib]) for bib in bib_sorted]
        test_features = convert_examples_to_inputs(contexts_sorted, [0] * n_refs, cfg.MAX_SEQ_LENGTH, cfg.tokenizer, mask=False)
        test_dataloader = get_data_loader(test_features, cfg.MAX_SEQ_LENGTH, 64, shuffle=False)

        predicted_scores = []
        with torch.inference_mode():
            for batch in test_dataloader:
                input_ids, input_mask, segment_ids, label_ids = (t.to(cfg.device) for t in batch)
                inputs = {
                    'input_ids': input_ids,
                    'attention_mask': input_mask,
                    'segment_ids': segment_ids,
                    'target': label_ids,
                }
                with autocast():
                    logits = model(inputs)['logits']
                # Cast to float32 before leaving the device; under autocast the
                # logits may be half precision.
                predicted_scores.extend(logits[:, 1].float().cpu().tolist())
        clean_memory()

        # The dataloader is unshuffled, so score i belongs to reference b<i>.
        sub_dict[cur_pid] = [float(utils.sigmoid(score)) for score in predicted_scores]

    utils.dump_json(sub_dict, OUTPUT_DIR, f"valid_submission_{cfg.model_name}.json")

# Run

In [17]:
# Start from a clean memory state with fixed seeds, then train fold 0.
clean_memory()
fold = 0
utils.seed_everything(cfg.seed)

train_loop(cfg, fold=fold)

# Uncomment to score the validation papers with the saved checkpoint:
# gen_kddcup_valid_submission_bert(cfg)

# gen_kddcup_test_submission_bert(cfg)