In [None]:
# baseline_with_faiss.py

import os
import time
import pickle
import numpy as np
import pandas as pd
import pyarrow.parquet as pq
from collections import defaultdict
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
from sklearn.metrics import recall_score
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from torch.amp import autocast, GradScaler
from tqdm import tqdm
import faiss

# ==============================
# 📁 Путь к данным и кэш
# ==============================
# Input data location and on-disk cache directory (created at import time).
DATA_DIR = "DATA"
CACHE_DIR = "cache3"
os.makedirs(CACHE_DIR, exist_ok=True)


def _cache_file(filename):
    """Return the path of *filename* inside CACHE_DIR."""
    return os.path.join(CACHE_DIR, filename)


# Cache files
ITEM_EMBEDDINGS_FILE = _cache_file("item_embeddings.pkl")
USER_EMBEDDINGS_FILE = _cache_file("user_embeddings.pkl")
PREDICTIONS_FILE = _cache_file("predictions.csv")
BEST_MODEL_PATH = _cache_file("best_model.pth")
SAMPLED_DATA_CACHE = _cache_file("sampled_grouped.pkl")
ENCODER_CACHE = _cache_file("label_encoders.pkl")

# ==============================
# 🔧 Hyperparameters
# ==============================
IS_DEBUG = True              # when True, __main__ trains on a subset via sample_data
DEBUG_SAMPLE_PERCENT = 0.1   # default fraction passed to sample_data
MAX_EPOCHS = 5               # upper bound on training epochs
PATIENCE = 3                 # epochs without validation improvement before stopping
EMBED_DIM = 256              # embedding / tower output size of TwoTower
BATCH_SIZE_INFERENCE = 8192  # users per batch during recommendation
BATCH_SIZE_ITEMS = 8192      # items per batch when computing item vectors

# ==============================
# 🕒 Логирование с временем
# ==============================
import builtins


def tprint(*args, **kwargs):
    """Forward to the builtin print, prefixing a local-time timestamp.

    All positional and keyword arguments are passed through unchanged, so
    ``sep``, ``end``, ``file`` and ``flush`` behave exactly as for ``print``.
    """
    stamp = time.strftime("%Y-%m-%d %H:%M:%S")
    builtins.print("[" + stamp + "]", *args, **kwargs)


# Shadow the builtin for the rest of the module so every print() is timestamped.
print = tprint

# ==============================
# 📦 Включаем отладку CUDA
# ==============================
# Synchronous kernel launches and autograd anomaly detection make CUDA errors
# point at the operation that caused them, but both slow training
# considerably. They are therefore enabled only in debug runs, not on every run.
# CUDA_LAUNCH_BLOCKING only takes effect if set before the first CUDA call.
if IS_DEBUG:
    os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
    torch.autograd.set_detect_anomaly(True)

# ==============================
# 🧠 Загрузка данных
# ==============================
def load_data():
    """Read the four parquet tables from DATA_DIR into pandas DataFrames.

    Returns:
        dict with keys 'clickstream', 'cat_features', 'events' and
        'test_users', each mapped to the DataFrame read from
        ``DATA_DIR/<key>.pq``.
    """
    print("Загрузка данных...")
    table_names = ('clickstream', 'cat_features', 'events', 'test_users')
    return {
        name: pq.read_table(os.path.join(DATA_DIR, f"{name}.pq")).to_pandas()
        for name in table_names
    }

# ==============================
# 🔍 Формирование пар (user, node)
# ==============================
def prepare_pairs(clickstream, events):
    """Aggregate the clickstream into one row per (cookie, node) pair.

    An event code counts as a contact if any row of ``events`` marks it with
    ``is_contact == 1``.

    Side effect: adds an integer ``is_contact`` column (0/1) to
    ``clickstream`` in place.

    Returns:
        DataFrame with columns cookie, node, is_contact (number of contact
        events for the pair) and target (1 if that number is positive, else 0).
    """
    print("Формирование пар (user, node)...")
    contact_mask = events['is_contact'] == 1
    contact_codes = events.loc[contact_mask, 'event'].unique()
    clickstream['is_contact'] = clickstream['event'].isin(contact_codes).astype(int)

    pairs = (
        clickstream
        .groupby(['cookie', 'node'], as_index=False)['is_contact']
        .sum()
    )
    pairs['target'] = pairs['is_contact'].gt(0).astype(int)
    return pairs

# ==============================
# 🧪 Отладочный режим: выборка N%
# ==============================
def sample_data(grouped, fraction=DEBUG_SAMPLE_PERCENT):
    """Keep all (cookie, node) rows of a uniformly sampled fraction of users.

    Users (distinct ``cookie`` values) are sampled directly. The earlier form
    sampled *rows* and kept their users, which over-represents heavy users
    and keeps an unpredictable share of the data instead of ``fraction``.

    Args:
        grouped: DataFrame with a 'cookie' column.
        fraction: share of distinct users to keep (pandas ``frac`` semantics).

    Returns:
        The subset of ``grouped`` whose cookie was sampled (fixed seed 42).
    """
    print(f"Отладка: оставляется {int(fraction * 100)}% пользователей...")
    users = pd.Series(grouped['cookie'].unique())
    users_sampled = users.sample(frac=fraction, random_state=42)
    return grouped[grouped['cookie'].isin(users_sampled)]

# ==============================
# 🔢 Кодирование пользователей и товаров
# ==============================
def encode_user_item(grouped):
    """Label-encode cookies and nodes into dense integer ids.

    Missing cookie/node values are replaced by the literal 'unknown' before
    encoding. The work is done on a copy: the input is often a filtered slice
    (see sample_data), and writing columns into it mutates the caller's data
    and triggers pandas' SettingWithCopyWarning.

    Returns:
        (encoded, num_users, num_items, le_user, le_item), where ``encoded`` is
        a copy of ``grouped`` with filled cookie/node columns plus 'user_id'
        and 'item_id', ``num_users``/``num_items`` are the numbers of classes,
        and ``le_user``/``le_item`` are the fitted LabelEncoders.
    """
    print("Кодирование пользователей и товаров...")
    encoded = grouped.copy()
    encoded['cookie'] = encoded['cookie'].fillna('unknown')
    encoded['node'] = encoded['node'].fillna('unknown')

    le_user = LabelEncoder()
    le_item = LabelEncoder()
    encoded['user_id'] = le_user.fit_transform(encoded['cookie'])
    encoded['item_id'] = le_item.fit_transform(encoded['node'])

    num_users = len(le_user.classes_)
    num_items = len(le_item.classes_)
    return encoded, num_users, num_items, le_user, le_item

# ==============================
# 🧮 Two-Tower модель
# ==============================
class TwoTower(nn.Module):
    """Two-tower recommender scoring (user, item) pairs by a dot product.

    Each side embeds its ids and passes them through its own
    Linear -> ReLU -> LayerNorm -> Linear MLP. ``forward`` returns the
    unnormalised dot product of the two tower outputs (used as a logit with
    BCEWithLogitsLoss during training).

    Both embedding tables have two rows more than ``num_users`` /
    ``num_items``, so ids ``num_* `` and ``num_* + 1`` are valid indices
    (compute_item_embeddings maps unknown nodes to ``num_items + 1``).
    """

    def __init__(self, num_users, num_items, embed_dim=256):
        super().__init__()
        # Creation order (embeddings, user tower, item tower) is kept so that
        # parameter names, state_dict layout and seeded initial weights match.
        self.user_emb = nn.Embedding(num_users + 2, embed_dim)
        self.item_emb = nn.Embedding(num_items + 2, embed_dim)
        self.user_tower = self._make_tower(embed_dim)
        self.item_tower = self._make_tower(embed_dim)

    @staticmethod
    def _make_tower(embed_dim, hidden_dim=512):
        """Build an embed_dim -> hidden_dim -> embed_dim MLP tower."""
        return nn.Sequential(
            nn.Linear(embed_dim, hidden_dim),
            nn.ReLU(),
            nn.LayerNorm(hidden_dim),
            nn.Linear(hidden_dim, embed_dim),
        )

    def get_user_vector(self, users):
        """Map user ids to user-tower output vectors."""
        return self.user_tower(self.user_emb(users))

    def get_item_vector(self, items):
        """Map item ids to item-tower output vectors."""
        return self.item_tower(self.item_emb(items))

    def forward(self, users, items):
        user_vec = self.get_user_vector(users)
        item_vec = self.get_item_vector(items)
        return (user_vec * item_vec).sum(dim=-1)

# ==============================
# 🏋️‍♂️ Обучение модели с валидацией
# ==============================
def _make_pair_loader(frame, shuffle, batch_size=2048):
    """Wrap the (user_id, item_id) -> target rows of *frame* in a DataLoader."""
    features = torch.tensor(frame[['user_id', 'item_id']].values, dtype=torch.long)
    targets = torch.tensor(frame['target'].values, dtype=torch.float)
    return DataLoader(TensorDataset(features, targets), batch_size=batch_size, shuffle=shuffle)


def _validation_recall(model, loader, device):
    """Return binary recall of ``sigmoid(logit) > 0.5`` predictions over *loader*."""
    model.eval()
    predicted, actual = [], []
    with torch.no_grad():
        for x_batch, y_batch in loader:
            logits = model(x_batch[:, 0].to(device), x_batch[:, 1].to(device))
            predicted.append((torch.sigmoid(logits) > 0.5).cpu())
            actual.append(y_batch)
    if not predicted:
        return 0.0
    y_pred = torch.cat(predicted).numpy().astype(int)
    y_true = torch.cat(actual).numpy().astype(int)
    return recall_score(y_true, y_pred, average='binary', zero_division=0)


def train_model_with_validation(grouped, num_users, num_items):
    """Train a TwoTower model with early stopping, or load the cached one.

    If BEST_MODEL_PATH exists and loads cleanly, those weights are returned
    without training. Otherwise the rows of ``grouped`` are split 80/20 into
    train/validation (seed 42). The model is trained with BCE-with-logits
    under autocast for up to MAX_EPOCHS epochs and stops early after PATIENCE
    epochs without an improvement in validation recall (threshold 0.5).

    Args:
        grouped: DataFrame with integer 'user_id' / 'item_id' columns and a
            0/1 'target' column.
        num_users: number of encoded users; every user_id must be < num_users.
        num_items: number of encoded items; every item_id must be < num_items.

    Returns:
        (model, device). The model holds the best-validation checkpoint when
        any epoch improved on a recall of 0, otherwise the last-epoch weights.

    Raises:
        ValueError: if an id is outside the encoder range.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Используется устройство: {device}")

    # Reuse a previously trained model when one is cached.
    if os.path.exists(BEST_MODEL_PATH):
        print("Загрузка модели из кэша...")
        try:
            model = TwoTower(num_users, num_items, EMBED_DIM).to(device)
            model.load_state_dict(torch.load(BEST_MODEL_PATH, map_location=device))
            return model, device
        except Exception as e:
            print(f"Ошибка загрузки модели: {e}. Обучение с нуля.")

    # Check the id range once up front; a per-batch `.max().item()` forces a
    # device synchronisation on every training step.
    if len(grouped) and (grouped['user_id'].max() >= num_users
                         or grouped['item_id'].max() >= num_items):
        raise ValueError("⚠️ Индексы выходят за диапазон!")

    # Hold out part of the pairs: validating on the training rows would only
    # measure memorisation and make early stopping meaningless.
    if len(grouped) > 1:
        train_data, val_data = train_test_split(grouped, test_size=0.2, random_state=42)
    else:
        train_data = val_data = grouped
    train_loader = _make_pair_loader(train_data, shuffle=True)
    val_loader = _make_pair_loader(val_data, shuffle=False)

    model = TwoTower(num_users, num_items, EMBED_DIM).to(device)
    optimizer = optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.01)
    criterion = nn.BCEWithLogitsLoss()
    scaler = GradScaler()

    best_metric = 0.0
    saved_best = False
    patience_counter = 0

    for epoch in range(MAX_EPOCHS):
        model.train()
        total_loss = 0.0
        with tqdm(train_loader, desc=f"Epoch {epoch+1}/{MAX_EPOCHS}") as pbar:
            for x_batch, y_batch in pbar:
                users = x_batch[:, 0].to(device)
                items = x_batch[:, 1].to(device)
                y_batch = y_batch.to(device)

                with autocast(device_type=device.type):
                    logits = model(users, items)
                    loss = criterion(logits, y_batch)

                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad(set_to_none=True)

                batch_loss = loss.item()
                total_loss += batch_loss
                pbar.set_postfix(loss=batch_loss)

        val_recall = _validation_recall(model, val_loader, device)
        mean_loss = total_loss / max(len(train_loader), 1)
        print(f"Epoch {epoch+1} | Loss: {mean_loss:.4f} | Val Recall (p>0.5): {val_recall:.4f}")

        # Early stopping on validation recall.
        if val_recall > best_metric:
            best_metric = val_recall
            torch.save(model.state_dict(), BEST_MODEL_PATH)
            saved_best = True
            patience_counter = 0
        else:
            patience_counter += 1

        if patience_counter >= PATIENCE:
            print("Early stopping triggered")
            break

    # Hand back the best checkpoint rather than whatever the last epoch left.
    if saved_best:
        model.load_state_dict(torch.load(BEST_MODEL_PATH, map_location=device))
    return model, device

# ==============================
# 👀 build_seen_nodes — просмотренные товары
# ==============================
def build_seen_nodes(clickstream):
    """Map each cookie to the set of nodes it appears with in the clickstream.

    Rows where either 'cookie' or 'node' is missing are skipped. Nodes are
    stored as strings so they compare equal to the string catalogue ids used
    at inference time. Cookies are kept exactly as they appear in the column.

    The work is column-wise rather than via ``iterrows``. ``iterrows`` is very
    slow on a full clickstream. It also packs each row into a single Series
    with a common dtype, which turns integer node ids into floats
    (``str`` -> '123.0') whenever the frame also has a float column.

    Args:
        clickstream: DataFrame with at least 'cookie' and 'node' columns.

    Returns:
        defaultdict(set) mapping cookie -> set of str(node).
    """
    pairs = clickstream[['cookie', 'node']].dropna()
    seen = defaultdict(set)
    for cookie, nodes in pairs.groupby('cookie', sort=False)['node']:
        seen[cookie].update(nodes.astype(str))
    return seen

# ==============================
# 🧱 Ускоренная генерация item эмбеддингов
# ==============================
def compute_item_embeddings(model, device, all_nodes, le_item, num_items):
    """Compute item-tower vectors for every catalogue node and cache them.

    Row k of the returned matrix is the vector for ``all_nodes[k]``. The rows
    stay aligned with ``all_nodes`` so that the matrix pickled to
    ITEM_EMBEDDINGS_FILE can be reused later with ``all_nodes`` alone as the
    row -> node mapping.

    Nodes unknown to ``le_item`` are mapped to the spare embedding row
    ``num_items + 1``. Training never uses that id (train_model_with_validation
    rejects ids >= num_items), so all such nodes share one untrained vector.

    Args:
        model: model exposing ``get_item_vector(ids)``.
        device: torch device the model lives on.
        all_nodes: sequence of catalogue node ids, matched as strings.
        le_item: fitted LabelEncoder for nodes.
        num_items: number of classes in ``le_item``.

    Returns:
        (item_embeddings, nodes): numpy array with one row per node, and the
        list of nodes in row order.
    """
    print("Генерация item эмбеддингов по батчам...")

    # A dict keeps the lookup O(len(all_nodes)). Keys are str(class) so
    # integer-encoded nodes still match string catalogue ids, where
    # le_item.transform(["<str>"]) would fail.
    node_to_id = {str(cls): idx for idx, cls in enumerate(le_item.classes_)}
    unknown_id = num_items + 1
    item_ids = torch.tensor(
        [node_to_id.get(str(node), unknown_id) for node in all_nodes],
        dtype=torch.long,
        device=device,
    )

    total = len(item_ids)
    chunks = []
    with torch.no_grad():
        for start in range(0, total, BATCH_SIZE_ITEMS):
            batch = item_ids[start:start + BATCH_SIZE_ITEMS]
            chunks.append(model.get_item_vector(batch).float().cpu().numpy())
            print(f"Обработано {min(start + BATCH_SIZE_ITEMS, total)} / {total} item'ов")
        if not chunks:
            # Empty catalogue: produce a (0, dim) matrix instead of letting
            # np.vstack fail on an empty list.
            chunks.append(model.get_item_vector(item_ids).float().cpu().numpy())

    item_embeddings = np.vstack(chunks)
    with open(ITEM_EMBEDDINGS_FILE, 'wb') as f:
        pickle.dump(item_embeddings, f)
    print("Item эмбеддинги сохранены")
    return item_embeddings, list(all_nodes)

# ==============================
# 🧰 Ускоренный инференс через FAISS
# ==============================
def _lookup_seen(seen_dict, raw_cookie):
    """Return the already-seen node strings for *raw_cookie* (empty set if none).

    The dict may be keyed by the cookie as read from the data or by its
    string form, so both are tried. ``.get`` is used so that a defaultdict is
    not grown by lookups.
    """
    seen = seen_dict.get(raw_cookie)
    if seen is None:
        seen = seen_dict.get(str(raw_cookie))
    return seen if seen is not None else set()


def recommend_for_users_resumable(model, device, test_users, all_nodes, le_user, seen_dict,
                                  top_k=40, max_overfetch=1000):
    """Recommend up to ``top_k`` unseen nodes per test user via FAISS inner product.

    Item vectors are loaded from ITEM_EMBEDDINGS_FILE when its row count
    matches ``all_nodes``, otherwise recomputed. Progress is checkpointed to
    PREDICTIONS_FILE. Users whose cookie already appears there are skipped,
    which makes an interrupted run resumable.

    Args:
        model: model exposing ``get_user_vector`` (and ``get_item_vector`` for
            recomputation).
        device: torch device of the model.
        test_users: DataFrame with a 'cookie' column. Cookies unknown to
            ``le_user`` are skipped; repeated cookies are processed once.
        all_nodes: catalogue node ids. Row k of the item matrix is all_nodes[k].
        le_user: fitted LabelEncoder for cookies.
        seen_dict: mapping cookie -> set of str(node) to exclude.
        top_k: maximum recommendations per user.
        max_overfetch: cap on extra candidates fetched per query to compensate
            for excluded seen nodes. Users with more seen nodes than this may
            get fewer than ``top_k`` results.

    Returns:
        DataFrame with columns node, cookie (both str) and score (inner product).
    """
    model.eval()
    columns = ['node', 'cookie', 'score']

    # 1. Item vectors, one row per entry of all_nodes.
    item_embeddings = None
    if os.path.exists(ITEM_EMBEDDINGS_FILE):
        print("Загрузка item эмбеддингов из кэша...")
        with open(ITEM_EMBEDDINGS_FILE, 'rb') as f:
            item_embeddings = pickle.load(f)
        if len(item_embeddings) != len(all_nodes):
            # A stale cache would silently map FAISS rows to the wrong nodes.
            print(f"Кэш item эмбеддингов не совпадает с каталогом "
                  f"({len(item_embeddings)} != {len(all_nodes)}), пересчёт...")
            item_embeddings = None
    if item_embeddings is None:
        # NOTE(review): le_item and num_items are not parameters; they are read
        # from the module globals assigned in the __main__ section.
        item_embeddings, _ = compute_item_embeddings(model, device, all_nodes, le_item, num_items)

    # 2. Exact inner-product index; FAISS requires a C-contiguous float32 matrix.
    item_matrix = np.ascontiguousarray(item_embeddings, dtype=np.float32)
    index = faiss.IndexFlatIP(item_matrix.shape[1])
    index.add(item_matrix)
    n_indexed = index.ntotal

    # 3. Encode test users with a dict (O(1) per cookie, no classes_ scan).
    user_to_id = {cls: idx for idx, cls in enumerate(le_user.classes_)}
    candidates = []  # (str cookie, raw cookie, encoded id)
    queued = set()
    for raw_cookie in test_users['cookie']:
        encoded_id = user_to_id.get(raw_cookie)
        cookie = str(raw_cookie)
        if encoded_id is None or cookie in queued:
            continue
        queued.add(cookie)
        candidates.append((cookie, raw_cookie, int(encoded_id)))

    # 4. Resume from a previous partial run.
    processed_cookies = set()
    predictions = []
    if os.path.exists(PREDICTIONS_FILE):
        df_prev = pd.read_csv(PREDICTIONS_FILE, dtype={'node': str, 'cookie': str})
        predictions = df_prev[columns].values.tolist()
        print(f"Загружено {len(df_prev)} записей из предыдущего запуска.")
        processed_cookies.update(df_prev['cookie'].unique())

    remaining = [c for c in candidates if c[0] not in processed_cookies]
    print(f"Осталось обработать: {len(remaining)} пользователей")

    start = time.time()
    for batch_start in range(0, len(remaining), BATCH_SIZE_INFERENCE):
        batch = remaining[batch_start:batch_start + BATCH_SIZE_INFERENCE]
        user_ids = torch.tensor([enc for _, _, enc in batch], dtype=torch.long, device=device)
        with torch.no_grad():
            user_vectors = model.get_user_vector(user_ids).float().cpu().numpy()
        seen_sets = [_lookup_seen(seen_dict, raw) for _, raw, _ in batch]

        # 5. Search. Over-fetch so that dropping seen nodes can still leave top_k,
        # and never ask for more rows than the index holds (FAISS pads with -1).
        extra = min(max(len(s) for s in seen_sets), max_overfetch)
        k_search = min(n_indexed, top_k + extra)
        if k_search > 0:
            scores, positions = index.search(np.ascontiguousarray(user_vectors, dtype=np.float32), k_search)
        else:
            scores = positions = np.empty((len(batch), 0))

        for row, ((cookie, _, _), seen) in enumerate(zip(batch, seen_sets)):
            kept = 0
            for j in range(k_search):
                pos = positions[row, j]
                if pos < 0 or kept >= top_k:
                    break
                node = str(all_nodes[pos])
                if node in seen:
                    continue
                predictions.append([node, cookie, float(scores[row, j])])
                kept += 1
            processed_cookies.add(cookie)

        if (batch_start // BATCH_SIZE_INFERENCE) % 10 == 0:
            pd.DataFrame(predictions, columns=columns).to_csv(PREDICTIONS_FILE, index=False)
            elapsed = time.time() - start
            print(f"[{batch_start + len(batch):>6}/{len(remaining)}] сохранено... [Время: {elapsed:.2f} сек.]")

    if remaining:
        # Persist the tail that the every-10-batches checkpoint may have missed.
        pd.DataFrame(predictions, columns=columns).to_csv(PREDICTIONS_FILE, index=False)

    print("Инференс завершён.")
    return pd.DataFrame(predictions, columns=columns)

# ==============================
# 💾 Сохранение результата
# ==============================
def save_submission(df, path="submission.csv"):
    """Write only the 'node' and 'cookie' columns of *df* to a CSV (no index)."""
    submission_columns = ['node', 'cookie']
    df.to_csv(path, columns=submission_columns, index=False)
    print(f"Результат сохранён в {path}")

# ==============================
# 🚀 MAIN пайплайн
# ==============================
if __name__ == "__main__":
    try:
        t_start = time.time()
        data = None  # raw parquet tables; loaded at most once per run

        # Step 1: restore encoded pairs and encoders from cache, or build them.
        # Both files are required; with only one present, rebuild everything.
        # (These pickles are written by this script itself, so loading is trusted.)
        if os.path.exists(SAMPLED_DATA_CACHE) and os.path.exists(ENCODER_CACHE):
            print("Загрузка подготовленных данных из кэша...")
            with open(SAMPLED_DATA_CACHE, 'rb') as f:
                grouped = pickle.load(f)
            with open(ENCODER_CACHE, 'rb') as f:
                le_user, le_item = pickle.load(f)
            num_users = len(le_user.classes_)
            num_items = len(le_item.classes_)
        else:
            print("Начало полной подготовки данных...")
            data = load_data()
            grouped = prepare_pairs(data['clickstream'], data['events'])
            if IS_DEBUG:
                grouped = sample_data(grouped)
            grouped, num_users, num_items, le_user, le_item = encode_user_item(grouped)
            with open(SAMPLED_DATA_CACHE, 'wb') as f:
                pickle.dump(grouped, f)
            with open(ENCODER_CACHE, 'wb') as f:
                pickle.dump((le_user, le_item), f)
            print("Данные закодированы и сохранены в кэш")

        print(f"num_users: {num_users}, num_items: {num_items}")
        print(f"Максимальный user_id: {grouped['user_id'].max()}, Максимальный item_id: {grouped['item_id'].max()}")

        # Step 2: train (or load) the model.
        model, device = train_model_with_validation(grouped, num_users, num_items)

        # Step 3: FAISS inference. Reuse the tables from step 1 if already loaded.
        # le_item / num_items stay module globals: recommend_for_users_resumable
        # reads them when item embeddings must be recomputed.
        if data is None:
            data = load_data()
        seen_dict = build_seen_nodes(data['clickstream'])

        submission = recommend_for_users_resumable(
            model=model,
            device=device,
            test_users=data['test_users'],
            all_nodes=data['cat_features']['node'].unique().astype(str),
            le_user=le_user,
            seen_dict=seen_dict
        )
        save_submission(submission)

        t_end = time.time()
        print(f"Общее время выполнения: {(t_end - t_start)/60:.2f} мин")

    except Exception as e:
        print(f"⚠️ Критическая ошибка в main: {e}")
        raise

[2025-05-04 04:01:30] Загрузка подготовленных данных из кэша...
[2025-05-04 04:01:31] num_users: 116846, num_items: 408278
[2025-05-04 04:01:31] Максимальный user_id: 116845, Максимальный item_id: 408277
[2025-05-04 04:01:31] Используется устройство: cuda
[2025-05-04 04:01:31] Загрузка модели из кэша...
[2025-05-04 04:01:32] Загрузка данных...
