In [None]:
%%bash
# Download the archive named after DATASET and unpack it under ./work_dir.
DATASET="wiki10-31k"
ARCHIVE="${DATASET}-bert.tar.gz"
ENCODER_DIR="./work_dir/xr-transformer-encoder"

# -nc (skip the download when the archive already exists) and -q (quiet) are
# passed as separate options: bundled as "-ncq", wget treats "q" as part of the
# "-n" option group and rejects it.
wget -nc -q "https://archive.org/download/xr-transformer-demos/${ARCHIVE}"

mkdir -p "${ENCODER_DIR}"
tar -zxf "./${ARCHIVE}" -C "${ENCODER_DIR}"

# List everything that was extracted.
find "${ENCODER_DIR}"/*

In [2]:
import logging
import numpy as np
from pecos.utils import smat_util, logging_util
DATASET = "wiki10-31k"
# set logging level to WARNING(1)
# you can change this to INFO(2) or DEBUG(3) if you'd like to see more logging
LOGGER = logging.getLogger(__name__)
logging_util.setup_logging_config(level=1)

# load training data: feature matrix, label matrix and raw text (one entry per line)
X_feat_trn = smat_util.load_matrix(f"xmc-base/{DATASET}/tfidf-attnxml/X.trn.npz", dtype=np.float32)
Y_trn = smat_util.load_matrix(f"xmc-base/{DATASET}/Y.trn.npz", dtype=np.float32)

# The encoding is given explicitly so the text is decoded the same way on every
# platform instead of following the locale default (matches how params.json is
# opened later in this notebook).
with open(f"xmc-base/{DATASET}/X.trn.txt", "r", encoding="utf-8") as fin:
    X_txt_trn = [line.strip() for line in fin]

# load test data (same layout as the training split)
X_feat_tst = smat_util.load_matrix(f"xmc-base/{DATASET}/tfidf-attnxml/X.tst.npz", dtype=np.float32)
Y_tst = smat_util.load_matrix(f"xmc-base/{DATASET}/Y.tst.npz", dtype=np.float32)

with open(f"xmc-base/{DATASET}/X.tst.txt", "r", encoding="utf-8") as fin:
    X_txt_tst = [line.strip() for line in fin]

In [None]:
import json
from model import XTransformer
import requests
# Read the XR-Transformer hyper-parameters for this dataset from the local
# ./params directory.
# Remote alternative noted by the original author (requires `requests`):
#   param_url = "https://raw.githubusercontent.com/amzn/pecos/mainline/examples/xr-transformer-neurips21/params/wiki10-31k/bert/params.json"
#   params = json.loads(requests.get(param_url).text)
file_path = f'./params/{DATASET}/bert/params.json'
with open(file_path, mode="r", encoding="utf-8") as f:
    params = json.load(f)

# Convert each JSON section into its typed parameter object.
train_section = params["train_params"]
cur_train_params = XTransformer.TrainParams.from_dict(train_section)
pred_section = params["pred_params"]
cur_pred_params = XTransformer.PredParams.from_dict(pred_section)

# To inspect the complete settings, run:
# print(json.dumps(cur_train_params.to_dict(), indent=True))
# print(json.dumps(cur_pred_params.to_dict(), indent=True))

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Show the training-parameter object built from params["train_params"].
print(cur_train_params)

XTransformer.TrainParams(preliminary_indexer_params=HierarchicalKMeans.TrainParams(nr_splits=16, min_codes=128, max_leaf_size=16, spherical=True, seed=10001, kmeans_max_iter=20, threads=-1, do_sample=False, max_sample_rate=1.0, min_sample_rate=0.1, warmup_ratio=0.4), refined_indexer_params=HierarchicalKMeans.TrainParams(nr_splits=16, min_codes=None, max_leaf_size=16, spherical=True, seed=10001, kmeans_max_iter=20, threads=-1, do_sample=False, max_sample_rate=1.0, min_sample_rate=0.1, warmup_ratio=0.4), matcher_params_chain=[TransformerMatcher.TrainParams(model_shortcut='bert-base-uncased', negative_sampling='tfn+man', loss_function='weighted-squared-hinge', bootstrap_method='weighted-linear', lr_schedule='linear', threshold=0.001, hidden_dropout_prob=0.1, batch_size=32, batch_gen_workers=16, max_active_matching_labels=1000, max_num_labels_in_gpu=65536, max_steps=1000, max_no_improve_cnt=-1, num_train_epochs=10, gradient_accumulation_steps=1, weight_decay=0.0, max_grad_norm=1.0, learnin

## Baseline: XR-Linear

In [None]:
from pecos.xmc import Indexer, LabelEmbeddingFactory
# Create label embeddings with the "pifa" method, then index them using the
# refined indexer settings from the training parameters.
label_embedding = LabelEmbeddingFactory.create(Y_trn, X_feat_trn, method="pifa")
cluster_chain = Indexer.gen(
    label_embedding,
    train_params=cur_train_params.refined_indexer_params,
)

# Display the matrices that make up the resulting chain.
print(cluster_chain.chain)

[<8x1 sparse matrix of type '<class 'numpy.float32'>'
	with 8 stored elements in Compressed Sparse Column format>, <128x8 sparse matrix of type '<class 'numpy.float32'>'
	with 128 stored elements in Compressed Sparse Column format>, <2048x128 sparse matrix of type '<class 'numpy.float32'>'
	with 2048 stored elements in Compressed Sparse Column format>, <30938x2048 sparse matrix of type '<class 'numpy.float32'>'
	with 30938 stored elements in Compressed Sparse Column format>]


In [None]:
from pecos.xmc import Indexer, LabelEmbeddingFactory
from pecos.xmc.xlinear import XLinearModel

# Step 1: label hierarchy, built from "pifa" label embeddings with the refined
# indexer settings.
label_embedding = LabelEmbeddingFactory.create(Y_trn, X_feat_trn, method="pifa")
cluster_chain = Indexer.gen(
    label_embedding,
    train_params=cur_train_params.refined_indexer_params,
)

# Step 2: fit XR-Linear on the training features using that hierarchy and the
# ranker sections of the train / predict parameters.
xlm = XLinearModel.train(
    X_feat_trn,
    Y_trn,
    C=cluster_chain,
    train_params=cur_train_params.ranker_params,
    pred_params=cur_pred_params.ranker_params,
)

# Step 3: score the test features.
P_xlm = xlm.predict(X_feat_tst)

# Step 4: compare the predictions with the ground-truth test labels.
metrics = smat_util.Metrics.generate(Y_tst, P_xlm)
print("Evaluation metrics of XR-Linear model", metrics, sep="\n")

## XR-Transformer without fine-tuning

In [None]:
# Bundle the training text, labels and feature matrix into one problem object.
from module import MLProblemWithText
prob = MLProblemWithText(X_txt_trn, Y_trn, X_feat=X_feat_trn)

# Switch fine-tuning off so the pre-trained encoder is used as-is.
cur_train_params.do_fine_tune = False

# Train XR-Transformer without fine-tuning.
# Expect this to be slow on a CPU-only machine.
xrt_pretrained = XTransformer.train(
    prob,
    train_params=cur_train_params,
    pred_params=cur_pred_params,
)

# Predict on the test split and evaluate against the ground truth.
P_xrt_pretrained = xrt_pretrained.predict(X_txt_tst, X_feat=X_feat_tst)
metrics = smat_util.Metrics.generate(Y_tst, P_xrt_pretrained)
print("Evaluation metrics of XR-Transformer (not fine-tuned)", metrics, sep="\n")

## Train the XR-Transformer model

In [6]:
# Number of rows in the training label matrix.
print(Y_trn.shape[0])

14146


In [None]:
# define the problem
from module import MLProblemWithText
from pecos.xmc import Indexer, LabelEmbeddingFactory

prob = MLProblemWithText(X_txt_trn, Y_trn, X_feat=X_feat_trn)

# Build the label hierarchy with preliminary_indexer_params, i.e. the indexer
# settings that belong to the XR-Transformer training parameters.
cluster_chain = Indexer.gen(
    LabelEmbeddingFactory.create(Y_trn, X_feat_trn, method="pifa"),
    train_params=cur_train_params.preliminary_indexer_params,
)
print(cluster_chain.chain)

# NOTE(review): the recorded run of this cell printed a chain with 4 matrices
# and then stopped with
#   ValueError: len(params.matcher_params_chain)=3 != 4
# so the hierarchy depth and the number of matcher parameter entries apparently
# disagree. TODO: reconcile the indexer settings with params.json before
# relying on this cell.

DO_FINE_TUNE_NOW = True

if DO_FINE_TUNE_NOW:
    cur_train_params.do_fine_tune = True
else:
    # skip fine-tuning and use the existing fine-tuned encoder; the path follows
    # DATASET so it stays in sync with the archive extracted by the download cell
    cur_train_params.do_fine_tune = False
    cur_train_params.matcher_params_chain[0].init_model_dir = (
        f"./work_dir/xr-transformer-encoder/{DATASET}/bert/text_encoder"
    )

# this will be slow on CPU only machine
xrt_fine_tuned = XTransformer.train(
    prob,
    clustering=cluster_chain,
    train_params=cur_train_params,
    pred_params=cur_pred_params,
)

# predict on the test split and evaluate against the ground truth
P_xrt_fine_tuned = xrt_fine_tuned.predict(X_txt_tst, X_feat=X_feat_tst)
metrics = smat_util.Metrics.generate(Y_tst, P_xrt_fine_tuned, topk=10)
print("Evaluation metrics of XR-Transformer")
print(metrics)

[<8x1 sparse matrix of type '<class 'numpy.float32'>'
	with 8 stored elements in Compressed Sparse Column format>, <128x8 sparse matrix of type '<class 'numpy.float32'>'
	with 128 stored elements in Compressed Sparse Column format>, <2048x128 sparse matrix of type '<class 'numpy.float32'>'
	with 2048 stored elements in Compressed Sparse Column format>, <30938x2048 sparse matrix of type '<class 'numpy.float32'>'
	with 30938 stored elements in Compressed Sparse Column format>]


ValueError: len(params.matcher_params_chain)=3 != 4

In [None]:
# Save the trained model to disk, drop the in-memory copy, then load it back
# from the saved folder with is_predict_only=True.
model_folder = "./work_dir/my_xrt"
xrt_fine_tuned.save(model_folder)
del xrt_fine_tuned
xrt_fine_tuned = XTransformer.load(model_folder, is_predict_only=True)