In [3]:
from datasets import load_dataset, Dataset
from easynmt import EasyNMT
import json
import torch
# Check if GPU is available; fall back to CPU otherwise
device = "cuda:0" if torch.cuda.is_available() else "cpu"
print("Using device:", device)

# Initialize the EasyNMT 'opus-mt' model on the selected device
model = EasyNMT('opus-mt', device=device)

# Source languages to translate into English. This single tuple drives both the
# per-language split dict (`ds`) and the output container (`language_dataset`)
# so the two cannot drift apart.
LANGUAGES = ("de", "fr", "es", "ja", "zh")

# Load dataset and keep only the columns used downstream
dataset = load_dataset("hungnm/multilingual-amazon-review-sentiment-processed")
desired_features = ['stars', 'text', 'language', 'label']
reduced_dataset = dataset.select_columns(desired_features)
# NOTE: despite its name, `train` holds the *validation* split (the export
# cell writes `*_val_en.json`); the name is kept for compatibility.
train = reduced_dataset['validation']

# Filter the split by language. The language code is passed via `fn_kwargs`
# instead of a closure over the comprehension variable.
ds = {
    lang: train.filter(
        lambda example, target: example['language'] == target,
        fn_kwargs={"target": lang},
    )
    for lang in LANGUAGES
}
# A typo in LANGUAGES (or an upstream dataset change) would otherwise yield an
# empty split silently and only surface much later as an empty output file.
for lang, subset in ds.items():
    assert len(subset) > 0, f"no rows found for language {lang!r}"

# Keep the per-language names exposed by the original version of this cell
de_train, fr_train, es_train, ja_train, zh_train = (
    ds["de"], ds["fr"], ds["es"], ds["ja"], ds["zh"]
)

# Per-language output containers, filled by the translation cell
language_dataset = {lang: [] for lang in LANGUAGES}




Using device: cuda:0


In [4]:
# Inspect the per-language splits (features and row counts)
ds

{'de': Dataset({
     features: ['stars', 'text', 'language', 'label'],
     num_rows: 7848
 }),
 'fr': Dataset({
     features: ['stars', 'text', 'language', 'label'],
     num_rows: 7929
 }),
 'es': Dataset({
     features: ['stars', 'text', 'language', 'label'],
     num_rows: 7921
 }),
 'ja': Dataset({
     features: ['stars', 'text', 'language', 'label'],
     num_rows: 8000
 }),
 'zh': Dataset({
     features: ['stars', 'text', 'language', 'label'],
     num_rows: 8000
 })}

In [5]:
# Translate each language's reviews to English and store them in
# `language_dataset[language]` as a list of row dicts with keys
# 'text', 'en', 'label', 'stars'.
# Assigning (rather than `.extend`-ing) makes this cell safe to re-run: a
# second run replaces previous results instead of duplicating every row.
for language, dset in ds.items():
    source_texts = dset['text']  # materialise the column once; reused below
    translated_texts = model.translate(
        source_texts,
        source_lang=language,
        target_lang='en',
        show_progress_bar=True,
    )
    # zip() would silently truncate on a length mismatch; fail loudly instead
    if len(translated_texts) != len(source_texts):
        raise ValueError(
            f"{language}: got {len(translated_texts)} translations "
            f"for {len(source_texts)} inputs"
        )
    language_dataset[language] = [
        {'text': text, 'en': en, 'label': label, 'stars': stars}
        for text, en, label, stars in zip(
            source_texts, translated_texts, dset['label'], dset['stars']
        )
    ]

  full_bar = Bar(frac,
100%|██████████| 15024/15023.0 [10:21<00:00, 24.19it/s]
12272it [10:05, 20.28it/s]                             
100%|██████████| 11776/11771.0 [05:50<00:00, 33.56it/s]
100%|██████████| 8048/8041.0 [07:28<00:00, 17.93it/s]
100%|██████████| 10736/10733.0 [06:24<00:00, 27.92it/s]


In [6]:
# Write one JSON file per language (e.g. `de_val_en.json`), each containing
# that language's list of translated records.
# UTF-8 with ensure_ascii=False keeps ja/zh/accented text human-readable
# instead of `\uXXXX` escapes; the decoded JSON content is unchanged.
for language, rows in language_dataset.items():
    file_name = f'{language}_val_en.json'
    with open(file_name, "w", encoding="utf-8") as file:
        json.dump(rows, file, ensure_ascii=False)
