In [1]:
import os
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Embedding, LSTM, GRU, RNN, Dense, Dropout

2025-05-19 16:28:33.012055: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1747672113.279346      35 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1747672113.366128      35 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [16]:
!pip install --upgrade wandb

Collecting wandb
  Downloading wandb-0.19.11-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (10 kB)
Downloading wandb-0.19.11-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (21.4 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m21.4/21.4 MB[0m [31m68.4 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[?25hInstalling collected packages: wandb
  Attempting uninstall: wandb
    Found existing installation: wandb 0.19.9
    Uninstalling wandb-0.19.9:
      Successfully uninstalled wandb-0.19.9
Successfully installed wandb-0.19.11


In [18]:
import wandb
from wandb.integration.keras import WandbCallback

wandb.login(key='e030007b097df00d9a751748294abc8440f932b1')



In [8]:
def load_data(file_path):
    df = pd.read_csv(file_path, sep='\t', header=None, names=['latin', 'native'])
    df = df.dropna()
    df['latin'] = df['latin'].astype(str)
    df['native'] = df['native'].astype(str)
    return df

def load_dakshina_dataset(language_code='hi', base_dir='/kaggle/input/dak-data/dakshina_dataset_v1.0'):
    path = os.path.join(base_dir, language_code, 'lexicons')
    return (
        load_data(os.path.join(path, f'{language_code}.translit.sampled.train.tsv')),
        load_data(os.path.join(path, f'{language_code}.translit.sampled.dev.tsv')),
        load_data(os.path.join(path, f'{language_code}.translit.sampled.test.tsv')),
    )

train_data, val_data, test_data = load_dakshina_dataset()

In [9]:
print(train_data.shape)

(44202, 2)


In [25]:
wandb.init(
    project="DA_seq2seq_transliteration",
    name="vanilla_lstm_run_q1",
    config={
        "model_type": "vanilla",
        "cell_type": "LSTM",
        "embedding_dim": 64,
        "hidden_dim": 128,
        "dropout_rate": 0.2,
        "batch_size": 64,
        "epochs": 10,
        "input_vocab_size": len(input_tokenizer.word_index) + 1,
        "target_vocab_size": len(target_tokenizer.word_index) + 1,
        "max_input_length": max_in,
        "max_target_length": max_out,
        "optimizer": "adam",
        "loss": "sparse_categorical_crossentropy",
        "dataset": "dakshina_hi"
    }
)

class VanillaSeq2Seq:
    def __init__(self, input_vocab_size, target_vocab_size, embedding_dim, hidden_dim, cell_type='LSTM', dropout_rate=0.2):
        self.input_vocab_size = input_vocab_size
        self.target_vocab_size = target_vocab_size
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.cell_type = cell_type
        self.dropout_rate = dropout_rate
        self.model = self._build_model()

    def _get_rnn_cell(self, return_sequences, return_state):
        if self.cell_type == 'LSTM':
            return LSTM(self.hidden_dim, return_sequences=return_sequences, return_state=return_state)
        elif self.cell_type == 'GRU':
            return GRU(self.hidden_dim, return_sequences=return_sequences, return_state=return_state)
        else:
            return RNN(self.hidden_dim, return_sequences=return_sequences, return_state=return_state)

    def _build_model(self):
        encoder_inputs = Input(shape=(None,), name='encoder_input')
        decoder_inputs = Input(shape=(None,), name='decoder_input')

        encoder_embed = Embedding(self.input_vocab_size, self.embedding_dim)(encoder_inputs)
        encoder_embed = Dropout(self.dropout_rate)(encoder_embed)

        if self.cell_type == 'LSTM':
            encoder_outputs, state_h, state_c = LSTM(self.hidden_dim, return_state=True)(encoder_embed)
            encoder_states = [state_h, state_c]
        else:
            encoder_outputs, state_h = self._get_rnn_cell(return_sequences=False, return_state=True)(encoder_embed)
            encoder_states = [state_h]

        decoder_embed = Embedding(self.target_vocab_size, self.hidden_dim)(decoder_inputs)
        decoder_embed = Dropout(self.dropout_rate)(decoder_embed)

        if self.cell_type == 'LSTM':
            decoder_outputs, _, _ = LSTM(self.hidden_dim, return_sequences=True, return_state=True)(decoder_embed, initial_state=encoder_states)
        else:
            decoder_outputs, _ = self._get_rnn_cell(return_sequences=True, return_state=True)(decoder_embed, initial_state=encoder_states)

        decoder_dense = Dense(self.target_vocab_size, activation='softmax')
        decoder_outputs = decoder_dense(decoder_outputs)

        return Model([encoder_inputs, decoder_inputs], decoder_outputs)

    def compile(self):
        self.model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])

    def fit(self, train_data, val_data, batch_size=64, epochs=10, callbacks=None):
        return self.model.fit(
            [train_data['encoder_input'], train_data['decoder_input']],
            np.expand_dims(train_data['decoder_target'], -1),
            validation_data=(
                [val_data['encoder_input'], val_data['decoder_input']],
                np.expand_dims(val_data['decoder_target'], -1)
            ),
            batch_size=batch_size,
            epochs=epochs,
            callbacks=callbacks
        )

# ✅ 7. Initialize and train model
input_vocab_size = len(input_tokenizer.word_index) + 1
target_vocab_size = len(target_tokenizer.word_index) + 1

model = VanillaSeq2Seq(
    input_vocab_size=input_vocab_size,
    target_vocab_size=target_vocab_size,
    embedding_dim=64,
    hidden_dim=128,
    cell_type='LSTM',
    dropout_rate=0.2
)

model.compile()

wandb_callback = WandbCallback(
    log_model=False,           # no wandb artifact
    save_graph=False,          # don't try to render graph
    save_model=False           # ✅ disables all auto saving
)

model.fit(
    train_data={
        'encoder_input': encoder_input_train,
        'decoder_input': decoder_input_train,
        'decoder_target': decoder_target_train
    },
    val_data={
        'encoder_input': encoder_input_val,
        'decoder_input': decoder_input_val,
        'decoder_target': decoder_target_val
    },
    batch_size=64,
    epochs=10,
    callbacks=[wandb_callback]
)

0,1
accuracy,▁
epoch,▁
loss,▁
val_accuracy,▁
val_loss,▁

0,1
accuracy,0.81783
best_epoch,0.0
best_val_loss,0.3863
epoch,0.0
loss,0.44345
val_accuracy,0.82446
val_loss,0.3863


Epoch 1/10
[1m691/691[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m30s[0m 34ms/step - accuracy: 0.7929 - loss: 0.5777 - val_accuracy: 0.8266 - val_loss: 0.3852
Epoch 2/10
[1m691/691[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m23s[0m 33ms/step - accuracy: 0.8264 - loss: 0.3971 - val_accuracy: 0.8274 - val_loss: 0.3826
Epoch 3/10
[1m691/691[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m23s[0m 33ms/step - accuracy: 0.8291 - loss: 0.3908 - val_accuracy: 0.8340 - val_loss: 0.3777
Epoch 4/10
[1m691/691[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m23s[0m 33ms/step - accuracy: 0.8315 - loss: 0.3884 - val_accuracy: 0.8352 - val_loss: 0.3720
Epoch 5/10
[1m691/691[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m22s[0m 32ms/step - accuracy: 0.8348 - loss: 0.3868 - val_accuracy: 0.8373 - val_loss: 0.3726
Epoch 6/10
[1m691/691[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m23s[0m 33ms/step - accuracy: 0.8374 - loss: 0.3820 - val_accuracy: 0.8389 - val_loss: 0.3699
Epoch 7/10
[1m6

<keras.src.callbacks.history.History at 0x7d5b312bbdd0>