## Python Script to Load and Use the .pt Model

In [None]:
import torch
from transformers import AutoTokenizer, BertModel  # or use AutoModel if unsure

# Tokenizer for the MiniLMv2 checkpoint (assumed to use a BERT-like tokenizer).
# NOTE(review): confirm this repo id resolves on the Hugging Face Hub.
tokenizer = AutoTokenizer.from_pretrained(
    "microsoft/MiniLMv2-L6-H384-distilled-from-BERT-Large"
)

# Model architecture: a thin torch module around the pretrained BERT encoder
class MiniLMv2Model(torch.nn.Module):
    """Wrap a pretrained ``BertModel`` as the ``bert`` submodule.

    Because the encoder is stored on ``self.bert``, the parameter names in
    this module's state dict carry the ``bert.`` prefix.
    """

    def __init__(self):
        super().__init__()
        checkpoint = "microsoft/MiniLMv2-L6-H384-distilled-from-BERT-Large"
        self.bert = BertModel.from_pretrained(checkpoint)

    def forward(self, input_ids, attention_mask=None, token_type_ids=None):
        """Run the wrapped encoder and return its output unchanged.

        All three arguments are forwarded by keyword to ``self.bert``;
        ``attention_mask`` and ``token_type_ids`` are optional.
        """
        encoder_kwargs = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "token_type_ids": token_type_ids,
        }
        return self.bert(**encoder_kwargs)

# Initialize the architecture, then load the saved weights into it.
model = MiniLMv2Model()
# The checkpoint is expected to be a plain state_dict, so weights_only=True
# restricts unpickling to tensors and basic containers instead of allowing
# arbitrary pickled objects to execute. map_location forces the tensors onto
# the CPU regardless of the device they were saved from.
model.load_state_dict(
    torch.load(
        "minilm_model.pt",
        map_location=torch.device("cpu"),
        weights_only=True,
    )
)
model.eval()  # switch to evaluation mode before running inference

# Example sentence to embed
text = "MiniLMv2 is a lightweight and fast transformer model."
inputs = tokenizer(text, padding=True, truncation=True, return_tensors="pt")

# Inference only, so run the forward pass without gradient tracking
with torch.no_grad():
    outputs = model(**inputs)
last_hidden_state = outputs.last_hidden_state

# The [CLS] token is at sequence position 0; take its vector
cls_embedding = last_hidden_state[:, 0]  # shape: (1, hidden_size)
print("CLS embedding shape:", cls_embedding.shape)


## What You Might Need to Adjust
### If your `.pt` file contains the entire pickled model (saved with `torch.save(model, ...)`) rather than just a state dict, you can load it directly with:

In [None]:
# Loading a fully pickled model requires weights_only=False: PyTorch >= 2.6
# defaults to weights_only=True, which refuses to unpickle a whole module.
# This executes arbitrary pickled code, so only use it on files you trust.
# map_location lets a model saved on a GPU load on a CPU-only machine.
model = torch.load(
    "minilm_model.pt",
    map_location=torch.device("cpu"),
    weights_only=False,
)
model.eval()  # switch to evaluation mode before running inference
