  Celda 1: Importaciones y Configuración Inicial

In [1]:

  # ============================================================================
  # SISTEMA DE TRADUCCIÓN NEURONAL HÍBRIDO MEJORADO - NLLB200 + ByT5
  # Versión Optimizada y Corregida
  # ============================================================================

  import torch
  import torch.nn as nn
  import torch.nn.functional as F
  from torch.utils.data import Dataset, DataLoader
  from transformers import (
      AutoTokenizer, AutoModelForSeq2SeqLM,
      T5ForConditionalGeneration, ByT5Tokenizer,
      get_linear_schedule_with_warmup
  )
  import numpy as np
  import pandas as pd
  from typing import List, Dict, Tuple, Optional, Union, Any
  import json
  import os
  from tqdm.auto import tqdm
  import gc
  import warnings
  import time
  from dataclasses import dataclass
  from datetime import datetime
  import matplotlib.pyplot as plt
  import seaborn as sns
  from pathlib import Path

  # Suprimir warnings innecesarios
  warnings.filterwarnings("ignore", category=UserWarning)
  warnings.filterwarnings("ignore", category=FutureWarning)

  print("✅ Importaciones completadas")
  print(f"🔥 PyTorch: {torch.__version__}")
  print(f"💾 CUDA disponible: {torch.cuda.is_available()}")
  if torch.cuda.is_available():
      print(f"🚀 GPU: {torch.cuda.get_device_name()}")
      print(f"💽 Memoria GPU: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")


✅ Importaciones completadas
🔥 PyTorch: 2.8.0+cu126
💾 CUDA disponible: True
🚀 GPU: Tesla T4
💽 Memoria GPU: 14.7 GB


  Celda 2: Configuración Optimizada del Sistema

In [2]:
  # ============================================================================
  # CONFIGURACIÓN OPTIMIZADA - SIN RALENTIZADORES
  # ============================================================================

  print("⚙️ Configurando entorno optimizado...")

  # ❌ ELIMINADO: Configuraciones que ralentizan
  # os.environ["CUDA_LAUNCH_BLOCKING"] = "1"  # ❌ RALENTIZA MUCHÍSIMO
  # os.environ["TORCH_USE_CUDA_DSA"] = "1"    # ❌ SOLO PARA DEBUG
  # torch.cuda.set_per_process_memory_fraction(0.8)  # ❌ FRAGMENTA MEMORIA

  # ✅ Configuración optimizada para velocidad
  if torch.cuda.is_available():
      device = torch.device("cuda")

      # Optimizaciones que SÍ mejoran rendimiento
      torch.backends.cudnn.benchmark = True  # ✅ Optimiza para tamaños fijos
      torch.backends.cuda.matmul.allow_tf32 = True  # ✅ Más rápido en GPUs modernas
      torch.backends.cudnn.allow_tf32 = True

      print(f"✅ GPU optimizada: {torch.cuda.get_device_name()}")
      print(f"💾 Memoria GPU total: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")

      # Solo limpiar memoria al inicio, NO en loops
      torch.cuda.empty_cache()

  else:
      device = torch.device("cpu")
      print("⚠️ Usando CPU - entrenamiento será lento")

  # Configuración de precisión mixta optimizada según GPU
  def setup_mixed_precision():
      """Configurar precisión mixta según la GPU disponible"""
      if torch.cuda.is_available():
          gpu_name = torch.cuda.get_device_name().lower()

          if 't4' in gpu_name:
              print("🎯 GPU T4 detectada: usando FP16")
              return True, torch.float16
          elif 'a100' in gpu_name or 'v100' in gpu_name:
              print("🚀 GPU moderna detectada: usando BF16")
              return True, torch.bfloat16
          else:
              print("🔧 GPU genérica: usando FP16")
              return True, torch.float16

      return False, torch.float32

  USE_AMP, AMP_DTYPE = setup_mixed_precision()

  print(f"✅ Configuración completada:")
  print(f"  🔥 Device: {device}")
  print(f"  ⚡ Mixed precision: {USE_AMP}")
  print(f"  📊 AMP dtype: {AMP_DTYPE}")

  # Detectar entorno
  try:
      import google.colab
      IN_COLAB = True
      print("📱 Entorno: Google Colab")
  except ImportError:
      IN_COLAB = False
      print("💻 Entorno: Local/Servidor")


⚙️ Configurando entorno optimizado...
✅ GPU optimizada: Tesla T4
💾 Memoria GPU total: 14.7 GB
🎯 GPU T4 detectada: usando FP16
✅ Configuración completada:
  🔥 Device: cuda
  ⚡ Mixed precision: True
  📊 AMP dtype: torch.float16
📱 Entorno: Google Colab


Celda 2.5 Conexion con Google Drive:

In [3]:

  # ============================================================================
  # MONTAJE DE GOOGLE DRIVE - CELDA QUE FALTABA
  # ============================================================================

  print("📁 Configurando acceso a Google Drive...")

  try:
      from google.colab import drive

      # Montar Google Drive
      print("🔗 Montando Google Drive...")
      drive.mount('/content/drive')

      # Verificar acceso
      drive_path = "/content/drive/MyDrive"
      if os.path.exists(drive_path):
          print("✅ Google Drive montado exitosamente")
          print(f"📁 Ruta base: {drive_path}")

          # Crear directorios necesarios
          directories_to_create = [
              "/content/drive/MyDrive/model_checkpoints",
              "/content/drive/MyDrive/translation_data",
              "/content/drive/MyDrive/final_models",
              "/content/drive/MyDrive/training_logs"
          ]

          for directory in directories_to_create:
              os.makedirs(directory, exist_ok=True)
              print(f"📂 Directorio asegurado: {directory}")

          DRIVE_MOUNTED = True
          DRIVE_BASE_PATH = drive_path

      else:
          print("❌ Error: No se puede acceder a Google Drive")
          DRIVE_MOUNTED = False
          DRIVE_BASE_PATH = None

  except ImportError:
      print("⚠️ No está en Google Colab")
      DRIVE_MOUNTED = False
      DRIVE_BASE_PATH = None
  except Exception as e:
      print(f"❌ Error montando Google Drive: {e}")
      DRIVE_MOUNTED = False
      DRIVE_BASE_PATH = None

  print(f"✅ Configuración Drive completada:")
  print(f"  DRIVE_MOUNTED = {DRIVE_MOUNTED}")
  print(f"  DRIVE_BASE_PATH = {DRIVE_BASE_PATH}")


📁 Configurando acceso a Google Drive...
🔗 Montando Google Drive...
Mounted at /content/drive
✅ Google Drive montado exitosamente
📁 Ruta base: /content/drive/MyDrive
📂 Directorio asegurado: /content/drive/MyDrive/model_checkpoints
📂 Directorio asegurado: /content/drive/MyDrive/translation_data
📂 Directorio asegurado: /content/drive/MyDrive/final_models
📂 Directorio asegurado: /content/drive/MyDrive/training_logs
✅ Configuración Drive completada:
  DRIVE_MOUNTED = True
  DRIVE_BASE_PATH = /content/drive/MyDrive


  Celda 3: Configuración del Modelo Híbrido Corregida

In [4]:

  # ============================================================================
  # CONFIGURACIÓN HÍBRIDA REALISTA Y OPTIMIZADA
  # ============================================================================

  class HybridTranslationConfig:
      def __init__(self):
          # Modelos base
          self.nllb_model_name = "facebook/nllb-200-distilled-600M"
          self.byt5_model_name = "google/byt5-small"

          # 🔧 CORRECCIÓN CRÍTICA: Uso realista de modelos
          self.use_nllb_in_training = False  # ❌ DESACTIVADO para entrenamiento (muy costoso)
          self.use_nllb_for_inference = True  # ✅ Solo para inferencia/fallback
          self.hybrid_mode = "byt5_primary"  # Modo: ByT5 principal, NLLB auxiliar

          # Configuración de entrenamiento
          self.num_epochs = 10
          self.batch_size = 8 if not IN_COLAB else 4  # Menor batch en Colab
          self.learning_rate = 1e-5
          self.weight_decay = 0.01
          self.max_samples_per_epoch = 50000
          self.max_length = 128

          # Optimizaciones críticas
          self.gradient_accumulation_steps = 4
          self.use_mixed_precision = USE_AMP
          self.max_grad_norm = 1.0

          # Early stopping
          self.early_stopping_patience = 3
          self.early_stopping_min_delta = 0.001

          # Warmup y scheduler
          self.warmup_steps = 1000
          self.scheduler_type = "linear"

          print("⚠️ CONFIGURACIÓN HÍBRIDA:")
          print(f"  🎯 NLLB en entrenamiento: {'❌ NO' if not self.use_nllb_in_training else '✅ SÍ'}")
          print(f"  🌍 NLLB para inferencia: {'✅ SÍ' if self.use_nllb_for_inference else '❌ NO'}")
          print(f"  🔤 Modelo principal: ByT5 (optimizado)")
          print(f"  📊 Batch size: {self.batch_size}")
          print(f"  ⚡ Mixed precision: {'✅ SÍ' if self.use_mixed_precision else '❌ NO'}")

  # Crear configuración
  config = HybridTranslationConfig()

  # Idiomas soportados (mantener tu lista original)

  SUPPORTED_LANGUAGES = {
      # ===== IDIOMAS PRINCIPALES =====
      'es': 'spa_Latn',    # Español (500M hablantes)
      'en': 'eng_Latn',    # Inglés (1.5B hablantes)
      'fr': 'fra_Latn',    # Francés (280M hablantes)
      'pt': 'por_Latn',    # Portugués (260M hablantes)
      'ar': 'ara_Arab',    # Árabe (422M hablantes)
      'ru': 'rus_Cyrl',    # Ruso (258M hablantes)
      'zh': 'zho_Hans',    # Chino simplificado (918M)
      'hi': 'hin_Deva',    # Hindi (602M hablantes)

      # ===== EUROPA (países en desarrollo del este) =====
      'uk': 'ukr_Cyrl',    # Ucraniano (37M) - país en reconstrucción
      'bg': 'bul_Cyrl',    # Búlgaro (7M)
      'hr': 'hrv_Latn',    # Croata (5M)
      'sr': 'srp_Cyrl',    # Serbio (12M)
      'mk': 'mkd_Cyrl',    # Macedonio (2M)
      'sq': 'sqi_Latn',    # Albanés (6M)
      'ro': 'ron_Latn',    # Rumano (22M)
      'hu': 'hun_Latn',    # Húngaro (13M)
      'pl': 'pol_Latn',    # Polaco (45M)
      'cs': 'ces_Latn',    # Checo (10M)
      'sk': 'slk_Latn',    # Eslovaco (5M)
      'et': 'est_Latn',    # Estonio (1M)
      'lv': 'lav_Latn',    # Letón (2M)
      'lt': 'lit_Latn',    # Lituano (3M)

      # ===== ÁFRICA (AMPLIADO) =====
      # África Occidental
      'sw': 'swa_Latn',    # Swahili (200M) - lingua franca África Oriental
      'ha': 'hau_Latn',    # Hausa (80M) - Nigeria, Níger, Ghana
      'yo': 'yor_Latn',    # Yoruba (45M) - Nigeria, Benín
      'ig': 'ibo_Latn',    # Igbo (45M) - Nigeria
      'wo': 'wol_Latn',    # Wolof (12M) - Senegal, Gambia
      'ff': 'fuv_Latn',    # Fulah (65M) - Sahel africano
      'bm': 'bam_Latn',    # Bambara (15M) - Mali
      'tw': 'twi_Latn',    # Twi (17M) - Ghana
      'ak': 'aka_Latn',    # Akan (11M) - Ghana, Costa de Marfil
      'ee': 'ewe_Latn',    # Ewe (6M) - Ghana, Togo
      'gaa': 'gaa_Latn',   # Ga (3M) - Ghana
      'kr': 'kau_Latn',    # Kanuri (10M) - Nigeria, Chad

      # África Oriental y Meridional
      'am': 'amh_Ethi',    # Amhárico (57M) - Etiopía
      'om': 'orm_Latn',    # Oromo (37M) - Etiopía (más hablado)
      'ti': 'tir_Ethi',    # Tigriña (9M) - Etiopía, Eritrea
      'so': 'som_Latn',    # Somalí (21M) - Somalia, Etiopía, Kenia
      'zu': 'zul_Latn',    # Zulu (27M) - Sudáfrica
      'xh': 'xho_Latn',    # Xhosa (19M) - Sudáfrica
      'af': 'afr_Latn',    # Afrikáans (16M) - Sudáfrica
      'st': 'sot_Latn',    # Sesotho (7M) - Lesoto, Sudáfrica
      'tn': 'tsn_Latn',    # Tswana (8M) - Botswana, Sudáfrica
      'ss': 'ssw_Latn',    # Siswati (2M) - Esuatini, Sudáfrica
      've': 'ven_Latn',    # Venda (1M) - Sudáfrica
      'ts': 'tso_Latn',    # Tsonga (7M) - Sudáfrica, Mozambique
      'nr': 'nbl_Latn',    # Ndebele (2M) - Sudáfrica
      'ny': 'nya_Latn',    # Chichewa (17M) - Malawi, Zambia
      'sn': 'sna_Latn',    # Shona (15M) - Zimbabue
      'rw': 'kin_Latn',    # Kinyarwanda (25M) - Ruanda
      'rn': 'run_Latn',    # Kirundi (13M) - Burundi
      'kg': 'kon_Latn',    # Kikongo (10M) - RD Congo, Angola
      'ln': 'lin_Latn',    # Lingala (15M) - RD Congo
      'lua': 'luo_Latn',   # Luo (4M) - Kenia, Tanzania
      'mg': 'mlg_Latn',    # Malgache (25M) - Madagascar

      # África del Norte (Bereber)
      'ber': 'ber_Tfng',   # Bereber/Amazigh (30M) - Marruecos, Argelia
      'kab': 'kab_Latn',   # Cabila (7M) - Argelia

      # ===== ASIA (PAÍSES EN DESARROLLO) =====
      # Asia Meridional
      'bn': 'ben_Beng',    # Bengalí (300M) - Bangladesh, India
      'ur': 'urd_Arab',    # Urdu (230M) - Pakistán, India
      'pa': 'pan_Guru',    # Panyabí (130M) - India, Pakistán
      'gu': 'guj_Gujr',    # Gujarati (60M) - India
      'or': 'ory_Orya',    # Oriya (45M) - India
      'as': 'asm_Beng',    # Asamés (15M) - India
      'ml': 'mal_Mlym',    # Malayalam (38M) - India
      'kn': 'kan_Knda',    # Canarés (44M) - India
      'te': 'tel_Telu',    # Telugu (95M) - India
      'ta': 'tam_Taml',    # Tamil (78M) - India, Sri Lanka
      'si': 'sin_Sinh',    # Cingalés (16M) - Sri Lanka
      'ne': 'nep_Deva',    # Nepalí (16M) - Nepal
      'my': 'mya_Mymr',    # Birmano (33M) - Myanmar
      'km': 'khm_Khmr',    # Jemer (16M) - Camboya
      'lo': 'lao_Laoo',    # Lao (30M) - Laos

      # Asia Central (AÑADIDO Uzbeko y otros)
      'uz': 'uzn_Latn',    # Uzbeko (34M) - Uzbekistán ✅ TU SOLICITADO
      'kk': 'kaz_Cyrl',    # Kazajo (15M) - Kazajstán
      'ky': 'kir_Cyrl',    # Kirguís (5M) - Kirguistán
      'tg': 'tgk_Cyrl',    # Tayiko (9M) - Tayikistán
      'tk': 'tuk_Latn',    # Turkmeno (6M) - Turkmenistán
      'az': 'aze_Latn',    # Azerbaiyano (23M) - Azerbaiyán
      'hy': 'hye_Armn',    # Armenio (7M) - Armenia
      'ka': 'kat_Geor',    # Georgiano (4M) - Georgia

      # Asia Oriental ✅ YA TIENES JAPONÉS
      'ja': 'jpn_Jpan',    # Japonés (125M)
      'ko': 'kor_Hang',    # Coreano (77M)
      'mn': 'khk_Cyrl',    # Mongol (6M) - Mongolia

      # Asia Sudoriental
      'th': 'tha_Thai',    # Tailandés (60M) - Tailandia
      'vi': 'vie_Latn',    # Vietnamita (95M) - Vietnam
      'id': 'ind_Latn',    # Indonesio (280M) - Indonesia
      'ms': 'msa_Latn',    # Malayo (60M) - Malasia, Brunéi
      'tl': 'tgl_Latn',    # Filipino/Tagalo (45M) - Filipinas
      'ceb': 'ceb_Latn',   # Cebuano (25M) - Filipinas
      'war': 'war_Latn',   # Waray (3M) - Filipinas
      'ilo': 'ilo_Latn',   # Ilocano (10M) - Filipinas

      # ===== AMÉRICA (PAÍSES EN DESARROLLO) =====
      # América Central y Caribe
      'ht': 'hat_Latn',    # Criollo haitiano (12M) - Haití
      'gn': 'grn_Latn',    # Guaraní (12M) - Paraguay
      'qu': 'quy_Latn',    # Quechua (10M) - Perú, Bolivia, Ecuador
      'ay': 'aym_Latn',    # Aimara (3M) - Bolivia, Perú

      # Idiomas indígenas importantes
      'nah': 'nah_Latn',   # Náhuatl (1.5M) - México
      'yua': 'yua_Latn',   # Maya yucateco (800K) - México
      'bzd': 'bzd_Latn',   # Belize Kriol (200K) - Belice

      # ===== OCEANÍA =====
      'fj': 'fij_Latn',    # Fiyiano (350K) - Fiyi
      'sm': 'smo_Latn',    # Samoano (510K) - Samoa
      'to': 'ton_Latn',    # Tongano (200K) - Tonga
      'bi': 'bis_Latn',    # Bislama (200K) - Vanuatu
      'ty': 'tah_Latn',    # Tahitiano (280K) - Polinesia Francesa

      # ===== OTROS IDIOMAS RELEVANTES =====
      'mt': 'mlt_Latn',    # Maltés (520K) - Malta
      'is': 'isl_Latn',    # Islandés (350K) - Islandia
      'ga': 'gle_Latn',    # Irlandés (170K) - Irlanda
      'cy': 'cym_Latn',    # Galés (880K) - Gales
      'eu': 'eus_Latn',    # Euskera (750K) - España, Francia
      'ca': 'cat_Latn',    # Catalán (10M) - España
      'gl': 'glg_Latn',    # Gallego (2.4M) - España

      # Europa menos desarrollada
      'be': 'bel_Cyrl',    # Bielorruso (5M) - Bielorrusia
      'lv': 'lav_Latn',    # Letón (1.9M) - Letonia
      'sl': 'slv_Latn',    # Esloveno (2.1M) - Eslovenia
      'fi': 'fin_Latn',    # Finés (5.5M) - Finlandia
      'da': 'dan_Latn',    # Danés (6M) - Dinamarca
      'sv': 'swe_Latn',    # Sueco (10M) - Suecia
      'no': 'nor_Latn',    # Noruego (5M) - Noruega
      'nl': 'nld_Latn',    # Neerlandés (24M) - Países Bajos
      'de': 'deu_Latn',    # Alemán (95M) - Alemania
      'it': 'ita_Latn',    # Italiano (65M) - Italia
      'tr': 'tur_Latn',    # Turco (88M) - Turquía
      'he': 'heb_Hebr',    # Hebreo (9M) - Israel
      'fa': 'pes_Arab',    # Persa/Farsi (70M) - Irán
      'ps': 'pbt_Arab',    # Pastún (60M) - Afganistán, Pakistán
  }

  print(f"🌍 TOTAL IDIOMAS SOPORTADOS: {len(SUPPORTED_LANGUAGES)}")
  print(f"📊 Cobertura estimada: ~4.5 mil millones de hablantes")
  print(f"🎯 Enfoque: Países en desarrollo y idiomas con pocos recursos")
  print(f"🌍 Idiomas soportados: {len(SUPPORTED_LANGUAGES)}")


⚠️ CONFIGURACIÓN HÍBRIDA:
  🎯 NLLB en entrenamiento: ❌ NO
  🌍 NLLB para inferencia: ✅ SÍ
  🔤 Modelo principal: ByT5 (optimizado)
  📊 Batch size: 4
  ⚡ Mixed precision: ✅ SÍ
🌍 TOTAL IDIOMAS SOPORTADOS: 123
📊 Cobertura estimada: ~4.5 mil millones de hablantes
🎯 Enfoque: Países en desarrollo y idiomas con pocos recursos
🌍 Idiomas soportados: 123


  Celda 4: Modelo Híbrido Corregido NLLB + ByT5

In [5]:

  # ============================================================================
  # MODELO HÍBRIDO CORREGIDO - SIN ERRORES DE TOKENIZACIÓN
  # ============================================================================

  class HybridNLLBByT5Model(nn.Module):
      def __init__(self, config):
          super().__init__()
          self.config = config

          # 🔧 CORRECCIÓN CRÍTICA: Solo cargar NLLB si realmente se va a usar
          if config.use_nllb_for_inference:
              print("🌍 Cargando NLLB para inferencia...")
              try:
                  self.nllb_model = AutoModelForSeq2SeqLM.from_pretrained(
                      config.nllb_model_name,
                      torch_dtype=torch.float32,
                      low_cpu_mem_usage=True
                  )
                  self.nllb_tokenizer = AutoTokenizer.from_pretrained(config.nllb_model_name)
                  print(f"✅ NLLB cargado para inferencia")
              except Exception as e:
                  print(f"⚠️ Error cargando NLLB: {e}")
                  self.nllb_model = None
                  self.nllb_tokenizer = None
          else:
              print("🚫 NLLB completamente desactivado")
              self.nllb_model = None
              self.nllb_tokenizer = None

          # ByT5: modelo principal y ÚNICO para entrenamiento
          print("🔤 Cargando ByT5 (modelo principal)...")
          self.byt5_model = T5ForConditionalGeneration.from_pretrained(
              config.byt5_model_name,
              low_cpu_mem_usage=True
          )
          self.byt5_tokenizer = ByT5Tokenizer.from_pretrained(config.byt5_model_name)

          # Verificar vocabulario
          print(f"  📊 ByT5 vocab size: {self.byt5_tokenizer.vocab_size}")
          print(f"  🔤 ByT5 pad token: {self.byt5_tokenizer.pad_token_id}")
          print(f"  🔤 ByT5 eos token: {self.byt5_tokenizer.eos_token_id}")

          # 🔧 CORRECCIÓN: Capas de mejora simplificadas
          byt5_hidden_size = self.byt5_model.config.d_model

          self.enhancement_layer = nn.Sequential(
              nn.Linear(byt5_hidden_size, byt5_hidden_size),
              nn.LayerNorm(byt5_hidden_size),
              nn.Dropout(0.1)
          )

          print(f"✅ Modelo híbrido inicializado")
          print(f"  🔤 ByT5 hidden size: {byt5_hidden_size}")
          print(f"  ⚙️ Solo ByT5 será entrenado")

      def forward(self, input_ids, attention_mask, labels=None, use_nllb=False, **kwargs):
          """Forward pass CORREGIDO - solo ByT5 durante entrenamiento"""

          # 🔧 FORZAR ByT5 durante entrenamiento
          if self.training:
              use_nllb = False

          # 🔧 CORRECCIÓN CRÍTICA: Solo usar ByT5 (vocabulario consistente)
          try:
              # Validación de entrada
              batch_size, seq_len = input_ids.shape
              device = input_ids.device

              # 🔧 CRÍTICO: Asegurar que input_ids está en rango válido para ByT5
              vocab_size = self.byt5_model.config.vocab_size
              input_ids = torch.clamp(input_ids, 0, vocab_size - 1)

              if labels is not None:
                  # 🔧 CRÍTICO: Manejar labels fuera de rango
                  labels = torch.clamp(labels, -100, vocab_size - 1)
                  # Reemplazar tokens fuera de rango con -100 (ignorar en loss)
                  labels = torch.where(
                      (labels >= vocab_size) | (labels < 0),
                      torch.tensor(-100, device=device, dtype=labels.dtype),
                      labels
                  )

              # Forward pass SOLO con ByT5
              outputs = self.byt5_model(
                  input_ids=input_ids,
                  attention_mask=attention_mask,
                  labels=labels,
                  **kwargs
              )

              return outputs

          except Exception as e:
              print(f"❌ Error en forward: {e}")

              # 🔧 FALLBACK seguro
              batch_size, seq_len = input_ids.shape
              device = input_ids.device
              vocab_size = self.byt5_model.config.vocab_size

              from transformers.modeling_outputs import Seq2SeqLMOutput
              return Seq2SeqLMOutput(
                  loss=torch.tensor(1.0, device=device, requires_grad=True) if labels is not None else None,
                  logits=torch.randn((batch_size, seq_len, vocab_size), device=device, requires_grad=True)
              )

      def generate(self, input_ids, attention_mask=None, **kwargs):
          """Generación simplificada - solo ByT5"""
          try:
              return self.byt5_model.generate(
                  input_ids=input_ids,
                  attention_mask=attention_mask,
                  pad_token_id=self.byt5_tokenizer.pad_token_id,
                  eos_token_id=self.byt5_tokenizer.eos_token_id,
                  **kwargs
              )
          except Exception as e:
              print(f"⚠️ Error en generación: {e}")
              # Fallback: devolver input_ids
              return input_ids

  print("✅ Clase HybridNLLBByT5Model CORREGIDA definida")


✅ Clase HybridNLLBByT5Model CORREGIDA definida


  Celda 5: Dataset y DataLoader Optimizados

In [6]:

  # ============================================================================
  # DATASET CORREGIDO - SIN ERRORES DE TOKENIZACIÓN
  # ============================================================================

  class MultilingualTranslationDataset(Dataset):
      def __init__(self, translation_pairs, tokenizer, max_length=128):
          self.data = translation_pairs
          self.tokenizer = tokenizer
          self.max_length = max_length

          # 🔧 CORRECCIÓN: Verificar tokenizer
          print(f"  📊 Dataset con {len(translation_pairs)} pares")
          print(f"  🔤 Tokenizer: {type(tokenizer).__name__}")
          print(f"  📏 Max length: {max_length}")

      def __len__(self):
          return len(self.data)

      def __getitem__(self, idx):
          item = self.data[idx]

          # Obtener textos con validación
          source_text = str(item.get('source', ''))[:200]  # Limitar longitud
          target_text = str(item.get('target', ''))[:200]

          # 🔧 CORRECCIÓN CRÍTICA: Validar que el texto no esté vacío
          if not source_text.strip():
              source_text = "Hello"  # Fallback
          if not target_text.strip():
              target_text = "Hola"   # Fallback

          try:
              # 🔧 TOKENIZACIÓN SEGURA
              source_encoding = self.tokenizer(
                  source_text,
                  truncation=True,
                  padding=False,
                  max_length=self.max_length,
                  return_tensors=None,
                  add_special_tokens=True  # 🔧 IMPORTANTE: Añadir tokens especiales
              )

              target_encoding = self.tokenizer(
                  target_text,
                  truncation=True,
                  padding=False,
                  max_length=self.max_length,
                  return_tensors=None,
                  add_special_tokens=True  # 🔧 IMPORTANTE
              )

              # 🔧 VALIDACIÓN: Asegurar que tenemos datos válidos
              if not source_encoding['input_ids'] or not target_encoding['input_ids']:
                  raise ValueError("Tokenización vacía")

              return {
                  'input_ids': source_encoding['input_ids'],
                  'target_ids': target_encoding['input_ids'],
                  'source_lang': item.get('source_lang', 'en'),
                  'target_lang': item.get('target_lang', 'es')
              }

          except Exception as e:
              print(f"⚠️ Error tokenizando item {idx}: {e}")
              # 🔧 FALLBACK SEGURO
              fallback_ids = [self.tokenizer.pad_token_id] * 10
              return {
                  'input_ids': fallback_ids,
                  'target_ids': fallback_ids,
                  'source_lang': 'en',
                  'target_lang': 'es'
              }

  def optimized_collate_fn(batch):
      """Función de collation CORREGIDA"""

      # Extraer datos
      input_ids = [item['input_ids'] for item in batch]
      target_ids = [item['target_ids'] for item in batch]

      # 🔧 CORRECCIÓN: pad_token_id correcto para ByT5
      pad_token_id = 0  # ByT5 usa 0 como pad_token_id

      # Verificar que tenemos datos válidos
      if not input_ids or not target_ids:
          print("⚠️ Batch vacío detectado")
          # Crear batch de fallback
          input_ids = [[pad_token_id] * 10 for _ in range(len(batch))]
          target_ids = [[pad_token_id] * 10 for _ in range(len(batch))]

      # Calcular longitudes máximas
      max_input_len = min(128, max(len(seq) for seq in input_ids))  # 🔧 Limitar a 128
      max_target_len = min(128, max(len(seq) for seq in target_ids))

      # Pad sequences y crear máscaras
      padded_inputs = []
      padded_targets = []
      input_masks = []

      for inp, tgt in zip(input_ids, target_ids):
          # 🔧 CORRECCIÓN: Validar longitudes
          inp = inp[:max_input_len]  # Truncar si es necesario
          tgt = tgt[:max_target_len]

          # Input padding
          inp_len = len(inp)
          inp_padded = inp + [pad_token_id] * (max_input_len - inp_len)
          inp_mask = [1] * inp_len + [0] * (max_input_len - inp_len)

          # Target padding
          tgt_len = len(tgt)
          tgt_padded = tgt + [pad_token_id] * (max_target_len - tgt_len)

          padded_inputs.append(inp_padded)
          padded_targets.append(tgt_padded)
          input_masks.append(inp_mask)

      # 🔧 CONVERSIÓN SEGURA A TENSORES
      try:
          return {
              'input_ids': torch.tensor(padded_inputs, dtype=torch.long),
              'attention_mask': torch.tensor(input_masks, dtype=torch.long),
              'labels': torch.tensor(padded_targets, dtype=torch.long),
          }
      except Exception as e:
          print(f"❌ Error creando tensores: {e}")
          # Fallback batch
          batch_size = len(batch)
          return {
              'input_ids': torch.zeros((batch_size, 10), dtype=torch.long),
              'attention_mask': torch.ones((batch_size, 10), dtype=torch.long),
              'labels': torch.zeros((batch_size, 10), dtype=torch.long),
          }

  def create_optimized_dataloader(dataset, batch_size, is_training=True):
      """Crear DataLoader CORREGIDO"""

      if IN_COLAB:
          dataloader_config = {
              'batch_size': batch_size,
              'shuffle': is_training,
              'num_workers': 0,
              'pin_memory': True,
              'drop_last': is_training,
              'collate_fn': optimized_collate_fn
          }
          print("🔧 DataLoader configurado para Colab")
      else:
          dataloader_config = {
              'batch_size': batch_size,
              'shuffle': is_training,
              'num_workers': 2,
              'pin_memory': True,
              'persistent_workers': True,
              'prefetch_factor': 2,
              'drop_last': is_training,
              'collate_fn': optimized_collate_fn
          }
          print("🚀 DataLoader optimizado para entorno local")

      return DataLoader(dataset, **dataloader_config)

  print("✅ Dataset y DataLoader CORREGIDOS definidos")



✅ Dataset y DataLoader CORREGIDOS definidos


In [7]:
  # ============================================================================
  # DEBUGGING - VERIFICAR RUTAS Y ARCHIVOS EN GOOGLE DRIVE
  # ============================================================================

  import os
  from pathlib import Path

  print("🔍 DEBUGGING: Verificando rutas en Google Drive...")

  # Rutas a verificar
  potential_paths = [
      "/content/drive/My Drive/GlobalTranslator/NMT/Dataset",
      "/content/drive/MyDrive/GlobalTranslator/NMT/Dataset",
      "/content/drive/My Drive/GlobalTranslatorApp/Codigo/NMT/Dataset",
      "/content/drive/MyDrive/GlobalTranslatorApp/Codigo/NMT/Dataset"
  ]

  for path in potential_paths:
      print(f"\n📂 Verificando: {path}")
      if os.path.exists(path):
          print(f"   ✅ Existe")
          try:
              files = os.listdir(path)
              csv_files = [f for f in files if f.endswith('.csv')]
              nmt_files = [f for f in csv_files if f.startswith('NMT_')]

              print(f"   📄 Total archivos: {len(files)}")
              print(f"   📊 Archivos CSV: {len(csv_files)}")
              print(f"   🎯 Archivos NMT_*: {len(nmt_files)}")

              if nmt_files:
                  print(f"   📋 Primeros archivos NMT:")
                  for f in nmt_files[:5]:
                      print(f"      - {f}")

              if csv_files:
                  print(f"   📋 Todos los CSV:")
                  for f in csv_files[:10]:
                      print(f"      - {f}")

          except Exception as e:
              print(f"   ❌ Error listando: {e}")
      else:
          print(f"   ❌ No existe")

  # Verificar estructura desde raíz
  print(f"\n🌳 Explorando estructura desde /content/drive/:")
  try:
      for root in ["/content/drive/My Drive", "/content/drive/MyDrive"]:
          if os.path.exists(root):
              print(f"\n📂 {root}/")
              for item in os.listdir(root)[:10]:
                  item_path = os.path.join(root, item)
                  if os.path.isdir(item_path):
                      print(f"   📁 {item}/")
                      # Buscar GlobalTranslator
                      if 'global' in item.lower():
                          print(f"      🎯 Posible match: {item}")
                          try:
                              sub_items = os.listdir(item_path)[:5]
                              for sub in sub_items:
                                  print(f"         📁 {sub}")
                          except:
                              pass
  except Exception as e:
      print(f"❌ Error explorando: {e}")


🔍 DEBUGGING: Verificando rutas en Google Drive...

📂 Verificando: /content/drive/My Drive/GlobalTranslator/NMT/Dataset
   ✅ Existe
   📄 Total archivos: 3
   📊 Archivos CSV: 3
   🎯 Archivos NMT_*: 3
   📋 Primeros archivos NMT:
      - NMT_train17.csv
      - NMT_train18.csv
      - NMT_val3.csv
   📋 Todos los CSV:
      - NMT_train17.csv
      - NMT_train18.csv
      - NMT_val3.csv

📂 Verificando: /content/drive/MyDrive/GlobalTranslator/NMT/Dataset
   ✅ Existe
   📄 Total archivos: 3
   📊 Archivos CSV: 3
   🎯 Archivos NMT_*: 3
   📋 Primeros archivos NMT:
      - NMT_train17.csv
      - NMT_train18.csv
      - NMT_val3.csv
   📋 Todos los CSV:
      - NMT_train17.csv
      - NMT_train18.csv
      - NMT_val3.csv

📂 Verificando: /content/drive/My Drive/GlobalTranslatorApp/Codigo/NMT/Dataset
   ❌ No existe

📂 Verificando: /content/drive/MyDrive/GlobalTranslatorApp/Codigo/NMT/Dataset
   ❌ No existe

🌳 Explorando estructura desde /content/drive/:

📂 /content/drive/My Drive/

📂 /content/drive/My

  Celda 6: Carga de Datos Robusta



In [8]:
# ============================================================================
# Celda 6: Drive + Config + Carga robusta de datasets pre-tokenizados (ByT5)
#           con filtrado de ejemplos inválidos (previene NaN)
# ============================================================================
import os, re, ast, gc
from pathlib import Path
from typing import List, Dict, Optional
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset

# ----------------------------- Drive ----------------------------------------
try:
    DRIVE_MOUNTED
    DRIVE_BASE_PATH
    print(f"✅ Variables de Drive ya definidas: DRIVE_MOUNTED={DRIVE_MOUNTED}, BASE={DRIVE_BASE_PATH}")
except NameError:
    print("⚠️ Variables de Drive no definidas, configurando...")
    try:
        from google.colab import drive
        print("🔗 Montando Google Drive...")
        drive.mount('/content/drive')
        DRIVE_BASE_PATH = "/content/drive/My Drive"
        if os.path.exists(DRIVE_BASE_PATH):
            DRIVE_MOUNTED = True
            print(f"✅ Google Drive montado: {DRIVE_BASE_PATH}")
            for d in [f"{DRIVE_BASE_PATH}/GlobalTranslator/NMT/Models",
                      f"{DRIVE_BASE_PATH}/GlobalTranslator/NMT/Dataset"]:
                os.makedirs(d, exist_ok=True)
        else:
            DRIVE_MOUNTED = False
            DRIVE_BASE_PATH = None
            print("❌ No se puede acceder a Google Drive")
    except Exception as e:
        print(f"❌ Error montando Drive: {e}")
        DRIVE_MOUNTED = False
        DRIVE_BASE_PATH = None

print(f"🎛️ Estado final: DRIVE_MOUNTED={DRIVE_MOUNTED}, DRIVE_BASE_PATH={DRIVE_BASE_PATH}")

# ----------------------------- Config ---------------------------------------
try:
    current_config = enhanced_config
    print("✅ Usando enhanced_config")
except NameError:
    try:
        current_config = config
        print("✅ Usando config")
    except NameError:
        print("⚠️ Creando configuración básica (fallback)...")
        from dataclasses import dataclass
        @dataclass
        class BasicConfig:
            # longitudes y padding
            max_length: int = 128
            pad_token_id: int = 0
            label_pad_id: int = -100
            # entrenamiento (por si los usa el trainer)
            batch_size: int = 4
            learning_rate: float = 1e-4
            num_epochs: int = 3
            gradient_accumulation_steps: int = 4
            clip_norm: float = 1.0
            patience: int = 3
            warmup_steps: int = 100
            lr_scheduler_type: str = "linear"
            min_lr: float = 1e-6
            # modelo
            byt5_model_name: str = "google/byt5-small"
            # rutas
            model_save_path: str = "/content/drive/My Drive/GlobalTranslator/NMT/Models"
            checkpoint_dir: str = "/content/drive/My Drive/GlobalTranslator/NMT/Models"
        current_config = BasicConfig()
        print("✅ Configuración básica creada")

MODEL_SAVE_PATH = f"{DRIVE_BASE_PATH}/GlobalTranslator/NMT/Models" if DRIVE_BASE_PATH else "/content"
os.makedirs(MODEL_SAVE_PATH, exist_ok=True)
current_config.model_save_path = MODEL_SAVE_PATH
current_config.checkpoint_dir = MODEL_SAVE_PATH
if 'enhanced_config' not in globals():
    enhanced_config = current_config
    print("✅ enhanced_config configurado globalmente")

# ---------------------- Parseo y normalización ------------------------------
def _safe_list_eval(val) -> List[int]:
    if isinstance(val, (list, tuple, np.ndarray)):
        return [int(x) for x in val]
    if isinstance(val, str):
        s = val.strip()
        try:
            out = ast.literal_eval(s)
            if isinstance(out, (list, tuple, np.ndarray)):
                return [int(x) for x in out]
        except Exception:
            pass
        s = s.strip("[]")
        if not s:
            return []
        return [int(tok) for tok in s.split(",") if tok.strip() != ""]
    return []

def _pad_or_trim(seq: List[int], L: int, pad_value: int) -> List[int]:
    if len(seq) > L:
        return seq[:L]
    if len(seq) < L:
        return seq + [pad_value] * (L - len(seq))
    return seq

def _normalize_example(ex: Dict, L: int, pad_id: int, label_pad_id: int) -> Dict:
    inp = _pad_or_trim(ex['input_ids'], L, pad_id)
    msk = _pad_or_trim(ex['attention_mask'], L, 0)
    tgt = _pad_or_trim(ex['labels'], L, pad_id)
    tgt_m = _pad_or_trim(ex['target_attention_mask'], L, 0)
    lbl = [label_pad_id if tok == pad_id else tok for tok in tgt]
    return {
        'input_ids': inp,
        'attention_mask': msk,
        'labels': lbl,
        'target_attention_mask': tgt_m,
        'src_lang': ex.get('src_lang'),
        'tgt_lang': ex.get('tgt_lang'),
    }

# ------------------- Búsqueda y carga streaming -----------------------------
def _find_files(base_path: str, prefix: str) -> List[Path]:
    base = Path(base_path)
    if not base.exists():
        print(f"❌ Ruta no encontrada: {base}")
        return []
    files = list(base.glob(f"{prefix}*.csv"))
    def extract_num(p: Path):
        nums = re.findall(r'\d+', p.stem)
        return int(nums[-1]) if nums else -1
    files.sort(key=extract_num)
    if not files:
        print(f"❌ No se encontraron archivos con patrón {prefix}*.csv")
    else:
        print(f"📁 {prefix}: {len(files)} archivos -> {[f.name for f in files]}")
    return files

def load_pretokenized_byT5(
    base_path: str,
    prefix: str,
    max_files: Optional[int] = None,
    max_rows_per_file: Optional[int] = None,
    L: Optional[int] = None,
    pad_id: int = 0,
    label_pad_id: int = -100,
    min_valid_target_tokens: int = 1
) -> List[Dict]:
    """
    Carga datos con columnas exactas:
      input_ids, input_attention_mask, target_ids, target_attention_mask, input_label, target_label
    Filtra ejemplos cuyo target tenga < min_valid_target_tokens tokens válidos (evita NaN).
    """
    files = _find_files(base_path, prefix)
    if not files:
        return []

    if max_files is not None:
        files = files[:max_files]
        print(f"📊 Limitando a {max_files} archivos para {prefix}")

    cols = [
        'input_ids',
        'input_attention_mask',
        'target_ids',
        'target_attention_mask',
        'input_label',
        'target_label'
    ]
    chunksize = 50_000
    out: List[Dict] = []
    kept = dropped_empty = dropped_short = 0

    for fp in files:
        try:
            print(f"📥 Leyendo {fp.name} ...")
            n_rows = 0
            for chunk in pd.read_csv(fp, usecols=cols, chunksize=chunksize):
                for _, row in chunk.iterrows():
                    n_rows += 1
                    try:
                        inp_ids = _safe_list_eval(row['input_ids'])
                        inp_msk = _safe_list_eval(row['input_attention_mask'])
                        tgt_ids = _safe_list_eval(row['target_ids'])
                        tgt_msk = _safe_list_eval(row['target_attention_mask'])

                        # Corrige máscaras si no coinciden
                        if len(inp_msk) != len(inp_ids):
                            inp_msk = [0 if tok == pad_id else 1 for tok in inp_ids]
                        if len(tgt_msk) != len(tgt_ids):
                            tgt_msk = [0 if tok == pad_id else 1 for tok in tgt_ids]

                        # Filtro clave: al menos N tokens target válidos (!= pad_id)
                        valid_target_tokens = sum(1 for t in tgt_ids if t != pad_id)
                        if valid_target_tokens == 0:
                            dropped_empty += 1
                            continue
                        if valid_target_tokens < min_valid_target_tokens:
                            dropped_short += 1
                            continue

                        ex = {
                            'input_ids': inp_ids,
                            'attention_mask': inp_msk,
                            'labels': tgt_ids,
                            'target_attention_mask': tgt_msk,
                            'src_lang': str(row['input_label']).strip(),
                            'tgt_lang': str(row['target_label']).strip()
                        }

                        if L is not None:
                            ex = _normalize_example(ex, L=L, pad_id=pad_id, label_pad_id=label_pad_id)

                        out.append(ex)
                        kept += 1

                        if max_rows_per_file and (kept % max_rows_per_file == 0):
                            break
                    except Exception:
                        continue  # ignora fila corrupta

                if max_rows_per_file and (kept % max_rows_per_file == 0):
                    break
                gc.collect()

            print(f"   ✅ {fp.name}: leídas={n_rows} | guardadas(acum)={kept} | vacías={dropped_empty} | cortas={dropped_short}")
        except Exception as e:
            print(f"❌ Error en {fp.name}: {e}")

    print(f"📊 TOTAL {prefix}: guardadas={kept} | descartadas(vacías)={dropped_empty} | descartadas(cortas)={dropped_short}")
    return out

# ------------------------ Dataset Torch -------------------------------------
class PreTokenizedByT5Dataset(Dataset):
    def __init__(self, data: List[Dict]):
        self.data = data
    def __len__(self):
        return len(self.data)
    def __getitem__(self, idx):
        it = self.data[idx]
        return {
            'input_ids': torch.tensor(it['input_ids'], dtype=torch.long),
            'attention_mask': torch.tensor(it['attention_mask'], dtype=torch.long),
            'labels': torch.tensor(it['labels'], dtype=torch.long),
            'target_attention_mask': torch.tensor(it['target_attention_mask'], dtype=torch.long),
            'src_lang': it.get('src_lang'),
            'tgt_lang': it.get('tgt_lang'),
        }

# ----------------------- Carga principal ------------------------------------
def load_translation_data_byT5():
    base = f"{DRIVE_BASE_PATH}/GlobalTranslator/NMT/Dataset" if DRIVE_BASE_PATH else "/content"
    print("🔄 Cargando datos pre-tokenizados ByT5...")
    print("📂 Carpeta:", base)

    L = int(getattr(current_config, "max_length", 128))
    pad_id = int(getattr(current_config, "pad_token_id", 0))
    label_pad_id = int(getattr(current_config, "label_pad_id", -100))

    train_pairs = load_pretokenized_byT5(base, "NMT_train", L=L, pad_id=pad_id, label_pad_id=label_pad_id)
    val_pairs   = load_pretokenized_byT5(base, "NMT_val",   L=L, pad_id=pad_id, label_pad_id=label_pad_id)
    return train_pairs, val_pairs

print("📊 Iniciando carga de datos pre-tokenizados ByT5...")
train_pairs, val_pairs = load_translation_data_byT5()

print("🔄 Creando datasets PyTorch...")
TRAIN_DATASET = PreTokenizedByT5Dataset(train_pairs)
VAL_DATASET   = PreTokenizedByT5Dataset(val_pairs)
print(f"✅ Datasets: train={len(TRAIN_DATASET)} | val={len(VAL_DATASET)}")

# Muestra de sanity-check
if len(TRAIN_DATASET) > 0:
    s = TRAIN_DATASET[0]
    print("📋 Sample:")
    print("   input_ids[:10]:", s['input_ids'][:10].tolist())
    print("   labels[:10]:   ", s['labels'][:10].tolist())

# Crea DataLoaders si existe tu utilidad de Celda 5
if 'create_tokenized_dataloaders' in globals():
    try:
        TRAIN_LOADER, VAL_LOADER = create_tokenized_dataloaders(
            train_pairs, val_pairs,
            batch_size=int(getattr(current_config, "batch_size", 4)),
            max_length=int(getattr(current_config, "max_length", 128)),
            pad_token_id=int(getattr(current_config, "pad_token_id", 0)),
            label_pad_id=int(getattr(current_config, "label_pad_id", -100)),
        )
        print("✅ DataLoaders creados (TRAIN_LOADER, VAL_LOADER)")
    except Exception as e:
        print(f"⚠️ create_tokenized_dataloaders falló: {e}")
        TRAIN_LOADER = VAL_LOADER = None
else:
    TRAIN_LOADER = VAL_LOADER = None
    print("ℹ️ No hay 'create_tokenized_dataloaders' (Celda 5).")


✅ Variables de Drive ya definidas: DRIVE_MOUNTED=True, BASE=/content/drive/MyDrive
🎛️ Estado final: DRIVE_MOUNTED=True, DRIVE_BASE_PATH=/content/drive/MyDrive
✅ Usando config
✅ enhanced_config configurado globalmente
📊 Iniciando carga de datos pre-tokenizados ByT5...
🔄 Cargando datos pre-tokenizados ByT5...
📂 Buscando en: /content/drive/MyDrive/GlobalTranslator/NMT/Dataset
📁 Encontrados 2 archivos NMT_train*.csv: ['NMT_train17.csv', 'NMT_train18.csv']
📥 Cargando NMT_train17.csv ...
   ✅ 100000 filas válidas de NMT_train17.csv
📥 Cargando NMT_train18.csv ...
   ✅ 100000 filas válidas de NMT_train18.csv
📊 TOTAL cargado para 'NMT_train': 200000 ejemplos
📁 Encontrados 1 archivos NMT_val*.csv: ['NMT_val3.csv']
📥 Cargando NMT_val3.csv ...
   ✅ 100000 filas válidas de NMT_val3.csv
📊 TOTAL cargado para 'NMT_val': 100000 ejemplos
🔄 Creando datasets PyTorch...
✅ Datasets creados:
   📈 Entrenamiento: 200000 muestras
   📊 Validación:    100000 muestras
📋 Ejemplo de muestra:
   input_ids shape: (128

  Celda 7: Inicialización del Modelo

In [9]:

  # ============================================================================
  # INICIALIZACIÓN DEL MODELO HÍBRIDO
  # ============================================================================

  print("🚀 Inicializando modelo híbrido...")

  # Crear modelo
  model = HybridNLLBByT5Model(enhanced_config)
  model.to(device)

  print(f"✅ Modelo cargado en: {device}")

  # Contar parámetros
  total_params = sum(p.numel() for p in model.parameters())
  trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

  print(f"📊 Parámetros del modelo:")
  print(f"  🔢 Total: {total_params:,}")
  print(f"  🎯 Entrenables: {trainable_params:,}")
  print(f"  📈 Porcentaje entrenable: {100 * trainable_params / total_params:.1f}%")

  # Verificar memoria GPU
  if torch.cuda.is_available():
      torch.cuda.empty_cache()  # Limpiar antes de medir
      memory_allocated = torch.cuda.memory_allocated() / 1024**3
      memory_reserved = torch.cuda.memory_reserved() / 1024**3

      print(f"💾 Memoria GPU:")
      print(f"  📊 Asignada: {memory_allocated:.2f} GB")
      print(f"  🔒 Reservada: {memory_reserved:.2f} GB")

  # Configurar rutas de guardado específicas
  enhanced_config.best_model_path = os.path.join(MODEL_SAVE_PATH, "best_hybrid_model.pt")
  enhanced_config.final_model_path = os.path.join(MODEL_SAVE_PATH, "final_hybrid_model.pt")

  print(f"📁 Configuración de guardado:")
  print(f"  🏆 Mejor modelo: {enhanced_config.best_model_path}")
  print(f"  🎯 Modelo final: {enhanced_config.final_model_path}")

  print("✅ Modelo híbrido inicializado correctamente")


🚀 Inicializando modelo híbrido...
🌍 Cargando NLLB para inferencia...


config.json:   0%|          | 0.00/846 [00:00<?, ?B/s]

`torch_dtype` is deprecated! Use `dtype` instead!


pytorch_model.bin:   0%|          | 0.00/2.46G [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/2.46G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/189 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/564 [00:00<?, ?B/s]

sentencepiece.bpe.model:   0%|          | 0.00/4.85M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/17.3M [00:00<?, ?B/s]

special_tokens_map.json: 0.00B [00:00, ?B/s]

✅ NLLB cargado para inferencia
🔤 Cargando ByT5 (modelo principal)...


config.json:   0%|          | 0.00/698 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/1.20G [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.20G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/147 [00:00<?, ?B/s]

tokenizer_config.json: 0.00B [00:00, ?B/s]

special_tokens_map.json: 0.00B [00:00, ?B/s]

  📊 ByT5 vocab size: 256
  🔤 ByT5 pad token: 0
  🔤 ByT5 eos token: 1
✅ Modelo híbrido inicializado
  🔤 ByT5 hidden size: 1472
  ⚙️ Solo ByT5 será entrenado
✅ Modelo cargado en: cuda
📊 Parámetros del modelo:
  🔢 Total: 916,882,752
  🎯 Entrenables: 916,882,752
  📈 Porcentaje entrenable: 100.0%
💾 Memoria GPU:
  📊 Asignada: 3.43 GB
  🔒 Reservada: 3.54 GB
📁 Configuración de guardado:
  🏆 Mejor modelo: /content/drive/MyDrive/GlobalTranslator/NMT/Models/best_hybrid_model.pt
  🎯 Modelo final: /content/drive/MyDrive/GlobalTranslator/NMT/Models/final_hybrid_model.pt
✅ Modelo híbrido inicializado correctamente


In [None]:
from tqdm.auto import tqdm
import torch
from torch.cuda.amp import autocast, GradScaler

def _count_valid_targets(labels):
    # cuenta tokens != -100
    return int((labels != -100).sum().item())

def train_epoch(model, dataloader, optimizer, scheduler, device, epoch, config):
    model.train()
    scaler = GradScaler(enabled=torch.cuda.is_available())
    total_loss = 0.0
    skipped = 0
    progress_bar = tqdm(dataloader, desc=f"Epoch {epoch+1}")

    for batch_idx, batch in enumerate(progress_bar):
        input_ids = batch['input_ids'].to(device, non_blocking=True)
        attention_mask = batch['attention_mask'].to(device, non_blocking=True)
        labels = batch['labels'].to(device, non_blocking=True)

        # ⚠️ SALTO si no hay ni un token válido
        if _count_valid_targets(labels) == 0:
            skipped += 1
            continue

        with autocast(enabled=torch.cuda.is_available()):
            outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
            loss = outputs['loss'] / config.gradient_accumulation_steps

        if torch.isnan(loss):
            skipped += 1
            continue

        scaler.scale(loss).backward()

        if (batch_idx + 1) % config.gradient_accumulation_steps == 0:
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), getattr(config, "clip_norm", 1.0))
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad(set_to_none=True)
            scheduler.step()

        total_loss += loss.item() * config.gradient_accumulation_steps

        progress_bar.set_postfix({
            'loss': f"{(total_loss / max(1, (batch_idx + 1))):.4f}",
            'lr': scheduler.get_last_lr()[0] if hasattr(scheduler, "get_last_lr") else None,
            'skip': skipped
        })

        if torch.cuda.is_available() and (batch_idx % 200 == 0):
            torch.cuda.empty_cache()

    if skipped > 0:
        print(f"⚠️ Batches saltados por labels vacíos/NaN: {skipped}")
    return total_loss / max(1, len(dataloader))

def validate(model, dataloader, device, config):
    model.eval()
    total_loss = 0.0
    skipped = 0

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Validation"):
            input_ids = batch['input_ids'].to(device, non_blocking=True)
            attention_mask = batch['attention_mask'].to(device, non_blocking=True)
            labels = batch['labels'].to(device, non_blocking=True)

            if _count_valid_targets(labels) == 0:
                skipped += 1
                continue

            with autocast(enabled=torch.cuda.is_available()):
                outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
                loss = outputs['loss']

            if torch.isnan(loss):
                skipped += 1
                continue

            total_loss += loss.item()

    if skipped > 0:
        print(f"⚠️ Val batches saltados por labels vacíos/NaN: {skipped}")
    return total_loss / max(1, len(dataloader))


  Celda 8: Configuración de Optimizador y Scheduler

In [10]:

  # ============================================================================
  # CONFIGURACIÓN DE OPTIMIZADOR Y SCHEDULER
  # ============================================================================

  print("⚙️ Configurando optimizador y scheduler...")

  # Optimizador
  optimizer = torch.optim.AdamW(
      model.parameters(),
      lr=config.learning_rate,
      weight_decay=config.weight_decay,
      eps=1e-8,
      betas=(0.9, 0.999)
  )

  # Calcular pasos totales para el scheduler
  total_steps = (len(train_pairs) // config.batch_size) * config.num_epochs
  warmup_steps = min(config.warmup_steps, total_steps // 10)  # Máximo 10% de warmup

  # Scheduler
  if config.scheduler_type == "linear":
      scheduler = get_linear_schedule_with_warmup(
          optimizer,
          num_warmup_steps=warmup_steps,
          num_training_steps=total_steps
      )
  else:
      scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
          optimizer,
          T_max=total_steps,
          eta_min=config.learning_rate * 0.01
      )

  # Configurar AMP scaler si está habilitado
  if USE_AMP:
      scaler = torch.cuda.amp.GradScaler()
      print("✅ AMP Scaler configurado")
  else:
      scaler = None

  print(f"✅ Configuración completada:")
  print(f"  🎯 Optimizador: AdamW")
  print(f"  📈 Learning rate: {config.learning_rate}")
  print(f"  🔥 Scheduler: {config.scheduler_type}")
  print(f"  📊 Total steps: {total_steps:,}")
  print(f"  🔄 Warmup steps: {warmup_steps:,}")
  print(f"  ⚡ Mixed precision: {'✅ SÍ' if USE_AMP else '❌ NO'}")


⚙️ Configurando optimizador y scheduler...
✅ AMP Scaler configurado
✅ Configuración completada:
  🎯 Optimizador: AdamW
  📈 Learning rate: 1e-05
  🔥 Scheduler: linear
  📊 Total steps: 500,000
  🔄 Warmup steps: 1,000
  ⚡ Mixed precision: ✅ SÍ


  Celda 9: Creación de Datasets y DataLoaders


In [11]:

  # ============================================================================
  # CREACIÓN DE DATASETS Y DATALOADERS
  # ============================================================================

  print("📊 Creando datasets y dataloaders...")

  # Obtener tokenizer principal (ByT5)
  tokenizer = model.byt5_tokenizer

  # Crear datasets
  train_dataset = MultilingualTranslationDataset(
      train_pairs,
      tokenizer,
      max_length=config.max_length
  )

  val_dataset = MultilingualTranslationDataset(
      val_pairs,
      tokenizer,
      max_length=config.max_length
  )

  # Crear dataloaders optimizados
  train_loader = create_optimized_dataloader(
      train_dataset,
      config.batch_size,
      is_training=True
  )

  val_loader = create_optimized_dataloader(
      val_dataset,
      config.batch_size,
      is_training=False
  )

  print(f"✅ Datasets y DataLoaders creados:")
  print(f"  🏋️ Train dataset: {len(train_dataset)} samples")
  print(f"  📊 Val dataset: {len(val_dataset)} samples")
  print(f"  📦 Train batches: {len(train_loader)}")
  print(f"  📦 Val batches: {len(val_loader)}")

  # Probar un batch
  print("\n🧪 Probando primer batch...")
  try:
      sample_batch = next(iter(train_loader))
      print(f"  ✅ Batch shape: {sample_batch['input_ids'].shape}")
      print(f"  ✅ Labels shape: {sample_batch['labels'].shape}")
      print(f"  ✅ Attention mask shape: {sample_batch['attention_mask'].shape}")
  except Exception as e:
      print(f"  ❌ Error en batch de prueba: {e}")


📊 Creando datasets y dataloaders...
  📊 Dataset con 200000 pares
  🔤 Tokenizer: ByT5Tokenizer
  📏 Max length: 128
  📊 Dataset con 100000 pares
  🔤 Tokenizer: ByT5Tokenizer
  📏 Max length: 128
🔧 DataLoader configurado para Colab
🔧 DataLoader configurado para Colab
✅ Datasets y DataLoaders creados:
  🏋️ Train dataset: 200000 samples
  📊 Val dataset: 100000 samples
  📦 Train batches: 50000
  📦 Val batches: 25000

🧪 Probando primer batch...
  ✅ Batch shape: torch.Size([4, 6])
  ✅ Labels shape: torch.Size([4, 5])
  ✅ Attention mask shape: torch.Size([4, 6])


  Celda 10: Funciones de Entrenamiento Optimizadas

In [12]:

  # ============================================================================
  # FUNCIONES DE ENTRENAMIENTO COMPLETAMENTE OPTIMIZADAS
  # ============================================================================

  def train_epoch_optimized(model, dataloader, optimizer, scheduler, device, epoch, config, scaler=None):
      """Época de entrenamiento completamente optimizada"""

      model.train()
      total_loss = 0.0
      num_batches = 0

      # Progress bar
      progress_bar = tqdm(dataloader, desc=f"Época {epoch+1}")

      accumulation_steps = config.gradient_accumulation_steps

      for batch_idx, batch in enumerate(progress_bar):
          # Mover a device
          input_ids = batch['input_ids'].to(device, non_blocking=True)
          attention_mask = batch['attention_mask'].to(device, non_blocking=True)
          labels = batch['labels'].to(device, non_blocking=True)

          # Zero gradients solo al inicio del accumulation
          if batch_idx % accumulation_steps == 0:
              optimizer.zero_grad()

          # Forward pass con AMP - SOLO BYT5 DURANTE ENTRENAMIENTO
          if USE_AMP and scaler is not None:
              with torch.cuda.amp.autocast(dtype=AMP_DTYPE):
                  outputs = model(
                      input_ids=input_ids,
                      attention_mask=attention_mask,
                      labels=labels,
                      use_nllb=False  # 🔧 FORZAR ByT5 durante entrenamiento
                  )
                  loss = outputs.loss / accumulation_steps
          else:
              outputs = model(
                  input_ids=input_ids,
                  attention_mask=attention_mask,
                  labels=labels,
                  use_nllb=False  # 🔧 FORZAR ByT5 durante entrenamiento
              )
              loss = outputs.loss / accumulation_steps

          # Backward pass
          if USE_AMP and scaler is not None:
              scaler.scale(loss).backward()
          else:
              loss.backward()

          # Optimization step cada accumulation_steps
          if (batch_idx + 1) % accumulation_steps == 0:
              if USE_AMP and scaler is not None:
                  scaler.unscale_(optimizer)
                  torch.nn.utils.clip_grad_norm_(model.parameters(), config.max_grad_norm)
                  scaler.step(optimizer)
                  scaler.update()
              else:
                  torch.nn.utils.clip_grad_norm_(model.parameters(), config.max_grad_norm)
                  optimizer.step()

              if scheduler:
                  scheduler.step()

          # Acumular loss
          total_loss += loss.item() * accumulation_steps
          num_batches += 1

          # Actualizar progress bar
          current_lr = optimizer.param_groups[0]['lr']
          progress_bar.set_postfix({
              'Loss': f"{loss.item() * accumulation_steps:.4f}",
              'Avg': f"{total_loss/num_batches:.4f}",
              'LR': f"{current_lr:.2e}",
              'Model': 'ByT5'  # Indicar que usa ByT5
          })

          # Limpieza de memoria periódica (cada 100 batches, no cada batch)
          if batch_idx > 0 and batch_idx % 100 == 0:
              torch.cuda.empty_cache()

      progress_bar.close()

      # Limpiar memoria al final de la época
      if torch.cuda.is_available():
          torch.cuda.empty_cache()
      gc.collect()

      avg_loss = total_loss / max(num_batches, 1)
      return avg_loss

  def validate_model(model, dataloader, device):
      """Validación del modelo - TAMBIÉN SOLO ByT5"""

      model.eval()
      total_loss = 0.0
      num_batches = 0

      with torch.no_grad():
          progress_bar = tqdm(dataloader, desc="Validando")

          for batch in progress_bar:
              input_ids = batch['input_ids'].to(device, non_blocking=True)
              attention_mask = batch['attention_mask'].to(device, non_blocking=True)
              labels = batch['labels'].to(device, non_blocking=True)

              # Forward pass - SOLO ByT5 durante validación también
              if USE_AMP:
                  with torch.cuda.amp.autocast(dtype=AMP_DTYPE):
                      outputs = model(
                          input_ids=input_ids,
                          attention_mask=attention_mask,
                          labels=labels,
                          use_nllb=False  # 🔧 ByT5 también en validación
                      )
              else:
                  outputs = model(
                      input_ids=input_ids,
                      attention_mask=attention_mask,
                      labels=labels,
                      use_nllb=False  # 🔧 ByT5 también en validación
                  )

              total_loss += outputs.loss.item()
              num_batches += 1

              progress_bar.set_postfix({
                  'Val Loss': f"{outputs.loss.item():.4f}",
                  'Model': 'ByT5'
              })

      progress_bar.close()
      avg_loss = total_loss / max(num_batches, 1)
      return avg_loss

  def save_checkpoint(model, optimizer, scheduler, epoch, loss, checkpoint_dir="./checkpoints"):
      """Guardar checkpoint - MEJORADO con Google Drive automático"""

      # Crear carpetas
      os.makedirs(checkpoint_dir, exist_ok=True)

      # 📁 GUARDAR TAMBIÉN EN GOOGLE DRIVE automáticamente
      drive_checkpoint_dir = "/content/drive/MyDrive/model_checkpoints"
      try:
          os.makedirs(drive_checkpoint_dir, exist_ok=True)
          save_to_drive = True
      except:
          save_to_drive = False
          print("⚠️ No se puede acceder a Google Drive")

      checkpoint = {
          'epoch': epoch,
          'model_state_dict': model.state_dict(),
          'optimizer_state_dict': optimizer.state_dict(),
          'scheduler_state_dict': scheduler.state_dict() if scheduler else None,
          'loss': loss,
          'config': config.__dict__,
          'timestamp': datetime.now().isoformat(),
          'model_type': 'HybridNLLBByT5Model',
          'training_mode': 'byt5_only',  # Indicar que se entrenó solo ByT5
          'pytorch_version': torch.__version__
      }

      # Guardar LOCAL
      checkpoint_path = os.path.join(checkpoint_dir, f"checkpoint_epoch_{epoch}.pt")
      latest_path = os.path.join(checkpoint_dir, "latest_model.pt")

      torch.save(checkpoint, checkpoint_path)
      torch.save(checkpoint, latest_path)

      print(f"💾 Checkpoint guardado LOCAL: {checkpoint_path}")

      # ✅ GUARDAR EN GOOGLE DRIVE si es posible
      if save_to_drive:
          try:
              drive_checkpoint_path = os.path.join(drive_checkpoint_dir, f"checkpoint_epoch_{epoch}.pt")
              drive_latest_path = os.path.join(drive_checkpoint_dir, "latest_model.pt")

              torch.save(checkpoint, drive_checkpoint_path)
              torch.save(checkpoint, drive_latest_path)

              print(f"☁️ Checkpoint guardado DRIVE: {drive_checkpoint_path}")

          except Exception as e:
              print(f"⚠️ Error guardando en Drive: {e}")

      return checkpoint_path

  def test_hybrid_inference(model, tokenizer, test_text="Hello, how are you?"):
      """
      Probar INFERENCIA HÍBRIDA - aquí sí se usa la selección inteligente
      """
      print(f"\n🧪 Probando inferencia híbrida con: '{test_text}'")

      model.eval()

      with torch.no_grad():
          # Tokenizar
          inputs = tokenizer(
              test_text,
              return_tensors="pt",
              padding=True,
              truncation=True,
              max_length=128
          ).to(device)

          # 🌍 PRUEBA 1: Con NLLB (si está disponible)
          if hasattr(model, 'generate_hybrid') and model.nllb_model is not None:
              print("🌍 Probando con NLLB (idiomas bien soportados)...")
              try:
                  nllb_output = model.generate_hybrid(
                      input_ids=inputs['input_ids'],
                      attention_mask=inputs['attention_mask'],
                      source_lang='en',
                      target_lang='es',  # Idioma bien soportado por NLLB
                      max_length=128,
                      num_beams=4,
                      early_stopping=True
                  )

                  nllb_translation = tokenizer.decode(nllb_output[0], skip_special_tokens=True)
                  print(f"  🌍 NLLB: {nllb_translation}")

              except Exception as e:
                  print(f"  ❌ NLLB falló: {e}")

          # 🔤 PRUEBA 2: Con ByT5 (modelo entrenado)
          print("🔤 Probando con ByT5 (modelo entrenado)...")
          try:
              byt5_output = model.byt5_model.generate(
                  input_ids=inputs['input_ids'],
                  attention_mask=inputs['attention_mask'],
                  max_length=128,
                  num_beams=4,
                  early_stopping=True
              )

              byt5_translation = tokenizer.decode(byt5_output[0], skip_special_tokens=True)
              print(f"  🔤 ByT5: {byt5_translation}")

          except Exception as e:
              print(f"  ❌ ByT5 falló: {e}")

  print("✅ Funciones de entrenamiento optimizadas definidas")
  print("🔧 Modo: ByT5 para entrenamiento, híbrido para inferencia")


✅ Funciones de entrenamiento optimizadas definidas
🔧 Modo: ByT5 para entrenamiento, híbrido para inferencia


  Celda 11: Función Principal de Entrenamiento

*   Elemento de lista
*   Elemento de lista



In [13]:

  # ============================================================================
  # FUNCIÓN PRINCIPAL DE ENTRENAMIENTO - CORREGIDA
  # ============================================================================

  def train_with_optimization(model, train_pairs, val_pairs, config, device):
      """Entrenamiento optimizado principal - FUNCIÓN CORREGIDA"""

      print("🎯 Iniciando entrenamiento optimizado...")

      # Crear datasets y loaders (ya creados anteriormente, pero por consistencia)
      tokenizer = model.byt5_tokenizer
      train_dataset = MultilingualTranslationDataset(train_pairs, tokenizer, config.max_length)
      val_dataset = MultilingualTranslationDataset(val_pairs, tokenizer, config.max_length)

      train_loader = create_optimized_dataloader(train_dataset, config.batch_size, True)
      val_loader = create_optimized_dataloader(val_dataset, config.batch_size, False)

      # Variables de seguimiento
      best_val_loss = float('inf')
      patience_counter = 0
      training_history = []

      print(f"📊 Configuración de entrenamiento:")
      print(f"  📈 Épocas: {config.num_epochs}")
      print(f"  🎯 Batch size: {config.batch_size}")
      print(f"  📊 Gradient accumulation: {config.gradient_accumulation_steps}")
      print(f"  ⚡ Mixed precision: {'✅' if USE_AMP else '❌'}")

      # Loop de entrenamiento
      for epoch in range(config.num_epochs):
          print(f"\n{'='*60}")
          print(f"🔄 ÉPOCA {epoch + 1}/{config.num_epochs}")
          print(f"{'='*60}")

          # Entrenar época
          epoch_start_time = time.time()
          train_loss = train_epoch_optimized(
              model, train_loader, optimizer, scheduler, device, epoch, config, scaler
          )
          epoch_time = time.time() - epoch_start_time

          # Validar
          print("\n📊 Ejecutando validación...")
          val_loss = validate_model(model, val_loader, device)

          # Guardar métricas
          epoch_metrics = {
              'epoch': epoch + 1,
              'train_loss': train_loss,
              'val_loss': val_loss,
              'epoch_time': epoch_time,
              'learning_rate': optimizer.param_groups[0]['lr']
          }
          training_history.append(epoch_metrics)

          # Mostrar resultados
          print(f"\n📈 RESULTADOS ÉPOCA {epoch + 1}:")
          print(f"  🏋️ Train Loss: {train_loss:.4f}")
          print(f"  📊 Val Loss: {val_loss:.4f}")
          print(f"  ⏱️ Tiempo: {epoch_time/60:.2f} min")
          print(f"  📈 Learning Rate: {optimizer.param_groups[0]['lr']:.2e}")

          # Early stopping
          if val_loss < best_val_loss:
              best_val_loss = val_loss
              patience_counter = 0
              print(f"  🏆 ¡Nuevo mejor modelo! (Val Loss: {best_val_loss:.4f})")

              # Guardar mejor modelo
              save_checkpoint(model, optimizer, scheduler, epoch, val_loss, "./checkpoints")

          else:
              patience_counter += 1
              print(f"  ⏳ Paciencia: {patience_counter}/{config.early_stopping_patience}")

          if patience_counter >= config.early_stopping_patience:
              print(f"\n🛑 Early stopping activado (sin mejora en {patience_counter} épocas)")
              break

          # Mostrar memoria GPU
          if torch.cuda.is_available():
              memory_used = torch.cuda.memory_allocated() / 1024**3
              memory_reserved = torch.cuda.memory_reserved() / 1024**3
              print(f"  💾 GPU Memory: {memory_used:.2f}GB usado, {memory_reserved:.2f}GB reservada")

      print(f"\n🎉 ¡ENTRENAMIENTO COMPLETADO!")
      print(f"🏆 Mejor Val Loss: {best_val_loss:.4f}")
      print(f"📊 Épocas entrenadas: {len(training_history)}")

      return training_history

  # 🔧 CORRECCIÓN CRÍTICA: Función que estaba faltante
  def train_hybrid_model(model, train_pairs, val_pairs, config, device):
      """
      Función principal que estaba faltante - CORRECCIÓN CRÍTICA
      """
      print("🚀 Ejecutando train_hybrid_model...")
      return train_with_optimization(model, train_pairs, val_pairs, config, device)

  print("✅ Función de entrenamiento principal definida y CORREGIDA")


✅ Función de entrenamiento principal definida y CORREGIDA


  Celda 12: Reanudación de Entrenamiento (Eliminar celdas duplicadas 15 y 17)


In [14]:

  # ============================================================================
  # REANUDACIÓN DE ENTRENAMIENTO - VERSIÓN MEJORADA CON GOOGLE DRIVE
  # ============================================================================

  def setup_training_resumption_enhanced(model, optimizer, scheduler):
      """
      Configurar reanudación con búsqueda inteligente en múltiples ubicaciones
      """

      print("🔍 Buscando checkpoints existentes...")

      # Buscar checkpoints en orden de prioridad
      checkpoint_paths = [
          "/content/drive/MyDrive/model_checkpoints/latest_model.pt",  # Drive - más reciente
          "./checkpoints/latest_model.pt",  # Local - backup
          "/content/drive/MyDrive/model_checkpoints/checkpoint_epoch_9.pt",  # Específicos recientes
          "/content/drive/MyDrive/model_checkpoints/checkpoint_epoch_8.pt",
          "/content/drive/MyDrive/model_checkpoints/checkpoint_epoch_7.pt",
          "/content/drive/MyDrive/model_checkpoints/checkpoint_epoch_6.pt",
          "./checkpoints/checkpoint_epoch_9.pt",  # Local específicos
          "./checkpoints/checkpoint_epoch_8.pt",
      ]

      for i, checkpoint_path in enumerate(checkpoint_paths):
          if os.path.exists(checkpoint_path):
              try:
                  print(f"📂 ENCONTRADO checkpoint ({i+1}): {checkpoint_path}")

                  # Cargar checkpoint
                  checkpoint = torch.load(checkpoint_path, map_location=device)

                  # Validar que el checkpoint es compatible
                  required_keys = ['model_state_dict', 'optimizer_state_dict', 'epoch']
                  if not all(key in checkpoint for key in required_keys):
                      print(f"  ⚠️ Checkpoint incompleto, probando siguiente...")
                      continue

                  # Cargar estados
                  model.load_state_dict(checkpoint['model_state_dict'])
                  optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

                  if scheduler and 'scheduler_state_dict' in checkpoint and checkpoint['scheduler_state_dict']:
                      scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
                      print("  ✅ Scheduler cargado")

                  # Extraer información
                  start_epoch = checkpoint.get('epoch', 0) + 1
                  best_loss = checkpoint.get('loss', float('inf'))
                  training_history = checkpoint.get('history', [])

                  # Información del checkpoint
                  timestamp = checkpoint.get('timestamp', 'unknown')
                  training_mode = checkpoint.get('training_mode', 'unknown')

                  print(f"  ✅ REANUDACIÓN EXITOSA:")
                  print(f"    🔄 Desde época: {start_epoch}")
                  print(f"    🎯 Mejor loss: {best_loss:.4f}")
                  print(f"    📊 Historial: {len(training_history)} épocas")
                  print(f"    🕐 Guardado: {timestamp}")
                  print(f"    🤖 Modo entrenamiento: {training_mode}")

                  return start_epoch, best_loss, training_history

              except Exception as e:
                  print(f"  ❌ Error cargando {checkpoint_path}: {e}")
                  print(f"  🔄 Probando siguiente ubicación...")
                  continue

      print("📝 No se encontraron checkpoints válidos")
      print("🆕 Iniciando entrenamiento desde cero")
      return 0, float('inf'), []

  def verify_model_compatibility(model, checkpoint_path):
      """Verificar que el modelo cargado es compatible"""

      try:
          checkpoint = torch.load(checkpoint_path, map_location='cpu')

          # Verificar configuración
          if 'config' in checkpoint:
              saved_config = checkpoint['config']
              current_config = config.__dict__

              # Verificar compatibilidad de modelos base
              if (saved_config.get('byt5_model_name') != current_config.get('byt5_model_name') or
                  saved_config.get('nllb_model_name') != current_config.get('nllb_model_name')):

                  print("⚠️ ADVERTENCIA: Configuración de modelos base diferente")
                  print(f"  Guardado: ByT5={saved_config.get('byt5_model_name')}, NLLB={saved_config.get('nllb_model_name')}")
                  print(f"  Actual:   ByT5={current_config.get('byt5_model_name')}, NLLB={current_config.get('nllb_model_name')}")

                  response = input("¿Continuar de todos modos? (y/n): ")
                  if response.lower() != 'y':
                      return False

          return True

      except Exception as e:
          print(f"❌ Error verificando compatibilidad: {e}")
          return False

  def cleanup_old_checkpoints(max_checkpoints=5):
      """Limpiar checkpoints antiguos para ahorrar espacio"""

      directories = [
          "./checkpoints",
          "/content/drive/MyDrive/model_checkpoints"
      ]

      for checkpoint_dir in directories:
          if not os.path.exists(checkpoint_dir):
              continue

          try:
              # Obtener todos los checkpoints de época
              epoch_files = glob.glob(os.path.join(checkpoint_dir, "checkpoint_epoch_*.pt"))

              if len(epoch_files) > max_checkpoints:
                  # Ordenar por número de época (extraer del nombre)
                  def extract_epoch_num(filepath):
                      try:
                          filename = os.path.basename(filepath)
                          return int(filename.split('epoch_')[1].split('.')[0])
                      except:
                          return 0

                  epoch_files.sort(key=extract_epoch_num)

                  # Eliminar los más antiguos
                  files_to_delete = epoch_files[:-max_checkpoints]

                  for file_path in files_to_delete:
                      try:
                          os.remove(file_path)
                          epoch_num = extract_epoch_num(file_path)
                          print(f"🗑️ Eliminado checkpoint antiguo: época {epoch_num}")
                      except Exception as e:
                          print(f"⚠️ Error eliminando {file_path}: {e}")

          except Exception as e:
              print(f"⚠️ Error limpiando {checkpoint_dir}: {e}")

  # ============================================================================
  # EJECUTAR CONFIGURACIÓN DE REANUDACIÓN
  # ============================================================================

  print("🚀 Configurando reanudación de entrenamiento...")

  # Buscar y cargar checkpoint automáticamente
  start_epoch, best_loss, training_history = setup_training_resumption_enhanced(
      model, optimizer, scheduler
  )

  # Limpiar checkpoints antiguos para ahorrar espacio
  cleanup_old_checkpoints(max_checkpoints=3)

  # Mostrar estado actual
  print(f"\n📊 ESTADO ACTUAL DEL ENTRENAMIENTO:")
  print(f"  🎯 Época inicial: {start_epoch}")
  print(f"  🏆 Mejor loss hasta ahora: {best_loss:.4f}")
  print(f"  📈 Épocas en historial: {len(training_history)}")

  # Configurar variables globales
  current_epoch = start_epoch
  current_best_loss = best_loss
  current_history = training_history

  # Probar inferencia híbrida si el modelo fue cargado
  if start_epoch > 0:
      print(f"\n🧪 Probando modelo cargado...")
      test_hybrid_inference(model, tokenizer, "Hello, how are you today?")

  print("✅ Configuración de reanudación completada")



🚀 Configurando reanudación de entrenamiento...
🔍 Buscando checkpoints existentes...
📝 No se encontraron checkpoints válidos
🆕 Iniciando entrenamiento desde cero
⚠️ Error limpiando /content/drive/MyDrive/model_checkpoints: name 'glob' is not defined

📊 ESTADO ACTUAL DEL ENTRENAMIENTO:
  🎯 Época inicial: 0
  🏆 Mejor loss hasta ahora: inf
  📈 Épocas en historial: 0
✅ Configuración de reanudación completada


In [None]:
# Auditoría rápida de targets
def audit_pairs(pairs, sample=10000, pad_id=0):
    n = min(sample, len(pairs))
    ok, empty, oob = 0, 0, 0
    max_lab = 0
    for i in range(n):
        lbl = pairs[i]['labels']
        max_lab = max(max_lab, max(lbl) if len(lbl)>0 else 0)
        valid = sum(1 for t in lbl if t != pad_id)
        if valid == 0:
            empty += 1
        else:
            ok += 1
    print(f"Audit: n={n} | ok={ok} | empty={empty} | max_label_id={max_lab}")

audit_pairs(train_pairs, pad_id=int(getattr(current_config, "pad_token_id", 0)))
audit_pairs(val_pairs,   pad_id=int(getattr(current_config, "pad_token_id", 0)))


  Celda 13: EJECUCIÓN DEL ENTRENAMIENTO

In [15]:

  # ============================================================================
  # 🎯 EJECUCIÓN PRINCIPAL DEL ENTRENAMIENTO
  # ============================================================================

  print("🚀 INICIANDO ENTRENAMIENTO DEL MODELO HÍBRIDO OPTIMIZADO")
  print("=" * 70)

  try:
      # Verificar que todo está listo
      print("🔍 Verificación pre-entrenamiento:")
      print(f"  ✅ Modelo: {type(model).__name__}")
      print(f"  ✅ Device: {device}")
      print(f"  ✅ Datos train: {len(train_pairs)}")
      print(f"  ✅ Datos val: {len(val_pairs)}")
      print(f"  ✅ Optimizador: {type(optimizer).__name__}")
      print(f"  ✅ Scheduler: {type(scheduler).__name__}")

      # 🔧 EJECUTAR FUNCIÓN CORREGIDA
      print(f"\n🎯 Iniciando entrenamiento...")
      history = train_hybrid_model(model, train_pairs, val_pairs, config, device)

      print("\n" + "=" * 70)
      print("🎉 ¡ENTRENAMIENTO COMPLETADO EXITOSAMENTE!")
      print("=" * 70)

      # Mostrar resumen final
      if history:
          final_metrics = history[-1]
          print(f"\n📊 MÉTRICAS FINALES:")
          print(f"  🏋️ Train Loss final: {final_metrics['train_loss']:.4f}")
          print(f"  📊 Val Loss final: {final_metrics['val_loss']:.4f}")
          print(f"  📈 Épocas completadas: {len(history)}")
          print(f"  ⏱️ Tiempo total: {sum(h['epoch_time'] for h in history)/60:.2f} min")

          # Encontrar mejor época
          best_epoch = min(history, key=lambda x: x['val_loss'])
          print(f"  🏆 Mejor época: {best_epoch['epoch']} (Val Loss: {best_epoch['val_loss']:.4f})")

      # Guardar historial
      with open('./training_history.json', 'w') as f:
          json.dump(history, f, indent=2)
      print(f"📄 Historial guardado en: ./training_history.json")

  except KeyboardInterrupt:
      print("\n⏹️ Entrenamiento interrumpido por el usuario")

  except Exception as e:
      print(f"\n❌ Error durante el entrenamiento: {e}")
      import traceback
      traceback.print_exc()

  finally:
      # Limpieza final
      if torch.cuda.is_available():
          torch.cuda.empty_cache()
      gc.collect()
      print("🧹 Limpieza de memoria completada")


🚀 INICIANDO ENTRENAMIENTO DEL MODELO HÍBRIDO OPTIMIZADO
🔍 Verificación pre-entrenamiento:
  ✅ Modelo: HybridNLLBByT5Model
  ✅ Device: cuda
  ✅ Datos train: 200000
  ✅ Datos val: 100000
  ✅ Optimizador: AdamW
  ✅ Scheduler: LambdaLR

🎯 Iniciando entrenamiento...
🚀 Ejecutando train_hybrid_model...
🎯 Iniciando entrenamiento optimizado...
  📊 Dataset con 200000 pares
  🔤 Tokenizer: ByT5Tokenizer
  📏 Max length: 128
  📊 Dataset con 100000 pares
  🔤 Tokenizer: ByT5Tokenizer
  📏 Max length: 128
🔧 DataLoader configurado para Colab
🔧 DataLoader configurado para Colab
📊 Configuración de entrenamiento:
  📈 Épocas: 10
  🎯 Batch size: 4
  📊 Gradient accumulation: 4
  ⚡ Mixed precision: ✅

🔄 ÉPOCA 1/10


Época 1:   0%|          | 0/50000 [00:00<?, ?it/s]


📊 Ejecutando validación...


Validando:   0%|          | 0/25000 [00:00<?, ?it/s]


📈 RESULTADOS ÉPOCA 1:
  🏋️ Train Loss: nan
  📊 Val Loss: nan
  ⏱️ Tiempo: 91.07 min
  📈 Learning Rate: 9.77e-06
  ⏳ Paciencia: 1/3
  💾 GPU Memory: 6.79GB usado, 7.82GB reservada

🔄 ÉPOCA 2/10


Época 2:   0%|          | 0/50000 [00:00<?, ?it/s]


📊 Ejecutando validación...


Validando:   0%|          | 0/25000 [00:00<?, ?it/s]


📈 RESULTADOS ÉPOCA 2:
  🏋️ Train Loss: nan
  📊 Val Loss: nan
  ⏱️ Tiempo: 88.75 min
  📈 Learning Rate: 9.52e-06
  ⏳ Paciencia: 2/3
  💾 GPU Memory: 6.79GB usado, 7.82GB reservada

🔄 ÉPOCA 3/10


Época 3:   0%|          | 0/50000 [00:00<?, ?it/s]


📊 Ejecutando validación...


Validando:   0%|          | 0/25000 [00:00<?, ?it/s]


📈 RESULTADOS ÉPOCA 3:
  🏋️ Train Loss: nan
  📊 Val Loss: nan
  ⏱️ Tiempo: 88.63 min
  📈 Learning Rate: 9.27e-06
  ⏳ Paciencia: 3/3

🛑 Early stopping activado (sin mejora en 3 épocas)

🎉 ¡ENTRENAMIENTO COMPLETADO!
🏆 Mejor Val Loss: inf
📊 Épocas entrenadas: 3

🎉 ¡ENTRENAMIENTO COMPLETADO EXITOSAMENTE!

📊 MÉTRICAS FINALES:
  🏋️ Train Loss final: nan
  📊 Val Loss final: nan
  📈 Épocas completadas: 3
  ⏱️ Tiempo total: 268.45 min
  🏆 Mejor época: 1 (Val Loss: nan)
📄 Historial guardado en: ./training_history.json
🧹 Limpieza de memoria completada


 Celda 14: Pruebas y Evaluación del Modelo

In [15]:

  # ============================================================================
  # PRUEBAS Y EVALUACIÓN DEL MODELO ENTRENADO
  # ============================================================================

  def test_model_translations(model, tokenizer, test_sentences=None):
      """Probar el modelo con oraciones de ejemplo"""

      if test_sentences is None:
          test_sentences = [
              "Hello, how are you today?",
              "The weather is beautiful.",
              "I love machine learning.",
              "Good morning!",
              "Thank you very much.",
              "Where is the library?",
              "I need help with this problem.",
              "What time is it?",
              "How much does this cost?",
              "I don't understand."
          ]

      model.eval()
      print("🧪 Probando traducciones del modelo...")
      print("-" * 60)

      with torch.no_grad():
          for i, source_text in enumerate(test_sentences, 1):
              try:
                  # Tokenizar entrada
                  inputs = tokenizer(
                      source_text,
                      return_tensors="pt",
                      padding=True,
                      truncation=True,
                      max_length=config.max_length
                  ).to(device)

                  # Generar traducción
                  if hasattr(model, 'generate_hybrid'):
                      # Usar generación híbrida si está disponible
                      outputs = model.generate_hybrid(
                          input_ids=inputs['input_ids'],
                          attention_mask=inputs['attention_mask'],
                          max_length=config.max_length,
                          num_beams=4,
                          temperature=0.7,
                          do_sample=True,
                          early_stopping=True,
                          source_lang='en',
                          target_lang='es'
                      )
                  else:
                      # Usar generación estándar
                      outputs = model.byt5_model.generate(
                          input_ids=inputs['input_ids'],
                          attention_mask=inputs['attention_mask'],
                          max_length=config.max_length,
                          num_beams=2,
                          early_stopping=True
                      )

                  # Decodificar resultado
                  translation = tokenizer.decode(outputs[0], skip_special_tokens=True)

                  print(f"{i:2d}. EN: {source_text}")
                  print(f"    ES: {translation}")
                  print()

              except Exception as e:
                  print(f"{i:2d}. ❌ Error traduciendo '{source_text}': {e}")
                  print()

  def calculate_model_size(model):
      """Calcular tamaño del modelo en MB"""
      param_size = 0
      buffer_size = 0

      for param in model.parameters():
          param_size += param.nelement() * param.element_size()

      for buffer in model.buffers():
          buffer_size += buffer.nelement() * buffer.element_size()

      size_mb = (param_size + buffer_size) / 1024**2
      return size_mb

  # Ejecutar pruebas si el entrenamiento fue exitoso
  if 'history' in locals() and history:
      print("🎯 Ejecutando pruebas del modelo entrenado...")

      # Información del modelo
      model_size_mb = calculate_model_size(model)
      print(f"📊 Tamaño del modelo: {model_size_mb:.2f} MB")

      # Pruebas de traducción
      test_model_translations(model, tokenizer)

      # Cargar mejor checkpoint si existe
      best_checkpoint_path = "./checkpoints/latest_model.pt"
      if os.path.exists(best_checkpoint_path):
          print("📂 Cargando mejor checkpoint para pruebas...")
          try:
              checkpoint = torch.load(best_checkpoint_path, map_location=device)
              model.load_state_dict(checkpoint['model_state_dict'])
              print("✅ Mejor checkpoint cargado")

              print("\n🏆 Traducciones con el MEJOR modelo:")
              test_model_translations(model, tokenizer)

          except Exception as e:
              print(f"⚠️ Error cargando checkpoint: {e}")

  else:
      print("⚠️ No se puede probar el modelo - entrenamiento no completado")


🎯 Ejecutando pruebas del modelo entrenado...
📊 Tamaño del modelo: 3532.61 MB
🧪 Probando traducciones del modelo...
------------------------------------------------------------
🌍 Generando con NLLB (en->es)
 1. ❌ Error traduciendo 'Hello, how are you today?': bytes must be in range(0, 256)

🌍 Generando con NLLB (en->es)
 2. ❌ Error traduciendo 'The weather is beautiful.': bytes must be in range(0, 256)

🌍 Generando con NLLB (en->es)
 3. EN: I love machine learning.
    ES:     ll  lloo  llooov    lloooooooooooooooovvv                           lll      lllloooooooooooooooooooooooooooooooooooooooovv

🌍 Generando con NLLB (en->es)
 4. ❌ Error traduciendo 'Good morning!': bytes must be in range(0, 256)

🌍 Generando con NLLB (en->es)
 5. ❌ Error traduciendo 'Thank you very much.': bytes must be in range(0, 256)

🌍 Generando con NLLB (en->es)
 6. ❌ Error traduciendo 'Where is the library?': bytes must be in range(0, 256)

🌍 Generando con NLLB (en->es)
 7. ❌ Error traduciendo 'I need help wit

  Celda 15: Guardar Modelo Final

In [None]:

  # ============================================================================
  # GUARDAR MODELO FINAL PARA PRODUCCIÓN
  # ============================================================================

  def save_final_model(model, tokenizer, config, save_dir="./final_model"):
      """Guardar modelo final optimizado para producción"""

      print(f"💾 Guardando modelo final en: {save_dir}")
      os.makedirs(save_dir, exist_ok=True)

      try:
          # Guardar estado del modelo
          model_path = os.path.join(save_dir, "model.pt")
          torch.save({
              'model_state_dict': model.state_dict(),
              'config': config.__dict__,
              'model_type': 'HybridNLLBByT5Model',
              'pytorch_version': torch.__version__,
              'save_timestamp': datetime.now().isoformat()
          }, model_path)

          # Guardar tokenizer
          tokenizer.save_pretrained(os.path.join(save_dir, "tokenizer"))

          # Guardar configuración como JSON
          config_path = os.path.join(save_dir, "config.json")
          with open(config_path, 'w') as f:
              json.dump(config.__dict__, f, indent=2)

          # Guardar información del modelo
          model_info = {
              'model_name': 'HybridNLLBByT5Model',
              'base_models': {
                  'primary': config.byt5_model_name,
                  'secondary': config.nllb_model_name if config.use_nllb_for_inference else None
              },
              'training_info': {
                  'epochs_trained': len(history) if 'history' in locals() else 0,
                  'final_train_loss': history[-1]['train_loss'] if 'history' in locals() and history else None,
                  'final_val_loss': history[-1]['val_loss'] if 'history' in locals() and history else None,
                  'best_val_loss': min(h['val_loss'] for h in history) if 'history' in locals() and history else None
              },
              'model_size_mb': calculate_model_size(model),
              'supported_languages': list(SUPPORTED_LANGUAGES.keys()),
              'optimizations': {
                  'mixed_precision': config.use_mixed_precision,
                  'gradient_accumulation': config.gradient_accumulation_steps,
                  'nllb_in_training': config.use_nllb_in_training,
                  'nllb_for_inference': config.use_nllb_for_inference
              }
          }

          info_path = os.path.join(save_dir, "model_info.json")
          with open(info_path, 'w') as f:
              json.dump(model_info, f, indent=2)

          # Listar archivos guardados
          saved_files = os.listdir(save_dir)
          print("📄 Archivos guardados:")
          for file in saved_files:
              file_path = os.path.join(save_dir, file)
              if os.path.isfile(file_path):
                  size_mb = os.path.getsize(file_path) / (1024**2)
                  print(f"  📄 {file}: {size_mb:.2f} MB")

          print(f"✅ Modelo final guardado exitosamente en: {save_dir}")
          return save_dir

      except Exception as e:
          print(f"❌ Error guardando modelo final: {e}")
          return None

  # Guardar modelo final si el entrenamiento fue exitoso
  if 'model' in locals() and model is not None:
      final_model_path = save_final_model(model, tokenizer, config)

      if final_model_path:
          print(f"\n🎉 ¡MODELO GUARDADO EXITOSAMENTE!")
          print(f"📁 Ubicación: {final_model_path}")
          print("\n📋 Para usar el modelo:")
          print("1. Carga el modelo: torch.load('final_model/model.pt')")
          print("2. Carga el tokenizer desde: 'final_model/tokenizer'")
          print("3. Revisa la configuración en: 'final_model/config.json'")

  else:
      print("⚠️ No hay modelo para guardar")


💾 Guardando modelo final en: ./final_model


## **USER GUIDE**


In [None]:
print("""
╔══════════════════════════════════════════════════════════════════╗
║          HYBRID NLLB-ByT5 TRANSLATION MODEL - USER GUIDE         ║
╚══════════════════════════════════════════════════════════════════╝

📚 CÓMO USAR ESTE NOTEBOOK:

1️⃣ PREPARACIÓN:
   - Asegúrate de tener GPU disponible (recomendado)
   - Instala todas las dependencias (Celda 1)

2️⃣ DATOS DE ENTRENAMIENTO:
   - Modifica la función prepare_training_data() en Celda 6
   - Formato: (texto_origen, texto_destino, idioma_origen, idioma_destino)
   - Para NLLB, usa códigos como: 'eng_Latn', 'spa_Latn', etc.

3️⃣ CONFIGURACIÓN:
   - Ajusta hiperparámetros en HybridTranslationConfig (Celda 3)
   - Batch size, learning rate, epochs, etc.

4️⃣ ENTRENAMIENTO:
   - Ejecuta main_training_pipeline() (Celda 20) para proceso completo
   - O entrena paso a paso ejecutando celdas individuales

5️⃣ AGREGAR NUEVOS IDIOMAS:
   - Usa add_new_language_support() (Celda 12)
   - Proporciona pares de entrenamiento para el nuevo idioma

6️⃣ INFERENCIA:
   - Usa model.generate_translation() para traducir
   - O usa interactive_translation() para interfaz interactiva

📊 IDIOMAS SOPORTADOS POR NLLB-200:
   - 200+ idiomas incluyendo:
     • Principales: Inglés, Español, Francés, Alemán, Chino, etc.
     • Regionales: Catalán, Gallego, Euskera, etc.
     • Minoritarios: Muchos idiomas con pocos recursos

   Lista completa: https://github.com/facebookresearch/fairseq/tree/nllb

⚙️ PERSONALIZACIÓN:
   - Puedes cambiar el modelo NLLB base por versiones más grandes
   - Ajusta el modelo ByT5 (small, base, large)
   - Modifica las capas de fusión según tus necesidades

⚠️ CONSIDERACIONES:
   - El modelo distilled NLLB-200 usa ~2.4GB de VRAM
   - ByT5-small usa ~1.2GB adicionales
   - Con batch_size=4, necesitas ~8GB de VRAM mínimo
   - Para modelos más grandes, ajusta batch_size y gradient_accumulation

💡 PRÓXIMOS PASOS:
   1. Entrenar con más datos (mínimo 10k pares por idioma)
   2. Fine-tuning específico por dominio
   3. Implementar beam search para mejor calidad
   4. Agregar post-procesamiento específico por idioma
   5. Crear API REST para servir el modelo

📧 SOPORTE:
   Si tienes problemas, verifica:
   - Versiones de librerías
   - Memoria GPU disponible
   - Formato correcto de datos
   - Códigos de idioma válidos

¡Buena suerte con tu modelo de traducción multilingüe! 🌍🚀
""")

# Verificar estado final
print("\n" + "="*60)
print("ESTADO DEL SISTEMA:")
print(f"  • GPU disponible: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"  • GPU: {torch.cuda.get_device_name(0)}")
    print(f"  • VRAM libre: {torch.cuda.mem_get_info()[0]/1e9:.2f} GB")
print(f"  • Modelos cargados: ✓")
print(f"  • Listo para entrenamiento: ✓")
print("="*60)