In [8]:
import pandas as pd
import numpy as np
import torch
import warnings
warnings.filterwarnings('ignore')

В этом блокноте переведем тексты из таблицы с новостями в эмбеддинги.

In [11]:
# DJIA quote tables for both splits; their Date columns are parsed with dayfirst=True.
table_train, table_test = (
    pd.read_csv('data/DJIA_Table(train).csv'),
    pd.read_csv('data/DJIA_Table(test).csv'),
)
for frame in (table_train, table_test):
    frame['Date'] = pd.to_datetime(frame['Date'], dayfirst=True)
table_train.head(1)

Unnamed: 0,Date,Open,High,Low,Close,Volume,Adj Close
0,2015-12-31,17590.66016,17590.66016,17421.16016,17425.0293,93690000,17425.0293


In [13]:
# Daily headline tables (Label + Top1..Top25) for both splits, with Date parsed to datetime.
# NOTE(review): unlike the other three frames, the train news Date column is parsed
# WITHOUT dayfirst=True. Presumably its on-disk format differs (e.g. ISO) - confirm;
# if it is day-first like the rest, ambiguous dates such as 01/02 would be read month-first.
news_train = pd.read_csv('data/Combined_News_DJIA(train).csv').assign(
    Date=lambda frame: pd.to_datetime(frame['Date'])
)
news_test = pd.read_csv('data/Combined_News_DJIA(test).csv').assign(
    Date=lambda frame: pd.to_datetime(frame['Date'], dayfirst=True)
)
news_train.head(1)

Unnamed: 0,Date,Label,Top1,Top2,Top3,Top4,Top5,Top6,Top7,Top8,...,Top16,Top17,Top18,Top19,Top20,Top21,Top22,Top23,Top24,Top25
0,2008-08-08,0,"b""Georgia 'downs two Russian warplanes' as cou...",b'BREAKING: Musharraf to be impeached.',b'Russia Today: Columns of troops roll into So...,b'Russian tanks are moving towards the capital...,"b""Afghan children raped with 'impunity,' U.N. ...",b'150 Russian tanks have entered South Ossetia...,"b""Breaking: Georgia invades South Ossetia, Rus...","b""The 'enemy combatent' trials are nothing but...",...,b'Georgia Invades South Ossetia - if Russia ge...,b'Al-Qaeda Faces Islamist Backlash',"b'Condoleezza Rice: ""The US would not act to p...",b'This is a busy day: The European Union has ...,"b""Georgia will withdraw 1,000 soldiers from Ir...",b'Why the Pentagon Thinks Attacking Iran is a ...,b'Caucasus in crisis: Georgia invades South Os...,b'Indian shoe manufactory - And again in a se...,b'Visitors Suffering from Mental Illnesses Ban...,"b""No Help for Mexico's Kidnapping Surge"""


Чтобы выбрать новости, для которых есть данные котировок, таблицы с котировками и новостями соединяем по дате.   
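Помним, что в обучающей выборке были пропуски, поэтому заполняем их нулями; после приведения к строкам они станут текстом «0».   
Сортируем по дате, оставляем только тексты заголовков и считаем среднее количество слов в заголовке, чтобы выбрать длину последовательности (`max_length`) для токенайзера.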

In [102]:
# The quote table drives the rows: every quote date is kept (news-only dates are dropped),
# and quote dates with no matching news row get NaN in the news columns.
train = table_train[['Date', 'Open']].merge(news_train, on='Date', how='left')
# Fill gaps with 0; after astype('str') below these become the text "0".
train = train.fillna(0)
test = table_test[['Date', 'Open']].merge(news_test, on='Date', how='left')
# Apply the same fill to test, so a missing headline is encoded as "0" in both splits
# instead of becoming the string "nan" in test only.
test = test.fillna(0)
train, test = train.sort_values('Date'), test.sort_values('Date')

# Keep only the headline columns, as strings.
# NOTE(review): headlines still carry the b'...' / b"..." wrapper visible in the raw
# data; consider stripping it before tokenization.
train_text = train.drop(['Date', 'Open', 'Label'], axis=1).astype('str')
test_text = test.drop(['Date', 'Open', 'Label'], axis=1).astype('str')
# Both splits must expose the same headline columns in the same order, otherwise the
# flattened texts (and their embeddings) would not line up between train and test.
assert list(train_text.columns) == list(test_text.columns), 'train/test headline columns differ'

# Mean number of whitespace-separated words per headline (used to pick max_length below).
train_text.map(lambda x: len(x.split())).mean().mean()

17.712420826623728

Будем использовать предобученную модель BERT (`bert-base-cased`).

In [104]:
from transformers import BertTokenizer
from transformers import BertModel  # https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertModel

# Tokenizer and weights come from the same pretrained checkpoint.
tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
model = BertModel.from_pretrained('bert-base-cased')
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

# Sequence length in tokens, including the [CLS] and [SEP] special tokens. It was picked
# from the mean word count above (~17.7 words per headline).
# NOTE(review): WordPiece usually produces more tokens than words, and two slots go to
# special tokens, so many headlines are likely truncated. Check the token-length
# distribution and consider raising this value.
MAX_LENGTH = 18


def encode_texts(texts):
    """Tokenize headlines into fixed-length BERT inputs.

    Args:
        texts: iterable of strings.

    Returns:
        Tokenizer output with 'input_ids' and 'attention_mask' (no token_type_ids).
        Each row is padded or truncated to MAX_LENGTH tokens.
    """
    # Calling the tokenizer directly is the recommended replacement for batch_encode_plus.
    return tokenizer(
        list(texts),
        max_length=MAX_LENGTH,
        padding='max_length',
        truncation=True,
        add_special_tokens=True,
        return_token_type_ids=False,
    )


# Flatten the (day x headline) tables into 1-D text arrays in row-major order: all
# headlines of the first day, then the next day, and so on. np.ravel accepts both the
# DataFrame from the previous cell and an already-flattened array, so re-running this
# cell no longer fails.
train_text = np.ravel(train_text)
test_text = np.ravel(test_text)

tokens_train = encode_texts(train_text)
tokens_test = encode_texts(test_text)

Из тензоров создаем датасет и даталоадер.

In [106]:
from torch.utils.data import TensorDataset, DataLoader, SequentialSampler

BATCH_SIZE = 32


def make_loader(tokens, batch_size=BATCH_SIZE):
    """Wrap tokenizer output in a DataLoader that keeps the original text order.

    Args:
        tokens: mapping with 'input_ids' and 'attention_mask' (lists of equal-length rows).
        batch_size: number of texts per batch.

    Returns:
        DataLoader yielding (input_ids, attention_mask, row_id) batches. row_id is the
        position of each text in the flattened text array.
    """
    input_ids = torch.tensor(tokens['input_ids'])
    attention_mask = torch.tensor(tokens['attention_mask'])
    # torch.arange gives a fixed int64 dtype; np.array(range(...)) is platform-dependent.
    row_ids = torch.arange(input_ids.shape[0])
    dataset = TensorDataset(input_ids, attention_mask, row_ids)
    # SequentialSampler (no shuffling) keeps embeddings aligned with the texts.
    return DataLoader(dataset, sampler=SequentialSampler(dataset), batch_size=batch_size)


train_loader = make_loader(tokens_train)
test_loader = make_loader(tokens_test)
# Each flattened headline must map to exactly one sample, or the saved embeddings would
# not line up with the texts.
assert len(train_loader.dataset) == len(train_text)
assert len(test_loader.dataset) == len(test_text)

Получаем эмбеддинги — скрытое состояние первого токена ([CLS]) на последнем слое модели — и сохраняем их в файлы.

In [108]:
from tqdm import tqdm

@torch.inference_mode()
def get_embeddings(model, loader, device=None, pooling='cls'):
    """Run an encoder over `loader` and return one pooled embedding per sample.

    Args:
        model: encoder called as model(input_ids, attention_mask). Its output must
            support ['last_hidden_state'] (e.g. transformers.BertModel), assumed to be
            shaped (batch, seq_len, hidden).
        loader: iterable of batches whose first two items are input_ids and
            attention_mask tensors.
        device: device the batches are moved to. Defaults to the device of the model's
            parameters, so the function no longer relies on a global `device`.
        pooling: 'cls' (default, the original behavior) takes the last-layer state of
            the first token ([CLS] for BERT). 'mean' averages last-layer states over
            non-padding tokens, using attention_mask as weights.

    Returns:
        CPU tensor with one row per sample, in loader order.

    Raises:
        ValueError: if `pooling` is neither 'cls' nor 'mean'.
    """
    if pooling not in ('cls', 'mean'):
        raise ValueError(f"pooling must be 'cls' or 'mean', got {pooling!r}")
    if device is None:
        device = next(model.parameters()).device
    # Disable dropout so embeddings are deterministic.
    model.eval()

    chunks = []
    for batch in tqdm(loader):
        input_ids = batch[0].to(device)
        attention_mask = batch[1].to(device)
        hidden = model(input_ids, attention_mask)['last_hidden_state']
        if pooling == 'cls':
            pooled = hidden[:, 0, :]
        else:
            weights = attention_mask.unsqueeze(-1).to(hidden.dtype)
            # clamp avoids division by zero for an all-padding row.
            pooled = (hidden * weights).sum(dim=1) / weights.sum(dim=1).clamp(min=1)
        # Move each batch to CPU right away so GPU memory does not grow with the dataset.
        chunks.append(pooled.cpu())

    return torch.cat(chunks, dim=0)

In [110]:
# Embed every headline of each split and write the tensors to disk
# (the train split took ~28 minutes in the recorded run below).
for split_loader, out_path in ((train_loader, 'train_embeddings.pt'),
                               (test_loader, 'test_embeddings.pt')):
    embeddings = get_embeddings(model, split_loader)
    torch.save(embeddings, out_path)

100%|██████████████████████████████████████████████████████████████████████████████| 1456/1456 [28:07<00:00,  1.16s/it]
100%|██████████████████████████████████████████████████████████████████████████████████| 99/99 [01:30<00:00,  1.10it/s]
