In [2]:
import torch
import pandas as pd
from transformers import M2M100Tokenizer, M2M100ForConditionalGeneration

In [3]:
# Run on Apple's Metal (MPS) backend when PyTorch reports it is usable; otherwise use the CPU.
if torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"Using device: {device}")

Using device: mps


In [4]:
# Load the M2M100 translation checkpoint onto `device`, together with its matching tokenizer.
MODEL_NAME = "facebook/m2m100_418M"
model = M2M100ForConditionalGeneration.from_pretrained(MODEL_NAME)
model = model.to(device)
tokenizer = M2M100Tokenizer.from_pretrained(MODEL_NAME)

In [5]:
# Input sample of tweets; the translation step reads its 'language' and 'tweet_text' columns.
INPUT_CSV = "sampled_data.csv"
data = pd.read_csv(INPUT_CSV)

In [6]:
# Columns whose text is translated into English.
translate_column = ['tweet_text']


def translate_function(index, row, fallback_src_lang="en"):
    """Translate the configured text columns of one DataFrame row into English.

    Meant to be used with ``DataFrame.apply(..., axis=1)``. Rows whose
    ``language`` is ``'en'`` are returned unchanged. For every other row, each
    column listed in ``translate_column`` that holds a string is translated with
    the global M2M100 ``model`` / ``tokenizer`` on ``device``.

    Parameters
    ----------
    index : hashable
        Row label; used only in the progress / error messages.
    row : pandas.Series
        The row to translate. It is never modified in place; a copy is returned.
    fallback_src_lang : str, default "en"
        Source-language code given to the tokenizer when the row's language is
        ``'und'`` (undetermined). Setting it explicitly stops such rows from
        inheriting whatever ``tokenizer.src_lang`` the previous row left behind.

    Returns
    -------
    pandas.Series
        ``row`` itself for English rows; otherwise a copy of ``row`` whose text
        columns hold the English translation. A cell whose translation raises
        keeps its original text, and the error is printed.
    """
    language = row['language']
    if language == 'en':
        return row

    # DataFrame.apply may pass a Series that shares memory with the source frame,
    # and mutating it is unsupported by pandas, so write into a private copy.
    row = row.copy()
    # tokenizer.src_lang is shared global state: always set it explicitly so the
    # result never depends on which row was translated before this one.
    src_lang = fallback_src_lang if language == 'und' else language

    for column in translate_column:
        text = row[column]
        # Only real strings are translated (NaN / None / non-text cells are skipped).
        if not isinstance(text, str):
            continue
        try:
            # Inside the try: an unsupported language code raises here and is reported below.
            tokenizer.src_lang = src_lang
            model_inputs = tokenizer(text, return_tensors="pt").to(device)
            # Force English as the first generated token so the output language is English.
            gen_tokens = model.generate(**model_inputs, forced_bos_token_id=tokenizer.get_lang_id("en"))
            row[column] = tokenizer.batch_decode(gen_tokens, skip_special_tokens=True)[0]
            print(f"Translated row {index}")
        except Exception as e:
            # Best-effort: keep the original text for this cell and continue.
            print(f"Error translating row {index}: {e}")
    return row

# Translate row by row with DataFrame.apply (axis="columns" is the same as axis=1);
# each row's index label is available as row.name.
def _translate_row(row):
    return translate_function(row.name, row)

data_new = data.apply(_translate_row, axis="columns")

Translated row 2
Translated row 3
Translated row 4
Translated row 7
Translated row 9
Translated row 10
Translated row 11
Translated row 12
Translated row 14
Translated row 16
Translated row 17
Translated row 18
Translated row 20
Translated row 24
Translated row 27
Translated row 31
Translated row 32
Translated row 33
Translated row 37
Translated row 38
Translated row 41
Translated row 42
Translated row 43
Translated row 49
Translated row 52
Translated row 54
Translated row 56
Translated row 57
Translated row 59
Translated row 61
Translated row 63
Translated row 64
Translated row 65
Translated row 66
Translated row 68
Translated row 70
Translated row 71
Translated row 73
Translated row 75
Translated row 76
Translated row 77
Translated row 79
Translated row 80
Translated row 83
Translated row 84
Translated row 85
Translated row 86
Translated row 87
Translated row 88
Translated row 89
Translated row 90
Translated row 91
Translated row 93
Translated row 94
Translated row 96
Translated row 

In [7]:
# Persist the translated frame; the DataFrame index is not written as a column.
output_path = "sampled_data_translated.csv"
data_new.to_csv(output_path, index=False)