In [1]:
%pip install torch transformers datasets evaluate --upgrade




In [2]:
import torch
from transformers import MBartForConditionalGeneration
from pathlib import Path

# Load original model
model_dir = Path("2. translation_model_ta_si")
model = MBartForConditionalGeneration.from_pretrained(model_dir)

# Apply dynamic quantization
quantized_model = torch.quantization.quantize_dynamic(
    model, {torch.nn.Linear}, dtype=torch.qint8
)

print(quantized_model)

# Create a directory to save
save_path = Path("quantized_ta_si_2")
save_path.mkdir(exist_ok=True)

# Save the model state_dict
torch.save(quantized_model.state_dict(), save_path / "pytorch_model.bin")

# Save the config
quantized_model.config.save_pretrained(save_path)

MBartForConditionalGeneration(
  (model): MBartModel(
    (shared): MBartScaledWordEmbedding(250054, 1024, padding_idx=1)
    (encoder): MBartEncoder(
      (embed_tokens): MBartScaledWordEmbedding(250054, 1024, padding_idx=1)
      (embed_positions): MBartLearnedPositionalEmbedding(1026, 1024)
      (layers): ModuleList(
        (0-11): 12 x MBartEncoderLayer(
          (self_attn): MBartSdpaAttention(
            (k_proj): DynamicQuantizedLinear(in_features=1024, out_features=1024, dtype=torch.qint8, qscheme=torch.per_tensor_affine)
            (v_proj): DynamicQuantizedLinear(in_features=1024, out_features=1024, dtype=torch.qint8, qscheme=torch.per_tensor_affine)
            (q_proj): DynamicQuantizedLinear(in_features=1024, out_features=1024, dtype=torch.qint8, qscheme=torch.per_tensor_affine)
            (out_proj): DynamicQuantizedLinear(in_features=1024, out_features=1024, dtype=torch.qint8, qscheme=torch.per_tensor_affine)
          )
          (self_attn_layer_norm): LayerNorm

In [5]:
import os

def print_size_of_model(model):
    torch.save(model.state_dict(), "temp.p")
    print('Size (MB):', os.path.getsize("temp.p")/1e6)
    os.remove('temp.p')

print_size_of_model(model)
print_size_of_model(quantized_model)

Size (MB): 2444.671001
Size (MB): 1643.909962


In [5]:
torch.save(quantized_model, "quantized_mbart_model.pt")

In [10]:
import torch

quantized_model = torch.load("quantized_mbart_model.pt", weights_only=False)
quantized_model.eval()


MBartForConditionalGeneration(
  (model): MBartModel(
    (shared): MBartScaledWordEmbedding(250054, 1024, padding_idx=1)
    (encoder): MBartEncoder(
      (embed_tokens): MBartScaledWordEmbedding(250054, 1024, padding_idx=1)
      (embed_positions): MBartLearnedPositionalEmbedding(1026, 1024)
      (layers): ModuleList(
        (0-11): 12 x MBartEncoderLayer(
          (self_attn): MBartSdpaAttention(
            (k_proj): DynamicQuantizedLinear(in_features=1024, out_features=1024, dtype=torch.qint8, qscheme=torch.per_tensor_affine)
            (v_proj): DynamicQuantizedLinear(in_features=1024, out_features=1024, dtype=torch.qint8, qscheme=torch.per_tensor_affine)
            (q_proj): DynamicQuantizedLinear(in_features=1024, out_features=1024, dtype=torch.qint8, qscheme=torch.per_tensor_affine)
            (out_proj): DynamicQuantizedLinear(in_features=1024, out_features=1024, dtype=torch.qint8, qscheme=torch.per_tensor_affine)
          )
          (self_attn_layer_norm): LayerNorm

In [1]:
import torch
from transformers import MBartTokenizer

# Load quantized model
quantized_model = torch.load("quantized_mbart_model.pt", weights_only=False)
quantized_model.eval()

# Load tokenizer (same as original model)
tokenizer = MBartTokenizer.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")

# 🔤 Input: Tamil → Sinhala
src_text = "நீங்கள் எப்படி இருக்கிறீர்கள்?"
inputs = tokenizer(src_text, return_tensors="pt", src_lang="ta_IN")

# Set target language
forced_bos_token_id = tokenizer.lang_code_to_id["si_LK"]

# 🔁 Translate using generate()
with torch.no_grad():
    output_tokens = quantized_model.generate(
        **inputs,
        forced_bos_token_id=forced_bos_token_id,
        max_length=100,
        num_beams=4,
        early_stopping=True
    )

# Decode translation
translated_text = tokenizer.batch_decode(output_tokens, skip_special_tokens=True)[0]
print("🔁 Translated Sinhala Output:", translated_text)


  device=storage.device,
The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'MBart50Tokenizer'. 
The class this function is called from is 'MBartTokenizer'.
Keyword arguments {'src_lang': 'ta_IN'} not recognized.


🔁 Translated Sinhala Output: ඔබ ඉන්නේ කෙසේද?
