In [None]:
import copy
import os
from pathlib import Path
from typing import Dict, Union, Any, List, Mapping, Union
import json
from collections import OrderedDict, defaultdict
import torch
import logging
import logging.config
import numpy as np
from torch.utils.data import Dataset
from transformers import AutoTokenizer
import re
import string
import random
import pandas as pd
from datasets import load_dataset
from nltk.corpus import stopwords as stop_words
from torchtext.data import get_tokenizer
from torch.utils.data import DataLoader
from torch.backends import cudnn
from scipy.stats import entropy
import heapq
from gensim.corpora import Dictionary
from gensim.models import CoherenceModel
from tqdm import tqdm
import torch.distributed
from abc import abstractmethod
from numpy import inf
from logger import TensorboardWriter
import torch.nn.functional as F
from sklearn.metrics import roc_auc_score, f1_score

In [None]:
def read_json(file: Union[str, os.PathLike]):
    """
    Read json from file
    :param file: the path to the json file
    :return: ordered dictionary content (nested objects are OrderedDict as well)
    """
    # JSON is UTF-8 by specification; do not depend on the platform default encoding.
    with Path(file).open("rt", encoding="utf-8") as handle:
        return json.load(handle, object_hook=OrderedDict)


def write_json(content: Dict, file: Union[str, os.PathLike]):
    """
    Serialize a dictionary to a json file (4-space indent, insertion order kept)
    :param content: the content dictionary
    :param file: the path to save json file
    """
    target = Path(file)
    with target.open("wt") as out:
        json.dump(content, out, sort_keys=False, indent=4)


def write_to_file(file: Union[str, os.PathLike], text: Union[str, list]):
    """
    Write text to a UTF-8 file, truncating any previous content.
    :param file: destination path
    :param text: a string written verbatim, or a list of strings joined with newlines;
        any other type leaves the (truncated) file empty
    """
    with open(file, "w", encoding="utf-8") as sink:
        if isinstance(text, list):
            sink.write("\n".join(text))
            return
        if isinstance(text, str):
            sink.write(text)


def prepare_device(n_gpu_use):
    """
    Pick the training device and the GPU indices to hand to DataParallel.
    :param n_gpu_use: number of GPUs requested
    :return: (torch.device, list of GPU indices); the request is clamped to what
        torch.cuda.device_count() reports, and a warning is printed when it is reduced
    """
    n_gpu = torch.cuda.device_count()
    if n_gpu == 0 and n_gpu_use > 0:
        print("Warning: There\'s no GPU available on this machine,"
              "training will be performed on CPU.")
        n_gpu_use = 0
    if n_gpu_use > n_gpu:
        print(f"Warning: The number of GPU\'s configured to use is {n_gpu_use}, but only {n_gpu} are "
              "available on this machine.")
        n_gpu_use = n_gpu
    device_name = 'cuda:0' if n_gpu_use > 0 else 'cpu'
    return torch.device(device_name), list(range(n_gpu_use))


def del_index_column(df):
    """Return df without the columns whose name starts with "Unnamed" (e.g. a saved index column)."""
    is_unnamed = df.columns.str.contains("^Unnamed")
    return df.loc[:, ~is_unnamed]


def get_project_root(**kwargs):
    """
    Return the project root as a path relative to the current working directory.

    The root is currently taken to be the working directory itself, so the result
    is the relative path from the cwd to the cwd. The ``project_name`` keyword is
    accepted for compatibility but does not influence the result.
    """
    cwd = os.getcwd()
    # Rebuild the absolute path from its components, then express it relative to the cwd.
    rebuilt = Path(os.sep.join(Path(cwd).parts))
    return os.path.relpath(rebuilt, cwd)

In [None]:

class Configuration:
    """
    This is the base class for all configuration class. Deal with the common hyper-parameters to all models'
    configuration, and include the methods for loading/saving configurations.
    For each sub configuration, a variable named 'type' is defined to indicate which class it belongs to

    Sub-configurations (``arch_config``, ``data_config``, ``trainer_config``, ``optimizer_config``,
    ``scheduler_config``) are plain dictionaries; values passed by the caller are merged over the defaults.
    """

    def __init__(self, **kwargs):
        # parameters in general
        self.n_gpu = kwargs.pop("n_gpu", 1)  # default using gpu for training
        self.embedding_type = kwargs.pop("embedding_type", "glove")
        self.max_length = kwargs.pop("max_length", 100)
        self.loss = kwargs.pop("loss", "cross_entropy")
        self.metrics = kwargs.pop("metrics", ["accuracy", "macro_f"])
        self.save_model = kwargs.pop("save_model", False)
        self.resume = kwargs.pop("resume", None)
        # setup default relative project path
        self.project_name = kwargs.pop("project_name", "nc_base")

        # Only derive the project root when the caller did not provide one
        # (a dict.pop default would be computed even if it is never used).
        if "project_root" in kwargs:
            self.project_root = kwargs.pop("project_root")
        else:
            self.project_root = get_project_root(project_name=self.project_name)
        self.data_root = os.path.join(self.project_root, "dataset")
        self.save_dir = kwargs.pop("save_dir", os.path.join(self.project_root, "saved"))
        self.seed = kwargs.pop("seed", 42)
        self.sub_configs = ["arch_config", "data_config", "trainer_config", "optimizer_config", "scheduler_config"]

        # parameters for architecture by default
        self.arch_config = {
            "type": "Baseline", "dropout_rate": 0.2, "embedding_type": self.embedding_type,
            "max_length": self.max_length,
        }
        self.arch_config.update(kwargs.pop("arch_config", {}))

        # parameters for loading data
        self.data_config = {
            "type": "NewsDataLoader", "batch_size": 32, "num_workers": 1, "name": "MIND15/keep_all",
            "max_length": self.max_length, "data_root": self.data_root, "embedding_type": self.embedding_type
        }
        self.data_config.update(kwargs.pop("data_config", {}))
        # identifier of experiment, default is identified by dataset name and architecture type.
        self.run_name = kwargs.pop("run_name", f"{self.data_config['name']}/{self.arch_config['type']}")

        # parameters for optimizer
        self.optimizer_config = {"type": "Adam", "lr": 1e-3, "weight_decay": 0}
        self.optimizer_config.update(kwargs.pop("optimizer_config", {}))

        # parameters for scheduler
        self.scheduler_config = {"type": "StepLR", "step_size": 50, "gamma": 0.1}
        self.scheduler_config.update(kwargs.pop("scheduler_config", {}))

        # parameters for trainer
        self.trainer_config = {
            "epochs": 3, "early_stop": 3, "monitor": "max val_accuracy", "verbosity": 2, "tensorboard": False
        }
        self.trainer_config.update(kwargs.pop("trainer_config", {}))

        # Additional attributes without default values (this also restores keys such as
        # data_root / sub_configs when the object is rebuilt from a saved json file)
        for key, value in kwargs.items():
            setattr(self, key, value)

    def get(self, key, default=None):
        """
        Dict-like lookup: a top-level attribute wins; otherwise the first sub-configuration
        (in ``sub_configs`` order) that contains ``key`` is used; otherwise ``default``.
        """
        if hasattr(self, key):
            return getattr(self, key)
        for sub in self.sub_configs:
            sub_config = getattr(self, sub)
            if key in sub_config:
                return sub_config[key]
        return default

    def set(self, key, value):
        """
        Assign ``value`` to the top-level attribute ``key`` (if it exists) and to every
        sub-configuration that already contains ``key``. Unknown keys are ignored.
        """
        if hasattr(self, key):
            setattr(self, key, value)
        for sub in self.sub_configs:
            sub_config = getattr(self, sub)
            if key in sub_config:
                sub_config[key] = value

    def update(self, config_dict: Dict[str, Any]):
        """
        Updates attributes of this class with attributes from ``config_dict``.

        A dict value is merged into the existing dict attribute of the same name (e.g. a
        sub-configuration); any other value - or a dict whose target attribute is missing or
        is not a dict - simply replaces the attribute.

        Args:
            config_dict (:obj:`Dict[str, Any]`): Dictionary of attributes that should be updated for this class.
        """
        for key, value in config_dict.items():
            if isinstance(value, dict) and isinstance(getattr(self, key, None), dict):
                self.update_sub_config(key, **value)
            else:
                setattr(self, key, value)

    def update_sub_config(self, sub_name: str, **kwargs):
        """
        update corresponding sub configure dictionary
        :param sub_name: the name of sub-configuration, such as arch_config
        """
        getattr(self, sub_name).update(kwargs)

    def save_config(self, save_dir: Union[str, os.PathLike], config_name: str = "config.json"):
        """
        Save configuration with the saved directory with corresponding configuration name in a json file
        :param config_name: default is config.json, should be a json filename
        :param save_dir: the directory to save the configuration
        :raises AssertionError: if ``save_dir`` points to an existing file
        """
        if os.path.isfile(save_dir):
            raise AssertionError(f"Provided path ({save_dir}) should be a directory, not a file")
        os.makedirs(save_dir, exist_ok=True)
        config_file = Path(save_dir) / config_name
        write_json(copy.deepcopy(self.__dict__), config_file)

    @classmethod
    def from_json_file(cls, json_file: Union[str, os.PathLike]):
        """
        load configuration from a json file
        :param json_file: the path to the json file
        :return: a configuration object
        """
        return cls(**read_json(json_file))

In [None]:
def setup_logging(save_dir):
    """
    Setup logging configuration

    Installs a console handler (stdout, DEBUG) and a rotating ``info.log`` file handler
    (INFO, 10 MiB per file, 20 backups) on the root logger.
    :param save_dir: directory that receives the log file; a ``str`` or any path-like object
    """
    log_config = {
        "version": 1,
        "disable_existing_loggers": False,
        "formatters": {
            "simple": {"format": "%(message)s"},
            "datetime": {"format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s"}
        },
        "handlers": {
            "console": {
                "class": "logging.StreamHandler",
                "level": "DEBUG",
                "formatter": "simple",
                "stream": "ext://sys.stdout"
            },
            "info_file_handler": {
                "class": "logging.handlers.RotatingFileHandler",
                "level": "INFO",
                "formatter": "datetime",
                "filename": "info.log",
                "maxBytes": 10485760,
                "backupCount": 20, "encoding": "utf8"
            }
        },
        "root": {
            "level": "INFO",
            "handlers": [
                "console",
                "info_file_handler"
            ]
        }
    }
    # modify logging paths based on run config; Path() lets plain strings work as well
    log_root = Path(save_dir)
    for handler in log_config['handlers'].values():
        if 'filename' in handler:
            handler['filename'] = str(log_root / handler['filename'])

    logging.config.dictConfig(log_config)

In [None]:
# Default hyper-parameters per architecture type; looked up by arch_default_config().
default_configs = {
    "PretrainedBaseline": {
        "n_layers": 1,
    },
    "TextCNNClassifyModel": {
        "num_filters": 100, "filter_sizes": (2, )
    },
    "NRMSNewsEncoderModel": {
        "variant_name": "base"
    },
    "GRUAttClassifierModel": {
        "variant_name": "gru_att"
    },
    "BiAttentionClassifyModel": {
        "head_num": None, "head_dim": 20, "entropy_constraint": False, "alpha": 0.01, "n_layers": 1, "variant_name": "base",
    },
    "TopicExtractorClassifyModel": {
        "head_num": None, "head_dim": 20, "entropy_constraint": False, "alpha": 0.01, "n_layers": 1
    },
    "FastformerClassifyModel": {
        "embedding_dim": 300, "n_layers": 2, "hidden_act": "gelu", "head_num": 15, "type_vocab_size": 2,
        "vocab_size": 100000, "layer_norm_eps": 1e-12, "initializer_range": 0.02, "pooler_type": "weightpooler",
        # NOTE(review): this is the string "False" (truthy), not the boolean - confirm the consumer expects a string
        "enable_fp16": "False"
    }
}


def arch_default_config(arch_type: str):
    """
    Build the default architecture configuration for ``arch_type``.
    :param arch_type: a key of ``default_configs``
    :return: a new dict with "type" set to ``arch_type`` followed by that architecture's defaults
    :raises KeyError: if ``arch_type`` has no entry in ``default_configs``
    """
    return {"type": arch_type, **default_configs[arch_type]}

In [None]:
class ConfigParser:
    def __init__(self, config: "Configuration", modification: dict = None):
        """
        class to parse configuration json file. Handles hyper-parameters for training, initializations of modules,
        checkpoint saving and logging module.
        :param config: Configuration object holding hyper-parameters for training. Normal saved in configs directory.
        :param modification: Dict keychain:value, specifying position values to be replaced from config dict.
        Timestamp is being used as default

        Side effects: creates ``<save_dir>/models/<run_name>`` and ``<save_dir>/log/<run_name>``, configures
        logging to write into the log directory, and saves the configuration into the model directory.
        """
        # load config file and apply modification.
        self.config = config
        if modification:
            self.config.update(modification)

        # set save_dir where training model and log will be saved.
        save_dir = Path(self.config.save_dir)
        run_name = self.config.run_name

        self._save_dir = save_dir / "models" / run_name
        self._log_dir = save_dir / "log" / run_name
        # configure logging module
        self.log_levels = {0: logging.WARNING, 1: logging.INFO, 2: logging.DEBUG}
        # make directory for saving checkpoints and log
        self._save_dir.mkdir(parents=True, exist_ok=True)
        self._log_dir.mkdir(parents=True, exist_ok=True)
        setup_logging(self.log_dir)
        # save updated config file to the checkpoint directory
        self.config.save_config(self._save_dir)

    @property
    def save_dir(self):
        """Directory where model checkpoints and the saved configuration are written."""
        return self._save_dir

    @property
    def log_dir(self):
        """Directory where log files are written."""
        return self._log_dir

    @classmethod
    def from_args(cls, args, options: list = None):
        """
        Initialize this class from some cli arguments. Used in train, test.
        :param args: an argument parser (its ``add_argument`` / ``parse_args`` are used)
        :param options: custom option descriptors exposing ``flags``, ``type``, ``target`` and
            optionally ``default``; ``None`` means no custom options
        :return: a ConfigParser wrapping the resulting Configuration object
        """
        options = options or []  # the default (None) must not break the loops below
        for opt in options:
            default = opt.default if hasattr(opt, "default") else None
            args.add_argument(*opt.flags, default=default, type=opt.type)
        if not isinstance(args, tuple):
            args = args.parse_args()
        # parse custom cli options into dictionary
        modification = {}
        if hasattr(args, "arch_type") and args.arch_type is not None:
            modification["arch_config"] = arch_default_config(args.arch_type)  # setup default arch params
        for opt in options:
            name = opt.flags[-1].replace("--", "")  # acquire param name
            # NOTE(review): the truthiness checks below skip falsy values such as 0 or False - confirm intended
            if opt.target:
                if opt.target not in modification:
                    modification[opt.target] = {}
                if getattr(args, name):
                    modification[opt.target][name] = getattr(args, name)  # setup custom params values
            else:
                if getattr(args, name):
                    modification[name] = getattr(args, name)  # setup custom params values
        if hasattr(args, "resume") and args.resume is not None:
            # NOTE(review): when resuming, the cli modifications collected above are not applied - confirm intended
            config_file = Path(args.resume).parent / "config.json"
            config = Configuration.from_json_file(config_file)
        else:
            # either way, the configuration ends up as a Configuration object
            config = Configuration(**modification)
        return cls(config)

    def init_obj(self, module_config: str, module: object, *args, **kwargs):
        """
        Finds a function handle with the name given as 'type' in config, and returns the
        instance initialized with corresponding arguments given.

        `object = config.init_obj('trainer_config', module, a, b=1)`
        is equivalent to
        `object = module.module_name(a, b=1)`
        """
        # work on a copy so the stored sub-configuration keeps its "type" entry
        module_args = copy.deepcopy(getattr(self.config, module_config))
        module_args.update(kwargs)
        module_name = module_args.pop("type")
        # look up the attribute named ``module_name`` on ``module`` and call it
        return getattr(module, module_name)(*args, **module_args)

    def get_logger(self, name, verbosity=2):
        """
        Return the logger ``name`` with its level set from ``verbosity``
        (0: WARNING, 1: INFO, 2: DEBUG).
        """
        msg_verbosity = f"verbosity option{verbosity} is invalid. Valid options are {self.log_levels.keys()}."
        assert verbosity in self.log_levels, msg_verbosity
        logger = logging.getLogger(name)
        logger.setLevel(self.log_levels[verbosity])
        return logger

    def __getitem__(self, name):
        """Access items like ordinary dict."""
        return getattr(self.config, name)

In [None]:
def text2index(text, word_dict, method="keep", ignore=True):
    """
    Tokenize ``text`` with ``tokenize`` and map the tokens to indices with ``word2index``.
    :param text: raw text
    :param word_dict: word -> index mapping (extended in place when ``ignore`` is False)
    :param method: tokenization method forwarded to ``tokenize``
    :param ignore: forwarded to ``word2index``
    :return: list of word indices
    """
    tokens = tokenize(text, method)
    return word2index(word_dict, tokens, ignore)

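'''
clean_text normalizes raw text: every character that is NOT an ASCII letter, a digit, a space or a
punctuation mark is replaced by a single space. Punctuation marks and digits are kept, not removed.
string.punctuation + "0123456789": builds the rule string holding all punctuation marks and the digits 0-9.
re.sub(rf'([^{rule}a-zA-Z ])', r" ", text): re.sub() replaces every match of the pattern. The pattern
rf'([^{rule}a-zA-Z ])' matches any character that is not in rule, not a lower- or upper-case English letter,
and not a space; each such character is replaced by one space.
'''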
def clean_text(text):
    """Replace every character outside ASCII letters, digits, punctuation and space with one space."""
    allowed = string.punctuation + "0123456789"
    disallowed_char = rf'([^{allowed}a-zA-Z ])'
    return re.sub(disallowed_char, r" ", text)


def aggressive_process(text):
    """
    Lower-case ``text``, blank out punctuation and digits, split on whitespace and
    drop English stop words.
    :return: list of remaining tokens
    """
    english_stopwords = set(stop_words.words("english"))
    removable = string.punctuation + "0123456789"
    # one translation table that turns every punctuation mark / digit into a space
    to_space = str.maketrans(removable, ' ' * len(removable))
    words = text.lower().translate(to_space).split()
    return [word for word in words if word not in english_stopwords]


def tokenize(text, method="keep_all"):
    """
    Clean ``text`` and split it into tokens.
    :param text: raw text
    :param method: "keep_all" (punctuation/digits become separate tokens), "aggressive"
        (see ``aggressive_process``) or "alphabet_only" (punctuation/digits are dropped);
        any other value yields an empty token list
    :return: list of tokens
    """
    cleaned = clean_text(text)
    symbols = string.punctuation + "0123456789"
    basic_tokenizer = get_tokenizer('basic_english')
    if method == "aggressive":
        return aggressive_process(cleaned)
    # replacement applied to every punctuation mark / digit before tokenizing
    replacements = {"keep_all": r" \1 ", "alphabet_only": r" "}
    if method not in replacements:
        return []
    return basic_tokenizer(re.sub(rf'([{symbols}])', replacements[method], cleaned.lower()))


def word2index(word_dict, sent, ignore=True):
    """
    Map each word of ``sent`` to its index in ``word_dict``.
    :param word_dict: word -> index mapping
    :param sent: iterable of words
    :param ignore: when True, unknown words map to 0 and ``word_dict`` is left untouched;
        when False, unknown words are added to ``word_dict`` with the next free index
    :return: list of indices, one per word
    """
    if ignore:
        return [word_dict.get(token, 0) for token in sent]
    # setdefault registers an unseen word with index len(word_dict) and returns its index
    return [word_dict.setdefault(token, len(word_dict)) for token in sent]

In [None]:
class BaseDataset(Dataset):
    """
    Dataset that turns raw texts into fixed-length sequences of word indices.

    Each item is a dict with "data" (indices, zero-padded / truncated to ``max_length``),
    "label" (index from ``label_dict``, -1 for an unknown label) and "mask"
    (1 where the index is non-zero, 0 elsewhere).
    """

    def __init__(self, texts, labels, label_dict, max_length, word_dict, process_method="keep_all"):
        super().__init__()
        self.texts = texts
        self.labels = labels
        self.label_dict = label_dict
        self.max_length = max_length
        self.word_dict = word_dict
        self.process_method = process_method

        # derive a label -> index mapping from the sorted label set when none is supplied
        if labels is not None and self.label_dict is None:
            self.label_dict = {name: idx for idx, name in enumerate(sorted(set(labels)))}

    def __getitem__(self, i):
        token_ids = text2index(self.texts[i], self.word_dict, self.process_method, True)[:self.max_length]
        token_ids.extend([0] * max(0, self.max_length - len(token_ids)))
        data = torch.tensor(token_ids, dtype=torch.long)
        label_id = self.label_dict.get(self.labels[i], -1)
        label = torch.tensor(label_id, dtype=torch.long).squeeze(0)
        mask = (data != 0).to(torch.long)
        return {"data": data, "label": label, "mask": mask}

    def __len__(self):
        return len(self.labels)


class BaseDatasetBert(Dataset):
    """
    Dataset that encodes raw texts with a pretrained tokenizer.

    Each item is a dict with "data" (token ids padded with the tokenizer's pad id up to
    ``max_length``), "mask" (int8, 1 for real tokens and 0 for padding) and, when labels
    are available, "label" (index from ``label_dict``, -1 for an unknown label).
    """

    def __init__(self, texts: List[str], labels: List[str] = None, label_dict: Mapping[str, int] = None,
                 max_length: int = 512, embedding_type: str = 'distilbert-base-uncased', is_local=False):
        self.texts, self.labels = texts, labels
        self.label_dict, self.max_length = label_dict, max_length

        # derive a label -> index mapping from the sorted label set when none is supplied
        if labels is not None and self.label_dict is None:
            self.label_dict = {name: idx for idx, name in enumerate(sorted(set(labels)))}

        # is_local switches to a fixed local model directory instead of the plain model name
        model_root = "D:\\AI\\model\\" if is_local else ''
        # the tokenizer is loaded (downloaded if necessary) here
        self.tokenizer = AutoTokenizer.from_pretrained(model_root + embedding_type)
        logging.getLogger("transformers.tokenization_utils").setLevel(logging.FATAL)

        # for this model type the eos token is reused as the pad token before reading the pad id
        if embedding_type == "transfo-xl-wt103":
            self.tokenizer.pad_token = self.tokenizer.eos_token
        self.pad_vid = self.tokenizer.pad_token_id

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, index) -> Mapping[str, torch.Tensor]:
        encoded = self.tokenizer.encode(self.texts[index], add_special_tokens=True, max_length=self.max_length,
                                        truncation=True, return_tensors="pt").squeeze(0)

        # pad the encoded sequence to max_length and build the matching attention mask
        n_pad = self.max_length - encoded.size(0)
        padding = torch.Tensor([self.pad_vid] * n_pad).long()
        token_mask = torch.ones_like(encoded, dtype=torch.int8)
        padding_mask = torch.zeros_like(padding, dtype=torch.int8)

        sample = {"data": torch.cat((encoded, padding)), 'mask': torch.cat((token_mask, padding_mask))}

        if self.labels is not None:
            label_id = self.label_dict.get(self.labels[index], -1)
            sample["label"] = torch.Tensor([label_id]).long().squeeze(0)

        return sample

In [None]:
def clean_df(data_df):
    """
    Clean the "title" and "body" columns of ``data_df`` in place and return it.
    :param data_df: DataFrame with "title" and "body" columns
    :return: the same DataFrame object, modified
    """
    # drop (in place) the rows where both title and body are missing
    data_df.dropna(subset=["title", "body"], how="all", inplace=True)
    # replace (in place) every remaining missing value with the string "empty"
    data_df.fillna("empty", inplace=True)
    # run clean_text over every title and every body
    for column in ("title", "body"):
        data_df[column] = data_df[column].apply(clean_text)
    return data_df

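'''
split_df operates on a pandas DataFrame and marks a random subset of its rows as the validation set.
Step by step:
df is a pandas DataFrame, i.e. a two-dimensional table.
indices: the row labels of df are collected into indices. For example, if df has 10 rows with the default
index, indices holds the values 0 to 9.
random.Random(42).shuffle(indices): shuffles indices. The fixed seed 42 guarantees that every run produces
the same random order.
split_len = round(split * len(df)): the size of the validation set. split is a float between 0 and 1, the
fraction of rows that go to the validation set; len(df) is the number of rows, so split * len(df) is the
desired size, rounded to the nearest integer.
df.loc[indices[:split_len], "split"] = "valid": performs the split. The rows whose labels are the first
split_len entries of the shuffled indices - i.e. split_len randomly chosen rows - get the value "valid" in a
(new) column named "split".
'''
# This function assigns rows to the train / valid (and optionally test) splits.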
def split_df(df, split=0.1, split_test=False):
    """
    Randomly (seed 42) assign the rows of ``df`` to splits by writing a "split" column in place.
    :param df: DataFrame to split; its index labels are used to address rows
    :param split: fraction of rows assigned to "valid" (and to "test" when ``split_test`` is True)
    :param split_test: when True, a second slice of the same size is marked "test"
    :return: the same DataFrame, with the "split" column filled with "valid" / "test" / "train"
    """
    # Shuffle a *copy* of the labels: shuffling df.index.values would permute the index's own
    # buffer in place (corrupting the row labels) and fails when that buffer is read-only.
    indices = df.index.tolist()
    random.Random(42).shuffle(indices)
    split_len = round(split * len(df))
    df.loc[indices[:split_len], "split"] = "valid"
    if split_test:
        df.loc[indices[split_len:split_len*2], "split"] = "test"
        df.loc[indices[split_len*2:], "split"] = "train"
    else:
        df.loc[indices[split_len:], "split"] = "train"
    return df


def load_set_by_type(dataset, set_type: str) -> pd.DataFrame:
    """
    Convert one split of ``dataset`` into a DataFrame.
    :param dataset: mapping whose ``dataset[set_type]`` provides "text" and "label" sequences
    :param set_type: name of the split to read; also stored in the "split" column
    :return: DataFrame with columns "data", "category" and "split"
    """
    # pair texts with labels (stops at the shorter of the two sequences)
    pairs = list(zip(dataset[set_type]["text"], dataset[set_type]["label"]))
    columns = {
        "data": [text for text, _ in pairs],
        "category": [label for _, label in pairs],
        "split": set_type,
    }
    return pd.DataFrame(columns)


def load_dataset_df(dataset_name, data_path):
    """
    Load a dataset as a DataFrame together with its label dictionary.
    :param dataset_name: "MIND15" / "News26" (read from the csv at ``data_path``) or
        "ag_news" / "yelp_review_full" / "imdb" (fetched through ``load_dataset``)
    :param data_path: path of the csv file; only used for the csv-based datasets
    :return: (DataFrame with at least "data", "category" and "split" columns - assuming the csv
        provides "category"/"split" - , dict mapping each sorted category to an index)
    :raises ValueError: if ``dataset_name`` is not supported
    """
    if dataset_name in ["MIND15", "News26"]:
        # these datasets are read from a local csv file
        df = clean_df(pd.read_csv(data_path, encoding="utf-8"))
        df["data"] = df.title + "\n" + df.body
    elif dataset_name in ["ag_news", "yelp_review_full", "imdb"]:
        # load corresponding dataset from datasets library
        dataset = load_dataset(dataset_name)
        train_set, test_set = split_df(load_set_by_type(dataset, "train")), load_set_by_type(dataset, "test")
        # DataFrame.append no longer exists in pandas >= 2.0; concat produces the same frame
        df = pd.concat([train_set, test_set])
    else:
        raise ValueError("dataset name should be in one of MIND15, IMDB, News26, and ag_news...")
    labels = df["category"].values.tolist()
    label_dict = dict(zip(sorted(set(labels)), range(len(set(labels)))))
    return df, label_dict


def load_word_dict(data_root, dataset_name, process_method, **kwargs):
    """
    Load the cached word dictionary of a dataset, building and caching it when missing.
    :param data_root: dataset root; the cache lives under ``<data_root>/utils/word_dict``
    :param dataset_name: dataset name, part of the cache file name
    :param process_method: tokenization method, part of the cache file name
    :param kwargs: optional ``embed_method`` (default "use_all"), ``data_path`` (csv location)
        and ``df`` (an already loaded DataFrame with a "data" column)
    :return: word -> index dictionary
    """
    embed_method = kwargs.get("embed_method", "use_all")
    wd_path = Path(data_root) / "utils" / "word_dict" / f"{dataset_name}_{process_method}_{embed_method}.json"
    if os.path.exists(wd_path):
        word_dict = read_json(wd_path)
    else:
        word_dict = {}
        df = kwargs.get("df")
        if df is None:
            # only load the dataset when the caller did not hand one in
            # (a dict.get default would be evaluated - i.e. loaded - unconditionally)
            data_path = kwargs.get("data_path", Path(data_root) / "data" / f"{dataset_name}.csv")
            df = load_dataset_df(dataset_name, data_path)[0]
        # text2index(..., ignore=False) registers every unseen word in word_dict as a side effect
        df.data.apply(lambda s: text2index(s, word_dict, process_method, False))
        os.makedirs(wd_path.parent, exist_ok=True)
        write_json(word_dict, wd_path)
    return word_dict


def load_glove_embedding(glove_path=None):
    """
    Load GloVe vectors from a space separated text file (one "word v1 v2 ..." entry per line).
    :param glove_path: path of the GloVe file; defaults to a location relative to the working directory
    :return: dict mapping each word to its vector (numpy array)
    """
    if not glove_path:
        # a relative path is used on purpose
        glove_path = '../../dataset/glove/glove.840B.300d.txt'
    # keep_default_na=False: tokens such as "null", "nan" or "NA" are real words and must stay
    # strings instead of being parsed as missing values.
    glove = pd.read_csv(glove_path, sep=" ", quoting=3, header=None, index_col=0, keep_default_na=False)
    # pair every word with its row of values; avoids transposing the whole table
    return dict(zip(glove.index, glove.to_numpy()))

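'''
load_embeddings loads the GloVe embeddings and re-indexes the words of word_dict into new_wd, so that the
index of every word in the returned dictionary matches the row of its vector in the returned embedding matrix.
'''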
def load_embeddings(data_root, dataset_name, process_method, word_dict, glove_path=None, embed_method="use_all"):
    """
    Return the embedding matrix and the matching word dictionary of a dataset.

    Cached files under ``<data_root>/utils`` are used when present; otherwise the matrix is
    built from GloVe, saved, and the word dictionary file is rewritten with the new indices.
    :param word_dict: vocabulary whose words should receive an embedding row
    :param glove_path: optional GloVe file path, forwarded to ``load_glove_embedding``
    :param embed_method: with "use_all", words missing from GloVe get a random vector;
        with any other value they are left out of the dictionary
    :return: (numpy embedding matrix, word -> row index dictionary); row 0 is the zero vector of "[UNK]"
    """
    embed_path = Path(data_root) / "utils" / "embed_dict" / f"{dataset_name}_{process_method}_{embed_method}.npy"
    wd_path = Path(data_root) / "utils" / "word_dict" / f"{dataset_name}_{process_method}_{embed_method}.json"
    if os.path.exists(embed_path):
        embeddings = np.load(str(embed_path))
        word_dict = read_json(wd_path)
    else:
        # load the pretrained word -> vector dictionary
        embedding_dict = load_glove_embedding(glove_path)
        # take the dimension from the loaded vectors instead of assuming 300
        embed_dim = len(next(iter(embedding_dict.values()))) if embedding_dict else 300
        new_wd = {"[UNK]": 0}
        embeddings, exclude_words = [np.zeros(embed_dim)], []
        for w in word_dict:
            if w in new_wd:
                # already registered (e.g. "[UNK]" coming from a previously rewritten dictionary);
                # adding it again would shift every following row against its index
                continue
            if w in embedding_dict:
                embeddings.append(embedding_dict[w])
                new_wd[w] = len(new_wd)
            else:
                exclude_words.append(w)
        if embed_method == "use_all":
            mean, std = np.mean(embeddings), np.std(embeddings)
            # append random embedding
            for w in exclude_words:
                new_wd[w] = len(new_wd)
                embeddings.append(np.random.normal(loc=mean, scale=std, size=embed_dim))
        os.makedirs(embed_path.parent, exist_ok=True)
        np.save(str(embed_path), np.array(embeddings))
        word_dict = new_wd
        write_json(word_dict, wd_path)
    return np.array(embeddings), word_dict

In [None]:
class NewsDataLoader:
    """
    Loads a news/text classification dataset and exposes ``train_loader``, ``valid_loader``
    and ``test_loader`` (torch DataLoader objects) built from the rows of each split.
    """

    def load_dataset(self, df):
        """
        Wrap the rows of ``df`` in the dataset class that matches ``self.embedding_type``.
        :param df: DataFrame with "data" and "category" columns
        :return: BaseDatasetBert for pretrained transformer models, BaseDataset for glove / init
        :raises ValueError: for an unsupported embedding type
        """
        pretrained_models = ["distilbert-base-uncased", "bert-base-uncased", "xlnet-base-cased", "roberta-base",
                             "longformer-base-4096", "transfo-xl-wt103"]
        if self.embedding_type in pretrained_models:
            # build the dataset of the selected pretrained model; it carries the tokenizer
            dataset = BaseDatasetBert(texts=df["data"].values.tolist(), labels=df["category"].values.tolist(),
                                      label_dict=self.label_dict, max_length=self.max_length,
                                      embedding_type=self.embedding_type)
            if self.embedding_type == "transfo-xl-wt103":
                # this tokenizer exposes its symbol -> index table as sym2idx
                self.word_dict = dataset.tokenizer.sym2idx
            else:
                self.word_dict = dataset.tokenizer.vocab
        elif self.embedding_type in ["glove", "init"]:
            # if we use glove embedding, then we ignore the unknown words.
            # These positional arguments are BaseDataset's (word dict + tokenization method);
            # BaseDatasetBert would take the word dict as its embedding_type.
            dataset = BaseDataset(df["data"].values.tolist(), df["category"].values.tolist(), self.label_dict,
                                  self.max_length, self.word_dict, self.method)
        else:
            raise ValueError(f"Embedding type should be one of {','.join(pretrained_models)} or glove and init")
        return dataset

    def __init__(self, batch_size=32, shuffle=True, num_workers=1, max_length=128, name="MIND15/keep", **kwargs):
        """
        :param batch_size, shuffle, num_workers: forwarded to all three DataLoaders
        :param max_length: sequence length of every sample
        :param name: "<dataset name>/<tokenization method>"
        :param kwargs: optional ``embedding_type`` (default "glove") and ``embed_method`` (default "use_all")
        """
        self.set_name, self.method = name.split("/")[0], name.split("/")[1]
        print("self.set_name, self.method: ", self.set_name, self.method)
        # embedding_type falls back to "glove" when it is not supplied
        self.max_length, self.embedding_type = max_length, kwargs.get("embedding_type", "glove")
        # NOTE(review): the data root is fixed here; a "data_root" keyword argument is not consulted
        self.data_root = "../../dataset"
        print("self.data_root: ", self.data_root)
        data_path = Path(self.data_root) / "data" / f"{self.set_name}.csv"
        print("data_path: ", data_path)
        # load the data
        df, self.label_dict = load_dataset_df(self.set_name, data_path)
        train_set, valid_set, test_set = df["split"] == "train", df["split"] == "valid", df["split"] == "test"
        if self.embedding_type in ["glove", "init"]:
            # setup word dictionary for glove or init embedding
            self.word_dict = load_word_dict(self.data_root, self.set_name, self.method, df=df)
        if self.embedding_type == "glove":
            # load the glove embedding matrix and the re-indexed word dictionary
            self.embeds, self.word_dict = load_embeddings(self.data_root, self.set_name, self.method, self.word_dict,
                                                          embed_method=kwargs.get("embed_method", "use_all"))
        self.init_params = {'batch_size': batch_size, 'shuffle': shuffle, 'num_workers': num_workers}
        # initialize train loader
        self.train_loader = DataLoader(self.load_dataset(df[train_set]), **self.init_params)
        # initialize validation loader
        self.valid_loader = DataLoader(self.load_dataset(df[valid_set]), **self.init_params)
        # initialize test loader
        self.test_loader = DataLoader(self.load_dataset(df[test_set]), **self.init_params)

In [None]:
def get_topic_dist(trainer, word_seq):
    """
    Compute the topic weight of every word index in ``word_seq``.
    :param trainer: object providing ``model.head_num``, ``load_batch_data`` and ``best_model.extract_topic``
    :param word_seq: sequence of word indices; each value is also the column written in the result
    :return: numpy array of shape (head_num, len(word_seq))
    """
    head_num = trainer.model.head_num
    topic_dist = np.zeros((head_num, len(word_seq)))
    bs = 512
    with torch.no_grad():
        # walk over the sequence in chunks of at most ``bs`` words; no empty chunk is produced
        # when the length is a multiple of ``bs``
        for start in range(0, len(word_seq), bs):
            words = list(word_seq[start:start + bs])
            input_feat = {"data": torch.tensor(words).unsqueeze(0), "mask": torch.ones(len(words)).unsqueeze(0)}
            input_feat = trainer.load_batch_data(input_feat)
            _, topic_weight = trainer.best_model.extract_topic(input_feat)  # (B, H, N)
            # explicit reshape (instead of squeeze) keeps the (H, N) layout for one-word chunks too
            topic_dist[:, words] = topic_weight.detach().cpu().reshape(head_num, len(words)).numpy()
    return topic_dist


def get_coherence(topics, texts, method):
    """
    Build a gensim CoherenceModel (top 25 words per topic) for ``topics`` over ``texts``.
    :param topics: list of topics, each a list of words
    :param texts: tokenized reference documents; also used to build the dictionary
    :param method: coherence measure name passed as ``coherence``
    :return: the CoherenceModel instance
    """
    vocabulary = Dictionary(texts)
    coherence_model = CoherenceModel(topics=topics, texts=texts, dictionary=vocabulary, coherence=method, topn=25)
    return coherence_model


def get_topic_list(matrix, top_n, reverse_dict):
    """
    Return, for every row of ``matrix``, the words of its ``top_n`` largest entries.
    :param matrix: iterable of numpy vectors (one per topic)
    :param top_n: number of words kept per topic
    :param reverse_dict: index -> word mapping
    :return: list of word lists, ordered from largest to smaller weight
    """
    topic_list = []
    for weights in matrix:
        # positions of the top_n largest weights, largest first
        best_positions = heapq.nlargest(top_n, range(len(weights)), weights.take)
        topic_list.append([reverse_dict[position] for position in best_positions])
    return topic_list


def evaluate_topic(topic_list, data_loader):
    """
    Score ``topic_list`` against the tokenized texts of the test set.
    :param topic_list: list of topics, each a list of words
    :param data_loader: provides ``method`` and ``test_loader.dataset.texts``
    :return: (per-topic c_npmi coherence, per-topic c_v coherence)
    """
    raw_texts = data_loader.test_loader.dataset.texts
    texts = [tokenize(raw, data_loader.method) for raw in raw_texts]
    per_measure = []
    for measure in ("c_npmi", "c_v"):
        per_measure.append(get_coherence(topic_list, texts, measure).get_coherence_per_topic())
    npmi, c_v = per_measure
    return npmi, c_v


def save_topic_info(path, weights, reverse_dict, data_loader, top_n=25):
    """
    Extract the topics from ``weights``, evaluate them and write the results into ``path``.

    Files written: ``topic_list.txt`` (one topic per line) and one file per coherence
    measure holding the per-topic scores rounded to 4 decimals; the mean score is part
    of each coherence file name.
    :return: dict with the mean "NPMI" and "CV" coherence
    """
    def _dump_scores(filename, scores):
        rounded = [str(score) for score in np.round(scores, 4)]
        write_to_file(os.path.join(path, filename), rounded)

    topic_list = get_topic_list(weights, top_n, reverse_dict)
    npmi, c_v = evaluate_topic(topic_list, data_loader)
    os.makedirs(path, exist_ok=True)
    topic_lines = [" ".join(topics) for topics in topic_list]
    write_to_file(os.path.join(path, "topic_list.txt"), topic_lines)
    topic_result = {"NPMI": np.mean(npmi), "CV": np.mean(c_v)}
    _dump_scores(f"cv_coherence_{topic_result['CV']}.txt", c_v)
    _dump_scores(f"npmi_coherence_{topic_result['NPMI']}.txt", npmi)
    return topic_result


In [None]:
def nll_loss(output, target):
    """Negative log-likelihood loss; thin wrapper around ``F.nll_loss`` with default arguments."""
    loss = F.nll_loss(output, target)
    return loss


def cross_entropy(output, target):
    """Cross-entropy loss; thin wrapper around ``F.cross_entropy`` with default arguments."""
    loss = F.cross_entropy(output, target)
    return loss


def categorical_loss(output, target, epsilon=1e-12):
    """
    Computes cross entropy between target (encoded as one-hot vectors) and output.
    Input: output (N, k) tensor of probabilities
           target (N, k) tensor
    Returns: scalar tensor, the summed cross entropy divided by N (``output.shape[0]``)
    """
    # clamp the probabilities away from 0 and 1 before taking the logarithm
    probs = torch.clamp(output.float(), epsilon, 1. - epsilon)
    log_likelihood = target.float() * torch.log(probs + 1e-9)
    return -torch.sum(log_likelihood) / probs.shape[0]

# name -> loss function registry
loss_dict = {"nll_loss": nll_loss, "cross_entropy": cross_entropy, "categorical_loss": categorical_loss}

In [None]:
class MetricTracker:
    """
    Keep a running total, a sample count and an average for a fixed set of metrics.

    The statistics are stored in a float DataFrame with one row per metric. All reads and writes go through
    label-based scalar access on the frame itself, so they do not rely on chained assignment or on writing through
    ``.values`` and keep working when pandas Copy-on-Write is active.
    """
    _COLUMNS = ("total", "counts", "average")

    def __init__(self, *keys, writer=None):
        """
        :param keys: names of the metrics to track
        :param writer: optional object with an ``add_scalar(key, value)`` method; every update is forwarded to it
        """
        self.writer = writer
        self._data = pd.DataFrame(index=list(keys), columns=list(self._COLUMNS), dtype=float)
        self.reset()

    def reset(self):
        """Set total, count and average of every metric back to zero."""
        for col in self._data.columns:
            # whole-column assignment on the frame itself, valid with and without Copy-on-Write
            self._data[col] = 0.0

    def update(self, key, value, n=1):
        """
        Add ``n`` samples with the (mean) value ``value`` to the metric ``key``.
        :param key: name of a tracked metric; an unknown name raises KeyError
        :param value: value observed for the batch
        :param n: number of samples the value stands for
        """
        if self.writer is not None:
            self.writer.add_scalar(key, value)
        total = float(self._data.at[key, "total"]) + float(value) * n
        counts = float(self._data.at[key, "counts"]) + n
        self._data.at[key, "total"] = total
        self._data.at[key, "counts"] = counts
        # a metric that has not seen any sample yet keeps an average of zero instead of dividing by zero
        self._data.at[key, "average"] = round(total / counts, 6) if counts else 0.0

    def avg(self, key):
        """Return the current average of the metric ``key``."""
        return float(self._data.at[key, "average"])

    def result(self):
        """Return a dictionary that maps every metric name to its current average."""
        return {key: float(value) for key, value in self._data["average"].items()}


def accuracy(output, target):
    """
    Fraction of samples whose highest-scoring class equals the target.
    :param output: score tensor; the prediction is the argmax along dimension 1
    :param target: tensor of target class indices, one per sample
    :return: accuracy as a Python float
    """
    with torch.no_grad():
        predicted = output.argmax(dim=1)
        total = len(target)
        assert predicted.shape[0] == total
        correct = (predicted == target).sum().item()
        return correct / total


def macro_f(output, target):
    """
    Macro-averaged F1 score of the argmax predictions, computed with ``f1_score``.
    :param output: score tensor; the prediction is the argmax along dimension 1
    :param target: tensor of target class indices
    :return: the value returned by ``f1_score(..., average="macro")``
    """
    with torch.no_grad():
        # both tensors are moved to the CPU before they are handed to f1_score
        predicted = torch.argmax(output, dim=1).cpu()
        expected = target.cpu()
        return f1_score(expected, predicted, average="macro")

# Registry that maps a metric name to the function implementing it.
metric_dict = {
    "accuracy": accuracy,
    "macro_f": macro_f,
}

In [None]:
class BaseTrainer:
    """
    Base class for all trainers.

    Builds the loss, the metrics, the optimizer and the learning-rate scheduler from the configuration and drives
    the epoch loop, best-model tracking, early stopping and checkpointing. Subclasses implement ``_train_epoch``.
    """
    def __init__(self, model, config):
        """
        :param model: the network to train; it is moved to the device returned by ``prepare_device``
        :param config: configuration object that supports item access and exposes ``config``, ``get_logger``,
            ``init_obj``, ``save_dir`` and ``log_dir``
        """
        self.config = config.config
        # trainer settings such as the number of epochs
        cfg_trainer = config["trainer_config"]
        self.logger = config.get_logger("trainer", cfg_trainer["verbosity"])
        # prepare for (multi-device) GPU training
        self.device, device_ids = prepare_device(config["n_gpu"])
        self.model = model.to(self.device)
        if len(device_ids) > 1:
            self.model = torch.nn.DataParallel(self.model, device_ids=device_ids)
        # until an epoch improves the monitored metric, the best model is the initial one
        self.best_model = model
        # look the loss up without removing it, so the shared registry stays intact for later trainers
        self.criterion = loss_dict[config["loss"]]
        # metric functions selected by name, e.g. ["accuracy", "macro_f"]
        self.metric_ftns = [metric_dict[met] for met in config["metrics"]]
        # build optimizer, learning rate scheduler. delete every lines containing lr_scheduler for disabling scheduler
        trainable_params = filter(lambda p: p.requires_grad, model.parameters())
        self.optimizer = config.init_obj("optimizer_config", torch.optim, trainable_params)
        self.lr_scheduler = config.init_obj("scheduler_config", torch.optim.lr_scheduler, self.optimizer)
        # set up trainer parameters
        self.epochs = cfg_trainer["epochs"]
        self.save_model = config["save_model"]
        self.last_best_path = None
        self.not_improved_count = 0
        self._setup_monitor(cfg_trainer)

        self.start_epoch = 1
        self.checkpoint_dir = config.save_dir

        # setup visualization writer instance
        self.writer = TensorboardWriter(config.log_dir, self.logger, cfg_trainer["tensorboard"])

        if config["resume"] is not None:
            self._resume_checkpoint(config["resume"])

    def _setup_monitor(self, cfg_trainer):
        """
        Configure best-model monitoring and early stopping.
        :param cfg_trainer: trainer settings; "monitor" is either "off" or "<min|max> <metric name>" and
            "early_stop" is the tolerated number of epochs without improvement (values <= 0 disable it)
        """
        self.monitor = cfg_trainer.get("monitor", "off")
        if self.monitor == "off":
            self.mnt_mode = "off"
            self.mnt_metric = None
            self.mnt_best = 0
            # train() always compares against early_stop, so it has to exist when monitoring is disabled as well
            self.early_stop = inf
            return

        self.mnt_mode, self.mnt_metric = self.monitor.split()
        assert self.mnt_mode in ["min", "max"]

        self.mnt_best = inf if self.mnt_mode == "min" else -inf
        self.early_stop = cfg_trainer.get("early_stop", inf)
        if self.early_stop <= 0:
            self.early_stop = inf

    @abstractmethod
    def _train_epoch(self, epoch):
        """
        Training logic for an epoch

        :param epoch: Current epoch number
        """
        raise NotImplementedError

    def _log_info(self, log):
        """Write every key/value pair of ``log`` to the trainer logger."""
        for key, value in log.items():
            self.logger.info("    {:15s}: {}".format(str(key), value))

    def save_log(self, log, **kwargs):
        """
        Add run metadata to ``log`` and store it as the first row of a csv file.

        ``log`` is modified in place: the seed, the run name and every architecture option whose value differs from
        the defaults returned by ``arch_default_config`` are added. Rows already present in the csv file are kept
        below the new row.
        :param log: dictionary with the values of the current epoch
        :param kwargs: ``saved_path`` overrides the default ``<checkpoint_dir>/model_best.csv``
        """
        log["seed"] = self.config["seed"]
        arch_config = self.config["arch_config"]
        default_config = arch_default_config(arch_config.get("type"))
        for key in arch_config.keys():
            if default_config.get(key, None) != arch_config.get(key):
                log[key] = arch_config.get(key)
        log["run_id"] = self.config["run_name"]
        saved_path = kwargs.get("saved_path", Path(self.checkpoint_dir) / "model_best.csv")
        log_df = pd.DataFrame(log, index=[0])
        if os.path.exists(saved_path):
            # DataFrame.append no longer exists in pandas 2; pd.concat produces the same row order
            previous = pd.read_csv(saved_path, float_precision="round_trip")
            log_df = pd.concat([log_df, previous], ignore_index=True)
        # drop the unnamed index columns that earlier to_csv calls wrote into the file
        log_df = log_df.loc[:, ~log_df.columns.str.contains("^Unnamed")]
        log_df.to_csv(saved_path)

    def _monitor(self, log, epoch):
        """
        Track the monitored metric; does nothing when monitoring is off.

        The epoch log is stored through ``save_log``. On improvement a deep copy of the model becomes
        ``best_model`` and, when ``save_model`` is set, a checkpoint is written; otherwise the counter of epochs
        without improvement is increased.
        :param log: dictionary with the values of the current epoch
        :param epoch: current epoch number
        """
        if self.mnt_mode != "off":
            try:
                # check whether model performance improved or not, according to specified metric(mnt_metric)
                improved = (self.mnt_mode == "min" and log[self.mnt_metric] <= self.mnt_best) or \
                           (self.mnt_mode == "max" and log[self.mnt_metric] >= self.mnt_best)
            except KeyError:
                err_msg = f"Warning:Metric {self.mnt_metric} is not found.Model performance monitoring is disabled."
                self.logger.warning(err_msg)
                self.mnt_mode = "off"
                improved = False
            log["split"] = "valid"
            self.save_log(log)

            if improved:
                self.mnt_best = log[self.mnt_metric]
                self.not_improved_count = 0
                self.best_model = copy.deepcopy(self.model)
                if self.save_model:
                    self._save_checkpoint(epoch, log[self.mnt_metric])
            else:
                self.not_improved_count += 1

    def train(self):
        """
        Full training logic: run every epoch, log its result, update the monitor and stop early once the number of
        epochs without improvement exceeds ``early_stop``.
        """
        for epoch in range(self.start_epoch, self.epochs + 1):
            result = self._train_epoch(epoch)

            # save logged information into log dict
            log = {"epoch": epoch}
            log.update(result)
            self._log_info(log)
            self._monitor(log, epoch)
            if self.not_improved_count > self.early_stop:
                self.logger.info(f"Validation performance did not improve for {self.early_stop} epochs. "
                                 "Training stops.")
                break

    def _save_checkpoint(self, epoch, score=0.0):
        """
        Save the current state as the best checkpoint and delete the previously saved best checkpoint.
        :param epoch: current epoch number
        :param score: current score of monitor metric
        """
        arch = type(self.model).__name__
        state = {
            "arch": arch,
            "epoch": epoch,
            "state_dict": self.model.state_dict(),
            "optimizer": self.optimizer.state_dict(),
            "monitor_best": self.mnt_best,
            "config": self.config
        }
        # wrap in Path, as save_log does, so a plain string checkpoint directory works too
        best_path = str(Path(self.checkpoint_dir) / f"{round(score, 4)}_model_best-epoch{epoch}.pth")
        if self.last_best_path and os.path.exists(self.last_best_path):
            os.remove(self.last_best_path)
        torch.save(state, best_path)
        self.logger.info(f"Saving current best: {best_path}")
        self.last_best_path = best_path

    def _resume_checkpoint(self, resume_path):
        """
        Resume from saved checkpoints

        :param resume_path: Checkpoint path to be resumed
        """
        resume_path = str(resume_path)
        self.logger.info(f"Loading checkpoint: {resume_path} ...")
        # map the stored tensors onto the device in use, so a checkpoint written on another device still loads
        checkpoint = torch.load(resume_path, map_location=self.device)
        self.start_epoch = checkpoint["epoch"] + 1
        self.mnt_best = checkpoint["monitor_best"]

        # load architecture params from checkpoint.
        if checkpoint["config"]["arch_config"] != self.config["arch_config"]:
            self.logger.warning("Warning: Architecture configuration given in config file is different from that of "
                                "checkpoint. This may yield an exception while state_dict is being loaded.")
        self.model.load_state_dict(checkpoint["state_dict"])
        # load optimizer state from checkpoint only when optimizer type is not changed.
        if checkpoint["config"]["optimizer_config"] != self.config["optimizer_config"]:
            self.logger.warning("Warning: Optimizer type given in config file is different from that of checkpoint. "
                                "Optimizer parameters not being resumed.")
        else:
            self.optimizer.load_state_dict(checkpoint["optimizer"])

        self.logger.info(f"Checkpoint loaded. Resume training from epoch {self.start_epoch}")


In [None]:
class NCTrainer(BaseTrainer):
    """
    Trainer for models that take a batch dictionary and return a sequence whose first item is the prediction.

    When entropy handling is enabled the model output is also expected to carry the attention weights at index 1
    and an entropy value at index 2.
    """
    def __init__(self, model, config, data_loader, **kwargs):
        """
        :param model: the network to train
        :param config: configuration object, kept as ``self.config`` and indexed like a dictionary
        :param data_loader: object exposing ``train_loader`` and ``valid_loader`` (the latter may be None)
        :param kwargs: accepted but not used here
        """
        super().__init__(model, config)
        self.config = config
        self.data_loader = data_loader.train_loader
        self.valid_loader = data_loader.valid_loader
        self.do_validation = self.valid_loader is not None
        self.len_epoch = len(self.data_loader)
        self.log_step = int(np.sqrt(self.data_loader.batch_size))
        # the configuration object supports item access; fetch the architecture settings
        arch_config = self.config["arch_config"]
        self.entropy_constraint = arch_config.get("entropy_constraint", False)
        self.calculate_entropy = arch_config.get("calculate_entropy", self.entropy_constraint)
        # weight of the entropy term that run_model adds to the loss
        self.alpha = arch_config.get("alpha", 0.001)
        tracked = ["loss", *(metric.__name__ for metric in self.metric_ftns)]
        if self.calculate_entropy:
            tracked.append("doc_entropy")
        self.train_metrics = MetricTracker(*tracked, writer=self.writer)
        self.valid_metrics = MetricTracker(*tracked, writer=self.writer)

    def load_batch_data(self, batch_dict):
        """
        Move every value of the batch to the training device.
        :param batch_dict: dictionary whose values support ``.to(device)``
        :return: new dictionary with the same keys and the moved values
        """
        moved = {}
        for name, value in batch_dict.items():
            moved[name] = value.to(self.device)
        return moved

    def run_model(self, batch_dict, model=None):
        """
        run model with the batch data
        :param batch_dict: the dictionary of data with format like {"data": Tensor(), "label": Tensor()}
        :param model: by default we use the self model
        :return: dictionary with the label, the loss and the prediction; the attention weights and the entropy are
            added when ``calculate_entropy`` is set
        """
        batch_dict = self.load_batch_data(batch_dict)
        network = self.model if model is None else model
        output = network(batch_dict)
        label = batch_dict["label"]
        loss = self.criterion(output[0], label)
        out_dict = {"label": label, "loss": loss, "predict": output[0]}
        if self.entropy_constraint:
            # entropy constraint: the in-place add on a tensor loss also changes the entry stored in out_dict
            loss += self.alpha * output[2]
        if self.calculate_entropy:
            out_dict["attention_weight"] = output[1]
            out_dict["entropy"] = output[2]
        return out_dict

    def update_metrics(self, metrics, out_dict):
        """
        Feed the results of one batch into a metric tracker.
        :param metrics: tracker to update
        :param out_dict: dictionary returned by ``run_model``
        """
        batch_size = len(out_dict["label"])
        metrics.update("loss", out_dict["loss"].item(), n=batch_size)
        if self.calculate_entropy:
            metrics.update("doc_entropy", out_dict["entropy"].item() / batch_size, n=batch_size)
        for metric in self.metric_ftns:
            score = metric(out_dict["predict"], out_dict["label"])
            metrics.update(metric.__name__, score, n=batch_size)

    def _train_epoch(self, epoch):
        """
        Training logic for an epoch
        :param epoch: Integer, current training epoch.
        :return: A log that contains average loss and metric in this epoch.
        """
        # switch to training mode and start the epoch with empty statistics
        self.model.train()
        self.train_metrics.reset()
        # tqdm renders a progress bar over the batches
        progress = tqdm(enumerate(self.data_loader), total=len(self.data_loader))
        for batch_idx, batch_dict in progress:
            self.optimizer.zero_grad()
            out_dict = self.run_model(batch_dict, self.model)
            out_dict["loss"].backward()
            self.optimizer.step()
            global_step = (epoch - 1) * self.len_epoch + batch_idx
            self.writer.set_step(global_step, "train")
            self.update_metrics(self.train_metrics, out_dict)
            if batch_idx % self.log_step == 0:
                progress.set_description(f"Train Epoch: {epoch} Loss: {out_dict['loss'].item()}")
            if batch_idx == self.len_epoch:
                break
        log = self.train_metrics.result()
        if self.do_validation:
            # merge the validation results into the epoch log
            log.update(self.evaluate(self.valid_loader, self.model, epoch))

        if self.lr_scheduler is not None:
            # advance the learning-rate schedule once per epoch
            self.lr_scheduler.step()
        return log

    def evaluate(self, loader, model, epoch=0, prefix="val"):
        """
        Run the model over a data loader without gradients and collect the averaged metrics.
        :param loader: data loader to iterate over
        :param model: model to evaluate; it is switched to evaluation mode
        :param epoch: epoch number used to compute the writer step
        :param prefix: prefix put in front of every metric name in the result
        :return: dictionary mapping "<prefix>_<metric>" to the averaged value
        """
        model.eval()
        self.valid_metrics.reset()
        batches_per_epoch = len(loader)
        with torch.no_grad():
            for batch_idx, batch_dict in tqdm(enumerate(loader), total=batches_per_epoch):
                out_dict = self.run_model(batch_dict, model)
                self.writer.set_step((epoch - 1) * batches_per_epoch + batch_idx, "evaluate")
                self.update_metrics(self.valid_metrics, out_dict)
        # add histogram of model parameters to the tensorboard
        for name, parameter in model.named_parameters():
            self.writer.add_histogram(name, parameter, bins='auto')
        prefixed = {}
        for metric_name, value in self.valid_metrics.result().items():
            prefixed[f"{prefix}_{metric_name}"] = value
        return prefixed


In [None]:
def init_default_model(config_parser: ConfigParser, data_loader: NewsDataLoader):
    """
    Build the model described by the "arch_config" section of the configuration.
    :param config_parser: configuration object whose ``init_obj`` creates the model from ``module_arch``
    :param data_loader: provides ``label_dict``, ``word_dict`` and optionally ``embeds``
    :return: the created model
    """
    # every model receives the number of classes and the word-to-index dictionary
    model_params = dict(num_classes=len(data_loader.label_dict), word_dict=data_loader.word_dict)
    # embeddings are forwarded only when the data loader carries them
    # (per the original note this is the case for GloVe embeddings - TODO confirm against NewsDataLoader)
    if hasattr(data_loader, "embeds"):
        model_params["embeds"] = data_loader.embeds
    return config_parser.init_obj("arch_config", module_arch, **model_params)


def init_data_loader(config_parser: ConfigParser):
    """
    Create the data loader described by the "data_config" section of the configuration.
    :param config_parser: configuration object whose ``init_obj`` looks the loader type up in ``module_data``
    :return: the created data loader
    """
    return config_parser.init_obj("data_config", module_data)


def run(config_parser: ConfigParser, data_loader: NewsDataLoader):
    """
    Build the default model, train it with an ``NCTrainer`` and return the trainer.
    :param config_parser: configuration object used to create the logger, the model and the trainer
    :param data_loader: data loader handed to the model factory and to the trainer
    :return: the trainer after training has finished
    """
    # disable the cudnn auto-tuner and request deterministic cudnn algorithms
    cudnn.benchmark, cudnn.deterministic = False, True
    logger = config_parser.get_logger("train")
    # build the base model and log its structure
    model = init_default_model(config_parser, data_loader)
    logger.info(model)
    # train() comes from BaseTrainer and calls NCTrainer._train_epoch for every epoch
    trainer = NCTrainer(model, config_parser, data_loader)
    trainer.train()
    return trainer


def test(trainer: NCTrainer, data_loader: NewsDataLoader):
    """
    Evaluate the best model of a trainer on the validation and the test split.
    :param trainer: trainer providing ``evaluate`` and ``best_model``
    :param data_loader: provides ``valid_loader`` (may be None) and ``test_loader``
    :return: dictionary with the "val_"-prefixed metrics (when a validation loader exists) and the
        "test_"-prefixed metrics
    """
    log = {}
    # the validation split is optional (NCTrainer skips validation when the loader is None), so do not iterate None
    if data_loader.valid_loader is not None:
        log.update(trainer.evaluate(data_loader.valid_loader, trainer.best_model, prefix="val"))
    # run test
    log.update(trainer.evaluate(data_loader.test_loader, trainer.best_model, prefix="test"))
    return log


def topic_evaluation(trainer: NCTrainer, data_loader: NewsDataLoader, path: Union[str, os.PathLike]):
    """
    Compute the topic distribution of the trained model, store the topic files and report the topic quality.
    :param trainer: trainer handed to ``get_topic_dist``
    :param data_loader: provides ``word_dict`` and is handed on to ``save_topic_info``
    :param path: output directory for the topic files
    :return: the dictionary returned by ``save_topic_info`` extended with "token_entropy", the mean of the
        entropies computed along axis 1 of the topic distribution
    """
    word_dict = data_loader.word_dict
    # invert the dictionary so that an index can be translated back into its word
    reverse_dict = {index: word for word, index in word_dict.items()}
    topic_dist = get_topic_dist(trainer, list(word_dict.values()))
    topic_result = save_topic_info(path, topic_dist, reverse_dict, data_loader)
    topic_result["token_entropy"] = np.mean(entropy(topic_dist, axis=1))
    return topic_result