In [None]:
import numpy
import os
from tensorflow import keras
import numpy as np
import math
import tensorflow as tf
import numpy as np
from tensorflow.keras.layers import (
    Input, Embedding, Dense, Concatenate, Layer
)
from tensorflow.keras.models import Model
from tensorflow.keras.preprocessing.sequence import pad_sequences

In [None]:
def extract_data(data_location = '/kaggle/input/as3-dataset/lexicons'):
    """Load the train/dev/test lexicons from tab-separated files.

    Args:
        data_location: Directory containing the three files
            ``gu.translit.sampled.{train,dev,test}.tsv``.

    Returns:
        dict mapping ``'train'``, ``'dev'`` and ``'test'`` to an object-dtype
        numpy array with one row per non-blank line and one column per
        tab-separated field.

    Raises:
        OSError: if one of the three files cannot be opened.
    """
    def load_tsv(file_path):
        rows = []
        with open(file_path, 'r', encoding='utf-8') as file:
            for line in file:
                stripped = line.strip()
                # A blank line would turn into a one-field row, which makes the
                # rows ragged and breaks 2-D column indexing such as data[:, 1].
                if not stripped:
                    continue
                rows.append(stripped.split('\t'))
        return numpy.array(rows, dtype=object)

    return {
        split: load_tsv(os.path.join(data_location, f"gu.translit.sampled.{split}.tsv"))
        for split in ('train', 'dev', 'test')
    }

def tokanize_texts(texts, char_level=True, start_end_tokens=False):
    """Fit a Keras ``Tokenizer`` on ``texts`` and return it.

    Args:
        texts: Iterable of strings to build the vocabulary from.
        char_level: Forwarded to ``Tokenizer``; ``True`` makes every character
            a token.
        start_end_tokens: When true, ``'<start>'`` and ``'<end>'`` are
            registered in ``word_index`` / ``index_word`` as two dedicated ids
            appended after the fitted vocabulary.

    Returns:
        The fitted tokenizer.
    """
    tokenizer = keras.preprocessing.text.Tokenizer(char_level=char_level, filters='', lower=False)
    tokenizer.fit_on_texts(texts)

    if start_end_tokens:
        # Gluing the markers onto each text does not work: a char-level
        # tokenizer splits '<start>' into '<', 's', 't', ... and a word-level
        # one fuses it with the neighbouring word, so no single '<start>' /
        # '<end>' token would ever exist. Register them as explicit ids instead.
        for marker in ('<start>', '<end>'):
            if marker not in tokenizer.word_index:
                marker_id = len(tokenizer.word_index) + 1
                tokenizer.word_index[marker] = marker_id
                tokenizer.index_word[marker_id] = marker

    return tokenizer

In [None]:
class seq2seq:
    """Stacked RNN encoder-decoder (no attention) with beam-search evaluation.

    ``build_training_model`` creates the teacher-forced training graph,
    ``build_inference_model`` re-uses its layers to create the step-wise
    ``encoder_model`` / ``decoder_model`` graphs, and ``evaluate`` runs batched
    beam search with those two models and reports exact-match accuracy.

    ``encoder_type`` / ``decoder_type`` must be ``'LSTM'``, ``'GRU'`` or
    ``'RNN'`` (``'RNN'`` selects ``SimpleRNN``).
    NOTE(review): encoder states are handed directly to the decoder, so both
    types are expected to be identical; mixing them is not handled here.
    """

    def __init__(
        self,
        input_vocab_size,
        output_vocab_size,
        embedding_dim,
        hidden_units,
        encoder_layers,
        decoder_layers,
        dropout_rate,
        recurrent_dropout_rate,
        encoder_type,
        decoder_type,
        beam_width
    ):
        self.input_vocab_size = input_vocab_size
        self.output_vocab_size = output_vocab_size
        self.embedding_dim = embedding_dim
        self.hidden_units = hidden_units
        self.encoder_layers = encoder_layers
        self.decoder_layers = decoder_layers
        self.dropout_rate = dropout_rate
        self.recurrent_dropout_rate = recurrent_dropout_rate
        self.beam_width = beam_width
        self.encoder_type = encoder_type
        self.decoder_type = decoder_type

    def _make_rnn(self, cell_type, return_sequences, name):
        """Create one recurrent layer of ``cell_type`` with this model's settings.

        Raises:
            ValueError: if ``cell_type`` is not 'LSTM', 'GRU' or 'RNN'.
        """
        layer_classes = {
            'LSTM': keras.layers.LSTM,
            'GRU': keras.layers.GRU,
            'RNN': keras.layers.SimpleRNN,
        }
        if cell_type not in layer_classes:
            raise ValueError(
                f"Unsupported cell type {cell_type!r}; expected one of {sorted(layer_classes)}"
            )
        return layer_classes[cell_type](
            units=self.hidden_units,
            return_sequences=return_sequences,
            return_state=True,
            dropout=self.dropout_rate,
            recurrent_dropout=self.recurrent_dropout_rate,
            name=name,
        )

    def _states_per_layer(self):
        """Number of state tensors one decoder layer takes (LSTM: h and c, else h)."""
        return 2 if self.decoder_type == 'LSTM' else 1

    def _decoder_initial_states(self, encoder_states):
        """Group the flat encoder state list into per-decoder-layer initial states.

        Decoder layer ``i`` starts from encoder layer ``i``'s final state; decoder
        layers beyond the encoder depth start from the last encoder layer's
        state. Works on any list (symbolic tensors or numpy arrays), so the
        training graph and beam search share the same rule.

        Returns:
            list with one list of states per decoder layer.
        """
        per_layer = self._states_per_layer()
        grouped = []
        for i in range(self.decoder_layers):
            if i < self.encoder_layers:
                grouped.append(list(encoder_states[i * per_layer:(i + 1) * per_layer]))
            else:
                grouped.append(list(encoder_states[-per_layer:]))
        return grouped

    def build_training_model(self):
        """Build ``self.training_model`` mapping [encoder ids, decoder ids] to per-step softmax."""
        encoder_inputs = keras.layers.Input(shape=(None,), name='encoder_inputs')
        encoder_outputs = keras.layers.Embedding(
            input_dim=self.input_vocab_size,
            output_dim=self.embedding_dim,
            mask_zero=True,
            name='encoder_embedding'
        )(encoder_inputs)

        encoder_states = []
        for i in range(self.encoder_layers):
            # Only the last encoder layer may drop the sequence dimension; the
            # layers below it must feed full sequences upwards.
            rnn_layer = self._make_rnn(
                self.encoder_type,
                return_sequences=(i < self.encoder_layers - 1),
                name=f'encoder_{i}'
            )
            encoder_outputs, *layer_states = rnn_layer(encoder_outputs)
            encoder_states.extend(layer_states)

        decoder_inputs = keras.layers.Input(shape=(None,), name='decoder_inputs')
        decoder_outputs = keras.layers.Embedding(
            input_dim=self.output_vocab_size,
            output_dim=self.embedding_dim,
            mask_zero=True,
            name='decoder_embedding'
        )(decoder_inputs)

        decoder_init_states = self._decoder_initial_states(encoder_states)
        for i in range(self.decoder_layers):
            rnn_layer = self._make_rnn(
                self.decoder_type,
                return_sequences=True,
                name=f'decoder_{i}'
            )
            decoder_outputs, *_ = rnn_layer(decoder_outputs, initial_state=decoder_init_states[i])

        decoder_dense = keras.layers.Dense(
            units=self.output_vocab_size,
            activation='softmax',
            name='decoder_dense'
        )
        decoder_outputs = decoder_dense(decoder_outputs)
        self.training_model = keras.models.Model(
            inputs=[encoder_inputs, decoder_inputs],
            outputs=decoder_outputs
        )

    def build_inference_model(self):
        """Build ``self.encoder_model`` and ``self.decoder_model`` from the trained layers.

        ``encoder_model`` maps encoder ids to the flat list of encoder states.
        ``decoder_model`` maps [decoder ids, *decoder states] to
        [softmax output, *updated decoder states], with one group of state
        inputs per decoder layer.
        Requires ``build_training_model`` to have been called first.
        """
        # Encoder
        encoder_inputs = keras.layers.Input(shape=(None,), name='encoder_inputs')
        encoder_outputs = self.training_model.get_layer('encoder_embedding')(encoder_inputs)

        encoder_states = []
        for i in range(self.encoder_layers):
            encoder_rnn_layer = self.training_model.get_layer(f'encoder_{i}')
            encoder_outputs, *state = encoder_rnn_layer(encoder_outputs)
            encoder_states.extend(state)

        self.encoder_model = keras.models.Model(
            inputs=encoder_inputs,
            outputs=encoder_states
        )

        # Decoder
        decoder_inputs = keras.layers.Input(shape=(None,), name='decoder_inputs')
        decoder_outputs = self.training_model.get_layer('decoder_embedding')(decoder_inputs)

        # Size the state inputs from the decoder depth, not the encoder's, so
        # the two stacks may have different numbers of layers.
        per_layer = self._states_per_layer()
        decoder_states_inputs = [
            keras.layers.Input(shape=(self.hidden_units,), name=f'decoder_state_input_{idx}')
            for idx in range(self.decoder_layers * per_layer)
        ]

        decoder_states = []
        for i in range(self.decoder_layers):
            decoder_rnn_layer = self.training_model.get_layer(f'decoder_{i}')
            layer_init = decoder_states_inputs[i * per_layer:(i + 1) * per_layer]
            decoder_outputs, *layer_states = decoder_rnn_layer(
                decoder_outputs, initial_state=layer_init
            )
            decoder_states.extend(layer_states)

        decoder_outputs = self.training_model.get_layer('decoder_dense')(decoder_outputs)

        self.decoder_model = keras.models.Model(
            inputs=[decoder_inputs] + decoder_states_inputs,
            outputs=[decoder_outputs] + decoder_states
        )

    def compile(self, optimizer='adam', loss = 'categorical_crossentropy', metrics=['accuracy']):
        """Compile the training model with the given optimizer, loss and metrics."""
        self.training_model.compile(
            optimizer=optimizer,
            loss=loss,
            metrics=metrics
        )

    def fit(self, x, y, batch_size=64, epochs=10, validation_split=0):
        """Fit the training model; arguments are forwarded to Keras ``Model.fit``."""
        self.training_model.fit(
            x=x,
            y=y,
            batch_size=batch_size,
            epochs=epochs,
            validation_split=validation_split,
        )

    def _decode_batch(self, batch_inputs, start_token, end_token, max_dec_len):
        """Beam-search decode one batch.

        Returns:
            One list of token ids per input row: the best hypothesis with the
            start token removed, cut before the first end token and
            right-padded with 0 to ``max_dec_len``.
        """
        bsz = batch_inputs.shape[0]
        B = self.beam_width

        enc_states = self.encoder_model.predict(batch_inputs, verbose=0)
        # A single-output model may hand back a bare array instead of a list.
        if not isinstance(enc_states, (list, tuple)):
            enc_states = [enc_states]
        init_states = [
            state
            for layer_states in self._decoder_initial_states(list(enc_states))
            for state in layer_states
        ]

        # Rows are laid out example-major: row i*B + j is beam j of example i.
        flat_states = [np.repeat(state, B, axis=0) for state in init_states]
        flat_dec_input = np.full((bsz * B, 1), start_token, dtype='int32')

        seqs = [[[start_token] for _ in range(B)] for _ in range(bsz)]
        # Only beam 0 is alive initially. Starting every beam at score 0 would
        # make all beams identical copies forever (i.e. plain greedy search).
        scores = np.full((bsz, B), -np.inf, dtype=np.float32)
        scores[:, 0] = 0.0

        for _ in range(max_dec_len):
            outs = self.decoder_model.predict([flat_dec_input] + flat_states, verbose=0)
            next_lp = np.log(outs[0][:, 0, :] + 1e-9).reshape(bsz, B, -1)
            vocab = next_lp.shape[2]

            new_seqs = []
            new_scores = np.empty((bsz, B), dtype=np.float32)
            gather = np.empty(bsz * B, dtype=np.int64)
            for i in range(bsz):
                total_lp = (scores[i][:, None] + next_lp[i]).reshape(-1)
                topk_idx = np.argpartition(-total_lp, B - 1)[:B]
                prev_beam = topk_idx // vocab
                token_id = topk_idx % vocab

                new_scores[i] = total_lp[topk_idx]
                # Remember which old row each surviving beam continues from.
                gather[i * B:(i + 1) * B] = i * B + prev_beam
                new_seqs.append(
                    [seqs[i][b] + [int(tok)] for b, tok in zip(prev_beam, token_id)]
                )

            seqs = new_seqs
            scores = new_scores
            flat_states = [np.asarray(st)[gather] for st in outs[1:]]
            flat_dec_input = np.array(
                [[s[-1]] for bs in seqs for s in bs], dtype='int32'
            )

            if all(s[-1] == end_token for bs in seqs for s in bs):
                break

        batch_preds = []
        for i, beams in enumerate(seqs):
            # Pick the best beam using THIS example's scores.
            best = beams[int(np.argmax(scores[i]))]
            seq = [tok for tok in best if tok != start_token]
            if end_token in seq:
                seq = seq[:seq.index(end_token)]
            seq = seq + [0] * (max_dec_len - len(seq))
            batch_preds.append(seq)
        return batch_preds

    def evaluate(
        self,
        input_seqs,
        target_seqs,
        start_token,
        end_token,
        max_dec_len,
        batch_size=64):
        """
        Batched beam search decoding + exact‐match accuracy.
        Uses one big GPU call per time‐step over all (batch×beam) hypotheses.

        Args:
            input_seqs: 2-D array of padded encoder token ids.
            target_seqs: 2-D array of expected outputs, compared row by row
                against predictions that exclude start/end tokens and are
                0-padded to ``max_dec_len``.
            start_token: Id fed to the decoder at the first step.
            end_token: Id that terminates a hypothesis.
            max_dec_len: Maximum number of decoding steps.
            batch_size: Number of inputs decoded per batch.

        Returns:
            Fraction of inputs whose prediction equals its target
            (0.0 for empty input).
        """
        N = input_seqs.shape[0]
        if N == 0:
            return 0.0

        total_correct = 0
        for offset in range(0, N, batch_size):
            batch_inputs = input_seqs[offset:offset + batch_size]
            batch_preds = self._decode_batch(batch_inputs, start_token, end_token, max_dec_len)

            tgt_slice = target_seqs[offset:offset + len(batch_preds)]
            for p, t in zip(batch_preds, tgt_slice):
                if np.array_equal(p, t):
                    total_correct += 1

        return total_correct / N

In [None]:
class BahdanauAttention(Layer):
    """Additive attention: score = V(tanh(W1(query) + W2(values))).

    Args:
        units: Width of the two projection layers ``W1`` and ``W2``.
    """

    def __init__(self, units, **kwargs):
        super().__init__(**kwargs)
        self.W1 = Dense(units)
        self.W2 = Dense(units)
        self.V  = Dense(1)

    def call(self, query, values, mask=None):
        """Attend from every query step over all value steps.

        Args:
            query: Decoder outputs; assumed (batch, T_q, dim) — this is how
                Seq2SeqAttention calls the layer.
            values: Encoder outputs; assumed (batch, T_v, dim).
            mask: Optional ``[query_mask, values_mask]`` pair. Only the values
                mask, shaped (batch, T_v), is used. Any other form (for
                example a single tensor describing the query) is ignored.
                NOTE(review): which form Keras passes automatically when the
                layer is called with two positional tensors depends on the
                Keras version — confirm if encoder padding must be masked.

        Returns:
            ``(context, attn_weights)`` with shapes (batch, T_q, dim) and
            (batch, T_q, T_v).
        """
        q_expanded = tf.expand_dims(query, 2)
        v_expanded = tf.expand_dims(values, 1)

        # Broadcasting the two expanded tensors gives one score per
        # (query step, value step) pair, with a trailing axis of size 1.
        score = self.V(tf.nn.tanh(self.W1(q_expanded) + self.W2(v_expanded)))

        values_mask = None
        # Only index into the mask when it really is a pair; indexing a single
        # mask tensor with [1] would silently select batch row 1 instead.
        if isinstance(mask, (list, tuple)) and len(mask) > 1:
            values_mask = mask[1]

        if values_mask is not None:
            # Align (batch, T_v) with the rank-4 score (batch, T_q, T_v, 1).
            keep = tf.cast(values_mask, score.dtype)[:, tf.newaxis, :, tf.newaxis]
            score -= (1.0 - keep) * 1e9

        attn_weights = tf.nn.softmax(score, axis=2)
        attn_weights = tf.squeeze(attn_weights, -1)
        context = tf.matmul(attn_weights, values)

        return context, attn_weights

class Seq2SeqAttention:
    """Single-layer encoder-decoder with additive attention and beam-search evaluation.

    ``build_training_model`` creates the teacher-forced training graph,
    ``build_inference_model`` re-uses its layers for step-wise decoding, and
    ``evaluate`` runs batched beam search and reports exact-match accuracy.

    ``encoder_type`` / ``decoder_type`` are Keras recurrent layer names
    (``'LSTM'``, ``'GRU'``); ``'RNN'`` selects ``SimpleRNN``.
    NOTE(review): encoder states initialise the decoder directly, so both
    types are expected to be identical.
    """

    def __init__(
        self,
        input_vocab_size,
        output_vocab_size,
        embedding_dim,
        hidden_units,
        dropout_rate=0.0,
        recurrent_dropout_rate=0.0,
        encoder_type='LSTM',
        decoder_type='LSTM',
        beam_width = 1
    ):
        self.input_vocab_size = input_vocab_size
        self.output_vocab_size = output_vocab_size
        self.embedding_dim = embedding_dim
        self.hidden_units = hidden_units
        self.dropout_rate = dropout_rate
        self.recurrent_dropout_rate = recurrent_dropout_rate
        self.encoder_type = encoder_type
        self.decoder_type = decoder_type
        self.beam_width = beam_width

    @staticmethod
    def _rnn_class(cell_type):
        """Resolve a cell-type string to a Keras recurrent layer class.

        ``'RNN'`` maps to ``SimpleRNN``: ``tf.keras.layers.RNN`` is the generic
        wrapper that expects a cell object, not a unit count.
        """
        if cell_type == 'RNN':
            return tf.keras.layers.SimpleRNN
        return getattr(tf.keras.layers, cell_type)

    def build_training_model(self):
        """Build ``self.training_model`` mapping [encoder ids, decoder ids] to per-step softmax."""
        enc_inputs = Input(shape=(None,), name='encoder_inputs')
        enc_emb = Embedding(
            self.input_vocab_size,
            self.embedding_dim,
            mask_zero=True,
            name='encoder_embedding'
        )(enc_inputs)

        EncoderCell = self._rnn_class(self.encoder_type)
        self.encoder_rnn = EncoderCell(
            self.hidden_units,
            return_sequences=True,
            return_state=True,
            dropout=self.dropout_rate,
            recurrent_dropout=self.recurrent_dropout_rate,
            name='encoder_'+self.encoder_type.lower()
        )
        enc_outputs, *enc_states = self.encoder_rnn(enc_emb)

        dec_inputs = Input(shape=(None,), name='decoder_inputs')
        dec_emb = Embedding(
            self.output_vocab_size,
            self.embedding_dim,
            mask_zero=True,
            name='decoder_embedding'
        )(dec_inputs)

        DecoderCell = self._rnn_class(self.decoder_type)
        self.decoder_rnn = DecoderCell(
            self.hidden_units,
            return_sequences=True,
            return_state=True,
            dropout=self.dropout_rate,
            recurrent_dropout=self.recurrent_dropout_rate,
            name='decoder_'+self.decoder_type.lower()
        )
        dec_outputs, *dec_states = self.decoder_rnn(
            dec_emb, initial_state=enc_states
        )

        self.attention_layer = BahdanauAttention(
            self.hidden_units, name='bahdanau_attn'
        )
        context, _ = self.attention_layer(
            dec_outputs, enc_outputs
        )

        concat = Concatenate(axis=-1, name='concat_layer')([dec_outputs, context])
        dec_logits = Dense(
            self.output_vocab_size,
            activation='softmax',
            name='output_dense'
        )(concat)

        self.training_model = Model(
            inputs=[enc_inputs, dec_inputs],
            outputs=dec_logits,
            name='seq2seq_training'
        )

    def build_inference_model(self):
        """Build ``self.encoder_model`` and ``self.decoder_model`` from the trained layers.

        ``encoder_model`` maps encoder ids to [encoder outputs, *encoder states].
        ``decoder_model`` maps [one decoder id, encoder outputs, *decoder states]
        to [softmax output, *updated decoder states].
        Requires ``build_training_model`` to have been called first.
        """
        enc_inputs_inf = Input(
            shape=(None,), name='encoder_inputs_inf'
        )
        # The trained layers live on ``self.training_model`` (there is no
        # ``_training_model`` attribute).
        enc_emb_inf = self.training_model.get_layer('encoder_embedding')(
            enc_inputs_inf
        )
        enc_rnn = self.training_model.get_layer(
            'encoder_'+self.encoder_type.lower()
        )
        enc_outputs_inf, *enc_states_inf = enc_rnn(enc_emb_inf)

        self.encoder_model = Model(
            inputs=enc_inputs_inf,
            outputs=[enc_outputs_inf] + enc_states_inf,
            name='encoder_inference'
        )

        dec_token_inf   = Input(shape=(1,), name='decoder_token_inf')
        enc_outputs_inp = Input(
            shape=(None, self.hidden_units),
            name='encoder_outputs_inf'
        )

        dec_state_inputs = [
            Input(shape=(self.hidden_units,), name=f'decoder_state_inf_{i}')
            for i in range(len(enc_states_inf))
        ]

        dec_emb_inf = self.training_model.get_layer('decoder_embedding')(
            dec_token_inf
        )
        dec_rnn = self.training_model.get_layer(
            'decoder_'+self.decoder_type.lower()
        )
        dec_out_step, *dec_states_out = dec_rnn(
            dec_emb_inf, initial_state=dec_state_inputs
        )

        context_inf, _ = self.attention_layer(
            dec_out_step, enc_outputs_inp
        )
        concat_inf = Concatenate(axis=-1)([dec_out_step, context_inf])
        dec_logits_inf = self.training_model.get_layer('output_dense')(
            concat_inf
        )

        self.decoder_model = Model(
            inputs=[dec_token_inf, enc_outputs_inp] + dec_state_inputs,
            outputs=[dec_logits_inf] + dec_states_out,
            name='decoder_inference'
        )

    def compile(self, optimizer='adam', loss = 'categorical_crossentropy', metrics=['accuracy']):
        """Compile the training model with the given optimizer, loss and metrics."""
        self.training_model.compile(
            optimizer=optimizer,
            loss=loss,
            metrics=metrics
        )

    def fit(self, x, y, batch_size=64, epochs=10, validation_split=0):
        """Fit the training model; arguments are forwarded to Keras ``Model.fit``."""
        self.training_model.fit(
            x=x,
            y=y,
            batch_size=batch_size,
            epochs=epochs,
            validation_split=validation_split,
        )

    def _decode_batch(self, batch_inputs, start_token, end_token, max_dec_len):
        """Beam-search decode one batch.

        Returns:
            One list of token ids per input row: the best hypothesis with the
            start token removed, cut before the first end token and
            right-padded with 0 to ``max_dec_len``.
        """
        bsz = batch_inputs.shape[0]
        B = self.beam_width

        # The encoder model returns the output sequence first, then the states.
        enc_results = self.encoder_model.predict(batch_inputs, verbose=0)
        enc_outputs, enc_states = enc_results[0], list(enc_results[1:])

        # Rows are laid out example-major: row i*B + j is beam j of example i.
        # Encoder outputs are shared by all beams of an example, so they are
        # tiled once and never reordered.
        flat_enc_outputs = np.repeat(enc_outputs, B, axis=0)
        flat_states = [np.repeat(state, B, axis=0) for state in enc_states]
        flat_dec_input = np.full((bsz * B, 1), start_token, dtype='int32')

        seqs = [[[start_token] for _ in range(B)] for _ in range(bsz)]
        # Only beam 0 is alive initially. Starting every beam at score 0 would
        # make all beams identical copies forever (i.e. plain greedy search).
        scores = np.full((bsz, B), -np.inf, dtype=np.float32)
        scores[:, 0] = 0.0

        for _ in range(max_dec_len):
            outs = self.decoder_model.predict(
                [flat_dec_input, flat_enc_outputs] + flat_states, verbose=0
            )
            next_lp = np.log(outs[0][:, 0, :] + 1e-9).reshape(bsz, B, -1)
            vocab = next_lp.shape[2]

            new_seqs = []
            new_scores = np.empty((bsz, B), dtype=np.float32)
            gather = np.empty(bsz * B, dtype=np.int64)
            for i in range(bsz):
                total_lp = (scores[i][:, None] + next_lp[i]).reshape(-1)
                topk_idx = np.argpartition(-total_lp, B - 1)[:B]
                prev_beam = topk_idx // vocab
                token_id = topk_idx % vocab

                new_scores[i] = total_lp[topk_idx]
                # Remember which old row each surviving beam continues from.
                gather[i * B:(i + 1) * B] = i * B + prev_beam
                new_seqs.append(
                    [seqs[i][b] + [int(tok)] for b, tok in zip(prev_beam, token_id)]
                )

            seqs = new_seqs
            scores = new_scores
            flat_states = [np.asarray(st)[gather] for st in outs[1:]]
            flat_dec_input = np.array(
                [[s[-1]] for bs in seqs for s in bs], dtype='int32'
            )

            if all(s[-1] == end_token for bs in seqs for s in bs):
                break

        batch_preds = []
        for i, beams in enumerate(seqs):
            # Pick the best beam using THIS example's scores.
            best = beams[int(np.argmax(scores[i]))]
            seq = [tok for tok in best if tok != start_token]
            if end_token in seq:
                seq = seq[:seq.index(end_token)]
            seq = seq + [0] * (max_dec_len - len(seq))
            batch_preds.append(seq)
        return batch_preds

    def evaluate(
        self,
        input_seqs,
        target_seqs,
        start_token,
        end_token,
        max_dec_len,
        batch_size=64):
        """
        Batched beam search decoding + exact‐match accuracy.
        Uses one big GPU call per time‐step over all (batch×beam) hypotheses.

        Args:
            input_seqs: 2-D array of padded encoder token ids.
            target_seqs: 2-D array of expected outputs, compared row by row
                against predictions that exclude start/end tokens and are
                0-padded to ``max_dec_len``.
            start_token: Id fed to the decoder at the first step.
            end_token: Id that terminates a hypothesis.
            max_dec_len: Maximum number of decoding steps.
            batch_size: Number of inputs decoded per batch.

        Returns:
            ``(accuracy, predictions)``: the fraction of exact matches
            (0.0 for empty input) and the list of predicted id lists.
        """
        N = input_seqs.shape[0]
        predictions = []
        if N == 0:
            return 0.0, predictions

        total_correct = 0
        for offset in range(0, N, batch_size):
            batch_inputs = input_seqs[offset:offset + batch_size]
            batch_preds = self._decode_batch(batch_inputs, start_token, end_token, max_dec_len)

            tgt_slice = target_seqs[offset:offset + len(batch_preds)]
            for p, t in zip(batch_preds, tgt_slice):
                if np.array_equal(p, t):
                    total_correct += 1

            predictions.extend(batch_preds)

        return total_correct / N, predictions



In [None]:
import os
import gc

# Expose only GPU 0 to this process; change the index to use another device.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import tensorflow as tf

gpus = tf.config.experimental.list_physical_devices('GPU')
if not gpus:
    print("No GPU found, running on CPU.")
else:
    try:
        # Ask TensorFlow to grow GPU memory on demand on every visible device.
        for device in gpus:
            tf.config.experimental.set_memory_growth(device, True)
        print(f"GPUs found: {[gpu.name for gpu in gpus]}")
    except RuntimeError as e:
        print(e)

In [None]:
import wandb
!wandb login 

In [None]:
def get_config_value(config, args, key, default=None):
    """Look up ``key`` on ``config`` first, then on ``args``, then use ``default``.

    Used to read a setting from the wandb config with a fallback to the
    argparse-like defaults object.
    """
    fallback = getattr(args, key, default)
    return getattr(config, key, fallback)

def train_model(config=None):
    """Train and evaluate one transliteration model inside a wandb run.

    Hyperparameters are read from ``wandb.config`` and fall back to the
    ``defaults`` dict below. Column 1 of each data row is the model input and
    column 0 the target.

    The decoder is trained with teacher forcing: its input is
    ``<start> + target`` and its expected output is ``target + <end>``.
    Evaluation compares beam-search output against the bare target characters.

    Args:
        config: Optional configuration forwarded to ``wandb.init``; a sweep
            agent calls this function without arguments.
    """
    # set default hyperparameters
    defaults = {
        'embedding_dim': 256,
        'hidden_units': 512,
        'dropout_rate': 0.2,
        'recurrent_dropout_rate': 0.2,
        'encoder_layers': 1,
        'decoder_layers': 1,
        'cell_type': 'LSTM',
        'beam_width': 1,
        'dataset': '/kaggle/input/as3-dataset/lexicons',
        'attention': False,
        'do_val':True,
        'do_test': False,
        'batch_size': 400,
        'epochs': 10,
    }

    def special_token_id(tokenizer, token):
        """Return the id of ``token``, registering a fresh id if the tokenizer lacks it."""
        if token not in tokenizer.word_index:
            new_id = len(tokenizer.word_index) + 1
            tokenizer.word_index[token] = new_id
            tokenizer.index_word[new_id] = token
        return tokenizer.word_index[token]

    def exact_match_accuracy(model, inputs, targets, start_id, end_id):
        """Run ``model.evaluate`` and return only the accuracy.

        ``seq2seq.evaluate`` returns a float while
        ``Seq2SeqAttention.evaluate`` returns ``(accuracy, predictions)``.
        """
        result = model.evaluate(
            input_seqs=inputs,
            target_seqs=targets,
            start_token=start_id,
            end_token=end_id,
            max_dec_len=targets.shape[1],
            batch_size=1000
        )
        return result[0] if isinstance(result, tuple) else result

    # Initialize wandb with the provided entity and project
    with wandb.init(entity='me21b138-indian-institute-of-technology-madras', project='AS3', config=config):
        config = wandb.config

        # Create a class to mimic argparse for the helper functions
        class Args:
            def __init__(self, **kwargs):
                for key, value in kwargs.items():
                    setattr(self, key, value)

        # Set up args with defaults
        args = Args(**defaults)

        def cfg(key):
            return get_config_value(config, args, key)

        data = extract_data(cfg('dataset'))
        train_data = data['train']
        dev_data = data['dev']
        test_data = data['test']

        encoder_tokenizer = tokanize_texts(np.concatenate((dev_data[:,1], train_data[:,1]), axis=0))
        decoder_tokenizer = tokanize_texts(np.concatenate((dev_data[:,0], train_data[:,0]), axis=0), start_end_tokens=True)

        # Make sure the two markers exist as single dedicated ids, whatever the
        # tokenizer did with them while fitting.
        start_id = special_token_id(decoder_tokenizer, '<start>')
        end_id = special_token_id(decoder_tokenizer, '<end>')

        train_x = encoder_tokenizer.texts_to_sequences(train_data[:,1])
        dev_x = encoder_tokenizer.texts_to_sequences(dev_data[:,1])
        test_x = encoder_tokenizer.texts_to_sequences(test_data[:,1])

        train_chars = decoder_tokenizer.texts_to_sequences(train_data[:,0])
        dev_chars = decoder_tokenizer.texts_to_sequences(dev_data[:,0])
        test_chars = decoder_tokenizer.texts_to_sequences(test_data[:,0])

        max_encoder_seq_length = max(len(seq) for seq in train_x + dev_x + test_x)
        # One extra step so the decoder can emit <end> after the longest target.
        decode_steps = max(len(seq) for seq in train_chars + dev_chars + test_chars) + 1

        train_x = pad_sequences(train_x, maxlen=max_encoder_seq_length, padding='post')
        dev_x = pad_sequences(dev_x, maxlen=max_encoder_seq_length, padding='post')
        test_x = pad_sequences(test_x, maxlen=max_encoder_seq_length, padding='post')

        # Teacher forcing: the decoder sees the target shifted right by one
        # (<start> first) and must predict the unshifted target followed by <end>.
        train_dec_in = pad_sequences(
            [[start_id] + seq for seq in train_chars], maxlen=decode_steps, padding='post'
        )
        train_dec_out = pad_sequences(
            [seq + [end_id] for seq in train_chars], maxlen=decode_steps, padding='post'
        )

        # evaluate() strips <start>/<end> from its predictions and pads with 0,
        # so the reference sequences are the bare characters, padded the same way.
        dev_targets = pad_sequences(dev_chars, maxlen=decode_steps, padding='post')
        test_targets = pad_sequences(test_chars, maxlen=decode_steps, padding='post')

        input_vocab_size = len(encoder_tokenizer.word_index) + 1
        output_vocab_size = len(decoder_tokenizer.word_index) + 1

        if cfg('attention'):
            model = Seq2SeqAttention(
                input_vocab_size=input_vocab_size,
                output_vocab_size=output_vocab_size,
                embedding_dim=cfg('embedding_dim'),
                hidden_units=cfg('hidden_units'),
                dropout_rate=cfg('dropout_rate'),
                recurrent_dropout_rate=cfg('recurrent_dropout_rate'),
                encoder_type=cfg('cell_type'),
                decoder_type=cfg('cell_type'),
                beam_width=cfg('beam_width')
            )
        else:
            model = seq2seq(
                input_vocab_size=input_vocab_size,
                output_vocab_size=output_vocab_size,
                embedding_dim=cfg('embedding_dim'),
                hidden_units=cfg('hidden_units'),
                encoder_layers=cfg('encoder_layers'),
                decoder_layers=cfg('decoder_layers'),
                dropout_rate=cfg('dropout_rate'),
                recurrent_dropout_rate=cfg('recurrent_dropout_rate'),
                encoder_type=cfg('cell_type'),
                decoder_type=cfg('cell_type'),
                beam_width=cfg('beam_width')
            )

        model.build_training_model()
        # Integer targets with the sparse loss avoid materialising a dense
        # (samples x steps x vocab) one-hot array.
        model.compile(
            optimizer='adam',
            loss='sparse_categorical_crossentropy',
            metrics=['accuracy']
        )
        model.fit(
            x=[train_x, train_dec_in],
            y=train_dec_out,
            batch_size=cfg('batch_size'),
            epochs=cfg('epochs'),
            validation_split=0.2
        )
        model.build_inference_model()

        if cfg('do_val'):
            dev_acc = exact_match_accuracy(model, dev_x, dev_targets, start_id, end_id)
            wandb.log({"validation_accuracy": dev_acc})

        if cfg('do_test'):
            test_acc = exact_match_accuracy(model, test_x, test_targets, start_id, end_id)
            wandb.log({"test_accuracy": test_acc})

        gc.collect()

In [None]:
# Single run outside a sweep: no config is passed, so train_model() falls back
# to its built-in defaults for anything the wandb config does not provide.
train_model()

In [None]:
# Bayesian hyperparameter search that maximises the logged validation accuracy.
sweep_config = dict(
    method='bayes',
    metric=dict(name='validation_accuracy', goal='maximize'),
    parameters=dict(
        embedding_dim=dict(values=[16, 32, 64, 128, 256, 512]),
        hidden_units=dict(values=[16, 32, 64, 128, 256, 512]),
        dropout_rate=dict(distribution='uniform', min=0.0, max=0.5),
        recurrent_dropout_rate=dict(distribution='uniform', min=0.0, max=0.5),
        encoder_layers=dict(values=[1, 2, 3]),
        decoder_layers=dict(values=[1, 2, 3]),
        cell_type=dict(values=['RNN', 'LSTM', 'GRU']),
        beam_width=dict(values=[1, 2, 3]),
    ),
)

In [None]:
# Configuration for the sweep
entity = 'me21b138-indian-institute-of-technology-madras'  # Your wandb entity
project = 'AS3'  # Your wandb project
count = 100  # Number of runs to execute

# Initialize the sweep
# NOTE(review): wandb.require("core") opts into a specific wandb backend;
# confirm it is still needed for the installed wandb version.
wandb.require("core")
# Registers sweep_config under the entity/project above and returns the sweep id.
sweep_id = wandb.sweep(sweep_config, entity=entity, project=project)

# Start the sweep agent
# The agent calls train_model once per run, up to `count` runs.
wandb.agent(sweep_id, function=train_model, count=count)