# Imports

In [1]:
# Suppress warnings
import warnings
warnings.filterwarnings("ignore")
import os
from os.path import join
from tqdm import tqdm
from collections import defaultdict as dd
from bs4 import BeautifulSoup
from fuzzywuzzy import fuzz
import numpy as np
import torch
from transformers import AutoTokenizer
from transformers import BertForSequenceClassification, AutoModelForSequenceClassification
from transformers import get_linear_schedule_with_warmup, get_cosine_schedule_with_warmup
from transformers.optimization import AdamW
from torch.utils.data import TensorDataset, DataLoader, SequentialSampler
from tqdm import trange
from sklearn.metrics import classification_report, precision_recall_fscore_support, average_precision_score
import logging

import utils
import settings
from torch.cuda.amp import autocast, GradScaler

import gc
from types import SimpleNamespace

from awp import AWP
from PST_model import * 


# Config

In [2]:
# Run configuration, collected in a plain dict first (key order is preserved in `cfg`).
config_dict = {
    "GRADIENT_ACCUMULATION_STEPS": 1,
    "NUM_TRAIN_EPOCHS": 20,
    "LEARNING_RATE": 2e-5,
    "WARMUP_PROPORTION": 0.1,
    "MAX_GRAD_NORM": 1000,
    'differential_learning_rate': 5e-4,
    'differential_learning_rate_layers': 'head',
    "weight_decay": 0.01,
    "seed": 42,  # passed to utils.seed_everything in the run cell
    "MAX_SEQ_LENGTH": 512,  # token length every example is truncated/padded to
    'model_name': 'scibert',  # roberta-base, deberta-base, scibert; selects the local checkpoint dir
    'clean_flag': True,  # forwarded to utils.find_bib_context
    'train_mask': False,
    'dist': 200,  # context window forwarded to utils.find_bib_context
    'optim_strategy': 'normal',  # normal, llrd
    'PATIENCE': 3,
    'Debug': False,
    'scheduler': 'cosine',  # linear, cosine
    'amp': True,
    'pool': 'ConcatPool', # WLP, GeM, Mean, ConcatPool, AP
    'loss_function': 'CrossEntropy',
    'BATCH_SIZE': 8,
    'awp_enable': True,
    'awp_start_epoch': 2,
    'layer_start': 9,
    'reinit_n_layers': 0,
}

# Expose the settings as attributes (cfg.LEARNING_RATE, ...) and attach the compute device last.
cfg = SimpleNamespace(**config_dict)
cfg.device = device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Logging.
# NOTE(review): logging.basicConfig does nothing when the root logger already has handlers.
# The captured log lines in this notebook do not use the format below, which suggests an
# earlier import configured logging first. Pass force=True if this format should take effect.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
    datefmt='%m/%d/%Y %H:%M:%S',
)
logger = logging.getLogger(__name__)


# Helper functions

In [3]:
def clean_memory():
    """Collect unreachable Python objects, then release PyTorch's unused cached CUDA memory."""
    for release in (gc.collect, torch.cuda.empty_cache):
        release()
    
class BertInputItem:
    """Plain record holding one example's raw text and its model inputs.

    Attributes:
        text: The original input text.
        input_ids: Token ids fed to the encoder.
        input_mask: Attention mask aligned with ``input_ids``.
        segment_ids: Token-type ids aligned with ``input_ids``.
        label_id: Target label for the example.
    """

    def __init__(self, text, input_ids, input_mask, segment_ids, label_id):
        # Attribute assignment order is kept stable so vars(item) lists fields in signature order.
        self.text, self.input_ids, self.input_mask = text, input_ids, input_mask
        self.segment_ids, self.label_id = segment_ids, label_id


def random_masking(text, tokenizer, max_length=512, mask_ratio=0.15):
    """Tokenize ``text``, replace a random fraction of tokens with the mask token, and return ids.

    The token sequence is truncated so that, after the start/end special tokens
    are added, it is at most ``max_length`` tokens long.

    Args:
        text: Raw input string.
        tokenizer: Hugging Face-style tokenizer exposing ``tokenize``,
            ``convert_tokens_to_ids`` and ``mask_token``. Its ``cls_token`` /
            ``sep_token`` are used as start/end markers when set, falling back
            to ``'[CLS]'`` / ``'[SEP]'``.
        max_length: Maximum length of the returned id list, special tokens included.
        mask_ratio: Fraction of the (truncated) tokens to replace with the mask token.

    Returns:
        The ids of ``[start] + tokens + [end]``, with masking applied.
    """
    # Reserve two slots for the start/end special tokens.
    tokens = tokenizer.tokenize(text)[:max(max_length - 2, 0)]

    num_tokens = len(tokens)
    num_to_mask = int(mask_ratio * num_tokens)

    # Sample positions without replacement so each token is masked at most once.
    # Uses NumPy's global RNG, so results follow any global seed set beforehand.
    for idx in np.random.choice(range(num_tokens), num_to_mask, replace=False):
        tokens[idx] = tokenizer.mask_token

    # Use the tokenizer's own special tokens: RoBERTa-style vocabularies use <s>/</s>,
    # so hard-coded '[CLS]'/'[SEP]' would map to unknown ids there.
    cls_token = getattr(tokenizer, "cls_token", None) or '[CLS]'
    sep_token = getattr(tokenizer, "sep_token", None) or '[SEP]'
    tokens = [cls_token] + tokens + [sep_token]

    return tokenizer.convert_tokens_to_ids(tokens)


def convert_examples_to_inputs(example_texts, example_labels, max_seq_length, tokenizer, mask, verbose=0):
    """Turn parallel sequences of texts and labels into fixed-length ``BertInputItem`` objects.

    Each text is tokenized (with 15% random masking when ``mask`` is true),
    truncated to ``max_seq_length``, and zero-padded to exactly that length.
    All tokens belong to segment 0, and the attention mask is 1 for real tokens
    and 0 for padding.

    Args:
        example_texts: Iterable of input strings.
        example_labels: Iterable of labels, zipped with ``example_texts``.
        max_seq_length: Target length of every id / mask / segment list.
        tokenizer: Tokenizer used for encoding (see ``random_masking`` for the masked path).
        mask: If true, apply ``random_masking``; otherwise use ``tokenizer.encode``.
        verbose: Unused.

    Returns:
        list of ``BertInputItem``.

    NOTE(review): the unmasked path wraps the text in literal "[CLS]"/"[SEP]" strings
    *and* calls ``tokenizer.encode``, which presumably also adds the tokenizer's own
    special tokens (HF default ``add_special_tokens=True``); truncation can then drop
    the trailing separator. This is left as is because changing it would alter model
    inputs relative to any checkpoint trained with this same function.
    """
    items = []
    for text, label in zip(example_texts, example_labels):
        if mask:
            ids = random_masking(text, tokenizer, max_seq_length, mask_ratio=0.15)
        else:
            ids = tokenizer.encode(f"[CLS] {text} [SEP]")[:max_seq_length]

        n_real = len(ids)
        n_pad = max_seq_length - n_real
        # 0 is used as the padding id; the attention mask excludes those positions.
        input_ids = ids + [0] * n_pad
        input_mask = [1] * n_real + [0] * n_pad
        segment_ids = [0] * len(input_ids)  # single-segment input

        assert len(input_ids) == max_seq_length
        assert len(input_mask) == max_seq_length
        assert len(segment_ids) == max_seq_length

        items.append(
            BertInputItem(text=text,
                          input_ids=input_ids,
                          input_mask=input_mask,
                          segment_ids=segment_ids,
                          label_id=label))

    return items


def get_data_loader(features, max_seq_length, batch_size, shuffle=True):
    """Pack ``features`` into a ``DataLoader`` yielding (input_ids, input_mask, segment_ids, label_ids).

    Args:
        features: Sequence of objects with ``input_ids``, ``input_mask``,
            ``segment_ids`` and ``label_id`` attributes.
        max_seq_length: Unused; kept for call-site compatibility.
        batch_size: Number of examples per batch.
        shuffle: Whether the loader shuffles the examples.

    Returns:
        ``torch.utils.data.DataLoader`` over a ``TensorDataset`` of int64 tensors.
    """
    field_names = ("input_ids", "input_mask", "segment_ids", "label_id")
    columns = [
        torch.tensor([getattr(feat, name) for feat in features], dtype=torch.long)
        for name in field_names
    ]
    return DataLoader(TensorDataset(*columns), shuffle=shuffle, batch_size=batch_size)

# Generate Test Submission

In [4]:
# Local checkpoint directory for each supported ``cfg.model_name``.
_KDDCUP_MODEL_PATHS = {
    "deberta-base": './bert_models/deberta_v3_base',
    "scibert": './bert_models/scibert_scivocab_uncased',
    'roberta-base': './bert_models/dsp_roberta_base_dapt_cs_tapt_sciie_3219',
}


def _count_references(bs):
    """Return 1 + the largest ``b<idx>`` reference index that has an analytic title (0 if none)."""
    n_refs = 0
    for ref in bs.find_all("biblStruct"):
        bid = ref.attrs.get("xml:id")
        # Only references with an id and an analytic title count, as in the submission template.
        if bid is None or ref.analytic is None or ref.analytic.title is None:
            continue
        n_refs = max(n_refs, int(bid[1:]) + 1)  # ids look like "b0", "b1", ...
    return n_refs


def _predict_positive_logits(model, cfg, contexts):
    """Return the model's class-1 logit for each context string, in input order."""
    if not contexts:
        return []
    # Labels are required by the input pipeline but are not used for scoring.
    features = convert_examples_to_inputs(contexts, [0] * len(contexts), cfg.MAX_SEQ_LENGTH,
                                          cfg.tokenizer, mask=False)
    dataloader = get_data_loader(features, cfg.MAX_SEQ_LENGTH, 64, shuffle=False)
    # Honour cfg.amp, and only request mixed precision when actually running on a GPU.
    use_amp = bool(getattr(cfg, "amp", True)) and torch.device(cfg.device).type == "cuda"

    positive_logits = []
    with torch.inference_mode():
        for input_ids, input_mask, segment_ids, label_ids in dataloader:
            inputs = {
                'input_ids': input_ids.to(cfg.device),
                'attention_mask': input_mask.to(cfg.device),
                'segment_ids': segment_ids.to(cfg.device),
                'target': label_ids.to(cfg.device),
            }
            with autocast(enabled=use_amp):
                logits = model(inputs)['logits']
            # Cast to float32 so the later utils.sigmoid (presumably NumPy-based — TODO confirm)
            # is not evaluated in float16, where probabilities near 1 collapse into ties.
            positive_logits.extend(logits[:, 1].float().cpu().numpy())
    return positive_logits


def gen_kddcup_test_submission_bert(cfg):
    """Score every reference of every test paper with the fine-tuned model and dump a submission JSON.

    For each paper listed in ``paper_source_trace_test_wo_ans.json``, its XML file
    (``paper-xml/<id>.xml``) is parsed to count references ``b0..b{n-1}``. The citation
    contexts of each reference (from ``utils.find_bib_context``) are joined and scored
    by the model, and the class-1 logit is passed through ``utils.sigmoid``. The result
    maps paper id to a list of scores indexed by reference number. It is written to
    ``<OUT_DIR>/kddcup/<model_name>/num_fold=0/test_submission_<model_name>.json``.

    Args:
        cfg: Config namespace. Reads ``model_name``, ``dist``, ``device``,
            ``clean_flag``, ``MAX_SEQ_LENGTH`` and optionally ``amp``; sets
            ``cfg.class_weight`` and ``cfg.tokenizer`` as side effects.

    Raises:
        NotImplementedError: if ``cfg.model_name`` has no known checkpoint path.
        AssertionError: if a paper's reference count disagrees with the submission template.
    """
    print("model name", cfg.model_name)
    print(f'The length of context text:{cfg.dist}')
    # Set before Net(cfg) is built; presumably consumed by the model's loss — TODO confirm.
    cfg.class_weight = torch.Tensor([0.5500, 5.5000]).to(cfg.device)
    data_dir = join(settings.DATA_TRACE_DIR, "PST")
    papers = utils.load_json(data_dir, "paper_source_trace_test_wo_ans.json")  # test paper ids

    model_path = _KDDCUP_MODEL_PATHS.get(cfg.model_name)
    if model_path is None:
        raise NotImplementedError(cfg.model_name)

    sub_example_dict = utils.load_json(data_dir, "submission_example_test.json")  # submission template
    print("device", cfg.device)
    cfg.tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = Net(cfg, model_path=model_path)
    OUTPUT_DIR = join(settings.OUT_DIR, "kddcup", cfg.model_name, 'num_fold=0')
    # Load on CPU first so a GPU-saved checkpoint also works on a CPU-only machine.
    model.load_state_dict(torch.load(join(OUTPUT_DIR, "pytorch_model.bin"), map_location="cpu"))
    model.to(cfg.device)
    model.eval()
    xml_dir = join(data_dir, "paper-xml")
    sub_dict = {}

    for paper in tqdm(papers):
        cur_pid = paper["_id"]
        with open(join(xml_dir, cur_pid + ".xml"), encoding='utf-8') as f:
            xml = f.read()
        n_refs = _count_references(BeautifulSoup(xml, "xml"))

        bib_to_contexts = utils.find_bib_context(xml, clean_flag=cfg.clean_flag, dist=cfg.dist)
        assert len(sub_example_dict[cur_pid]) == n_refs

        contexts_sorted = [" ".join(bib_to_contexts["b" + str(ii)]) for ii in range(n_refs)]
        positive_logits = _predict_positive_logits(model, cfg, contexts_sorted)
        clean_memory()

        # Scores come back in reference order, so position i is reference b<i>.
        sub_dict[cur_pid] = [float(utils.sigmoid(score)) for score in positive_logits]

    utils.dump_json(sub_dict, OUTPUT_DIR, f"test_submission_{cfg.model_name}.json")

In [5]:
# Start from a clean GPU cache and a fixed seed, then build the submission file.
fold = 0  # NOTE(review): not referenced by any visible code
clean_memory()
utils.seed_everything(cfg.seed)
gen_kddcup_test_submission_bert(cfg)

2024-06-12 15:24:12,422 loading paper_source_trace_test_wo_ans.json ...
2024-06-12 15:24:12,431 paper_source_trace_test_wo_ans.json loaded
2024-06-12 15:24:12,462 loading submission_example_test.json ...
2024-06-12 15:24:12,464 submission_example_test.json loaded


Global seed set to 42
model name scibert
The length of context text:200
device cuda
BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(31090, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0-11): 12 x BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-

  0%|                                                                                          | 0/394 [00:01<?, ?it/s]


KeyboardInterrupt: 