In [None]:
import os

import torch
import tensorflow as tf
from transformers import BertTokenizer, BertModel, BertForMaskedLM, TFBertModel

import warnings; warnings.filterwarnings('ignore')

In [None]:
# Hugging Face Hub identifier of the uncased Swedish BERT (SweBERT) checkpoint,
# published under the af-ai-center organisation.
pretrained_model_name = "af-ai-center/bert-base-swedish-uncased"

In [None]:
# Optionally use a local copy of the checkpoint (e.g. downloaded to ./private) instead of the
# Hub identifier set above. Only switch when that directory actually exists, so the notebook
# still runs end-to-end for anyone who does not have the private copy.
local_model_dir = './private/bert-base-swedish-uncased'
if os.path.isdir(local_model_dir):
    pretrained_model_name = local_model_dir
pretrained_model_name

# 1. Check that the SweBERT model can be loaded

### a. Tokenizer

In [None]:
# do_lower_case=False: the input text is lowercased manually instead (see example_sentence).
# NOTE(review): presumably chosen because BertTokenizer's built-in lowercasing also strips
# accents by default, which would turn Swedish å/ä/ö into a/a/o — confirm against model docs.
tokenizer = BertTokenizer.from_pretrained(
    pretrained_model_name,
    do_lower_case=False,
)

### b. PyTorch model

In [None]:
# PyTorch encoder without a task head. Named `model_pt` to mirror `model_tf` below, and so that
# the name `model` refers only to the masked-LM model loaded in section 2.
model_pt = BertModel.from_pretrained(pretrained_model_name)

### c. TensorFlow model

In [None]:
# TensorFlow variant of the encoder, built from the PyTorch checkpoint weights (from_pt=True).
model_tf = TFBertModel.from_pretrained(
    pretrained_model_name,
    from_pt=True,
)

# 2. Simple Model Application (Masked Token Prediction)

In [None]:
# The checkpoint is uncased and the tokenizer is loaded with do_lower_case=False,
# so the text is lowercased here before tokenization.
example_sentence = str.lower('Jag är ett barn, och det här är mitt hem. Alltså är det ett barnhem!')
example_sentence

### 1. Tokenize `example_sentence`

In [None]:
# Reload the tokenizer (same settings as in section 1) so this section reads on its own.
tokenizer = BertTokenizer.from_pretrained(
    pretrained_model_name,
    do_lower_case=False,
)

In [None]:
# Split the sentence into WordPiece tokens (continuation pieces carry a '##' prefix).
# tokenize() does not insert the [CLS]/[SEP] special tokens.
tokenized_text = tokenizer.tokenize(example_sentence)

print(len(tokenized_text), 'tokens')
print(tokenized_text)

### 2. Mask one of the tokens

In [None]:
# Position of the token to hide from the model. 16 is expected to hold 'barn' given the
# tokenization printed above; the assertion guards against a different vocabulary/tokenizer
# version silently masking some other token. '[MASK]' is accepted so the cell can be re-run.
masked_token = 'barn'
masked_index = 16
assert tokenized_text[masked_index] in (masked_token, '[MASK]'), (
    f'expected {masked_token!r} at position {masked_index}, '
    f'got {tokenized_text[masked_index]!r} in {tokenized_text}'
)
tokenized_text[masked_index] = '[MASK]'

print(tokenized_text)

### 3. Prepare the tokens for use with SweBERT

In [None]:
# Map every WordPiece token (including '[MASK]') to its vocabulary id.
indexed_tokens = [tokenizer.convert_tokens_to_ids(token) for token in tokenized_text]

In [None]:
# Add a leading batch dimension: shape (1, number of tokens).
indexed_tokens_tensor = torch.tensor(indexed_tokens).unsqueeze(0)
print(indexed_tokens_tensor)

### 4. Use the SweBERT model to predict back the masked token

In [None]:
# Same checkpoint, now with the masked-language-modelling head on top of the encoder.
model = BertForMaskedLM.from_pretrained(
    pretrained_model_name,
)
# eval() switches to inference mode (e.g. disables dropout) and returns the module itself;
# binding it to `_` only keeps the cell from echoing the module repr.
_ = model.eval()

In [None]:
# BERT was pre-trained on inputs framed as [CLS] ... [SEP], but tokenizer.tokenize() does not
# add them, so wrap the ids here. The scores for those two extra positions are sliced off
# afterwards so that `predictions` stays aligned with `tokenized_text` / `masked_index`.
model_input = torch.cat(
    [
        torch.tensor([[tokenizer.cls_token_id]]),
        indexed_tokens_tensor,
        torch.tensor([[tokenizer.sep_token_id]]),
    ],
    dim=1,
)

# predict all tokens (no gradients needed for inference)
with torch.no_grad():
    outputs = model(model_input)

predictions = outputs[0][:, 1:-1]  # drop the [CLS]/[SEP] positions
print(predictions.shape)  # (1 example, number of tokens, vocabulary size of the checkpoint)

In [None]:
# vocabulary id with the highest score at the masked position
predicted_index = predictions[0, masked_index].argmax().item()
print(predicted_index)

In [None]:
# map the predicted vocabulary id back to its token string
(predicted_token,) = tokenizer.convert_ids_to_tokens([predicted_index])
print(predicted_token)

In [None]:
# Sanity check: the model should recover the token that was masked out.
assert predicted_token == 'barn', f"expected 'barn' at the masked position, got {predicted_token!r}"

# Conclusions