# Ejemplo Toy de Pointer Network para Sumar Subconjuntos

Este cuaderno explica paso a paso el código del proyecto toy_ptrnet, que implementa un Pointer Network para resolver la tarea de encontrar un subconjunto de índices cuya suma coincide con un valor objetivo. 

**Contenido**:
  - Definición del dataset y DataLoader
  - Arquitectura del modelo y explicación detallada
  - Bucle de entrenamiento con funciones de pérdida combinada
  - Evaluación con métricas y tiempos de inferencia
  - Comentarios sobre mejoras 

## 1. Instalación de dependencias

Es necesario instalar la librería `torch` (PyTorch):

```bash
pip install torch
```


In [1]:
%pip install torch

Note: you may need to restart the kernel to use updated packages.



[notice] A new release of pip is available: 24.3.1 -> 25.1.1
[notice] To update, run: python.exe -m pip install --upgrade pip


## 2. Definición del Dataset 

Generamos secuencias de enteros $x$ y, para cada una, un subconjunto de índices cuyos elementos correspondientes suman $S$.

In [None]:
from torch.utils.data import Dataset, DataLoader
import random, torch

class SubsetSumDataset(Dataset):
    """Synthetic subset-sum examples generated once at construction time.

    Every example is a tuple ``(x, idxs, S)`` where ``x`` is a list of
    ``seq_len`` random integers in ``[1, max_val]``, ``idxs`` is a list of
    distinct positions of ``x`` (in random order) and ``S`` is the sum of the
    values of ``x`` at those positions.

    Args:
        num_examples: number of examples to generate.
        seq_len: length of every input sequence.
        max_val: largest value (inclusive) an element of ``x`` can take.
        max_subset: largest subset size (inclusive). Values larger than
            ``seq_len`` are clamped to ``seq_len``.
    """

    def __init__(self, num_examples, seq_len, max_val, max_subset):
        # A subset cannot contain more distinct positions than the sequence
        # has; without this clamp random.sample raises ValueError whenever
        # the drawn size exceeds seq_len.
        largest_k = min(max_subset, seq_len)
        self.data = []
        for _ in range(num_examples):
            x = [random.randint(1, max_val) for _ in range(seq_len)]
            k = random.randint(1, largest_k)
            idxs = random.sample(range(seq_len), k)
            S = sum(x[i] for i in idxs)
            self.data.append((x, idxs, S))

    def __len__(self):
        """Return the number of generated examples."""
        return len(self.data)

    def __getitem__(self, i):
        """Return the ``(x, idxs, S)`` tuple stored at position ``i``."""
        return self.data[i]

# Collate function: pads the index lists so they can be batched
def collate_fn(batch):
    """Turn a list of ``(x, idxs, S)`` examples into three long tensors.

    Index sequences shorter than the longest one in the batch are padded by
    repeating their own last index, so every padded entry is still a valid
    position of the input sequence.

    Args:
        batch: sequence of ``(x, idxs, S)`` tuples; all ``x`` must have the
            same length and every ``idxs`` must be non-empty.

    Returns:
        Tuple ``(x_tensor, idxs_tensor, S_tensor)`` with shapes
        ``(B, seq_len)``, ``(B, max_k)`` and ``(B,)``, all ``torch.long``.
    """
    xs, idxs_list, Ss = zip(*batch)
    max_k = max(len(idxs) for idxs in idxs_list)
    padded_rows = []
    for idxs in idxs_list:
        row = list(idxs)  # accept tuples as well as lists
        padded_rows.append(row + [row[-1]] * (max_k - len(row)))
    x_tensor = torch.tensor(xs, dtype=torch.long)
    idxs_tensor = torch.tensor(padded_rows, dtype=torch.long)
    return x_tensor, idxs_tensor, torch.tensor(Ss, dtype=torch.long)

# DataLoader factory
def get_dataloader(batch_size, num_examples, seq_len, max_val, max_subset, shuffle=True):
    """Generate a fresh SubsetSumDataset and wrap it in a DataLoader.

    Args:
        batch_size: number of examples per batch.
        num_examples: number of examples the dataset generates.
        seq_len: length of every input sequence.
        max_val: largest value (inclusive) of a sequence element.
        max_subset: largest subset size (inclusive).
        shuffle: whether the loader reshuffles the examples.

    Returns:
        A DataLoader whose batches are produced by ``collate_fn``.
    """
    dataset = SubsetSumDataset(
        num_examples=num_examples,
        seq_len=seq_len,
        max_val=max_val,
        max_subset=max_subset,
    )
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        collate_fn=collate_fn,
    )
    return loader

- **SubsetSumDataset**: construye las muestras.
- **collate_fn**: agrupa un batch, hace padding de los índices (repitiendo el último índice de cada muestra hasta igualar la longitud máxima del batch) y devuelve tensores.

## 3. Definición del Modelo

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class PointerNetwork(nn.Module):
    """
    Toy Pointer Network for the subset-sum task.

    A bidirectional LSTM encodes the input sequence; an LSTMCell decoder,
    conditioned on the target sum S through its first input, emits at every
    step an attention distribution over the input positions (the "pointer").

    Args:
        seq_len: length of the input sequences (accepted for interface
            compatibility; not used by the layers).
        vocab_size: number of rows of the embedding table; input values must
            lie in ``[0, vocab_size)``.
        embed_dim: size of the value embeddings and of the decoder input.
        hidden_dim: hidden size of the encoder (per direction) and decoder.
        dropout: dropout probability applied to embeddings, encoder outputs
            and decoder inputs.
    """

    def __init__(
        self,
        seq_len: int,
        vocab_size: int,
        embed_dim: int,
        hidden_dim: int,
        dropout: float = 0.1,
    ):
        super().__init__()

        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.dropout = nn.Dropout(dropout)
        # Bidirectional encoder: outputs have 2 * hidden_dim features
        self.encoder = nn.LSTM(
            embed_dim, hidden_dim, batch_first=True, bidirectional=True
        )
        # Maps the concatenated final encoder states to the decoder state size
        self.init_linear = nn.Linear(hidden_dim * 2, hidden_dim)

        self.decoder_cell = nn.LSTMCell(embed_dim, hidden_dim)

        # Additive attention used as the pointer
        self.W1 = nn.Linear(hidden_dim * 2, hidden_dim, bias=False)
        self.W2 = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.v = nn.Linear(hidden_dim, 1, bias=False)

        # Projects the scalar target sum into the decoder input space
        self.sum_proj = nn.Linear(1, embed_dim)

    def forward(self, x: torch.LongTensor, S: torch.LongTensor, max_output_len: int):
        """Run the encoder and ``max_output_len`` greedy decoding steps.

        Args:
            x: ``(B, seq_len)`` tensor of input values.
            S: ``(B,)`` tensor with the target sum of every example.
            max_output_len: number of pointer steps to decode.

        Returns:
            ``(B, max_output_len, seq_len)`` tensor; every ``[b, t]`` row is a
            softmax distribution over the input positions.
        """
        B, _ = x.size()
        # Encode
        emb = self.dropout(self.embedding(x))
        enc_out, (h_n, c_n) = self.encoder(emb)
        enc_out = self.dropout(enc_out)
        # Initial decoder state from the last two entries of h_n / c_n
        # (the two directions of the single-layer bidirectional encoder)
        h = self.init_linear(torch.cat([h_n[-2], h_n[-1]], dim=1))
        c = self.init_linear(torch.cat([c_n[-2], c_n[-1]], dim=1))
        # The first decoder input is the projected target sum
        inp = self.dropout(self.sum_proj(S.unsqueeze(-1).float()))
        # The encoder-side attention term does not depend on the decoding
        # step, so it is computed once instead of once per step.
        enc_proj = self.W1(enc_out)
        # Decode steps
        ptr_dists = []
        for _ in range(max_output_len):
            h, c = self.decoder_cell(inp, (h, c))
            # Attention scores over input positions act as the pointer
            dec_proj = self.W2(h).unsqueeze(1)
            u = self.v(torch.tanh(enc_proj + dec_proj)).squeeze(-1)
            a = F.softmax(u, dim=1)
            ptr_dists.append(a)
            # Feed the embedding of the greedily selected element back in
            idx = a.argmax(dim=1)
            inp = emb.gather(1, idx.view(B, 1, 1).expand(-1, -1, emb.size(2)))
            inp = inp.squeeze(1)
            inp = self.dropout(inp)
        return torch.stack(ptr_dists, dim=1)


- La atención calcula puntuaciones $u_{t,j}$ y aplica $\mathrm{softmax}$ sobre las posiciones de entrada.
- El decoder apunta a índices en $x$ en cada paso.



## 4. Entrenamiento

In [2]:
import time
import torch
import torch.nn.functional as F
from torch.optim import Adam
from data import get_dataloader
from model import PointerNetwork

# Hyperparameters: optimisation
epochs = 20
batch_size = 64
lr = 1e-3
alpha = 0.1  # weight of the L1 sum-error term in the combined loss

# Hyperparameters: data
seq_len = 20
max_val = 20
max_subset = 3
train_size = 10000
val_size = 2000

# Hyperparameters: model
vocab_size = max_val + 1  # values run from 1 to max_val
embed_dim = 32
hidden_dim = 64

# DataLoaders (validation keeps a fixed order)
train_loader = get_dataloader(
    batch_size, train_size, seq_len, max_val, max_subset
)
val_loader = get_dataloader(
    batch_size, val_size, seq_len, max_val, max_subset, shuffle=False
)

# Model + optimizer
use_cuda = torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')
model = PointerNetwork(seq_len, vocab_size, embed_dim, hidden_dim)
model = model.to(device)
opt = Adam(model.parameters(), lr=lr)

# One pass over the training set followed by exact-match validation, per epoch
for ep in range(1, epochs + 1):
    t0 = time.time()
    model.train()
    train_loss = 0
    for xs, idxs, Ss in train_loader:
        xs, idxs, Ss = xs.to(device), idxs.to(device), Ss.to(device)
        ptr = model(xs, Ss, max_output_len=idxs.size(1))
        # Cross-entropy between every pointer distribution and its target
        # index, averaged over decoding steps (1e-8 guards against log(0))
        ce = sum(
            F.nll_loss(torch.log(ptr[:, t, :] + 1e-8), idxs[:, t])
            for t in range(ptr.size(1))
        ) / ptr.size(1)
        # Sum error: expected value selected by the pointers vs. target sum.
        # NOTE(review): if short subsets are padded by repeating their last
        # index (as the collate_fn in this notebook does), the padded steps
        # are included in `expected` while Ss counts that element once --
        # confirm this is intended.
        expected = (ptr * xs.unsqueeze(1).float()).sum((1, 2))
        ls = F.l1_loss(expected, Ss.float())
        loss = ce + alpha * ls
        opt.zero_grad()
        loss.backward()
        opt.step()
        train_loss += loss.item()
    # Exact-match validation: predicted index set must equal the target set
    model.eval()
    correct = total = 0
    with torch.no_grad():
        for xs, idxs, Ss in val_loader:
            # Ss must live on the model's device too; leaving it on the CPU
            # makes the forward pass fail when training on CUDA.
            xs, idxs, Ss = xs.to(device), idxs.to(device), Ss.to(device)
            ptr = model(xs, Ss, idxs.size(1))
            pred = ptr.argmax(-1)
            for pred_row, true_row in zip(pred.tolist(), idxs.tolist()):
                if set(pred_row) == set(true_row):
                    correct += 1
                total += 1
    t1 = time.time()
    print(f"Ep {ep:02d} Loss {train_loss/len(train_loader):.3f} "
          f"ValAcc {correct/total:.2%} Time {t1-t0:.1f}s")

# Guardar modelo
# torch.save(model.state_dict(), 'best_ptrnet.pth')

KeyboardInterrupt: 

## 5. Evaluación

In [None]:
import time
import torch
from data import get_dataloader
from model import PointerNetwork

# Same parameters as train.py: data
seq_len = 20
max_val = 20
max_subset = 3
batch_size = 64

# Same parameters as train.py: model
vocab_size = max_val + 1
embed_dim = 32
hidden_dim = 64

# Checkpoint with the trained weights
model_path = 'best_ptrnet.pth'

use_cuda = torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')

# Rebuild the architecture, then restore the weights onto the chosen device
model = PointerNetwork(seq_len, vocab_size, embed_dim, hidden_dim)
model = model.to(device)
state_dict = torch.load(model_path, map_location=device)
model.load_state_dict(state_dict)
model.eval()

# Exact-match accuracy and per-example inference time on a fresh test set
test_loader = get_dataloader(batch_size, 2000, seq_len, max_val, max_subset, shuffle=False)
correct = total = 0
times = []  # per-example forward time in seconds, one entry per batch
on_cuda = device.type == 'cuda'
with torch.no_grad():
    for xs, idxs, Ss in test_loader:
        xs, idxs, Ss = xs.to(device), idxs.to(device), Ss.to(device)
        # CUDA kernels run asynchronously: synchronise before and after the
        # forward pass so the timer covers the actual computation.
        if on_cuda:
            torch.cuda.synchronize(device)
        t0 = time.perf_counter()
        ptr = model(xs, Ss, idxs.size(1))
        if on_cuda:
            torch.cuda.synchronize(device)
        times.append((time.perf_counter() - t0) / xs.size(0))
        pred = ptr.argmax(-1)
        for pred_row, true_row in zip(pred.tolist(), idxs.tolist()):
            if set(pred_row) == set(true_row):
                correct += 1
            total += 1
print(f"Test Acc: {correct/total:.2%}")
print(f"Avg inf time: {sum(times)/len(times)*1000:.2f} ms")