In [1]:
"""
SUBMISSION BLENDING - MAXIMUM SCORE OPTIMIZED
==============================================
KEY OPTIMIZATIONS:
1. ✅ Fixed ALL regex double-escaping bugs
2. ✅ Improved generation: more beams, better length penalty
3. ✅ Smarter blending with multiple quality signals
4. ✅ Better post-processing (preserves more content)
5. ✅ Ensembling via weighted averaging of checkpoint weights (model soup)
"""

import os
import pandas as pd
import torch
from torch.utils.data import DataLoader, Dataset
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import re
from collections import Counter

# ============================================================
# CONFIG - OPTIMIZED FOR MAXIMUM SCORE
# ============================================================
# Input paths, the checkpoints averaged into one model together with their
# (un-normalized) mixing weights, tokenizer/generation settings, and the
# [ours, external] weights used when blending with the external submission.
CONFIG = {
    "data_path": "/kaggle/input/deep-past-initiative-machine-translation/test.csv",
    "external_submission_path": "/kaggle/input/akkadian2eng-v1/submission.csv",
    "models": [
        "/kaggle/input/byt5-base-big-data2",
        "/kaggle/input/byt5-akkadian-model",
        "/kaggle/input/train-gap-all-2/byt5-base-akkadian_gap_setence2"
    ],
    # One weight per entry in "models"; normalized to sum to 1 before averaging.
    "model_weights": [0.995, 0.98, 0.395],
    "device": torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    "max_len": 512,  # tokenizer truncation length for the source prompt
    "batch_size": 12,  # Increased for H100
    "gen_params": {
        "num_beams": 12,           # More beams = better search
        "max_new_tokens": 512,
        "length_penalty": 1.05,    # Slightly favor longer (more complete)
        "early_stopping": True,
        "no_repeat_ngram_size": 3, # Prevent repetition
        "repetition_penalty": 1.1, # Additional repetition control
    },
    "blend_weights": [0.75, 0.25]  # [ours, external]; tune based on validation
}

print(f"🔬 OPTIMIZED: {CONFIG['blend_weights'][0]*100:.0f}% ours + {CONFIG['blend_weights'][1]*100:.0f}% external")
print(f"🖥️ Device: {CONFIG['device']}")

# ============================================================
# PREPROCESSING - FIXED REGEX
# ============================================================
def preprocess_transliteration(text):
    """Normalize gap markers in a raw transliteration before tokenization.

    Runs of three or more dots, or of one or more ellipsis characters (U+2026),
    become ``<big_gap>``. Runs of two or more ``x``, or a single ``x`` surrounded
    by whitespace, become ``<gap>``.

    Args:
        text: Raw transliteration. NaN/None yields an empty string; any other
            value is converted with ``str``.

    Returns:
        The normalized transliteration string.
    """
    if pd.isna(text):
        return ""
    processed_text = str(text)
    # U+2026 is written as an escape so the pattern survives re-encoding of this
    # file; "\u2026+" already covers "……" and longer runs.
    processed_text = re.sub(r'(\.{3,}|\u2026+)', '<big_gap>', processed_text)
    # NOTE(review): the "\s+x\s+" branch consumes the surrounding whitespace
    # ("a x b" -> "a<gap>b"). Left unchanged on the assumption that it mirrors
    # the training-time preprocessing - confirm before altering.
    processed_text = re.sub(r'(xx+|\s+x\s+)', '<gap>', processed_text)
    return processed_text

# ============================================================
# POSTPROCESSING - FIXED & OPTIMIZED
# ============================================================
# Pre-compile regex patterns for speed.
# Non-ASCII characters are written as \u escapes so the patterns stay correct
# even if this file is re-encoded.
_PATTERNS = {
    'gap_markers': re.compile(r'(\[x\]|\(x\)|\bx\b)', re.IGNORECASE),
    'ellipsis': re.compile(r'(\.{3,}|\u2026|\[\.+\])'),
    'double_gap': re.compile(r'<gap>\s*<gap>'),
    'double_big_gap': re.compile(r'<big_gap>\s*<big_gap>'),
    'annotations': re.compile(r'\((fem|plur|pl|sing|singular|plural|\?|!)\.?\s*\w*\)', re.IGNORECASE),
    'repeated_words': re.compile(r'\b(\w+)(?:\s+\1\b)+'),
    'whitespace': re.compile(r'\s+'),
}

# Latin h/H with breve below -> plain h/H; subscript digits (U+2080..U+2089)
# -> ASCII digits.
_CHAR_MAP = str.maketrans({
    '\u1e2b': 'h',
    '\u1e2a': 'H',
    **{chr(0x2080 + digit): str(digit) for digit in range(10)},
})

# Characters deleted from the output: ! ? ( ) " , em dash, en dash, < >,
# all four half brackets (U+2308..U+230B), [ ], +, modifier letter right half
# ring (U+02BE) and '/'.
_BAD_CHARS_TABLE = str.maketrans(
    '', '', '!?()"\u2014\u2013<>\u2308\u2309\u230a\u230b[]+\u02be/'
)

# Decimal -> fraction-glyph rules, applied in order. The leading-zero forms must
# run first: otherwise "(\d+)\.5" turns "0.5" into "0 ½" before "\b0\.5" can
# see it.
_FRACTION_RULES = tuple(
    (re.compile(pattern), replacement)
    for pattern, replacement in (
        (r'\b0\.5\b', '\u00bd'),
        (r'\b0\.25\b', '\u00bc'),
        (r'\b0\.75\b', '\u00be'),
        (r'\b0\.33\d*\b', '\u2153'),
        (r'\b0\.66\d*\b', '\u2154'),
        (r'(\d+)\.5\b', '\\1 \u00bd'),
        (r'(\d+)\.25\b', '\\1 \u00bc'),
        (r'(\d+)\.75\b', '\\1 \u00be'),
        (r'(\d+)\.33\d*\b', '\\1 \u2153'),
        (r'(\d+)\.66\d*\b', '\\1 \u2154'),
    )
)

# Immediately repeated 4-, 3- and 2-word phrases (longest first) collapse to one copy.
_REPEATED_PHRASE_PATTERNS = tuple(
    re.compile(r'\b((?:\w+\s+){' + str(n - 1) + r'}\w+)(?:\s+\1\b)+')
    for n in range(4, 1, -1)
)


def postprocess_translation(text):
    """Clean a decoded model translation into submission form.

    Steps:
      1. Map special letters and subscript digits to ASCII.
      2. Normalize gap markers to ``<gap>``/``<big_gap>``.
      3. Drop grammatical annotations such as "(fem.)".
      4. Delete unwanted punctuation while protecting gap markers.
      5. Convert decimal fractions to fraction glyphs.
      6. Collapse repeated words and phrases, and normalize whitespace.
      7. Capitalize the first letter and ensure a final period.

    Args:
        text: Decoded translation.

    Returns:
        The cleaned translation. Returns a fixed placeholder sentence when
        ``text`` is not a string or is blank, and an empty string if nothing
        survives cleaning.
    """
    if not isinstance(text, str) or not text.strip():
        return "The tablet contains fragmentary text."

    processed = text.translate(_CHAR_MAP)

    # Gap normalization; two adjacent single gaps are promoted to one big gap.
    processed = _PATTERNS['gap_markers'].sub('<gap>', processed)
    processed = _PATTERNS['ellipsis'].sub('<big_gap>', processed)
    processed = _PATTERNS['double_gap'].sub(' <big_gap> ', processed)
    processed = _PATTERNS['double_big_gap'].sub(' <big_gap> ', processed)

    # Remove annotations
    processed = _PATTERNS['annotations'].sub('', processed)

    # '<' and '>' are in the delete list, so park gap markers behind
    # NUL-delimited placeholders while deleting, then restore them.
    processed = processed.replace('<gap>', '\x00GAP\x00').replace('<big_gap>', '\x00BIG\x00')
    processed = processed.translate(_BAD_CHARS_TABLE)
    processed = processed.replace('\x00GAP\x00', ' <gap> ').replace('\x00BIG\x00', ' <big_gap> ')

    for pattern, replacement in _FRACTION_RULES:
        processed = pattern.sub(replacement, processed)

    # Remove repeated words, then repeated phrases
    processed = _PATTERNS['repeated_words'].sub(r'\1', processed)
    for pattern in _REPEATED_PHRASE_PATTERNS:
        processed = pattern.sub(r'\1', processed)

    # Trim BEFORE the capitalization / final-period checks; otherwise the space
    # padding around restored gap markers hides the real first/last character.
    processed = _PATTERNS['whitespace'].sub(' ', processed).strip().strip('-').strip()
    if not processed:
        return processed

    if processed[0].islower():
        processed = processed[0].upper() + processed[1:]
    if processed[-1] not in '.!?':
        processed += '.'
    return processed

# ============================================================
# IMPROVED BLENDING - MULTI-SIGNAL SCORING
# ============================================================
# Keyword tables for score_translation(). Matching is done on whole lower-cased
# word tokens; a keyword also matches its "s"/"es" plural ("kings", "houses").
_HIGH_VALUE_KEYWORDS = (
    'tablet', 'king', 'god', 'temple', 'city', 'year', 'month', 'silver',
    'barley', 'field', 'house', 'son', 'daughter', 'servant', 'lord',
)
_MEDIUM_VALUE_KEYWORDS = (
    'said', 'wrote', 'gave', 'received', 'sent', 'took', 'made',
    'brought', 'placed', 'sealed', 'witnessed',
)
_NOUN_KEYWORDS = ('king', 'god', 'man', 'city', 'tablet', 'field')
_VERB_KEYWORDS = ('said', 'gave', 'took', 'made', 'is', 'was', 'has')
_WORD_TOKEN_RE = re.compile(r'[a-z]+')


def _has_keyword(tokens, keyword):
    """Return True if *keyword* or its "s"/"es" plural is in the token set."""
    return keyword in tokens or keyword + 's' in tokens or keyword + 'es' in tokens


def score_translation(text):
    """
    Multi-factor heuristic quality score for a translation (higher is better).

    Signals:
      - word-count band;
      - a leading capital and final punctuation;
      - each domain keyword present (counted once);
      - penalties for placeholder text or many gap markers;
      - a penalty when one word repeats more than 4 times in a text of over 10 words;
      - a bonus when both a noun keyword and a verb keyword occur.

    Keywords are matched as whole words (plus plural), not substrings.
    Substring matching let 'is' fire on "this"/"his", 'son' on "person" and
    'man' on "many", which made the noun+verb bonus nearly unconditional.

    Returns -100 for non-string or empty input.
    """
    if not text or not isinstance(text, str):
        return -100

    score = 0.0
    words = text.split()
    word_count = len(words)

    # 1. Length scoring (prefer medium-length translations)
    if 8 <= word_count <= 50:
        score += 3.0
    elif 5 <= word_count <= 80:
        score += 1.5
    elif word_count < 3:
        score -= 5.0  # Penalize very short

    # 2. Structural quality
    if text[0].isupper():
        score += 1.0
    if text[-1] in '.!?':
        score += 1.0

    # 3. Content quality - domain-specific keywords, each counted once
    text_lower = text.lower()
    tokens = set(_WORD_TOKEN_RE.findall(text_lower))
    for kw in _HIGH_VALUE_KEYWORDS:
        if _has_keyword(tokens, kw):
            score += 0.8
    for kw in _MEDIUM_VALUE_KEYWORDS:
        if _has_keyword(tokens, kw):
            score += 0.4

    # 4. Penalize problematic patterns (plain substring checks are intended here)
    if '???' in text or 'xxx' in text_lower:
        score -= 3.0
    if 'fragmentary' in text_lower:
        score -= 2.0
    if text.count('<gap>') > 5:
        score -= 1.0
    if text.count('<big_gap>') > 3:
        score -= 1.0

    # 5. Repetition penalty
    word_freq = Counter(words)
    most_common_count = word_freq.most_common(1)[0][1] if word_freq else 0
    if most_common_count > 4 and word_count > 10:
        score -= (most_common_count - 4) * 0.5

    # 6. Coherence bonus (has both subject and verb indicators)
    has_noun = any(_has_keyword(tokens, kw) for kw in _NOUN_KEYWORDS)
    has_verb = any(_has_keyword(tokens, kw) for kw in _VERB_KEYWORDS)
    if has_noun and has_verb:
        score += 2.0

    return score


def blend_translations(text1, text2, weight1=0.75, weight2=0.25):
    """
    Pick one of two candidate translations using weighted quality scores.

    Args:
        text1: Our model's translation; it wins near-ties (weighted scores
            within 0.5 of each other).
        text2: The external translation. It may be a non-string such as NaN
            when read from a CSV with missing cells; that counts as empty.
        weight1: Multiplier applied to score_translation(text1).
        weight2: Multiplier applied to score_translation(text2).

    Returns:
        text1 or text2. Returns a fixed placeholder sentence if neither is a
        non-blank string.
    """
    # isinstance guards: a NaN cell is truthy and has no .strip().
    has_text1 = isinstance(text1, str) and bool(text1.strip())
    has_text2 = isinstance(text2, str) and bool(text2.strip())
    if not has_text1:
        return text2 if has_text2 else "The tablet contains fragmentary text."
    if not has_text2:
        return text1

    # Apply confidence weights.
    # NOTE(review): scaling also shrinks *negative* scores, so a low-weighted
    # text with a bad score can outrank a higher-weighted one. Revisit if the
    # scores are expected to go negative.
    weighted1 = score_translation(text1) * weight1
    weighted2 = score_translation(text2) * weight2

    # If scores are very close, prefer our model (text1)
    if abs(weighted1 - weighted2) < 0.5:
        return text1

    return text1 if weighted1 >= weighted2 else text2


def smart_ensemble_blend(our_text, external_text, our_weight=0.75):
    """
    Choose between our translation and an external one.

    Rules, in order of precedence:
      1. If ours has fewer than 5 words and the external one has at least 10,
         take the external one.
      2. If ours scores above 0 and the external one below -3, take ours.
      3. Otherwise defer to blend_translations() with weights
         (our_weight, 1 - our_weight).

    Non-string inputs (e.g. NaN from a CSV with missing cells) are treated as
    empty strings instead of crashing on ``.split()``.
    """
    our_text = our_text if isinstance(our_text, str) else ""
    external_text = external_text if isinstance(external_text, str) else ""

    # If our translation is too short but external has content, use external
    if len(our_text.split()) < 5 and len(external_text.split()) >= 10:
        return external_text

    # If external is garbage but ours is decent, use ours
    if score_translation(our_text) > 0 and score_translation(external_text) < -3:
        return our_text

    # Otherwise, quality-based selection
    return blend_translations(our_text, external_text, our_weight, 1 - our_weight)

# ============================================================
# MODEL SOUP - OPTIMIZED
# ============================================================
def create_model_soup(base_idx=1):
    """Average the checkpoints in CONFIG['models'] into one model ("model soup").

    Each checkpoint's state dict is weighted by its entry in
    CONFIG['model_weights'], normalized to sum to 1. The checkpoint at
    ``base_idx`` is loaded first and serves as the template that receives the
    averaged weights. A key missing from some checkpoint is averaged only over
    the checkpoints that have it. Keys absent from the template are ignored.

    Args:
        base_idx: Index into CONFIG['models'] of the template checkpoint.
            Defaults to 1, the same checkpoint the script loads its tokenizer
            from.

    Returns:
        The averaged model on CONFIG['device'] in eval mode. It is bfloat16
        when the name of CUDA device 0 contains "H100", float32 otherwise.

    Raises:
        ValueError: if the model and weight lists differ in length, the
            weights do not sum to a positive value, or base_idx is out of range.
    """
    model_paths = CONFIG['models']
    raw_weights = CONFIG['model_weights']
    if len(model_paths) != len(raw_weights):
        raise ValueError(
            f"CONFIG['models'] has {len(model_paths)} entries but "
            f"CONFIG['model_weights'] has {len(raw_weights)}"
        )
    total_score = sum(raw_weights)
    if total_score <= 0:
        raise ValueError("CONFIG['model_weights'] must sum to a positive value")
    if not 0 <= base_idx < len(model_paths):
        raise ValueError(f"base_idx {base_idx} is out of range for {len(model_paths)} models")

    weights = [w / total_score for w in raw_weights]
    print(f"Loading ensemble with normalized weights: {[f'{w:.3f}' for w in weights]}")

    # Load the template model; its state dict becomes the running weighted sum.
    print(f"  Loading model {base_idx}: {model_paths[base_idx].split('/')[-1]}")
    template_model = AutoModelForSeq2SeqLM.from_pretrained(model_paths[base_idx])
    soup_sd = template_model.state_dict()

    # Per-key sum of the weights that actually contributed to that key.
    norm_factors = {key: weights[base_idx] for key in soup_sd}
    for key in soup_sd:
        soup_sd[key] = weights[base_idx] * soup_sd[key].float()

    # Accumulate every other checkpoint, in list order.
    for idx, model_path in enumerate(model_paths):
        if idx == base_idx:
            continue
        print(f"  Loading model {idx}: {model_path.split('/')[-1]}")
        temp_sd = AutoModelForSeq2SeqLM.from_pretrained(model_path).state_dict()

        for key in soup_sd:
            if key in temp_sd:
                soup_sd[key] += weights[idx] * temp_sd[key].float()
                norm_factors[key] += weights[idx]

        del temp_sd
        torch.cuda.empty_cache()

    # Normalize
    for key in soup_sd:
        soup_sd[key] = soup_sd[key] / norm_factors[key]

    template_model.load_state_dict(soup_sd)

    # Use BF16 on H100 for speed, FP32 on older GPUs for quality
    if torch.cuda.is_available() and 'H100' in torch.cuda.get_device_name(0):
        return template_model.to(CONFIG['device']).eval().bfloat16()
    return template_model.to(CONFIG['device']).eval().float()

# ============================================================
# DATASET
# ============================================================
class AkkadianTranslationDataset(Dataset):
    """Map-style dataset yielding ``(id, prompt)`` pairs for seq2seq generation.

    Each prompt is the fixed task prefix followed by the row's
    ``transliteration`` value converted with ``str``. Both columns are
    materialized into lists at construction time.
    """

    _TASK_PREFIX = "translate Akkadian to English: "

    def __init__(self, dataframe):
        self.ids = dataframe['id'].tolist()
        self.texts = list(map(self._format_prompt, dataframe['transliteration']))

    @classmethod
    def _format_prompt(cls, transliteration):
        """Prepend the task prefix to one transliteration."""
        return cls._TASK_PREFIX + str(transliteration)

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, idx):
        ident = self.ids[idx]
        prompt = self.texts[idx]
        return ident, prompt

# ============================================================
# MAIN EXECUTION
# ============================================================
print("\n" + "="*60)
print("📂 LOADING DATA")
print("="*60)

# Load external submission. Empty cells come back from read_csv as NaN, so
# coerce them to "" and the blending code only ever sees strings.
print("\nLoading external submission...")
external_submissions = pd.read_csv(CONFIG['external_submission_path'])
external_dict = dict(zip(
    external_submissions['id'],
    external_submissions['translation'].fillna("").astype(str),
))
print(f"✅ Loaded {len(external_dict)} external translations")

# Load test data
print("\nLoading test data...")
dataframe = pd.read_csv(CONFIG['data_path'])
dataframe['transliteration'] = dataframe['transliteration'].apply(preprocess_transliteration)
print(f"✅ Loaded {len(dataframe)} test samples")

print("\n" + "="*60)
print("🧠 CREATING MODEL ENSEMBLE")
print("="*60)
model = create_model_soup()
# Tokenizer comes from models[1], the checkpoint create_model_soup() uses as its template.
tokenizer = AutoTokenizer.from_pretrained(CONFIG['models'][1])
print("✅ Model loaded")


def _collate(batch):
    """Split (id, prompt) pairs into an id list and a padded, truncated tokenizer batch."""
    batch_ids = [item[0] for item in batch]
    encoded = tokenizer(
        [item[1] for item in batch],
        max_length=CONFIG['max_len'],
        padding=True,
        truncation=True,
        return_tensors="pt"
    )
    return batch_ids, encoded


# DataLoader
data_loader = DataLoader(
    AkkadianTranslationDataset(dataframe),
    batch_size=CONFIG['batch_size'],
    shuffle=False,
    num_workers=0,
    # Pinned host memory only helps when batches are copied to a GPU.
    pin_memory=CONFIG['device'].type == 'cuda',
    collate_fn=_collate,
)

print("\n" + "="*60)
print("🔮 GENERATING PREDICTIONS")
print("="*60)
our_predictions = {}
with torch.inference_mode():
    for batch_idx, (ids, inputs) in enumerate(data_loader):
        outputs = model.generate(
            input_ids=inputs.input_ids.to(CONFIG['device']),
            attention_mask=inputs.attention_mask.to(CONFIG['device']),
            **CONFIG['gen_params']
        )

        decoded_texts = tokenizer.batch_decode(outputs, skip_special_tokens=True)
        cleaned_translations = [postprocess_translation(text) for text in decoded_texts]

        for id_, translation in zip(ids, cleaned_translations):
            our_predictions[id_] = translation

        if (batch_idx + 1) % 20 == 0:
            print(f"  Processed {batch_idx + 1}/{len(data_loader)} batches")

print(f"✅ Generated {len(our_predictions)} predictions")

print("\n" + "="*60)
print("🔀 SMART BLENDING")
print("="*60)
blended_results = []
blend_stats = {"ours": 0, "external": 0}

for id_ in sorted(our_predictions):
    our_translation = our_predictions[id_]
    external_translation = external_dict.get(id_, "")

    blended = smart_ensemble_blend(
        our_translation,
        external_translation,
        our_weight=CONFIG['blend_weights'][0]
    )

    if blended == our_translation:
        blend_stats["ours"] += 1
    else:
        blend_stats["external"] += 1

    blended_results.append((id_, blended))

print(f"📊 Selection: {blend_stats['ours']} ours / {blend_stats['external']} external")

# Create submission
submission_df = pd.DataFrame(blended_results, columns=['id', 'translation'])

# Final quality check: anything that is not a string of at least 3 words gets a placeholder.
submission_df['translation'] = submission_df['translation'].apply(
    lambda x: "The tablet contains an incomplete inscription."
    if not isinstance(x, str) or len(x.split()) < 3 else x
)

submission_df.to_csv("submission.csv", index=False)
print(f"\n✅ Saved submission.csv with {len(submission_df)} rows")

print("\n" + "="*60)
print("📋 SAMPLE OUTPUT")
print("="*60)
for i in range(min(3, len(submission_df))):
    id_ = submission_df.iloc[i]['id']
    final_text = submission_df.iloc[i]['translation']
    print(f"\nID: {id_}")
    # Only mark truncation when the text was actually cut.
    suffix = "..." if len(final_text) > 100 else ""
    print(f"  Final: {final_text[:100]}{suffix}")

print("\n" + "="*60)
print("🏁 OPTIMIZATION COMPLETE")
print("="*60)


üî¨ OPTIMIZED: 75% ours + 25% external
üñ•Ô∏è Device: cuda

üìÇ LOADING DATA

Loading external submission...
‚úÖ Loaded 4 external translations

Loading test data...
‚úÖ Loaded 4 test samples

üß† CREATING MODEL ENSEMBLE
Loading ensemble with normalized weights: ['0.420', '0.414', '0.167']
  Loading model 1: byt5-akkadian-model


2026-01-29 08:16:05.714431: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1769674565.903564      23 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1769674565.960631      23 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1769674566.416541      23 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1769674566.416589      23 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1769674566.416592      23 computation_placer.cc:177] computation placer alr

  Loading model 0: byt5-base-big-data2
  Loading model 2: byt5-base-akkadian_gap_setence2
‚úÖ Model loaded

üîÆ GENERATING PREDICTIONS
‚úÖ Generated 4 predictions

üîÄ SMART BLENDING
üìä Selection: 4 ours / 0 external

‚úÖ Saved submission.csv with 4 rows

üìã SAMPLE OUTPUT

ID: 0
  Final: From the Kanesh colony to Aqil <big_gap> datum, our messengers, every single one and two of us: A ta...

ID: 1
  Final: In a tablet from the City you wrote to me as follows: This day whoever receives my gold, will Daur o...

ID: 2
  Final: In accordance with our letter, he has given me for an investment to a palace,....

üèÅ OPTIMIZATION COMPLETE
