In [3]:
from collections import defaultdict

import numpy as np
import json
import pickle

In [4]:
embeddings_input_path = '../data/Beauty/content_embeddings.pkl'

In [5]:
with open(embeddings_input_path, 'rb') as f:
    data = pickle.load(f)

item_ids = np.array(data['item_id'], dtype=np.int64)
X = np.array(data['embedding'], dtype=np.float32)

In [6]:
X.shape

(12101, 4096)

In [7]:
X[0]

array([-0.04701373,  0.26094416, -0.32016486, ..., -0.32683063,
        0.16486016,  0.2398256 ], shape=(4096,), dtype=float32)

In [8]:
pairs_path = '../data/Beauty/positive_pairs.txt'

In [9]:
pairs = []
with open(pairs_path, 'r', encoding='utf-8') as f:
    for line in f:
        two_ints = line.strip().split()
        if len(two_ints) != 2:
            raise ValueError(f"not two ints, {two_ints_ints}")
        anchor, positive = map(int, two_ints)
        pairs.append((anchor, positive))

In [11]:
len(pairs)

131413

In [12]:
pairs[:10]

[(9839, 11863),
 (11863, 11752),
 (11752, 9449),
 (3309, 4572),
 (4572, 9079),
 (9079, 4136),
 (4386, 6362),
 (6362, 4208),
 (4208, 5665),
 (5665, 453)]

чек что в наших парах нет дубликатов

In [13]:
cnt = 0
for (a, b) in pairs:
    if a == b:
        cnt += 1
cnt

0

In [14]:
num_items, D = X.shape
num_items, D

(12101, 4096)

In [15]:
max_id = int(item_ids.max())
max_id

12100

In [16]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
from typing import List, Tuple
import numpy as np
import pickle

In [17]:
base_emb = nn.Embedding(max_id + 1, D)

In [18]:
with torch.no_grad():
    base_emb.weight.zero_()
    base_emb.weight[item_ids] = F.normalize(torch.tensor(X), dim=1)

for p in base_emb.parameters():
    p.requires_grad_(False)

In [19]:
class TowerMLP(nn.Module):
    def __init__(self, d_in, d_hidden, d_out, num_layers=2, p_drop=0.0):
        super().__init__()
        layers = []
        last = d_in
        for _ in range(num_layers - 1):
            layers += [nn.Linear(last, d_hidden), nn.ReLU(), nn.Dropout(p_drop)]
            last = d_hidden
        layers += [nn.Linear(last, d_out)]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return F.normalize(self.net(x), dim=1)

In [20]:
tower = TowerMLP(D, D, D, num_layers=2, p_drop=0.0)

In [21]:
@dataclass
class PairDataset(torch.utils.data.Dataset):
    pairs: List[Tuple[int, int]]
    def __len__(self): return len(self.pairs)
    def __getitem__(self, idx):
        a, p = self.pairs[idx]
        return torch.tensor(a), torch.tensor(p)

In [22]:
B = 32

In [23]:
ds = PairDataset(pairs)
loader = torch.utils.data.DataLoader(ds, batch_size=B, shuffle=True, drop_last=True)

In [24]:
def nt_xent_loss(z1, z2, tau=0.07):
    """
    с 0.07 вопросы, конечно, смотрел статью SimCLR, а в либе lightly.loss.ntx_ent_loss.NTXentLoss стоит 0.5 дефолт
    """
    logits12 = (z1 @ z2.T) / tau
    logits21 = (z2 @ z1.T) / tau
    labels = torch.arange(z1.size(0), device=z1.device)
    return 0.5 * (F.cross_entropy(logits12, labels) + F.cross_entropy(logits21, labels))

In [25]:
opt = torch.optim.AdamW(tower.parameters(), lr=3e-4, weight_decay=1e-4)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
tower = tower.to(device)
base_emb = base_emb.to(device)



In [26]:
device

device(type='cuda')

In [27]:
base_emb

Embedding(12101, 4096)

In [29]:
base_emb(torch.tensor(9839).to(device))

tensor([ 0.0093,  0.0063, -0.0002,  ..., -0.0030, -0.0096,  0.0041],
       device='cuda:0')

In [30]:
from tqdm.auto import tqdm 

In [33]:
for epoch in range(3):
    running = 0.0
    print(f"len(loader) is {len(loader)}")
    for a_ids, p_ids in tqdm(loader, desc=f"epoch {epoch+1}", leave=False):
        a_ids, p_ids = a_ids.to(device), p_ids.to(device)
        with torch.no_grad():
            a_base = base_emb(a_ids)
            p_base = base_emb(p_ids)
        zA = tower(a_base)
        zP = tower(p_base)
        loss = nt_xent_loss(zA, zP, tau=0.07)
        opt.zero_grad(set_to_none=True)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(tower.parameters(), 1.0)
        opt.step()
        running += loss.item()
        # break # временно прогоняю только один батч
    print(f"epoch {epoch+1}: loss={running/len(loader):.4f}")

len(loader) is 4106


                                                             

epoch 1: loss=2.7126
len(loader) is 4106


                                                             

epoch 2: loss=2.6199
len(loader) is 4106


                                                             

epoch 3: loss=2.5577




In [34]:
base_emb(torch.tensor(9839).to(device))

tensor([ 0.0093,  0.0063, -0.0002,  ..., -0.0030, -0.0096,  0.0041],
       device='cuda:0')

In [38]:
with torch.no_grad():
    item = 9839
    idx = torch.tensor([item], dtype=torch.long, device=device)
    base = base_emb(idx)
tower(base)

tensor([[-0.0082, -0.0239, -0.0034,  ...,  0.0116, -0.0077,  0.0035]],
       device='cuda:0', grad_fn=<DivBackward0>)

In [40]:
new_df = {
    'item_id': [],
    'embedding': []
}

In [41]:
tower.eval()
with torch.no_grad():
    for batch_ids in tqdm(torch.split(torch.tensor(item_ids), 512)):
        batch_ids = batch_ids.to(device)
        base_vecs = base_emb(batch_ids)
        tuned_vecs = tower(base_vecs)  # (bs, D)
        new_df['item_id'] += batch_ids.cpu().tolist()
        # вот тут надо хорошо подумать, хорошо бы, чтобы тюненые эмбеды далеко не убегали от исходного
        # какой-то лосс или мягкая сумма (1 - alpha) * ... + alpha * ... 
        # или сделать, как в той статье, где обучали вектор смещения дельта
        # или почитать, как сделали в PLUM
        new_df['embedding'] += tuned_vecs.cpu().tolist()

100%|██████████| 24/24 [00:01<00:00, 17.54it/s]


In [42]:
tuned_embeddings_output_path = '../data/Beauty/tuned_content_embeddings.pkl'

In [43]:
with open(tuned_embeddings_output_path, 'wb') as f:
    pickle.dump(new_df, f, protocol=pickle.HIGHEST_PROTOCOL)