In [39]:
import os
import pickle
import sys

import pandas as pd
import torch
from transformers import BertConfig

sys.path.append(os.path.abspath('../'))
from models.modules import TabFormerBertLM, TabFormerBertModel, TabFormerHierarchicalLM
from dataset.vocab import Vocabulary

def load_pretrained_model(model_dir, vocab_path, return_vocab=False):
    """Build a TabFormerHierarchicalLM from a saved config and load its weights.

    Args:
        model_dir: Directory containing ``config.json`` and ``pytorch_model.bin``.
        vocab_path: Path to the pickled vocabulary passed to the model constructor.
        return_vocab: If True, also return the loaded vocabulary. Defaults to
            False, which keeps the original single-value return.

    Returns:
        The model in eval mode, or ``(model, vocab)`` when ``return_vocab`` is True.
    """
    # NOTE(review): pickle.load and torch.load can execute arbitrary code —
    # only point these at trusted, locally produced files.
    with open(vocab_path, "rb") as vocab_file:
        vocab = pickle.load(vocab_file)

    # Step 1: Load the config
    config = BertConfig.from_json_file(os.path.join(model_dir, "config.json"))

    # Step 2: Initialize the model (the constructor needs the vocab)
    model = TabFormerHierarchicalLM(config, vocab)

    # Step 3: Load the weights onto CPU so loading works without a GPU
    model_weights_path = os.path.join(model_dir, "pytorch_model.bin")
    state_dict = torch.load(model_weights_path, map_location=torch.device("cpu"))
    model.load_state_dict(state_dict)

    model.eval()  # Set model to evaluation mode
    return (model, vocab) if return_vocab else model

In [90]:
def get_cls_embedding(model, input_tokens, vocab=None):
    """Return the position-0 (CLS slot) output vector for one token-ID sequence.

    Args:
        model: Callable (typically a torch ``nn.Module``) whose output's first
            element is indexed as ``[:, 0, :]``, i.e. assumed (batch, seq, hidden).
        input_tokens: List of integer token IDs for a single example.
        vocab: Unused; kept so existing call sites keep working.

    Returns:
        1-D NumPy array taken from position 0 of ``outputs[0]`` for the single
        batch row.
    """
    # NOTE(review): for an LM head, outputs[0] may be prediction scores rather
    # than hidden states — confirm which tensor this model returns first.

    # Build the batch on the model's device so GPU-resident models also work;
    # fall back to CPU for parameter-free callables.
    params = getattr(model, "parameters", None)
    first_param = next(params(), None) if callable(params) else None
    device = first_param.device if first_param is not None else torch.device("cpu")

    input_tensor = torch.tensor([input_tokens], dtype=torch.long, device=device)  # Add batch dimension

    with torch.no_grad():
        outputs = model(input_tensor)

    cls_embedding = outputs[0][:, 0, :]
    # .cpu() is a no-op for CPU tensors and required before .numpy() on CUDA tensors.
    return cls_embedding.squeeze(0).cpu().numpy()


In [40]:
# Root of the Cosmetics data; override with the COSMETICS_DATA_PATH environment variable.
# os.path.join(..., "") guarantees a trailing separator, since some cells append to data_path directly.
data_path = os.path.join(
    os.environ.get("COSMETICS_DATA_PATH", "/data/IDEA_DeFi_Research/Data/eCommerce/Cosmetics/"),
    "",
)

exp_name = "debug"
model_path = os.path.join(data_path, "preprocessed", "output", exp_name, "final-model")
vocab_path = os.path.join(data_path, "preprocessed", "vocab", "vocab_ob")

model = load_pretrained_model(model_path, vocab_path)

In [86]:
def load_trans_rids_columns_from_pkl(pkl_path):
    """Read tokenized sequences, their RID lists, and column names from a pickle.

    The pickle must be a mapping with "trans", "RIDs" and "columns" keys; any
    other keys are ignored.

    Returns:
        (transactions, transaction_ids, columns) — the values stored under
        "trans", "RIDs" and "columns" respectively.

    Raises:
        KeyError: if one of the three keys is missing.
        AssertionError: if "trans" and "RIDs" have different lengths.
    """
    # NOTE(review): pickle.load can execute arbitrary code; load trusted files only.
    with open(pkl_path, "rb") as handle:
        payload = pickle.load(handle)

    trans, rids, cols = payload["trans"], payload["RIDs"], payload["columns"]

    # One RID list per sequence is required to pair embeddings with IDs later.
    assert len(trans) == len(rids), "Mismatch between transactions and RIDs"
    return trans, rids, cols

# Convert token sequences to embeddings while associating each token with its field.
def _tokens_to_ids(tokens, columns, vocab):
    """Map each token to its vocab ID under the matching column; None if any token is unknown."""
    input_ids = []
    has_missing = False
    for token, field_name in zip(tokens, columns):
        try:
            input_ids.append(vocab.get_id(token, field_name))
        except KeyError:
            # Report every missing token before giving up on the sequence.
            print(f"Token '{token}' not found in vocab under field '{field_name}'")
            has_missing = True
    return None if has_missing else input_ids


def process_and_embed(pkl_path, model, vocab):
    """Embed every valid sequence in a pickle, keyed by the last entry of its RID list.

    Args:
        pkl_path: Pickle with "trans", "RIDs" and "columns" keys
            (read via load_trans_rids_columns_from_pkl).
        model: Model passed through to get_cls_embedding.
        vocab: Object exposing get_id(token, field_name); a KeyError from it
            marks the token as unknown.

    Returns:
        dict mapping each sequence's last RID to its position-0 embedding.
        Sequences sharing the same last RID overwrite earlier entries.

    A sequence is skipped (with a message) when its RID list is empty, when
    its token count differs from len(columns), or when any token is unknown.
    A sequence with an unknown token is dropped whole, because removing only
    that token would shift every later token onto the wrong field.
    """
    # NOTE(review): only sequences with exactly len(columns) tokens are embedded.
    # If "trans" holds multi-transaction windows, confirm this check is intended.
    # NOTE(review): in the inspected test pickle the RID values are non-integer
    # floats, so dict keys rely on float equality — confirm they are identifiers.
    transactions, transaction_ids, columns = load_trans_rids_columns_from_pkl(pkl_path)

    embeddings_dict = {}  # {last_transaction_id: CLS embedding}
    for i, (tokens, rids) in enumerate(zip(transactions, transaction_ids)):
        if len(rids) == 0:
            print(f"Skipping sequence {i}: no transaction IDs.")
            continue
        last_transaction_id = rids[-1]

        if len(tokens) != len(columns):
            print(f"Skipping sequence {i}: Token and column count mismatch.")
            continue

        input_ids = _tokens_to_ids(tokens, columns, vocab)
        if input_ids is None:
            print(f"Skipping sequence {i}: token(s) missing from vocab.")
            continue

        if input_ids:  # empty only when there are no columns at all
            embeddings_dict[last_transaction_id] = get_cls_embedding(model, input_ids, vocab)

    return embeddings_dict

In [91]:
data_to_encode_path = os.path.join(
    data_path, "preprocessed", "preprocessed", "transactions_user_time_test.user.pkl"
)

# Load the vocabulary the model was built with; no earlier cell defines `vocab`
# at module level, so a fresh kernel would otherwise fail here.
with open(vocab_path, "rb") as vocab_file:
    vocab = pickle.load(vocab_file)

cls_embeddings_with_last_id = process_and_embed(data_to_encode_path, model, vocab)

# Show one example, if any sequence was embedded
first_item = next(iter(cls_embeddings_with_last_id.items()), None)
if first_item is not None:
    trans_id, embedding = first_item
    print(f"Last Transaction ID: {trans_id}, CLS Embedding Shape: {embedding.shape}")

# One row per sequence: transaction_id, then one column per embedding dimension
df = (
    pd.DataFrame.from_dict(cls_embeddings_with_last_id, orient="index")
    .reset_index()
    .rename(columns={"index": "transaction_id"})
)

# Save to CSV
df.to_csv("cls_embeddings_with_ids.csv", index=False)

RuntimeError: mat1 and mat2 shapes cannot be multiplied (44x64 and 2880x2880)

In [48]:
# Raw contents of the pickle being encoded, loaded for inspection.
# It is a dict (not a list of tuples) with 'trans', 'labels', 'RIDs' and 'columns' keys.
with open(data_to_encode_path, "rb") as f:
    data = pickle.load(f)

missing_keys = {"trans", "RIDs", "columns"} - set(data)
assert not missing_keys, f"Pickle is missing expected keys: {sorted(missing_keys)}"

# NOTE(review): these placeholders are not populated in any visible cell; remove if unused.
transaction_ids = []
token_sequences = []


In [61]:
# Show the top-level keys of the loaded pickle.
data.keys()

dict_keys(['trans', 'labels', 'RIDs', 'columns'])

In [83]:
# Label lists of the first five sequences (bare last expression → rich display).
data["labels"][:5]

[[13.98140797394406], [14.901737125956725, 14.901737125956725, 14.901737125956725, 14.901737125956725, 14.901737125956725, 14.901737125956725, 14.901737125956725, 14.901737125956725, 14.901737125956725, 14.901737125956725, 14.901737125956725, 14.901737125956725, 14.901737125956725, 14.901737125956725, 14.901737125956725, 14.901737125956725, 14.901737125956725, 14.901737125956725, 14.901737125956725, 14.901737125956725, 14.901737125956725, 14.901737125956725, 14.901737125956725, 14.901737125956725], [15.354779622007266, 15.354779622007266, 15.354779622007266, 15.354779622007266, 15.354779622007266, 15.354779622007266, 15.354779622007266], [15.40303266391096], [16.13551475364073]]


In [85]:
# RID lists of the first five sequences (bare last expression → rich display).
# NOTE(review): these values are non-integer floats (e.g. 15.0167), so they may not
# be raw transaction IDs — confirm before using them as embedding keys.
data["RIDs"][:5]

[[15.016693875286428], [14.81294763443448, 14.81295648615924, 14.812960543173569, 14.813022871597544, 14.813067494887552, 14.813082983497893, 14.81308999017138, 14.813143460536365, 14.813197296773918, 14.813226794853042, 14.813255554642437, 14.81329353104582, 14.813383120147126, 14.81348302280663, 14.813490395301178, 14.813526151128691, 14.81352762556836, 14.81353978961266, 14.813558219700884, 14.813560799886131, 14.81356522304535, 14.813566328832096, 14.81368722080863, 14.814296986473032], [13.46599847083523, 14.828855765655373, 14.82887028565504, 14.828872826633305, 14.828879723541766, 14.828881538509773, 14.828885168435898], [14.949156279710778], [14.945504581790887]]
