## 順位の実装  
・アプリケーションから使用するスニペットのみ取得

##### モジュール

In [34]:
device = "cuda:1"
# device = "cuda:0"

In [35]:
import pickle
def read_bin(filename):
    with open(filename,'rb') as bf:
        bin_data = pickle.load(bf)
    return bin_data

def save_bin(filename,data):
    with open(filename,'wb') as bf:
        pickle.dump(data,bf)

In [18]:
from typing import List,Dict
from collections import defaultdict
import numpy
import pickle
import csv
import os

import torch
from transformers import Wav2Vec2ForPreTraining,Wav2Vec2Processor

In [20]:
num = str(1)
# num = str(2)
# num = str(3)
# num = str(4)
# num = str(5)

##### データ

In [36]:
# # =============知識グラフ
nodes=dict()
with open('new_data/all_BirdDBnode.tsv', mode='r', newline='', encoding='utf-8') as f:
    for row in csv.DictReader(f, delimiter = '\t'):
        nodes[row["id"]] = row

In [37]:
import pandas as pd
test = pd.read_parquet('BirdModel_remove_test/data_cross-valid/test_'+num+'.parquet')#10293
test_index = list(test.index)

In [38]:
# testのid一覧,英語正式名称一覧を作成
WavDesc2Kana=read_bin('new_data/WavDesc2Kana.bin')
name2jid=read_bin('new_data/name2jid.bin')
tuples_bid_jid = read_bin('new_data/tuples_bid_jid.bin')

test_bids = []
for test_name in list(test["description"]):
    test_jid = name2jid[WavDesc2Kana[test_name]]
    for bid,jid in tuples_bid_jid:
        if jid == test_jid:
            test_bids.append(bid)
            break

test_queries = []
for test_bid in test_bids:
    test_queries.append(nodes[test_bid]["en_name"])

In [39]:
if len(test_index) == len(test_bids) == len(test_queries):
    print("OK")

OK


In [40]:
del test,WavDesc2Kana,name2jid,tuples_bid_jid

##### 関数

In [41]:
# ============= 2つのListのCos類似度算出関数
import torch

def cos_sim(v1, v2):
    # リストをテンソルに変換し、GPUに転送
    min_l = min(len(v1), len(v2))
    
    # v1 = torch.tensor(v1[:min_l]).float().cuda()
    # v2 = torch.tensor(v2[:min_l]).float().cuda()
    v1 = torch.tensor(v1[:min_l], device=device, dtype=torch.float32)
    v2 = torch.tensor(v2[:min_l], device=device, dtype=torch.float32)

    # Cos類似度の計算
    dot_product = torch.dot(v1, v2)
    norm_v1 = torch.norm(v1)
    norm_v2 = torch.norm(v2)
    cos_sim = dot_product / (norm_v1 * norm_v2)
    
    return cos_sim.item()  # 結果をPythonのfloat型に変換して返す

def to_katakana(text):
    # kakasiオブジェクトを作成
    kakasi_instance = kakasi()
    kakasi_instance.setMode("J", "K")  # J（漢字）をH（ひらがな）に変換
    kakasi_instance.setMode("H", "K")  # H（ひらがな）をK（カタカナ）に変換
    
    # カタカナに変換
    conv = kakasi_instance.getConverter()
    katakana_text = conv.do(text)
    return katakana_text

In [42]:
#=================================================マルチモーダル検索用関数
def concat_vecs(query,lang,Wikidata_id,inp_Svec):
    inp_Lvec = [0]*768
    inp_Gvec = [0]*64
    inp_Svec = [0]*256

    if query != None:
        if lang=="en":
            en_tokens = en_tokenizer(query, return_tensors="pt", padding=True, truncation=True)
            with torch.no_grad():
                en_model.eval()
                output = en_model(**en_tokens)
        else:
            query = to_katakana(query)
            ja_tokens = ja_tokenizer(query, return_tensors="pt", padding=True, truncation=True)
            with torch.no_grad():
                ja_model.eval()
                output = ja_model(**ja_tokens)
        inp_Lvec = output.last_hidden_state[0][0].tolist()# queryの分散表現

    if Wikidata_id != None and Wikidata_id in bid2Gvec:
        inp_Gvec = bid2Gvec[Wikidata_id]

    return inp_Lvec+inp_Gvec+inp_Svec

In [43]:
print(1)
en_concat_vecs_et=read_bin('new_data/en_concat_vecs.bin')
print(2)

1
2


検索内リスト化関数

In [26]:
import numpy as np
import scipy.io.wavfile as wavfile

def Vecs_tolist(vec,processor_mode):
    out=[]
    v=np.mean(vec,axis=1)
    return v.tolist()[0]

In [44]:
import csv

def multi_Search_topN_SLG(concat_vecs,input_v,filename,i,num):
    max_sim=dict()
    c2 = 0
        
    for bid,v in concat_vecs:
        print('\r%d : %d / %d' %(i,c2, len(concat_vecs)), end='')
        c2 =  c2 + 1
        cos_sim_value = cos_sim(input_v,v)
        
        if bid not in max_sim:
            max_sim[bid] = cos_sim_value
        else:
            if max_sim[bid] < cos_sim_value:
                max_sim[bid] = cos_sim_value
    save_bin('csv_cv_'+flg+"_"+num+'/'+filename+'.bin',max_sim)

## input_vecs_list_ALGの作成

In [45]:
from transformers import BertModel, BertTokenizer
en_model = BertModel.from_pretrained('models/en_model')
en_tokenizer = BertTokenizer.from_pretrained('models/en_tokenizer')
bid2Gvec=read_bin('new_data/bid2Gvec.bin')

In [67]:
# num = str(1)
# num = str(2)
# num = str(3)
# num = str(4)
num = str(5)

In [68]:
from contextlib import redirect_stdout
from transformers import Wav2Vec2ForPreTraining,Wav2Vec2FeatureExtractor

with redirect_stdout(open(os.devnull, 'w')):#一時的に出力を無効化
    path = "./models/cv"+str(num)+"/"
    processor = Wav2Vec2FeatureExtractor.from_pretrained("patrickvonplaten/wav2vec2-base-v2")# Wav2Vec2Processor => Wav2Vec2FeatureExtractor
    model = Wav2Vec2ForPreTraining.from_pretrained(path)# Wav2Vec2ForCTC => Wav2Vec2ForPreTraining
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model.to(device)

Some weights of the model checkpoint at ./models/cv5/ were not used when initializing Wav2Vec2ForPreTraining: ['wav2vec2.encoder.pos_conv_embed.conv.weight_g', 'wav2vec2.encoder.pos_conv_embed.conv.weight_v']
- This IS expected if you are initializing Wav2Vec2ForPreTraining from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing Wav2Vec2ForPreTraining from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of Wav2Vec2ForPreTraining were not initialized from the model checkpoint at ./models/cv5/ and are newly initialized: ['wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original0', 'wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original1']
You should probably TRAIN this mo

In [69]:
for slg_i in range(5):
    SLG = [(0,1,0),(0,0,1),(0,1,1),(1,0,1),(1,1,0),(1,1,1)]
    if_S,if_L,if_G = SLG[slg_i]
    
    input_vecs_list = []
    
    # n = 10 #結果をいくつ表示するか
    processor_mode = False #未学習の埋め込みを使用するモード
    lang = "en"
    
    c = 0
    
    for i in range(len(test_index)):
        idx = test_index[i]
        query = None
        Wikidata_id = None
        print('\r%d / %d' %(c, len(test_index)), end='')
        c = c + 1
        
        if if_S == 1:
            sample_rate, waveform = wavfile.read("wav/"+str(idx)+".wav")
            SVec = processor(waveform, sampling_rate=sample_rate, return_tensors="pt").input_values
            del sample_rate,waveform
            SVec = SVec.to(torch.float32)
            SVec = SVec.to(device)
            with torch.no_grad():
                model_output = model(SVec)
                imp_SVec = model_output.projected_states.detach().cpu().numpy()
                # imp_SVec = normalize(imp_SVec)#正規化処理を行う場合以下を実行
                imp_SVec = Vecs_tolist(imp_SVec,processor_mode)
        else:
            imp_SVec = [0]*256
    
        if if_L == 1:
            query = test_queries[i]
    
        if if_G == 1:
            Wikidata_id = test_bids[i]
        
        input_vecs = concat_vecs(query,lang,Wikidata_id,imp_SVec)
        input_vecs_list.append(input_vecs)
    
    flg = ""
    if if_S == 1:
        flg += "S"
    if if_L == 1:
        flg += "L"
    if if_G == 1:
        flg += "G"
    
    print(flg)
    
    save_bin("new_data_cv_SLG/input_vecs_list_"+flg+"_"+num+".bin",input_vecs_list)

1837 / 1838SLG


### 検索実装

In [28]:
# num = str(1)
num = str(2)
# num = str(3)
# num = str(4)
# num = str(5)

In [29]:
from transformers import BertModel, BertTokenizer
en_model = BertModel.from_pretrained('models/en_model')
en_tokenizer = BertTokenizer.from_pretrained('models/en_tokenizer')
bid2Gvec=read_bin('new_data/bid2Gvec.bin')

In [30]:
from contextlib import redirect_stdout
from transformers import Wav2Vec2ForPreTraining,Wav2Vec2FeatureExtractor

with redirect_stdout(open(os.devnull, 'w')):#一時的に出力を無効化
    path = "./models/cv"+str(num)+"/"
    processor = Wav2Vec2FeatureExtractor.from_pretrained("patrickvonplaten/wav2vec2-base-v2")# Wav2Vec2Processor => Wav2Vec2FeatureExtractor
    model = Wav2Vec2ForPreTraining.from_pretrained(path)# Wav2Vec2ForCTC => Wav2Vec2ForPreTraining
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model.to(device)

Some weights of the model checkpoint at ./models/cv2/ were not used when initializing Wav2Vec2ForPreTraining: ['wav2vec2.encoder.pos_conv_embed.conv.weight_g', 'wav2vec2.encoder.pos_conv_embed.conv.weight_v']
- This IS expected if you are initializing Wav2Vec2ForPreTraining from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing Wav2Vec2ForPreTraining from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of Wav2Vec2ForPreTraining were not initialized from the model checkpoint at ./models/cv2/ and are newly initialized: ['wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original0', 'wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original1']
You should probably TRAIN this mo

In [31]:
def multi_Search_topN_filechk(concat_vecs,input_v,file_path,i):
    max_sim=dict()
    c2 = 0
        
    for bid,v in concat_vecs:
        print('\r%d : %d / %d' %(i,c2, len(concat_vecs)), end='')
        c2 =  c2 + 1
        cos_sim_value = cos_sim(input_v,v)
        
        if bid not in max_sim:
            max_sim[bid] = cos_sim_value
        else:
            if max_sim[bid] < cos_sim_value:
                max_sim[bid] = cos_sim_value
    save_bin(file_path,max_sim)

In [32]:
import pandas as pd
test = pd.read_parquet('BirdModel_remove_test/data_cross-valid/test_'+num+'.parquet')#10293
test_index = list(test.index)

In [None]:
print(len(test_index))
for slg_i in range(5):
    SLG = [(0,1,0),(0,0,1),(0,1,1),(1,0,1),(1,1,0),(1,1,1)
    if_S,if_L,if_G = SLG[slg_i]

    flg = ""
    if if_S == 1:
        flg += "S"
    if if_L == 1:
        flg += "L"
    if if_G == 1:
        flg += "G"
    print(flg)
    
    input_vecs_list = read_bin("new_data_cv_SLG/input_vecs_list_"+flg+"_"+num+".bin")
    try:
        os.makedirs('csv_cv_'+flg+'_'+num)
    except:
        pass

    for i in range(len(test_index)):
        idx = test_index[i]
        file_path = 'csv_cv_'+flg+"_"+num+'/'+str(idx)+'.bin'
        if not os.path.exists(file_path):
            input_vecs = input_vecs_list[i]
            multi_Search_topN_filechk(en_concat_vecs_et,input_vecs,file_path,i)

1838
L
32 : 21456 / 88152

## 結果のソート

In [35]:
# prm = "loss_1_10"
# prm = "loss_1_100"
# prm = "step_3_2"
# prm = "step_3_2_norm"
prm = "cv"+num

In [36]:
import pickle

def save_bin(filename,data):
    with open(filename,'wb') as bf:
        pickle.dump(data,bf)

def read_bin(filename):
    with open(filename,'rb') as bf:
        bin_data = pickle.load(bf)
    return bin_data

In [37]:
import pandas as pd

# test_index = list(test.index)
# del test

In [38]:
# not p
data_sorted = []

for idx in test_index:
    data = read_bin('csv_cv'+num+'/'+str(idx)+'.bin')
    data_sorted.append(dict(sorted(data.items(), key=lambda item: item[1],reverse=True)))

In [39]:
save_bin('csv_cv'+num+'/data_sorted_'+prm+'.bin',data_sorted)

In [4]:
# all_l = []
# for i in range(20):
#     all_l.append(list(data_sorted[i].keys()))
#     all_l.append(list(data_sorted[i].values()))
    
# import pandas as pd

# today = "0729"
# df = pd.DataFrame(all_l)
# df.to_csv('output/output_'+today+'_'+prm+'.csv', index=False)

In [None]:
#p
data_p_sorted = []

for idx in test_index:
    data_p = read_bin('csv_p/'+str(idx)+'.bin')
    # data_p = read_bin('csv_cv'+prm+'_p/'+str(idx)+'.bin')
    print(idx)
    data_p_sorted.append(dict(sorted(data_p.items(), key=lambda item: item[1])))

In [None]:
save_bin('csv_cv'+num+'_p/data_p_sorted_'+prm+'.bin',data_p_sorted)

In [5]:
# all_l = []
# for i in range(20):
#     all_l.append(list(data_p_sorted[i].keys()))
#     all_l.append(list(data_p_sorted[i].values()))
    
# import pandas as pd

# today = "0714"
# df = pd.DataFrame(all_l)
# # df.to_csv('output_p_'+today+'_'+prm+'.csv', index=False)
# df.to_csv('output/output_p_'+today+'_'+prm+'.csv', index=False)

## 検索

In [133]:
# prmは一つ上のセクションにて定義

In [134]:
# data_sorted = read_bin('csv_cv'+prm+'/data_sorted_'+prm+'.bin')
# data_sorted = read_bin('csv_cv'+prm+'_p/data_p_sorted_'+prm+'.bin')

In [135]:
# data_sorted[3]

In [136]:
data_sorted_id = []
data_sorted_sim = []

for d in data_sorted:
    data_sorted_id.append(list(d.keys())[:400])
    data_sorted_sim.append(list(d.values())[:400])

In [137]:
data_sorted_sim[0][:10]

[0.05552816789009378,
 0.05483066575735448,
 0.05449225043376088,
 0.054418008033872,
 0.05427616965235937,
 0.0539751159009753,
 0.05388954977586201,
 0.053651935039797,
 0.05348710437880466,
 0.0534708726473036]

In [138]:
ontology = read_bin("new_data/ontology.pickle")

In [139]:
ontology_d = dict()

for d in ontology:
    ontology_d[d["id"]]=d

del ontology

In [140]:
data_sorted_name_en = []
data_sorted_name_ja = []

for id_l in data_sorted_id:
    en_l = []
    ja_l = []
    for a_id in id_l:
        d = ontology_d[a_id]
        en = ""
        ja = ""
        
        if d["en_name"] != None:
            en = d["en_name"]
        if d["ja_name"] != None:
            ja = d["ja_name"]
            
        if d["en_aliases"] != {}:
            en = en + ' | '+' | '.join(list(d["en_aliases"].values()))
            
        if d["ja_aliases"] != {}:
            ja = ja + ' | '+' | '.join(list(d["ja_aliases"].values()))
            
        en_l.append(en)
        ja_l.append(ja)
        
    data_sorted_name_en.append(en_l)
    data_sorted_name_ja.append(ja_l)

In [141]:
all_l = []
for i in range(20):
    all_l.append(data_sorted_name_en[i])
    all_l.append(data_sorted_name_ja[i])
    all_l.append(data_sorted_id[i])
    all_l.append(data_sorted_sim[i])
    all_l.append([])

In [142]:
df = pd.DataFrame(all_l)
df.to_csv('output_'+prm+'.csv', index=False)
# df.to_csv('output_p_'+prm+'.csv', index=False)

In [143]:
# for i in range(20):
#     for j in range(15):
#         print(data_sorted_id[j][i],end=" ")
#     print()
#     print()

## hit@k and MRR

検索結果の読み込み

In [144]:
# prmは２つ上のセクションにて定義済み

In [146]:
data_sorted = read_bin('csv/data_sorted_'+prm+'.bin')
# data_sorted_p = read_bin('csv_p/data_p_sorted_'+prm+'.bin')

In [147]:
def get_sorted_id(data_sorted):
    data_sorted_id = []
    for d in data_sorted:
        data_sorted_id.append(list(d.keys())[:400])
    return data_sorted_id

In [148]:
id_sorted = get_sorted_id(data_sorted)
# id_sorted_p = get_sorted_id(data_sorted_p)

In [151]:
# del data_sorted
# del data_sorted_p

正解文書

In [152]:
#20件
test_ans = ["Q195518","Q2669182","Q1589896","Q25700","Q1074291","Q27075595","Q235152","Q1272534","Q1074163","Q1270171","Q1071547","Q177798","Q493173","Q26650","Q495144","Q195518","Q1034960","Q890912","Q184820","Q862812"]	

検索

In [153]:
def ranks(id_sorted):
    ranks = []
    for i in range(20):
        if test_ans[i] in id_sorted[i]:
            index = id_sorted[i].index(test_ans[i])
            ranks.append(index+1)
        else:
            ranks.append(-1)
    return ranks

In [154]:
# id_sorted
# id_sorted_p

In [155]:
result_ranks = ranks(id_sorted)
# result_ranks_p = ranks(id_sorted_p)

In [156]:
print(" ".join(map(str, result_ranks)))
# print(" ".join(map(str, result_ranks_p)))

134 59 191 227 211 132 400 150 87 126 73 38 86 96 204 141 58 20 285 195


In [91]:
def calculate_mrr(ranks):
    reciprocal_ranks = [1 / (rank + 1) for rank in ranks]
    mrr = sum(reciprocal_ranks) / len(ranks)
    print(mrr)

In [157]:
calculate_mrr(result_ranks)
# calculate_mrr(result_ranks_p)

0.011015160092636093


In [None]:
# prm = "step_3_2"

    # 112 73 144 201 233 46 400 175 46 112 66 55 103 129 192 98 26 28 244 182
    # 320 258 400 345 400 224 400 400 326 400 400 400 195 400 296 400 244 223 176 46

    # 0.012225021479883039
    # 0.004187707828634603

# prm = "step_3_2_norm"

    # 134 59 191 227 211 132 400 150 87 126 73 38 86 96 204 141 58 20 285 195

    # 0.011015160092636093

In [None]:
# 音声の何パーセントがなくなっているのかは

In [None]:
# ランダムシャッフルした検索順位での比較実験

In [None]:
# epoch = 80

In [None]:
# 0.12以上

In [None]:
# setp数はシャッフルのあと
# それよりは

In [None]:
# 20%でテスト（5ホールド前提）

In [None]:
# クロスバリデーション