In [1]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

/kaggle/input/msum-sum/es/val.txt.src
/kaggle/input/msum-sum/es/test.txt.urls
/kaggle/input/msum-sum/es/test.txt.src
/kaggle/input/msum-sum/es/test.txt.tgt
/kaggle/input/msum-sum/es/val.txt.urls
/kaggle/input/msum-sum/es/val.txt.tgt
/kaggle/input/msum-sum/es/train.txt.src
/kaggle/input/msum-sum/es/train.txt.tgt
/kaggle/input/msum-sum/es/train.txt.urls


In [2]:
!ls

kaggle	     test.txt.src.tokenized   val.txt.src.tokenized
outputs.zip  test.txt.tgt.tokenized   val.txt.tgt.tokenized
saved	     train.txt.src.tokenized
state.db     train.txt.tgt.tokenized


In [3]:
import os
import re 
import json
from collections import Counter
from tqdm import tqdm
from typing import List, Tuple, Dict
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import time
import re
import torch.nn as nn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

# Constantes

In [4]:

import torch
import time
import os

# =================================================
# DATA / VOCABULARY
# =================================================
# Rutas relativas al directorio del proyecto
BASE_DIR = '/kaggle/input/msum-sum/'
DATA_DIR = os.path.join(BASE_DIR, "es")
VOCAB_NAME = "Vocabulary.json"
CHECKPOINT_VOCABULARY_DIR = os.path.join("saved", "working")

# Ruta a embeddings pre-entrenados 
EMBEDDING_PATH = os.path.join('kaggle/working', "wiki.es.vec")

MAX_VOCAB_SIZE = 50000
MAX_LEN_SRC = 500
MAX_LEN_TGT = 50
BATCH_SIZE = 16

PAD_TOKEN = "[PAD]"
UNK_TOKEN = "[UNK]"
START_DECODING = "[START]"
END_DECODING = "[END]"

CREATE_VOCABULARY = not os.path.exists(
    os.path.join(CHECKPOINT_VOCABULARY_DIR, VOCAB_NAME)
)

# =================================================
# MODEL ARCHITECTURE
# =================================================
EMBEDDING_SIZE = 300 
HIDDEN_SIZE = 256

NUM_ENC_LAYERS = 1
NUM_DEC_LAYERS = 1
BIDIRECTIONAL = True
 

IS_ATTENTION = True
IS_PGEN = True
IS_COVERAGE = True
COV_LOSS_LAMBDA = 1.0 

# =================================================
# DECODING
# =================================================
DECODING_STRATEGY = "beam_search"
BEAM_SIZE = 5


EPOCHS = 15
WARMUP_EPOCHS = 0
ITERS_PER_EPOCH = None  



LEARNER = "adam"

LEARNING_RATE = 0.001 
GRAD_CLIP = 2.0  
TRAIN_BATCH_SIZE = 64  
EVAL_BATCH_SIZE = 64
SAVE_HISTORY = True
SAVE_MODEL_EPOCH = True
DROPOUT_RATIO = 0.3


# =================================================
# GPU / REPRODUCIBILITY
# =================================================
USE_GPU = True
GPU_ID = 0
DEVICE = torch.device(f"cuda:{GPU_ID}" if USE_GPU and torch.cuda.is_available() else "cpu")

SEED = 42
REPRODUCIBILITY = False

# =================================================
# PATHS / LOGGING
# =================================================
CHECKPOINT_DIR = os.path.join('kaggle/working', "saved")
GENERATED_TEXT_DIR = os.path.join('kaggle/working/', "generated")
PLOT = False


# Configuración

In [5]:
class Config:
    def __init__(
        self,
        **kwargs
    ):
        self.config_dict = {
            **kwargs
        }

        self._init_device()
        
    def _init_device(self):
        if self.config_dict["use_gpu"] and torch.cuda.is_available():
            self.config_dict["device"] = torch.device(
                f"cuda:{self.config_dict['gpu_id']}"
            )
        else:
            self.config_dict["device"] = torch.device("cpu")

    def __getitem__(self, item):
        return self.config_dict.get(item, None)
    
    def get_config_dict(self):
        return self.config_dict
        
    def __str__(self):
        args_info = "\nHyper Parameters:\n"
        for key, value in self.config_dict.items():
            args_info += f"{key}={value}\n"
        return args_info

    def __repr__(self):
        return self.__str__()


In [6]:
import os 
from collections import Counter
import json
import re
import spacy

class Vocabulary:
    def __init__(self, CREATE_VOCABULARY,
                 PAD_TOKEN, UNK_TOKEN,
                  END_DECODING, START_DECODING,
                 MAX_VOCAB_SIZE, CHECKPOINT_VOCABULARY_DIR, DATA_DIR,VOCAB_NAME):
        
        self.vocab_name = VOCAB_NAME
        self.create_vocabulary = CREATE_VOCABULARY
        self.checkpoint_vocab_dir = CHECKPOINT_VOCABULARY_DIR
        self.data_dir = DATA_DIR
        self.max_vocab_size = MAX_VOCAB_SIZE
        self._c = 0
        try:
            spacy.prefer_gpu()
            self.nlp = spacy.load("es_core_news_sm", disable=['parser','ner','lemmatizer','morphologizer','attribute_ruler'])
            self.nlp.add_pipe('sentencizer')
        except OSError:
            print("Descargando modelo de spaCy español (Large)...")
            spacy.cli.download("es_core_news_sm")
            spacy.prefer_gpu()
            self.nlp = spacy.load("es_core_news_sm", disable=['parser','ner','lemmatizer','attribute_ruler','morphologizer'])
            self.nlp.add_pipe('sentencizer')
        # Token de Relleno (Padding) - Usado para igualar longitudes de secuencias.
        self.pad_token = PAD_TOKEN
        # Token Desconocido (Unknown) - Usado para palabras no vistas en el vocabulario.
        self.unk_token = UNK_TOKEN
        
        # Tokens para Delimitación de Sentencias/Secuencias
        self.start_decoding = START_DECODING
        self.end_decoding = END_DECODING
        
   
        # Diccionario para mapear palabras a sus IDs (índices)     
        self.word_to_id = {}
        # Lista para mapear IDs a sus palabras
        self.id_to_word = []
        # Contador de frecuencia de palabras
        self.word_count = {}
        
        self._add_special_tokens()
        
    def total_size(self):
        return len(self.word_to_id)

    def word2id(self, word):
        """Retorna el id de la palabra o [UNK] id si es OOV."""
        if word not in self.word_to_id:
          return self.word_to_id[self.unk_token]
        return self.word_to_id[word]

    def id2word(self, word_id):
        """Retorna la palabra dado el id si existe en el vocabulario"""
        if 0 <= word_id < len(self.id_to_word):
            return self.id_to_word[word_id]
        
        raise ValueError('Id no esta en el vocab: %d' % word_id)
        
    def _add_special_tokens(self):
        """Añade los tokens especiales al vocabulario."""
        # Se añaden en un orden específico para que sus IDs sean fijos.
        special_tokens = [
            self.pad_token, self.unk_token, 
            self.start_decoding, self.end_decoding
        ]
        
        for token in special_tokens: #{'[PAD]':0,'[UNK]':1,'[START]':2,'[END]':3}
            if token not in self.word_to_id:
                self.word_to_id[token] = len(self.id_to_word)
                self.id_to_word.append(token)
                self.word_count[token] = 0 # Frecuencia inicial 0
        
        self.num_special_tokens = len(self.id_to_word)

    def _load_vocabulary(self):
        """
        Carga el vocabulario completo desde disco y restaura el estado interno.
        """
        try:
            vocab_path = os.path.join(self.checkpoint_vocab_dir, self.vocab_name)
    
            if not os.path.exists(vocab_path):
                raise FileNotFoundError(f"Vocabulario no encontrado en {vocab_path}")
            print(vocab_path)
            with open(vocab_path, 'r', encoding='utf-8') as f:
                saved_data = json.load(f)
            # -------------------------
            # Restaurar vocabulario base
            # -------------------------
                self.word_to_id = saved_data['word_to_id']
                self.id_to_word = saved_data['id_to_word']
                self.word_count = saved_data['word_count']
                self._c = saved_data['size']
        
                # -------------------------
                # Restaurar tokens especiales
                # -------------------------
                special_tokens = saved_data['special_tokens']
    
                self.pad_token = special_tokens['PAD']
                self.unk_token = special_tokens['UNK']
                self.start_decoding = special_tokens.get('START')
                self.end_decoding = special_tokens.get('END_DECODING')
        
                self.num_special_tokens = saved_data['metadata']['num_special_tokens']
        
                # -------------------------
                # Restaurar metadata
                # -------------------------
                metadata = saved_data['metadata']
        
                self.max_vocab_size = metadata['max_vocab_size']
                self.data_dir = metadata['data_dir']
                self.create_vocabulary = metadata['create_vocabulary']
                self.vocab_name = metadata['vocab_name']
                self.checkpoint_vocab_dir = metadata['checkpoint_dir']
    
            
            if len(self.word_to_id) != len(self.id_to_word):
                raise ValueError("Inconsistencia: word_to_id e id_to_word tienen tamaños distintos")
    
            if self.pad_token not in self.word_to_id:
                raise ValueError("Token PAD no encontrado en el vocabulario")
    
            print(f" Vocabulario cargado desde: {vocab_path}")
            print(f" Tamaño total: {len(self.word_to_id)}")
            print(f" Tokens especiales: {self.num_special_tokens}")
            print(f" Tokens regulares: {self._c}")
    
            return True
    
        except Exception as e:
            print(f"✗ Error cargando vocabulario: {e}")
            return False

            
    def size(self):
        """Retorna el tamaño real de el vocabulario"""
        return self._c    
        
    def _clean_text(self, text,for_vocab=False):
        """Limpieza inicial de texto antes de pasar por spaCy."""
        # 1. Quitar HTML
        text = re.sub(r'<[^>]+>', ' ', text)
        
        # Atrapa http, https, www y los que empiezan con //
        url_pattern = r'(http[s]?://|www\.|//)[^\s/$.?#].[^\s]*'
        text = re.sub(url_pattern, ' ', text)
        # 3. Limpieza de caracteres especiales y ruido
        text = text.replace('\xa0', ' ')
        # Caracteres decorativos repetidos
        text = re.sub(r'[~*\-_=]{2,}', ' ', text)

        if for_vocab:
            # Quita números aislados: "1", "2025", "10.5", "50%"
            # Y también combinaciones numéricas con guión: "1-0", "24-7", "2023-2024"
            text = re.sub(r'\b\d+([.,-]\d+)*%?\b', ' ', text)
        
        text = text.replace('...', ' ')
        
        # 4. Normalizar espacios 
        return re.sub(r'\s+', ' ', text).strip()

    def _tokens_from_doc(self, doc, for_vocab=False):
        """Extrae y filtra tokens de un doc de spaCy."""
        tokens = []
        for token in doc:
            # Si estamos filtrando para el VOCABULARIO
            if for_vocab:
                # Omitimos Números y Fechas en el vocabulario fijo
                if token.like_num or token.pos_ == "NUM":
                    continue
                # Intento de detectar fechas por forma básica
                if re.match(r'\d+[/-]\d+', token.text):
                    continue
            
            # Filtros comunes (Puntuación ruidosa, brackets, quotes)
            if token.is_punct and token.text not in ['.', ',', '!', '?','¿']:
                continue
            if token.is_bracket or token.is_quote:
                continue
            
            t = token.text
            t = t.replace('``', '"').replace("''", '"')
            if t:
                tokens.append(t)
        return tokens

    def process_text(self, text):
        """Procesa un único texto para el modelo (mantiene fechas/números)."""
        text = self._clean_text(text,for_vocab=False)
        doc = self.nlp(text)
        return self._tokens_from_doc(doc, for_vocab=False)

    
    def _save_vocabulary(self):
        """Guarda el vocabulario completo en el disco."""
        try:
            # Crear directorio si no existe
            os.makedirs(self.checkpoint_vocab_dir, exist_ok=True)
            
            path = os.path.join(self.checkpoint_vocab_dir, self.vocab_name)
            
            # Preparar datos para guardar
            save_data = {
                'word_to_id': self.word_to_id,
                'id_to_word': self.id_to_word,
                'word_count': self.word_count,
                'size': self._c,
                'special_tokens': {
                    'PAD': self.pad_token,
                    'UNK': self.unk_token,
                    'START': self.start_decoding,
                    'END_DECODING': self.end_decoding
                },
                'metadata': {
                    'max_vocab_size': self.max_vocab_size,
                    'data_dir': self.data_dir,
                    'create_vocabulary': self.create_vocabulary,
                    'vocab_name': self.vocab_name,
                    'checkpoint_dir': self.checkpoint_vocab_dir,
                    'num_special_tokens': self.num_special_tokens,
                    'total_size': len(self.word_to_id)
                }
            }
            
            # Guardar como JSON
            with open(path, 'w', encoding='utf-8') as f:
                json.dump(save_data, f, ensure_ascii=False, indent=4)
            
            print(f"  Vocabulario guardado en: {path}")
            print(f"  Tamaño total: {len(self.word_to_id)} palabras")
            print(f"  Tokens especiales: {self.num_special_tokens}")
            print(f"  Tokens regulares: {self._c}")
                       
            return True
            
        except Exception as e:
            print(f"✗ Error al guardar el vocabulario: {e}")
            raise
            
    def _create_vocabulary(self):
        import multiprocessing
        num_cores = max(1, multiprocessing.cpu_count() - 2)
        print(f"Construyendo vocabulario usando {num_cores} núcleos...")
        print(f"Construyendo vocabulario a partir de los datos en: {self.data_dir}")
        src_files = [os.path.join(self.data_dir, f"{split}.txt.src") for split in ["train"]]
        tgt_files = [os.path.join(self.data_dir, f"{split}.txt.tgt") for split in ["train"]]
        all_files = src_files + tgt_files
        all_words = []
       
        word_counts = Counter()
       
        for file_path in all_files:
            
            def line_generator(path):
                with open(path, "r", encoding="utf-8") as f:
                    for line in f:
                        # Aplicamos la limpieza básica de strings antes de spaCy
                        yield self._clean_text(line, for_vocab=True)
                        
            doc_stream = self.nlp.pipe(
                line_generator(file_path), 
                batch_size=500,
                num_workers=num_cores
            )
            for doc in tqdm(doc_stream, desc=f"Procesando {os.path.basename(file_path)}"):
                tokens = self._tokens_from_doc(doc, for_vocab=True)
                word_counts.update(tokens)
      
        # Calcular cuántas palabras regulares podemos añadir:
        if self.max_vocab_size <= self.num_special_tokens:
            raise ValueError(
                f"ERROR: MAX_VOCAB_SIZE ({self.max_vocab_size}) debe ser mayor que "
                f"el número de tokens especiales ({self.num_special_tokens}). "
                "Vocabulario muy pequeño."
            )
        limit = self.max_vocab_size - self.num_special_tokens
        # Seleccionar las 'limit' palabras más comunes, excluyendo las que ya son tokens especiales
        for word, count in word_counts.most_common(limit):
            if word not in self.word_to_id and len(self.word_to_id) < self.max_vocab_size:
                self.word_to_id[word] = len(self.id_to_word)
                self.id_to_word.append(word)
                self.word_count[word] = count
                self._c+=1
                
        # Guardar el vocabulario 
        self._save_vocabulary()

        print(f"Vocabulario construido. Tamaño final: {len(self.word_to_id)}")
        return True
        
    def build_vocabulary(self):
        if not self.create_vocabulary:
            return self._load_vocabulary()
        return self._create_vocabulary()

    def load_pretrained_embeddings(self, embedding_path, embedding_dim):
        """
        Carga embeddings pre-entrenados y los alinea con el vocabulario actual.
        Solo realiza coincidencias EXACTAS. Las palabras no encontradas (incluyendo
        variaciones de mayúsculas/minúsculas no presentes en el archivo) serán
        aprendidas por el modelo durante el entrenamiento.
        """
        import torch
        import numpy as np
        from tqdm import tqdm

        # 1. Verificar si el archivo existe, si no, descargar SBW (News) por defecto
        if embedding_path is not None and not os.path.exists(embedding_path):
            print(f"⚠ Archivo {embedding_path} no encontrado.")
            if "sbw_news.vec" in embedding_path:
                print("Iniciando descarga automática de SBW News Embeddings (Noticias en español)...")
                self.download_spanish_embeddings(os.path.dirname(embedding_path), type='sbw')
            elif "wiki.es.vec" in embedding_path:
                print("Iniciando descarga automática de FastText Spanish...")
                self.download_spanish_embeddings(os.path.dirname(embedding_path), type='fasttext')
            else:
                print("Usando inicialización aleatoria.")
                return torch.randn(len(self.id_to_word), embedding_dim) * 0.1

        vocab_size = len(self.id_to_word)
        # Inicialización aleatoria para que el modelo "aprenda" lo que no esté en los embeddings
        weights = torch.randn(vocab_size, embedding_dim) * 0.1
        
        if embedding_path is None:
            return weights

        print(f"Cargando embeddings (Solo Coincidencias Exactas) desde {embedding_path}...")
        
        found_indices = set()
        
        try:
            with open(embedding_path, 'r', encoding='utf-8', errors='ignore') as f:
                header = f.readline().split()
                if len(header) != 2:
                    f.seek(0)
                
                for line in tqdm(f, desc="Alineando embeddings"):
                    parts = line.rstrip().split(' ')
                    if len(parts) < embedding_dim + 1:
                        continue
                        
                    emb_word = parts[0]
                    
                    # Búsqueda Exacta ÚNICAMENTE
                    if emb_word in self.word_to_id:
                        idx = self.word_to_id[emb_word]
                        if idx not in found_indices:
                            vec = np.array([float(x) for x in parts[1:embedding_dim+1]])
                            weights[idx] = torch.from_numpy(vec)
                            found_indices.add(idx)
                                
            print(f"✓ Cobertura Exacta: {len(found_indices)} / {vocab_size} ({len(found_indices)/vocab_size*100:.1f}%)")
            print(f"  - Las {vocab_size - len(found_indices)} palabras restantes serán aprendidas desde cero.")

            if self.pad_token in self.word_to_id:
                weights[self.word_to_id[self.pad_token]] = torch.zeros(embedding_dim)
                
        except Exception as e:
            print(f"✗ Error cargando embeddings: {e}")
            
        return weights

    def download_spanish_embeddings(self, target_dir, type='sbw'):
        """
        Descarga y extrae vectores pre-entrenados.
        type: 'fasttext' (Wikipedia/CC) o 'sbw' (Spanish Billion Words - Noticias/Libros)
        """
        import urllib.request
        import os
        import gzip
        import shutil

        os.makedirs(target_dir, exist_ok=True)
        
        if type == 'sbw':
            # Spanish Billion Words (Más orientado a Noticias/Libros)
            url = "http://dcc.uchile.cl/~jperez/word-embeddings/glove-sbwc.i2e.300d.vec.gz"
            target_name = "sbw_news.vec"
        else:
            # FastText CC Spanish
            url = "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.es.300.vec.gz"
            target_name = "wiki.es.vec"

        gz_path = os.path.join(target_dir, f"{target_name}.gz")
        vec_path = os.path.join(target_dir, target_name)
        
        print(f"Descargando embeddings de tipo '{type}' desde {url}...")
        try:
            urllib.request.urlretrieve(url, gz_path)
            print(f"✓ Descarga completada. Descomprimiendo en {vec_path}...")
            
            with gzip.open(gz_path, 'rb') as f_in:
                with open(vec_path, 'wb') as f_out:
                    shutil.copyfileobj(f_in, f_out)
            
            print(f"✓ Extracción completada.")
            os.remove(gz_path) 
            return vec_path
        except Exception as e:
            print(f"✗ Error descargando: {e}")
            return None
        

In [7]:
"""import os
import sys
from tqdm import tqdm
import multiprocessing



def preprocess_files(vocab, files, is_source=True):
    num_cores = max(1, multiprocessing.cpu_count() - 2)
    
    for file_path, original_name in files:
        output_path = original_name + ".tokenized"
                    
        print(f"Procesando {os.path.basename(file_path)} -> {os.path.basename(output_path)}")
        
        def line_generator(path):
            with open(path, "r", encoding="utf-8") as f:
                for line in f:
                    yield vocab._clean_text(line, for_vocab=False)
        
        doc_stream = vocab.nlp.pipe(
            line_generator(file_path), 
            batch_size=1000, 
        )
        
        with open(output_path, "w", encoding="utf-8") as f_out:
            for doc in tqdm(doc_stream, desc=f"  Tokenizando..."):
                # Procesamos por oraciones para mantener la estructura
                sentences_tokens = []
                for sent in doc.sents:
                    tokens = vocab._tokens_from_doc(sent, for_vocab=False)
                    if tokens:
                        sentences_tokens.append(" ".join(tokens))
                
                # Unimos las oraciones con "[.]" para que dataset.py las recupere de forma segura
                f_out.write(" [.] ".join(sentences_tokens) + "\n")

def main():
    print("=== INICIANDO PREPROCESAMIENTO ===")
    
    # 1. Cargar Vocabulario
    print("Cargando instancia de Vocabulary...")
    vocab = Vocabulary(
        CREATE_VOCABULARY=False,
        PAD_TOKEN=PAD_TOKEN,
        UNK_TOKEN=UNK_TOKEN,
        START_DECODING=START_DECODING,
        END_DECODING=END_DECODING,
        MAX_VOCAB_SIZE=MAX_VOCAB_SIZE,
        CHECKPOINT_VOCABULARY_DIR=CHECKPOINT_VOCABULARY_DIR,
        DATA_DIR=DATA_DIR,
        VOCAB_NAME=VOCAB_NAME
    )

    # 2. Definir archivos
    splits = ['train', 'val', 'test']
    src_files = []
    tgt_files = []
    
    for split in splits:
        s_name = f"{split}.txt.src"
        t_name = f"{split}.txt.tgt"
        s_path = os.path.join(DATA_DIR, s_name)
        t_path = os.path.join(DATA_DIR, t_name)
        
        if os.path.exists(s_path):
            src_files.append((s_path, s_name))
        if os.path.exists(t_path):
            tgt_files.append((t_path, t_name))
            
    # 3. Procesar
    if src_files:
        print(f"\nArchivos Source encontrados: {len(src_files)}")
        preprocess_files(vocab, src_files, is_source=True)
    
    if tgt_files:
        print(f"\nArchivos Target encontrados: {len(tgt_files)}")
        preprocess_files(vocab, tgt_files, is_source=False)
    
    print("\n=== PREPROCESAMIENTO COMPLETADO ===")

if __name__ == "__main__":
    main()"""


'import os\nimport sys\nfrom tqdm import tqdm\nimport multiprocessing\n\n\n\ndef preprocess_files(vocab, files, is_source=True):\n    num_cores = max(1, multiprocessing.cpu_count() - 2)\n    \n    for file_path, original_name in files:\n        output_path = original_name + ".tokenized"\n                    \n        print(f"Procesando {os.path.basename(file_path)} -> {os.path.basename(output_path)}")\n        \n        def line_generator(path):\n            with open(path, "r", encoding="utf-8") as f:\n                for line in f:\n                    yield vocab._clean_text(line, for_vocab=False)\n        \n        doc_stream = vocab.nlp.pipe(\n            line_generator(file_path), \n            batch_size=1000, \n        )\n        \n        with open(output_path, "w", encoding="utf-8") as f_out:\n            for doc in tqdm(doc_stream, desc=f"  Tokenizando..."):\n                # Procesamos por oraciones para mantener la estructura\n                sentences_tokens = []\n      

In [8]:
import os
from typing import List, Tuple, Dict
import torch
from torch.utils.data import Dataset
class PGNDataset(Dataset):
    """
    Dataset para PGN con OOVs dinámicos y head truncation
    """
    
    def __init__(self, vocab, MAX_LEN_SRC: int, MAX_LEN_TGT: int, data_dir: str, split: str):
        self.vocab = vocab
        self.MAX_LEN_SRC = MAX_LEN_SRC
        self.MAX_LEN_TGT = MAX_LEN_TGT
        self.data_dir = data_dir
        self.split = split

        self.PAD_ID = self.vocab.word2id(self.vocab.pad_token)
        self.SOS_ID = self.vocab.word2id(self.vocab.start_decoding)
        self.EOS_ID = self.vocab.word2id(self.vocab.end_decoding)
        self.UNK_ID = self.vocab.word2id(self.vocab.unk_token)

        src_path = os.path.join(data_dir, f"{split}.txt.src")
        tgt_path = os.path.join(data_dir, f"{split}.txt.tgt")
        
        # Verificar si existen versiones tokenizadas
        src_tokenized_path = f"{split}.txt.src" + ".tokenized"
        tgt_tokenized_path = f"{split}.txt.tgt" + ".tokenized"
        
        self.is_tokenized = False
        
        if os.path.exists(src_tokenized_path) and os.path.exists(tgt_tokenized_path):
            print(f"✓ Usando archivos TOKENIZADOS para {split} (Carga optimizada en Kaggle)")
            src_path = src_tokenized_path
            tgt_path = tgt_tokenized_path
            self.is_tokenized = True
        else:
            print(f"⚠ Usando archivos RAW para {split} (Tokenización en tiempo real -> LENTO)")

        if not os.path.exists(src_path) or not os.path.exists(tgt_path):
            raise FileNotFoundError(f"Split '{split}' no encontrado")

        with open(src_path, encoding="utf-8") as f:
            self.src_lines = f.readlines()

        with open(tgt_path, encoding="utf-8") as f:
            self.tgt_lines = f.readlines()

        assert len(self.src_lines) == len(self.tgt_lines)
    
    def _get_extended_src_ids(
        self, src_tokens_raw: List[str]
    ) -> Tuple[List[int], int, Dict[str, int], List[str]]:
        """Obtener IDs extendidos para fuente con OOVs"""
        extended_src_ids = []
        temp_oov_map = {}
        oov_words = []
        
        vocab_size = len(self.vocab.word_to_id)
        oov_id_counter = vocab_size  # Empezar después del vocabulario base
        
        for token in src_tokens_raw:
            base_id = self.vocab.word2id(token)
            
            if base_id == self.UNK_ID:
                if token not in temp_oov_map:
                    temp_oov_map[token] = oov_id_counter
                    oov_words.append(token)
                    oov_id_counter += 1
                extended_src_ids.append(temp_oov_map[token])
            else:
                extended_src_ids.append(base_id)
        
        extended_vocab_size = oov_id_counter
        return extended_src_ids, extended_vocab_size, temp_oov_map, oov_words
    
    def _map_target_to_extended_ids(self, tgt_tokens, oov_map):
        """Mapear tokens objetivo a IDs extendidos"""
        mapped_ids = []
        for token in tgt_tokens:
            base_id = self.vocab.word2id(token)
            if base_id == self.UNK_ID and token in oov_map:
                mapped_ids.append(oov_map[token])
            else:
                mapped_ids.append(base_id)
        return mapped_ids
    
    def _pad_sequence(self, ids, max_len):
        """Rellenar secuencia con PAD_ID"""
        if len(ids) < max_len:
            ids.extend([self.PAD_ID] * (max_len - len(ids)))
        return ids[:max_len]
    
    def __len__(self):
        return len(self.src_lines)
    
    def __getitem__(self, idx):
        src_line = self.src_lines[idx].strip()
        tgt_line = self.tgt_lines[idx].strip()
                         
        # --- 1. Head truncation por oraciones ---
        if self.is_tokenized:
            # Si ya está tokenizado, recuperamos las oraciones separando por "[.]"
            # ya que preprocess_data.py las guardó así expresamente.
            raw_sentences = src_line.split(" [.] ")
        else:
             # Usamos spaCy/NLTK para dividir oraciones (más lento pero robusto)
             # Nota: Se usa NLTK según la última configuración aceptada
             raw_sentences = nltk.sent_tokenize(src_line, language='spanish')
             
        trimmed_src_tokens = []
        
        for sentence in raw_sentences:
            if self.is_tokenized:
                # Fast path: Ya está tokenizado por palabras
                sentence_tokens = sentence.split()
                # El split por " [.] " quitó el punto, lo añadimos para el modelo
                tokens_to_add = sentence_tokens 
            else:
                # Slow path: Tokenización en tiempo real
                sentence_tokens = self.vocab.process_text(sentence.strip())
                if not sentence_tokens:
                    continue
                tokens_to_add = sentence_tokens
            
            if len(trimmed_src_tokens) + len(tokens_to_add) > self.MAX_LEN_SRC:
                break
            
            trimmed_src_tokens.extend(tokens_to_add)
        
        # Eliminar posible punto final si la última oración ya lo tenía y lo añadimos doble (opcional)
        # O simplemente dejar la lógica fluir.
        
        # --- 2. Encoder con OOVs ---
        ext_src_ids, ext_vocab_size, oov_map, oov_words = \
            self._get_extended_src_ids(trimmed_src_tokens)
        
        max_oov_len = ext_vocab_size - len(self.vocab.word_to_id)
        
        # Extended encoder input (con IDs extendidos para pointer-generator)
        # Dynamic Batching: NO rellenamos aquí, solo truncamos
        # extended_encoder_input = self._pad_sequence(ext_src_ids.copy(), self.MAX_LEN_SRC)
        extended_encoder_input = ext_src_ids[:self.MAX_LEN_SRC]
        
        # Encoder input regular (convertir OOVs a UNK para embeddings)
        encoder_input = [
            i if i < len(self.vocab.word_to_id) else self.UNK_ID
            for i in extended_encoder_input
        ]
        
        # --- 3. Decoder ---
        if self.is_tokenized:
            tgt_tokens = tgt_line.strip().split()
        else:
            tgt_tokens = self.vocab.process_text(tgt_line)
            
        tgt_ext_ids = self._map_target_to_extended_ids(tgt_tokens, oov_map)
        
        MAX_RAW_TGT_LEN = self.MAX_LEN_TGT - 1
        tgt_ext_ids = tgt_ext_ids[:MAX_RAW_TGT_LEN]
        
        # Decoder input (convertir OOVs a UNK para embeddings)
        decoder_input_ids = [self.SOS_ID]
        for token_id in tgt_ext_ids:
            if token_id < len(self.vocab.word_to_id):
                decoder_input_ids.append(token_id)
            else:
                decoder_input_ids.append(self.UNK_ID)
        
        # Decoder target (mantener extended IDs para loss)
        decoder_output_ids = tgt_ext_ids + [self.EOS_ID]
        
        # Dynamic Batching: NO rellenamos aquí
        # decoder_input = self._pad_sequence(decoder_input_ids, self.MAX_LEN_TGT)
        # decoder_output = self._pad_sequence(decoder_output_ids, self.MAX_LEN_TGT)
        decoder_input = decoder_input_ids[:self.MAX_LEN_TGT]
        decoder_output = decoder_output_ids[:self.MAX_LEN_TGT]
        
        # --- 4. Información adicional ---
        encoder_length = len(trimmed_src_tokens)
        # Mask será dinámica en collate, aquí solo devolvemos el largo real
        # encoder_mask = [1] * encoder_length + [0] * (self.MAX_LEN_SRC - encoder_length)
        encoder_mask = [1] * encoder_length
        
        return {
            "encoder_input": torch.tensor(encoder_input, dtype=torch.long),
            "extended_encoder_input": torch.tensor(extended_encoder_input, dtype=torch.long),
            "encoder_length": torch.tensor(encoder_length, dtype=torch.long),
            "encoder_mask": torch.tensor(encoder_mask, dtype=torch.bool),
            "decoder_input": torch.tensor(decoder_input, dtype=torch.long),
            "decoder_target": torch.tensor(decoder_output, dtype=torch.long),
            "max_oov_len": max_oov_len,
            "oov_words": oov_words,
            "pad_id": self.PAD_ID  # Pasamos PAD_ID para el collate
        }

def pgn_collate_fn(batch):
    """Función para combinar muestras en batches"""
    # Filtrar ejemplos con encoder_length <= 0
    filter_batch = []
    for x in batch:
        if x['encoder_length'].item() > 0:
            filter_batch.append(x)
    
    # Si todos los ejemplos fueron filtrados, retornar None
    if len(filter_batch) == 0:
        return None
    
    batch = filter_batch
    max_oov = max(x["max_oov_len"] for x in batch)
    
    # 1. Obtener lengths máximos del batch actual (Dynamic Batching)
    max_enc_len = max(x["encoder_length"].item() for x in batch)
    max_dec_len = max(len(x["decoder_input"]) for x in batch)
    
    pad_id = batch[0]["pad_id"]
    
    def pad_tensor(t, length, val):
        """Pad tensor to length with val"""
        return torch.cat([t, torch.full((length - len(t),), val, dtype=t.dtype)])

    def pad_oov(words):
        """Rellenar lista de OOVs con strings vacíos"""
        return words + [""] * (max_oov - len(words))
    
    return {
        # Encoders: Pad a max_enc_len
        "encoder_input": torch.stack([pad_tensor(x["encoder_input"], max_enc_len, pad_id) for x in batch]),
        "extended_encoder_input": torch.stack([pad_tensor(x["extended_encoder_input"], max_enc_len, pad_id) for x in batch]),
        "encoder_length": torch.stack([x["encoder_length"] for x in batch]),
        "encoder_mask": torch.stack([pad_tensor(x["encoder_mask"], max_enc_len, 0) for x in batch]), # 0 es False/Pad
        
        # Decoders: Pad a max_dec_len
        "decoder_input": torch.stack([pad_tensor(x["decoder_input"], max_dec_len, pad_id) for x in batch]),
        "decoder_target": torch.stack([pad_tensor(x["decoder_target"], max_dec_len, pad_id) for x in batch]),
        
        "max_oov_len": torch.tensor([x["max_oov_len"] for x in batch], dtype=torch.long),
        "oov_words": [pad_oov(x["oov_words"]) for x in batch]
    }

In [9]:
import torch
import torch.nn.functional as F
from typing import List, Tuple
import numpy as np

class Hypothesis:
    """
    Representa una hipótesis durante beam search.
    """
    def __init__(self, tokens, log_probs, decoder_state, context_vector, coverage, p_gens=None):
        """
        Args:
            tokens: List[int] - Secuencia de tokens generados
            log_probs: List[float] - Log probabilities de cada token
            decoder_state: Tuple (h, c) - Estado del decoder
            context_vector: Tensor - Context vector actual
            context_vector: Tensor - Context vector actual
            coverage: Tensor - Coverage acumulado
            p_gens: List[Tensor] - Lista de p_gen para cada paso
        """
        self.tokens = tokens
        self.log_probs = log_probs
        self.decoder_state = decoder_state
        self.context_vector = context_vector
        self.coverage = coverage
        self.p_gens = p_gens if p_gens is not None else []
    
    def extend(self, token, log_prob, decoder_state, context_vector, coverage, p_gen):
        """
        Extiende la hipótesis con un nuevo token.
        
        Returns:
            Nueva Hypothesis
        """
        return Hypothesis(
            tokens=self.tokens + [token],
            log_probs=self.log_probs + [log_prob],
            decoder_state=decoder_state,
            context_vector=context_vector,
            coverage=coverage,
            p_gens=self.p_gens + [p_gen]
        )
    
    @property
    def avg_log_prob(self):
        """Promedio de log probabilities (para ranking)."""
        return sum(self.log_probs) / len(self.tokens)
    
    @property
    def latest_token(self):
        """Último token generado."""
        return self.tokens[-1]


class BeamSearch:
    """
    Implementa Beam Search para decodificación.
    """
    
    def __init__(self, model, vocab, beam_size=4, max_len=50, min_len=10):
        """
        Args:
            model: PointerGeneratorNetwork
            vocab: Vocabulary object
            beam_size: Tamaño del beam
            max_len: Longitud máxima de generación
            min_len: Longitud mínima (penaliza secuencias cortas)
        """
        self.model = model
        self.vocab = vocab
        self.beam_size = beam_size
        self.max_len = max_len
        self.min_len = min_len
        
        self.start_id = vocab.word2id(vocab.start_decoding)
        self.end_id = vocab.word2id(vocab.end_decoding)
        self.unk_id = vocab.word2id(vocab.unk_token)
        self.vocab_size = len(vocab.word_to_id)
        
        self.device = model.device
    
    def search(self, batch):
        """
        Realiza beam search para un batch (asume batch_size=1).
        
        Args:
            batch: Dict con encoder inputs
            
        Returns:
            best_hypothesis: Hypothesis con la mejor secuencia
        """
        # Asegurar batch_size = 1
        encoder_input = batch['encoder_input'].to(self.device)  # (1, src_len)
        extended_encoder_input = batch['extended_encoder_input'].to(self.device)
        encoder_length = batch['encoder_length'].to(self.device)
        encoder_mask = batch['encoder_mask'].to(self.device)
        
        batch_size = encoder_input.size(0)
        assert batch_size == 1, "Beam search solo soporta batch_size=1"
        
        src_len = encoder_input.size(1)
        
        # 1. Encoder
        encoder_outputs, decoder_state = self.model.encoder(encoder_input, encoder_length)
        # encoder_outputs: (1, src_len, hidden_size * 2)
        
        # 2. Inicializar beam
        initial_context = torch.zeros(1, self.model.config['hidden_size'] * 2, device=self.device)
        initial_coverage = torch.zeros(1, src_len, device=self.device) if self.model.is_coverage else None
        
        # Hipótesis inicial con START token
        initial_hyp = Hypothesis(
            tokens=[self.start_id],
            log_probs=[0.0],
            decoder_state=decoder_state,
            context_vector=initial_context,
            coverage=initial_coverage,
            p_gens=[]
        )
        
        hypotheses = [initial_hyp]  # Beam actual
        completed = []  # Hipótesis completas (con END)
        
        # 3. Beam search loop
        for step in range(self.max_len):
            if len(completed) >= self.beam_size:
                break
            
            all_candidates = []
            
            # Expandir cada hipótesis en el beam
            for hyp in hypotheses:
                # Si ya terminó, mover a completed
                if hyp.latest_token == self.end_id:
                    completed.append(hyp)
                    continue
                
                # Preparar input para el decoder
                decoder_input = torch.tensor(
                    [hyp.latest_token],
                    dtype=torch.long,
                    device=self.device
                )
                
                # Convertir OOV a UNK para embeddings
                if decoder_input.item() >= self.vocab_size:
                    decoder_input = torch.tensor(
                        [self.unk_id],
                        dtype=torch.long,
                        device=self.device
                    )
                
                # Decoder step
                final_dist, new_decoder_state, new_context, attention_dist, p_gen, new_coverage = self.model.decoder(
                    decoder_input=decoder_input,
                    decoder_state=hyp.decoder_state,
                    encoder_outputs=encoder_outputs,
                    encoder_mask=encoder_mask,
                    extended_encoder_input=extended_encoder_input,
                    context_vector=hyp.context_vector,
                    coverage=hyp.coverage
                )
                
                # Log probabilities
                log_probs = torch.log(final_dist + 1e-12)  # (1, extended_vocab_size)
                log_probs = log_probs.squeeze(0)  # (extended_vocab_size,)
                
                # Penalizar UNK si queremos
                log_probs[self.unk_id] -= 100.0
                
                # Penalizar END si estamos antes de min_len
                if step < self.min_len:
                    log_probs[self.end_id] = -1e20
                
                # Top-k candidatos
                top_k_log_probs, top_k_ids = torch.topk(log_probs, self.beam_size * 2)
                
                # Crear nuevas hipótesis
                for i in range(self.beam_size * 2):
                    token_id = top_k_ids[i].item()
                    token_log_prob = top_k_log_probs[i].item()
                    
                    new_hyp = hyp.extend(
                        token=token_id,
                        log_prob=token_log_prob,
                        decoder_state=new_decoder_state,
                        context_vector=new_context,
                        coverage=new_coverage,
                        p_gen=p_gen if p_gen is not None else torch.tensor([[1.0]], device=self.device)
                    )
                    
                    all_candidates.append(new_hyp)
            
            # Ordenar candidatos por avg_log_prob
            all_candidates.sort(key=lambda h: h.avg_log_prob, reverse=True)
            
            # Seleccionar top beam_size hipótesis
            hypotheses = all_candidates[:self.beam_size]
        
        # 4. Si no hay hipótesis completas, tomar las mejores actuales
        if len(completed) == 0:
            completed = hypotheses
        
        # Ordenar por avg_log_prob y retornar la mejor
        completed.sort(key=lambda h: h.avg_log_prob, reverse=True)
        best_hypothesis = completed[0]
        
        return best_hypothesis
    
    def decode_batch(self, data_loader, num_examples=None):
        """
        Decodifica múltiples ejemplos usando beam search.
        
        Args:
            data_loader: DataLoader
            num_examples: Número de ejemplos a decodificar (None = todos)
            
        Returns:
            List[List[int]]: Secuencias generadas
        """
        self.model.eval()
        generated_sequences = []
        
        with torch.no_grad():
            for i, batch in enumerate(data_loader):
                if num_examples is not None and i >= num_examples:
                    break
                
                # Beam search espera batch_size=1
                # Si el batch tiene más de 1, procesar uno por uno
                batch_size = batch['encoder_input'].size(0)
                
                for b in range(batch_size):
                    # Extraer ejemplo individual
                    single_batch = {
                        'encoder_input': batch['encoder_input'][b:b+1],
                        'extended_encoder_input': batch['extended_encoder_input'][b:b+1],
                        'encoder_length': batch['encoder_length'][b:b+1],
                        'encoder_mask': batch['encoder_mask'][b:b+1]
                    }
                    
                    # Beam search
                    best_hyp = self.search(single_batch)
                    generated_sequences.append(best_hyp.tokens)
                    
                    if num_examples is not None and len(generated_sequences) >= num_examples:
                        break
        
        return generated_sequences


# Model

In [10]:
import torch
import torch.optim as optim
import math

class ScheduledOptimizer:
    """
    Optimizer con learning rate scheduling para Pointer-Generator Network.
    Implementa:
    - Warm-up lineal
    - Decaimiento (opcional)
    - Gradient clipping
    """
    
    def __init__(self, optimizer, config):
        """
        Args:
            optimizer: PyTorch optimizer (Adam, SGD, etc.)
            config: Config object con hiperparámetros
        """
        self.optimizer = optimizer
        self.config = config
        
        self.initial_lr = config['learning_rate']
        self.current_lr = config['learning_rate']
        self.warmup_epochs = config['warmup_epochs'] if config['warmup_epochs'] else 0
        self.grad_clip = config['grad_clip']
        
        self.current_epoch = 0
        self.current_step = 0
    
    def step(self):
        """Realiza un paso de optimización con gradient clipping."""
        # Gradient clipping
        if self.grad_clip > 0:
            torch.nn.utils.clip_grad_norm_(
                self._get_parameters(),
                self.grad_clip
            )
        
        self.optimizer.step()
        self.current_step += 1
    
    def zero_grad(self):
        """Resetea los gradientes."""
        self.optimizer.zero_grad()
    
    def update_learning_rate(self, epoch):
        """
        Actualiza el learning rate basado en el epoch actual.
        
        Args:
            epoch: Epoch actual (0-indexed)
        """
        self.current_epoch = epoch
        
        # Warm-up: incremento lineal del learning rate
        if epoch < self.warmup_epochs:
            # LR crece linealmente de 0 a initial_lr
            self.current_lr = self.initial_lr * (epoch + 1) / self.warmup_epochs
        else:
            # Después del warm-up, mantener o decrementar
            # Opción 1: Mantener constante
            self.current_lr = self.initial_lr
            
            # Opción 2: Decaimiento exponencial (comentado por defecto)
            # decay_rate = 0.5
            # decay_epochs = 5
            # epochs_after_warmup = epoch - self.warmup_epochs
            # self.current_lr = self.initial_lr * (decay_rate ** (epochs_after_warmup / decay_epochs))
        
        # Aplicar nuevo learning rate
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = self.current_lr
    
    def get_learning_rate(self):
        """Retorna el learning rate actual."""
        return self.current_lr
    
    def _get_parameters(self):
        """Obtiene todos los parámetros del optimizer."""
        params = []
        for param_group in self.optimizer.param_groups:
            params.extend(param_group['params'])
        return params
    
    def state_dict(self):
        """Guarda el estado del optimizer."""
        return {
            'optimizer': self.optimizer.state_dict(),
            'current_epoch': self.current_epoch,
            'current_step': self.current_step,
            'current_lr': self.current_lr
        }
    
    def load_state_dict(self, state_dict):
        """Carga el estado del optimizer."""
        self.optimizer.load_state_dict(state_dict['optimizer'])
        self.current_epoch = state_dict['current_epoch']
        self.current_step = state_dict['current_step']
        self.current_lr = state_dict['current_lr']


def build_optimizer(model, config):
    """
    Construye el optimizer basado en la configuración.
    
    Args:
        model: Modelo PyTorch
        config: Config object
        
    Returns:
        ScheduledOptimizer
    """
    learner_type = (config['learner'] if config['learner'] else 'adam').lower()
    learning_rate = config['learning_rate']

    
   
    if learner_type =='adagrad':
        base_optimizer = optim.Adagrad(
            model.parameters(),
            lr=learning_rate,
            initial_accumulator_value=0.1
        )
    elif learner_type == 'adam':
        base_optimizer = optim.Adam(
            model.parameters(),
            lr=learning_rate,
            betas=(0.9, 0.999),
            eps=1e-6
        )
    elif learner_type == 'sgd':
        base_optimizer = optim.SGD(
            model.parameters(),
            lr=learning_rate,
            momentum=0.9
        )
    else:
        raise ValueError(f"Optimizer desconocido: {learner_type}")
    
    # Envolver en ScheduledOptimizer
    scheduled_optimizer = ScheduledOptimizer(base_optimizer, config)
    
    return scheduled_optimizer


In [11]:
class Attention(nn.Module):
    """
    Mecanismo de atención de Bahdanau para el modelo Pointer-Generator.
    Calcula la distribución de atención sobre los encoder outputs.
    """
    
    def __init__(self, hidden_size, is_coverage=False):
        """
        Args:
            hidden_size: Dimensión del hidden state
            is_coverage: Si se usa coverage mechanism
        """
        super(Attention, self).__init__()
        
        self.hidden_size = hidden_size
        self.is_coverage = is_coverage
        
        # Proyecciones para calcular attention scores
        # encoder_outputs: (batch, src_len, hidden_size * 2) si bidireccional
        self.W_h = nn.Linear(hidden_size * 2, hidden_size, bias=False)  # Para encoder outputs
        self.W_s = nn.Linear(hidden_size, hidden_size, bias=True)       # Para decoder state
        
        # Coverage feature
        if is_coverage:
            self.W_c = nn.Linear(1, hidden_size, bias=False)
        
        # Proyección final
        self.v = nn.Linear(hidden_size, 1, bias=False)
    
    def forward(self, decoder_state, encoder_outputs, encoder_mask, coverage=None):
        """
        Args:
            decoder_state: (batch_size, hidden_size) - Estado actual del decoder
            encoder_outputs: (batch_size, src_len, hidden_size * 2) - Outputs del encoder
            encoder_mask: (batch_size, src_len) - Máscara de padding (1 = válido, 0 = padding)
            coverage: (batch_size, src_len) - Vector de coverage acumulado (opcional)
            
        Returns:
            context_vector: (batch_size, hidden_size * 2) - Vector de contexto
            attention_dist: (batch_size, src_len) - Distribución de atención
            coverage: (batch_size, src_len) - Coverage actualizado (si is_coverage=True)
        """
        batch_size, src_len, _ = encoder_outputs.size()
        
        # 1. Proyectar encoder outputs
        encoder_features = self.W_h(encoder_outputs)  # (batch_size, src_len, hidden_size)
        
        # 2. Proyectar decoder state y expandir
        decoder_features = self.W_s(decoder_state)  # (batch_size, hidden_size)
        decoder_features = decoder_features.unsqueeze(1)  # (batch_size, 1, hidden_size)
        decoder_features = decoder_features.expand(-1, src_len, -1)  # (batch_size, src_len, hidden_size)
        
        # 3. Calcular attention scores
        attention_features = encoder_features + decoder_features  # (batch_size, src_len, hidden_size)
        
        # 4. Añadir coverage si está activado
        if self.is_coverage and coverage is not None:
            coverage_features = self.W_c(coverage.unsqueeze(2))  # (batch_size, src_len, hidden_size)
            attention_features = attention_features + coverage_features
        
        # 5. Calcular scores
        e = self.v(torch.tanh(attention_features))  # (batch_size, src_len, 1)
        e = e.squeeze(2)  # (batch_size, src_len)
        
        # 6. Aplicar máscara (hacer -inf los padding para que softmax → 0)
        e = e.masked_fill(encoder_mask == 0, -1e4)
        
        # 7. Softmax para obtener distribución de atención
        attention_dist = F.softmax(e, dim=1)  # (batch_size, src_len)
        
        # 8. Calcular context vector
        attention_dist_expanded = attention_dist.unsqueeze(1)  # (batch_size, 1, src_len)
        context_vector = torch.bmm(attention_dist_expanded, encoder_outputs)  # (batch_size, 1, hidden_size * 2)
        context_vector = context_vector.squeeze(1)  # (batch_size, hidden_size * 2)
        
        # 9. Actualizar coverage
        if self.is_coverage:
            if coverage is None:
                coverage = attention_dist
            else:
                coverage = coverage + attention_dist
        
        return context_vector, attention_dist, coverage


In [12]:
import torch
import torch.nn as nn

class Encoder(nn.Module):
    """
    Encoder bidireccional LSTM para el modelo Pointer-Generator.
    """
    
    def __init__(self, vocab_size, embedding_size, hidden_size, num_enc_layers, dropout_ratio, bidirectional, pretrained_weights=None):
        """
        Args:
            vocab_size: Tamaño del vocabulario base (sin OOVs)
            embedding_size: Dimensión de los embeddings
            hidden_size: Dimensión del hidden state del LSTM
            num_enc_layers: Número de capas del LSTM
            dropout_ratio: Probabilidad de dropout
            bidirectional: Si el LSTM es bidireccional
            pretrained_weights: Tensor con pesos pre-entrenados (opcional)
        """
        super(Encoder, self).__init__()
        
        self.vocab_size = vocab_size
        self.embedding_size = embedding_size
        self.hidden_size = hidden_size
        self.num_enc_layers = num_enc_layers
        self.bidirectional = bidirectional
        
        # Embedding layer (solo para vocabulario base)
        self.embedding = nn.Embedding(vocab_size, embedding_size, padding_idx=0)
        
        if pretrained_weights is not None:
            self.embedding.weight.data.copy_(pretrained_weights)
            self.embedding.weight.requires_grad = False 
            print("✓ Encoder: Pesos de embedding inicializados (no Entrenables).")
        
        # LSTM bidireccional
        self.lstm = nn.LSTM(
            input_size=embedding_size,
            hidden_size=hidden_size,
            num_layers=num_enc_layers,
            batch_first=True,
            bidirectional=bidirectional,
            dropout=dropout_ratio if num_enc_layers > 1 else 0
        )
        self.dropout = nn.Dropout(dropout_ratio)
        # Proyección para reducir hidden state bidireccional
        if bidirectional:
            self.reduce_h = nn.Linear(hidden_size * 2, hidden_size)
            self.reduce_c = nn.Linear(hidden_size * 2, hidden_size)
    
    def forward(self, encoder_input, encoder_length):
        """
        Args:
            encoder_input: (batch_size, src_len) - IDs del vocabulario base (OOVs → UNK)
            encoder_length: (batch_size,) - Longitudes reales de cada secuencia
            
        Returns:
            encoder_outputs: (batch_size, src_len, hidden_size * 2) si bidirectional
            hidden: Tuple (h_n, c_n) reducidos a (batch_size, hidden_size)
        """
        batch_size, src_len = encoder_input.size()
        
        # 1. Embeddings
        embedded = self.embedding(encoder_input)  # (batch_size, src_len, embedding_size)
        embedded = self.dropout(embedded)
        # 2. Pack padded sequence para eficiencia
        packed = nn.utils.rnn.pack_padded_sequence(
            embedded, 
            encoder_length.cpu(), 
            batch_first=True, 
            enforce_sorted=False
        )
        
        # 3. LSTM
        packed_outputs, (h_n, c_n) = self.lstm(packed)
        
        # 4. Unpack
        encoder_outputs, _ = nn.utils.rnn.pad_packed_sequence(
            packed_outputs, 
            batch_first=True,
            total_length=src_len
        )
        # encoder_outputs: (batch_size, src_len, hidden_size * 2)
        
        # 5. Reducir hidden states si es bidireccional
        if self.bidirectional:
            # h_n: (num_layers * 2, batch_size, hidden_size)
            # Tomar última capa: (2, batch_size, hidden_size)
            h_n = h_n[-2:]  # [forward, backward] de la última capa
            c_n = c_n[-2:]
            
            # Concatenar forward y backward
            h_n = torch.cat([h_n[0], h_n[1]], dim=1)  # (batch_size, hidden_size * 2)
            c_n = torch.cat([c_n[0], c_n[1]], dim=1)
            
            # Reducir a hidden_size
            h_n = torch.relu(self.reduce_h(h_n))  # (batch_size, hidden_size)
            c_n = torch.relu(self.reduce_c(c_n))  # (batch_size, hidden_size)
            
            # Añadir dimensión de capas
            h_n = h_n.unsqueeze(0)  # (1, batch_size, hidden_size)
            c_n = c_n.unsqueeze(0)
        
        return encoder_outputs, (h_n, c_n)

In [13]:


class Decoder(nn.Module):
    """
    Decoder LSTM con Pointer-Generator Network y Coverage Mechanism.
    """
    
    def __init__(self,
                 vocab_size,
                 embedding_size,
                 hidden_size,
                 num_dec_layers,
                 dropout_ratio,
                 is_attention=True,
                 is_pgen=True,
                 is_coverage=True):
        """
        Args:
            vocab_size: Tamaño del vocabulario base (sin OOVs)
            embedding_size: Dimensión de los embeddings
            hidden_size: Dimensión del hidden state del LSTM
            num_dec_layers: Número de capas del LSTM (debe ser 1 para PGN)
            dropout_ratio: Probabilidad de dropout
            is_attention: Si se usa atención
            is_pgen: Si se usa pointer-generator
            is_coverage: Si se usa coverage mechanism
        """
        super(Decoder, self).__init__()
        
        self.vocab_size = vocab_size
        self.embedding_size = embedding_size
        self.hidden_size = hidden_size
        self.is_attention = is_attention
        self.is_pgen = is_pgen
        self.is_coverage = is_coverage
        self.dropout = nn.Dropout(dropout_ratio)
        # Embedding layer (solo para vocabulario base)
        self.embedding = nn.Embedding(vocab_size, embedding_size, padding_idx=0)
        
        # LSTM decoder
        # Input: embedding + context vector (si hay atención)
        lstm_input_size = embedding_size + (hidden_size * 2 if is_attention else 0)
        
        self.lstm = nn.LSTM(
            input_size=lstm_input_size,
            hidden_size=hidden_size,
            num_layers=num_dec_layers,
            batch_first=True,
            dropout=dropout_ratio if num_dec_layers > 1 else 0
        )
        
        # Attention mechanism
        if is_attention:
            self.attention = Attention(hidden_size, is_coverage=is_coverage)
        
        # Proyección para generar distribución de vocabulario
        # Input: decoder state + context vector
        vocab_input_size = hidden_size + (hidden_size * 2 if is_attention else 0)
        self.vocab_proj = nn.Linear(vocab_input_size, vocab_size)
        
        # Pointer-Generator: calcular p_gen
        if is_pgen:
            # p_gen depende de: context vector, decoder state, decoder input
            pgen_input_size = (hidden_size * 2) + hidden_size + embedding_size
            self.p_gen_linear = nn.Linear(pgen_input_size, 1)
    
    def forward(self, decoder_input, decoder_state, encoder_outputs, encoder_mask, 
                extended_encoder_input, context_vector=None, coverage=None):
        """
        Un paso de decodificación.
        
        Args:
            decoder_input: (batch_size,) - Token actual (ID del vocabulario base)
            decoder_state: Tuple (h, c) donde cada uno es (1, batch_size, hidden_size)
            encoder_outputs: (batch_size, src_len, hidden_size * 2)
            encoder_mask: (batch_size, src_len)
            extended_encoder_input: (batch_size, src_len) - IDs extendidos del source
            context_vector: (batch_size, hidden_size * 2) - Context vector del paso anterior
            coverage: (batch_size, src_len) - Coverage acumulado
            
        Returns:
            final_dist: (batch_size, extended_vocab_size) - Distribución final sobre vocab extendido
            decoder_state: Tuple (h, c) actualizado
            context_vector: (batch_size, hidden_size * 2) - Nuevo context vector
            attention_dist: (batch_size, src_len) - Distribución de atención
            p_gen: (batch_size, 1) - Probabilidad de generar (si is_pgen=True)
            coverage: (batch_size, src_len) - Coverage actualizado
        """
        batch_size = decoder_input.size(0)
        
        # 1. Embeddings del input
        embedded = self.embedding(decoder_input)  # (batch_size, embedding_size)
        embedded = self.dropout(embedded)
        embedded = embedded.unsqueeze(1)  # (batch_size, 1, embedding_size)
        
        # 2. Concatenar con context vector si hay atención
        if self.is_attention and context_vector is not None:
            lstm_input = torch.cat([embedded, context_vector.unsqueeze(1)], dim=2)
        else:
            lstm_input = embedded
        
        # 3. LSTM step
        lstm_output, decoder_state = self.lstm(lstm_input, decoder_state)
        # lstm_output: (batch_size, 1, hidden_size)
        lstm_output = lstm_output.squeeze(1)  # (batch_size, hidden_size)
        
        # 4. Calcular atención
        if self.is_attention:
            context_vector, attention_dist, coverage = self.attention(
                lstm_output, encoder_outputs, encoder_mask, coverage
            )
            # context_vector: (batch_size, hidden_size * 2)
            # attention_dist: (batch_size, src_len)
        else:
            attention_dist = None
        
        # 5. Generar distribución de vocabulario
        if self.is_attention:
            vocab_input = torch.cat([lstm_output, context_vector], dim=1)
        else:
            vocab_input = lstm_output
        
        vocab_logits = self.vocab_proj(vocab_input)  # (batch_size, vocab_size)
        vocab_dist = F.softmax(vocab_logits, dim=1)  # (batch_size, vocab_size)
        
        # 6. Pointer-Generator mechanism
        p_gen = None
        if self.is_pgen and self.is_attention:
            # Calcular p_gen
            pgen_input = torch.cat([
                context_vector,           # (batch_size, hidden_size * 2)
                lstm_output,              # (batch_size, hidden_size)
                embedded.squeeze(1)       # (batch_size, embedding_size)
            ], dim=1)
            
            p_gen = torch.sigmoid(self.p_gen_linear(pgen_input))  # (batch_size, 1)
            
            # Combinar vocab_dist y attention_dist
            # final_dist = p_gen * vocab_dist + (1 - p_gen) * copy_dist
            
            # Crear distribución extendida
            src_len = extended_encoder_input.size(1)
            extended_vocab_size = self.vocab_size + src_len  # Aproximación conservadora
            
            # Inicializar distribución extendida
            final_dist = torch.zeros(batch_size, extended_vocab_size, device=vocab_dist.device)
            
            # Añadir vocab_dist ponderado por p_gen
            final_dist[:, :self.vocab_size] = p_gen * vocab_dist
            
            # Añadir attention_dist ponderado por (1 - p_gen) usando scatter_add
            # Copiar de las posiciones del source
            attention_weighted = (1 - p_gen) * attention_dist  # (batch_size, src_len)
            
            # scatter_add para acumular probabilidades en posiciones extendidas
            final_dist.scatter_add_(
                dim=1,
                index=extended_encoder_input,
                src=attention_weighted
            )
        else:
            final_dist = vocab_dist
        
        return final_dist, decoder_state, context_vector, attention_dist, p_gen, coverage

In [14]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class PointerGeneratorNetwork(nn.Module):
    """
    Pointer-Generator Network con Coverage Mechanism para text summarization.
    
    Referencia: "Get To The Point: Summarization with Pointer-Generator Networks"
    (See et al., 2017) - https://arxiv.org/abs/1704.04368
    """
    
    def __init__(self, config, vocab, pretrained_weights=None):
        """
        Args:
            config: Objeto Config con hiperparámetros
            vocab: Objeto Vocabulary
            pretrained_weights: Tensor con pesos pre-entrenados (opcional)
        """
        super(PointerGeneratorNetwork, self).__init__()
        
        self.config = config
        self.vocab = vocab
        self.vocab_size = config['max_vocab_size']
        
        # Encoder
        self.encoder = Encoder(
            vocab_size=self.vocab_size,
            embedding_size=config['embedding_size'],
            hidden_size=config['hidden_size'],
            num_enc_layers=config['num_enc_layers'],  
            dropout_ratio=config['dropout_ratio'],
            bidirectional=config['bidirectional'],
            pretrained_weights=pretrained_weights
        )
        
        # Decoder
        self.decoder = Decoder(
            vocab_size=self.vocab_size,
            embedding_size=config['embedding_size'],
            hidden_size=config['hidden_size'],
            num_dec_layers=config['num_dec_layers'],  
            dropout_ratio=config['dropout_ratio'],
            is_attention=True,
            is_pgen=config['is_pgen'],
            is_coverage=config['is_coverage']
        )
        
        self.is_coverage = config['is_coverage']
        self.coverage_lambda = config['coverage_lambda'] if config['coverage_lambda'] is not None else 1.0
        self.device = config['device']
        
        # Compartir embeddings entre encoder y decoder
        self.decoder.embedding.weight = self.encoder.embedding.weight
    
    def forward(self, batch, is_training=True):
        """
        Forward pass completo para training.
        
        Args:
            batch: Dict con:
                - encoder_input: (batch_size, src_len) - IDs base
                - extended_encoder_input: (batch_size, src_len) - IDs extendidos
                - encoder_length: (batch_size,)
                - encoder_mask: (batch_size, src_len)
                - decoder_input: (batch_size, tgt_len) - IDs base
                - decoder_target: (batch_size, tgt_len) - IDs extendidos
            is_training: Si es modo training (teacher forcing)
            
        Returns:
            Dict con:
                - loss: Scalar tensor
                - vocab_loss: Scalar tensor
                - coverage_loss: Scalar tensor (si is_coverage=True)
                - final_dists: (batch_size, tgt_len, extended_vocab_size)
        """
        # Unpack batch
        encoder_input = batch['encoder_input'].to(self.device)
        extended_encoder_input = batch['extended_encoder_input'].to(self.device)
        encoder_length = batch['encoder_length'].to(self.device)
        encoder_mask = batch['encoder_mask'].to(self.device)
        decoder_input = batch['decoder_input'].to(self.device)
        decoder_target = batch['decoder_target'].to(self.device)
        
        batch_size, tgt_len = decoder_input.size()
        src_len = encoder_input.size(1)
        
        # 1. Encoder
        encoder_outputs, decoder_state = self.encoder(encoder_input, encoder_length)
        # encoder_outputs: (batch_size, src_len, hidden_size * 2)
        # decoder_state: Tuple (h, c) - (1, batch_size, hidden_size)
        
        # 2. Inicializar
        context_vector = torch.zeros(batch_size, self.config['hidden_size'] * 2, device=self.device)
        coverage = torch.zeros(batch_size, src_len, device=self.device) if self.is_coverage else None
        
        # 3. Decoder loop (teacher forcing)
        final_dists = []
        attention_dists = []
        coverages = []
        
        for t in range(tgt_len):
            decoder_input_t = decoder_input[:, t]  # (batch_size,)
            
            final_dist, decoder_state, context_vector, attention_dist, p_gen, coverage = self.decoder(
                decoder_input=decoder_input_t,
                decoder_state=decoder_state,
                encoder_outputs=encoder_outputs,
                encoder_mask=encoder_mask,
                extended_encoder_input=extended_encoder_input,
                context_vector=context_vector,
                coverage=coverage
            )
            
            final_dists.append(final_dist)
            attention_dists.append(attention_dist)
            if self.is_coverage:
                coverages.append(coverage)
        
        # Stack outputs
        final_dists = torch.stack(final_dists, dim=1)  # (batch_size, tgt_len, extended_vocab_size)
        attention_dists = torch.stack(attention_dists, dim=1)  # (batch_size, tgt_len, src_len)
        
        # 4. Calcular loss
        vocab_loss = self._calculate_vocab_loss(final_dists, decoder_target)
        
        coverage_loss = torch.tensor(0.0, device=self.device)
        if self.is_coverage and len(coverages) > 0:
            coverages = torch.stack(coverages, dim=1)  # (batch_size, tgt_len, src_len)
            coverage_loss = self._calculate_coverage_loss(attention_dists, coverages, encoder_mask)
        
        # Loss total (aplicar lambda al coverage loss)
        total_loss = vocab_loss + self.coverage_lambda * coverage_loss
        
        return {
            'loss': total_loss,
            'vocab_loss': vocab_loss,
            'coverage_loss': coverage_loss,
            'final_dists': final_dists
        }
    
    def _calculate_vocab_loss(self, final_dists, targets):
        """
        Calcula negative log likelihood loss.
        
        Args:
            final_dists: (batch_size, tgt_len, extended_vocab_size)
            targets: (batch_size, tgt_len) - IDs extendidos
            
        Returns:
            loss: Scalar tensor
        """
        batch_size, tgt_len, _ = final_dists.size()
        
        # Evitar log(0)
        final_dists = final_dists + 1e-12
        
        # Gather las probabilidades de los targets
        targets_expanded = targets.unsqueeze(2)  # (batch_size, tgt_len, 1)
        
        # Clamp targets para evitar índices fuera de rango
        max_idx = final_dists.size(2) - 1
        targets_clamped = torch.clamp(targets_expanded, 0, max_idx)
        
        probs = torch.gather(final_dists, dim=2, index=targets_clamped)  # (batch_size, tgt_len, 1)
        probs = probs.squeeze(2)  # (batch_size, tgt_len)
        
        # Negative log likelihood
        losses = -torch.log(probs)
        
        # Máscara de padding (PAD_ID = 0)
        mask = (targets != 0).float()
        
        # Loss promedio sobre tokens no-padding
        loss = (losses * mask).sum() / mask.sum()
        
        return loss
    
    def _calculate_coverage_loss(self, attention_dists, coverages, encoder_mask):
        """
        Calcula coverage loss para penalizar atención repetida.
        
        Args:
            attention_dists: (batch_size, tgt_len, src_len)
            coverages: (batch_size, tgt_len, src_len) - Coverage en cada paso
            encoder_mask: (batch_size, src_len)
            
        Returns:
            coverage_loss: Scalar tensor
        """
        # Coverage loss = sum_t min(a_t, c_t)
        # Penaliza cuando atendemos posiciones ya atendidas
        
        # Shift coverage: usamos coverage del paso anterior
        coverage_prev = torch.cat([
            torch.zeros_like(coverages[:, :1, :]),  # t=0 no tiene coverage previo
            coverages[:, :-1, :]  # t>0 usa coverage de t-1
        ], dim=1)
        
        # min(attention, coverage_prev)
        min_vals = torch.min(attention_dists, coverage_prev)
        
        # Sumar sobre src_len y tgt_len, aplicar máscara
        encoder_mask_expanded = encoder_mask.unsqueeze(1)  # (batch_size, 1, src_len)
        
        coverage_loss = (min_vals * encoder_mask_expanded.float()).sum()
        
        # Normalizar por número de tokens
        num_tokens = encoder_mask.sum()
        coverage_loss = coverage_loss / num_tokens
        
        return coverage_loss
    
    def decode_greedy(self, batch, max_len=None):
        """
        Decodificación greedy (sin beam search).
        
        Args:
            batch: Dict con encoder inputs
            max_len: Longitud máxima de generación
            
        Returns:
            generated_ids: (batch_size, max_len) - Secuencia generada
        """
        if max_len is None:
            max_len = self.config['tgt_len']
        
        # Unpack
        encoder_input = batch['encoder_input'].to(self.device)
        extended_encoder_input = batch['extended_encoder_input'].to(self.device)
        encoder_length = batch['encoder_length'].to(self.device)
        encoder_mask = batch['encoder_mask'].to(self.device)
        
        batch_size = encoder_input.size(0)
        src_len = encoder_input.size(1)
        
        # Encoder
        encoder_outputs, decoder_state = self.encoder(encoder_input, encoder_length)
        
        # Inicializar
        context_vector = torch.zeros(batch_size, self.config['hidden_size'] * 2, device=self.device)
        coverage = torch.zeros(batch_size, src_len, device=self.device) if self.is_coverage else None
        
        # Start token
        decoder_input = torch.full(
            (batch_size,), 
            self.vocab.word2id(self.vocab.start_decoding),
            dtype=torch.long,
            device=self.device
        )
        
        generated_ids = []
        p_gens = []
                
        for t in range(max_len):
            final_dist, decoder_state, context_vector, attention_dist, p_gen, coverage = self.decoder(
                decoder_input=decoder_input,
                decoder_state=decoder_state,
                encoder_outputs=encoder_outputs,
                encoder_mask=encoder_mask,
                extended_encoder_input=extended_encoder_input,
                context_vector=context_vector,
                coverage=coverage
            )
            
            # Greedy: seleccionar el token con mayor probabilidad
            predicted_ids = torch.argmax(final_dist, dim=1)  # (batch_size,)
            generated_ids.append(predicted_ids)
            
            # Guardamos p_gen (si es None por no-pgen, guardamos 1.0 = generación pura)
            if p_gen is not None:
                p_gens.append(p_gen)
            else:
                p_gens.append(torch.ones(batch_size, 1, device=self.device))
            
            # Próximo input: convertir OOVs a UNK
            decoder_input = torch.where(
                predicted_ids < self.vocab_size,
                predicted_ids,
                torch.full_like(predicted_ids, self.vocab.word2id(self.vocab.unk_token))
            )
        
        generated_ids = torch.stack(generated_ids, dim=1)  # (batch_size, max_len)
        p_gens = torch.stack(p_gens, dim=1) # (batch_size, max_len, 1)
        return generated_ids, p_gens


In [15]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import os
import time
from tqdm import tqdm
import json

class Trainer:
    """
    Clase para entrenar el modelo Pointer-Generator Network.
    """
    
    def __init__(self, config, vocab, pretrained_weights=None):
        """
        Args:
            config: Config object
            vocab: Vocabulary object
            pretrained_weights: Tensor con pesos pre-entrenados (opcional)
        """
        self.config = config
        self.vocab = vocab
        self.device = config['device']
        
        # Modelo
        self.model = PointerGeneratorNetwork(config, vocab, pretrained_weights).to(self.device)
        
        # Optimizer
        self.optimizer = build_optimizer(self.model, config)
        
        # Beam search para validación
        self.beam_search = BeamSearch(
            self.model,
            vocab,
            beam_size=config['beam_size'],
            max_len=config['tgt_len']
        )
        
        # Tracking
        self.start_epoch = 0
        self.global_step = 0
        self.best_val_loss = float('inf')
        
        # History
        self.train_history = {
            'epoch': [],
            'train_loss': [],
            'vocab_loss': [],
            'coverage_loss': [],
            'val_loss': []
        }
        
        # Paths
        self.checkpoint_dir = config['checkpoint_dir']
        os.makedirs(self.checkpoint_dir, exist_ok=True)
        
        # AMP Scaler
        self.scaler = torch.amp.GradScaler('cuda', enabled=config['use_gpu'])
    
    def train(self, train_loader, val_loader, num_epochs):
        """
        Entrena el modelo.
        
        Args:
            train_loader: DataLoader de entrenamiento
            val_loader: DataLoader de validación
            num_epochs: Número de épocas
        """
        print(f"\n{'='*60}")
        print(f"Iniciando entrenamiento")
        print(f"{'='*60}")
        print(f"Device: {self.device}")
        print(f"Épocas: {num_epochs}")
        print(f"Batch size: {self.config['train_batch_size']}")
        print(f"Learning rate: {self.config['learning_rate']}")
        print(f"Pointer-Generator: {self.config['is_pgen']}")
        print(f"Coverage: {self.config['is_coverage']}")
        print(f"{'='*60}\n")
        
        for epoch in range(self.start_epoch, num_epochs):
            epoch_start_time = time.time()
            
            # Actualizar learning rate
            self.optimizer.update_learning_rate(epoch)
            current_lr = self.optimizer.get_learning_rate()
            
            print(f"\n--- Epoch {epoch+1}/{num_epochs} (LR: {current_lr:.6f}) ---")
            
            # Training
            train_metrics = self._train_epoch(train_loader, epoch)
            
            # Validation
            val_metrics = self._validate_epoch(val_loader, epoch)
            
            # Guardar history
            self.train_history['epoch'].append(epoch + 1)
            self.train_history['train_loss'].append(train_metrics['loss'])
            self.train_history['vocab_loss'].append(train_metrics['vocab_loss'])
            self.train_history['coverage_loss'].append(train_metrics['coverage_loss'])
            self.train_history['val_loss'].append(val_metrics['loss'])
            
            # Guardar checkpoint
            is_best = val_metrics['loss'] < self.best_val_loss
            if is_best:
                self.best_val_loss = val_metrics['loss']
            
            save_model = self.config['save_model_epoch'] if self.config['save_model_epoch'] is not None else True
            if save_model:
                self._save_checkpoint(epoch, is_best)
            
            # Tiempo
            epoch_time = time.time() - epoch_start_time
            print(f"Tiempo de época: {epoch_time/60:.2f} min")
            
        # Guardar history final
        save_hist = self.config['save_history'] if self.config['save_history'] is not None else True
        if save_hist:
            self._save_history()
        
        print(f"\n{'='*60}")
        print(f"Entrenamiento completado!")
        print(f"Mejor Val Loss: {self.best_val_loss:.4f}")
        print(f"{'='*60}\n")
    
    def _train_epoch(self, train_loader, epoch):
        """Entrena una época."""
        self.model.train()
        
        total_loss = 0.0
        total_vocab_loss = 0.0
        total_coverage_loss = 0.0
        num_batches = 0
        
        # Progress bar
        pbar = tqdm(train_loader, desc=f"Training", total=len(train_loader))
        
        for batch_idx, batch in enumerate(pbar):
            if batch is None:
                continue
            self.optimizer.zero_grad()
            
            # Forward pass (con AMP)
            with torch.amp.autocast('cuda', enabled=self.config['use_gpu']):
                outputs = self.model(batch, is_training=True)
                loss = outputs['loss']
            
            # Backward pass (con Scaler)
            self.scaler.scale(loss).backward()
            
            # Optimizer step (con gradient clipping y Scaler)
            # Desescalar gradientes par clipping
            self.scaler.unscale_(self.optimizer.optimizer)
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config['grad_clip'])
            
            self.scaler.step(self.optimizer.optimizer)
            self.scaler.update()
            
            # Actualizar contador de pasos del wrapper (ya que no llamamos a self.optimizer.step())
            self.optimizer.current_step += 1
            
            # Tracking
            total_loss += loss.item()
            total_vocab_loss += outputs['vocab_loss'].item()
            total_coverage_loss += outputs['coverage_loss'].item()
            num_batches += 1
            self.global_step += 1
            
            # Update progress bar
            pbar.set_postfix({
                'loss': f"{loss.item():.4f}",
                'vocab': f"{outputs['vocab_loss'].item():.4f}",
                'cov': f"{outputs['coverage_loss'].item():.4f}"
            })
            
            # Limitar iteraciones por época si está configurado
            iters_per_epoch = self.config['iters_per_epoch']
            if iters_per_epoch and batch_idx >= iters_per_epoch:
                break
        
        pbar.close()
        
        # Promedios
        avg_loss = total_loss / num_batches
        avg_vocab_loss = total_vocab_loss / num_batches
        avg_coverage_loss = total_coverage_loss / num_batches
        
        print(f"Train Loss: {avg_loss:.4f} | "
              f"Vocab: {avg_vocab_loss:.4f} | "
              f"Coverage: {avg_coverage_loss:.4f}")
        
        return {
            'loss': avg_loss,
            'vocab_loss': avg_vocab_loss,
            'coverage_loss': avg_coverage_loss
        }
    
    def _validate_epoch(self, val_loader, epoch):
        """Valida el modelo."""
        self.model.eval()
        
        total_loss = 0.0
        total_vocab_loss = 0.0
        total_coverage_loss = 0.0
        num_batches = 0
        
        with torch.no_grad():
            pbar = tqdm(val_loader, desc=f"Validation", total=len(val_loader))
            
            for batch in pbar:
                if batch is None:
                    continue
                outputs = self.model(batch, is_training=True)
                
                total_loss += outputs['loss'].item()
                total_vocab_loss += outputs['vocab_loss'].item()
                total_coverage_loss += outputs['coverage_loss'].item()
                num_batches += 1
                
                pbar.set_postfix({
                    'val_loss': f"{outputs['loss'].item():.4f}"
                })
            
            pbar.close()
        
        # Promedios
        avg_loss = total_loss / num_batches
        avg_vocab_loss = total_vocab_loss / num_batches
        avg_coverage_loss = total_coverage_loss / num_batches
        
        print(f"Val Loss: {avg_loss:.4f} | "
              f"Vocab: {avg_vocab_loss:.4f} | "
              f"Coverage: {avg_coverage_loss:.4f}")
        
        return {
            'loss': avg_loss,
            'vocab_loss': avg_vocab_loss,
            'coverage_loss': avg_coverage_loss
        }
    
    def _save_checkpoint(self, epoch, is_best=False):
        """Guarda un checkpoint del modelo."""
        checkpoint = {
            'epoch': epoch + 1,
            'global_step': self.global_step,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'best_val_loss': self.best_val_loss,
            'config': self.config.get_config_dict(),
            'train_history': self.train_history
        }
        
        
        # Guardar último checkpoint (sobreescribe)
        checkpoint_path = os.path.join(self.checkpoint_dir, 'checkpoint_last2.pt')
        torch.save(checkpoint, checkpoint_path)
        
        # Guardar mejor checkpoint
        if is_best:
            best_path = os.path.join(self.checkpoint_dir, 'checkpoint_best2.pt')
            torch.save(checkpoint, best_path)
            print(f"✓ Guardado mejor modelo (Val Loss: {self.best_val_loss:.4f})")
    
    def _save_history(self):
        """Guarda el historial de entrenamiento."""
        history_path = os.path.join(self.checkpoint_dir, 'train_history.json')
        with open(history_path, 'w') as f:
            json.dump(self.train_history, f, indent=2)
        print(f"✓ Historial guardado en {history_path}")
    
    def load_checkpoint(self, checkpoint_path):
        """
        Carga un checkpoint.
        
        Args:
            checkpoint_path: Ruta al checkpoint
        """
        if not os.path.exists(checkpoint_path):
            print(f"⚠ Checkpoint no encontrado: {checkpoint_path}")
            return
        
        print(f"Cargando checkpoint desde {checkpoint_path}")
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.start_epoch = checkpoint['epoch']
        self.global_step = checkpoint['global_step']
        self.best_val_loss = checkpoint['best_val_loss']
        self.train_history = checkpoint.get('train_history', self.train_history)
        
        print(f"✓ Checkpoint cargado (Epoch {self.start_epoch}, Step {self.global_step})")
    
    def find_latest_checkpoint(self):
        """Busca el checkpoint más reciente."""
        if not os.path.exists(self.checkpoint_dir):
            return None
        
        # Buscar checkpoint_last.pt primero
        last_checkpoint = os.path.join(self.checkpoint_dir, 'checkpoint_last2.pt')
        if os.path.exists(last_checkpoint):
            return last_checkpoint
        
        # Si no existe, buscar el checkpoint de época más reciente
        epoch_checkpoints = []
        for filename in os.listdir(self.checkpoint_dir):
            if filename.startswith('checkpoint_epoch_x2') and filename.endswith('.pt'):
                try:
                    epoch_num = int(filename.replace('checkpoint_epoch_x2', '').replace('.pt', ''))
                    epoch_checkpoints.append((epoch_num, os.path.join(self.checkpoint_dir, filename)))
                except ValueError:
                    continue
        
        if epoch_checkpoints:
            # Retornar el de mayor número de época
            epoch_checkpoints.sort(reverse=True)
            return epoch_checkpoints[0][1]
        
        return None


def main():
    """
    Función principal para entrenar el modelo.
    """
    # 1. Configurar reproducibilidad
    if REPRODUCIBILITY:
        torch.manual_seed(SEED)
        torch.cuda.manual_seed_all(SEED)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    
    # 2. Crear vocabulario
    print("Construyendo vocabulario...")
    vocab = Vocabulary(
        CREATE_VOCABULARY=CREATE_VOCABULARY,
        PAD_TOKEN=PAD_TOKEN,
        UNK_TOKEN=UNK_TOKEN,
        START_DECODING=START_DECODING,
        END_DECODING=END_DECODING,
        MAX_VOCAB_SIZE=MAX_VOCAB_SIZE,
        CHECKPOINT_VOCABULARY_DIR=CHECKPOINT_VOCABULARY_DIR,
        DATA_DIR=DATA_DIR,
        VOCAB_NAME=VOCAB_NAME
    )
    vocab.build_vocabulary()
    print(f"✓ Vocabulario construido: {vocab.total_size()} palabras")
    
    # 3. Configurar modelo
    config = Config(
        max_vocab_size=vocab.total_size(),
        src_len=MAX_LEN_SRC,
        tgt_len=MAX_LEN_TGT,
        embedding_size=EMBEDDING_SIZE,
        hidden_size=HIDDEN_SIZE,
        num_enc_layers=NUM_ENC_LAYERS,
        num_dec_layers=NUM_DEC_LAYERS,
        use_gpu=USE_GPU,
        is_pgen=IS_PGEN,
        is_coverage=IS_COVERAGE,
        coverage_lambda=COV_LOSS_LAMBDA,
        grad_clip=GRAD_CLIP,
        epochs=EPOCHS,
        data_path=DATA_DIR,
        generated_text_dir=GENERATED_TEXT_DIR,
        checkpoint_dir=CHECKPOINT_DIR,
        reproducibility=REPRODUCIBILITY,
        plot=PLOT,
        dropout_ratio=DROPOUT_RATIO,
        bidirectional=BIDIRECTIONAL,
        save_history=SAVE_HISTORY,
        save_model_epoch=SAVE_MODEL_EPOCH,
        seed=SEED,
        device=DEVICE,
        decoding_strategy=DECODING_STRATEGY,
        beam_size=BEAM_SIZE,
        train_batch_size=TRAIN_BATCH_SIZE,
        eval_batch_size=EVAL_BATCH_SIZE,
        learner=LEARNER,
        learning_rate=LEARNING_RATE,
        iters_per_epoch=ITERS_PER_EPOCH,
        gpu_id=GPU_ID,
        warmup_epochs=WARMUP_EPOCHS,
        embedding_path=EMBEDDING_PATH
    )
    
    print(config)

    # 4. Cargar embeddings pre-entrenados (OPCIONAL)
    pretrained_weights = None
    if config['embedding_path'] is not None:
        pretrained_weights = vocab.load_pretrained_embeddings(
            config['embedding_path'], 
            config['embedding_size']
        )
    
    # 5. Crear datasets
    print("\nCargando datasets...")
    train_dataset = PGNDataset(
        vocab=vocab,
        MAX_LEN_SRC=config['src_len'],
        MAX_LEN_TGT=config['tgt_len'],
        data_dir=config['data_path'],
        split='train'
    )
    
    val_dataset = PGNDataset(
        vocab=vocab,
        MAX_LEN_SRC=config['src_len'],
        MAX_LEN_TGT=config['tgt_len'],
        data_dir=config['data_path'],
        split='val'
    )
    
    print(f"✓ Train: {len(train_dataset)} ejemplos")
    print(f"✓ Val: {len(val_dataset)} ejemplos")
    
    # 5. Crear DataLoaders
    train_loader = DataLoader(
        train_dataset,
        batch_size=config['train_batch_size'],
        shuffle=True,
        collate_fn=pgn_collate_fn,
        num_workers=2,  # Optimizado para Kaggle (2-4 suele ser ideal)
        pin_memory=True if USE_GPU else False
    )
    
    val_loader = DataLoader(
        val_dataset,
        batch_size=config['eval_batch_size'],
        shuffle=False,
        collate_fn=pgn_collate_fn,
        num_workers=2,
        pin_memory=True if USE_GPU else False
    )
    
    # 6. Crear trainer
    trainer = Trainer(config, vocab, pretrained_weights)
    
    # 7. Buscar y cargar checkpoint automáticamente
    latest_checkpoint = trainer.find_latest_checkpoint()
    if latest_checkpoint:
        print(f"\n🔄 Checkpoint encontrado: {latest_checkpoint}")
        trainer.load_checkpoint(latest_checkpoint)
        print(f"Continuando desde época {trainer.start_epoch}\n")
    else:
        print("\n🆕 No se encontró checkpoint previo. Iniciando entrenamiento desde cero.\n")
    
    # 8. Entrenar
    trainer.train(
        train_loader=train_loader,
        val_loader=val_loader,
        num_epochs=config['epochs']
    )


if __name__ == "__main__":
    main()


Construyendo vocabulario...
Descargando modelo de spaCy español (Large)...
Collecting es-core-news-sm==3.8.0
  Downloading https://github.com/explosion/spacy-models/releases/download/es_core_news_sm-3.8.0/es_core_news_sm-3.8.0-py3-none-any.whl (12.9 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 12.9/12.9 MB 31.8 MB/s eta 0:00:00
Installing collected packages: es-core-news-sm
Successfully installed es-core-news-sm-3.8.0
[38;5;2m✔ Download and installation successful[0m
You can now load the package via spacy.load('es_core_news_sm')
[38;5;3m⚠ Restart to reload dependencies[0m
If you are in a Jupyter or Colab notebook, you may need to restart Python in
order to load all the package's dependencies. You can do this by selecting the
'Restart kernel' or 'Restart runtime' option.
saved/working/Vocabulary.json
 Vocabulario cargado desde: saved/working/Vocabulary.json
 Tamaño total: 50000
 Tokens especiales: 4
 Tokens regulares: 49996
✓ Vocabulario construido: 50000 palabras

Hyper Parame

Alineando embeddings: 2000000it [00:31, 64001.48it/s]


✓ Cobertura Exacta: 49958 / 50000 (99.9%)
  - Las 42 palabras restantes serán aprendidas desde cero.

Cargando datasets...
✓ Usando archivos TOKENIZADOS para train (Carga optimizada en Kaggle)
✓ Usando archivos TOKENIZADOS para val (Carga optimizada en Kaggle)
✓ Train: 266367 ejemplos
✓ Val: 10358 ejemplos
✓ Encoder: Pesos de embedding inicializados (no Entrenables).

🔄 Checkpoint encontrado: kaggle/working/saved/checkpoint_last2.pt
Cargando checkpoint desde kaggle/working/saved/checkpoint_last2.pt
✓ Checkpoint cargado (Epoch 6, Step 24972)
Continuando desde época 6


Iniciando entrenamiento
Device: cuda:0
Épocas: 15
Batch size: 64
Learning rate: 0.001
Pointer-Generator: True
Coverage: True


--- Epoch 7/15 (LR: 0.001000) ---


Training: 100%|██████████| 4162/4162 [39:14<00:00,  1.77it/s, loss=3.0947, vocab=3.0790, cov=0.0158]


Train Loss: 2.9726 | Vocab: 2.9540 | Coverage: 0.0186


Validation: 100%|██████████| 162/162 [00:32<00:00,  4.98it/s, val_loss=3.4514]


Val Loss: 3.2078 | Vocab: 3.1891 | Coverage: 0.0187
Tiempo de época: 39.81 min

--- Epoch 8/15 (LR: 0.001000) ---


Training: 100%|██████████| 4162/4162 [39:14<00:00,  1.77it/s, loss=2.6625, vocab=2.6477, cov=0.0149]


Train Loss: 2.9168 | Vocab: 2.8985 | Coverage: 0.0183


Validation: 100%|██████████| 162/162 [00:32<00:00,  5.00it/s, val_loss=3.4433]


Val Loss: 3.2070 | Vocab: 3.1886 | Coverage: 0.0183
Tiempo de época: 39.81 min

--- Epoch 9/15 (LR: 0.001000) ---


Training: 100%|██████████| 4162/4162 [39:10<00:00,  1.77it/s, loss=2.7482, vocab=2.7289, cov=0.0193]


Train Loss: 2.8641 | Vocab: 2.8462 | Coverage: 0.0179


Validation: 100%|██████████| 162/162 [00:32<00:00,  4.98it/s, val_loss=3.4268]


Val Loss: 3.2248 | Vocab: 3.2046 | Coverage: 0.0201
Tiempo de época: 39.74 min

--- Epoch 10/15 (LR: 0.001000) ---


Training: 100%|██████████| 4162/4162 [39:15<00:00,  1.77it/s, loss=2.6322, vocab=2.6147, cov=0.0175]


Train Loss: 2.8234 | Vocab: 2.8055 | Coverage: 0.0179


Validation: 100%|██████████| 162/162 [00:32<00:00,  5.01it/s, val_loss=3.4530]


Val Loss: 3.2376 | Vocab: 3.2173 | Coverage: 0.0203
Tiempo de época: 39.83 min

--- Epoch 11/15 (LR: 0.001000) ---


Training: 100%|██████████| 4162/4162 [39:14<00:00,  1.77it/s, loss=2.8434, vocab=2.8203, cov=0.0231]


Train Loss: 2.8223 | Vocab: 2.8010 | Coverage: 0.0212


Validation: 100%|██████████| 162/162 [00:32<00:00,  5.03it/s, val_loss=3.4762]


Val Loss: 3.2557 | Vocab: 3.2359 | Coverage: 0.0197
Tiempo de época: 39.80 min

--- Epoch 12/15 (LR: 0.001000) ---


Training: 100%|██████████| 4162/4162 [39:11<00:00,  1.77it/s, loss=2.8557, vocab=2.8360, cov=0.0197]


Train Loss: 2.8017 | Vocab: 2.7812 | Coverage: 0.0205


Validation: 100%|██████████| 162/162 [00:32<00:00,  5.02it/s, val_loss=3.4779]


Val Loss: 3.2506 | Vocab: 3.2301 | Coverage: 0.0204
Tiempo de época: 39.75 min

--- Epoch 13/15 (LR: 0.001000) ---


Training: 100%|██████████| 4162/4162 [39:12<00:00,  1.77it/s, loss=2.8316, vocab=2.8164, cov=0.0152]


Train Loss: 2.7505 | Vocab: 2.7305 | Coverage: 0.0200


Validation: 100%|██████████| 162/162 [00:32<00:00,  4.98it/s, val_loss=3.4756]


Val Loss: 3.2620 | Vocab: 3.2413 | Coverage: 0.0207
Tiempo de época: 39.77 min

--- Epoch 14/15 (LR: 0.001000) ---


Training: 100%|██████████| 4162/4162 [39:13<00:00,  1.77it/s, loss=2.6551, vocab=2.6352, cov=0.0199]


Train Loss: 2.7245 | Vocab: 2.7052 | Coverage: 0.0193


Validation: 100%|██████████| 162/162 [00:32<00:00,  4.97it/s, val_loss=3.5223]


Val Loss: 3.2808 | Vocab: 3.2597 | Coverage: 0.0211
Tiempo de época: 39.78 min

--- Epoch 15/15 (LR: 0.001000) ---


Training: 100%|██████████| 4162/4162 [39:17<00:00,  1.77it/s, loss=2.9747, vocab=2.9342, cov=0.0404]


Train Loss: 2.7592 | Vocab: 2.7322 | Coverage: 0.0270


Validation: 100%|██████████| 162/162 [00:32<00:00,  5.00it/s, val_loss=3.5651]


Val Loss: 3.3569 | Vocab: 3.3192 | Coverage: 0.0377
Tiempo de época: 39.86 min
✓ Historial guardado en kaggle/working/saved/train_history.json

Entrenamiento completado!
Mejor Val Loss: 3.2034



# generar

In [16]:
import os
import json
import sys

def analyze_missing_words():
    vocab_path = os.path.join(CHECKPOINT_VOCABULARY_DIR, VOCAB_NAME)
    
    if not os.path.exists(vocab_path):
        print(f"✗ No se encontró el vocabulario en {vocab_path}. Asegúrate de haberlo construido antes.")
        return

    print(f"Cargando vocabulario desde {vocab_path}...")
    with open(vocab_path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    
    word_to_id = data['word_to_id']
    id_to_word = data['id_to_word']
    word_count = data['word_count']
    
    if not os.path.exists(EMBEDDING_PATH):
        print(f"✗ No se encontró el archivo de embeddings en {EMBEDDING_PATH}")
        return

    print(f"Leyendo archivo de embeddings para identificar palabras presentes...")
    words_in_embeddings = set()
    a=[]
    s=1
    try:
        with open(EMBEDDING_PATH, 'r', encoding='utf-8', errors='ignore') as f:
            # Saltar header si existe
            header = f.readline().split()
            if len(header) != 2:
                f.seek(0)
            
            for line in f:
                parts = line.rstrip().split(' ')
                words_in_embeddings.add(parts[0])
                if s==1:
                    s+=1
                    a.append((parts[0],parts[1:301]))
    except Exception as e:
        print(f"✗ Error leyendo embeddings: {e}")
        return

    missing_words = []
    found_count = 0
    
    for word in id_to_word:
        # Ignorar tokens especiales
        if word.startswith('[') and word.endswith(']'):
            continue
            
        if word in words_in_embeddings:
            found_count += 1
        else:
            count = word_count.get(word, 0)
            missing_words.append((word, count))

    # Ordenar palabras faltantes por frecuencia
    missing_words.sort(key=lambda x: x[1], reverse=True)

    total_vocab = len(id_to_word) - 4 # Descontar tokens especiales
    print(f"\n{'='*60}")
    print(f"ANÁLISIS DE COBERTURA")
    print(f"{'='*60}")
    print(f"Total palabras (sin especiales): {total_vocab}")
    print(f"Palabras encontradas:            {found_count} ({found_count/total_vocab*100:.2f}%)")
    print(f"Palabras FALTANTES:              {len(missing_words)} ({len(missing_words)/total_vocab*100:.2f}%)")
    print(f"{'='*60}")
    
    print("\nTOP 200 PALABRAS MÁS FRECUENTES FALTANTES EN EMBEDDINGS:")
    print(f"{'Palabra':<30} | {'Frecuencia':<10}")
    print("-" * 45)
    for word, count in missing_words[:200]:
        print(f"{word:<30} | {count:<10}")
    

    # Análisis de causas comunes
    casing_issues = sum(1 for w, c in missing_words if w.lower() in words_in_embeddings and w != w.lower())
    print(f"\nPosibles mejoras:")
    print(f"- {casing_issues} palabras podrían encontrarse si pasamos todo a minúsculas.")
    print(a[0])
if __name__ == "__main__":
    analyze_missing_words()


Cargando vocabulario desde saved/working/Vocabulary.json...
Leyendo archivo de embeddings para identificar palabras presentes...

ANÁLISIS DE COBERTURA
Total palabras (sin especiales): 49996
Palabras encontradas:            49958 (99.92%)
Palabras FALTANTES:              38 (0.08%)

TOP 200 PALABRAS MÁS FRECUENTES FALTANTES EN EMBEDDINGS:
Palabra                        | Frecuencia
---------------------------------------------
I+D                            | 1849      
d'Esquadra                     | 1591      
-M.                            | 907       
I+D+i                          | 820       
elpaissemanal                  | 756       
REUTERS-QUALITY                | 700       
n't                            | 678       
km/h                           | 660       
-year-old                      | 639       
-S.                            | 596       
-N.                            | 560       
L'Hospitalet                   | 553       
Canal+                         | 534     

In [4]:
import torch
from torch.utils.data import DataLoader
import os
from tqdm import tqdm
import json

def decode_sequence_to_text(id_sequence, vocab, oov_id_to_word):
    """
    Decodifica una secuencia de IDs a texto.
    
    Args:
        id_sequence: List o tensor de IDs
        vocab: Vocabulary object
        oov_id_to_word: Dict[int, str] - Mapeo de IDs extendidos a palabras OOV
    
    Returns:
        List[str] - Palabras decodificadas
    """
    if torch.is_tensor(id_sequence):
        id_sequence = id_sequence.cpu().tolist()
    
    V_base = len(vocab.word_to_id)
    decoded_words = []
    
    for id in id_sequence:
        id = int(id)
        
        # Tokens especiales
        if id == vocab.word2id(vocab.pad_token):
            continue  # Ignorar padding
        elif id == vocab.word2id(vocab.start_decoding):
            continue  # Ignorar START
        elif id == vocab.word2id(vocab.end_decoding):
            break  # Terminar en END
        # IDs base del vocabulario
        elif id < V_base:
            decoded_words.append(vocab.id2word(id))
        # IDs extendidos (OOV copiado)
        elif id in oov_id_to_word:
            decoded_words.append(oov_id_to_word[id])
        else:
            decoded_words.append(vocab.unk_token)
    
    return decoded_words


def create_oov_id_to_word_map(oov_words, V_base):
    """
    Crea el mapeo ID a Palabra OOV.
    
    Args:
        oov_words: List[str] - Palabras OOV del ejemplo
        V_base: int - Tamaño del vocabulario base
    
    Returns:
        Dict[int, str] - Mapeo de ID extendido a palabra OOV
    """
    oov_id_to_word = {}
    oov_id = V_base
    
    for word in oov_words:
        if word == '':  # Ignorar padding
            continue
        oov_id_to_word[oov_id] = word
        oov_id += 1
    
    return oov_id_to_word


class Generator:
    """
    Clase para generar resúmenes usando el modelo entrenado.
    """
    
    def __init__(self, config, vocab, model_path):
        """
        Args:
            config: Config object
            vocab: Vocabulary object
            model_path: Ruta al checkpoint del modelo
        """
        self.config = config
        self.vocab = vocab
        self.device = config['device']
        
        # Cargar modelo
        print(f"Cargando modelo desde {model_path}")
        self.model = PointerGeneratorNetwork(config, vocab).to(self.device)
        
        checkpoint = torch.load(model_path, map_location=self.device)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.model.eval()
        
        print(f"✓ Modelo cargado (Epoch {checkpoint['epoch']})")
        
        # Beam search
        self.beam_search = BeamSearch(
            self.model,
            vocab,
            beam_size=config['beam_size'],
            max_len=config['tgt_len']
        )
    
    def generate(self, test_loader, output_file=None, num_examples=None):
        """
        Genera resúmenes para el dataset de test.
        
        Args:
            test_loader: DataLoader de test
            output_file: Archivo donde guardar los resúmenes generados
            num_examples: Número de ejemplos a generar (None = todos)
        
        Returns:
            List[Dict] - Lista con resultados (source, target, generated)
        """
        results = []
        V_base = len(self.vocab.word_to_id)
        
        print(f"\n{'='*60}")
        print(f"Generando resúmenes")
        print(f"{'='*60}")
        print(f"Estrategia: {self.config['decoding_strategy']}")
        print(f"Beam size: {self.config['beam_size']}")
        print(f"{'='*60}\n")
        
        with torch.no_grad():
            pbar = tqdm(test_loader, desc="Generando", total=len(test_loader))
            
            for batch_idx, batch in enumerate(pbar):
                batch_size = batch['encoder_input'].size(0)
                
                for b in range(batch_size):
                    # Extraer ejemplo individual
                    single_batch = {
                        'encoder_input': batch['encoder_input'][b:b+1],
                        'extended_encoder_input': batch['extended_encoder_input'][b:b+1],
                        'encoder_length': batch['encoder_length'][b:b+1],
                        'encoder_mask': batch['encoder_mask'][b:b+1],
                        'decoder_target': batch['decoder_target'][b:b+1]
                    }
                    
                    oov_words = batch['oov_words'][b]
                    oov_map = create_oov_id_to_word_map(oov_words, V_base)
                    
                    # Generar resumen
                    if self.config['decoding_strategy'] == 'beam_search':
                        hypothesis = self.beam_search.search(single_batch)
                        generated_ids = hypothesis.tokens[1:]  # Quitar START
                    else:  # greedy
                        generated_ids = self.model.decode_greedy(single_batch, max_len=self.config['tgt_len'])
                        generated_ids = generated_ids[0].cpu().tolist()
                    
                    # Decodificar a texto
                    source_ids = single_batch['extended_encoder_input'][0].cpu().tolist()
                    target_ids = single_batch['decoder_target'][0].cpu().tolist()
                    
                    source_text = decode_sequence_to_text(source_ids, self.vocab, oov_map)
                    target_text = decode_sequence_to_text(target_ids, self.vocab, oov_map)
                    generated_text = decode_sequence_to_text(generated_ids, self.vocab, oov_map)
                    
                    result = {
                        'source': ' '.join(source_text),
                        'target': ' '.join(target_text),
                        'generated': ' '.join(generated_text)
                    }
                    
                    results.append(result)
                    
                    if num_examples and len(results) >= num_examples:
                        break
                
                if num_examples and len(results) >= num_examples:
                    break
            
            pbar.close()
        
        print(f"\n✓ Generados {len(results)} resúmenes")
        
        # Guardar resultados
        if output_file:
            os.makedirs(os.path.dirname(output_file), exist_ok=True)
            
            with open(output_file, 'w', encoding='utf-8') as f:
                json.dump(results, f, ensure_ascii=False, indent=2)
            
            print(f"✓ Resultados guardados en {output_file}")
            
            # También guardar en formato legible
            txt_file = output_file.replace('.json', '.txt')
            with open(txt_file, 'w', encoding='utf-8') as f:
                for i, result in enumerate(results):
                    f.write(f"{'='*60}\n")
                    f.write(f"Ejemplo {i+1}\n")
                    f.write(f"{'='*60}\n")
                    f.write(f"SOURCE:\n{result['source']}\n\n")
                    f.write(f"TARGET:\n{result['target']}\n\n")
                    f.write(f"GENERATED:\n{result['generated']}\n\n")
            
            print(f"✓ Resultados legibles en {txt_file}")
        
        return results


def main():
    """
    Función principal para generar resúmenes.
    """
    # 1. Cargar vocabulario
    print("Cargando vocabulario...")
    vocab = Vocabulary(
        CREATE_VOCABULARY=False,  # No crear, solo cargar
        PAD_TOKEN=PAD_TOKEN,
        UNK_TOKEN=UNK_TOKEN,
        START_DECODING=START_DECODING,
        END_DECODING=END_DECODING,
        MAX_VOCAB_SIZE=MAX_VOCAB_SIZE,
        CHECKPOINT_VOCABULARY_DIR=CHECKPOINT_VOCABULARY_DIR,
        DATA_DIR=DATA_DIR,
        VOCAB_NAME=VOCAB_NAME
    )
    vocab.build_vocabulary()
    print(f"✓ Vocabulario cargado: {vocab.total_size()} palabras")
    
    # 2. Configurar
    config = Config(
        max_vocab_size=vocab.total_size(),
        src_len=MAX_LEN_SRC,
        tgt_len=MAX_LEN_TGT,
        embedding_size=EMBEDDING_SIZE,
        hidden_size=HIDDEN_SIZE,
        num_enc_layers=NUM_ENC_LAYERS,
        num_dec_layers=NUM_DEC_LAYERS,
        use_gpu=USE_GPU,
        is_pgen=IS_PGEN,
        is_coverage=IS_COVERAGE,
        dropout_ratio=DROPOUT_RATIO,
        bidirectional=BIDIRECTIONAL,
        device=DEVICE,
        decoding_strategy=DECODING_STRATEGY,
        beam_size=BEAM_SIZE,
        gpu_id=GPU_ID
    )
    
    # 3. Dataset de test
    print("\nCargando dataset de test...")
    test_dataset = PGNDataset(
        vocab=vocab,
        MAX_LEN_SRC=config['src_len'],
        MAX_LEN_TGT=config['tgt_len'],
        data_dir=DATA_DIR,
        split='test'
    )
    
    print(f"✓ Test: {len(test_dataset)} ejemplos")
    
    test_loader = DataLoader(
        test_dataset,
        batch_size=1,  # Procesar de uno en uno para beam search
        shuffle=True,
        collate_fn=pgn_collate_fn,
        num_workers=0
    )
    
    # 4. Ruta del modelo
    # Usar el mejor modelo por defecto
    model_path = os.path.join(CHECKPOINT_DIR, 'checkpoint_best2.pt')
    
    # Si no existe, buscar el último
    if not os.path.exists(model_path):
        model_path = os.path.join(CHECKPOINT_DIR, 'checkpoint_last2.pt')
    
    if not os.path.exists(model_path):
        print(f"⚠ No se encontró ningún checkpoint en {CHECKPOINT_DIR}")
        return
    
    # 5. Generar
    generator = Generator(config, vocab, model_path)
    
    # Generar todos los ejemplos (o especificar un número)
    output_file = os.path.join(GENERATED_TEXT_DIR, 'test_results.json')
    
    results = generator.generate(
        test_loader,
        output_file=output_file,
        num_examples=10  # None = todos, o especificar un número
    )
    
    # Mostrar algunos ejemplos
    print(f"\n{'='*60}")
    print("Ejemplos generados:")
    print(f"{'='*60}\n")
    
    for i, result in enumerate(results[:20]):
        print(f"Ejemplo {i+1}:")
        print(f"SRC: {result['source']}")
        print(f"TARGET: {result['target']}")
        print(f"GENERATED: {result['generated']}")
        print()


if __name__ == "__main__":
    main()


Cargando vocabulario...
Descargando modelo de spaCy español (Large)...
Collecting es-core-news-sm==3.8.0
  Downloading https://github.com/explosion/spacy-models/releases/download/es_core_news_sm-3.8.0/es_core_news_sm-3.8.0-py3-none-any.whl (12.9 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 12.9/12.9 MB 46.1 MB/s eta 0:00:00
Installing collected packages: es-core-news-sm
Successfully installed es-core-news-sm-3.8.0
[38;5;2m✔ Download and installation successful[0m
You can now load the package via spacy.load('es_core_news_sm')
[38;5;3m⚠ Restart to reload dependencies[0m
If you are in a Jupyter or Colab notebook, you may need to restart Python in
order to load all the package's dependencies. You can do this by selecting the
'Restart kernel' or 'Restart runtime' option.
saved/working/Vocabulary.json
 Vocabulario cargado desde: saved/working/Vocabulary.json
 Tamaño total: 50000
 Tokens especiales: 4
 Tokens regulares: 49996
✓ Vocabulario cargado: 50000 palabras

Cargando dataset de

Generando:   0%|          | 9/13920 [00:15<6:35:01,  1.70s/it]


✓ Generados 10 resúmenes
✓ Resultados guardados en kaggle/working/generated/test_results.json
✓ Resultados legibles en kaggle/working/generated/test_results.txt

Ejemplos generados:

Ejemplo 1:
SRC: Puntuación 6,5 Arquitectura 7 Decoración 5 Estado de conservación 7 Confortabilidad habitaciones 6 Aseos 5 Ambiente 6 Desayuno 8 Atención 9 Tranquilidad 7 Instalaciones 6 Merche y José Luis Castillo se anticiparon en 2002 al boom del enoturismo en La Rioja con la apertura de este hotelito-bodega situado en Ábalos , a unos 15 kilómetros de Haro . Sus conocimientos en viticultura y enología , por no mencionar los platos que acompañan estas catas , los guiaron hasta una casa solariega del siglo XVII reconocida por su impecable factura barroca que enseguida reformaron y decoraron con acento rústico , ecléctico y colorista . La piedra adusta , acaramelada con maderas y detalles de forja , era el reclamo necesario para la seriedad que impone un viaje por la arquitectura histórica de la comarca ,




# Evaluar

In [8]:
import torch
from torch.utils.data import DataLoader
import os
from tqdm import tqdm
import json
import numpy as np
import matplotlib.pyplot as plt
from rouge import Rouge
from nltk.translate.meteor_score import meteor_score
import nltk


try:
    nltk.data.find('tokenizers/punkt')
except LookupError:
    nltk.download('punkt', quiet=True)
try:
    nltk.data.find('corpora/wordnet')
except LookupError:
    nltk.download('wordnet', quiet=True)


def decode_sequence_to_text(id_sequence, vocab, oov_id_to_word):
    """Decodifica una secuencia de IDs a texto."""
    if torch.is_tensor(id_sequence):
        id_sequence = id_sequence.cpu().tolist()
    
    V_base = len(vocab.word_to_id)
    decoded_words = []
    
    for id in id_sequence:
        id = int(id)
        
        if id == vocab.word2id(vocab.pad_token):
            continue
        elif id == vocab.word2id(vocab.start_decoding):
            continue
        elif id == vocab.word2id(vocab.end_decoding):
            break
        elif id < V_base:
            decoded_words.append(vocab.id2word(id))
        elif id in oov_id_to_word:
            decoded_words.append(oov_id_to_word[id])
        else:
            decoded_words.append(vocab.unk_token)
    
    return decoded_words


def create_oov_id_to_word_map(oov_words, V_base):
    """Crea el mapeo ID a Palabra OOV."""
    oov_id_to_word = {}
    oov_id = V_base
    
    for word in oov_words:
        if word == '':
            continue
        oov_id_to_word[oov_id] = word
        oov_id += 1
    
    return oov_id_to_word


def calculate_rouge_scores(reference, candidate):
    """
    Calcula ROUGE scores usando la librería rouge.
    
    Args:
        reference: str - Resumen de referencia
        candidate: str - Resumen generado
        
    Returns:
        Dict con ROUGE-1, ROUGE-2 y ROUGE-L scores
    """
    if not reference.strip() or not candidate.strip():
        return {
            'rouge-1': {'f': 0.0, 'p': 0.0, 'r': 0.0},
            'rouge-2': {'f': 0.0, 'p': 0.0, 'r': 0.0},
            'rouge-l': {'f': 0.0, 'p': 0.0, 'r': 0.0}
        }
    
    rouge = Rouge()
    try:
        scores = rouge.get_scores(candidate, reference)[0]
        return scores
    except Exception as e:
        print(f"⚠ Error calculando ROUGE: {e}")
        return {
            'rouge-1': {'f': 0.0, 'p': 0.0, 'r': 0.0},
            'rouge-2': {'f': 0.0, 'p': 0.0, 'r': 0.0},
            'rouge-l': {'f': 0.0, 'p': 0.0, 'r': 0.0}
        }


def calculate_meteor(reference, candidate):
    """
    Calcula METEOR score usando NLTK.
    
    Args:
        reference: str - Resumen de referencia
        candidate: str - Resumen generado
        
    Returns:
        float - METEOR score
    """
    if not reference.strip() or not candidate.strip():
        return 0.0
    
    try:
        # METEOR requiere lista de tokens
        reference_tokens = reference.split()
        candidate_tokens = candidate.split()
        
        # METEOR espera una lista de referencias
        score = meteor_score([reference_tokens], candidate_tokens)
        return score
    except Exception as e:
        print(f"⚠ Error calculando METEOR: {e}")
        return 0.0


class Evaluator:
        
    """
    Clase para evaluar el modelo en el dataset de test.
    """
    
    def __init__(self, config, vocab, model_path):
        """
        Args:
            config: Config object
            vocab: Vocabulary object
            model_path: Ruta al checkpoint del modelo
        """
        self.config = config
        self.vocab = vocab
        self.device = config['device']
        
        # Cargar modelo
        print(f"Cargando modelo desde {model_path}")
        self.model = PointerGeneratorNetwork(config, vocab).to(self.device)
        
        checkpoint = torch.load(model_path, map_location=self.device)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.model.eval()
        
        print(f"✓ Modelo cargado (Epoch {checkpoint['epoch']})")
        print(f"✓ Best Val Loss: {checkpoint.get('best_val_loss', 'N/A')}")
        
        # Beam search
        self.beam_search = BeamSearch(
            self.model,
            vocab,
            beam_size=config['beam_size'],
            max_len=config['tgt_len']
        )
    def _copy_rate(self, candidate_tokens, source_tokens):
            """Porcentaje de palabras del resumen generado que aparecen en el source."""
            if not candidate_tokens or not source_tokens:
                return 0.0
            source_set = set(source_tokens)
            copy_count = sum(1 for w in candidate_tokens if w in source_set)
            return copy_count / len(candidate_tokens)

    def _ngram_overlap(self, candidate_tokens, source_tokens, n=2):
            """Porcentaje de n-gramas del resumen generado que aparecen en el source."""
            if len(candidate_tokens) < n or len(source_tokens) < n:
                return 0.0
            def ngrams(tokens, n):
                return set(tuple(tokens[i:i+n]) for i in range(len(tokens)-n+1))
            cand_ngrams = ngrams(candidate_tokens, n)
            src_ngrams = ngrams(source_tokens, n)
            if not cand_ngrams:
                return 0.0
            overlap = len(cand_ngrams & src_ngrams)
            return overlap / len(cand_ngrams)
    def evaluate(self, test_loader, num_examples=None):
        """
        Evalúa el modelo en el dataset de test.
        
        Args:
            test_loader: DataLoader de test
            num_examples: Número de ejemplos a evaluar (None = todos)
            
        Returns:
            Dict con métricas y resultados
        """
        V_base = len(self.vocab.word_to_id)
        
        # Métricas acumuladas
        rouge1_scores = []
        rouge2_scores = []
        rougeL_scores = []
        meteor_scores = []
        copy_rates = []
        bigram_overlaps = []
        p_gen_avgs = [] # Promedio de p_gen por ejemplo
        
        test_loss = 0.0
        num_batches = 0
        
        results = []
        
        print(f"\n{'='*60}")
        print(f"Evaluando modelo en Test Set")
        print(f"{'='*60}")
        print(f"Estrategia: {self.config['decoding_strategy']}")
        print(f"Beam size: {self.config['beam_size']}")
        print(f"{'='*60}\n")
        
        with torch.no_grad():
            pbar = tqdm(test_loader, desc="Evaluando", total=len(test_loader))
            
            for batch_idx, batch in enumerate(pbar):
                if batch is None:
                    continue
                
                # Calcular loss
                outputs = self.model(batch, is_training=True)
                test_loss += outputs['loss'].item()
                num_batches += 1
                
                batch_size = batch['encoder_input'].size(0)
                
                for b in range(batch_size):
                    # Extraer ejemplo individual
                    single_batch = {
                        'encoder_input': batch['encoder_input'][b:b+1],
                        'extended_encoder_input': batch['extended_encoder_input'][b:b+1],
                        'encoder_length': batch['encoder_length'][b:b+1],
                        'encoder_mask': batch['encoder_mask'][b:b+1],
                        'decoder_target': batch['decoder_target'][b:b+1]
                    }
                    
                    oov_words = batch['oov_words'][b]
                    oov_map = create_oov_id_to_word_map(oov_words, V_base)
                    
                    # Generar resumen
                    if self.config['decoding_strategy'] == 'beam_search':
                        hypothesis = self.beam_search.search(single_batch)
                        generated_ids = hypothesis.tokens[1:]  # Quitar START
                        # Extraer p_gens de la hipótesis (lista de tensores (1,1))
                        p_gens_list = hypothesis.p_gens
                        if p_gens_list:
                            p_gens_tensor = torch.cat(p_gens_list).squeeze() # (tgt_len,)
                        else:
                            p_gens_tensor = torch.tensor([])
                            
                    else:  # greedy
                        generated_ids, p_gens_tensor = self.model.decode_greedy(single_batch, max_len=self.config['tgt_len'])
                        generated_ids = generated_ids[0].cpu().tolist()
                        p_gens_tensor = p_gens_tensor[0].squeeze(-1).cpu() # (max_len,)
                    
                    # Calcular promedio de p_gen para este ejemplo
                    if p_gens_tensor.numel() > 0:
                        avg_p_gen_example = p_gens_tensor.mean().item()
                    else:
                        avg_p_gen_example = 0.0
                        
                    p_gen_avgs.append(avg_p_gen_example)
                    
                    # Decodificar a texto
                    target_ids = single_batch['decoder_target'][0].cpu().tolist()
                    
                    reference = decode_sequence_to_text(target_ids, self.vocab, oov_map)
                    candidate = decode_sequence_to_text(generated_ids, self.vocab, oov_map)
                    
                    # Convertir a string
                    reference_text = ' '.join(reference)
                    candidate_text = ' '.join(candidate)
                    
                    # Calcular ROUGE usando la librería
                    rouge_scores = calculate_rouge_scores(reference_text, candidate_text)
                    # Calcular METEOR
                    meteor = calculate_meteor(reference_text, candidate_text)
                    rouge1_scores.append(rouge_scores['rouge-1']['f'])
                    rouge2_scores.append(rouge_scores['rouge-2']['f'])
                    rougeL_scores.append(rouge_scores['rouge-l']['f'])
                    meteor_scores.append(meteor)
                    # Calcular tasa de copia y bigram overlap
                    src_ids = single_batch['encoder_input'][0].cpu().tolist()
                    source_tokens = decode_sequence_to_text(src_ids, self.vocab, oov_map)
                    copy_rate = self._copy_rate(candidate, source_tokens)
                    bigram_overlap = self._ngram_overlap(candidate, source_tokens, n=2)

                    copy_rates.append(copy_rate)
                    bigram_overlaps.append(bigram_overlap)

                    # Guardar resultado
                    result = {
                        'reference': reference_text,
                        'candidate': candidate_text,
                        'rouge1_f1': rouge_scores['rouge-1']['f'],
                        'rouge2_f1': rouge_scores['rouge-2']['f'],
                        'rougeL_f1': rouge_scores['rouge-l']['f'],
                        'meteor': meteor,
                        'copy_rate': copy_rate,
                        'bigram_overlap': bigram_overlap,
                        'avg_p_gen': avg_p_gen_example,
                        'avg_p_copy': 1.0 - avg_p_gen_example
                    }
                    results.append(result)
                    
                    # Update progress bar
                    pbar.set_postfix({
                        'loss': f"{outputs['loss'].item():.4f}",
                        'R1': f"{np.mean(rouge1_scores):.4f}",
                        'R2': f"{np.mean(rouge2_scores):.4f}",
                        'RL': f"{np.mean(rougeL_scores):.4f}",
                        'M': f"{np.mean(meteor_scores):.4f}",
                        'Copy': f"{np.mean(copy_rates):.2f}",
                        'BiOv': f"{np.mean(bigram_overlaps):.2f}",
                        'P_Gen': f"{np.mean(p_gen_avgs):.2f}"
                    })
                    
                    if num_examples and len(results) >= num_examples:
                        break
                
                if num_examples and len(results) >= num_examples:
                    break
            
            pbar.close()
        
        # Calcular promedios
        avg_test_loss = test_loss / num_batches if num_batches > 0 else 0.0
        avg_rouge1 = np.mean(rouge1_scores) if rouge1_scores else 0.0
        avg_rouge2 = np.mean(rouge2_scores) if rouge2_scores else 0.0
        avg_rougeL = np.mean(rougeL_scores) if rougeL_scores else 0.0
        avg_meteor = np.mean(meteor_scores) if meteor_scores else 0.0
        
        avg_copy_rate = np.mean(copy_rates) if copy_rates else 0.0
        avg_bigram_overlap = np.mean(bigram_overlaps) if bigram_overlaps else 0.0
        avg_p_gen = np.mean(p_gen_avgs) if p_gen_avgs else 0.0
        
        metrics = {
            'test_loss': avg_test_loss,
            'rouge1_f1': avg_rouge1,
            'rouge2_f1': avg_rouge2,
            'rougeL_f1': avg_rougeL,
            'meteor': avg_meteor,
            'copy_rate': avg_copy_rate,
            'bigram_overlap': avg_bigram_overlap,
            'avg_p_gen': avg_p_gen,
            'avg_p_copy': 1.0 - avg_p_gen,
            'num_examples': len(results)
        }
        
        return metrics, results
    
    def save_results(self, metrics, results, output_dir):
        """
        Guarda los resultados de la evaluación.
        
        Args:
            metrics: Dict con métricas
            results: Lista de resultados por ejemplo
            output_dir: Directorio donde guardar
        """
        os.makedirs(output_dir, exist_ok=True)
        
        # Guardar métricas
        metrics_path = os.path.join(output_dir, 'test_metrics.json')
        with open(metrics_path, 'w', encoding='utf-8') as f:
            json.dump(metrics, f, indent=2)
        print(f"\n✓ Métricas guardadas en {metrics_path}")
        
        # Guardar resultados completos
        results_path = os.path.join(output_dir, 'test_results.json')
        with open(results_path, 'w', encoding='utf-8') as f:
            json.dump(results, f, ensure_ascii=False, indent=2)
        print(f"✓ Resultados guardados en {results_path}")
        
        # Guardar formato legible
        txt_path = os.path.join(output_dir, 'test_results.txt')
        with open(txt_path, 'w', encoding='utf-8') as f:
            f.write(f"{'='*60}\n")
            f.write(f"MÉTRICAS DE EVALUACIÓN\n")
            f.write(f"{'='*60}\n")
            f.write(f"Test Loss: {metrics['test_loss']:.4f}\n")
            f.write(f"ROUGE-1 F1: {metrics['rouge1_f1']:.4f}\n")
            f.write(f"ROUGE-2 F1: {metrics['rouge2_f1']:.4f}\n")
            f.write(f"ROUGE-L F1: {metrics['rougeL_f1']:.4f}\n")
            f.write(f"METEOR: {metrics['meteor']:.4f}\n")
            f.write(f"Copy Rate: {metrics['copy_rate']:.4f}\n")
            f.write(f"Bigram Overlap: {metrics['bigram_overlap']:.4f}\n")
            f.write(f"Avg P_Gen: {metrics['avg_p_gen']:.4f}\n")
            f.write(f"Avg P_Copy: {metrics['avg_p_copy']:.4f}\n")
            f.write(f"Ejemplos evaluados: {metrics['num_examples']}\n")
            f.write(f"{'='*60}\n\n")
            
            for i, result in enumerate(results[:10]):  # Primeros 10 ejemplos
                f.write(f"{'='*60}\n")
                f.write(f"Ejemplo {i+1}\n")
                f.write(f"{'='*60}\n")
                f.write(f"REFERENCE:\n{result['reference']}\n\n")
                f.write(f"GENERATED:\n{result['candidate']}\n\n")
                f.write(f"ROUGE-1: {result['rouge1_f1']:.4f} | ")
                f.write(f"ROUGE-2: {result['rouge2_f1']:.4f} | ")
                f.write(f"ROUGE-L: {result['rougeL_f1']:.4f} | ")
                f.write(f"METEOR: {result['meteor']:.4f} | ")
                f.write(f"Copy Rate: {result['copy_rate']:.4f} | ")
                f.write(f"Bigram Overlap: {result['bigram_overlap']:.4f} | ")
                f.write(f"P_Gen: {result['avg_p_gen']:.4f} | ")
                f.write(f"P_Copy: {result['avg_p_copy']:.4f}\n\n")
        
        print(f"✓ Resultados legibles en {txt_path}")


def plot_training_history(output_dir, checkpoint_path):
    """
    Grafica la historia de entrenamiento desde el checkpoint del modelo.
    
    Args:
        output_dir: Directorio donde guardar las gráficas
        checkpoint_path: Ruta al checkpoint del modelo
    """
    if not os.path.exists(checkpoint_path):
        print(f"⚠ No se encontró el checkpoint {checkpoint_path}")
        return
    
    print(f"\nCargando historial de entrenamiento desde {checkpoint_path}...")
    
    # Cargar checkpoint
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    
    # Extraer historial del checkpoint
    history = checkpoint.get('train_history', None)
    
    if not history or not history.get('epoch'):
        print("⚠ El historial está vacío o no existe en el checkpoint")
        return
    
    # Extraer datos
    epochs = history['epoch']
    train_losses = history['train_loss']
    val_losses = history['val_loss']
    
    # Crear figura con 2 subplots
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
    
    # Plot 1: Train y Val Loss
    ax1.plot(epochs, train_losses, 'b-', label='Train Loss', linewidth=2)
    ax1.plot(epochs, val_losses, 'r-', label='Val Loss', linewidth=2)
    ax1.set_xlabel('Epoch', fontsize=12)
    ax1.set_ylabel('Loss', fontsize=12)
    ax1.set_title('Training and Validation Loss', fontsize=14, fontweight='bold')
    ax1.legend(fontsize=10)
    ax1.grid(True, alpha=0.3)
    
    # Plot 2: Vocab Loss y Coverage Loss
    vocab_losses = history.get('vocab_loss', [])
    coverage_losses = history.get('coverage_loss', [])
    
    if vocab_losses and coverage_losses:
        ax2.plot(epochs, vocab_losses, 'g-', label='Vocab Loss', linewidth=2)
        ax2.plot(epochs, coverage_losses, 'r-', label='Coverage Loss', linewidth=2)
        ax2.set_xlabel('Epoch', fontsize=12)
        ax2.set_ylabel('Loss', fontsize=12)
        ax2.set_title('Vocab Loss vs Coverage Loss', fontsize=14, fontweight='bold')
        ax2.legend(fontsize=10)
        ax2.grid(True, alpha=0.3)
    
    plt.tight_layout()
    
    # Guardar figura
    os.makedirs(output_dir, exist_ok=True)
    plot_path = os.path.join(output_dir, 'training_curves.png')
    plt.savefig(plot_path, dpi=300, bbox_inches='tight')
    print(f"✓ Gráfica guardada en {plot_path}")
    
    # Mostrar estadísticas
    print(f"\n{'='*60}")
    print(f"ESTADÍSTICAS DE ENTRENAMIENTO")
    print(f"{'='*60}")
    print(f"Épocas completadas: {len(epochs)}")
    print(f"Mejor Train Loss: {min(train_losses):.4f} (Epoch {epochs[train_losses.index(min(train_losses))]})")
    print(f"Mejor Val Loss: {min(val_losses):.4f} (Epoch {epochs[val_losses.index(min(val_losses))]})")
    print(f"Última Train Loss: {train_losses[-1]:.4f}")
    print(f"Última Val Loss: {val_losses[-1]:.4f}")
    
    if vocab_losses and coverage_losses:
        print(f"Última Vocab Loss: {vocab_losses[-1]:.4f}")
        print(f"Última Coverage Loss: {coverage_losses[-1]:.4f}")
    
    print(f"{'='*60}\n")
    
    plt.close()


def main():
    """
    Función principal para evaluar el modelo.
    """
    # 1. Cargar vocabulario
    print("Cargando vocabulario...")
    vocab = Vocabulary(
        CREATE_VOCABULARY=False,
        PAD_TOKEN=PAD_TOKEN,
        UNK_TOKEN=UNK_TOKEN,
        START_DECODING=START_DECODING,
        END_DECODING=END_DECODING,
        MAX_VOCAB_SIZE=MAX_VOCAB_SIZE,
        CHECKPOINT_VOCABULARY_DIR=CHECKPOINT_VOCABULARY_DIR,
        DATA_DIR=DATA_DIR,
        VOCAB_NAME=VOCAB_NAME
    )
    vocab.build_vocabulary()
    print(f"✓ Vocabulario cargado: {vocab.total_size()} palabras")
    
    # 2. Configurar
    config = Config(
        max_vocab_size=vocab.total_size(),
        src_len=MAX_LEN_SRC,
        tgt_len=MAX_LEN_TGT,
        embedding_size=EMBEDDING_SIZE,
        hidden_size=HIDDEN_SIZE,
        num_enc_layers=NUM_ENC_LAYERS,
        num_dec_layers=NUM_DEC_LAYERS,
        use_gpu=USE_GPU,
        is_pgen=IS_PGEN,
        is_coverage=IS_COVERAGE,
        coverage_lambda=COV_LOSS_LAMBDA,
        dropout_ratio=DROPOUT_RATIO,
        bidirectional=BIDIRECTIONAL,
        device=DEVICE,
        decoding_strategy=DECODING_STRATEGY,
        beam_size=BEAM_SIZE,
        gpu_id=GPU_ID
    )
    
    # 3. Dataset de test
    print("\nCargando dataset de test...")
    test_dataset = PGNDataset(
        vocab=vocab,
        MAX_LEN_SRC=config['src_len'],
        MAX_LEN_TGT=config['tgt_len'],
        data_dir=DATA_DIR,
        split='test'
    )
    
    print(f"✓ Test: {len(test_dataset)} ejemplos")
    
    test_loader = DataLoader(
        test_dataset,
        batch_size=1,
        shuffle=True,
        collate_fn=pgn_collate_fn,
        num_workers=0
    )
    
    # 4. Ruta del modelo
    model_path = os.path.join(CHECKPOINT_DIR, 'checkpoint_best2.pt')
    
    if not os.path.exists(model_path):
        model_path = os.path.join(CHECKPOINT_DIR, 'checkpoint_last1.pt')
    
    if not os.path.exists(model_path):
        print(f"⚠ No se encontró ningún checkpoint en {CHECKPOINT_DIR}")
        return
    
    output_dir = GENERATED_TEXT_DIR
    
    # 5. Graficar pérdida de entrenamiento (desde el checkpoint)
    plot_training_history(output_dir, model_path)
    
    # 6. Evaluar
    evaluator = Evaluator(config, vocab, model_path)
    
    metrics, results = evaluator.evaluate(
        test_loader,
        num_examples=30  # None = todos
    )
    
    # 7. Mostrar resultados
    print(f"\n{'='*60}")
    print(f"RESULTADOS DE EVALUACIÓN")
    print(f"{'='*60}")
    print(f"Test Loss:   {metrics['test_loss']:.4f}")
    print(f"ROUGE-1 F1:  {metrics['rouge1_f1']:.4f}")
    print(f"ROUGE-2 F1:  {metrics['rouge2_f1']:.4f}")
    print(f"ROUGE-L F1:  {metrics['rougeL_f1']:.4f}")
    print(f"METEOR:      {metrics['meteor']:.4f}")
    print(f"Ejemplos:    {metrics['num_examples']}")
    print(f"{'='*60}\n")
    
    # 8. Guardar resultados
    evaluator.save_results(metrics, results, output_dir)
    
    # 9. Mostrar algunos ejemplos
    print(f"\n{'='*60}")
    print("EJEMPLOS DE RESÚMENES GENERADOS")
    print(f"{'='*60}\n")
    
    for i, result in enumerate(results[:3]):
        print(f"Ejemplo {i+1}:")
        print(f"REFERENCE: {result['reference']}...")
        print(f"GENERATED: {result['candidate']}...")
        print(f"ROUGE-1: {result['rouge1_f1']:.4f} | "
              f"ROUGE-2: {result['rouge2_f1']:.4f} | "
              f"ROUGE-L: {result['rougeL_f1']:.4f} | "
              f"METEOR: {result['meteor']:.4f} | "
              f"P_Gen: {result['avg_p_gen']:.4f}")
        print()


if __name__ == "__main__":
    main()



Cargando vocabulario...
saved/working/Vocabulary.json
 Vocabulario cargado desde: saved/working/Vocabulary.json
 Tamaño total: 50000
 Tokens especiales: 4
 Tokens regulares: 49996
✓ Vocabulario cargado: 50000 palabras

Cargando dataset de test...
✓ Usando archivos TOKENIZADOS para test (Carga optimizada en Kaggle)
✓ Test: 13920 ejemplos

Cargando historial de entrenamiento desde kaggle/working/saved/checkpoint_best2.pt...
✓ Gráfica guardada en kaggle/working/generated/training_curves.png

ESTADÍSTICAS DE ENTRENAMIENTO
Épocas completadas: 6
Mejor Train Loss: 3.0379 (Epoch 6)
Mejor Val Loss: 3.2034 (Epoch 6)
Última Train Loss: 3.0379
Última Val Loss: 3.2034
Última Vocab Loss: 3.0191
Última Coverage Loss: 0.0188

Cargando modelo desde kaggle/working/saved/checkpoint_best2.pt
✓ Modelo cargado (Epoch 6)
✓ Best Val Loss: 3.2033928868211348

Evaluando modelo en Test Set
Estrategia: beam_search
Beam size: 5



Evaluando:   0%|          | 29/13920 [00:56<7:29:07,  1.94s/it, loss=3.3668, R1=0.2413, R2=0.0930, RL=0.2020, M=0.1981, Copy=0.95, BiOv=0.90, P_Gen=0.25]


RESULTADOS DE EVALUACIÓN
Test Loss:   2.7415
ROUGE-1 F1:  0.2413
ROUGE-2 F1:  0.0930
ROUGE-L F1:  0.2020
METEOR:      0.1981
Ejemplos:    30


✓ Métricas guardadas en kaggle/working/generated/test_metrics.json
✓ Resultados guardados en kaggle/working/generated/test_results.json
✓ Resultados legibles en kaggle/working/generated/test_results.txt

EJEMPLOS DE RESÚMENES GENERADOS

Ejemplo 1:
REFERENCE: Si existían sospechas sobre la erosión del orden liberal en el mundo , lo sucedido esta semana en Bruselas permite abandonar cualquier atisbo de duda...
GENERATED: El Consejo es otra cosa los ultras están más coordinados que nunca...
ROUGE-1: 0.0000 | ROUGE-2: 0.0000 | ROUGE-L: 0.0000 | METEOR: 0.0422 | P_Gen: 0.1981

Ejemplo 2:
REFERENCE: Urge el debate de un programa nacional para incorporar la bici a un sistema [UNK] de transporte...
GENERATED: La bici se ha convertido en un signo de distinción , compromiso y modernidad...
ROUGE-1: 0.2069 | ROUGE-2: 0.0000 | ROUGE-L: 0.2069 | METEOR: 0.1




In [6]:
!pip install rouge

Collecting rouge
  Downloading rouge-1.0.1-py3-none-any.whl.metadata (4.1 kB)
Downloading rouge-1.0.1-py3-none-any.whl (13 kB)
Installing collected packages: rouge
Successfully installed rouge-1.0.1


In [None]:
# En tu notebook de Kaggle
import os
from IPython.display import FileLink

# Comprimir archivos específicos
import zipfile

# Listar archivos en /kaggle/working/
archivos = os.listdir('/kaggle/working/kaggle/working/saved/')
print("Archivos disponibles:", archivos)

# Crear zip con outputs importantes
with zipfile.ZipFile('/kaggle/working/outputs.zip', 'w') as zipf:
    for file in archivos:
        if file.endswith('checkpoint_last1.pt'):
            zipf.write(f'/kaggle/working//kaggle/working/saved/{file}', file)

# Crear enlace para descarga
display(FileLink('outputs.zip'))