In [1]:
import time
from transformers import T5Tokenizer, T5ForConditionalGeneration
import torch
import platform

In [2]:


# Choose the quantized-kernel backend used by the dynamic quantization below.
# qnnpack is PyTorch's ARM-oriented backend (used for macOS); fbgemm (and, on
# newer PyTorch builds, 'x86') target x86 CPUs. Assigning an engine that the
# installed build does not list in `supported_engines` raises a RuntimeError,
# so each candidate is checked before it is set.
system_name = platform.system()
supported_engines = torch.backends.quantized.supported_engines
if system_name == 'Darwin':  # macOS
    preferred_engines = ['qnnpack']
else:  # Windows, Linux, others: x86 backends first, ARM backend as fallback
    preferred_engines = ['fbgemm', 'x86', 'qnnpack']

quant_engine = next((e for e in preferred_engines if e in supported_engines), None)
if quant_engine is not None:
    torch.backends.quantized.engine = quant_engine
else:
    print(f"No supported quantization engine among {preferred_engines} on {system_name}; "
          f"keeping default '{torch.backends.quantized.engine}'")



In [3]:
# Load the pretrained T5-large seq2seq model together with the tokenizer from
# the same checkpoint, so vocabulary and embeddings stay in sync.
checkpoint_name = 't5-large'
model = T5ForConditionalGeneration.from_pretrained(checkpoint_name)
tokenizer = T5Tokenizer.from_pretrained(checkpoint_name)

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [4]:


# Apply dynamic int8 quantization to every torch.nn.Linear layer: weights are
# converted to qint8 ahead of time, activations are quantized at inference.
# quantize_dynamic returns a quantized copy by default (inplace=False), so
# `model` keeps its original fp32 weights.
# torch.ao.quantization is the current home of this API; torch.quantization is
# the legacy alias.
start_time = time.perf_counter()  # monotonic clock intended for durations
print("Applying quantization...")
quantized_model = torch.ao.quantization.quantize_dynamic(
    model, {torch.nn.Linear}, dtype=torch.qint8  # Quantize the linear layers
)
print(f"Applied quantization in {time.perf_counter() - start_time:.2f} seconds")

# Save only the quantized model's state dictionary. To reload it, first build
# the same architecture and run quantize_dynamic on it with the same settings,
# so the state-dict keys and packed-parameter modules match before load_state_dict.
QUANTIZED_STATE_DICT_PATH = 'quantized_t5_large.pth'
torch.save(quantized_model.state_dict(), QUANTIZED_STATE_DICT_PATH)
print("Quantized model saved successfully.")

# Quantization does not touch tokenization: reuse the same tokenizer under a
# name that pairs with `quantized_model`.
quantized_tokenizer = tokenizer

Applying quantization...
Applied quantization in 4.06 seconds
Quantized model saved successfully.


In [5]:

# Translate a sample sentence with the quantized model. The reported time
# covers generation plus decoding (tokenization happens before the timer starts).
MAX_NEW_TOKENS = 50  # upper bound on generated tokens for the translation
input_text = "translate English to French: My name is Alana"
inputs = quantized_tokenizer(input_text, return_tensors="pt")
input_ids = inputs.input_ids

start_time = time.perf_counter()  # monotonic clock intended for durations
print("Generating translation...")
# Pass the full encoding (input_ids + attention_mask) so generate() does not
# have to infer the mask itself.
outputs = quantized_model.generate(**inputs, max_new_tokens=MAX_NEW_TOKENS)
translation = quantized_tokenizer.decode(outputs[0], skip_special_tokens=True)
print(f"Generated translation in {time.perf_counter() - start_time:.2f} seconds")
print("Translation:", translation)

Generating translation...




Generated translation in 3.67 seconds
Translation: Mon nom est Alana
