In [1]:
import torch
from torch import nn, optim
import os

from data_preprocess import *
from model import *
from utils import *

In [4]:
user = "ee20d201-indian-institute-of-technology-madras"
project = "DA6401_Assignment_3"
display_name = "attention_heatmaps"

# Configs
data_dir = 'dakshina_dataset_v1.0'
lang = 'hi'  # Hindi
subfolder_dir = 'lexicons'
train_path = os.path.join(data_dir, lang, subfolder_dir, f'{lang}.translit.sampled.train.tsv')
dev_path = os.path.join(data_dir, lang, subfolder_dir, f'{lang}.translit.sampled.dev.tsv')
test_path = os.path.join(data_dir, lang, subfolder_dir, f'{lang}.translit.sampled.test.tsv')

encoder_embedding_dim = 32
decoder_embedding_dim = 128
hidden_dim = 128
num_encoder_layers = 2
num_decoder_layers = 2
rnn_type = 'lstm'  # can be 'RNN' or 'LSTM' or 'GRU'
batch_size = 32
num_epochs = 20
learning_rate = 0.005
dropout_prob = 0.2
use_attention = True
if use_attention:
    output_dir = 'predictions_attention'
else:
    output_dir = 'predictions_vanilla'
teacher_forcing_ratio = 0.75

# device = torch.device('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() and torch.backends.mps.is_built() else 'cpu')
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

# Get dataloaders
train_loader, val_loader, test_loader, src_vocab, tgt_vocab = prepare_dataloaders(train_path, dev_path, test_path, batch_size, repeat_datapoints=True, num_workers=4)

# Initialize models
encoder = Encoder(input_dim=len(src_vocab), emb_dim=encoder_embedding_dim, hidden_dim=hidden_dim,
                  num_layers=num_encoder_layers, rnn_type=rnn_type, dropout=dropout_prob).to(device)
decoder = Decoder(output_dim=len(tgt_vocab), emb_dim=decoder_embedding_dim, hidden_dim=hidden_dim,
                  num_layers=num_decoder_layers, rnn_type=rnn_type, dropout=dropout_prob, use_attention=True).to(device)
model = Seq2Seq(encoder, decoder, device).to(device)

cpu


1

In [5]:
if use_attention:
    checkpoint = torch.load('best_att.pth', map_location=device, weights_only=True)
else:
    checkpoint = torch.load('best.pth', map_location=device, weights_only=True)
criterion = nn.CrossEntropyLoss(ignore_index=tgt_vocab.pad_idx)

# Load model weights
model.load_state_dict(checkpoint['model_state_dict'])
model.to(device)
model.eval()
# test_model_alternate(model, test_loader, criterion, src_vocab, tgt_vocab, device, beam_validate=True, output_dir=output_dir, wandb_log=wandb_log, n=20)
# plot_attention_heatmaps(model, test_loader, tgt_vocab, src_vocab, device, 10, args.wandb_log)
plot_connectivity(model, src_vocab, tgt_vocab, filepath='predictions_attention/predictions.txt', device=device)


AttributeError: 'torch.device' object has no attribute 'encode'

In [5]:
os.makedirs("predictions_vanilla", exist_ok=True)

total_token_acc = 0
total_word_acc = 0
total_samples = 0

predictions = []
sources = []
references = []

In [4]:
with torch.no_grad():
    for i, (src_batch, tgt_batch) in enumerate(test_loader):
        src_batch = src_batch.to(device)
        tgt_batch = tgt_batch.to(device)
        print(i)
        break

0


In [9]:
with torch.no_grad():
    for j in range(src_batch.size(0)):
        src = src_batch[j].unsqueeze(0)
        tgt = tgt_batch[j].unsqueeze(0)

        pred_tokens = beam_fn(model, src, src_vocab, tgt_vocab, 3, 50, device)
        if j == 61:
            break
        
# print(src_vocab.decode(src.squeeze(0)))        
# print(tgt_vocab.decode(tgt.squeeze(0)))

print(tgt.squeeze(0).size(0))
print(len(pred_tokens))

11
14


In [23]:
with torch.no_grad():
    # for j in range(src_batch.size(0)):
    src = src_batch[61].unsqueeze(0)  # shape: [1, src_len]
    tgt = tgt_batch[61].unsqueeze(0)  # shape: [1, tgt_len]  

    # Beam search decoding
    pred_tokens = beam_fn(model, src, src_vocab, tgt_vocab, 3, 50, device)
    
    # Reference tokens (excluding SOS, including EOS)
    reference = tgt.squeeze(0).tolist() 

    if tgt_vocab.eos_idx in pred_tokens:
        pred_tokens = pred_tokens[1:pred_tokens.index(tgt_vocab.eos_idx)+1]
    else:
        pred_tokens = pred_tokens[1:]
    if tgt_vocab.eos_idx in reference:
        reference = reference[1:reference.index(tgt_vocab.eos_idx)+1]
    else:
        reference = reference[1:]
        

    if len(pred_tokens)<len(reference):
        pred_tokens = pred_tokens + [tgt_vocab.pad_idx] * (len(reference) - len(pred_tokens))
    elif len(pred_tokens)>len(reference):
        reference = reference + [tgt_vocab.pad_idx] * (len(pred_tokens) - len(reference))
    print(pred_tokens)
    print(reference)
    
    # reference_eos_idx = reference.index(tgt_vocab.eos_idx) + 1
    # list_len = max(pred_eos_idx, reference_eos_idx)
    # pred_tokens = pred_tokens[1:list_len]
    # reference = reference[1:list_len]

    # break
    
    # print(reference)
    # break
    targets = torch.tensor(reference).to(device)
    preds = torch.tensor(pred_tokens).to(device)

    correct = (preds == targets).sum().item()
    total = targets.size(0)
    # correct, total = compute_token_accuracy(pred_tokens, tgt, tgt_vocab.pad_idx)
    token_correct += correct
    token_total += total
    
    if all(pred_tokens == reference):
        word_correct += 1
    word_total += 1

    # Save input/prediction/reference for logging
    src_tokens = src.squeeze(0).tolist()
    sources.append(src_vocab.decode(src_tokens))
    predictions.append(tgt_vocab.decode(pred_tokens))
    references.append(tgt_vocab.decode(reference))

# print(pred_tokens)
# print(reference)

[55, 42, 7, 37, 12, 10, 19, 32, 17, 14, 22, 24, 2]
[4, 7, 37, 19, 30, 21, 56, 16, 2, 0, 0, 0, 0]


NameError: name 'token_correct' is not defined

In [28]:
for src_batch, tgt_batch in val_loader:
    src_batch = src_batch.to(device)
    tgt_batch = tgt_batch.to(device)

    outputs = model(src_batch, tgt_batch, teacher_forcing_ratio=0.0)

    output_dim = outputs.size(-1)
    outputs_flat = outputs[:, 1:].reshape(-1, output_dim)
    tgt_flat = tgt_batch[:, 1:].reshape(-1)

    # loss = criterion(outputs_flat, tgt_flat)
    # val_loss += loss.item()

    preds = outputs_flat.argmax(1)

    print(outputs.shape)
    print(tgt_flat.shape)

    break

torch.Size([64, 11, 67])
torch.Size([640])


In [121]:
src_vocab.decode(src.squeeze(0).tolist())

'ank'

In [6]:
def beam_fn(model, src_tensor, src_vocab, tgt_vocab,
                       beam_width:int=3, max_len=50, device='cpu'):
    """
    Beam search decoding for a trained Seq2Seq model.

    Args:
        model: Trained Seq2Seq model
        src_tensor: tensor of token indices (1D tensor) for the source sentence
        src_vocab: source vocabulary (to access <sos> etc.)
        tgt_vocab: target vocabulary (to access <sos>, <eos>)
        beam_width: number of beams to maintain
        max_len: maximum output length
        device: torch.device

    Returns:
        Best output sequence (list of token indices)
    """
    model.eval()

    with torch.no_grad():
        # --- Step 1: Encode source ---
        src_tensor = src_tensor.to(device)  # (1, src_len)
        encoder_outputs, hidden = model.encoder(src_tensor)

        # Beam = list of tuples: (sequence, score, hidden_state)
        beams = [([tgt_vocab.sos_idx], 0.0, hidden)]  # Start with SOS token

        completed_sequences = []

        for _ in range(max_len):
            new_beams = []
            for seq, score, hidden_state in beams:
                input_token = torch.tensor([seq[-1]], device=device)
                if seq[-1] == tgt_vocab.eos_idx:
                    completed_sequences.append((seq, score))
                    continue

                # Decoder step
                output, hidden_next = model.decoder(input_token, hidden_state)
                output = output.squeeze(1)  # (1, vocab_size)

                log_probs = torch.log_softmax(output, dim=-1)  # (1, vocab_size)
                topk_log_probs, topk_indices = log_probs.topk(beam_width)  # (1, beam)

                for i in range(beam_width):
                    token = topk_indices[0, i].item()
                    token_log_prob = topk_log_probs[0, i].item()

                    new_seq = seq + [token]
                    new_score = score + token_log_prob
                    new_beams.append((new_seq, new_score, hidden_next))

            # Keep top k beams
            beams = sorted(new_beams, key=lambda x: x[1], reverse=True)[:beam_width]

            # Stop early if all beams are done
            if all(seq[-1] == tgt_vocab.eos_idx for seq, _, _ in beams):
                completed_sequences.extend(beams)
                break

        # Add unfinished beams
        completed_sequences.extend([b for b in beams if b[0][-1] != tgt_vocab.eos_idx])

        # Sort all complete sequences by score and return best one
        completed_sequences = sorted(completed_sequences, key=lambda x: x[1], reverse=True)

        best_seq = completed_sequences[0][0]
        return best_seq

In [150]:
for i, (src_batch, tgt_batch) in enumerate(tqdm(test_loader, desc="Testing")):
    src_batch = src_batch.to(device)
    tgt_batch = tgt_batch.to(device)

    for j in range (src_batch.size(0)):
        src = src_batch[j].unsqueeze(0)  # shape: [1, src_len]
        tgt = tgt_batch[j].unsqueeze(0)

        print(src)

    if i > 5:
        break

Testing:   5%|█▌                                 | 6/130 [00:00<00:00, 130.25it/s]

tensor([[4, 5, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0]])
tensor([[4, 5, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0]])
tensor([[4, 5, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0]])
tensor([[4, 5, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0]])
tensor([[4, 5, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0]])
tensor([[4, 5, 6, 4, 0, 0, 0, 0, 0, 0, 0, 0]])
tensor([[4, 5, 6, 8, 9, 0, 0, 0, 0, 0, 0, 0]])
tensor([[4, 5, 6, 8, 9, 0, 0, 0, 0, 0, 0, 0]])
tensor([[4, 5, 6, 8, 9, 0, 0, 0, 0, 0, 0, 0]])
tensor([[ 4,  5,  4,  6, 20,  5,  0,  0,  0,  0,  0,  0]])
tensor([[ 4,  5,  6, 16, 20,  5,  0,  0,  0,  0,  0,  0]])
tensor([[ 4,  5,  6, 20,  5,  0,  0,  0,  0,  0,  0,  0]])
tensor([[ 4,  5,  6, 20,  5,  0,  0,  0,  0,  0,  0,  0]])
tensor([[ 4,  5,  7,  6, 20, 14,  0,  0,  0,  0,  0,  0]])
tensor([[ 4,  5,  6, 20, 14,  0,  0,  0,  0,  0,  0,  0]])
tensor([[ 4,  5,  6, 20, 14,  0,  0,  0,  0,  0,  0,  0]])
tensor([[ 4,  5,  7,  4,  4, 14,  4,  6,  0,  0,  0,  0]])
tensor([[ 4,  5,  7,  4, 14,  4,  6,  0,  0,  0,  0,  0]])
tensor([[ 4,  5,  7,  4, 14,  4,  6,  0,  0,  


