<a href="https://colab.research.google.com/github/SanjanaRitika/TextToCode_seq2seq/blob/main/improved_version.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [16]:
# Import required libraries
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import math
import time
from datasets import load_dataset
from transformers import AutoTokenizer
from sklearn.model_selection import train_test_split

# Set device
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {DEVICE}")

Using device: cuda


In [17]:
# Load dataset
ds = load_dataset("Nan-Do/code-search-net-python")
print(ds)

DatasetDict({
    train: Dataset({
        features: ['repo', 'path', 'func_name', 'original_string', 'language', 'code', 'code_tokens', 'docstring', 'docstring_tokens', 'sha', 'url', 'partition', 'summary'],
        num_rows: 455243
    })
})


In [18]:
# Extract relevant columns
full_df = ds['train'].to_pandas()
df = full_df[['code', 'code_tokens', 'docstring', 'docstring_tokens']]

# Filter by sequence length constraints and sample
length_filter = (df['docstring_tokens'].map(len) <= 50) & (df['code_tokens'].map(len) <= 80)
filtered_data = df[length_filter]
sampled_data = filtered_data.sample(n=10000, random_state=42).reset_index(drop=True)

print(f"Sampled dataset shape: {sampled_data.shape}")
print("\nFirst few examples:")
print(sampled_data.head(3))

Sampled dataset shape: (10000, 4)

First few examples:
                                                code  \
0  def _lib(self, name, only_if_have=False):\n   ...   
1  def open(self):\n        """Opens an existing ...   
2  def hmac_sha1(self, key_handle, data, flags = ...   

                                         code_tokens  \
0  [def, _lib, (, self, ,, name, ,, only_if_have,...   
1  [def, open, (, self, ), :, try, :, self, ., gr...   
2  [def, hmac_sha1, (, self, ,, key_handle, ,, da...   

                                           docstring  \
0  Specify a linker library.\n\n        Example:\...   
1                           Opens an existing cache.   
2  Have the YubiHSM generate a HMAC SHA1 of 'data...   

                                    docstring_tokens  
0                   [Specify, a, linker, library, .]  
1                    [Opens, an, existing, cache, .]  
2  [Have, the, YubiHSM, generate, a, HMAC, SHA1, ...  


In [19]:
# Analyze token length statistics
stats_df = pd.DataFrame({
    'code_length': sampled_data['code_tokens'].apply(len),
    'docstring_length': sampled_data['docstring_tokens'].apply(len)
}).describe()

print("Dataset Statistics:")
print(stats_df)

Dataset Statistics:
        code_length  docstring_length
count  10000.000000        10000.0000
mean      48.047400           12.5361
std       15.613723            8.8588
min       20.000000            1.0000
25%       35.000000            7.0000
50%       46.000000           10.0000
75%       60.000000           15.0000
max       80.000000           50.0000


In [20]:
# Initialize CodeBERT tokenizer
tokenizer = AutoTokenizer.from_pretrained("microsoft/codebert-base", add_prefix_space=True)
VOCAB_SIZE = tokenizer.vocab_size
PAD_IDX = tokenizer.pad_token_id

print(f"Vocabulary size: {VOCAB_SIZE}")
print(f"Padding index: {PAD_IDX}")

Vocabulary size: 50265
Padding index: 1


In [21]:
# Convert tokens to IDs
def tokenize_row(row):
    code_toks = [str(t) for t in row['code_tokens']]
    doc_toks = [str(t) for t in row['docstring_tokens']]

    code_ids = tokenizer(code_toks, is_split_into_words=True,
                         truncation=True, padding='max_length', max_length=128)['input_ids']
    doc_ids = tokenizer(doc_toks, is_split_into_words=True,
                        truncation=True, padding='max_length', max_length=64)['input_ids']

    return pd.Series({'code_ids': code_ids, 'docstring_ids': doc_ids})

tokenized_data = sampled_data.apply(tokenize_row, axis=1)
dataset = pd.concat([sampled_data, tokenized_data], axis=1)
print("\nTokenization complete!")


Tokenization complete!


In [22]:
# Split into train/validation/test sets
train_data, temp_data = train_test_split(dataset, test_size=0.40, random_state=42)
val_data, test_data = train_test_split(temp_data, test_size=0.50, random_state=42)

print(f"Train: {len(train_data)} | Validation: {len(val_data)} | Test: {len(test_data)}")

Train: 6000 | Validation: 2000 | Test: 2000


In [23]:
# Create DataLoaders
def create_dataloader(df, batch_size=64, shuffle=True):
    src_tensor = torch.tensor(df['code_ids'].tolist())
    trg_tensor = torch.tensor(df['docstring_ids'].tolist())
    dataset = TensorDataset(src_tensor, trg_tensor)
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)

train_loader = create_dataloader(train_data)
val_loader = create_dataloader(val_data)
test_loader = create_dataloader(test_data, shuffle=False)

print("DataLoaders ready!")

DataLoaders ready!


In [24]:
class RNN_Encoder(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_dim, pad_idx, n_layers=2, dropout=0.3):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim, padding_idx=pad_idx)
        self.rnn = nn.RNN(
            embed_dim,
            hidden_dim,
            num_layers=n_layers,
            nonlinearity="tanh",
            batch_first=True,
            dropout=dropout if n_layers > 1 else 0.0
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, src):
        # src: [B, S]
        embedded = self.dropout(self.embed(src))  # [B,S,E]
        _, hidden = self.rnn(embedded)            # hidden: [L,B,H]
        return hidden


class RNN_Decoder(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_dim, pad_idx, n_layers=2, dropout=0.3, tie_weights=True):
        super().__init__()
        self.vocab_size = vocab_size
        self.embed = nn.Embedding(vocab_size, embed_dim, padding_idx=pad_idx)
        self.rnn = nn.RNN(
            embed_dim,
            hidden_dim,
            num_layers=n_layers,
            nonlinearity="tanh",
            batch_first=True,
            dropout=dropout if n_layers > 1 else 0.0
        )
        self.dropout = nn.Dropout(dropout)

        # output projection
        self.out = nn.Linear(hidden_dim, vocab_size)

        # ✅ Weight tying (only if embed_dim == hidden_dim)
        if tie_weights:
            if embed_dim != hidden_dim:
                print("⚠️ tie_weights skipped because embed_dim != hidden_dim")
            else:
                self.out.weight = self.embed.weight  # tie

    def forward(self, token, hidden):
        # token: [B] (one token id per batch)  <-- simpler & standard
        token = token.unsqueeze(1)  # [B,1]
        embedded = self.dropout(self.embed(token))  # [B,1,E]
        output, hidden = self.rnn(embedded, hidden) # output: [B,1,H]
        prediction = self.out(output.squeeze(1))    # [B,V]
        return prediction, hidden


class VanillaSeq2Seq(nn.Module):
    def __init__(self, encoder, decoder, device):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device = device

    def forward(self, src, trg, teacher_ratio=0.5):
        # trg: [B,T]
        batch_size, trg_len = trg.shape
        vocab_size = self.decoder.vocab_size

        outputs = torch.zeros(batch_size, trg_len, vocab_size, device=self.device)

        hidden = self.encoder(src)   # [L,B,H]

        token = trg[:, 0]  # [B] (BOS)

        for t in range(1, trg_len):
            pred, hidden = self.decoder(token, hidden)
            outputs[:, t] = pred

            use_teacher = torch.rand(1).item() < teacher_ratio
            token = trg[:, t] if use_teacher else pred.argmax(1)

        return outputs

print("✅ Improved Vanilla RNN model defined")


✅ Improved Vanilla RNN model defined


In [25]:
# ----------------------------
# Better hyperparameters (Model 1: Vanilla RNN)
# ----------------------------
EMBED_DIM  = 256     # make equal to HIDDEN_DIM for weight tying
HIDDEN_DIM = 256
N_LAYERS   = 2
DROPOUT    = 0.3

N_EPOCHS = 10
CLIP = 1.0

# ----------------------------
# Initialize Model 1 (improved encoder/decoder)
# NOTE: this assumes you updated RNN_Encoder/RNN_Decoder to accept:
# (vocab_size, embed_dim, hidden_dim, pad_idx, n_layers, dropout, tie_weights)
# ----------------------------
encoder1 = RNN_Encoder(
    vocab_size=VOCAB_SIZE,
    embed_dim=EMBED_DIM,
    hidden_dim=HIDDEN_DIM,
    pad_idx=PAD_IDX,
    n_layers=N_LAYERS,
    dropout=DROPOUT
)

decoder1 = RNN_Decoder(
    vocab_size=VOCAB_SIZE,
    embed_dim=EMBED_DIM,
    hidden_dim=HIDDEN_DIM,
    pad_idx=PAD_IDX,
    n_layers=N_LAYERS,
    dropout=DROPOUT,
    tie_weights=True   # works because EMBED_DIM == HIDDEN_DIM
)

model1 = VanillaSeq2Seq(encoder1, decoder1, DEVICE).to(DEVICE)

# Multi-GPU (only if available)
if torch.cuda.device_count() > 1:
    print(f"Using {torch.cuda.device_count()} GPUs")
    model1 = nn.DataParallel(model1)

# ----------------------------
# Optimizer + Loss (better defaults)
# ----------------------------
optimizer1 = optim.AdamW(
    model1.parameters(),
    lr=3e-4,                 # smaller LR = more stable for seq2seq
    betas=(0.9, 0.98),
    weight_decay=1e-2
)

criterion = nn.CrossEntropyLoss(ignore_index=PAD_IDX)


In [26]:
from torch.cuda.amp import autocast, GradScaler

scaler = GradScaler(enabled=(DEVICE.type == "cuda"))

# Training function (improved)
def train_model(model, loader, optimizer, criterion, clip, teacher_ratio=0.5):
    model.train()
    total_loss = 0
    total_acc = 0

    for src, trg in loader:
        src, trg = src.to(DEVICE), trg.to(DEVICE)
        optimizer.zero_grad(set_to_none=True)

        with autocast(enabled=(DEVICE.type == "cuda")):
            output = model(src, trg, teacher_ratio)

            # output: [B, T, V]
            # ignore BOS position
            output_flat = output[:, 1:].reshape(-1, output.shape[-1])
            trg_flat = trg[:, 1:].reshape(-1)

            loss = criterion(output_flat, trg_flat)

        # backward (AMP-safe)
        scaler.scale(loss).backward()

        # clip gradients
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)

        scaler.step(optimizer)
        scaler.update()

        total_loss += loss.item()

        # token accuracy (ignore PAD)
        preds = output[:, 1:].argmax(-1)              # [B, T-1]
        gold  = trg[:, 1:]                           # [B, T-1]
        mask  = gold != PAD_IDX
        correct = ((preds == gold) & mask).sum().item()
        total = mask.sum().item()
        total_acc += correct / (total + 1e-9)

    return total_loss / len(loader), total_acc / len(loader)


# Evaluation function (improved)
def eval_model(model, loader, criterion):
    model.eval()
    total_loss = 0
    total_acc = 0

    with torch.no_grad():
        for src, trg in loader:
            src, trg = src.to(DEVICE), trg.to(DEVICE)

            with autocast(enabled=(DEVICE.type == "cuda")):
                output = model(src, trg, 0)  # No teacher forcing

                output_flat = output[:, 1:].reshape(-1, output.shape[-1])
                trg_flat = trg[:, 1:].reshape(-1)

                loss = criterion(output_flat, trg_flat)

            total_loss += loss.item()

            preds = output[:, 1:].argmax(-1)
            gold  = trg[:, 1:]
            mask  = gold != PAD_IDX
            correct = ((preds == gold) & mask).sum().item()
            total = mask.sum().item()
            total_acc += correct / (total + 1e-9)

    return total_loss / len(loader), total_acc / len(loader)

print("✅ Improved training functions ready")


✅ Improved training functions ready


  scaler = GradScaler(enabled=(DEVICE.type == "cuda"))


In [27]:
best_val_loss_m1 = float("inf")
print("Training Vanilla RNN Seq2Seq...")

train_losses_m1 = []
val_losses_m1 = []
train_accs_m1 = []
val_accs_m1 = []

for epoch in range(N_EPOCHS):

    start_time = time.time()

    # teacher forcing schedule
    teacher_ratio = max(0.3, 0.7 - epoch * 0.05)

    train_loss, train_acc = train_model(
        model1,
        train_loader,
        optimizer1,
        criterion,
        CLIP,
        teacher_ratio=teacher_ratio
    )

    val_loss, val_acc = eval_model(
        model1,
        val_loader,
        criterion
    )

    train_losses_m1.append(train_loss)
    val_losses_m1.append(val_loss)
    train_accs_m1.append(train_acc)
    val_accs_m1.append(val_acc)

    epoch_time = time.time() - start_time

    if val_loss < best_val_loss_m1:
        best_val_loss_m1 = val_loss
        torch.save(model1.state_dict(), "vanilla_rnn.pt")

    print(
        "Epoch {:02d} | Time {:.0f}s | Train Loss {:.3f} | Val Loss {:.3f}".format(
            epoch + 1,
            epoch_time,
            train_loss,
            val_loss
        )
    )

print("Vanilla RNN training finished.")


Training Vanilla RNN Seq2Seq...


  with autocast(enabled=(DEVICE.type == "cuda")):
  with autocast(enabled=(DEVICE.type == "cuda")):


Epoch 01 | Time 72s | Train Loss 19.403 | Val Loss 12.681
Epoch 02 | Time 70s | Train Loss 13.138 | Val Loss 10.003
Epoch 03 | Time 68s | Train Loss 10.413 | Val Loss 8.846
Epoch 04 | Time 68s | Train Loss 9.134 | Val Loss 8.468
Epoch 05 | Time 68s | Train Loss 8.398 | Val Loss 8.029
Epoch 06 | Time 68s | Train Loss 7.852 | Val Loss 7.693
Epoch 07 | Time 68s | Train Loss 7.457 | Val Loss 7.473
Epoch 08 | Time 68s | Train Loss 7.194 | Val Loss 7.235
Epoch 09 | Time 68s | Train Loss 7.001 | Val Loss 7.135
Epoch 10 | Time 68s | Train Loss 6.847 | Val Loss 7.041
Vanilla RNN training finished.


In [30]:
# Load best saved model
model1.load_state_dict(torch.load("vanilla_rnn.pt", map_location=DEVICE))
model1.eval()

# Evaluate on test set
test_loss, test_acc = eval_model(model1, test_loader, criterion)

test_ppl = math.exp(test_loss)

print("Vanilla RNN Test Results")
print("-------------------------")
print("Test Loss :", round(test_loss, 3))
print("Test PPL  :", round(test_ppl, 3))
print("Test Acc  :", round(test_acc, 3))


  with autocast(enabled=(DEVICE.type == "cuda")):


Vanilla RNN Test Results
-------------------------
Test Loss : 7.052
Test PPL  : 1154.857
Test Acc  : 0.083


In [32]:
class LSTM_Encoder(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_dim, pad_idx, n_layers=2, dropout=0.3):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim, padding_idx=pad_idx)
        self.dropout = nn.Dropout(dropout)

        self.lstm = nn.LSTM(
            input_size=embed_dim,
            hidden_size=hidden_dim,
            num_layers=n_layers,
            batch_first=True,
            dropout=dropout if n_layers > 1 else 0.0
        )

    def forward(self, src):
        # src: [B, S]
        embedded = self.dropout(self.embed(src))     # [B,S,E]
        _, (hidden, cell) = self.lstm(embedded)      # hidden/cell: [L,B,H]
        return hidden, cell


class LSTM_Decoder(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_dim, pad_idx, n_layers=2, dropout=0.3, tie_weights=True):
        super().__init__()
        self.vocab_size = vocab_size
        self.embed = nn.Embedding(vocab_size, embed_dim, padding_idx=pad_idx)
        self.dropout = nn.Dropout(dropout)

        self.lstm = nn.LSTM(
            input_size=embed_dim,
            hidden_size=hidden_dim,
            num_layers=n_layers,
            batch_first=True,
            dropout=dropout if n_layers > 1 else 0.0
        )

        self.out = nn.Linear(hidden_dim, vocab_size)

        # optional weight tying (only if embed_dim == hidden_dim)
        if tie_weights:
            if embed_dim == hidden_dim:
                self.out.weight = self.embed.weight
            else:
                # keep silent or print a short warning
                pass

    def forward(self, token, hidden, cell):
        # token: [B]  (one token per batch)
        token = token.unsqueeze(1)                   # [B,1]
        embedded = self.dropout(self.embed(token))   # [B,1,E]

        output, (hidden, cell) = self.lstm(embedded, (hidden, cell))  # output: [B,1,H]
        prediction = self.out(output.squeeze(1))     # [B,V]
        return prediction, hidden, cell


class LSTMSeq2Seq(nn.Module):
    def __init__(self, encoder, decoder, device):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device = device

    def forward(self, src, trg, teacher_ratio=0.5):
        # src: [B,S], trg: [B,T]
        batch_size, trg_len = trg.shape
        vocab_size = self.decoder.vocab_size

        outputs = torch.zeros(batch_size, trg_len, vocab_size, device=self.device)

        hidden, cell = self.encoder(src)

        token = trg[:, 0]  # BOS

        for t in range(1, trg_len):
            pred, hidden, cell = self.decoder(token, hidden, cell)
            outputs[:, t] = pred

            use_teacher = random.random() < teacher_ratio
            token = trg[:, t] if use_teacher else pred.argmax(1)

        return outputs

print("Improved LSTM model defined")


Improved LSTM model defined


In [51]:
# ----------------------------
# Model 2 hyperparams (recommended)
# ----------------------------
EMBED_DIM  = 256      # set equal to HIDDEN_DIM for weight tying
HIDDEN_DIM = 256
N_LAYERS   = 2
DROPOUT    = 0.3

# ----------------------------
# Initialize Model 2 (LSTM Seq2Seq)
# ----------------------------
encoder2 = LSTM_Encoder(
    vocab_size=VOCAB_SIZE,
    embed_dim=EMBED_DIM,
    hidden_dim=HIDDEN_DIM,
    pad_idx=PAD_IDX,
    n_layers=N_LAYERS,
    dropout=DROPOUT
)

decoder2 = LSTM_Decoder(
    vocab_size=VOCAB_SIZE,
    embed_dim=EMBED_DIM,
    hidden_dim=HIDDEN_DIM,
    pad_idx=PAD_IDX,
    n_layers=N_LAYERS,
    dropout=DROPOUT,
    tie_weights=True
)

model2 = LSTMSeq2Seq(encoder2, decoder2, DEVICE).to(DEVICE)

# Multi-GPU (optional)
if torch.cuda.device_count() > 1:
    model2 = nn.DataParallel(model2)

# ----------------------------
# Optimizer (more stable than default Adam)
# ----------------------------
optimizer2 = optim.AdamW(
    model2.parameters(),
    lr=3e-4,
    betas=(0.9, 0.98),
    weight_decay=1e-2
)


In [54]:
import random
best_val_loss_m2 = float("inf")
print("Training LSTM Seq2Seq...")

train_losses_m2, val_losses_m2 = [], []
train_accs_m2,  val_accs_m2  = [], []

for epoch in range(N_EPOCHS):
    start_time = time.time()

    # Teacher forcing schedule
    teacher_ratio = max(0.3, 0.7 - epoch * 0.05)

    train_loss, train_acc = train_model(
        model2,
        train_loader,
        optimizer2,
        criterion,
        CLIP,
        teacher_ratio=teacher_ratio
    )

    val_loss, val_acc = eval_model(
        model2,
        val_loader,
        criterion
    )

    train_losses_m2.append(train_loss)
    val_losses_m2.append(val_loss)
    train_accs_m2.append(train_acc)
    val_accs_m2.append(val_acc)

    elapsed = time.time() - start_time

    if val_loss < best_val_loss_m2:
        best_val_loss_m2 = val_loss
        torch.save(model2.state_dict(), "lstm_seq2seq.pt")

    print(
        "Epoch {:02d} | Time {:.0f}s | TF {:.2f} | Train Loss {:.3f} | Val Loss {:.3f} | Train Acc {:.3f} | Val Acc {:.3f}".format(
            epoch + 1,
            elapsed,
            teacher_ratio,
            train_loss,
            val_loss,
            train_acc,
            val_acc
        )
    )

print("LSTM training finished.")



Training LSTM Seq2Seq...


  with autocast(enabled=(DEVICE.type == "cuda")):
  with autocast(enabled=(DEVICE.type == "cuda")):


Epoch 01 | Time 73s | TF 0.70 | Train Loss 9.496 | Val Loss 8.956 | Train Acc 0.090 | Val Acc 0.081
Epoch 02 | Time 72s | TF 0.65 | Train Loss 8.442 | Val Loss 8.328 | Train Acc 0.117 | Val Acc 0.085
Epoch 03 | Time 72s | TF 0.60 | Train Loss 7.811 | Val Loss 7.826 | Train Acc 0.121 | Val Acc 0.087
Epoch 04 | Time 72s | TF 0.55 | Train Loss 7.346 | Val Loss 7.487 | Train Acc 0.121 | Val Acc 0.085
Epoch 05 | Time 72s | TF 0.50 | Train Loss 7.033 | Val Loss 7.242 | Train Acc 0.121 | Val Acc 0.084
Epoch 06 | Time 72s | TF 0.45 | Train Loss 6.830 | Val Loss 7.101 | Train Acc 0.118 | Val Acc 0.084
Epoch 07 | Time 73s | TF 0.40 | Train Loss 6.681 | Val Loss 7.005 | Train Acc 0.115 | Val Acc 0.084
Epoch 08 | Time 72s | TF 0.35 | Train Loss 6.590 | Val Loss 6.941 | Train Acc 0.113 | Val Acc 0.087
Epoch 09 | Time 73s | TF 0.30 | Train Loss 6.526 | Val Loss 6.901 | Train Acc 0.111 | Val Acc 0.085
Epoch 10 | Time 72s | TF 0.30 | Train Loss 6.465 | Val Loss 6.898 | Train Acc 0.112 | Val Acc 0.082


In [55]:
# Load best saved model
model2.load_state_dict(torch.load("lstm_seq2seq.pt", map_location=DEVICE))
model2.eval()

# Evaluate on test set
test_loss, test_acc = eval_model(model2, test_loader, criterion)

test_ppl = math.exp(test_loss)

print("LSTM Test Results")
print("------------------")
print("Test Loss :", round(test_loss, 3))
print("Test PPL  :", round(test_ppl, 3))
print("Test Acc  :", round(test_acc, 3))


  with autocast(enabled=(DEVICE.type == "cuda")):


LSTM Test Results
------------------
Test Loss : 6.901
Test PPL  : 993.193
Test Acc  : 0.082


In [56]:
class BahdanauAttention(nn.Module):
    def __init__(self, hidden_dim):
        super().__init__()
        # dec_hidden (H) + enc_out (2H) -> 3H
        self.W = nn.Linear(hidden_dim * 3, hidden_dim)
        self.v = nn.Linear(hidden_dim, 1, bias=False)

    def forward(self, dec_hidden, encoder_outputs, src_mask):
        """
        dec_hidden:      [B, H]
        encoder_outputs: [B, S, 2H]
        src_mask:        [B, S]   (1 for real tokens, 0 for PAD)
        """
        B, S, _ = encoder_outputs.shape

        dec_exp = dec_hidden.unsqueeze(1).expand(B, S, dec_hidden.size(-1))  # [B,S,H]
        energy = torch.tanh(self.W(torch.cat((dec_exp, encoder_outputs), dim=2)))  # [B,S,H]
        scores = self.v(energy).squeeze(2)  # [B,S]

        # ✅ mask PAD positions
        scores = scores.masked_fill(src_mask == 0, -1e9)

        attn = torch.softmax(scores, dim=1)  # [B,S]
        return attn


class AttentionEncoder(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_dim, pad_idx, n_layers=1, dropout=0.3):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim, padding_idx=pad_idx)
        self.dropout = nn.Dropout(dropout)

        self.lstm = nn.LSTM(
            embed_dim,
            hidden_dim,
            num_layers=n_layers,
            bidirectional=True,
            batch_first=True,
            dropout=dropout if n_layers > 1 else 0.0
        )

        # bridge bi states (2H) -> H
        self.fc_hidden = nn.Linear(hidden_dim * 2, hidden_dim)
        self.fc_cell = nn.Linear(hidden_dim * 2, hidden_dim)

    def forward(self, src):
        embedded = self.dropout(self.embed(src))               # [B,S,E]
        outputs, (hidden, cell) = self.lstm(embedded)          # outputs: [B,S,2H]

        # take last layer forward/backward
        # hidden: [2*n_layers, B, H] -> last forward: [-2], last backward: [-1]
        h_fwd = hidden[-2, :, :]   # [B,H]
        h_bwd = hidden[-1, :, :]   # [B,H]
        c_fwd = cell[-2, :, :]
        c_bwd = cell[-1, :, :]

        hidden = torch.tanh(self.fc_hidden(torch.cat((h_fwd, h_bwd), dim=1)))  # [B,H]
        cell   = torch.tanh(self.fc_cell(torch.cat((c_fwd, c_bwd), dim=1)))    # [B,H]

        return outputs, hidden, cell


class AttentionDecoder(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_dim, pad_idx, dropout, attention):
        super().__init__()
        self.vocab_size = vocab_size
        self.attention = attention

        self.embed = nn.Embedding(vocab_size, embed_dim, padding_idx=pad_idx)
        self.dropout = nn.Dropout(dropout)

        # input = [embedded(E) + context(2H)]
        self.lstm = nn.LSTM(embed_dim + (hidden_dim * 2), hidden_dim, batch_first=True)

        # output = [dec_out(H) + context(2H) + embedded(E)] -> vocab
        self.fc = nn.Linear(hidden_dim + (hidden_dim * 2) + embed_dim, vocab_size)

    def forward(self, token, hidden, cell, encoder_outputs, src_mask):
        """
        token:           [B]
        hidden/cell:     [B,H]
        encoder_outputs: [B,S,2H]
        src_mask:        [B,S]
        """
        token = token.unsqueeze(1)                               # [B,1]
        embedded = self.dropout(self.embed(token))               # [B,1,E]

        attn = self.attention(hidden, encoder_outputs, src_mask) # [B,S]
        attn = attn.unsqueeze(1)                                 # [B,1,S]

        context = torch.bmm(attn, encoder_outputs)               # [B,1,2H]

        lstm_input = torch.cat((embedded, context), dim=2)       # [B,1,E+2H]
        output, (hidden1, cell1) = self.lstm(
            lstm_input, (hidden.unsqueeze(0), cell.unsqueeze(0))
        )

        output = output.squeeze(1)       # [B,H]
        context = context.squeeze(1)     # [B,2H]
        embedded = embedded.squeeze(1)   # [B,E]

        pred = self.fc(torch.cat((output, context, embedded), dim=1))  # [B,V]
        return pred, hidden1.squeeze(0), cell1.squeeze(0), attn.squeeze(1)


class AttentionSeq2Seq(nn.Module):
    def __init__(self, encoder, decoder, device, pad_idx):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device = device
        self.pad_idx = pad_idx

    def forward(self, src, trg, teacher_ratio=0.5, return_attn=False):
        """
        return_attn=True will return (outputs, attn_matrix)
        attn_matrix: [B, T-1, S]
        """
        batch_size, trg_len = trg.shape
        vocab_size = self.decoder.vocab_size

        outputs = torch.zeros(batch_size, trg_len, vocab_size, device=self.device)

        encoder_outputs, hidden, cell = self.encoder(src)
        src_mask = (src != self.pad_idx).long()  # [B,S]

        token = trg[:, 0]

        attn_all = []  # store attention weights

        for t in range(1, trg_len):
            pred, hidden, cell, attn = self.decoder(token, hidden, cell, encoder_outputs, src_mask)
            outputs[:, t] = pred
            attn_all.append(attn.unsqueeze(1))  # [B,1,S]

            use_teacher = random.random() < teacher_ratio
            token = trg[:, t] if use_teacher else pred.argmax(1)

        if return_attn:
            attn_matrix = torch.cat(attn_all, dim=1)  # [B,T-1,S]
            return outputs, attn_matrix

        return outputs

print("Improved Attention model defined")


Improved Attention model defined
