# RNN-based Seq2Seq Model for Sentence Disambiguation
## Introduction and Problem Overview
Sentence disambiguation can be framed as a **paraphrase generation** task: we take an ambiguous sentence and generate a rephrased version that resolves ambiguities (lexical, structural, referential) while preserving the original meaning.

We will implement a **recurrent neural network (RNN)** based encoder-decoder (seq2seq) model in PyTorch, largely from scratch (no high-level seq2seq libraries). Our design emphasizes:
- **Minimal external dependencies:** We'll use only PyTorch and Python standard libraries, writing our own tokenizer, data pipeline, and network modules.
- **Flexibility in rephrasing:** The model is not constrained to copy input tokens exactly; it can learn to produce different words or reorder phrases to resolve ambiguity.
- **Expressivity for disambiguation:** We use an architecture (with choices like LSTM units and an attention mechanism) capable of capturing context and meaning needed to handle lexical choice, structural reordering, and pronoun resolution.
- **Scientific rigor:** Each design choice (embedding size, hidden layer type/size, attention, etc.) is justified with reference to established research or best practices.
- **Device adaptability:** The implementation will automatically use GPU if available, falling back to CPU gracefully.
- **Consistency with provided preprocessing:** We will follow the same tokenization and vocabulary construction approach as in the provided `preprocessing.ipynb/vocab_lookup`, ensuring our data pipeline (e.g. handling of special tokens, casing, and underscores) matches the intended setup.


#### for detailed information on the architecture refer to `rnn.ipynb`


In [6]:
import re
import torch
import pandas as pd
import ast
from torch.utils.data import Dataset, DataLoader, random_split
from torch.nn.utils.rnn import pad_sequence
import torch.nn as nn
import torch.nn.functional as F
import random
import numpy as np
SEED = 42
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def tokenize(text):
    tokens = re.findall(r"\w+|[^\w\s]", text, flags=re.UNICODE)
    return [t.lower() for t in tokens]


# --- 1) Load raw vocab.csv (columns: 'primary', 'secondary') and flatten all tokens ---
vocab_df = pd.read_csv('data/vocab.csv')
raw_tokens = set()

for col in ['primary', 'secondary']:
    for cell in vocab_df[col].dropna():
        # each cell is a string like "['and','but',...']"
        try:
            lst = ast.literal_eval(cell)
            if isinstance(lst, list):
                raw_tokens.update(lst)
        except Exception:
            pass

# --- 1a) Define & prepend special tokens ---
specials = ['<PAD>', '<SOS>', '<EOS>', '<UNK>']
# Ensure no overlap:
for s in specials:
    raw_tokens.discard(s)

# Final ordered vocab: specials first, then sorted rest
all_tokens = specials + sorted(raw_tokens)

# --- 1b) Rebuild mappings from scratch ---
word2index = {tok: idx for idx, tok in enumerate(all_tokens)}
index2word = {idx: tok for tok, idx in word2index.items()}
vocab_size = len(all_tokens)

print(f"Prepended specials: {specials}")
print(f"Total unique tokens from CSV: {len(raw_tokens)}")
print(f"New vocab size = {vocab_size}")




# --- 6) Define Dataset class ---
class FinalDataset(Dataset):
    def __init__(self, csv_path, w2i, tokenizer):
        df = pd.read_csv(csv_path)
        self.src_prefixes = df['source_prefix'].astype(str).tolist()
        self.prev_tgts    = df['prev_target'].astype(str).tolist()
        self.next_tgts    = df['target_word'].astype(str).tolist()
        self.w2i = w2i
        self.tokenize = tokenizer

    def __len__(self):
        return len(self.src_prefixes)

    def __getitem__(self, idx):
        src = self.tokenize(self.src_prefixes[idx])
        prev = self.prev_tgts[idx]
        nxt  = self.next_tgts[idx]
        src_ids  = [ self.w2i.get(t, self.w2i['<UNK>']) for t in src ]
        prev_id  = self.w2i.get(prev, self.w2i['<UNK>'])
        next_id  = self.w2i.get(nxt,  self.w2i['<UNK>'])
        return torch.tensor(src_ids, dtype=torch.long), prev_id, next_id
# --- 7) Define collate_fn for padding sequences ---
def collate_fn(batch):
    src_seqs, prev_ids, next_ids = zip(*batch)
    src_lens = [len(s) for s in src_seqs]
    src_pad  = pad_sequence(src_seqs, batch_first=True, padding_value=word2index['<PAD>'])
    return (
        src_pad,
        torch.tensor(src_lens, dtype=torch.long),
        torch.tensor(prev_ids, dtype=torch.long),
        torch.tensor(next_ids, dtype=torch.long)
    )

batch_size = 128
K = 3  

# 8) Load & split
dataset2 = FinalDataset('data/final_dataset_2.csv', word2index, tokenize)
val_size = int(0.1 * len(dataset2))
train_size = len(dataset2) - val_size
train_ds, val_ds = random_split(dataset2, [train_size, val_size],
                                 generator=torch.Generator().manual_seed(SEED))

train_loader2 = DataLoader(train_ds, batch_size=batch_size, shuffle=True,  collate_fn=collate_fn)
val_loader2   = DataLoader(val_ds,   batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

print(f"Train examples: {len(train_ds)}, Val examples: {len(val_ds)} on device {device}")

# --- 8) Build DataLoader ---
dataset = FinalDataset('data/final_dataset_2.csv', word2index, tokenize)
train_loader = DataLoader(dataset, batch_size=128, shuffle=True, collate_fn=collate_fn)

print(f"Vocab size = {vocab_size}, #examples = {len(dataset)}")

# --- 9) Model definition ---
# --- 1) Device setup ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# --- 1) Device setup ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# --- 2) Model definitions (reuse classes you already have) ---
class Encoder(nn.Module):
    def __init__(self, vocab_size, embed_size, hidden_size):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_size, padding_idx=word2index['<PAD>'])
        self.lstm = nn.LSTM(embed_size, hidden_size, batch_first=True)
    def forward(self, x):
        emb = self.embedding(x)                    # (B, L_in, E)
        outputs, hidden = self.lstm(emb)           # outputs=(B,L_in,H), hidden=(h_n,c_n)
        return outputs, hidden

class Decoder(nn.Module):
    def __init__(self, vocab_size, embed_size, hidden_size):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_size, padding_idx=word2index['<PAD>'])
        self.lstm = nn.LSTM(embed_size, hidden_size, batch_first=True)
        self.out  = nn.Linear(hidden_size*2, vocab_size)
    def forward_step(self, prev_tok, hidden, enc_outputs):
        emb = self.embedding(prev_tok).unsqueeze(1)     # (B,1,E)
        out, hidden = self.lstm(emb, hidden)            # out=(B,1,H)
        # dot-product attention
        scores = torch.bmm(out, enc_outputs.transpose(1,2))  # (B,1,L_in)
        attn  = torch.softmax(scores, dim=2)                # (B,1,L_in)
        ctx   = torch.bmm(attn, enc_outputs).squeeze(1)     # (B,H)
        out_t = out.squeeze(1)                              # (B,H)
        cat   = torch.cat([out_t, ctx], dim=1)              # (B,2H)
        logits= self.out(cat)                               # (B,V)
        return logits, hidden, attn






Prepended specials: ['<PAD>', '<SOS>', '<EOS>', '<UNK>']
Total unique tokens from CSV: 483
New vocab size = 487
Train examples: 1103032, Val examples: 122559 on device cpu
Vocab size = 487, #examples = 1225591
Using device: cpu
Using device: cpu


We train the same encoder–decoder with attention on these prefix examples.  For each example \(i\) with source prefix \(x_{1:r_i}\) and previous target \(y_{t-1}^{(i)}\):

1. **Encode prefix**  
   $$
     (h_{1:r_i}, c_{1:r_i})
     = \mathrm{Encoder}(x_{1:r_i}).
   $$

2. **Single-step decode**  
   $$
     \tilde{h}_t
     = \mathrm{DecoderStep}\bigl(y_{t-1},\,h_{r_i},\,c_{r_i},\,h_{1:r_i}\bigr).
   $$

3. **Dot-product attention**  
   $$
     \alpha_j
     = \frac{\exp(\tilde{h}_t^\top h_j)}
            {\sum_{k=1}^{r_i}\exp(\tilde{h}_t^\top h_k)},\quad
     c_t = \sum_{j=1}^{r_i}\alpha_j\,h_j.
   $$

4. **Prediction & loss**  
   $$
     \hat y_t
     = \arg\max\mathrm{Softmax}\bigl(W[\tilde{h}_t; c_t]+b\bigr),\quad
     \mathcal{L}
     = -\sum_i \log p_\theta\bigl(y_t^{(i)}\mid x_{1:r_i}^{(i)},\,y_{<t}^{(i)}\bigr).
   $$

We backpropagate this cross-entropy loss and update all model parameters.

In [7]:
import torch.nn as nn
from tqdm.notebook import tqdm

# 1) Hyperparameters    
embed_size = 400
hidden_size = 128
num_epochs = 10
lr = 1e-3

# 2) Model instantiation (reuse your Encoder/Decoder classes)
encoder = Encoder(vocab_size, embed_size, hidden_size).to(device)
decoder = Decoder(vocab_size, embed_size, hidden_size).to(device)

# 3) Optimizer & loss
params    = list(encoder.parameters()) + list(decoder.parameters())
optimizer = torch.optim.Adam(params, lr=lr)
criterion = nn.CrossEntropyLoss(ignore_index=word2index['<PAD>'])

# 4) Training & Validation
for epoch in range(1, num_epochs+1):
    encoder.train(); decoder.train()
    total_train = 0.0
    for src_pad, src_lens, prev_ids, tgt_ids in tqdm(train_loader2, desc=f"Epoch {epoch} Training"):
        src_pad, src_lens = src_pad.to(device), src_lens.to(device)
        prev_ids, tgt_ids = prev_ids.to(device), tgt_ids.to(device)
        optimizer.zero_grad()
        # encode prefix
        enc_out, enc_hidden = encoder(src_pad)
        # one-step decode
        logits, dec_hidden, _ = decoder.forward_step(prev_ids, enc_hidden, enc_out)
        loss = criterion(logits, tgt_ids)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(params, max_norm=1.0)
        optimizer.step()
        total_train += loss.item()
    avg_train = total_train / len(train_loader2)

    encoder.eval(); decoder.eval()
    total_val = 0.0
    with torch.no_grad():
        for src_pad, src_lens, prev_ids, tgt_ids in val_loader2:
            src_pad, src_lens = src_pad.to(device), src_lens.to(device)
            prev_ids, tgt_ids = prev_ids.to(device), tgt_ids.to(device)
            enc_out, enc_hidden = encoder(src_pad)
            logits, dec_hidden, _ = decoder.forward_step(prev_ids, enc_hidden, enc_out)
            total_val += criterion(logits, tgt_ids).item()
    avg_val = total_val / len(val_loader2)

    print(f"Epoch {epoch} → Train Loss: {avg_train:.4f}, Val Loss: {avg_val:.4f}")
# --- 10) Save each model seperately---
torch.save(encoder.state_dict(), 'models/encoder_2.pth')
torch.save(decoder.state_dict(), 'models/decoder_2.pth')



Epoch 1 Training:   0%|          | 0/8618 [00:00<?, ?it/s]

Epoch 1 → Train Loss: 1.1899, Val Loss: 0.6422


Epoch 2 Training:   0%|          | 0/8618 [00:00<?, ?it/s]

Epoch 2 → Train Loss: 0.5886, Val Loss: 0.5733


Epoch 3 Training:   0%|          | 0/8618 [00:00<?, ?it/s]

Epoch 3 → Train Loss: 0.5332, Val Loss: 0.5414


Epoch 4 Training:   0%|          | 0/8618 [00:00<?, ?it/s]

Epoch 4 → Train Loss: 0.4970, Val Loss: 0.5251


Epoch 5 Training:   0%|          | 0/8618 [00:00<?, ?it/s]

Epoch 5 → Train Loss: 0.4711, Val Loss: 0.5085


Epoch 6 Training:   0%|          | 0/8618 [00:00<?, ?it/s]

Epoch 6 → Train Loss: 0.4497, Val Loss: 0.4960


Epoch 7 Training:   0%|          | 0/8618 [00:00<?, ?it/s]

Epoch 7 → Train Loss: 0.4326, Val Loss: 0.4871


Epoch 8 Training:   0%|          | 0/8618 [00:00<?, ?it/s]

Epoch 8 → Train Loss: 0.4180, Val Loss: 0.4771


Epoch 9 Training:   0%|          | 0/8618 [00:00<?, ?it/s]

Epoch 9 → Train Loss: 0.4052, Val Loss: 0.4733


Epoch 10 Training:   0%|          | 0/8618 [00:00<?, ?it/s]

Epoch 10 → Train Loss: 0.3941, Val Loss: 0.4670


At test time we generate one token at a time under the same Wait-$K$ schedule:

- Initialize $t=1$, $r(1)=\min(K, L_x)$, and $\hat y_0=\texttt{<sos>}$.  
- Repeat until $\hat y_{t-1}=\texttt{<eos>}$ or $t>\text{max\_len}$:

  1. **Encode prefix**  
     $$
       (h_{1:r(t)}, c_{1:r(t)})
       = \mathrm{Encoder}(x_{1:r(t)}).
     $$

  2. **Decode one step**  
     $$
       \tilde{h}_t
       = \mathrm{DecoderStep}\bigl(\hat y_{t-1},\,h_{r(t)},\,c_{r(t)},\,h_{1:r(t)}\bigr).
     $$

  3. **Attention & predict**  
     $$
       \alpha_j = \frac{\exp(\tilde{h}_t^\top h_j)}
                       {\sum_{k=1}^{r(t)}\exp(\tilde{h}_t^\top h_k)},\quad
       c_t = \sum_{j=1}^{r(t)}\alpha_j\,h_j,
     $$
     $$
       \hat y_t = \arg\max\mathrm{Softmax}\bigl(W[\tilde{h}_t; c_t]+b\bigr).
     $$

  4. **Advance**  
     $t \leftarrow t+1$ and  
     $r(t)=\min(K+(t-1),L_x)$.

The output sequence $\hat y_{1:T}$ is generated under the desired latency constraint.

In [10]:

embed_size = 400
hidden_size = 128
lr = 1e-3
special_tokens = {'<PAD>','<SOS>','<EOS>','<UNK>'}
# load the encoder and decoder models from the models directory
encoder = Encoder(vocab_size, embed_size, hidden_size).to(device)
decoder = Decoder(vocab_size, embed_size, hidden_size).to(device)
encoder.load_state_dict(torch.load('models/encoder_2.pth'))
decoder.load_state_dict(torch.load('models/decoder_2.pth'))

def translate_wait_k(src_sentence, K, max_len=50):
    encoder.eval(); decoder.eval()
    src_toks = tokenize(src_sentence)
    src_ids  = [ word2index.get(t, word2index['<UNK>']) for t in src_toks ]
    outputs = []
    t = 1
    while True:
        r = min(K + (t-1), len(src_ids))
        inp = torch.tensor(src_ids[:r], device=device).unsqueeze(0)
        enc_out, enc_hidden = encoder(inp)
        prev_id = word2index['<SOS>'] if t == 1 else outputs[-1]
        logits, dec_hidden, attn = decoder.forward_step(
            torch.tensor([prev_id], device=device), enc_hidden, enc_out
        )
        # mask specials
        for sp in special_tokens - {'<EOS>'}:
            logits[:, word2index[sp]] = -1e9
        next_id = logits.argmax(dim=1).item()
        if next_id == word2index['<EOS>'] or t >= max_len:
            break
        tok = index2word[next_id]
        # copy from source if UNK
        if tok == '<UNK>':
            a = attn.squeeze(0).squeeze(0)
            src_pos = a.argmax().item()
            tok = src_toks[src_pos]
            next_id = word2index.get(tok, next_id)
        if tok not in special_tokens:
            outputs.append(next_id)
        else:
            outputs.append(next_id)
        t += 1

    return " ".join(index2word[i] for i in outputs)

# Sample usage:
sentence = "Today is our dragon boat festival, in our Chinese culture, to celebrate it with all safe and great in our lives."
print(translate_wait_k(sentence, K=3, max_len=len(sentence.split())+10))


our dragon boat festival , in our chinese culture , to celebrate trapezoidal message with all safe and great in our st._john_chrysostom with all st._john_chrysostom with all st._john_chrysostom with all
