In [1]:

import argparse
import copy
import os
from enum import Enum
from typing import Dict, List, Union

import numpy as np
import pandas as pd
import torch
from config.configuration import Config
from data.dataloader import GeneralTrainerDataLoader
from data.dataset import GeneralDataset
from data.interaction import Interaction
from data.utils import data_reparation
from models import NeuMF
from root import DATASET_DIR, ROOT_DIR, absolute
from torch.nn.utils import rnn as rnn_utils
from torch.utils.data import Dataset as TorchDataset
from trainer import Trainer
from utils.logger import init_logger
import hashlib

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
from root import RESOURCE_DIR

class Template:
    """Base class that turns a ``{field: value}`` mapping into descriptive text.

    Subclasses override :meth:`_fit`; ``str(template)`` renders ``self.content``
    through it.
    """

    def __init__(self, content:Dict[str, str]) -> None:
        # Keep the raw mapping so __str__ can render it on demand.
        self.content = content

    def _fit(self, content:Dict[str, str]):
        """Render ``content`` to a string. Must be provided by subclasses."""
        raise NotImplementedError

    def __str__(self) -> str:
        rendered = self._fit(self.content)
        return rendered

class BasicTempalte(Template):
    """Template rendering each field as ``"The <field> is <value>"``, comma-joined.

    The rendered text is computed once in ``__init__`` and cached in
    ``self.output``; ``str()`` returns that cached string. (The misspelled
    class name is kept because callers refer to it.)
    """

    def __init__(self, content:Dict[str, str]) -> None:
        super().__init__(content)
        # Render eagerly so str() is a plain attribute read.
        self.output = self._fit(content)

    def _template(self, subject, predicate) :
        """Render a single ``(field, value)`` pair as a sentence fragment."""
        return f"The {subject} is {predicate}"

    def _fit(self, content: Dict[str, str]):
        """Render every pair in insertion order, joined by a bare ``","`` (no space)."""
        return ",".join(
            self._template(subject, predicate)
            for subject, predicate in content.items()
        )

    def __str__(self) -> str:
        return self.output

from root import ORIGINAL_DATASET_DIR

from langchain.embeddings import HuggingFaceEmbeddings, HuggingFaceInstructEmbeddings

# Short model key (EmbeddingModel value) -> (langchain embedding class, HuggingFace model id).
# NOTE(review): "e5" (intfloat/e5-large-v2) is not an INSTRUCTOR checkpoint, yet it is
# loaded through HuggingFaceInstructEmbeddings; confirm whether HuggingFaceEmbeddings
# was intended for it.
embedding_models = dict(
    il=(HuggingFaceInstructEmbeddings, "hkunlp/instructor-large"),
    e5=(HuggingFaceInstructEmbeddings, "intfloat/e5-large-v2"),
    ixl=(HuggingFaceInstructEmbeddings, "hkunlp/instructor-xl"),
)

class EmbeddingType(Enum):
    """Which table to embed: the user list or the item (service) list."""

    USER = "user"
    ITEM = "item"
    
class EmbeddingModel(Enum):
    """Embedding checkpoint selector; each value is a key into ``embedding_models``."""

    INSTRUCTOR_LARGE = "il"
    INSTRUCTOR_E5 = "e5"
    INSTRUCTOR_XL = "ixl"
    
class TemplateType(Enum):
    """Available sentence templates for rendering table rows (only BASIC so far)."""

    BASIC = "basic"

class EmbeddingHelper:
    """Render user/item table rows as text and embed them, with an on-disk cache.

    The user table is read from ``ORIGINAL_DATASET_DIR/userlist.txt`` and the
    item (service) table from ``ORIGINAL_DATASET_DIR/wslist.txt``. Each row is
    rendered through a template and embedded with a langchain HuggingFace
    embedding model. Results are cached as ``.npy`` files under
    ``RESOURCE_DIR/embedding``.
    """

    def __init__(self) -> None:
        self.upath = os.path.join(ORIGINAL_DATASET_DIR, "userlist.txt")
        self.ipath = os.path.join(ORIGINAL_DATASET_DIR, "wslist.txt")
        self._load_user_and_item()

    @property
    def _user_info_header(self):
        """Column names for the user table (they replace the file's own header row)."""
        # "country" was previously misspelled "counrty". Header names end up verbatim
        # in the template sentences, so the typo leaked into every user embedding.
        return ["user_id", "ip_address", "country", "ip_number", "AS", "latitude", "longitude"]

    @property
    def _item_info_header(self):
        """Column names for the item table (they replace the file's own header row)."""
        return ["service_id", "wsdl_address", "provider", "ip_address", "country", "ip_number", "AS", "latitude", "longitude"]

    def _load_user_and_item(self):
        """Read both tab-separated tables into ``self.user_info`` / ``self.item_info``."""
        # header=0 together with names= discards the file's header line and uses ours.
        self.user_info = pd.read_csv(self.upath, sep="\t", header=0, names=self._user_info_header)
        self.item_info = pd.read_csv(self.ipath, sep="\t", header=0, names=self._item_info_header)

    def info2template(self, type_:EmbeddingType, template_type:TemplateType)->List[str]:
        """Render every row of the user or item table into a template string.

        Args:
            type_: ``EmbeddingType.USER`` selects the user table; any other value
                selects the item table.
            template_type: Template to apply; only ``TemplateType.BASIC`` is supported.

        Returns:
            One rendered string per table row, in file order.

        Raises:
            NotImplementedError: If ``template_type`` is not supported.
        """
        info = self.user_info if type_ == EmbeddingType.USER else self.item_info
        if template_type == TemplateType.BASIC:
            template_func = BasicTempalte
        else:
            raise NotImplementedError(f"unsupported template type: {template_type}")
        return [
            str(template_func(row_dict))  # type: ignore
            for row_dict in info.to_dict(orient="records")
        ]

    @property
    def embedding_path(self):
        """Directory holding cached embeddings; created on first access if missing."""
        path = os.path.join(RESOURCE_DIR, "embedding")
        # The previous ``if not os.path.join(...)`` check was always False (a
        # non-empty string is truthy), so the directory was never created.
        os.makedirs(path, exist_ok=True)
        return path

    def _embedding_file(self, embed_name):
        """Return the cache file path for ``embed_name``, always ending in ``.npy``."""
        # np.save appends ".npy" when it is missing but np.load does not, so both
        # sides must use the already-suffixed path to agree on the same file.
        if not embed_name.endswith(".npy"):
            embed_name += ".npy"
        return os.path.join(self.embedding_path, embed_name)

    def get_models(self, type_:EmbeddingModel) -> Union[HuggingFaceEmbeddings, HuggingFaceInstructEmbeddings]:
        """Instantiate the embedding model registered under ``type_.value`` in ``embedding_models``."""
        model_cls, model_name = embedding_models[type_.value]
        return model_cls(model_name=model_name)

    def save_embedding(self, embed_data, embed_name):
        """Write ``embed_data`` to ``<embedding_path>/<embed_name>.npy`` unless it already exists."""
        path = self._embedding_file(embed_name)
        if not os.path.exists(path):
            np.save(path, np.asarray(embed_data))

    def load_embedding(self, embed_name):
        """Load a cached embedding array.

        Raises:
            FileNotFoundError: If no cache file exists for ``embed_name``.
        """
        path = self._embedding_file(embed_name)
        if not os.path.exists(path):
            raise FileNotFoundError(path)
        return np.load(path)

    def fit(self, type_:EmbeddingType, template_type: TemplateType, model_type:EmbeddingModel, auto_save = True):
        """Return embeddings for every user or item row, reusing the disk cache if present.

        The cache name is the first 6 hex chars of
        md5("<type>_<template>_<model>"). It therefore depends only on the three
        enum values, not on the table contents or the template text. Delete
        stale cache files after changing either.

        Args:
            type_: Which table to embed.
            template_type: How rows are rendered to text.
            model_type: Which embedding model to use.
            auto_save: Write freshly computed embeddings to the cache.

        Returns:
            A ``np.ndarray`` of embeddings, presumably shaped (n_rows, dim). Both
            the cached and the freshly computed paths now return an ndarray.
        """
        combined_string = f"{type_.value}_{template_type.value}_{model_type.value}"
        file_name = hashlib.md5(combined_string.encode()).hexdigest()[:6]
        try:
            return self.load_embedding(file_name)
        except FileNotFoundError:
            pass  # cache miss: compute below
        model = self.get_models(model_type)
        embeddings = np.asarray(model.embed_documents(self.info2template(type_, template_type)))
        if auto_save:
            self.save_embedding(embeddings, file_name)
        return embeddings
    
# Build user-side embeddings using the basic template and the INSTRUCTOR-large model.
eh = EmbeddingHelper()
eh.fit(
    type_=EmbeddingType.USER,
    template_type=TemplateType.BASIC,
    model_type=EmbeddingModel.INSTRUCTOR_LARGE,
)
        

load INSTRUCTOR_Transformer
max_seq_length  512


KeyboardInterrupt: 

### TODO
- 添加验证集
- NGCF
- LightGCN
- 所有代码整体过一遍
- tensorboard
- 训练参数整理
- checkpoint逻辑
- 日志 ✅

In [None]:

# CLI options. parse_known_args ignores unrecognised argv entries
# (presumably the extra arguments present when run inside a notebook kernel).
parser = argparse.ArgumentParser()
parser.add_argument("--dataset", "-d", type=str, default="wsdream-rt", help="name of datasets")
args, _ = parser.parse_known_args()

# Run configuration and logging for a NeuMF experiment on the chosen dataset.
config = Config(model="NeuMF", dataset=args.dataset)
init_logger(config)

# Data -> model -> trainer.
dataset = GeneralDataset(config)
train_data, test_data = data_reparation(config, dataset)
model = NeuMF(config, dataset)
trainer = Trainer(config, model)

# NOTE(review): test_data appears to be used as the evaluation set during
# training; a separate validation split is still listed as a TODO.
trainer.fit(train_data, test_data, saved=False, show_progress=True)
