In [1]:
# Импортируем необходимые модули и библиотеки
import os
from catboost import CatBoostClassifier, Pool, CatBoost
import pandas as pd
from sqlalchemy import create_engine, Column, Integer, String
from sqlalchemy.orm import sessionmaker, Session
from sqlalchemy.ext.declarative import declarative_base
from typing import List
from fastapi import FastAPI, Depends 
from datetime import datetime
from pydantic import BaseModel

'''
ФУНКЦИИ ПО ЗАГРУЗКЕ МОДЕЛЕЙ
'''
# Проверка если код выполняется в лмс, или локально
def get_model_path(path: str) -> str:
    """Просьба не менять этот код"""
    if os.environ.get("IS_LMS") == "1":  # проверяем где выполняется код в лмс, или локально. Немного магии
        MODEL_PATH = 'catboost_model_1.cbm'
    else:
        MODEL_PATH = path
    return MODEL_PATH

class CatBoostWrapper(CatBoost):
    def predict_proba(self, X):
        return self.predict(X, prediction_type='Probability')

def load_models():
    model_path = get_model_path("c:\\Users\\Alex\\Desktop\\Repos\\Start_ML\\Final_project\\fast_api\\catboost_model_1.cbm")
    model = CatBoostWrapper()
    model.load_model(model_path)
    return model


'''
Получение данных из базы данных
'''

# Определяем функцию для получения данных из базы данных PostgreSQL
def batch_load_sql(query: str) -> pd.DataFrame:
    CHUNKSIZE = 200000
    engine = create_engine(
        "postgresql://robot-startml-ro:pheiph0hahj1Vaif@"
        "postgres.lab.karpov.courses:6432/startml"
    )
    conn = engine.connect().execution_options(stream_results=True)
    chunks = []
    for chunk_dataframe in pd.read_sql(query, conn, chunksize=CHUNKSIZE):
        chunks.append(chunk_dataframe)
    conn.close()
    return pd.concat(chunks, ignore_index=True)

def load_features() -> pd.DataFrame:
    query = "a-efimik_features_lesson_22_4"
    return batch_load_sql(query)

# Определяем переменные для подключения к базе данных
SQLALCHEMY_DATABASE_URL = "postgresql://robot-startml-ro:pheiph0hahj1Vaif@postgres.lab.karpov.courses:6432/startml"
engine = create_engine(SQLALCHEMY_DATABASE_URL)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base = declarative_base()

# Определяем класс Post для работы с таблицей базы данных post
class Post(Base):
    __tablename__ = 'post'
    id = Column(Integer, primary_key=True)
    text = Column(String)
    topic = Column(String)

class PostGet(BaseModel):
    id: int
    text: str
    topic: str

    class Config:
        orm_mode = True

# Определяем функцию для получения сессии базы данных
def get_db():
    with SessionLocal() as db:
        return db





  Base = declarative_base()


In [2]:
model = load_models()

In [3]:
features = load_features()

### Here we prepare the data for the model

In [11]:
def caching_predictions(feed_data, model):
    # Get a list of unique users
    unique_users = feed_data['user_id'].unique()

    # Create a group ID based on the 'user_id'
    group_id_dict = {user_id: idx for idx, user_id in enumerate(unique_users)}
    feed_data['group_id'] = feed_data['user_id'].map(group_id_dict)

    feed_data_sorted = feed_data.sort_values(by='group_id')
    data_pool = Pool(feed_data_sorted.drop(columns=['user_id']), group_id=feed_data_sorted['group_id'])
    y_pred_proba = model.predict_proba(data_pool)[:, 1]
    feed_data['pred_proba'] = y_pred_proba
    top_5_posts = feed_data.groupby('user_id').apply(lambda x: x.nlargest(5, 'pred_proba')['post_id'])
    top_5_posts_dict = top_5_posts.reset_index().groupby('user_id')['post_id'].apply(list).to_dict()
    
    return top_5_posts_dict


top_5_posts_dict = caching_predictions(features, model)

In [12]:
top_5_posts_dict

{200: [1630, 1881, 1864, 5593, 6661],
 201: [1336, 821, 3371, 5514, 4150],
 202: [5208, 2418, 7014, 4730, 4231],
 203: [4541, 1795, 2062, 5490, 5423],
 204: [7002, 368, 5246, 2390, 6056],
 205: [1158, 3712, 2052, 226, 3099],
 206: [6042, 3608, 4022, 6030, 5419],
 207: [2823, 5541, 3062, 427, 5538],
 208: [5342, 681, 3775, 1803, 1166],
 209: [1560, 1347, 7053, 1718, 2611],
 210: [2376, 6307, 1150, 2993, 1907],
 211: [7069, 693, 1280, 4632, 4932],
 212: [5375, 2881, 5321, 4691, 7092],
 213: [5842, 6655, 2401, 1416, 1398],
 214: [1760, 2004, 7198, 371, 1746],
 215: [3449, 1086, 5073, 7116, 1869],
 216: [2529, 4065, 3875, 5327, 1700],
 217: [180, 2544, 4510, 567, 1131],
 218: [4444, 1677, 1167, 6971, 5576],
 219: [1230, 2980, 1257, 6694, 4708],
 220: [6379, 2370, 7064, 1849, 1677],
 221: [1792, 2037, 5221, 6329, 542],
 222: [6138, 4858, 1269, 767, 4752],
 223: [505, 4956, 4890, 3136, 1421],
 224: [1629, 3035, 1566, 1956, 3266],
 225: [5792, 2889, 4794, 3849, 1080],
 226: [6943, 286, 1269, 

In [13]:
top_5_posts_dict[200]

[1630, 1881, 1864, 5593, 6661]