<a href="https://colab.research.google.com/github/agungfirdaus717-ux/torentotgd/blob/main/SubTranslator.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# @title Translate Multiple SRT & ASS to Indonesian (No API Key)
# ⚙️ Setup & Install
!pip -q install transformers sentencepiece sacremoses langdetect regex tqdm

# --- Imports
import os, re, io
from langdetect import detect
from tqdm import tqdm
import torch
from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer

# --- Select the compute device: CUDA GPU when one is available, CPU otherwise
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print('Using device:', device)

# ============================
# 🔧 CONFIG
# ============================
# Source language code; 'auto' makes the script detect it per file with langdetect.
SRC_LANG = 'auto'
# Target language code handed to tokenizer.get_lang_id ('id' = Indonesian).
TGT_LANG = 'id'
# Model identifier passed to from_pretrained for both the model and the tokenizer.
MODEL_NAME = 'facebook/m2m100_418M'
# Default length limit (in characters) above which split_long_line splits a line.
MAX_CHARS_PER_LINE = 200
# Number of subtitle lines sent to the model per generate() call.
BATCH_SIZE = 16
# When True, preserve_caps is applied so all-caps source lines stay all-caps.
PRESERVE_CASE = True
# Term -> replacement mapping. Terms are swapped for [[GLS-term]] markers before
# translation (apply_glossary_pre) and the markers are replaced by the mapped
# value afterwards (apply_glossary_post).
GLOSSARY = {}

# ============================
# 📂 Upload multiple files (SRT/ASS)
# ============================
from google.colab import files

# The upload result is used as a mapping keyed by the uploaded file names.
uploads = files.upload()

# Keep only subtitle files; the extension check is case-insensitive.
input_files = []
for uploaded_name in uploads.keys():
    if uploaded_name.lower().endswith(('.srt', '.ass')):
        input_files.append(uploaded_name)
print('Files loaded:', input_files)

# ============================
# 🧩 Utilities
# ============================
# Matches one SRT cue. Capture groups:
#   1 - the numeric cue index on its own line
#   2 - the start timestamp (HH:MM:SS,mmm)
#   3 - the end timestamp (HH:MM:SS,mmm)
#   4 - the cue text, matched lazily up to the newline(s) that precede either the
#       next "index line + timestamp line" header or the end of the input
# Flags: VERBOSE ignores the layout whitespace inside the pattern, DOTALL lets the
# text group span several lines, MULTILINE lets ^ match at the start of any line.
TIMECODE_RE = re.compile(r"""
    ^\s*(\d+)\s*\n
    (\d{2}:\d{2}:\d{2},\d{3})\s*-->\s*(\d{2}:\d{2}:\d{2},\d{3})\s*\n
    (.*?)\n{1,}(?=\s*\d+\s*\n\d{2}:|\Z)
""", re.VERBOSE | re.DOTALL | re.MULTILINE)

# Spans that must pass through translation untouched, as (compiled pattern, label)
# pairs. The label becomes part of the placeholder that replaces each match.
PROTECT_PATTERNS = [
    (re.compile(r"<[^>]+>"), 'TAG'),       # angle-bracket tags, e.g. <i> ... </i>
    (re.compile(r"{[^}]+}"), 'BRACE'),     # brace groups, e.g. {\an8}
    (re.compile(r"\\N"), 'NEWLINE'),       # a literal backslash followed by N
    (re.compile(r"\u266A|\u266B|\u266C|\u2669|\u266F|\u266D|\u266E|♪"), 'NOTE'),  # music symbols
]

# Every placeholder has the form "[[PH-<LABEL>-<index>]]".
PLACEHOLDER_PREFIX = "[[PH-"


def protect_text(s):
    """Replace every protected span in *s* with an indexed placeholder.

    The patterns in PROTECT_PATTERNS are applied one after another. Each match is
    appended to a list and replaced by ``[[PH-<LABEL>-<i>]]``, where ``i`` is the
    position of the match in that list.

    Returns:
        A ``(protected_text, placeholders)`` tuple; ``placeholders[i]`` is the
        original text that placeholder ``i`` stands for.
    """
    placeholders = []

    def repl_factory(label):
        def _repl(m):
            placeholders.append(m.group(0))
            return f"{PLACEHOLDER_PREFIX}{label}-{len(placeholders)-1}]]"
        return _repl

    for pat, label in PROTECT_PATTERNS:
        s = pat.sub(repl_factory(label), s)
    return s, placeholders


def restore_text(s, placeholders):
    """Put the original spans back in place of the placeholders found in *s*.

    Args:
        s: Text that may contain ``[[PH-<LABEL>-<i>]]`` placeholders.
        placeholders: The list returned by ``protect_text``.

    Returns:
        *s* with every placeholder whose index exists in *placeholders* replaced
        by the stored text. A placeholder with an unknown index is left as-is
        instead of raising IndexError: *s* normally comes back from a translation
        model, which may alter or invent the digits.
    """
    placeholder_re = re.escape(PLACEHOLDER_PREFIX) + r"[A-Z]+-(\d+)]]"

    def _repl(m):
        i = int(m.group(1))
        if i < len(placeholders):
            return placeholders[i]
        # Unknown index: keep the marker visible rather than crash the whole run.
        return m.group(0)

    return re.sub(placeholder_re, _repl, s)

def apply_glossary_pre(s, glossary=None):
    """Replace glossary terms in *s* with ``[[GLS-<term>]]`` markers.

    Terms are matched case-insensitively as whole words, longest term first.
    Markers inserted for an earlier (longer) term are skipped when later terms
    are processed, so a term that is a word inside a longer term (for example
    "York" inside "New York") no longer produces a marker nested in a marker.
    Empty terms are ignored.

    Args:
        s: The text to process.
        glossary: Mapping whose keys are the terms; defaults to the module-level
            GLOSSARY.

    Returns:
        The text with every matched term replaced by its marker.
    """
    if glossary is None:
        glossary = GLOSSARY
    # Group 1 matches an existing marker so that it can be copied through unchanged.
    existing_marker = r"(\[\[GLS-.*?\]\])"
    for k in sorted(glossary, key=len, reverse=True):
        if not k:
            # An empty term would match at every word boundary.
            continue
        pattern = r"(?i)" + existing_marker + r"|\b" + re.escape(k) + r"\b"

        def _repl(m, key=k):
            return m.group(0) if m.group(1) else f"[[GLS-{key}]]"

        s = re.sub(pattern, _repl, s)
    return s

def apply_glossary_post(s, glossary=None):
    """Replace ``[[GLS-<term>]]`` markers in *s* with their glossary values.

    Args:
        s: Text that may contain markers of the form ``[[GLS-<term>]]``.
        glossary: Mapping of term -> replacement text; defaults to the
            module-level GLOSSARY.

    Returns:
        The text with each known marker replaced by the mapped value. Markers
        whose term is not in the glossary are left untouched.
    """
    if glossary is None:
        glossary = GLOSSARY
    for k, v in glossary.items():
        s = s.replace(f"[[GLS-{k}]]", v)
    return s

def preserve_caps(src, tgt):
    """Return *tgt* uppercased when *src* is written entirely in capitals.

    ``str.isupper`` is True only when *src* contains at least one cased
    character and none of its cased characters is lowercase. Sources without any
    cased character (digits, punctuation, or scripts without letter case such as
    Japanese) therefore leave *tgt* unchanged instead of forcing the whole
    translation to upper case.

    Args:
        src: The original (untranslated) line.
        tgt: The translated line.

    Returns:
        ``tgt.upper()`` if *src* is all-caps, otherwise *tgt* unchanged.
    """
    return tgt.upper() if src.isupper() else tgt

def split_long_line(line, max_chars=MAX_CHARS_PER_LINE):
    """Split *line* into chunks of at most *max_chars* characters where possible.

    A line that already fits is returned unchanged as a one-element list.
    Otherwise the line is cut into sentences, each kept together with its
    trailing ``.``, ``!`` or ``?`` run, and consecutive sentences are packed
    greedily into chunks. Keeping the terminator attached means punctuation is
    never pushed to the start of the next chunk and an ellipsis is never cut.

    A single sentence longer than *max_chars* is not broken further and is
    emitted as one oversized chunk.

    Args:
        line: The text to split.
        max_chars: Maximum chunk length in characters.

    Returns:
        A list of stripped, non-empty chunks (empty if the line is only
        whitespace and longer than *max_chars*).
    """
    if len(line) <= max_chars:
        return [line]
    # Each item is either "text + terminator run" or a trailing run of text with
    # no terminator; together the items cover the whole line.
    sentences = re.findall(r"[^.!?]*[.!?]+|[^.!?]+", line)
    chunks, cur = [], ''
    for sentence in sentences:
        if len(cur) + len(sentence) <= max_chars:
            cur += sentence
        else:
            if cur:
                chunks.append(cur.strip())
            cur = sentence
    if cur:
        chunks.append(cur.strip())
    return [c for c in chunks if c]

# ============================
# 🚀 Load model
# ============================
print('Loading model:', MODEL_NAME)
model = M2M100ForConditionalGeneration.from_pretrained(MODEL_NAME)
tokenizer = M2M100Tokenizer.from_pretrained(MODEL_NAME)
# Move the weights to the selected device, then switch to inference mode.
model = model.to(device)
model = model.eval()

# ============================
# 🔄 Process every uploaded file (SRT/ASS)
# ============================
# Matches any marker inserted by protect_text ("[[PH-...]]") or by
# apply_glossary_pre ("[[GLS-...]]").
_PLACEHOLDER_RE = re.compile(
    r"(?:" + re.escape(PLACEHOLDER_PREFIX) + r"|\[\[GLS-).*?\]\]"
)


def _detect_source_lang(sample):
    """Return SRC_LANG, or the language detected from *sample* when it is 'auto'."""
    if SRC_LANG != 'auto':
        return SRC_LANG
    try:
        detected = detect(sample)
    except Exception:
        # Best effort: detection fails on empty or feature-less text.
        return 'en'
    # langdetect reports regional variants such as 'zh-cn' / 'zh-tw'; keep only the
    # base code so that it is not rejected and silently replaced by 'en' below.
    return detected.split('-')[0]


def _select_languages(src_lang):
    """Set the tokenizer source language and return the forced BOS id for TGT_LANG."""
    try:
        tokenizer.src_lang = src_lang
    except Exception:
        # Best effort: an unknown code falls back to English, but say so.
        print('Unsupported source language', repr(src_lang), '- falling back to en')
        tokenizer.src_lang = 'en'
    return tokenizer.get_lang_id(TGT_LANG)


def _has_translatable_text(protected):
    """Return True if *protected* has at least one letter outside the markers."""
    return any(ch.isalpha() for ch in _PLACEHOLDER_RE.sub('', protected))


def _translate_lines(source_lines, forced_id, max_length):
    """Translate *source_lines* in batches of BATCH_SIZE; return a same-length list.

    Lines without any letter outside the markers (empty, markup-only, digits or
    punctuation only) are not sent to the model and are returned unchanged.
    """
    translated = list(source_lines)
    pending = [i for i, line in enumerate(source_lines) if _has_translatable_text(line)]
    for start in tqdm(range(0, len(pending), BATCH_SIZE)):
        batch_ids = pending[start:start + BATCH_SIZE]
        batch = [source_lines[i] for i in batch_ids]
        enc = tokenizer(batch, return_tensors='pt', padding=True, truncation=True).to(device)
        with torch.no_grad():
            gen = model.generate(**enc, forced_bos_token_id=forced_id, max_length=max_length)
        outs = tokenizer.batch_decode(gen, skip_special_tokens=True)
        for i, out in zip(batch_ids, outs):
            translated[i] = out
    return translated


def _finalize_line(translated, original, placeholders):
    """Restore protected spans and glossary terms, then re-apply source casing."""
    text = restore_text(translated, placeholders)
    text = apply_glossary_post(text)
    if PRESERVE_CASE:
        text = preserve_caps(original, text)
    return text


for filepath in input_files:
    print("\n=== Processing:", filepath, "===")
    base, ext = os.path.splitext(filepath)
    ext = ext.lower()

    # 'utf-8-sig' drops a leading byte-order mark. With plain 'utf-8' the BOM stays
    # in front of the first cue index, TIMECODE_RE cannot match that cue, and the
    # first subtitle of the file is silently lost.
    with io.open(filepath, 'r', encoding='utf-8-sig', errors='ignore') as f:
        content = f.read()

    # ============================
    # SRT format
    # ============================
    if ext == '.srt':
        # Guarantee the blank-line terminator the cue pattern expects.
        content = content.strip() + '\n\n'
        blocks = []
        for m in TIMECODE_RE.finditer(content):
            norm_lines = []
            for ln in m.group(4).strip('\n').split('\n'):
                norm_lines.extend(split_long_line(ln))
            blocks.append({
                'idx': int(m.group(1)),
                'start': m.group(2),
                'end': m.group(3),
                'lines': norm_lines,
            })

        # Detect the language from the text of the first 50 cues.
        joined_sample = '\n'.join('\n'.join(b['lines']) for b in blocks[:50])
        src_lang = _detect_source_lang(joined_sample)
        print('Source lang:', src_lang, '-> Target:', TGT_LANG)
        forced_id = _select_languages(src_lang)

        # Flatten every cue line; meta remembers where each one goes back.
        source_lines, meta = [], []
        for bi, b in enumerate(blocks):
            for li, line in enumerate(b['lines']):
                prot, ph = protect_text(apply_glossary_pre(line))
                source_lines.append(prot)
                meta.append((bi, li, line, ph))

        translated_lines = _translate_lines(source_lines, forced_id, max_length=256)

        for text, (bi, li, orig, ph) in zip(translated_lines, meta):
            blocks[bi]['lines'][li] = _finalize_line(text, orig, ph)

        out_path = f"{base}.id.srt"
        with io.open(out_path, 'w', encoding='utf-8', errors='ignore') as f:
            for b in blocks:
                f.write(str(b['idx']) + '\n')
                f.write(f"{b['start']} --> {b['end']}\n")
                f.write('\n'.join(b['lines']).strip() + '\n\n')

        print('Saved:', out_path)

    # ============================
    # ASS format
    # ============================
    elif ext == '.ass':
        lines = content.splitlines()

        # Collect every Dialogue line that has all ten comma-separated fields; the
        # tenth field (index 9) is the subtitle text and may itself contain commas.
        dialogues = []
        for idx, raw in enumerate(lines):
            if raw.strip().startswith("Dialogue:"):
                parts = raw.split(",", 9)
                if len(parts) >= 10:
                    dialogues.append((idx, parts))

        # Detect the language from the first 20 dialogue texts.
        joined_sample = " ".join(parts[9] for _, parts in dialogues[:20])
        src_lang = _detect_source_lang(joined_sample)
        print('Source lang:', src_lang, '-> Target:', TGT_LANG)
        forced_id = _select_languages(src_lang)

        # Prepare the text of each dialogue line for translation.
        source_lines, meta = [], []
        for idx, parts in dialogues:
            orig_text = parts[9]
            prot, ph = protect_text(apply_glossary_pre(orig_text))
            source_lines.append(prot)
            meta.append((idx, parts, orig_text, ph))

        translated_lines = _translate_lines(source_lines, forced_id, max_length=512)

        for text, (idx, parts, orig_text, ph) in zip(translated_lines, meta):
            parts[9] = _finalize_line(text, orig_text, ph)
            lines[idx] = ",".join(parts)

        out_path = f"{base}.id.ass"
        with io.open(out_path, 'w', encoding='utf-8', errors='ignore') as f:
            f.write("\n".join(lines))

        print('Saved:', out_path)

# ============================
# ⬇️ Download all results
# ============================
from google.colab import files

# Offer every translated subtitle found in the working directory for download.
for fn in os.listdir():
    if not fn.endswith(('.id.srt', '.id.ass')):
        continue
    files.download(fn)
