In [1]:
import os
import sacremoses
import transformers
import torch


## Create a combined `en-mul` and `mul-en` model

Here we create a single model that can translate in both directions. This can be useful for endpoint hosting, since only one model object has to be deployed.

In [3]:
# Change paths accordingly
mul_en_checkpoint_path = "savedmodels/mul-en"
# The tokenizer and the model are both loaded from the same checkpoint path.
mul_en_tokenizer = transformers.AutoTokenizer.from_pretrained(
    mul_en_checkpoint_path
)
mul_en_model = transformers.AutoModelForSeq2SeqLM.from_pretrained(
    mul_en_checkpoint_path
)

In [4]:
en_mul_checkpoint_path = "savedmodels/en-mul"
# The tokenizer and the model are both loaded from the same checkpoint path.
en_mul_tokenizer = transformers.AutoTokenizer.from_pretrained(
    en_mul_checkpoint_path
)
en_mul_model = transformers.AutoModelForSeq2SeqLM.from_pretrained(
    en_mul_checkpoint_path
)

In [6]:
class CombinedTranslationModel(torch.nn.Module):
    """Bundle the two one-directional translation models in a single module.

    ``generate`` dispatches to ``mul_en_model`` when the target language is
    ``'en'`` and to ``en_mul_model`` for every other value.
    """

    def __init__(self, en_mul_model, mul_en_model):
        """Store both sub-models as attributes of this module."""
        super().__init__()
        self.en_mul_model = en_mul_model
        self.mul_en_model = mul_en_model

    def generate(self, target_language, **inputs):
        """Run ``generate`` on the sub-model selected by ``target_language``.

        All keyword arguments are forwarded unchanged to the chosen model,
        and its return value is passed straight back to the caller.
        """
        into_english = target_language == 'en'
        chosen_model = self.mul_en_model if into_english else self.en_mul_model
        return chosen_model.generate(**inputs)

In [52]:
# Wrap both directions in one module so a single object can serve either one.
combined_model = CombinedTranslationModel(
    en_mul_model=en_mul_model,
    mul_en_model=mul_en_model,
)

## Example of using the combined model for translation

This works like ordinary `transformers` translation, except for the extra `target_language` argument passed to `model.generate()`, which selects the translation direction.

In [9]:
def translate(source_text, model, tokenizer, target_language, device='cpu'):
    """Translate ``source_text`` and return the decoded output string.

    Args:
        source_text: Text to translate. In the example calls, text translated
            out of English is prefixed with a ``>>xxx<<`` language token.
        model: Object exposing ``eval()``, ``to(device)`` and
            ``generate(target_language, **inputs)``, e.g. a
            ``CombinedTranslationModel``.
        tokenizer: Tokenizer used both to encode the input and to decode the
            generated ids; it has to match the translation direction.
        target_language: Forwarded as the first argument of
            ``model.generate``.
        device: Device (string or ``torch.device``) to run on. Defaults to
            ``'cpu'``, which was previously hard-coded.

    Returns:
        The decoded text with special tokens skipped.

    Note:
        For a ``torch.nn.Module``, ``eval()`` and ``to()`` modify the module
        that was passed in, so the caller's model is left in eval mode on
        ``device``.
    """
    device = torch.device(device)
    model = model.eval().to(device)
    inputs = tokenizer(source_text, return_tensors="pt").to(device)
    # Inference only: make sure no autograd graph is built, even if the
    # wrapped model's ``generate`` does not disable gradients itself.
    with torch.no_grad():
        tokens = model.generate(target_language, **inputs)
    # ``squeeze()`` drops the batch dimension of the single generated sequence.
    return tokenizer.decode(tokens.squeeze(), skip_special_tokens=True)

Make sure the tokenizer matches the translation direction (`mul_en_tokenizer` when translating into English, `en_mul_tokenizer` when translating out of English); otherwise the output will be unintelligible.

In [67]:
# Into English: pair the combined model with the mul-en tokenizer.
translate(
    source_text='Nina amagulu ana',
    model=combined_model,
    tokenizer=mul_en_tokenizer,
    target_language='en',
)

'I have four legs.'

In [66]:
# Out of English: the `>>lug<<` prefix is part of the source text, and the
# en-mul tokenizer is used.
translate(
    source_text='>>lug<< My name is Bainomugisha',
    model=combined_model,
    tokenizer=en_mul_tokenizer,
    target_language='mul',
)

'Erinnya lyange Bainomugisha.'

## Test saving and loading

In [74]:
# Pickles the whole module object (not just a state dict), so loading it later
# requires the `CombinedTranslationModel` class to be defined first.
torch.save(obj=combined_model, f='savedmodels/combined.pt')

In [7]:
# NOTE: same definition as earlier in the notebook. `torch.load` below
# unpickles an instance of this class, so the class has to be defined in the
# session before loading.
class CombinedTranslationModel(torch.nn.Module):
    """Bundle the two one-directional translation models in a single module.

    ``generate`` dispatches to ``mul_en_model`` when the target language is
    ``'en'`` and to ``en_mul_model`` for every other value.
    """

    def __init__(self, en_mul_model, mul_en_model):
        """Store both sub-models as attributes of this module."""
        super().__init__()
        self.en_mul_model = en_mul_model
        self.mul_en_model = mul_en_model

    def generate(self, target_language, **inputs):
        """Run ``generate`` on the sub-model selected by ``target_language``.

        All keyword arguments are forwarded unchanged to the chosen model,
        and its return value is passed straight back to the caller.
        """
        into_english = target_language == 'en'
        chosen_model = self.mul_en_model if into_english else self.en_mul_model
        return chosen_model.generate(**inputs)
    
# The checkpoint is a pickled `CombinedTranslationModel` object, not a plain
# state dict. Since PyTorch 2.6, `torch.load` defaults to `weights_only=True`,
# which refuses to unpickle custom classes, so full unpickling is requested
# explicitly (the `weights_only` argument requires PyTorch >= 1.13).
# SECURITY: full unpickling can execute arbitrary code -- only load checkpoint
# files that you created yourself or otherwise trust.
loaded_model = torch.load('savedmodels/combined.pt', weights_only=False)

In [10]:
# Check that the reloaded model still translates into English.
translate(
    source_text='Okola ki?',
    model=loaded_model,
    tokenizer=mul_en_tokenizer,
    target_language='en',
)

'What are you doing?'

In [11]:
# Check that the reloaded model still translates out of English.
translate(
    source_text='>>lug<< I am a computer',
    model=loaded_model,
    tokenizer=en_mul_tokenizer,
    target_language='mul',
)

'Nze kompyuta.'