In [1]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

/kaggle/input/ml-study-meetup-osaka/OsakaWinter_station.csv
/kaggle/input/ml-study-meetup-osaka/OsakaWinter_test.csv
/kaggle/input/ml-study-meetup-osaka/OsakaWinter_train.csv
/kaggle/input/ml-study-meetup-osaka/OsakaWinter_city.csv
/kaggle/input/ml-study-meetup-osaka/OsakaWinter_data_dictionary.csv
/kaggle/input/ml-study-meetup-osaka/OsakaWinter_sample_submission.csv
/kaggle/input/ml-study-meetup-osaka/glove.840B.300d/glove.840B.300d.txt


### BERTを使って、textデータをベクトル化します
https://huggingface.co/transformers/model_doc/bert.html <br>

データ分析のコンペではテキスト処理によくBERTは使われます <br>
下記URLを参考にしています<br>
https://www.guruguru.science/competitions/16/discussions/fb792c87-6bad-445d-aa34-b4118fc378c1/

Settingsで下記設定に変更してください
- AcceleratorをGPU
- InternetをON

In [2]:
# ==================
# Library
# ==================
import pandas as pd
import numpy as np
import torch
import transformers
from transformers import BertTokenizer
from tqdm import tqdm
from sklearn.decomposition import TruncatedSVD
tqdm.pandas()

In [3]:
# ==================
# Constant
# ==================
CITY_PATH = '/kaggle/input/ml-study-meetup-osaka/OsakaWinter_city.csv'

In [4]:
city = pd.read_csv(CITY_PATH)

In [5]:
city.head()

Unnamed: 0,Prefecture,Municipality,Latitude,Longitude,wiki_description
0,Hyogo Prefecture,"Fukusaki Town,Kanzaki County",34.950238,134.760182,"Fukusaki (福崎町, Fukusaki-chō) is a town in Kanz..."
1,Hyogo Prefecture,Kasai City,34.928023,134.841609,"Kasaï-Oriental (French for ""East Kasai"") is on..."
2,Hyogo Prefecture,Tamba Sasayama City,35.075729,135.219196,"Tamba-Sasayama (丹波篠山市, Tanba-Sasayama-shi), fo..."
3,Hyogo Prefecture,Yabu City,35.404612,134.767632,"Yabu (養父市, Yabu-shi) is a city located in Hyōg..."
4,Hyogo Prefecture,Tanba City,35.177132,135.035842,


In [6]:
class BertSequenceVectorizer:
    def __init__(self):
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.model_name = 'bert-base-uncased'
        self.tokenizer = BertTokenizer.from_pretrained(self.model_name)
        self.bert_model = transformers.BertModel.from_pretrained(self.model_name)
        self.bert_model = self.bert_model.to(self.device)
        self.max_len = 128


    def vectorize(self, sentence : str) -> np.array:
        inp = self.tokenizer.encode(sentence)
        len_inp = len(inp)

        if len_inp >= self.max_len:
            inputs = inp[:self.max_len]
            masks = [1] * self.max_len
        else:
            inputs = inp + [0] * (self.max_len - len_inp)
            masks = [1] * len_inp + [0] * (self.max_len - len_inp)

        inputs_tensor = torch.tensor([inputs], dtype=torch.long).to(self.device)
        masks_tensor = torch.tensor([masks], dtype=torch.long).to(self.device)

        bert_out = self.bert_model(inputs_tensor, masks_tensor)
        seq_out, pooled_out = bert_out['last_hidden_state'], bert_out['pooler_output']

        if torch.cuda.is_available():    
            return seq_out[0][0].cpu().detach().numpy() # 0番目は [CLS] token, 768 dim の文章特徴量
        else:
            return seq_out[0][0].detach().numpy()

In [7]:
BSV = BertSequenceVectorizer() # インスタンス化します
city['wiki_description'] = city['wiki_description'].fillna("NaN") # null は代わりのもので埋めます
city['wiki_description_feature'] = city['wiki_description'].progress_apply(lambda x: BSV.vectorize(x))

Downloading:   0%|          | 0.00/232k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/433 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/440M [00:00<?, ?B/s]

  0%|          | 0/274 [00:00<?, ?it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (988 > 512). Running this sequence through the model will result in indexing errors
100%|██████████| 274/274 [01:18<00:00,  3.51it/s]


In [8]:
bert_array = np.zeros([len(city),768])
for n,i in enumerate(city['wiki_description_feature']):
    bert_array[n,:] = i

In [10]:
svd = TruncatedSVD(n_components=50)
X = svd.fit_transform(bert_array)
df = pd.DataFrame(X, columns=[f"wiki_description_bert_svd_{i}" for i in range(50)])

In [None]:
Prefecture	Municipality

In [12]:
df["Prefecture"] = city["Prefecture"]
df["Municipality"] = city["Municipality"]

In [13]:
df.head()

Unnamed: 0,wiki_description_bert_svd_0,wiki_description_bert_svd_1,wiki_description_bert_svd_2,wiki_description_bert_svd_3,wiki_description_bert_svd_4,wiki_description_bert_svd_5,wiki_description_bert_svd_6,wiki_description_bert_svd_7,wiki_description_bert_svd_8,wiki_description_bert_svd_9,...,wiki_description_bert_svd_42,wiki_description_bert_svd_43,wiki_description_bert_svd_44,wiki_description_bert_svd_45,wiki_description_bert_svd_46,wiki_description_bert_svd_47,wiki_description_bert_svd_48,wiki_description_bert_svd_49,Prefecture,Municipality
0,7.19585,6.540291,-1.202089,-0.861769,0.11923,2.853026,-2.235604,-0.182519,0.972919,-0.123232,...,0.114649,-0.575004,-0.559227,0.113885,0.286264,0.579359,0.498586,0.120311,Hyogo Prefecture,"Fukusaki Town,Kanzaki County"
1,7.393426,5.078303,-0.418622,2.075598,1.727283,-3.374451,0.144032,-0.900871,-2.064772,-0.465511,...,-0.873664,-0.838587,-1.539198,-0.236962,-0.635799,-0.266622,-0.37894,0.717455,Hyogo Prefecture,Kasai City
2,8.321034,6.666421,-0.225001,-0.807661,-0.470637,-0.795802,3.042514,-0.933491,-0.232067,0.275798,...,0.124066,-0.461417,-0.372746,-0.275443,0.783521,-0.846106,0.106188,-0.155596,Hyogo Prefecture,Tamba Sasayama City
3,7.236345,8.13092,-0.133214,-1.833246,-1.677707,-2.200563,2.409288,0.206769,-0.694948,0.031241,...,0.551018,-0.048959,0.203427,0.783047,0.170565,-0.13276,0.707002,-0.055784,Hyogo Prefecture,Yabu City
4,11.585023,-8.557754,-0.37303,-0.128543,0.000157,-0.149383,0.008602,0.039303,0.151931,0.057495,...,0.000738,0.01009,-0.004297,0.011787,-0.00536,-0.015314,-0.002867,0.001425,Hyogo Prefecture,Tanba City


In [14]:
df.to_csv("city_wiki_description_bert.csv",index=False)