In [1]:
# %pip install torch transformers sentencepiece

In [2]:
import math 
import random 
from typing import List, Tuple
import torch 
import torch.nn as nn 
import torch.nn.functional as F 
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModel

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Report whether PyTorch can see a usable CUDA device in this kernel.
print(torch.cuda.is_available())

True


In [4]:
class SyntheticDataset(Dataset):
    """Synthetic sentence-ordering dataset for quick examples / smoke training.

    Every item is a pair ``(shuffled_sents, order)``:

    * ``shuffled_sents`` -- list of ``n`` template sentences in shuffled order,
      with ``min_sents <= n <= max_sents``.
    * ``order`` -- pointer targets: ``order[t]`` is the index *inside
      shuffled_sents* of the sentence that belongs at position ``t`` of the
      original text, so ``[shuffled_sents[j] for j in order]`` restores the
      original order. This is the quantity a pointer decoder predicts at
      decoding step ``t``.

    Args:
        n_examples: number of examples to generate.
        min_sents: minimum number of sentences per example (inclusive).
        max_sents: maximum number of sentences per example (inclusive).
        vocab: unused; kept only so existing call sites keep working.
        seed: optional seed. When given, a private ``random.Random(seed)`` is
            used so the dataset is reproducible and the global RNG is left
            untouched. When ``None`` the module-level ``random`` generator is
            used.
    """
    def __init__(self, n_examples=1000, min_sents=3, max_sents=8, vocab=None, seed=None):
        super().__init__()
        # The `random` module exposes the same randint/shuffle API as a
        # Random instance, so it can serve as the default generator.
        rng = random if seed is None else random.Random(seed)

        self.examples = []
        for _ in range(n_examples):
            n = rng.randint(min_sents, max_sents)

            # original_positions[j] = original position of the sentence that
            # ends up in shuffled slot j.
            original_positions = list(range(n))
            rng.shuffle(original_positions)

            shuffled_sents = [
                f"Ini kalimat ke {i} yang berisi contoh informasi."
                for i in original_positions
            ]

            # Invert the permutation: for each original position t, record
            # which shuffled slot holds that sentence.
            pointer_order = [0] * n
            for shuffled_slot, original_pos in enumerate(original_positions):
                pointer_order[original_pos] = shuffled_slot

            self.examples.append((shuffled_sents, pointer_order))

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        sents, order = self.examples[idx]
        return sents, order

In [5]:
# Build a tiny dataset and display its first example as a quick sanity check.
data = SyntheticDataset(10)
data[0]

(['Ini kalimat ke 0 yang berisi contoh informasi.',
  'Ini kalimat ke 2 yang berisi contoh informasi.',
  'Ini kalimat ke 1 yang berisi contoh informasi.',
  'Ini kalimat ke 3 yang berisi contoh informasi.'],
 [0, 2, 1, 3])

In [6]:
def collate_fn(batch, tokenizer, device, max_length=64):
    """Collate ``(sentences, order)`` pairs into a padded batch.

    Sentences are returned as raw strings; no tokenisation or embedding
    happens here, so ``tokenizer``, ``device`` and ``max_length`` are accepted
    but not used.

    Args:
        batch: sequence of items where ``item[0]`` is a list of sentences and
            ``item[1]`` is the matching list of integer targets.

    Returns:
        tuple ``(batch_sentences, lengths, target_orders)`` where
          - batch_sentences: list[list[str]], one list per example
          - lengths: list[int], number of sentences per example
          - target_orders: LongTensor ``(len(batch), max(lengths))`` created on
            the default device, padded with -1 past each example's targets
    """
    sentence_lists = [example[0] for example in batch]
    order_lists = [example[1] for example in batch]

    lengths = [len(sentences) for sentences in sentence_lists]
    widest = max(lengths)

    # Start from an all -1 grid and overwrite the leading slots of every row.
    target_orders = torch.full((len(batch), widest), -1, dtype=torch.long)
    for row_idx in range(len(order_lists)):
        order = order_lists[row_idx]
        target_orders[row_idx, :len(order)] = torch.tensor(order, dtype=torch.long)

    return sentence_lists, lengths, target_orders

In [7]:
class SentenceEmbedder:
    """Turns sentences into fixed-size vectors using a pretrained encoder.

    The tokenizer and model are loaded with ``AutoTokenizer`` / ``AutoModel``.
    The sentence vector is the hidden state of the first token of the
    encoder's ``last_hidden_state``.

    Args:
        model_name: identifier passed to ``from_pretrained``.
        device: device the model is moved to and on which embeddings are built.
        freeze: when True, encoder parameters get ``requires_grad = False`` and
            embeddings are computed without building an autograd graph. When
            False, gradients are tracked so the encoder can be fine-tuned.
    """
    def __init__(self, model_name='indobenchmark/indobert-base-p1', device='cpu', freeze=True):
        self.device = device
        # Remembered so embed_sentences knows whether gradients are wanted.
        self.freeze = freeze
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name).to(device)
        if freeze:
            for p in self.model.parameters():
                p.requires_grad = False

    def embed_sentences(self, batch_sentences: List[List[str]], batch_size=16, max_length=64) -> torch.Tensor:
        """Embed a batch of examples.

        Args:
            batch_sentences: list of N examples, each a list of M_i sentences.
            batch_size: maximum number of sentences sent through the encoder
                in one forward pass. ``None`` or a non-positive value encodes
                all sentences in a single pass.
            max_length: truncation length (in tokens) given to the tokenizer.

        Returns:
            Tensor of shape (N, M, D) with M = max(M_i); rows past an
            example's own length are zeros.

        Raises:
            ValueError: if ``batch_sentences`` contains no sentence at all.
        """
        lengths = [len(sents) for sents in batch_sentences]
        # Flatten so the encoder sees plain lists of sentences.
        flat_sents = [sent for sents in batch_sentences for sent in sents]
        if not flat_sents:
            raise ValueError("embed_sentences() needs at least one sentence to embed")

        if batch_size is None or batch_size < 1:
            chunk_size = len(flat_sents)
        else:
            chunk_size = batch_size

        # Encode chunk by chunk to bound peak memory; each chunk is padded
        # only to its own longest sentence.
        cls_chunks = []
        with torch.set_grad_enabled(not self.freeze):
            for start in range(0, len(flat_sents), chunk_size):
                chunk = flat_sents[start:start + chunk_size]
                encodings = self.tokenizer(chunk, padding=True, truncation=True,
                                           max_length=max_length, return_tensors='pt')
                encodings = {k: v.to(self.device) for k, v in encodings.items()}
                outputs = self.model(**encodings)
                # outputs.last_hidden_state: (chunk, seq_len, hidden);
                # keep the first-token vector of every sentence.
                cls_chunks.append(outputs.last_hidden_state[:, 0, :])
        cls_embeds = torch.cat(cls_chunks, dim=0)  # (total_sents, D)

        # Scatter the flat embeddings back into a zero-padded (N, M, D) tensor.
        hidden_dim = cls_embeds.size(-1)
        n_examples = len(batch_sentences)
        max_sents = max(lengths)
        embeds = torch.zeros((n_examples, max_sents, hidden_dim), device=self.device)
        offset = 0
        for i, n_sents in enumerate(lengths):
            if n_sents > 0:
                embeds[i, :n_sents, :] = cls_embeds[offset: offset + n_sents]
            offset += n_sents
        return embeds

In [8]:
class PointerNetDecoder(nn.Module):
    """Pointer network decoder: at each step it points to one encoder position.

    Based on: Vinyals et al. (2015).

    Simplifications of this implementation:
      * the LSTMCell input is an all-zero vector at every step; the choice made
        at the previous step only influences decoding through the availability
        mask;
      * the initial hidden state is the mean of the encoder outputs, so
        ``dec_hidden_dim`` must equal ``enc_hidden_dim``.
    """
    def __init__(self, enc_hidden_dim, dec_hidden_dim):
        super().__init__()
        self.dec_rnn = nn.LSTMCell(enc_hidden_dim, dec_hidden_dim)
        self.W_ref = nn.Linear(enc_hidden_dim, dec_hidden_dim, bias=False)
        self.W_q = nn.Linear(dec_hidden_dim, dec_hidden_dim, bias=False)
        self.v = nn.Linear(dec_hidden_dim, 1, bias=False)

    def forward(self, encoder_outputs, mask=None, lengths=None, max_decode_len=None, teacher_forcing=None):
        """Decode a sequence of pointers over the encoder positions.

        Args:
            encoder_outputs: (B, M, H_e) encoder states.
            mask: optional (B, M) bool tensor, True for padded positions.
                Defaults to "nothing is padded".
            lengths: optional per-example number of valid positions; inferred
                from ``mask`` when omitted.
            max_decode_len: number of decoding steps T (defaults to M).
            teacher_forcing: optional (B, T) tensor of gold pointers, with any
                value (e.g. -1) at steps ``t >= lengths[b]``. When given, the
                gold pointer decides which position is consumed at each step.

        Returns:
            logits: (B, T, M) scores; consumed and padded positions are -inf.
                Rows at steps ``t >= lengths[b]`` may be entirely -inf and
                must be ignored by the loss.
            pointers: (B, T) long tensor. With teacher forcing this is the
                gold sequence clamped to [0, M-1]. Otherwise pointers are
                sampled in training mode and chosen greedily (argmax) in eval
                mode, with 0 as a placeholder at steps ``t >= lengths[b]``.
        """
        # encoder_outputs: (B, M, H_e)
        B, M, H_e = encoder_outputs.size()
        device = encoder_outputs.device

        if mask is None:
            mask = torch.zeros(B, M, dtype=torch.bool, device=device)
        if lengths is None:
            # infer from mask (count non-masked positions)
            lengths = (~mask).sum(dim=1)
        lengths = torch.as_tensor(lengths, dtype=torch.long, device=device)

        if max_decode_len is None:
            max_decode_len = M

        ref_proj = self.W_ref(encoder_outputs)  # (B, M, D)

        # Initial hidden state: mean over *valid* positions only, so it does
        # not depend on how much padding the batch happens to contain.
        keep = (~mask).unsqueeze(-1).to(encoder_outputs.dtype)  # (B, M, 1)
        n_valid = keep.sum(dim=1).clamp(min=1.0)                # (B, 1)
        h = (encoder_outputs * keep).sum(dim=1) / n_valid       # (B, H_e)
        c = torch.zeros_like(h)

        logits_seq = []
        pointers = []
        avail_mask = ~mask  # True means the position can still be selected

        # Constant all-zero input for the LSTMCell (size = enc_hidden_dim).
        input_t = encoder_outputs.new_zeros(B, H_e)

        for t in range(max_decode_len):
            h, c = self.dec_rnn(input_t, (h, c))  # (B, dec_hidden_dim)

            q = self.W_q(h).unsqueeze(1)  # (B, 1, D)
            scores = self.v(torch.tanh(ref_proj + q)).squeeze(-1)  # (B, M)

            # mask out already chosen or padding positions
            scores = scores.masked_fill(~avail_mask, float('-inf'))
            logits_seq.append(scores)

            # Whether this decode step exists for each example.
            valid_step = t < lengths  # (B,)

            if teacher_forcing is not None:
                # Padded targets (-1) are clamped to a legal index; they never
                # touch avail_mask because valid_step is False for them.
                idx = teacher_forcing[:, t].to(device).clamp(0, M - 1)
            else:
                # A row whose candidates are exhausted is all -inf; softmax of
                # that is NaN and multinomial would raise. Give such rows
                # neutral scores -- their pointer is a placeholder anyway.
                has_candidate = avail_mask.any(dim=1, keepdim=True)  # (B, 1)
                safe_scores = torch.where(has_candidate, scores, torch.zeros_like(scores))
                if self.training:
                    probs = F.log_softmax(safe_scores, dim=-1).exp()
                    idx = probs.multinomial(1).squeeze(-1)  # (B,)
                else:
                    # deterministic decoding at inference time
                    idx = safe_scores.argmax(dim=-1)  # (B,)
                idx = torch.where(valid_step, idx, torch.zeros_like(idx))

            # Consume the chosen position, but only for examples that are
            # still decoding at this step.
            one_hot = F.one_hot(idx, num_classes=M).bool()  # (B, M)
            one_hot = one_hot & valid_step.unsqueeze(1)
            avail_mask = avail_mask & (~one_hot)
            pointers.append(idx)

        logits = torch.stack(logits_seq, dim=1)  # (B, T, M)
        pointers = torch.stack(pointers, dim=1)  # (B, T)
        return logits, pointers

In [9]:
class SONModel(nn.Module):
    """Sentence-ordering model: LSTM encoder over sentence embeddings followed
    by a pointer-network decoder.

    Args:
        enc_input_dim: size D of the incoming sentence embeddings.
        enc_hidden_dim: hidden size of the encoder LSTM (per direction) and of
            the encoder states handed to the decoder.
        dec_hidden_dim: hidden size of the pointer decoder.
        bidirectional: whether the encoder LSTM runs in both directions.
        dropout: dropout probability applied to the encoder states.
    """
    def __init__(self, enc_input_dim, enc_hidden_dim=512, dec_hidden_dim=512, bidirectional=True, dropout=0.1):
        super().__init__()
        # number of LSTM directions
        if bidirectional:
            self.bi = 2
        else:
            self.bi = 1
        self.enc_lstm = nn.LSTM(
            enc_input_dim,
            enc_hidden_dim,
            batch_first=True,
            bidirectional=bidirectional,
        )
        # squeezes the concatenated directions back to enc_hidden_dim
        self.reduce_dim = nn.Linear(self.bi * enc_hidden_dim, enc_hidden_dim)
        self.decoder = PointerNetDecoder(enc_hidden_dim, dec_hidden_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, sent_embs, lengths, teacher_forcing=None):
        """Encode the sentence embeddings and decode pointers.

        Args:
            sent_embs: (B, M, D) padded sentence embeddings.
            lengths: number of real sentences per example (list or tensor).
            teacher_forcing: optional gold pointer sequence forwarded to the
                decoder unchanged.

        Returns:
            The ``(logits, pointers)`` pair produced by the decoder.
        """
        seq_lens = torch.as_tensor(lengths, dtype=torch.long, device=sent_embs.device)

        # Pack so the LSTM skips padded steps; lengths must live on the CPU.
        packed_in = nn.utils.rnn.pack_padded_sequence(
            sent_embs, seq_lens.cpu(), batch_first=True, enforce_sorted=False
        )
        packed_out, _ = self.enc_lstm(packed_in)
        padded_out, _ = nn.utils.rnn.pad_packed_sequence(packed_out, batch_first=True)

        # (B, M, enc_hidden_dim * num_directions) -> (B, M, enc_hidden_dim)
        enc_out = self.dropout(torch.tanh(self.reduce_dim(padded_out)))

        # True wherever a position lies beyond the example's own length.
        n_steps = enc_out.size(1)
        positions = torch.arange(n_steps, device=enc_out.device)
        pad_mask = positions.unsqueeze(0) >= seq_lens.unsqueeze(1)

        # The decoder receives the lengths so it knows when each example stops.
        logits, pointers = self.decoder(
            enc_out,
            mask=pad_mask,
            lengths=seq_lens,
            max_decode_len=n_steps,
            teacher_forcing=teacher_forcing,
        )
        return logits, pointers

In [10]:
def pointer_loss(logits, targets, lengths):
    """Mean negative log-likelihood of the gold pointers.

    Args:
        logits: (B, T, M) pointer scores.
        targets: (B, T) gold indices; entries at steps >= lengths[i] are
            never read.
        lengths: per-example number of decoding steps to score.

    Returns:
        Sum of the per-step negative log-probabilities over all scored steps,
        divided by the number of scored steps (or by 1 when there are none).
    """
    n_examples, _, _ = logits.size()
    nll_sum = 0.0
    n_steps = 0
    for b in range(n_examples):
        steps = lengths[b]
        if steps == 0:
            continue
        # Only the first `steps` rows are normalised; later rows are skipped.
        log_probs = F.log_softmax(logits[b, :steps], dim=-1)
        gold = targets[b, :steps]
        # log-probability assigned to the gold index at every scored step
        picked = log_probs[range(steps), gold]
        nll_sum = nll_sum + (-picked.sum())
        n_steps = n_steps + steps
    return nll_sum / max(1, n_steps)

In [15]:
def train_example(device='cpu', n_examples=500, epochs=30, batch_size=8, lr=1e-4,
                  model_name='indobenchmark/indobert-base-p1'):
    """Train a SONModel on synthetic data and print one example prediction.

    Args:
        device: device string / object used for the embedder and the model.
        n_examples: size of the synthetic dataset.
        epochs: number of passes over the dataset.
        batch_size: number of examples per training batch.
        lr: AdamW learning rate.
        model_name: pretrained encoder used for sentence embeddings
            (swap for another checkpoint if desired).

    Returns:
        The trained ``SONModel`` (left in eval mode), so the caller can save
        or evaluate it.
    """
    device = torch.device(device)

    # data + embedder; the embedder's own tokenizer is handed to collate_fn
    # instead of loading a second copy of it.
    dataset = SyntheticDataset(n_examples=n_examples)
    embedder = SentenceEmbedder(model_name=model_name, device=device, freeze=True)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True,
                            collate_fn=lambda b: collate_fn(b, embedder.tokenizer, device))

    # get the embedding dim by running a single example through the embedder
    emb = embedder.embed_sentences([dataset[0][0]])
    D = emb.size(-1)

    # model
    model = SONModel(enc_input_dim=D).to(device)
    optim = torch.optim.AdamW(model.parameters(), lr=lr)

    model.train()
    for epoch in range(epochs):
        total_loss = 0.0
        for batch_sents, lengths, padded_orders in dataloader:
            sent_embs = embedder.embed_sentences(batch_sents).to(device)
            padded_orders = padded_orders.to(device)

            # Teacher forcing: the gold target sequence tells the decoder
            # which position is consumed at every step, and is also the
            # target of the loss.
            # Call the module itself (not .forward) so hooks are honoured.
            logits, _ = model(sent_embs, lengths, teacher_forcing=padded_orders)
            loss = pointer_loss(logits, padded_orders, lengths)

            optim.zero_grad()
            loss.backward()
            optim.step()

            total_loss += loss.item()
        print(f"Epoch {epoch+1}, Loss: {total_loss/max(1, len(dataloader)):.4f}")

    # testing on one example (the second one when the dataset has it)
    model.eval()
    ex_sents, _ = dataset[min(1, len(dataset) - 1)]
    print("\nExample shuffled input:\n", ex_sents)
    with torch.no_grad():
        emb = embedder.embed_sentences([ex_sents]).to(device)
        logits, preds = model(emb, [len(ex_sents)], teacher_forcing=None)
        print("Predicted pointer order:", preds[0].tolist())

    return model

In [16]:
# Shell out to nvidia-smi to show the GPU, driver/CUDA versions and memory in use.
!nvidia-smi

Sat Nov  8 14:02:25 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 566.07                 Driver Version: 566.07         CUDA Version: 12.7     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                  Driver-Model | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA GeForce RTX 3050 ...  WDDM  |   00000000:01:00.0 Off |                  N/A |
| N/A   55C    P8              9W /   40W |     841MiB /   4096MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

In [17]:
if __name__ == '__main__':
    # Train on the GPU when CUDA is available, otherwise fall back to the CPU.
    if torch.cuda.is_available():
        device = 'cuda'
    else:
        device = 'cpu'
    train_example(device=device)

Epoch 1, Loss: 1.0469
Epoch 2, Loss: 1.0245
Epoch 3, Loss: 1.0089
Epoch 4, Loss: 1.0002
Epoch 5, Loss: 1.0009
Epoch 6, Loss: 0.9704
Epoch 7, Loss: 0.9257
Epoch 8, Loss: 0.8894
Epoch 9, Loss: 0.8383
Epoch 10, Loss: 0.8095
Epoch 11, Loss: 0.8041
Epoch 12, Loss: 0.7668
Epoch 13, Loss: 0.7209
Epoch 14, Loss: 0.6950
Epoch 15, Loss: 0.6612
Epoch 16, Loss: 0.6228
Epoch 17, Loss: 0.5832
Epoch 18, Loss: 0.5647
Epoch 19, Loss: 0.5468
Epoch 20, Loss: 0.5105
Epoch 21, Loss: 0.4937
Epoch 22, Loss: 0.4793
Epoch 23, Loss: 0.4478
Epoch 24, Loss: 0.4387
Epoch 25, Loss: 0.4196
Epoch 26, Loss: 0.4090
Epoch 27, Loss: 0.4168
Epoch 28, Loss: 0.3852
Epoch 29, Loss: 0.3759
Epoch 30, Loss: 0.3638

Example shuffled input:
 ['Ini kalimat ke 3 yang berisi contoh informasi.', 'Ini kalimat ke 4 yang berisi contoh informasi.', 'Ini kalimat ke 2 yang berisi contoh informasi.', 'Ini kalimat ke 1 yang berisi contoh informasi.', 'Ini kalimat ke 5 yang berisi contoh informasi.', 'Ini kalimat ke 0 yang berisi contoh infor

In [14]:
# # after training loop
# torch.save(model.state_dict(), "son_model.pt")
# print("✅ Model saved as son_model.pt")
