In [None]:
import torch
import torch_tensorrt as torchtrt
import clip

# Set device to CUDA if available
device = "cuda" if torch.cuda.is_available() else "cpu"

100%|███████████████████████████████████████| 338M/338M [00:10<00:00, 32.5MiB/s]


In [None]:
# Load the CLIP model (ViT-B/32) and its preprocessing pipeline
model, preprocess = clip.load("ViT-B/32", device=device)
model.eval()

In [None]:
# Extract the visual encoder from CLIP
visual_encoder = model.visual.to(device)

# Compile the visual encoder with Torch-TensorRT
trt_visual_encoder = torchtrt.compile(
    visual_encoder,
    inputs=[torchtrt.Input((1, 3, 224, 224), dtype=torch.half)],
    enabled_precisions={torchtrt.dtype.f16}  # use half-precision
)

# Create a traced TorchScript module for deployment
ts_trt_visual_encoder = torch.jit.trace(
    trt_visual_encoder,
    torch.rand(1, 3, 224, 224).to(device).type(torch.half)
)

In [None]:
# Defining a wrapper for the text encoder.
# This module will take tokenized text as input and return the text embeddings.
class CLIPTextEncoder(torch.nn.Module):
    def __init__(self, clip_model):
        super(CLIPTextEncoder, self).__init__()
        self.clip_model = clip_model
    def forward(self, tokens):
        # tokens: expected shape [batch, 77] with dtype torch.int64
        return self.clip_model.encode_text(tokens)

# Instantiate the text encoder wrapper.
text_encoder = CLIPTextEncoder(model).to(device)
text_encoder.eval()

# Compile the text encoder with Torch-TensorRT, An input shape of (1, 77) (batch size 1 and 77 tokens) is specified.
trt_text_encoder = torchtrt.compile(
    text_encoder,
    inputs=[torchtrt.Input((1, 77), dtype=torch.int64)],
    enabled_precisions={torchtrt.dtype.f16}  # Enable FP16 kernels if supported.
)

# Create a traced TorchScript module for deployment.
# Use a sample tokenized text input.
sample_text = "Hello, world!"
# Tokenize the sample text using CLIP's tokenize function.
# clip.tokenize returns a tensor of shape [batch, 77] and type torch.int64.
example_tokens = clip.tokenize([sample_text]).to(device).long()

# Trace the compiled module.
ts_trt_text_encoder = torch.jit.trace(
    trt_text_encoder,
    example_tokens  # Ensure the traced input has the same dtype and shape.
)

# Test inference: Pass in tokenized text and get embeddings.
with torch.no_grad():
    embeddings = ts_trt_text_encoder(example_tokens)
    print("Text embeddings shape:", embeddings.shape)


In [None]:
import os
# Create the directory if it doesn't exist
os.makedirs("/content/model_repository/clip_visual/1", exist_ok=True)
# Save the model
torch.jit.save(ts_trt_visual_encoder, "/content/model_repository/clip_visual/1/model.pt")

In [8]:
# Create the directory if it doesn't exist
os.makedirs("/content/model_repository/clip_text/1", exist_ok=True)
# Save the model
torch.jit.save(ts_trt_text_encoder, "/content/model_repository/clip_text/1/model.pt")

In [9]:
from PIL import Image

# Set device to CUDA if available
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the saved TorchScript model
model_path = "/content/model_repository/clip_visual/1/model.pt"
visual_encoder = torch.jit.load(model_path).to(device)
visual_encoder.eval()  # Set model to evaluation mode

# Load CLIP to get its preprocessing pipeline
_, preprocess = clip.load("ViT-B/32", device=device)

# Attempt to load a sample image; if not available, use a random tensor
try:
    # Replace "sample.jpg" with the path to your test image if needed
    image = Image.open("Adidas-1.jpg").convert("RGB")
    # Preprocess image to get a tensor of shape (3, 224, 224)
    image_tensor = preprocess(image)
    # Convert to half-precision to match the compiled model's input dtype
    image_tensor = image_tensor.half()
except Exception as e:
    print("Sample image not found or failed to load. Using a random tensor for testing.")
    image_tensor = torch.rand(3, 224, 224).half()

# Add a batch dimension and move the tensor to the selected device
input_tensor = image_tensor.unsqueeze(0).to(device)

# Perform inference with the loaded model
with torch.no_grad():
    features = visual_encoder(input_tensor)
    output = visual_encoder(input_tensor)

print("Extracted visual features shape:", features.shape)

# Get the predicted class index (largest logit)
_, predicted = torch.max(output, 1)
print("Predicted class index:", predicted.item())


Sample image not found or failed to load. Using a random tensor for testing.
Extracted visual features shape: torch.Size([1, 512])
Predicted class index: 321


In [None]:
# Set device to CUDA if available.
print("Using device:", device)

# Load the TorchScript text encoder module.
model_path = "/content/model_repository/clip_text/1/model.pt"  # Ensure this file is in the current directory.
print(f"Loading TorchScript module from {model_path} ...")
ts_model = torch.jit.load(model_path, map_location=device)
ts_model.eval()

# Function to perform inference on input text.
def infer_text(text: str):
    # Tokenize the input text using CLIP's tokenizer.
    # clip.tokenize returns a tensor of shape [batch, 77] with dtype torch.int64.
    tokens = clip.tokenize([text]).to(device).long()  # Ensure tokens are Long tensors.

    # Run inference using the traced model.
    with torch.no_grad():
        embeddings = ts_model(tokens)
    return embeddings

# Example inference: specify your text input.
input_text = "whats up chatgpt> how are you doing?"
embeddings = infer_text(input_text)

Using device: cuda
Loading TorchScript module from /content/model_repository/clip_text/1/model.pt ...


In [None]:
embeddings_np = embeddings.cpu().numpy().tolist()
print("Embedding as NumPy array:", embeddings_np)


Embedding as NumPy array: [[0.26123046875, 0.180908203125, -0.213134765625, -0.2049560546875, -0.0419921875, 0.105712890625, 0.28076171875, -0.27490234375, -0.036285400390625, 0.08984375, 0.079345703125, -0.1348876953125, 0.3369140625, -0.007724761962890625, -0.14892578125, 0.335693359375, -0.51953125, -0.2724609375, -0.306884765625, -0.301025390625, 0.194580078125, -0.6162109375, -0.2081298828125, 0.11334228515625, 0.015869140625, 0.139404296875, 0.14990234375, 0.03277587890625, -0.05047607421875, 0.4091796875, -0.1324462890625, -0.217529296875, -0.024078369140625, 0.021759033203125, -0.03076171875, -0.1419677734375, 0.106689453125, -0.042205810546875, 0.19140625, -0.157470703125, -0.09649658203125, 0.06756591796875, 0.271728515625, 0.00954437255859375, -0.031768798828125, 0.18798828125, -0.1923828125, -0.0562744140625, -0.07171630859375, 0.26708984375, -0.020355224609375, -0.31982421875, 0.155517578125, -0.0595703125, -0.4990234375, 0.1304931640625, -0.197509765625, 0.032806396484375