# Predicción de Puntuación y Capitalización en Texto Normalizado

In [None]:
!pip install transformers
from transformers import BertTokenizer, BertModel
import torch
import re
import pandas as pd
import random
from datasets import load_dataset
from torch import nn

RANDOM_SEED = 0

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

In [None]:
bert_model_name = "bert-base-multilingual-cased"
tokenizer = BertTokenizer.from_pretrained(bert_model_name)
bert_model = BertModel.from_pretrained(bert_model_name)

def get_multilingual_token_embedding(token: str):
  """
  Devuelve el embedding (estático) para el token.
  """
  token_id = tokenizer.convert_tokens_to_ids(token)
  if token_id is None or token_id == tokenizer.unk_token_id:
    print(f"❌ El token '{token}' no pertenece al vocabulario de multilingual BERT.")
    return None

  embedding_vector = bert_model.embeddings.word_embeddings.weight[token_id]

  print(f"✅ Token: '{token}' | ID: {token_id}")
  print(f"Embedding shape: {embedding_vector.shape}")
  return embedding_vector

In [13]:
from transformers import BertTokenizerFast
tokenizer_fast = BertTokenizerFast.from_pretrained("bert-base-multilingual-cased")

In [None]:
def capitalizacion_de_palabra(palabra: str) -> int:
    if palabra.islower(): return 0
    elif palabra.istitle(): return 1
    elif palabra.isupper(): return 3
    else: return 2

def tokenizar_con_etiquetas_list(
    oracion: str,
    instancia_id: int,
    tokenizer: BertTokenizerFast
) -> list[dict]:
    """
    Like your old tokenizar_con_etiquetas, but returns a LIST of dicts,
    not a tiny DataFrame.  Much cheaper to extend a Python list.
    """
    tokens_reales = re.findall(r"\w+['’]?\w*|¿|¡|\?|,|\.|!", oracion)
    rows = []
    i = 0
    while i < len(tokens_reales):
        tok = tokens_reales[i]
        # initial‐punct
        if tok in INIT_PUNCT_MAP and tok not in ['','.',',','?','!']:
            i += 1
            if i >= len(tokens_reales):
                break
            init_lbl = INIT_PUNCT_MAP[tokens_reales[i-1]]
            tok = tokens_reales[i]
        else:
            init_lbl = INIT_PUNCT_MAP['']

        # final‐punct
        if i+1 < len(tokens_reales) and tokens_reales[i+1] in FINAL_PUNCT_MAP:
            final_lbl = FINAL_PUNCT_MAP[tokens_reales[i+1]]
            skip_final = True
        else:
            final_lbl = FINAL_PUNCT_MAP['']
            skip_final = False

        # BERT sub‑tokens in batch
        sub_tokens = tokenizer.tokenize(tok.lower())
        cap_lbl = capitalizacion_de_palabra(tok)

        for j, sub in enumerate(sub_tokens):
            rows.append({
                'instancia_id':   instancia_id,
                'token_id':       tokenizer.convert_tokens_to_ids(sub),
                'token':          sub,
                'punt_inicial':   init_lbl   if j==0 else INIT_PUNCT_MAP[''],
                'punt_final':     final_lbl  if j==0 else FINAL_PUNCT_MAP[''],
                'capitalizacion': cap_lbl
            })

        i += 1
        if skip_final:
            i += 1

    return rows

In [None]:
def tokenizar_con_etiquetas(
    oracion: str,
    instancia_id: int,
    tokenizer: BertTokenizerFast
) -> list[dict]:
    """
    Like your old tokenizar_con_etiquetas, but returns a LIST of dicts.
    Much cheaper to extend one big list than to build tiny DataFrames.
    """
    # same INIT_PUNCT_MAP, FINAL_PUNCT_MAP, CAP_MAP, capitalizacion_de_palabra...
    tokens_reales = re.findall(r"\w+['’]?\w*|¿|¡|\?|,|\.|!", oracion)
    rows = []
    i = 0
    while i < len(tokens_reales):
        tok = tokens_reales[i]
        if tok in INIT_PUNCT_MAP and tok not in ['','.',',','?','!']:
            i += 1
            if i >= len(tokens_reales): break
            init_lbl = INIT_PUNCT_MAP[tokens_reales[i-1]]
            tok = tokens_reales[i]
        else:
            init_lbl = INIT_PUNCT_MAP['']

        if i+1 < len(tokens_reales) and tokens_reales[i+1] in FINAL_PUNCT_MAP:
            final_lbl = FINAL_PUNCT_MAP[tokens_reales[i+1]]
            skip_final = True
        else:
            final_lbl = FINAL_PUNCT_MAP['']
            skip_final = False

        sub_tokens = tokenizer.tokenize(tok.lower())
        cap_lbl = capitalizacion_de_palabra(tok)

        for j, sub in enumerate(sub_tokens):
            rows.append({
                'instancia_id':   instancia_id,
                'token_id':       tokenizer.convert_tokens_to_ids(sub),
                'token':          sub,
                'punt_inicial':   init_lbl   if j==0 else INIT_PUNCT_MAP[''],
                'punt_final':     final_lbl  if j==0 else FINAL_PUNCT_MAP[''],
                'capitalizacion': cap_lbl
            })

        i += 1
        if skip_final:
            i += 1

    return rows


In [None]:
INIT_PUNCT_MAP = {'': 0, '¿': 1}
FINAL_PUNCT_MAP = {'': 0, ',': 1, '.': 2, '?': 3}
CAP_MAP = {'lower': 0, 'title': 1, 'mixed': 2, 'upper': 3}
df = tokenizar_con_etiquetas("¿Esperemos a todos los jugadorazos?", 1, tokenizer_fast)
df

In [None]:
def cargar_dataset(
    path: str,
    tokenizer: BertTokenizerFast,
    n_max_oraciones: int | None = None,
    shuffle: bool = False,
    random_seed: int = 0,
    report_every: int = 500_000
) -> pd.DataFrame:
    """
    Reads up to `n_max_oraciones` lines; optionally shuffles them;
    tokenizes/labels in one pass (collecting into a list); logs
    progress every `report_every` lines; then builds a single DataFrame.
    """
    # 1) load raw lines
    with open(path, "r", encoding="utf-8") as f:
        if n_max_oraciones:
            lines = [f.readline() for _ in range(n_max_oraciones)]
        else:
            lines = f.readlines()

    if shuffle:
        import random
        random.seed(random_seed)
        random.shuffle(lines)

    # 2) process all sentences into one big Python list
    all_rows = []
    for idx, sent in enumerate(lines, start=1):
        all_rows.extend(tokenizar_con_etiquetas(sent, idx, tokenizer))
        if idx % report_every == 0:
            print(f"… processed {idx} sentences")

    # 3) one-shot DataFrame construction
    df = pd.DataFrame(all_rows)
    print(f"Done: total sentences = {idx}, total tokens = {len(df)}")
    return df

In [None]:
dataset['punt_final'].value_counts()

In [None]:
from torch.utils.data import Dataset
import torch
import pandas as pd

class CapitalizacionDataset(Dataset):
    def __init__(self, dataset: pd.DataFrame):
        self.dataset = dataset
        self.instance_ids = sorted(dataset["instancia_id"].unique())

    def __len__(self):
        return len(self.instance_ids)

    def __getitem__(self, idx: int):
        inst_id = self.instance_ids[idx]
        row = self.dataset[self.dataset["instancia_id"] == inst_id]

        input_ids = torch.tensor(row["token_id"].tolist(), dtype=torch.long)
        labels_init = torch.tensor(row["punt_inicial"].tolist(), dtype=torch.long)
        labels_final = torch.tensor(row["punt_final"].tolist(), dtype=torch.long)
        labels_cap = torch.tensor(row["capitalizacion"].tolist(), dtype=torch.long)

        return input_ids, labels_init, labels_final, labels_cap


In [None]:
from torch.nn.utils.rnn import pad_sequence

PAD_token = 0
PAD_token_target = 4
MAX_LEN = dataset["instancia_id"].value_counts().max()

def collate_fn(batch):
    """
    batch: lista de tuplas (tokens_tensor, target_tensor)
    MAX_LEN: longitud fija a la que se debe paddear o truncar
    PAD_token: ID usado para el padding
    """
    tokens, targets = [], []
    for token_tensor, target_tensor in batch:
        tokens.append(token_tensor)
        targets.append(target_tensor)

    # Primero paddeamos hasta el más largo del batch
    tokens_padded = pad_sequence(tokens, batch_first=True, padding_value=PAD_token)
    targets_padded = pad_sequence(targets, batch_first=True, padding_value=PAD_token_target)

    # Luego truncamos o paddeamos a longitud fija MAX_LEN
    def pad_or_truncate(tensor, max_len, pad_token):
        if tensor.size(1) > max_len:
            return tensor[:, :max_len]
        elif tensor.size(1) < max_len:
            pad_size = max_len - tensor.size(1)
            padding = torch.full((tensor.size(0), pad_size), pad_token, dtype=torch.long)
            return torch.cat([tensor, padding], dim=1)
        else:
            return tensor

    tokens_fixed = pad_or_truncate(tokens_padded, MAX_LEN, PAD_token)
    targets_fixed = pad_or_truncate(targets_padded, MAX_LEN, PAD_token_target)

    return tokens_fixed, targets_fixed

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from torch.nn.utils.rnn import pad_sequence
from transformers import BertModel, BertTokenizer

# --- Dataset and Collate Function ---
class TokenClassificationDataset(Dataset):
    def __init__(self, data_df):
        """data_df debe contener columnas: instancia_id, token_id, punt_inicial, punt_final, capitalizacion"""
        self.df = data_df
        self.instance_ids = sorted(data_df['instancia_id'].unique())

    def __len__(self):
        return len(self.instance_ids)

    def __getitem__(self, idx):
        inst_id = self.instance_ids[idx]
        subset = self.df[self.df['instancia_id'] == inst_id]
        input_ids = torch.tensor(subset['token_id'].tolist(), dtype=torch.long)
        labels_init = torch.tensor(subset['punt_inicial'].tolist(), dtype=torch.long)
        labels_final = torch.tensor(subset['punt_final'].tolist(), dtype=torch.long)
        labels_cap = torch.tensor(subset['capitalizacion'].tolist(), dtype=torch.long)
        return input_ids, labels_init, labels_final, labels_cap

PAD_IDX = 0
PAD_LABEL_INIT = 0  # ajustar según mapeo de etiquetas
PAD_LABEL_FINAL = 0
PAD_LABEL_CAP = 0
MAX_LEN = 128  # o el que convenga

def collate_batch(batch):
    inputs, l_init, l_final, l_cap = zip(*batch)
    inputs_padded = pad_sequence(inputs, batch_first=True, padding_value=PAD_IDX)
    init_padded = pad_sequence(l_init, batch_first=True, padding_value=PAD_LABEL_INIT)
    final_padded = pad_sequence(l_final, batch_first=True, padding_value=PAD_LABEL_FINAL)
    cap_padded = pad_sequence(l_cap, batch_first=True, padding_value=PAD_LABEL_CAP)
    # truncar o pad a MAX_LEN
    if inputs_padded.size(1) > MAX_LEN:
        inputs_padded = inputs_padded[:, :MAX_LEN]
        init_padded = init_padded[:, :MAX_LEN]
        final_padded = final_padded[:, :MAX_LEN]
        cap_padded = cap_padded[:, :MAX_LEN]
    else:
        pad_size = MAX_LEN - inputs_padded.size(1)
        inputs_padded = nn.functional.pad(inputs_padded, (0, pad_size), value=PAD_IDX)
        init_padded = nn.functional.pad(init_padded, (0, pad_size), value=PAD_LABEL_INIT)
        final_padded = nn.functional.pad(final_padded, (0, pad_size), value=PAD_LABEL_FINAL)
        cap_padded = nn.functional.pad(cap_padded, (0, pad_size), value=PAD_LABEL_CAP)
    return inputs_padded, init_padded, final_padded, cap_padded

# --- Model ---
class BiLSTMMultiHead(nn.Module):
    def __init__(self,
                 bert_model_name: str,
                 hid_dim: int,
                 n_init: int,
                 n_final: int,
                 n_cap: int,
                 n_layers: int = 1,
                 dropout: float = 0.1):
        super().__init__()
        # BERT embeddings as input
        bert = BertModel.from_pretrained(bert_model_name)
        self.embedding = bert.embeddings.word_embeddings
        emb_dim = bert.config.hidden_size
        self.lstm = nn.LSTM(
            input_size=emb_dim,
            hidden_size=hid_dim,
            num_layers=n_layers,
            batch_first=True,
            bidirectional=True,
            dropout=dropout if n_layers > 1 else 0.0,
        )
        self.dropout = nn.Dropout(dropout)
        self.fc_init = nn.Linear(hid_dim * 2, n_init)
        self.fc_final = nn.Linear(hid_dim * 2, n_final)
        self.fc_cap = nn.Linear(hid_dim * 2, n_cap)

    def forward(self, input_ids):
        # input_ids: [batch, seq]
        x = self.embedding(input_ids)  # [batch, seq, emb_dim]
        x, _ = self.lstm(x)            # [batch, seq, hid_dim*2]
        x = self.dropout(x)
        out_init = self.fc_init(x)     # [batch, seq, n_init]
        out_final = self.fc_final(x)   # [batch, seq, n_final]
        out_cap = self.fc_cap(x)       # [batch, seq, n_cap]
        return out_init, out_final, out_cap

# --- Training Loop ---
def train_epoch(model, dataloader, optimizer, criterion, device):
    model.train()
    total_loss = 0
    for inputs, init_lbl, final_lbl, cap_lbl in dataloader:
        inputs = inputs.to(device)
        init_lbl = init_lbl.to(device)
        final_lbl = final_lbl.to(device)
        cap_lbl = cap_lbl.to(device)

        optimizer.zero_grad()
        logits_init, logits_final, logits_cap = model(inputs)

        # reshape for Loss: (batch*seq, classes)
        bs, seq_len, _ = logits_init.size()
        loss_init = criterion(logits_init.view(-1, logits_init.size(-1)), init_lbl.view(-1))
        loss_final = criterion(logits_final.view(-1, logits_final.size(-1)), final_lbl.view(-1))
        loss_cap = criterion(logits_cap.view(-1, logits_cap.size(-1)), cap_lbl.view(-1))
        loss = loss_init + loss_final + loss_cap
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(dataloader)

# --- Evaluation ---
def evaluate(model, dataloader, criterion, device):
    model.eval()
    total_loss = 0
    with torch.no_grad():
        for inputs, init_lbl, final_lbl, cap_lbl in dataloader:
            inputs = inputs.to(device)
            init_lbl = init_lbl.to(device)
            final_lbl = final_lbl.to(device)
            cap_lbl = cap_lbl.to(device)

            logits_init, logits_final, logits_cap = model(inputs)
            loss_init = criterion(logits_init.view(-1, logits_init.size(-1)), init_lbl.view(-1))
            loss_final = criterion(logits_final.view(-1, logits_final.size(-1)), final_lbl.view(-1))
            loss_cap = criterion(logits_cap.view(-1, logits_cap.size(-1)), cap_lbl.view(-1))
            total_loss += (loss_init + loss_final + loss_cap).item()
    return total_loss / len(dataloader)

# --- Inference ---

def predict_sentence(model, tokenizer: BertTokenizer, sentence: str, device):
    model.eval()
    tokens = tokenizer.tokenize(sentence)
    ids = tokenizer.convert_tokens_to_ids(tokens)
    input_ids = torch.tensor(ids, dtype=torch.long).unsqueeze(0).to(device)
    with torch.no_grad():
        logits_init, logits_final, logits_cap = model(input_ids)
    preds_init = torch.argmax(logits_init, dim=-1).squeeze(0).cpu().tolist()
    preds_final = torch.argmax(logits_final, dim=-1).squeeze(0).cpu().tolist()
    preds_cap = torch.argmax(logits_cap, dim=-1).squeeze(0).cpu().tolist()
    return list(zip(tokens, preds_init, preds_final, preds_cap))

# --- Ejemplo de uso ---
# tokenizer = BertTokenizer.from_pretrained('bert-base-multilingual-cased')
# model = BiLSTMMultiHead(
#     bert_model_name='bert-base-multilingual-cased',
#     hid_dim=256,
#     n_init=5,
#     n_final=5,
#     n_cap=4,
#     n_layers=2,
#     dropout=0.2
# ).to(device)
# optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# criterion = nn.CrossEntropyLoss(ignore_index=PAD_LABEL_INIT)
# train_loader = DataLoader(train_dataset, batch_size=32, collate_fn=collate_batch)
# val_loader = DataLoader(val_dataset, batch_size=32, collate_fn=collate_batch)
# for epoch in range(10):
#     train_loss = train_epoch(model, train_loader, optimizer, criterion, device)
#     val_loss = evaluate(model, val_loader, criterion, device)
#     print(f"Epoch {epoch} | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")


In [None]:
def restore_text(
    model,
    tokenizer: BertTokenizer,
    sentence: str,
    device: torch.device,
    init_map: dict,
    final_map: dict,
    cap_map: dict
) -> str:
    """
    Reconstructs a sentence by applying token-level predictions:
      - init_map:    {label_idx: init_punct_str, …}
      - final_map:   {label_idx: final_punct_str, …}
      - cap_map:     {label_idx: 'lower'|'title'|'mixed'|'upper', …}

    Returns the fully punctuated & capitalized sentence.
    """
    model.eval()
    # 1) tokenize & convert to IDs
    tokens = tokenizer.tokenize(sentence)
    ids = tokenizer.convert_tokens_to_ids(tokens)
    input_ids = torch.tensor(ids, dtype=torch.long).unsqueeze(0).to(device)

    # 2) model forward → get label logits
    with torch.no_grad():
        logits_init, logits_final, logits_cap = model(input_ids)

    # 3) take argmax to get class indices
    preds_init  = torch.argmax(logits_init,  dim=-1).squeeze(0).cpu().tolist()
    preds_final = torch.argmax(logits_final, dim=-1).squeeze(0).cpu().tolist()
    preds_cap   = torch.argmax(logits_cap,   dim=-1).squeeze(0).cpu().tolist()
    print(tokens)
    print(preds_init)
    print(preds_final)
    print(preds_cap)

    # 4) rebuild words, merging wordpieces
    words: list[str] = []
    for token, i_init, i_fin, i_cap in zip(tokens, preds_init, preds_final, preds_cap):
        piece = token
        if piece.startswith("##"):
            # merge onto previous word
            words[-1] += piece[2:]
            continue

        # apply capitalization
        cap_label = cap_map[i_cap]
        if cap_label == "upper":
            piece = piece.upper()
        elif cap_label == "title":
            piece = piece.capitalize()
        # lower or mixed: keep as-is (mixed will contain original casing in training)

        # prepend initial punct, append final punct
        init_sign = init_map.get(i_init, "")
        fin_sign  = final_map.get(i_fin, "")

        words.append(f"{init_sign}{piece}{fin_sign}")

    # 5) join & clean up spaces before punctuation
    sentence_restored = " ".join(words)
    for p in [",", ".", "?", "!", ":", ";"]:
        sentence_restored = sentence_restored.replace(f" {p}", p)

    return sentence_restored


In [None]:
data_path = "es_419_validas.txt"
dataset = cargar_dataset(data_path, random_seed=RANDOM_SEED,tokenizer = tokenizer_fast,n_max_oraciones=100000)
dataset

In [None]:
import torch
from torch.utils.data import DataLoader, random_split
from transformers import BertTokenizer
import torch.nn as nn
import torch.optim as optim

# 1) Prepare dataset + split
full_dataset = CapitalizacionDataset(dataset)  # your DataFrame → Dataset
train_size = int(len(full_dataset) * 0.9)
val_size   = len(full_dataset) - train_size
train_ds, val_ds = random_split(full_dataset, [train_size, val_size])

train_loader = DataLoader(
    train_ds,
    batch_size=32,
    shuffle=True,
    collate_fn=collate_batch,
)
val_loader = DataLoader(
    val_ds,
    batch_size=64,
    shuffle=False,
    collate_fn=collate_batch,
)

# 2) Instantiate model, tokenizer, device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = tokenizer_fast

model = BiLSTMMultiHead(
    bert_model_name="bert-base-multilingual-cased",
    hid_dim=256,
    n_init=len(INIT_PUNCT_MAP),
    n_final=len(FINAL_PUNCT_MAP),
    n_cap=len(CAP_MAP),
    n_layers=3,
    dropout=0.1,
).to(device)

# 3) Loss + optimizer
# We’ll ignore pad‐labels for all three heads
ignore_idx = PAD_LABEL_INIT  # make sure PAD_LABEL_INIT == PAD_LABEL_FINAL == PAD_LABEL_CAP
criterion = nn.CrossEntropyLoss(ignore_index=ignore_idx)
optimizer = optim.Adam(model.parameters(), lr=3e-4)

# 4) Training loop
n_epochs = 3
for epoch in range(1, n_epochs + 1):
    train_loss = train_epoch(model, train_loader, optimizer, criterion, device)
    val_loss   = evaluate(  model, val_loader,   criterion, device)

    print(f"[Epoch {epoch:02d}] "
          f"Train Loss = {train_loss:.4f}  |  Val Loss = {val_loss:.4f}")

# 5) Save your model
torch.save(model.state_dict(), "bilstm_multihead.pth")


In [None]:
# invert your maps:
init_map  = {v:k for k,v in INIT_PUNCT_MAP.items()}
final_map = {v:k for k,v in FINAL_PUNCT_MAP.items()}
cap_map   = {v:k for k,v in CAP_MAP.items()}

output = restore_text(
    model, tokenizer,
     "estoy podrido de esto por que esta pasando ayuda por favor me estoy muriendo",
    device,
    init_map=init_map,
    final_map=final_map,
    cap_map=cap_map
)
print(output)  # → "arquitectura"


In [None]:
print(init_map)
print(final_map)
print(cap_map)

In [None]:
init_map

# Transformer

In [17]:
from transformers import BertTokenizer, BertModel
import torch
import re
import pandas as pd
import random
from datasets import load_dataset
from torch import nn

RANDOM_SEED = 0
def clean_text(text: str) -> str:
    """
    Lowercases and removes punctuation from the text, returning a "clean" version.
    """
    # remove leading/trailing whitespace
    text = text.strip()
    # lowercase
    text = text.lower()
    # remove punctuation (keep letters, numbers and spaces)
    text = re.sub(r"[^\w\sáéíóúüñÁÉÍÓÚÜÑ]", "", text)
    return text


def cargar_dataset(
    path: str,
    tokenizer: BertTokenizerFast,
    n_max_oraciones: int | None = None,
    shuffle: bool = False,
    random_seed: int = 0,
    report_every: int = 500_000
) -> pd.DataFrame:
    """
    Reads up to `n_max_oraciones` lines; optionally shuffles them;
    tokenizes/labels in one pass (collecting into a list); logs
    progress every `report_every` lines; then returns a DataFrame
    with two columns:
      - raw_sentence: original sentence (with punctuation & casing)
      - clean_sentence: lowercase, punctuation-free version
    """
    # 1) load raw lines
    with open(path, "r", encoding="utf-8") as f:
        if n_max_oraciones:
            raw_lines = [f.readline().rstrip("\n") for _ in range(n_max_oraciones)]
        else:
            raw_lines = [line.rstrip("\n") for line in f]

    # 2) optional shuffle
    if shuffle:
        random.seed(random_seed)
        random.shuffle(raw_lines)

    # 3) build DataFrame rows
    rows = []
    for idx, raw in enumerate(raw_lines, start=1):
        clean = clean_text(raw)
        rows.append({
            'raw_sentence': raw,
            'clean_sentence': clean
        })
        if idx % report_every == 0:
            print(f"… processed {idx} sentences")

    # 4) assemble DataFrame
    df = pd.DataFrame(rows)
    print(f"Done: total sentences = {len(df)}")
    return df

df = cargar_dataset("es_419_validas.txt", tokenizer_fast, n_max_oraciones=150000)
df.head()

Done: total sentences = 150000


Unnamed: 0,raw_sentence,clean_sentence
0,Te mostraré los resultados.,te mostraré los resultados
1,Me permite hablar por los muertos.,me permite hablar por los muertos
2,Serás Margaret Penobscott.,serás margaret penobscott
3,Somos compañeras de clase de Mariko.,somos compañeras de clase de mariko
4,Pasado mañana.,pasado mañana


In [10]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer
# import argparse # Remove argparse as it's not needed in the notebook context

def set_seed(seed: int = 42):
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

class TransformerAutoencoder(nn.Module):
    def __init__(self,
                 vocab_size: int,
                 d_model: int = 256,
                 nhead: int = 4,
                 num_encoder_layers: int = 3,
                 num_decoder_layers: int = 3,
                 dim_feedforward: int = 512,
                 dropout: float = 0.1,
                 pad_token_id: int = 0,
                 max_seq_len: int = 128):
        super().__init__()
        self.pad_token_id = pad_token_id
        self.embedding = nn.Embedding(vocab_size, d_model, padding_idx=pad_token_id)
        self.positional_encoding = nn.Parameter(
            torch.zeros(max_seq_len, d_model), requires_grad=True
        )
        self.transformer = nn.Transformer(
            d_model=d_model,
            nhead=nhead,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True
        )
        self.output_layer = nn.Linear(d_model, vocab_size)

    def forward(self,
                src_ids: torch.LongTensor,
                tgt_ids: torch.LongTensor,
                src_key_padding_mask: torch.BoolTensor = None,
                tgt_key_padding_mask: torch.BoolTensor = None,
                tgt_mask: torch.Tensor = None):
        # Embedding + positional
        seq_len_src = src_ids.size(1)
        seq_len_tgt = tgt_ids.size(1)
        src_emb = self.embedding(src_ids) + self.positional_encoding[:seq_len_src]
        tgt_emb = self.embedding(tgt_ids) + self.positional_encoding[:seq_len_tgt]

        # Transformer
        out = self.transformer(
            src=src_emb,
            tgt=tgt_emb,
            tgt_mask=tgt_mask,
            src_key_padding_mask=src_key_padding_mask,
            tgt_key_padding_mask=tgt_key_padding_mask
        )
        # Final projection
        return self.output_layer(out)

class RestorationDataset(Dataset):
    def __init__(self, examples, tokenizer, max_len=64):
        self.examples = examples
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        src = self.examples[idx][0]
        tgt = self.examples[idx][1]
        # Encode source
        src_enc = self.tokenizer.encode_plus(
            src,
            add_special_tokens=True,
            max_length=self.max_len,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        # Encode target
        tgt_enc = self.tokenizer.encode_plus(
            tgt,
            add_special_tokens=True,
            max_length=self.max_len,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        src_ids = src_enc.input_ids.squeeze(0)
        src_mask = src_enc.attention_mask.squeeze(0).bool()
        full_tgt_ids = tgt_enc.input_ids.squeeze(0)
        # decoder inputs and labels
        decoder_input_ids = full_tgt_ids[:-1]
        decoder_attention_mask = tgt_enc.attention_mask.squeeze(0)[:-1].bool()
        labels = full_tgt_ids[1:].clone()
        labels[labels == self.tokenizer.pad_token_id] = -100

        return {
            'src_ids': src_ids,
            'src_mask': ~src_mask,
            'tgt_ids': decoder_input_ids,
            'tgt_mask': ~decoder_attention_mask,
            'labels': labels
        }

def generate_square_subsequent_mask(sz: int) -> torch.Tensor:
    mask = torch.triu(torch.ones(sz, sz), diagonal=1).bool()
    return mask


def train(model, dataloader, optimizer, criterion, device):
    model.train()
    total_loss = 0
    for batch in dataloader:
        optimizer.zero_grad()
        src_ids = batch['src_ids'].to(device)
        tgt_ids = batch['tgt_ids'].to(device)
        src_key_mask = batch['src_mask'].to(device)
        tgt_key_mask = batch['tgt_mask'].to(device)
        labels = batch['labels'].to(device)

        tgt_mask = generate_square_subsequent_mask(tgt_ids.size(1)).to(device)
        outputs = model(
            src_ids=src_ids,
            tgt_ids=tgt_ids,
            src_key_padding_mask=src_key_mask,
            tgt_key_padding_mask=tgt_key_mask,
            tgt_mask=tgt_mask
        )
        # outputs shape: (B, T, V)
        loss = criterion(outputs.view(-1, outputs.size(-1)), labels.view(-1))
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(dataloader)


def evaluate(model, dataloader, criterion, device):
    model.eval()
    total_loss = 0
    with torch.no_grad():
        for batch in dataloader:
            src_ids = batch['src_ids'].to(device)
            tgt_ids = batch['tgt_ids'].to(device)
            src_key_mask = batch['src_mask'].to(device)
            tgt_key_mask = batch['tgt_mask'].to(device)
            labels = batch['labels'].to(device)

            tgt_mask = generate_square_subsequent_mask(tgt_ids.size(1)).to(device)
            outputs = model(
                src_ids=src_ids,
                tgt_ids=tgt_ids,
                src_key_padding_mask=src_key_mask,
                tgt_key_padding_mask=tgt_key_mask,
                tgt_mask=tgt_mask
            )
            loss = criterion(outputs.view(-1, outputs.size(-1)), labels.view(-1))
            total_loss += loss.item()
    return total_loss / len(dataloader)


def greedy_decode(model, tokenizer, src_sentence: str, device, max_len: int = 64):
    model.eval()
    with torch.no_grad():
        src_enc = tokenizer.encode_plus(
            src_sentence,
            add_special_tokens=True,
            max_length=max_len,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        src_ids = src_enc.input_ids.to(device)
        # Ensure src_mask is 2D by unsqueezing the batch dimension if it's 1D
        src_mask = ~src_enc.attention_mask.bool().to(device)
        if src_mask.dim() == 1:
            src_mask = src_mask.unsqueeze(0)

        # Encode memory
        seq_len_src = src_ids.size(1)
        memory = model.embedding(src_ids) + model.positional_encoding[:seq_len_src]
        memory = model.transformer.encoder(memory, src_key_padding_mask=src_mask)

        ys = torch.full((1, 1), tokenizer.cls_token_id, dtype=torch.long).to(device)
        for i in range(max_len - 1):
            tgt_mask = generate_square_subsequent_mask(ys.size(1)).to(device)
            tgt_emb = model.embedding(ys) + model.positional_encoding[:ys.size(1)]
            out = model.transformer.decoder(
                tgt=tgt_emb,
                memory=memory,
                tgt_mask=tgt_mask,
                memory_key_padding_mask=src_mask # memory_key_padding_mask should also be 2D
            )
            logits = model.output_layer(out[:, -1, :])
            next_token = logits.argmax(dim=-1).unsqueeze(1)
            ys = torch.cat([ys, next_token], dim=1)
            if next_token.item() == tokenizer.sep_token_id:
                break
        return tokenizer.decode(ys.squeeze(), skip_special_tokens=True)


# Remove the if __name__ == '__main__': block and argparse calls
# Replace it with direct argument setting and function call
class Args: # Simple class to mimic argparse Namespace
    epochs: int = 5
    batch_size: int = 8
    lr: float = 5e-4
    max_len: int = 64
    save_path: str = 'best_model.pt'

args = Args()

# Call the main function with the defined arguments
main(args) # Commenting out the main function call to avoid running it automatically

Epoch 1 | Train Loss: 11.6433 | Val Loss: 10.0844
Epoch 2 | Train Loss: 10.0698 | Val Loss: 9.0064
Epoch 3 | Train Loss: 9.2176 | Val Loss: 8.2967
Epoch 4 | Train Loss: 8.8104 | Val Loss: 7.7722
Epoch 5 | Train Loss: 8.3723 | Val Loss: 7.3332
Input:    hola como estas
Restored: ##lalalalala Como Como Como Como Como Como Como Como Como Como Como Como Como Como Como Como Como Como Como Como Como Como Como Como Como Como Como Como Como Como Como Como Como Como Como Como Como Como Como Como Como Como Como Como Como Como Como Como Como Como Como Como Como Como Como Como Como Como


In [19]:
def main(epochs,batch_size,lr,max_len,save_path):

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    tokenizer = BertTokenizerFast.from_pretrained('bert-base-multilingual-cased')

    df = cargar_dataset("es_419_validas.txt", tokenizer_fast, n_max_oraciones=150000)
    dataset = RestorationDataset(df, tokenizer, max_len=args.max_len)
    train_loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True)
    val_loader   = DataLoader(dataset, batch_size=args.batch_size)

    model = TransformerAutoencoder(
        vocab_size=tokenizer.vocab_size,
        pad_token_id=tokenizer.pad_token_id,
        max_seq_len=args.max_len
    ).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    criterion = nn.CrossEntropyLoss(ignore_index=-100)

    best_val = float('inf')
    for epoch in range(1, args.epochs+1):
        train_loss = train(model, train_loader, optimizer, criterion, device)
        val_loss = evaluate(model, val_loader, criterion, device)
        print(f"Epoch {epoch}: train_loss={train_loss:.4f}, val_loss={val_loss:.4f}")
        if val_loss < best_val:
            best_val = val_loss
            torch.save(model.state_dict(), args.save_path)

    # Demo inference
    test_sent = "hola como estas"
    print("Input:   ", test_sent)
    print("Restored:", greedy_decode(model, tokenizer, test_sent, device, args.max_len))

if __name__ == '__main__':
    main(3,32,5e-4,64,"transformer_autoencoder.pt")

Done: total sentences = 150000


KeyError: 30565