## todo
- negatives. : done
- logQ correction : divay
- postional encoding : arshi
- time feature : done
- pl model : done
- pl dataset : done
- trainer : done
- make random for only max_positives : divay
- encoder to decoder
- positive mask loss removal : done

# Data loader

In [0]:
import os
import mlflow
import gc
import pickle 
from time import time
from google.cloud import storage
import math
import datetime as datetime
import sys
import psutil
from pyspark.sql.window import Window
import pyspark.sql.types as T
import pyspark.sql.functions as F

import pyarrow.parquet as pq
import pyarrow.dataset as pads

import pandas as pd
import numpy as np
import random

import torch
import torch.nn as nn
from torch.utils.data import DistributedSampler, DataLoader, Dataset, IterableDataset

import lightning.pytorch as pl
from lightning import Trainer

from transformers import get_cosine_schedule_with_warmup

storage_client = storage.Client()


bucket = storage_client.get_bucket('gcs-dsci-fryou-fy-dev-prd')

2025-02-20 11:22:20.067590: I tensorflow/core/util/port.cc:111] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-02-20 11:22:20.120933: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-02-20 11:22:20.120959: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-02-20 11:22:20.120984: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-02-20 11:22:20.131150: I tensorflow/core/platform/cpu_feature_g

# Functions

In [0]:
def get_equally_distributed_files(files, world_size):
    total_files = len(files)
    print("total files inside fn: ", total_files)
    end = (total_files // world_size) * world_size
    print("value of end inside fn: ", end)
    files = files[:end]
    return files

In [0]:
def get_index(index_path, entity, special_tokens, shift_index=True):
    bucket.blob(blob_name=f"{index_path}{entity}_index.pkl").download_to_filename(
        f"{entity}_index.pkl"
    )
    with open(f"{entity}_index.pkl", "rb") as file:
        entity_index = pickle.load(file)
    if shift_index:
        special_tokens_len = len(special_tokens)
        entity_index = special_tokens | {k:v+special_tokens_len for k,v in entity_index.items()}
    return entity_index, len(entity_index)

# Dataset

In [0]:
class SeqDataSet(IterableDataset):
    def __init__(self, files, behaviour_context_length, all_dense_perc, future_window, max_positives, negatives, batch_size):
        super().__init__()
        self.files = files  
        self.behaviour_context_length = behaviour_context_length
        self.all_dense_perc = all_dense_perc
        self.future_window = future_window
        self.max_positives = max_positives
        self.negatives = negatives
        self.batch_size = batch_size
        self.global_negs = torch.randint(0, len(catalog_id_index), (behaviour_context_length, negatives))
        self.rand_inbatch_negs = np.random.randint(0, self.max_positives * self.batch_size, (self.negatives,))

    def __len__(self):
        return len(self.files) * 100000 // self.batch_size

    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:
            total_workers = 1
            worker_id = 0
        else:
            total_workers = worker_info.num_workers
            worker_id = worker_info.id
        
        for file in self.files[worker_id::total_workers]:
            parr_data = pads.dataset(source=file.astype(str), format="parquet")
            for batch in parr_data.to_batches(batch_size=100000):
                i = 0
                df = batch.to_pandas()
                df_size = df.shape[0]
                while i < df_size:
                    data = df[i:i + self.batch_size]
                    batch_series = data.apply(lambda x: self.create_seq(x.interaction_list), axis=1)
                    __batch_size = len(data)
                    batch_random = torch.randint(1, 10000, (1,))
                    global_negs = (self.global_negs * batch_random) % (catalog_id_vocab_size - 2) + 2
                    global_negs = global_negs.unsqueeze(0).tile((__batch_size, 1, 1))

                    batch_series_dict = {
                        "_positives": np.concatenate(batch_series.map(lambda x: x['_positives']).to_list(), axis=0),
                        "_len_pos": np.array(batch_series.map(lambda x: x['_len_pos']).to_list())
                    }

                    inbatch_negs = self.generate_inbatch_negatives(batch_series_dict, __batch_size)
                    negatives = torch.cat((inbatch_negs, global_negs), dim=-1)

                    selection_indices = torch.tensor(np.array(batch_series.map(lambda x: x['selection_indices']).to_list()), dtype=torch.long)
                    batch_indices = torch.arange(selection_indices.shape[0]).view(-1, 1)
                    negatives = negatives[batch_indices, selection_indices, :]

                    yield {
                        "behaviour_seq": torch.tensor(np.array(batch_series.map(lambda x: x['behaviour_seq']).to_list()), dtype=torch.long),
                        "event_type_seq": torch.tensor(np.array(batch_series.map(lambda x: x['event_type_seq']).to_list()), dtype=torch.long),
                        "event_time_seq": torch.tensor(np.array(batch_series.map(lambda x: x['event_time_seq']).to_list()), dtype=torch.long),
                        "positives": torch.tensor(np.array(batch_series.map(lambda x: x['positives']).to_list()), dtype=torch.long),
                        "mask": torch.tensor(np.array(batch_series.map(lambda x: x['mask']).to_list()), dtype=torch.bool),
                        "selection_indices": selection_indices,
                        "negatives": negatives
                    }

                    i = i + self.batch_size
    
    def create_seq(self, interaction_seq):
        event_times = [datetime.datetime.fromtimestamp(ele['event_time'] / 1000000000) for ele in interaction_seq]
        end_time = event_times[-1]
        start_time = end_time - datetime.timedelta(days=self.future_window)
        all_event = [event for event in interaction_seq if datetime.datetime.fromtimestamp(event['event_time'] / 1000000000) < start_time]
        all_event = all_event[-1 * int(self.behaviour_context_length):]

        catalog_id_seq = np.array([catalog_id_index.get(event['catalog_id'], catalog_id_index['<MASK>']) for event in all_event])
        event_type_seq = np.array([event_type_index[event['event_type']] for event in all_event])
        event_time_seq = np.array([event['event_time'] / 1000000000 for event in all_event])

        mask = np.zeros(len(catalog_id_seq), dtype=bool)
        num_positives = min(int(len(catalog_id_seq) * self.all_dense_perc), self.max_positives - 1)
        mask[:num_positives] = True
        np.random.shuffle(mask)
        mask[-1] = True
        num_masking = np.sum(mask)

        all_positives = np.array([catalog_id_index.get(event['catalog_id'], catalog_id_index['<MASK>']) for event in interaction_seq if datetime.datetime.fromtimestamp(event['event_time'] / 1000000000) >= start_time])

        _positives = all_positives[np.random.randint(0, len(all_positives), num_masking)]

        seq_len = len(catalog_id_seq)

        if seq_len < self.behaviour_context_length:
            pad_num = self.behaviour_context_length - seq_len
            catalog_id_seq = np.pad(catalog_id_seq, (pad_num, 0))
            event_type_seq = np.pad(event_type_seq, (pad_num, 0))
            mask = np.pad(mask, (pad_num, 0))
            event_time_seq = np.pad(event_time_seq, (pad_num, 0))

        positive_pad_num = self.max_positives - num_masking
        selection_indices = np.pad(np.where(mask == 1)[0], (positive_pad_num, 0))
        positives = np.pad(_positives, (positive_pad_num, 0))
            
        return {
            "behaviour_seq": catalog_id_seq, 
            "event_type_seq": event_type_seq,
            "event_time_seq": event_time_seq,
            "selection_indices": selection_indices,
            "positives": positives,
            "mask": mask,
            "_positives": _positives,
            "_len_pos": num_masking,
        }

    def generate_inbatch_negatives(self, batch, _batch_size):
        _positives = batch['_positives']
        _lens = batch['_len_pos']
        batch_size = _batch_size  
    
        inbatch_negs_num = self.negatives
        inbatch_negs = np.zeros((batch_size, inbatch_negs_num))
        
        start = 0
        indices = np.arange(len(_positives))
        for i in range(batch_size):
            exclude_range = np.delete(indices, np.arange(start, start + _lens[i]))
            exclude_range_len = len(exclude_range)
            inbatch_negs[i] = _positives[exclude_range][self.rand_inbatch_negs % exclude_range_len]
            start += _lens[i]

        inbatch_negs = torch.tensor(inbatch_negs, dtype=torch.long).unsqueeze(1).tile((1, self.behaviour_context_length, 1))

        return inbatch_negs

In [0]:
class PinnerFormerDataset(pl.LightningDataModule):

    def __init__(self, train_files, val_files, behaviour_context_length, all_dense_perc, future_window, max_positives, negatives, batch_size, num_workers):
        super().__init__()
        self.train_files = train_files  
        self.val_files = val_files  
        self.behaviour_context_length = behaviour_context_length
        self.all_dense_perc = all_dense_perc
        self.future_window = future_window
        self.max_positives = max_positives
        self.negatives = negatives
        self.batch_size = batch_size
        self.num_workers = num_workers

    def setup(self, stage=None):
        if self.trainer:
            start = self.trainer.global_rank
            step = self.trainer.world_size
        else:
            start = 0
            step = 1
        if stage == "fit" or stage is None:
            self.train_set = SeqDataSet(
                self.train_files[start::step],
                self.behaviour_context_length,
                self.all_dense_perc,
                self.future_window,
                self.max_positives,
                self.negatives,
                self.batch_size,

            )

            self.val_set = SeqDataSet(
                self.val_files[start::step],
                self.behaviour_context_length,
                self.all_dense_perc,
                self.future_window,
                self.max_positives,
                self.negatives,
                self.batch_size,
            )
    
    def train_dataloader(self):
        return DataLoader(self.train_set, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=True, persistent_workers=True, pin_memory=True)

    def val_dataloader(self):
        return DataLoader(self.val_set, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False, persistent_workers=True, pin_memory=True)

In [0]:
# mask = np.zeros(10, dtype=bool)
# num_positives = min(4,8)
# mask[:num_positives] = True
# np.random.shuffle(mask)
# mask[-1]=True
# num_masking = np.sum(mask)

In [0]:
# behaviour_context_length = 1500
# all_dense_perc = 0.5
# future_window = 28
# max_positives = 640
# negatives = 1000
# batch_size = 512
# num_workers = 1
# ds = SeqDataSet(train_files, behaviour_context_length, all_dense_perc, future_window, max_positives, negatives, batch_size)

In [0]:
# batch_data = next(iter(ds))

In [0]:
# dataset = PinnerFormerDataset(train_files, val_files, behaviour_context_length, all_dense_perc, future_window, max_positives, negatives, batch_size, num_workers)

In [0]:
# dataset.setup("fit")

In [0]:
# next(iter(dataset.train_dataloader()))

# Model

In [0]:
class CatalogEmbedding(nn.Module):
    def __init__(
        self, 
        catalog_id_vocab_size,
        embedding_dim,
        use_pretrained=False,
        pretrained_weights=None,
    ):
        super().__init__()

        self.catalog_id_vocab_size = catalog_id_vocab_size
        self.embedding_dim = embedding_dim

        self.catalog_embedding = nn.Embedding(num_embeddings=catalog_id_vocab_size, embedding_dim=embedding_dim, padding_idx=0)
        if use_pretrained and pretrained_weights is not None:
            self.catalog_embedding.weight = nn.Parameter(pretrained_weights, requires_grad=False)

    def forward(self, X):
        x = self.catalog_embedding(X)
        return x

In [0]:
class UserTower(nn.Module):
    def __init__(self,
                 num_layers,
                 nhead,
                 behaviour_context_length,
                 embedding_dim,
                 event_type_vocab_size
                ):
        super(UserTower, self).__init__()

        self.nhead = nhead
        dropout_perc = 0.3 
        self.behaviour_context_length = behaviour_context_length
        self.embedding_dim = embedding_dim

        # Event type embedding
        self.event_type_embedding = nn.Embedding(event_type_vocab_size, embedding_dim)
        num_features = 2
        time_features = 0
        self.input_dim = num_features * embedding_dim + time_features

        # Transformer layers
        self.dropout = nn.Dropout(dropout_perc)
        self.norm = nn.LayerNorm(embedding_dim)
        encoder_layer = nn.TransformerEncoderLayer(embedding_dim, nhead=nhead, batch_first=True, dropout=dropout_perc, norm_first=True)
        self.tf_encoder = nn.TransformerEncoder(encoder_layer, num_layers, norm=self.norm)

        # MLP head
        self.mlp_extract = nn.Linear(self.input_dim, embedding_dim)

        self.mlp_out = nn.Sequential(
            nn.LayerNorm(embedding_dim),
            nn.Linear(embedding_dim, 4 * embedding_dim),
            nn.GELU(),
            nn.Linear(4 * embedding_dim, embedding_dim),
            nn.LayerNorm(embedding_dim)
        )

        self.future_mask = torch.triu(torch.ones(self.behaviour_context_length, self.behaviour_context_length, dtype=torch.bool), diagonal=1)
        self.register_buffer('future_mask_const', self.future_mask)
        self.register_buffer('seq_diag_const', ~torch.diag(torch.ones(self.behaviour_context_length, dtype=torch.bool)))

        div_term = torch.exp(torch.arange(0, embedding_dim, 2) * -(torch.log(torch.tensor(10000.0)) / embedding_dim))
        self.register_buffer('div_term', div_term)

    def merge_attn_masks(self, X):
        batch_size = X["behaviour_seq"].shape[0]
        seq_len = X["behaviour_seq"].shape[1]

        padding = (X["behaviour_seq"] == 0)
        masking = (X["behaviour_seq"] == 1)
        padding_mask = padding.logical_or(masking)

        padding_mask_broadcast = padding_mask.bool().unsqueeze(1)
        future_masks = torch.tile(self.future_mask_const[:seq_len, :seq_len], (batch_size, 1, 1))

        merged_masks = torch.logical_or(padding_mask_broadcast, future_masks)
        
        diag_masks = torch.tile(self.seq_diag_const[:seq_len, :seq_len], (batch_size * self.nhead, 1, 1))
        multi_head_mask = torch.tile(merged_masks.unsqueeze(1), (1, self.nhead, 1, 1)).reshape((-1, seq_len, seq_len))
        
        return torch.logical_and(diag_masks, multi_head_mask)

    def positional_encoding(self, timestamps, embedding_dim=16):
        pe = torch.stack((torch.sin(timestamps.unsqueeze(-1) * self.div_term), torch.cos(timestamps.unsqueeze(-1) * self.div_term)), dim=-1).flatten(start_dim=-2)
        return pe

    def forward(self, X, catalog_embedding):
        pe = self.positional_encoding(X['event_time_seq'], embedding_dim=self.embedding_dim)

        x = torch.cat((
            catalog_embedding(X['behaviour_seq']), 
            self.event_type_embedding(X['event_type_seq'])
        ), dim=-1)

        x = self.mlp_extract(x) + pe

        attn_mask = self.merge_attn_masks(X)

        x = self.tf_encoder(x, mask=attn_mask)

        batch_indices = torch.arange(X['selection_indices'].shape[0]).view(-1, 1)
        x = x[batch_indices, X['selection_indices'], :]

        x = self.mlp_out(x)

        x = F.normalize(x, p=2, dim=-1)

        return x

In [0]:
class CatalogTower(nn.Module):
    def __init__(self, embedding_dim):
        super(CatalogTower, self).__init__()

        self.mlp = nn.Sequential(
            nn.LayerNorm(embedding_dim),
            nn.Linear(embedding_dim, 4 * embedding_dim),
            nn.GELU(),
            nn.Linear(4 * embedding_dim, embedding_dim),
            nn.LayerNorm(embedding_dim)
        )
    
    def forward(self, x, catalog_embedding):
        x = catalog_embedding(x)
        x = self.mlp(x)
        x = F.normalize(x, p=2, dim=-1)
        return x

In [0]:
# class PinnerFormer(pl.LightningModule):
#     def __init__(self, user_tower, catalog_tower, catalog_embedding, batch_size, max_positives):
#         super().__init__()
#         self.user_tower = user_tower
#         self.catalog_tower = catalog_tower
#         self.catalog_embedding = catalog_embedding
#         self.criterion = nn.CrossEntropyLoss(reduction='none')
#         self.lr = 0.001
#         self.temperature = 0.1
#         self.max_positives = max_positives
#         self.register_buffer('pos_labels', torch.ones((batch_size*max_positives), dtype=torch.long))
    
#     def training_step(self, batch, batch_idx):
#         batch_size = batch['behaviour_seq'].shape[0]
#         context_length = batch['behaviour_seq'].shape[1]
        
#         catalog_embeddings = self.catalog_embedding(batch['behaviour_seq'])
#         user_embedding = self.user_tower(batch, self.catalog_embedding)

#         candidates = torch.cat([batch['positives'].unsqueeze(-1), batch['negatives']], dim=-1)
#         candidates_embedding = self.catalog_tower(candidates, self.catalog_embedding)
    
#         loss = self.embedding_loss(batch, user_embedding, candidates_embedding)
#         self.log('train_loss', loss.item(), on_step=True, on_epoch=True)
        
#         return loss
        
    
#     def validation_step(self, batch, batch_idx):
#         batch_size = batch['behaviour_seq'].shape[0]
#         context_length = batch['behaviour_seq'].shape[1]
        
#         catalog_embeddings = self.catalog_embedding(batch['behaviour_seq'])
#         user_embedding = self.user_tower(batch, self.catalog_embedding)

#         candidates = torch.cat([batch['positives'].unsqueeze(-1), batch['negatives']], dim=-1)
#         candidates_embedding = self.catalog_tower(candidates, self.catalog_embedding)
    
#         loss = self.embedding_loss(batch, user_embedding, candidates_embedding)
#         self.log('val_loss', loss.item(), on_step=True, on_epoch=True)
    
#     def configure_optimizers(self):
#        optimizer = torch.optim.AdamW([
#            {"params": self.user_tower.parameters()}, 
#            {"params": self.catalog_tower.parameters()},
#            {"params": self.catalog_embedding.parameters()}
#         ], lr=self.lr)        
#        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.9)
       
#        return {
#             'optimizer': optimizer,
#             'lr_scheduler': {
#                 'scheduler': scheduler,
#                 'interval': 'step',
#                 'frequency': 500, 
#                 "name": 'lr'
#             }
#         }
    
#     def embedding_loss(self, batch, user_embedding, candidates_embedding):
#         scores = torch.matmul(user_embedding.unsqueeze(2), candidates_embedding.transpose(-2, -1)).squeeze(-2)
#         scores = scores.reshape((-1, scores.shape[-1])) / self.temperature
#         non_padding = (batch["positives"] != 0).logical_or(batch["positives"] != 1).reshape(-1)
#         loss = (self.criterion(scores, self.pos_labels[:batch["positives"].shape[0]*self.max_positives]) * non_padding).sum() / non_padding.sum()
#         return loss

In [0]:
class PinnerFormer(pl.LightningModule):
    def __init__(self, user_tower, catalog_tower, catalog_embedding, batch_size, max_positives):
        super().__init__()
        self.user_tower = user_tower
        self.catalog_tower = catalog_tower
        self.catalog_embedding = catalog_embedding
        self.criterion = nn.CrossEntropyLoss(reduction='none')
        self.lr = 0.001
        self.temperature = 0.1
        self.max_positives = max_positives
        self.register_buffer('pos_labels', torch.ones((batch_size * max_positives), dtype=torch.long))
    
    def training_step(self, batch, batch_idx):
        batch_size = batch['behaviour_seq'].shape[0]
        context_length = batch['behaviour_seq'].shape[1]
        
        catalog_embeddings = self.catalog_embedding(batch['behaviour_seq'])
        user_embedding = self.user_tower(batch, self.catalog_embedding)

        candidates = torch.cat([batch['positives'].unsqueeze(-1), batch['negatives']], dim=-1)
        candidates_embedding = self.catalog_tower(candidates, self.catalog_embedding)
    
        loss = self.embedding_loss(batch, user_embedding, candidates_embedding)

        self.log('train_loss', loss.item(), rank_zero_only=True)

        return loss
    
    def validation_step(self, batch, batch_idx):
        batch_size = batch['behaviour_seq'].shape[0]
        context_length = batch['behaviour_seq'].shape[1]
        
        catalog_embeddings = self.catalog_embedding(batch['behaviour_seq'])
        user_embedding = self.user_tower(batch, self.catalog_embedding)

        candidates = torch.cat([batch['positives'].unsqueeze(-1), batch['negatives']], dim=-1)
        candidates_embedding = self.catalog_tower(candidates, self.catalog_embedding)
    
        loss = self.embedding_loss(batch, user_embedding, candidates_embedding)

        self.log('val_loss', loss.item(), rank_zero_only=True)
    
    def configure_optimizers(self):
        optimizer = torch.optim.AdamW([
            {"params": self.user_tower.parameters()}, 
            {"params": self.catalog_tower.parameters()},
            {"params": self.catalog_embedding.parameters()}
        ], lr=self.lr)        
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.9, verbose=True, last_epoch=-1)
       
        return {
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': scheduler,
                'interval': 'step',
                'frequency': 500, 
                "name": 'lr'
            }
        }
    
    def cosine_sim(self, a, b, eps=1e-8):
        a_n, b_n = a.norm(dim=-1).unsqueeze(-1), b.norm(dim=-1).unsqueeze(-1)
        a_norm = a / torch.clamp(a_n, min=eps)
        b_norm = b / torch.clamp(b_n, min=eps)
        sim_mt = torch.matmul(a_norm.unsqueeze(2), b_norm.transpose(-2, -1)).squeeze(2)
        return sim_mt

    def embedding_loss(self, batch, user_embedding, candidates_embedding):
        scores = self.cosine_sim(user_embedding.unsqueeze(2), candidates_embedding.transpose(-2, -1)).squeeze(-2)
        scores = scores.reshape((-1, scores.shape[-1])) / self.temperature
        non_padding = (batch["positives"] != 0).logical_or(batch["positives"] != 1).reshape(-1)
        loss = (self.criterion(scores, self.pos_labels[:batch["positives"].shape[0] * self.max_positives]) * non_padding).sum() / non_padding.sum()
        return loss

In [0]:
# catalog_embedding = CatalogEmbedding(catalog_id_vocab_size,embedding_dim)

# user_tower = UserTower(
#     num_layers,
#     nhead,
#     behaviour_context_length,
#     embedding_dim,
#     event_type_vocab_size
# )

# catalog_tower = CatalogTower(embedding_dim)

In [0]:
# x = user_tower(batch_data,catalog_embedding)

# Config

In [0]:
db_host = dbutils.notebook.entry_point.getDbutils().notebook().getContext().extraContext().apply('api_url')
db_token = dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiToken().get()

username = 'arshi.khan@meesho.com'
experiment_path = f'/Users/{username}/long-term-embedding'
experiment_id = mlflow.set_experiment(experiment_path).experiment_id
mlflow.start_run(experiment_id = experiment_id, log_system_metrics=True)
mlflow.set_tag(key='model',value='transformer')
run_id = mlflow.active_run().info.run_id

2025/02/20 11:22:28 INFO mlflow.system_metrics.system_metrics_monitor: Started monitoring system metrics.


In [0]:
# mlflow.end_run()

In [0]:
config = {
    # data params 
    "checkpoint_version": "v9.1",
    "index_path": "long_term_user_emb/indexes/v9/",
    "filtered_journeys_path": "long_term_user_emb/user_seq_final/v9/",
    "checkpoint_path": "gs://gcs-dsci-fryou-fy-dev-prd/long_term_user_emb/checkpoint_dir/v9.1/",

    # dataset params
    "behaviour_context_length": 1000,
    "all_dense_perc": 0.15,
    "future_window": 14,
    "max_positives": 64,
    "negatives": 10,
    "batch_size": 2,
    "num_workers": 1,

    # model params
    "embedding_dim": 16,
    "num_layers": 2,
    "nhead": 2,

    # training params 
    "max_epochs": 100,
    "limit_val_batches": -1,
    "limit_train_batches": 5000,
    "val_check_interval": 500,
    "overfit_batches": 1,
    "log_every_n_steps": 1,
    "accelerator": "gpu",
    "num_gpu": 1,
    "num_nodes": 1,
    "num_sanity_val_steps": 0
}

# config = {
#     # data params 
#     "checkpoint_version": "v10.1",
#     "index_path": "long_term_user_emb/indexes/v10/",
#     "filtered_journeys_path": "long_term_user_emb/user_seq_final/v10/",
#     "checkpoint_path": "gs://gcs-dsci-fryou-fy-dev-prd/long_term_user_emb/checkpoint_dir/v10.1/",

#     # dataset params
#     "behaviour_context_length": 1000,
#     "all_dense_perc": 0.15,
#     "future_window": 14,
#     "max_positives": 64,
#     "negatives": 500,
#     "batch_size": 64,
#     "num_workers": 8,

#     # model params
#     "embedding_dim": 32,
#     "num_layers": 2,
#     "nhead": 2,

#     # training params 
#     "max_epochs": 1,
#     "limit_val_batches": 0,
#     "limit_train_batches": 5000,
#     "val_check_interval": 500,
#     "overfit_batches": 0,
#     "log_every_n_steps": 50,
#     "accelerator": "gpu",
#     "num_gpu": 1,
#     "num_nodes": 1,
#     "num_sanity_val_steps": 2
# }


In [0]:
config['num_gpu'] = torch.cuda.device_count()

In [0]:
mlflow.log_params(config)

In [0]:
config["filtered_journeys_path"]

'long_term_user_emb/user_seq_final/v9/'

In [0]:
files = [ele.name for ele in bucket.list_blobs(prefix=config["filtered_journeys_path"]) if 'parquet' in ele.name]  
files = [f"gs://gcs-dsci-fryou-fy-dev-prd/{ele}" for ele in files]
files = np.array(files).astype(np.string_)

WORLD_SIZE = config["num_nodes"] if config["num_gpu"] == 0 else config["num_gpu"] * config["num_nodes"]

files = get_equally_distributed_files(files,WORLD_SIZE)
train_files = files[:-1*WORLD_SIZE]
val_files = files[-1*WORLD_SIZE:]
print(f"num train files:{len(train_files)}, num val files:{len(val_files)}")

total files inside fn:  154
value of end inside fn:  152
num train files:148, num val files:4


In [0]:
pl.seed_everything(42, workers=True)
SPECIAl_TOKENS = {
    "<PAD>":0,
    "<MASK>":1
}

event_type_index, event_type_vocab_size = get_index(config["index_path"], "event_type", SPECIAl_TOKENS)
catalog_id_index, catalog_id_vocab_size = get_index(config["index_path"], "catalog_id", SPECIAl_TOKENS)

print(f"catalog vocab size:{catalog_id_vocab_size/1000000} M")

INFO: [rank: 0] Seed set to 42
INFO:lightning.fabric.utilities.seed:[rank: 0] Seed set to 42


catalog vocab size:8.40019 M


# Init Model and Dataset

In [0]:
dataset = PinnerFormerDataset(
    train_files,
    val_files,
    config["behaviour_context_length"],
    config["all_dense_perc"],
    config["future_window"],
    config["max_positives"],
    config["negatives"],
    config["batch_size"],
    config["num_workers"],
)


In [0]:
# dataset.setup("fit")
# batch_data = next(iter(dataset.train_dataloader()))

In [0]:
catalog_embedding = CatalogEmbedding(catalog_id_vocab_size,config["embedding_dim"])

user_tower = UserTower(
    config["num_layers"],
    config["nhead"],
    config["behaviour_context_length"],
    config["embedding_dim"],
    event_type_vocab_size
)
catalog_tower = CatalogTower(config["embedding_dim"])

model = PinnerFormer(user_tower, catalog_tower, catalog_embedding, config["batch_size"], config["max_positives"])

# Trainer Test

In [0]:
mlflow.set_experiment(experiment_path)
mlf_logger = pl.loggers.MLFlowLogger(experiment_name=experiment_path, run_id=run_id)

In [0]:
trainer = Trainer(
    precision=16,
    max_epochs=config["max_epochs"],
    limit_val_batches=config["limit_val_batches"],
    limit_train_batches=config["limit_train_batches"],
    val_check_interval = config["val_check_interval"],
    overfit_batches=config["overfit_batches"],
    log_every_n_steps = config["log_every_n_steps"],
    check_val_every_n_epoch=None,
    accelerator=config["accelerator"],
    strategy="auto",
    logger=mlf_logger,
    num_sanity_val_steps = config["num_sanity_val_steps"],
)

/local_disk0/.ephemeral_nfs/cluster_libraries/python/lib/python3.10/site-packages/lightning/fabric/connector.py:563: `precision=16` is supported for historical reasons but its usage is discouraged. Please set your precision to 16-mixed instead!
INFO: Trainer will use only 1 of 4 GPUs because it is running inside an interactive / notebook environment. You may try to set `Trainer(devices=4)` but please note that multi-GPU inside interactive / notebook environments is considered experimental and unstable. Your mileage may vary.
INFO:lightning.pytorch.utilities.rank_zero:Trainer will use only 1 of 4 GPUs because it is running inside an interactive / notebook environment. You may try to set `Trainer(devices=4)` but please note that multi-GPU inside interactive / notebook environments is considered experimental and unstable. Your mileage may vary.
INFO: Using 16bit Automatic Mixed Precision (AMP)
INFO:lightning.pytorch.utilities.rank_zero:Using 16bit Automatic Mixed Precision (AMP)
INFO: GPU

In [0]:
trainer.fit(model, dataset)

INFO: You are using a CUDA device ('NVIDIA L4') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
INFO:lightning.pytorch.utilities.rank_zero:You are using a CUDA device ('NVIDIA L4') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]
INFO:lightning.pytorch.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]
INFO: 
  | Name              | Type             | Params
---------------------------------------------------

Adjusting learning rate of group 0 to 1.0000e-03.
Adjusting learning rate of group 1 to 1.0000e-03.
Adjusting learning rate of group 2 to 1.0000e-03.


[0;31m---------------------------------------------------------------------------[0m
[0;31mValueError[0m                                Traceback (most recent call last)
File [0;32m<command-720585262727022>, line 1[0m
[0;32m----> 1[0m [43mtrainer[49m[38;5;241;43m.[39;49m[43mfit[49m[43m([49m[43mmodel[49m[43m,[49m[43m [49m[43mdataset[49m[43m)[49m

File [0;32m/local_disk0/.ephemeral_nfs/cluster_libraries/python/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py:544[0m, in [0;36mTrainer.fit[0;34m(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)[0m
[1;32m    542[0m [38;5;28mself[39m[38;5;241m.[39mstate[38;5;241m.[39mstatus [38;5;241m=[39m TrainerStatus[38;5;241m.[39mRUNNING
[1;32m    543[0m [38;5;28mself[39m[38;5;241m.[39mtraining [38;5;241m=[39m [38;5;28;01mTrue[39;00m
[0;32m--> 544[0m [43mcall[49m[38;5;241;43m.[39;49m[43m_call_and_handle_interrupt[49m[43m([49m
[1;32m    545[0m [43m    

# Training

In [0]:
callback_checkpoint = pl.callbacks.ModelCheckpoint(
  dirpath=config["checkpoint_path"], 
  filename='{epoch}-{step}-{val_loss_5:.2f}',
  save_top_k = -1
)

[0;31m---------------------------------------------------------------------------[0m
[0;31mValueError[0m                                Traceback (most recent call last)
File [0;32m<command-720585262727022>, line 1[0m
[0;32m----> 1[0m [43mtrainer[49m[38;5;241;43m.[39;49m[43mfit[49m[43m([49m[43mmodel[49m[43m,[49m[43m [49m[43mdataset[49m[43m)[49m

File [0;32m/local_disk0/.ephemeral_nfs/cluster_libraries/python/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py:544[0m, in [0;36mTrainer.fit[0;34m(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)[0m
[1;32m    542[0m [38;5;28mself[39m[38;5;241m.[39mstate[38;5;241m.[39mstatus [38;5;241m=[39m TrainerStatus[38;5;241m.[39mRUNNING
[1;32m    543[0m [38;5;28mself[39m[38;5;241m.[39mtraining [38;5;241m=[39m [38;5;28;01mTrue[39;00m
[0;32m--> 544[0m [43mcall[49m[38;5;241;43m.[39;49m[43m_call_and_handle_interrupt[49m[43m([49m
[1;32m    545[0m [43m    

In [0]:
def main_training_loop(num_tasks, num_proc_per_task):
 
  from torch import optim, nn, utils, Tensor
  from torchvision import datasets, transforms
  from lightning.pytorch.callbacks import LearningRateMonitor
  import pytorch_lightning as pl
  import mlflow
  
  print("running distributed training")
  ############################
  ##### Setting up MLflow ####
  # We need to do this so that different processes that will be able to find mlflow
  os.environ['DATABRICKS_HOST'] = db_host
  os.environ['DATABRICKS_TOKEN'] = db_token
  
  # NCCL P2P can cause issues with incorrect peer settings, so let's turn this off to scale for now
  os.environ["NCCL_SOCKET_IFNAME"] = "eth0"
  os.environ['NCCL_P2P_DISABLE'] = '1'
  os.environ['NCCL_DEBUG']='INFO'
  
  mlflow.set_experiment(experiment_path)
  mlf_logger = pl.loggers.MLFlowLogger(experiment_name=experiment_path, run_id=run_id)
  callback_lr_monitor = LearningRateMonitor(logging_interval='step')
  
  trainer = Trainer(
    precision=16,
    max_epochs=config["max_epochs"],
    limit_val_batches=config["limit_val_batches"],
    limit_train_batches=config["limit_train_batches"],
    val_check_interval = config["val_check_interval"],
    overfit_batches=config["overfit_batches"],
    log_every_n_steps = config["log_every_n_steps"],
    check_val_every_n_epoch=None,
    accelerator=config["accelerator"],
    devices=num_proc_per_task, 
    strategy="ddp",
    logger=mlf_logger,
    num_sanity_val_steps = config["num_sanity_val_steps"],
    callbacks=[callback_checkpoint, callback_lr_monitor]
  )
  trainer.fit(model, dataset)

[0;31m---------------------------------------------------------------------------[0m
[0;31mValueError[0m                                Traceback (most recent call last)
File [0;32m<command-720585262727022>, line 1[0m
[0;32m----> 1[0m [43mtrainer[49m[38;5;241;43m.[39;49m[43mfit[49m[43m([49m[43mmodel[49m[43m,[49m[43m [49m[43mdataset[49m[43m)[49m

File [0;32m/local_disk0/.ephemeral_nfs/cluster_libraries/python/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py:544[0m, in [0;36mTrainer.fit[0;34m(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)[0m
[1;32m    542[0m [38;5;28mself[39m[38;5;241m.[39mstate[38;5;241m.[39mstatus [38;5;241m=[39m TrainerStatus[38;5;241m.[39mRUNNING
[1;32m    543[0m [38;5;28mself[39m[38;5;241m.[39mtraining [38;5;241m=[39m [38;5;28;01mTrue[39;00m
[0;32m--> 544[0m [43mcall[49m[38;5;241;43m.[39;49m[43m_call_and_handle_interrupt[49m[43m([49m
[1;32m    545[0m [43m    

In [0]:
NUM_TASKS = 1
NUM_PROC_PER_TASK = config["num_gpu"]
NUM_PROC = NUM_TASKS * NUM_PROC_PER_TASK

[0;31m---------------------------------------------------------------------------[0m
[0;31mValueError[0m                                Traceback (most recent call last)
File [0;32m<command-720585262727022>, line 1[0m
[0;32m----> 1[0m [43mtrainer[49m[38;5;241;43m.[39;49m[43mfit[49m[43m([49m[43mmodel[49m[43m,[49m[43m [49m[43mdataset[49m[43m)[49m

File [0;32m/local_disk0/.ephemeral_nfs/cluster_libraries/python/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py:544[0m, in [0;36mTrainer.fit[0;34m(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)[0m
[1;32m    542[0m [38;5;28mself[39m[38;5;241m.[39mstate[38;5;241m.[39mstatus [38;5;241m=[39m TrainerStatus[38;5;241m.[39mRUNNING
[1;32m    543[0m [38;5;28mself[39m[38;5;241m.[39mtraining [38;5;241m=[39m [38;5;28;01mTrue[39;00m
[0;32m--> 544[0m [43mcall[49m[38;5;241;43m.[39;49m[43m_call_and_handle_interrupt[49m[43m([49m
[1;32m    545[0m [43m    

In [0]:
from pyspark.ml.torch.distributor import TorchDistributor

TorchDistributor(
    num_processes=NUM_PROC, local_mode=False, use_gpu=True
).run(main_training_loop, NUM_TASKS, NUM_PROC_PER_TASK)


[0;31m---------------------------------------------------------------------------[0m
[0;31mValueError[0m                                Traceback (most recent call last)
File [0;32m<command-720585262727022>, line 1[0m
[0;32m----> 1[0m [43mtrainer[49m[38;5;241;43m.[39;49m[43mfit[49m[43m([49m[43mmodel[49m[43m,[49m[43m [49m[43mdataset[49m[43m)[49m

File [0;32m/local_disk0/.ephemeral_nfs/cluster_libraries/python/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py:544[0m, in [0;36mTrainer.fit[0;34m(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)[0m
[1;32m    542[0m [38;5;28mself[39m[38;5;241m.[39mstate[38;5;241m.[39mstatus [38;5;241m=[39m TrainerStatus[38;5;241m.[39mRUNNING
[1;32m    543[0m [38;5;28mself[39m[38;5;241m.[39mtraining [38;5;241m=[39m [38;5;28;01mTrue[39;00m
[0;32m--> 544[0m [43mcall[49m[38;5;241;43m.[39;49m[43m_call_and_handle_interrupt[49m[43m([49m
[1;32m    545[0m [43m    

In [0]:
# from pyspark.ml.torch.distributor import TorchDistributor

# TorchDistributor(num_processes=NUM_PROC, local_mode=True, use_gpu=True).run(main_training_loop, NUM_TASKS, NUM_PROC_PER_TASK) 

[0;31m---------------------------------------------------------------------------[0m
[0;31mValueError[0m                                Traceback (most recent call last)
File [0;32m<command-720585262727022>, line 1[0m
[0;32m----> 1[0m [43mtrainer[49m[38;5;241;43m.[39;49m[43mfit[49m[43m([49m[43mmodel[49m[43m,[49m[43m [49m[43mdataset[49m[43m)[49m

File [0;32m/local_disk0/.ephemeral_nfs/cluster_libraries/python/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py:544[0m, in [0;36mTrainer.fit[0;34m(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)[0m
[1;32m    542[0m [38;5;28mself[39m[38;5;241m.[39mstate[38;5;241m.[39mstatus [38;5;241m=[39m TrainerStatus[38;5;241m.[39mRUNNING
[1;32m    543[0m [38;5;28mself[39m[38;5;241m.[39mtraining [38;5;241m=[39m [38;5;28;01mTrue[39;00m
[0;32m--> 544[0m [43mcall[49m[38;5;241;43m.[39;49m[43m_call_and_handle_interrupt[49m[43m([49m
[1;32m    545[0m [43m    