In [1]:
import torch

def format_pytorch_version(version):
    return version.split('+')[0]

TORCH_version = torch.__version__
TORCH = format_pytorch_version(TORCH_version)

def format_cuda_version(version):
    return 'cu' + version.replace('.', '')

CUDA_version = torch.version.cuda
CUDA = format_cuda_version(CUDA_version)

!pip install -q torch-scatter     -f https://pytorch-geometric.com/whl/torch-{TORCH}+{CUDA}.html
!pip install -q torch-sparse      -f https://pytorch-geometric.com/whl/torch-{TORCH}+{CUDA}.html
!pip install -q torch-cluster     -f https://pytorch-geometric.com/whl/torch-{TORCH}+{CUDA}.html
!pip install -q torch-spline-conv -f https://pytorch-geometric.com/whl/torch-{TORCH}+{CUDA}.html
!pip install -q torch-geometric
!pip install -q sentence-transformers==2.2.2

!pip install faiss-cpu
!pip install optuna

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m10.2/10.2 MB[0m [31m53.7 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m4.9/4.9 MB[0m [31m37.9 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.3/3.3 MB[0m [31m67.3 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m887.8/887.8 kB[0m [31m13.0 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.0/1.0 MB[0m [31m12.7 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m86.0/86.0 kB[0m [31m2.2 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.9/7.9 MB[0m [31m67.6 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m83.4 MB/s[0m eta

In [2]:
import numpy as np
import polars as pl
from tqdm import tqdm

from collections import defaultdict
from typing import List, Any

import faiss
import scipy.sparse as sp
from sklearn.model_selection import train_test_split

import torch
from torch_geometric.nn import Node2Vec, SAGEConv, LightGCN, to_hetero
from torch_geometric.data import Data, HeteroData
from torch_geometric.utils import degree
import torch_geometric.transforms as T

# будем использовать cuda, если доступны вычисления на gpu
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'device: {device}')

device: cpu


In [3]:
data = pl.read_parquet('train.parquet')
data

uid,friend_uid
i64,i64
93464,114312
93464,103690
93464,108045
93464,116128
93464,94113
93464,101668
93464,118820
93464,93617
93464,97587
93464,101941


Данные состоят из двух колонок:

- `uid` – идентификатор пользователя
- `friend_uid` – идентификатор друга этого пользователя

Нашей задачей будет порекомендовать возможных друзей, для оценки вашего решения будет использоваться метрика Recall@10, равная проценту верно угаданных друзей

In [13]:
TOP_K = 20
RANDOM_STATE = 42

SUBMISSION_PATH = 'submission.parquet'


def user_intersection(y_rel: List[Any], y_rec: List[Any], k: int = TOP_K) -> int:
    """
    :param y_rel: relevant items
    :param y_rec: recommended items
    :param k: number of top recommended items
    :return: number of items in intersection of y_rel and y_rec (truncated to top-K)
    """
    return len(set(y_rec[:k]).intersection(set(y_rel)))


def user_recall(y_rel: List[Any], y_rec: List[Any], k: int = TOP_K) -> float:
    """
    :param y_rel: relevant items
    :param y_rec: recommended items
    :param k: number of top recommended items
    :return: percentage of found relevant items through recommendations
    """
    return user_intersection(y_rel, y_rec, k) / min(k, len(set(y_rel)))

## Валидация

Так как у нас нет временной последовательности и рекомендации друзей не так сильно зависят от временной составляющей, в качестве валидации будем использовать случайно выбранные ребра в графе (при этом внутри каждого пользователя будет равная пропорция друзей в валидации, которую можно достичь с помощью stratify параметра)

In [5]:
friends_count = data.groupby('uid').count()
filtered_uid = set(friends_count.filter(pl.col('count') > 1)['uid'].to_list())
data = data.filter(pl.col('uid').is_in(filtered_uid))

train_df, test_df = train_test_split(
    data,
    stratify=data['uid'],
    test_size=0.1,
    random_state=RANDOM_STATE
)
train_df

uid,friend_uid
i64,i64
17595,49363
67176,70404
92795,92911
9989,92398
45017,118933
46520,100857
48723,109402
35872,71525
54936,68900
15052,111432


## TF-IDF

Простым бейзлайном было бы использовать количество общих друзей между двумя пользователями, который легко интерпретируется и не требует слишком больших вычислений. Также можно добавить дополнительные веса, в частности подойдет tf-idf, который старается уменьшить эффект от слишком популярных пользователей.

Впрочем, признаки могут пригодиться для следующего этапа: ранжирования

In [6]:
# соберем строчки для разреженной матрицы
rows = []
cols = []
values = []
for row_ind, col_ind in train_df.rows():
    rows.append(row_ind)
    cols.append(col_ind)
    values.append(1)

    rows.append(col_ind)
    cols.append(row_ind)
    values.append(1)

sparse_data = sp.csr_matrix((values, (rows, cols)))
sparse_data

<120061x120061 sparse matrix of type '<class 'numpy.int64'>'
	with 5163024 stored elements in Compressed Sparse Row format>

In [7]:
# количество друзей у каждого пользователя
friends_count = np.asarray(sparse_data.sum(axis=1)).squeeze(-1)

# IDF – обратное число встречаемости этого пользователя
idf = np.log(train_df['uid'].n_unique() / (friends_count + 1e-9))
idf = np.where(friends_count == 0, 0, idf)

sparse_data = sparse_data.multiply(sp.csr_matrix(idf))
sparse_data

<120061x120061 sparse matrix of type '<class 'numpy.float64'>'
	with 5163024 stored elements in Compressed Sparse Row format>

In [8]:
grouped_df = (
    test_df
    .groupby('uid')
    .agg(pl.col('friend_uid').alias('y_rel'))
    .join(
        train_df
        .groupby('uid')
        .agg(pl.col('friend_uid').alias('user_history')),
        'uid',
        how='left'
    )
)

recall_list = []
for user_id, y_rel, user_history in tqdm(grouped_df.rows()):
    similarities = sparse_data[user_id] @ sparse_data
    y_rec = np.argsort(-similarities.A[0])[1:][:TOP_K + len(user_history)]
    y_rec = [uid for uid in y_rec if uid not in user_history]
    recall_list.append(user_recall(y_rel, y_rec))

print(f'Recall@{TOP_K} = {np.mean(recall_list)}')

100%|██████████| 67970/67970 [04:06<00:00, 275.19it/s]


Recall@20 = 0.1896923452242527


### Строим рекомендации

In [29]:
# соберем строчки для разреженной матрицы
rows = []
cols = []
values = []
for row_ind, col_ind in data.rows():
    rows.append(row_ind)
    cols.append(col_ind)
    values.append(1)

    rows.append(col_ind)
    cols.append(row_ind)
    values.append(1)

sparse_data = sp.csr_matrix((values, (rows, cols)))
sparse_data

<120061x120061 sparse matrix of type '<class 'numpy.int64'>'
	with 5736694 stored elements in Compressed Sparse Row format>

In [30]:
# количество друзей у каждого пользователя
friends_count = np.asarray(sparse_data.sum(axis=1)).squeeze(-1)

# IDF – обратное число встречаемости этого пользователя
idf = np.log(train_df['uid'].n_unique() / (friends_count + 1e-9))
idf = np.where(friends_count == 0, 0, idf)

sparse_data = sparse_data.multiply(sp.csr_matrix(idf))
sparse_data

<120061x120061 sparse matrix of type '<class 'numpy.float64'>'
	with 5736694 stored elements in Compressed Sparse Row format>

In [31]:
sample_submission = pl.read_parquet('sample_submission.parquet')

In [32]:
grouped_df = (
    sample_submission.select('uid')
    .join(
        data
        .groupby('uid')
        .agg(pl.col('friend_uid').alias('user_history')),
        'uid',
        how='left'
    )
)

submission = []
for user_id, user_history in tqdm(grouped_df.rows()):
    user_history = [] if user_history is None else user_history

    similarities = sparse_data[user_id] @ sparse_data
    y_rec = np.argsort(-similarities.A[0])[1:][:TOP_K + len(user_history)]
    y_rec = [uid for uid in y_rec if uid not in user_history]

    submission.append((user_id, y_rec))

submission = pl.DataFrame(submission, schema=['user_id', 'y_recs'])
submission.write_parquet('tfidf_submission.parquet')
submission

  0%|          | 0/85483 [00:00<?, ?it/s]

user_id,y_recs
i64,list[i64]
0,"[68034, 40381, … 102729]"
1,"[90756, 101349, … 98545]"
3,"[70188, 2963, … 104228]"
4,"[38464, 41652, … 93734]"
5,"[18562, 60942, … 8860]"
6,"[66133, 103202, … 81952]"
7,"[102537, 98968, … 50881]"
8,"[29574, 55892, … 94312]"
9,"[10855, 29133, … 99734]"
10,"[37541, 18127, … 24697]"


## Graph ML

Так как в блоке были задания на тему графовых нейросетей, то я также приведу примеры их использования, однако результаты получаются значительно хуже, чем у "простых" подходов вроде tf-idf из-за слишком разреженных данных и отсутствия дополнительных признаков в данных.

In [9]:
edge_index = torch.from_numpy(train_df.to_numpy()).long()

num_nodes = edge_index.max().item() + 1
graph_data = Data(
    x=torch.arange(num_nodes),
    edge_index=edge_index.T.contiguous(),
    num_nodes=num_nodes
).to(device)

graph_data = T.ToUndirected()(graph_data)
graph_data.validate(raise_on_error=True)
graph_data

Data(x=[120061], edge_index=[2, 5163024], num_nodes=120061)

In [21]:
def get_recommendations(user_embs: np.array, item_embs: np.array, k: int = TOP_K):
    # строим индекс объектов
    index = faiss.IndexFlatIP(item_embs.shape[1])
    index.add(item_embs)

    # строим рекомендации с помощью dot-product расстояния
    return index.search(user_embs, k)


grouped_df = (
    test_df
    .groupby('uid')
    .agg(pl.col('friend_uid').alias('y_rel'))
    .join(
        train_df
        .groupby('uid')
        .agg(pl.col('friend_uid').alias('user_history')),
        'uid',
        how='left'
    )
)
grouped_df

uid,y_rel,user_history
i64,list[i64],list[i64]
9768,"[19105, 13261, … 39250]","[79126, 83720, … 47057]"
67472,"[107070, 118515, 100052]","[117655, 86808, … 95816]"
8512,[69784],"[86019, 65450, … 81639]"
61184,[67454],"[114142, 76185, … 111515]"
33840,[101458],"[61967, 63022, … 115693]"
20216,[69762],"[93920, 112494, … 70380]"
46616,[106196],"[92942, 82576, … 52513]"
60040,"[77944, 64329, 102902]","[112025, 111356, … 93687]"
8656,"[97808, 20407, … 115218]","[76099, 58768, … 70783]"
96128,[97373],"[106297, 113675, … 101360]"


In [22]:
median_seq_len = int(grouped_df['user_history'].apply(len).median())
print(f"среднее число uid в user_history: {median_seq_len}")

среднее число uid в user_history: 22


### Node2Vec

In [13]:
model = Node2Vec(
    graph_data.edge_index,
    embedding_dim=128,  # размер эмбеддинга вершины
    walk_length=10,  # длина случайного блуждания
    context_size=10,  # размер окна из случайного блуждания (как в w2v)
    walks_per_node=5,  # количество случайных блужданий из одной вершины
    num_negative_samples=5,  # количество негативных примеров на один позитивный
    p=9.0,  # параметр вероятности вернуться в предыдущую вершину
    q=7.5,  # параметр вероятности исследовать граф вглубь
    sparse=True,
).to(device)


n_epochs = 200

# класс Node2Vec предоставляет сразу генератор случайного блуждания
loader = model.loader(batch_size=1024, shuffle=True, num_workers=4)
optimizer = torch.optim.SparseAdam(list(model.parameters()), lr=0.01)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, n_epochs, eta_min=1e-3
)


def train(model, loader, optimizer):
    model.train()
    total_loss = 0
    for pos_rw, neg_rw in tqdm(loader):
        # pos_rw – последовательность из случайного блуждания
        # neg_rw – случайные негативные примеры
        optimizer.zero_grad()
        loss = model.loss(pos_rw.to(device), neg_rw.to(device))
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    scheduler.step()
    return total_loss / len(loader)


@torch.no_grad()
def test(model):
    model.eval()
    user_embs = model().cpu().detach().numpy()
    # найдем ближайшие по метрике dot-product пользователей в полученном векторном пространстве
    _, recs = get_recommendations(user_embs, user_embs, TOP_K + median_seq_len)

    recall_list = []
    for user_id, y_rel, user_history in tqdm(grouped_df.rows()):
        if user_id >= len(recs):
            continue

        y_rec = [uid for uid in recs[user_id] if uid not in user_history]
        recall_list.append(user_recall(y_rel, y_rec))
    return np.mean(recall_list)


for epoch in range(1, n_epochs + 1):
    loss = train(model, loader, optimizer)
    if epoch % 50 == 0:
        mean_recall = test(model)
        print(f'Epoch: {epoch:03d}, Loss: {loss:.3f}, Recall@{TOP_K}: {mean_recall:.4f}')
    else:
        print(f'Epoch: {epoch:03d}, Loss: {loss:.3f}')

100%|██████████| 118/118 [00:03<00:00, 34.61it/s]


Epoch: 001, Loss: 8.853


100%|██████████| 118/118 [00:03<00:00, 34.60it/s]


Epoch: 002, Loss: 7.551


100%|██████████| 118/118 [00:03<00:00, 34.05it/s]


Epoch: 003, Loss: 6.361


100%|██████████| 118/118 [00:03<00:00, 33.97it/s]


Epoch: 004, Loss: 5.334


100%|██████████| 118/118 [00:03<00:00, 34.53it/s]


Epoch: 005, Loss: 4.482


100%|██████████| 118/118 [00:03<00:00, 34.57it/s]


Epoch: 006, Loss: 3.800


100%|██████████| 118/118 [00:03<00:00, 34.36it/s]


Epoch: 007, Loss: 3.246


100%|██████████| 118/118 [00:03<00:00, 33.70it/s]


Epoch: 008, Loss: 2.802


100%|██████████| 118/118 [00:03<00:00, 34.18it/s]


Epoch: 009, Loss: 2.439


100%|██████████| 118/118 [00:03<00:00, 34.61it/s]


Epoch: 010, Loss: 2.144


100%|██████████| 118/118 [00:03<00:00, 34.84it/s]


Epoch: 011, Loss: 1.909


100%|██████████| 118/118 [00:03<00:00, 34.24it/s]


Epoch: 012, Loss: 1.734


100%|██████████| 118/118 [00:03<00:00, 34.14it/s]


Epoch: 013, Loss: 1.604


100%|██████████| 118/118 [00:03<00:00, 34.85it/s]


Epoch: 014, Loss: 1.506


100%|██████████| 118/118 [00:03<00:00, 34.41it/s]


Epoch: 015, Loss: 1.433


100%|██████████| 118/118 [00:03<00:00, 34.09it/s]


Epoch: 016, Loss: 1.376


100%|██████████| 118/118 [00:03<00:00, 33.99it/s]


Epoch: 017, Loss: 1.331


100%|██████████| 118/118 [00:03<00:00, 34.70it/s]


Epoch: 018, Loss: 1.298


100%|██████████| 118/118 [00:03<00:00, 34.59it/s]


Epoch: 019, Loss: 1.272


100%|██████████| 118/118 [00:03<00:00, 34.28it/s]


Epoch: 020, Loss: 1.252


100%|██████████| 118/118 [00:03<00:00, 33.73it/s]


Epoch: 021, Loss: 1.236


100%|██████████| 118/118 [00:03<00:00, 34.24it/s]


Epoch: 022, Loss: 1.223


100%|██████████| 118/118 [00:03<00:00, 34.44it/s]


Epoch: 023, Loss: 1.213


100%|██████████| 118/118 [00:03<00:00, 34.61it/s]


Epoch: 024, Loss: 1.205


100%|██████████| 118/118 [00:03<00:00, 34.09it/s]


Epoch: 025, Loss: 1.199


100%|██████████| 118/118 [00:03<00:00, 33.89it/s]


Epoch: 026, Loss: 1.193


100%|██████████| 118/118 [00:03<00:00, 34.41it/s]


Epoch: 027, Loss: 1.188


100%|██████████| 118/118 [00:03<00:00, 34.52it/s]


Epoch: 028, Loss: 1.185


100%|██████████| 118/118 [00:03<00:00, 34.29it/s]


Epoch: 029, Loss: 1.181


100%|██████████| 118/118 [00:03<00:00, 33.54it/s]


Epoch: 030, Loss: 1.179


100%|██████████| 118/118 [00:03<00:00, 34.60it/s]


Epoch: 031, Loss: 1.176


100%|██████████| 118/118 [00:03<00:00, 34.42it/s]


Epoch: 032, Loss: 1.175


100%|██████████| 118/118 [00:03<00:00, 34.18it/s]


Epoch: 033, Loss: 1.173


100%|██████████| 118/118 [00:03<00:00, 33.54it/s]


Epoch: 034, Loss: 1.172


100%|██████████| 118/118 [00:03<00:00, 34.52it/s]


Epoch: 035, Loss: 1.171


100%|██████████| 118/118 [00:03<00:00, 34.81it/s]


Epoch: 036, Loss: 1.170


100%|██████████| 118/118 [00:03<00:00, 34.83it/s]


Epoch: 037, Loss: 1.170


100%|██████████| 118/118 [00:03<00:00, 33.49it/s]


Epoch: 038, Loss: 1.169


100%|██████████| 118/118 [00:03<00:00, 33.69it/s]


Epoch: 039, Loss: 1.169


100%|██████████| 118/118 [00:03<00:00, 34.65it/s]


Epoch: 040, Loss: 1.168


100%|██████████| 118/118 [00:03<00:00, 34.33it/s]


Epoch: 041, Loss: 1.168


100%|██████████| 118/118 [00:03<00:00, 34.27it/s]


Epoch: 042, Loss: 1.168


100%|██████████| 118/118 [00:03<00:00, 33.66it/s]


Epoch: 043, Loss: 1.168


100%|██████████| 118/118 [00:03<00:00, 34.35it/s]


Epoch: 044, Loss: 1.167


100%|██████████| 118/118 [00:03<00:00, 34.35it/s]


Epoch: 045, Loss: 1.167


100%|██████████| 118/118 [00:03<00:00, 34.43it/s]


Epoch: 046, Loss: 1.167


100%|██████████| 118/118 [00:03<00:00, 33.67it/s]


Epoch: 047, Loss: 1.167


100%|██████████| 118/118 [00:03<00:00, 34.45it/s]


Epoch: 048, Loss: 1.166


100%|██████████| 118/118 [00:03<00:00, 34.37it/s]


Epoch: 049, Loss: 1.166


100%|██████████| 118/118 [00:03<00:00, 34.55it/s]
100%|██████████| 67970/67970 [00:04<00:00, 15587.38it/s]


Epoch: 050, Loss: 1.166, Recall@20: 0.0341


100%|██████████| 118/118 [00:03<00:00, 34.45it/s]


Epoch: 051, Loss: 1.165


100%|██████████| 118/118 [00:03<00:00, 33.95it/s]


Epoch: 052, Loss: 1.165


100%|██████████| 118/118 [00:03<00:00, 33.99it/s]


Epoch: 053, Loss: 1.165


100%|██████████| 118/118 [00:03<00:00, 34.52it/s]


Epoch: 054, Loss: 1.165


100%|██████████| 118/118 [00:03<00:00, 34.40it/s]


Epoch: 055, Loss: 1.165


100%|██████████| 118/118 [00:03<00:00, 34.01it/s]


Epoch: 056, Loss: 1.164


100%|██████████| 118/118 [00:03<00:00, 33.48it/s]


Epoch: 057, Loss: 1.164


100%|██████████| 118/118 [00:03<00:00, 34.21it/s]


Epoch: 058, Loss: 1.163


100%|██████████| 118/118 [00:03<00:00, 34.41it/s]


Epoch: 059, Loss: 1.163


100%|██████████| 118/118 [00:03<00:00, 34.05it/s]


Epoch: 060, Loss: 1.162


100%|██████████| 118/118 [00:03<00:00, 33.22it/s]


Epoch: 061, Loss: 1.162


100%|██████████| 118/118 [00:03<00:00, 34.01it/s]


Epoch: 062, Loss: 1.161


100%|██████████| 118/118 [00:03<00:00, 34.19it/s]


Epoch: 063, Loss: 1.161


100%|██████████| 118/118 [00:03<00:00, 34.49it/s]


Epoch: 064, Loss: 1.161


100%|██████████| 118/118 [00:03<00:00, 33.90it/s]


Epoch: 065, Loss: 1.160


100%|██████████| 118/118 [00:03<00:00, 33.80it/s]


Epoch: 066, Loss: 1.160


100%|██████████| 118/118 [00:03<00:00, 34.48it/s]


Epoch: 067, Loss: 1.159


100%|██████████| 118/118 [00:03<00:00, 34.75it/s]


Epoch: 068, Loss: 1.158


100%|██████████| 118/118 [00:03<00:00, 34.51it/s]


Epoch: 069, Loss: 1.158


100%|██████████| 118/118 [00:03<00:00, 33.61it/s]


Epoch: 070, Loss: 1.158


100%|██████████| 118/118 [00:03<00:00, 34.49it/s]


Epoch: 071, Loss: 1.157


100%|██████████| 118/118 [00:03<00:00, 34.45it/s]


Epoch: 072, Loss: 1.157


100%|██████████| 118/118 [00:03<00:00, 34.37it/s]


Epoch: 073, Loss: 1.156


100%|██████████| 118/118 [00:03<00:00, 33.64it/s]


Epoch: 074, Loss: 1.156


100%|██████████| 118/118 [00:03<00:00, 34.35it/s]


Epoch: 075, Loss: 1.155


100%|██████████| 118/118 [00:03<00:00, 34.28it/s]


Epoch: 076, Loss: 1.155


100%|██████████| 118/118 [00:03<00:00, 34.28it/s]


Epoch: 077, Loss: 1.155


100%|██████████| 118/118 [00:03<00:00, 34.03it/s]


Epoch: 078, Loss: 1.154


100%|██████████| 118/118 [00:03<00:00, 33.74it/s]


Epoch: 079, Loss: 1.154


100%|██████████| 118/118 [00:03<00:00, 34.58it/s]


Epoch: 080, Loss: 1.153


100%|██████████| 118/118 [00:03<00:00, 34.44it/s]


Epoch: 081, Loss: 1.153


100%|██████████| 118/118 [00:03<00:00, 34.07it/s]


Epoch: 082, Loss: 1.152


100%|██████████| 118/118 [00:03<00:00, 33.27it/s]


Epoch: 083, Loss: 1.152


100%|██████████| 118/118 [00:03<00:00, 34.34it/s]


Epoch: 084, Loss: 1.152


100%|██████████| 118/118 [00:03<00:00, 34.50it/s]


Epoch: 085, Loss: 1.151


100%|██████████| 118/118 [00:03<00:00, 34.32it/s]


Epoch: 086, Loss: 1.150


100%|██████████| 118/118 [00:03<00:00, 33.64it/s]


Epoch: 087, Loss: 1.151


100%|██████████| 118/118 [00:03<00:00, 34.04it/s]


Epoch: 088, Loss: 1.149


100%|██████████| 118/118 [00:03<00:00, 34.34it/s]


Epoch: 089, Loss: 1.149


100%|██████████| 118/118 [00:03<00:00, 34.42it/s]


Epoch: 090, Loss: 1.149


100%|██████████| 118/118 [00:03<00:00, 33.68it/s]


Epoch: 091, Loss: 1.148


100%|██████████| 118/118 [00:03<00:00, 33.78it/s]


Epoch: 092, Loss: 1.148


100%|██████████| 118/118 [00:03<00:00, 34.43it/s]


Epoch: 093, Loss: 1.148


100%|██████████| 118/118 [00:03<00:00, 34.31it/s]


Epoch: 094, Loss: 1.147


100%|██████████| 118/118 [00:03<00:00, 34.16it/s]


Epoch: 095, Loss: 1.147


100%|██████████| 118/118 [00:03<00:00, 33.60it/s]


Epoch: 096, Loss: 1.146


100%|██████████| 118/118 [00:03<00:00, 34.13it/s]


Epoch: 097, Loss: 1.146


100%|██████████| 118/118 [00:03<00:00, 34.39it/s]


Epoch: 098, Loss: 1.145


100%|██████████| 118/118 [00:03<00:00, 34.33it/s]


Epoch: 099, Loss: 1.145


100%|██████████| 118/118 [00:03<00:00, 33.37it/s]
100%|██████████| 67970/67970 [00:04<00:00, 15543.87it/s]


Epoch: 100, Loss: 1.144, Recall@20: 0.0421


100%|██████████| 118/118 [00:03<00:00, 33.69it/s]


Epoch: 101, Loss: 1.144


100%|██████████| 118/118 [00:03<00:00, 34.17it/s]


Epoch: 102, Loss: 1.144


100%|██████████| 118/118 [00:03<00:00, 34.38it/s]


Epoch: 103, Loss: 1.143


100%|██████████| 118/118 [00:03<00:00, 34.04it/s]


Epoch: 104, Loss: 1.143


100%|██████████| 118/118 [00:03<00:00, 33.78it/s]


Epoch: 105, Loss: 1.142


100%|██████████| 118/118 [00:03<00:00, 34.27it/s]


Epoch: 106, Loss: 1.142


100%|██████████| 118/118 [00:03<00:00, 34.58it/s]


Epoch: 107, Loss: 1.142


100%|██████████| 118/118 [00:03<00:00, 33.96it/s]


Epoch: 108, Loss: 1.141


100%|██████████| 118/118 [00:03<00:00, 33.71it/s]


Epoch: 109, Loss: 1.140


100%|██████████| 118/118 [00:03<00:00, 34.18it/s]


Epoch: 110, Loss: 1.140


100%|██████████| 118/118 [00:03<00:00, 34.49it/s]


Epoch: 111, Loss: 1.140


100%|██████████| 118/118 [00:03<00:00, 34.18it/s]


Epoch: 112, Loss: 1.139


100%|██████████| 118/118 [00:03<00:00, 34.02it/s]


Epoch: 113, Loss: 1.139


100%|██████████| 118/118 [00:03<00:00, 34.40it/s]


Epoch: 114, Loss: 1.138


100%|██████████| 118/118 [00:03<00:00, 34.27it/s]


Epoch: 115, Loss: 1.138


100%|██████████| 118/118 [00:03<00:00, 33.93it/s]


Epoch: 116, Loss: 1.138


100%|██████████| 118/118 [00:03<00:00, 33.61it/s]


Epoch: 117, Loss: 1.138


100%|██████████| 118/118 [00:03<00:00, 33.99it/s]


Epoch: 118, Loss: 1.137


100%|██████████| 118/118 [00:03<00:00, 34.46it/s]


Epoch: 119, Loss: 1.136


100%|██████████| 118/118 [00:03<00:00, 34.11it/s]


Epoch: 120, Loss: 1.136


100%|██████████| 118/118 [00:03<00:00, 34.07it/s]


Epoch: 121, Loss: 1.136


100%|██████████| 118/118 [00:03<00:00, 34.14it/s]


Epoch: 122, Loss: 1.135


100%|██████████| 118/118 [00:03<00:00, 34.13it/s]


Epoch: 123, Loss: 1.135


100%|██████████| 118/118 [00:03<00:00, 34.16it/s]


Epoch: 124, Loss: 1.135


100%|██████████| 118/118 [00:03<00:00, 33.55it/s]


Epoch: 125, Loss: 1.134


100%|██████████| 118/118 [00:03<00:00, 34.43it/s]


Epoch: 126, Loss: 1.134


100%|██████████| 118/118 [00:03<00:00, 33.93it/s]


Epoch: 127, Loss: 1.134


100%|██████████| 118/118 [00:03<00:00, 34.18it/s]


Epoch: 128, Loss: 1.133


100%|██████████| 118/118 [00:03<00:00, 33.64it/s]


Epoch: 129, Loss: 1.133


100%|██████████| 118/118 [00:03<00:00, 34.34it/s]


Epoch: 130, Loss: 1.132


100%|██████████| 118/118 [00:03<00:00, 34.21it/s]


Epoch: 131, Loss: 1.132


100%|██████████| 118/118 [00:03<00:00, 34.06it/s]


Epoch: 132, Loss: 1.132


100%|██████████| 118/118 [00:03<00:00, 33.12it/s]


Epoch: 133, Loss: 1.131


100%|██████████| 118/118 [00:03<00:00, 33.72it/s]


Epoch: 134, Loss: 1.131


100%|██████████| 118/118 [00:03<00:00, 34.15it/s]


Epoch: 135, Loss: 1.131


100%|██████████| 118/118 [00:03<00:00, 34.43it/s]


Epoch: 136, Loss: 1.130


100%|██████████| 118/118 [00:03<00:00, 33.67it/s]


Epoch: 137, Loss: 1.130


100%|██████████| 118/118 [00:03<00:00, 34.28it/s]


Epoch: 138, Loss: 1.129


100%|██████████| 118/118 [00:03<00:00, 34.12it/s]


Epoch: 139, Loss: 1.129


100%|██████████| 118/118 [00:03<00:00, 34.28it/s]


Epoch: 140, Loss: 1.129


100%|██████████| 118/118 [00:03<00:00, 33.75it/s]


Epoch: 141, Loss: 1.129


100%|██████████| 118/118 [00:03<00:00, 34.44it/s]


Epoch: 142, Loss: 1.128


100%|██████████| 118/118 [00:03<00:00, 34.18it/s]


Epoch: 143, Loss: 1.128


100%|██████████| 118/118 [00:03<00:00, 34.08it/s]


Epoch: 144, Loss: 1.128


100%|██████████| 118/118 [00:03<00:00, 33.70it/s]


Epoch: 145, Loss: 1.127


100%|██████████| 118/118 [00:03<00:00, 34.58it/s]


Epoch: 146, Loss: 1.127


100%|██████████| 118/118 [00:03<00:00, 34.25it/s]


Epoch: 147, Loss: 1.127


100%|██████████| 118/118 [00:03<00:00, 34.15it/s]


Epoch: 148, Loss: 1.126


100%|██████████| 118/118 [00:03<00:00, 33.31it/s]


Epoch: 149, Loss: 1.126


100%|██████████| 118/118 [00:03<00:00, 34.30it/s]
100%|██████████| 67970/67970 [00:04<00:00, 15816.27it/s]


Epoch: 150, Loss: 1.126, Recall@20: 0.0559


100%|██████████| 118/118 [00:03<00:00, 33.97it/s]


Epoch: 151, Loss: 1.125


100%|██████████| 118/118 [00:03<00:00, 34.45it/s]


Epoch: 152, Loss: 1.125


100%|██████████| 118/118 [00:03<00:00, 33.47it/s]


Epoch: 153, Loss: 1.125


100%|██████████| 118/118 [00:03<00:00, 34.14it/s]


Epoch: 154, Loss: 1.125


100%|██████████| 118/118 [00:03<00:00, 34.38it/s]


Epoch: 155, Loss: 1.125


100%|██████████| 118/118 [00:03<00:00, 34.45it/s]


Epoch: 156, Loss: 1.124


100%|██████████| 118/118 [00:03<00:00, 33.89it/s]


Epoch: 157, Loss: 1.124


100%|██████████| 118/118 [00:03<00:00, 34.00it/s]


Epoch: 158, Loss: 1.124


100%|██████████| 118/118 [00:03<00:00, 34.41it/s]


Epoch: 159, Loss: 1.124


100%|██████████| 118/118 [00:03<00:00, 34.13it/s]


Epoch: 160, Loss: 1.123


100%|██████████| 118/118 [00:03<00:00, 33.88it/s]


Epoch: 161, Loss: 1.123


100%|██████████| 118/118 [00:03<00:00, 33.75it/s]


Epoch: 162, Loss: 1.123


100%|██████████| 118/118 [00:03<00:00, 34.53it/s]


Epoch: 163, Loss: 1.123


100%|██████████| 118/118 [00:03<00:00, 34.49it/s]


Epoch: 164, Loss: 1.122


100%|██████████| 118/118 [00:03<00:00, 33.84it/s]


Epoch: 165, Loss: 1.122


100%|██████████| 118/118 [00:03<00:00, 33.99it/s]


Epoch: 166, Loss: 1.122


100%|██████████| 118/118 [00:03<00:00, 34.15it/s]


Epoch: 167, Loss: 1.122


100%|██████████| 118/118 [00:03<00:00, 34.64it/s]


Epoch: 168, Loss: 1.122


100%|██████████| 118/118 [00:03<00:00, 33.87it/s]


Epoch: 169, Loss: 1.121


100%|██████████| 118/118 [00:03<00:00, 33.63it/s]


Epoch: 170, Loss: 1.121


100%|██████████| 118/118 [00:03<00:00, 34.21it/s]


Epoch: 171, Loss: 1.121


100%|██████████| 118/118 [00:03<00:00, 34.52it/s]


Epoch: 172, Loss: 1.121


100%|██████████| 118/118 [00:03<00:00, 33.97it/s]


Epoch: 173, Loss: 1.121


100%|██████████| 118/118 [00:03<00:00, 33.92it/s]


Epoch: 174, Loss: 1.121


100%|██████████| 118/118 [00:03<00:00, 34.34it/s]


Epoch: 175, Loss: 1.120


100%|██████████| 118/118 [00:03<00:00, 34.28it/s]


Epoch: 176, Loss: 1.120


100%|██████████| 118/118 [00:03<00:00, 33.99it/s]


Epoch: 177, Loss: 1.120


100%|██████████| 118/118 [00:03<00:00, 33.43it/s]


Epoch: 178, Loss: 1.120


100%|██████████| 118/118 [00:03<00:00, 34.42it/s]


Epoch: 179, Loss: 1.120


100%|██████████| 118/118 [00:03<00:00, 34.54it/s]


Epoch: 180, Loss: 1.120


100%|██████████| 118/118 [00:03<00:00, 34.24it/s]


Epoch: 181, Loss: 1.120


100%|██████████| 118/118 [00:03<00:00, 33.41it/s]


Epoch: 182, Loss: 1.119


100%|██████████| 118/118 [00:03<00:00, 34.13it/s]


Epoch: 183, Loss: 1.119


100%|██████████| 118/118 [00:03<00:00, 34.37it/s]


Epoch: 184, Loss: 1.119


100%|██████████| 118/118 [00:03<00:00, 34.14it/s]


Epoch: 185, Loss: 1.119


100%|██████████| 118/118 [00:03<00:00, 33.67it/s]


Epoch: 186, Loss: 1.119


100%|██████████| 118/118 [00:03<00:00, 34.28it/s]


Epoch: 187, Loss: 1.118


100%|██████████| 118/118 [00:03<00:00, 34.56it/s]


Epoch: 188, Loss: 1.119


100%|██████████| 118/118 [00:03<00:00, 34.14it/s]


Epoch: 189, Loss: 1.118


100%|██████████| 118/118 [00:03<00:00, 33.48it/s]


Epoch: 190, Loss: 1.118


100%|██████████| 118/118 [00:03<00:00, 34.36it/s]


Epoch: 191, Loss: 1.118


100%|██████████| 118/118 [00:03<00:00, 34.51it/s]


Epoch: 192, Loss: 1.118


100%|██████████| 118/118 [00:03<00:00, 34.42it/s]


Epoch: 193, Loss: 1.118


100%|██████████| 118/118 [00:03<00:00, 33.46it/s]


Epoch: 194, Loss: 1.118


100%|██████████| 118/118 [00:03<00:00, 34.08it/s]


Epoch: 195, Loss: 1.118


100%|██████████| 118/118 [00:03<00:00, 34.54it/s]


Epoch: 196, Loss: 1.117


100%|██████████| 118/118 [00:03<00:00, 34.14it/s]


Epoch: 197, Loss: 1.118


100%|██████████| 118/118 [00:03<00:00, 33.53it/s]


Epoch: 198, Loss: 1.117


100%|██████████| 118/118 [00:03<00:00, 34.23it/s]


Epoch: 199, Loss: 1.117


100%|██████████| 118/118 [00:03<00:00, 34.50it/s]
100%|██████████| 67970/67970 [00:04<00:00, 16361.95it/s]

Epoch: 200, Loss: 1.117, Recall@20: 0.0634





In [14]:
model.eval()
user_embs = model().cpu().detach().numpy()
_, recs = get_recommendations(user_embs, user_embs, TOP_K + median_seq_len)

In [19]:
grouped_df = (
    sample_submission.select('uid')
    .join(
        data
        .groupby('uid')
        .agg(pl.col('friend_uid').alias('user_history')),
        'uid',
        how='left'
    )
)

submission = []
for user_id, user_history in tqdm(grouped_df.rows()):
    user_history = [] if user_history is None else user_history

    y_rec = [uid for uid in recs[user_id] if uid not in user_history]
    submission.append((user_id, y_rec))

submission = pl.DataFrame(submission, schema=['user_id', 'y_recs'])
submission.write_parquet('node2vec_submission.parquet')
submission

100%|██████████| 85483/85483 [00:04<00:00, 18375.58it/s]


user_id,y_recs
i64,list[i64]
0,"[0, 74993, … 77084]"
1,"[1, 12009, … 96208]"
3,"[3, 70188, … 6058]"
4,"[4, 57466, … 65780]"
5,"[5, 51435, … 25570]"
6,"[39233, 78442, … 22478]"
7,"[7, 74993, … 15211]"
8,"[8, 43188, … 78442]"
9,"[9, 78442, … 86618]"
10,"[43188, 51435, … 30700]"


### LightGCN

Тут я привожу свою реализацию, которая использует разреженные матрицы, чтобы эффективно строить обучающую выборку. Детали реализации можно прочитать [тут](https://telegra.ph/LightGCN-Simplifying-and-Powering-Graph-Convolution-Network-for-Recommendation-07-28). Тем не менее, в задаче рекомендации друзей объекты излишне и можно уменьшить число параметров, попробуйте это сделать в качестве практики

In [25]:
from collections import defaultdict

import numpy as np
import scipy.sparse as sp
import torch
import torch.nn as nn
from scipy.sparse import csr_matrix
from sklearn.metrics import roc_auc_score
from torch.utils.data import DataLoader, Dataset
from tqdm.notebook import tqdm


class LightGCNDataset(Dataset):
    def __init__(
            self,
            user_items: csr_matrix,
            n_negatives: int = 10,
            ns_exponent: float = 0.0,
            verify_negative_samples: bool = True):
        self.n_negatives = n_negatives
        self.verify_negative_samples = verify_negative_samples

        self.n_users = user_items.shape[0]
        self.m_items = user_items.shape[1]
        n_nodes = self.n_users + self.m_items

        user_items_coo = user_items.tocoo()
        self.unique_users = user_items_coo.row

        self.positive_items = defaultdict(list)
        for user_id, item_id in zip(user_items_coo.row, user_items_coo.col):
            self.positive_items[user_id].append(item_id)

        # разреженная матрица формируется из прямых и обратных ребер
        tmp_adj = sp.csr_matrix((user_items_coo.data, (user_items_coo.row, user_items_coo.col + self.n_users)),
                                shape=(n_nodes, n_nodes))
        adj_mat = tmp_adj + tmp_adj.T

        # нормируем матрицу
        rowsum = np.array(adj_mat.sum(1))
        d_inv = np.power(rowsum, -0.5).flatten()
        d_inv[np.isinf(d_inv)] = 0.
        d_mat_inv = sp.diags(d_inv)

        # нормируем на число взаимодействий по пользователям
        norm_adj_tmp = d_mat_inv.dot(adj_mat)
        # нормируем на число взаимодействий по объектам
        normalized_adj_matrix = norm_adj_tmp.dot(d_mat_inv)

        # преобразуем scipy.sparse -> torch.sparse
        adj_mat_coo = normalized_adj_matrix.tocoo()

        values = adj_mat_coo.data
        indices = np.vstack((adj_mat_coo.row, adj_mat_coo.col))

        i = torch.LongTensor(indices)
        v = torch.FloatTensor(values)
        shape = adj_mat_coo.shape

        self.adj_matrix = torch.sparse_coo_tensor(i, v, torch.Size(shape))

        # ns_exponent – аналогично использованию в w2v
        self.neg_probs = rowsum.flatten()[:self.m_items] ** ns_exponent
        self.neg_probs[np.isnan(self.neg_probs)] = 0.
        self.neg_probs[np.isinf(self.neg_probs)] = 0.
        self.neg_probs /= self.neg_probs.sum()

        self.random_items_buffer = None
        self.random_items_buffer_pointer = 0

    def get_user_positives(self, user):
        return self.positive_items[user]

    def get_random_item(self):
        if self.random_items_buffer is None or self.random_items_buffer_pointer == len(self.random_items_buffer):
            self.random_items_buffer = np.random.choice(np.arange(self.m_items), 10_000, p=self.neg_probs)
            self.random_items_buffer_pointer = 0

        ret = self.random_items_buffer[self.random_items_buffer_pointer]
        self.random_items_buffer_pointer += 1
        return ret

    def get_user_negatives(self, user, k=10):
        neg = []
        positives = set(self.get_user_positives(user)) if self.verify_negative_samples else []
        while len(neg) < k:
            candidate = self.get_random_item()
            if not self.verify_negative_samples or \
                    self.verify_negative_samples and candidate not in positives:
                neg.append(candidate)
        return neg

    def get_sparse_graph(self):
        """
        Returns a graph in torch.sparse_coo_tensor.
        A = |0,   R|
            |R^T, 0|
        """
        return self.adj_matrix

    def __len__(self):
        return len(self.unique_users)

    def __getitem__(self, idx):
        """
        returns user, pos_items, neg_items

        :param idx: index of user from unique_users
        :return:
        """
        user = self.unique_users[idx]
        pos = np.random.choice(self.get_user_positives(user), self.n_negatives)
        neg = self.get_user_negatives(user, self.n_negatives)
        return user, pos, neg


def collate_function(batch):
    users = []
    pos_items = []
    neg_items = []
    for user, pos, neg in batch:
        users.extend([user for _ in pos])
        pos_items.extend(pos)
        neg_items.extend(neg)
    return list(map(torch.tensor, [users, pos_items, neg_items]))


class LightGCN(nn.Module):
    def __init__(
            self,
            learning_rate: float = 0.01,
            regularization: float = 0.01,
            batch_size: int = 128,
            factors: int = 100,
            n_negatives: int = 10,
            iterations: int = 100,
            n_layers: int = 2,
            ns_exponent: float = 0.0,
            verify_negative_samples: bool = True,
            calculate_training_roc_auc: bool = True
    ):
        """
        :param learning_rate: float, optional
            The learning rate to apply for SGD updates during training
        :param regularization: float, optional
            The regularization factor to use
        :param batch_size: int, optional
            Size of the batch used in training
        :param factors: int, optional
            The number of latent factors to compute
        :param n_negatives: int, optional
            The number of negative candidates in sampling
        :param iterations: int, optional
            The number of training epochs to use when fitting the data
        :param verify_negative_samples: bool, optional
            When sampling negative items, check if the randomly picked negative item has actually
            been liked by the user. This check increases the time needed to train but usually leads
            to better predictions.
        """
        super(LightGCN, self).__init__()
        self.learning_rate = learning_rate
        self.regularization = regularization
        self.batch_size = batch_size
        self.factors = factors
        self.n_negatives = n_negatives
        self.iterations = iterations
        self.n_layers = n_layers
        self.ns_exponent = ns_exponent
        self.verify_negative_samples = verify_negative_samples
        self.calculate_training_roc_auc = calculate_training_roc_auc

        self.user_factors = None
        self.item_factors = None

        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.to(self.device)

    def __init_weight(self, dataset: LightGCNDataset):
        """
        Инициализация весов с помощью xavier_normal_
        :return:
        """
        self.num_users = dataset.n_users
        self.num_items = dataset.m_items
        self.Graph = dataset.get_sparse_graph().to(self.device)

        if self.user_factors is not None and self.item_factors is not None:
            return

        self.user_factors = torch.nn.Embedding(
            num_embeddings=dataset.n_users, embedding_dim=self.factors
        ).to(self.device)
        self.item_factors = torch.nn.Embedding(
            num_embeddings=dataset.m_items, embedding_dim=self.factors
        ).to(self.device)

        nn.init.xavier_normal_(self.user_factors.weight)
        nn.init.xavier_normal_(self.item_factors.weight)

    def computer(self) -> tuple:
        """
        Propagate high-hop embeddings for lightGCN
        :return: user embeddings, item embeddings
        """
        users_emb = self.user_factors.weight
        items_emb = self.item_factors.weight
        all_emb = torch.cat([users_emb, items_emb])

        # на нулевом слое использует эмбеддинги из весов
        layer_embeddings = [all_emb]
        for _ in range(self.n_layers):
            # в цикле делаем шаг lightGCN в матричном виде, подробности в статьей
            all_emb = torch.sparse.mm(self.Graph, all_emb)
            layer_embeddings.append(all_emb)
        layer_embeddings = torch.stack(layer_embeddings, dim=1)

        final_embeddings = layer_embeddings.mean(dim=1)  # output – среднее всех слоев
        users, items = torch.split(final_embeddings, [self.num_users, self.num_items])
        return users, items

    def get_embedding(self, users: torch.tensor, pos_items: torch.tensor,
                      neg_items: torch.tensor) -> tuple:
        all_users, all_items = self.computer()
        users_emb = all_users[users]  # эмбеддинги пользователей
        pos_emb = all_items[pos_items] # эмбеддинги позитивных объектов
        neg_emb = all_items[neg_items] # эмбеддинги негативных объектов
        users_emb_ego = self.user_factors(users) # эмбеддинги пользователей с нулевого слоя
        pos_emb_ego = self.item_factors(pos_items) # эмбеддинги позитивных объектов с нулевого слоя
        neg_emb_ego = self.item_factors(neg_items) # эмбеддинги позитивных объектов с нулевого слоя
        return users_emb, pos_emb, neg_emb, users_emb_ego, pos_emb_ego, neg_emb_ego

    def bpr_loss(self, users: torch.tensor, pos: torch.tensor, neg: torch.tensor) -> tuple:
        """
        Calculate BPR loss as - sum ln(sigma(pos_scores - neg_scores)) + L2 norm
        :param users: users for which calculate loss
        :param pos: positive items
        :param neg: negative items
        :return: loss, reg_loss
        """
        (users_emb, pos_emb, neg_emb,
         userEmb0, posEmb0, negEmb0) = self.get_embedding(users.long(), pos.long(), neg.long())
        reg_loss = (1 / 2) * (userEmb0.norm(2).pow(2) +
                              posEmb0.norm(2).pow(2) +
                              negEmb0.norm(2).pow(2)) / float(len(users))

        # скоры – dot-product между веторами пользователей и объектов
        pos_scores = torch.mul(users_emb, pos_emb)
        pos_scores = torch.sum(pos_scores, dim=1)
        neg_scores = torch.mul(users_emb, neg_emb)
        neg_scores = torch.sum(neg_scores, dim=1)

        roc_auc = None
        if self.calculate_training_roc_auc:
            scores = torch.cat([pos_scores, neg_scores])
            labels = torch.cat([torch.ones_like(pos_scores), torch.zeros_like(neg_scores)])
            roc_auc = roc_auc_score(labels.detach().cpu().numpy(), scores.detach().cpu().numpy())

        # bpr loss
        loss = - (pos_scores - neg_scores).sigmoid().log().mean()
        return loss, reg_loss, roc_auc

    def forward(self, users: torch.tensor, items: torch.tensor):
        # compute embedding
        all_users, all_items = self.computer()

        users_emb = all_users[users]
        items_emb = all_items[items]
        inner_prod = torch.mul(users_emb, items_emb)
        return torch.sum(inner_prod, dim=1).sigmoid()

    def fit(self, user_items: csr_matrix, callback_fn=None):
        """
        Fitting model with BPR loss.
        :param user_items: dataset for training
        :param callback_fn: callback function
        :return:
        """
        dataset = LightGCNDataset(
            user_items, self.n_negatives, self.ns_exponent, self.verify_negative_samples)
        self.__init_weight(dataset)

        dataloader = DataLoader(
            dataset, batch_size=self.batch_size,
            shuffle=True, collate_fn=collate_function,
            pin_memory=True, num_workers=4
        )

        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        pbar = tqdm(range(self.iterations))
        print(f'len of dataloader = {len(dataloader)}')

        for _ in pbar:
            for users, pos, neg in dataloader:
                optimizer.zero_grad()
                users, pos, neg = users.to(self.device), pos.to(self.device), neg.to(self.device)
                loss, reg_loss, roc_auc = self.bpr_loss(users, pos, neg)
                total_loss = loss + self.regularization * reg_loss

                total_loss.backward()
                optimizer.step()

                pbar.set_postfix({
                    'bpr_loss': loss.item(),
                    'reg_loss': reg_loss.item(),
                    'train_auc': roc_auc,
                })

            if callback_fn is not None:
                callback_fn()

In [26]:
# соберем строчки для разреженной матрицы
rows = []
cols = []
values = []
for row_ind, col_ind in train_df.rows():
    rows.append(row_ind)
    cols.append(col_ind)
    values.append(1)

    rows.append(col_ind)
    cols.append(row_ind)
    values.append(1)

sparse_data = sp.csr_matrix((values, (rows, cols)))
sparse_data

<120061x120061 sparse matrix of type '<class 'numpy.int64'>'
	with 5163024 stored elements in Compressed Sparse Row format>

In [None]:
grouped_df = (
    test_df
    .groupby('uid')
    .agg(pl.col('friend_uid').alias('y_rel'))
    .join(
        train_df
        .groupby('uid')
        .agg(pl.col('friend_uid').alias('user_history')),
        'uid',
        how='left'
    )
)

In [27]:
model = LightGCN(
    batch_size=8192,  # размер батча для SGD
    factors=256,  # размерность эмбеддингов
    n_negatives=5,  # количество негативных примеров на один позитивный
    iterations=5,  # количество эпох
    n_layers=1,  # количество слоев в алгоритме LightGCN
    regularization=1e-4,  # регуляризация для эмбеддингов
    verify_negative_samples=False,  # нужно ли проверять, что случайные негативы на самом деле позитивные (увеличивает время сэмплирования)
    ns_exponent=0.75  # параметр для вероятности в негативном сэмплировании
)

def callback_fn():
    user_factors, item_factors = model.computer()

    user_factors = user_factors.cpu().detach().numpy()
    item_factors = item_factors.cpu().detach().numpy()

    _, recs = get_recommendations(user_factors, item_factors, TOP_K + median_seq_len)

    # посчитаем метрики качества рекомендаций
    recall_list = []
    for user_id, user_history, y_rel in grouped_df.select('uid', 'user_history', 'y_rel').rows():
        y_rec = [
            item_id
            for item_id in recs[user_id]
            # фильтруем уже просмотренное
            if item_id not in user_history
        ]
        recall_list.append(user_recall(y_rel, y_rec))

    mean_recall = np.mean(recall_list)
    print(f'Recall@{TOP_K} = {mean_recall}')

model.fit(sparse_data, callback_fn=callback_fn)

  d_inv = np.power(rowsum, -0.5).flatten()


  0%|          | 0/5 [00:00<?, ?it/s]

len of dataloader = 631
Recall@20 = 0.06295178054093842
Recall@20 = 0.06835021834733394
Recall@20 = 0.06874360334508252
Recall@20 = 0.06748730814956592
Recall@20 = 0.06421586035335795


In [28]:
user_factors, item_factors = model.computer()

user_factors = user_factors.cpu().detach().numpy()
item_factors = item_factors.cpu().detach().numpy()

_, recs = get_recommendations(user_factors, item_factors, TOP_K + median_seq_len)

grouped_df = (
    sample_submission.select('uid')
    .join(
        data
        .groupby('uid')
        .agg(pl.col('friend_uid').alias('user_history')),
        'uid',
        how='left'
    )
)

submission = []
for user_id, user_history in tqdm(grouped_df.rows()):
    user_history = [] if user_history is None else user_history

    y_rec = [
        item_id
        for item_id in recs[user_id]
        # фильтруем уже просмотренное
        if item_id not in user_history
    ]

    submission.append((user_id, y_rec))

submission = pl.DataFrame(submission, schema=['user_id', 'y_recs'])
submission.write_parquet('lightgcn_submission.parquet')
submission

  0%|          | 0/85483 [00:00<?, ?it/s]

user_id,y_recs
i64,list[i64]
0,"[0, 64886, … 108866]"
1,"[1, 90756, … 98970]"
3,"[3, 2963, … 29580]"
4,"[4, 100651, … 74993]"
5,"[5, 48984, … 75160]"
6,"[6, 48163, … 95375]"
7,"[102537, 60568, … 90335]"
8,"[8, 42199, … 69340]"
9,"[9, 42669, … 104388]"
10,"[17525, 110077, … 44960]"
