In [1]:

import torch
from torch.utils.data import Dataset as TorchDataset
from data.interaction import Interaction
from torch.nn.utils import rnn as rnn_utils
import os
import pandas as pd
from root import DATASET_DIR, absolute, ROOT_DIR
from enum import Enum
from typing import Union, Dict, List
import numpy as np
import copy

class FeatType(Enum):
    Token = (0, "单个离散特征序列")
    TokenSeq = (1, "多个离散特征序列")
    Float = (2, "单个连续特征序列")
    FloatSeq = (3, "多个连续特征序列")
    
    @classmethod
    def from_code(cls, code:Union[str, int]):
        if isinstance(code, str): code = int(code)
        for feat_type in FeatType:
            if feat_type.value[0] == code:
                return feat_type
        return None

class RecboleDataset(TorchDataset):
    def __init__(self, config) -> None:
        super().__init__()
        self.config = config
        
        self._get_preset()
        self._get_field_from_config()
        self._load_data(DATASET_DIR)
        self._data_processing()
        
        
    def _get_field_from_config(self):
        """初始化数据集的通用字段"""
        self.dataset_name = self.config['dataset']
        self.uid_field = self.config["USER_ID_FIELD"]
        self.iid_field = self.config["ITEM_ID_FIELD"]
        self.label_field = self.config["LABEL_FIELD"]
        
        self.split_ratio = self.config['split_ratio']
    
    def _get_preset(self):
        self.field2type:Dict[str, FeatType] = {}
        self.field2num:Dict[str, int] = {}
        
    def _data_processing(self):
        self.feat_name_list = self._build_feat_name_list()
        
    def _load_data(self, data_dir):
        """加载数据集"""
        dataset_dir = os.path.join(data_dir, self.dataset_name)
        self._load_item_feat(dataset_dir, self.dataset_name)
        self._load_user_feat(dataset_dir, self.dataset_name)
        self._load_inter_feat(dataset_dir, self.dataset_name)
        
        
    def _load_feat(self, feat_dir, feat_name):
        path = os.path.join(feat_dir, feat_name)
        if not os.path.exists(path): raise FileNotFoundError(f"{path} not found")
        df = pd.read_csv(
            path, 
            header=0
        )
        new_columns = []
        for col in df.columns:
            name, dtype = col.split(":")
            dtype = FeatType.from_code(dtype)
            if dtype is None: raise ValueError(f"feat type {dtype} not found")
            new_columns.append(name)    
            self.field2type[name] = dtype
        df.columns = new_columns
        
        return df
    
    @property
    def uid_num(self):
        return self._count_unique(self.uid_field)
    
    @property
    def iid_num(self):
        return self._count_unique(self.iid_field)
    
    def _count_unique(self, feat_name):
        return len(self.inter_feat[feat_name].unique())

        
    def _load_inter_feat(self, feat_dir, feat_prefix):
        feat_name = f"{feat_prefix}.inter"
        self.inter_feat = self._load_feat(feat_dir, feat_name)
        
    
    def _load_user_feat(self, feat_dir, feat_prefix):
        feat_name = f"{feat_prefix}.user"
        self.user_feat = self._load_feat(feat_dir, feat_name)
    
    def _load_item_feat(self, feat_dir, feat_prefix):
        feat_name = f"{feat_prefix}.item"
        self.item_feat = self._load_feat(feat_dir, feat_name)
        
    def _build_feat_name_list(self):
        feat_name_list = [
            feat_name
            for feat_name in ["inter_feat", "user_feat", "item_feat"]
            if getattr(self, feat_name, None) is not None
        ]
        return feat_name_list
    
    def _change_feat_format(self):
        for feat_name in self.feat_name_list:
            feat = getattr(self, feat_name)
            setattr(self, feat_name, self._dataframe_to_interaction(feat))
    
    def build(self):
        self._change_feat_format()
        dataset = self.split_by_ratio_without_eval(self.split_ratio)
        return dataset
        
    def split_by_ratio_without_eval(self, split_ratio:float):
        """分割训练集和测试集"""
        assert 0 < split_ratio < 1
        total_cnt = self.__len__()
        split_ids = self._calcu_split_ids(total_cnt, [split_ratio, 1 - split_ratio])
        next_index = [
            range(start, end)
            for start, end in zip([0] + split_ids, split_ids + [total_cnt])
        ]
        next_df = [self.inter_feat[index] for index in next_index]
        next_ds = [self.copy(_) for _ in next_df]
        return next_ds
        
        
    def copy(self, new_inter_feat):
        """Given a new interaction feature, return a new :class:`Dataset` object,
        whose interaction feature is updated with ``new_inter_feat``, and all the other attributes the same.

        Args:
            new_inter_feat (Interaction): The new interaction feature need to be updated.

        Returns:
            :class:`~Dataset`: the new :class:`~Dataset` object, whose interaction feature has been updated.
        """
        nxt = copy.copy(self)
        nxt.inter_feat = new_inter_feat
        return nxt
        
        

    def _calcu_split_ids(self, tot, ratios):
        """Given split ratios, and total number, calculate the number of each part after splitting.

        Other than the first one, each part is rounded down.

        Args:
            tot (int): Total number.
            ratios (list): List of split ratios. No need to be normalized.

        Returns:
            list: Number of each part after splitting.
        """
        cnt = [int(ratios[i] * tot) for i in range(len(ratios))]
        cnt[0] = tot - sum(cnt[1:])
        for i in range(1, len(ratios)):
            if cnt[0] <= 1:
                break
            if 0 < ratios[-i] * tot < 1:
                cnt[-i] += 1
                cnt[0] -= 1
        split_ids = np.cumsum(cnt)[:-1]
        return list(split_ids)
        
    def _dataframe_to_interaction(self, data:pd.DataFrame):
        data_for_tensor = {}
        for col_name in data:
            assert isinstance(col_name, str)
            value = data[col_name].values
            ftype = self.field2type[col_name]
            if ftype == FeatType.Token:
                data_for_tensor[col_name] = torch.LongTensor(value)
            elif ftype == FeatType.Float:
                data_for_tensor[col_name] = torch.FloatTensor(value)
            else:
                raise NotImplementedError(f"feat type {ftype} not implemented")
        return Interaction(data_for_tensor)
    
    def __len__(self):
        return len(self.inter_feat)
    
    def __getitem__(self, index, join=True):
        df = self.inter_feat[index]
        return self.join(df) if join else df

    def join(self, df):
        """Given interaction feature, join user/item feature into it.

        Args:
            df (Interaction): Interaction feature to be joint.

        Returns:
            Interaction: Interaction feature after joining operation.
        """
        if self.user_feat is not None and self.uid_field in df:
            df.update(self.user_feat[df[self.uid_field]])
        if self.item_feat is not None and self.iid_field in df:
            df.update(self.item_feat[df[self.iid_field]])
        return df


    
class GeneralDataset(RecboleDataset):
    def __init__(self, config):
        super().__init__(config)



config_test = {
    "split_rate": 0.2,
    "dataset": "wsdream-rt",
    "USER_ID_FIELD": "user_id",
    "ITEM_ID_FIELD": "item_id",
    "LABEL_FIELD": "rt"
}
        


In [2]:
# 把原始的wsdream数据转成原子形式
# https://recbole.io/cn/atomic_files.html

import pandas as pd
from root import ORIGINAL_DATASET_DIR, absolute, DATASET_DIR
import os
from enum import Enum
import numpy as np


class WSDreamDataType(Enum):
    TP_ONLY = (1, "wsdream-tp")
    RT_ONLY = (2, "wsdream-rt")
    TP_AND_RT = (3, "wsdream-all")

    @classmethod
    def from_code(cls, code:int):
        for wsdream_type in WSDreamDataType:
            if wsdream_type.value[0] == code:
                return wsdream_type
        return None

class BasicDataConvert:

    def load_user_data(self):
        raise NotImplementedError
    
    def loda_item_data(self):
        raise NotImplementedError
    
    def load_inter_data(self):
        raise NotImplementedError
    
    def fit(self):
        raise NotImplementedError
    
ALL_USER_FIELD = ["[User ID]", "[IP Address]", "[Country]", "[IP No.]", "[AS]", "[Latitude]", "[Longitude]"]
ALL_ITEM_FIELD = ["[Service ID]","[WSDL Address]","[Service Provider]","[IP Address]","[Country]","[IP No.]","[AS]","[Latitude]","[Longitude]"]
    
class WSDreamDataConvert(BasicDataConvert):
    
    def __init__(self, wsdream_type:WSDreamDataType) -> None:
        super().__init__()
        
        self.origin_user_field = ["[User ID]", "[Country]", "[AS]"]
        self.user_field = ["user_id", "country", "AS"]
        
        self.origin_item_field = ["[Service ID]", "[Country]", "[AS]"]
        self.item_field = ["item_id", "country", "AS",]
        
        self.inter_field = ["user_id", "item_id"]
        if wsdream_type == WSDreamDataType.RT_ONLY: self.inter_field.append("rt")
        elif wsdream_type == WSDreamDataType.TP_ONLY: self.inter_field.append("tp")
        else: self.inter_field.extend(["rt", "tp"])
            
        self.upath = os.path.join(ORIGINAL_DATASET_DIR, "userlist.txt")
        self.ipath = os.path.join(ORIGINAL_DATASET_DIR, "wslist.txt")
        self.rt_inter = os.path.join(ORIGINAL_DATASET_DIR, "rtMatrix.txt")
        self.tp_inter = os.path.join(ORIGINAL_DATASET_DIR, "tpMatrix.txt")
        
        self.wstype = wsdream_type
        self.output_dir = os.path.join(DATASET_DIR, wsdream_type.value[1])
        
        self.dataset_name = wsdream_type.value[1]
        
        self._load_data()
        
    def _load_data(self):
        self.user_data = self.load_user_data()
        self.item_data = self.loda_item_data()
        self.inter_data = self.load_inter_data()
        for name in ["[Country]", "[AS]"]:
            self._deal_categorical_feat(name)
        
    def _deal_categorical_feat(self, name:str):
        if self.item_data is None and self.user_data is None: self._load_data()
        feat_kinds = []
        if name in self.user_data: feat_kinds.extend(self.user_data[name].unique().tolist())
        if name in self.item_data: feat_kinds.extend(self.item_data[name].unique().tolist())
        feat_kinds = list(set(feat_kinds))
        map_ = {
            feat:idx for idx, feat in enumerate(feat_kinds)
        }
        if name in self.user_data: self.user_data.replace({name:map_}, inplace=True)
        if name in self.item_data: self.item_data.replace({name:map_}, inplace=True)

        
    def _feat_type_wrap(self, type_:str):
        feat_types = []
        if type_ == "user":
            feat_types = [0, 0, 0]
            return list(map(lambda x,y:f'{x}:{y}', self.user_field, feat_types))
        elif type_ == "item":
            feat_types = [0, 0, 0]
            return list(map(lambda x,y:f'{x}:{y}', self.item_field, feat_types))
        else:
            if self.wstype == WSDreamDataType.RT_ONLY or self.wstype == WSDreamDataType.TP_ONLY:
                feat_types = [0, 0, 2]
            else:
                feat_types = [0, 0, 2, 2]
            return list(map(lambda x,y:f'{x}:{y}', self.inter_field, feat_types))
    

    def load_inter_data(self) -> pd.DataFrame:
        rt_path, tp_path = None, None
        if self.wstype == WSDreamDataType.RT_ONLY: rt_path = self.rt_inter
        elif self.wstype == WSDreamDataType.TP_ONLY: tp_path = self.tp_inter
        else: rt_path, tp_path = self.rt_inter, self.tp_inter
        if rt_path and tp_path:
            rt_data = np.loadtxt(rt_path, dtype=np.float64)
            tp_data = np.loadtxt(tp_path, dtype=np.float64)
            rows, cols = np.nonzero(rt_data)
            inter_data = pd.DataFrame({self.inter_field[0]:rows, self.inter_field[1]:cols, self.inter_field[2]:rt_data[rows, cols], self.inter_field[3]: tp_data[rows, cols]})
        else:
            path = self.rt_inter if rt_path else self.tp_inter
            inter_data = np.loadtxt(path, dtype=np.float64)
            rows, cols = np.nonzero(inter_data)
            inter_data = pd.DataFrame({self.inter_field[0]:rows, self.inter_field[1]:cols, self.inter_field[2]:inter_data[rows, cols]})
        return inter_data

    def load_user_data(self) -> pd.DataFrame:
        return pd.read_csv(self.upath, sep="\t", header=0)[self.origin_user_field]
    
    def loda_item_data(self):
        return pd.read_csv(self.ipath, sep="\t", header=0)[self.origin_item_field]
    
    def _convert(self, type_:str):
        if not os.path.exists(self.output_dir):
            os.makedirs(self.output_dir)
        if type_ == "user":
            data = self.user_data
        elif type_ == "item":
            data = self.item_data
        else:
            data = self.inter_data
        data.columns = self._feat_type_wrap(type_)
        data.to_csv(os.path.join(self.output_dir, f"{self.dataset_name}.{type_}"), index=False)
        
    def fit(self):
        for type_ in ["user", "item", "inter"]:
            self._convert(type_)
        
    
wc = WSDreamDataConvert(wsdream_type=WSDreamDataType.RT_ONLY)
# wc.load_user_data()
wc.fit()

# data = pd.read_csv(os.path.join(DATASET_DIR, "rt.inter"), sep=",", header=0)


In [3]:


from torch import optim
from tqdm import tqdm
from torch.amp.autocast_mode import autocast
from torch.cuda.amp.grad_scaler import GradScaler
import time
from recbole.utils import early_stopping, dict2str, ensure_dir, get_gpu_usage, get_local_time, set_color

class AbstractTrainer(object):
    r"""Trainer Class is used to manage the training and evaluation processes of recommender system models.
    AbstractTrainer is an abstract class in which the fit() and evaluate() method should be implemented according
    to different training and evaluation strategies.
    """

    def __init__(self, config, model):
        self.config = config
        self.model = model

    def fit(self, train_data):
        r"""Train the model based on the train data."""
        raise NotImplementedError("Method [next] should be implemented.")

    def evaluate(self, eval_data):
        r"""Evaluate the model based on the eval data."""

        raise NotImplementedError("Method [next] should be implemented.")



class Trainer(AbstractTrainer):
    r"""The basic Trainer for basic training and evaluation strategies in recommender systems. This class defines common
    functions for training and evaluation processes of most recommender system models, including fit(), evaluate(),
    resume_checkpoint() and some other features helpful for model training and evaluation.

    Generally speaking, this class can serve most recommender system models, If the training process of the model is to
    simply optimize a single loss without involving any complex training strategies, such as adversarial learning,
    pre-training and so on.

    Initializing the Trainer needs two parameters: `config` and `model`. `config` records the parameters information
    for controlling training and evaluation, such as `learning_rate`, `epochs`, `eval_step` and so on.
    `model` is the instantiated object of a Model Class.

    """

    def __init__(self, config, model):
        super(Trainer, self).__init__(config, model)
        
        self.learner = config["learner"]
        self.learning_rate = config["learning_rate"]
        self.epochs = config["epochs"]
        self.eval_step:int = min(config["eval_step"], self.epochs)
        self.stopping_step = config["stopping_step"]
        self.valid_metric_bigger = config["valid_metric_bigger"] # 是不是越大越好
        self.test_batch_size = config["eval_batch_size"]
        self.gpu_available = torch.cuda.is_available() and config["use_gpu"]
        self.device = config["device"]
        self.checkpoint_dir = config["checkpoint_dir"]
        self.enable_amp = config["enable_amp"]
        self.enable_scaler = torch.cuda.is_available() and config["enable_scaler"]
        self.enable_amp = config["enable_amp"]
        ensure_dir(self.checkpoint_dir)
        saved_model_file = "{}-{}.pth".format(self.config["model"], get_local_time())
        self.saved_model_file = os.path.join(self.checkpoint_dir, saved_model_file)
        self.weight_decay = config["weight_decay"]

        self.start_epoch = 0
        self.cur_step = 0
        self.best_valid_score = -np.inf if self.valid_metric_bigger else np.inf
        self.best_valid_result = None
        self.train_loss_dict = dict()
        self.optimizer = self._build_optimizer()
        self.eval_type = config["eval_type"]
        self.item_tensor = None
        self.tot_item_num = None

    def _build_optimizer(self, **kwargs):
        r"""Init the Optimizer

        Args:
            params (torch.nn.Parameter, optional): The parameters to be optimized.
                Defaults to ``self.model.parameters()``.
            learner (str, optional): The name of used optimizer. Defaults to ``self.learner``.
            learning_rate (float, optional): Learning rate. Defaults to ``self.learning_rate``.
            weight_decay (float, optional): The L2 regularization weight. Defaults to ``self.weight_decay``.

        Returns:
            torch.optim: the optimizer
        """
        params = kwargs.pop("params", self.model.parameters())
        learner = kwargs.pop("learner", self.learner)
        learning_rate = kwargs.pop("learning_rate", self.learning_rate)
        weight_decay = kwargs.pop("weight_decay", self.weight_decay)

        if (
            self.config["reg_weight"]
            and weight_decay
            and weight_decay * self.config["reg_weight"] > 0
        ):
            ...
            # self.logger.warning(
            #     "The parameters [weight_decay] and [reg_weight] are specified simultaneously, "
            #     "which may lead to double regularization."
            # )

        if learner.lower() == "adam":
            optimizer = optim.Adam(params, lr=learning_rate, weight_decay=weight_decay)
        elif learner.lower() == "sgd":
            optimizer = optim.SGD(params, lr=learning_rate, weight_decay=weight_decay)
        elif learner.lower() == "adagrad":
            optimizer = optim.Adagrad(
                params, lr=learning_rate, weight_decay=weight_decay
            )
        elif learner.lower() == "rmsprop":
            optimizer = optim.RMSprop(
                params, lr=learning_rate, weight_decay=weight_decay
            )
        elif learner.lower() == "sparse_adam":
            optimizer = optim.SparseAdam(params, lr=learning_rate)
            if weight_decay > 0:
                ...
                # self.logger.warning(
                #     "Sparse Adam cannot argument received argument [{weight_decay}]"
                # )
        else:
            # self.logger.warning(
            #     "Received unrecognized optimizer, set default Adam optimizer"
            # )
            optimizer = optim.Adam(params, lr=learning_rate)
        return optimizer

    def _train_epoch(self, train_data, epoch_idx, loss_func=None, show_progress=False):
        r"""Train the model in an epoch

        Args:
            train_data (DataLoader): The train data.
            epoch_idx (int): The current epoch id.
            loss_func (function): The loss function of :attr:`model`. If it is ``None``, the loss function will be
                :attr:`self.model.calculate_loss`. Defaults to ``None``.
            show_progress (bool): Show the progress of training epoch. Defaults to ``False``.

        Returns:
            float/tuple: The sum of loss returned by all batches in this epoch. If the loss in each batch contains
            multiple parts and the model return these multiple parts loss instead of the sum of loss, it will return a
            tuple which includes the sum of loss in each part.
        """
        self.model.train()
        loss_func = loss_func or self.model.calculate_loss
        total_loss = None
        iter_data = (
            tqdm(
                train_data,
                total=len(train_data),
                ncols=100,
                desc=set_color(f"Train {epoch_idx:>5}", "pink"),
            )
            if show_progress
            else train_data
        )

        scaler = GradScaler(enabled=self.enable_scaler)
        for batch_idx, interaction in enumerate(iter_data):
            interaction = interaction.to(self.device)
            self.optimizer.zero_grad()
            # 自动混合精度
            with autocast(device_type=self.device.type, enabled=self.enable_amp):
                losses = loss_func(interaction)

            if isinstance(losses, tuple):
                loss = sum(losses)
                loss_tuple = tuple(per_loss.item() for per_loss in losses)
                total_loss = (
                    loss_tuple
                    if total_loss is None
                    else tuple(map(sum, zip(total_loss, loss_tuple)))
                )
            else:
                loss = losses
                total_loss = (
                    losses.item() if total_loss is None else total_loss + losses.item()
                )
            self._check_nan(loss)
            scaler.scale(loss).backward()
            scaler.step(self.optimizer)
            scaler.update()
            # if self.gpu_available and show_progress:
            #     iter_data.set_postfix_str(
            #         set_color("GPU RAM: " + get_gpu_usage(self.device), "yellow")
            #     )
        return total_loss

    def _save_checkpoint(self, epoch, verbose=True, **kwargs):
        r"""Store the model parameters information and training information.

        Args:
            epoch (int): the current epoch id

        """
        if not self.config["single_spec"] and self.config["local_rank"] != 0:
            return
        saved_model_file = kwargs.pop("saved_model_file", self.saved_model_file)
        state = {
            "config": self.config,
            "epoch": epoch,
            "cur_step": self.cur_step,
            "best_valid_score": self.best_valid_score,
            "state_dict": self.model.state_dict(),
            "other_parameter": self.model.other_parameter(),
            "optimizer": self.optimizer.state_dict(),
        }
        torch.save(state, saved_model_file, pickle_protocol=4)
        if verbose:
            ...
            # self.logger.info(
            #     set_color("Saving current", "blue") + f": {saved_model_file}"
            # )

    def resume_checkpoint(self, resume_file):
        r"""Load the model parameters information and training information.

        Args:
            resume_file (file): the checkpoint file

        """
        resume_file = str(resume_file)
        self.saved_model_file = resume_file
        checkpoint = torch.load(resume_file, map_location=self.device)
        self.start_epoch = checkpoint["epoch"] + 1
        self.cur_step = checkpoint["cur_step"]
        self.best_valid_score = checkpoint["best_valid_score"]

        # load architecture params from checkpoint
        if checkpoint["config"]["model"].lower() != self.config["model"].lower():
            ...
            # self.logger.warning(
            #     "Architecture configuration given in config file is different from that of checkpoint. "
            #     "This may yield an exception while state_dict is being loaded."
            # )
        self.model.load_state_dict(checkpoint["state_dict"])
        self.model.load_other_parameter(checkpoint.get("other_parameter"))

        # load optimizer state from checkpoint only when optimizer type is not changed
        self.optimizer.load_state_dict(checkpoint["optimizer"])
        message_output = "Checkpoint loaded. Resume training from epoch {}".format(
            self.start_epoch
        )
        # self.logger.info(message_output)

    def _check_nan(self, loss):
        if torch.isnan(loss):
            raise ValueError("Training loss is nan")

    def _generate_train_loss_output(self, epoch_idx, s_time, e_time, losses):
        des = self.config["loss_decimal_place"] or 4
        train_loss_output = (
            set_color("epoch %d training", "green")
            + " ["
            + set_color("time", "blue")
            + ": %.2fs, "
        ) % (epoch_idx, e_time - s_time)
        if isinstance(losses, tuple):
            des = set_color("train_loss%d", "blue") + ": %." + str(des) + "f"
            train_loss_output += ", ".join(
                des % (idx + 1, loss) for idx, loss in enumerate(losses)
            )
        else:
            des = "%." + str(des) + "f"
            train_loss_output += set_color("train loss", "blue") + ": " + des % losses
        return train_loss_output + "]"

    # def _add_train_loss_to_tensorboard(self, epoch_idx, losses, tag="Loss/Train"):
    #     if isinstance(losses, tuple):
    #         for idx, loss in enumerate(losses):
    #             self.tensorboard.add_scalar(tag + str(idx), loss, epoch_idx)
    #     else:
    #         self.tensorboard.add_scalar(tag, losses, epoch_idx)

    # def _add_hparam_to_tensorboard(self, best_valid_result):
    #     # base hparam
    #     hparam_dict = {
    #         "learner": self.config["learner"],
    #         "learning_rate": self.config["learning_rate"],
    #         "train_batch_size": self.config["train_batch_size"],
    #     }
    #     # unrecorded parameter
    #     unrecorded_parameter = {
    #         parameter
    #         for parameters in self.config.parameters.values()
    #         for parameter in parameters
    #     }.union({"model", "dataset", "config_files", "device"})
    #     # other model-specific hparam
    #     hparam_dict.update(
    #         {
    #             para: val
    #             for para, val in self.config.final_config_dict.items()
    #             if para not in unrecorded_parameter
    #         }
    #     )
    #     for k in hparam_dict:
    #         if hparam_dict[k] is not None and not isinstance(
    #             hparam_dict[k], (bool, str, float, int)
    #         ):
    #             hparam_dict[k] = str(hparam_dict[k])

    #     self.tensorboard.add_hparams(
    #         hparam_dict, {"hparam/best_valid_result": best_valid_result}
    #     )

    def fit(
        self,
        train_data,
        valid_data=None,
        verbose=True,
        saved=True,
        show_progress=False,
        callback_fn=None,
    ):
        r"""Train the model based on the train data and the valid data.

        Args:
            train_data (DataLoader): the train data
            valid_data (DataLoader, optional): the valid data, default: None.
                                               If it's None, the early_stopping is invalid.
            verbose (bool, optional): whether to write training and evaluation information to logger, default: True
            saved (bool, optional): whether to save the model parameters, default: True
            show_progress (bool): Show the progress of training epoch and evaluate epoch. Defaults to ``False``.
            callback_fn (callable): Optional callback function executed at end of epoch.
                                    Includes (epoch_idx, valid_score) input arguments.

        Returns:
             (float, dict): best valid score and best valid result. If valid_data is None, it returns (-1, None)
        """
        if saved and self.start_epoch >= self.epochs:
            self._save_checkpoint(-1, verbose=verbose)
        
        valid_step = 0

        for epoch_idx in range(self.start_epoch, self.epochs):
            # train
            training_start_time = time.time()
            train_loss = self._train_epoch(
                train_data, epoch_idx, show_progress=show_progress
            )
            self.train_loss_dict[epoch_idx] = (
                sum(train_loss) if isinstance(train_loss, tuple) else train_loss
            )
            training_end_time = time.time()
            train_loss_output = self._generate_train_loss_output(
                epoch_idx, training_start_time, training_end_time, train_loss
            )
            if verbose:
                # self.logger.info(train_loss_output)
                ...
            # TODO
            # self._add_train_loss_to_tensorboard(epoch_idx, train_loss)
            # self.wandblogger.log_metrics(
            #     {"epoch": epoch_idx, "train_loss": train_loss, "train_step": epoch_idx},
            #     head="train",
            # )

            # eval
            if self.eval_step <= 0 or not valid_data:
                if saved:
                    self._save_checkpoint(epoch_idx, verbose=verbose)
                continue
            # TODO
            # if (epoch_idx + 1) % self.eval_step == 0:
            #     valid_start_time = time.time()
            #     valid_score, valid_result = self._valid_epoch(
            #         valid_data, show_progress=show_progress
            #     )

            #     (
            #         self.best_valid_score,
            #         self.cur_step,
            #         stop_flag,
            #         update_flag,
            #     ) = early_stopping(
            #         valid_score,
            #         self.best_valid_score,
            #         self.cur_step,
            #         max_step=self.stopping_step,
            #         bigger=self.valid_metric_bigger,
            #     )
            #     valid_end_time = time.time()
            #     valid_score_output = (
            #         set_color("epoch %d evaluating", "green")
            #         + " ["
            #         + set_color("time", "blue")
            #         + ": %.2fs, "
            #         + set_color("valid_score", "blue")
            #         + ": %f]"
            #     ) % (epoch_idx, valid_end_time - valid_start_time, valid_score)
            #     valid_result_output = (
            #         set_color("valid result", "blue") + ": \n" + dict2str(valid_result)
            #     )
            #     if verbose:
            #         # self.logger.info(valid_score_output)
            #         # self.logger.info(valid_result_output)
            #         ...
            #     # self.tensorboard.add_scalar("Vaild_score", valid_score, epoch_idx)
            #     # self.wandblogger.log_metrics(
            #     #     {**valid_result, "valid_step": valid_step}, head="valid"
            #     # )

            #     if update_flag:
            #         if saved:
            #             self._save_checkpoint(epoch_idx, verbose=verbose)
            #         self.best_valid_result = valid_result

            #     if callback_fn:
            #         callback_fn(epoch_idx, valid_score)

            #     if stop_flag:
            #         stop_output = "Finished training, best eval result in epoch %d" % (
            #             epoch_idx - self.cur_step * self.eval_step
            #         )
            #         if verbose:
            #             # self.logger.info(stop_output)
            #             ...
            #         break

            #     valid_step += 1

        # self._add_hparam_to_tensorboard(self.best_valid_score)
        return self.best_valid_score, self.best_valid_result


    def _spilt_predict(self, interaction, batch_size):
        spilt_interaction = dict()
        for key, tensor in interaction.interaction.items():
            spilt_interaction[key] = tensor.split(self.test_batch_size, dim=0)
        num_block = (batch_size + self.test_batch_size - 1) // self.test_batch_size
        result_list = []
        for i in range(num_block):
            current_interaction = dict()
            for key, spilt_tensor in spilt_interaction.items():
                current_interaction[key] = spilt_tensor[i]
            result = self.model.predict(
                Interaction(current_interaction).to(self.device)
            )
            if len(result.shape) == 0:
                result = result.unsqueeze(0)
            result_list.append(result)
        return torch.cat(result_list, dim=0)

In [4]:
from data.dataloader import GeneralDataLoader
from models.neumf import NeuMF

def data_reparation(config, dataset:RecboleDataset):
    built_dataset = dataset.build()
    train_dataset, test_dataset = built_dataset
    train_data = GeneralDataLoader(train_dataset, config)
    test_data = GeneralDataLoader(test_dataset, config)
    return train_data, test_data

from config.configuration import Config
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
        "--dataset", "-d", type=str, default="wsdream-rt", help="name of datasets"
    )

args, _ = parser.parse_known_args()

config = Config(model="NeuMF", dataset=args.dataset)
    

dataset = GeneralDataset(config)
train_data, test_data = dataset.build()
model = NeuMF(config, dataset)
trainer = Trainer(config, model)
trainer.fit(train_data, test_data, saved=False, show_progress=True)

[1;35mTrain     0[0m:   3%|█                                         | 10508/394935 [00:31<19:01, 336.80it/s][0m


KeyboardInterrupt: 