In [2]:
# -*- coding: utf-8 -*-
"""baseline_qwen.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/1dvfywmCg9ng6d4tl9QvsRQuUhaiV489o
"""

import os
import time
import json
import pickle
import numpy as np
import pandas as pd
import pyarrow.parquet as pq
from collections import defaultdict
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
from sklearn.metrics import recall_score
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from torch.amp import autocast, GradScaler
from tqdm import tqdm  # ‚ùå –ù–µ –∏—Å–ø–æ–ª—å–∑—É–µ–º auto –¥–ª—è –∏–∑–±–µ–∂–∞–Ω–∏—è –≤–∏–¥–∂–µ—Ç–æ–≤

#from google.colab import drive
#drive.mount('/content/drive')

#DATA_DIR = "/content/drive/MyDrive/AvitoTechML25/Data"
#CACHE_DIR = "/content/drive/MyDrive/AvitoTechML25/cache"
DATA_DIR = "Data"
CACHE_DIR = "cache2"

os.makedirs(CACHE_DIR, exist_ok=True)

PREDICTIONS_FILE = os.path.join(CACHE_DIR, "predictions.csv")
ITEM_EMBEDDINGS_FILE = os.path.join(CACHE_DIR, "item_embeddings.pkl")
BEST_MODEL_PATH = os.path.join(CACHE_DIR, "best_model.pth")

# ==============================
# üîß –û—Å–Ω–æ–≤–Ω—ã–µ –ø–∞—Ä–∞–º–µ—Ç—Ä—ã
# ==============================
IS_DEBUG = True
DEBUG_SAMPLE_PERCENT = 0.1
PATIENCE = 3

# –í–∫–ª—é—á–∞–µ–º —Å–∏–Ω—Ö—Ä–æ–Ω–Ω—É—é —Ä–∞–±–æ—Ç—É —Å GPU
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
torch.autograd.set_detect_anomaly(True)

# ==============================
# üïí –õ–æ–≥–∏—Ä–æ–≤–∞–Ω–∏–µ —Å –≤—Ä–µ–º–µ–Ω–µ–º
# ==============================
import builtins
def tprint(*args, **kwargs):
    current_time = time.strftime("%Y-%m-%d %H:%M:%S")
    builtins.print(f"[{current_time}]", *args, **kwargs)
print = tprint

# ==============================
# üß† 1. –ó–∞–≥—Ä—É–∑–∫–∞ –¥–∞–Ω–Ω—ã—Ö
# ==============================
import pyarrow.parquet as pq
import os

def load_data(fraction=0.1):
    print("–ó–∞–≥—Ä—É–∑–∫–∞ 10% –¥–∞–Ω–Ω—ã—Ö...")

    def load_fraction(path, fraction):
        table = pq.read_table(path)
        num_rows = int(len(table) * fraction)
        return table.slice(0, num_rows).to_pandas()

    data = {
        'clickstream': load_fraction(os.path.join(DATA_DIR, "clickstream.pq"), fraction),
        'cat_features': load_fraction(os.path.join(DATA_DIR, "cat_features.pq"), fraction),
        'events': load_fraction(os.path.join(DATA_DIR, "events.pq"), fraction),
        'test_users': load_fraction(os.path.join(DATA_DIR, "test_users.pq"), fraction),
    }

    return data

t_start = time.time()
data = load_data()

# ==============================
# üîç 2. –ü–æ–¥–≥–æ—Ç–æ–≤–∫–∞ –ø–∞—Ä (user, node)
# ==============================
def prepare_pairs(clickstream, events):
    print("–§–æ—Ä–º–∏—Ä–æ–≤–∞–Ω–∏–µ –ø–∞—Ä (user, node)...")
    contact_events = events[events['is_contact'] == 1]['event'].unique()
    clickstream['is_contact'] = clickstream['event'].isin(contact_events).astype(int)
    grouped = clickstream.groupby(['cookie', 'node'], as_index=False)['is_contact'].sum()
    grouped['target'] = (grouped['is_contact'] > 0).astype(int)
    return grouped

grouped = prepare_pairs(data['clickstream'], data['events'])

# ==============================
# üß™ –û—Ç–ª–∞–¥–æ—á–Ω—ã–π —Ä–µ–∂–∏–º: –≤—ã–±–æ—Ä–∫–∞ –∏–∑ N%
# ==============================
def sample_data(grouped, fraction=DEBUG_SAMPLE_PERCENT):
    print(f"–û—Ç–ª–∞–¥–∫–∞: –æ—Å—Ç–∞–≤–ª—è–µ—Ç—Å—è {int(fraction * 100)}% –¥–∞–Ω–Ω—ã—Ö...")
    users_sampled = grouped.sample(frac=fraction, random_state=42)['cookie'].unique()
    return grouped[grouped['cookie'].isin(users_sampled)]

# ==============================
# üî¢ –ö–æ–¥–∏—Ä–æ–≤–∞–Ω–∏–µ –ø–æ–ª—å–∑–æ–≤–∞—Ç–µ–ª–µ–π –∏ —Ç–æ–≤–∞—Ä–æ–≤
# ==============================
def encode_user_item(grouped):
    print("–ö–æ–¥–∏—Ä–æ–≤–∞–Ω–∏–µ –ø–æ–ª—å–∑–æ–≤–∞—Ç–µ–ª–µ–π –∏ —Ç–æ–≤–∞—Ä–æ–≤...")
    le_user = LabelEncoder()
    le_item = LabelEncoder()

    # –û–±—Ä–∞–±–æ—Ç–∫–∞ –≤–æ–∑–º–æ–∂–Ω—ã—Ö NaN
    grouped['cookie'] = grouped['cookie'].fillna('unknown')
    grouped['node'] = grouped['node'].fillna('unknown')

    grouped['user_id'] = le_user.fit_transform(grouped['cookie'])
    grouped['item_id'] = le_item.fit_transform(grouped['node'])

    num_users = grouped['user_id'].nunique()
    num_items = grouped['item_id'].nunique()

    return grouped, num_users, num_items, le_user, le_item

if IS_DEBUG:
  grouped = sample_data(grouped)

grouped, num_users, num_items, le_user, le_item = encode_user_item(grouped)

print(f"num_users: {num_users}, num_items: {num_items}")
print(f"–ú–∞–∫—Å–∏–º–∞–ª—å–Ω—ã–π user_id: {grouped['user_id'].max()}, –ú–∞–∫—Å–∏–º–∞–ª—å–Ω—ã–π item_id: {grouped['item_id'].max()}")

# ==============================
# üß™ –†–∞–∑–¥–µ–ª–µ–Ω–∏–µ –Ω–∞ train/val
# ==============================
def create_train_val_split(grouped, val_size=0.2, random_state=42):
    print("–†–∞–∑–¥–µ–ª–µ–Ω–∏–µ –Ω–∞ train/val...")
    users = grouped['user_id'].unique()
    train_users, val_users = train_test_split(users, test_size=val_size, random_state=random_state)
    train_mask = grouped['user_id'].isin(train_users)
    val_mask = grouped['user_id'].isin(val_users)
    return grouped[train_mask], grouped[val_mask]

# ==============================
# üßÆ Two-Tower –º–æ–¥–µ–ª—å
# ==============================
class TwoTower(nn.Module):
    def __init__(self, num_users, num_items, embed_dim=256):
        super().__init__()
        self.user_emb = nn.Embedding(num_users + 2, embed_dim)
        self.item_emb = nn.Embedding(num_items + 2, embed_dim)

        self.user_tower = nn.Sequential(
            nn.Linear(embed_dim, 512),
            nn.ReLU(),
            nn.LayerNorm(512),
            nn.Linear(512, embed_dim)
        )

        self.item_tower = nn.Sequential(
            nn.Linear(embed_dim, 512),
            nn.ReLU(),
            nn.LayerNorm(512),
            nn.Linear(512, embed_dim)
        )

    def forward(self, users, items):
        u = self.user_tower(self.user_emb(users))
        i = self.item_tower(self.item_emb(items))
        return torch.sum(u * i, dim=-1)

    def get_user_vector(self, users):
        return self.user_tower(self.user_emb(users))

    def get_item_vector(self, items):
        return self.item_tower(self.item_emb(items))

# ==============================
# üèãÔ∏è‚Äç‚ôÇÔ∏è 3. –û–±—É—á–µ–Ω–∏–µ –º–æ–¥–µ–ª–∏ —Å –≤–∞–ª–∏–¥–∞—Ü–∏–µ–π
# ==============================
def train_model_with_validation(grouped, num_users, num_items):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"–ò—Å–ø–æ–ª—å–∑—É–µ—Ç—Å—è —É—Å—Ç—Ä–æ–π—Å—Ç–≤–æ: {device}")

    model = TwoTower(num_users, num_items).to(device)
    optimizer = optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.01)
    criterion = nn.BCEWithLogitsLoss()

    # –ï—Å–ª–∏ –º–æ–¥–µ–ª—å —É–∂–µ –µ—Å—Ç—å ‚Äî –∑–∞–≥—Ä—É–∂–∞–µ–º
    if os.path.exists(BEST_MODEL_PATH):
        print("–ó–∞–≥—Ä—É–∑–∫–∞ –æ–±—É—á–µ–Ω–Ω–æ–π –º–æ–¥–µ–ª–∏ –∏–∑ –∫—ç—à–∞...")
        model.load_state_dict(torch.load(BEST_MODEL_PATH))
        return model, device

    # –ü–æ–¥–≥–æ—Ç–æ–≤–∫–∞ –¥–∞–Ω–Ω—ã—Ö
    train_data, val_data = create_train_val_split(grouped)
    X_train = torch.tensor(train_data[['user_id', 'item_id']].values, dtype=torch.long)
    y_train = torch.tensor(train_data['target'].values, dtype=torch.float)
    X_val = torch.tensor(val_data[['user_id', 'item_id']].values, dtype=torch.long)
    y_val = torch.tensor(val_data['target'].values, dtype=torch.float)

    train_dataset = TensorDataset(X_train, y_train)
    val_dataset = TensorDataset(X_val, y_val)

    train_loader = DataLoader(train_dataset, batch_size=2048, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=2048, shuffle=False)

    scaler = GradScaler()

    best_loss = float('inf')
    patience_counter = 0

    for epoch in range(5):
        model.train()
        total_loss = 0
        with tqdm(train_loader, desc=f"Epoch {epoch+1}/10") as pbar:
            for x_batch, y_batch in pbar:
                users = x_batch[:, 0].to(device)
                items = x_batch[:, 1].to(device)
                y_batch = y_batch.to(device)

                if users.max().item() >= num_users or items.max().item() >= num_items:
                    raise ValueError("‚ö†Ô∏è –ù–∞–π–¥–µ–Ω—ã user/item_id –≤–Ω–µ –¥–∏–∞–ø–∞–∑–æ–Ω–∞!")

                with autocast(device_type=device.type):
                    logits = model(users, items)
                    loss = criterion(logits, y_batch)

                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad(set_to_none=True)

                total_loss += loss.item()
                pbar.set_postfix(loss=loss.item())

        avg_loss = total_loss / len(train_loader)
        print(f"Epoch {epoch+1} | Loss: {avg_loss:.4f}")

        # --- –í–∞–ª–∏–¥–∞—Ü–∏—è ---
        model.eval()
        all_preds = []
        all_true = []

        with torch.no_grad():
            for x_batch, y_batch in val_loader:
                users = x_batch[:, 0].to(device)
                items = x_batch[:, 1].to(device)
                scores = model(users, items)
                preds = (torch.sigmoid(scores) > 0.5).float()
                all_preds.extend(preds.cpu())
                all_true.extend(y_batch.cpu())

        val_recall = recall_score(all_true, all_preds, average='binary')
        print(f"Epoch {epoch+1} | Val Recall@40: {val_recall:.4f}")

        # --- Early Stopping ---
        if avg_loss < best_loss:
            best_loss = avg_loss
            torch.save(model.state_dict(), BEST_MODEL_PATH)
            patience_counter = 0
        else:
            patience_counter += 1

        if patience_counter >= PATIENCE:
            print("Early stopping triggered")
            break

    print("–ú–æ–¥–µ–ª—å –æ–±—É—á–µ–Ω–∞ –∏ —Å–æ—Ö—Ä–∞–Ω–µ–Ω–∞.")
    return model, device

model, device = train_model_with_validation(grouped, num_users, num_items)

# ==============================
# üëÄ build_seen_nodes ‚Äî —Å–±–æ—Ä –∏–Ω—Ñ–æ—Ä–º–∞—Ü–∏–∏ –æ –ø—Ä–æ—Å–º–æ—Ç—Ä–µ–Ω–Ω—ã—Ö –Ω–æ–¥–∞—Ö
# ==============================
def build_seen_nodes(clickstream):
    seen = defaultdict(set)
    for _, row in clickstream.iterrows():
        cookie = row['cookie']
        node = row['node']
        if pd.notna(cookie) and pd.notna(node):
            seen[cookie].add(str(node))  # –ü—Ä–∏–≤–µ–¥–µ–Ω–∏–µ –∫ —Å—Ç—Ä–æ–∫–µ –¥–ª—è —Å—Ç–∞–±–∏–ª—å–Ω–æ—Å—Ç–∏
    return seen
seen_dict = build_seen_nodes(data['clickstream'])

# ==============================
# üîÑ –†–µ–∫–æ–º–µ–Ω–¥–∞—Ü–∏—è —Å –∫—ç—à–∏—Ä–æ–≤–∞–Ω–∏–µ–º
# ==============================
def recommend_for_users_resumable(model, device, test_users, all_nodes, le_user, seen_dict, top_k=40):
    model.eval()
    try:
        all_nodes = [str(node) for node in all_nodes]
        item_ids = torch.arange(len(all_nodes), device=device)

        # --- Item —ç–º–±–µ–¥–¥–∏–Ω–≥–∏ ---
        if os.path.exists(ITEM_EMBEDDINGS_FILE):
            print("–ó–∞–≥—Ä—É–∑–∫–∞ item —ç–º–±–µ–¥–¥–∏–Ω–≥–æ–≤ –∏–∑ –∫—ç—à–∞...")
            with open(ITEM_EMBEDDINGS_FILE, 'rb') as f:
                item_embeddings = pickle.load(f)
        else:
            print("–í—ã—á–∏—Å–ª–µ–Ω–∏–µ item —ç–º–±–µ–¥–¥–∏–Ω–≥–æ–≤...")
            with torch.no_grad():
                item_embeddings = model.get_item_vector(item_ids).cpu().numpy()
            with open(ITEM_EMBEDDINGS_FILE, 'wb') as f:
                pickle.dump(item_embeddings, f)

        # --- –ü–æ–¥–≥–æ—Ç–æ–≤–∫–∞ —Ç–µ—Å—Ç–æ–≤—ã—Ö –ø–æ–ª—å–∑–æ–≤–∞—Ç–µ–ª–µ–π ---
        valid_cookies = []
        encoded_ids = []
        for _, row in test_users.iterrows():
            cookie = row['cookie']
            if cookie in le_user.classes_:
                try:
                    encoded_id = int(le_user.transform([cookie])[0])
                    encoded_ids.append(encoded_id)
                    valid_cookies.append(str(cookie))
                except Exception as e:
                    print(f"–û—à–∏–±–∫–∞ –∫–æ–¥–∏—Ä–æ–≤–∞–Ω–∏—è cookie {cookie}: {e}")

        processed_cookies = set()
        predictions = []

        # --- –í–æ–∑–æ–±–Ω–æ–≤–ª–µ–Ω–∏–µ –ø—Ä–µ–¥—Å–∫–∞–∑–∞–Ω–∏–π ---
        if os.path.exists(PREDICTIONS_FILE):
            try:
                df_prev = pd.read_csv(PREDICTIONS_FILE)
                predictions = df_prev.values.tolist()
                print(f"–ó–∞–≥—Ä—É–∂–µ–Ω–æ {len(df_prev)} –∑–∞–ø–∏—Å–µ–π –∏–∑ –ø—Ä–µ–¥—ã–¥—É—â–µ–≥–æ –∑–∞–ø—É—Å–∫–∞.")
                processed_cookies.update(df_prev['cookie'].astype(str).unique())
            except Exception as e:
                print(f"–û—à–∏–±–∫–∞ –∑–∞–≥—Ä—É–∑–∫–∏ –ø—Ä–µ–¥—ã–¥—É—â–∏—Ö –ø—Ä–µ–¥—Å–∫–∞–∑–∞–Ω–∏–π: {e}")

        remaining = [(c, e) for c, e in zip(valid_cookies, encoded_ids) if c not in processed_cookies]
        print(f"–û—Å—Ç–∞–ª–æ—Å—å –æ–±—Ä–∞–±–æ—Ç–∞—Ç—å: {len(remaining)} –ø–æ–ª—å–∑–æ–≤–∞—Ç–µ–ª–µ–π")

        BATCH_SIZE = 8192
        start = time.time()

        for i in range(0, len(remaining), BATCH_SIZE):
            batch = remaining[i:i+BATCH_SIZE]
            cookies_batch, encoded_batch = zip(*batch)
            user_tensor = torch.LongTensor(encoded_batch).to(device)

            with torch.no_grad():
                user_vectors = model.get_user_vector(user_tensor).cpu().numpy()

            scores = user_vectors @ item_embeddings.T

            for idx, cookie in enumerate(cookies_batch):
                ranked = [
                    (all_nodes[j], float(scores[idx, j]))
                    for j in np.argsort(-scores[idx])
                    if all_nodes[j] not in seen_dict.get(cookie, set())
                ]
                for node, score in ranked[:top_k]:
                    # ‚úÖ –¢–µ–ø–µ—Ä—å —Å–æ—Ö—Ä–∞–Ω—è–µ–º –≤ —Ñ–æ—Ä–º–∞—Ç–µ: node, cookie, score
                    predictions.append([str(node), str(cookie), score])
                processed_cookies.add(cookie)

            if (i // BATCH_SIZE) % 10 == 0:
                pd.DataFrame(predictions, columns=['node', 'cookie', 'score']).to_csv(PREDICTIONS_FILE, index=False)
                elapsed = time.time() - start
                print(f"[{i + len(batch):>6}/{len(valid_cookies)}] —Å–æ—Ö—Ä–∞–Ω–µ–Ω–æ... [–í—Ä–µ–º—è: {elapsed:.2f} —Å–µ–∫.]")

        print("–ò–Ω—Ñ–µ—Ä–µ–Ω—Å –∑–∞–≤–µ—Ä—à—ë–Ω.")
        return pd.DataFrame(predictions, columns=['node', 'cookie', 'score'])

    except Exception as e:
        print(f"–û—à–∏–±–∫–∞ —Ä–µ–∫–æ–º–µ–Ω–¥–∞—Ü–∏–∏: {e}")
        raise
        
submission = recommend_for_users_resumable(
            model=model,
            device=device,
            test_users=data['test_users'],
            all_nodes=data['cat_features']['node'].unique(),
            le_user=le_user,
            seen_dict=seen_dict,
            top_k=40
        )

submission['node'] = submission['node'].astype(int)
submission['cookie'] = submission['cookie'].astype(int)

# ==============================
# üíæ –°–æ—Ö—Ä–∞–Ω–µ–Ω–∏–µ —Ä–µ–∑—É–ª—å—Ç–∞—Ç–∞
# ==============================
def save_submission(df, path="submission.csv"):
    try:
        # –£–±–µ–¥–∏—Ç–µ—Å—å, —á—Ç–æ –∫–æ–ª–æ–Ω–∫–∏ –≤ –Ω—É–∂–Ω–æ–º –ø–æ—Ä—è–¥–∫–µ
        df[['node', 'cookie', 'score']].to_csv(path, index=False)
        print(f"–†–µ–∑—É–ª—å—Ç–∞—Ç —Å–æ—Ö—Ä–∞–Ω—ë–Ω –≤ {path}")
    except Exception as e:
        print(f"–û—à–∏–±–∫–∞ —Å–æ—Ö—Ä–∞–Ω–µ–Ω–∏—è —Ä–µ–∑—É–ª—å—Ç–∞—Ç–æ–≤: {e}")
        raise

save_submission(submission)

[2025-05-01 23:42:04] –ó–∞–≥—Ä—É–∑–∫–∞ 10% –¥–∞–Ω–Ω—ã—Ö...
[2025-05-01 23:42:10] –§–æ—Ä–º–∏—Ä–æ–≤–∞–Ω–∏–µ –ø–∞—Ä (user, node)...
[2025-05-01 23:42:11] –û—Ç–ª–∞–¥–∫–∞: –æ—Å—Ç–∞–≤–ª—è–µ—Ç—Å—è 10% –¥–∞–Ω–Ω—ã—Ö...
[2025-05-01 23:42:11] –ö–æ–¥–∏—Ä–æ–≤–∞–Ω–∏–µ –ø–æ–ª—å–∑–æ–≤–∞—Ç–µ–ª–µ–π –∏ —Ç–æ–≤–∞—Ä–æ–≤...
[2025-05-01 23:42:12] num_users: 81549, num_items: 201806
[2025-05-01 23:42:12] –ú–∞–∫—Å–∏–º–∞–ª—å–Ω—ã–π user_id: 81548, –ú–∞–∫—Å–∏–º–∞–ª—å–Ω—ã–π item_id: 201805
[2025-05-01 23:42:12] –ò—Å–ø–æ–ª—å–∑—É–µ—Ç—Å—è —É—Å—Ç—Ä–æ–π—Å—Ç–≤–æ: cuda
[2025-05-01 23:42:12] –ó–∞–≥—Ä—É–∑–∫–∞ –æ–±—É—á–µ–Ω–Ω–æ–π –º–æ–¥–µ–ª–∏ –∏–∑ –∫—ç—à–∞...
[2025-05-01 23:44:19] –ó–∞–≥—Ä—É–∑–∫–∞ item —ç–º–±–µ–¥–¥–∏–Ω–≥–æ–≤ –∏–∑ –∫—ç—à–∞...
[2025-05-01 23:44:21] –ó–∞–≥—Ä—É–∂–µ–Ω–æ 217200 –∑–∞–ø–∏—Å–µ–π –∏–∑ –ø—Ä–µ–¥—ã–¥—É—â–µ–≥–æ –∑–∞–ø—É—Å–∫–∞.
[2025-05-01 23:44:21] –û—Å—Ç–∞–ª–æ—Å—å –æ–±—Ä–∞–±–æ—Ç–∞—Ç—å: 0 –ø–æ–ª—å–∑–æ–≤–∞—Ç–µ–ª–µ–π
[2025-05-01 23:44:21] –ò–Ω—Ñ–µ—Ä–µ–Ω—Å –∑–∞–≤–µ—Ä—à—ë–Ω.
[2025-05-01 23:44:21] –†–µ–∑—É–ª—å—