Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 16 additions & 18 deletions rectools/dataset/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from rectools.utils.config import BaseConfig

from .features import AbsentIdError, DenseFeatures, Features, SparseFeatureName, SparseFeatures
from .identifiers import ExternalId, IdMap
from .identifiers import IdMap
from .interactions import Interactions

AnyFeatureName = tp.Union[str, SparseFeatureName]
Expand Down Expand Up @@ -65,24 +65,24 @@ class FeaturesSchema(BaseConfig):
class IdMapSchema(BaseConfig):
"""IdMap schema."""

external_ids: tp.List[ExternalId]
size: int
dtype: str


class EntitySchema(BaseConfig):
"""Entity schema."""

n_hot: int
id_map: IdMapSchema
features: tp.Optional[FeaturesSchema] = None
id_map: tp.Optional[IdMapSchema] = None


class DatasetSchema(BaseConfig):
"""Dataset schema."""

n_interactions: int
items: EntitySchema
users: EntitySchema
items: EntitySchema


@attr.s(slots=True, frozen=True)
Expand Down Expand Up @@ -134,22 +134,20 @@ def _get_feature_schema(features: Features) -> FeaturesSchema:

@staticmethod
def _get_id_map_schema(id_map: IdMap) -> IdMapSchema:
return IdMapSchema(external_ids=id_map.external_ids.tolist(), dtype=id_map.external_dtype.str)
return IdMapSchema(size=id_map.size, dtype=id_map.external_dtype.str)

def get_schema(self, add_user_id_map: bool = False, add_item_id_map: bool = False) -> DatasetSchemaDict:
def get_schema(self) -> DatasetSchemaDict:
"""Get dataset schema in a dict form that contains all the information about the dataset and its statistics."""
user_schema = EntitySchema(n_hot=self.n_hot_users)
if self.user_features is not None:
user_schema.features = self._get_feature_schema(self.user_features)
if add_user_id_map:
user_schema.id_map = self._get_id_map_schema(self.user_id_map)

item_schema = EntitySchema(n_hot=self.n_hot_items)
if self.item_features is not None:
item_schema.features = self._get_feature_schema(self.item_features)
if add_item_id_map:
item_schema.id_map = self._get_id_map_schema(self.item_id_map)

user_schema = EntitySchema(
n_hot=self.n_hot_users,
id_map=self._get_id_map_schema(self.user_id_map),
features=self._get_feature_schema(self.user_features) if self.user_features is not None else None,
)
item_schema = EntitySchema(
n_hot=self.n_hot_items,
id_map=self._get_id_map_schema(self.item_id_map),
features=self._get_feature_schema(self.item_features) if self.item_features is not None else None,
)
schema = DatasetSchema(
n_interactions=self.interactions.df.shape[0],
users=user_schema,
Expand Down
26 changes: 20 additions & 6 deletions rectools/models/nn/transformer_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ def forward(
# #### -------------- Lightning Model -------------- #### #


class TransformerLightningModuleBase(LightningModule):
class TransformerLightningModuleBase(LightningModule): # pylint: disable=too-many-instance-attributes
"""
Base class for transfofmers lightning module. To change train procedure inherit
from this class and pass your custom LightningModule to your model parameters.
Expand Down Expand Up @@ -279,6 +279,7 @@ def __init__(
torch_model: TransformerTorchBackbone,
model_config: tp.Dict[str, tp.Any],
dataset_schema: DatasetSchemaDict,
item_external_ids: ExternalIds,
data_preparator: TransformerDataPreparatorBase,
lr: float,
gbce_t: float,
Expand All @@ -292,6 +293,7 @@ def __init__(
self.torch_model = torch_model
self.model_config = model_config
self.dataset_schema = dataset_schema
self.item_external_ids = item_external_ids
self.lr = lr
self.loss = loss
self.adam_betas = adam_betas
Expand Down Expand Up @@ -755,11 +757,13 @@ def _init_lightning_model(
self,
torch_model: TransformerTorchBackbone,
dataset_schema: DatasetSchemaDict,
item_external_ids: ExternalIds,
model_config: tp.Dict[str, tp.Any],
) -> None:
self.lightning_model = self.lightning_module_type(
torch_model=torch_model,
dataset_schema=dataset_schema,
item_external_ids=item_external_ids,
model_config=model_config,
data_preparator=self.data_preparator,
lr=self.lr,
Expand All @@ -781,9 +785,15 @@ def _fit(
torch_model = self._init_torch_model()
torch_model.construct_item_net(self.data_preparator.train_dataset)

dataset_schema = self.data_preparator.train_dataset.get_schema(add_item_id_map=True)
dataset_schema = self.data_preparator.train_dataset.get_schema()
item_external_ids = self.data_preparator.train_dataset.item_id_map.external_ids
model_config = self.get_config()
self._init_lightning_model(torch_model, dataset_schema, model_config)
self._init_lightning_model(
torch_model=torch_model,
dataset_schema=dataset_schema,
item_external_ids=item_external_ids,
model_config=model_config,
)

self.fit_trainer = deepcopy(self._trainer)
self.fit_trainer.fit(self.lightning_model, train_dataloader, val_dataloader)
Expand Down Expand Up @@ -916,15 +926,19 @@ def _model_from_checkpoint(cls, checkpoint: tp.Dict[str, tp.Any]) -> tpe.Self:
dataset_schema = DatasetSchema.model_validate(dataset_schema)

# Update data preparator
id_map_schema = dataset_schema.items.id_map
item_external_ids = np.array(id_map_schema.external_ids, dtype=id_map_schema.dtype)
item_external_ids = checkpoint["hyper_parameters"]["item_external_ids"]
loaded.data_preparator.item_id_map = IdMap(item_external_ids)
loaded.data_preparator._init_extra_token_ids() # pylint: disable=protected-access

# Init and update torch model and lightning model
torch_model = loaded._init_torch_model()
torch_model.construct_item_net_from_dataset_schema(dataset_schema)
loaded._init_lightning_model(torch_model, dataset_schema, model_config)
loaded._init_lightning_model(
torch_model=torch_model,
dataset_schema=dataset_schema,
item_external_ids=item_external_ids,
model_config=model_config,
)
loaded.lightning_model.load_state_dict(checkpoint["state_dict"])

return loaded
Expand Down
14 changes: 7 additions & 7 deletions tests/dataset/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,15 +66,15 @@ def setup_method(self) -> None:
"users": {
"n_hot": 3,
"id_map": {
"external_ids": ["u1", "u2", "u3"],
"size": 3,
"dtype": "|O",
},
"features": None,
},
"items": {
"n_hot": 3,
"id_map": {
"external_ids": ["i1", "i2", "i5"],
"size": 3,
"dtype": "|O",
},
"features": None,
Expand Down Expand Up @@ -105,15 +105,15 @@ def test_construct_with_extra_cols(self) -> None:
expected = self.expected_interactions
expected.df["extra_col"] = self.interactions_df["extra_col"]
assert_interactions_set_equal(actual, expected)
actual_schema = dataset.get_schema(add_item_id_map=True, add_user_id_map=True)
actual_schema = dataset.get_schema()
assert actual_schema == self.expected_schema

def test_construct_without_features(self) -> None:
dataset = Dataset.construct(self.interactions_df)
self.assert_dataset_equal_to_expected(dataset, None, None)
assert dataset.n_hot_users == 3
assert dataset.n_hot_items == 3
actual_schema = dataset.get_schema(add_item_id_map=True, add_user_id_map=True)
actual_schema = dataset.get_schema()
assert actual_schema == self.expected_schema

@pytest.mark.parametrize("user_id_col", ("id", Columns.User))
Expand Down Expand Up @@ -162,7 +162,7 @@ def test_construct_with_features(self, user_id_col: str, item_id_col: str) -> No
"users": {
"n_hot": 3,
"id_map": {
"external_ids": ["u1", "u2", "u3"],
"size": 3,
"dtype": "|O",
},
"features": {
Expand All @@ -175,7 +175,7 @@ def test_construct_with_features(self, user_id_col: str, item_id_col: str) -> No
"items": {
"n_hot": 3,
"id_map": {
"external_ids": ["i1", "i2", "i5"],
"size": 3,
"dtype": "|O",
},
"features": {
Expand All @@ -186,7 +186,7 @@ def test_construct_with_features(self, user_id_col: str, item_id_col: str) -> No
},
},
}
actual_schema = dataset.get_schema(add_item_id_map=True, add_user_id_map=True)
actual_schema = dataset.get_schema()
assert actual_schema == expected_schema

@pytest.mark.parametrize("user_id_col", ("id", Columns.User))
Expand Down