<a href="https://colab.research.google.com/github/Misetsu/SBERT-Model/blob/main/SERT_Model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# ライブラリをインストール

In [None]:
!pip install -U sentence-transformers | tail -n 1
!apt-get install mecab mecab-ipadic-utf8 python-mecab libmecab-dev | tail -n 1
!pip install mecab-python3 fugashi ipadic | tail -n 1

Successfully installed huggingface-hub-0.17.3 safetensors-0.4.0 sentence-transformers-2.2.2 sentencepiece-0.1.99 tokenizers-0.14.1 transformers-4.34.0
E: Unable to locate package python-mecab
Reading state information...
Successfully installed fugashi-1.3.0 ipadic-1.0.0 mecab-python3-1.0.8


In [None]:
!wget https://nlp.ist.i.kyoto-u.ac.jp/DLcounter/lime.cgi?down=https://nlp.ist.i.kyoto-u.ac.jp/nl-resource/JSNLI/jsnli_1.1.zip&name=JSNLI.zip
!mv *zip* jsnli_1.1.zip
!unzip jsnli_1.1.zip

--2023-10-10 00:57:56--  https://nlp.ist.i.kyoto-u.ac.jp/DLcounter/lime.cgi?down=https://nlp.ist.i.kyoto-u.ac.jp/nl-resource/JSNLI/jsnli_1.1.zip
Resolving nlp.ist.i.kyoto-u.ac.jp (nlp.ist.i.kyoto-u.ac.jp)... 219.94.192.32
Connecting to nlp.ist.i.kyoto-u.ac.jp (nlp.ist.i.kyoto-u.ac.jp)|219.94.192.32|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://nlp.ist.i.kyoto-u.ac.jp/nl-resource/JSNLI/jsnli_1.1.zip [following]
--2023-10-10 00:57:56--  https://nlp.ist.i.kyoto-u.ac.jp/nl-resource/JSNLI/jsnli_1.1.zip
Reusing existing connection to nlp.ist.i.kyoto-u.ac.jp:443.
HTTP request sent, awaiting response... 200 OK
Length: 44931163 (43M) [application/zip]
Saving to: ‘lime.cgi?down=https:%2F%2Fnlp.ist.i.kyoto-u.ac.jp%2Fnl-resource%2FJSNLI%2Fjsnli_1.1.zip’


2023-10-10 00:58:01 (9.63 MB/s) - ‘lime.cgi?down=https:%2F%2Fnlp.ist.i.kyoto-u.ac.jp%2Fnl-resource%2FJSNLI%2Fjsnli_1.1.zip’ saved [44931163/44931163]

Archive:  jsnli_1.1.zip
   creating: jsnli_1.1/
  infl

In [None]:
# Run Colab's interactive user-authentication flow.
# NOTE(review): no later cell visibly uses these credentials — confirm this is still needed.
from google.colab import auth
auth.authenticate_user()

# モデルのロード関数

In [None]:
import os
import json
import tensorflow as tf
from sentence_transformers import models, SentenceTransformer
from transformers import BertJapaneseTokenizer, AdamW

def load_model(model_name, max_seq_length=75):
  """Build a SentenceTransformer from a transformer checkpoint plus mean pooling.

  Args:
    model_name: Model name or path handed to `models.Transformer`.
    max_seq_length: Maximum sequence length handed to `models.Transformer`.

  Returns:
    A `SentenceTransformer` made of the transformer module followed by a
    mean-pooling module sized to the transformer's word-embedding dimension.
  """
  encoder = models.Transformer(model_name, max_seq_length=max_seq_length)
  embedding_dim = encoder.get_word_embedding_dimension()
  mean_pooling = models.Pooling(embedding_dim, pooling_mode='mean')
  return SentenceTransformer(modules=[encoder, mean_pooling])

# データの読み込み

JSNLIデータセットを使用する。各データは二つの文のペアで、その関係を表す「含意」(entailment)、「矛盾」(contradiction)、「どちらでもない」(neutral) のいずれかのラベルが付与されている。

In [None]:
import math
from sentence_transformers import models, losses, datasets
from sentence_transformers import LoggingHandler, SentenceTransformer, util, InputExample
from sentence_transformers.evaluation import ParaphraseMiningEvaluator
import logging
from datetime import datetime
import sys
import os
import gzip
import csv
import random

# Send log records at INFO level and above through sentence-transformers'
# LoggingHandler, prefixing each message with a "YYYY-MM-DD HH:MM:SS" timestamp.
logging.basicConfig(format='%(asctime)s - %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S',
                    level=logging.INFO,
                    handlers=[LoggingHandler()])

In [None]:
# Training data
def add_to_samples(data, sent1, sent2, label):
  """Record `sent2` under `label` in the bucket kept for `sent1`.

  Args:
    data: Dict mapping a sentence to its
      {'contradiction', 'entailment', 'neutral'} sets; mutated in place.
    sent1: Sentence whose bucket is updated (created on first use).
    sent2: Sentence added to the `label` set of that bucket.
    label: One of 'contradiction', 'entailment' or 'neutral'.

  Raises:
    KeyError: If `label` is not one of the three bucket names.
  """
  buckets = data.setdefault(
      sent1, {'contradiction': set(), 'entailment': set(), 'neutral': set()})
  buckets[label].add(sent2)

def load_data(filename):
  """Build (anchor, positive, negative) training triplets from a JSNLI TSV file.

  Every non-blank line must hold `label<TAB>sentence1<TAB>sentence2`. Spaces
  inside the sentences are removed. Each pair is registered in both directions
  through `add_to_samples`.

  For every sentence that has at least one entailment partner and at least one
  contradiction partner, two `InputExample`s are produced:
  `[sentence, entailment, contradiction]` and
  `[entailment, sentence, contradiction]`. Partners are drawn with
  `random.choice`, so the result depends on the state of the `random` module.

  Args:
    filename: Path to a UTF-8 encoded TSV file.

  Returns:
    List of `InputExample` objects with duplicate triplets removed.

  Raises:
    ValueError: If a non-blank line has fewer than three tab-separated fields.
    KeyError: If a label is not one known to `add_to_samples`.
  """
  data = {}
  # Explicit UTF-8: the corpus is Japanese text and must not depend on the locale.
  with open(filename, "r", encoding="utf-8") as f:
    # Stream the file instead of holding several copies of every line in memory.
    for line_no, raw in enumerate(f, start=1):
      stripped = raw.strip()
      if not stripped:
        continue  # tolerate blank lines instead of failing with IndexError
      fields = stripped.split("\t")
      if len(fields) < 3:
        raise ValueError("{}:{}: expected 3 tab-separated fields, got {}".format(
            filename, line_no, len(fields)))
      label = fields[0]
      sent1 = fields[1].replace(" ", "")
      sent2 = fields[2].replace(" ", "")
      add_to_samples(data, sent1, sent2, label)
      add_to_samples(data, sent2, sent1, label)

  # Combine each sentence with an entailed (similar) sentence and a
  # contradicting sentence into one training example.
  samples = []
  for sent1, others in data.items():
    if others['entailment'] and others['contradiction']:
      # Sort the candidates so that, for a given `random` seed, the draws do not
      # depend on set iteration order (which varies between interpreter runs).
      entailments = sorted(others['entailment'])
      contradictions = sorted(others['contradiction'])
      samples.append(InputExample(texts=[sent1, random.choice(entailments), random.choice(contradictions)]))
      samples.append(InputExample(texts=[random.choice(entailments), sent1, random.choice(contradictions)]))

  # Remove duplicate triplets. The key is the tuple of texts: concatenating the
  # texts could make two different triplets collide.
  dedup = {}
  for sample in samples:
    dedup[tuple(sample.texts)] = sample
  return list(dedup.values())

# Build the training triplets from the filtered JSNLI training split.
train_samples = load_data("jsnli_1.1/train_w_filtering.tsv")

# Number of training samples
len(train_samples)

294576

In [None]:
# Show the texts of the first three training examples as a sanity check.
for sample in train_samples[:3]:
  print(sample.texts)

['ガレージで、壁にナイフを投げる男。', 'ガレージに男がいます。', '男が台所のテーブルで本を読んでいます。'] 
['ガレージに男がいます。', 'ガレージで、壁にナイフを投げる男。', '男が台所のテーブルで本を読んでいます。'] 
['ラップトップコンピューターを使用して机に座っている若い白人男。', '人は椅子に座っています。', '黒人はデスクトップコンピューターを使用します。'] 


In [None]:
# Evaluation data
def load_data_for_paraphrase_mining(filename):
  """Load a JSNLI TSV file as paraphrase-mining evaluation data.

  Every non-blank line must hold `label<TAB>sentence1<TAB>sentence2`. Spaces
  inside the sentences are removed. Each distinct sentence receives a string
  id ("0", "1", ...) in order of first appearance.

  Args:
    filename: Path to a UTF-8 encoded TSV file.

  Returns:
    A `(sentences_map, duplicates_list)` tuple: `sentences_map` maps id to
    sentence, and `duplicates_list` holds one `(id1, id2)` tuple per row
    labelled "entailment".

  Raises:
    ValueError: If a non-blank line has fewer than three tab-separated fields.
  """
  sentences_map = {}  # id -> sentence
  sentence_ids = {}  # sentence -> id
  duplicates_list = []  # (id1, id2)

  def register(sent):
    """Return the id of `sent`, assigning the next sequential id on first use."""
    sent_id = sentence_ids.get(sent)
    if sent_id is None:
      sent_id = str(len(sentence_ids))
      sentence_ids[sent] = sent_id
      sentences_map[sent_id] = sent
    return sent_id

  # Explicit UTF-8: the corpus is Japanese text and must not depend on the locale.
  with open(filename, "r", encoding="utf-8") as f:
    for line_no, raw in enumerate(f, start=1):
      stripped = raw.strip()
      if not stripped:
        continue  # tolerate blank lines instead of failing with IndexError
      fields = stripped.split("\t")
      if len(fields) < 3:
        raise ValueError("{}:{}: expected 3 tab-separated fields, got {}".format(
            filename, line_no, len(fields)))
      label = fields[0]
      sent1 = fields[1].replace(" ", "")
      sent2 = fields[2].replace(" ", "")
      # Register both sentences even when the row is not an entailment, so the
      # map contains every sentence in the file.
      ids = (register(sent1), register(sent2))
      if label == "entailment":
        duplicates_list.append(ids)
  return sentences_map, duplicates_list

# Load the dev split as (id -> sentence) plus the list of entailment id pairs.
sentences_map, duplicates_list = load_data_for_paraphrase_mining("jsnli_1.1/dev.tsv")

# Number of distinct evaluation sentences
len(sentences_map)

5809

In [None]:
# Number of pairs labelled "entailment" (treated as similar sentences)
len(duplicates_list)

1432

# 学習

In [None]:
# Pretrained checkpoint that is fine-tuned below.
model_name = "cl-tohoku/bert-base-japanese-whole-word-masking"

# Output directory for the fine-tuned model; "/" in the checkpoint name is
# replaced with "_" so the name forms a single path component.
model_save_path = "./strf_" + model_name.replace("/", "_")
model_save_path

'./strf_cl-tohoku_bert-base-japanese-whole-word-masking'

In [None]:
# Hyperparameters for fine-tuning.
train_batch_size = 48
max_seq_length = 75
num_epochs = 1

# Wrap the pretrained checkpoint named by `model_name` with mean pooling (see load_model).
model = load_model(model_name, max_seq_length=max_seq_length)

Downloading (…)lve/main/config.json:   0%|          | 0.00/479 [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/445M [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/110 [00:00<?, ?B/s]

Downloading (…)solve/main/vocab.txt:   0%|          | 0.00/258k [00:00<?, ?B/s]

In [None]:
# Display which tokenizer class the loaded model uses.
type(model.tokenizer).__name__

'BertJapaneseTokenizer'

In [None]:
# Batches of `train_batch_size` training examples.
# NOTE(review): NoDuplicatesDataLoader presumably keeps identical texts out of the
# same batch — confirm against the sentence-transformers docs.
train_dataloader = datasets.NoDuplicatesDataLoader(train_samples, batch_size=train_batch_size)

# NOTE(review): MultipleNegativesRankingLoss is expected to treat texts[0]/texts[1] as
# the positive pair and the other texts in the batch as negatives — confirm.
train_loss = losses.MultipleNegativesRankingLoss(model)

# Evaluator built from the dev sentences and their entailment id pairs.
dev_evaluator = ParaphraseMiningEvaluator(sentences_map, duplicates_list, name="paramin-jsnli-dev")

# Warmup length: 10% of the total number of training batches, rounded up.
warmup_steps = math.ceil(len(train_dataloader) * num_epochs * 0.1)
logging.info("Warmup-steps: {}".format(warmup_steps))

In [None]:
# Fine-tune the model. The dev evaluator runs every 10% of an epoch's batches
# (evaluation_steps) and the output is written to `model_save_path`.
# Mixed precision is disabled (use_amp=False).
model.fit(train_objectives=[(train_dataloader, train_loss)],
          evaluator=dev_evaluator,
          epochs=num_epochs,
          evaluation_steps=int(len(train_dataloader)*0.1),
          warmup_steps=warmup_steps,
          output_path=model_save_path,
          use_amp=False
          )

Epoch:   0%|          | 0/1 [00:00<?, ?it/s]

Iteration:   0%|          | 0/6137 [00:00<?, ?it/s]

In [None]:
# Archive the fine-tuned model directory. Because an absolute path is zipped, the
# entries are stored under `content/...`; unzipping into /content therefore yields
# /content/content/..., which is the path later cells load from.
!zip -r /content/strf_cl-tohoku_bert_model.zip /content/strf_cl-tohoku_bert-base-japanese-whole-word-masking

  adding: content/strf_cl-tohoku_bert-base-japanese-whole-word-masking/ (stored 0%)
  adding: content/strf_cl-tohoku_bert-base-japanese-whole-word-masking/vocab.txt (deflated 49%)
  adding: content/strf_cl-tohoku_bert-base-japanese-whole-word-masking/config.json (deflated 48%)
  adding: content/strf_cl-tohoku_bert-base-japanese-whole-word-masking/pytorch_model.bin (deflated 7%)
  adding: content/strf_cl-tohoku_bert-base-japanese-whole-word-masking/modules.json (deflated 53%)
  adding: content/strf_cl-tohoku_bert-base-japanese-whole-word-masking/special_tokens_map.json (deflated 42%)
  adding: content/strf_cl-tohoku_bert-base-japanese-whole-word-masking/README.md (deflated 58%)
  adding: content/strf_cl-tohoku_bert-base-japanese-whole-word-masking/eval/ (stored 0%)
  adding: content/strf_cl-tohoku_bert-base-japanese-whole-word-masking/eval/paraphrase_mining_evaluation_paramin-jsnli-dev_results.csv (deflated 52%)
  adding: content/strf_cl-tohoku_bert-base-japanese-whole-word-masking/toke

In [None]:
# Trigger a browser download of the model archive from the Colab VM.
# NOTE(review): later cells read this archive from Google Drive (MyDrive), so it
# presumably has to be uploaded there by hand — confirm the intended workflow.
from google.colab import files
files.download("/content/strf_cl-tohoku_bert_model.zip")

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [None]:
!unzip "/content/drive/MyDrive/strf_cl-tohoku_bert_model.zip" -d "./"

Archive:  /content/drive/MyDrive/strf_cl-tohoku_bert_model.zip
   creating: ./content/strf_cl-tohoku_bert-base-japanese-whole-word-masking/
  inflating: ./content/strf_cl-tohoku_bert-base-japanese-whole-word-masking/vocab.txt  
  inflating: ./content/strf_cl-tohoku_bert-base-japanese-whole-word-masking/config.json  
  inflating: ./content/strf_cl-tohoku_bert-base-japanese-whole-word-masking/pytorch_model.bin  
  inflating: ./content/strf_cl-tohoku_bert-base-japanese-whole-word-masking/modules.json  
  inflating: ./content/strf_cl-tohoku_bert-base-japanese-whole-word-masking/special_tokens_map.json  
  inflating: ./content/strf_cl-tohoku_bert-base-japanese-whole-word-masking/README.md  
   creating: ./content/strf_cl-tohoku_bert-base-japanese-whole-word-masking/eval/
  inflating: ./content/strf_cl-tohoku_bert-base-japanese-whole-word-masking/eval/paraphrase_mining_evaluation_paramin-jsnli-dev_results.csv  
  inflating: ./content/strf_cl-tohoku_bert-base-japanese-whole-word-masking/token

# 蒸留モデル

In [None]:
import torch
from torch.utils.data import DataLoader
from sentence_transformers.datasets import ParallelSentencesDataset
from sentence_transformers import models, losses, evaluation, SentenceTransformer
from sentence_transformers.evaluation import ParaphraseMiningEvaluator

In [None]:
# Load the fine-tuned model as the distillation teacher. The doubled `content`
# in the path comes from unzipping an archive whose entries start with `content/`
# into /content.
teacher_model_name = "/content/content/strf_cl-tohoku_bert-base-japanese-whole-word-masking"
teacher_model = SentenceTransformer(teacher_model_name)

In [None]:
# Directory the distilled student model is saved to by `student_model.fit`.
output_path = "./strf_distilled_cl-tohoku_bert-base-japanese-whole-word-masking"

In [None]:
# The student starts as a second copy of the teacher, loaded from the same path.
student_model = SentenceTransformer(teacher_model_name)
# NOTE(review): `_first_module()` is a private sentence-transformers accessor;
# presumably it returns the Transformer module wrapping the BERT model — confirm.
auto_model = student_model._first_module().auto_model
# 0-based indices into `auto_model.encoder.layer` that the student keeps.
layers_to_keep = [1, 4, 7, 10]

# Rebuild the encoder's layer list with only the kept layers, in their original order.
new_layers = torch.nn.ModuleList([layer_module for i, layer_module in enumerate(auto_model.encoder.layer) if i in layers_to_keep])
auto_model.encoder.layer = new_layers
# Keep the config consistent with the reduced depth.
auto_model.config.num_hidden_layers = len(layers_to_keep)

In [None]:
# Batch size passed to ParallelSentencesDataset (inference) and to the training DataLoader.
# NOTE: this rebinds `train_batch_size`, which the fine-tuning section also defines.
inference_batch_size = 48
train_batch_size = 48

## データの読み込み

In [None]:
def load_sentences(filename):
  """Return the distinct sentences found in a JSNLI TSV file.

  Every non-blank line must hold `label<TAB>sentence1<TAB>sentence2`; the label
  is ignored. Spaces inside the sentences are removed.

  Args:
    filename: Path to a UTF-8 encoded TSV file.

  Returns:
    List of unique sentences in order of first appearance, so the result is the
    same on every run.

  Raises:
    ValueError: If a non-blank line has fewer than three tab-separated fields.
  """
  # A dict is used as an insertion-ordered set.
  unique = {}
  # Explicit UTF-8: the corpus is Japanese text and must not depend on the locale.
  with open(filename, "r", encoding="utf-8") as f:
    for line_no, raw in enumerate(f, start=1):
      stripped = raw.strip()
      if not stripped:
        continue  # tolerate blank lines instead of failing with IndexError
      fields = stripped.split("\t")
      if len(fields) < 3:
        raise ValueError("{}:{}: expected 3 tab-separated fields, got {}".format(
            filename, line_no, len(fields)))
      unique[fields[1].replace(" ", "")] = None
      unique[fields[2].replace(" ", "")] = None
  return list(unique)

# Unique sentences of the training split, used as distillation inputs.
train_sentences = load_sentences("jsnli_1.1/train_w_filtering.tsv")
len(train_sentences)

584921

In [None]:
# Dataset that pairs the student with the teacher for distillation.
# NOTE(review): ParallelSentencesDataset presumably encodes each sentence with the
# teacher to produce the regression target — confirm against the library docs.
train_data = ParallelSentencesDataset(student_model=student_model, teacher_model=teacher_model, batch_size=inference_batch_size, use_embedding_cache=False)
# Each entry is a one-element list holding a single training sentence.
train_data.add_dataset([[sent] for sent in train_sentences], max_sentence_length=75)

In [None]:
# Shuffled batches over the distillation dataset; this rebinds `train_dataloader`
# and `train_loss` from the fine-tuning section.
train_dataloader = DataLoader(train_data, shuffle=True, batch_size=train_batch_size)
# Mean-squared-error loss attached to the student model.
train_loss = losses.MSELoss(model=student_model)

In [None]:
def load_data_for_paraphrase_mining(filename):
  """Load a JSNLI TSV file as paraphrase-mining evaluation data.

  Every non-blank line must hold `label<TAB>sentence1<TAB>sentence2`. Spaces
  inside the sentences are removed. Each distinct sentence receives a string
  id ("0", "1", ...) in order of first appearance.

  Args:
    filename: Path to a UTF-8 encoded TSV file.

  Returns:
    A `(sentences_map, duplicates_list)` tuple: `sentences_map` maps id to
    sentence, and `duplicates_list` holds one `(id1, id2)` tuple per row
    labelled "entailment".

  Raises:
    ValueError: If a non-blank line has fewer than three tab-separated fields.
  """
  sentences_map = {}  # id -> sentence
  sentence_ids = {}  # sentence -> id
  duplicates_list = []  # (id1, id2)

  def register(sent):
    """Return the id of `sent`, assigning the next sequential id on first use."""
    sent_id = sentence_ids.get(sent)
    if sent_id is None:
      sent_id = str(len(sentence_ids))
      sentence_ids[sent] = sent_id
      sentences_map[sent_id] = sent
    return sent_id

  # Explicit UTF-8: the corpus is Japanese text and must not depend on the locale.
  with open(filename, "r", encoding="utf-8") as f:
    for line_no, raw in enumerate(f, start=1):
      stripped = raw.strip()
      if not stripped:
        continue  # tolerate blank lines instead of failing with IndexError
      fields = stripped.split("\t")
      if len(fields) < 3:
        raise ValueError("{}:{}: expected 3 tab-separated fields, got {}".format(
            filename, line_no, len(fields)))
      label = fields[0]
      sent1 = fields[1].replace(" ", "")
      sent2 = fields[2].replace(" ", "")
      # Register both sentences even when the row is not an entailment, so the
      # map contains every sentence in the file.
      ids = (register(sent1), register(sent2))
      if label == "entailment":
        duplicates_list.append(ids)
  return sentences_map, duplicates_list

# Load the dev split as (id -> sentence) plus the list of entailment id pairs.
sentences_map, duplicates_list = load_data_for_paraphrase_mining("jsnli_1.1/dev.tsv")
len(sentences_map)

5809

In [None]:
# Unique dev sentences, passed as both source and target sentences of the MSE evaluator.
dev_sentences = load_sentences("jsnli_1.1/dev.tsv")
# NOTE(review): MSEEvaluator presumably compares student embeddings with the
# teacher's embeddings of the same sentences — confirm.
dev_evaluator_mse = evaluation.MSEEvaluator(dev_sentences, dev_sentences, teacher_model=teacher_model)
# Same paraphrase-mining evaluation as used for the teacher.
dev_evaluator = ParaphraseMiningEvaluator(sentences_map, duplicates_list, name="paramin-jsnli-dev")

In [None]:
# Train the pruned student for one epoch; both evaluators run every 5000 steps.
# NOTE(review): `AdamW` is not imported in this section. It comes from the
# `from transformers import BertJapaneseTokenizer, AdamW` line in the model-loading
# cell, so that cell must have been run in the same kernel.
# NOTE(review): `correct_bias` is presumably specific to the transformers AdamW
# class — confirm before swapping in another optimizer.
# NOTE(review): with save_best_model=True the score of a SequentialEvaluator is
# presumably the last evaluator's (the MSE evaluator here) — confirm.
student_model.fit(train_objectives=[(train_dataloader, train_loss)],
                  evaluator=evaluation.SequentialEvaluator([dev_evaluator, dev_evaluator_mse]),
                  epochs=1,
                  warmup_steps=1000,
                  evaluation_steps=5000,
                  output_path=output_path,
                  save_best_model=True,
                  optimizer_class= AdamW,
                  optimizer_params={'lr': 1e-4, 'eps': 1e-6, 'correct_bias': False},
                  use_amp=False)



Epoch:   0%|          | 0/1 [00:00<?, ?it/s]

Iteration:   0%|          | 0/12173 [00:00<?, ?it/s]

  labels = torch.tensor(labels)


In [None]:
# Archive the distilled model directory. Because an absolute path is zipped, the
# entries are stored under `content/...`.
!zip -r /content/strf_distilled_cl-tohoku_bert_model.zip /content/strf_distilled_cl-tohoku_bert-base-japanese-whole-word-masking

  adding: content/strf_distilled_cl-tohoku_bert-base-japanese-whole-word-masking/ (stored 0%)
  adding: content/strf_distilled_cl-tohoku_bert-base-japanese-whole-word-masking/config.json (deflated 48%)
  adding: content/strf_distilled_cl-tohoku_bert-base-japanese-whole-word-masking/1_Pooling/ (stored 0%)
  adding: content/strf_distilled_cl-tohoku_bert-base-japanese-whole-word-masking/1_Pooling/config.json (deflated 47%)
  adding: content/strf_distilled_cl-tohoku_bert-base-japanese-whole-word-masking/sentence_bert_config.json (deflated 4%)
  adding: content/strf_distilled_cl-tohoku_bert-base-japanese-whole-word-masking/special_tokens_map.json (deflated 42%)
  adding: content/strf_distilled_cl-tohoku_bert-base-japanese-whole-word-masking/config_sentence_transformers.json (deflated 28%)
  adding: content/strf_distilled_cl-tohoku_bert-base-japanese-whole-word-masking/eval/ (stored 0%)
  adding: content/strf_distilled_cl-tohoku_bert-base-japanese-whole-word-masking/eval/paraphrase_mining_ev

# 評価

In [None]:
!unzip "/content/drive/MyDrive/strf_cl-tohoku_bert_model.zip" -d "./"

Archive:  /content/drive/MyDrive/strf_cl-tohoku_bert_model.zip
   creating: ./content/strf_cl-tohoku_bert-base-japanese-whole-word-masking/
  inflating: ./content/strf_cl-tohoku_bert-base-japanese-whole-word-masking/vocab.txt  
  inflating: ./content/strf_cl-tohoku_bert-base-japanese-whole-word-masking/config.json  
  inflating: ./content/strf_cl-tohoku_bert-base-japanese-whole-word-masking/pytorch_model.bin  
  inflating: ./content/strf_cl-tohoku_bert-base-japanese-whole-word-masking/modules.json  
  inflating: ./content/strf_cl-tohoku_bert-base-japanese-whole-word-masking/special_tokens_map.json  
  inflating: ./content/strf_cl-tohoku_bert-base-japanese-whole-word-masking/README.md  
   creating: ./content/strf_cl-tohoku_bert-base-japanese-whole-word-masking/eval/
  inflating: ./content/strf_cl-tohoku_bert-base-japanese-whole-word-masking/eval/paraphrase_mining_evaluation_paramin-jsnli-dev_results.csv  
  inflating: ./content/strf_cl-tohoku_bert-base-japanese-whole-word-masking/token

In [None]:
!unzip "/content/drive/MyDrive/strf_distilled_cl-tohoku_bert_model.zip" -d "./"

Archive:  /content/drive/MyDrive/strf_distilled_cl-tohoku_bert_model.zip
   creating: ./content/strf_distilled_cl-tohoku_bert-base-japanese-whole-word-masking/
  inflating: ./content/strf_distilled_cl-tohoku_bert-base-japanese-whole-word-masking/config.json  
   creating: ./content/strf_distilled_cl-tohoku_bert-base-japanese-whole-word-masking/1_Pooling/
  inflating: ./content/strf_distilled_cl-tohoku_bert-base-japanese-whole-word-masking/1_Pooling/config.json  
  inflating: ./content/strf_distilled_cl-tohoku_bert-base-japanese-whole-word-masking/sentence_bert_config.json  
  inflating: ./content/strf_distilled_cl-tohoku_bert-base-japanese-whole-word-masking/special_tokens_map.json  
  inflating: ./content/strf_distilled_cl-tohoku_bert-base-japanese-whole-word-masking/config_sentence_transformers.json  
   creating: ./content/strf_distilled_cl-tohoku_bert-base-japanese-whole-word-masking/eval/
  inflating: ./content/strf_distilled_cl-tohoku_bert-base-japanese-whole-word-masking/eval/pa

In [None]:
# Load the fine-tuned model and the distilled model from the unpacked archives.
# This rebinds `model`, which the training section also uses.
model = SentenceTransformer("/content/content/strf_cl-tohoku_bert-base-japanese-whole-word-masking")
dis_model = SentenceTransformer("/content/content/strf_distilled_cl-tohoku_bert-base-japanese-whole-word-masking")

In [None]:
# NOTE: `sentences_map` and `duplicates_list` are not defined in this section;
# they must already exist in the kernel from an earlier data-loading cell.
dev_evaluator = ParaphraseMiningEvaluator(sentences_map, duplicates_list, name="paramin-jsnli-dev")

In [None]:
# Paraphrase-mining score of the fine-tuned (non-distilled) model on the dev split.
dev_evaluator(model)

0.13061097484805062

In [None]:
from sklearn.decomposition import PCA
import numpy as np
import torch
from sentence_transformers import models, losses, evaluation, SentenceTransformer

In [None]:
# Target embedding size after PCA-based dimensionality reduction.
new_dimension = 128

In [None]:
def load_sentences(filename):
  """Return the distinct sentences found in a JSNLI TSV file.

  Every non-blank line must hold `label<TAB>sentence1<TAB>sentence2`; the label
  is ignored. Spaces inside the sentences are removed.

  Args:
    filename: Path to a UTF-8 encoded TSV file.

  Returns:
    List of unique sentences in order of first appearance, so the result is the
    same on every run.

  Raises:
    ValueError: If a non-blank line has fewer than three tab-separated fields.
  """
  # A dict is used as an insertion-ordered set.
  unique = {}
  # Explicit UTF-8: the corpus is Japanese text and must not depend on the locale.
  with open(filename, "r", encoding="utf-8") as f:
    for line_no, raw in enumerate(f, start=1):
      stripped = raw.strip()
      if not stripped:
        continue  # tolerate blank lines instead of failing with IndexError
      fields = stripped.split("\t")
      if len(fields) < 3:
        raise ValueError("{}:{}: expected 3 tab-separated fields, got {}".format(
            filename, line_no, len(fields)))
      unique[fields[1].replace(" ", "")] = None
      unique[fields[2].replace(" ", "")] = None
  return list(unique)

# Unique sentences of the training split, used to fit the PCA below.
train_sentences = load_sentences("jsnli_1.1/train_w_filtering.tsv")
len(train_sentences)

584921

In [None]:
# Embed every training sentence with the distilled model, returned as a NumPy array.
train_embeddings = dis_model.encode(train_sentences, convert_to_numpy=True)

In [None]:
# Fit a PCA with `new_dimension` components on the training embeddings.
pca = PCA(n_components=new_dimension)
pca.fit(train_embeddings)
# Principal axes as an ndarray — presumably shaped (new_dimension, embedding_dim); TODO confirm.
pca_comp = np.asarray(pca.components_)

In [None]:
# Append a bias-free, identity-activation Dense layer whose weights are the PCA
# components, so the model outputs `new_dimension`-sized embeddings.
# NOTE(review): sklearn's PCA presumably centres the data before projecting, while
# this layer has no bias — confirm the missing mean shift is acceptable.
# NOTE(review): this cell mutates `dis_model`; re-running it would size the layer from
# the already-reduced model — reload `dis_model` before running it again.
dense = models.Dense(in_features=dis_model.get_sentence_embedding_dimension(), out_features=new_dimension, bias=False, activation_function=torch.nn.Identity())
dense.linear.weight = torch.nn.Parameter(torch.tensor(pca_comp))
dis_model.add_module('dense', dense)

In [None]:
# Unique dev sentences for the MSE evaluator.
dev_sentences = load_sentences("jsnli_1.1/dev.tsv")
# NOTE(review): `dev_evaluator_mse` is not called in any visible cell. The teacher
# (`model`) and the reduced `dis_model` presumably now differ in embedding size —
# confirm before using it.
dev_evaluator_mse = evaluation.MSEEvaluator(dev_sentences, dev_sentences, teacher_model=model)
# `sentences_map` and `duplicates_list` must already exist in the kernel.
dev_evaluator = ParaphraseMiningEvaluator(sentences_map, duplicates_list, name="paramin-jsnli-dev")

In [None]:
# Paraphrase-mining score of the distilled model after the PCA projection layer was added.
dev_evaluator(dis_model)

0.11169841647482813