To run this proof of concept (POC):
1. Download and decompress the preprocessed TenRec dataset: https://drive.google.com/file/d/1OW6zIk2jUOyYiugr4fNDOCiYF86Nz4-O/view?usp=sharing  

2. Run a Merlin PyTorch Docker container, adjusting the mounted volume paths (`-v`) to your environment:
```bash
docker run --runtime=nvidia --rm -it --ipc=host --cap-add SYS_NICE -v /home/gmoreira/projects/nvidia/nvidia_merlin/:/merlin_dev/ -v /mnt/nvme0n1/datasets:/data -p 8888:8888 nvcr.io/nvidia/merlin/merlin-pytorch:23.06 /bin/bash
```

3. Inside the container, pull the latest `main` branch of the Merlin Models repository (`/models`) and install it with pip, without reinstalling its dependencies:
```bash
cd /models
git config remote.origin.fetch "+refs/heads/*:refs/remotes/origin/*" && git fetch && git checkout main && pip install . --no-deps
```

4. Start the Jupyter notebook server:
```bash
jupyter notebook --no-browser --ip 0.0.0.0 --allow-root
```

In [1]:
import os

In [2]:
# Expose only the first GPU to this process. This has to happen before torch
# initialises CUDA, which is why it runs ahead of the torch imports.
os.environ.update(CUDA_VISIBLE_DEVICES="0")

In [3]:
from torch import nn
import torch

In [4]:
from pytorch_lightning.loggers import WandbLogger
import pytorch_lightning as pl

  from .autonotebook import tqdm as notebook_tqdm


In [5]:
import merlin.models.torch as mm
from merlin.dataloader.torch import Loader
from merlin.io.dataset import Dataset
from merlin.schema import ColumnSchema

  warn(f"Tensorflow dtype mappings did not load successfully due to an error: {exc.msg}")


In [6]:
# Hyperparameters for this run; the whole dict is also sent to W&B when the
# logger is created.
config = dict(
    epochs=1,
    batch_size=16384,
    embedding_dim=64,
    bottom_mlp=[64],
    top_mlp=[256, 128, 64],
    dropout=0.18033334331720113,
    l2_reg=1e-6,  # commented-out alternative: 3.5665386015190466e-05
    LR=0.001,
    # LR_decay_factor=0.98,
    # LR_decay_steps=100,
    positive_class_weight=9.0,  # loss weight for click == 1 examples
    seed=42,  # global RNG seed so weight init / dropout are repeatable across runs
)

# Seed python, numpy and torch (plus dataloader workers) before the model is built.
# NOTE(review): this presumably does not control any shuffling done inside the
# Merlin Loader itself, and some CUDA kernels remain non-deterministic — confirm.
pl.seed_everything(config["seed"], workers=True);

In [7]:
# Lightning logger that forwards training/validation metrics to Weights & Biases.
wandb_logger = WandbLogger()
# `.experiment` is the underlying wandb run (created lazily on first access);
# record the hyperparameters on it so each run is self-describing.
wandb_run = wandb_logger.experiment
wandb_run.config.update(config)

[34m[1mwandb[0m: Currently logged in as: [33mgspmoreira[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [8]:
# Root of the preprocessed TenRec ranking split. Override it with the
# TENREC_DATA_DIR environment variable instead of editing absolute paths here.
DATA_DIR = os.environ.get("TENREC_DATA_DIR", "/data/tenrec/tenrec_preproc_v01/ranking")

train_dataset = Dataset(os.path.join(DATA_DIR, "train", "*.parquet"))
valid_dataset = Dataset(os.path.join(DATA_DIR, "valid", "*.parquet"))
# Schema of the training data; it is what DLRMModel builds its inputs from.
schema = train_dataset.schema

In [9]:
def bce_with_class_weight(pos_weight, neg_weight=1.0):
    """Build a class-weighted binary cross-entropy loss.

    Each element's BCE term is multiplied by ``pos_weight`` where the target is 1
    and by ``neg_weight`` where it is 0 (soft targets interpolate linearly between
    the two). The weighted terms are then averaged.

    Args:
        pos_weight: weight applied to positive (target == 1) examples.
        neg_weight: weight applied to negative (target == 0) examples.

    Returns:
        A callable ``loss(pred, target)`` that returns a scalar tensor. ``pred`` must
        already hold probabilities in [0, 1] (``nn.BCELoss`` applies no sigmoid)
        and must have the same shape as ``target``.
    """
    bce_loss = nn.BCELoss(reduction="none")  # keep per-element terms so they can be weighted

    def weighted_bce(pred, target):
        # BCELoss needs a floating-point target matching pred's dtype; casting lets
        # integer/bool label tensors be passed straight through.
        target = target.to(pred.dtype)
        weights = target * pos_weight + (1 - target) * neg_weight
        return (weights * bce_loss(pred, target)).mean()

    return weighted_bce

In [10]:
# Binary head predicting the "click" column, trained with the class-weighted BCE
# so positives count config["positive_class_weight"] times as much as negatives.
click_loss = bce_with_class_weight(pos_weight=config["positive_class_weight"])
output_block = mm.BinaryOutput(ColumnSchema("click"), loss=click_loss)

# Bottom and top MLP stacks share the same dropout rate.
bottom_mlp_block = mm.MLPBlock(config["bottom_mlp"], dropout=config["dropout"])
top_mlp_block = mm.MLPBlock(config["top_mlp"], dropout=config["dropout"])

model = mm.DLRMModel(
    schema,
    dim=config["embedding_dim"],  # embedding size for every categorical feature
    bottom_block=bottom_mlp_block,
    top_block=top_mlp_block,
    output_block=output_block,
)



In [11]:
# Show the module tree (embedding tables per feature, MLP blocks, output head) as a sanity check.
model

DLRMModel(
  (0): DLRMBlock(
    (0): DLRMInputBlock(
      (categorical): EmbeddingTables(
          (pre): Block(
            (0): SelectKeys(user_id, item_id, video_category, gender, age)
          )
          (branches): (
            (user_id): Block(
              (0): EmbeddingTable(
                features: user_id
                (table): Embedding(2633852, 64)
              )
            )
            (item_id): Block(
              (0): EmbeddingTable(
                features: item_id
                (table): Embedding(179281, 64)
              )
            )
            (video_category): Block(
              (0): EmbeddingTable(
                features: video_category
                (table): Embedding(6, 64)
              )
            )
            (gender): Block(
              (0): EmbeddingTable(
                features: gender
                (table): Embedding(6, 64)
              )
            )
            (age): Block(
              (0): EmbeddingTable(
     

In [12]:
# Adam with L2 regularisation through `weight_decay` (coupled L2, unlike AdamW's
# decoupled decay). NOTE(review): presumably the Merlin model's Lightning hooks
# pick this up from the `optimizer` attribute — confirm.
optimizer_kwargs = dict(lr=config["LR"], weight_decay=config["l2_reg"])
model.optimizer = torch.optim.Adam(model.parameters(), **optimizer_kwargs)

# Optional step-wise LR decay (needs the LR_decay_* keys enabled in `config`):
# model.scheduler = torch.optim.lr_scheduler.StepLR(
#     model.optimizer, step_size=config["LR_decay_steps"], gamma=config["LR_decay_factor"]
# )

In [13]:
%%time
trainer = pl.Trainer(devices=1, max_epochs=config["epochs"], #max_steps=1000, 
                    val_check_interval=1000, limit_val_batches=100,  #check_val_every_n_epoch=None,
                    logger=wandb_logger
                    )

with Loader(train_dataset, batch_size=config["batch_size"]) as train_loader, \
      Loader(valid_dataset, batch_size=config["batch_size"]) as valid_loader:
    model.initialize(train_loader)
    trainer.fit(model, train_loader, #val_dataloaders=valid_loader
               )

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn(
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name   | Type       | Params
--------------------------------------
0 | values | ModuleList | 180 M 
--------------------------------------
180 M     Trainable params
0         Non-trainable params
180 M     Total params
720.344   Total estimated model params size (MB)


Epoch 0: 100%|██████████████████| 9229/9229 [06:10<00:00, 24.93it/s, v_num=rm6f]

`Trainer.fit` stopped: `max_epochs=1` reached.


Epoch 0: 100%|██████████████████| 9229/9229 [06:13<00:00, 24.70it/s, v_num=rm6f]
CPU times: user 7min 9s, sys: 36.3 s, total: 7min 45s
Wall time: 6min 16s


In [14]:
# Latest metric values sent to the logger during `fit`.
# NOTE(review): these may be last-step values rather than epoch averages — confirm.
trainer.logged_metrics

{'train_loss': tensor(1.4949),
 'train_binary_accuracy': tensor(0.4096),
 'train_binary_auroc': tensor(0.7442),
 'train_binary_precision': tensor(0.3254),
 'train_binary_recall': tensor(0.9713)}

In [15]:
%%time
metrics = trainer.validate(model, valid_loader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Validation DataLoader 0: 100%|███████████████| 100/100 [00:00<00:00, 108.10it/s]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
     Validate metric           DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
   val_binary_accuracy      0.41141846776008606
    val_binary_auroc        0.7365943789482117
  val_binary_precision      0.34194207191467285
    val_binary_recall       0.9782023429870605
        val_loss            1.5347142219543457
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
CPU times: user 1.95 s, sys: 712 ms, total: 2.66 s
Wall time: 2.56 s
