Celda 1: Instalación de dependencias

In [1]:
!pip install transformers datasets sentencepiece sacrebleu torch accelerate
!pip install protobuf==3.20.* --force-reinstall
!apt-get install git-lfs

Collecting protobuf==3.20.*
  Using cached protobuf-3.20.3-py2.py3-none-any.whl.metadata (720 bytes)
Using cached protobuf-3.20.3-py2.py3-none-any.whl (162 kB)
Installing collected packages: protobuf
  Attempting uninstall: protobuf
    Found existing installation: protobuf 3.20.3
    Uninstalling protobuf-3.20.3:
      Successfully uninstalled protobuf-3.20.3
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
tensorflow-metadata 1.17.2 requires protobuf>=4.25.2; python_version >= "3.11", but you have protobuf 3.20.3 which is incompatible.
ydf 0.13.0 requires protobuf<7.0.0,>=5.29.1, but you have protobuf 3.20.3 which is incompatible.
grpcio-status 1.71.2 requires protobuf<6.0dev,>=5.26.1, but you have protobuf 3.20.3 which is incompatible.[0m[31m
[0mSuccessfully installed protobuf-3.20.3


Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
git-lfs is already the newest version (3.0.2-1ubuntu0.3).
0 upgraded, 0 newly installed, 0 to remove and 35 not upgraded.


Celda 2: Importaciones necesarias

In [1]:
  import torch
  import torch.nn as nn
  import torch.nn.functional as F
  from torch.utils.data import Dataset, DataLoader
  from transformers import (
      NllbTokenizer,
      AutoModelForSeq2SeqLM,
      ByT5Tokenizer,
      T5ForConditionalGeneration,
      get_linear_schedule_with_warmup,
      AutoTokenizer
  )
  import numpy as np
  import pandas as pd
  from typing import List, Dict, Tuple, Optional
  import json
  import os
  from tqdm.auto import tqdm
  import gc
  import warnings

  # Configurar warnings y determinismo
  warnings.filterwarnings('ignore')

  # Configuración para evitar NaN y problemas de estabilidad
  torch.backends.cudnn.deterministic = True
  torch.backends.cudnn.benchmark = False
  torch.manual_seed(42)
  np.random.seed(42)

  # Configuración adicional para debugging CUDA
  os.environ["CUDA_LAUNCH_BLOCKING"] = "1"  # Para debug CUDA
  os.environ["TORCH_USE_CUDA_DSA"] = "1"    # Para device-side assertions

  # Verificar GPU y configurar device
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  print(f"🔧 Using device: {device}")

  if torch.cuda.is_available():
      print(f"🎮 GPU: {torch.cuda.get_device_name(0)}")
      print(f"💾 GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

      # Limpiar caché CUDA al inicio
      torch.cuda.empty_cache()

      # Mostrar memoria disponible
      memory_free = torch.cuda.mem_get_info()[0] / 1e9
      memory_total = torch.cuda.mem_get_info()[1] / 1e9
      print(f"💾 GPU Memory Available: {memory_free:.2f} GB / {memory_total:.2f} GB")

      # Configurar para evitar fragmentación de memoria
      torch.cuda.set_per_process_memory_fraction(0.8)  # Usar solo 80% de la GPU

  else:
      print("⚠️ CUDA no disponible - usando CPU (será más lento)")

  # Configuración de PyTorch para estabilidad numérica
  torch.set_default_dtype(torch.float32)

  # Funciones de utilidad para debugging
  def check_tensor_health(tensor, name="tensor"):
      """Verifica si un tensor tiene valores problemáticos"""
      if torch.isnan(tensor).any():
          print(f"⚠️ NaN detectado en {name}")
          return False
      if torch.isinf(tensor).any():
          print(f"⚠️ Inf detectado en {name}")
          return False
      return True

  def print_memory_usage():
      """Imprime uso actual de memoria"""
      if torch.cuda.is_available():
          allocated = torch.cuda.memory_allocated() / 1e9
          cached = torch.cuda.memory_reserved() / 1e9
          print(f"💾 GPU Memory - Allocated: {allocated:.2f} GB, Cached: {cached:.2f} GB")

  def cleanup_memory():
      """Limpia memoria GPU"""
      if torch.cuda.is_available():
          torch.cuda.empty_cache()
      gc.collect()

  print("✅ Importaciones y configuración inicial completada")
  print("🔧 Configuración de estabilidad numérica activada")
  print("🧹 Limpieza automática de memoria configurada")


🔧 Using device: cpu
⚠️ CUDA no disponible - usando CPU (será más lento)
✅ Importaciones y configuración inicial completada
🔧 Configuración de estabilidad numérica activada
🧹 Limpieza automática de memoria configurada


Celda 3: Configuración del modelo híbrido con idiomas adicionales

In [2]:

  class HybridTranslationConfig:
      def __init__(self):
          # Modelos base
          self.nllb_model_name = "facebook/nllb-200-distilled-600M"
          self.byt5_model_name = "google/byt5-small"

          # CONFIGURACIÓN OPTIMIZADA: ÉPOCAS MÁS CORTAS
          self.batch_size = 8
          self.learning_rate = 1e-5
          self.num_epochs = 10           # ← MÁS ÉPOCAS
          self.max_length = 64
          self.warmup_steps = 100
          self.gradient_accumulation_steps = 4

          # LÍMITES POR ÉPOCA (CLAVE PARA ÉPOCAS CORTAS)
          self.max_samples_per_epoch = 50000    # ← SOLO 50K samples por época
          self.early_stopping_patience = 3     # ← Para en 3 épocas sin mejora
          self.save_every_n_epochs = 1         # ← Guardar cada época

          # Configuración del modelo híbrido
          self.hidden_size = 512
          self.fusion_dropout = 0.1
          self.temperature = 1.0

          # Learning rate scheduling mejorado
          self.lr_scheduler_type = "cosine"     # ← Cosine annealing
          self.min_lr = 1e-7                   # ← LR mínimo

          # IDIOMAS SOPORTADOS POR NLLB-200
          self.nllb_languages = [
              # Idiomas principales
              'eng_Latn', 'spa_Latn', 'fra_Latn', 'deu_Latn', 'ita_Latn',
              'por_Latn', 'rus_Cyrl', 'zho_Hans', 'jpn_Jpan', 'kor_Hang',
              'ara_Arab', 'hin_Deva', 'tur_Latn', 'pol_Latn', 'nld_Latn',

              # IDIOMAS AFRICANOS INCLUIDOS EN NLLB-200
              'wol_Latn',    # Wolof (Senegal, Gambia)
              'swh_Latn',    # Swahili (Tanzania, Kenya, Uganda)
              'amh_Ethi',    # Amhárico (Etiopía)
              'hau_Latn',    # Hausa (Nigeria, Níger)
              'ibo_Latn',    # Igbo (Nigeria)
              'yor_Latn',    # Yoruba (Nigeria, Benín)
              'sna_Latn',    # Shona (Zimbabwe)
              'som_Latn',    # Somalí (Somalia, Etiopía)
              'afr_Latn',    # Afrikáans (Sudáfrica)
              'xho_Latn',    # Xhosa (Sudáfrica)
              'zul_Latn',    # Zulu (Sudáfrica)
              'tsn_Latn',    # Tswana (Botsuana)
              'nso_Latn',    # Sotho del Norte (Sudáfrica)
              'ven_Latn',    # Venda (Sudáfrica)
              'tso_Latn',    # Tsonga (Sudáfrica)
              'ssw_Latn',    # Siswati (Esuatini)
              'lug_Latn',    # Luganda (Uganda)
              'kik_Latn',    # Kikuyu (Kenya)
              'luo_Latn',    # Luo (Kenya, Uganda)
              'rny_Latn',    # Runyanakore (Uganda)
              'lgg_Latn',    # Lugbara (Uganda, RDC)
              'fon_Latn',    # Fon (Benín)
              'twi_Latn',    # Twi (Ghana)
              'aka_Latn',    # Akan (Ghana)
              'bam_Latn',    # Bambara (Mali)
              'dyu_Latn',    # Dyula (Costa de Marfil)
              'mos_Latn',    # Mossi (Burkina Faso)
              'fuv_Latn',    # Fulfulde (Nigeria, otros países)

              # UZBEKO Y OTROS IDIOMAS ASIÁTICOS
              'uzn_Latn',    # Uzbeko (Uzbekistán)
              'kaz_Cyrl',    # Kazajo (Kazajistán)
              'kir_Cyrl',    # Kirguís (Kirguistán)
              'tgk_Cyrl',    # Tayiko (Tayikistán)
              'tuk_Latn',    # Turkmeno (Turkmenistán)
              'aze_Latn',    # Azerbaiyano (Azerbaiyán)

              # OTROS IDIOMAS ÚTILES
              'ben_Beng',    # Bengalí
              'urd_Arab',    # Urdu
              'fas_Arab',    # Persa/Farsi
              'mya_Mymr',    # Birmano
              'tha_Thai',    # Tailandés
              'vie_Latn',    # Vietnamita
              'ind_Latn',    # Indonesio
              'msa_Latn',    # Malayo
              'tgl_Latn',    # Tagalo (Filipinas)
              'ceb_Latn',    # Cebuano (Filipinas)
          ]

          # Nuevos idiomas para entrenar con ByT5
          self.new_languages = []

          # MAPEO DE CÓDIGOS PERSONALIZADOS A NLLB
          self.language_mapping = {
              # Códigos de tus datos → Códigos NLLB
              'EN': 'eng_Latn', 'ES': 'spa_Latn', 'FR': 'fra_Latn',
              'DE': 'deu_Latn', 'IT': 'ita_Latn', 'PT': 'por_Latn',
              'RU': 'rus_Cyrl', 'AR': 'ara_Arab', 'ZH': 'zho_Hans',
              'JA': 'jpn_Jpan', 'KO': 'kor_Hang',

              # Idiomas africanos
              'WO': 'wol_Latn',    # Wolof
              'SW': 'swh_Latn',    # Swahili
              'AM': 'amh_Ethi',    # Amhárico
              'HA': 'hau_Latn',    # Hausa
              'IG': 'ibo_Latn',    # Igbo
              'YO': 'yor_Latn',    # Yoruba
              'SN': 'sna_Latn',    # Shona
              'SO': 'som_Latn',    # Somalí
              'AF': 'afr_Latn',    # Afrikáans
              'XH': 'xho_Latn',    # Xhosa
              'ZU': 'zul_Latn',    # Zulu
              'TW': 'twi_Latn',    # Twi
              'AK': 'aka_Latn',    # Akan
              'BM': 'bam_Latn',    # Bambara

              # Uzbeko y otros asiáticos
              'UZ': 'uzn_Latn',    # Uzbeko
              'KK': 'kaz_Cyrl',    # Kazajo
              'KY': 'kir_Cyrl',    # Kirguís
              'TG': 'tgk_Cyrl',    # Tayiko
              'TK': 'tuk_Latn',    # Turkmeno
              'AZ': 'aze_Latn',    # Azerbaiyano

              # Otros útiles
              'BN': 'ben_Beng',    # Bengalí
              'UR': 'urd_Arab',    # Urdu
              'FA': 'fas_Arab',    # Persa
              'MY': 'mya_Mymr',    # Birmano
              'TH': 'tha_Thai',    # Tailandés
              'VI': 'vie_Latn',    # Vietnamita
              'ID': 'ind_Latn',    # Indonesio
              'MS': 'msa_Latn',    # Malayo
              'TL': 'tgl_Latn',    # Tagalo
          }

      def get_nllb_code(self, custom_code):
          """Convierte código personalizado a código NLLB"""
          return self.language_mapping.get(custom_code.upper(), custom_code)

      def is_supported_by_nllb(self, lang_code):
          """Verifica si un idioma está soportado por NLLB"""
          nllb_code = self.get_nllb_code(lang_code)
          return nllb_code in self.nllb_languages

  config = HybridTranslationConfig()

  print("🔧 CONFIGURACIÓN OPTIMIZADA:")
  print(f"  • Épocas: {config.num_epochs} (más épocas)")
  print(f"  • Samples por época: {config.max_samples_per_epoch:,} (épocas más cortas)")
  print(f"  • Tiempo estimado por época: ~45-60 minutos")
  print(f"  • Total estimado: ~8-10 horas")
  print(f"  • Guardado: Cada {config.save_every_n_epochs} época(s)")
  print(f"  • Early stopping: {config.early_stopping_patience} épocas sin mejora")


🔧 CONFIGURACIÓN OPTIMIZADA:
  • Épocas: 10 (más épocas)
  • Samples por época: 50,000 (épocas más cortas)
  • Tiempo estimado por época: ~45-60 minutos
  • Total estimado: ~8-10 horas
  • Guardado: Cada 1 época(s)
  • Early stopping: 3 épocas sin mejora


Celda 4: Clase del modelo híbrido NLLB + ByT5

In [3]:

  class HybridNLLBByT5Model(nn.Module):
      def __init__(self, config):
          super().__init__()
          self.config = config

          # Cargar modelos pre-entrenados
          print("Loading NLLB model...")
          self.nllb_model = AutoModelForSeq2SeqLM.from_pretrained(
              config.nllb_model_name,
              torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
          )
          self.nllb_tokenizer = AutoTokenizer.from_pretrained(config.nllb_model_name)

          print("Loading ByT5 model...")
          self.byt5_model = T5ForConditionalGeneration.from_pretrained(config.byt5_model_name)
          self.byt5_tokenizer = ByT5Tokenizer.from_pretrained(config.byt5_model_name)

          # Obtener dimensiones de hidden states
          nllb_hidden_size = self.nllb_model.config.hidden_size
          byt5_hidden_size = self.byt5_model.config.d_model
          byt5_vocab_size = self.byt5_model.config.vocab_size

          print(f"NLLB hidden size: {nllb_hidden_size}")
          print(f"ByT5 hidden size: {byt5_hidden_size}")
          print(f"ByT5 vocab size: {byt5_vocab_size}")

          # CORRECCIÓN: Usar solo ByT5 para datos tokenizados con ByT5
          # El modelo híbrido se simplifica para evitar incompatibilidades

          # Capa de adaptación para mejorar ByT5
          self.adaptation_layer = nn.Sequential(
              nn.Linear(byt5_hidden_size, config.hidden_size),
              nn.LayerNorm(config.hidden_size),
              nn.ReLU(),
              nn.Dropout(config.fusion_dropout),
              nn.Linear(config.hidden_size, byt5_hidden_size)
          )

          # Proyección mejorada para salida
          self.enhanced_output = nn.Sequential(
              nn.Linear(byt5_hidden_size, byt5_hidden_size * 2),
              nn.GELU(),
              nn.Dropout(config.fusion_dropout),
              nn.Linear(byt5_hidden_size * 2, byt5_vocab_size)
          )

          # Inicializar pesos correctamente
          self._init_weights()

      def _init_weights(self):
          """Inicializa pesos con valores pequeños para evitar NaN"""
          for module in [self.adaptation_layer, self.enhanced_output]:
              for layer in module:
                  if isinstance(layer, nn.Linear):
                      # Inicialización Xavier con valores más pequeños
                      torch.nn.init.xavier_normal_(layer.weight, gain=0.1)
                      if layer.bias is not None:
                          torch.nn.init.zeros_(layer.bias)

      def forward(self, input_ids, attention_mask, labels=None, **kwargs):
          """Forward pass simplificado usando solo ByT5 con mejoras"""

          # Validar entrada
          if input_ids.max().item() >= self.byt5_model.config.vocab_size:
              print(f"⚠️ Token fuera de rango detectado: max={input_ids.max().item()}, vocab_size={self.byt5_model.config.vocab_size}")
              # Clamp tokens para evitar errores
              input_ids = torch.clamp(input_ids, 0, self.byt5_model.config.vocab_size - 1)

          if labels is not None and labels.max().item() >= self.byt5_model.config.vocab_size:
              # Ignorar tokens fuera de rango en labels
              labels = torch.where(
                  labels >= self.byt5_model.config.vocab_size,
                  torch.tensor(-100, device=labels.device),
                  labels
              )

          try:
              # Usar solo ByT5 (más estable para datos tokenizados con ByT5)
              outputs = self.byt5_model(
                  input_ids=input_ids,
                  attention_mask=attention_mask,
                  labels=labels,
                  output_hidden_states=True
              )

              # Obtener hidden states
              encoder_hidden_states = outputs.encoder_last_hidden_state

              # Aplicar capa de adaptación
              adapted_hidden = self.adaptation_layer(encoder_hidden_states)

              # Skip connection
              enhanced_hidden = encoder_hidden_states + adapted_hidden

              # Proyección mejorada para logits
              enhanced_logits = self.enhanced_output(enhanced_hidden)

              # Calcular pérdida si hay labels
              loss = None
              if labels is not None:
                  # Usar solo logits válidos
                  shift_logits = enhanced_logits[..., :-1, :].contiguous()
                  shift_labels = labels[..., 1:].contiguous()

                  loss_fct = nn.CrossEntropyLoss(ignore_index=-100, reduction='mean')
                  loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

                  # Verificar que no sea NaN
                  if torch.isnan(loss):
                      print("⚠️ NaN loss detectado, usando loss por defecto")
                      loss = torch.tensor(0.0, device=loss.device, requires_grad=True)

              return {
                  'loss': loss,
                  'logits': enhanced_logits,
                  'hidden_states': enhanced_hidden
              }

          except Exception as e:
              print(f"❌ Error en forward pass: {e}")
              # Devolver loss válido para evitar crash
              dummy_loss = torch.tensor(0.0, device=input_ids.device, requires_grad=True)
              dummy_logits = torch.zeros(
                  input_ids.shape[0], input_ids.shape[1], self.byt5_model.config.vocab_size,
                  device=input_ids.device
              )
              return {
                  'loss': dummy_loss,
                  'logits': dummy_logits,
                  'hidden_states': None
              }

      def generate_translation(self, text, src_lang, tgt_lang, max_length=256):
          """Genera traducción usando ByT5 mejorado"""
          self.eval()

          with torch.no_grad():
              # Usar solo ByT5 para generación
              inputs = self.byt5_tokenizer(text, return_tensors="pt",
                                          max_length=max_length, truncation=True)
              inputs = {k: v.to(next(self.parameters()).device) for k, v in inputs.items()}

              try:
                  generated_tokens = self.byt5_model.generate(
                      **inputs,
                      max_length=max_length,
                      num_beams=2,
                      temperature=0.8,
                      do_sample=True,
                      pad_token_id=self.byt5_tokenizer.pad_token_id
                  )

                  translation = self.byt5_tokenizer.decode(generated_tokens[0],
                                                         skip_special_tokens=True)
                  return translation

              except Exception as e:
                  print(f"Error en generación: {e}")
                  return f"Error generating translation for: {text}"


Celda 5: Dataset para datos PRE-TOKENIZADOS con ByT5

In [4]:

  from torch.utils.data import Dataset, DataLoader
  import torch

  class PreTokenizedTranslationDataset(Dataset):
      """Dataset para datos ya tokenizados con ByT5Tokenizer"""

      def __init__(self, tokenized_pairs, max_length=256):
          """
          tokenized_pairs: Lista de diccionarios con:
          - input_ids: tensor de tokens de entrada
          - attention_mask: tensor de máscaras de atención
          - labels: tensor de tokens objetivo
          - target_attention_mask: tensor de máscaras objetivo
          - src_lang: idioma fuente
          - tgt_lang: idioma destino
          """
          self.tokenized_pairs = tokenized_pairs
          self.max_length = max_length

      def __len__(self):
          return len(self.tokenized_pairs)

      def pad_or_truncate(self, tensor, max_length, pad_token_id=0):
          """Trunca o hace padding a tensor"""
          if len(tensor) > max_length:
              return tensor[:max_length]
          elif len(tensor) < max_length:
              padding = torch.full((max_length - len(tensor),), pad_token_id, dtype=tensor.dtype)
              return torch.cat([tensor, padding])
          else:
              return tensor

      def __getitem__(self, idx):
          item = self.tokenized_pairs[idx]

          # Obtener tensores
          input_ids = item['input_ids']
          attention_mask = item['attention_mask']
          labels = item['labels']
          target_attention_mask = item['target_attention_mask']

          # Aplicar padding/truncation
          input_ids = self.pad_or_truncate(input_ids, self.max_length, pad_token_id=0)
          attention_mask = self.pad_or_truncate(attention_mask, self.max_length, pad_token_id=0)
          labels = self.pad_or_truncate(labels, self.max_length, pad_token_id=-100)  # -100 para ignorar en loss
          target_attention_mask = self.pad_or_truncate(target_attention_mask, self.max_length, pad_token_id=0)

          return {
              'input_ids': input_ids,
              'attention_mask': attention_mask,
              'labels': labels,
              'target_attention_mask': target_attention_mask,
              'src_lang': item['src_lang'],
              'tgt_lang': item['tgt_lang']
          }

  def create_tokenized_data_collator():
      """Collator para datos pre-tokenizados"""
      def collate_fn(batch):
          # Apilar tensores
          input_ids = torch.stack([item['input_ids'] for item in batch])
          attention_mask = torch.stack([item['attention_mask'] for item in batch])
          labels = torch.stack([item['labels'] for item in batch])
          target_attention_mask = torch.stack([item['target_attention_mask'] for item in batch])

          # Idiomas
          src_langs = [item['src_lang'] for item in batch]
          tgt_langs = [item['tgt_lang'] for item in batch]

          return {
              'input_ids': input_ids,
              'attention_mask': attention_mask,
              'labels': labels,
              'target_attention_mask': target_attention_mask,
              'src_langs': src_langs,
              'tgt_langs': tgt_langs
          }

      return collate_fn

  # Función helper para crear dataloaders
  def create_tokenized_dataloaders(train_pairs, val_pairs, batch_size=4, max_length=256):
      """Crea dataloaders para datos pre-tokenizados"""

      # Crear datasets
      train_dataset = PreTokenizedTranslationDataset(train_pairs, max_length)
      val_dataset = PreTokenizedTranslationDataset(val_pairs, max_length)

      # Crear collator
      collator = create_tokenized_data_collator()

      # Crear dataloaders
      train_loader = DataLoader(
          train_dataset,
          batch_size=batch_size,
          shuffle=True,
          collate_fn=collator,
          num_workers=0,  # 0 para evitar problemas en Colab
          pin_memory=torch.cuda.is_available()
      )

      val_loader = DataLoader(
          val_dataset,
          batch_size=batch_size,
          shuffle=False,
          collate_fn=collator,
          num_workers=0,
          pin_memory=torch.cuda.is_available()
      )

      print(f"✅ Dataloaders creados:")
      print(f"  • Train batches: {len(train_loader)}")
      print(f"  • Val batches: {len(val_loader)}")
      print(f"  • Batch size: {batch_size}")
      print(f"  • Max length: {max_length}")

      return train_loader, val_loader


Celda 6: Carga de datos PRE-TOKENIZADOS desde Google Drive



In [5]:

  from google.colab import drive
  import pandas as pd
  import os
  import glob
  import gc
  import torch
  import ast
  import json
  import random

  def mount_drive_and_setup_paths():
      """Monta Google Drive y configura rutas"""
      print("🔌 Montando Google Drive...")
      drive.mount('/content/drive')
      print("✅ Google Drive montado exitosamente")

      base_path = "/content/drive/MyDrive/GlobalTranslator/NMT/"
      dataset_path = base_path + "Dataset/"
      models_path = base_path + "Models/"

      os.makedirs(dataset_path, exist_ok=True)
      os.makedirs(models_path, exist_ok=True)

      print(f"📂 Dataset path: {dataset_path}")
      print(f"💾 Models path: {models_path}")

      return dataset_path, models_path

  def parse_tokenized_data(data_string):
      """Convierte string de lista a tensor de PyTorch"""
      try:
          if isinstance(data_string, str):
              if data_string.startswith('"') and data_string.endswith('"'):
                  data_string = data_string[1:-1]
              token_list = ast.literal_eval(data_string)
          else:
              token_list = data_string

          return torch.tensor(token_list, dtype=torch.long)
      except Exception as e:
          print(f"Error parseando tokens: {e}")
          return torch.tensor([1], dtype=torch.long)

  def load_tokenized_csv(file_path, max_samples=None):
      """Carga CSV con datos pre-tokenizados"""

      print(f"📥 Cargando {os.path.basename(file_path)}...")

      try:
          df = pd.read_csv(file_path)

          expected_cols = ['input_ids', 'input_attention_mask', 'target_ids', 'target_attention_mask', 'input_label', 'target_label']
          missing_cols = [col for col in expected_cols if col not in df.columns]

          if missing_cols:
              print(f"  ❌ Columnas faltantes: {missing_cols}")
              return []

          print(f"  📊 Filas encontradas: {len(df)}")

          if max_samples and len(df) > max_samples:
              df = df.head(max_samples)
              print(f"  ✂️ Limitado a: {len(df)} samples")

          tokenized_pairs = []
          errors = 0

          for idx, row in df.iterrows():
              try:
                  input_ids = parse_tokenized_data(row['input_ids'])
                  input_attention_mask = parse_tokenized_data(row['input_attention_mask'])
                  target_ids = parse_tokenized_data(row['target_ids'])
                  target_attention_mask = parse_tokenized_data(row['target_attention_mask'])

                  src_lang = str(row['input_label']).strip()
                  tgt_lang = str(row['target_label']).strip()

                  if len(input_ids) < 2 or len(target_ids) < 2:
                      continue

                  if len(input_ids) != len(input_attention_mask):
                      continue

                  if len(target_ids) != len(target_attention_mask):
                      continue

                  tokenized_pairs.append({
                      'input_ids': input_ids,
                      'attention_mask': input_attention_mask,
                      'labels': target_ids,
                      'target_attention_mask': target_attention_mask,
                      'src_lang': src_lang,
                      'tgt_lang': tgt_lang
                  })

              except Exception as e:
                  errors += 1
                  if errors < 5:
                      print(f"  ⚠️ Error en fila {idx}: {str(e)[:50]}...")
                  continue

          if errors > 0:
              print(f"  ⚠️ Total errores: {errors}")

          print(f"  ✅ Pares válidos: {len(tokenized_pairs)}")
          return tokenized_pairs

      except Exception as e:
          print(f"  ❌ Error cargando archivo: {e}")
          return []

  def load_fragmented_tokenized_data(dataset_path):
      """Carga todos los archivos CSV tokenizados"""

      train_files = sorted(glob.glob(dataset_path + "NMT_train*.csv"))
      val_files = sorted(glob.glob(dataset_path + "NMT_val*.csv"))

      print(f"\n📁 ARCHIVOS TOKENIZADOS ENCONTRADOS:")
      print(f"  • Training files: {len(train_files)}")
      total_size_mb = 0
      for f in train_files:
          size_mb = os.path.getsize(f) / (1024*1024)
          total_size_mb += size_mb
          print(f"    - {os.path.basename(f)} ({size_mb:.1f} MB)")

      print(f"  • Validation files: {len(val_files)}")
      for f in val_files:
          size_mb = os.path.getsize(f) / (1024*1024)
          total_size_mb += size_mb
          print(f"    - {os.path.basename(f)} ({size_mb:.1f} MB)")

      print(f"  📊 Tamaño total: {total_size_mb:.1f} MB")

      if not train_files:
          raise FileNotFoundError(f"❌ No se encontraron archivos NMT_train*.csv en {dataset_path}")

      # Configurar límites para evitar colapso de memoria
      memory_limit_mb = 2000
      samples_per_mb = 100
      max_samples_per_file = max(1000, int(memory_limit_mb * samples_per_mb / len(train_files)))

      print(f"🧠 Configuración de memoria:")
      print(f"  • Límite total: {memory_limit_mb} MB")
      print(f"  • Max samples por archivo: {max_samples_per_file}")

      def load_files_batch(file_list, file_type="training"):
          """Carga archivos en batches para optimizar memoria"""
          all_pairs = []

          print(f"\n⚡ Cargando {len(file_list)} archivos de {file_type}...")

          for i, file_path in enumerate(file_list):
              print(f"  📥 [{i+1}/{len(file_list)}] {os.path.basename(file_path)}")

              file_pairs = load_tokenized_csv(file_path, max_samples_per_file)
              all_pairs.extend(file_pairs)

              del file_pairs
              gc.collect()

              if torch.cuda.is_available():
                  memory_used = torch.cuda.memory_allocated() / 1e9
                  print(f"    💾 GPU Memory: {memory_used:.2f} GB")

          print(f"  🎯 Total {file_type}: {len(all_pairs):,} pares tokenizados")
          return all_pairs

      train_pairs = load_files_batch(train_files, "entrenamiento")

      if val_files:
          val_pairs = load_files_batch(val_files, "validación")
      else:
          print("\n⚠️ No hay archivos de validación, creando subset...")
          split_size = min(2000, max(100, len(train_pairs) // 10))
          val_pairs = train_pairs[-split_size:]
          train_pairs = train_pairs[:-split_size]
          print(f"  🔀 {len(val_pairs):,} pares movidos a validación")

      return train_pairs, val_pairs

  def limit_samples_per_epoch(train_pairs, val_pairs, config):
      """Limita samples por época para épocas más cortas"""

      print(f"\n✂️ LIMITANDO SAMPLES POR ÉPOCA:")

      original_train = len(train_pairs)
      original_val = len(val_pairs)

      # Limitar training samples
      if len(train_pairs) > config.max_samples_per_epoch:
          # Usar subset random pero reproducible
          random.seed(42)

          # Mezclar y tomar subset
          train_indices = list(range(len(train_pairs)))
          random.shuffle(train_indices)

          limited_train = [train_pairs[i] for i in train_indices[:config.max_samples_per_epoch]]
      else:
          limited_train = train_pairs

      # Limitar validation samples (proporcionalmente)
      val_ratio = len(val_pairs) / len(train_pairs) if len(train_pairs) > 0 else 0.1
      max_val_samples = int(config.max_samples_per_epoch * val_ratio)
      max_val_samples = max(100, min(1000, max_val_samples))  # Entre 100 y 1000

      if len(val_pairs) > max_val_samples:
          random.seed(42)
          val_indices = list(range(len(val_pairs)))
          random.shuffle(val_indices)
          limited_val = [val_pairs[i] for i in val_indices[:max_val_samples]]
      else:
          limited_val = val_pairs

      print(f"  • Train: {original_train:,} → {len(limited_train):,}")
      print(f"  • Val: {original_val:,} → {len(limited_val):,}")

      if original_train > 0:
          reduction_pct = ((original_train - len(limited_train)) / original_train * 100)
          print(f"  • Reducción: {reduction_pct:.1f}%")

      # Estimar tiempo
      batches_per_epoch = len(limited_train) // config.batch_size
      estimated_minutes = (batches_per_epoch * 0.2) / 60  # ~0.2 seg por batch
      print(f"  • Batches por época: {batches_per_epoch:,}")
      print(f"  • Tiempo estimado por época: ~{estimated_minutes:.0f}-{estimated_minutes*1.5:.0f} minutos")

      return limited_train, limited_val

  def setup_model_saving(models_path):
      """Configura sistema de guardado"""
      global drive_models_path
      drive_models_path = models_path
      print(f"💾 Sistema de guardado configurado en: {models_path}")
      return models_path

  def analyze_tokenized_data(train_pairs, val_pairs):
      """Analiza las características de los datos tokenizados"""

      if len(train_pairs) == 0:
          print("❌ No hay datos para analizar")
          return

      print(f"\n🔍 ANÁLISIS DE DATOS TOKENIZADOS:")
      print(f"  • Training samples: {len(train_pairs):,}")
      print(f"  • Validation samples: {len(val_pairs):,}")

      # Analizar longitudes
      input_lengths = [len(pair['input_ids']) for pair in train_pairs[:1000]]
      target_lengths = [len(pair['labels']) for pair in train_pairs[:1000]]

      print(f"\n📏 ESTADÍSTICAS DE LONGITUD (muestra de 1000):")
      print(f"  • Input - Promedio: {sum(input_lengths)/len(input_lengths):.1f}, Max: {max(input_lengths)}")
      print(f"  • Target - Promedio: {sum(target_lengths)/len(target_lengths):.1f}, Max: {max(target_lengths)}")

      # Analizar idiomas
      src_langs = {}
      tgt_langs = {}
      sample_size = min(1000, len(train_pairs))

      for pair in train_pairs[:sample_size]:
          src_lang = pair['src_lang']
          tgt_lang = pair['tgt_lang']
          src_langs[src_lang] = src_langs.get(src_lang, 0) + 1
          tgt_langs[tgt_lang] = tgt_langs.get(tgt_lang, 0) + 1

      print(f"\n🌍 IDIOMAS DETECTADOS (muestra de {sample_size}):")
      print(f"  • Origen: {dict(sorted(src_langs.items()))}")
      print(f"  • Destino: {dict(sorted(tgt_langs.items()))}")

      # Mostrar ejemplo
      if len(train_pairs) > 0:
          example = train_pairs[0]
          print(f"\n📝 EJEMPLO DE DATOS:")
          print(f"  • Input tokens: {example['input_ids'][:20].tolist()}... (len: {len(example['input_ids'])})")
          print(f"  • Target tokens: {example['labels'][:20].tolist()}... (len: {len(example['labels'])})")
          print(f"  • Idiomas: {example['src_lang']} → {example['tgt_lang']}")

  def prepare_tokenized_training_data():
      """Función principal para datos pre-tokenizados"""

      try:
          print("🚀 INICIANDO CARGA DE DATOS PRE-TOKENIZADOS CON ÉPOCAS OPTIMIZADAS")
          print("=" * 80)

          # Configurar rutas
          dataset_path, models_path = mount_drive_and_setup_paths()
          setup_model_saving(models_path)

          # Optimizar garbage collection
          gc.set_threshold(500, 10, 10)
          gc.collect()

          # Cargar datos tokenizados
          train_pairs, val_pairs = load_fragmented_tokenized_data(dataset_path)

          if len(train_pairs) == 0:
              raise ValueError("No se pudieron cargar datos tokenizados válidos")

          # LIMITAR SAMPLES POR ÉPOCA (NUEVO)
          train_pairs, val_pairs = limit_samples_per_epoch(train_pairs, val_pairs, config)

          # Analizar datos
          analyze_tokenized_data(train_pairs, val_pairs)

          print(f"\n✅ DATOS TOKENIZADOS LISTOS PARA ENTRENAMIENTO OPTIMIZADO")
          print(f"📊 Dataset: ✓ | 💾 Guardado: ✓ | ✂️ Épocas cortas: ✓ | 🧹 Memoria optimizada: ✓")

          return train_pairs, val_pairs

      except FileNotFoundError as e:
          print(f"\n📂 {str(e)}")
          print("💡 Verifica que tus archivos CSV estén en la ruta correcta")
          return create_example_tokenized_data()

      except Exception as e:
          print(f"\n❌ ERROR: {str(e)}")
          print("💡 Usando datos de ejemplo...")
          return create_example_tokenized_data()

  def create_example_tokenized_data():
      """Crea datos tokenizados de ejemplo para testing"""
      print("🔧 Generando datos tokenizados de ejemplo...")

      example_pairs = []

      for i in range(5):
          input_ids = torch.tensor([119, 104, 118, 119, 35] + [100+i]*10 + [1])
          attention_mask = torch.ones_like(input_ids)
          labels = torch.tensor([104, 109, 104, 112, 115, 111, 114, 35] + [200+i]*5 + [1])
          target_attention_mask = torch.ones_like(labels)

          example_pairs.append({
              'input_ids': input_ids,
              'attention_mask': attention_mask,
              'labels': labels,
              'target_attention_mask': target_attention_mask,
              'src_lang': 'EN' if i % 2 == 0 else 'ES',
              'tgt_lang': 'ES' if i % 2 == 0 else 'SW'
          })

      # Limitar también los ejemplos
      limited_train = example_pairs[:4]
      limited_val = example_pairs[4:]

      # Simular limitación
      if hasattr(config, 'max_samples_per_epoch'):
          limited_train, limited_val = limit_samples_per_epoch(limited_train, limited_val, config)

      print(f"✅ Datos de ejemplo creados: {len(limited_train)} train, {len(limited_val)} val")
      return limited_train, limited_val

  # Variables globales
  drive_models_path = None

  # EJECUTAR: Cargar datos tokenizados
  print("Iniciando carga de datos pre-tokenizados con épocas optimizadas...")
  train_pairs, val_pairs = prepare_tokenized_training_data()


Iniciando carga de datos pre-tokenizados con épocas optimizadas...
🚀 INICIANDO CARGA DE DATOS PRE-TOKENIZADOS CON ÉPOCAS OPTIMIZADAS
🔌 Montando Google Drive...
Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
✅ Google Drive montado exitosamente
📂 Dataset path: /content/drive/MyDrive/GlobalTranslator/NMT/Dataset/
💾 Models path: /content/drive/MyDrive/GlobalTranslator/NMT/Models/
💾 Sistema de guardado configurado en: /content/drive/MyDrive/GlobalTranslator/NMT/Models/

📁 ARCHIVOS TOKENIZADOS ENCONTRADOS:
  • Training files: 2
    - NMT_train17.csv (180.9 MB)
    - NMT_train18.csv (181.0 MB)
  • Validation files: 1
    - NMT_val3.csv (180.9 MB)
  📊 Tamaño total: 542.8 MB
🧠 Configuración de memoria:
  • Límite total: 2000 MB
  • Max samples por archivo: 100000

⚡ Cargando 2 archivos de entrenamiento...
  📥 [1/2] NMT_train17.csv
📥 Cargando NMT_train17.csv...
  📊 Filas encontradas: 100000
  ✅ Pares válidos: 10000

Celda 7: Función de entrenamiento

In [6]:
def train_epoch(model, dataloader, optimizer, scheduler, device, epoch, config):
    """Entrena una época"""
    model.train()
    total_loss = 0
    progress_bar = tqdm(dataloader, desc=f"Epoch {epoch+1}")

    for batch_idx, batch in enumerate(progress_bar):
        # Mover datos a device
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)
        src_langs = batch['src_langs']
        tgt_langs = batch['tgt_langs']

        # Determinar si usar NLLB basado en el primer elemento del batch
        use_nllb = src_langs[0] in config.nllb_languages

        # Forward pass
        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
            src_lang=src_langs[0],
            tgt_lang=tgt_langs[0],
            use_nllb=use_nllb
        )

        loss = outputs['loss']

        # Gradient accumulation
        loss = loss / config.gradient_accumulation_steps
        loss.backward()

        if (batch_idx + 1) % config.gradient_accumulation_steps == 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

        total_loss += loss.item() * config.gradient_accumulation_steps

        # Actualizar progress bar
        progress_bar.set_postfix({
            'loss': loss.item() * config.gradient_accumulation_steps,
            'lr': scheduler.get_last_lr()[0]
        })

        # Liberar memoria
        del loss, outputs
        if batch_idx % 10 == 0:
            torch.cuda.empty_cache() if torch.cuda.is_available() else None

    return total_loss / len(dataloader)

def validate(model, dataloader, device, config):
    """Valida el modelo"""
    model.eval()
    total_loss = 0

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Validation"):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)
            src_langs = batch['src_langs']
            tgt_langs = batch['tgt_langs']

            use_nllb = src_langs[0] in config.nllb_languages

            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels,
                src_lang=src_langs[0],
                tgt_lang=tgt_langs[0],
                use_nllb=use_nllb
            )

            total_loss += outputs['loss'].item()

    return total_loss / len(dataloader)

Celda 8: Loop principal de entrenamiento para datos PRE-TOKENIZADOS

In [7]:

  def train_epoch_tokenized(model, dataloader, optimizer, scheduler, device, epoch, config):
      """Entrena una época con datos pre-tokenizados"""
      model.train()
      total_loss = 0
      progress_bar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{config.num_epochs}")

      for batch_idx, batch in enumerate(progress_bar):
          input_ids = batch['input_ids'].to(device)
          attention_mask = batch['attention_mask'].to(device)
          labels = batch['labels'].to(device)
          src_langs = batch['src_langs']
          tgt_langs = batch['tgt_langs']

          # MAPEO DE IDIOMAS MEJORADO
          src_lang_code = src_langs[0]
          tgt_lang_code = tgt_langs[0]

          # Usar mapeo de la configuración
          lang_mapping = config.language_mapping
          mapped_src = lang_mapping.get(src_lang_code, src_lang_code)
          mapped_tgt = lang_mapping.get(tgt_lang_code, tgt_lang_code)

          # Verificar soporte NLLB usando la configuración
          use_nllb = config.is_supported_by_nllb(src_lang_code) and config.is_supported_by_nllb(tgt_lang_code)

          # Forward pass
          try:
              outputs = model(
                  input_ids=input_ids,
                  attention_mask=attention_mask,
                  labels=labels,
                  src_lang=mapped_src,
                  tgt_lang=mapped_tgt,
                  use_nllb=use_nllb
              )

              loss = outputs['loss']

              if loss is None or torch.isnan(loss):
                  print(f"⚠️ Loss is None or NaN at batch {batch_idx}")
                  continue

          except Exception as e:
              print(f"❌ Error en forward pass: {str(e)}")
              print(f"Input shape: {input_ids.shape}, Labels shape: {labels.shape}")
              continue

          # Gradient accumulation
          loss = loss / config.gradient_accumulation_steps
          loss.backward()

          if (batch_idx + 1) % config.gradient_accumulation_steps == 0:
              torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
              optimizer.step()
              scheduler.step()
              optimizer.zero_grad()

          total_loss += loss.item() * config.gradient_accumulation_steps

          # Actualizar progress bar
          progress_bar.set_postfix({
              'loss': loss.item() * config.gradient_accumulation_steps,
              'lr': scheduler.get_last_lr()[0],
              'use_nllb': use_nllb,
              'epoch': f"{epoch+1}/{config.num_epochs}"
          })

          # Liberar memoria periódicamente
          del loss, outputs
          if batch_idx % 20 == 0:
              torch.cuda.empty_cache() if torch.cuda.is_available() else None

      return total_loss / len(dataloader)

  def validate_tokenized(model, dataloader, device, config):
      """Valida el modelo con datos pre-tokenizados"""
      model.eval()
      total_loss = 0
      num_batches = 0

      with torch.no_grad():
          for batch in tqdm(dataloader, desc="Validation"):
              input_ids = batch['input_ids'].to(device)
              attention_mask = batch['attention_mask'].to(device)
              labels = batch['labels'].to(device)
              src_langs = batch['src_langs']
              tgt_langs = batch['tgt_langs']

              # Mapear idiomas usando configuración
              src_lang_code = src_langs[0]
              tgt_lang_code = tgt_langs[0]
              lang_mapping = config.language_mapping
              mapped_src = lang_mapping.get(src_lang_code, src_lang_code)
              mapped_tgt = lang_mapping.get(tgt_lang_code, tgt_lang_code)

              use_nllb = config.is_supported_by_nllb(src_lang_code) and config.is_supported_by_nllb(tgt_lang_code)

              try:
                  outputs = model(
                      input_ids=input_ids,
                      attention_mask=attention_mask,
                      labels=labels,
                      src_lang=mapped_src,
                      tgt_lang=mapped_tgt,
                      use_nllb=use_nllb
                  )

                  if outputs['loss'] is not None:
                      total_loss += outputs['loss'].item()
                      num_batches += 1

              except Exception as e:
                  print(f"⚠️ Error en validación: {str(e)}")
                  continue

      return total_loss / num_batches if num_batches > 0 else float('inf')

  def train_hybrid_model_tokenized(model, train_pairs, val_pairs, config, device):
      """Entrena el modelo híbrido con datos PRE-TOKENIZADOS y épocas optimizadas"""

      print(f"🏗️ CONFIGURANDO ENTRENAMIENTO OPTIMIZADO PARA ÉPOCAS CORTAS")
      print(f"=" * 70)

      # Crear dataloaders para datos tokenizados
      print("📦 Creando dataloaders...")
      train_loader, val_loader = create_tokenized_dataloaders(
          train_pairs, val_pairs,
          batch_size=config.batch_size,
          max_length=config.max_length
      )

      # Verificar un batch de ejemplo
      print("\n🔍 Verificando formato de datos...")
      sample_batch = next(iter(train_loader))
      print(f"  • Input shape: {sample_batch['input_ids'].shape}")
      print(f"  • Labels shape: {sample_batch['labels'].shape}")
      print(f"  • Idiomas ejemplo: {sample_batch['src_langs'][0]} → {sample_batch['tgt_langs'][0]}")
      print(f"  • Tokens ejemplo: {sample_batch['input_ids'][0][:10].tolist()}...")

      # Configurar optimizador
      print("\n⚙️ Configurando optimizador...")
      optimizer = torch.optim.AdamW(
          model.parameters(),
          lr=config.learning_rate,
          weight_decay=0.01,
          eps=1e-8
      )

      # Configurar scheduler
      total_steps = len(train_loader) * config.num_epochs

      if hasattr(config, 'lr_scheduler_type') and config.lr_scheduler_type == "cosine":
          from torch.optim.lr_scheduler import CosineAnnealingLR
          scheduler = CosineAnnealingLR(optimizer, T_max=total_steps, eta_min=config.min_lr)
      else:
          scheduler = get_linear_schedule_with_warmup(
              optimizer,
              num_warmup_steps=config.warmup_steps,
              num_training_steps=total_steps
          )

      print(f"  • Learning rate: {config.learning_rate}")
      print(f"  • Total steps: {total_steps}")
      print(f"  • Scheduler: {getattr(config, 'lr_scheduler_type', 'linear')}")

      # CONFIGURACIÓN DE ENTRENAMIENTO OPTIMIZADO
      history = {'train_loss': [], 'val_loss': []}
      best_val_loss = float('inf')
      patience_counter = 0
      epoch_losses = []

      print(f"\n💾 CONFIGURACIÓN DE GUARDADO Y EARLY STOPPING:")
      print(f"  • Guardar cada: {config.save_every_n_epochs} época(s)")
      print(f"  • Early stopping patience: {config.early_stopping_patience} épocas")
      print(f"  • Samples por época: {len(train_pairs):,}")
      print(f"  • Batches por época: {len(train_loader):,}")

      print(f"\n🚀 INICIANDO ENTRENAMIENTO OPTIMIZADO...")
      print(f"  • Épocas: {config.num_epochs}")
      print(f"  • Batch size: {config.batch_size}")
      print(f"  • Gradient accumulation: {config.gradient_accumulation_steps}")

      # Training loop optimizado
      for epoch in range(config.num_epochs):
          print(f"\n{'='*70}")
          print(f"ÉPOCA {epoch+1}/{config.num_epochs}")
          print(f"{'='*70}")

          # Entrenar
          try:
              train_loss = train_epoch_tokenized(
                  model, train_loader, optimizer, scheduler, device, epoch, config
              )
              history['train_loss'].append(train_loss)

              print(f"\n📊 Train Loss: {train_loss:.4f}")

          except Exception as e:
              print(f"❌ Error en entrenamiento: {e}")
              print("Intentando continuar...")
              continue

          # Validar
          try:
              val_loss = validate_tokenized(model, val_loader, device, config)
              history['val_loss'].append(val_loss)

              print(f"📊 Val Loss: {val_loss:.4f}")

          except Exception as e:
              print(f"❌ Error en validación: {e}")
              val_loss = float('inf')
              history['val_loss'].append(val_loss)

          # GUARDADO CADA ÉPOCA (NUEVO)
          if (epoch + 1) % config.save_every_n_epochs == 0:
              checkpoint_name = f"checkpoint_epoch_{epoch+1}.pt"
              checkpoint_data = {
                  'epoch': epoch,
                  'model_state_dict': model.state_dict(),
                  'optimizer_state_dict': optimizer.state_dict(),
                  'scheduler_state_dict': scheduler.state_dict(),
                  'train_loss': train_loss,
                  'val_loss': val_loss,
                  'config': config.__dict__,
                  'history': history,
                  'epoch_losses': epoch_losses
              }
              torch.save(checkpoint_data, checkpoint_name)
              print(f"💾 Checkpoint guardado: {checkpoint_name}")

          # Guardar mejor modelo
          if val_loss < best_val_loss:
              best_val_loss = val_loss
              patience_counter = 0

              # Guardar como mejor modelo
              best_checkpoint_data = {
                  'epoch': epoch,
                  'model_state_dict': model.state_dict(),
                  'optimizer_state_dict': optimizer.state_dict(),
                  'train_loss': train_loss,
                  'val_loss': val_loss,
                  'config': config.__dict__
              }

              torch.save(best_checkpoint_data, 'best_hybrid_model_tokenized.pt')
              print("🏆 ¡NUEVO MEJOR MODELO GUARDADO!")

          else:
              patience_counter += 1
              print(f"⏳ Paciencia: {patience_counter}/{config.early_stopping_patience}")

          # EARLY STOPPING (NUEVO)
          if patience_counter >= config.early_stopping_patience:
              print(f"\n🛑 EARLY STOPPING ACTIVADO!")
              print(f"   • {config.early_stopping_patience} épocas consecutivas sin mejora")
              print(f"   • Mejor val_loss: {best_val_loss:.4f} (época {epoch+1-patience_counter})")
              print(f"   • Entrenamiento finalizado en época {epoch+1}")
              break

          # Guardar estadísticas de la época
          epoch_losses.append({
              'epoch': epoch+1,
              'train_loss': train_loss,
              'val_loss': val_loss,
              'lr': scheduler.get_last_lr()[0],
              'patience': patience_counter
          })

          # Mostrar resumen de época
          print(f"\n📈 RESUMEN ÉPOCA {epoch+1}:")
          print(f"  • Train Loss: {train_loss:.4f}")
          print(f"  • Val Loss: {val_loss:.4f}")
          print(f"  • Best Val Loss: {best_val_loss:.4f}")
          print(f"  • Learning Rate: {scheduler.get_last_lr()[0]:.2e}")
          print(f"  • Paciencia: {patience_counter}/{config.early_stopping_patience}")

          # Estimación de tiempo restante
          remaining_epochs = config.num_epochs - (epoch + 1)
          estimated_time = remaining_epochs * 50  # ~50 min por época
          print(f"  • Épocas restantes: {remaining_epochs}")
          print(f"  • Tiempo estimado restante: ~{estimated_time} minutos")

          # Limpiar memoria
          torch.cuda.empty_cache() if torch.cuda.is_available() else None
          gc.collect()

      print(f"\n🎉 ENTRENAMIENTO COMPLETADO!")
      print(f"  • Mejor Val Loss: {best_val_loss:.4f}")
      print(f"  • Épocas completadas: {len(history['train_loss'])}")
      print(f"  • Modelo guardado en: best_hybrid_model_tokenized.pt")

      # Guardar historial final
      final_history = {
          'history': history,
          'epoch_details': epoch_losses,
          'best_val_loss': best_val_loss,
          'config': config.__dict__
      }
      torch.save(final_history, 'training_history.pt')
      print(f"  • Historial guardado en: training_history.pt")

      return history

  # Función wrapper para compatibilidad con el código existente
  def train_hybrid_model(model, train_pairs, val_pairs, config, device):
      """Función wrapper que detecta el tipo de datos y usa la función apropiada"""

      if len(train_pairs) > 0:
          if isinstance(train_pairs[0], dict) and 'input_ids' in train_pairs[0]:
              print("🔍 Detectados datos PRE-TOKENIZADOS")
              return train_hybrid_model_tokenized(model, train_pairs, val_pairs, config, device)
          else:
              print("🔍 Detectados datos de TEXTO RAW")
              raise NotImplementedError("Función para texto raw no implementada en esta versión")
      else:
          raise ValueError("No hay datos de entrenamiento")


Celda 9: Inicialización y entrenamiento

In [None]:
# Inicializar modelo
print("Initializing Hybrid Model...")
model = HybridNLLBByT5Model(config)
model = model.to(device)

# Contar parámetros
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

# Entrenar modelo
print("\nStarting training...")
history = train_hybrid_model(model, train_pairs, val_pairs, config, device)

Initializing Hybrid Model...
Loading NLLB model...


`torch_dtype` is deprecated! Use `dtype` instead!


Loading ByT5 model...
NLLB hidden size: 1024
ByT5 hidden size: 1472
ByT5 vocab size: 384
Total parameters: 921,689,280
Trainable parameters: 921,689,280

Starting training...
🔍 Detectados datos PRE-TOKENIZADOS
🏗️ CONFIGURANDO ENTRENAMIENTO OPTIMIZADO PARA ÉPOCAS CORTAS
📦 Creando dataloaders...
✅ Dataloaders creados:
  • Train batches: 6250
  • Val batches: 125
  • Batch size: 8
  • Max length: 64

🔍 Verificando formato de datos...
  • Input shape: torch.Size([8, 64])
  • Labels shape: torch.Size([8, 64])
  • Idiomas ejemplo: ES → SW
  • Tokens ejemplo: [119, 117, 100, 113, 118, 111, 100, 119, 104, 35]...

⚙️ Configurando optimizador...
  • Learning rate: 1e-05
  • Total steps: 62500
  • Scheduler: cosine

💾 CONFIGURACIÓN DE GUARDADO Y EARLY STOPPING:
  • Guardar cada: 1 época(s)
  • Early stopping patience: 3 épocas
  • Samples por época: 50,000
  • Batches por época: 6,250

🚀 INICIANDO ENTRENAMIENTO OPTIMIZADO...
  • Épocas: 10
  • Batch size: 8
  • Gradient accumulation: 4

ÉPOCA 1/1

Epoch 1/10:   0%|          | 0/6250 [00:00<?, ?it/s]

Celda 10: Función de inferencia y pruebas

In [None]:
def test_translation(model, test_pairs, device):
    """Prueba el modelo con pares de prueba"""
    model.eval()
    model = model.to(device)

    results = []

    for source_text, expected_target, src_lang, tgt_lang in test_pairs:
        print(f"\n{'='*50}")
        print(f"Source ({src_lang}): {source_text}")
        print(f"Expected ({tgt_lang}): {expected_target}")

        # Generar traducción
        translation = model.generate_translation(
            source_text, src_lang, tgt_lang
        )

        print(f"Generated: {translation}")

        results.append({
            'source': source_text,
            'expected': expected_target,
            'generated': translation,
            'src_lang': src_lang,
            'tgt_lang': tgt_lang
        })

    return results

# Pruebas
test_pairs = [
    ("Hello, how are you today?", "Hola, ¿cómo estás hoy?", "eng_Latn", "spa_Latn"),
    ("The weather is nice", "El clima está agradable", "eng_Latn", "spa_Latn"),
    ("I need help", "Necesito ayuda", "eng_Latn", "spa_Latn"),
]

print("Testing model...")
results = test_translation(model, test_pairs, device)

Celda 11: Guardar y cargar modelo

In [None]:
def save_hybrid_model(model, filepath='hybrid_translation_model'):
    """Guarda el modelo completo"""
    os.makedirs(filepath, exist_ok=True)

    # Guardar configuración
    with open(f'{filepath}/config.json', 'w') as f:
        json.dump({
            'nllb_model_name': model.config.nllb_model_name,
            'byt5_model_name': model.config.byt5_model_name,
            'hidden_size': model.config.hidden_size,
            'fusion_dropout': model.config.fusion_dropout,
            'nllb_languages': model.config.nllb_languages,
            'new_languages': model.config.new_languages
        }, f, indent=2)

    # Guardar pesos del modelo
    torch.save(model.state_dict(), f'{filepath}/model_weights.pt')

    print(f"Model saved to {filepath}")

def load_hybrid_model(filepath='hybrid_translation_model', device='cuda'):
    """Carga el modelo guardado"""
    # Cargar configuración
    with open(f'{filepath}/config.json', 'r') as f:
        config_dict = json.load(f)

    # Recrear configuración
    config = HybridTranslationConfig()
    for key, value in config_dict.items():
        setattr(config, key, value)

    # Recrear modelo
    model = HybridNLLBByT5Model(config)

    # Cargar pesos
    model.load_state_dict(torch.load(f'{filepath}/model_weights.pt', map_location=device))
    model = model.to(device)
    model.eval()

    print(f"Model loaded from {filepath}")
    return model

# Guardar modelo
save_hybrid_model(model, 'my_hybrid_translator')

Celda 12: Agregar soporte para nuevos idiomas

In [None]:
def add_new_language_support(model, new_language_pairs, config):
    """
    Agrega soporte para un nuevo idioma entrenando solo ByT5
    new_language_pairs: Lista de (source, target, src_lang, tgt_lang)
    """
    # Congelar NLLB para preservar conocimiento
    for param in model.nllb_model.parameters():
        param.requires_grad = False

    # Solo entrenar ByT5 y capas de fusión
    trainable_params = []
    for name, param in model.named_parameters():
        if 'nllb' not in name:
            param.requires_grad = True
            trainable_params.append(param)
        else:
            param.requires_grad = False

    # Configurar optimizer solo para parámetros entrenables
    optimizer = torch.optim.AdamW(trainable_params, lr=config.learning_rate * 0.5)

    # Crear dataset
    tokenizer = model.byt5_tokenizer
    dataset = MultilingualTranslationDataset(new_language_pairs, tokenizer, config.max_length)
    collator = create_data_collator(tokenizer)
    dataloader = DataLoader(
        dataset,
        batch_size=config.batch_size,
        shuffle=True,
        collate_fn=collator
    )

    # Fine-tuning rápido
    model.train()
    for epoch in range(2):  # Menos épocas para fine-tuning
        print(f"\nFine-tuning for new language - Epoch {epoch+1}")
        total_loss = 0

        for batch in tqdm(dataloader):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels,
                src_lang=batch['src_langs'][0],
                tgt_lang=batch['tgt_langs'][0],
                use_nllb=False  # Forzar uso de ByT5
            )

            loss = outputs['loss']
            loss.backward()

            torch.nn.utils.clip_grad_norm_(trainable_params, 1.0)
            optimizer.step()
            optimizer.zero_grad()

            total_loss += loss.item()

        avg_loss = total_loss / len(dataloader)
        print(f"Average Loss: {avg_loss:.4f}")

    # Descongelar todo para uso normal
    for param in model.parameters():
        param.requires_grad = True

    print("✓ New language support added successfully!")
    return model

# Ejemplo de uso para agregar un nuevo idioma
new_lang_pairs = [
    ("Hello", "Saluton", "eng", "esperanto"),
    ("Goodbye", "Ĝis revido", "eng", "esperanto"),
    # Agregar más pares de entrenamiento
]

# model = add_new_language_support(model, new_lang_pairs, config)

Celda 13: Evaluación con métricas BLEU

In [None]:
from sacrebleu import corpus_bleu

def evaluate_model_bleu(model, test_pairs, device):
    """Evalúa el modelo usando BLEU score"""
    model.eval()

    predictions = []
    references = []

    print("Generating translations for evaluation...")
    for source, target, src_lang, tgt_lang in tqdm(test_pairs):
        # Generar traducción
        translation = model.generate_translation(
            source, src_lang, tgt_lang
        )

        predictions.append(translation)
        references.append([target])  # BLEU espera lista de referencias

    # Calcular BLEU score
    bleu = corpus_bleu(predictions, references)

    print(f"\n{'='*50}")
    print(f"BLEU Score: {bleu.score:.2f}")
    print(f"{'='*50}")

    return bleu.score

# Evaluar modelo
# bleu_score = evaluate_model_bleu(model, test_pairs, device)

Celda 14: Interfaz interactiva para traducción

In [None]:
def interactive_translation(model, device):
    """Interfaz interactiva para probar traducciones"""
    model.eval()
    model = model.to(device)

    print("="*60)
    print("Interactive Translation Interface")
    print("="*60)
    print("\nSupported NLLB languages:", ", ".join(model.config.nllb_languages[:10]), "...")
    print("\nType 'quit' to exit\n")

    while True:
        # Input source text
        source_text = input("Enter text to translate (or 'quit'): ").strip()
        if source_text.lower() == 'quit':
            break

        # Input source language
        src_lang = input("Source language code (e.g., 'eng_Latn'): ").strip()

        # Input target language
        tgt_lang = input("Target language code (e.g., 'spa_Latn'): ").strip()

        try:
            # Generate translation
            print("\nTranslating...")
            translation = model.generate_translation(
                source_text, src_lang, tgt_lang
            )

            print(f"\n{'='*40}")
            print(f"Source ({src_lang}): {source_text}")
            print(f"Translation ({tgt_lang}): {translation}")
            print(f"{'='*40}\n")

        except Exception as e:
            print(f"Error: {e}\n")

print("\nThank you for using the translator!")

# Ejecutar interfaz interactiva
# interactive_translation(model, device)

Celda 15: Preparación de datos desde archivos

In [None]:
def load_data_from_file(filepath, file_format='tsv'):
    """
    Carga pares de traducción desde un archivo
    Formato esperado: source_text\ttarget_text\tsrc_lang\ttgt_lang
    """
    data_pairs = []

    if file_format == 'tsv':
        df = pd.read_csv(filepath, sep='\t', header=None,
                        names=['source', 'target', 'src_lang', 'tgt_lang'])
    elif file_format == 'csv':
        df = pd.read_csv(filepath)
    elif file_format == 'json':
        df = pd.read_json(filepath)
    else:
        raise ValueError(f"Unsupported file format: {file_format}")

    for _, row in df.iterrows():
        data_pairs.append((
            row['source'],
            row['target'],
            row['src_lang'],
            row['tgt_lang']
        ))

    return data_pairs

def prepare_parallel_corpus(source_file, target_file, src_lang, tgt_lang):
    """
    Prepara corpus paralelo desde archivos separados
    """
    with open(source_file, 'r', encoding='utf-8') as f:
        source_lines = f.readlines()

    with open(target_file, 'r', encoding='utf-8') as f:
        target_lines = f.readlines()

    assert len(source_lines) == len(target_lines), "Files must have same number of lines"

    data_pairs = []
    for src, tgt in zip(source_lines, target_lines):
        src = src.strip()
        tgt = tgt.strip()
        if src and tgt:  # Skip empty lines
            data_pairs.append((src, tgt, src_lang, tgt_lang))

    return data_pairs

# Ejemplo de uso
# train_pairs = load_data_from_file('path/to/training_data.tsv')
# val_pairs = load_data_from_file('path/to/validation_data.tsv')

Celda 16: Data augmentation para mejorar el entrenamiento

In [None]:
import random

class DataAugmentation:
    """Técnicas de augmentation para datos de traducción"""

    @staticmethod
    def back_translation(text, model, src_lang, tgt_lang, intermediate_lang='eng_Latn'):
        """Traducción ida y vuelta para generar variaciones"""
        # Traducir a idioma intermedio
        intermediate = model.generate_translation(text, src_lang, intermediate_lang)
        # Traducir de vuelta
        back_translated = model.generate_translation(intermediate, intermediate_lang, src_lang)
        return back_translated

    @staticmethod
    def token_dropout(text, dropout_prob=0.1):
        """Elimina tokens aleatoriamente"""
        tokens = text.split()
        kept_tokens = [token for token in tokens if random.random() > dropout_prob]
        return ' '.join(kept_tokens) if kept_tokens else text

    @staticmethod
    def token_shuffle(text, shuffle_distance=3):
        """Mezcla tokens dentro de una ventana"""
        tokens = text.split()
        for i in range(len(tokens)):
            j = min(len(tokens) - 1, i + random.randint(0, shuffle_distance))
            if i != j:
                tokens[i], tokens[j] = tokens[j], tokens[i]
        return ' '.join(tokens)

    @staticmethod
    def synonym_replacement(text, replacement_prob=0.1):
        """Reemplaza palabras con sinónimos (simplificado)"""
        # Diccionario simple de sinónimos
        synonyms = {
            'good': ['great', 'excellent', 'nice'],
            'bad': ['poor', 'terrible', 'awful'],
            'big': ['large', 'huge', 'enormous'],
            'small': ['tiny', 'little', 'minute'],
            # Agregar más sinónimos según necesidad
        }

        tokens = text.split()
        augmented_tokens = []

        for token in tokens:
            token_lower = token.lower()
            if token_lower in synonyms and random.random() < replacement_prob:
                augmented_tokens.append(random.choice(synonyms[token_lower]))
            else:
                augmented_tokens.append(token)

        return ' '.join(augmented_tokens)

def augment_training_data(data_pairs, augmentation_factor=2):
    """Aumenta los datos de entrenamiento"""
    augmenter = DataAugmentation()
    augmented_pairs = []

    for source, target, src_lang, tgt_lang in data_pairs:
        # Agregar par original
        augmented_pairs.append((source, target, src_lang, tgt_lang))

        # Agregar variaciones
        for _ in range(augmentation_factor - 1):
            aug_method = random.choice([
                augmenter.token_dropout,
                augmenter.token_shuffle,
                augmenter.synonym_replacement
            ])

            aug_source = aug_method(source)
            augmented_pairs.append((aug_source, target, src_lang, tgt_lang))

    random.shuffle(augmented_pairs)
    return augmented_pairs

# Aumentar datos de entrenamiento
# augmented_train_pairs = augment_training_data(train_pairs, augmentation_factor=2)
# print(f"Original pairs: {len(train_pairs)}")
# print(f"Augmented pairs: {len(augmented_train_pairs)}")

Celda 17: Optimización de memoria y velocidad

In [None]:
class OptimizedHybridModel(HybridNLLBByT5Model):
    """Versión optimizada del modelo híbrido con mixed precision y gradient checkpointing"""

    def __init__(self, config):
        super().__init__(config)

        # Habilitar gradient checkpointing para ahorrar memoria
        if hasattr(self.nllb_model, 'gradient_checkpointing_enable'):
            self.nllb_model.gradient_checkpointing_enable()
        if hasattr(self.byt5_model, 'gradient_checkpointing_enable'):
            self.byt5_model.gradient_checkpointing_enable()

        # Configurar mixed precision
        self.use_amp = torch.cuda.is_available()

    def forward_with_amp(self, *args, **kwargs):
        """Forward pass con automatic mixed precision"""
        if self.use_amp:
            with torch.cuda.amp.autocast():
                return self.forward(*args, **kwargs)
        else:
            return self.forward(*args, **kwargs)

def train_with_optimization(model, train_pairs, val_pairs, config, device):
    """Entrenamiento optimizado con mixed precision y gradient accumulation"""

    # Preparar datos
    tokenizer = model.byt5_tokenizer
    train_dataset = MultilingualTranslationDataset(train_pairs, tokenizer, config.max_length)
    val_dataset = MultilingualTranslationDataset(val_pairs, tokenizer, config.max_length)

    collator = create_data_collator(tokenizer)
    train_loader = DataLoader(
        train_dataset,
        batch_size=config.batch_size,
        shuffle=True,
        collate_fn=collator,
        num_workers=2,
        pin_memory=True if torch.cuda.is_available() else False
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=config.batch_size * 2,  # Batch más grande para validación
        shuffle=False,
        collate_fn=collator
    )

    # Optimizador y scheduler
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config.learning_rate,
        weight_decay=0.01
    )

    total_steps = len(train_loader) * config.num_epochs
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=config.warmup_steps,
        num_training_steps=total_steps
    )

    # GradScaler para mixed precision
    scaler = torch.cuda.amp.GradScaler() if torch.cuda.is_available() else None

    # Training loop optimizado
    best_val_loss = float('inf')

    for epoch in range(config.num_epochs):
        model.train()
        train_loss = 0
        optimizer.zero_grad()

        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{config.num_epochs}")

        for batch_idx, batch in enumerate(progress_bar):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            # Mixed precision forward pass
            if scaler is not None:
                with torch.cuda.amp.autocast():
                    outputs = model(
                        input_ids=input_ids,
                        attention_mask=attention_mask,
                        labels=labels,
                        src_lang=batch['src_langs'][0],
                        tgt_lang=batch['tgt_langs'][0]
                    )
                    loss = outputs['loss'] / config.gradient_accumulation_steps

                # Backward pass con scaler
                scaler.scale(loss).backward()

                if (batch_idx + 1) % config.gradient_accumulation_steps == 0:
                    scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                    scaler.step(optimizer)
                    scaler.update()
                    scheduler.step()
                    optimizer.zero_grad()
            else:
                # Sin mixed precision
                outputs = model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    labels=labels,
                    src_lang=batch['src_langs'][0],
                    tgt_lang=batch['tgt_langs'][0]
                )
                loss = outputs['loss'] / config.gradient_accumulation_steps
                loss.backward()

                if (batch_idx + 1) % config.gradient_accumulation_steps == 0:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                    optimizer.step()
                    scheduler.step()
                    optimizer.zero_grad()

            train_loss += loss.item() * config.gradient_accumulation_steps

            # Actualizar progress bar
            progress_bar.set_postfix({
                'loss': loss.item() * config.gradient_accumulation_steps,
                'lr': scheduler.get_last_lr()[0]
            })

            # Limpiar cache periódicamente
            if batch_idx % 50 == 0 and torch.cuda.is_available():
                torch.cuda.empty_cache()

        # Validación
        model.eval()
        val_loss = validate(model, val_loader, device, config)

        avg_train_loss = train_loss / len(train_loader)
        print(f"\nEpoch {epoch+1} - Train Loss: {avg_train_loss:.4f}, Val Loss: {val_loss:.4f}")

        # Guardar mejor modelo
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), 'best_optimized_model.pt')
            print("✓ Saved best model")

    return model

# Usar modelo optimizado
# optimized_config = HybridTranslationConfig()
# optimized_config.batch_size = 8  # Podemos usar batch más grande con optimizaciones
# optimized_model = OptimizedHybridModel(optimized_config).to(device)
# optimized_model = train_with_optimization(optimized_model, train_pairs, val_pairs, optimized_config, device)

Celda 18: Exportar modelo para producción

In [None]:
def export_model_for_production(model, export_path='production_model'):
    """Exporta el modelo para uso en producción"""

    os.makedirs(export_path, exist_ok=True)

    # 1. Guardar el modelo completo en formato PyTorch
    torch.save(model, f'{export_path}/complete_model.pt')

    # 2. Guardar solo los pesos (más eficiente)
    torch.save(model.state_dict(), f'{export_path}/model_weights.pt')

    # 3. Exportar a ONNX para interoperabilidad
    dummy_input = torch.randint(0, 256, (1, 128)).to(device)
    dummy_attention = torch.ones(1, 128).to(device)

    try:
        torch.onnx.export(
            model,
            (dummy_input, dummy_attention),
            f'{export_path}/model.onnx',
            input_names=['input_ids', 'attention_mask'],
            output_names=['output'],
            dynamic_axes={
                'input_ids': {0: 'batch_size', 1: 'sequence'},
                'attention_mask': {0: 'batch_size', 1: 'sequence'},
                'output': {0: 'batch_size', 1: 'sequence'}
            }
        )
        print("✓ ONNX export successful")
    except Exception as e:
        print(f"ONNX export failed: {e}")

    # 4. Guardar configuración
    config_dict = {
        'model_type': 'HybridNLLBByT5',
        'nllb_model': model.config.nllb_model_name,
        'byt5_model': model.config.byt5_model_name,
        'supported_languages': {
            'nllb': model.config.nllb_languages,
            'custom': model.config.new_languages
        },
        'max_length': model.config.max_length,
        'device_requirements': 'CUDA recommended, CPU supported'
    }

    with open(f'{export_path}/config.json', 'w') as f:
        json.dump(config_dict, f, indent=2)

    # 5. Crear script de inferencia
    inference_script = '''
import torch
import json
from transformers import AutoTokenizer

def load_production_model(model_path):
    """Carga el modelo para producción"""
    model = torch.load(f'{model_path}/complete_model.pt', map_location='cpu')
    model.eval()
    return model

def translate(text, src_lang, tgt_lang, model, max_length=256):
    """Función de traducción simplificada"""
    with torch.no_grad():
        translation = model.generate_translation(text, src_lang, tgt_lang, max_length)
    return translation

# Ejemplo de uso
if __name__ == "__main__":
    model = load_production_model(".")
    result = translate("Hello world", "eng_Latn", "spa_Latn", model)
    print(result)
'''

    with open(f'{export_path}/inference.py', 'w') as f:
        f.write(inference_script)

    print(f"\n✅ Model exported successfully to '{export_path}/'")
    print(f"Files created:")
    print(f"  - complete_model.pt (full model)")
    print(f"  - model_weights.pt (weights only)")
    print(f"  - config.json (configuration)")
    print(f"  - inference.py (inference script)")

    return export_path

# Exportar modelo
# export_path = export_model_for_production(model)

Celda 19: Monitoreo y métricas avanzadas

In [None]:
import matplotlib.pyplot as plt
from collections import defaultdict
import time

class TranslationMetrics:
    """Clase para calcular y monitorear métricas de traducción"""

    def __init__(self):
        self.metrics_history = defaultdict(list)

    def calculate_bleu(self, predictions, references):
        """Calcula BLEU score"""
        from sacrebleu import corpus_bleu
        bleu = corpus_bleu(predictions, references)
        return bleu.score

    def calculate_ter(self, predictions, references):
        """Calcula Translation Edit Rate (TER)"""
        # Implementación simplificada de TER
        total_edits = 0
        total_words = 0

        for pred, ref_list in zip(predictions, references):
            ref = ref_list[0] if isinstance(ref_list, list) else ref_list
            pred_tokens = pred.split()
            ref_tokens = ref.split()

            # Distancia de Levenshtein normalizada
            edits = self._levenshtein_distance(pred_tokens, ref_tokens)
            total_edits += edits
            total_words += len(ref_tokens)

        ter = (total_edits / total_words) * 100 if total_words > 0 else 0
        return ter

    def _levenshtein_distance(self, s1, s2):
        """Calcula distancia de Levenshtein"""
        if len(s1) < len(s2):
            return self._levenshtein_distance(s2, s1)

        if len(s2) == 0:
            return len(s1)

        previous_row = range(len(s2) + 1)
        for i, c1 in enumerate(s1):
            current_row = [i + 1]
            for j, c2 in enumerate(s2):
                insertions = previous_row[j + 1] + 1
                deletions = current_row[j] + 1
                substitutions = previous_row[j] + (c1 != c2)
                current_row.append(min(insertions, deletions, substitutions))
            previous_row = current_row

        return previous_row[-1]

    def calculate_inference_speed(self, model, test_texts, device):
        """Mide velocidad de inferencia"""
        model.eval()
        total_time = 0
        total_tokens = 0

        with torch.no_grad():
            for text in test_texts:
                start_time = time.time()
                _ = model.generate_translation(text, "eng_Latn", "spa_Latn")
                end_time = time.time()

                total_time += (end_time - start_time)
                total_tokens += len(text.split())

        tokens_per_second = total_tokens / total_time if total_time > 0 else 0
        return tokens_per_second

    def plot_metrics(self):
        """Visualiza métricas"""
        if not self.metrics_history:
            print("No metrics to plot")
            return

        fig, axes = plt.subplots(2, 2, figsize=(12, 8))

        # Plot 1: Loss history
        if 'train_loss' in self.metrics_history:
            axes[0, 0].plot(self.metrics_history['train_loss'], label='Train Loss')
            if 'val_loss' in self.metrics_history:
                axes[0, 0].plot(self.metrics_history['val_loss'], label='Val Loss')
            axes[0, 0].set_xlabel('Epoch')
            axes[0, 0].set_ylabel('Loss')
            axes[0, 0].set_title('Training Progress')
            axes[0, 0].legend()

        # Plot 2: BLEU scores
        if 'bleu' in self.metrics_history:
            axes[0, 1].plot(self.metrics_history['bleu'])
            axes[0, 1].set_xlabel('Evaluation')
            axes[0, 1].set_ylabel('BLEU Score')
            axes[0, 1].set_title('BLEU Score Progress')

        # Plot 3: TER scores
        if 'ter' in self.metrics_history:
            axes[1, 0].plot(self.metrics_history['ter'])
            axes[1, 0].set_xlabel('Evaluation')
            axes[1, 0].set_ylabel('TER (%)')
            axes[1, 0].set_title('Translation Edit Rate')

        # Plot 4: Inference speed
        if 'speed' in self.metrics_history:
            axes[1, 1].plot(self.metrics_history['speed'])
            axes[1, 1].set_xlabel('Evaluation')
            axes[1, 1].set_ylabel('Tokens/Second')
            axes[1, 1].set_title('Inference Speed')

        plt.tight_layout()
        plt.savefig('translation_metrics.png')
        plt.show()

    def evaluate_model_complete(self, model, test_pairs, device):
        """Evaluación completa del modelo"""
        print("\n" + "="*60)
        print("COMPLETE MODEL EVALUATION")
        print("="*60)

        predictions = []
        references = []
        source_texts = []

        # Generar traducciones
        print("\nGenerating translations...")
        for source, target, src_lang, tgt_lang in tqdm(test_pairs):
            translation = model.generate_translation(source, src_lang, tgt_lang)
            predictions.append(translation)
            references.append([target])
            source_texts.append(source)

        # Calcular métricas
        bleu_score = self.calculate_bleu(predictions, references)
        ter_score = self.calculate_ter(predictions, references)
        speed = self.calculate_inference_speed(model, source_texts[:10], device)

        # Guardar en historial
        self.metrics_history['bleu'].append(bleu_score)
        self.metrics_history['ter'].append(ter_score)
        self.metrics_history['speed'].append(speed)

        # Imprimir resultados
        print(f"\n📊 EVALUATION RESULTS:")
        print(f"  • BLEU Score: {bleu_score:.2f}")
        print(f"  • TER Score: {ter_score:.2f}%")
        print(f"  • Inference Speed: {speed:.1f} tokens/second")

        # Ejemplos de traducción
        print(f"\n📝 SAMPLE TRANSLATIONS:")
        for i in range(min(3, len(predictions))):
            print(f"\n  Example {i+1}:")
            print(f"    Source: {source_texts[i]}")
            print(f"    Reference: {references[i][0]}")
            print(f"    Generated: {predictions[i]}")

        print("="*60)

        return {
            'bleu': bleu_score,
            'ter': ter_score,
            'speed': speed,
            'predictions': predictions,
            'references': references
        }

# Usar métricas
metrics = TranslationMetrics()
# results = metrics.evaluate_model_complete(model, test_pairs, device)
# metrics.plot_metrics()

Celda 20: Script principal de ejecución

In [None]:
def main_training_pipeline():
    """Pipeline completo de entrenamiento"""

    print("🚀 Starting Hybrid NLLB-ByT5 Translation Model Training Pipeline")
    print("="*70)

    # 1. Configuración
    print("\n📋 Step 1: Setting up configuration...")
    config = HybridTranslationConfig()
    config.batch_size = 4
    config.num_epochs = 3
    config.learning_rate = 5e-5
    print(f"  • Batch size: {config.batch_size}")
    print(f"  • Epochs: {config.num_epochs}")
    print(f"  • Learning rate: {config.learning_rate}")

    # 2. Preparar datos
    print("\n📊 Step 2: Preparing data...")
    train_pairs, val_pairs = prepare_training_data()
    print(f"  • Training samples: {len(train_pairs)}")
    print(f"  • Validation samples: {len(val_pairs)}")

    # 3. Inicializar modelo
    print("\n🤖 Step 3: Initializing hybrid model...")
    model = OptimizedHybridModel(config)
    model = model.to(device)
    total_params = sum(p.numel() for p in model.parameters())
    print(f"  • Total parameters: {total_params:,}")
    print(f"  • Device: {device}")

    # 4. Entrenar modelo
    print("\n🏋️ Step 4: Training model...")
    model = train_with_optimization(model, train_pairs, val_pairs, config, device)

    # 5. Evaluar modelo
    print("\n📈 Step 5: Evaluating model...")
    metrics = TranslationMetrics()
    test_pairs = val_pairs[:5]  # Usar subset para evaluación rápida
    results = metrics.evaluate_model_complete(model, test_pairs, device)

    # 6. Guardar modelo
    print("\n💾 Step 6: Saving model...")
    export_path = export_model_for_production(model, 'final_hybrid_model')

    print("\n✅ Pipeline completed successfully!")
    print(f"  • Model saved to: {export_path}")
    print(f"  • Final BLEU score: {results['bleu']:.2f}")

    return model, results

# Ejecutar pipeline completo
# trained_model, evaluation_results = main_training_pipeline()

Celda 21: Instrucciones de uso y próximos pasos

In [None]:
print("""
╔══════════════════════════════════════════════════════════════════╗
║          HYBRID NLLB-ByT5 TRANSLATION MODEL - USER GUIDE         ║
╚══════════════════════════════════════════════════════════════════╝

📚 CÓMO USAR ESTE NOTEBOOK:

1️⃣ PREPARACIÓN:
   - Asegúrate de tener GPU disponible (recomendado)
   - Instala todas las dependencias (Celda 1)

2️⃣ DATOS DE ENTRENAMIENTO:
   - Modifica la función prepare_training_data() en Celda 6
   - Formato: (texto_origen, texto_destino, idioma_origen, idioma_destino)
   - Para NLLB, usa códigos como: 'eng_Latn', 'spa_Latn', etc.

3️⃣ CONFIGURACIÓN:
   - Ajusta hiperparámetros en HybridTranslationConfig (Celda 3)
   - Batch size, learning rate, epochs, etc.

4️⃣ ENTRENAMIENTO:
   - Ejecuta main_training_pipeline() (Celda 20) para proceso completo
   - O entrena paso a paso ejecutando celdas individuales

5️⃣ AGREGAR NUEVOS IDIOMAS:
   - Usa add_new_language_support() (Celda 12)
   - Proporciona pares de entrenamiento para el nuevo idioma

6️⃣ INFERENCIA:
   - Usa model.generate_translation() para traducir
   - O usa interactive_translation() para interfaz interactiva

📊 IDIOMAS SOPORTADOS POR NLLB-200:
   - 200+ idiomas incluyendo:
     • Principales: Inglés, Español, Francés, Alemán, Chino, etc.
     • Regionales: Catalán, Gallego, Euskera, etc.
     • Minoritarios: Muchos idiomas con pocos recursos

   Lista completa: https://github.com/facebookresearch/fairseq/tree/nllb

⚙️ PERSONALIZACIÓN:
   - Puedes cambiar el modelo NLLB base por versiones más grandes
   - Ajusta el modelo ByT5 (small, base, large)
   - Modifica las capas de fusión según tus necesidades

⚠️ CONSIDERACIONES:
   - El modelo distilled NLLB-200 usa ~2.4GB de VRAM
   - ByT5-small usa ~1.2GB adicionales
   - Con batch_size=4, necesitas ~8GB de VRAM mínimo
   - Para modelos más grandes, ajusta batch_size y gradient_accumulation

💡 PRÓXIMOS PASOS:
   1. Entrenar con más datos (mínimo 10k pares por idioma)
   2. Fine-tuning específico por dominio
   3. Implementar beam search para mejor calidad
   4. Agregar post-procesamiento específico por idioma
   5. Crear API REST para servir el modelo

📧 SOPORTE:
   Si tienes problemas, verifica:
   - Versiones de librerías
   - Memoria GPU disponible
   - Formato correcto de datos
   - Códigos de idioma válidos

¡Buena suerte con tu modelo de traducción multilingüe! 🌍🚀
""")

# Verificar estado final
print("\n" + "="*60)
print("ESTADO DEL SISTEMA:")
print(f"  • GPU disponible: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"  • GPU: {torch.cuda.get_device_name(0)}")
    print(f"  • VRAM libre: {torch.cuda.mem_get_info()[0]/1e9:.2f} GB")
print(f"  • Modelos cargados: ✓")
print(f"  • Listo para entrenamiento: ✓")
print("="*60)