In [None]:
!pip -q install transformers hydra-core omegaconf fastBPE

[K     |████████████████████████████████| 133kB 19.3MB/s 
[K     |████████████████████████████████| 112kB 42.9MB/s 
[K     |████████████████████████████████| 645kB 43.9MB/s 
[?25h  Building wheel for fastBPE (setup.py) ... [?25l[?25hdone
  Building wheel for antlr4-python3-runtime (setup.py) ... [?25l[?25hdone


In [None]:
from typing import Type, Union

import torch
from transformers import FSMTForConditionalGeneration, FSMTTokenizer

class BaseTranslationModel:
    """Identity translator that concrete translation models extend.

    Subclasses override ``get_inference``; ``translate`` is the public entry
    point and simply forwards to it.
    """

    # Filled in by subclasses that wrap a real model / tokenizer.
    model = None
    tokenizer = None

    def get_inference(self, input: str) -> str:
        """Default inference: hand the text back unchanged."""
        unchanged = input
        return unchanged

    def translate(self, input: str) -> str:
        """Translate ``input`` by delegating to ``get_inference``."""
        translated = self.get_inference(input)
        return translated

# https://arxiv.org/abs/1907.06616
class FBWmt19(BaseTranslationModel):
    """WMT19 FSMT translation model loaded through Hugging Face ``transformers``.

    The model is moved to CUDA when a GPU is available, otherwise it stays on CPU.
    """

    def __init__(self, mname="facebook/wmt19-ru-en"):
        """Load the tokenizer and model weights for checkpoint ``mname``.

        Args:
            mname: model id or local path passed to ``from_pretrained``.
        """
        tokenizer = FSMTTokenizer.from_pretrained(mname)
        model = FSMTForConditionalGeneration.from_pretrained(mname)

        # ``is_available`` must be *called*: the bare function object is always
        # truthy, which forced ``.to('cuda')`` and crashed on CPU-only machines.
        if torch.cuda.is_available():
            model = model.to('cuda')

        self.model = model
        self.tokenizer = tokenizer

    def get_inference(self, input: str) -> str:
        """Translate one string and return the decoded first generated sequence."""
        input_ids = self.tokenizer.encode(input, return_tensors="pt")
        # Place inputs on whatever device the model lives on (CPU or GPU).
        input_ids = input_ids.to(self.model.device)

        outputs = self.model.generate(input_ids)
        return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
    

class FairseqWmt19(BaseTranslationModel):
    """WMT19 translation model loaded from the ``pytorch/fairseq`` torch.hub repo.

    The model is moved to CUDA when a GPU is available, otherwise it stays on CPU.
    """

    def __init__(self, mname = "transformer.wmt19.ru-en",
                 checkpoint_file='model1.pt:model2.pt:model3.pt:model4.pt'):
        """Load the hub model.

        Args:
            mname: torch.hub entry-point name in ``pytorch/fairseq``. It was
                previously ignored in favour of a hard-coded string with the same
                default value.
            checkpoint_file: forwarded to ``torch.hub.load``. The default loads the
                four checkpoints the original code hard-coded.
        """
        model = torch.hub.load('pytorch/fairseq', mname,
                               checkpoint_file=checkpoint_file,
                               tokenizer='moses', bpe='fastbpe')

        # Call ``is_available``: the bare function object is always truthy, which
        # forced ``.to('cuda')`` and crashed on CPU-only machines.
        if torch.cuda.is_available():
            model = model.to('cuda')

        self.model = model

    def get_inference(self, input: str) -> str:
        """Translate ``input`` using the hub interface's ``translate`` method."""
        return self.model.translate(input)

In [None]:
from tqdm import tqdm_notebook

class Translater:
    """Translate a text file line by line with a ``BaseTranslationModel``.

    Every input line is translated independently and written as one output line,
    so line k of the output corresponds to line k of the input.
    """

    def __init__(self,
                 input_path='/content/eval-ru-100.txt',
                 output_path='/content/answer.txt',
                 model: Union[Type[BaseTranslationModel], BaseTranslationModel] = FBWmt19):
        """
        Args:
            input_path: UTF-8 text file with one source sentence per line.
            output_path: destination file; it is overwritten.
            model: either a ``BaseTranslationModel`` subclass, which is
                instantiated here, or an already-built instance, which is reused
                as-is so loaded weights can serve several files.
        """
        self.input_path = input_path
        self.output_path = output_path
        # Instantiating a model class may load large weights, so allow reuse.
        self.model = model if isinstance(model, BaseTranslationModel) else model()

    def translate(self):
        """Translate every line of ``input_path`` and write the results to ``output_path``.

        Output lines are newline-separated with no trailing newline after the
        last one, matching the previous format. Blank input lines become blank
        output lines without invoking the model.
        """
        # ``tqdm_notebook`` is deprecated; ``tqdm.auto`` picks the notebook widget
        # when running in Jupyter and a console bar otherwise.
        from tqdm.auto import tqdm

        # Read and close the input before writing; strip the line terminator so
        # the model never sees a trailing '\n'.
        with open(self.input_path, 'r', encoding='utf-8') as src:
            lines = [line.rstrip('\n') for line in src]

        with open(self.output_path, 'w', encoding='utf-8') as dst:
            for i, line in enumerate(tqdm(lines)):
                # Keep input/output alignment: an empty source line maps to an
                # empty translation instead of whatever the model emits for "".
                translation = self.model.translate(line) if line.strip() else ''
                if i:
                    dst.write("\n")
                dst.write(translation)

In [None]:
# Translate the evaluation set with the default Hugging Face FSMT model.
# The paths are Translater's defaults, spelled out so readers see the files used.
tl = Translater(
    input_path='/content/eval-ru-100.txt',
    output_path='/content/answer.txt',
    model=FBWmt19,
)
tl.translate()

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`


HBox(children=(FloatProgress(value=0.0), HTML(value='')))


