<a href="https://colab.research.google.com/github/Dash400air/Bert_task/blob/main/STS_B.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **STS-B** 2文が意味的にどれだけ類似しているかをスコア1~5で判別

# Setup

In [1]:
!pip install sentence-transformers

Collecting sentence-transformers
  Downloading sentence-transformers-2.0.0.tar.gz (85 kB)
[K     |████████████████████████████████| 85 kB 3.8 MB/s 
[?25hCollecting transformers<5.0.0,>=4.6.0
  Downloading transformers-4.10.2-py3-none-any.whl (2.8 MB)
[K     |████████████████████████████████| 2.8 MB 18.3 MB/s 
Collecting sentencepiece
  Downloading sentencepiece-0.1.96-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.2 MB)
[K     |████████████████████████████████| 1.2 MB 64.2 MB/s 
[?25hCollecting huggingface-hub
  Downloading huggingface_hub-0.0.17-py3-none-any.whl (52 kB)
[K     |████████████████████████████████| 52 kB 2.1 MB/s 
Collecting tokenizers<0.11,>=0.10.1
  Downloading tokenizers-0.10.3-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (3.3 MB)
[K     |████████████████████████████████| 3.3 MB 56.9 MB/s 
[?25hCollecting pyyaml>=5.1
  Downloading PyYAML-5.4.1-cp37-cp37m-manylinux1_x86_64.whl (636 kB)
[K     |██

In [2]:
import pandas as pd
import numpy as np

import torch

from sentence_transformers import SentenceTransformer

In [3]:
!wget "http://ixa2.si.ehu.es/stswiki/images/4/48/Stsbenchmark.tar.gz"

--2021-09-16 14:07:17--  http://ixa2.si.ehu.es/stswiki/images/4/48/Stsbenchmark.tar.gz
Resolving ixa2.si.ehu.es (ixa2.si.ehu.es)... 158.227.106.100
Connecting to ixa2.si.ehu.es (ixa2.si.ehu.es)|158.227.106.100|:80... connected.
HTTP request sent, awaiting response... 302 Found
Location: http://ixa2.si.ehu.eus/stswiki/images/4/48/Stsbenchmark.tar.gz [following]
--2021-09-16 14:07:17--  http://ixa2.si.ehu.eus/stswiki/images/4/48/Stsbenchmark.tar.gz
Resolving ixa2.si.ehu.eus (ixa2.si.ehu.eus)... 158.227.106.100
Connecting to ixa2.si.ehu.eus (ixa2.si.ehu.eus)|158.227.106.100|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 409630 (400K) [application/x-gzip]
Saving to: ‘Stsbenchmark.tar.gz’


2021-09-16 14:07:19 (373 KB/s) - ‘Stsbenchmark.tar.gz’ saved [409630/409630]



In [4]:
!tar -xzvf "/content/Stsbenchmark.tar.gz"

stsbenchmark/
stsbenchmark/readme.txt
stsbenchmark/sts-test.csv
stsbenchmark/correlation.pl
stsbenchmark/LICENSE.txt
stsbenchmark/sts-dev.csv
stsbenchmark/sts-train.csv


In [5]:
train = pd.read_csv("/content/stsbenchmark/sts-train.csv", sep='\t', 
                    header=None, error_bad_lines=False, 
                    names=['type', 'source', 'ver', 'num', 'score', 's1', 's2'])
valid = pd.read_csv("/content/stsbenchmark/sts-dev.csv", sep='\t', 
                    header=None, error_bad_lines=False, 
                    names=['type', 'source', 'ver', 'num', 'score', 's1', 's2'])

# Preprocessing

In [6]:
train.head()

Unnamed: 0,type,source,ver,num,score,s1,s2
0,main-captions,MSRvid,2012test,1,5.0,A plane is taking off.,An air plane is taking off.
1,main-captions,MSRvid,2012test,4,3.8,A man is playing a large flute.,A man is playing a flute.
2,main-captions,MSRvid,2012test,5,3.8,A man is spreading shreded cheese on a pizza.,A man is spreading shredded cheese on an uncoo...
3,main-captions,MSRvid,2012test,6,2.6,Three men are playing chess.,Two men are playing chess.
4,main-captions,MSRvid,2012test,9,4.25,A man is playing the cello.,A man seated is playing the cello.


In [7]:
valid.head()

Unnamed: 0,type,source,ver,num,score,s1,s2
0,main-captions,MSRvid,2012test,0,5.0,A man with a hard hat is dancing.,A man wearing a hard hat is dancing.
1,main-captions,MSRvid,2012test,2,4.75,A young child is riding a horse.,A child is riding a horse.
2,main-captions,MSRvid,2012test,3,5.0,A man is feeding a mouse to a snake.,The man is feeding a mouse to the snake.
3,main-captions,MSRvid,2012test,7,2.4,A woman is playing the guitar.,A man is playing guitar.
4,main-captions,MSRvid,2012test,8,2.75,A woman is playing the flute.,A man is playing a flute.


## Vectorizing

In [11]:
model = SentenceTransformer('all-mpnet-base-v2')

In [12]:
def get_vectors(df):
    s1 = df['s1'].values
    s2 = df['s2'].values

    s1_vectors = model.encode(s1)
    s2_vectors = model.encode(s2)

    return s1_vectors, s2_vectors

In [13]:
train_s1_vectors, train_s2_vectors = get_vectors(train)
valid_s1_vectors, valid_s2_vectors = get_vectors(valid)

## Get Cosine Similarity

In [14]:
def get_cossimlarity(s1_vectors, s2_vectors):
    s1_tensor = torch.from_numpy(s1_vectors).clone()
    s2_tensor = torch.from_numpy(s2_vectors).clone()

    sims = []
    for s1_vec, s2_vec in zip(s1_tensor, s2_tensor):
        sim = torch.cosine_similarity(s1_vec, s2_vec, dim=0)
        sim = sim.detach().numpy().copy()
        sims.append(sim)
    return sims

In [15]:
train_sims = get_cossimlarity(train_s1_vectors, train_s2_vectors)
valid_sims = get_cossimlarity(valid_s1_vectors, valid_s2_vectors)

In [16]:
train['cos_similarity'] = train_sims
valid['cos_similarity'] = valid_sims

In [17]:
train.head()

Unnamed: 0,type,source,ver,num,score,s1,s2,cos_similarity
0,main-captions,MSRvid,2012test,1,5.0,A plane is taking off.,An air plane is taking off.,0.95965767
1,main-captions,MSRvid,2012test,4,3.8,A man is playing a large flute.,A man is playing a flute.,0.8690923
2,main-captions,MSRvid,2012test,5,3.8,A man is spreading shreded cheese on a pizza.,A man is spreading shredded cheese on an uncoo...,0.87193775
3,main-captions,MSRvid,2012test,6,2.6,Three men are playing chess.,Two men are playing chess.,0.8031405
4,main-captions,MSRvid,2012test,9,4.25,A man is playing the cello.,A man seated is playing the cello.,0.90234494


In [18]:
valid.head()

Unnamed: 0,type,source,ver,num,score,s1,s2,cos_similarity
0,main-captions,MSRvid,2012test,0,5.0,A man with a hard hat is dancing.,A man wearing a hard hat is dancing.,0.9967053
1,main-captions,MSRvid,2012test,2,4.75,A young child is riding a horse.,A child is riding a horse.,0.9509652
2,main-captions,MSRvid,2012test,3,5.0,A man is feeding a mouse to a snake.,The man is feeding a mouse to the snake.,0.85433936
3,main-captions,MSRvid,2012test,7,2.4,A woman is playing the guitar.,A man is playing guitar.,0.59489024
4,main-captions,MSRvid,2012test,8,2.75,A woman is playing the flute.,A man is playing a flute.,0.73564374


# Linear Regression

In [19]:
from sklearn.linear_model import LinearRegression

In [20]:
X_train, X_valid = train['cos_similarity'].values.reshape(-1, 1), valid['cos_similarity'].values.reshape(-1, 1)
y_train, y_valid = train['score'].values.reshape(-1, 1), valid['score'].values.reshape(-1, 1)

In [21]:
lr = LinearRegression()

In [22]:
lr.fit(X_train, y_train)

LinearRegression(copy_X=True, fit_intercept=True, n_jobs=None, normalize=False)

In [23]:
pred = lr.predict(X_valid)

In [24]:
valid['prediction'] = pred

In [25]:
valid[['s1', 's2', 'score', 'prediction']].sample(10)

Unnamed: 0,s1,s2,score,prediction
1175,Selenski descended down the wall and used the ...,Selenski used the mattress to scale a 10-foot ...,3.0,2.987624
553,Two men are playing a game of Scrabble together.,The two women are playing a game.,2.2,1.244518
231,A baby is crawling happily.,A cat is walking on hardwood floor.,0.1,0.555896
334,A person is standing underneath an overpass ne...,A person walks a dog along the water's edge.,0.2,0.399342
109,People are playing baseball.,The cricket player hit the ball.,0.5,1.238011
1326,Indian police round up all five suspects in Mu...,Mumbai police arrest fifth suspect in gang-rap...,3.8,3.599162
536,The man is in a deserted field.,The man is outside in the field.,4.0,2.424569
100,A woman is riding on a horse.,A man is turning over tables in anger.,0.0,-0.710159
1285,"Romney leads rivals, not Obama, in fundraising",Romney: I left all management of Bain in 1999,0.8,1.291568
1325,Lebanon businesses strike in protest at politi...,Tunisia president says confident can overcome ...,0.6,1.058821


# Score

In [26]:
from sklearn.metrics import mean_squared_error
from scipy.stats import pearsonr, spearmanr

In [27]:
RMSE = np.sqrt(mean_squared_error(y_valid, pred))
print(RMSE)

0.7361655425935434


In [28]:
y_pred = pred.squeeze()

In [29]:
print(spearmanr(y_valid.squeeze(), y_pred))
print(pearsonr(y_valid.squeeze(), y_pred))

SpearmanrResult(correlation=0.8743784663328783, pvalue=0.0)
(0.8721285941029627, 0.0)


# Test

In [66]:
def examine(s1, s2, lr):
    df = pd.DataFrame({'s1': s1, 's2': s2})
    vec1, vec2 = get_vectors(df)
    sims = np.array(get_cossimlarity(vec1, vec2))

    pred = lr.predict(sims.reshape(-1, 1))

    for t1, t2, y in zip(s1, s2, pred.tolist()):
        if y[0] < 0:
            y = np.array([0.00])
        print('文1：', t1)
        print('文2：', t2)
        print('score：', round(*y, ndigits=3))
        print('----------------------')

In [67]:
s1 = ["He won the game.",
      "I hate apples",
      "Reading books is thought to be good for mental health.",
      "In a few year, people are able to travel moon without much money."]

s2 = ["He beated the rival.",
      "I love apples",
      "He always turns up late.",
      "Space travel will be affordable soon for any people."]

examine(s1, s2, lr)

文1： He won the game.
文2： He beated the rival.
score： 2.632
----------------------
文1： I hate apples
文2： I love apples
score： 3.657
----------------------
文1： Reading books is thought to be good for mental health.
文2： He always turns up late.
score： 0.0
----------------------
文1： In a few year, people are able to travel moon without much money.
文2： Space travel will be affordable soon for any people.
score： 2.788
----------------------
