# sBERT

In [None]:
from transformers import AutoModel, AutoTokenizer
import torch.nn as nn
import torch


class sBERTRegressor(nn.Module):
    """Pretrained transformer encoder with a single-output linear regression head.

    The encoder is loaded with ``AutoModel.from_pretrained`` and the hidden
    state at sequence position 0 is fed to a ``hidden_size -> 1`` linear layer.
    """

    def __init__(self, model_name: str, is_freeze: bool) -> None:
        """Build the encoder and the regression head.

        Args:
            model_name: Identifier passed to ``AutoModel.from_pretrained``.
            is_freeze: When true, gradients are disabled for every encoder
                parameter so that only the regression head is trainable.
        """
        super().__init__()
        self.sbert = AutoModel.from_pretrained(model_name)
        self.regressor = nn.Linear(self.sbert.config.hidden_size, 1)

        if is_freeze:
            for param in self.sbert.parameters():
                param.requires_grad = False

    def forward(
        self, input_ids: torch.Tensor, attention_mask: torch.Tensor
    ) -> torch.Tensor:
        """Return one regression score per input sequence.

        Args:
            input_ids: Token ids, forwarded to the encoder unchanged.
            attention_mask: Attention mask, forwarded to the encoder unchanged.

        Returns:
            Tensor of shape ``[B]`` holding one score per batch element.
        """
        outputs = self.sbert(input_ids=input_ids, attention_mask=attention_mask)
        cls_embedding = outputs.last_hidden_state[:, 0, :]  # [B, h] at position 0
        regression_output = self.regressor(cls_embedding)  # [B, 1]
        # Drop only the trailing feature axis. A bare squeeze() would also
        # remove the batch axis when B == 1 and return a 0-d tensor, which no
        # longer lines up with a [B]-shaped target.
        return regression_output.squeeze(-1)
    
# Run on the GPU when PyTorch reports CUDA as available; otherwise use the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [None]:
model_name = "snunlp/KR-SBERT-V40K-klueNLI-augSTS"
is_freeze = False

# Baseline: pretrained encoder with a newly created (untrained) regression head.
model = sBERTRegressor(model_name, is_freeze).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Same architecture, with weights restored from the saved checkpoint.
trained_model = sBERTRegressor(model_name, is_freeze).to(device)
model_path = "checkpoints/sBERT_2024-07-03 17:15:29.pth"
# map_location remaps the stored tensors onto the current device, so a
# checkpoint written on a CUDA machine still loads on a CPU-only one.
# NOTE: torch.load unpickles the file; only load checkpoints you trust.
trained_model.load_state_dict(torch.load(model_path, map_location=device))

# Both models are used for inference only; make evaluation mode explicit.
model.eval()
trained_model.eval()

In [None]:
# Tokenize a single example sentence, padded/truncated to a fixed length,
# and move the resulting tensors to the selected device.
example = "으아아아"
max_length = 128
encoded_text = tokenizer(
    example,
    return_tensors="pt",
    max_length=max_length,
    padding="max_length",
    truncation=True,
)
iids, atm = (
    encoded_text[key].to(device) for key in ("input_ids", "attention_mask")
)

In [None]:
# Inference only: disable autograd so no computation graph is built while
# scoring the example with the untrained and the fine-tuned model.
with torch.no_grad():
    print("학습되지 않은 모델:", round(float(model(iids, atm)), 3), "점")
    print("학습된 모델:", round(float(trained_model(iids, atm)), 3), "점")