# BERT - Multilingual Model

## 01. BertTokenizer

In [5]:
import dill

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split

from transformers import BertModel, BertTokenizer

In [6]:
class SentDataset(Dataset):
    """Map-style dataset over a corpus deserialized from a dill file.

    Args:
        path: Location of the dill-serialized corpus file.
        prefix: Accepted for interface compatibility; not used by this class.
        is_sample: When true, keep only the first 1000 corpus entries.
    """

    def __init__(self, path, prefix="train", is_sample=False):
        # NOTE: dill.load can execute code embedded in the file, so only
        # point this at corpus files that come from a trusted source.
        with open(path, 'rb') as corpus_file:
            loaded = dill.load(corpus_file)

        # Sampling truncates the corpus to a small prefix for quick experiments.
        self.corpus = loaded[:1000] if is_sample else loaded

    def __len__(self):
        """Return the number of entries in the corpus."""
        return len(self.corpus)

    def __getitem__(self, idx):
        """Return the corpus entry stored at position ``idx``."""
        return self.corpus[idx]

In [9]:
# Dataset: load the complete pickled corpus (is_sample=False disables the
# 1000-entry truncation). The path is relative to the working directory.
corpus_path = '../../data/corpus/kowiki_corpus.pkl'
dataset = SentDataset(corpus_path, prefix='train', is_sample=False)

# Tokenizer: instantiate the tokenizer registered under the multilingual
# cased BERT checkpoint name.
tokenizer = BertTokenizer.from_pretrained('bert-base-multilingual-cased')

In [10]:
# Display the tokenizer's vocabulary size (the recorded output is 119547).
tokenizer.vocab_size

119547

In [11]:
# Sample Korean sentence; it is reused when computing sentence embeddings.
tmp_text = "지미 카터는 조지아주 섬터 카운티 플레인스 마을에서 태어났다."

# Convert the sentence into a list of vocabulary ids and print it.
# NOTE(review): the recorded output starts with 101 and ends with 102,
# presumably the [CLS]/[SEP] ids that encode() adds by default — confirm.
encoded = tokenizer.encode(tmp_text)
print(encoded)

[101, 9706, 22458, 9786, 21876, 11018, 9678, 12508, 16985, 16323, 9430, 21876, 9786, 21614, 45725, 9944, 56645, 12030, 12605, 9246, 10622, 11489, 88921, 119, 102]


## 02. BERT Multilingual Model

In [12]:
from sentence_transformers import models
from sentence_transformers import SentenceTransformer

In [13]:
# # Use BERT for mapping tokens to embeddings
# word_embedding_model = models.BERT('bert-base-multilingual-cased')

# # Apply mean pooling to get one fixed sized sentence vector
# pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension(),
#                                pooling_mode_mean_tokens=True,
#                                pooling_mode_cls_token=False,
#                                pooling_mode_max_tokens=False)

# model = SentenceTransformer(modules=[word_embedding_model, pooling_model])

In [14]:
# Build the sentence embedder from the named multilingual model; this
# replaces the manual BERT + mean-pooling pipeline that is commented out.
embedder = SentenceTransformer('distiluse-base-multilingual-cased')

In [9]:
# embedder.tokenize(tmp_text)

In [10]:
# Embed the sample sentence. encode() receives a one-element list, so the
# sentence's vector is at index 0 of the result.
embeddings = embedder.encode([tmp_text])

In [13]:
# Display the shape of the sample sentence's embedding (the recorded
# output is (512,)).
embeddings[0].shape

(512,)

In [15]:
import torch

In [16]:
# Chain the sentence embedder with a softmax taken over dimension 1.
# NOTE(review): nn.Softmax operates on a tensor; verify that the embedder's
# forward pass yields a tensor (rather than a feature dict) before calling
# this model end to end.
model = nn.Sequential(embedder, nn.Softmax(dim=1))