In [1]:
import numpy as np
import os
import pandas as pd
import sacremoses
import tqdm
import transformers
from transformers import AutoTokenizer
import torch

Change the paths below so that they point to the locations of your saved model checkpoints.

In [2]:
# Load the tokenizer and the seq2seq model from the same saved "mul-en" checkpoint.
mul_en_checkpoint_path = "savedmodels/mul-en"
mul_en_tokenizer = transformers.AutoTokenizer.from_pretrained(mul_en_checkpoint_path)
mul_en_model = transformers.AutoModelForSeq2SeqLM.from_pretrained(mul_en_checkpoint_path)

In [144]:
# Load the tokenizer and the seq2seq model from the same saved "en-mul" checkpoint.
en_mul_checkpoint_path = "savedmodels/en-mul"
en_mul_tokenizer = transformers.AutoTokenizer.from_pretrained(en_mul_checkpoint_path)
en_mul_model = transformers.AutoModelForSeq2SeqLM.from_pretrained(en_mul_checkpoint_path)

We run the models on the CPU so that translation can be done while a model is still training, without taking up any extra GPU memory.

In [145]:
def sentence_format(input, first_word_title_case = True):
    '''Ensure capital letter at the start and full stop at the end.

    Args:
        input: The sentence to format. Leading and trailing whitespace is
            removed before formatting, so the closing punctuation is never
            appended after a stray newline or space.
        first_word_title_case: When True (the default), the first character
            of the sentence is capitalized. When False, the original casing
            is left untouched.

    Returns:
        The trimmed sentence, ending in '.', '!' or '?'. An empty or
        whitespace-only input yields an empty string.
    '''
    # `input` shadows the builtin, but it is the public parameter name and is
    # kept so that keyword callers continue to work.
    input = input.strip()
    if not input:
        # Nothing to capitalize or punctuate; avoids indexing an empty string.
        return input
    if first_word_title_case:
        input = input[0].capitalize() + input[1:]
    if not input.endswith(('.', '!', '?')):
        input = input + '.'
    return input

def translate(source_text, model, tokenizer):
    '''Translate `source_text` with the given model and tokenizer on the CPU.

    The model is switched to evaluation mode and moved to the CPU, the text
    is tokenized into PyTorch tensors, and the generated token ids are
    decoded back into a string with special tokens removed.
    '''
    cpu = torch.device('cpu')
    cpu_model = model.eval().to(cpu)
    encoded = tokenizer(source_text, return_tensors="pt").to(cpu)
    generated = cpu_model.generate(**encoded)
    # squeeze() drops the batch dimension of the single generated sequence.
    return tokenizer.decode(generated.squeeze(), skip_special_tokens=True)

def translate_to_multiple(query,
                          languages = {'ach': 'Acholi', 'lgg': 'Lugbara',
                                       'lug': 'Luganda','nyn': 'Runyankore',
                                       'teo': 'Ateso'}):
    '''Translate `query` into every language listed in `languages`.

    Args:
        query: The source sentence. It is punctuation-normalized and passed
            through `sentence_format` before translation.
        languages: Mapping from target-language code (used in the
            ``>>code<<`` prefix given to the model) to the display name
            shown in the result. Only the languages present in this mapping
            are translated. The mapping is read, never modified.

    Returns:
        A DataFrame indexed by 'target' (the display name) with a single
        'translation' column, one row per requested language. An empty
        mapping yields an empty frame with the same layout.
    '''
    normalizer = sacremoses.MosesPunctNormalizer()
    normalized_query = sentence_format(normalizer.normalize(query))

    # Keep the established row order for the known codes, skip any the caller
    # did not ask for, then append additional codes in the order supplied.
    preferred_order = ['lug', 'ach', 'nyn', 'teo', 'lgg']
    codes = [code for code in preferred_order if code in languages]
    codes += [code for code in languages if code not in preferred_order]

    translations = [
        {'target': languages[code],
         'translation': translate(f">>{code}<< {normalized_query}",
                                  en_mul_model, en_mul_tokenizer)}
        for code in codes
    ]

    # Explicit columns keep 'target' available for set_index even when there
    # are no rows.
    df = pd.DataFrame(translations, columns=['target', 'translation'])
    return df.set_index('target')

def translate_to_english(query):
    '''Translate `query` with the mul-en model and return the decoded text.

    The query is punctuation-normalized and run through `sentence_format`
    before being handed to `translate`.
    '''
    cleaned = sacremoses.MosesPunctNormalizer().normalize(query)
    return translate(sentence_format(cleaned), mul_en_model, mul_en_tokenizer)

# Show full cell contents so that long translations are not truncated.
pd.options.display.max_colwidth = None

In [146]:
# Sentence to translate into each of the target languages.
query = (
    "\n"
    "It may rain tomorrow.\n"
)

translate_to_multiple(query)

target,translation
Luganda,Enkuba eyinza okutonnya enkya.
Acholi,Kot romo cwer diki.
Runyankore,Enjura nebaasa kugwa nyenkyakare.
Ateso,Aticepak edou moi.
Lugbara,Ozoo eco drusi ra.


In [142]:
# Sentence to translate into English.
query = (
    "\n"
    "Omukyala agamba akooye okubonabona.\n"
)

translate_to_english(query)

'The lady says she is tired of suffering.'