# In [1]:
import os

# Route HTTP(S) traffic through a local proxy (presumably needed to reach the
# dataset hub used by `load_dataset` below). `setdefault` keeps any proxy the
# user has already exported instead of silently overwriting it.
_LOCAL_PROXY = "http://localhost:7890"
for _proxy_var in ("http_proxy", "https_proxy"):
    os.environ.setdefault(_proxy_var, _LOCAL_PROXY)

# In [2]:
import logging
import os
from dataclasses import dataclass
from pathlib import Path

import pandas as pd
from datasets import disable_progress_bar, load_dataset, load_from_disk
from datasets.dataset_dict import DatasetDict
from recbole.data import create_dataset, data_preparation

# Logger for this module.
logger = logging.getLogger(name=__name__)

# Turn off the `datasets` library's progress bars.
disable_progress_bar()


# Per-task settings, keyed by the config's `task_name`.
#   load_args   -- positional arguments unpacked into `load_dataset(...)`
#                  (dataset repo id, dataset config name)
#   metric_keys -- metric names associated with the task
TASK_ATTRS = dict(
    # Amazon-beauty
    beauty=dict(
        load_args=(
            "McAuley-Lab/Amazon-Reviews-2023",
            "0core_rating_only_All_Beauty",
        ),
        metric_keys=("accuracy", "recall"),
    ),
)


@dataclass
class DataConfig:
    """Data-loading settings consumed by `DataModule`.

    Several field names (`MAX_ITEM_LIST_LENGTH`, `load_col`,
    `train_neg_sample_args`) look like RecBole config keys.
    NOTE(review): confirm how these values are forwarded to RecBole.
    """

    task_name: str  # key into TASK_ATTRS
    datasets_path: str  # on-disk cache of the raw datasets (load_from_disk / save_to_disk)
    preprocessed_datasets_path: str  # directory that receives `<task_name>.inter`
    train_batch_size: int
    valid_batch_size: int
    test_batch_size: int
    num_proc: int
    force_preprocess: bool  # redo preprocessing even when a cached result exists
    MAX_ITEM_LIST_LENGTH: int
    load_col: dict
    train_neg_sample_args: dict
    # NOTE(review): looks like a typo of `train_neg_sample_args`. It defaults
    # to None so callers may omit it; passing a value still works.
    train_net_sample_args: "dict | None" = None


class DataModule:
    """Load a task's raw interactions, write a RecBole `.inter` file, build loaders.

    All work happens in the constructor::

        data_module = DataModule(config.data)
        data_module.train_loader  # training dataloader from `data_preparation`
        data_module.valid_loader  # validation dataloader
        data_module.test_loader   # test dataloader

        # convert an external dataset (e.g. distilled data) to the same format
        inter_df = data_module.preprocess_dataset(dataset=dataset)

    The loaders are plain attributes, not methods.

    NOTE(review): `config` is read as an attribute container here
    (`config.task_name`, ...) and is also passed unchanged to RecBole's
    `create_dataset` / `data_preparation`. Confirm the object supports both uses.
    """

    # Raw column name -> RecBole atomic-file header ("<name>:<type>").
    _INTER_COLUMNS = {
        "uid": "user_id:token",
        "iid": "item_id:token",
        "rating": "rating:float",
        "timestamp": "timestamp:float",
    }

    def __init__(self, config):
        self.config = config

        # Resolve the task first so an unknown name fails with a clear message.
        try:
            self.dataset_attr = TASK_ATTRS[self.config.task_name]
        except KeyError:
            raise KeyError(
                f"Unknown task_name {self.config.task_name!r}; "
                f"expected one of {sorted(TASK_ATTRS)}"
            ) from None

        # load raw dataset
        self.datasets: DatasetDict = self.get_dataset()
        logger.info(f"Datasets: {self.datasets}")

        # build (or reuse) the preprocessed `.inter` file
        self.preprocessed_datasets = None
        self.run_preprocess()

        # generate dataloaders
        self.train_loader = None
        self.valid_loader = None
        self.test_loader = None
        self.get_dataloader()

    def get_dataset(self):
        """Load the raw datasets, downloading and caching them on first use.

        Returns:
            `load_from_disk(config.datasets_path)` if that path exists. Otherwise
            the result of `load_dataset(*load_args)` for the task, which is also
            saved to `config.datasets_path` for later runs.
        """
        datasets_path = self.config.datasets_path
        if os.path.exists(datasets_path):
            return load_from_disk(datasets_path)

        datasets = load_dataset(*self.dataset_attr["load_args"])

        # `dirname` is "" for a bare relative path, and os.makedirs("") raises.
        parent_dir = os.path.dirname(datasets_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        datasets.save_to_disk(datasets_path)
        return datasets

    def _inter_file_path(self):
        """Return `<preprocessed_datasets_path>/<task_name>.inter`."""
        return os.path.join(
            self.config.preprocessed_datasets_path,
            f"{self.config.task_name}.inter",
        )

    def run_preprocess(self):
        """Create, or reuse, the preprocessed interaction file.

        If the `.inter` file already exists and `config.force_preprocess` is
        false, it is read back into `self.preprocessed_datasets`. Otherwise
        `self.datasets` is preprocessed and written there as a tab-separated
        file without an index column.
        """
        inter_path = self._inter_file_path()

        if os.path.exists(inter_path) and not self.config.force_preprocess:
            logger.info(
                "Load preprocessed datasets from `{}`".format(
                    self.config.preprocessed_datasets_path
                )
            )
            # The cache is the tab-separated file written below, not a
            # `save_to_disk` directory, so read it back with pandas.
            self.preprocessed_datasets = pd.read_csv(inter_path, sep="\t")
            return

        self.preprocessed_datasets = self.preprocess_dataset(dataset=self.datasets)

        logger.info(
            f"Save preprocessed datasets to `{self.config.preprocessed_datasets_path}`"
        )
        # Create the output directory itself, because the file goes *inside* it.
        os.makedirs(self.config.preprocessed_datasets_path, exist_ok=True)
        self.preprocessed_datasets.to_csv(inter_path, sep="\t", index=False)

    @staticmethod
    def _to_dataframe(dataset):
        """Return a fresh DataFrame for any input `preprocess_dataset` accepts."""
        if isinstance(dataset, pd.DataFrame):
            return dataset.copy()  # never rename the caller's columns
        if hasattr(dataset, "to_pandas"):  # a single `datasets.Dataset`
            return dataset.to_pandas()
        if (
            isinstance(dataset, dict)
            and dataset
            and all(hasattr(split, "to_pandas") for split in dataset.values())
        ):
            # A DatasetDict (split name -> Dataset): use the rows of every split.
            return pd.concat(
                [split.to_pandas() for split in dataset.values()],
                ignore_index=True,
            )
        return pd.DataFrame(dataset)

    def preprocess_dataset(
        self, dataset, *, min_item_interactions=3, min_user_interactions=5
    ):
        """Convert raw interactions into a RecBole atomic `.inter` DataFrame.

        Args:
            dataset: a pandas DataFrame; a `datasets.Dataset` (anything with
                `to_pandas()`); a dict of such splits, such as a `DatasetDict`,
                whose splits are concatenated; or anything `pd.DataFrame`
                accepts. It must have exactly four columns.
            min_item_interactions: items with fewer interactions are dropped first.
            min_user_interactions: users with fewer remaining interactions are
                dropped next. This is a single pass, not an iterated k-core, so
                dropping users can leave some items below `min_item_interactions`.

        Returns:
            A DataFrame with columns `user_id:token`, `item_id:token`,
            `rating:float` and `timestamp:float`. Rows are ordered by original
            user id, then by timestamp. User and item ids are re-indexed to
            0..n-1 in order of first appearance.

        Raises:
            ValueError: if the input does not have exactly four columns.
        """
        dataset_df = self._to_dataframe(dataset)
        # Columns are renamed by position.
        # NOTE(review): assumes the source order is (user, item, rating, timestamp).
        dataset_df.columns = ["uid", "iid", "rating", "timestamp"]

        # Drop rare items, then users with too few remaining interactions.
        item_counts = dataset_df.groupby("iid")["iid"].transform("size")
        filtered_df = dataset_df[item_counts >= min_item_interactions]
        user_counts = filtered_df.groupby("uid")["uid"].transform("size")
        filtered_df = filtered_df[user_counts >= min_user_interactions]

        # Put users in sorted order and each user's history in chronological
        # order (same row order as groupby("uid") + per-group sort, without
        # the deprecated `include_groups` argument).
        filtered_df = filtered_df.sort_values(by=["uid", "timestamp"]).reset_index(
            drop=True
        )

        # Re-index ids to contiguous integers in order of first appearance.
        filtered_df["uid"] = pd.factorize(filtered_df["uid"])[0]
        filtered_df["iid"] = pd.factorize(filtered_df["iid"])[0]

        return filtered_df.rename(columns=self._INTER_COLUMNS)

    def get_dataloader(self):
        """Build the RecBole dataset and its train/valid/test dataloaders."""
        # NOTE(review): RecBole locates the atomic file through its own config.
        # Presumably that resolves to `_inter_file_path()`; confirm.
        dataset = create_dataset(self.config)
        # dataset splitting
        self.train_loader, self.valid_loader, self.test_loader = data_preparation(
            self.config, dataset
        )


# (captured notebook stderr, tqdm import warning)
#   from .autonotebook import tqdm as notebook_tqdm
