## Importing Libraries and models

In [5]:
# Show the GPUs and their current memory/utilisation so a free one can be chosen for CUDA_VISIBLE_DEVICES
!nvidia-smi

Sat Oct 18 09:22:10 2025       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.54.03              Driver Version: 535.54.03    CUDA Version: 12.5     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA A100 80GB PCIe          Off | 00000000:17:00.0 Off |                    0 |
| N/A   53C    P0              78W / 300W |  40494MiB / 81920MiB |     21%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   1  NVIDIA A100 80GB PCIe          Off | 00000000:31:00.0 Off |  

In [6]:
import os
os.environ.update({"CUDA_VISIBLE_DEVICES": "2"})  # expose only physical GPU 2; must run before CUDA is initialised

In [7]:
from __future__ import unicode_literals, print_function, division
from io import open
import unicodedata
import string
import re
import wandb
import random
import pandas as pd
import torch
import time
import numpy as np
import torch.nn as nn
from torch import optim
import matplotlib.pyplot as plt
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader

In [8]:
# Use the (single visible) GPU when PyTorch can see one, otherwise fall back to the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
torch.cuda.empty_cache()  # free unused memory held by PyTorch's CUDA caching allocator
print(device)

cuda


## Load Dataset

In [9]:
class Language:
    """Character-level vocabulary for one language.

    Indices 0-2 are reserved: '#' padding, '$' unknown character, '^' start of sequence.
    Every new character receives the next free index, in first-seen order.
    """

    def __init__(self, name):
        self.name = name
        reserved = ('#', '$', '^')   # '#': padding, '$': unknown char, '^': start of sequence
        self.char2index = {ch: pos for pos, ch in enumerate(reserved)}
        self.index2char = dict(enumerate(reserved))
        self.vocab_size = len(reserved)  # Count

    def addWord(self, word):
        """Register every character of `word`."""
        for ch in word:
            self.addChar(ch)

    def addChar(self, char):
        """Give `char` the next free index unless it is already known."""
        if char in self.char2index:
            return
        new_index = self.vocab_size
        self.char2index[char] = new_index
        self.index2char[new_index] = char
        self.vocab_size = new_index + 1

    def encode(self, s):
        """Map each character of `s` to its index (KeyError for unseen characters)."""
        return list(map(self.char2index.__getitem__, s))

    def decode(self, l):
        """Map a sequence of indices back to a string (KeyError for unknown indices)."""
        return "".join(self.index2char[idx] for idx in l)

    def vocab(self):
        """Live view of the known characters, in index order."""
        return self.char2index.keys()


In [10]:
# returns maximum length of input and output words
def maxLength(data):
    """Return (longest word in column 0, longest word in column 1) of a DataFrame.

    Rows are read positionally, so frames with a non-default index (e.g. after filtering
    or concatenation) work as well. An empty frame yields (0, 0).
    """
    ip_mlen = max((len(word) for word in data[0]), default=0)
    op_mlen = max((len(word) for word in data[1]), default=0)
    return ip_mlen, op_mlen

In [11]:
import pandas as pd

def getMaxLengthValues(lang, base_dir="../../aks_dataset"):
    """Return (max input word length, max output word length) across all splits of `lang`.

    Reads the headerless CSVs `{base_dir}/{lang}/train.csv`, `valid.csv` and `test.csv`
    (column 0: source word, column 1: target word) and takes, per column, the maximum
    length over the three splits. Vocabulary construction is done in load_prepare_data,
    not here.

    Args:
        lang: dataset sub-directory / language code, e.g. 'hin'.
        base_dir: root directory containing one sub-directory per language.
    """
    base_path = f"{base_dir}/{lang}"
    split_lengths = [
        maxLength(pd.read_csv(f"{base_path}/{split}.csv", header=None))
        for split in ("train", "valid", "test")
    ]
    return (max(inp for inp, _ in split_lengths),
            max(out for _, out in split_lengths))

# Global max word lengths over every Hindi split; later cells size the padded tensors from these
input_max_len, output_max_len = getMaxLengthValues(lang='hin')
print(input_max_len, output_max_len)

29 26


In [12]:
input_shape = 0
def preprocess(data, input_lang, output_lang, input_max_len, output_max_len, s='', target_device=None):
    """Encode a (source word, target word) DataFrame into fixed-width index tensors.

    Each source word is right-padded with '#' to input_max_len + 1 characters. Each target
    word is prefixed with the '^' start symbol and right-padded with '#' to
    output_max_len + 2 characters. Characters missing from a vocabulary map to that
    vocabulary's own '$' (unknown) index.

    Args:
        data: DataFrame with source words in column 0 and target words in column 1.
            Rows are read positionally, so any index works.
        input_lang, output_lang: objects exposing a `char2index` dict (e.g. Language).
        input_max_len, output_max_len: longest source / target word lengths to pad for.
        s: label printed together with the resulting shapes.
        target_device: device for the tensors; defaults to the notebook-level `device`.

    Returns:
        TensorDataset of two int32 tensors with shapes (n, input_max_len + 1) and
        (n, output_max_len + 2).

    Raises:
        ValueError: if a word does not fit into its padded width.
    """
    out_device = device if target_device is None else target_device

    def _encode(text, width, lang, unknown):
        # one row of indices for `text`, right-padded with '#' to `width`
        if len(text) > width:
            raise ValueError(f"{s} dataset: {text!r} is longer than the padded width {width}")
        return [lang.char2index.get(ch, unknown) for ch in text.ljust(width, '#')]

    in_unknown = input_lang.char2index['$']
    out_unknown = output_lang.char2index['$']   # target side uses its own vocabulary's '$'
    in_width = input_max_len + 1
    out_width = output_max_len + 2

    # Build plain Python lists and convert once; writing element by element into a CUDA
    # tensor issues a separate device operation per character.
    src_rows = [_encode(word, in_width, input_lang, in_unknown) for word in data[0]]
    tgt_rows = [_encode('^' + word, out_width, output_lang, out_unknown) for word in data[1]]

    # reshape keeps the 2-D shape even when `data` is empty
    src_tensor = torch.tensor(src_rows, dtype=torch.int32, device=out_device).reshape(len(src_rows), in_width)
    tgt_tensor = torch.tensor(tgt_rows, dtype=torch.int32, device=out_device).reshape(len(tgt_rows), out_width)

    print(s, ' dataset')
    print(src_tensor.shape)
    print(tgt_tensor.shape)

    return TensorDataset(src_tensor, tgt_tensor)

In [13]:
def load_prepare_data(lang, base_dir="../../aks_dataset"):
    """Build vocabularies from the training split and encode the test split.

    The vocabularies are filled in training-file order. That order fixes each character's
    index, so it has to match the order used when the model was trained.

    Args:
        lang: dataset sub-directory / language code, e.g. 'hin'.
        base_dir: root directory containing one sub-directory per language.

    Returns:
        (test_data, input_lang, output_lang): a TensorDataset of the encoded test split
        plus the source ('eng') and target vocabularies.

    Relies on the notebook-level `input_max_len` / `output_max_len` for padding widths.
    """
    train_df = pd.read_csv(f"{base_dir}/{lang}/train.csv", header = None)
    test_df = pd.read_csv(f"{base_dir}/{lang}/test.csv", header = None)

    input_lang = Language('eng')
    output_lang = Language(lang)

    # create vocabulary (positional iteration, independent of the DataFrame index)
    for src_word, tgt_word in zip(train_df[0], train_df[1]):
        input_lang.addWord(src_word)   # 'eng'
        output_lang.addWord(tgt_word)  # e.g. 'hin'

    # encode the test split
    test_data = preprocess(test_df, input_lang, output_lang, input_max_len, output_max_len, 'test')

    return test_data, input_lang, output_lang


# Build vocabularies from the Hindi training split and encode the Hindi test split
test_data, input_lang, output_lang = load_prepare_data(lang='hin')

test  dataset
torch.Size([10112, 30])
torch.Size([10112, 28])


In [14]:
# test_data[23][1]

In [15]:
# Decode one encoded (source, target) pair back to text as a sanity check
sample_src, sample_tgt = test_data[23]
print(input_lang.decode(sample_src.tolist()))
output_lang.decode(sample_tgt.tolist())

jadule########################


'^जडूले######################'

In [16]:
# Raw encoded target of sample 23: '^' start index, character indices, then '#' (index 0) padding
test_data[23][1]

tensor([ 2,  8, 32, 29, 13, 30,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0], device='cuda:0',
       dtype=torch.int32)

In [17]:
# Vocabulary sizes and fixed sequence widths; every encoded sample shares the same width,
# so the first test sample is representative.
first_src, first_tgt = test_data[0]

input_vocab_size = input_lang.vocab_size    # encoder vocabulary
encoder_block_size = len(first_src)          # encoder sequence width

output_vocab_size = output_lang.vocab_size  # decoder vocabulary
decoder_block_size = len(first_tgt)          # decoder sequence width (incl. '^')

In [18]:
print(encoder_block_size, decoder_block_size, sep='\n')

30
28


### Encoder model

In [19]:
class Head(nn.Module):
    """ one attention head: self-attention, or cross-attention when encoder_output is given """

    def __init__(self, n_embd, d_k, dropout, mask=0): # d_k: key/query/value width, normally n_embd // num_heads
        super().__init__()
        self.mask = mask
        # creation order (key, query, value) fixes parameter-init RNG order and state_dict keys
        self.key = nn.Linear(n_embd, d_k, bias=False, device=device)
        self.query = nn.Linear(n_embd, d_k, bias=False, device=device)
        self.value = nn.Linear(n_embd, d_k, bias=False, device=device)
        if mask:
            # lower-triangular 0/1 matrix, sliced to [:T, :Tk] in forward, so it must cover
            # both the query and the key lengths
            self.register_buffer('tril', torch.tril(torch.ones(encoder_block_size, encoder_block_size, device=device)))
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, encoder_output = None):
        B, T, C = x.shape
        # keys and values come from the encoder output for cross-attention, otherwise from x
        source = x if encoder_output is None else encoder_output

        k = self.key(source)   # (B, Tk, d_k)
        q = self.query(x)      # (B, T, d_k)
        # NOTE(review): scaled by C**-0.5 with C = n_embd, not d_k**-0.5 as in the usual
        # formulation; left unchanged so already-trained checkpoints (presumably trained
        # with this scaling) behave identically.
        wei = q @ k.transpose(-2, -1) * C**-0.5  # (B, T, Tk)

        if self.mask:
            # NOTE(review): with encoder_output this is also a triangular mask over encoder
            # positions (query t sees keys <= t) -- confirm this is intended
            Tk = source.shape[1]
            wei = wei.masked_fill(self.tril[:T, :Tk] == 0, float('-inf'))

        wei = self.dropout(F.softmax(wei, dim=-1))
        v = self.value(source)  # (B, Tk, d_k)
        return wei @ v          # (B, T, d_k)

class MultiHeadAttention(nn.Module):
    """ several attention heads in parallel; outputs concatenated and projected back to n_embd """

    def __init__(self, n_embd, num_head, d_k, dropout, mask=0):
        super().__init__()
        self.heads = nn.ModuleList([Head(n_embd, d_k, dropout, mask) for _ in range(num_head)])
        # The concatenated heads are num_head * d_k wide. That equals n_embd only when n_embd
        # is divisible by num_head, so size the projection from the real width (identical
        # shape, and therefore identical checkpoints, in the divisible case).
        self.proj = nn.Linear(num_head * d_k, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, encoder_output=None):
        out = torch.cat([h(x, encoder_output) for h in self.heads], dim=-1)  # (B, T, num_head * d_k)
        out = self.dropout(self.proj(out))                                   # (B, T, n_embd)
        return out

class FeedForward(nn.Module):
    """ position-wise feed-forward network: Linear -> ReLU -> Linear -> Dropout """

    def __init__(self, n_embd, dropout, hidden_mult=4):
        """
        Args:
            n_embd: model (input and output) width.
            dropout: dropout probability applied to the output.
            hidden_mult: hidden-layer width as a multiple of n_embd (default 4).
        """
        super().__init__()
        hidden = hidden_mult * n_embd
        # keep the Sequential layout (net.0 .. net.3) so existing state_dict keys still match
        self.net = nn.Sequential(
            nn.Linear(n_embd, hidden),
            nn.ReLU(),
            nn.Linear(hidden, n_embd),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)

class encoderBlock(nn.Module):
    """ Transformer encoder block (pre-LayerNorm): residual self-attention, then residual feed-forward """

    def __init__(self, n_embd, n_head, dropout):
        super().__init__()
        head_size = n_embd // n_head
        # submodules are created in the same order as before (init RNG / state_dict keys)
        self.sa = MultiHeadAttention(n_embd, n_head, head_size, dropout)
        self.ffwd = FeedForward(n_embd, dropout)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x, encoder_output=None):
        attended = self.sa(self.ln1(x), encoder_output)  # communication
        x = x + attended
        return x + self.ffwd(self.ln2(x))                # computation

class Encoder(nn.Module):
    """Transformer encoder: token + learned position embeddings, a stack of encoderBlocks and a
    final LayerNorm. Maps (B, T) token ids to (B, T, n_embd) features."""

    def __init__(self, n_embd, n_head, n_layers, dropout):
        super().__init__()

        self.token_embedding_table = nn.Embedding(input_vocab_size, n_embd) # n_embd: input embedding dimension
        self.position_embedding_table = nn.Embedding(encoder_block_size, n_embd)
        self.blocks = nn.Sequential(*[encoderBlock(n_embd, n_head, dropout) for _ in range(n_layers)])
        self.ln_f = nn.LayerNorm(n_embd) # final layer norm

    def forward(self, idx):
        B, T = idx.shape
        max_T = self.position_embedding_table.num_embeddings
        if T > max_T:
            # fail with a clear message instead of an out-of-range embedding lookup
            # (which surfaces as an opaque device-side assert on CUDA)
            raise ValueError(f"Encoder input length {T} exceeds encoder_block_size {max_T}")
        tok_emb = self.token_embedding_table(idx) # (B,T,n_embd)
        # positions live on the input's device, not on the notebook-level `device`
        pos_emb = self.position_embedding_table(torch.arange(T, device=idx.device)) # (T,n_embd)
        x = tok_emb + pos_emb # (B,T,n_embd)
        x = self.blocks(x) # (B,T,n_embd)
        x = self.ln_f(x) # (B,T,n_embd)
        return x


### Decoder model

In [20]:
class decoderBlock(nn.Module):
    """ Transformer decoder block (pre-LayerNorm): masked self-attention, attention over the
    encoder output, then feed-forward, each with a residual connection.

    Takes and returns an (x, encoder_output) tuple so blocks can be chained in nn.Sequential.
    """

    def __init__(self, n_embd, n_head, dropout):
        super().__init__()
        head_size = n_embd // n_head
        self.sa = MultiHeadAttention(n_embd, n_head, head_size, dropout, mask = 1)
        # NOTE(review): mask=1 on the cross-attention as well, so decoder position t only sees
        # encoder positions <= t; kept unchanged (presumably how saved checkpoints were trained)
        self.ca = MultiHeadAttention(n_embd, n_head, head_size, dropout, mask = 1)
        self.ffwd = FeedForward(n_embd, dropout)
        self.ln1 = nn.LayerNorm(n_embd, device=device)
        self.ln2 = nn.LayerNorm(n_embd, device=device)
        self.ln3 = nn.LayerNorm(n_embd, device=device)

    def forward(self, x_encoder_output):
        hidden, memory = x_encoder_output[0], x_encoder_output[1]
        hidden = hidden + self.sa(self.ln1(hidden))          # self communication
        hidden = hidden + self.ca(self.ln2(hidden), memory)  # cross communication
        hidden = hidden + self.ffwd(self.ln3(hidden))        # computation
        return (hidden, memory)

class Decoder(nn.Module):
    """Transformer decoder: token + learned position embeddings, a stack of decoderBlocks
    (masked self-attention, attention over the encoder output, feed-forward), a final
    LayerNorm and a linear head producing logits over the output vocabulary."""

    def __init__(self, n_embd, n_head, n_layers, dropout):
        super().__init__()

        self.token_embedding_table = nn.Embedding(output_vocab_size, n_embd) # n_embd: input embedding dimension
        self.position_embedding_table = nn.Embedding(decoder_block_size, n_embd)
        self.blocks = nn.Sequential(*[decoderBlock(n_embd, n_head=n_head, dropout=dropout) for _ in range(n_layers)])
        self.ln_f = nn.LayerNorm(n_embd) # final layer norm
        self.lm_head = nn.Linear(n_embd, output_vocab_size)

    def forward(self, idx, encoder_output, targets=None):
        """
        Args:
            idx: (B, T) decoder input token ids.
            encoder_output: output of the Encoder for the same batch.
            targets: optional (B, T) next-token ids; when given, a cross-entropy loss over all
                positions (padding included) is returned as well.

        Returns:
            (logits, loss): logits of shape (B, T, output_vocab_size); loss is None without targets.
        """
        B, T = idx.shape
        max_T = self.position_embedding_table.num_embeddings
        if T > max_T:
            # fail with a clear message instead of an out-of-range embedding lookup
            raise ValueError(f"Decoder input length {T} exceeds decoder_block_size {max_T}")

        tok_emb = self.token_embedding_table(idx) # (B,T,n_embd)
        # positions live on the input's device, not on the notebook-level `device`
        pos_emb = self.position_embedding_table(torch.arange(T, device=idx.device)) # (T,n_embd)
        x = tok_emb + pos_emb # (B,T,n_embd)

        x, _ = self.blocks((x, encoder_output))  # blocks pass (x, encoder_output) tuples along
        x = self.ln_f(x) # (B,T,n_embd)
        logits = self.lm_head(x) # (B,T,output_vocab_size)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            loss = F.cross_entropy(logits.view(B*T, C), targets.reshape(B*T).long())

        return logits, loss



## Generate output sequences (greedy decoding)

In [17]:
def generate(input, start_token=2):
    """Greedy autoregressive decoding with the notebook-level `encoder` and `decoder`.

    Args:
        input: (B, T_in) tensor of source token ids.
        start_token: id placed at position 0 of every output row (2 is '^' in Language).

    Returns:
        (B, decoder_block_size) LongTensor: the start token followed by
        decoder_block_size - 1 greedily chosen tokens.
    """
    encoder_output = encoder(input)
    batch_size = input.size(0)
    # build on the input's device so CPU and GPU inputs both work
    idx = torch.full((batch_size, 1), start_token, dtype=torch.long, device=input.device) # (B,1)

    for _ in range(decoder_block_size - 1):
        logits, _ = decoder(idx, encoder_output) # (B, T, vocab_size)
        # greedy step: highest-scoring token at the last position (no softmax or sampling needed)
        idx_next = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True) # (B, 1)
        idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
    return idx

## Check Test Accuracy

In [18]:
import torch
from torch.utils.data import DataLoader

def _greedy_decode(decoder, encoder_output, start_tokens, total_len):
    """Greedily extend `start_tokens` (B, 1) to `total_len` columns using `decoder`."""
    idx = start_tokens
    for _ in range(total_len - 1):
        logits, _ = decoder(idx, encoder_output)                              # (B, cur_len, V)
        next_token = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)     # (B, 1)
        idx = torch.cat((idx, next_token), dim=1)
    return idx


def check(encoder, decoder, test_data, device='cpu', batch_size=64, pad_idx=0,
          start_token=2, strict_word_match=False):
    """
    Evaluate an encoder/decoder pair on `test_data` with greedy decoding.

    Compute:
      - word-level accuracy: fraction of sequences whose non-pad target tokens are all predicted
        exactly. With strict_word_match=True the padding positions must match too, so extra
        characters predicted after the target word ends count as an error.
      - char-level accuracy: fraction of non-pad target tokens that are predicted correctly
      - avg validation loss: teacher-forced decoder loss, averaged over batches
    Parameters:
      encoder, decoder  : the models to evaluate (used for both decoding and loss)
      test_data         : dataset (not a dataloader) yielding (src, tgt) pairs; tgt[:, 0] is the start token
      device            : device the batches are moved to ('cpu' or 'cuda')
      batch_size        : dataloader batch size
      pad_idx           : integer index used for padding tokens; set to None to not mask pads
      start_token       : id used to seed greedy decoding (2 is '^' in Language)
      strict_word_match : also require padding positions to match for word-level accuracy
    Returns:
      dict with avg_loss, char_acc and word_acc (percentages) plus the raw counts.
    """
    encoder.eval()
    decoder.eval()

    test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False)

    total_sequences = 0           # sequences counted for word-level accuracy
    total_nonpad_tokens = 0       # non-pad target tokens counted for char-level accuracy
    char_correct = 0              # correctly predicted tokens (char-level)
    word_exact_correct = 0        # sequences counted as an exact match
    running_loss_val = 0.0
    n_batches = 0

    with torch.no_grad():
        for val_x, val_y in test_loader:
            val_x = val_x.to(device)
            val_y = val_y.to(device)

            # Encode once and reuse it for both greedy decoding and the teacher-forced loss,
            # always with the encoder/decoder passed in rather than notebook globals.
            encoder_output = encoder(val_x)
            start = torch.full((val_x.size(0), 1), start_token, dtype=torch.long, device=val_x.device)
            output = _greedy_decode(decoder, encoder_output, start, val_y.size(1))

            _, loss = decoder(val_y[:, :-1], encoder_output, val_y[:, 1:])
            running_loss_val += loss.item()
            n_batches += 1

            # drop the start-token column; pred and target have the same width by construction
            pred = output[:, 1:]
            target = val_y[:, 1:]
            matches = pred == target

            if pad_idx is not None:
                mask = target != pad_idx          # True where token is not padding
                char_correct += (matches & mask).sum().item()
                total_nonpad_tokens += mask.sum().item()

                if strict_word_match:
                    seq_equal = matches.all(dim=1)
                else:
                    # lenient: positions where the target is padding are ignored
                    seq_equal = (matches | ~mask).all(dim=1)
                # sequences with zero non-pad tokens are excluded from word-level stats
                valid_seq = mask.any(dim=1)
                word_exact_correct += seq_equal[valid_seq].sum().item()
                total_sequences += valid_seq.sum().item()
            else:
                # no padding: count all positions
                total_nonpad_tokens += target.numel()
                char_correct += matches.sum().item()
                word_exact_correct += matches.all(dim=1).sum().item()
                total_sequences += pred.size(0)

    # final metrics
    avg_loss = running_loss_val / n_batches if n_batches > 0 else float('nan')
    char_acc = (char_correct / total_nonpad_tokens * 100.0) if total_nonpad_tokens > 0 else float('nan')
    word_acc = (word_exact_correct / total_sequences * 100.0) if total_sequences > 0 else float('nan')

    print(f"Validation avg loss (per-batch): {avg_loss:.6f}")
    print(f"Char-level accuracy (non-pad tokens): {char_acc:.4f}%  ({char_correct}/{total_nonpad_tokens})")
    print(f"Word-level (sequence-exact) accuracy: {word_acc:.4f}%  ({word_exact_correct}/{total_sequences})")

    return {
        "avg_loss": avg_loss,
        "char_acc": char_acc,
        "word_acc": word_acc,
        "char_correct": char_correct,
        "total_nonpad_tokens": total_nonpad_tokens,
        "word_exact_correct": word_exact_correct,
        "total_sequences": total_sequences
    }


In [21]:
import torch

# Make sure Encoder and Decoder class definitions are already imported or defined above this line!
# The .pth files hold whole pickled nn.Module objects, which are rebuilt from the classes in scope.

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

ENCODER_PATH = 'models/transformer-encoder.pth'
DECODER_PATH = 'models/transformer-decoder.pth'

# Full-module checkpoints cannot be read with weights_only=True, which became the torch.load
# default in PyTorch 2.6, so opt out explicitly.
# SECURITY: weights_only=False unpickles arbitrary objects -- only load checkpoints you trust.
encoder = torch.load(ENCODER_PATH, map_location=device, weights_only=False)
decoder = torch.load(DECODER_PATH, map_location=device, weights_only=False)

# Put them in eval mode before testing (disables dropout)
encoder.eval()
decoder.eval()

print("Models loaded successfully on", device)

Models loaded successfully on cuda


In [22]:
# Greedy-decode the test split and report loss, char-level and word-level accuracy
check(encoder, decoder, test_data, device=device)

NameError: name 'check' is not defined

In [25]:
# input_lang.char2index
# output_lang.index2char

In [3]:
import torch
from torch.utils.data import TensorDataset

def encode_str(input_str, input_lang, input_max_len, s='', target_device=None):
    """Encode a whitespace-separated string into a batch of padded character-index rows.

    Each word becomes one row. Its characters are mapped through `input_lang.char2index`
    (unseen characters map to the '$' index, or 0 if '$' is missing), and the row is
    right-padded with '#' to input_max_len + 1 columns.

    Args:
        input_str: words separated by whitespace; repeated, leading or trailing whitespace
            is ignored instead of producing empty, all-padding rows.
        input_lang: object exposing a `char2index` dict (e.g. Language).
        input_max_len: longest word length the model was built for.
        s: unused label, kept for call compatibility.
        target_device: device for the result; defaults to the notebook-level `device`.

    Returns:
        LongTensor of shape (number of words, input_max_len + 1).

    Raises:
        ValueError: if a word is longer than input_max_len + 1 characters.
    """
    out_device = device if target_device is None else target_device
    unknown = input_lang.char2index.get('$', 0)  # fallback in case '$' is missing
    width = input_max_len + 1

    rows = []
    for word in input_str.split():
        if len(word) > width:
            raise ValueError(f"word {word!r} is longer than the padded width {width}")
        rows.append([input_lang.char2index.get(ch, unknown) for ch in word.ljust(width, '#')])

    # convert once instead of writing element by element into a (possibly CUDA) tensor;
    # reshape keeps the 2-D shape when there are no words
    return torch.tensor(rows, dtype=torch.long, device=out_device).reshape(len(rows), width)

In [4]:
# Encode a space-separated test string: one padded row of character indices per word
encoded_input = encode_str("input str how", input_lang, input_max_len=input_max_len)

NameError: name 'input_lang' is not defined

In [42]:
# for enc in enc_list:
#     print(input_lang.decode(enc.tolist()))

input#########################
str###########################
how###########################


In [46]:
def test_generate(input, max_len=None, start_token=2):
    """
    Greedy autoregressive decoding with the notebook-level `encoder` and `decoder`.

    Parameters:
        input      : torch.Tensor of shape (B, T_in) holding source token ids
        max_len    : total output length including the start token
                     (defaults to decoder_block_size)
        start_token: id written at position 0 of every output row

    Returns:
        torch.Tensor of shape (B, max_len) with the start token followed by predicted ids
    """
    memory = encoder(input)
    steps = decoder_block_size if max_len is None else max_len

    # every row starts with the start-of-sequence token, on the input's device
    seq = torch.full((input.size(0), 1), start_token, dtype=torch.long, device=input.device)

    for _ in range(steps - 1):
        logits, _ = decoder(seq, memory)                          # (B, cur_len, vocab_size)
        best_next = logits[:, -1, :].argmax(dim=-1, keepdim=True)  # greedy pick, (B, 1)
        seq = torch.cat((seq, best_next), dim=1)

    return seq

In [24]:
import torch
import torch.nn.functional as F

def _repeat_for_beams(encoder_output, beam_width):
    """Repeat encoder output (a tensor, or a tuple/list of tensors) beam_width times along dim 0."""
    if isinstance(encoder_output, (tuple, list)):
        return tuple(t.repeat_interleave(beam_width, dim=0) for t in encoder_output)
    return encoder_output.repeat_interleave(beam_width, dim=0)


def _beam_lengths(sequences, eos_idx):
    """Per-beam generated length (start token excluded), up to and including the first EOS; min 1."""
    generated = sequences[..., 1:]
    lengths = torch.full(generated.shape[:-1], max(1, generated.size(-1)),
                         dtype=torch.long, device=sequences.device)
    if eos_idx is None or generated.size(-1) == 0:
        return lengths
    is_eos = generated == eos_idx
    # argmax returns the first maximal index, i.e. the position of the first EOS
    first_eos = is_eos.to(torch.float32).argmax(dim=-1) + 1
    return torch.where(is_eos.any(dim=-1), first_eos, lengths)


def beam_search_generate(input_ids,
                         max_len=None,
                         start_token=2,
                         beam_width=4,
                         EOS_IDX=None,
                         length_penalty=0.0,
                         encoder_model=None,
                         decoder_model=None):
    """
    Batched beam search decoding.

    Args:
        input_ids (torch.LongTensor): (B, T_in) encoder inputs.
        max_len (int or None): max decode length (including start token). If None, uses decoder_block_size.
        start_token (int): start-of-sequence token id.
        beam_width (int): beam width.
        EOS_IDX (int or None): optional end-of-sequence token id. A beam that emits it is
                               finished: it is only extended with further EOS tokens and its
                               score is kept unchanged.
        length_penalty (float): exponent alpha. final_score = score / (length ** alpha), where
                               length is each beam's generated length up to and including its
                               first EOS (all beams share one length when EOS_IDX is None).
                               Use 0.0 to disable.
        encoder_model, decoder_model: models to use; default to the notebook-level
                               `encoder` / `decoder`.

    Returns:
        best_sequences: (B, L) LongTensor -- best beam per batch (includes start_token)
        all_beams: dict with keys:
            'sequences' -> (B, beam_width, L) LongTensor (all beams)
            'scores'    -> (B, beam_width) FloatTensor (raw log-prob sums, before length penalty)
    Notes:
        - `decoder(decoder_input_ids, encoder_output)` must return (logits, ...) with logits of
          shape (B*beam_width, cur_len, V).
    """
    enc_model = encoder if encoder_model is None else encoder_model
    dec_model = decoder if decoder_model is None else decoder_model

    device = input_ids.device
    B = input_ids.size(0)

    encoder_output = enc_model(input_ids)

    if max_len is None:
        max_len = decoder_block_size  # assumed to be defined in scope

    # loop-invariant: expand the encoder output to one copy per beam once, up front
    enc_flat = _repeat_for_beams(encoder_output, beam_width)

    # sequences include the start token: (B, beam_width, cur_len)
    cur_len = 1
    sequences = torch.full((B, beam_width, cur_len), fill_value=start_token,
                           dtype=torch.long, device=device)

    # log-prob sums per beam; only beam 0 is live at first so the initial top-k has no duplicates
    neg_inf = -1e9
    scores = torch.full((B, beam_width), neg_inf, device=device)
    scores[:, 0] = 0.0

    finished = torch.zeros((B, beam_width), dtype=torch.bool, device=device)

    for _ in range(1, max_len):
        flat_seq = sequences.view(B * beam_width, cur_len)
        logits, _ = dec_model(flat_seq, enc_flat)                     # (B*beam_width, cur_len, V)
        log_probs = F.log_softmax(logits[:, -1, :], dim=-1)           # (B*beam_width, V)
        V = log_probs.size(-1)
        log_probs = log_probs.view(B, beam_width, V)

        if EOS_IDX is not None and finished.any():
            # finished beams may only append EOS, at zero cost, so their score is preserved
            frozen = torch.full_like(log_probs, neg_inf)
            frozen[..., EOS_IDX] = 0.0
            log_probs = torch.where(finished.unsqueeze(-1), frozen, log_probs)

        # candidate score for every (beam, token) pair, flattened per batch: (B, beam_width * V)
        candidate_scores = (scores.unsqueeze(-1) + log_probs).view(B, beam_width * V)
        # beam_width * V >= beam_width, so exactly beam_width candidates always exist
        topk_scores, topk_indices = torch.topk(candidate_scores, k=beam_width, dim=-1)

        prev_beam_idx = torch.div(topk_indices, V, rounding_mode='floor')  # (B, beam_width)
        token_idx = topk_indices % V                                        # (B, beam_width)

        # vectorised beam reordering: copy each parent sequence, then append its chosen token
        parent = sequences.gather(1, prev_beam_idx.unsqueeze(-1).expand(B, beam_width, cur_len))
        sequences = torch.cat((parent, token_idx.unsqueeze(-1)), dim=-1)
        scores = topk_scores
        finished = finished.gather(1, prev_beam_idx)
        if EOS_IDX is not None:
            finished = finished | (token_idx == EOS_IDX)
        cur_len += 1

        if finished.all():
            break

    # rank final beams, optionally normalising by each beam's own generated length
    if length_penalty != 0.0:
        lengths = _beam_lengths(sequences, EOS_IDX).to(scores.dtype)
        final_scores = scores / lengths.pow(length_penalty)
    else:
        final_scores = scores

    best_beam_idx = torch.argmax(final_scores, dim=1)                          # (B,)
    best_sequences = sequences[torch.arange(B, device=device), best_beam_idx]  # (B, cur_len)

    all_beams = {'sequences': sequences, 'scores': scores}
    return best_sequences, all_beams


In [25]:
input_str = "pushpak shalaka kaustubh sandhyaa"

print('I am a Transformer Model')

# one padded row of character indices per space-separated word
encoded_input = encode_str(input_str, input_lang, input_max_len)
# greedy decoding: each row is '^' followed by predicted characters / '#' padding
predicted_tokens = test_generate(encoded_input.to(device))

# show only the text between the start symbol and the first padding symbol
for token_row in predicted_tokens:
    text = output_lang.decode(token_row.tolist())
    print(text.split('^')[1].split('#')[0])

I am a Transformer Model


NameError: name 'test_generate' is not defined

In [28]:
input_str = "pushpak shalaka kaustubh sandhyaa"
print('I am a Transformer Model')

# --- encode input ---
encoded_input = encode_str(input_str, input_lang, input_max_len)

# ensure batch dimension (B, T)
if encoded_input.dim() == 1:
    encoded_input = encoded_input.unsqueeze(0)   # shape -> (1, T)

# --- SOS / EOS indices ---
SOS_IDX = 2   # '^' in Language.char2index
# Use an explicit EOS index only if output_lang exposes one (Language has no such attribute;
# decoded words end at the first '#' padding). Compare against None instead of chaining `or`,
# which would silently skip a legitimate EOS index of 0.
EOS_IDX = getattr(output_lang, "eos_idx", None)
if EOS_IDX is None:
    token2index = getattr(output_lang, "token2index", None) or {}
    for eos_token in ("<EOS>", "</s>"):
        if token2index.get(eos_token) is not None:
            EOS_IDX = token2index[eos_token]
            break

# --- run beam search (beam_width = 2) ---
predicted_tokens, all_beams = beam_search_generate(
    encoded_input.to(device),
    max_len=decoder_block_size,
    start_token=SOS_IDX,
    beam_width=2,
    EOS_IDX=EOS_IDX,
    length_penalty=0.0
)  # predicted_tokens: (B, T_out) including the start token

# --- decode to strings: keep the text between '^' and the first '#' ---
for seq in predicted_tokens:
    decoded = output_lang.decode(seq.tolist())
    print(decoded.split('^')[1].split('#')[0])

print(all_beams)

I am a Transformer Model
पुष्पक
शालका
कौस्तुभ
संध्या
{'sequences': tensor([[[ 2,  3, 19, 48, 18,  3,  5,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
           0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
         [ 2,  3, 19, 41, 18,  3,  5,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
           0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0]],

        [[ 2, 41, 12, 13,  5, 12,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
           0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
         [ 2, 41, 13,  5, 12,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
           0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0]],

        [[ 2,  5, 40, 17, 18, 24, 19, 35,  0,  0,  0,  0,  0,  0,  0,  0,  0,
           0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
         [ 2,  5, 40, 17, 18,  7, 19, 35,  0,  0,  0,  0,  0,  0,  0,  0,  0,
           0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0]],

        [[ 2, 17, 11, 51, 18,  9, 12,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
           0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  