In [11]:
from itertools import combinations

import numpy as np
import torch
from transformers import BertJapaneseTokenizer, BertModel

# Japanese BERT checkpoint; the tokenizer and encoder must come from the same one.
MODEL_NAME = 'cl-tohoku/bert-base-japanese-whole-word-masking'

# BertModel is the bare encoder, so the checkpoint's pretraining heads (cls.*)
# are discarded -- that is what the "weights ... were not used" warning reports.
tokenizer, model = (
    BertJapaneseTokenizer.from_pretrained(MODEL_NAME),
    BertModel.from_pretrained(MODEL_NAME),
)

Some weights of the model checkpoint at cl-tohoku/bert-base-japanese-whole-word-masking were not used when initializing BertModel: ['cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [12]:
sentences1 = ["The cat is sleeping.", "The dog is running.",
              "The mouse is hiding.", "The fish is swimming.", "The bird is singing."]
sentences2 = ["The dog is barking.", "The cat is playing.",
              "The bird is flying.", "The fish is jumping.", "The mouse is eating."]

In [13]:
def sentence_to_vector(model, tokenizer, sentence, pooling="cls"):
    """Embed a single sentence with a BERT-style encoder.

    Args:
        model: Encoder called as ``model(input_ids)``. Element ``[0]`` of its
            output is used as the last hidden state, indexed as
            (batch, token, hidden).
        tokenizer: Callable returning a mapping whose ``"input_ids"`` entry is
            the list of token ids for one sentence.
        sentence: Text to embed.
        pooling: ``"cls"`` (default) returns the hidden state of the first
            token (the [CLS] token when the tokenizer adds special tokens).
            ``"mean"`` averages the hidden states of all tokens.

    Returns:
        A tensor of shape (1, hidden).

    Raises:
        ValueError: If ``pooling`` is not ``"cls"`` or ``"mean"``.
    """
    if pooling not in ("cls", "mean"):
        raise ValueError(f"pooling must be 'cls' or 'mean', got {pooling!r}")

    token_ids = tokenizer(sentence, add_special_tokens=True)["input_ids"]
    # Batch of one sentence: shape (1, seq_len).
    input_ids = torch.tensor(token_ids).unsqueeze(0)
    with torch.no_grad():
        # Only the last hidden state is used, so intermediate layers are not
        # requested from the model.
        last_hidden_state = model(input_ids)[0]

    if pooling == "cls":
        # Keep the batch axis so the result is (1, hidden).
        return last_hidden_state[:, 0, :]
    # A single unpadded sentence: every position is a real token, so no
    # attention mask is needed for the average.
    return last_hidden_state.mean(dim=1)

In [14]:
def calc_similarity(sentence1, sentence2):
    """Print two sentences and return the cosine similarity of their embeddings.

    Embeddings come from ``sentence_to_vector`` using the notebook-level
    ``model`` and ``tokenizer``. Prints both sentences, then the score.

    Returns:
        The cosine similarity as a Python float.
    """
    print("{}\n{}".format(sentence1, sentence2))

    # Embed each sentence and flatten it to a 1-D vector so the similarity
    # can be taken along dim 0.
    vec_a, vec_b = (
        sentence_to_vector(model, tokenizer, text).reshape(-1)
        for text in (sentence1, sentence2)
    )

    score = torch.nn.functional.cosine_similarity(vec_a, vec_b, dim=0).item()
    print("Similarity:", score)
    return score

In [15]:
def calc_average_similarity(sentences):
    """Average similarity of the first sentence to each of the others.

    Every comparison goes through ``calc_similarity`` (which prints the pair
    and its score). Afterwards the list of similarities and their mean are
    printed.

    NOTE(review): a later cell redefines ``calc_average_similarity`` with
    all-pairs semantics and a list return value. Once that cell runs, this
    version is shadowed. Consider giving one of them a distinct name.

    Args:
        sentences: Indexable sequence of at least two sentences.
            ``sentences[0]`` is the reference compared against every other
            entry.

    Returns:
        The mean similarity as a float.

    Raises:
        ValueError: If fewer than two sentences are given, since there is
            nothing to compare.
    """
    if len(sentences) < 2:
        raise ValueError("need at least two sentences to compare")

    similarities = [
        calc_similarity(sentences[0], sentences[i])
        for i in range(1, len(sentences))
    ]
    average_similarity = sum(similarities) / len(similarities)
    print("Similarities:", similarities)
    print("Average similarity:", average_similarity)

    return average_similarity

In [16]:
calc_average_similarity(sentences1)

The cat is sleeping.
The dog is running.
Similarity: 0.9558551907539368
The cat is sleeping.
The mouse is hiding.
Similarity: 0.9288668632507324
The cat is sleeping.
The fish is swimming.
Similarity: 0.9353626370429993
The cat is sleeping.
The bird is singing.
Similarity: 0.9280246496200562
Similarities: [0.9558551907539368, 0.9288668632507324, 0.9353626370429993, 0.9280246496200562]
Average similarity: 0.9370273351669312


0.9370273351669312

In [17]:
def calc_average_similarity(sentences):
    """Summarize cosine similarities over every unordered pair of sentences.

    Each pair ``(sentences[i], sentences[j])`` with ``i < j`` is compared
    once via ``calc_similarity``, which prints the pair and its score.

    NOTE(review): this redefines an earlier ``calc_average_similarity``
    that compared only against ``sentences[0]`` and returned a single float.
    This version returns a list.

    Args:
        sentences: Sequence of at least two sentences.

    Returns:
        ``[min, q1, median, q3, max, mean]`` of the pairwise similarities, as
        a list of NumPy floats.

    Raises:
        ValueError: If fewer than two sentences are given, since there would
            be no pairs to summarize.
    """
    if len(sentences) < 2:
        raise ValueError("need at least two sentences to compare")

    # combinations() yields pairs in the same (i, j), i < j order as a
    # nested index loop.
    similarities = np.array(
        [calc_similarity(a, b) for a, b in combinations(sentences, 2)]
    )
    # 0 and 1 are the min and max; 0.5 is the median.
    quartiles = np.quantile(similarities, [0, 0.25, 0.5, 0.75, 1])
    average_similarity = np.mean(similarities)

    return list(quartiles) + [average_similarity]

In [18]:
calc_average_similarity(sentences1)

The cat is sleeping.
The dog is running.
Similarity: 0.9558551907539368
The cat is sleeping.
The mouse is hiding.
Similarity: 0.9288668632507324
The cat is sleeping.
The fish is swimming.
Similarity: 0.9353626370429993
The cat is sleeping.
The bird is singing.
Similarity: 0.9280246496200562
The dog is running.
The mouse is hiding.
Similarity: 0.8939509391784668
The dog is running.
The fish is swimming.
Similarity: 0.918599545955658
The dog is running.
The bird is singing.
Similarity: 0.9109365940093994
The mouse is hiding.
The fish is swimming.
Similarity: 0.9141663908958435
The mouse is hiding.
The bird is singing.
Similarity: 0.9057633876800537
The fish is swimming.
The bird is singing.
Similarity: 0.941848874092102


[0.8939509391784668,
 0.9117440432310104,
 0.9233120977878571,
 0.9337386935949326,
 0.9558551907539368,
 0.9233375072479248]