In [1]:
from sentence_transformers import SentenceTransformer
from datasets import load_dataset

import numpy as np
import pandas as pd
import torch
import os

In [2]:
# Pin this kernel to physical GPUs 2 and 3, numbering devices by PCI bus ID.
# NOTE: these variables only take effect if CUDA has not yet been initialised
# in this kernel; the printed device count confirms what torch can see.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": "2,3",
})
print(torch.cuda.device_count())

2


In [3]:
def get_pd_utterances_speaker(data):
    '''
    Flatten dialogues into a DataFrame with one row per turn.

    Parameters
    ----------
    data : iterable of dict
        One entry per dialogue (e.g. a split's ``'turns'`` column). Each entry
        provides parallel per-turn lists under ``'utterance'``, ``'speaker'``
        and ``'dialogue_acts'``; every element of ``'dialogue_acts'`` must
        contain ``['dialog_act']['act_type']``. ``data`` is consumed in a
        single pass, so one-shot iterators/generators are accepted.

    Returns
    -------
    pandas.DataFrame
        Columns ``utterance``, ``speaker`` and ``intent`` (the turn's
        ``act_type`` value, stored as-is), ordered by dialogue, then by turn.

    Raises
    ------
    ValueError
        If a dialogue's per-turn lists have different lengths; without this
        check the rows could end up silently misaligned across columns.
    '''
    utterances, speakers, intents = [], [], []

    # Single pass over `data`: the previous version looped over it three
    # times, which breaks for iterators that can only be consumed once.
    for dialogue_idx, dialogue in enumerate(data):
        turn_utterances = dialogue['utterance']
        turn_speakers = dialogue['speaker']
        turn_intents = [act['dialog_act']['act_type']
                        for act in dialogue['dialogue_acts']]

        if not len(turn_utterances) == len(turn_speakers) == len(turn_intents):
            raise ValueError(
                f"dialogue {dialogue_idx}: {len(turn_utterances)} utterances, "
                f"{len(turn_speakers)} speakers, {len(turn_intents)} dialogue acts"
            )

        utterances.extend(turn_utterances)
        speakers.extend(turn_speakers)
        intents.extend(turn_intents)

    return pd.DataFrame({
        'utterance': utterances,
        'speaker': speakers,
        'intent': intents,
    })

In [4]:
# Pin the dataset config explicitly. The bare call only worked by falling back
# to the dataset's default config (logged as v2.2_active_only), which would
# change silently if the dataset's default ever changed.
MULTIWOZ_CONFIG = "v2.2_active_only"
dataset = load_dataset("multi_woz_v22", MULTIWOZ_CONFIG)

# Keep only the per-dialogue 'turns' column of each split.
train_dataset = dataset['train']['turns']
validation_dataset = dataset['validation']['turns']
test_dataset = dataset['test']['turns']

# Flatten each split into one row per utterance (with speaker and intent).
train_df = get_pd_utterances_speaker(train_dataset)
test_df = get_pd_utterances_speaker(test_dataset)
validation_df = get_pd_utterances_speaker(validation_dataset)

No config specified, defaulting to: multi_woz_v22/v2.2_active_only
Found cached dataset multi_woz_v22 (/home/ledneva/.cache/huggingface/datasets/multi_woz_v22/v2.2_active_only/2.2.0/6719c8b21478299411a0c6fdb7137c3ebab2e6425129af831687fb7851c69eb5)


  0%|          | 0/3 [00:00<?, ?it/s]

In [5]:
# Stack the three splits into one frame (train, then test, then validation)
# with a fresh 0..N-1 index; the embeddings below follow this row order.
df = pd.concat(
    objs=[train_df, test_df, validation_df],
    ignore_index=True,
)

In [6]:
# Sentence encoder used to embed every utterance.
MODEL_NAME = 'sentence-transformers/all-distilroberta-v1'

# Pass the device to the constructor rather than calling `.to('cuda')`
# afterwards. Hard-coding 'cuda' crashed on hosts where no GPU is visible
# (e.g. if CUDA_VISIBLE_DEVICES names devices this machine lacks).
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
if DEVICE == 'cpu':
    print("No CUDA device visible; encoding on CPU will be much slower.")

model = SentenceTransformer(MODEL_NAME, device=DEVICE)

Downloading:   0%|          | 0.00/737 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/190 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/10.3k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/653 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/116 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/15.7k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/329M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/239 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/333 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/13.1k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/798k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/349 [00:00<?, ?B/s]

In [7]:
def get_embedding(sentence, batch_size=32):
    '''
    Encode text with the notebook-global sentence-transformer ``model``.

    Parameters
    ----------
    sentence : str or list of str
        A single string, or a list of strings to encode together.
    batch_size : int, default 32
        Number of strings ``model.encode`` processes per forward pass
        (32 matches sentence-transformers' documented default).

    Returns
    -------
    numpy.ndarray
        A 1-D vector for a single string; a 2-D array with one row per
        string for a list.
    '''
    return model.encode(sentence, batch_size=batch_size)

# Encode every utterance in one batched call. The previous per-row
# `df['utterance'].apply(get_embedding)` took ~15 minutes, largely because
# each sentence went through the model on its own.
# NOTE: `embeddings` is now a 2-D array (one row per row of `df`) rather than
# a Series of 1-D vectors; np.row_stack / np.vstack stack both identically.
embeddings = get_embedding(df['utterance'].tolist())

In [8]:
# Turn the per-utterance vectors into a (n_utterances x embedding_dim) frame.
# np.vstack replaces np.row_stack, which NumPy 2.0 deprecated as an alias.
# Guarded so that re-running this cell is a no-op: stacking an existing
# DataFrame would iterate its column labels instead of its rows and silently
# replace `embeddings` with a 768x1 frame.
if not isinstance(embeddings, pd.DataFrame):
    embeddings = pd.DataFrame(np.vstack(embeddings))

In [9]:
# The embeddings are saved without the utterance text, so they can only be
# matched back to `df` by row position: check there is exactly one row per
# utterance before displaying/saving.
assert len(embeddings) == len(df), (len(embeddings), len(df))
embeddings

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,758,759,760,761,762,763,764,765,766,767
0,-0.007226,-0.033957,-0.009085,-0.001372,-0.022856,-0.048311,0.041502,0.054538,0.071536,0.008875,...,-0.003501,0.022292,-0.002492,0.009908,-0.043196,-0.002803,0.040020,0.020825,-0.011816,-0.034958
1,0.018751,-0.011421,-0.017195,0.030763,0.011367,-0.008825,0.035039,0.003447,0.065989,-0.047363,...,-0.061483,-0.015553,0.002013,0.003989,0.041973,0.013563,0.017065,0.086267,-0.011211,0.018178
2,-0.009338,-0.012030,0.003013,0.006073,-0.019766,-0.040937,0.017565,-0.019893,0.066497,0.015676,...,-0.034179,0.010144,-0.027549,0.000214,0.045478,-0.022040,0.017160,0.082377,0.029150,-0.003910
3,0.003873,0.046686,-0.032505,0.020259,0.003932,-0.047845,0.034129,0.049410,-0.006170,0.003214,...,-0.013998,0.009503,-0.008501,0.007559,0.019120,0.027074,0.051466,0.011480,-0.001995,0.020767
4,-0.029017,0.014379,-0.002468,-0.030916,-0.026700,-0.065000,0.045521,-0.039768,0.094795,0.024008,...,-0.026093,0.033713,-0.020920,0.019808,0.023427,-0.001191,-0.018421,0.056101,0.019062,-0.012382
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
143039,0.012976,-0.058718,0.016624,0.040777,-0.059632,-0.004792,-0.020179,0.058529,0.028063,-0.012712,...,0.047848,-0.026998,0.021815,0.008680,0.017850,0.034079,-0.094322,-0.028001,-0.063287,-0.005191
143040,-0.022427,-0.006912,0.032676,-0.037562,-0.042634,-0.024707,0.030946,-0.017636,0.010411,0.058166,...,-0.039170,0.032023,-0.015683,-0.032754,0.008142,0.030408,-0.075982,-0.012069,-0.016874,-0.000491
143041,-0.021554,-0.042587,0.029049,-0.054671,-0.031003,-0.011591,0.010116,-0.014084,0.048402,-0.001506,...,-0.051333,0.020171,-0.032218,-0.022407,-0.031142,-0.009772,-0.061097,-0.002960,-0.017953,0.019558
143042,0.018267,-0.000391,0.008706,-0.016725,0.019559,-0.021098,-0.011251,0.045212,0.051478,-0.004916,...,0.005725,0.043488,-0.002871,0.041669,0.043827,-0.044472,0.082341,0.030984,0.005650,0.002012


In [10]:
# Persist the embedding matrix. The CSV's first (unnamed) column is the row
# index, i.e. the position of the utterance in `df`.
embeddings.to_csv(path_or_buf="distilroberta_embeddings.csv")