In [1]:
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
model_name = "facebook/nllb-200-distilled-600M"
tokenizer = AutoTokenizer.from_pretrained(model_name, src_lang="run_Latn", tgt_lang="eng_Latn")
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

  return self.fget.__get__(instance, owner)()


In [2]:
import torch
import torch.quantization

In [3]:
# Assuming model is your loaded NLLB model
model.eval()  # Set the model to evaluation mode

# Specify quantization configuration
quantization_config = torch.quantization.get_default_qconfig('fbgemm')

In [4]:
total_params = sum(p.numel() for p in model.parameters())
total_params

615073792

In [5]:
# Apply the configuration to the model
model.qconfig = quantization_config
torch.quantization.prepare(model, inplace=True)




M2M100ForConditionalGeneration(
  (model): M2M100Model(
    (shared): M2M100ScaledWordEmbedding(256206, 1024, padding_idx=1)
    (encoder): M2M100Encoder(
      (embed_tokens): M2M100ScaledWordEmbedding(256206, 1024, padding_idx=1)
      (embed_positions): M2M100SinusoidalPositionalEmbedding()
      (layers): ModuleList(
        (0-11): 12 x M2M100EncoderLayer(
          (self_attn): M2M100SdpaAttention(
            (k_proj): Linear(
              in_features=1024, out_features=1024, bias=True
              (activation_post_process): HistogramObserver(min_val=inf, max_val=-inf)
            )
            (v_proj): Linear(
              in_features=1024, out_features=1024, bias=True
              (activation_post_process): HistogramObserver(min_val=inf, max_val=-inf)
            )
            (q_proj): Linear(
              in_features=1024, out_features=1024, bias=True
              (activation_post_process): HistogramObserver(min_val=inf, max_val=-inf)
            )
            (out_pr

In [6]:
total_params_quantized = sum(p.numel() for p in model.parameters())
total_params_quantized

615073792

In [7]:
tokenizer.src_lang = "run_Latn"

In [8]:
import re

def word_tokenize(text):
    """
    Split a text into words, numbers, and punctuation marks
    (for languages where words are separated by spaces)
    """
    return re.findall('(\w+|[^\w\s])', text)

In [9]:
from datasets import load_dataset
dataset = load_dataset("Muennighoff/flores200", 'eng_Latn-run_Latn')

In [10]:
import pandas as pd
dataset = dataset['dev']

In [11]:
flores_train = pd.DataFrame(([dataset['sentence_eng_Latn'], dataset['sentence_run_Latn']]))
flores_train = flores_train.T
flores_train.columns = ['eng', 'run']

In [12]:
from tqdm.auto import tqdm, trange
import random
texts_with_unk = [
    text for text in tqdm(flores_train.run) 
    if tokenizer.unk_token_id in tokenizer(text).input_ids
]
print(len(texts_with_unk))
# 163
s = random.sample(texts_with_unk, 5)
print(s)

  0%|          | 0/997 [00:00<?, ?it/s]

151
['Itangazo ry’uyu munsi ryahaye inguvu irindi Leta yashizeho mu kwa gatatu kw’uno mwaka ryo kwongereza izindi modoka.', "Mu 1977, Dogoteri Damadian yararangije kwubaka sikaneri ya mbere ya IRM “y'umubiri-wose”, ayita ”Mutananirwa”.", "Ibibanza vyo mu bice bizwi cane nka Bright Angel Campgound iri hafi ya Phantom Ranch, mu bisanzwe bifatwa vyose n'ababisaba kw’itariki ya mbere batangurirako kubikisha ibibanza.", 'Iyo raporo yerekanye ingene amanota y’ibibazo yaduze cane ku rugero rutangaje ko kandi ishure ryabibonye ntirikore na kimwe.', 'Umushikiranganji w’amagara y’abantu yavuze ko atewe impungenge n’abantu bariko bakoresha mategeko y’agateganyo kubijanye n’ukubaho kwabantu ku gatwe kabo, ndetse n’ibihano bifitanye isano n’ibiyovyabwenge vyatanzwe kuva habaye ihinduka rishingiye ku mategekomashasha.']


In [13]:
import re
import sys
import unicodedata
from sacremoses import MosesPunctNormalizer

mpn = MosesPunctNormalizer(lang="en")
mpn.substitutions = [
    (re.compile(r), sub) for r, sub in mpn.substitutions
]

def get_non_printing_char_replacer(replace_by: str = " "):
    non_printable_map = {
        ord(c): replace_by
        for c in (chr(i) for i in range(sys.maxunicode + 1))
        # same as \p{C} in perl
        # see https://www.unicode.org/reports/tr44/#General_Category_Values
        if unicodedata.category(c) in {"C", "Cc", "Cf", "Cs", "Co", "Cn"}
    }

    def replace_non_printing_char(line) -> str:
        return line.translate(non_printable_map)

    return replace_non_printing_char

replace_nonprint = get_non_printing_char_replacer(" ")

def preproc(text):
    clean = mpn.normalize(text)
    clean = replace_nonprint(clean)
    # replace 𝓕𝔯𝔞𝔫𝔠𝔢𝔰𝔠𝔞 by Francesca
    clean = unicodedata.normalize("NFKC", clean)
    return clean

In [14]:
from tqdm import tqdm
texts_with_unk_normed = [
    text for text in tqdm(texts_with_unk) 
    if tokenizer.unk_token_id in tokenizer(preproc(text)).input_ids
]
print(len(texts_with_unk_normed))  # 0

100%|██████████| 151/151 [00:00<00:00, 6139.63it/s]

0





In [36]:
from transformers.optimization import Adafactor
from transformers import get_constant_schedule_with_warmup
model.cuda();
def get_optimizer():
    optimizer = Adafactor(
        [p for p in model.parameters() if p.requires_grad],
        scale_parameter=False,
        relative_step=False,
        lr=1e-5,
        clip_threshold=1.0,
        weight_decay=1e-3,
    )
    scheduler = get_constant_schedule_with_warmup(optimizer, num_warmup_steps=1000)

    return optimizer, scheduler

In [37]:
flores_train

Unnamed: 0,eng,run
0,"On Monday, scientists from the Stanford Univer...","Ku wa mbere, abahinga bo kuri kaminuza yitwa S..."
1,Lead researchers say this may bring early dete...,Abashakashatsi nyamukuru bavuga ko ako gakores...
2,The JAS 39C Gripen crashed onto a runway at ar...,"Isaha 9:30 zo mu gitondo (0230 UTC), iyo ndege..."
3,The pilot was identified as Squadron Leader Di...,"Basanze umudereva yari Dilokrit Pattavee, umuk..."
4,Local media reports an airport fire vehicle ro...,Ibimenyeshamakuru vyaho bivuga ko hari kizimya...
...,...,...
992,The tourist season for the hill stations gener...,Igihe c'ingenzi mu mahuriro yo mu misozi mu bi...
993,"However, they have a different kind of beauty ...","N'aho biri ukwo, bifise ubwoko butandukanye bw..."
994,Only a few airlines still offer bereavement fa...,Amashirahamwe y'ivyindege amwe gusa niyo azota...
995,"Airlines that offer these include Air Canada, ...",Amashirahamwe y'ivyindege atanga ivyo harimwo ...


In [38]:
import random
LANGS = [('eng', 'eng_Latn'), ('run', 'run_Latn')]

def get_batch_pairs(batch_size, data=flores_train, langs=LANGS):
    (l1, long1), (l2, long2) = random.sample(langs, 2)
    xx, yy = [], []
    for _ in range(batch_size):
        item = data.iloc[random.randint(0, len(data)-1)]
        xx.append(preproc(item[l1]))
        yy.append(preproc(item[l2]))
    return xx, yy, long1, long2

print(get_batch_pairs(1))

(["Ishirahamwe Virgin Group, rya Richard Branson, ryaciye ryankirwa, imbere y'uko iyo banki ishirwa mu minwe ya reta."], ["Sir Richard Branson's Virgin Group had a bid for the bank rejected prior to the bank's nationalisation."], 'run_Latn', 'eng_Latn')


In [39]:
from datasets import load_dataset

def get_dataset(lang):
    dataset = load_dataset("Muennighoff/flores200", f'eng_Latn-{lang}_Latn')
    dataset = dataset['dev']
    dataset = pd.DataFrame(([dataset['sentence_eng_Latn'], dataset[f'sentence_{lang}_Latn']]))
    dataset = dataset.T
    dataset.columns = ['eng', lang]
    return dataset

In [40]:
import gc
import torch

def cleanup():
    """Try to free GPU memory"""
    gc.collect()
    torch.cuda.empty_cache()

In [41]:
batch_size = 16  # 32 already doesn't fit well to 15GB of GPU memory
max_length = 128  # token sequences will be truncated
training_steps = 100  # Usually, I set a large number of steps,
# and then just interrupt the training manually
losses = []  # with this list, I do very simple tracking of average loss
MODEL_SAVE_PATH = './NLLB/nllb-eng-kir-v1'  # on my Google drive

In [42]:
optimizers = []
schedulers = []
for i in range(3):
    optimizer, scheduler = get_optimizer()
    optimizers.append(optimizer)
    schedulers.append(scheduler)

In [43]:
import copy
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
models = [copy.deepcopy(model) for i in range(3)]

tokenizer1 = AutoTokenizer.from_pretrained(model_name, src_lang="bem_Latn", tgt_lang="eng_Latn")
tokenizer2 = AutoTokenizer.from_pretrained(model_name, src_lang="kin_Latn", tgt_lang="eng_Latn")
tokenizer3 = AutoTokenizer.from_pretrained(model_name, src_lang="lug_Latn", tgt_lang="eng_Latn")

tokenizers = [tokenizer1, tokenizer2, tokenizer3]

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


In [44]:
names = ['bem', 'kin', 'lug']
model_langs = [[('eng', 'eng_Latn'), ('bem', 'bem_Latn')], [('eng', 'eng_Latn'), ('kin', 'kin_Latn')], [('eng', 'eng_Latn'), ('lug', 'lug_Latn')]]

In [45]:
model_data = [get_dataset(x) for x in ['bem', 'kin', 'lug']]

In [47]:
import numpy as np

for i in range(3):
    models[i].train()
x, y, loss = None, None, None
cleanup()

global_weights = models[0].state_dict()

tq = trange(len(losses), training_steps)
for n in tq:
    w, local_loss = [], 0
    for i in range(3):
        models[i].load_state_dict(global_weights)
        xx, yy, lang1, lang2 = get_batch_pairs(batch_size, data=model_data[i], langs=model_langs[i])
        try:
            tokenizers[i].src_lang = lang1
            x = tokenizers[i](xx, return_tensors='pt', padding=True, truncation=True, max_length=max_length).to(models[i].device)
            tokenizers[i].src_lang = lang2
            y = tokenizers[i](yy, return_tensors='pt', padding=True, truncation=True, max_length=max_length).to(models[i].device)
            # -100 is a magic value ignored in the loss function
            # because we don't want the model to learn to predict padding ids
            y.input_ids[y.input_ids == tokenizers[i].pad_token_id] = -100

            loss = models[i](**x, labels=y.input_ids).loss
            loss.backward()
            local_loss += loss.item()

            torch.nn.utils.clip_grad_norm_(models[i].parameters(), max_norm=1.0)

            optimizers[i].step()
            optimizers[i].zero_grad(set_to_none=True)
            schedulers[i].step()

        except RuntimeError as e:  # usually, it is out-of-memory
            optimizers[i].zero_grad(set_to_none=True)
            x, y, loss = None, None, None
            cleanup()
            print('error', max(len(s) for s in xx + yy), e)
            continue

        w.append(copy.deepcopy(models[i].state_dict()))
        
        if n % 10 == 0 and n > 0:
            model.save_pretrained(f'./NLLB/{names[i]}')
            tokenizers[i].save_pretrained(f'./NLLB/{names[i]}')

    weights_avg = copy.deepcopy(w[0])
    for k in weights_avg.keys():
        for i in range(1, len(w)):
            weights_avg[k] += w[i][k]

        weights_avg[k] = torch.div(weights_avg[k], len(w))
    
    global_weights = weights_avg

    losses = [(local_loss / 3)]

    if n % 1 == 0:
        # each 1000 steps, I report average loss at these steps
        print(i, losses)

   

  0%|          | 0/100 [00:00<?, ?it/s]

2 [2.282008330027262]




2 [2.282008330027262, 9.868823687235514, 167.23800659179688, 1172.4419962565105, 6177.994791666667, 20783.942057291668, 63603.928385416664, 193023.86979166666, 579949.75, 1728642.6666666667, 5183393.5]
2 [2.282008330027262, 9.868823687235514, 167.23800659179688, 1172.4419962565105, 6177.994791666667, 20783.942057291668, 63603.928385416664, 193023.86979166666, 579949.75, 1728642.6666666667, 5183393.5, 15443793.666666666, 46623834.666666664, 138769333.33333334, 415407658.6666667, 1259567744.0, 3767256405.3333335, 11337098240.0, 34228019882.666668, 102652581205.33333, 305526876842.6667]
2 [2.282008330027262, 9.868823687235514, 167.23800659179688, 1172.4419962565105, 6177.994791666667, 20783.942057291668, 63603.928385416664, 193023.86979166666, 579949.75, 1728642.6666666667, 5183393.5, 15443793.666666666, 46623834.666666664, 138769333.33333334, 415407658.6666667, 1259567744.0, 3767256405.3333335, 11337098240.0, 34228019882.666668, 102652581205.33333, 305526876842.6667, 916294270976.0, 2773

KeyboardInterrupt: 

In [26]:
models[2].state_dict()

OrderedDict([('model.shared.weight',
              tensor([[-inf, inf, inf,  ..., inf, -inf, -inf],
                      [-inf, inf, -inf,  ..., inf, -inf, -inf],
                      [-inf, -inf, -inf,  ..., inf, -inf, -inf],
                      ...,
                      [-inf, -inf, -inf,  ..., inf, -inf, -inf],
                      [inf, -inf, -inf,  ..., inf, -inf, -inf],
                      [-inf, -inf, -inf,  ..., inf, -inf, -inf]])),
             ('model.encoder.embed_tokens.weight',
              tensor([[-inf, inf, inf,  ..., inf, -inf, -inf],
                      [-inf, inf, -inf,  ..., inf, -inf, -inf],
                      [-inf, -inf, -inf,  ..., inf, -inf, -inf],
                      ...,
                      [-inf, -inf, -inf,  ..., inf, -inf, -inf],
                      [inf, -inf, -inf,  ..., inf, -inf, -inf],
                      [-inf, -inf, -inf,  ..., inf, -inf, -inf]])),
             ('model.encoder.layers.0.self_attn.k_proj.weight',
              te

In [27]:
batch_size = 16  # 32 already doesn't fit well to 15GB of GPU memory
max_length = 128  # token sequences will be truncated
training_steps = 100  # Usually, I set a large number of steps,
# and then just interrupt the training manually
losses = []  # with this list, I do very simple tracking of average loss
MODEL_SAVE_PATH = './NLLB/nllb-eng-kir-v1'  # on my Google drive

In [28]:
import numpy as np

models[0].train()
x, y, loss = None, None, None
cleanup()

tq = trange(len(losses), training_steps)
for i in tq:
    xx, yy, lang1, lang2 = get_batch_pairs(batch_size)
    try:
        tokenizer.src_lang = lang1
        x = tokenizer(xx, return_tensors='pt', padding=True, truncation=True, max_length=max_length).to(model.device)
        tokenizer.src_lang = lang2
        y = tokenizer(yy, return_tensors='pt', padding=True, truncation=True, max_length=max_length).to(model.device)
        # -100 is a magic value ignored in the loss function
        # because we don't want the model to learn to predict padding ids
        y.input_ids[y.input_ids == tokenizer.pad_token_id] = -100

        loss = models[0](**x, labels=y.input_ids).loss
        loss.backward()
        losses.append(loss.item())

        optimizer.step()
        optimizer.zero_grad(set_to_none=True)
        scheduler.step()

    except RuntimeError as e:  # usually, it is out-of-memory
        optimizer.zero_grad(set_to_none=True)
        x, y, loss = None, None, None
        cleanup()
        print('error', max(len(s) for s in xx + yy), e)
        continue

    if i % 100 == 0:
        # each 1000 steps, I report average loss at these steps
        print(i, np.mean(losses[-1000:]))

    if i % 100 == 0 and i > 0:
        model.save_pretrained(MODEL_SAVE_PATH)
        tokenizer.save_pretrained(MODEL_SAVE_PATH)

  0%|          | 0/100 [00:00<?, ?it/s]

0 nan


KeyboardInterrupt: 

In [None]:

# Convert the model to a quantized version
torch.quantization.convert(model, inplace=True)

In [29]:
def translate(
    text, model, src_lang='run_Latn', tgt_lang='eng_Latn', 
    a=32, b=3, max_input_length=1024, num_beams=4, **kwargs
):
    """Turn a text or a list of texts into a list of translations"""
    tokenizer.src_lang = src_lang
    tokenizer.tgt_lang = tgt_lang
    inputs = tokenizer(
        text, return_tensors='pt', padding=True, truncation=True, 
        max_length=max_input_length
    )
    model.eval() # turn off training mode
    result = model.generate(
        **inputs.to(model.device),
        forced_bos_token_id=tokenizer.convert_tokens_to_ids(tgt_lang),
        max_new_tokens=int(a + b * inputs.input_ids.shape[1]),
        num_beams=num_beams, **kwargs
    )
    return tokenizer.batch_decode(result, skip_special_tokens=True)

In [30]:
def batched_translate(texts, model, batch_size=16, **kwargs):
    """Translate texts in batches of similar length"""
    idxs, texts2 = zip(*sorted(enumerate(texts), key=lambda p: len(p[1]), reverse=True))
    results = []
    for i in trange(0, len(texts2), batch_size):
        results.extend(translate(texts2[i: i+batch_size], model, **kwargs))
    return [p for i, p in sorted(zip(idxs, results))]

In [32]:
from datasets import load_dataset
dataset_test = load_dataset("Muennighoff/flores200", 'eng_Latn-run_Latn')

In [33]:
dataset_test = dataset_test['devtest']
flores_test = pd.DataFrame(([dataset_test['sentence_eng_Latn'], dataset_test['sentence_run_Latn']]))
flores_test = flores_test.T
flores_test.columns = ['eng', 'run']

In [34]:
flores_test

Unnamed: 0,eng,run
0,"""We now have 4-month-old mice that are non-dia...","Yongeyeko ati: ""Ubu turafise imbeba y'amezi 4 ..."
1,"Dr. Ehud Ur, professor of medicine at Dalhousi...","Umuhinga Ehud Ur, umwigisha w'ivy'ubuganga kur..."
2,"Like some other experts, he is skeptical about...","Cokimwe n'abandi bahinga, arafise amakenga ku ..."
3,"On Monday, Sara Danius, permanent secretary of...","Ku wa mbere, Sara Danius, umunyamabanga ntayeg..."
4,"Danius said, ""Right now we are doing nothing. ...","Danius yavuze ati: ""Ubu nta co turiko turakora..."
...,...,...
1007,"As the areas are sparsely populated, and light...","Kuko ivyo bice bibamwo abantu inkehwa, kandi n..."
1008,Japanese work culture is more hierarchical and...,Akaranga mu kazi k'Abayapani karasumbasumbana ...
1009,"Suits are standard business attire, and cowork...","Ikositimu niwo mwambaro w'akazi umenyerewe, ka..."
1010,"Workplace harmony is crucial, emphasizing grou...",Itunganywa ryiza ry'ikibanza c'akazi ni ngombw...


In [35]:
translations = batched_translate(flores_test['run'].tolist()[:100], models[0])

  0%|          | 0/7 [00:00<?, ?it/s]

In [36]:
import sacrebleu
bleu_calc = sacrebleu.BLEU()
chrf_calc = sacrebleu.CHRF(word_order=2)  # this metric is called ChrF++

print(bleu_calc.corpus_score(translations, [flores_test['eng'].tolist()]))
print(chrf_calc.corpus_score(translations, [flores_test['eng'].tolist()]))

BLEU = 0.00 0.0/0.0/0.0/0.0 (BP = 1.000 ratio = 8.328 hyp_len = 21420 ref_len = 2572)
chrF2++ = 0.97


23.453369494797222 in 14 minutes inference

In [179]:
flores_train[['eng', 'run']].to_csv('flores-eng-kir.csv', sep='\t', header=False, index=False)

In [170]:
flores

Unnamed: 0,lat,run,translated
0,"""We now have 4-month-old mice that are non-dia...","Yongeyeko ati: ""Ubu turafise imbeba y'amezi 4 ...","He added: ""We now have four-month-old mice who..."
1,"Dr. Ehud Ur, professor of medicine at Dalhousi...","Umuhinga Ehud Ur, umwigisha w'ivy'ubuganga kur...","Professor Ehud Ur, a professor of medicine at ..."
2,"Like some other experts, he is skeptical about...","Cokimwe n'abandi bahinga, arafise amakenga ku ...","Like other scientists, he is skeptical of the ..."
3,"On Monday, Sara Danius, permanent secretary of...","Ku wa mbere, Sara Danius, umunyamabanga ntayeg...","On Monday, Sara Danius, permanent secretary of..."
4,"Danius said, ""Right now we are doing nothing. ...","Danius yavuze ati: ""Ubu nta co turiko turakora...","Danius says: ""Now that we're doing nothing, I'..."
...,...,...,...
1007,"As the areas are sparsely populated, and light...","Kuko ivyo bice bibamwo abantu inkehwa, kandi n...","For this is a small part of the human family, ..."
1008,Japanese work culture is more hierarchical and...,Akaranga mu kazi k'Abayapani karasumbasumbana ...,Japanese craftsmanship is more sophisticated a...
1009,"Suits are standard business attire, and cowork...","Ikositimu niwo mwambaro w'akazi umenyerewe, ka...","Costumes are the most common work clothes, and..."
1010,"Workplace harmony is crucial, emphasizing grou...",Itunganywa ryiza ry'ikibanza c'akazi ni ngombw...,"Good workplace planning is essential, celebrat..."
