In [1]:
# !pip install lightning

# Инференс

In [2]:
from bert4rec import Bert4Rec, Bert4RecModel
# Restore the trained model from a checkpoint file located next to this notebook.
# NOTE(review): the checkpoint file name is hard-coded; it must exist in the working directory.
best_model = Bert4Rec.load_from_checkpoint('./epoch=92-step=1116.ckpt')

In [3]:
from prediction_callbacks import (
    SparkPredictionCallback,
    PandasPredictionCallback,
    TorchPredictionCallback,
    QueryEmbeddingsPredictionCallback,
)
from schema import (
    FeatureHint,
    FeatureInfo,
    FeatureSchema,
    FeatureSource,
    FeatureType,
)
import pandas as pd
import lightning as L
from lightning.pytorch.loggers import CSVLogger
from sequence_tokenizer import SequenceTokenizer
from postprocessors import RemoveSeenItems
from last_n_splitter import LastNSplitter
from sklearn.preprocessing import LabelEncoder
from data import Dataset, get_spark_session
from schema import (
    TensorFeatureSource,
    TensorSchema,
    TensorFeatureInfo)
from torch.utils.data import DataLoader
from dataset import (
    Bert4RecPredictionDataset,
)

# Inference configuration: these constants are consumed by the prediction DataLoader.
MAX_SEQ_LEN = 100
BATCH_SIZE = 512
NUM_WORKERS = 4
# --------------------
spark_session = get_spark_session()
le = LabelEncoder()

# Raw inputs, read relative to the notebook location.
interactions = pd.read_csv('../data/events.csv')

item_features = pd.read_csv('../data/item_features.csv')

user_features = pd.read_csv('../data/user_features.csv')
# Replace the gender values with integer class codes.
user_features['gender'] = le.fit_transform(user_features['gender'])

# Splitter configured with N=1 per user_id.
# NOTE(review): presumably holds out the last interaction of every user — confirm
# against LastNSplitter.
splitter = LastNSplitter(
    N=1,
    divide_column="user_id",
    query_column="user_id",
    strategy="interactions",
)

# First split: full log -> (test events, test ground truth).
raw_test_events, raw_test_gt = splitter.split(interactions)
# Second split is applied to the test events: -> (validation events, validation ground truth).
raw_validation_events, raw_validation_gt = splitter.split(raw_test_events)
# The training events are the very same object as the validation events.
raw_train_events = raw_validation_events
def prepare_feature_schema(is_ground_truth: bool) -> FeatureSchema:
    """Build the feature schema describing an interactions table.

    Args:
        is_ground_truth: when True, the schema contains only the ``user_id``
            (query id) and ``item_id`` (item id) columns; when False, a
            numerical ``timestamp`` column is appended as well.

    Returns:
        The ``FeatureSchema`` for the requested kind of table.
    """
    query_id_info = FeatureInfo(
        column="user_id",
        feature_hint=FeatureHint.QUERY_ID,
        feature_type=FeatureType.CATEGORICAL,
    )
    item_id_info = FeatureInfo(
        column="item_id",
        feature_hint=FeatureHint.ITEM_ID,
        feature_type=FeatureType.CATEGORICAL,
    )
    id_schema = FeatureSchema([query_id_info, item_id_info])

    if not is_ground_truth:
        # Event tables additionally carry the interaction timestamp.
        timestamp_info = FeatureInfo(
            column="timestamp",
            feature_type=FeatureType.NUMERICAL,
            feature_hint=FeatureHint.TIMESTAMP,
        )
        return id_schema + FeatureSchema([timestamp_info])

    return id_schema
# Training dataset: events plus user/item feature tables, using the schema
# that includes the timestamp column.
# NOTE(review): categorical_encoded=False presumably tells Dataset that ids are
# still raw (not label-encoded) — confirm against the Dataset implementation.
train_dataset = Dataset(
    feature_schema=prepare_feature_schema(is_ground_truth=False),
    interactions=raw_train_events,
    query_features=user_features,
    item_features=item_features,
    check_consistency=True,
    categorical_encoded=False,
)

# Name of the single tensor feature: the per-user sequence of item ids.
ITEM_FEATURE_NAME = "item_id_seq"

# Tensor schema with one sequential categorical feature sourced from the
# item-id column of the interactions table.
tensor_schema = TensorSchema(
    TensorFeatureInfo(
        name=ITEM_FEATURE_NAME,
        is_seq=True,
        feature_type=FeatureType.CATEGORICAL,
        feature_sources=[TensorFeatureSource(FeatureSource.INTERACTIONS, train_dataset.feature_schema.item_id_column)],
        feature_hint=FeatureHint.ITEM_ID,
        embedding_dim=300,
    )
)
# Fit the tokenizer on the training dataset; its encoders are used later to
# encode the test query ids and to decode the predictions.
tokenizer = SequenceTokenizer(tensor_schema, allow_collect_to_master=True)
tokenizer.fit(train_dataset)


# --------------------
# Validation events and ground truth.
# NOTE(review): validation_dataset / validation_gt are not referenced by any
# cell visible in this notebook.
validation_dataset = Dataset(
    feature_schema=prepare_feature_schema(is_ground_truth=False),
    interactions=raw_validation_events,
    query_features=user_features,
    item_features=item_features,
    check_consistency=True,
    categorical_encoded=False,
)
validation_gt = Dataset(
    feature_schema=prepare_feature_schema(is_ground_truth=True),
    interactions=raw_validation_gt,
    check_consistency=True,
    categorical_encoded=False,
)
# Test events (model input) and the matching held-out ground truth.
test_dataset = Dataset(
    feature_schema=prepare_feature_schema(is_ground_truth=False),
    interactions=raw_test_events,
    query_features=user_features,
    item_features=item_features,
    check_consistency=True,
    categorical_encoded=False,
)
test_gt = Dataset(
    feature_schema=prepare_feature_schema(is_ground_truth=True),
    interactions=raw_test_gt,
    check_consistency=True,
    categorical_encoded=False,
)
# Restrict prediction to the queries (users) that appear in the test ground truth:
# encode their ids with the fitted tokenizer, then filter the tokenized test set.
test_query_ids = test_gt.query_ids
test_query_ids_np = tokenizer.query_id_encoder.transform(test_query_ids)["user_id"].values
sequential_test_dataset = tokenizer.transform(test_dataset).filter_by_query_id(test_query_ids_np)
# Cut-offs of interest; the callbacks below are configured with the largest one only.
TOPK = [1, 10, 20, 100]

# The same postprocessor list is shared by all three top-k callbacks.
# NOTE(review): RemoveSeenItems presumably drops items already present in each
# user's input sequence — confirm against the postprocessor implementation.
postprocessors = [RemoveSeenItems(sequential_test_dataset)]

# Collects the predictions as a Spark DataFrame with user_id / item_id / score columns.
spark_prediction_callback = SparkPredictionCallback(
    spark_session=spark_session,
    top_k=max(TOPK),
    query_column="user_id",
    item_column="item_id",
    rating_column="score",
    postprocessors=postprocessors,
)

# Collects the predictions as a pandas DataFrame with the same column names.
pandas_prediction_callback = PandasPredictionCallback(
    top_k=max(TOPK),
    query_column="user_id",
    item_column="item_id",
    rating_column="score",
    postprocessors=postprocessors,
)

# Collects the predictions as torch objects (unpacked later into user ids, item ids, scores).
torch_prediction_callback = TorchPredictionCallback(
    top_k=max(TOPK),
    postprocessors=postprocessors,
)
# Batches of tokenized test sequences for prediction.
prediction_dataloader = DataLoader(
    dataset=Bert4RecPredictionDataset(
        sequential_test_dataset,
        max_sequence_length=MAX_SEQ_LEN,
    ),
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    pin_memory=True,
)

# Callback that gathers the query (user) embeddings during prediction.
query_embeddings_callback = QueryEmbeddingsPredictionCallback()

# A single CSV logger is created; ".logs/train" is the directory the trainer
# has effectively been using in this notebook.
csv_logger = CSVLogger(save_dir=".logs/train", name="Bert4Rec_example")

# Trainer used for prediction only: every callback collects its own view of the results.
trainer = L.Trainer(
    callbacks=[
        spark_prediction_callback,
        pandas_prediction_callback,
        torch_prediction_callback,
        query_embeddings_callback,
    ],
    logger=csv_logger,
    inference_mode=True
)


Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
24/10/25 22:23:58 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
24/10/25 22:23:58 WARN SparkConf: Note that spark.local.dir will be overridden by the value set by the cluster manager (via SPARK_LOCAL_DIRS in mesos/standalone/kubernetes and LOCAL_DIRS in YARN).
GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [4]:
# Run prediction. return_predictions=False is passed, so the results are not
# taken from the return value but read back from the callbacks below.
trainer.predict(best_model, dataloaders=prediction_dataloader, return_predictions=False)

# Fetch what each callback accumulated during the predict run.
spark_res = spark_prediction_callback.get_result()
pandas_res = pandas_prediction_callback.get_result()
torch_user_ids, torch_item_ids, torch_scores = torch_prediction_callback.get_result()
user_embeddings = query_embeddings_callback.get_result()

/Users/akzholba/Documents/RecSys/RecSysCompetition/.venv/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:419: Consider setting `persistent_workers=True` in 'predict_dataloader' to speed up the dataloader worker initialization.


Predicting DataLoader 0: 100%|██████████| 12/12 [01:24<00:00,  0.14it/s]


24/10/25 22:25:50 WARN SQLConf: The SQL config 'spark.sql.execution.arrow.enabled' has been deprecated in Spark v3.0 and may be removed in the future. Use 'spark.sql.execution.arrow.pyspark.enabled' instead of it.
24/10/25 22:25:50 WARN SQLConf: The SQL config 'spark.sql.execution.arrow.enabled' has been deprecated in Spark v3.0 and may be removed in the future. Use 'spark.sql.execution.arrow.pyspark.enabled' instead of it.
24/10/25 22:25:50 WARN SQLConf: The SQL config 'spark.sql.execution.arrow.enabled' has been deprecated in Spark v3.0 and may be removed in the future. Use 'spark.sql.execution.arrow.pyspark.enabled' instead of it.


In [5]:
# Apply the tokenizer's inverse transform to the Spark predictions.
# NOTE(review): presumably this maps encoded user/item ids back to the original ids — confirm.
recommendations = tokenizer.query_and_item_id_encoder.inverse_transform(spark_res)
# Print a preview of the result (the captured output shows the top 20 rows).
recommendations.show()

                                                                                

+-----------------+-------+-------+
|            score|user_id|item_id|
+-----------------+-------+-------+
|9.312718391418457|      0|   3421|
|8.847206115722656|      0|   1422|
|8.750907897949219|      0|    434|
|8.747825622558594|      0|    708|
| 8.62032413482666|      0|   3025|
|8.587139129638672|      0|   2543|
|8.581725120544434|      0|   1461|
|8.497233390808105|      0|   2003|
| 8.42972469329834|      0|   1332|
|8.354842185974121|      0|   1287|
|8.291471481323242|      0|    566|
|8.272330284118652|      0|    827|
|8.268226623535156|      0|   2138|
|8.233556747436523|      0|   3460|
|8.218438148498535|      0|   2108|
|8.189332008361816|      0|    980|
| 8.14272403717041|      0|   2251|
|8.138208389282227|      0|   1250|
|7.928656101226807|      0|   1951|
|7.888411045074463|      0|   1128|
+-----------------+-------+-------+
only showing top 20 rows



In [6]:
# Convert the Spark DataFrame to a pandas DataFrame
pandas_df2 = recommendations.toPandas()

# Save the results to a CSV file
pandas_df2.to_csv('recommendations.csv', index=False)

def get_top_n(user_item_ratings, model_name, n=100):
    """Return the top-n rows for every user.

    Args:
        user_item_ratings: DataFrame with a ``user_id`` column and a score column.
        model_name: name of the column holding the predicted score to rank by.
        n: maximum number of rows kept per user.

    Returns:
        DataFrame sorted by the score column in descending order, containing at
        most ``n`` rows per ``user_id``, with a fresh 0..k-1 index.
    """
    # Rank all rows by predicted score, best first.
    ranked = user_item_ratings.sort_values(by=model_name, ascending=False)

    # Keep the first n rows of each user; row order of the ranking is preserved.
    per_user_top = ranked.groupby('user_id').head(n)

    return per_user_top.reset_index(drop=True)
# Keep at most 10 best-scored rows per user and retain only the id columns
top_10_films = get_top_n(pandas_df2, 'score', n=10)[['user_id', 'item_id']]

# Export the result to a CSV file
top_10_films.to_csv('top_10_films.csv', index=False)
def format_for_submission(df):
    """Collapse recommendations into one row per user.

    Args:
        df: DataFrame with ``user_id`` and ``item_id`` columns.

    Returns:
        DataFrame with columns ``user_id`` and ``item_id``, where ``item_id`` is
        a single string of that user's item ids separated by spaces, in the
        order they appear in ``df``.
    """
    def _join_ids(item_ids):
        # Stringify the ids and glue them together with single spaces.
        return ' '.join(item_ids.astype(str))

    items_by_user = df.groupby('user_id')['item_id']
    joined = items_by_user.apply(_join_ids)
    return joined.reset_index()
# Build the one-row-per-user submission table and write it to disk
submission = format_for_submission(top_10_films)
submission.to_csv('submission.csv', index=False)

  PyArrow >= 4.0.0 must be installed; however, it was not found.
Attempting non-optimization as 'spark.sql.execution.arrow.pyspark.fallback.enabled' is set to true.
  warn(msg)
                                                                                

# Метрика

In [8]:
# NOTE(review): star import — recall_at_k_overall used below comes from here;
# an explicit import would make the origin of the name visible.
from utils_bert.recall_at_k import *
submission_file_path = 'submission.csv'
real_interactions_file_path = '../data/dataset_to_recall.csv'

# Both tables are read with the default integer row index.
dataset_for_recall = pd.read_csv(real_interactions_file_path)
submission_df = pd.read_csv(submission_file_path)

# Turn the whitespace-separated id strings into lists of string tokens.
submission_df['item_id'] = submission_df['item_id'].apply(lambda x: x.split())
dataset_for_recall['last_10_interactions'] = dataset_for_recall['last_10_interactions'].apply(lambda x: x.split())

# NOTE(review): this assignment pairs the two tables by row index, not by user id.
# It is only correct if both files list the same users in the same order —
# TODO confirm, or join on the user id column if dataset_to_recall.csv has one.
submission_df['y_real']  = dataset_for_recall['last_10_interactions']

result = recall_at_k_overall(submission_df, actual_col='y_real', predicted_col='item_id')
print(f"Recall@10 = {result:.4f}")

Recall@10 = 0.0208
