# Subtasks 1 & 2: Translation with MarianMT

In [None]:
import torch
from transformers import MarianMTModel, MarianTokenizer
import os
import csv
from tqdm import tqdm
from pathlib import Path

class Translator:
    """Translate text from one source language into English with MarianMT.

    The checkpoint ``Helsinki-NLP/opus-mt-{source_lang}-en`` is loaded once in
    the constructor and reused by every translation method.
    """

    def __init__(self, source_lang: str, device: "str | torch.device"):
        """Load the tokenizer and the model for ``source_lang``.

        Args:
            source_lang: Language code used to build the checkpoint name
                (e.g. ``"it"`` -> ``Helsinki-NLP/opus-mt-it-en``).
            device: Device the model is moved to and on which inputs are placed.

        Raises:
            RuntimeError: If the tokenizer or the model cannot be loaded or
                moved to ``device``. The original error is chained as cause.
        """
        self.device = device
        self.source_lang = source_lang
        model_name = f"Helsinki-NLP/opus-mt-{source_lang}-en"

        try:
            self.tokenizer = MarianTokenizer.from_pretrained(model_name)
            self.model = MarianMTModel.from_pretrained(model_name).to(self.device)
        except Exception as e:
            # Raising a bare string is itself a TypeError; raise a real
            # exception and keep the underlying cause attached.
            raise RuntimeError(
                f"Error during the download of the model for the language {source_lang}: {e}"
            ) from e

    def translate_to_eng(self, text: str) -> str:
        """Translate a single text into English.

        Delegates to :meth:`translate_batch_to_eng` so that device placement,
        truncation and gradient handling are identical in both code paths.
        """
        return self.translate_batch_to_eng([text])[0]

    def translate_batch_to_eng(self, texts: list[str]) -> list[str]:
        """Translate a batch of texts into English.

        Args:
            texts: Source-language texts; they are padded to a common length.

        Returns:
            One decoded translation per input text, in the same order.
            An empty input list yields an empty list.
        """
        if not texts:
            return []

        inputs = self.tokenizer(
            texts,
            return_tensors="pt",
            padding=True,
            # Cut over-long texts instead of passing them to the model whole.
            truncation=True,
        )
        # The model lives on self.device, so its inputs must be there as well.
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        with torch.no_grad():
            translated = self.model.generate(**inputs)

        return [
            self.tokenizer.decode(t, skip_special_tokens=True)
            for t in translated
        ]

    @staticmethod
    def _read_rows(csv_path):
        """Read a CSV file and return ``(rows, output_fieldnames)``.

        ``output_fieldnames`` is the input header plus ``"text_en"`` (added
        only when the header does not already contain it).

        Raises:
            ValueError: If the file has no header row (e.g. it is empty).
        """
        # newline="" lets the csv module handle line endings inside fields.
        with open(csv_path, encoding="utf-8", newline="") as f:
            reader = csv.DictReader(f)
            rows = list(reader)
            header = reader.fieldnames

        if header is None:
            raise ValueError(f"No CSV header found in {csv_path}")

        fieldnames = list(header)
        if "text_en" not in fieldnames:
            fieldnames.append("text_en")
        return rows, fieldnames

    def translate_directory(self, input_dir, output_dir, text_column="text"):
        """Translate every ``.csv`` file of ``input_dir`` row by row.

        For each input file ``name.csv`` a file ``en_name.csv`` is written to
        ``output_dir`` with the original columns plus ``text_en``.

        Args:
            input_dir: Directory scanned (non-recursively) for ``.csv`` files.
            output_dir: Directory for the translated files; created if missing.
            text_column: Name of the column holding the text to translate.

        Raises:
            KeyError: If a row has no ``text_column`` key.
            ValueError: If a CSV file has no header row.
        """
        os.makedirs(output_dir, exist_ok=True)

        for filename in os.listdir(input_dir):
            if not filename.endswith(".csv"):
                continue

            in_path = os.path.join(input_dir, filename)
            out_path = os.path.join(output_dir, f"en_{filename}")

            print(f"Translating {filename}")

            rows, fieldnames = self._read_rows(in_path)

            for row in tqdm(rows):
                # Short rows are filled with None by DictReader; translate
                # them as empty text instead of handing None to the tokenizer.
                row["text_en"] = self.translate_to_eng(row[text_column] or "")

            with open(out_path, "w", encoding="utf-8", newline="") as f:
                writer = csv.DictWriter(f, fieldnames=fieldnames)
                writer.writeheader()
                writer.writerows(rows)

    def translate_file_by_batch(self, input_file_path, output_dir, batch_size=8, text_column="text"):
        """Translate one CSV file in batches, writing results incrementally.

        The output file ``en_<input name>`` is written to ``output_dir`` and
        flushed after every batch, so rows translated so far reach the file
        even if a later batch fails.

        Args:
            input_file_path: Path of the ``.csv`` file to translate.
            output_dir: Directory for the translated file; created if missing.
            batch_size: Number of rows translated per model call (>= 1).
            text_column: Name of the column holding the text to translate.
                Rows lacking a value for it are translated as empty text.

        Raises:
            FileNotFoundError: If ``input_file_path`` does not exist.
            ValueError: If the file is not a ``.csv``, has no header row, or
                ``batch_size`` is smaller than 1.
        """
        input_path = Path(input_file_path)

        if not input_path.exists():
            raise FileNotFoundError("The file doesn't exist")

        if input_path.suffix != ".csv":
            raise ValueError("The extension file is not correct")

        if batch_size < 1:
            raise ValueError("batch_size must be a positive integer")

        filename = input_path.name
        rows, fieldnames = self._read_rows(input_path)

        # Create the output directory only once the input is known to be usable.
        os.makedirs(output_dir, exist_ok=True)
        out_path = os.path.join(output_dir, f"en_{filename}")

        print(f"Translating {filename} from {self.source_lang} into eng...")

        with open(out_path, "w", encoding="utf-8", newline="") as f:
            writer = csv.DictWriter(f, fieldnames=fieldnames)
            writer.writeheader()

            for i in tqdm(range(0, len(rows), batch_size)):
                batch_rows = rows[i : i + batch_size]

                # `or ""` also covers cells DictReader filled with None.
                texts = [row.get(text_column) or "" for row in batch_rows]

                translations = self.translate_batch_to_eng(texts)

                for row, text_en in zip(batch_rows, translations):
                    row["text_en"] = text_en if text_en else ""

                writer.writerows(batch_rows)
                # Persist each batch right away to limit loss on interruption.
                f.flush()

In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Parallel lists: source_langs[i] is the language code used in the model name,
# file_prefixes[i] is the stem of the matching CSV file under ./train.
source_langs = ["ar", "bn", "de", "hi", "it", "pl", "ru", "es", "tr", "ur", "zh"]
file_prefixes = ["arb", "ben", "deu", "hin", "ita", "pol", "rus", "spa", "tur", "urd", "zho"]
output_dir = "./translation"

language_files = zip(source_langs, file_prefixes)

for lang, prefix in language_files:
    print(f"Processing language: {lang} with file: {prefix}.csv")

    # A failure for one language is reported and the loop moves on to the next.
    try:
        translator = Translator(lang, device)
        input_file = f"./train/{prefix}.csv"
        translator.translate_file_by_batch(
            input_file,
            output_dir,
            batch_size=32,
        )
    except Exception as e:
        print(f"An error occurred while processing {lang} ({prefix}.csv): {e}")