## Импорт библиотек

In [4]:
import pandas as pd
import numpy as np
from tqdm.notebook import trange
import torch
import transformers as ppb
import warnings
warnings.filterwarnings('ignore')

Определим, возможно ли провести расчеты на GPU

In [5]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

'cuda'

Загрузим очищенные и предобработанные данные

In [6]:
df = pd.read_feather('df_train_BERT.feather')
df.info()
df.head()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 440535 entries, 0 to 440534
Data columns (total 6 columns):
 #   Column            Non-Null Count   Dtype 
---  ------            --------------   ----- 
 0   post_index        440535 non-null  int64 
 1   post_text         440535 non-null  object
 2   comment_text      440535 non-null  object
 3   comment_score     440535 non-null  int64 
 4   post_text_fix     440535 non-null  object
 5   comment_text_fix  440535 non-null  object
dtypes: int64(2), object(4)
memory usage: 20.2+ MB


Unnamed: 0,post_index,post_text,comment_text,comment_score,post_text_fix,comment_text_fix
0,0,How many summer Y Combinator fundees decided n...,Going back to school is not identical with giv...,0,how many summer y combinator fundees decided n...,going back to school is not identical with giv...
1,0,How many summer Y Combinator fundees decided n...,There will invariably be those who do not see ...,1,how many summer y combinator fundees decided n...,there will invariably be those who do not see ...
2,0,How many summer Y Combinator fundees decided n...,For me school is a way to be connected to what...,2,how many summer y combinator fundees decided n...,for me school is a way to be connected to what...
3,0,How many summer Y Combinator fundees decided n...,I guess it really depends on how hungry you ar...,3,how many summer y combinator fundees decided n...,i guess it really depends on how hungry you ar...
4,0,How many summer Y Combinator fundees decided n...,I know pollground decided to go back to school...,4,how many summer y combinator fundees decided n...,i know pollground decided to go back to school...


Добавим разделиители в виде: 
* "p:" - post ;
* "c:" - comment.

In [7]:
df['post_plus_text'] = 'p: ' + df['post_text_fix'] + \
    ' c: ' + df['comment_text_fix']
df['post_plus_text']

0         p: how many summer y combinator fundees decide...
1         p: how many summer y combinator fundees decide...
2         p: how many summer y combinator fundees decide...
3         p: how many summer y combinator fundees decide...
4         p: how many summer y combinator fundees decide...
                                ...                        
440530    p: pay your rent with a credit or debit card. ...
440531    p: pay your rent with a credit or debit card. ...
440532    p: pay your rent with a credit or debit card. ...
440533    p: pay your rent with a credit or debit card. ...
440534    p: pay your rent with a credit or debit card. ...
Name: post_plus_text, Length: 440535, dtype: object

Модель BERT работает с текстовыми данными, длинна которых не превышает 512 символов. Отберем из данных только те посты + комментарии, длинна которых не превышает 512.

In [10]:
# Определим объекты длина которых больше 512
df['is_valid'] = df['post_plus_text'].str.len() > 512

In [11]:
# Определим индексы этих объектов
valid_idx = df.groupby('post_index')['is_valid'].sum()[lambda x: x < 1].index

In [12]:
valid_idx

Int64Index([    1,     2,     6,     8,     9,    15,    16,    20,    23,
               24,
            ...
            88046, 88050, 88055, 88059, 88061, 88066, 88079, 88094, 88104,
            88106],
           dtype='int64', name='post_index', length=18562)

Определим датафрейм dt с выбранными объектами.

In [15]:
dt = df.loc[df['post_index'].isin(valid_idx)].copy()

Импортируем предобученную модель DistilBERT и токенезатор.

In [11]:
model_class, tokenizer_class, pretrained_weights = (
    ppb.DistilBertModel,
    ppb.DistilBertTokenizer,
    'distilbert-base-uncased')

In [12]:
tokenizer = tokenizer_class.from_pretrained(pretrained_weights)
model = model_class.from_pretrained(pretrained_weights)

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_layer_norm.weight', 'vocab_projector.weight', 'vocab_transform.bias', 'vocab_transform.weight', 'vocab_layer_norm.bias', 'vocab_projector.bias']
- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Применим токенезатор к датафрейму.

In [13]:
%%time
tokenized = dt['post_plus_text']\
    .apply(lambda x: tokenizer.encode(x, add_special_tokens=True))

CPU times: total: 1min 56s
Wall time: 1min 55s


Определим padding и attention mask

In [14]:
max_len = 0
for i in tokenized.values:
    if len(i) > max_len:
        max_len = len(i)

padded = np.array([i + [0]*(max_len-len(i)) for i in tokenized.values])

In [15]:
np.array(padded).shape

(92810, 428)

In [16]:
attention_mask = np.where(padded != 0, 1, 0)
attention_mask.shape

(92810, 428)

In [17]:
padded.shape[0]

92810

Получим embendings с помощью предобученной модели.

In [18]:
%%time
batch_size = 15 # для примера возьмем такой батч, где будет всего две строки датасета
embeddings = [] 
for i in trange(padded.shape[0] // batch_size):
    batch = torch.LongTensor(padded[batch_size*i:batch_size*(i+1)]).to(device) # закидываем тензор на GPU
    attention_mask_batch = torch.LongTensor(
        attention_mask[batch_size*i:batch_size*(i+1)]).to(device)

    with torch.inference_mode():
        model.to(device)
        batch_embeddings = model(batch, attention_mask=attention_mask_batch)

    embeddings.append(batch_embeddings[0][:,0,:].cpu().numpy()) # перевод обратно на проц, чтобы в нумпай кинуть
    del batch
    del attention_mask_batch
    del batch_embeddings

features = np.concatenate(embeddings) 

  0%|          | 0/6187 [00:00<?, ?it/s]

CPU times: total: 49min 6s
Wall time: 49min 7s


Результатом являются embendings размером n x 768

In [19]:
pd.DataFrame(features)

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,758,759,760,761,762,763,764,765,766,767
0,0.144441,-0.275229,0.121034,-0.022717,-0.011386,-0.371438,-0.053915,0.416983,0.018654,-0.390223,...,0.014957,-0.223352,0.093404,-0.217284,0.143401,0.041880,-0.052693,-0.105857,0.474585,0.331695
1,0.084572,-0.088871,0.032786,0.102412,-0.013723,-0.360138,0.129133,0.175513,0.043969,-0.081129,...,-0.316360,-0.084931,0.211140,-0.289426,0.011242,0.042296,0.154023,-0.127552,0.650891,0.313405
2,-0.154102,-0.205968,-0.072974,-0.098393,-0.020828,-0.174417,0.107194,0.315201,0.073962,0.037438,...,0.007686,-0.075235,0.043659,-0.063159,0.172889,0.033672,-0.046778,-0.272708,0.584986,0.433110
3,-0.005600,0.012485,-0.133596,-0.120348,-0.041818,-0.272322,0.045477,0.303554,0.250218,-0.067512,...,-0.286925,-0.268241,0.150624,-0.180598,0.171506,0.122144,0.172196,-0.105090,0.736107,0.368727
4,-0.133143,-0.068215,-0.133230,0.009282,-0.072491,-0.192820,0.087804,0.277646,0.196886,-0.114941,...,-0.184534,-0.263188,0.163898,-0.065655,0.090101,0.059097,0.109243,-0.154215,0.651898,0.307381
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
92800,-0.107862,-0.205408,0.015032,-0.274493,-0.107417,-0.285679,0.317828,0.270155,-0.032156,-0.366551,...,0.223418,-0.003935,-0.180761,-0.329440,0.290537,0.162301,-0.183496,-0.018543,0.320735,0.310992
92801,-0.194246,-0.165839,-0.062190,-0.038309,-0.041377,-0.222664,0.204285,0.449760,-0.224287,-0.314920,...,0.246309,-0.109552,-0.022632,-0.191103,0.288921,0.248423,-0.043909,-0.064538,0.287570,0.453056
92802,-0.070702,-0.280414,-0.057889,-0.059047,-0.088623,-0.550471,0.412163,0.576670,0.139625,-0.455293,...,0.152376,-0.239245,-0.081165,-0.318592,0.315074,0.132671,-0.335718,-0.131204,0.365121,0.310466
92803,-0.022111,-0.155452,-0.057162,-0.118525,-0.297256,-0.521813,0.352698,0.551069,0.029096,-0.353160,...,0.196285,-0.064299,-0.083500,-0.434804,0.435073,0.050470,-0.216222,-0.065211,0.382141,0.380226


In [21]:
pd.DataFrame(features).to_pickle('Train_embendings.pickle')

Сохраним датафрейм в формате .pickle, предварительно добавив индекс поста

In [3]:
data = pd.read_pickle('Train_embendings.pickle')
data

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,758,759,760,761,762,763,764,765,766,767
0,0.144441,-0.275229,0.121034,-0.022717,-0.011386,-0.371438,-0.053915,0.416983,0.018654,-0.390223,...,0.014957,-0.223352,0.093404,-0.217284,0.143401,0.041880,-0.052693,-0.105857,0.474585,0.331695
1,0.084572,-0.088871,0.032786,0.102412,-0.013723,-0.360138,0.129133,0.175513,0.043969,-0.081129,...,-0.316360,-0.084931,0.211140,-0.289426,0.011242,0.042296,0.154023,-0.127552,0.650891,0.313405
2,-0.154102,-0.205968,-0.072974,-0.098393,-0.020828,-0.174417,0.107194,0.315201,0.073962,0.037438,...,0.007686,-0.075235,0.043659,-0.063159,0.172889,0.033672,-0.046778,-0.272708,0.584986,0.433110
3,-0.005600,0.012485,-0.133596,-0.120348,-0.041818,-0.272322,0.045477,0.303554,0.250218,-0.067512,...,-0.286925,-0.268241,0.150624,-0.180598,0.171506,0.122144,0.172196,-0.105090,0.736107,0.368727
4,-0.133143,-0.068215,-0.133230,0.009282,-0.072491,-0.192820,0.087804,0.277646,0.196886,-0.114941,...,-0.184534,-0.263188,0.163898,-0.065655,0.090101,0.059097,0.109243,-0.154215,0.651898,0.307381
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
92800,-0.107862,-0.205408,0.015032,-0.274493,-0.107417,-0.285679,0.317828,0.270155,-0.032156,-0.366551,...,0.223418,-0.003935,-0.180761,-0.329440,0.290537,0.162301,-0.183496,-0.018543,0.320735,0.310992
92801,-0.194246,-0.165839,-0.062190,-0.038309,-0.041377,-0.222664,0.204285,0.449760,-0.224287,-0.314920,...,0.246309,-0.109552,-0.022632,-0.191103,0.288921,0.248423,-0.043909,-0.064538,0.287570,0.453056
92802,-0.070702,-0.280414,-0.057889,-0.059047,-0.088623,-0.550471,0.412163,0.576670,0.139625,-0.455293,...,0.152376,-0.239245,-0.081165,-0.318592,0.315074,0.132671,-0.335718,-0.131204,0.365121,0.310466
92803,-0.022111,-0.155452,-0.057162,-0.118525,-0.297256,-0.521813,0.352698,0.551069,0.029096,-0.353160,...,0.196285,-0.064299,-0.083500,-0.434804,0.435073,0.050470,-0.216222,-0.065211,0.382141,0.380226


In [26]:
data['post_index'] = dt['post_index'][:-5].values
data

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,759,760,761,762,763,764,765,766,767,post_index
0,0.144441,-0.275229,0.121034,-0.022717,-0.011386,-0.371438,-0.053915,0.416983,0.018654,-0.390223,...,-0.223352,0.093404,-0.217284,0.143401,0.041880,-0.052693,-0.105857,0.474585,0.331695,1
1,0.084572,-0.088871,0.032786,0.102412,-0.013723,-0.360138,0.129133,0.175513,0.043969,-0.081129,...,-0.084931,0.211140,-0.289426,0.011242,0.042296,0.154023,-0.127552,0.650891,0.313405,1
2,-0.154102,-0.205968,-0.072974,-0.098393,-0.020828,-0.174417,0.107194,0.315201,0.073962,0.037438,...,-0.075235,0.043659,-0.063159,0.172889,0.033672,-0.046778,-0.272708,0.584986,0.433110,1
3,-0.005600,0.012485,-0.133596,-0.120348,-0.041818,-0.272322,0.045477,0.303554,0.250218,-0.067512,...,-0.268241,0.150624,-0.180598,0.171506,0.122144,0.172196,-0.105090,0.736107,0.368727,1
4,-0.133143,-0.068215,-0.133230,0.009282,-0.072491,-0.192820,0.087804,0.277646,0.196886,-0.114941,...,-0.263188,0.163898,-0.065655,0.090101,0.059097,0.109243,-0.154215,0.651898,0.307381,1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
92800,-0.107862,-0.205408,0.015032,-0.274493,-0.107417,-0.285679,0.317828,0.270155,-0.032156,-0.366551,...,-0.003935,-0.180761,-0.329440,0.290537,0.162301,-0.183496,-0.018543,0.320735,0.310992,88104
92801,-0.194246,-0.165839,-0.062190,-0.038309,-0.041377,-0.222664,0.204285,0.449760,-0.224287,-0.314920,...,-0.109552,-0.022632,-0.191103,0.288921,0.248423,-0.043909,-0.064538,0.287570,0.453056,88104
92802,-0.070702,-0.280414,-0.057889,-0.059047,-0.088623,-0.550471,0.412163,0.576670,0.139625,-0.455293,...,-0.239245,-0.081165,-0.318592,0.315074,0.132671,-0.335718,-0.131204,0.365121,0.310466,88104
92803,-0.022111,-0.155452,-0.057162,-0.118525,-0.297256,-0.521813,0.352698,0.551069,0.029096,-0.353160,...,-0.064299,-0.083500,-0.434804,0.435073,0.050470,-0.216222,-0.065211,0.382141,0.380226,88104


In [31]:
data.index = dt[:-5].index
data

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,759,760,761,762,763,764,765,766,767,post_index
5,0.144441,-0.275229,0.121034,-0.022717,-0.011386,-0.371438,-0.053915,0.416983,0.018654,-0.390223,...,-0.223352,0.093404,-0.217284,0.143401,0.041880,-0.052693,-0.105857,0.474585,0.331695,1
6,0.084572,-0.088871,0.032786,0.102412,-0.013723,-0.360138,0.129133,0.175513,0.043969,-0.081129,...,-0.084931,0.211140,-0.289426,0.011242,0.042296,0.154023,-0.127552,0.650891,0.313405,1
7,-0.154102,-0.205968,-0.072974,-0.098393,-0.020828,-0.174417,0.107194,0.315201,0.073962,0.037438,...,-0.075235,0.043659,-0.063159,0.172889,0.033672,-0.046778,-0.272708,0.584986,0.433110,1
8,-0.005600,0.012485,-0.133596,-0.120348,-0.041818,-0.272322,0.045477,0.303554,0.250218,-0.067512,...,-0.268241,0.150624,-0.180598,0.171506,0.122144,0.172196,-0.105090,0.736107,0.368727,1
9,-0.133143,-0.068215,-0.133230,0.009282,-0.072491,-0.192820,0.087804,0.277646,0.196886,-0.114941,...,-0.263188,0.163898,-0.065655,0.090101,0.059097,0.109243,-0.154215,0.651898,0.307381,1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
440520,-0.107862,-0.205408,0.015032,-0.274493,-0.107417,-0.285679,0.317828,0.270155,-0.032156,-0.366551,...,-0.003935,-0.180761,-0.329440,0.290537,0.162301,-0.183496,-0.018543,0.320735,0.310992,88104
440521,-0.194246,-0.165839,-0.062190,-0.038309,-0.041377,-0.222664,0.204285,0.449760,-0.224287,-0.314920,...,-0.109552,-0.022632,-0.191103,0.288921,0.248423,-0.043909,-0.064538,0.287570,0.453056,88104
440522,-0.070702,-0.280414,-0.057889,-0.059047,-0.088623,-0.550471,0.412163,0.576670,0.139625,-0.455293,...,-0.239245,-0.081165,-0.318592,0.315074,0.132671,-0.335718,-0.131204,0.365121,0.310466,88104
440523,-0.022111,-0.155452,-0.057162,-0.118525,-0.297256,-0.521813,0.352698,0.551069,0.029096,-0.353160,...,-0.064299,-0.083500,-0.434804,0.435073,0.050470,-0.216222,-0.065211,0.382141,0.380226,88104


In [32]:
data.to_pickle('Train_embendings_with_index.pickle')