In [None]:
import json
import re
from typing import List

import requests
import torch
from transformers import FSMTForConditionalGeneration, FSMTTokenizer

In [None]:
!pip3 install transformers

In [None]:
# Download the cleaned Alpaca dataset from: https://github.com/gururise/AlpacaDataCleaned
# The URL pins a specific commit (latest main at the time of writing) for reproducibility.
# The timeout keeps the cell from hanging indefinitely when the server stops responding.
r = requests.get(
    "https://raw.githubusercontent.com/gururise/AlpacaDataCleaned/2ee9f5ca1d4dc2df3777a765bab88ad061e83378/alpaca_data_cleaned.json",
    timeout=60,
)

In [None]:
# A requests Response is falsy when its status code signals an error (4xx/5xx).
# Raise instead of only printing, so "Run All" stops here rather than carrying an
# error response into the JSON parsing cell.
if not r:
    raise RuntimeError(f"Error downloading dataset! (HTTP status {r.status_code})")

In [None]:
# Parse the response body as JSON. Later cells iterate over `data` and read the
# "instruction", "input" and "output" keys of every element.
data = r.json()

In [None]:
# One checkpoint name drives both loads, so tokenizer and model always match.
model_name = "facebook/wmt19-en-de"

tokenizer = FSMTTokenizer.from_pretrained(model_name)
model = FSMTForConditionalGeneration.from_pretrained(model_name)

# Move the weights to the first CUDA device. As the last expression of the cell,
# the returned module is also what the notebook displays.
model.to("cuda:0")

In [None]:
# Build one list per field, in dataset order. Newlines are swapped for a "<br>"
# marker before translation; translate() turns the marker back into "\n" afterwards.
source_instructions, source_inputs, source_outputs = (
    [example[field].replace("\n", "<br>") for example in data]
    for field in ("instruction", "input", "output")
)

def generate_batches(input_list: List[str], batch_size: int):
    """Yield consecutive slices of ``input_list`` holding at most ``batch_size`` items.

    The last batch is shorter when the list length is not a multiple of
    ``batch_size``; an empty list yields nothing.

    Args:
        input_list: Strings to split into batches.
        batch_size: Maximum number of items per batch; must be at least 1.

    Yields:
        Lists of consecutive items, in their original order.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1. Because this is a generator,
            the error surfaces when iteration starts, not when the function is called.
    """
    # range() already rejects a step of 0, but a negative step would produce an empty
    # range and silently drop every item, so reject both cases the same way.
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")
    for start in range(0, len(input_list), batch_size):
        yield input_list[start:start + batch_size]

In [None]:
# generate_batches() returns a one-shot generator. Materialise the batches as lists so
# that re-running the translation cell iterates over the data again instead of
# silently seeing nothing. The lists only hold references to the existing strings.
# Outputs use a smaller batch size (64) than instructions and inputs (128) --
# presumably because outputs tend to be longer; confirm before changing.
source_instruction_batches = list(generate_batches(source_instructions, 128))
source_input_batches = list(generate_batches(source_inputs, 128))
source_output_batches = list(generate_batches(source_outputs, 64))

In [None]:
def translate(batch: List[str]):
    """Translate a batch of strings with the notebook-level ``tokenizer`` and ``model``.

    Decoding uses nucleus sampling (``do_sample=True``, ``top_p=0.8``, one beam, one
    returned sequence, ``max_length=512``). Line-break markers that come back from the
    model as ``<br>`` (with or without spaces between the characters) are converted
    to ``"\\n"``.

    Args:
        batch: Source strings. Empty strings are allowed.

    Returns:
        A list with one translated string per input string, in the same order. Empty
        input strings map to ``""`` without being sent to the model.
    """
    # Only strings with content go through the model: sampling from an empty prompt
    # produces made-up text and wastes GPU time.
    non_empty_positions = [position for position, text in enumerate(batch) if text]
    translations = [""] * len(batch)
    if not non_empty_positions:
        return translations

    tokenized_batch = tokenizer([batch[position] for position in non_empty_positions],
                                return_tensors="pt", padding=True)

    # Follow whatever device the model currently lives on instead of repeating a
    # hard-coded device string here.
    device = model.device
    generate_kwargs = {"num_beams": 1, "do_sample": True, "top_p": 0.8,
                       "num_return_sequences": 1, "max_length": 512}
    translated_texts = model.generate(tokenized_batch["input_ids"].to(device),
                                      attention_mask=tokenized_batch["attention_mask"].to(device),
                                      **generate_kwargs)

    for position, token_ids in zip(non_empty_positions, translated_texts):
        decoded = tokenizer.decode(token_ids, skip_special_tokens=True)
        # A plain replace of "< br > " misses a marker at the very end of the text
        # (no trailing space) and leaves a stray space in front of the newline; the
        # pattern accepts any spacing inside and around the marker.
        translations[position] = re.sub(r"\s*<\s*br\s*>\s*", "\n", decoded)
    return translations

In [None]:
# translate() samples from the model, so fix the seed right before translating to make
# a re-run of this cell repeatable on the same setup.
SEED = 42
torch.manual_seed(SEED)

translated_instructions = []
translated_inputs = []
translated_outputs = []

# One loop for all three fields. Results are appended batch by batch to the lists
# above, so partial progress stays available if the cell is interrupted.
translation_jobs = (
    ("Instruction", source_instruction_batches, translated_instructions),
    ("Input", source_input_batches, translated_inputs),
    ("Output", source_output_batches, translated_outputs),
)
for label, batches, translated in translation_jobs:
    for index, batch in enumerate(batches):
        print(f"Translating {label} Batch {index+1}")
        translated.extend(translate(batch))

# The examples are later re-assembled with zip(), which silently truncates to the
# shortest list; fail loudly here if a field was not translated completely (for
# example because a batch generator had already been consumed by an earlier run).
for label, sources, translated in (
    ("Instruction", source_instructions, translated_instructions),
    ("Input", source_inputs, translated_inputs),
    ("Output", source_outputs, translated_outputs),
):
    if len(translated) != len(sources):
        raise RuntimeError(
            f"{label}: expected {len(sources)} translations, got {len(translated)}"
        )

In [None]:
# Re-assemble one record per example. An example whose source "input" was empty keeps
# an empty "input", whatever the translation step returned for it.
translated_data = [
    {
        "instruction": instruction_de,
        "input": input_de if input_src else "",
        "output": output_de,
    }
    for input_src, input_de, instruction_de, output_de in zip(
        source_inputs, translated_inputs, translated_instructions, translated_outputs
    )
]

In [None]:
# Write the translated examples as UTF-8 JSON. The explicit encoding makes the file
# independent of the platform default, and ensure_ascii=False stores umlauts and
# other non-ASCII characters as themselves instead of \uXXXX escapes.
with open("translated_german_alpaca.json", "wt", encoding="utf-8") as f_p:
    json.dump(translated_data, f_p, ensure_ascii=False)