<a href="https://colab.research.google.com/github/ShinAsakawa/ShinAsakawa.github.io/blob/master/2023notebooks/2023_0603sentenceBERT_test.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# sentence BERT

覚書（2023 年 9 月 18 日現在）：`from transformers import BertModel` は tensorflow-io に依存しているらしい。
しかし，tensorflow-io は M1 Mac では動作しない。
そのため，Intel Mac または Colab で実行する必要がある。


In [None]:
import IPython
# True when the string form of the active IPython shell mentions 'google.colab',
# i.e. the notebook appears to be running on Google Colab.
isColab = 'google.colab' in str(IPython.get_ipython())
if isColab:
    # On Colab, install the packages this notebook relies on via shell escapes.
    # fugashi / ipadic are presumably needed by BertJapaneseTokenizer -- TODO confirm.
    !pip install 'fugashi'
    !pip install 'ipadic'
    !pip install transformers

In [None]:
import torch
from transformers import BertJapaneseTokenizer, BertModel
from scipy.stats import pearsonr

class SentenceBertJapanese:
    """Sentence encoder that mean-pools the token embeddings of a BERT model.

    Attributes:
        tokenizer: ``BertJapaneseTokenizer`` loaded from ``model_name_or_path``.
        model: ``BertModel`` loaded from the same location, in eval mode.
        device: ``torch.device`` the model has been moved to.
    """

    def __init__(self, model_name_or_path, device=None):
        """Load the tokenizer and model, then move the model to ``device``.

        Args:
            model_name_or_path: Identifier or path handed to ``from_pretrained``.
            device: Anything ``torch.device`` accepts. When ``None``, "cuda" is
                chosen if ``torch.cuda.is_available()``, otherwise "cpu".
        """
        self.tokenizer = BertJapaneseTokenizer.from_pretrained(model_name_or_path)
        self.model = BertModel.from_pretrained(model_name_or_path)
        self.model.eval()

        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)
        # Use the normalised torch.device so model and inputs share one value.
        self.model.to(self.device)

    def _mean_pooling(self, model_output, attention_mask):
        """Average token embeddings over positions where the mask is non-zero.

        Args:
            model_output: Indexable model output; element 0 holds the token
                embeddings.
            attention_mask: Mask tensor; it is unsqueezed on the last axis and
                expanded to the size of the token embeddings.

        Returns:
            Tensor of masked sums along axis 1 divided by the mask counts.
        """
        # First element of model_output contains all token embeddings.
        token_embeddings = model_output[0]

        mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        summed = torch.sum(token_embeddings * mask, 1)
        # Clamp the denominator so a fully masked row cannot divide by zero.
        counts = torch.clamp(mask.sum(1), min=1e-9)
        return summed / counts

    @torch.no_grad()
    def encode(self, sentences, batch_size=8):
        """Encode sentences into mean-pooled embeddings on the CPU.

        Args:
            sentences: Sequence of strings. A single ``str`` is treated as a
                one-sentence batch instead of being sliced character-wise.
            batch_size: Number of sentences tokenised and run per forward pass.

        Returns:
            A CPU tensor with one row per sentence. For empty input the result
            has shape ``(0, hidden_size)`` taken from ``self.model.config``.
        """
        if isinstance(sentences, str):
            sentences = [sentences]
        # torch.cat / torch.stack reject an empty list, so answer directly.
        if len(sentences) == 0:
            return torch.empty((0, self.model.config.hidden_size))

        batch_embeddings = []
        for start in range(0, len(sentences), batch_size):
            batch = sentences[start:start + batch_size]

            # Tokenizer __call__ is the supported replacement for the
            # deprecated batch_encode_plus and takes the same arguments.
            encoded_input = self.tokenizer(
                batch,
                padding="longest",
                truncation=True,
                return_tensors="pt",
            ).to(self.device)
            model_output = self.model(**encoded_input)
            pooled = self._mean_pooling(
                model_output,
                encoded_input["attention_mask"],
            )
            # Move each batch off the device right away to free device memory.
            batch_embeddings.append(pooled.to('cpu'))

        # Concatenating batch tensors yields the same rows as stacking them
        # one by one, without iterating over every row in Python.
        return torch.cat(batch_embeddings)


# Model identifier handed to SentenceBertJapanese (and on to from_pretrained).
MODEL_NAME = "sonoisa/sentence-bert-base-ja-mean-tokens-v2"  # <- note: this is the v2 model.
# Module-level encoder instance used by the cells below.
model = SentenceBertJapanese(MODEL_NAME)

In [None]:
# Two sentences to compare.
sentences = [
    "暴走したAI",
    "暴走した人工知能",
]

# Encode both sentences in a single batch.
sentence_embeddings = model.encode(sentences, batch_size=8)
# Uncomment to inspect the raw vectors:
# print("Sentence embeddings:", sentence_embeddings)

# Pearson correlation between the two embedding vectors.
first_vec, second_vec = sentence_embeddings[0], sentence_embeddings[1]
correlation = pearsonr(first_vec, second_vec)[0]
print(f'相関係数:{correlation:.3f}')

In [None]:
# Recompute and display the correlation of the two embeddings from the cell above.
correlation = pearsonr(sentence_embeddings[0], sentence_embeddings[1])[0]
print(f'相関係数:{correlation:.3f}')