<a href="https://colab.research.google.com/github/Priyanshu-Naik/Gen_AI/blob/main/seq2seq.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Step-by-Step Seq2Seq Implementation


Step 1: Imports

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

Step 2: Encoder

We will define the encoder:

Each input token is converted to a dense vector (embedding).

The GRU processes the sequence one token at a time, updating its hidden state.

The final hidden state is returned as the context vector, summarizing the input sequence.

In [2]:
class Encoder(nn.Module):
  def __init__(self, input_dim, emb_dim, hidden_dim):
    super().__init__()
    self.embedding = nn.Embedding(input_dim, emb_dim)
    self.rnn = nn.GRU(emb_dim, hidden_dim)

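  def forward(self, x):
    embedded = self.embedding(x)         # [seq_len, batch] -> [seq_len, batch, emb_dim]
    output, hidden = self.rnn(embedded)  # hidden: [1, batch, hidden_dim]
    return hidden                        # final hidden state, used as the context vector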
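
To make the tensor shapes in the encoder concrete, here is a minimal sketch that runs the same two layers outside the class. The sizes are illustrative assumptions, and the check at the end relies on the GRU being single-layer and unidirectional (the `nn.GRU` defaults).

```python
import torch
import torch.nn as nn

# Illustrative sizes (assumptions chosen for this sketch)
vocab_size, emb_dim, hidden_dim = 10, 8, 16
seq_len, batch_size = 5, 2

embedding = nn.Embedding(vocab_size, emb_dim)
rnn = nn.GRU(emb_dim, hidden_dim)  # default batch_first=False: [seq_len, batch, features]

tokens = torch.randint(0, vocab_size, (seq_len, batch_size))
embedded = embedding(tokens)       # [seq_len, batch_size, emb_dim]
output, hidden = rnn(embedded)     # output: [seq_len, batch_size, hidden_dim]
                                   # hidden: [1, batch_size, hidden_dim]

print(embedded.shape)
print(output.shape)
print(hidden.shape)

# For a single-layer, unidirectional GRU, the output at the last time step
# is the final hidden state, i.e. the context vector the encoder returns.
same = torch.allclose(output[-1], hidden[0])
print(same)
```

The encoder discards `output` and keeps only `hidden`, so the whole input sequence is compressed into one vector of size `hidden_dim` per batch element.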

Step 3: Decoder

We will define the decoder:

Takes the current input token and converts it to an embedding.

GRU uses the previous hidden state (or context vector initially) to compute the new hidden state.

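The output is passed through a linear layer to get unnormalized scores (logits) over the output vocabulary; applying softmax to these scores gives the predicted token probabilities.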

In [3]:
class Decoder(nn.Module):
  def __init__(self, output_dim, emb_dim, hidden_dim):
    super().__init__()
    self.embedding = nn.Embedding(output_dim, emb_dim)
    # forward() feeds only the embedding to the GRU, so its input size must be emb_dim
    self.rnn = nn.GRU(emb_dim, hidden_dim)
    self.fc = nn.Linear(hidden_dim, output_dim)

  def forward(self, x, hidden):
    x = x.unsqueeze(0)
    embedded = self.embedding(x)
    output, hidden = self.rnn(embedded, hidden)
    prediction = self.fc(output.squeeze(0))
    return prediction, hidden
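
A single decoder step can be traced shape by shape. The sketch below is a stand-alone illustration, not the notebook's class: it assumes the GRU input size equals `emb_dim` (only the embedding is fed in), uses a zero tensor as a stand-in for the encoder's context vector, and picks arbitrary sizes.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Illustrative sizes (assumptions chosen for this sketch)
vocab_size, emb_dim, hidden_dim, batch_size = 10, 8, 16, 2

embedding = nn.Embedding(vocab_size, emb_dim)
rnn = nn.GRU(emb_dim, hidden_dim)
fc = nn.Linear(hidden_dim, vocab_size)

prev_token = torch.zeros(batch_size, dtype=torch.long)  # [batch_size], start token 0
hidden = torch.zeros(1, batch_size, hidden_dim)         # stand-in for the context vector

step_in = embedding(prev_token.unsqueeze(0))  # [1, batch_size, emb_dim]
output, hidden = rnn(step_in, hidden)         # both [1, batch_size, hidden_dim]
logits = fc(output.squeeze(0))                # [batch_size, vocab_size]

probs = F.softmax(logits, dim=1)              # each row sums to 1
next_token = logits.argmax(dim=1)             # [batch_size], greedy choice

print(logits.shape)
print(probs.sum(dim=1))
print(next_token)
```

Because softmax preserves ordering, taking `argmax` of the logits selects the same token as taking `argmax` of the probabilities, so greedy decoding needs no explicit softmax.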

Step 4: Seq2Seq Model with Teacher Forcing

Batch size: taken from the second dimension of the input, which is shaped [seq_len, batch_size].

Encoding: input sequence → encoder → context vector (hidden).

Start token: initialize decoder with token 0.

Loop over max_len:

Decoder predicts next token.

top1 → token with max probability.

Append top1 to outputs.

Teacher forcing: with probability teacher_forcing_ratio, feed the true target token as the next input instead of the model's own prediction.

Return predictions: concatenated sequence of token IDs.
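
The teacher forcing decision in the loop above can be sketched on its own. This is a small illustration under assumed values: the token IDs and the ratio of 0.7 are made up, and one random draw decides for the whole batch at each step.

```python
import torch

torch.manual_seed(0)
teacher_forcing_ratio = 0.7

# One decision for a single decoding step
predicted = torch.tensor([3, 7])  # hypothetical model predictions for a batch of 2
target = torch.tensor([4, 1])     # hypothetical ground-truth tokens at the same step
use_truth = torch.rand(1).item() < teacher_forcing_ratio
next_input = target if use_truth else predicted
print(next_input)

# Over many steps, the ground truth is fed in roughly teacher_forcing_ratio of the time
draws = torch.rand(10_000)
fraction = (draws < teacher_forcing_ratio).float().mean().item()
print(f"ground truth used in {fraction:.1%} of steps")
```

A ratio of 1.0 always feeds the true token, and 0.0 always feeds the model's own prediction, which is what happens at inference time when no target is available.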

In [4]:
class seq2seq(nn.Module):
  def __init__(self, encoder, decoder, device):
    super().__init__()
    self.encoder = encoder
    self.decoder = decoder
    self.device = device

  def forward(self, src, trg = None, max_len = 10, teacher_forcing_ratio = 0.5):
    batch_size = src.shape[1]
    # Number of target time steps (fc.out_features is the vocabulary size, not a length)
    trg_len = trg.shape[0] if trg is not None else 0
    outputs = []

    hidden = self.encoder(src)

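    # Start token (ID 0) for every sequence in the batch
    dec_input = torch.zeros(batch_size, dtype = torch.long, device = self.device)

    for t in range(max_len):
      output, hidden = self.decoder(dec_input, hidden)
      top1 = output.argmax(1)            # greedy choice: highest-scoring token ID
      outputs.append(top1.unsqueeze(0))  # [1, batch_size]

      # Teacher forcing: sometimes feed the true token instead of the prediction
      if trg is not None and t < trg_len and torch.rand(1).item() < teacher_forcing_ratio:
        dec_input = trg[t]
      else:
        dec_input = top1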

    outputs = torch.cat(outputs, dim = 0)
    return outputs
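
The way per-step predictions are assembled into one tensor can be shown in isolation. In this sketch the "predictions" are stand-in values (the step index repeated across the batch), chosen only so the resulting layout is easy to read.

```python
import torch

max_len, batch_size = 5, 2
outputs = []

for t in range(max_len):
    top1 = torch.full((batch_size,), t)  # stand-in for the argmax token IDs at step t
    outputs.append(top1.unsqueeze(0))    # [batch_size] -> [1, batch_size]

stacked = torch.cat(outputs, dim=0)      # [max_len, batch_size]
print(stacked.shape)
print(stacked[:, 0])                     # tokens of the first sequence, one per step
```

Each row of the result holds one decoding step and each column holds one sequence, matching the [seq_len, batch_size] layout used for `src` and `trg`.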

Step 5: Usage Example with Outputs

Test the model with a small example:

src: random input token IDs.

trg: random target token IDs (used for teacher forcing).

outputs: predicted token IDs for each sequence.

.T: transpose to show batch sequences as rows.

In [9]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

VOCAB_SIZE = 10
EMB_DIM = 8
HID_DIM = 16
SEQ_LEN = 5
BATCH_SIZE = 2

enc = Encoder(VOCAB_SIZE, EMB_DIM, HID_DIM)
dec = Decoder(VOCAB_SIZE, EMB_DIM, HID_DIM)
model = seq2seq(enc, dec, device).to(device)

src = torch.randint(1, VOCAB_SIZE, (SEQ_LEN, BATCH_SIZE)).to(device)
trg = torch.randint(1, VOCAB_SIZE, (SEQ_LEN, BATCH_SIZE)).to(device)

outputs = model(src, trg, max_len=SEQ_LEN, teacher_forcing_ratio=0.7)

print("Source sequence (input tokens):")
print(src.T)
print("\nTarget sequence (true tokens):")
print(trg.T)
print("\nPredicted sequence (model output tokens):")
print(outputs.T)


Source sequence (input tokens):
tensor([[5, 9, 8, 7, 7],
        [9, 8, 6, 4, 8]])

Target sequence (true tokens):
tensor([[4, 1, 2, 2, 8],
        [9, 8, 1, 8, 9]])

Predicted sequence (model output tokens):
tensor([[5, 4, 7, 5, 4],
        [4, 0, 4, 4, 4]])
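
The effect of `.T` in the printed results can be seen on a small tensor. This is a minimal sketch with made-up values: the tensors in this notebook are time-major ([seq_len, batch_size]), and transposing them puts each sequence on its own row.

```python
import torch

seq_len, batch_size = 5, 2

# Time-major layout: row t holds the tokens of every sequence at step t
time_major = torch.arange(seq_len * batch_size).reshape(seq_len, batch_size)
batch_major = time_major.T  # [batch_size, seq_len]: one row per sequence

print(time_major.shape)
print(batch_major.shape)
print(batch_major)
```

This is why each printed tensor above has 2 rows of 5 tokens even though the model works with tensors of shape [5, 2].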
