In [1]:
import torch
from datasets import load_dataset

raw_dataset1 = load_dataset("kde4", lang1="ja", lang2="pl")
raw_dataset2 = load_dataset("json", data_files="./ted_multi_jp-pl.json")

print(f"DATASET 1 ({len(raw_dataset1['train'])})")
print(raw_dataset1['train'][5]['translation'])
print(f"DATASET 2 ({len(raw_dataset2['train'])})")
print(raw_dataset2['train'][5]['translation'])

  from .autonotebook import tqdm as notebook_tqdm


DATASET 1 (125653)
{'ja': 'kinfocenter', 'pl': 'kinfocenter'}
DATASET 2 (165758)
{'ja': '我々は彼らの対処に苦慮しています', 'pl': 'Zastanawiamy się , jak sobie z nimi radzić'}


## Hyperparameters

In [40]:
BATCH_SIZE = 32
LERANING_RATE = 0.1
MAX_LENGTH = 128

## Podzielenie danych na treningowe i testowe

In [3]:
split_dataset1 = raw_dataset1['train'].train_test_split(train_size=0.9)
split_dataset2 = raw_dataset2['train'].train_test_split(train_size=0.9)

In [4]:
torch.cuda.empty_cache()

### Przygotowanie modelu

In [5]:
from transformers import M2M100Config, M2M100ForConditionalGeneration, M2M100Tokenizer

model = M2M100ForConditionalGeneration.from_pretrained("facebook/m2m100_418M")
tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_418M", src_lang="ja", tgt_lang="pl")

src_text = "我々は彼らの対処に苦慮しています"
tgt_text = "Zastanawiamy się , jak sobie z nimi radzić"

optimizer = torch.optim.Adam(params= model.parameters(), lr=LERANING_RATE)

model_inputs = tokenizer(src_text, text_target=tgt_text, return_tensors="pt")

loss = model(**model_inputs).loss
loss

tensor(1.9513, grad_fn=<NllLossBackward0>)

### Testowe tłumaczenie

In [10]:
src_text = "我々は彼らの対処に苦慮しています"
tgt_text = "Zastanawiamy się , jak sobie z nimi radzić"
model.eval()
encoded_string = tokenizer(src_text, return_tensors="pt")
generated_tokens = model.generate(**encoded_string, forced_bos_token_id=tokenizer.get_lang_id("pl"))
tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)

['Mamy kłopoty z ich traktowaniem.']

### Data Loadery

In [27]:
train_dataLoader = torch.utils.data.DataLoader(dataset=split_dataset1['train']["translation"],
                                              batch_size=BATCH_SIZE,
                                              shuffle=True)
test_dataLoader = torch.utils.data.DataLoader(dataset=split_dataset1['test']["translation"],
                                              batch_size=BATCH_SIZE,
                                              shuffle=False)

print(f"Train dataLoader: {len(train_dataLoader)} batches of {BATCH_SIZE}")
print(f"Test dataLoader: {len(test_dataLoader)} batches of {BATCH_SIZE}")

Train dataLoader: 3534 batches of 32
Test dataLoader: 393 batches of 32


### Podział na batche

In [43]:
def preprocess_function(examples):
    model_inputs = tokenizer(
        examples['ja'], text_target=examples['pl'], max_length=MAX_LENGTH, truncation=True
    )
    return model_inputs

# FINE TUNING!!!!!

In [46]:
from timeit import default_timer as timer
from tqdm.auto import tqdm

epochs = 1

for epoch in tqdm(range(0, epochs)):
    model.train()
    
    for id, batch in enumerate(train_dataLoader):
        model_inputs = preprocess_function(batch)
        #___Pętla treningowa___
        #1. Loss
        loss = model(**model_inputs).loss
        #2. optimizer zero grad
        optimizer.zero_grad()
        #3. Backward
        loss.backward()
        #5. Step
        optimizer.step()
        print(loss)
    

  0%|                                                                                            | 0/1 [00:00<?, ?it/s]


In [None]:
model.eval()
with torch.inference_mode():
    encoded_string = tokenizer(src_text, return_tensors="pt")
    generated_tokens = model.generate(**encoded_string, forced_bos_token_id=tokenizer.get_lang_id("pl"))
    test = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
test

### Zapisywanie modelu

In [None]:
from pathlib import Path
MODEL_NAME = "Translator_v0.2.pt"

MODEL_PATH = Path("Models")
MODEL_PATH.mkdir(parents=True,
                exist_ok=True)

MODEL_SAVE_PATH = MODEL_PATH / MODEL_NAME

torch.save(modelV2.state_dict(), MODEL_SAVE_PATH)