In [None]:
!pip install transformers==4.16.0
!pip install torch
!pip install sentencepiece

In [None]:
import random
import json
import urllib.request
import torch
from tqdm import tqdm
from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer, AdamW
from torch.utils.data import Dataset, DataLoader

### Load Dataset

In [None]:
# Download `filtered.json` from the kt2k01/petci GitHub repository and parse it.
url = 'https://raw.githubusercontent.com/kt2k01/petci/main/data/json/filtered.json'
# The context manager closes the HTTP response even if reading or parsing fails.
with urllib.request.urlopen(url) as response:
    data = json.loads(response.read())

In [None]:
# Flatten the idiom records into {'id', 'prompt', 'completion'} training pairs.
# Each idiom contributes up to two pairs that share the same Chinese prompt:
# one for its 'gold' translation (if present) and one for its first 'human'
# translation (if any).
# NOTE(review): the literal "</s>" suffix is kept from the original; presumably
# the M2M100 tokenizer appends its own end-of-sequence token, so confirm this
# suffix is still wanted.
format_data = []
i = 0  # running id, unique across every emitted pair
for idiom in data:
    chinese = idiom['chinese']

    references = []
    if 'gold' in idiom:
        references.append(idiom['gold'])
    # .get() so a record without a 'human' key (or with an empty list) is
    # skipped instead of raising KeyError, matching how 'gold' is handled.
    if idiom.get('human'):
        references.append(idiom['human'][0])

    for reference in references:
        format_data.append({
            'id': i,
            'prompt': chinese + "</s>",
            'completion': reference + "</s>",
        })
        i += 1


In [None]:
    # Fixed seed so the shuffle, and therefore every split below, is reproducible.
    random.seed(10)

    # Shuffle a copy: shuffling `format_data` in place would make re-running this
    # cell shuffle an already-shuffled list and silently change the splits.
    shuffled_data = list(format_data)
    random.shuffle(shuffled_data)

    # 80% of the pairs go to train+validation, the remaining 20% to test;
    # the first 20% of the train portion is then held out for validation.
    # Slicing replaces the original `x not in other_list` filters, which did a
    # linear scan per element (quadratic overall) to compute the same partition.
    # NOTE(review): the split is per pair, not per idiom, so the gold and human
    # pairs of one idiom can land in different splits -- confirm this is intended.
    len_train = int(len(shuffled_data) * 0.8)
    len_validation = int(len_train * 0.2)

    test_data = shuffled_data[len_train:]
    validation_data = shuffled_data[:len_validation]
    train_data = shuffled_data[len_validation:len_train]

    # Parallel source/target lists consumed by the training dataset.
    src = ["translate " + item["prompt"] for item in train_data]
    tgt = [item["completion"] for item in train_data]

### Load the pretrained model

In [None]:
from transformers import M2M100Config, M2M100ForConditionalGeneration, M2M100Tokenizer

# Load the pretrained facebook/m2m100_418M checkpoint and its tokenizer,
# configured for Chinese ("zh") source text and English ("en") target text.
model = M2M100ForConditionalGeneration.from_pretrained("facebook/m2m100_418M")
tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_418M", src_lang="zh", tgt_lang="en")

# Use a GPU when one is available. The training loop sends every batch to
# `model.device`, so moving the model here is all that is needed; it must
# happen before the optimizer is built from `model.parameters()`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

### Tokenize data and set up training

In [None]:
from torch.utils.data import Dataset, DataLoader


class TranslationDataset(Dataset):
    """Map-style dataset of tokenized (source, target) translation pairs.

    Args:
        src_texts: Sequence of source-language strings.
        tgt_texts: Sequence of target-language strings, aligned by index with
            ``src_texts``.
        tokenizer: Callable tokenizer accepting ``padding``, ``truncation``,
            ``max_length`` and ``return_tensors`` keyword arguments and
            returning a mapping with ``'input_ids'`` and ``'attention_mask'``.
            If it provides ``as_target_tokenizer()``, targets are encoded
            inside that context.
        max_length: Every sequence is padded or truncated to this many tokens.

    Raises:
        ValueError: If ``src_texts`` and ``tgt_texts`` differ in length.
    """

    def __init__(self, src_texts, tgt_texts, tokenizer, max_length):
        if len(src_texts) != len(tgt_texts):
            raise ValueError(
                f"src_texts and tgt_texts must have the same length, "
                f"got {len(src_texts)} and {len(tgt_texts)}"
            )
        self.src_texts = src_texts
        self.tgt_texts = tgt_texts
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        """Return the number of sentence pairs."""
        return len(self.src_texts)

    def _encode(self, text, as_target=False):
        """Tokenize one string to fixed-length tensors of shape (1, max_length)."""
        kwargs = dict(padding='max_length', truncation=True, max_length=self.max_length,
                      return_tensors='pt')
        enter_target_mode = getattr(self.tokenizer, 'as_target_tokenizer', None)
        if as_target and enter_target_mode is not None:
            # Target text must be encoded in target mode so it carries the
            # target-language prefix token rather than the source-language one.
            with enter_target_mode():
                return self.tokenizer(text, **kwargs)
        # Tokenizers without a target mode are called as in the original code.
        return self.tokenizer(text, **kwargs)

    def __getitem__(self, index):
        """Return ``input_ids``, ``attention_mask`` and ``target_ids`` for one pair.

        Each value is a 1-D tensor of length ``max_length``.
        """
        src_tokens = self._encode(str(self.src_texts[index]))
        tgt_tokens = self._encode(str(self.tgt_texts[index]), as_target=True)

        # squeeze(0) drops only the leading batch dimension; a bare squeeze()
        # would also collapse the sequence dimension when max_length == 1.
        return {
            'input_ids': src_tokens['input_ids'].squeeze(0),
            'attention_mask': src_tokens['attention_mask'].squeeze(0),
            'target_ids': tgt_tokens['input_ids'].squeeze(0),
        }


In [None]:
# Wrap the training pairs in a dataset (sequences padded/truncated to 20 tokens)
# and serve them in shuffled mini-batches of 64.
translated_data = TranslationDataset(
    src_texts=src,
    tgt_texts=tgt,
    tokenizer=tokenizer,
    max_length=20,
)
train_loader = DataLoader(dataset=translated_data, batch_size=64, shuffle=True)

# AdamW over all model parameters; the StepLR schedule multiplies the learning
# rate by 0.9 on every call to scheduler.step().
optimizer = AdamW(params=model.parameters(), lr=5e-5)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.9)

### Train the model

In [None]:
# Fine-tune the model for a fixed number of passes over the training data.
epochs = 3

# from_pretrained() returns the model in evaluation mode (dropout disabled);
# switch to training mode before optimizing.
model.train()

for epoch in range(epochs):
    print(f'Epoch {epoch + 1}/{epochs}')
    for batch in tqdm(train_loader):
        input_ids = batch['input_ids'].to(model.device)
        attention_mask = batch['attention_mask'].to(model.device)
        target_ids = batch['target_ids'].to(model.device)

        # Padding positions are set to -100 so the model's built-in
        # cross-entropy loss ignores them.
        labels = target_ids.masked_fill(target_ids == tokenizer.pad_token_id, -100)

        # Passing `labels` lets the model build its own decoder inputs by
        # shifting the labels right behind the decoder start token, which is
        # the layout generate() uses at inference time. The previous manual
        # [:, :-1] / [:, 1:] shift left that start token out.
        outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels,
                        use_cache=False)
        loss = outputs.loss

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

    # Decay the learning rate once per epoch. Stepping StepLR(gamma=0.9) after
    # every batch shrank the learning rate to effectively zero within roughly
    # a hundred updates.
    scheduler.step()

    # Loss of the last batch in this epoch.
    print('Loss:', loss.item())

### Save the model locally

In [None]:
# Write the fine-tuned weights and the tokenizer files into one local directory
# so both can later be restored from the same path.
save_dir = 'm2m100_chinese_to_english'
model.save_pretrained(save_dir)
tokenizer.save_pretrained(save_dir)