自然言語処理関連の特徴量作成用notebook。

Output Files:

ALL_w2v.csv, ALL_svd_fastText.csv, ALL_pca_fastText.csv, ALL_svd_SciBERT.csv

w2v.csv, svd_fastText.csv, pca_fastText.csv, svd_SciBERT.csv, tfidf.csv

# Libraries

In [None]:
!pip install texthero



In [None]:
!pip install fasttext



In [None]:
!pip install transformers



In [None]:
import warnings
warnings.filterwarnings('ignore')

import random
import os
import time
import matplotlib.pyplot as plt
import nltk
import numpy as np
import pandas as pd
import seaborn as sns
import re
import itertools
import collections
import pickle
import texthero as hero
import torch
import transformers
from transformers import BertTokenizer
from contextlib import contextmanager
from pathlib import Path
from typing import Optional
from fasttext import load_model
from gensim.models import word2vec, KeyedVectors
from geopy.geocoders import Nominatim
from sklearn.metrics import mean_squared_error
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import KFold
from sklearn.decomposition import TruncatedSVD
from sklearn.decomposition import PCA
from sklearn.feature_extraction.text import TfidfVectorizer
from tqdm import tqdm

tqdm.pandas()

# Utils

In [None]:
@contextmanager
def timer(name: str):
    """Report how long the wrapped block takes.

    Prints ``[name] start`` on entry and, once the body has finished without
    raising, ``[name] done in N s`` with the elapsed wall-clock time rounded
    to whole seconds.
    """
    started_at = time.time()
    print(f"[{name}] start")
    yield
    elapsed = time.time() - started_at
    print(f"[{name}] done in {elapsed:.0f} s")
    
    
def set_seed(seed=42):
    """Seed the global random number generators of ``random`` and NumPy.

    ``PYTHONHASHSEED`` is also written to the environment; note that hash
    randomisation of an already running interpreter is fixed at start-up, so
    this only matters for processes launched afterwards.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)
    for seeder in (random.seed, np.random.seed):
        seeder(seed)


# Seed shared by the whole notebook (also reused as ``random_state`` below).
SEED = 42
set_seed(SEED)

# Data Loading

In [None]:
# Mount Google Drive into the Colab runtime so the project files are reachable.
from google.colab import drive
drive.mount('/content/drive')
# Switch the working directory to the project folder; the path is relative, so
# this presumably relies on the runtime starting in /content (see output below).
%cd "drive/My Drive/probspace_citations"

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
/content/drive/My Drive/probspace_citations


In [None]:
# Directory holding the pickled input frames and directory receiving the
# generated feature CSVs (both relative to the current working directory).
DATADIR = Path('input')
OUTDIR = Path('features')
# Make sure the output directory exists so the ``to_csv`` calls further down
# do not fail on a fresh workspace.
OUTDIR.mkdir(parents=True, exist_ok=True)

In [None]:
# NOTE: ``read_pickle`` can execute arbitrary code while unpickling; only load
# files from a trusted source.
with timer("Load ALL_df"):
    ALL_df = pd.read_pickle(DATADIR / 'ALL_df.pkl') # all train/test data concatenated

[Load ALL_df] start
[Load ALL_df] done in 31 s


In [None]:
# NOTE: ``read_pickle`` can execute arbitrary code while unpickling; only load
# files from a trusted source.
with timer("Load df"):
    df = pd.read_pickle(DATADIR / 'df.pkl') # train rows that have a cites value, concatenated with test

[Load df] start
[Load df] done in 4 s


# ALL_df Features

citesの無いtrainデータ含めて作成した特徴量

## Word2Vec

In [None]:
def make_id_feat(x):
    """Return a coarse grouping key derived from a paper id.

    If the id contains a ``/`` the part before the last ``/`` is returned
    (``"hep-ph/9902295"`` -> ``"hep-ph"``); otherwise the first four
    characters are returned (``"1403.7138"`` -> ``"1403"``).
    """
    try:
        match = re.search('(.*)/(.*)', x)
    except TypeError:
        # Not a string/bytes value: fall through to the slicing fallback,
        # as the original catch-all handler did.
        match = None
    if match is not None:
        return match.group(1)
    return x[:4]
def make_doi_feat(x):
    """Return the part of a DOI string before the first ``/``.

    Values without a ``split`` method (e.g. ``NaN`` or ``None`` for a missing
    DOI) yield ``np.nan``.
    """
    try:
        return x.split('/')[0]
    except AttributeError:
        return np.nan

# Coarse grouping keys derived from the raw id / DOI strings.
ALL_df['id_feat'] = ALL_df['id'].apply(make_id_feat)
ALL_df['doi_feat'] = ALL_df['doi'].apply(make_doi_feat)
# Category tokens: split on whitespace only ...
ALL_df['categories_space_list'] = ALL_df['categories'].apply(lambda cats: cats.split())
# ... and split on both '.' and ' ' (despite the "comma" in the column name).
ALL_df['categories_comma_list'] = ALL_df['categories'].apply(lambda cats: re.split('[. ]', cats))

In [None]:
# Each frame pairs the paper id with a list of tokens in a 'target' column;
# these token lists are the "sentences" fed to Word2Vec below.
_id_doi_tokens = (ALL_df['id_feat'] + ' ' + ALL_df['doi_feat']).apply(lambda text: text.split())
id_doi_df = pd.concat([ALL_df['id'], _id_doi_tokens], axis=1).rename(columns={0: 'target'})
categories_space_df = ALL_df[['id', 'categories_space_list']].rename(columns={'categories_space_list': 'target'})
categories_comma_df = ALL_df[['id', 'categories_comma_list']].rename(columns={'categories_comma_list': 'target'})
# Element-wise list concatenation: id/doi tokens followed by category tokens.
_id_doi_category_tokens = id_doi_df['target'] + categories_space_df['target']
id_doi_categories_space_df = pd.concat([ALL_df['id'], _id_doi_category_tokens], axis=1).rename(columns={0: 'target'})

In [None]:
# Embedding dimensionality used for each token source.
model_size = {
    'id_doi': 16,
    'categories_space': 16,
    'categories_comma': 32,
    'id_doi_categories_space': 64,
}

n_iter = 100
w2v_dfs = []
_w2v_inputs = [
    ('id_doi', id_doi_df),
    ('categories_space', categories_space_df),
    ('categories_comma', categories_comma_df),
    ('id_doi_categories_space', id_doi_categories_space_df),
]
for _df_name, _df in _w2v_inputs:

    with timer(f"Creating w2v for {_df_name}"):
        # Train Word2Vec on the token lists.
        # NOTE(review): `size` / `iter` are the gensim 3.x keyword names
        # (gensim 4 renamed them to `vector_size` / `epochs`).
        w2v_model = word2vec.Word2Vec(
            _df['target'].values.tolist(),
            size=model_size[_df_name],
            min_count=1,
            window=100,
            workers=1,
            iter=n_iter,
        )

    with timer(f"Getting document vector for {_df_name}"):
        # Document vector = mean of the vectors of the tokens in that row.
        sentence_vectors = _df['target'].progress_apply(
            lambda tokens: np.mean([w2v_model.wv[token] for token in tokens], axis=0)
        )
        sentence_vectors = np.vstack(sentence_vectors.tolist())
        sentence_vector_df = pd.DataFrame(
            sentence_vectors,
            columns=[f"ALL_{_df_name}_w2v_{i}" for i in range(model_size[_df_name])],
        )
        # Index by paper id so the frames can be merged on 'id' afterwards.
        sentence_vector_df.index = _df['id']
        w2v_dfs.append(sentence_vector_df)

[Creating w2v for id_doi] start


  1%|          | 6881/910608 [00:00<00:25, 35497.37it/s]

[Creating w2v for id_doi] done in 210 s
[Getting document vector for id_doi] start


100%|██████████| 910608/910608 [00:25<00:00, 36067.55it/s]


[Getting document vector for id_doi] done in 27 s
[Creating w2v for categories_space] start


  0%|          | 2049/910608 [00:00<00:44, 20488.55it/s]

[Creating w2v for categories_space] done in 163 s
[Getting document vector for categories_space] start


100%|██████████| 910608/910608 [00:24<00:00, 37791.06it/s]


[Getting document vector for categories_space] done in 26 s
[Creating w2v for categories_comma] start


  0%|          | 3712/910608 [00:00<00:24, 37104.91it/s]

[Creating w2v for categories_comma] done in 193 s
[Getting document vector for categories_comma] start


100%|██████████| 910608/910608 [00:22<00:00, 40328.89it/s]


[Getting document vector for categories_comma] done in 25 s
[Creating w2v for id_doi_categories_space] start


  0%|          | 3520/910608 [00:00<00:25, 35196.76it/s]

[Creating w2v for id_doi_categories_space] done in 362 s
[Getting document vector for id_doi_categories_space] start


100%|██████████| 910608/910608 [00:24<00:00, 37431.16it/s]


[Getting document vector for id_doi_categories_space] done in 27 s


In [None]:
# Attach every Word2Vec feature block to the ids present in ``df``.
# Iterating over ``w2v_dfs`` directly avoids hard-coding how many feature
# blocks the loop above produced.
ALL_w2v = df[['id']]
for _w2v_df in w2v_dfs:
    ALL_w2v = pd.merge(ALL_w2v, _w2v_df, on='id', how='left')

In [None]:
# Persist the merged Word2Vec features (the id is a regular column, so the index is dropped).
ALL_w2v.to_csv(OUTDIR / 'ALL_w2v.csv', index=False)

## fastText features

fastTextの学習済モデルを用いた特徴量作成

In [None]:
# Load the pretrained fastText model from the working directory.
# The file name suggests the 300-dimensional English vectors — verify.
model_en = load_model("cc.en.300.bin")



In [None]:
def text_cleaning(raw_text):
    """Run the texthero cleaning pipeline over ``raw_text``.

    ``raw_text`` is handed to ``hero.clean`` unchanged (in this notebook it is
    a text column of a DataFrame); the cleaned result is returned.
    """
    steps = [
        hero.preprocessing.fillna,  # fill missing values
        hero.preprocessing.lowercase,  # lowercase everything
        hero.preprocessing.remove_digits,  # drop digits
        hero.preprocessing.remove_punctuation,  # drop punctuation
        hero.preprocessing.remove_diacritics,  # strip diacritics (e.g. à, é)
        hero.preprocessing.remove_stopwords,  # drop stop words
        hero.preprocessing.remove_whitespace,  # remove extra whitespace
        hero.preprocessing.remove_brackets,
    ]
    return hero.clean(raw_text, pipeline=steps)

In [None]:
# Clean the free-text columns in place; the embedding cells below read the
# cleaned versions.
for _text_col in ('abstract', 'title', 'comments'):
    ALL_df[_text_col] = text_cleaning(ALL_df[_text_col])

In [None]:
# Start both result frames with just the id column; feature columns are
# merged in by the cells below.
ALL_svd_fastText = ALL_df[["id"]].copy()
ALL_pca_fastText = ALL_df[["id"]].copy()
ALL_svd_fastText

Unnamed: 0,id
0,hep-ph/9902295
1,1403.7138
2,1405.5857
3,1807.01034
4,1905.05921
...,...
910603,1210.4112
910604,1701.03465
910605,1709.10428
910606,gr-qc/9803020


### TruncatedSVD

In [None]:
with timer("ALL SVD abstract fastText"):
    # One fastText sentence vector per abstract; newlines are removed because
    # the text is passed to fastText as a single line.
    X = ALL_df["abstract"].progress_apply(lambda x: model_en.get_sentence_vector(x.replace("\n", "")))
    X = np.stack(X.values)
    # Compress the sentence vectors to 64 components.
    svd = TruncatedSVD(n_components=64, random_state=SEED)
    X = svd.fit_transform(X)
    X_df = pd.DataFrame(X, columns=[f"ALL_SVD_abstract_fastText_{i}" for i in range(X.shape[1])])
    # Rows of X follow ALL_df's row order, so attach the ids by position.
    # Assigning the Series itself would align on index labels and produce
    # wrong/NaN ids whenever ALL_df's index is not 0..n-1.
    X_df["id"] = ALL_df["id"].values
    ALL_svd_fastText = ALL_svd_fastText.merge(X_df, on="id", how="left")

  0%|          | 570/910608 [00:00<06:25, 2361.70it/s]

[ALL SVD abstract fastText] start


100%|██████████| 910608/910608 [04:17<00:00, 3540.78it/s]


[ALL SVD abstract fastText] done in 278 s


In [None]:
with timer("ALL SVD title fastText"):
    # One fastText sentence vector per title; newlines are removed because
    # the text is passed to fastText as a single line.
    X = ALL_df["title"].progress_apply(lambda x: model_en.get_sentence_vector(x.replace("\n", "")))
    X = np.stack(X.values)
    # Compress the sentence vectors to 32 components.
    svd = TruncatedSVD(n_components=32, random_state=SEED)
    X = svd.fit_transform(X)
    X_df = pd.DataFrame(X, columns=[f"ALL_SVD_title_fastText_{i}" for i in range(X.shape[1])])
    # Rows of X follow ALL_df's row order, so attach the ids by position.
    # Assigning the Series itself would align on index labels and produce
    # wrong/NaN ids whenever ALL_df's index is not 0..n-1.
    X_df["id"] = ALL_df["id"].values
    ALL_svd_fastText = ALL_svd_fastText.merge(X_df, on="id", how="left")

  0%|          | 4275/910608 [00:00<00:43, 20980.43it/s]

[ALL SVD title fastText] start


100%|██████████| 910608/910608 [00:39<00:00, 23157.36it/s]


[ALL SVD title fastText] done in 54 s


In [None]:
with timer("ALL SVD comments fastText"):
    # One fastText sentence vector per comments field; newlines are removed
    # because the text is passed to fastText as a single line.
    X = ALL_df["comments"].progress_apply(lambda x: model_en.get_sentence_vector(x.replace("\n", "")))
    X = np.stack(X.values)
    # Compress the sentence vectors to 16 components.
    svd = TruncatedSVD(n_components=16, random_state=SEED)
    X = svd.fit_transform(X)
    X_df = pd.DataFrame(X, columns=[f"ALL_SVD_comments_fastText_{i}" for i in range(X.shape[1])])
    # Rows of X follow ALL_df's row order, so attach the ids by position.
    # Assigning the Series itself would align on index labels and produce
    # wrong/NaN ids whenever ALL_df's index is not 0..n-1.
    X_df["id"] = ALL_df["id"].values
    ALL_svd_fastText = ALL_svd_fastText.merge(X_df, on="id", how="left")

  0%|          | 3531/910608 [00:00<00:25, 35305.15it/s]

[ALL SVD comments fastText] start


100%|██████████| 910608/910608 [00:25<00:00, 36104.23it/s]


[ALL SVD comments fastText] done in 36 s


In [None]:
# Display the assembled SVD feature table as the cell output.
ALL_svd_fastText

Unnamed: 0,id,ALL_SVD_abstract_fastText_0,ALL_SVD_abstract_fastText_1,ALL_SVD_abstract_fastText_2,ALL_SVD_abstract_fastText_3,ALL_SVD_abstract_fastText_4,ALL_SVD_abstract_fastText_5,ALL_SVD_abstract_fastText_6,ALL_SVD_abstract_fastText_7,ALL_SVD_abstract_fastText_8,ALL_SVD_abstract_fastText_9,ALL_SVD_abstract_fastText_10,ALL_SVD_abstract_fastText_11,ALL_SVD_abstract_fastText_12,ALL_SVD_abstract_fastText_13,ALL_SVD_abstract_fastText_14,ALL_SVD_abstract_fastText_15,ALL_SVD_abstract_fastText_16,ALL_SVD_abstract_fastText_17,ALL_SVD_abstract_fastText_18,ALL_SVD_abstract_fastText_19,ALL_SVD_abstract_fastText_20,ALL_SVD_abstract_fastText_21,ALL_SVD_abstract_fastText_22,ALL_SVD_abstract_fastText_23,ALL_SVD_abstract_fastText_24,ALL_SVD_abstract_fastText_25,ALL_SVD_abstract_fastText_26,ALL_SVD_abstract_fastText_27,ALL_SVD_abstract_fastText_28,ALL_SVD_abstract_fastText_29,ALL_SVD_abstract_fastText_30,ALL_SVD_abstract_fastText_31,ALL_SVD_abstract_fastText_32,ALL_SVD_abstract_fastText_33,ALL_SVD_abstract_fastText_34,ALL_SVD_abstract_fastText_35,ALL_SVD_abstract_fastText_36,ALL_SVD_abstract_fastText_37,ALL_SVD_abstract_fastText_38,...,ALL_SVD_title_fastText_8,ALL_SVD_title_fastText_9,ALL_SVD_title_fastText_10,ALL_SVD_title_fastText_11,ALL_SVD_title_fastText_12,ALL_SVD_title_fastText_13,ALL_SVD_title_fastText_14,ALL_SVD_title_fastText_15,ALL_SVD_title_fastText_16,ALL_SVD_title_fastText_17,ALL_SVD_title_fastText_18,ALL_SVD_title_fastText_19,ALL_SVD_title_fastText_20,ALL_SVD_title_fastText_21,ALL_SVD_title_fastText_22,ALL_SVD_title_fastText_23,ALL_SVD_title_fastText_24,ALL_SVD_title_fastText_25,ALL_SVD_title_fastText_26,ALL_SVD_title_fastText_27,ALL_SVD_title_fastText_28,ALL_SVD_title_fastText_29,ALL_SVD_title_fastText_30,ALL_SVD_title_fastText_31,ALL_SVD_comments_fastText_0,ALL_SVD_comments_fastText_1,ALL_SVD_comments_fastText_2,ALL_SVD_comments_fastText_3,ALL_SVD_comments_fastText_4,ALL_SVD_comments_fastText_5,ALL_SVD_comments_fastText_6,ALL_SVD_comments_fastText_7,ALL_SVD_comment
s_fastText_8,ALL_SVD_comments_fastText_9,ALL_SVD_comments_fastText_10,ALL_SVD_comments_fastText_11,ALL_SVD_comments_fastText_12,ALL_SVD_comments_fastText_13,ALL_SVD_comments_fastText_14,ALL_SVD_comments_fastText_15
0,hep-ph/9902295,0.330791,-0.040681,-0.036119,0.013659,-0.016790,0.087835,0.048660,0.039931,0.031968,-0.024824,0.003808,0.042545,-0.005603,-0.050364,0.022270,0.033589,0.019019,0.101123,0.000044,0.013913,-0.020021,-0.023491,-0.030688,-0.015847,0.032511,0.028007,-0.001736,0.004960,-0.010098,0.045267,0.027386,-0.009925,-0.021486,-0.044623,0.013781,-0.013956,-0.039226,0.010656,0.008050,...,0.064884,-0.009866,0.061122,0.042080,-0.121580,-0.005575,0.059404,-0.057897,0.059150,-0.018576,-0.070641,0.018830,0.096663,-0.139469,-0.035567,-0.047924,0.010380,0.011918,0.026518,0.082314,-0.014170,-0.053461,-0.009393,0.009566,2.489214e-01,1.867214e-01,5.922201e-03,-4.138458e-02,0.117438,0.101735,0.010537,0.030196,-0.008867,-0.028349,-0.014525,-0.031311,-0.022195,-0.021503,-0.012576,0.020626
1,1403.7138,0.363122,0.029175,-0.061232,-0.003548,0.012257,-0.057539,0.032653,-0.012248,0.045867,-0.011151,0.008271,0.055627,0.030200,0.031697,0.025476,0.025927,0.001187,0.001148,-0.011435,-0.046066,-0.003005,-0.004352,0.001753,0.019525,-0.009865,0.005342,-0.000365,-0.000570,0.009256,0.010934,-0.005433,-0.000334,0.005619,-0.003421,-0.006042,-0.001026,0.032405,-0.018932,-0.005383,...,0.030021,-0.007932,-0.031351,-0.072215,-0.029012,0.006704,-0.016811,0.007148,-0.031679,-0.050298,0.044642,0.034284,-0.008809,0.004279,0.057215,0.043713,-0.015236,-0.024644,-0.000256,0.021189,-0.042301,-0.012851,-0.020254,-0.001741,4.625807e-01,1.755179e-01,1.996905e-02,1.759009e-01,-0.030985,-0.021186,-0.013829,0.027517,-0.017860,0.048070,-0.044942,0.009950,0.056432,0.068888,0.023235,0.017437
2,1405.5857,0.349692,0.035704,-0.011180,0.033973,-0.006480,0.029008,-0.046002,-0.008810,-0.014794,-0.011983,0.010968,0.002813,0.026181,-0.017254,-0.008956,-0.002908,-0.015610,-0.005823,-0.009052,0.014030,0.007281,0.018405,0.002123,0.002863,0.004967,-0.010812,0.000266,-0.000521,-0.014107,0.002901,0.005273,0.002242,-0.000694,-0.007664,-0.016045,-0.020220,0.022278,-0.002251,-0.003424,...,-0.005732,0.008806,0.002111,-0.026526,0.019286,0.002025,-0.014652,0.011609,0.013377,0.028202,-0.001844,0.051041,0.002964,0.037172,-0.009589,0.012611,-0.033493,0.026556,-0.010884,-0.002346,-0.024532,0.024549,-0.034807,0.002468,2.663621e-01,2.050515e-01,4.030778e-02,-9.591454e-02,-0.011152,0.061577,-0.050832,-0.003069,-0.011843,-0.009176,-0.016062,0.036345,0.010105,-0.024213,-0.001937,-0.026670
3,1807.01034,0.362087,-0.005666,-0.035338,-0.015742,-0.021473,0.010435,-0.012916,-0.008198,-0.003925,0.015096,-0.000610,-0.016530,-0.029869,0.009424,0.008627,-0.016949,0.023073,0.027168,-0.002743,-0.001642,0.018823,0.004214,-0.000848,-0.022814,0.013625,0.010001,-0.037796,0.001498,-0.001814,-0.007115,-0.011093,-0.003159,-0.008315,0.026198,-0.001329,0.000190,-0.004834,0.009012,-0.000548,...,0.046117,-0.021214,-0.009864,0.038760,0.053239,0.031428,-0.010799,-0.026263,-0.036329,-0.047076,-0.033318,0.042232,-0.002615,-0.072132,0.080986,-0.054732,-0.023784,0.085740,0.021546,0.023753,0.001067,-0.014426,-0.025000,0.003958,6.615096e-01,-1.414772e-01,7.334760e-01,6.600955e-03,0.028152,-0.015025,0.018883,0.008865,-0.003711,-0.029585,-0.035047,-0.000242,0.017281,-0.002563,-0.000390,-0.001503
4,1905.05921,0.367315,-0.069502,-0.030228,0.019579,-0.043252,-0.005621,0.008671,-0.008865,0.001567,0.004722,0.031792,0.013475,-0.005978,0.042248,-0.013319,0.025463,-0.007599,0.009399,0.008288,-0.006536,-0.009018,0.021290,-0.011798,-0.002491,-0.018998,0.021806,-0.000102,0.008271,0.007988,-0.012246,0.000185,-0.013066,-0.007627,0.007977,-0.008081,-0.011934,0.026948,0.009257,-0.005914,...,0.000249,0.001728,0.022186,0.049517,-0.000759,0.096961,0.017068,-0.009351,0.017197,-0.026852,-0.031919,-0.006205,-0.004080,0.054149,-0.016639,-0.036510,0.046630,0.029566,0.028383,-0.001707,-0.002072,-0.035128,-0.012119,0.029055,1.676206e-12,2.677602e-11,2.318584e-09,6.815106e-07,0.000001,0.000010,-0.000003,-0.000002,0.000003,0.000008,-0.000002,0.000006,-0.000009,0.000006,0.000014,0.000005
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
910603,1210.4112,0.332291,0.127288,0.043539,0.032968,-0.051357,0.058937,-0.018546,0.001195,0.019642,0.016690,-0.012374,-0.009929,-0.013056,-0.011729,-0.005097,0.041661,0.012513,0.016216,0.007717,0.017680,-0.014941,0.031982,0.028239,-0.023249,-0.002966,0.013337,-0.031382,0.010637,-0.004468,0.024369,0.002987,-0.004356,0.013225,-0.008216,0.006025,0.004557,-0.001738,-0.000653,-0.007761,...,0.060669,0.021829,0.033957,0.039530,0.009117,-0.051582,-0.030634,-0.028661,0.081542,0.080760,-0.016722,0.026461,0.041409,-0.032970,-0.011108,0.028472,-0.073479,0.024853,-0.011068,0.028837,-0.007738,0.016145,-0.012336,0.006911,7.457914e-01,-1.806703e-01,1.120440e-01,1.806329e-02,0.001429,0.000785,0.003305,-0.005845,-0.001281,-0.018253,-0.014578,-0.006847,0.001964,-0.005046,-0.001860,-0.003397
910604,1701.03465,0.283563,0.094405,-0.038168,0.101497,0.065103,-0.011192,0.040075,-0.032620,-0.017636,-0.011220,0.006241,-0.009018,-0.007224,0.005731,0.042061,-0.003559,-0.014912,-0.025615,0.028341,0.001485,0.014747,-0.003887,0.008522,-0.023313,-0.025224,-0.004481,-0.009921,0.009283,0.013465,-0.014544,0.011850,-0.020512,0.005230,0.012494,0.004155,-0.012511,-0.010866,0.015184,0.003866,...,-0.019132,-0.020609,-0.080366,-0.021578,0.013457,0.006356,0.007575,-0.048747,-0.044910,-0.035894,-0.001025,-0.060074,0.008764,-0.013223,0.005352,-0.004465,-0.021526,-0.043921,0.038958,0.041951,0.034302,-0.023034,-0.003574,-0.005756,4.340814e-01,1.632236e-01,2.695745e-02,1.323287e-01,0.015347,-0.061250,-0.021638,-0.012232,-0.018616,0.011477,0.090792,0.035409,-0.018433,0.017594,0.030256,0.060864
910605,1709.10428,0.345093,0.017260,0.057072,0.006777,-0.005121,-0.030490,-0.021627,0.032796,0.005612,0.008500,-0.015465,-0.016583,0.007779,0.018799,0.017693,0.012854,0.017518,-0.023898,0.017344,-0.000950,0.007909,0.014940,-0.005505,-0.027073,0.006648,0.005475,-0.012111,0.022093,-0.012773,-0.010269,-0.007269,0.008070,-0.004874,-0.004906,-0.003040,0.004226,-0.015147,0.001271,0.010307,...,0.012624,0.000461,-0.062092,0.010613,0.047554,-0.031582,0.021917,-0.021160,0.004175,0.110188,0.039952,-0.031892,-0.062594,-0.012739,0.034873,-0.015253,-0.037143,0.000590,-0.044777,-0.020679,0.005291,0.011539,-0.008917,0.028518,8.300738e-01,-2.198623e-01,-5.093879e-01,2.951659e-02,-0.025332,0.016612,-0.012250,-0.020556,0.001120,-0.006954,0.005887,-0.013473,-0.013328,-0.007536,-0.003398,-0.005320
910606,gr-qc/9803020,0.360455,-0.058124,0.057613,-0.019414,0.096685,-0.011575,-0.008305,0.019542,0.008696,-0.010449,0.010774,-0.046026,0.031459,0.018637,0.016230,-0.004454,-0.011340,-0.007492,-0.030532,0.015442,-0.038508,-0.003543,0.043072,0.030364,0.008828,0.005538,-0.022054,-0.026604,-0.022139,0.010313,0.011223,-0.012056,-0.028589,-0.008258,-0.013974,0.015846,-0.021865,-0.022467,0.005979,...,-0.015363,-0.063604,-0.076294,-0.047457,-0.034851,0.054332,-0.005424,-0.103521,-0.033232,-0.041258,-0.011034,-0.010167,-0.015650,-0.001727,-0.059196,0.028081,-0.074098,-0.061237,-0.038331,0.006055,-0.041777,-0.024815,-0.003778,0.022226,3.375283e-01,1.764231e-01,7.257260e-02,4.557867e-02,0.012872,0.068631,-0.034886,0.043715,0.032794,-0.004187,-0.005500,-0.070158,0.013084,-0.048618,-0.026590,-0.031455


In [None]:
# Persist the SVD-compressed fastText features (id is a regular column, so the index is dropped).
ALL_svd_fastText.to_csv(OUTDIR / 'ALL_svd_fastText.csv', index=False)

### PCA

In [None]:
with timer("ALL PCA abstract fastText"):
    # One fastText sentence vector per abstract; newlines are removed because
    # the text is passed to fastText as a single line.
    X = ALL_df["abstract"].progress_apply(lambda x: model_en.get_sentence_vector(x.replace("\n", "")))
    X = np.stack(X.values)
    # Project the sentence vectors onto 64 principal components.
    pca = PCA(n_components=64, random_state=SEED)
    X = pca.fit_transform(X)
    # NOTE(review): unlike the SVD columns these names carry no "ALL_" prefix;
    # kept as-is because downstream consumers may depend on them.
    X_df = pd.DataFrame(X, columns=[f"PCA_abstract_fastText_{i}" for i in range(X.shape[1])])
    # Rows of X follow ALL_df's row order, so attach the ids by position.
    # Assigning the Series itself would align on index labels and produce
    # wrong/NaN ids whenever ALL_df's index is not 0..n-1.
    X_df["id"] = ALL_df["id"].values
    ALL_pca_fastText = ALL_pca_fastText.merge(X_df, on="id", how="left")

  0%|          | 566/910608 [00:00<06:20, 2392.74it/s]

[ALL PCA abstract fastText] start


100%|██████████| 910608/910608 [04:28<00:00, 3389.19it/s]


[ALL PCA abstract fastText] done in 289 s


In [None]:
with timer("ALL PCA title fastText"):
    # One fastText sentence vector per title; newlines are removed because
    # the text is passed to fastText as a single line.
    X = ALL_df["title"].progress_apply(lambda x: model_en.get_sentence_vector(x.replace("\n", "")))
    X = np.stack(X.values)
    # Project the sentence vectors onto 32 principal components.
    pca = PCA(n_components=32, random_state=SEED)
    X = pca.fit_transform(X)
    # NOTE(review): unlike the SVD columns these names carry no "ALL_" prefix;
    # kept as-is because downstream consumers may depend on them.
    X_df = pd.DataFrame(X, columns=[f"PCA_title_fastText_{i}" for i in range(X.shape[1])])
    # Rows of X follow ALL_df's row order, so attach the ids by position.
    # Assigning the Series itself would align on index labels and produce
    # wrong/NaN ids whenever ALL_df's index is not 0..n-1.
    X_df["id"] = ALL_df["id"].values
    ALL_pca_fastText = ALL_pca_fastText.merge(X_df, on="id", how="left")

  1%|          | 4621/910608 [00:00<00:38, 23501.15it/s]

[ALL PCA title fastText] start


100%|██████████| 910608/910608 [00:41<00:00, 22125.75it/s]


[ALL PCA title fastText] done in 54 s


In [None]:
with timer("ALL PCA comments fastText"):
    # One fastText sentence vector per comments field; newlines are removed
    # because the text is passed to fastText as a single line.
    X = ALL_df["comments"].progress_apply(lambda x: model_en.get_sentence_vector(x.replace("\n", "")))
    X = np.stack(X.values)
    # Project the sentence vectors onto 16 principal components.
    pca = PCA(n_components=16, random_state=SEED)
    X = pca.fit_transform(X)
    # NOTE(review): unlike the SVD columns these names carry no "ALL_" prefix;
    # kept as-is because downstream consumers may depend on them.
    X_df = pd.DataFrame(X, columns=[f"PCA_comments_fastText_{i}" for i in range(X.shape[1])])
    # Rows of X follow ALL_df's row order, so attach the ids by position.
    # Assigning the Series itself would align on index labels and produce
    # wrong/NaN ids whenever ALL_df's index is not 0..n-1.
    X_df["id"] = ALL_df["id"].values
    ALL_pca_fastText = ALL_pca_fastText.merge(X_df, on="id", how="left")

  0%|          | 2384/910608 [00:00<00:38, 23839.97it/s]

[ALL PCA comments fastText] start


100%|██████████| 910608/910608 [00:25<00:00, 35604.55it/s]


[ALL PCA comments fastText] done in 39 s


In [None]:
# Display the assembled PCA feature table as the cell output.
ALL_pca_fastText

Unnamed: 0,id,PCA_abstract_fastText_0,PCA_abstract_fastText_1,PCA_abstract_fastText_2,PCA_abstract_fastText_3,PCA_abstract_fastText_4,PCA_abstract_fastText_5,PCA_abstract_fastText_6,PCA_abstract_fastText_7,PCA_abstract_fastText_8,PCA_abstract_fastText_9,PCA_abstract_fastText_10,PCA_abstract_fastText_11,PCA_abstract_fastText_12,PCA_abstract_fastText_13,PCA_abstract_fastText_14,PCA_abstract_fastText_15,PCA_abstract_fastText_16,PCA_abstract_fastText_17,PCA_abstract_fastText_18,PCA_abstract_fastText_19,PCA_abstract_fastText_20,PCA_abstract_fastText_21,PCA_abstract_fastText_22,PCA_abstract_fastText_23,PCA_abstract_fastText_24,PCA_abstract_fastText_25,PCA_abstract_fastText_26,PCA_abstract_fastText_27,PCA_abstract_fastText_28,PCA_abstract_fastText_29,PCA_abstract_fastText_30,PCA_abstract_fastText_31,PCA_abstract_fastText_32,PCA_abstract_fastText_33,PCA_abstract_fastText_34,PCA_abstract_fastText_35,PCA_abstract_fastText_36,PCA_abstract_fastText_37,PCA_abstract_fastText_38,...,PCA_title_fastText_8,PCA_title_fastText_9,PCA_title_fastText_10,PCA_title_fastText_11,PCA_title_fastText_12,PCA_title_fastText_13,PCA_title_fastText_14,PCA_title_fastText_15,PCA_title_fastText_16,PCA_title_fastText_17,PCA_title_fastText_18,PCA_title_fastText_19,PCA_title_fastText_20,PCA_title_fastText_21,PCA_title_fastText_22,PCA_title_fastText_23,PCA_title_fastText_24,PCA_title_fastText_25,PCA_title_fastText_26,PCA_title_fastText_27,PCA_title_fastText_28,PCA_title_fastText_29,PCA_title_fastText_30,PCA_title_fastText_31,PCA_comments_fastText_0,PCA_comments_fastText_1,PCA_comments_fastText_2,PCA_comments_fastText_3,PCA_comments_fastText_4,PCA_comments_fastText_5,PCA_comments_fastText_6,PCA_comments_fastText_7,PCA_comments_fastText_8,PCA_comments_fastText_9,PCA_comments_fastText_10,PCA_comments_fastText_11,PCA_comments_fastText_12,PCA_comments_fastText_13,PCA_comments_fastText_14,PCA_comments_fastText_15
0,hep-ph/9902295,-0.034009,-0.027467,0.034190,-0.020338,0.071076,0.079460,0.032222,0.030883,-0.028977,0.002730,0.042287,-0.007282,-0.045794,0.044792,0.033063,0.030043,0.080763,0.042274,-0.000531,-0.025177,-0.025730,-0.030648,-0.016472,-0.027330,0.032695,-0.001768,0.006116,-0.008904,0.050653,-0.010952,-0.007194,-0.018085,-0.049698,0.008482,-0.016584,-0.039945,0.008781,-0.003764,-0.001783,...,0.005420,0.069683,-0.007154,-0.125480,0.016297,0.053856,-0.059930,-0.064618,-0.023876,-0.064458,-0.016769,0.091364,-0.139449,-0.042287,-0.053730,0.005880,0.012420,-0.030252,0.076496,-0.014310,-0.056411,-0.014962,0.008311,0.018551,-0.201534,0.001345,0.004868,0.001520,0.107258,0.106533,0.006339,0.031008,-0.009703,-0.027321,-0.014041,-0.030959,-0.022415,-0.022086,-0.012576,0.020606
1,1403.7138,0.017416,-0.065326,-0.005350,0.011330,-0.059585,0.013561,-0.025806,0.046607,-0.009489,0.006072,0.056326,0.030148,0.031944,0.021936,0.008136,-0.012914,-0.021111,0.048298,0.010481,0.005039,0.001667,0.007354,0.019298,0.004105,0.001155,-0.006795,-0.001424,0.010394,0.009537,0.006879,-0.001400,0.004880,0.001531,-0.009286,-0.002389,0.033047,-0.020034,0.006441,0.011188,...,-0.025143,-0.053273,0.053872,-0.022409,-0.032081,-0.024660,-0.001255,0.034079,-0.050565,0.043055,0.031418,0.004060,0.006553,0.063791,0.044858,-0.018906,-0.020528,0.000528,0.022580,-0.039969,-0.003958,-0.020788,-0.012126,0.012905,0.003238,0.007290,0.168955,-0.143630,-0.047562,-0.023402,-0.012295,0.027360,-0.017133,0.047482,-0.047766,0.008870,0.056183,0.067945,0.022838,0.017862
2,1405.5857,0.035972,-0.014313,0.024727,-0.002210,0.038460,-0.041943,0.005480,-0.016978,-0.013649,0.009739,0.003327,0.026520,-0.017472,-0.007434,0.001166,-0.012303,0.004495,-0.015865,-0.015668,0.006241,0.016396,-0.000110,0.002951,-0.003692,-0.008267,0.005770,-0.000035,-0.014247,0.004500,-0.004276,0.002478,-0.000211,-0.006015,-0.018001,-0.018378,0.022288,-0.000473,0.000130,-0.001520,...,0.006040,-0.006189,0.022209,0.021526,-0.009381,-0.014020,0.013280,-0.011525,0.019687,-0.017334,0.051596,0.014490,0.035619,-0.006435,0.020131,-0.028405,0.030227,0.012058,0.000188,-0.021680,0.025151,-0.032486,-0.001568,0.005272,-0.192350,0.033331,0.047226,0.094547,0.009096,0.061350,-0.052661,-0.002674,-0.012370,-0.008573,-0.016611,0.035951,0.010027,-0.025089,-0.000798,-0.026301
3,1807.01034,-0.014406,-0.035689,-0.012672,-0.021794,0.012426,-0.010690,-0.002568,-0.004844,0.009108,0.002426,-0.017346,-0.031128,0.010881,0.009443,-0.012926,0.031022,0.024278,0.004976,-0.002914,0.019531,0.003790,0.001419,-0.023271,-0.019987,0.003556,-0.034740,0.002448,-0.003976,-0.010186,0.008688,-0.003819,-0.009053,0.024529,0.001852,0.001979,-0.004948,0.009650,-0.000263,-0.002344,...,-0.011828,0.009000,-0.045194,0.049043,0.027353,-0.027878,-0.032967,0.033605,-0.049699,-0.034501,0.018442,0.012730,-0.073521,0.080517,-0.052126,-0.022866,0.082802,-0.022792,0.027357,0.004263,-0.007704,-0.024275,-0.002455,-0.034481,0.316522,0.736910,-0.002334,-0.015386,0.023976,-0.013744,0.018854,0.008790,-0.003871,-0.030864,-0.033901,-0.000500,0.017192,-0.002702,-0.000151,-0.001436
4,1905.05921,-0.070930,-0.020321,0.031987,-0.043099,-0.004906,0.002867,-0.011990,0.001564,0.002014,0.032079,0.013438,-0.005995,0.041985,-0.008722,0.025766,-0.010451,0.004551,0.010158,0.011279,-0.007735,0.021066,-0.014210,-0.002348,0.019557,0.016494,-0.011119,0.007445,0.009527,-0.011069,-0.005643,-0.013567,-0.007857,0.008593,-0.008180,-0.010332,0.026841,0.010683,0.002603,-0.007200,...,0.000579,0.022222,-0.009911,-0.012911,0.102361,-0.026650,-0.023673,-0.009846,-0.034388,-0.030978,-0.007578,-0.013369,0.054246,-0.023215,-0.042902,0.037185,0.018399,-0.029081,-0.005414,-0.004432,-0.036396,-0.016759,0.032596,0.008735,-0.325181,0.014027,-0.226825,-0.039978,-0.082232,-0.017465,0.008874,-0.001602,0.002135,-0.010499,0.008478,0.002236,0.000302,0.008110,-0.001871,-0.003651
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
910603,1210.4112,0.133973,0.028699,0.007632,-0.045762,0.066094,-0.014883,0.003652,0.020723,0.015113,-0.010005,-0.010447,-0.014314,-0.009097,0.014792,0.044820,0.012771,0.013147,-0.007318,-0.000093,-0.016408,0.030627,0.022621,-0.022637,0.007391,0.008047,-0.028301,0.011479,-0.007612,0.022583,0.009239,-0.001851,0.014976,-0.014523,0.007938,0.004780,-0.001550,0.000137,0.004887,-0.010233,...,0.034789,0.047242,-0.030029,0.010564,-0.026469,0.003459,0.001984,-0.097511,0.080995,-0.024475,0.018810,0.050127,-0.035413,-0.008757,0.037015,-0.060393,0.037969,0.016724,0.034447,-0.003986,0.009694,-0.009895,0.009354,0.015988,0.415275,0.117345,-0.037165,-0.019585,-0.005797,0.000221,0.003479,-0.005881,-0.001407,-0.019087,-0.013589,-0.006861,0.001981,-0.005024,-0.001696,-0.003456
910604,1701.03465,0.115725,-0.040556,0.099291,0.067216,-0.015735,0.028967,-0.041041,-0.019318,-0.013279,0.005859,-0.009206,-0.007472,0.007447,0.036699,-0.019265,-0.018528,-0.016483,-0.025298,0.023389,0.013553,-0.005347,0.007081,-0.022998,0.020887,-0.012141,-0.014059,0.008842,0.015760,-0.010077,-0.016375,-0.019862,0.005561,0.011115,0.004606,-0.011922,-0.011338,0.015608,-0.002722,0.000716,...,-0.018322,-0.070201,-0.025039,0.017602,-0.018603,-0.007720,-0.056761,0.039168,-0.024419,0.018044,-0.063864,-0.002233,-0.011557,0.003535,-0.005851,-0.024209,-0.034964,-0.036542,0.039356,0.034770,-0.021470,-0.003300,-0.012244,0.001180,-0.018197,0.016498,0.130955,-0.118090,0.004304,-0.060502,-0.020987,-0.012106,-0.018917,0.015438,0.089844,0.036038,-0.018470,0.019264,0.027829,0.061051
910605,1709.10428,0.024717,0.054867,0.000512,-0.004952,-0.028139,-0.019456,0.034873,0.007476,0.010852,-0.014589,-0.016627,0.007251,0.020274,0.019966,0.005585,0.010695,-0.026283,-0.012987,0.015211,0.008154,0.013625,-0.006607,-0.027177,-0.005915,0.004683,-0.009789,0.022909,-0.014075,-0.012123,0.004192,0.007623,-0.004930,-0.004121,-0.003284,0.004457,-0.015007,0.001332,-0.008871,-0.001960,...,-0.001261,-0.053148,-0.033767,0.048233,-0.007773,0.034634,-0.016210,-0.013245,0.113437,0.032430,0.003925,-0.069595,-0.012003,0.036055,-0.011010,-0.042924,0.004117,0.047509,-0.023557,0.004089,0.012444,-0.008874,0.027583,-0.024284,0.514029,-0.502220,-0.071996,-0.023796,-0.035466,0.014185,-0.011925,-0.020574,0.001089,-0.007270,0.006664,-0.013306,-0.013294,-0.007383,-0.003192,-0.005461
910606,gr-qc/9803020,-0.055093,0.062854,-0.016323,0.093787,-0.017120,0.004282,0.025686,0.006886,-0.015613,0.010655,-0.045633,0.031599,0.019849,0.014361,-0.009362,-0.010480,-0.000911,-0.007384,-0.034510,-0.037932,0.001250,0.046033,0.030461,-0.010984,0.002130,-0.019740,-0.025845,-0.022176,0.013734,-0.009362,-0.011483,-0.027742,-0.009867,-0.013207,0.016623,-0.021790,-0.022401,-0.007051,-0.002828,...,-0.067639,-0.073333,-0.000211,-0.029757,-0.010334,-0.050858,-0.115470,0.025796,-0.036306,-0.003709,-0.017472,-0.010318,-0.002004,-0.059530,0.040175,-0.082095,-0.050590,0.044526,0.006910,-0.038668,-0.028425,-0.003322,0.024498,-0.012224,-0.113451,0.064682,0.079460,-0.044107,0.000705,0.067225,-0.036137,0.044024,0.032594,-0.004639,-0.004655,-0.070161,0.013259,-0.049731,-0.024896,-0.031172


In [None]:
# Persist the PCA-compressed fastText features (id is a regular column, so the index is dropped).
ALL_pca_fastText.to_csv(OUTDIR / 'ALL_pca_fastText.csv', index=False)

## SciBERT

In [None]:
class BertSequenceVectorizer:
    """Encode a sentence as a fixed-size vector with a pretrained BERT model.

    The vector is the last-layer hidden state at the first token position of
    the (padded or truncated) input sequence.
    """

    def __init__(self):
        """Load tokenizer and model from ``model_name`` and pick the device."""
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.model_name = 'scibert_scivocab_uncased'  # path where the model is stored
        self.tokenizer = BertTokenizer.from_pretrained(self.model_name)
        self.bert_model = transformers.BertModel.from_pretrained(self.model_name)
        self.bert_model = self.bert_model.to(self.device)
        # The model is only used for inference; make the evaluation mode
        # (dropout disabled) explicit.
        self.bert_model.eval()
        self.max_len = 128
        print(self.device)

    def vectorize(self, sentence: str) -> np.ndarray:
        """Return the first-token embedding of ``sentence`` as a NumPy array.

        The token ids are truncated to ``max_len`` or right-padded with id 0
        (assumed to be the tokenizer's padding id) up to ``max_len``; padded
        positions are masked out via the attention mask.
        """
        inp = self.tokenizer.encode(sentence)
        len_inp = len(inp)

        if len_inp >= self.max_len:
            # NOTE(review): plain slicing also cuts off whatever trailing
            # special token the tokenizer appended.
            inputs = inp[:self.max_len]
            masks = [1] * self.max_len
        else:
            n_pad = self.max_len - len_inp
            inputs = inp + [0] * n_pad
            masks = [1] * len_inp + [0] * n_pad

        inputs_tensor = torch.tensor([inputs], dtype=torch.long).to(self.device)
        masks_tensor = torch.tensor([masks], dtype=torch.long).to(self.device)

        # Inference only: skip autograd bookkeeping so no graph is built (and
        # kept alive) for each of the many calls.
        with torch.no_grad():
            bert_out = self.bert_model(inputs_tensor, masks_tensor)
        seq_out = bert_out['last_hidden_state']

        # ``.cpu()`` is a no-op for tensors already on the CPU, so a single
        # code path serves both devices.
        return seq_out[0][0].detach().cpu().numpy()

# Build the vectorizer once (loads tokenizer and model, prints the chosen device)
# and reuse it for every text column below.
BSV = BertSequenceVectorizer()

cuda


In [None]:
# Start the result frame with just the id column; feature columns are merged
# in by the cells below.
ALL_svd_SciBERT = ALL_df[["id"]].copy()
ALL_svd_SciBERT

Unnamed: 0,id
0,hep-ph/9902295
1,1403.7138
2,1405.5857
3,1807.01034
4,1905.05921
...,...
910603,1210.4112
910604,1701.03465
910605,1709.10428
910606,gr-qc/9803020


In [None]:
with timer("ALL SVD abstract bert"):
    # One BERT vector per abstract (missing values are replaced by the string 'nan').
    bert = ALL_df['abstract'].fillna('nan').progress_apply(lambda x: BSV.vectorize(x))
    # Compress the BERT vectors to 32 components.
    svd = TruncatedSVD(n_components=32, random_state=SEED+1)
    X = np.stack(bert.values)
    X = svd.fit_transform(X)
    X_df = pd.DataFrame(X, columns=[f"ALL_scibert_abstract_{i}" for i in range(X.shape[1])])
    # Rows of X follow ALL_df's row order, so attach the ids by position.
    # Assigning the Series itself would align on index labels and produce
    # wrong/NaN ids whenever ALL_df's index is not 0..n-1.
    X_df["id"] = ALL_df["id"].values
    ALL_svd_SciBERT = ALL_svd_SciBERT.merge(X_df, on="id", how="left")

  0%|          | 0/910608 [00:00<?, ?it/s]

[ALL SVD abstract bert] start


100%|██████████| 910608/910608 [3:53:05<00:00, 65.11it/s]


[ALL SVD abstract bert] done in 14011 s


In [None]:
with timer("ALL SVD title bert"):
    # One BERT vector per title (missing values are replaced by the string 'nan').
    bert = ALL_df['title'].fillna('nan').progress_apply(lambda x: BSV.vectorize(x))
    # Compress the BERT vectors to 16 components.
    svd = TruncatedSVD(n_components=16, random_state=SEED+1)
    X = np.stack(bert.values)
    X = svd.fit_transform(X)
    X_df = pd.DataFrame(X, columns=[f"ALL_scibert_title_{i}" for i in range(X.shape[1])])
    # Rows of X follow ALL_df's row order, so attach the ids by position.
    # Assigning the Series itself would align on index labels and produce
    # wrong/NaN ids whenever ALL_df's index is not 0..n-1.
    X_df["id"] = ALL_df["id"].values
    ALL_svd_SciBERT = ALL_svd_SciBERT.merge(X_df, on="id", how="left")

  0%|          | 8/910608 [00:00<3:12:02, 79.03it/s]

[ALL SVD title bert] start


100%|██████████| 910608/910608 [3:04:41<00:00, 82.17it/s]


[ALL SVD title bert] done in 11109 s


In [None]:
# Persist the SVD-compressed SciBERT features (id is a regular column, so the index is dropped).
ALL_svd_SciBERT.to_csv(OUTDIR / 'ALL_svd_SciBERT.csv', index=False)

# df Features

citesのあるtrain及びtestのみを使用した特徴量作成

## Word2Vec

In [None]:
def make_id_feat(x):
    """Return the prefix of an id string used as a coarse grouping feature.

    If the id contains a '/', the text before the last '/' is returned
    (e.g. 'hep-ph/9902295' -> 'hep-ph'). Otherwise the first four
    characters are returned (e.g. '1403.7138' -> '1403').

    Args:
        x: The id string.

    Returns:
        The prefix string described above.

    Raises:
        TypeError: If `x` is not a string (raised by `re.search`).
    """
    match = re.search('(.*)/(.*)', x)
    if match is None:
        # No '/' in the id: fall back to the leading four characters.
        return x[:4]
    # The first group is greedy, so it extends up to the last '/'.
    return match.group(1)
def make_doi_feat(x):
    """Return the part of a DOI string before its first '/'.

    Args:
        x: The DOI string, or a missing value such as NaN / None.

    Returns:
        The text before the first '/' (the whole string if it has no '/'),
        or `np.nan` when `x` has no `split` method (i.e. is not string-like).
    """
    try:
        return x.split('/')[0]
    except AttributeError:
        # Missing values (NaN floats, None) have no .split(); keep them missing.
        return np.nan

# Prefix features derived from the id and DOI strings.
df['id_feat'] = df['id'].apply(make_id_feat)
df['doi_feat'] = df['doi'].apply(make_doi_feat)

# The categories string tokenised two ways: on whitespace, and on '.' or ' '.
df['categories_space_list'] = df['categories'].apply(lambda cats: cats.split())
df['categories_comma_list'] = df['categories'].apply(lambda cats: re.split('[. ]', cats))

# Rewrite each parsed author entry as [first two fields joined by a space,
# last field]. This overwrites the column, so the cell is not safe to re-run.
df['authors_parsed'] = df['authors_parsed'].apply(
    lambda authors: [[entry[0] + ' ' + entry[1], entry[-1]] for entry in authors]
)

# Flatten the per-author pairs into one list, dropping empty strings.
df['authors_list'] = df['authors_parsed'].apply(
    lambda authors: [part for entry in authors for part in entry if part != '']
)

In [None]:
# Every frame below pairs 'id' with one source of tokens, exposed under the
# common column name 'target'; missing values are replaced by a single space.
id_df, doi_df, submitter_df, categories_df, authors_df = (
    df[['id', col]].fillna(' ').rename(columns={col: 'target'})
    for col in ('id_feat', 'doi_feat', 'submitter', 'categories', 'authors_list')
)

# id_feat and doi_feat joined by a space, then split into a token list.
# The combined Series is unnamed, so it arrives as column 0.
id_doi_df = (
    pd.concat(
        [df['id'], (df['id_feat'] + ' ' + df['doi_feat']).apply(lambda x: list(x.split()))],
        axis=1,
    )
    .fillna(' ')
    .rename(columns={0: 'target'})
)

categories_space_df, categories_comma_df = (
    pd.concat([df['id'], df[col]], axis=1).fillna(' ').rename(columns={col: 'target'})
    for col in ('categories_space_list', 'categories_comma_list')
)

# Row-wise concatenation of the id/doi tokens and the space-split categories.
id_doi_categories_space_df = (
    pd.concat([df['id'], id_doi_df['target'] + categories_space_df['target']], axis=1)
    .fillna(' ')
    .rename(columns={0: 'target'})
)

In [None]:
# Embedding dimension used for each token source.
model_size = {
    'id': 10,
    'doi': 10,
    'submitter': 10,
    'categories': 10,
    'authors': 10,
    'id_doi': 10,
    'categories_space': 10,
    'categories_comma': 15,
    'id_doi_categories_space': 30,
}


def _mean_token_vector(tokens, keyed_vectors, dim):
    """Average the embedding vectors of `tokens`; all zeros if there are none.

    Without the guard, `np.mean([])` yields a scalar NaN, which cannot be
    stacked with the `dim`-sized vectors of the other rows.
    """
    if len(tokens) == 0:
        return np.zeros(dim, dtype=np.float32)
    return np.mean([keyed_vectors[tok] for tok in tokens], axis=0)


n_iter = 100
w2v_dfs = []
for _df, _df_name in zip(
        [id_df, doi_df, submitter_df, categories_df, authors_df, id_doi_df, categories_space_df, categories_comma_df, id_doi_categories_space_df],
        ['id', 'doi', 'submitter', 'categories', 'authors', 'id_doi', 'categories_space', 'categories_comma', 'id_doi_categories_space']
    ):

    with timer(f"Creating w2v for {_df_name}"):
        # Train Word2Vec on this source's token sequences.
        # NOTE(review): some `target` columns appear to hold plain strings
        # rather than lists; iterating those yields single characters.
        # Confirm character-level tokens are intended for them.
        w2v_model = word2vec.Word2Vec(_df['target'].values.tolist(),
                                    size=model_size[_df_name],
                                    min_count=1,
                                    window=100,
                                    workers=1,
                                    iter=n_iter)

    with timer(f"Getting document vector for {_df_name}"):
        # Convert every token of a row to its vector and average them into
        # one document vector per row.
        sentence_vectors = _df['target'].progress_apply(
            lambda x: _mean_token_vector(x, w2v_model.wv, model_size[_df_name])
        )
        sentence_vectors = np.vstack([x for x in sentence_vectors])
        sentence_vector_df = pd.DataFrame(sentence_vectors,
                                        columns=[f"{_df_name}_w2v_{i}" for i in range(model_size[_df_name])])
        # Index by id so the table can later be joined on 'id'.
        sentence_vector_df.index = _df['id']
        w2v_dfs.append(sentence_vector_df)

[Creating w2v for id] start


  5%|▍         | 3655/74201 [00:00<00:01, 36538.54it/s]

[Creating w2v for id] done in 17 s
[Getting document vector for id] start


100%|██████████| 74201/74201 [00:02<00:00, 34877.99it/s]


[Getting document vector for id] done in 2 s
[Creating w2v for doi] start


  4%|▍         | 3128/74201 [00:00<00:02, 31276.38it/s]

[Creating w2v for doi] done in 17 s
[Getting document vector for doi] start


100%|██████████| 74201/74201 [00:02<00:00, 30969.09it/s]


[Getting document vector for doi] done in 3 s
[Creating w2v for submitter] start


  3%|▎         | 2329/74201 [00:00<00:03, 23283.14it/s]

[Creating w2v for submitter] done in 48 s
[Getting document vector for submitter] start


100%|██████████| 74201/74201 [00:03<00:00, 21896.18it/s]


[Getting document vector for submitter] done in 4 s
[Creating w2v for categories] start


  3%|▎         | 2021/74201 [00:00<00:03, 20203.85it/s]

[Creating w2v for categories] done in 51 s
[Getting document vector for categories] start


100%|██████████| 74201/74201 [00:03<00:00, 20206.73it/s]


[Getting document vector for categories] done in 4 s
[Creating w2v for authors] start


  4%|▍         | 3029/74201 [00:00<00:02, 30282.81it/s]

[Creating w2v for authors] done in 226 s
[Getting document vector for authors] start


100%|██████████| 74201/74201 [00:02<00:00, 28795.65it/s]


[Getting document vector for authors] done in 3 s
[Creating w2v for id_doi] start


  7%|▋         | 4836/74201 [00:00<00:01, 48359.01it/s]

[Creating w2v for id_doi] done in 17 s
[Getting document vector for id_doi] start


100%|██████████| 74201/74201 [00:01<00:00, 42778.56it/s]


[Getting document vector for id_doi] done in 2 s
[Creating w2v for categories_space] start


  7%|▋         | 5437/74201 [00:00<00:01, 54364.61it/s]

[Creating w2v for categories_space] done in 13 s
[Getting document vector for categories_space] start


100%|██████████| 74201/74201 [00:02<00:00, 34828.91it/s]


[Getting document vector for categories_space] done in 2 s
[Creating w2v for categories_comma] start


  7%|▋         | 5056/74201 [00:00<00:01, 50551.13it/s]

[Creating w2v for categories_comma] done in 16 s
[Getting document vector for categories_comma] start


100%|██████████| 74201/74201 [00:01<00:00, 44789.80it/s]


[Getting document vector for categories_comma] done in 2 s
[Creating w2v for id_doi_categories_space] start


  6%|▌         | 4459/74201 [00:00<00:01, 44585.68it/s]

[Creating w2v for id_doi_categories_space] done in 34 s
[Getting document vector for id_doi_categories_space] start


100%|██████████| 74201/74201 [00:01<00:00, 42353.09it/s]


[Getting document vector for id_doi_categories_space] done in 2 s


In [None]:
# Left-join every per-source Word2Vec feature table onto the id column.
# Iterating the list itself (instead of a fixed count) keeps this cell
# correct if sources are added or removed.
w2v = df[['id']]
for _w2v_feat_df in w2v_dfs:
    w2v = pd.merge(w2v, _w2v_feat_df, on='id', how='left')

In [None]:
# Write the merged Word2Vec features to OUTDIR as CSV, without the DataFrame index.
w2v.to_csv(OUTDIR / 'w2v.csv', index=False)

## fastText features

fastTextの学習済モデルを用いた特徴量作成

In [None]:
# Load the fastText model from a path relative to the current working directory.
# (Presumably the pretrained 300-dimensional English vectors, going by the
# file name — confirm.)
model_en = load_model("cc.en.300.bin")



In [None]:
def text_cleaning(raw_text):
    """Apply a fixed texthero preprocessing pipeline to `raw_text`.

    Args:
        raw_text: The text to clean, passed straight to `hero.clean`.

    Returns:
        Whatever `hero.clean` returns for the pipeline below.
    """
    pre = hero.preprocessing
    steps = [
        pre.fillna,              # fill missing values
        pre.lowercase,           # normalise to lower case
        pre.remove_digits,       # remove digits
        pre.remove_punctuation,  # remove punctuation
        pre.remove_diacritics,   # remove diacritical marks (e.g. à, é)
        pre.remove_stopwords,    # remove stop words
        pre.remove_whitespace,   # remove spaces
        pre.remove_brackets,
    ]
    return hero.clean(raw_text, pipeline=steps)

In [None]:
# Overwrite each raw text column with its cleaned version.
for _text_col in ('abstract', 'title', 'comments'):
    df[_text_col] = text_cleaning(df[_text_col])

In [None]:
# Both output tables start from the id column alone; feature columns are
# left-joined onto them by id afterwards.
svd_fastText = df["id"].to_frame()
pca_fastText = df["id"].to_frame()
svd_fastText

Unnamed: 0,id
0,1403.7138
1,1405.5857
2,1807.01034
3,astro-ph/9908243
4,hep-ph/0103252
...,...
74196,1210.4112
74197,1701.03465
74198,1709.10428
74199,gr-qc/9803020


### TruncatedSVD

In [None]:
with timer("SVD abstract fastText"):
    # One fastText sentence vector per abstract (newlines stripped first).
    X = np.stack(
        df["abstract"]
        .progress_apply(lambda text: model_en.get_sentence_vector(text.replace("\n", "")))
        .values
    )

    # Reduce to 64 components.
    svd = TruncatedSVD(n_components=64, random_state=SEED)
    X = svd.fit_transform(X)

    # Name the components, attach the ids and left-join onto the output table.
    X_df = pd.DataFrame(X).add_prefix("SVD_abstract_fastText_")
    X_df["id"] = df["id"]
    svd_fastText = pd.merge(svd_fastText, X_df, how="left", on="id")

  1%|          | 572/74201 [00:00<00:26, 2770.34it/s]

[SVD abstract fastText] start


100%|██████████| 74201/74201 [00:23<00:00, 3120.89it/s]


[SVD abstract fastText] done in 26 s


In [None]:
with timer("SVD title fastText"):
    # One fastText sentence vector per title (newlines stripped first).
    X = np.stack(
        df["title"]
        .progress_apply(lambda text: model_en.get_sentence_vector(text.replace("\n", "")))
        .values
    )

    # Reduce to 32 components.
    svd = TruncatedSVD(n_components=32, random_state=SEED)
    X = svd.fit_transform(X)

    # Name the components, attach the ids and left-join onto the output table.
    X_df = pd.DataFrame(X).add_prefix("SVD_title_fastText_")
    X_df["id"] = df["id"]
    svd_fastText = pd.merge(svd_fastText, X_df, how="left", on="id")

  6%|▌         | 4468/74201 [00:00<00:03, 23236.84it/s]

[SVD title fastText] start


100%|██████████| 74201/74201 [00:03<00:00, 19086.81it/s]


[SVD title fastText] done in 5 s


In [None]:
with timer("SVD comments fastText"):
    # One fastText sentence vector per comments field (newlines stripped first).
    X = np.stack(
        df["comments"]
        .progress_apply(lambda text: model_en.get_sentence_vector(text.replace("\n", "")))
        .values
    )

    # Reduce to 16 components.
    svd = TruncatedSVD(n_components=16, random_state=SEED)
    X = svd.fit_transform(X)

    # Name the components, attach the ids and left-join onto the output table.
    X_df = pd.DataFrame(X).add_prefix("SVD_comments_fastText_")
    X_df["id"] = df["id"]
    svd_fastText = pd.merge(svd_fastText, X_df, how="left", on="id")

  3%|▎         | 2149/74201 [00:00<00:03, 21476.96it/s]

[SVD comments fastText] start


100%|██████████| 74201/74201 [00:02<00:00, 29335.04it/s]


[SVD comments fastText] done in 4 s


In [None]:
svd_fastText

Unnamed: 0,id,SVD_abstract_fastText_0,SVD_abstract_fastText_1,SVD_abstract_fastText_2,SVD_abstract_fastText_3,SVD_abstract_fastText_4,SVD_abstract_fastText_5,SVD_abstract_fastText_6,SVD_abstract_fastText_7,SVD_abstract_fastText_8,SVD_abstract_fastText_9,SVD_abstract_fastText_10,SVD_abstract_fastText_11,SVD_abstract_fastText_12,SVD_abstract_fastText_13,SVD_abstract_fastText_14,SVD_abstract_fastText_15,SVD_abstract_fastText_16,SVD_abstract_fastText_17,SVD_abstract_fastText_18,SVD_abstract_fastText_19,SVD_abstract_fastText_20,SVD_abstract_fastText_21,SVD_abstract_fastText_22,SVD_abstract_fastText_23,SVD_abstract_fastText_24,SVD_abstract_fastText_25,SVD_abstract_fastText_26,SVD_abstract_fastText_27,SVD_abstract_fastText_28,SVD_abstract_fastText_29,SVD_abstract_fastText_30,SVD_abstract_fastText_31,SVD_abstract_fastText_32,SVD_abstract_fastText_33,SVD_abstract_fastText_34,SVD_abstract_fastText_35,SVD_abstract_fastText_36,SVD_abstract_fastText_37,SVD_abstract_fastText_38,...,SVD_title_fastText_8,SVD_title_fastText_9,SVD_title_fastText_10,SVD_title_fastText_11,SVD_title_fastText_12,SVD_title_fastText_13,SVD_title_fastText_14,SVD_title_fastText_15,SVD_title_fastText_16,SVD_title_fastText_17,SVD_title_fastText_18,SVD_title_fastText_19,SVD_title_fastText_20,SVD_title_fastText_21,SVD_title_fastText_22,SVD_title_fastText_23,SVD_title_fastText_24,SVD_title_fastText_25,SVD_title_fastText_26,SVD_title_fastText_27,SVD_title_fastText_28,SVD_title_fastText_29,SVD_title_fastText_30,SVD_title_fastText_31,SVD_comments_fastText_0,SVD_comments_fastText_1,SVD_comments_fastText_2,SVD_comments_fastText_3,SVD_comments_fastText_4,SVD_comments_fastText_5,SVD_comments_fastText_6,SVD_comments_fastText_7,SVD_comments_fastText_8,SVD_comments_fastText_9,SVD_comments_fastText_10,SVD_comments_fastText_11,SVD_comments_fastText_12,SVD_comments_fastText_13,SVD_comments_fastText_14,SVD_comments_fastText_15
0,1403.7138,0.363249,0.036163,-0.057263,-0.004489,0.012805,-0.059323,-0.028795,-0.013458,-0.046762,0.015787,0.009254,0.057568,-0.016350,-0.032146,0.022353,-0.029111,0.002450,0.006109,0.008820,-0.045247,-0.000790,-0.007173,-0.000417,0.022783,0.010546,0.004842,-0.000400,0.005816,0.009108,0.001712,0.011019,0.002084,0.007094,-0.005201,0.001718,0.019941,0.034087,0.002584,0.000305,...,0.027592,0.000116,-0.025004,-0.075135,-0.034174,-0.007044,-0.016533,0.000785,-0.032443,-0.055686,0.039786,0.034201,-0.006243,0.001384,-0.039855,-0.053504,-0.008725,-0.026833,-0.001316,0.017851,-0.031920,-0.026893,0.023250,0.008686,0.463168,0.174357,0.018888,0.174098,-0.039492,-0.018320,-0.015919,0.026770,-0.016106,0.048573,-0.044592,0.006404,0.052530,0.079725,0.011259,0.018703
1,1405.5857,0.349707,0.036987,-0.007622,0.033312,-0.007483,0.033859,0.043349,-0.009825,0.014303,0.011604,0.010482,0.007153,-0.025723,0.018025,-0.008113,0.001240,-0.013418,-0.006776,0.010874,0.011695,0.012848,0.014915,-0.003451,0.003813,-0.005019,-0.009489,0.004690,-0.004948,-0.014124,0.005673,-0.000462,-0.001560,0.000566,-0.010549,0.017704,-0.000235,0.027379,-0.008552,-0.000534,...,-0.004573,0.009838,0.006643,-0.026310,0.015480,-0.004861,-0.015985,0.008354,0.011801,0.024794,0.003222,0.053409,0.002897,-0.038642,0.009655,-0.014963,-0.026050,0.032809,0.016913,-0.004278,-0.032901,0.017529,0.032499,0.010253,0.266957,0.204622,0.039873,-0.095879,-0.001073,0.060874,-0.052167,-0.006817,-0.016820,-0.010484,-0.013713,0.035832,0.012730,-0.027548,0.010707,-0.018787
2,1807.01034,0.362031,-0.002213,-0.035914,-0.016621,-0.020911,0.011678,0.011961,-0.008798,0.004393,-0.017695,-0.000504,-0.021193,0.026756,-0.012224,0.007477,0.020000,0.019287,0.025180,0.005475,0.000524,0.019232,-0.006324,0.000448,-0.022061,-0.009714,-0.006328,-0.039916,-0.003471,-0.000698,-0.006873,0.003015,0.000718,-0.009903,0.027684,0.005380,-0.005478,-0.004694,-0.007034,0.001672,...,0.044675,-0.014755,-0.014650,0.040130,0.053732,-0.035711,-0.016279,-0.026685,-0.029782,-0.046798,-0.042801,0.033379,-0.008264,0.083723,-0.079690,0.024421,-0.008615,0.093962,-0.012229,0.018197,-0.000186,-0.008751,0.022243,0.010486,0.660610,-0.140529,0.734568,0.011401,0.024995,-0.015010,0.018599,0.010156,-0.003934,-0.030926,-0.030800,-0.000739,0.020555,-0.002467,0.001499,-0.003411
3,astro-ph/9908243,0.339835,0.062082,-0.028124,0.055019,0.004731,-0.001190,-0.006085,-0.018830,0.027612,-0.002322,0.001375,0.011555,0.010688,-0.003520,-0.004322,-0.009801,-0.001265,0.027869,0.002879,-0.002030,-0.013345,0.009664,-0.001662,0.032070,-0.030826,0.005684,0.044412,0.008662,0.003435,-0.013374,0.001728,-0.049111,-0.005913,-0.000974,0.008223,-0.001653,-0.007390,0.013300,-0.006665,...,-0.044306,0.002913,-0.006779,-0.028796,0.026989,-0.024975,0.002021,0.004989,-0.018964,-0.008060,0.006756,0.040405,0.032710,-0.000859,0.061759,0.011202,0.003182,-0.001428,-0.005056,-0.051107,0.009577,-0.029309,-0.051310,0.023640,0.218959,0.232689,0.026076,-0.064775,-0.015977,0.022838,-0.037940,0.015046,0.023376,0.035361,-0.049240,-0.006938,-0.048712,-0.044078,-0.026303,0.022692
4,hep-ph/0103252,0.385218,0.010888,-0.013582,-0.017762,0.022672,0.015428,0.037803,-0.026076,-0.007963,-0.002703,-0.023248,0.007276,0.008429,0.026948,0.003904,0.012716,0.007105,-0.000488,0.032834,-0.011182,0.027872,0.007437,0.023449,0.030629,-0.003584,0.026192,0.013416,0.019120,0.000578,0.017370,0.010743,0.021854,0.021295,-0.020993,0.020946,0.011955,0.026938,0.008775,-0.010776,...,0.063160,0.130272,0.039636,-0.009426,-0.026935,0.019737,-0.084880,0.045116,-0.017734,0.037710,-0.019074,0.075588,-0.005836,0.026587,-0.001543,-0.009520,0.030540,0.006504,0.030035,-0.005267,0.025228,0.009073,0.026211,0.005693,0.337722,0.187405,0.001712,-0.020637,0.020054,0.057137,0.137323,-0.017663,-0.029259,0.010441,-0.031804,-0.026501,0.001716,0.002044,0.034518,0.024360
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
74196,1210.4112,0.332029,0.121272,0.055344,0.031685,-0.052997,0.060249,0.013624,0.000941,-0.019317,-0.020356,-0.012135,-0.011987,0.010982,0.010924,-0.008637,-0.038228,0.015540,0.019175,-0.006177,0.011759,-0.001497,0.035805,-0.030017,-0.026485,0.003092,0.003280,-0.033815,-0.012360,-0.001248,0.023245,0.012355,0.007253,0.015515,-0.008260,-0.008013,0.003343,-0.003819,0.002208,0.007274,...,0.054396,0.021934,0.028995,0.041716,0.011202,0.051186,-0.024808,-0.028759,0.089410,0.074944,-0.010017,0.030061,0.040673,0.025043,0.014938,-0.034774,-0.063904,0.033221,0.007855,0.031172,-0.012481,0.013914,0.012028,0.013282,0.745196,-0.182449,0.113219,0.018035,-0.001024,-0.000637,0.003108,-0.005618,-0.001643,-0.018248,-0.013345,-0.006522,0.003373,-0.005904,-0.000103,-0.003856
74197,1701.03465,0.283698,0.100649,-0.028159,0.100974,0.062808,-0.013002,-0.040827,-0.027231,0.019650,0.010576,0.005820,-0.011478,0.005533,-0.005890,0.041349,-0.002061,-0.019285,-0.023930,-0.026760,0.001512,0.015251,-0.011197,-0.010344,-0.024415,0.024906,-0.004109,-0.005032,-0.003520,0.013490,-0.004749,-0.017725,0.018083,0.002114,0.012739,0.000317,-0.021826,-0.007042,-0.009837,0.000651,...,-0.015571,-0.012364,-0.075102,-0.034222,0.012408,-0.012016,0.009450,-0.047186,-0.042244,-0.033538,-0.005990,-0.063990,0.005715,0.013026,-0.008306,-0.002378,-0.027336,-0.033787,-0.052031,0.030812,0.047150,-0.004337,0.013316,-0.000704,0.434565,0.162231,0.025861,0.134348,0.011715,-0.061259,-0.022461,-0.008251,-0.018422,0.014246,0.089114,0.027914,-0.025899,0.025244,-0.003881,0.072035
74198,1709.10428,0.345082,0.011044,0.058631,0.007143,-0.006866,-0.028062,0.025711,0.031323,-0.006512,-0.009441,-0.016158,-0.014673,-0.008513,-0.020917,0.018354,-0.015482,0.016254,-0.022182,-0.015883,-0.002990,0.016325,0.008899,0.004145,-0.025713,-0.006047,0.000892,-0.014928,-0.026211,-0.002191,-0.010457,-0.001126,-0.012107,-0.006804,-0.002103,0.002672,-0.005617,-0.015833,0.006143,-0.006326,...,0.015099,0.010050,-0.061346,0.000390,0.050218,0.029616,0.017998,-0.017461,0.006640,0.108778,0.050959,-0.027884,-0.060860,0.016973,-0.034339,0.001221,-0.036968,0.003406,0.051482,-0.020302,0.003457,0.010011,-0.008570,0.042196,0.829783,-0.224368,-0.508128,0.024676,-0.027046,0.013727,-0.012383,-0.021396,0.000630,-0.005592,0.004120,-0.012309,-0.013811,-0.009343,-0.001697,-0.004287
74199,gr-qc/9803020,0.360536,-0.063012,0.052926,-0.017128,0.095782,-0.011085,0.009809,0.019053,-0.008621,0.010012,0.009454,-0.038591,-0.041560,-0.021131,0.016149,0.002126,-0.011822,-0.007633,0.026703,0.010361,-0.043425,0.012003,-0.039361,0.030072,-0.001832,-0.001932,-0.023172,0.016034,-0.031370,0.018884,-0.002024,0.009447,-0.027012,-0.006684,0.011875,0.012547,-0.021343,0.026424,0.001380,...,-0.005170,-0.058724,-0.071851,-0.060252,-0.036436,-0.063686,0.016507,-0.099997,-0.020394,-0.034679,-0.013692,-0.014983,-0.012698,-0.003540,0.063365,-0.019162,-0.080282,-0.057334,0.027765,0.010988,-0.021594,-0.035033,0.005631,0.024332,0.337867,0.174960,0.071900,0.047363,0.009070,0.071994,-0.038819,0.037083,0.035533,-0.002086,-0.009043,-0.064786,0.011254,-0.058656,-0.014876,-0.033238


In [None]:
# Write the SVD-reduced fastText features to OUTDIR as CSV, without the DataFrame index.
svd_fastText.to_csv(OUTDIR / 'svd_fastText.csv', index=False)

### PCA

In [None]:
with timer("PCA abstract fastText"):
    # One fastText sentence vector per abstract (newlines stripped first).
    X = np.stack(
        df["abstract"]
        .progress_apply(lambda text: model_en.get_sentence_vector(text.replace("\n", "")))
        .values
    )

    # Project onto 64 principal components.
    pca = PCA(n_components=64, random_state=SEED)
    X = pca.fit_transform(X)

    # Name the components, attach the ids and left-join onto the output table.
    X_df = pd.DataFrame(X).add_prefix("PCA_abstract_fastText_")
    X_df["id"] = df["id"]
    pca_fastText = pd.merge(pca_fastText, X_df, how="left", on="id")

  1%|          | 606/74201 [00:00<00:24, 3019.93it/s]

[PCA abstract fastText] start


100%|██████████| 74201/74201 [00:24<00:00, 3052.77it/s]


[PCA abstract fastText] done in 26 s


In [None]:
with timer("PCA title fastText"):
    # One fastText sentence vector per title (newlines stripped first).
    X = np.stack(
        df["title"]
        .progress_apply(lambda text: model_en.get_sentence_vector(text.replace("\n", "")))
        .values
    )

    # Project onto 32 principal components.
    pca = PCA(n_components=32, random_state=SEED)
    X = pca.fit_transform(X)

    # Name the components, attach the ids and left-join onto the output table.
    X_df = pd.DataFrame(X).add_prefix("PCA_title_fastText_")
    X_df["id"] = df["id"]
    pca_fastText = pd.merge(pca_fastText, X_df, how="left", on="id")

  4%|▍         | 3130/74201 [00:00<00:04, 15375.39it/s]

[PCA title fastText] start


100%|██████████| 74201/74201 [00:03<00:00, 20311.73it/s]


[PCA title fastText] done in 5 s


In [None]:
with timer("PCA comments fastText"):
    # One fastText sentence vector per comments field (newlines stripped first).
    X = np.stack(
        df["comments"]
        .progress_apply(lambda text: model_en.get_sentence_vector(text.replace("\n", "")))
        .values
    )

    # Project onto 16 principal components.
    pca = PCA(n_components=16, random_state=SEED)
    X = pca.fit_transform(X)

    # Name the components, attach the ids and left-join onto the output table.
    X_df = pd.DataFrame(X).add_prefix("PCA_comments_fastText_")
    X_df["id"] = df["id"]
    pca_fastText = pd.merge(pca_fastText, X_df, how="left", on="id")

  4%|▍         | 3331/74201 [00:00<00:02, 33304.87it/s]

[PCA comments fastText] start


100%|██████████| 74201/74201 [00:02<00:00, 29120.12it/s]


[PCA comments fastText] done in 4 s


In [None]:
pca_fastText

Unnamed: 0,id,PCA_abstract_fastText_0,PCA_abstract_fastText_1,PCA_abstract_fastText_2,PCA_abstract_fastText_3,PCA_abstract_fastText_4,PCA_abstract_fastText_5,PCA_abstract_fastText_6,PCA_abstract_fastText_7,PCA_abstract_fastText_8,PCA_abstract_fastText_9,PCA_abstract_fastText_10,PCA_abstract_fastText_11,PCA_abstract_fastText_12,PCA_abstract_fastText_13,PCA_abstract_fastText_14,PCA_abstract_fastText_15,PCA_abstract_fastText_16,PCA_abstract_fastText_17,PCA_abstract_fastText_18,PCA_abstract_fastText_19,PCA_abstract_fastText_20,PCA_abstract_fastText_21,PCA_abstract_fastText_22,PCA_abstract_fastText_23,PCA_abstract_fastText_24,PCA_abstract_fastText_25,PCA_abstract_fastText_26,PCA_abstract_fastText_27,PCA_abstract_fastText_28,PCA_abstract_fastText_29,PCA_abstract_fastText_30,PCA_abstract_fastText_31,PCA_abstract_fastText_32,PCA_abstract_fastText_33,PCA_abstract_fastText_34,PCA_abstract_fastText_35,PCA_abstract_fastText_36,PCA_abstract_fastText_37,PCA_abstract_fastText_38,...,PCA_title_fastText_8,PCA_title_fastText_9,PCA_title_fastText_10,PCA_title_fastText_11,PCA_title_fastText_12,PCA_title_fastText_13,PCA_title_fastText_14,PCA_title_fastText_15,PCA_title_fastText_16,PCA_title_fastText_17,PCA_title_fastText_18,PCA_title_fastText_19,PCA_title_fastText_20,PCA_title_fastText_21,PCA_title_fastText_22,PCA_title_fastText_23,PCA_title_fastText_24,PCA_title_fastText_25,PCA_title_fastText_26,PCA_title_fastText_27,PCA_title_fastText_28,PCA_title_fastText_29,PCA_title_fastText_30,PCA_title_fastText_31,PCA_comments_fastText_0,PCA_comments_fastText_1,PCA_comments_fastText_2,PCA_comments_fastText_3,PCA_comments_fastText_4,PCA_comments_fastText_5,PCA_comments_fastText_6,PCA_comments_fastText_7,PCA_comments_fastText_8,PCA_comments_fastText_9,PCA_comments_fastText_10,PCA_comments_fastText_11,PCA_comments_fastText_12,PCA_comments_fastText_13,PCA_comments_fastText_14,PCA_comments_fastText_15
0,1403.7138,0.022213,-0.063580,-0.006930,0.012005,-0.060646,0.007012,-0.026887,-0.047010,0.013938,0.007185,0.059024,-0.014033,-0.032295,0.020724,0.016359,-0.007507,-0.018424,0.048894,-0.009012,0.007991,-0.002842,-0.006262,0.022443,0.005798,-0.002985,0.005018,0.006741,0.009631,0.000006,0.010181,0.003141,0.006641,-0.000299,0.006594,-0.019092,0.035545,0.001255,0.000645,0.009431,...,-0.015427,-0.056469,0.057619,-0.026314,0.027284,-0.026489,0.000808,0.032853,-0.054654,0.042696,-0.026532,0.008023,-0.001114,0.051280,-0.054159,-0.009021,-0.021708,-0.001483,0.020374,0.033468,-0.016812,0.027175,-0.000312,-0.008590,0.002440,0.005865,0.166830,-0.145452,-0.045220,-0.020820,-0.014590,0.026754,-0.015489,0.048165,-0.046970,0.005813,0.053057,0.078501,0.011781,0.018191
1,1405.5857,0.037169,-0.012214,0.024233,-0.002867,0.044837,-0.036380,0.003958,0.016378,0.013121,0.009504,0.008721,-0.025580,0.018166,-0.006840,0.001289,-0.011708,0.001809,-0.014148,0.015613,0.011453,0.013115,-0.001125,0.003983,-0.003748,-0.002661,-0.009839,-0.005478,-0.013711,0.006578,0.001089,-0.001652,0.000873,-0.007856,0.019372,0.000215,0.024720,-0.011748,-0.004984,-0.005072,...,0.007250,-0.004720,0.024331,0.017823,0.004571,-0.013848,0.013540,-0.011610,0.018792,-0.006821,-0.054149,0.016559,-0.037320,-0.005688,-0.022170,-0.018319,0.035826,0.017803,-0.000399,0.030720,0.019443,0.032302,0.008010,0.000793,-0.193191,0.032933,0.045226,0.093213,0.010066,0.060231,-0.054515,-0.006676,-0.017124,-0.009652,-0.014209,0.035692,0.012673,-0.028530,0.010638,-0.017865
2,1807.01034,-0.012415,-0.036621,-0.013441,-0.021336,0.014132,-0.009251,-0.003328,0.005234,-0.011307,0.002026,-0.023612,0.026777,-0.013517,0.008134,-0.015343,0.025941,0.026618,0.001923,0.006342,0.019214,-0.006920,-0.003280,-0.022798,-0.016680,-0.020302,0.030592,-0.004742,-0.002571,-0.008502,0.001463,0.000861,-0.010143,0.026225,-0.000043,0.004637,-0.006892,-0.007443,0.000335,-0.002657,...,-0.005449,0.008538,-0.047452,0.048375,-0.032910,-0.039609,-0.020476,0.027560,-0.048890,-0.039564,-0.015069,0.005120,0.084489,0.077087,0.029557,-0.003029,0.090141,-0.015419,0.022399,-0.002227,-0.003775,0.017506,0.011322,0.044233,0.314680,0.737771,0.002718,-0.017067,0.021739,-0.013355,0.018708,0.010159,-0.004158,-0.031874,-0.029754,-0.000814,0.020570,-0.002682,0.001592,-0.003387
3,astro-ph/9908243,0.063824,-0.034454,0.045905,0.008925,0.002876,-0.003132,-0.022353,0.027110,-0.006074,0.001075,0.011230,0.011012,-0.003354,-0.002625,0.010289,-0.000633,0.019072,0.020736,0.000325,-0.012990,0.010021,0.001100,0.032292,-0.019690,0.034485,-0.038456,0.008658,0.001907,-0.013874,-0.001113,-0.047985,-0.006718,0.001945,0.008487,0.001345,-0.007518,0.013335,-0.005626,0.008025,...,-0.002260,-0.016145,0.019994,0.028147,-0.010525,-0.014325,-0.004506,0.022570,-0.014562,0.001146,-0.040068,0.030250,-0.001089,-0.067057,0.007869,-0.002898,-0.002806,-0.005437,-0.054022,-0.000815,-0.034964,-0.047748,0.025213,0.018508,-0.247330,0.018235,0.055072,0.065864,-0.010199,0.019676,-0.037559,0.014991,0.024053,0.033058,-0.049399,-0.007251,-0.048892,-0.044344,-0.026545,0.023386
4,hep-ph/0103252,-0.005219,-0.019715,-0.030220,0.025976,0.027397,-0.039935,-0.015773,-0.006468,-0.001625,-0.022875,0.006812,0.008800,0.027447,0.000959,-0.013674,0.006386,-0.006198,0.017207,0.027181,0.030258,0.006336,0.023616,0.029941,-0.001981,0.024575,0.002844,0.020066,0.003694,0.016758,0.013131,0.022375,0.021391,-0.015081,0.025005,-0.011697,0.025768,0.004716,-0.019043,-0.014855,...,0.137253,0.027379,0.029302,-0.024237,0.029705,-0.041602,0.082604,0.016554,0.031526,-0.029509,-0.066031,0.012187,0.025301,0.002903,-0.010786,0.026716,-0.004158,0.020574,-0.008599,-0.022767,0.012156,0.020912,0.014718,-0.043028,-0.120902,-0.006444,0.071370,0.020597,0.028079,0.063888,0.134639,-0.017748,-0.029060,0.010061,-0.032260,-0.026775,0.001951,0.001447,0.034119,0.024854
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
74196,1210.4112,0.130983,0.036393,0.006982,-0.046848,0.067428,-0.006147,0.001885,-0.020630,-0.018876,-0.010033,-0.013455,0.011334,0.008832,0.005432,0.043775,0.016438,0.017253,-0.005475,-0.001198,-0.002055,0.034794,-0.025016,-0.025295,0.006973,-0.005768,0.028270,-0.013483,-0.002527,0.014461,0.021070,0.005181,0.016777,-0.015079,-0.009854,-0.003396,-0.003698,0.002014,0.004421,-0.005632,...,0.034882,0.048157,-0.031349,0.012993,0.029915,0.003565,0.004571,-0.098820,0.077429,-0.015224,-0.020312,0.052180,0.028111,-0.010470,-0.042808,-0.048209,0.043831,0.012214,0.036885,0.004971,0.008894,0.011622,0.015122,-0.024788,0.414773,0.118569,-0.036596,-0.019291,-0.006362,-0.001210,0.003351,-0.005634,-0.001755,-0.018884,-0.012550,-0.006527,0.003338,-0.005900,-0.000016,-0.003833
74197,1701.03465,0.121059,-0.034027,0.097275,0.065330,-0.019411,0.030518,-0.035418,0.021232,0.012663,0.005454,-0.011955,0.005183,-0.007439,0.039145,-0.009613,-0.022331,-0.015674,-0.023142,-0.023121,0.013745,-0.012321,-0.009189,-0.023802,0.022940,-0.012696,0.004299,-0.002881,0.015118,-0.000522,-0.018839,0.017515,0.002288,0.011430,-0.001818,0.021538,-0.008591,-0.009496,0.001036,-0.005098,...,-0.010966,-0.070571,-0.017619,0.018337,0.017078,-0.025670,-0.050636,0.039319,-0.025016,0.006861,0.067564,-0.006837,0.011767,0.007920,-0.000208,-0.028969,-0.025235,-0.047796,0.026990,-0.047666,-0.004877,0.014926,-0.007070,0.007740,-0.019066,0.015197,0.129274,-0.120491,0.009047,-0.060210,-0.021295,-0.008196,-0.018592,0.017335,0.088243,0.028354,-0.026159,0.027123,-0.004073,0.070981
74198,1709.10428,0.021074,0.056537,0.001111,-0.006540,-0.024038,-0.025086,0.033584,-0.008491,-0.011697,-0.015349,-0.014671,-0.008763,-0.022115,0.020111,0.009864,0.011072,-0.024216,-0.011496,-0.014870,0.016574,0.007645,0.004258,-0.025929,-0.006196,-0.002462,0.012158,-0.027001,-0.003570,-0.010692,-0.003091,-0.011867,-0.006871,-0.001610,0.002549,0.005043,-0.015465,0.006692,-0.004368,-0.003580,...,0.009611,-0.052942,-0.030488,0.051974,0.012558,0.025420,-0.017121,-0.012336,0.112897,0.041975,-0.001444,-0.066923,0.015337,0.038146,-0.002683,-0.039042,0.009157,0.053672,-0.022601,0.000820,0.010822,-0.007057,0.038942,0.012406,0.514865,-0.500632,-0.075906,-0.021514,-0.034459,0.010940,-0.012002,-0.021416,0.000636,-0.005886,0.004664,-0.012226,-0.013893,-0.009116,-0.001636,-0.004305
74199,gr-qc/9803020,-0.057975,0.060791,-0.014519,0.092623,-0.017144,0.000813,0.026219,-0.007001,0.015320,0.009363,-0.036752,-0.043098,-0.022121,0.015134,-0.005955,-0.011719,-0.002170,-0.006290,0.028658,-0.041225,0.016289,-0.042347,0.030529,-0.004305,-0.010995,0.018762,0.015142,-0.029431,0.022162,0.001625,0.008825,-0.026697,-0.006388,0.012009,-0.013180,-0.019662,0.027958,0.001491,-0.002639,...,-0.062238,-0.078414,0.007299,-0.029821,0.002190,-0.069073,-0.106558,0.017551,-0.031120,-0.008445,0.018239,-0.010587,-0.002988,-0.060018,-0.036412,-0.088684,-0.043994,0.033466,0.011699,0.020168,-0.040240,0.002969,0.023122,-0.003142,-0.113976,0.063982,0.077625,-0.045167,-0.002056,0.069951,-0.040821,0.037195,0.035424,-0.002563,-0.008381,-0.064806,0.011275,-0.059616,-0.014692,-0.031802


In [None]:
# Write the PCA-reduced fastText features to OUTDIR as CSV, without the DataFrame index.
pca_fastText.to_csv(OUTDIR / 'pca_fastText.csv', index=False)

## SciBERT

In [None]:
class BertSequenceVectorizer:
    """Encode a single text into one vector with a pretrained BERT model.

    The tokenizer and the model are both loaded from ``self.model_name``.
    The model is moved to the GPU when CUDA is available and stays on the
    CPU otherwise. Every input is truncated or padded to exactly
    ``self.max_len`` token ids before it is fed to the model.
    """

    def __init__(self):
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.model_name = 'scibert_scivocab_uncased'  # path where the model is saved
        self.tokenizer = BertTokenizer.from_pretrained(self.model_name)
        self.bert_model = transformers.BertModel.from_pretrained(self.model_name)
        self.bert_model = self.bert_model.to(self.device)
        self.max_len = 128
        print(self.device)

    def vectorize(self, sentence: str) -> np.ndarray:
        """Return the model's last hidden state at sequence position 0.

        Args:
            sentence: Text to encode.

        Returns:
            A 1-D NumPy array taken from ``last_hidden_state[0][0]``, i.e. the
            output for the first token of the (single-item) batch. With the
            tokenizer's default special tokens this is presumably the [CLS]
            position -- verify against the tokenizer configuration.
        """
        token_ids = self.tokenizer.encode(sentence)
        n_tokens = len(token_ids)

        if n_tokens >= self.max_len:
            # NOTE(review): plain slicing also cuts off whatever token ends an
            # over-long sequence (normally [SEP]). Kept as-is so that features
            # generated earlier with this class stay reproducible.
            inputs = token_ids[:self.max_len]
            masks = [1] * self.max_len
        else:
            n_pad = self.max_len - n_tokens
            # Pads with id 0 and masks the padded positions out.
            # Assumes 0 is the tokenizer's padding id -- TODO confirm.
            inputs = token_ids + [0] * n_pad
            masks = [1] * n_tokens + [0] * n_pad

        inputs_tensor = torch.tensor([inputs], dtype=torch.long).to(self.device)
        masks_tensor = torch.tensor([masks], dtype=torch.long).to(self.device)

        # Inference only: skip autograd bookkeeping so no graph is built and
        # kept alive for each of the many per-row calls.
        with torch.no_grad():
            bert_out = self.bert_model(inputs_tensor, masks_tensor)
        seq_out = bert_out['last_hidden_state']

        # .cpu() is a no-op for tensors already on the CPU, so one code path
        # serves both devices.
        return seq_out[0][0].detach().cpu().numpy()

# Single shared vectorizer instance; constructing it loads the tokenizer and
# model and prints the selected device.
BSV = BertSequenceVectorizer()

cuda


In [None]:
# Start the SciBERT feature table as a one-column frame holding only the ids;
# feature columns are merged onto it afterwards.
svd_SciBERT = df["id"].to_frame()
svd_SciBERT

Unnamed: 0,id
0,1403.7138
1,1405.5857
2,1807.01034
3,astro-ph/9908243
4,hep-ph/0103252
...,...
74196,1210.4112
74197,1701.03465
74198,1709.10428
74199,gr-qc/9803020


In [None]:
with timer("SVD abstract scibert"):
    # Embed every abstract one row at a time; missing abstracts are replaced
    # by the literal text 'nan' before being vectorized.
    bert = df['abstract'].fillna('nan').progress_apply(BSV.vectorize)
    # Reduce the stacked embeddings to 32 components.
    svd = TruncatedSVD(n_components=32, random_state=SEED+1)
    X = np.stack(bert.values)
    X = svd.fit_transform(X)
    X_df = pd.DataFrame(X, columns=[f"scibert_abstract_{i}" for i in range(X.shape[1])])
    # Rows of X are in df's row order, so attach the ids positionally.
    # Assigning the Series itself would align on index labels and silently
    # produce wrong/NaN ids whenever df does not have a default 0..n-1 index.
    X_df["id"] = df["id"].values
    svd_SciBERT = svd_SciBERT.merge(X_df, on="id", how="left")

  0%|          | 7/74201 [00:00<17:41, 69.88it/s]

[SVD abstract scibert] start


100%|██████████| 74201/74201 [19:26<00:00, 63.63it/s]


[SVD abstract scibert] done in 1170 s


In [None]:
with timer("SVD title scibert"):
    # Embed every title one row at a time; missing titles are replaced by the
    # literal text 'nan' before being vectorized.
    bert = df['title'].fillna('nan').progress_apply(BSV.vectorize)
    # Reduce the stacked embeddings to 16 components.
    svd = TruncatedSVD(n_components=16, random_state=SEED+1)
    X = np.stack(bert.values)
    X = svd.fit_transform(X)
    X_df = pd.DataFrame(X, columns=[f"scibert_title_{i}" for i in range(X.shape[1])])
    # Rows of X are in df's row order, so attach the ids positionally.
    # Assigning the Series itself would align on index labels and silently
    # produce wrong/NaN ids whenever df does not have a default 0..n-1 index.
    X_df["id"] = df["id"].values
    svd_SciBERT = svd_SciBERT.merge(X_df, on="id", how="left")

  0%|          | 10/74201 [00:00<13:28, 91.79it/s]

[SVD title scibert] start


100%|██████████| 74201/74201 [15:24<00:00, 80.28it/s]


[SVD title scibert] done in 927 s


In [None]:
# Write the id column plus the merged scibert_abstract_* / scibert_title_*
# columns to OUTDIR as CSV, without the DataFrame index.
svd_SciBERT.to_csv(OUTDIR / 'svd_SciBERT.csv', index=False)

## Tfidf features

精度が上がらなかったため不使用。

In [None]:
# Build one TF-IDF -> TruncatedSVD feature frame per text column and collect
# them in tfidf_dfs, in the same order as text_columns.
tfidf_dfs = []
VOCAB_SIZE = 10000       # max_features passed to the TF-IDF vectorizer
COMPONENT_SIZE = 50      # number of SVD components kept per text column
text_columns = ['title', 'abstract', 'comments']
for col in text_columns:
    with timer(f'working on {col}'):
        docs = df[col]

        # Word-level 1- to 3-grams, capped at VOCAB_SIZE features.
        tv = TfidfVectorizer(
            max_features=VOCAB_SIZE,
            analyzer='word',
            ngram_range=(1, 3),
        )
        svd = TruncatedSVD(n_components=COMPONENT_SIZE, random_state=SEED)

        # Fit TF-IDF on the column, then compress the matrix with SVD.
        X = tv.fit_transform(docs)
        X = svd.fit_transform(X)

        _df = pd.DataFrame(
            X,
            columns=[f'tfidf_{col}_{i}' for i in range(COMPONENT_SIZE)],
        )
        # Index the features by id so they can be joined on 'id' later.
        _df.index = df['id']
        tfidf_dfs.append(_df)

[working on title] start
[working on title] done in 5 s
[working on abstract] start
[working on abstract] done in 57 s
[working on comments] start
[working on comments] done in 4 s


In [None]:
# Display the first collected frame, i.e. the one built from text_columns[0].
tfidf_dfs[0]

Unnamed: 0_level_0,tfidf_title_0,tfidf_title_1,tfidf_title_2,tfidf_title_3,tfidf_title_4,tfidf_title_5,tfidf_title_6,tfidf_title_7,tfidf_title_8,tfidf_title_9,tfidf_title_10,tfidf_title_11,tfidf_title_12,tfidf_title_13,tfidf_title_14,tfidf_title_15,tfidf_title_16,tfidf_title_17,tfidf_title_18,tfidf_title_19,tfidf_title_20,tfidf_title_21,tfidf_title_22,tfidf_title_23,tfidf_title_24,tfidf_title_25,tfidf_title_26,tfidf_title_27,tfidf_title_28,tfidf_title_29,tfidf_title_30,tfidf_title_31,tfidf_title_32,tfidf_title_33,tfidf_title_34,tfidf_title_35,tfidf_title_36,tfidf_title_37,tfidf_title_38,tfidf_title_39,tfidf_title_40,tfidf_title_41,tfidf_title_42,tfidf_title_43,tfidf_title_44,tfidf_title_45,tfidf_title_46,tfidf_title_47,tfidf_title_48,tfidf_title_49
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1,Unnamed: 25_level_1,Unnamed: 26_level_1,Unnamed: 27_level_1,Unnamed: 28_level_1,Unnamed: 29_level_1,Unnamed: 30_level_1,Unnamed: 31_level_1,Unnamed: 32_level_1,Unnamed: 33_level_1,Unnamed: 34_level_1,Unnamed: 35_level_1,Unnamed: 36_level_1,Unnamed: 37_level_1,Unnamed: 38_level_1,Unnamed: 39_level_1,Unnamed: 40_level_1,Unnamed: 41_level_1,Unnamed: 42_level_1,Unnamed: 43_level_1,Unnamed: 44_level_1,Unnamed: 45_level_1,Unnamed: 46_level_1,Unnamed: 47_level_1,Unnamed: 48_level_1,Unnamed: 49_level_1,Unnamed: 50_level_1
1403.7138,0.022092,-0.025197,0.001876,0.001340,-0.013368,-0.015695,0.010623,-0.003666,-0.008180,0.002093,-0.002974,-0.005010,-0.004291,-0.006800,0.008022,-0.003574,-0.006543,-0.006811,0.001579,-0.001033,0.000433,0.003499,-0.004929,0.005418,-0.001621,-0.014007,-0.007102,-0.006330,0.004639,-0.000343,-0.000698,-0.000189,-0.006186,0.006236,0.000338,-0.002472,0.006124,-0.009179,-0.014351,-0.004119,0.023604,0.000761,-0.010376,-0.006981,0.034924,0.030098,0.010955,-0.007343,0.000830,0.009619
1405.5857,0.039071,-0.015600,-0.011777,-0.006100,-0.021955,-0.012472,0.010274,-0.007363,-0.019584,0.019338,-0.013201,-0.007193,-0.019394,0.005822,-0.022268,-0.013668,-0.049980,0.000151,-0.003409,0.005417,-0.050395,-0.009596,-0.000275,-0.024614,-0.005751,-0.023067,-0.009088,0.019793,-0.051109,0.011580,-0.003654,0.003815,-0.014845,0.017690,-0.014772,0.001448,0.022097,-0.011567,0.024491,0.013943,-0.038399,-0.036524,0.072319,0.105217,-0.001384,0.055054,-0.024756,0.036734,-0.005273,-0.006620
1807.01034,0.064398,0.010299,-0.002961,-0.016136,-0.036859,-0.002454,-0.014653,0.005606,-0.046992,0.046362,0.010179,0.003106,-0.088866,0.147413,-0.000095,0.138788,0.090244,0.230015,-0.054889,0.006077,0.026439,0.012836,0.062635,0.067700,0.048007,-0.008884,-0.007613,-0.037485,0.000729,-0.000334,0.017643,-0.002840,-0.019006,-0.023163,0.001837,-0.014909,0.029141,-0.011661,-0.026074,-0.021488,-0.003926,-0.020361,-0.005310,-0.008551,-0.001246,-0.011700,0.014528,0.044758,0.003956,-0.000280
astro-ph/9908243,0.041271,-0.055143,-0.018558,0.002658,-0.089187,-0.087120,0.054102,-0.042897,0.111121,-0.034166,0.061036,0.017554,-0.037726,0.006297,0.008291,-0.004310,0.002430,-0.007016,-0.007481,-0.000015,0.013666,0.008857,-0.012984,0.016916,0.022556,-0.010352,-0.000571,0.006515,0.009366,-0.004062,0.006338,0.013250,0.013342,-0.002487,-0.014131,0.002107,0.000227,0.004397,0.006032,-0.012862,0.000117,0.015252,-0.000039,0.010219,-0.004081,0.002097,0.000914,0.005600,-0.009347,-0.001482
hep-ph/0103252,0.078460,-0.115212,0.097084,0.010201,0.025180,0.029005,-0.027669,-0.020705,-0.004142,-0.004497,0.010110,0.013843,-0.009440,-0.044549,-0.022563,0.007759,-0.006818,0.012014,0.004066,0.037335,0.003853,0.011135,0.012955,-0.022339,0.018941,-0.022067,-0.008790,-0.031957,0.018606,0.056578,0.026778,-0.117538,-0.005887,-0.009796,0.000734,-0.040185,-0.050068,0.054078,-0.070434,0.016771,-0.036722,-0.019203,0.076261,-0.032048,0.026092,-0.000374,-0.054850,0.052380,-0.014402,-0.011968
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1210.4112,0.040924,-0.023408,-0.000949,-0.009484,-0.031344,-0.006087,0.014426,0.002487,0.037970,0.066051,-0.072117,-0.022755,0.056696,-0.110857,-0.104536,0.073635,0.040779,0.076894,0.065806,0.007308,-0.065504,0.054936,-0.010204,-0.007389,-0.083515,-0.016871,-0.077354,0.097635,0.009339,-0.023994,0.032001,-0.000033,0.026592,0.028432,-0.026142,-0.024227,-0.016424,0.000985,0.027325,0.023699,0.018981,0.000121,0.032671,-0.010991,0.003375,-0.029798,-0.020322,-0.017541,0.047261,-0.023958
1701.03465,0.041688,-0.046463,-0.006669,0.002882,-0.049054,-0.055600,0.024967,-0.022833,-0.064195,-0.024989,0.025476,-0.014147,0.008259,0.011338,0.031057,-0.016756,0.006144,-0.018280,-0.031051,-0.019174,-0.026812,0.011386,-0.036347,0.045279,0.038001,-0.011437,-0.059120,0.033904,-0.002554,-0.032362,-0.012904,0.016632,-0.012508,0.052533,0.050060,-0.043404,0.006829,0.035399,0.007529,-0.008232,0.022882,-0.009863,-0.001706,-0.023466,0.011501,0.009434,0.046350,0.031306,0.000071,0.052191
1709.10428,0.086494,0.041244,-0.000504,-0.005493,-0.021754,0.024251,-0.090039,0.042658,0.013826,0.119960,0.101589,-0.058991,-0.003887,-0.004259,0.004367,-0.036508,0.007040,-0.049933,-0.018390,-0.033750,-0.032583,-0.001268,0.120197,0.112117,0.017791,0.088096,-0.069112,-0.045585,-0.084628,-0.057848,-0.004192,-0.044674,-0.019837,-0.015791,0.003991,0.031053,-0.017518,-0.049298,0.046528,0.001388,0.089797,0.026401,0.013819,-0.014683,-0.023339,0.085520,-0.049048,0.013170,-0.011741,-0.016854
gr-qc/9803020,0.034044,-0.000941,-0.008669,0.000490,0.009465,0.003445,0.009003,-0.006327,-0.018251,0.043154,-0.028326,-0.010443,-0.096139,-0.024816,0.036979,-0.006141,0.160831,-0.077672,-0.068816,-0.068734,-0.056514,-0.004276,-0.040833,-0.069732,-0.012580,-0.012090,0.013307,-0.006211,-0.010631,-0.008695,-0.015451,-0.004839,0.000647,-0.010384,-0.010034,-0.004942,0.006784,0.021750,-0.014884,0.010817,0.004975,0.013016,0.003313,0.001069,-0.010813,-0.009200,0.006365,-0.011989,0.011005,0.001927


In [None]:
# Start from the id column and left-join every collected TF-IDF frame onto it.
# Each frame is indexed by 'id', which is what on='id' matches against.
# Iterating tfidf_dfs directly (instead of a hard-coded range(3)) keeps this
# in step with however many text columns were processed.
tfidf_title_abstract_comments = df[['id']]
for tfidf_df in tfidf_dfs:
    tfidf_title_abstract_comments = pd.merge(tfidf_title_abstract_comments, tfidf_df, how='left', on='id')

In [None]:
# Write the id column plus all merged tfidf_* columns to OUTDIR as CSV,
# without the DataFrame index.
tfidf_title_abstract_comments.to_csv(OUTDIR / 'tfidf.csv', index=False)