In [1]:
import numpy as np
import pandas as pd
# import gymnasium as gym
import os

In [2]:
print(os.getcwd())
os.chdir('../../')
print(os.getcwd())

c:\Users\User\Desktop\music_rec_system\code\environment
c:\Users\User\Desktop\music_rec_system


In [3]:
from config import DATA_CLEANED_PATH, DATA_PREP_PATH, DATA_RAW_PATH

In [4]:
import pandas as pd
import numpy as np
from tqdm import tqdm
# Настройки
OUTPUT_PATH = 'data/prep/sessions_tied.parquet'
TIED_WINDOW = 2  # Окно событий для проверки (±2)

def clean_tied_events(row):
    events = np.array(row['event_type'])
    item_ids = np.array(row['item_ids'])
    played_ratios = np.array(row['played_ratio_pct']) if row['played_ratio_pct'] is not None else np.full(len(events), None)
    
    # Маска для сохранения: listen или tied non-listen (с совпадением item_id)
    keep_mask = np.zeros(len(events), dtype=bool)
    for i in range(len(events)):
        if events[i] == 'listen':
            keep_mask[i] = True
            continue
        # Проверяем окно на listen с ТЕМ ЖЕ item_id
        current_id = item_ids[i]
        window_start = max(0, i - TIED_WINDOW)
        window_end = min(len(events), i + TIED_WINDOW + 1)
        has_tied_listen = any(
            events[j] == 'listen' and item_ids[j] == current_id
            for j in range(window_start, window_end)
        )
        if has_tied_listen:
            keep_mask[i] = True
    
    # Если маска пустая — skip сессию
    if not np.any(keep_mask):
        return None
    
    # Фильтруем
    filtered = {
        'uid': row['uid'],
        'session_idx': row['session_idx'],
        'session_length': int(np.sum(keep_mask)),
        'item_ids': item_ids[keep_mask].tolist(),
        'played_ratio_pct': played_ratios[keep_mask].tolist() if played_ratios is not None else None,
        'event_type': events[keep_mask].tolist(),
    }
    return filtered

# Основной цикл
print("Загружаем сессии...")
sessions_df = pd.read_parquet(DATA_CLEANED_PATH)

cleaned_sessions = []
for _, row in tqdm(sessions_df.iterrows(), total=len(sessions_df), desc="Очистка сессий"):
    cleaned = clean_tied_events(row)
    if cleaned:
        cleaned_sessions.append(cleaned)

cleaned_df = pd.DataFrame(cleaned_sessions)
cleaned_df.to_parquet(OUTPUT_PATH, index=False, compression='snappy')

print(f"Готово! Новый файл: {OUTPUT_PATH}")
print(f"Сессий до: {len(sessions_df)}, после: {len(cleaned_df)}")
print(f"Средняя длина: {cleaned_df['session_length'].mean():.2f}")

Загружаем сессии...


Очистка сессий: 100%|██████████| 1662367/1662367 [02:18<00:00, 11991.44it/s]


Готово! Новый файл: data/prep/sessions_tied.parquet
Сессий до: 1662367, после: 1656954
Средняя длина: 24.61


In [5]:
import pandas as pd
import numpy as np

# Настройки
INPUT_PATH = OUTPUT_PATH
OUTPUT_PATH = 'data/prep/sessions_clipped.parquet'

def clip_played_ratios(row):
    played_ratios = row['played_ratio_pct']
    if played_ratios is None:
        return row  # Если вся колонка None — не трогаем
    
    # Конвертируем в numpy array, clip только not None
    ratios_array = np.array(played_ratios)
    # Маска для not None
    mask = ~pd.isnull(ratios_array)
    if np.any(mask):
        ratios_array[mask] = np.clip(ratios_array[mask], a_min=0, a_max=100)
    
    # Обновляем row
    updated_row = row.copy()
    updated_row['played_ratio_pct'] = ratios_array.tolist()
    return updated_row

# Основной цикл
print("Загружаем сессии...")
sessions_df = pd.read_parquet(INPUT_PATH)

updated_sessions = []
for _, row in tqdm(sessions_df.iterrows(), total=len(sessions_df), desc="Нормализация значений"):
    updated = clip_played_ratios(row.to_dict())
    if updated is not None:  # На случай, если skip (но здесь всегда сохраняем)
        updated_sessions.append(updated)

updated_df = pd.DataFrame(updated_sessions)
updated_df.to_parquet(OUTPUT_PATH, index=False, compression='snappy')

print(f"Готово! Новый файл: {OUTPUT_PATH}")
print(f"Сессий: {len(updated_df)}")
print(f"Пример первой строки: {updated_df.iloc[0]}")

Загружаем сессии...


Нормализация значений: 100%|██████████| 1656954/1656954 [01:59<00:00, 13913.64it/s]


Готово! Новый файл: data/prep/sessions_clipped.parquet
Сессий: 1656954
Пример первой строки: uid                                                            468600
session_idx                                                       121
session_length                                                     17
item_ids            [6307161, 132318, 3168387, 1939442, 4875792, 2...
played_ratio_pct    [1.0, 1.0, 1.0, 59.0, 1.0, 1.0, 1.0, 1.0, 1.0,...
event_type          [listen, listen, listen, listen, listen, liste...
Name: 0, dtype: object


In [7]:
print(updated_df['played_ratio_pct'].explode().unique())

[1.0 59.0 2.0 95.0 100.0 9.0 88.0 13.0 57.0 6.0 69.0 0.0 4.0 28.0 16.0 3.0
 98.0 99.0 89.0 74.0 35.0 92.0 64.0 5.0 93.0 90.0 50.0 32.0 91.0 24.0 25.0
 11.0 84.0 21.0 7.0 10.0 12.0 8.0 23.0 14.0 15.0 29.0 19.0 27.0 22.0 18.0
 47.0 81.0 52.0 67.0 46.0 30.0 51.0 26.0 83.0 82.0 73.0 20.0 37.0 31.0
 76.0 40.0 96.0 43.0 55.0 61.0 63.0 85.0 39.0 38.0 79.0 53.0 33.0 48.0
 49.0 36.0 44.0 97.0 nan 86.0 80.0 34.0 62.0 70.0 41.0 17.0 77.0 94.0 56.0
 66.0 68.0 78.0 60.0 54.0 45.0 72.0 75.0 71.0 87.0 58.0 42.0 65.0]


In [26]:
from config import DROPOUT_BATCH_RATE, BATCH_SIZE
def clean_small_sessions(row, dropout=DROPOUT_BATCH_RATE, batch=BATCH_SIZE):
    if row['session_length'] <= batch:
        return None
    elif row['session_length'] < 2*batch:
        if np.random.uniform(0, 1) < dropout:
            return None
    return row

In [27]:
from tqdm import tqdm

sessions_df = pd.read_parquet('data/prep/sessions_clipped.parquet')

cleaned_sessions=[]
for _, row in tqdm(sessions_df.iterrows(), total=len(sessions_df), desc="Очистка и фильтрация сессий"):
    cleaned = clean_small_sessions(row)
    if cleaned is not None:
        cleaned_sessions.append(cleaned.to_dict())

cleaned_df = pd.DataFrame(cleaned_sessions)
cleaned_df.to_parquet('data/prep/final_dataset.parquet', index=False, compression='snappy')
print(cleaned_df)




Очистка и фильтрация сессий: 100%|██████████| 1656954/1656954 [01:30<00:00, 18218.01it/s]


            uid  session_idx  session_length  \
0        468600          121              17   
1        671200           86              10   
2        580200          158              16   
3        408000         1963              13   
4        504900           22              11   
...         ...          ...             ...   
1434534  143700          387              13   
1434535  478800          147             131   
1434536  825700          413              26   
1434537  400300          543              28   
1434538  248500         1148              18   

                                                  item_ids  \
0        [6307161, 132318, 3168387, 1939442, 4875792, 2...   
1        [1615981, 2476958, 7710425, 1509660, 5676576, ...   
2        [871242, 6149293, 9004079, 4612877, 2118649, 7...   
3        [2698357, 6511451, 6511451, 6511451, 2698357, ...   
4        [678062, 872911, 7571390, 8614811, 1268565, 79...   
...                                                

In [None]:
pd.DataFrame({
    "uid": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
    "session_idx": [100, 200, 300, 400, 500, 600, 700, 800, 900, 1000],

    "item_ids": [
        # 1. Длинная сессия с повторениями и tied actions
        [101, 102, 101, 103, 104, 105, 104, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118],
        # 2. 15 треков + like/dislike на повторяющихся
        [201, 202, 203, 202, 204, 205, 206, 205, 207, 208, 209, 210, 211, 212, 213],
        # 3. 12 треков + unlike/undislike
        [301, 302, 302, 304, 303, 305, 306, 307, 308, 309, 310, 311],
        # 4. 25 треков — длинная
        list(range(400, 425)),
        # 5. 8 треков — граничная
        [501, 502, 503, 504, 504, 504, 506, 507],
        # 6. 18 треков + много действий
        [601, 602, 603, 604, 605, 605, 606, 607, 608, 609, 610, 611, 612, 613, 614, 615, 616, 617],
        # 7. 10 треков
        [701, 702, 703, 704, 705, 706, 707, 708, 709, 710],
        # 8. 20 треков
        list(range(800, 820)),
        # 9. 16 треков
        [901, 902, 903, 904, 905, 906, 907, 908, 909, 910, 911, 912, 913, 914, 915, 916],
        # 10. 14 треков
        list(range(1000, 1014)),
    ],

    "played_ratio_pct": [
        # None везде, где не listen
        [85.0, 90.0, None, 70.0, 60.0, 95.0, None, 88.0, 45.0, 33.0, 77.0, 82.0, 91.0, 29.0, 66.0, 55.0, 99.0, 44.0, 78.0, 22.0],
        [78.0, 34.0, 91.0, None, 87.0, 66.0, 41.0, None, 93.0, 55.0, 77.0, 32.0, 88.0, 44.0, 81.0],
        [31.0, 18.0, None, 88.0, 67.0, 33.0, 19.0, 77.0, 22.0, 66.0, 28.0, 91.0],
        [np.random.uniform(20, 100) for _ in range(25)],
        [68.0, 33.0, 91.0, 27.0, None, None, 44.0, 77.0],
        [82.0, 39.0, 88.0, 45.0, 92.0, None, 33.0, 79.0, 61.0, 25.0, 88.0, 41.0, 95.0, 38.0, 77.0, 55.0, 83.0, 29.0],
        [75.0, 41.0, 88.0, 33.0, 77.0, 55.0, 92.0, 38.0, 81.0, 44.0],
        [np.random.uniform(30, 95) for _ in range(20)],
        [88.0, 44.0, 91.0, 33.0, 77.0, 55.0, 82.0, 39.0, 95.0, 28.0, 88.0, 41.0, 77.0, 55.0, 92.0, 36.0],
        [np.random.uniform(40, 98) for _ in range(14)],
    ],

    "event_type": [
        # listen + like/dislike/unlike/undislike — все tied!
        ['listen', 'listen', 'like', 'listen', 'listen', 'listen', 'dislike', 'listen', 'listen', 'listen', 'like', 'listen', 'listen', 'listen', 'listen', 'listen', 'listen', 'listen', 'listen', 'listen'],
        ['listen', 'listen', 'listen', 'like', 'listen', 'listen', 'listen', 'dislike', 'listen', 'listen', 'listen', 'listen', 'listen', 'listen', 'listen'],
        ['listen', 'listen', 'like', 'listen', 'listen', 'listen', 'listen', 'listen', 'listen', 'listen', 'listen', 'listen'],
        ['listen'] * 25,
        ['listen', 'listen', 'listen', 'listen', 'like', 'unlike', 'listen', 'listen'],
        ['listen', 'listen', 'listen', 'listen', 'like', 'listen', 'listen', 'listen', 'listen', 'listen', 'listen', 'listen', 'listen', 'listen', 'listen', 'listen', 'listen', 'listen'],
        ['listen', 'listen', 'listen', 'listen', 'listen', 'listen', 'listen', 'listen', 'listen', 'listen'],
        ['listen'] * 20,
        ['listen', 'listen', 'listen', 'listen', 'listen', 'listen', 'listen', 'listen', 'listen', 'listen', 'listen', 'listen', 'listen', 'listen', 'listen', 'listen'],
        ['listen', 'listen', 'listen', 'listen', 'listen', 'listen', 'listen', 'listen', 'listen', 'listen', 'listen', 'listen', 'listen', 'listen']
    ],

    "session_length": [
        20, 15, 12, 25, 8, 18, 10, 20, 16, 14
    ]
})

In [71]:
import pandas as pd
import numpy as np
from tqdm import tqdm

# ================================
# НАСТРОЙКИ
# ================================
OUTPUT_PATH = 'final_dataset_v2.parquet'
BATCH_SIZE = 5
DROPOUT_BATCH_RATE = 0.6

# Правильный маппинг с накоплением
ACTION_TO_SCORE = {
    'like':       1.0,
    'dislike':   -1.0,
    'unlike':    -0.5,
    'undislike':   0.5,
    'listen':     0.0,   # listen не влияет на score
}

print("Загружаем датасет...")
df = pd.read_parquet('data/prep/final_dataset.parquet')
print(f"Исходных сессий: {len(df)}")

cleaned_sessions = []

for _, row in tqdm(df.iterrows(), total=len(df), desc="RL-ready с накоплением действий"):
    item_ids = np.array(row['item_ids'])
    ratios = np.array(row['played_ratio_pct']) if row['played_ratio_pct'] is not None else np.full(len(item_ids), None)
    events = np.array(row['event_type'])

    # Заменяем None в ratios на 0.5 (или можно среднее по сессии)
    ratios_clean = np.where(pd.isnull(ratios), 0.5, ratios).astype(float)
    ratios_clean = np.clip(ratios_clean / 100.0, 0.0, 1.0)

    # DataFrame для группировки
    temp_df = pd.DataFrame({
        'item_id': item_ids,
        'ratio': ratios_clean,
        'event': events
    })

    # Только tied действия (уже очищено ранее) — оставляем все
    # Группируем по item_id
    grouped = temp_df.groupby('item_id').agg({
        'ratio': 'mean',  # среднее прослушивание (или max, или last — как хочешь)
        'event': list
    }).reset_index()

    # Считаем накопительный score
    scores = []
    for event_list in grouped['event']:
        score = 0.0
        for e in event_list:
            if e in ACTION_TO_SCORE:
                score += ACTION_TO_SCORE[e]
        scores.append(score)

    grouped['action_score'] = scores

    # Финальные списки
    final_items = grouped['item_id'].tolist()
    final_ratios = grouped['ratio'].tolist()
    final_actions = grouped['action_score'].tolist()

    new_length = len(final_items)

    # Дропаут
    if new_length < BATCH_SIZE:
        continue
    elif new_length < 2 * BATCH_SIZE:
        if np.random.uniform() < DROPOUT_BATCH_RATE:
            continue

    cleaned_sessions.append({
        'uid': row['uid'],
        'session_idx': row['session_idx'],
        'session_length': new_length,
        'item_ids': final_items,
        'played_ratio_pct': final_ratios,
        'actions': final_actions,  # ← накопительные баллы!
    })

# Сохраняем
final_df = pd.DataFrame(cleaned_sessions)
final_df.to_parquet(OUTPUT_PATH, index=False, compression='snappy')

print("\n" + "="*60)
print("ГОТОВО! RL-READY v2 с НАКОПЛЕНИЕМ ДЕЙСТВИЙ")
print("="*60)
print(f"Сессий: {len(final_df):,} (из {len(df):,})")
print(f"Средняя длина: {final_df['session_length'].mean():.2f}")
print(f"Макс. action_score: {final_df['actions'].explode().astype(float).max():.2f}")
print(f"Мин. action_score: {final_df['actions'].explode().astype(float).min():.2f}")
print(f"Пример (с накоплением):")
example = final_df.iloc[0]
for i, item in enumerate(example['item_ids'][:10]):
    print(f"  {item}: ratio={example['played_ratio_pct'][i]:.2f}, actions → {example['actions'][i]:+.2f}")
print("="*60)
print(f"Сохранено → {OUTPUT_PATH}")

Загружаем датасет...
Исходных сессий: 1434539


RL-ready с накоплением действий: 100%|██████████| 1434539/1434539 [49:38<00:00, 481.64it/s] 



ГОТОВО! RL-READY v2 с НАКОПЛЕНИЕМ ДЕЙСТВИЙ
Сессий: 1,314,906 (из 1,434,539)
Средняя длина: 26.22
Макс. action_score: 4.50
Мин. action_score: -9.00
Пример (с накоплением):
  132318: ratio=0.01, actions → +0.00
  954649: ratio=0.01, actions → +0.00
  1939442: ratio=0.59, actions → +0.00
  2874245: ratio=0.01, actions → +0.00
  3168387: ratio=0.01, actions → +0.00
  4242136: ratio=0.01, actions → +0.00
  4319747: ratio=0.01, actions → +0.00
  4875792: ratio=0.01, actions → +0.00
  5150427: ratio=0.02, actions → +0.00
  5223476: ratio=0.02, actions → +0.00
Сохранено → final_dataset_v2.parquet


In [70]:
print(final_df['session_length'].iloc[0])
print(final_df['item_ids'].iloc[0])
print(final_df['played_ratio_pct'].iloc[0])
print(final_df['actions'].iloc[0])

18
[101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118]
[0.4275, 0.9, 0.7, 0.3025, 0.95, 0.88, 0.45, 0.33, 0.77, 0.82, 0.91, 0.29, 0.66, 0.55, 0.99, 0.44, 0.78, 0.22]
[1.0, 0.0, 0.0, -1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]


In [57]:
load_dotenv()

environment = MusicRecEnv(sessions_path='data/prep/final_dataset.parquet', qdrant_url=os.getenv('QDRANT_URL'))

In [59]:
environment.reset()

ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

In [3]:
import pandas as pd
import os

print(os.getcwd())
# os.chdir('../../')

df = pd.read_parquet('data/prep/final_dataset_v2.parquet')
item_ids = df['item_ids'].explode().unique()
print(item_ids[:5])

c:\Users\User\Desktop\music_rec_system
[np.int64(132318) np.int64(954649) np.int64(1939442) np.int64(2874245)
 np.int64(3168387)]


In [7]:
embed = pd.read_parquet('data/raw/embeddings.parquet', columns=['item_id', 'normalized_embed'], filters=[('item_id', 'in', item_ids)])
embed.head()

Unnamed: 0,item_id,normalized_embed
0,26,"[0.04859397761569854, -0.011305588410678977, 0..."
1,43,"[-0.002297713108941854, -0.09647996650361734, ..."
2,50,"[-0.052600025461501755, 0.04867209823695481, 0..."
3,71,"[-0.1421526758688841, -0.10845793726932372, 0...."
4,81,"[-0.13562516940387961, -0.017089275237262717, ..."


In [9]:
print(len(embed))

762873


In [3]:
from dotenv import load_dotenv
import os
import pandas as pd
from qdrant_client import QdrantClient
print(os.getcwd())
os.chdir('../../')

load_dotenv()

print(os.getenv('QDRANT_URL'))

client = QdrantClient(
    url=os.getenv('QDRANT_URL'),
    api_key=os.getenv('QDRANT_API_KEY')
)

def get_by_point_id(point_id: int):
    points = client.retrieve(
        collection_name="yambda_50m",
        ids=[point_id],
        with_payload=False,
        with_vectors=True
    )
    return points[0] if points else None

# Использование
point = get_by_point_id(26)

c:\Users\User\Desktop\music_rec_system\code\environment
https://60266f02-4dd1-4ed8-82c0-801dc928f25d.eu-central-1-0.aws.cloud.qdrant.io:6333


In [5]:
print(point.vector)


[0.04859398, -0.011305588, 0.023160774, -0.045000304, -0.06427269, 0.04539275, -0.2367534, -0.02932359, 0.05037033, 0.01460231, 0.10112755, -0.1344197, -0.11328399, 0.060744457, 0.1893022, 0.07762245, -0.17071949, 0.043256227, 0.11297903, 0.122982815, -0.046959046, -0.11841896, -0.07622981, 0.10609183, -0.024477905, 0.036465302, -0.14199466, 0.062253356, -0.18444568, -0.06949321, 0.06848837, -0.12575838, 0.10475047, 0.122214906, -0.03863612, -0.029027352, -0.047815546, -0.12221988, 0.0005227094, -0.10080943, 0.07042479, 0.035289723, -0.030674824, 0.03330636, 0.024252098, 0.047780305, 0.13634527, -0.056022093, 0.0889867, -0.23868129, -0.11623513, -0.003988632, 0.08229693, 0.008950317, 0.07493132, 0.03597163, -0.13230133, -0.15976518, -0.12172516, 0.07115884, 0.029266864, -0.027016606, -0.08218977, 0.058771487, 0.047750324, 0.019146333, -0.100106135, 0.13820551, 0.07286545, -0.08536321, 0.039928734, -0.04043635, -0.039924886, -0.048481718, 0.15336001, 0.12683474, -0.045358777, -0.0973595

In [2]:
import os
print(os.getcwd())
os.chdir('../../')

c:\Users\User\Desktop\music_rec_system\code\environment


In [3]:
import gymnasium as gym
from gymnasium import spaces
import numpy as np
import pandas as pd
from dotenv import load_dotenv
from qdrant_client import QdrantClient
from typing import Dict, Any
from config import DATA_READY_PATH, EMBED_DIM, COLLECTION_NAME, BATCH_SIZE




class MusicRecEnv(gym.Env):

    def __init__(
        self,
        qdrant_url: str,
        qdrant_api__key: str,
        sessions_path: str = DATA_READY_PATH,
        collection_name: str = COLLECTION_NAME,
        batch_size: int = BATCH_SIZE,
        embedding_dim: int = EMBED_DIM
    ):
        super().__init__()
        
        # === ДАННЫЕ ===
        self.sessions_df = pd.read_parquet(sessions_path)
        self.client = QdrantClient(
            url=qdrant_url,
            api_key=qdrant_api__key
        )
        self.collection_name = collection_name
        self.batch_size = batch_size
        self.embedding_dim = embedding_dim
        
        # === СПЕЙСЫ ===
        # Action: коррекция в 128D пространстве
        self.action_space = spaces.Box(
            low=-1.0, high=1.0, shape=(embedding_dim,), dtype=np.float32
        )
        
        # Observation: batch_size × (128 emb + time + action_score)
        state_size = batch_size * (embedding_dim + 2)
        self.observation_space = spaces.Box(
            low=-np.inf, high=np.inf, shape=(state_size,), dtype=np.float32
        )
        
        # === Состояние ===
        self.current_session = None #dataframe всей сессии
        self.session_pos = 0 #позици

    def reset(self, seed=None, options=None):
        super().reset(seed=seed)
        self.current_session = self.sessions_df.sample(1, random_state=seed).iloc[0]
        self.session_pos = 0
        
        batch_ids = self.current_session['item_ids'][:self.batch_size]
        times = self.current_session['played_ratio_pct'][:self.batch_size]
        feedbacks = self.current_session['actions'][:self.batch_size]
        
        state = self._build_state(batch_ids, times, feedbacks)
        self.current_state = state
        return state, {}

    def step(self, action: np.ndarray):
        start_idx = self.session_pos + self.batch_size
        if start_idx >= self.current_session['session_length']:
            return self.current_state, 0.0, True, False, {}

        # === Извлекаем pos/neg из текущего батча ===
        pos_embs, neg_embs = [], []
        pos_w, neg_w = [], []
        
        for i in range(self.batch_size):
            emb = self.current_state[i*(self.embedding_dim+2):i*(self.embedding_dim+2)+self.embedding_dim]
            print(f"EMBEDDING LENGTH: {len(emb)}")
            time = self.current_state[i*(self.embedding_dim+2)+self.embedding_dim]
            print(f"TIME: {time}")
            fb = self.current_state[i*(self.embedding_dim+2)+self.embedding_dim+1]
            print(f"ACTION RAWARD: {fb}")
            
            if fb > 0 or time > 0.7:
                pos_embs.append(emb)
                pos_w.append(time + fb)
            if fb < 0 or time < 0.4:
                neg_embs.append(emb)
                neg_w.append((1 - time) + abs(fb))

        target_pos = np.average(pos_embs, axis=0, weights=pos_w) if pos_embs else np.zeros(self.embedding_dim)
        target_neg = np.average(neg_embs, axis=0, weights=neg_w) if neg_embs else np.zeros(self.embedding_dim)
        
        # === Коррекция от action ===
        final_target = target_pos - 0.5 * target_neg + 0.15 * action
        final_target = final_target / (np.linalg.norm(final_target) + 1e-8)

        # === Поиск в Qdrant ===
        results = self.client.search(
            collection_name=self.collection_name,
            query_vector=final_target.tolist(),
            limit=self.batch_size * 10,
            with_payload=True
        )
        
        seen = set(self.current_session['item_ids'][:start_idx + self.batch_size])
        new_batch = []
        for r in results:
            item_id = r.payload.get('item_id')
            if item_id and item_id not in seen and item_id not in new_batch:
                new_batch.append(item_id)
            if len(new_batch) == self.batch_size:
                break
        
        # fallback
        if len(new_batch) < self.batch_size:
            candidates = self.sessions_df['item_ids'].explode().unique()
            for cand in np.random.choice(candidates, size=100, replace=False):
                if cand not in seen and cand not in new_batch:
                    new_batch.append(cand)
                if len(new_batch) == self.batch_size:
                    break

        # === Фидбек из реальной сессии ===
        real_times = np.array(self.current_session['played_ratio_pct'][start_idx:start_idx+self.batch_size])
        real_actions = np.array(self.current_session['actions'][start_idx:start_idx+self.batch_size])
        
        reward = float(np.mean(real_times) + np.mean(real_actions))

        # === Новый state ===
        self.current_state = self._build_state(new_batch, real_times, real_actions)
        self.session_pos = start_idx
        
        done = (self.session_pos + self.batch_size >= self.current_session['session_length'])
        return self.current_state, reward, done, False, {}

    def _get_emb(self, item_id) -> np.ndarray:
        try:
            point = client.retrieve(
                collection_name=self.collection_name,
                ids=[int(item_id)],
                with_payload=False,
                with_vectors=True
            )
            return np.array(point[0].vector, dtype=np.float32)
        except Exception as e:
            print(f"EXCEPTION WHILE GETTING EMBEDDING: {e}")
            return np.zeros(self.embedding_dim, dtype=np.float32)
        

    def _build_state(self, batch_ids, times, feedbacks):
        state = []
        for i in range(self.batch_size):
            emb = self._get_emb(batch_ids[i])
            state.extend(emb)
            state.append(float(times[i]))
            state.append(float(feedbacks[i]))
        return np.array(state, dtype=np.float32)

In [None]:
import gymnasium as gym
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import SubprocVecEnv
from stable_baselines3.common.callbacks import EvalCallback
from stable_baselines3.common.logger import configure
import os

# Настройки
SESSIONS_PATH = "yambda_sessions_rl_ready_v2.parquet"  # Твой датасет
QDRANT_URL = "http://localhost:6333"  # Твой Qdrant
QDRANT_API_KEY = "your_key_here"  # Из .env
LOG_DIR = "./ppo_logs/"  # Для TensorBoard
MODEL_DIR = "./ppo_models/"
os.makedirs(LOG_DIR, exist_ok=True)
os.makedirs(MODEL_DIR, exist_ok=True)

# Функция для создания env (для параллелизма)
def make_env():
    return MusicRecEnv(
        qdrant_url=QDRANT_URL,
        qdrant_api__key=QDRANT_API_KEY,
        sessions_path=SESSIONS_PATH,
        batch_size=5,
        embedding_dim=128
    )

# 8 параллельных env (для ускорения)
n_envs = 8
vec_env = SubprocVecEnv([make_env for _ in range(n_envs)])

# Создаём модель PPO
model = PPO(
    "MlpPolicy",  # MLP-нейросеть (можно CnnPolicy, если state как изображение)
    vec_env,      # Вектор env для параллелизма
    verbose=1,    # Печатать прогресс
    tensorboard_log=LOG_DIR,  # Логи для TensorBoard
    learning_rate=3e-4,       # Скорость обучения (маленькая для стабильности)
    n_steps=2048,             # Шаги на батч
    batch_size=256,           # Мини-батч
    n_epochs=10,              # Эпохи на данные
    gamma=0.99,               # Дисконт для будущих reward (длинные сессии)
    gae_lambda=0.95,          # Advantage estimation
    clip_range=0.2,           # PPO clip для стабильности
    ent_coef=0.01,            # Энтропия для exploration
    device="cuda" if torch.cuda.is_available() else "cpu"  # GPU если есть
)

# Callback для оценки (опционально)
eval_env = make_env()
eval_callback = EvalCallback(
    eval_env, best_model_save_path=MODEL_DIR,
    log_path=LOG_DIR, eval_freq=10_000,
    deterministic=True, render=False
)

# Обучение
print("СТАРТ ОБУЧЕНИЯ PPO!")
model.learn(
    total_timesteps=2_000_000,  # 2М шагов — ~6-12 часов на GPU
    callback=eval_callback,     # Сохраняет лучшую модель
    tb_log_name="ppo_music_v1"  # Имя для TensorBoard
)

# Сохранение финальной модели
model.save(MODEL_DIR + "ppo_music_final")
print("Модель сохранена!")

# Закрываем env
vec_env.close()

In [28]:
load_dotenv()

print(os.getenv('QDRANT_URL'))

environment = MusicRecEnv(os.getenv('QDRANT_URL'), os.getenv('QDRANT_API_KEY'))

https://60266f02-4dd1-4ed8-82c0-801dc928f25d.eu-central-1-0.aws.cloud.qdrant.io:6333


In [29]:
environment.reset()

(array([-1.20209590e-01,  3.00348122e-02, -6.40748888e-02, -7.41998926e-02,
        -4.45951037e-02, -1.84151769e-01, -8.50320086e-02,  4.32399176e-02,
        -4.63093892e-02, -4.05351855e-02, -1.59596615e-02,  8.93082246e-02,
        -9.35604200e-02, -1.05724230e-01, -1.62426401e-02,  5.73646016e-02,
         7.73822740e-02,  1.06718205e-01, -3.42969852e-03, -3.48600410e-02,
        -8.16980302e-02,  3.97368567e-03, -9.83161703e-02, -6.61197901e-02,
        -2.00565830e-02,  1.22501433e-01,  8.55014473e-03,  5.91160879e-02,
         9.06834379e-02,  7.70311654e-02, -3.73283401e-02, -8.92357342e-03,
        -2.38256708e-01,  1.02747202e-01,  2.47338470e-02, -5.54777384e-02,
        -8.69660079e-02,  2.72770636e-02,  7.87386075e-02,  8.27625021e-02,
         3.30131575e-02, -6.69487342e-02, -4.56252508e-02, -5.79202399e-02,
        -2.34274231e-02, -1.84014328e-02,  1.43721271e-02, -1.92199320e-01,
         9.07070637e-02,  1.35056255e-02, -8.20029974e-02,  5.88627607e-02,
         4.8

In [30]:
environment.step(np.ndarray([]))

EMBEDDING LENGTH: 128
TIME: 0.019999999552965164
ACTION RAWARD: 0.0
EMBEDDING LENGTH: 128
TIME: 0.10000000149011612
ACTION RAWARD: 0.0
EMBEDDING LENGTH: 128
TIME: 0.38999998569488525
ACTION RAWARD: 0.0
EMBEDDING LENGTH: 128
TIME: 0.9900000095367432
ACTION RAWARD: 0.0
EMBEDDING LENGTH: 128
TIME: 0.1899999976158142
ACTION RAWARD: 0.0


  results = self.client.search(


(array([-5.31583279e-02, -1.35798296e-02,  7.74330199e-02,  3.28604155e-03,
        -9.74206105e-02,  9.96894762e-02, -9.40124020e-02,  2.27793083e-02,
         7.63544515e-02, -5.96551783e-02, -5.23113739e-03, -1.63230281e-02,
        -3.09013091e-02, -9.80885848e-02, -9.41529050e-02,  6.80304095e-02,
         3.38732265e-02,  8.22616220e-02, -4.17902321e-02,  1.75932243e-01,
        -3.40861604e-02,  9.23346914e-03, -8.54981169e-02, -5.58488779e-02,
         6.69327658e-03,  1.18084483e-01, -5.23124039e-02,  6.18495978e-02,
        -1.01200275e-01,  2.06770953e-02,  4.75484319e-02, -2.10516416e-02,
        -2.61180848e-02,  2.12997630e-01,  6.42915741e-02, -1.09069698e-01,
        -6.61834553e-02, -1.32842362e-01,  7.95526281e-02,  3.57404910e-02,
         2.56592184e-01,  3.42040136e-02,  6.18433058e-02,  1.48254007e-01,
         5.52924685e-02, -4.96745110e-02,  4.23478000e-02,  1.15801670e-01,
        -7.85778090e-02,  1.30489632e-01, -7.00717121e-02,  1.29585117e-02,
         1.9