109 changes: 77 additions & 32 deletions rectools/models/nn/sasrec.py
@@ -112,9 +112,73 @@ def _collate_fn_recommend(self, batch: List[Tuple[List[int], List[float]]]) -> D
return {"x": torch.LongTensor(x)}


class SASRecTransformerLayer(nn.Module):
"""
The SASRec authors' exact transformer block architecture, but with the PyTorch multi-head attention implementation.

Parameters
----------
n_factors : int
Latent embeddings size.
n_heads : int
Number of attention heads.
dropout_rate : float
Probability of a hidden unit being zeroed.
"""

def __init__(
self,
n_factors: int,
n_heads: int,
dropout_rate: float,
):
super().__init__()
# Important: the original SASRec architecture used a different multi-head attention implementation
self.multi_head_attn = torch.nn.MultiheadAttention(n_factors, n_heads, dropout_rate, batch_first=True)
self.q_layer_norm = nn.LayerNorm(n_factors)
self.ff_layer_norm = nn.LayerNorm(n_factors)
self.feed_forward = PointWiseFeedForward(n_factors, n_factors, dropout_rate, torch.nn.ReLU())
self.dropout = torch.nn.Dropout(dropout_rate)

def forward(
self,
seqs: torch.Tensor,
attn_mask: tp.Optional[torch.Tensor],
key_padding_mask: tp.Optional[torch.Tensor],
) -> torch.Tensor:
"""
Forward pass through the transformer block.

Parameters
----------
seqs : torch.Tensor
User sequences of item embeddings.
attn_mask : torch.Tensor, optional
Optional mask to use in forward pass of multi-head attention as `attn_mask`.
key_padding_mask : torch.Tensor, optional
Optional mask to use in forward pass of multi-head attention as `key_padding_mask`.

Returns
-------
torch.Tensor
User sequences passed through the transformer block.
"""
q = self.q_layer_norm(seqs)
mha_output, _ = self.multi_head_attn(
q, seqs, seqs, attn_mask=attn_mask, key_padding_mask=key_padding_mask, need_weights=False
)
seqs = q + mha_output
ff_input = self.ff_layer_norm(seqs)
seqs = self.feed_forward(ff_input)
seqs = self.dropout(seqs)
seqs += ff_input
return seqs
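
A minimal usage sketch of this block, with hypothetical shapes and hyperparameters. For torch.nn.MultiheadAttention, a boolean attn_mask entry set to True marks a position that may not be attended to, so an upper-triangular mask yields causal attention:

import torch

layer = SASRecTransformerLayer(n_factors=64, n_heads=4, dropout_rate=0.1)

batch_size, session_max_len = 2, 5
seqs = torch.randn(batch_size, session_max_len, 64)  # item embedding sequences
# True above the diagonal: each position attends only to itself and the past
attn_mask = torch.triu(
    torch.ones(session_max_len, session_max_len, dtype=torch.bool), diagonal=1
)

out = layer(seqs, attn_mask=attn_mask, key_padding_mask=None)
assert out.shape == (batch_size, session_max_len, 64)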


class SASRecTransformerLayers(TransformerLayersBase):
"""
Exactly SASRec author's transformer blocks architecture but with pytorch Multi-Head Attention realisation.
SASRec transformer blocks.

Parameters
----------
@@ -137,15 +201,16 @@ def __init__(
):
super().__init__()
self.n_blocks = n_blocks
self.multi_head_attn = nn.ModuleList(
[torch.nn.MultiheadAttention(n_factors, n_heads, dropout_rate, batch_first=True) for _ in range(n_blocks)]
) # important: original architecture had another version of MHA
self.q_layer_norm = nn.ModuleList([nn.LayerNorm(n_factors) for _ in range(n_blocks)])
self.ff_layer_norm = nn.ModuleList([nn.LayerNorm(n_factors) for _ in range(n_blocks)])
self.feed_forward = nn.ModuleList(
[PointWiseFeedForward(n_factors, n_factors, dropout_rate, torch.nn.ReLU()) for _ in range(n_blocks)]
self.transformer_blocks = nn.ModuleList(
[
SASRecTransformerLayer(
n_factors,
n_heads,
dropout_rate,
)
for _ in range(self.n_blocks)
]
)
self.dropout = nn.ModuleList([torch.nn.Dropout(dropout_rate) for _ in range(n_blocks)])
self.last_layernorm = torch.nn.LayerNorm(n_factors, eps=1e-8)

def forward(
@@ -175,21 +240,11 @@ def forward(
torch.Tensor
User sequences passed through transformer layers.
"""
seqs *= timeline_mask # [batch_size, session_max_len, n_factors]
for i in range(self.n_blocks):
q = self.q_layer_norm[i](seqs)
mha_output, _ = self.multi_head_attn[i](
q, seqs, seqs, attn_mask=attn_mask, key_padding_mask=key_padding_mask, need_weights=False
)
seqs = q + mha_output
ff_input = self.ff_layer_norm[i](seqs)
seqs = self.feed_forward[i](ff_input)
seqs = self.dropout[i](seqs)
seqs += ff_input
seqs *= timeline_mask

seqs *= timeline_mask # [batch_size, session_max_len, n_factors]
seqs = self.transformer_blocks[i](seqs, attn_mask, key_padding_mask)
seqs *= timeline_mask
seqs = self.last_layernorm(seqs)

return seqs
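
A sketch of how the stacked blocks are driven (hypothetical data; the forward signature (seqs, timeline_mask, attn_mask, key_padding_mask) is inferred from the loop body above). The timeline mask zeroes the embeddings of padding positions before and after each block:

import torch

layers = SASRecTransformerLayers(n_blocks=2, n_factors=64, n_heads=4, dropout_rate=0.1)

sessions = torch.tensor([[0, 0, 3, 7, 9],
                         [0, 5, 2, 8, 1]])     # 0 is the padding item id
seqs = torch.randn(2, 5, 64)                   # embeddings of the two sessions
timeline_mask = (sessions != 0).unsqueeze(-1)  # [2, 5, 1]
attn_mask = torch.triu(torch.ones(5, 5, dtype=torch.bool), diagonal=1)

out = layers(seqs, timeline_mask, attn_mask, key_padding_mask=None)  # [2, 5, 64]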


@@ -374,13 +429,3 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals
get_val_mask_func=get_val_mask_func,
get_trainer_func=get_trainer_func,
)

def _init_data_preparator(self) -> None:
self.data_preparator = self.data_preparator_type(
session_max_len=self.session_max_len,
n_negatives=self.n_negatives if self.loss != "softmax" else None,
batch_size=self.batch_size,
dataloader_num_workers=self.dataloader_num_workers,
train_min_user_interactions=self.train_min_user_interactions,
get_val_mask_func=self.get_val_mask_func,
)
96 changes: 18 additions & 78 deletions rectools/models/nn/transformer_backbone.py
@@ -16,21 +16,8 @@

import torch

from rectools.dataset.dataset import Dataset, DatasetSchema

from .item_net import (
CatFeaturesItemNet,
IdEmbeddingsItemNet,
ItemNetBase,
ItemNetConstructorBase,
SumOfEmbeddingsConstructor,
)
from .transformer_net_blocks import (
LearnableInversePositionalEncoding,
PositionalEncodingBase,
PreLNTransformerLayers,
TransformerLayersBase,
)
from .item_net import ItemNetBase
from .transformer_net_blocks import PositionalEncodingBase, TransformerLayersBase


class TransformerTorchBackbone(torch.nn.Module):
@@ -39,89 +26,42 @@ class TransformerTorchBackbone(torch.nn.Module):

Parameters
----------
n_blocks : int
Number of transformer blocks.
n_factors : int
Latent embeddings size.
n_heads : int
Number of attention heads.
session_max_len : int
Maximum length of user sequence.
dropout_rate : float
Probability of a hidden unit to be zeroed.
use_pos_emb : bool, default True
If ``True``, learnable positional encoding will be added to session item embeddings.
item_model : ItemNetBase
Network for item embeddings.
pos_encoding_layer : PositionalEncodingBase
Positional encoding layer.
transformer_layers : TransformerLayersBase
Transformer layers.
use_causal_attn : bool, default True
If ``True``, causal mask is used in multi-head self-attention.
transformer_layers_type : type(TransformerLayersBase), default `PreLNTransformerLayers`
Type of transformer layers architecture.
item_net_type : type(ItemNetBase), default `IdEmbeddingsItemNet`
Type of network returning item embeddings.
pos_encoding_type : type(PositionalEncodingBase), default `LearnableInversePositionalEncoding`
Type of positional encoding.
use_key_padding_mask : bool, default False
If ``True``, key padding mask is used in multi-head self-attention.
"""

def __init__(
self,
n_blocks: int,
n_factors: int,
n_heads: int,
session_max_len: int,
dropout_rate: float,
use_pos_emb: bool = True,
item_model: ItemNetBase,
pos_encoding_layer: PositionalEncodingBase,
transformer_layers: TransformerLayersBase,
use_causal_attn: bool = True,
use_key_padding_mask: bool = False,
transformer_layers_type: tp.Type[TransformerLayersBase] = PreLNTransformerLayers,
item_net_block_types: tp.Sequence[tp.Type[ItemNetBase]] = (IdEmbeddingsItemNet, CatFeaturesItemNet),
item_net_constructor_type: tp.Type[ItemNetConstructorBase] = SumOfEmbeddingsConstructor,
pos_encoding_type: tp.Type[PositionalEncodingBase] = LearnableInversePositionalEncoding,
) -> None:
super().__init__()

self.item_model: ItemNetConstructorBase
self.pos_encoding = pos_encoding_type(use_pos_emb, session_max_len, n_factors)
self.item_model = item_model
self.pos_encoding_layer = pos_encoding_layer
self.emb_dropout = torch.nn.Dropout(dropout_rate)
self.transformer_layers = transformer_layers_type(
n_blocks=n_blocks,
n_factors=n_factors,
n_heads=n_heads,
dropout_rate=dropout_rate,
)
self.transformer_layers = transformer_layers
self.use_causal_attn = use_causal_attn
self.use_key_padding_mask = use_key_padding_mask
self.n_factors = n_factors
self.dropout_rate = dropout_rate
self.n_heads = n_heads

self.item_net_block_types = item_net_block_types
self.item_net_constructor_type = item_net_constructor_type

def construct_item_net(self, dataset: Dataset) -> None:
"""
Construct network for item embeddings from dataset.

Parameters
----------
dataset : Dataset
RecTools dataset with user-item interactions.
"""
self.item_model = self.item_net_constructor_type.from_dataset(
dataset, self.n_factors, self.dropout_rate, self.item_net_block_types
)

def construct_item_net_from_dataset_schema(self, dataset_schema: DatasetSchema) -> None:
"""
Construct network for item embeddings from dataset schema.

Parameters
----------
dataset_schema : DatasetSchema
RecTools schema with dataset statistics.
"""
self.item_model = self.item_net_constructor_type.from_dataset_schema(
dataset_schema, self.n_factors, self.dropout_rate, self.item_net_block_types
)

@staticmethod
def _convert_mask_to_float(mask: torch.Tensor, query: torch.Tensor) -> torch.Tensor:
return torch.zeros_like(mask, dtype=query.dtype).masked_fill_(mask, float("-inf"))
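
A worked example of the helper above: True (masked) entries become -inf, so they get zero weight after the attention softmax, while allowed positions stay 0 and leave the logits unchanged:

import torch

mask = torch.tensor([[False, True],
                     [False, False]])
query = torch.randn(2, 8)
float_mask = torch.zeros_like(mask, dtype=query.dtype).masked_fill_(mask, float("-inf"))
# tensor([[0., -inf],
#         [0., 0.]])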
@@ -169,7 +109,7 @@ def _merge_masks(
res = (
merged_mask.view(batch_size, 1, seq_len, seq_len)
.expand(-1, self.n_heads, -1, -1)
.view(-1, seq_len, seq_len)
.reshape(-1, seq_len, seq_len)
) # [batch_size * n_heads, session_max_len, session_max_len]
torch.diagonal(res, dim1=1, dim2=2).zero_()
return res
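
A plausible motivation for switching the final view to reshape: expand produces a non-contiguous tensor with a zero-stride head dimension, and view cannot merge the batch and head dimensions of such a tensor, while reshape silently falls back to a copy. A minimal reproduction of the difference:

import torch

t = torch.zeros(2, 1, 5, 5).expand(-1, 4, -1, -1)  # zero-stride dim 1
try:
    t.view(-1, 5, 5)  # raises RuntimeError: incompatible size and stride
except RuntimeError:
    pass
out = t.reshape(-1, 5, 5)  # copies when necessary
assert out.shape == (8, 5, 5)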
@@ -199,7 +139,7 @@ def encode_sessions(self, sessions: torch.Tensor, item_embs: torch.Tensor) -> to
timeline_mask = (sessions != 0).unsqueeze(-1) # [batch_size, session_max_len, 1]

seqs = item_embs[sessions] # [batch_size, session_max_len, n_factors]
seqs = self.pos_encoding(seqs)
seqs = self.pos_encoding_layer(seqs)
seqs = self.emb_dropout(seqs)

if self.use_causal_attn:
51 changes: 38 additions & 13 deletions rectools/models/nn/transformer_base.py
@@ -277,7 +277,14 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals
self.fit_trainer: tp.Optional[Trainer] = None

def _init_data_preparator(self) -> None:
raise NotImplementedError()
self.data_preparator = self.data_preparator_type(
session_max_len=self.session_max_len,
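# Negative sampling is only needed for sampled losses; "softmax" scores the full catalogue.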
n_negatives=self.n_negatives if self.loss != "softmax" else None,
batch_size=self.batch_size,
dataloader_num_workers=self.dataloader_num_workers,
train_min_user_interactions=self.train_min_user_interactions,
get_val_mask_func=self.get_val_mask_func,
)

def _init_trainer(self) -> None:
if self.get_trainer_func is None:
@@ -294,20 +301,38 @@ def _init_trainer(self) -> None:
else:
self._trainer = self.get_trainer_func()

def _init_torch_model(self) -> TransformerTorchBackbone:
return TransformerTorchBackbone(
def _construct_item_net(self, dataset: Dataset) -> ItemNetBase:
return self.item_net_constructor_type.from_dataset(
dataset, self.n_factors, self.dropout_rate, self.item_net_block_types
)

def _construct_item_net_from_dataset_schema(self, dataset_schema: DatasetSchema) -> ItemNetBase:
return self.item_net_constructor_type.from_dataset_schema(
dataset_schema, self.n_factors, self.dropout_rate, self.item_net_block_types
)

def _init_pos_encoding_layer(self) -> PositionalEncodingBase:
return self.pos_encoding_type(self.use_pos_emb, self.session_max_len, self.n_factors)

def _init_transformer_layers(self) -> TransformerLayersBase:
return self.transformer_layers_type(
n_blocks=self.n_blocks,
n_factors=self.n_factors,
n_heads=self.n_heads,
session_max_len=self.session_max_len,
dropout_rate=self.dropout_rate,
use_pos_emb=self.use_pos_emb,
)

def _init_torch_model(self, item_model: ItemNetBase) -> TransformerTorchBackbone:
pos_encoding_layer = self._init_pos_encoding_layer()
transformer_layers = self._init_transformer_layers()
return TransformerTorchBackbone(
n_heads=self.n_heads,
dropout_rate=self.dropout_rate,
item_model=item_model,
pos_encoding_layer=pos_encoding_layer,
transformer_layers=transformer_layers,
use_causal_attn=self.use_causal_attn,
use_key_padding_mask=self.use_key_padding_mask,
transformer_layers_type=self.transformer_layers_type,
item_net_block_types=self.item_net_block_types,
pos_encoding_type=self.pos_encoding_type,
item_net_constructor_type=self.item_net_constructor_type,
)
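
Taken together, construction is now dependency injection: the model wrapper builds each submodule and hands ready instances to TransformerTorchBackbone, rather than passing *_type classes for the backbone to instantiate itself. A sketch of the resulting wiring (model here stands for any configured transformer model instance, as in _fit below):

item_model = model._construct_item_net(dataset)    # item net built from dataset features
torch_model = model._init_torch_model(item_model)  # encoding and layers built inside

This presumably keeps the backbone free of dataset access and lets each submodule be swapped or tested in isolation.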

def _init_lightning_model(
@@ -339,8 +364,8 @@ def _fit(
train_dataloader = self.data_preparator.get_dataloader_train()
val_dataloader = self.data_preparator.get_dataloader_val()

torch_model = self._init_torch_model()
torch_model.construct_item_net(self.data_preparator.train_dataset)
item_model = self._construct_item_net(self.data_preparator.train_dataset)
torch_model = self._init_torch_model(item_model)

dataset_schema = self.data_preparator.train_dataset.get_schema()
item_external_ids = self.data_preparator.train_dataset.item_id_map.external_ids
@@ -440,8 +465,8 @@ def _model_from_checkpoint(cls, checkpoint: tp.Dict[str, tp.Any]) -> tpe.Self:
loaded.data_preparator._init_extra_token_ids() # pylint: disable=protected-access

# Init and update torch model and lightning model
torch_model = loaded._init_torch_model()
torch_model.construct_item_net_from_dataset_schema(dataset_schema)
item_model = loaded._construct_item_net_from_dataset_schema(dataset_schema)
torch_model = loaded._init_torch_model(item_model)
loaded._init_lightning_model(
torch_model=torch_model,
dataset_schema=dataset_schema,