From ae86ffcd50edcef889dc7256dd786f0f75ffb6d2 Mon Sep 17 00:00:00 2001 From: stbaione Date: Fri, 12 Dec 2025 21:14:17 +0000 Subject: [PATCH] - Add file and instructions for validating the ROCm accuracy against a tabular classification transformer model --- README.md | 136 +++++++ examples/tab_transform_pytorch/client.py | 337 ++++++++++++++++++ examples/tab_transform_pytorch/config.pbtxt | 33 ++ .../generate_reference.py | 167 +++++++++ examples/tab_transform_pytorch/model.py | 148 ++++++++ .../tab_transform_pytorch/requirements.txt | 2 + 6 files changed, 823 insertions(+) create mode 100644 examples/tab_transform_pytorch/client.py create mode 100644 examples/tab_transform_pytorch/config.pbtxt create mode 100644 examples/tab_transform_pytorch/generate_reference.py create mode 100644 examples/tab_transform_pytorch/model.py create mode 100644 examples/tab_transform_pytorch/requirements.txt diff --git a/README.md b/README.md index 8d867a2d..1d8c1ab4 100644 --- a/README.md +++ b/README.md @@ -89,6 +89,13 @@ any C++ code. - [Model Instance Kind](#model-instance-kind) - [Auto-complete config](#auto-complete-config) - [Custom Metrics](#custom-metrics-1) +- [Validate Tabular Accuracy for ROCm](#validate-tabular-accuracy-for-rocm) + - [Setup](#setup) + - [Collect CPU Reference Outputs](#collect-cpu-reference-outputs) + - [Copy artifacts to triton-server repository](#copy-artifacts-to-triton-server-repository) + - [Add dependency to triton-server container](#add-dependency-to-triton-server-container) + - [Start the Triton Server](#start-the-triton-server) + - [Test the model](#test-the-model) - [Running with Inferentia](#running-with-inferentia) - [Logging](#logging) - [Reporting problems, asking questions](#reporting-problems-asking-questions) @@ -1737,6 +1744,135 @@ The example shows how to use custom metrics API in Python Backend. You can find the complete example instructions in [examples/custom_metrics](examples/custom_metrics/README.md). +# Validate Tabular Accuracy for ROCm + +This section shows how to validate the accuracy of the `python_backend` ROCm +enabled implementation for the [FTTransformer](https://github.com/lucidrains/tab-transformer-pytorch/blob/main/README.md#ft-transformer) model, which is a classification transformer for tabular data. + +You can find the reference paper for the model from the link above. + +For this experiment, we generate random weights and inputs, then collect the outputs from +running the `FTTransformer` model on CPU. + +We then re-use those same weights and inputs, and run the `FTTransformer` model +on ROCm, via the [Triton Inference Server](https://github.com/ROCm/triton-inference-server-server/tree/rocm_python_backend). + +We then compare the outputs from CPU and ROCm, and calculate the accuracy. + +## Setup + +You will need to have two terminals open for this test. + +1. Clone the [Triton Inference Server](https://github.com/ROCm/triton-inference-server-server/tree/rocm_python_backend) repository. + +``` +git clone https://github.com/ROCm/triton-inference-server-server.git +cd triton-inference-server-server +git checkout rocm_python_backend +``` + +2. In a separate terminal, build this repository [from source](#building-from-source). + +3. The rest of the necessary files for this test are included in the [examples/tab_transform_pytorch](examples/tab_transform_pytorch) directory of this repository. + +## Collect CPU Reference Outputs + +> [!NOTE] +> These should all be done from the `python_backend` terminal. 
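+
+The copy commands in a later step reference `../python_backend`, so this guide
+assumes the two repositories are cloned side by side, for example:
+
+```
+workspace/
+├── python_backend                    # this repository
+└── triton-inference-server-server   # rocm_python_backend branch
+```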
+ +After building from source, install the additional dependencies for this example. +``` +pip install -r examples/tab_transform_pytorch/requirements.txt +``` + +Then, generate the reference outputs. +``` +cd examples/tab_transform_pytorch +python generate_reference.py --output-dir . --num-samples 10000 --seed 42 +``` + +This will generate the following files: +- `ft_transformer.pt` - The model weights. +- `reference_inputs.npz` - The input data. +- `reference_outputs.npz` - The expected outputs from the model. + +## Copy artifacts to triton-server repository + +Next, we need to copy over our artifacts to the `triton-server` repo to run the model on ROCm. + +> [!NOTE] +> These should all be done from the `triton-server` terminal. + +1. Create the model repository. + +```bash +mkdir -p models/tab_transform_pytorch/1/ +``` + +2. Copy over the model and config files. + +This will define the model that we'll run on ROCm. + +```bash +cp ../python_backend/examples/tab_transform_pytorch/model.py models/tab_transform_pytorch/1/model.py +cp ../python_backend/examples/tab_transform_pytorch/config.pbtxt models/tab_transform_pytorch/config.pbtxt +``` + +3. Copy over the weights. + +This will copy over the weights that we generated earlier. + +```bash +cp ../python_backend/examples/tab_transform_pytorch/ft_transformer.pt models/tab_transform_pytorch/1/ft_transformer.pt +``` + +## Add dependency to triton-server container + +We will need to add the `tab-transformer-pytorch` dependency to the triton-server container. + +Search for the `pip3 install --upgrade "numpy<2"` line in the +`docker_prepare_container_linux` function of the `build.py` file. + +Add the following line after the `numpy` installation: + +```python +pip3 install --upgrade "tab-transformer-pytorch>=0.5.1" && \ +``` + +## Start the Triton Server + +See the [Triton Inference Server](https://github.com/ROCm/triton-inference-server-server/tree/rocm_python_backend) repository for instructions on how to start the server. Use the new model repository we created earlier. + +Build the container and run the server according to the instructions. + +## Test the model + +> [!NOTE] +> These should all be done from the `python_backend` terminal. + +Run the client script in the [examples/tab_transform_pytorch](examples/tab_transform_pytorch) directory +to verify that the GPU outputs match the CPU outputs. + +```bash +cd examples/tab_transform_pytorch +python client.py --verify --reference-dir . --tolerance 1e-5 +``` + +This should print out the results of the verification: + +```bash +Results: + Max absolute difference: 4.17e-07 + Mean absolute difference: 7.86e-08 + Tolerance: 1.00e-05 + Samples exceeding tolerance: 0/10000 + +============================================================ +PASS: All 10000 samples within tolerance (1e-05) + GPU implementation matches CPU reference! +============================================================ +``` + # Running with Inferentia Please see the diff --git a/examples/tab_transform_pytorch/client.py b/examples/tab_transform_pytorch/client.py new file mode 100644 index 00000000..67821f88 --- /dev/null +++ b/examples/tab_transform_pytorch/client.py @@ -0,0 +1,337 @@ +""" +Client for testing FTTransformer model on Triton Inference Server. 
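+
+The model is expected to take five INT64 categorical features (INPUT0) and ten
+FP32 continuous features (INPUT1) per sample, returning a single FP32 value
+(OUTPUT0), matching config.pbtxt.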
+ +Usage: + # Random inference test + python client.py [--url URL] [--batch-size N] [--verbose] + + # Verification mode (compare GPU outputs against CPU reference) + python client.py --verify --reference-dir ./1 [--url URL] [--tolerance T] + +Example: + python client.py --url localhost:8000 --batch-size 2 --verbose + python client.py --verify --reference-dir ./1 --tolerance 1e-5 +""" + +import argparse +import os +import sys + +import numpy as np +import tritonclient.http as httpclient +from tritonclient.utils import np_to_triton_dtype + +MODEL_NAME = "tab_transform_pytorch" + +# Model configuration (should match model.py and config.pbtxt) +NUM_CATEGORIES = 5 # Number of categorical features +NUM_CONTINUOUS = 10 # Number of continuous features +CATEGORIES = (10, 5, 6, 5, 8) # Max unique values per categorical feature + + +def generate_categorical_input(batch_size: int) -> np.ndarray: + """Generate random categorical input data. + + Each categorical feature value must be in range [0, max_category_value). + """ + categorical_data = np.zeros((batch_size, NUM_CATEGORIES), dtype=np.int64) + for i, max_val in enumerate(CATEGORIES): + categorical_data[:, i] = np.random.randint(0, max_val, size=batch_size) + return categorical_data + + +def generate_continuous_input(batch_size: int) -> np.ndarray: + """Generate random continuous input data (normalized).""" + return np.random.randn(batch_size, NUM_CONTINUOUS).astype(np.float32) + + +def run_inference( + client: httpclient.InferenceServerClient, + categorical_data: np.ndarray, + continuous_data: np.ndarray, + verbose: bool = False, +) -> np.ndarray: + """Run inference and return output array.""" + if verbose: + print(f"\nCategorical input (INPUT0):\n{categorical_data}") + print(f"\nContinuous input (INPUT1):\n{continuous_data}") + + # Prepare inputs + inputs = [ + httpclient.InferInput( + "INPUT0", + categorical_data.shape, + np_to_triton_dtype(categorical_data.dtype), + ), + httpclient.InferInput( + "INPUT1", + continuous_data.shape, + np_to_triton_dtype(continuous_data.dtype), + ), + ] + inputs[0].set_data_from_numpy(categorical_data) + inputs[1].set_data_from_numpy(continuous_data) + + # Prepare outputs + outputs = [httpclient.InferRequestedOutput("OUTPUT0")] + + # Run inference + response = client.infer(MODEL_NAME, inputs, request_id=str(1), outputs=outputs) + + output = response.as_numpy("OUTPUT0") + if output is None: + raise RuntimeError("No output data received from server") + return output + + +def run_random_inference(url: str, batch_size: int, verbose: bool = False) -> bool: + """Run inference with random inputs (original behavior). 
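+
+    The requested batch size should not exceed the max_batch_size of 4
+    declared in config.pbtxt.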
+ + Args: + url: Triton server URL (e.g., "localhost:8000") + batch_size: Number of samples in the batch + verbose: Whether to print detailed output + + Returns: + True if inference succeeded, False otherwise + """ + try: + with httpclient.InferenceServerClient(url, verbose=verbose) as client: + # Check if model is ready + if not client.is_model_ready(MODEL_NAME): + print(f"ERROR: Model '{MODEL_NAME}' is not ready on server") + return False + + # Generate input data + categorical_data = generate_categorical_input(batch_size) + continuous_data = generate_continuous_input(batch_size) + + # Run inference + output_data = run_inference( + client, categorical_data, continuous_data, verbose + ) + + if output_data is None: + print("ERROR: No output data received from server") + return False + + print("\n" + "=" * 60) + print("FTTransformer Inference Results") + print("=" * 60) + print(f"Batch size: {batch_size}") + print("Input shapes:") + print(f" - Categorical (INPUT0): {categorical_data.shape}") + print(f" - Continuous (INPUT1): {continuous_data.shape}") + print(f"Output shape: {output_data.shape}") + print("\nPredictions (OUTPUT0):") + for i, pred in enumerate(output_data): + print(f" Sample {i}: {pred}") + print("=" * 60) + + # Basic validation: output shape should be (batch_size, 1) + expected_shape = (batch_size, 1) + if output_data.shape != expected_shape: + print( + f"ERROR: Unexpected output shape. " + f"Expected {expected_shape}, got {output_data.shape}" + ) + return False + + print("\nPASS: tab_transform_pytorch") + return True + + except Exception as e: + print(f"ERROR: Inference failed with exception: {e}") + return False + + +def run_verification( + url: str, + reference_dir: str, + tolerance: float = 1e-5, + verbose: bool = False, +) -> bool: + """Run verification against CPU reference outputs. 
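+
+    Reference inputs and outputs produced by generate_reference.py are loaded
+    from reference_dir, replayed against the server in batches of 4 (the
+    model's max_batch_size), and compared element-wise to the CPU outputs
+    using an absolute-difference tolerance.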
+ + Args: + url: Triton server URL + reference_dir: Directory containing reference_inputs.npz and reference_outputs.npz + tolerance: Maximum allowed absolute difference (default: 1e-5) + verbose: Whether to print detailed output + + Returns: + True if verification passed, False otherwise + """ + # Load reference data + inputs_path = os.path.join(reference_dir, "reference_inputs.npz") + outputs_path = os.path.join(reference_dir, "reference_outputs.npz") + + if not os.path.exists(inputs_path): + print(f"ERROR: Reference inputs not found: {inputs_path}") + print(" Run generate_reference.py first") + return False + + if not os.path.exists(outputs_path): + print(f"ERROR: Reference outputs not found: {outputs_path}") + print(" Run generate_reference.py first") + return False + + print("=" * 60) + print("FTTransformer Verification Mode") + print("=" * 60) + print(f"Reference directory: {reference_dir}") + print(f"Tolerance: {tolerance}") + print() + + # Load reference data + print("[1/4] Loading reference data...") + inputs_data = np.load(inputs_path) + outputs_data = np.load(outputs_path) + + x_categ = inputs_data["categorical"] + x_numer = inputs_data["continuous"] + reference_outputs = outputs_data["outputs"] + + num_samples = x_categ.shape[0] + print(f" Loaded {num_samples} samples") + print(f" Categorical shape: {x_categ.shape}") + print(f" Continuous shape: {x_numer.shape}") + print(f" Reference output shape: {reference_outputs.shape}") + + try: + with httpclient.InferenceServerClient(url, verbose=verbose) as client: + # Check if model is ready + print("[2/4] Checking model status...") + if not client.is_model_ready(MODEL_NAME): + print(f"ERROR: Model '{MODEL_NAME}' is not ready on server") + return False + print(f" Model '{MODEL_NAME}' is ready") + + # Run inference in batches (max_batch_size is 4) + print("[3/4] Running GPU inference...") + batch_size = 4 + gpu_outputs = [] + + for start_idx in range(0, num_samples, batch_size): + end_idx = min(start_idx + batch_size, num_samples) + batch_categ = x_categ[start_idx:end_idx] + batch_numer = x_numer[start_idx:end_idx] + + output = run_inference(client, batch_categ, batch_numer, verbose=False) + gpu_outputs.append(output) + + if verbose: + print(f" Processed samples {start_idx}-{end_idx}") + + gpu_outputs = np.vstack(gpu_outputs) + print(f" GPU output shape: {gpu_outputs.shape}") + + # Compare outputs + print("[4/4] Comparing outputs...") + abs_diff = np.abs(gpu_outputs - reference_outputs) + max_diff = np.max(abs_diff) + mean_diff = np.mean(abs_diff) + num_mismatches = np.sum(abs_diff > tolerance) + + print() + print("Results:") + print(f" Max absolute difference: {max_diff:.2e}") + print(f" Mean absolute difference: {mean_diff:.2e}") + print(f" Tolerance: {tolerance:.2e}") + print(f" Samples exceeding tolerance: {num_mismatches}/{num_samples}") + + if verbose or num_mismatches > 0: + print() + print("Sample-by-sample comparison (first 10):") + for i in range(min(10, num_samples)): + diff = abs_diff[i][0] + status = "✓" if diff <= tolerance else "✗" + print( + f" [{status}] Sample {i:3d}: " + f"CPU={reference_outputs[i][0]:+.8f}, " + f"GPU={gpu_outputs[i][0]:+.8f}, " + f"diff={diff:.2e}" + ) + + print() + print("=" * 60) + + if max_diff <= tolerance: + print( + f"PASS: All {num_samples} samples within tolerance ({tolerance:.0e})" + ) + print(" GPU implementation matches CPU reference!") + print("=" * 60) + return True + else: + print( + f"FAIL: {num_mismatches} samples exceed tolerance ({tolerance:.0e})" + ) + print(" GPU implementation may 
have numerical issues") + print("=" * 60) + return False + + except Exception as e: + print(f"ERROR: Verification failed with exception: {e}") + import traceback + + traceback.print_exc() + return False + + +def main(): + parser = argparse.ArgumentParser( + description="Client for FTTransformer model on Triton Inference Server" + ) + parser.add_argument( + "--url", + type=str, + default="localhost:8000", + help="Triton server URL (default: localhost:8000)", + ) + parser.add_argument( + "--batch-size", + type=int, + default=1, + help="Batch size for random inference (default: 1, max: 4 per config)", + ) + parser.add_argument( + "--verbose", + action="store_true", + help="Enable verbose output", + ) + parser.add_argument( + "--verify", + action="store_true", + help="Run verification mode (compare against CPU reference)", + ) + parser.add_argument( + "--reference-dir", + type=str, + default="./1", + help="Directory containing reference data (default: ./1)", + ) + parser.add_argument( + "--tolerance", + type=float, + default=1e-5, + help="Max absolute difference allowed in verification (default: 1e-5)", + ) + + args = parser.parse_args() + + if args.verify: + success = run_verification( + args.url, + args.reference_dir, + args.tolerance, + args.verbose, + ) + else: + success = run_random_inference(args.url, args.batch_size, args.verbose) + + sys.exit(0 if success else 1) + + +if __name__ == "__main__": + main() diff --git a/examples/tab_transform_pytorch/config.pbtxt b/examples/tab_transform_pytorch/config.pbtxt new file mode 100644 index 00000000..2e491211 --- /dev/null +++ b/examples/tab_transform_pytorch/config.pbtxt @@ -0,0 +1,33 @@ +name: "tab_transform_pytorch" +backend: "python" +max_batch_size: 4 + +input [ + { + name: "INPUT0" + data_type: TYPE_INT64 + dims: [ 5 ] + } +] +input [ + { + name: "INPUT1" + data_type: TYPE_FP32 + dims: [ 10 ] + } +] +output [ + { + name: "OUTPUT0" + data_type: TYPE_FP32 + dims: [ 1 ] + } +] + +instance_group [ + { + kind: KIND_GPU + count: 1 + gpus: [ 0 ] + } +] diff --git a/examples/tab_transform_pytorch/generate_reference.py b/examples/tab_transform_pytorch/generate_reference.py new file mode 100644 index 00000000..a3a5bbef --- /dev/null +++ b/examples/tab_transform_pytorch/generate_reference.py @@ -0,0 +1,167 @@ +""" +Generate deterministic FT Transformer weights and CPU reference outputs. + +This script creates: +1. ft_transformer.pt - Model weights with fixed random seed +2. reference_inputs.npz - Test inputs (categorical + continuous) +3. 
reference_outputs.npz - Expected outputs from CPU inference + +Usage: + python generate_reference.py [--output-dir DIR] [--num-samples N] [--seed S] + +Example: + python generate_reference.py --output-dir ./1 --num-samples 100 --seed 42 +""" + +import argparse +import numpy as np +import os +import torch + +from tab_transformer_pytorch import FTTransformer + +# Model configuration (must match model.py and config.pbtxt) +CATEGORIES = (10, 5, 6, 5, 8) # Unique values per categorical feature +NUM_CONTINUOUS = 10 # Number of continuous features +MODEL_CONFIG = { + "categories": CATEGORIES, + "num_continuous": NUM_CONTINUOUS, + "dim": 32, + "dim_out": 1, + "depth": 6, + "heads": 8, + "attn_dropout": 0.0, # Disabled for deterministic inference + "ff_dropout": 0.0, # Disabled for deterministic inference +} + + +def generate_categorical_inputs(num_samples: int, seed: int) -> np.ndarray: + """Generate deterministic categorical inputs.""" + rng = np.random.default_rng(seed) + categorical_data = np.zeros((num_samples, len(CATEGORIES)), dtype=np.int64) + for i, max_val in enumerate(CATEGORIES): + categorical_data[:, i] = rng.integers(0, max_val, size=num_samples) + return categorical_data + + +def generate_continuous_inputs(num_samples: int, seed: int) -> np.ndarray: + """Generate deterministic continuous inputs (normalized).""" + rng = np.random.default_rng(seed + 1000) # Different seed for continuous + return rng.standard_normal((num_samples, NUM_CONTINUOUS)).astype(np.float32) + + +def create_model(seed: int) -> FTTransformer: + """Create FTTransformer with deterministic initialization.""" + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + np.random.seed(seed) + + model = FTTransformer(**MODEL_CONFIG) + model.eval() + return model + + +def run_cpu_inference( + model: FTTransformer, + x_categ: np.ndarray, + x_numer: np.ndarray, +) -> np.ndarray: + """Run inference on CPU and return outputs.""" + model.cpu() + model.eval() + + x_categ_tensor = torch.from_numpy(x_categ).long() + x_numer_tensor = torch.from_numpy(x_numer).float() + + with torch.no_grad(): + output = model(x_categ_tensor, x_numer_tensor) + + return output.numpy() + + +def main(): + parser = argparse.ArgumentParser( + description="Generate FT Transformer reference weights and outputs" + ) + parser.add_argument( + "--output-dir", + type=str, + default="./1", + help="Output directory for weights and reference data (default: ./1)", + ) + parser.add_argument( + "--num-samples", + type=int, + default=100, + help="Number of test samples to generate (default: 100)", + ) + parser.add_argument( + "--seed", + type=int, + default=42, + help="Random seed for reproducibility (default: 42)", + ) + args = parser.parse_args() + + # Create output directory if needed + os.makedirs(args.output_dir, exist_ok=True) + + print("=" * 60) + print("FT Transformer Reference Generation") + print("=" * 60) + print(f"Output directory: {args.output_dir}") + print(f"Number of samples: {args.num_samples}") + print(f"Random seed: {args.seed}") + print(f"Model config: {MODEL_CONFIG}") + print() + + # Step 1: Create model with deterministic weights + print("[1/4] Creating model with deterministic weights...") + model = create_model(args.seed) + + weights_path = os.path.join(args.output_dir, "ft_transformer.pt") + torch.save(model.state_dict(), weights_path) + print(f" Saved weights to: {weights_path}") + + # Step 2: Generate deterministic inputs + print("[2/4] Generating deterministic test inputs...") + x_categ = generate_categorical_inputs(args.num_samples, 
args.seed) + x_numer = generate_continuous_inputs(args.num_samples, args.seed) + + inputs_path = os.path.join(args.output_dir, "reference_inputs.npz") + np.savez(inputs_path, categorical=x_categ, continuous=x_numer) + print(f" Saved inputs to: {inputs_path}") + print(f" Categorical shape: {x_categ.shape}") + print(f" Continuous shape: {x_numer.shape}") + + # Step 3: Run CPU inference + print("[3/4] Running CPU inference...") + outputs = run_cpu_inference(model, x_categ, x_numer) + + outputs_path = os.path.join(args.output_dir, "reference_outputs.npz") + np.savez(outputs_path, outputs=outputs) + print(f" Saved outputs to: {outputs_path}") + print(f" Output shape: {outputs.shape}") + + # Step 4: Print sample results for verification + print("[4/4] Sample results (first 5):") + print() + for i in range(min(5, args.num_samples)): + print(f" Sample {i}:") + print(f" Categorical: {x_categ[i]}") + print(f" Continuous: {x_numer[i][:5]}... (truncated)") + print(f" Output: {outputs[i][0]:.8f}") + print() + + print("=" * 60) + print("Reference generation complete!") + print() + print("Next steps:") + print(f" 1. Copy {args.output_dir}/ to your model repository") + print(" 2. Start Triton server with the model") + print(" 3. Run: python client.py --verify --reference-dir " + args.output_dir) + print("=" * 60) + + +if __name__ == "__main__": + main() diff --git a/examples/tab_transform_pytorch/model.py b/examples/tab_transform_pytorch/model.py new file mode 100644 index 00000000..c8a7cf62 --- /dev/null +++ b/examples/tab_transform_pytorch/model.py @@ -0,0 +1,148 @@ +import json +import os +import torch + +import triton_python_backend_utils as pb_utils + +from tab_transformer_pytorch import FTTransformer + + +class TritonPythonModel: + """Triton Python model for FTTransformer inference. + + This model uses the FTTransformer from tab-transformer-pytorch for + tabular data prediction with both categorical and continuous features. + """ + + def initialize(self, args): + """`initialize` is called only once when the model is being loaded. + + Parameters + ---------- + args : dict + Both keys and values are strings. 
The dictionary keys and values are: + * model_config: A JSON string containing the model configuration + * model_instance_kind: A string containing model instance kind + * model_instance_device_id: A string containing model instance device ID + * model_repository: Model repository path + * model_version: Model version + * model_name: Model name + """ + self.model_config = json.loads(args["model_config"]) + self.model_repository = args["model_repository"] + self.model_version = args["model_version"] + + # Get output configuration + output0_config = pb_utils.get_output_config_by_name( + self.model_config, "OUTPUT0" + ) + self.output0_dtype = pb_utils.triton_string_to_numpy( + output0_config["data_type"] + ) + + # Determine the device + device_id = args["model_instance_device_id"] + instance_kind = args["model_instance_kind"] + + if instance_kind == "GPU": + self.device = torch.device(f"cuda:{device_id}") + pb_utils.Logger.log_info(f"FTTransformer initialized on GPU {device_id}") + else: + self.device = torch.device("cpu") + pb_utils.Logger.log_info("FTTransformer initialized on CPU") + + # FTTransformer model configuration + # These should match your trained model's configuration + self.categories = (10, 5, 6, 5, 8) # Unique values per categorical feature + self.num_continuous = 10 # Number of continuous features + + # Initialize FTTransformer model + # Note: Dropout is disabled (0.0) to ensure deterministic inference + # for verification against CPU reference outputs + self.model = FTTransformer( + categories=self.categories, + num_continuous=self.num_continuous, + dim=32, # Embedding dimension (paper recommends 32) + dim_out=1, # Output dimension (1 for binary/regression) + depth=6, # Number of transformer layers (paper recommends 6) + heads=8, # Number of attention heads (paper recommends 8) + attn_dropout=0.0, # Disabled for deterministic verification + ff_dropout=0.0, # Disabled for deterministic verification + ) + + # Load pre-trained weights if available + weights_path = os.path.join( + self.model_repository, self.model_version, "ft_transformer.pt" + ) + if os.path.exists(weights_path): + self.model.load_state_dict( + torch.load(weights_path, map_location=self.device) + ) + pb_utils.Logger.log_info(f"Loaded model weights from {weights_path}") + else: + pb_utils.Logger.log_warn( + f"No weights found at {weights_path}. Using randomly initialized model." + ) + + self.model.to(self.device) + self.model.eval() + + pb_utils.Logger.log_info( + f"FTTransformer initialized on {self.device} with " + f"categories={self.categories}, num_continuous={self.num_continuous}" + ) + + def execute(self, requests): + """`execute` is called when inference is requested for this model. 
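+
+        For each request, INPUT0 (categorical, INT64) and INPUT1 (continuous,
+        FP32) are moved to the model's device, a forward pass is run under
+        torch.no_grad(), and the result is returned as OUTPUT0 in the
+        configured output dtype.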
+ + Parameters + ---------- + requests : list + A list of pb_utils.InferenceRequest + + Returns + ------- + list + A list of pb_utils.InferenceResponse + """ + responses = [] + + for request in requests: + try: + # Get categorical input (INPUT0) - shape: [batch_size, num_categories] + input0_tensor = pb_utils.get_input_tensor_by_name(request, "INPUT0") + x_categ = torch.from_numpy(input0_tensor.as_numpy()).to(self.device) + + # Get continuous input (INPUT1) - shape: [batch_size, num_continuous] + input1_tensor = pb_utils.get_input_tensor_by_name(request, "INPUT1") + x_numer = torch.from_numpy(input1_tensor.as_numpy()).to(self.device) + + # Run inference + with torch.no_grad(): + output = self.model(x_categ, x_numer) + + # Convert output to numpy and create response tensor + output_np = output.cpu().numpy().astype(self.output0_dtype) + output_tensor = pb_utils.Tensor("OUTPUT0", output_np) + + inference_response = pb_utils.InferenceResponse( + output_tensors=[output_tensor] + ) + + except Exception as e: + inference_response = pb_utils.InferenceResponse( + output_tensors=[], + error=pb_utils.TritonError(f"Inference failed: {str(e)}"), + ) + + responses.append(inference_response) + + return responses + + def finalize(self): + """`finalize` is called only once when the model is being unloaded. + + This function allows the model to perform any necessary clean ups + before exit. + """ + pb_utils.Logger.log_info("Cleaning up FTTransformer model...") diff --git a/examples/tab_transform_pytorch/requirements.txt b/examples/tab_transform_pytorch/requirements.txt new file mode 100644 index 00000000..1b3ee791 --- /dev/null +++ b/examples/tab_transform_pytorch/requirements.txt @@ -0,0 +1,2 @@ +tab-transformer-pytorch>=0.5.1 +tritonclient[all]==2.39.0