Docker build with pure Python 3.11 (DataSphere)
```dockerfile
FROM ubuntu:22.04
ENV DEBIAN_FRONTEND=noninteractive

RUN apt-get update && apt-get install -y --no-install-recommends \
      ca-certificates \
      python3.11 python3.11-venv python3.11-distutils \
      python3-pip \
    && ln -sf /usr/bin/python3.11 /usr/bin/python3 \
    && ln -sf /usr/bin/pip3 /usr/bin/pip \
    && python3 -m pip install -U pip \
    && rm -rf /var/lib/apt/lists/*

RUN useradd -ms /bin/bash --uid 1000 jupyter
USER jupyter
WORKDIR /home/jupyter
```

In [1]:
!python3 --version

Python 3.11.0rc1
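
The interpreter reports `3.11.0rc1`, which is a release candidate rather than a final 3.11 build. A small check along these lines can surface that in scripts. This is only a sketch: the regular expression and the name `parse_python_version` are illustrative assumptions, not part of the project.

```python
import re

# Matches strings such as "Python 3.11.0rc1" or "Python 3.11.9".
VERSION_RE = re.compile(r"Python (\d+)\.(\d+)\.(\d+)((?:a|b|rc)\d+)?")

def parse_python_version(text):
    """Return (major, minor, micro, prerelease_tag_or_None) from `python3 --version` output."""
    match = VERSION_RE.search(text)
    if match is None:
        raise ValueError(f"not a Python version string: {text!r}")
    major, minor, micro, pre = match.groups()
    return int(major), int(minor), int(micro), pre

major, minor, micro, pre = parse_python_version("Python 3.11.0rc1")
print((major, minor, micro), "pre-release" if pre else "final")  # (3, 11, 0) pre-release
```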


In [2]:
%pip install murmurhash==1.0.13
%pip install numpy==2.1.2
%pip install tensorboard==2.20.0
%pip install transformers==4.50.3

Defaulting to user installation because normal site-packages is not writeable

In [3]:
%pip install torch==2.5.1+cu121 --index-url https://download.pytorch.org/whl/cu121

Defaulting to user installation because normal site-packages is not writeable
Looking in indexes: https://download.pytorch.org/whl/cu121


In [4]:
import json
from functools import partial

import torch
from torch.utils.data import DataLoader

from modeling import utils
from modeling.dataloader import BatchProcessor
from modeling.dataset import Dataset
from modeling.loss import IdentityLoss
from modeling.metric import NDCGSemanticMetric, RecallSemanticMetric
from modeling.models import TigerModel, CorrectItemsLogitsProcessor
from modeling.trainer import Trainer
from modeling.utils import parse_args, create_logger, fix_random_seed



In [5]:
def parse_config(config_path):
    with open(config_path) as f:
        params = json.load(f)
    
    return params

In [6]:
config = parse_config('configs/tiger_kmeans_train_config.json')
print('Training config: \n{}'.format(json.dumps(config, indent=2)))

Training config: 
{
  "experiment_name": "tiger_beauty_kmeans",
  "dataset": {
    "inter_json_path": "../data/Beauty/inter.json",
    "index_json_path": "../data/Beauty/index_rqkmeans.json",
    "num_codebooks": 4,
    "max_sequence_length": 20,
    "sampler_type": "tiger"
  },
  "dataloader": {
    "train_batch_size": 256,
    "validation_batch_size": 256
  },
  "model": {
    "embedding_dim": 128,
    "codebook_size": 256,
    "num_positions": 20,
    "user_ids_count": 2000,
    "num_heads": 6,
    "num_encoder_layers": 4,
    "num_decoder_layers": 4,
    "dim_feedforward": 1024,
    "d_kv": 64,
    "dropout": 0.1,
    "activation": "relu",
    "num_beams": 100,
    "top_k": 20,
    "layer_norm_eps": 1e-06,
    "initializer_range": 0.02
  },
  "optimizer": {
    "lr": 0.0003
  },
  "train_epochs_num": 20
}


In [7]:
SEED_VALUE = 42
fix_random_seed(SEED_VALUE)

In [8]:
print('Current DEVICE: {}'.format(utils.DEVICE))

Current DEVICE: cuda


In [9]:
dataset = Dataset.create(
    inter_json_path=config['dataset']['inter_json_path'],
    max_sequence_length=config['dataset']['max_sequence_length'],
    sampler_type=config['dataset']['sampler_type'],
    is_extended=True
)
train_sampler, validation_sampler, test_sampler = dataset.get_samplers()
num_codebooks = config['dataset']['num_codebooks']
user_ids_count = config['model']['user_ids_count']

[2025-10-25 07:20:42] [INFO]: Train dataset size: 131413
[2025-10-25 07:20:42] [INFO]: Validation dataset size: 22363
[2025-10-25 07:20:42] [INFO]: Test dataset size: 22363
[2025-10-25 07:20:42] [INFO]: Max item id: 12100
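
The training loader in the next cell uses `drop_last=True` with a `train_batch_size` of 256, so the 131,413 training samples give a fixed number of batches per epoch. The arithmetic below is a sketch that assumes one trainer step equals one batch. Under that assumption it agrees with the training log, where epoch 1 starts right after the evaluation at step 512.

```python
train_size = 131_413   # "Train dataset size" from the log above
batch_size = 256       # config['dataloader']['train_batch_size']
epochs = 20            # config['train_epochs_num']

# drop_last=True discards the final partial batch of every epoch.
steps_per_epoch = train_size // batch_size
dropped_samples = train_size - steps_per_epoch * batch_size
total_steps = steps_per_epoch * epochs

print(steps_per_epoch, dropped_samples, total_steps)  # 513 85 10260
```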


In [10]:
batch_processor = BatchProcessor.create(
    config['dataset']['index_json_path'], num_codebooks, user_ids_count
)

train_dataloader = DataLoader(
    dataset=train_sampler,
    batch_size=config['dataloader']['train_batch_size'],
    drop_last=True,
    shuffle=True,
    collate_fn=batch_processor
)

validation_dataloader = DataLoader(
    dataset=validation_sampler,
    batch_size=config['dataloader']['validation_batch_size'],
    drop_last=False,
    shuffle=False,
    collate_fn=batch_processor
)

eval_dataloader = DataLoader(
    dataset=test_sampler,
    batch_size=config['dataloader']['validation_batch_size'],
    drop_last=False,
    shuffle=False,
    collate_fn=batch_processor
)

model = TigerModel(
    embedding_dim=config['model']['embedding_dim'],
    codebook_size=config['model']['codebook_size'],
    sem_id_len=num_codebooks,
    user_ids_count=user_ids_count,
    num_positions=config['model']['num_positions'],
    num_heads=config['model']['num_heads'],
    num_encoder_layers=config['model']['num_encoder_layers'],
    num_decoder_layers=config['model']['num_decoder_layers'],
    dim_feedforward=config['model']['dim_feedforward'],
    num_beams=config['model']['num_beams'],
    num_return_sequences=config['model']['top_k'],
    activation=config['model']['activation'],
    d_kv=config['model']['d_kv'],
    dropout=config['model']['dropout'],
    layer_norm_eps=config['model']['layer_norm_eps'],
    initializer_range=config['model']['initializer_range'],
    logits_processor=partial(
        CorrectItemsLogitsProcessor,
        config['dataset']['num_codebooks'],
        config['model']['codebook_size'],
        config['dataset']['index_json_path'],
        config['model']['num_beams']
    )
).to(utils.DEVICE)
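
Each item is addressed by `num_codebooks` = 4 tokens from codebooks of size 256, so the decoder could emit 256^4 distinct token sequences, far more than the catalogue holds (max item id 12100). A logits processor such as `CorrectItemsLogitsProcessor` is presumably there to keep beam search on sequences that map to real items. The sketch below illustrates that general idea with a prefix lookup. It is an assumption about the technique, not the repository's implementation, and `build_prefix_index` and `allowed_next_tokens` are hypothetical names.

```python
from collections import defaultdict

def build_prefix_index(semantic_ids):
    """Map every proper prefix of a semantic ID to the set of tokens that may follow it."""
    index = defaultdict(set)
    for sem_id in semantic_ids:
        for pos, token in enumerate(sem_id):
            index[tuple(sem_id[:pos])].add(token)
    return index

def allowed_next_tokens(index, prefix):
    """Tokens that extend `prefix` towards at least one known item (empty set if none)."""
    return index.get(tuple(prefix), set())

# Toy catalogue: three items, each a 4-token semantic ID (values are made up).
catalogue = [(3, 17, 200, 0), (3, 17, 45, 1), (9, 2, 2, 0)]
index = build_prefix_index(catalogue)

print(sorted(allowed_next_tokens(index, [])))       # [3, 9]
print(sorted(allowed_next_tokens(index, [3, 17])))  # [45, 200]
print(256 ** 4)                                     # 4294967296 possible sequences
```

During constrained decoding, the logits of every token outside the allowed set would be masked (for example set to negative infinity) before the beam step.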

In [11]:
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f'Overall parameters: {total_params:,}')
print(f'Trainable parameters: {trainable_params:,}')

Overall parameters: 5,236,352
Trainable parameters: 5,236,352


In [12]:
loss_function = IdentityLoss(
    predictions_prefix='loss',
    output_prefix='loss'
)  # Passes through the loss computed inside the model without modification

optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=config['optimizer']['lr'],
)

codebook_size = config['model']['codebook_size']
ranking_metrics = {
    'ndcg@5': NDCGSemanticMetric(5, codebook_size, num_codebooks),
    'ndcg@10': NDCGSemanticMetric(10, codebook_size, num_codebooks),
    'ndcg@20': NDCGSemanticMetric(20, codebook_size, num_codebooks),
    'recall@5': RecallSemanticMetric(5, codebook_size, num_codebooks),
    'recall@10': RecallSemanticMetric(10, codebook_size, num_codebooks),
    'recall@20': RecallSemanticMetric(20, codebook_size, num_codebooks)
}

print('Everything is ready for training process!')

Everything is ready for training process!
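
The metric objects above are configured for cutoffs 5, 10 and 20. As a reference point, here is a minimal sketch of recall@k and NDCG@k for the common next-item setting with exactly one ground-truth item per user. The single-target setup and the function names are assumptions. `RecallSemanticMetric` and `NDCGSemanticMetric` operate on semantic-ID tokens and may differ in detail.

```python
import math

def recall_at_k(ranked_items, target, k):
    """1.0 if the target appears among the top-k ranked items, else 0.0."""
    return 1.0 if target in ranked_items[:k] else 0.0

def ndcg_at_k(ranked_items, target, k):
    """With a single relevant item the ideal DCG is 1, so NDCG = 1 / log2(rank + 1), rank from 1."""
    top_k = ranked_items[:k]
    if target not in top_k:
        return 0.0
    rank = top_k.index(target) + 1
    return 1.0 / math.log2(rank + 1)

ranked = [42, 7, 13, 99, 5]
print(recall_at_k(ranked, 13, 5))  # 1.0
print(ndcg_at_k(ranked, 13, 5))    # 0.5 (rank 3 -> 1 / log2(4))
print(ndcg_at_k(ranked, 13, 2))    # 0.0 (target is outside the top 2)
```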


In [13]:
config.get('train_epochs_num')

20
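
The trainer cell that follows reads several settings with `config.get`. The config printed earlier defines `train_epochs_num` but neither `train_steps_num` nor `early_stopping_threshold`. The snippet below reproduces what those lookups evaluate to, using a trimmed copy of the config under the hypothetical name `trimmed_config`. How `Trainer` interprets a `step_cnt` of `None` is not shown in this notebook.

```python
# Trimmed copy of the printed config; only the keys relevant here are kept.
trimmed_config = {"experiment_name": "tiger_beauty_kmeans", "train_epochs_num": 20}

epoch_cnt = trimmed_config.get("train_epochs_num")                     # key present
step_cnt = trimmed_config.get("train_steps_num")                       # key absent, no default
epochs_threshold = trimmed_config.get("early_stopping_threshold", 40)  # key absent, default used

print(epoch_cnt, step_cnt, epochs_threshold)  # 20 None 40
```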

In [None]:
trainer = Trainer(
    experiment_name=config['experiment_name'],
    train_dataloader=train_dataloader,
    validation_dataloader=validation_dataloader,
    eval_dataloader=eval_dataloader,
    model=model,
    optimizer=optimizer,
    loss_function=loss_function,
    ranking_metrics=ranking_metrics,
    epoch_cnt=config.get('train_epochs_num'),
    step_cnt=config.get('train_steps_num'),
    best_metric='ndcg@20',
    epochs_threshold=config.get('early_stopping_threshold', 40),
    valid_step=256,
    eval_step=256
)

best_checkpoint = trainer.train()
trainer.save()

print('Training finished!')

[2025-10-25 07:20:55] [DEBUG]: Start training...
[2025-10-25 07:20:55] [DEBUG]: Start epoch 0
[2025-10-25 07:20:57] [DEBUG]: Running validation on step 0...
[2025-10-25 07:22:25] [DEBUG]: Running validation on step 0 is done!
[2025-10-25 07:22:25] [DEBUG]: Running eval on step 0...
[2025-10-25 07:23:52] [DEBUG]: Running eval on step 0 is done!


0 0.000757793589679488 0 0.0009387740354357149


[2025-10-25 07:24:09] [DEBUG]: Running validation on step 256...
[2025-10-25 07:27:03] [DEBUG]: Running eval on step 256 is done!


256 0.0027599836157562643 0.000757793589679488 0.0027145292450510335


[2025-10-25 07:27:19] [DEBUG]: Running validation on step 512...
[2025-10-25 07:28:46] [DEBUG]: Running validation on step 512 is done!
[2025-10-25 07:28:46] [DEBUG]: Running eval on step 512...
[2025-10-25 07:30:13] [DEBUG]: Running eval on step 512 is done!
[2025-10-25 07:30:13] [DEBUG]: Start epoch 1


512 0.009556818176604807 0.0027599836157562643 0.007540107550984796


[2025-10-25 07:30:30] [DEBUG]: Running validation on step 768...
[2025-10-25 07:31:57] [DEBUG]: Running validation on step 768 is done!
[2025-10-25 07:31:57] [DEBUG]: Running eval on step 768...
[2025-10-25 07:33:23] [DEBUG]: Running eval on step 768 is done!


768 0.018546506602585125 0.009556818176604807 0.015173930251171374


[2025-10-25 07:33:40] [DEBUG]: Running validation on step 1024...
[2025-10-25 07:35:06] [DEBUG]: Running validation on step 1024 is done!
[2025-10-25 07:35:06] [DEBUG]: Running eval on step 1024...
[2025-10-25 07:36:33] [DEBUG]: Running eval on step 1024 is done!


1024 0.022572664778784723 0.018546506602585125 0.018555083167375436


[2025-10-25 07:36:33] [DEBUG]: Start epoch 2
[2025-10-25 07:36:49] [DEBUG]: Running validation on step 1280...
[2025-10-25 07:38:15] [DEBUG]: Running validation on step 1280 is done!
[2025-10-25 07:38:15] [DEBUG]: Running eval on step 1280...
[2025-10-25 07:39:42] [DEBUG]: Running eval on step 1280 is done!


1280 0.02380148114794211 0.022572664778784723 0.02088523586265027


[2025-10-25 07:39:58] [DEBUG]: Running validation on step 1536...
[2025-10-25 07:58:37] [DEBUG]: Running eval on step 2816 is done!
[2025-10-25 07:58:53] [DEBUG]: Running validation on step 3072...
[2025-10-25 08:00:20] [DEBUG]: Running validation on step 3072 is done!
[2025-10-25 08:00:20] [DEBUG]: Running eval on step 3072...
[2025-10-25 08:01:47] [DEBUG]: Running eval on step 3072 is done!


3072 0.03383879984290838 0.03103755620515491 0.028773254263925033


[2025-10-25 08:01:47] [DEBUG]: Start epoch 6
[2025-10-25 08:02:03] [DEBUG]: Running validation on step 3328...
[2025-10-25 08:03:30] [DEBUG]: Running validation on step 3328 is done!
[2025-10-25 08:03:30] [DEBUG]: Running eval on step 3328...
[2025-10-25 08:11:29] [DEBUG]: Running validation on step 4096...
[2025-10-25 08:12:55] [DEBUG]: Running validation on step 4096 is done!
[2025-10-25 08:12:55] [DEBUG]: Running eval on step 4096...
[2025-10-25 08:14:21] [DEBUG]: Running eval on step 4096 is done!


4096 0.0352996068688074 0.03383879984290838 0.02918270947332113


[2025-10-25 08:14:22] [DEBUG]: Start epoch 8
[2025-10-25 08:14:38] [DEBUG]: Running validation on step 4352...
[2025-10-25 08:16:03] [DEBUG]: Running validation on step 4352 is done!
[2025-10-25 08:16:03] [DEBUG]: Running eval on step 4352...
[2025-10-25 08:17:29] [DEBUG]: Running eval on step 4352 is done!
[2025-10-25 08:17:45] [DEBUG]: Running validation on step 4608...
[2025-10-25 08:19:11] [DEBUG]: Running validation on step 4608 is done!
[2025-10-25 08:19:11] [DEBUG]: Running eval on step 4608...
[2025-10-25 08:20:37] [DEBUG]: Running eval on step 4608 is done!
[2025-10-25 08:20:37] [DEBUG]: Start epoch 9
[2025-10-25 08:20:53] [DEBUG]: Running validation on step 4864...
[2025-10-25 08:22:19] [DEBUG]: Running validation on step 4864 is done!
[2025-10-25 08:22:19] [DEBUG]: Running eval on step 4864...
[2025-10-25 08:23:44] [DEBUG]: Running eval on step 4864 is done!


4864 0.03675197100058836 0.0352996068688074 0.030686815992635397


[2025-10-25 08:24:00] [DEBUG]: Running validation on step 5120...
[2025-10-25 08:25:26] [DEBUG]: Running validation on step 5120 is done!
[2025-10-25 08:25:26] [DEBUG]: Running eval on step 5120...
[2025-10-25 08:26:52] [DEBUG]: Running eval on step 5120 is done!
[2025-10-25 08:26:53] [DEBUG]: Start epoch 10
[2025-10-25 08:27:08] [DEBUG]: Running validation on step 5376...
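
The timestamps in the log show where the wall-clock time goes. The helper below is a sketch that pairs each `Running ... on step N...` line with its `is done!` counterpart and reports the duration in seconds. The name `phase_durations` and the regular expression are hypothetical and not part of the repository.

```python
import re
from datetime import datetime

LINE_RE = re.compile(
    r"\[(?P<ts>\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2})\] \[DEBUG\]: "
    r"Running (?P<phase>validation|eval) on step (?P<step>\d+)(?P<done> is done!|\.\.\.)"
)

def phase_durations(log_text):
    """Return {(phase, step): seconds} for every start/done pair found in the log."""
    started, durations = {}, {}
    for line in log_text.splitlines():
        match = LINE_RE.search(line)
        if match is None:
            continue  # e.g. "Start epoch" lines or metric printouts
        key = (match["phase"], int(match["step"]))
        ts = datetime.strptime(match["ts"], "%Y-%m-%d %H:%M:%S")
        if match["done"] == "...":
            started[key] = ts
        elif key in started:
            durations[key] = (ts - started.pop(key)).total_seconds()
    return durations

sample = """\
[2025-10-25 07:20:57] [DEBUG]: Running validation on step 0...
[2025-10-25 07:22:25] [DEBUG]: Running validation on step 0 is done!
[2025-10-25 07:22:25] [DEBUG]: Running eval on step 0...
[2025-10-25 07:23:52] [DEBUG]: Running eval on step 0 is done!
"""
print(phase_durations(sample))
# {('validation', 0): 88.0, ('eval', 0): 87.0}
```

In this run a validation pass plus an eval pass takes close to three minutes, while the 256 training steps between them take about 16 seconds (for example 07:23:52 to 07:24:09), so most of the wall-clock time is spent on evaluation.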


In [16]:
print('Final evaluation is being performed...')

trainer.load(best_checkpoint)
trainer.eval()

Final evaluation is being performed...


[2025-10-25 06:51:05] [DEBUG]: Running eval on step 0...
[2025-10-25 06:52:31] [DEBUG]: Running eval on step 0 is done!


ndcg@5 0.0064544741835421676
ndcg@10 0.006818608953440727
ndcg@20 0.007540107550984796
recall@5 0.008988060635871752
recall@10 0.01010597862540804
recall@20 0.012967848678620936


Evaluation from a saved checkpoint file

In [15]:
from pathlib import Path

checkpoint_path = Path(f"../checkpoints/{config['experiment_name']}_final_state.pth")
state_dict = torch.load(checkpoint_path, map_location=utils.DEVICE)

trainer.load(state_dict)
trainer.eval()

[2025-10-25 09:51:13] [DEBUG]: Running eval on step 0...
[2025-10-25 09:52:38] [DEBUG]: Running eval on step 0 is done!


ndcg@5 0.022584518243720333
ndcg@10 0.028540282975111846
ndcg@20 0.03560548419336809
recall@5 0.035147341591020884
recall@10 0.05374949693690471
recall@20 0.08187631355363771
