In [None]:
import torch
import tensorflow as tf
from transformers import BertTokenizer, BertModel, TFBertModel, BertForMaskedLM 

import warnings; warnings.filterwarnings('ignore')

# 0. Choose SweBERT model

In [None]:
# Hub id of the Swedish BERT checkpoint every later cell loads.
pretrained_model_name: str = 'af-ai-center/bert-base-swedish-uncased'

# 1. Check SweBERT Model Accessibility

### a. Tokenizer

In [None]:
# Lower-casing is switched off inside the tokenizer; section 2 lower-cases the
# text by hand (step 1a) instead.
# NOTE(review): presumably because BertTokenizer's built-in lower-casing also
# strips accents (å/ä/ö) — confirm against the model card.
tokenizer = BertTokenizer.from_pretrained(
    pretrained_model_name,
    do_lower_case=False,
)

### b. Model PyTorch

In [None]:
# A successful load shows the PyTorch weights of the checkpoint are reachable.
# The name `model` is rebound by later cells; this object is not used again.
model = BertModel.from_pretrained(
    pretrained_model_name,
)

### c. Model TensorFlow

In [None]:
# A successful load shows the checkpoint can also be opened from TensorFlow.
# The name `model` is rebound again in section 2; this object is not used again.
model = TFBertModel.from_pretrained(
    pretrained_model_name,
)

# 2. Simple Model Application (Masked Token Prediction)

In [None]:
# Swedish sample sentence; 'barn' occurs both on its own and inside 'barnhem'.
example = (
    'Jag är ett barn, och det här är mitt hem. '
    'Alltså är det ett barnhem!'
)
example

### 1. Preprocess the example

#### a. Lowercase

The tokenizer is loaded with `do_lower_case=False`, so the text is lower-cased by hand here.

In [None]:
# Lower-case by hand, since the tokenizer is loaded with do_lower_case=False.
example_uncased = str.lower(example)
example_uncased

#### b. Add special tokens

In [None]:
# Wrap the text in BERT's [CLS] ... [SEP] markers, space-separated.
example_preprocessed = ' '.join(['[CLS]', example_uncased, '[SEP]'])
example_preprocessed

### 2. Tokenize the preprocessed example

In [None]:
# Re-create the tokenizer so section 2 can run without section 1. Settings
# match section 1a: do_lower_case=False, because step 1a above lower-cases
# the text manually.
tokenizer = BertTokenizer.from_pretrained(
    pretrained_model_name,
    do_lower_case=False,
)

In [None]:
# Split the preprocessed text into word-pieces ([CLS]/[SEP] are kept intact).
tokens = tokenizer.tokenize(example_preprocessed)

# One print call: the count line, a newline, then the token list.
print(f'{len(tokens)} tokens:', tokens, sep='\n')

### 3. Mask one of the tokens

In [None]:
# 0-based position (counting [CLS] as 0) of the 'barn' token to hide.
# NOTE(review): by a manual count of the printed tokens this is the 'barn'
# piece of 'barnhem' — confirm against the output of the previous cell.
masked_index = 17

# Guard against tokenizer/vocabulary drift: the hard-coded index must still
# point at 'barn'. '[MASK]' is also accepted so re-running this cell works.
assert tokens[masked_index] in ('barn', '[MASK]'), tokens[masked_index]
tokens[masked_index] = '[MASK]'

print(tokens)

### 4. Prepare the tokens for use with SweBERT

In [None]:
# Look up the vocabulary id of every word-piece, including the [MASK] token.
indexed_tokens = [tokenizer.convert_tokens_to_ids(piece) for piece in tokens]

In [None]:
# Build a 1-D id tensor, then add a leading batch axis: shape (1, len(tokens)).
indexed_tokens_tensor = torch.tensor(indexed_tokens).unsqueeze(0)
print(indexed_tokens_tensor)

### 5. Use SweBERT to predict the masked token

In [None]:
# instantiate model
model = BertForMaskedLM.from_pretrained(pretrained_model_name)
_ = model.eval()

In [None]:
# predict all tokens
# Forward pass over the whole (masked) sequence; no gradients are needed.
with torch.no_grad():
    outputs = model(indexed_tokens_tensor)

# First output element: per-position scores over the vocabulary.
predictions = outputs[0]
print(predictions.shape)  # (batch=1, number of tokens, vocabulary size)

# Check the shape against the actual input and the checkpoint's own config
# rather than hard-coding a token count or vocabulary size.
assert predictions.shape[:2] == indexed_tokens_tensor.shape
assert predictions.shape[-1] == model.config.vocab_size

In [None]:
# show predicted index for masked token
# Highest-scoring vocabulary id at the masked position of the single example.
predicted_index = predictions[0, masked_index].argmax().item()
print(predicted_index)

In [None]:
# show predicted masked token
# Map the predicted id back to its word-piece string (one-element list unpacked).
(predicted_token,) = tokenizer.convert_ids_to_tokens([predicted_index])
print(predicted_token)

In [None]:
# The model should recover the 'barn' token that was replaced by [MASK];
# on failure, report what it predicted instead.
assert predicted_token == 'barn', f"expected 'barn', got {predicted_token!r}"

# Conclusions