In [2]:
import os
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Embedding, LSTM, GRU, RNN, Dense, Dropout

In [None]:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

In [2]:
!pip install --upgrade wandb

Collecting wandb
  Downloading wandb-0.19.11-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (10 kB)
Downloading wandb-0.19.11-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (21.4 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m21.4/21.4 MB[0m [31m63.6 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[?25hInstalling collected packages: wandb
  Attempting uninstall: wandb
    Found existing installation: wandb 0.19.9
    Uninstalling wandb-0.19.9:
      Successfully uninstalled wandb-0.19.9
Successfully installed wandb-0.19.11


In [3]:
import wandb
from wandb.integration.keras import WandbCallback

wandb.login(key='e030007b097df00d9a751748294abc8440f932b1')

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


True

In [4]:
def load_data(file_path):
    df = pd.read_csv(file_path, sep='\t', header=None, names=['latin', 'native'])
    df = df.dropna()
    df['latin'] = df['latin'].astype(str)
    df['native'] = df['native'].astype(str)
    return df

def load_dakshina_dataset(language_code='hi', base_dir='/kaggle/input/dak-data/dakshina_dataset_v1.0'):
    path = os.path.join(base_dir, language_code, 'lexicons')
    return (
        load_data(os.path.join(path, f'{language_code}.translit.sampled.train.tsv')),
        load_data(os.path.join(path, f'{language_code}.translit.sampled.dev.tsv')),
        load_data(os.path.join(path, f'{language_code}.translit.sampled.test.tsv')),
    )

train_data, val_data, test_data = load_dakshina_dataset()

# Question 1

In [5]:
# ─── 1) Extract raw texts ─────────────────────────────────────────
input_texts  = train_data['latin'].tolist()
target_texts = ['\t' + t + '\n' for t in train_data['native'].tolist()]

val_input_texts  = val_data['latin'].tolist()
val_target_texts = ['\t' + t + '\n' for t in val_data['native'].tolist()]

input_tokenizer = Tokenizer(char_level=True, oov_token=None)
input_tokenizer.fit_on_texts(input_texts + val_input_texts)

target_tokenizer = Tokenizer(char_level=True, oov_token=None)
target_tokenizer.fit_on_texts(target_texts + val_target_texts)

# ─── 3) Convert texts → integer sequences + pad to max lengths ───
# Compute max lengths
max_in  = max(len(txt) for txt in input_texts + val_input_texts)
max_out = max(len(txt) for txt in target_texts + val_target_texts)

# Integer‑encode + pad
encoder_input_train = pad_sequences(
    input_tokenizer.texts_to_sequences(input_texts),
    maxlen=max_in,
    padding='post'
)
decoder_input_train = pad_sequences(
    target_tokenizer.texts_to_sequences(target_texts),
    maxlen=max_out,
    padding='post'
)
# decoder targets are the decoder inputs shifted left by one
decoder_target_train = np.array(decoder_input_train)[:, 1:]
decoder_input_train   = np.array(decoder_input_train)[:, :-1]

# Do the same for validation set
encoder_input_val = pad_sequences(
    input_tokenizer.texts_to_sequences(val_input_texts),
    maxlen=max_in,
    padding='post'
)
decoder_input_val = pad_sequences(
    target_tokenizer.texts_to_sequences(val_target_texts),
    maxlen=max_out,
    padding='post'
)
decoder_target_val = np.array(decoder_input_val)[:, 1:]
decoder_input_val   = np.array(decoder_input_val)[:, :-1]

# Make sure any previous wandb runs are finished
try:
    wandb.finish()
except:
    pass

# Initialize wandb with proper error handling
try:
    wandb.init(
        project="DA_seq2seq_transliteration",
        name="vanilla_lstm_run_q1",
        # Removed reinit=True to prevent connection issues
        config={
            "model_type": "vanilla",
            "cell_type": "LSTM",
            "embedding_dim": 64,
            "hidden_dim": 128,
            "dropout_rate": 0.2,
            "batch_size": 64,
            "epochs": 10,
            "input_vocab_size": len(input_tokenizer.word_index) + 1,
            "target_vocab_size": len(target_tokenizer.word_index) + 1,
            "max_input_length": max_in,
            "max_target_length": max_out,
            "optimizer": "adam",
            "loss": "sparse_categorical_crossentropy",
            "dataset": "dakshina_hi"
        }
    )
except Exception as e:
    print(f"Failed to initialize wandb: {e}")
    # Create a dummy wandb to avoid errors in the code
    class DummyWandb:
        def log(self, *args, **kwargs):
            pass
        def config(self, *args, **kwargs):
            return type('obj', (object,), {
                'embedding_dim': 64,
                'hidden_dim': 128,
                'get': lambda s, k, d: d
            })
    wandb = DummyWandb()

class VanillaSeq2Seq:
    def __init__(self,
                 input_vocab_size,
                 target_vocab_size,
                 embedding_dim,
                 hidden_dim,
                 cell_type='LSTM',
                 dropout_rate=0.2,
                 num_encoder_layers=1,
                 num_decoder_layers=1):
        self.input_vocab_size  = input_vocab_size
        self.target_vocab_size = target_vocab_size
        self.embedding_dim     = embedding_dim
        self.hidden_dim        = hidden_dim
        self.cell_type         = cell_type
        self.dropout_rate      = dropout_rate
        self.num_encoder_layers = num_encoder_layers
        self.num_decoder_layers = num_decoder_layers
        self.model = self._build_model()

    def _rnn_layer(self, return_sequences, return_state):
        """Factory for one RNN/LSTM/GRU layer."""
        if self.cell_type == 'LSTM':
            return LSTM(self.hidden_dim,
                        return_sequences=return_sequences,
                        return_state=return_state)
        elif self.cell_type == 'GRU':
            return GRU(self.hidden_dim,
                       return_sequences=return_sequences,
                       return_state=return_state)
        else:
            return RNN(self.hidden_dim,
                       return_sequences=return_sequences,
                       return_state=return_state)

    def _build_model(self):
        encoder_inputs = Input(shape=(None,), name='encoder_input')
        x = Embedding(self.input_vocab_size, self.embedding_dim)(encoder_inputs)
        x = Dropout(self.dropout_rate)(x)

        # Stack encoder layers
        encoder_states = []
        for i in range(self.num_encoder_layers):
            # last encoder layer returns only state, earlier ones return sequences
            rs = (i < self.num_encoder_layers - 1)
            if self.cell_type == 'LSTM':
                x, state_h, state_c = LSTM(
                    self.hidden_dim,
                    return_sequences=rs,
                    return_state=True,
                    name=f'enc_lstm_{i}'
                )(x)
                encoder_states = [state_h, state_c]
            else:
                x, state_h = self._rnn_layer(
                    return_sequences=rs,
                    return_state=True
                )(x)
                encoder_states = [state_h]

        decoder_inputs = Input(shape=(None,), name='decoder_input')
        y = Embedding(self.target_vocab_size, self.embedding_dim)(decoder_inputs)
        y = Dropout(self.dropout_rate)(y)

        # Stack decoder layers
        for i in range(self.num_decoder_layers):
            rs = True  # decoder always returns sequences for all but we only care about final dense
            if self.cell_type == 'LSTM':
                # feed initial_state only to the first decoder layer
                init_st = encoder_states if i == 0 else None
                y, dh, dc = LSTM(
                    self.hidden_dim,
                    return_sequences=True,
                    return_state=True,
                    name=f'dec_lstm_{i}'
                )(y, initial_state=init_st) if init_st else LSTM(
                    self.hidden_dim,
                    return_sequences=True,
                    return_state=True,
                    name=f'dec_lstm_{i}'
                )(y)
            else:
                init_st = encoder_states if i == 0 else None
                y, dh = self._rnn_layer(
                    return_sequences=True,
                    return_state=True
                )(y, initial_state=init_st) if init_st else self._rnn_layer(
                    return_sequences=True,
                    return_state=True
                )(y)

        # Final projection
        outputs = Dense(self.target_vocab_size, activation='softmax')(y)
        return Model([encoder_inputs, decoder_inputs], outputs)

    def compile(self, optimizer='adam', loss='sparse_categorical_crossentropy'):
        self.model.compile(optimizer=optimizer, loss=loss, metrics=['accuracy'])

    def fit(self, train_data, val_data, batch_size=64, epochs=10, callbacks=None):
        return self.model.fit(
            [train_data['encoder_input'], train_data['decoder_input']],
            np.expand_dims(train_data['decoder_target'], -1),
            validation_data=(
                [val_data['encoder_input'], val_data['decoder_input']],
                np.expand_dims(val_data['decoder_target'], -1)
            ),
            batch_size=batch_size,
            epochs=epochs,
            callbacks=callbacks
        )
        
input_vocab_size = len(input_tokenizer.word_index) + 1
target_vocab_size = len(target_tokenizer.word_index) + 1

model = VanillaSeq2Seq(
    input_vocab_size=input_vocab_size,
    target_vocab_size=target_vocab_size,
    embedding_dim=64,
    hidden_dim=128,
    cell_type='LSTM',
    dropout_rate=0.2
)

model.compile()

# Modified wandb callback with error handling
try:
    wandb_callback = WandbCallback(
        log_model=False,           # no wandb artifact
        save_graph=False,          # don't try to render graph
        save_model=False           # ✅ disables all auto saving
    )
    callbacks = [wandb_callback]
except Exception as e:
    print(f"Failed to initialize WandbCallback: {e}")
    callbacks = []

# Use try-except for wandb config access
try:
    D = wandb.config.embedding_dim
    H = wandb.config.hidden_dim
    L_e = wandb.config.get("num_encoder_layers", 1)
    L_d = wandb.config.get("num_decoder_layers", 1)
except Exception as e:
    print(f"Failed to access wandb config: {e}")
    D = 64  # Default values
    H = 128
    L_e = 1
    L_d = 1

T_enc = encoder_input_train.shape[1]
T_dec = decoder_input_train.shape[1]

flops_per_step = 4 * (H * D + H * H)

# 4) Total ops over all layers & timesteps
total_enc_flops = L_e * T_enc * flops_per_step
total_dec_flops = L_d * T_dec * flops_per_step
total_flops = total_enc_flops + total_dec_flops

print(f"Approximate total multiplications (encoder + decoder): {total_flops:,}")

total_params = model.model.count_params()
print(f"Total trainable parameters: {total_params:,}")

model.model.summary()

history = model.fit(
    train_data={
        'encoder_input': encoder_input_train,
        'decoder_input': decoder_input_train,
        'decoder_target': decoder_target_train
    },
    val_data={
        'encoder_input': encoder_input_val,
        'decoder_input': decoder_input_val,
        'decoder_target': decoder_target_val
    },
    batch_size=64,
    epochs=10,
    callbacks=callbacks
)

# Log metrics to wandb with error handling
try:
    wandb.log({
        "total_flops": total_flops,
        "total_trainable_params": total_params
    })
    # Properly close the wandb run
    wandb.finish()
except Exception as e:
    print(f"Failed to log to wandb: {e}")

2025-05-20 05:41:36.847827: E external/local_xla/xla/stream_executor/cuda/cuda_driver.cc:152] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: UNKNOWN ERROR (303)


Approximate total multiplications (encoder + decoder): 2,260,992
Total trainable parameters: 201,869


Epoch 1/10
[1m691/691[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m28s[0m 32ms/step - accuracy: 0.7650 - loss: 0.6321 - val_accuracy: 0.8245 - val_loss: 0.3865
Epoch 2/10
[1m691/691[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m22s[0m 32ms/step - accuracy: 0.8265 - loss: 0.3960 - val_accuracy: 0.8294 - val_loss: 0.3805
Epoch 3/10
[1m691/691[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m22s[0m 31ms/step - accuracy: 0.8296 - loss: 0.3907 - val_accuracy: 0.8330 - val_loss: 0.3783
Epoch 4/10
[1m691/691[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m21s[0m 30ms/step - accuracy: 0.8345 - loss: 0.3857 - val_accuracy: 0.8411 - val_loss: 0.3725
Epoch 5/10
[1m691/691[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m21s[0m 31ms/step - accuracy: 0.8358 - loss: 0.3850 - val_accuracy: 0.8402 - val_loss: 0.3715
Epoch 6/10
[1m691/691[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m21s[0m 31ms/step - accuracy: 0.8394 - loss: 0.3791 - val_accuracy: 0.8363 - val_loss: 0.3704
Epoch 7/10
[1m6

0,1
accuracy,▁▄▅▆▆▇▇▇██
epoch,▁▂▃▃▄▅▆▆▇█
loss,█▃▃▂▂▂▂▁▁▁
total_flops,▁
total_trainable_params,▁
val_accuracy,▁▃▄▇▆▅▆▇██
val_loss,█▆▆▄▄▄▂▂▁▁

0,1
accuracy,0.8455
best_epoch,8.0
best_val_loss,0.35896
epoch,9.0
loss,0.36751
total_flops,2260992.0
total_trainable_params,201869.0
val_accuracy,0.84381
val_loss,0.35908


# Question 2

In [11]:
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.layers import Input, Embedding, LSTM, GRU, RNN, Dense, Dropout
from tensorflow.keras.models import Model
import os
import time
import matplotlib.pyplot as plt

In [None]:
# Load the Dakshina dataset (Hindi as an example)
# You can change 'hi' to the language of your choice
def load_dakshina_data(lang='hi'):
    base_path = f'/kaggle/input/dak-data/dakshina_dataset_v1.0/{lang}/lexicons/'
    
    # Load train, dev, test sets
    train_data = pd.read_csv(f'{base_path}{lang}.translit.sampled.train.tsv', sep='\t', 
                             header=None, names=['latin', 'native', 'class'])
    val_data = pd.read_csv(f'{base_path}{lang}.translit.sampled.dev.tsv', sep='\t', 
                           header=None, names=['latin', 'native', 'class'])
    test_data = pd.read_csv(f'{base_path}{lang}.translit.sampled.test.tsv', sep='\t', 
                            header=None, names=['latin', 'native', 'class'])

    # Drop any rows with missing values
    train_data = train_data.dropna().astype(str)
    val_data   = val_data.dropna().astype(str)
    test_data  = test_data.dropna().astype(str)

    return train_data, val_data, test_data

# Process data and create sequences
def process_data(train_data, val_data):
    # Extract texts
    input_texts = train_data['latin'].tolist()
    target_texts = ['\t' + t + '\n' for t in train_data['native'].tolist()]
    
    val_input_texts = val_data['latin'].tolist()
    val_target_texts = ['\t' + t + '\n' for t in val_data['native'].tolist()]
    
    # Build character-level tokenizers
    input_tokenizer = Tokenizer(char_level=True, oov_token=None)
    input_tokenizer.fit_on_texts(input_texts + val_input_texts)
    
    target_tokenizer = Tokenizer(char_level=True, oov_token=None)
    target_tokenizer.fit_on_texts(target_texts + val_target_texts)
    
    # Find max lengths
    max_in = max(len(txt) for txt in input_texts + val_input_texts)
    max_out = max(len(txt) for txt in target_texts + val_target_texts)
    
    # Convert to sequences and pad
    encoder_input_train = pad_sequences(
        input_tokenizer.texts_to_sequences(input_texts),
        maxlen=max_in,
        padding='post'
    )
    decoder_input_train = pad_sequences(
        target_tokenizer.texts_to_sequences(target_texts),
        maxlen=max_out,
        padding='post'
    )
    decoder_target_train = np.array(decoder_input_train)[:, 1:]
    decoder_input_train = np.array(decoder_input_train)[:, :-1]
    
    # Same for validation set
    encoder_input_val = pad_sequences(
        input_tokenizer.texts_to_sequences(val_input_texts),
        maxlen=max_in,
        padding='post'
    )
    decoder_input_val = pad_sequences(
        target_tokenizer.texts_to_sequences(val_target_texts),
        maxlen=max_out,
        padding='post'
    )
    decoder_target_val = np.array(decoder_input_val)[:, 1:]
    decoder_input_val = np.array(decoder_input_val)[:, :-1]
    
    return {
        'input_tokenizer': input_tokenizer,
        'target_tokenizer': target_tokenizer,
        'max_in': max_in,
        'max_out': max_out,
        'encoder_input_train': encoder_input_train,
        'decoder_input_train': decoder_input_train,
        'decoder_target_train': decoder_target_train,
        'encoder_input_val': encoder_input_val,
        'decoder_input_val': decoder_input_val,
        'decoder_target_val': decoder_target_val,
        'input_texts': input_texts,
        'target_texts': target_texts,
        'val_input_texts': val_input_texts,
        'val_target_texts': val_target_texts
    }

# Seq2Seq model with configurable parameters
class VanillaSeq2Seq:
    def __init__(self,
                 input_vocab_size,
                 target_vocab_size,
                 embedding_dim,
                 hidden_dim,
                 cell_type='LSTM',
                 dropout_rate=0.2,
                 num_encoder_layers=1,
                 num_decoder_layers=1):
        self.input_vocab_size = input_vocab_size
        self.target_vocab_size = target_vocab_size
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.cell_type = cell_type
        self.dropout_rate = dropout_rate
        self.num_encoder_layers = num_encoder_layers
        self.num_decoder_layers = num_decoder_layers
        self.model = self._build_model()
        
    def _rnn_layer(self, return_sequences, return_state):
        """Factory for one RNN/LSTM/GRU layer."""
        if self.cell_type == 'LSTM':
            return LSTM(self.hidden_dim,
                        return_sequences=return_sequences,
                        return_state=return_state)
        elif self.cell_type == 'GRU':
            return GRU(self.hidden_dim,
                       return_sequences=return_sequences,
                       return_state=return_state)
        else:
            return RNN(self.hidden_dim,
                       return_sequences=return_sequences,
                       return_state=return_state)
    
    def _build_model(self):
        encoder_inputs = Input(shape=(None,), name='encoder_input')
        x = Embedding(self.input_vocab_size, self.embedding_dim)(encoder_inputs)
        x = Dropout(self.dropout_rate)(x)
        
        # Stack encoder layers
        encoder_states = []
        for i in range(self.num_encoder_layers):
            # last encoder layer returns only state, earlier ones return sequences
            rs = (i < self.num_encoder_layers - 1)
            if self.cell_type == 'LSTM':
                x, state_h, state_c = LSTM(
                    self.hidden_dim,
                    return_sequences=rs,
                    return_state=True,
                    name=f'enc_lstm_{i}'
                )(x)
                encoder_states = [state_h, state_c]
            else:
                x, state_h = self._rnn_layer(
                    return_sequences=rs,
                    return_state=True
                )(x)
                encoder_states = [state_h]
        
        decoder_inputs = Input(shape=(None,), name='decoder_input')
        y = Embedding(self.target_vocab_size, self.embedding_dim)(decoder_inputs)
        y = Dropout(self.dropout_rate)(y)
        
        # Stack decoder layers
        for i in range(self.num_decoder_layers):
            rs = True  # decoder always returns sequences
            if self.cell_type == 'LSTM':
                # feed initial_state only to the first decoder layer
                init_st = encoder_states if i == 0 else None
                y, dh, dc = LSTM(
                    self.hidden_dim, 
                    return_sequences=True,
                    return_state=True,
                    name=f'dec_lstm_{i}'
                )(y, initial_state=init_st) if init_st else LSTM(
                    self.hidden_dim,
                    return_sequences=True,
                    return_state=True,
                    name=f'dec_lstm_{i}'
                )(y)
            else:
                init_st = encoder_states if i == 0 else None
                y, dh = self._rnn_layer(
                    return_sequences=True,
                    return_state=True
                )(y, initial_state=init_st) if init_st else self._rnn_layer(
                    return_sequences=True,
                    return_state=True
                )(y)
        
        # Final projection
        outputs = Dense(self.target_vocab_size, activation='softmax')(y)
        return Model([encoder_inputs, decoder_inputs], outputs)
    
    def compile(self, optimizer='adam', loss='sparse_categorical_crossentropy'):
        self.model.compile(optimizer=optimizer, loss=loss, metrics=['accuracy'])
        
    def fit(self, train_data, val_data, batch_size=64, epochs=10, callbacks=None):
        return self.model.fit(
            [train_data['encoder_input'], train_data['decoder_input']],
            np.expand_dims(train_data['decoder_target'], -1),
            validation_data=(
                [val_data['encoder_input'], val_data['decoder_input']],
                np.expand_dims(val_data['decoder_target'], -1)
            ),
            batch_size=batch_size,
            epochs=epochs,
            callbacks=callbacks
        )

def run_wandb_sweep(processed_data):
    # Define sweep configuration
    sweep_config = {
        'method': 'bayes',
        'metric': {
            'name': 'val_accuracy',
            'goal': 'maximize'
        },
        'parameters': {
            'embedding_dim': {
                'values': [16, 32, 64, 128]
            },
            'hidden_dim': {
                'values': [32, 64, 128, 256]
            },
            'cell_type': {
                'values': ['RNN', 'GRU', 'LSTM']
            },
            'dropout_rate': {
                'values': [0.1, 0.2, 0.3]
            },
            'num_encoder_layers': {
                'values': [1, 2]
            },
            'num_decoder_layers': {
                'values': [1, 2]
            }
        }
    }
    
    # Initialize sweep
    sweep_id = wandb.sweep(sweep_config, project="DA_seq2seq_transliteration")
    
    # Define the training function
    def train_model():
        # Make sure we're in a clean wandb state
        try:
            wandb.finish()
        except:
            pass
        
        # Start a new wandb run
        run = wandb.init()
        
        # Access hyperparameters from wandb
        config = wandb.config
        
        # Create model with hyperparameters from wandb
        input_vocab_size = len(processed_data['input_tokenizer'].word_index) + 1
        target_vocab_size = len(processed_data['target_tokenizer'].word_index) + 1
        
        model = VanillaSeq2Seq(
            input_vocab_size=input_vocab_size,
            target_vocab_size=target_vocab_size,
            embedding_dim=config.embedding_dim,
            hidden_dim=config.hidden_dim,
            cell_type=config.cell_type,
            dropout_rate=config.dropout_rate,
            num_encoder_layers=config.num_encoder_layers,
            num_decoder_layers=config.num_decoder_layers
        )
        
        model.compile()
        
        # Configure wandb callback
        wandb_callback = WandbCallback(
            log_model=False,
            save_graph=False,
            save_model=False
        )
        
        # Compute model complexity metrics
        D = config.embedding_dim
        H = config.hidden_dim
        L_e = config.num_encoder_layers
        L_d = config.num_decoder_layers
        T_enc = processed_data['encoder_input_train'].shape[1]
        T_dec = processed_data['decoder_input_train'].shape[1]
        
        flops_per_step = 4 * (H * D + H * H)
        total_enc_flops = L_e * T_enc * flops_per_step
        total_dec_flops = L_d * T_dec * flops_per_step
        total_flops = total_enc_flops + total_dec_flops
        
        total_params = model.model.count_params()
        
        print(f"Embedding dim: {D}, Hidden dim: {H}")
        print(f"Encoder layers: {L_e}, Decoder layers: {L_d}")
        print(f"Cell type: {config.cell_type}, Dropout: {config.dropout_rate}")
        print(f"Total parameters: {total_params:,}")
        print(f"Total FLOPs: {total_flops:,}")
        
        # Log model complexity metrics
        wandb.log({
            "total_flops": total_flops,
            "total_params": total_params
        })
        
        # Train the model
        start_time = time.time()
        
        history = model.fit(
            train_data={
                'encoder_input': processed_data['encoder_input_train'],
                'decoder_input': processed_data['decoder_input_train'],
                'decoder_target': processed_data['decoder_target_train']
            },
            val_data={
                'encoder_input': processed_data['encoder_input_val'],
                'decoder_input': processed_data['decoder_input_val'],
                'decoder_target': processed_data['decoder_target_val']
            },
            batch_size=64,
            epochs=10,
            callbacks=[wandb_callback]
        )
        
        training_time = time.time() - start_time
        
        # Log additional metrics
        wandb.log({
            "training_time": training_time,
            "final_train_accuracy": history.history['accuracy'][-1],
            "final_val_accuracy": history.history['val_accuracy'][-1]
        })
        
        # Clean up (important to avoid memory leaks)
        tf.keras.backend.clear_session()
        run.finish()
    
    # Run the sweep
    wandb.agent(sweep_id, train_model, count=3)  # Adjust count based on time constraints

# Main execution
if __name__ == "__main__":
    # Ensure TensorFlow doesn't reserve all GPU memory
    gpus = tf.config.experimental.list_physical_devices('GPU')
    if gpus:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
    
    # Load data
    try:
        train_data, val_data, test_data = load_dakshina_data(lang='hi')
        print(f"Data loaded successfully! Train size: {len(train_data)}")
        
        # Process data
        processed_data = process_data(train_data, val_data)
        print("Data processed successfully!")
        
        # Run sweep
        print("Starting hyperparameter sweep...")
        run_wandb_sweep(processed_data)
        
    except Exception as e:
        print(f"Error: {str(e)}")

Data loaded successfully! Train size: 44202
Data processed successfully!
Starting hyperparameter sweep...
Create sweep with ID: qoxu1z5h
Sweep URL: https://wandb.ai/mm21b044-indian-institute-of-technology-madras/DA_seq2seq_transliteration/sweeps/qoxu1z5h


[34m[1mwandb[0m: Agent Starting Run: tr1ifgs3 with config:
[34m[1mwandb[0m: 	cell_type: GRU
[34m[1mwandb[0m: 	dropout_rate: 0.3
[34m[1mwandb[0m: 	embedding_dim: 128
[34m[1mwandb[0m: 	hidden_dim: 32
[34m[1mwandb[0m: 	num_decoder_layers: 1
[34m[1mwandb[0m: 	num_encoder_layers: 1


Embedding dim: 128, Hidden dim: 32
Encoder layers: 1, Decoder layers: 1
Cell type: GRU, Dropout: 0.3
Total parameters: 43,965
Total FLOPs: 819,200
Epoch 1/10
[1m691/691[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m23s[0m 24ms/step - accuracy: 0.6977 - loss: 1.2141 - val_accuracy: 0.7457 - val_loss: 0.8501
Epoch 2/10
[1m691/691[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m16s[0m 23ms/step - accuracy: 0.7442 - loss: 0.8595 - val_accuracy: 0.7710 - val_loss: 0.7643
Epoch 3/10
[1m409/691[0m [32m━━━━━━━━━━━[0m[37m━━━━━━━━━[0m [1m6s[0m 22ms/step - accuracy: 0.7614 - loss: 0.7994