<a href="https://colab.research.google.com/github/alexgaskell10/NLP_Translation/blob/master/notebooks/stuff.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [24]:
! pip install transformers



In [0]:
import torch
from transformers import BertModel, BertTokenizer

**Loading BERT**

In [26]:
tokenizer = BertTokenizer.from_pretrained('bert-base-multilingual-cased')
model = BertModel.from_pretrained('bert-base-multilingual-cased')
input_ids = torch.tensor([tokenizer.encode("Here is some text to encode", add_special_tokens=True)])  # Add special tokens takes care of adding [CLS], [SEP], <s>... tokens in the right way for each model.
with torch.no_grad():
    hidden_states = model(input_ids)  # Models outputs are now tuples
    last_hidden_states = model(input_ids)[0]  # Models outputs are now tuples

print(len(hidden_states), last_hidden_states.shape)

2 torch.Size([1, 9, 768])


In [0]:
# Download and unzip the data
from os.path import exists
if not exists('ende_data.zip'):
    !wget -O ende_data.zip https://competitions.codalab.org/my/datasets/download/c748d2c0-d6be-4e36-9f12-ca0e88819c4d
    !unzip ende_data.zip

In [28]:
# Check the files
import io

#English-German
print("---EN-DE---")
print()

with open("./train.ende.src", "r") as ende_src:
    print("Source: ",ende_src.readline())
with open("./train.ende.mt", "r") as ende_mt:
    print("Translation: ",ende_mt.readline())
with open("./train.ende.scores", "r") as ende_scores:
    print("Score: ",ende_scores.readline())


---EN-DE---

Source:  José Ortega y Gasset visited Husserl at Freiburg in 1934.

Translation:  1934 besuchte José Ortega y Gasset Husserl in Freiburg.

Score:  1.1016968715664406



In [0]:
# Load data into variables
with open("./train.ende.src", "r") as ende_src:
    en_train = ende_src.read().split('\n')
with open("./train.ende.mt", "r") as ende_src:
    de_train = ende_src.read().split('\n')
with open("./train.ende.scores", "r") as ende_src:
    train_scores = ende_src.read().split('\n')

In [0]:
# Convert input sequences to correct format

# Tokenize English
inputs_en = []
max_len_en = 0
    
for i in range(len(en_train)-1):
    seq = en_train[i][:-1]
    input_ids = torch.tensor([tokenizer.encode(seq, add_special_tokens=True)])  # Add special tokens takes care of adding [CLS], [SEP], <s>... tokens in the right way for each model.
    inputs_en.append(input_ids)
    if input_ids.shape[-1] > max_len_en:
        max_len_en = input_ids.shape[-1]

# Tokenize German
inputs_de = []
max_len_de = 0

for i in range(len(en_train)-1):
    seq = de_train[i][:-1]
    input_ids = torch.tensor([tokenizer.encode(seq, add_special_tokens=True)])  # Add special tokens takes care of adding [CLS], [SEP], <s>... tokens in the right way for each model.
    inputs_de.append(input_ids)
    if input_ids.shape[-1] > max_len_de:
        max_len_de = input_ids.shape[-1]

# Combine tokens into single
inp_tensor = torch.zeros((len(inputs_en), max_len_en + max_len_de - 2))      # <-- -2 because special tokens are not necessary at beginning of German sequence

for i in range(len(inputs_en)):
    # Add English tokens
    en_tokens = inputs_en[i].squeeze()
    inp_tensor[i, : len(en_tokens)] = en_tokens

    # Add German tokens
    de_tokens = inputs_de[i][:,2:].squeeze()      # <-- ignore first 2 tokens as these are special tokens and unnecessary in this case
    inp_tensor[i, max_len_en : max_len_en + len(de_tokens)] = de_tokens

In [37]:
USE_GPU = True
dtype = torch.float32 

if USE_GPU and torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

model = model.to(device=device)
inp_tensor = inp_tensor.to(device=device, dtype=torch.long)

batches = torch.split(inp_tensor, 1000, dim=0)
list_bert_embs = []
for X in batches:
    with torch.no_grad():
        last_hidden_states = model(X)[0]    # <-- take word embeddings ([1] gives sentence embeddings)

    list_bert_embs.append(last_hidden_states)

    print(last_hidden_states.shape)

    torch.cuda.empty_cache()

bert_embs = torch.cat(list_bert_embs, dim=0)

print(bert_embs.shape)

torch.Size([1000, 132, 768])
torch.Size([1000, 132, 768])
torch.Size([1000, 132, 768])
torch.Size([1000, 132, 768])
torch.Size([1000, 132, 768])
torch.Size([1000, 132, 768])
torch.Size([1000, 132, 768])
torch.Size([7000, 132, 768])
