# This notebook traces a fine-tuned PyTorch (DistilBERT) model to TorchScript and packages it for Triton Inference Server

### Defining the model class

In [74]:
import torch
from torch import nn, optim
from transformers import DistilBertTokenizer, DistilBertModel

In [75]:
class BertRegressor(nn.Module):
    """Single-output regression head on top of a DistilBERT-style encoder.

    The encoder's hidden state at sequence position 0 (the first token) is
    projected to one value by a linear layer.

    Args:
        bert_model: encoder module. It must expose ``config.dim`` (hidden size)
            and return a sequence whose first element is the hidden-state
            tensor, indexed here as (batch, seq, hidden).
    """

    def __init__(self, bert_model):
        super().__init__()
        self.bert = bert_model
        hidden_size = self.bert.config.dim
        self.regressor = nn.Linear(hidden_size, 1)

    def forward(self, input_ids, attention_mask):
        """Return one regression value per sequence (the trailing size-1 axis is squeezed)."""
        encoder_outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        hidden_states = encoder_outputs[0]
        first_token_state = hidden_states[:, 0, :]
        prediction = self.regressor(first_token_state)
        return prediction.squeeze(-1)

### Loading the model

In [76]:
# Identifiers for the Triton repository entry: model_repository/<model_name>/<version>/
model_name, version = "distil_v1_friendly", "1"

In [77]:
model_path = f"model_checkpoints/torch/{model_name}.pt"  # Replace with your .pt file path
base_checkpoint = "distilbert-base-uncased"  # shared by the encoder and the tokenizer

# Build the architecture from the pretrained base, then overwrite its weights
# with the fine-tuned checkpoint.
bert_model = DistilBertModel.from_pretrained(base_checkpoint)
model = BertRegressor(bert_model)
# map_location="cpu": a checkpoint saved from a GPU session would otherwise fail
# to load on a CPU-only host; preprocessing, tracing and serving here are all CPU.
# NOTE(review): torch.load unpickles the file, so only load trusted checkpoints
# (newer torch versions accept weights_only=True to restrict this).
model.load_state_dict(torch.load(model_path, map_location="cpu"))
model.eval()  # Set the model to evaluation mode (disables dropout) before tracing

tokenizer = DistilBertTokenizer.from_pretrained(base_checkpoint)

### Pre-processing & tracing functions

In [78]:
def _remove_padding(input_tensor):
    last_non_padding = torch.nonzero(input_tensor[0]).squeeze(-1)[-1].item() + 1
    return input_tensor[:, :last_non_padding]

def preprocess(text, remove_padding=True, max_len=512, device="cpu"):
    """Tokenize `text` into (input_ids, attention_mask) tensors for the model.

    Args:
        text: input string.
        remove_padding: if True, strip the trailing padding columns so the
            sequence is only as long as the real tokens.
        max_len: maximum token length, including special tokens. Longer inputs
            are truncated. The default of 512 presumably matches DistilBERT's
            position-embedding limit; TODO confirm for other checkpoints.
        device: torch device the returned tensors are moved to.

    Returns:
        (input_ids, attention_mask): tensors with a leading batch dimension of 1.
    """
    # truncation=True is required: with padding='max_length' the tokenizer does
    # NOT truncate by default, so texts longer than `max_len` tokens would yield
    # over-long sequences that the model may reject.
    inputs = tokenizer.encode_plus(
        text,
        add_special_tokens=True,
        max_length=max_len,
        truncation=True,
        return_token_type_ids=False,
        padding='max_length',
        return_attention_mask=True,
        return_tensors='pt',
    )

    input_ids = inputs['input_ids'].to(device)
    attention_mask = inputs['attention_mask'].to(device)

    if remove_padding:
        # Both tensors are padded at the same positions. This assumes the pad
        # token id is 0, so trimming each by its trailing zeros yields equal lengths.
        input_ids = _remove_padding(input_ids)
        attention_mask = _remove_padding(attention_mask)

    return input_ids, attention_mask

In [79]:
def trace_model(input_ids, attention_mask):
    """Return a TorchScript module produced by tracing the global `model`.

    Args:
        input_ids: example tensor passed as the model's first forward argument.
        attention_mask: example tensor passed as the model's second forward argument.
    """
    example_inputs = (input_ids, attention_mask)
    return torch.jit.trace(model, example_inputs)

### Trace the model

In [80]:
# Example input used only to drive tracing: i = input_ids, am = attention_mask.
sample_text = "I love machine learning!"
i, am = preprocess(sample_text)

In [81]:
# Trace the regressor using the example tensors from the preprocessing cell.
torchscript = trace_model(input_ids=i, attention_mask=am)

In [82]:
import os

def ensure_path_exists(path):
    """Create directory `path` (and any parents) if needed, printing which case occurred.

    Args:
        path: directory path to ensure.

    Raises:
        FileExistsError: if `path` exists but is not a directory. Previously
            this was misreported as "already exists" and only failed later.
    """
    if os.path.isdir(path):
        print(f"Path already exists: {path}")
        return
    # exist_ok=True tolerates the directory appearing between the check and
    # this call; a regular file at `path` still raises FileExistsError.
    os.makedirs(path, exist_ok=True)
    print(f"Path created: {path}")

# Triton layout: model_repository/<model_name>/<version>/ holds the serialized model.
version_dir = f"model_repository/{model_name}/{version}"
ensure_path_exists(version_dir)

Path created: model_repository/distil_v1_friendly/1


In [83]:
# Serialize the traced TorchScript module into the version directory as model.pt.
traced_model_path = f"model_repository/{model_name}/{version}/model.pt"
torchscript.save(traced_model_path)

In [84]:
# Generate the Triton model configuration (config.pbtxt) for the traced model.
#
# - Both inputs are declared TYPE_INT64: that is what `preprocess` produces
#   (tokenizer output) and what the model was traced with.
# - The model returns one value per sequence, i.e. shape [batch]. With batching
#   enabled, Triton's docs recommend declaring such a tensor as dims [1] with
#   reshape [] so the config matches the actual tensor shape.
# NOTE(review): Triton documents the "openvino" CPU execution accelerator for the
#   ONNX Runtime backend; confirm it has any effect on pytorch_libtorch.
# NOTE(review): the PyTorch backend documents a `<name>__<index>` convention for
#   output names (e.g. "output__0"); confirm "output" is accepted by your Triton version.
config = f"""name: "{model_name}"
platform: "pytorch_libtorch"
max_batch_size: 1
input [
  {{
    name: "input_ids"
    data_type: TYPE_INT64
    dims: [ -1 ]
  }},
  {{
    name: "attention_mask"
    data_type: TYPE_INT64
    dims: [ -1 ]
  }}
]
output [
  {{
    name: "output"
    data_type: TYPE_FP32
    dims: [ 1 ]
    reshape: {{ shape: [ ] }}
  }}
]
optimization {{
  execution_accelerators {{
    cpu_execution_accelerator : [{{
      name : "openvino"
    }}]
  }}
}}
parameters: {{
  key: "INFERENCE_MODE"
  value: {{
    string_value: "true"
  }}
}}
instance_group [
  {{
    count: 1
    kind: KIND_CPU
  }}
]"""

# config.pbtxt lives in the model directory, next to the numbered version folders.
with open(f"model_repository/{model_name}/config.pbtxt", 'w') as file:
    file.write(config)
