In [None]:
# Dependencies
!pip install datasets --upgrade transformers pandas psutil polars protobuf tiktoken blobfile sentencepiece accelerate==0.26.1

In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from datasets import load_dataset
import pandas as pd
from tqdm import tqdm
import gc
import psutil
import os
from datetime import datetime
import polars as pl

# Output locations: a directory for per-batch result files, and the path of
# the single combined CSV.
OUTPUT_DIR = "/content/batches"
FINAL_OUTPUT_PATH = "/content/final_translations.csv"

def get_memory_usage():
    """Return the resident set size (RSS) of this process in megabytes."""
    rss_bytes = psutil.Process(os.getpid()).memory_info().rss
    return rss_bytes / (1024 * 1024)

def clean_memory():
    """Force a garbage-collection pass and, if CUDA exists, drop cached GPU blocks."""
    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()

class BatchTranslator:
    """Translate English passages to Vietnamese with a Hugging Face seq2seq model.

    The model is loaded eagerly in ``__init__``. It can be released with
    ``unload_model``, and ``translate_batch`` reloads it on demand.
    """

    def __init__(self, model_name="vinai/vinai-translate-en2vi-v2", device=None, batch_size=8):
        """
        Args:
            model_name: Hugging Face model id (or local path) to load.
            device: torch device string. Defaults to "cuda" when available,
                otherwise "cpu".
            batch_size: stored on the instance. ``translate_batch`` translates
                whatever list it is given and does not read this value.
        """
        self.device = device if device else ("cuda" if torch.cuda.is_available() else "cpu")
        self.batch_size = batch_size
        self.model_name = model_name
        # src_lang selects English ("en_XX") as the tokenizer's source-language code.
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, src_lang="en_XX")
        self.load_model(model_name)

    def load_model(self, model_name):
        """Load ``model_name`` onto ``self.device`` and print the RSS afterwards."""
        print(f"Loading model: {model_name} to {self.device}")
        self.model = AutoModelForSeq2SeqLM.from_pretrained(
            model_name,
            # Half precision on GPU halves weight memory; CPU stays in float32.
            torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
            low_cpu_mem_usage=True,
        ).to(self.device)
        print(f"Model loaded. Memory usage: {get_memory_usage():.2f} MB")

    def unload_model(self):
        """Drop the model reference (if any) and run ``clean_memory``."""
        print("Unloading model...")
        if hasattr(self, 'model'):
            del self.model
        clean_memory()
        print(f"Model unloaded. Memory usage after cleanup: {get_memory_usage():.2f} MB")

    def _lang_token_id(self, lang_code):
        """Return the vocabulary id of a language-code token such as ``"vi_VN"``.

        Not every tokenizer class/version exposes ``lang_code_to_id``, so fall
        back to a plain token-to-id lookup when the mapping is absent.
        """
        mapping = getattr(self.tokenizer, "lang_code_to_id", None)
        if mapping and lang_code in mapping:
            return mapping[lang_code]
        return self.tokenizer.convert_tokens_to_ids(lang_code)

    def translate_batch(self, batch_en_passages):
        """Translate a list of English passages to Vietnamese.

        Args:
            batch_en_passages: list of English strings. Inputs are truncated
                to 512 tokens.

        Returns:
            list[str]: one stripped translation per input, in input order.
            If translation fails, every element is ``"Translation Error: <exc>"``,
            so the output length always matches the input length.
        """
        if not hasattr(self, 'model'):
            self.load_model(self.model_name)
        if not batch_en_passages:
            return []

        try:
            inputs = self.tokenizer(
                batch_en_passages,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=512,
            ).to(self.device)

            with torch.no_grad():
                # The attention mask must be passed alongside input_ids. With
                # padding=True, it is what stops the model attending to the pad
                # tokens of the shorter passages in the batch.
                # do_sample=True makes outputs vary between runs.
                output_ids = self.model.generate(
                    input_ids=inputs["input_ids"],
                    attention_mask=inputs.get("attention_mask"),
                    do_sample=True,
                    top_k=100,
                    top_p=0.8,
                    decoder_start_token_id=self._lang_token_id("vi_VN"),
                    max_length=512,
                    num_return_sequences=1,
                )

            translations = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)
            cleaned_translations = [t.strip() for t in translations]

            del inputs, output_ids
            clean_memory()

            return cleaned_translations

        except RuntimeError as e:
            # Typically CUDA OOM or device errors; keep going with placeholder rows.
            print(f"Runtime Error during translation: {e}")
            clean_memory()
            return [f"Translation Error: {e}"] * len(batch_en_passages)
        except Exception as e:
            print(f"Unexpected Error during translation: {e}")
            clean_memory()
            return [f"Translation Error: {e}"] * len(batch_en_passages)

def process_dataset_in_batches(df, batch_size, st, e, output_dir=None):
    """Translate ``df['en_passages']`` rows ``[st, e)`` in batches and save each batch.

    Each batch is written to ``<output_dir>/batch_<start>_<end>.csv``. The row
    indices are zero-padded to 8 digits, so filename order matches row order.
    The columns are ``query_id`` (when ``df`` has that column), ``en_passages``
    and ``vi_translation``. Failed batches contain the translator's
    ``"Translation Error: ..."`` placeholder strings.

    Args:
        df: DataFrame supporting row slicing ``df[a:b]`` and column access
            (the polars frame built in ``__main__``). Must have an
            ``en_passages`` string column.
        batch_size: rows per translation batch. Must be > 0.
        st: first row to translate (inclusive).
        e: row to stop at (exclusive). Clamped to ``len(df)``.
        output_dir: directory for batch files. Defaults to ``OUTPUT_DIR``.

    Returns:
        list[str]: all translations produced, in row order.

    Raises:
        ValueError: if ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")
    if output_dir is None:
        output_dir = OUTPUT_DIR

    stop = min(e, len(df))
    print(f"Processing rows {st} to {stop} with batch size {batch_size}...")
    if st >= stop:
        print("No rows to process.")
        return []

    os.makedirs(output_dir, exist_ok=True)
    has_query_id = "query_id" in df.columns
    translator = BatchTranslator(batch_size=batch_size)
    all_translations = []

    try:
        for start in range(st, stop, batch_size):
            end = min(start + batch_size, stop)
            batch = df[start:end]
            batch_data = batch['en_passages'].to_list()
            translations = translator.translate_batch(batch_data)
            all_translations.extend(translations)

            record = {}
            if has_query_id:
                record["query_id"] = batch["query_id"].to_list()
            record["en_passages"] = batch_data
            record["vi_translation"] = translations
            batch_path = os.path.join(output_dir, f"batch_{start:08d}_{end:08d}.csv")
            pd.DataFrame(record).to_csv(batch_path, index=False)
            print(f"Translated rows {start} to {end} -> {batch_path}")
    finally:
        # Release the model even if a batch write or lookup raises.
        translator.unload_model()

    return all_translations

def combine_batch_files(output_dir=None, final_output_path=None):
    """Concatenate per-batch CSV files into one CSV file.

    Reads every ``batch_*.csv`` in ``output_dir`` in sorted filename order.
    Zero-padded batch names make this row order. The files are concatenated
    and the result is written to ``final_output_path``.

    Args:
        output_dir: directory holding batch files. Defaults to ``OUTPUT_DIR``.
        final_output_path: destination CSV. Defaults to ``FINAL_OUTPUT_PATH``.

    Returns:
        pandas.DataFrame | None: the combined table, or ``None`` when no batch
        files were found (nothing is written in that case).
    """
    if output_dir is None:
        output_dir = OUTPUT_DIR
    if final_output_path is None:
        final_output_path = FINAL_OUTPUT_PATH

    print("Combining batch files into final output...")
    if not os.path.isdir(output_dir):
        print(f"No batch directory found at {output_dir}; nothing to combine.")
        return None

    batch_files = sorted(
        name for name in os.listdir(output_dir)
        if name.startswith("batch_") and name.endswith(".csv")
    )
    if not batch_files:
        print(f"No batch files found in {output_dir}; nothing to combine.")
        return None

    # keep_default_na=False keeps text such as "NA" or "" as strings, not NaN.
    frames = [
        pd.read_csv(os.path.join(output_dir, name), keep_default_na=False)
        for name in batch_files
    ]
    combined = pd.concat(frames, ignore_index=True)

    parent = os.path.dirname(final_output_path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    combined.to_csv(final_output_path, index=False)
    print(f"Wrote {len(combined)} rows from {len(batch_files)} batch files to {final_output_path}")
    return combined

if __name__ == "__main__":
    # Report the accelerator situation up front so a CPU-only run is obvious.
    cuda_ready = torch.cuda.is_available()
    print(f"CUDA available: {cuda_ready}")
    if cuda_ready:
        active_device = torch.cuda.current_device()
        print(f"CUDA device count: {torch.cuda.device_count()}")
        print(f"Current CUDA device: {active_device}")
        print(f"Device name: {torch.cuda.get_device_name(active_device)}")

    print("Loading datasets...")
    try:
        # Japanese MS MARCO variant and English MS MARCO v2.1 train shards,
        # read directly from the Hugging Face Hub.
        jp_path_pattern = (
            "hf://datasets/hotchpotch/ms_marco_japanese/"
            "v2.1-madlad400-3b/train-*.parquet"
        )
        en_path = "hf://datasets/microsoft/ms_marco/v2.1/train-*.parquet"

        jp_df = pl.read_parquet(jp_path_pattern)
        en_df = pl.read_parquet(en_path)
        print(f"Loaded Japanese data shape: {jp_df.shape}")
        print(f"Loaded English data shape: {en_df.shape}")

        # Keep only the query id plus the passage_text field of each dataset.
        passage_text = pl.col("passages").struct.field("passage_text")
        jp_df = jp_df.select(["query_id", passage_text.alias("jp_passages")])
        en_df = en_df.select(["query_id", passage_text.alias("en_passages")])

        print("Joining datasets on query_id...")
        merged = jp_df.join(en_df, on="query_id", how="inner")
        print(f"Merged data shape: {merged.shape}")
        if merged.is_empty():
            raise ValueError("Merged DataFrame is empty. Check join key 'query_id' and input data.")

        # One row per English passage, then drop passages that are empty strings.
        merged = merged.explode(["en_passages"])
        print(f"Exploded data shape (one row per passage): {merged.shape}")
        merged = merged.filter(pl.col("en_passages") != "")
        print(f"Shape after filtering empty passages: {merged.shape}")

        print("Sample of merged data:")
        print(merged.head(5))

        # Translate a fixed 1024-row window starting at row 10,000.
        START_ROW = 10_000
        END_ROW = START_ROW + 1024
        BATCH_SIZE = 128

        process_dataset_in_batches(merged, batch_size=BATCH_SIZE, st=START_ROW, e=END_ROW)
        combine_batch_files()

    except Exception as main_e:
        print(f"An error occurred in the main execution block: {main_e}")
        import traceback
        traceback.print_exc()

    print("Script finished.")
