In [None]:
import numpy as np
import onnxruntime as ort
import tensorrt as trt
import torch

from transformers import AutoModelForCausalLM, AutoTokenizer, StaticCache

# Convert to ONNX

In [None]:
# Device used for the wrapper and for the dummy / text inputs further down.
device = 'cpu'

# Kaggle input directory; the tokenizer and the weights are read from the same place.
model_path = "/kaggle/input/m/mambagetout/tinyllama/pytorch/default/1/"
model = AutoModelForCausalLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)

# Switch to evaluation mode; as the cell's last expression this also echoes the module.
model.eval()

## Export takes care of a static KV cache and handling dynamic inputs

In [51]:
# Shape of the dummy inputs used for the export
batch_size, sequence_length = 1, 256  # sequence length reduced for testing

# Read the attention geometry from the model configuration
model_config = model.config
num_layers = model_config.num_hidden_layers
num_attention_heads = model_config.num_attention_heads
hidden_size = model_config.hidden_size
# Per-head width, derived as hidden size split evenly across the heads
head_dim = hidden_size // num_attention_heads

In [None]:
class ModelWrapper(torch.nn.Module):
    """Wrap a causal LM so that a forward pass takes three tensors and returns one.

    The wrapper builds a fresh ``StaticCache`` on every call and hands it to the
    wrapped model together with the inputs, then returns only the first element
    of the model's tuple output. The cache is sized from the module-level
    ``batch_size`` and ``sequence_length`` constants, so those must be defined
    before ``forward`` is called.
    """

    def __init__(self, model):
        """Store the model whose forward pass is wrapped."""
        super().__init__()
        self.model = model

    def forward(self, input_ids, attention_mask, position_ids):
        """Run the wrapped model with a newly allocated static KV cache.

        Args:
            input_ids: token ids; dimension 1 is treated as the sequence axis.
            attention_mask: mask passed through unchanged to the model.
            position_ids: position indices passed through unchanged to the model.

        Returns:
            The first element of the model's tuple output (the logits, per the
            original author's note).
        """
        # Fresh cache per call; capacity comes from the module-level constants.
        past_key_values = StaticCache(
            config=self.model.config,
            batch_size=batch_size,
            max_cache_len=sequence_length,
            device=input_ids.device,
            dtype=self.model.dtype,
        )

        # One cache slot per input token. This is derived from the actual input
        # length rather than the global `sequence_length`, so that inputs shorter
        # than the cache capacity get a matching number of positions and the
        # traced graph does not bake in a constant range.
        cache_position = torch.arange(input_ids.shape[1], device=input_ids.device)

        # return_dict=False makes the model return a plain tuple.
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            cache_position=cache_position,
            use_cache=True,
            return_dict=False,
        )

        return outputs[0]  # Only the first output is exposed, for simplicity


wrapped_model = ModelWrapper(model).to(device)

# Dummy tensors that drive the export: all-ones ids and mask, plus 0..N-1 positions per row
dummy_input_ids = torch.ones((batch_size, sequence_length), dtype=torch.long, device=device)
dummy_attention_mask = torch.ones_like(dummy_input_ids)
dummy_position_ids = torch.arange(sequence_length, device=device)[None, :].repeat(batch_size, 1)

# Every input declares the same two dynamic axes: batch (0) and sequence (1)
dynamic_axes = {
    input_name: {0: 'batch_size', 1: 'sequence_length'}
    for input_name in ('input_ids', 'attention_mask', 'position_ids')
}

In [None]:
# Trace the wrapper with the dummy inputs and write the ONNX file
onnx_input_names = ['input_ids', 'attention_mask', 'position_ids']
torch.onnx.export(
    wrapped_model,
    (dummy_input_ids, dummy_attention_mask, dummy_position_ids),
    '/kaggle/working/tiny_llama.onnx',
    input_names=onnx_input_names,
    output_names=['logits'],
    dynamic_axes=dynamic_axes,
    # Opset 18 is deliberate: do NOT use opset 15, it was very painful.
    opset_version=18,
    do_constant_folding=False,
    verbose=False,
)

## Run inference using the ONNX Runtime `InferenceSession`

In [None]:
# Load the exported graph with ONNX Runtime, using the CPU execution provider only
ort_session = ort.InferenceSession(
    "/kaggle/working/tiny_llama.onnx",
    providers=["CPUExecutionProvider"],
)

In [59]:
def to_numpy(tensor):
    """Return the tensor's data as a NumPy array on the CPU.

    Tensors that require grad are detached first, because ``.numpy()`` cannot
    be called on them directly.
    """
    if tensor.requires_grad:
        tensor = tensor.detach()
    return tensor.cpu().numpy()

# Feed the dummy tensors to the session, matched by position to its first three inputs
session_inputs = ort_session.get_inputs()
ort_inputs = {
    session_inputs[idx].name: to_numpy(tensor)
    for idx, tensor in enumerate((dummy_input_ids, dummy_attention_mask, dummy_position_ids))
}
# Passing None as the output list requests every output of the graph
ort_outs = ort_session.run(None, ort_inputs)

In [60]:
# Inspect the ONNX Runtime result: container type, number of outputs, shape of the first output
type(ort_outs), len(ort_outs), ort_outs[0].shape

(list, 1, (1, 256, 32000))

## The outputs are close, which is good

In [21]:
# Compare the PyTorch wrapper against ONNX Runtime on the same dummy inputs
torch_logits = to_numpy(
    wrapped_model(dummy_input_ids, dummy_attention_mask, dummy_position_ids)
)
np.testing.assert_allclose(torch_logits, ort_outs[0], rtol=1e-03, atol=1e-04)
# With atol=1e-05 this raises an AssertionError,
# so the tolerance above is considered fair enough.

Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)


## This works, so now try text inputs

In [69]:
def prepare_inputs(text_input, tokenizer, target_length=256, device='cuda'):
    """
    Tokenize text and force it to exactly ``target_length`` tokens.

    Longer inputs are cut off at ``target_length``; shorter ones are padded on
    the right with the tokenizer's pad token id (0 if the tokenizer has none).

    Returns a tuple ``(input_ids, attention_mask, position_ids)``, each a long
    tensor with a leading batch dimension of 1, moved to ``device``. The mask is
    1 for real tokens and 0 for padding; the positions are 0..target_length-1.
    """
    # Tokenize, drop the batch dimension, and truncate in one go
    # (slicing is a no-op when the sequence is already short enough)
    token_ids = tokenizer.encode(text_input, add_special_tokens=True, return_tensors="pt")[0]
    token_ids = token_ids[:target_length]

    n_real = len(token_ids)
    n_pad = target_length - n_real

    # Right-pad the ids
    pad_value = tokenizer.pad_token_id
    if pad_value is None:
        pad_value = 0
    filler = torch.full((n_pad,), pad_value, dtype=torch.long)
    input_ids = torch.cat([token_ids, filler])

    # Mask: the first n_real entries are 1, the rest 0
    attention_mask = (torch.arange(target_length) < n_real).to(torch.long)

    # Positions simply count up over the whole padded length
    position_ids = torch.arange(target_length, dtype=torch.long)

    # Add the batch dimension and move everything to the requested device
    return tuple(
        t.unsqueeze(0).to(device)
        for t in (input_ids, attention_mask, position_ids)
    )

In [73]:
def make_both_inferences(text_input: str):
    """Run the same text through the PyTorch wrapper and the ONNX Runtime session.

    The text is tokenized and padded/truncated with ``prepare_inputs`` to the
    module-level ``sequence_length``, so the input length matches the length
    used for the static cache and the export.

    Relies on the module-level ``tokenizer``, ``device``, ``wrapped_model``,
    ``ort_session`` and ``sequence_length``.

    Args:
        text_input: the prompt to run.

    Returns:
        A pair ``(wrapped_model_outputs, ort_outs)``: the wrapper's output as a
        NumPy array, and the list of outputs returned by ``ort_session.run``.
    """
    input_ids, attention_mask, position_ids = prepare_inputs(
        text_input,
        tokenizer,
        target_length=sequence_length,
        device=device,
    )

    # Inference only: no_grad avoids building an autograd graph for the pass.
    with torch.no_grad():
        wrapped_model_outputs = to_numpy(
            wrapped_model(input_ids, attention_mask, position_ids)
        )

    # Feed the session by its own input names, in the order the tensors were exported.
    session_inputs = ort_session.get_inputs()
    ort_inputs = {
        session_inputs[0].name: to_numpy(input_ids),
        session_inputs[1].name: to_numpy(attention_mask),
        session_inputs[2].name: to_numpy(position_ids),
    }
    ort_outs = ort_session.run(None, ort_inputs)

    return wrapped_model_outputs, ort_outs

In [85]:
# Run one prompt through both back ends
prompt = 'hi there my work is computer science and I like'
wrapped_model_outputs, ort_outs = make_both_inferences(text_input=prompt)

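## The token ids converted back to text aren't pretty, though

The argmax below is taken independently at each of the 256 positions, padding included, so the decoded string is a per-position next-token guess rather than a generated continuation of the prompt.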

In [87]:
 # Pick the highest-scoring token id at every position, for both back ends
wrapped_token_ids = torch.from_numpy(wrapped_model_outputs).argmax(dim=-1)
ort_token_ids = torch.from_numpy(ort_outs[0]).argmax(dim=-1)  # first ORT output is assumed to be the relevant one

# Turn the first row of each id tensor back into text, dropping special tokens
wrapped_text_output, ort_text_output = (
    tokenizer.decode(token_ids[0], skip_special_tokens=True)
    for token_ids in (wrapped_token_ids, ort_token_ids)
)
wrapped_text_output, ort_text_output

('amedag. my. my. computer computer am I I I my my my my my computer computer computer I I I computer computer computer computer my my my my I I my my my my computer computer computer computer computer computer andag my my my my computer computer computer computer my I\n\nag my my my my computer computer computer computer computer computer computer computer computer computer my computer I I my my my my my computer computer computer computer computer my my I my my my my my and computer my my and and andag my my my my computer computer computer my my my my I I I I I\n\n I my my my computer computer computer computer I I I\n\nag\n my my my my computer computer computer computer I I I my my my computer computer computer computer my my my and and and my my my my computer computer my my I and and I my my computer computer\n\n\n my my my my computer computer computer my I I I I my my my\n\n my my my my computer I I I my my my computer computer computer my my my and and and my my computer comp

In [3]:
# Shell escape: render the notebook file tensorrt-v2.ipynb to HTML with nbconvert
!jupyter nbconvert --to html tensorrt-v2.ipynb

[NbConvertApp] Converting notebook tensorrt-v2.ipynb to html
[NbConvertApp] Writing 310840 bytes to tensorrt-v2.html
