In [None]:
"""
Тестируем работоспособность модели с базой SQL 
Модель с эмбедингами нейросети  
"""

In [1]:
from fastapi import  FastAPI, Depends, HTTPException
from sqlalchemy.orm import Session
#from schema import PostGet
from typing import List
from json import dumps
import pandas as pd
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from datetime import datetime
from url import conn_uri

In [2]:
"""
Создадим классы для валидации выходных данных 
"""

from pydantic import BaseModel
class PostGet(BaseModel):
    id: int
    text: str
    topic: str

    class Config:
        orm_mode = True

In [3]:
SQLALCHEMY_DATABASE_URL = url

engine = create_engine(SQLALCHEMY_DATABASE_URL)

SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)


In [4]:
import os
from catboost import CatBoostClassifier

def get_model_path(path: str) -> str:
    """
    Функция проверяет где работает модель - локально или на удаленной машине
    """
    if os.environ.get("IS_LMS") == "1":  # проверяем где выполняется код 
        MODEL_PATH = '/workdir/user_input/model'
    else:
        MODEL_PATH = path
    return MODEL_PATH

def load_models():
    """
    Загружаем модель
    """
    model_path = get_model_path("model_Cat_2000it_(2ml)_with_data_nn") #поменять путь в готовом проекте

    model = CatBoostClassifier()  # parameters not required.
    model.load_model(model_path, format='cbm') # пример как можно загружать модели

    return model


In [5]:
def batch_load_sql(query: str) -> pd.DataFrame:
    CHUNKSIZE = 200000
    engine = create_engine(SQLALCHEMY_DATABASE_URL)
    
    conn = engine.connect().execution_options(stream_results=True)
    chunks = []
    for chunk_dataframe in pd.read_sql(query, conn, chunksize=CHUNKSIZE):
        chunks.append(chunk_dataframe)
    conn.close()
    return pd.concat(chunks, ignore_index=True)



def load_features_user() -> pd.DataFrame:
    user = batch_load_sql('SELECT * FROM "shestov_user_lesson_22_v2"')
    return user


def load_features_post_new() -> pd.DataFrame:
    post_sort = batch_load_sql("""SELECT * from "shestov_post_lesson_22_v1.1"
     """)
    return post_sort



In [6]:
model_cat = load_models()


In [7]:
user_transformd_ = load_features_user().drop(['index'],axis=1)

In [8]:
post_sort = load_features_post_new().drop(['index'],axis=1)

In [9]:
app = FastAPI()#

In [10]:
#@app.get("/post/recommendations/", response_model=List[PostGet])
def recommended_posts(id: int, time: datetime, limit: int = 5) -> List[PostGet]:
    """
    Функция принимает на вход id  пользователя, время реакции пользователья на пост и лимит рекомендаций
    Возвращает пользователю рекомендации в количестве limit
    """
    result_list = []
    df_test = pd.merge(
        user_transformd_[user_transformd_['user_id']==id], 
        post_sort,how='cross').drop(['user_id'], axis =1).set_index(['post_id'])
    df_test['hour'] = pd.to_datetime(time).hour
    df_test['month'] = pd.to_datetime(time).month
    df_test['dayofweek'] = pd.to_datetime(time).dayofweek
    predict_pr = model_cat.predict_proba(df_test.drop(['text','topic'],axis=1))
    result = pd.DataFrame(predict_pr, index = df_test.index).drop([0], axis=1).sort_values(by =1 ,ascending = False)[:limit]
   
    for i in range(limit):
        id_ = int(result.index[-1+i])
        result_list.append(            
            
                {"id": id_ ,
                "text": str(df_test.loc[id_]['text']),
                "topic":str(df_test.loc[id_]['topic']) }
            
        )

    return result_list

In [11]:
rec = recommended_posts(id = 204, time = '01.01.2001', limit= 5)

In [12]:
rec

[{'id': 2882,
  'text': '#Yes to #Lockdown2 if #Covid19 gets out of control in #NewSouthWales, she said @AusChamberCEO.\n\nYou &amp; @thepmo are… https://t.co/xNj66t6sfH',
  'topic': 'covid'},
 {'id': 3830,
  'text': 'World tops 25 million #COVID19 cases;\nclosing in on 850,000 deaths.\nhttps://t.co/ojo2CypY9R',
  'topic': 'covid'},
 {'id': 4008,
  'text': 'COVID-19 CASES IN PH BREACH 143,000-MARK\n\nDOH reports 4,444 confirmed cases of #COVID19. As of 4 p.m., Aug. 12, 202… https://t.co/YzNc0D6fZk',
  'topic': 'covid'},
 {'id': 3939,
  'text': 'INTERNATIONAL and state #travelrestrictions due to #COVID19 are likely to redirect the 181 million domestic outboun… https://t.co/IohWbS9SV8',
  'topic': 'covid'},
 {'id': 3368,
  'text': '#Covid19 Positive cases reach 421 996. Covid-19 cases continue to rise rapidly across the country, as the so-called… https://t.co/95n2DEFU0a',
  'topic': 'covid'}]