In [2]:
# Импортируем необходимые модули и библиотеки
import os
from catboost import CatBoostClassifier, Pool, CatBoost
import pandas as pd
from sqlalchemy import create_engine, Column, Integer, String
from sqlalchemy.orm import sessionmaker, Session
from sqlalchemy.ext.declarative import declarative_base
from typing import List
from fastapi import FastAPI, Depends 
from datetime import datetime
from pydantic import BaseModel

'''
ФУНКЦИИ ПО ЗАГРУЗКЕ МОДЕЛЕЙ
'''
# Проверка если код выполняется в лмс, или локально
def get_model_path(path: str) -> str:
    """Просьба не менять этот код"""
    if os.environ.get("IS_LMS") == "1":  # проверяем где выполняется код в лмс, или локально. Немного магии
        MODEL_PATH = 'catboost_model.cbm'
    else:
        MODEL_PATH = path
    return MODEL_PATH

class CatBoostWrapper(CatBoost):
    def predict_proba(self, X):
        return self.predict(X, prediction_type='Probability')

# Загрузка модели
def load_models():
    model_path = get_model_path("catboost_model.cbm")
    model = CatBoostWrapper()
    model.load_model(model_path)
    return model


'''
Получение данных из базы данных
'''

# Определяем функцию для получения данных из базы данных PostgreSQL
def batch_load_sql(query: str) -> pd.DataFrame:
    CHUNKSIZE = 200000
    engine = create_engine(
        "postgresql://robot-startml-ro:pheiph0hahj1Vaif@"
        "postgres.lab.karpov.courses:6432/startml"
    )
    conn = engine.connect().execution_options(stream_results=True)
    chunks = []
    for chunk_dataframe in pd.read_sql(query, conn, chunksize=CHUNKSIZE):
        chunks.append(chunk_dataframe)
    conn.close()
    return pd.concat(chunks, ignore_index=True)

def load_features() -> pd.DataFrame:
    query = "a-efimik_features_lesson_22_500MB"
    return batch_load_sql(query)

# Определяем переменные для подключения к базе данных
SQLALCHEMY_DATABASE_URL = "postgresql://robot-startml-ro:pheiph0hahj1Vaif@postgres.lab.karpov.courses:6432/startml"
engine = create_engine(SQLALCHEMY_DATABASE_URL)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base = declarative_base()

# Определяем класс Post для работы с таблицей базы данных post
class Post(Base):
    __tablename__ = 'post'
    id = Column(Integer, primary_key=True)
    text = Column(String)
    topic = Column(String)

class PostGet(BaseModel):
    id: int
    text: str
    topic: str

    class Config:
        orm_mode = True

# Определяем функцию для получения сессии базы данных
def get_db():
    with SessionLocal() as db:
        return db





  Base = declarative_base()


In [3]:
model = load_models()

In [4]:
features = load_features()

In [5]:
features

Unnamed: 0,user_id,post_id_x,gender,age,country,city,exp_group,os,source,topic,...,component_2,component_3,component_4,component_5,component_6,component_7,component_8,component_9,component_10,post_id_y
0,43893,4645,0,34,7,1953,2,0,0,3,...,0.042752,-0.055726,-0.000477,-0.023751,0.047259,0.020530,-0.028873,-0.018328,-0.028018,4645
1,43893,5918,0,34,7,1953,2,0,0,3,...,0.091917,-0.009930,0.009773,-0.032072,-0.019178,0.065090,-0.039035,-0.008044,-0.005215,5918
2,43896,3619,1,18,7,3578,1,0,0,1,...,-0.029965,-0.084669,0.234959,-0.062218,-0.007342,0.007093,-0.050568,-0.022288,0.005569,3619
3,43896,182,1,18,7,3578,1,0,0,0,...,-0.148778,-0.040344,-0.000720,-0.013706,0.033206,-0.074778,0.019613,-0.013318,0.010663,182
4,43896,760,1,18,7,3578,1,0,0,2,...,-0.031280,0.038805,-0.002746,-0.002253,-0.005034,0.082962,-0.097861,-0.072218,0.048612,760
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1537873,43893,1468,0,34,7,1953,2,0,0,5,...,-0.107158,0.160502,0.040028,-0.087193,-0.114284,-0.108797,-0.124626,0.083079,0.031612,1468
1537874,43893,2152,0,34,7,1953,2,0,0,6,...,-0.049721,-0.011634,0.002557,-0.079772,0.123424,0.031392,0.036403,-0.018418,0.007575,2152
1537875,43893,1130,0,34,7,1953,2,0,0,4,...,-0.276711,-0.167969,-0.006701,0.152747,-0.158673,0.073473,0.109715,0.050098,-0.079709,1130
1537876,43893,2598,0,34,7,1953,2,0,0,1,...,-0.036207,-0.045127,0.187084,-0.070025,-0.000820,0.033582,-0.054700,-0.023181,0.001254,2598


In [13]:
def prepare_data_for_prediction(user_id: int, features: pd.DataFrame) -> Pool:
    # Получение фичей по пользователю
    X = features[features['user_id'] == user_id]

    # Получение уникальных user_id
    unique_user_ids = X['user_id'].unique()
    
    # Создание словаря с соответствием user_id и group_id
    group_id_dict = {user_id: idx for idx, user_id in enumerate(unique_user_ids)}
    
    # Добавление group_id в датафрейм
    X['group_id'] = X['user_id'].map(group_id_dict)
    
    # Сортировка по group_id
    X = X.sort_values(by='group_id')

    # Создание Pool
    data_pool = Pool(X.drop(columns=['user_id']), group_id=X['group_id'])

    return data_pool

# Prepare the data for prediction
data_pool = prepare_data_for_prediction(user_id = 43893, features = features)

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  X['group_id'] = X['user_id'].map(group_id_dict)


In [11]:
# Make predictions using the trained model and data_pool
predictions = model.predict_proba(data_pool)[:, 1] # [:, 1] - вероятность класса 1

CatBoostError: C:/Program Files (x86)/Go Agent/pipelines/BuildMaster/catboost.git/catboost/libs/data/model_dataset_compatibility.cpp:81: At position 47 should be feature with name post_id_y (found group_id).

In [10]:
def get_recommendations(user_id: int = 43893, limit: int = 5):
    from sqlalchemy.orm import Session

    # Create a new session
    db = SessionLocal()

    

    # Add the predictions to the user_features dataframe
    user_features = features[features['user_id'] == user_id]
    user_features['proba'] = predictions

    # Select limit posts with the highest probabilities
    preds_id = user_features.sort_values(by='proba', ascending=False)['post_id_x'][:limit].values.tolist()
    preds = []

    # For each selected post, query the database to get its text and topic
    for i in preds_id:
        post = db.query(Post).filter(Post.id == i).one_or_none()
        preds.append(PostGet(id=post.id, text=post.text, topic=post.topic))
    
    # Close the session
    db.close()

    # Return the list of PostGet objects containing the text and topic of the selected articles
    return preds

# Test the function
print(get_recommendations(1, 5))


CatBoostError: C:/Program Files (x86)/Go Agent/pipelines/BuildMaster/catboost.git/catboost/libs/data/model_dataset_compatibility.cpp:81: At position 47 should be feature with name post_id_y (found group_id).