In [17]:
import tensorflow as tf
import os
# from tensorflow.python.keras.layers import Layer # No longer needed, can import from tf.keras
from tensorflow.python.keras import backend as K
import warnings
warnings.filterwarnings('ignore')

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
# from tensorflow.python import keras # No longer needed, can use tf.keras directly

# Import layers from the standard tf.keras.layers
from tensorflow.keras.layers import (
    Embedding, Input, Dense, LSTM, GRU, RNN, SimpleRNN, Softmax,
    Dropout, Concatenate, TimeDistributed, Layer # Import Layer here
)
from tensorflow.python.keras.callbacks import Callback
from tensorflow.python.keras import Model

from math import log
import math


class AttentionLayer(Layer): # Inherit from the correctly imported Layer
    """
    Implements Bahdanau Attention Mechanism (Additive Attention).
    Reference: https://arxiv.org/pdf/1409.0473.pdf

    The layer takes as input a sequence of encoder outputs and decoder outputs,
    and computes attention weights and context vectors for each decoder timestep.

    Trainable weights:
        - W_a: applied to encoder outputs, shape (encoder_dim, encoder_dim)
        - U_a: applied to decoder state, shape (decoder_dim, encoder_dim)
        - V_a: used to produce attention energies, shape (encoder_dim, 1)
    """

    def __init__(self, **kwargs):
        super(AttentionLayer, self).__init__(**kwargs)

    def build(self, input_shape):
        """
        Initialize trainable weights based on input shapes.

        Args:
            input_shape: List of shapes [encoder_output_shape, decoder_output_shape].
                Both shapes are indexed at position 2, so both inputs must be rank 3.
        """
        assert isinstance(input_shape, list)

        encoder_dim = input_shape[0][2]
        decoder_dim = input_shape[1][2]

        self.W_a = self.add_weight(
            name='W_a',
            shape=(encoder_dim, encoder_dim),
            initializer='uniform',
            trainable=True
        )
        self.U_a = self.add_weight(
            name='U_a',
            shape=(decoder_dim, encoder_dim),
            initializer='uniform',
            trainable=True
        )
        self.V_a = self.add_weight(
            name='V_a',
            shape=(encoder_dim, 1),
            initializer='uniform',
            trainable=True
        )

        super(AttentionLayer, self).build(input_shape)

    def call(self, inputs, verbose=False):
        """
        Compute context vectors and attention weights.

        Args:
            inputs: List of two tensors [encoder_output_seq, decoder_output_seq]
            verbose: Flag to print intermediate shapes (for debugging)
        Returns:
            context_vectors: Weighted sum of encoder outputs per decoder timestep
            attention_weights: Attention weights for each encoder timestep
        """
        assert isinstance(inputs, list)
        encoder_output_seq, decoder_output_seq = inputs

        if verbose:
            print('Encoder Output Shape:', encoder_output_seq.shape)
            print('Decoder Output Shape:', decoder_output_seq.shape)

        # The encoder projection does not depend on the decoder timestep, so it
        # is computed once here instead of once per step inside K.rnn.
        encoder_projection = K.dot(encoder_output_seq, self.W_a)

        def compute_energy(decoder_hidden_state, _):
            """
            Compute attention weights for a single decoder timestep.
            """
            # Project current decoder state using U_a and add a timestep axis so
            # it broadcasts against every encoder timestep.
            decoder_projection = K.expand_dims(K.dot(decoder_hidden_state, self.U_a), 1)

            # Compute energy scores with tanh non-linearity and project using V_a
            combined_projection = K.tanh(encoder_projection + decoder_projection)
            energy_scores = K.squeeze(K.dot(combined_projection, self.V_a), axis=-1)

            # Apply softmax over encoder timesteps to obtain attention weights
            attention_weights = K.softmax(energy_scores)

            return attention_weights, [attention_weights]

        def compute_context(attention_weights, _):
            """
            Compute context vector as the weighted sum of encoder outputs.
            """
            context_vector = K.sum(encoder_output_seq * K.expand_dims(attention_weights, -1), axis=1)
            return context_vector, [context_vector]

        # K.rnn requires the states returned by the step function to have the
        # same shapes as the initial states, so each dummy state must match the
        # output of the step function it is paired with:
        #   compute_energy  returns (batch_size, encoder_timesteps)
        #   compute_context returns (batch_size, encoder_dim)
        dummy_state_energy = tf.zeros_like(encoder_output_seq[:, :, 0])
        dummy_state_context = tf.zeros_like(encoder_output_seq[:, 0, :])

        # Run RNN over decoder timesteps to compute attention weights
        _, attention_weights_seq, _ = K.rnn(
            compute_energy, decoder_output_seq, [dummy_state_energy]
        )

        # Run RNN again to compute context vectors using the attention weights
        _, context_vector_seq, _ = K.rnn(
            compute_context, attention_weights_seq, [dummy_state_context]
        )

        return context_vector_seq, attention_weights_seq

    def compute_output_shape(self, input_shape):
        """
        Specify output shapes of the layer.
        Returns:
            - context_vectors: (batch_size, decoder_timesteps, encoder_dim) - based on the K.sum in compute_context
            - attention_weights: (batch_size, decoder_timesteps, encoder_timesteps)
        """
        return [
            tf.TensorShape((input_shape[1][0], input_shape[1][1], input_shape[0][2])), # Context vector shape should match encoder_dim, not decoder_dim
            tf.TensorShape((input_shape[1][0], input_shape[1][1], input_shape[0][1]))
        ]

Data: mount Google Drive and define the dataset paths

In [18]:
# Mount Google Drive at /content/drive so the dataset files referenced below
# can be opened (this import is only available inside Google Colab).
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [19]:
# Hindi lexicon split files; all three share the same directory and file stem
# and differ only in the split name.
# NOTE(review): `train` is rebound later in this notebook by `def train()`, so
# the data-loading cells must be executed before that definition.
_LEXICON_PREFIX = '/content/drive/MyDrive/dakshina_dataset_v1.0/dakshina_dataset_v1.0/hi/lexicons/hi.translit.sampled.'
train, val, test = (_LEXICON_PREFIX + split_name + '.tsv' for split_name in ('train', 'dev', 'test'))

Preprocessing

In [20]:
def training_extract(path):
    """
    Extracts vocabulary and sequence length metadata from a training dataset.

    This function reads a tab-separated text file whose first column is the
    target text and whose second column is the input text (any further columns
    are ignored) and computes:
    - Character-level vocabularies for both input and target sequences.
    - Mapping dictionaries between characters and token indices.
    - Maximum sequence lengths for encoder and decoder inputs.

    Blank lines are skipped, so the file may or may not end with a newline.
    Target sequences are wrapped in a start token ("\\t") and an end token
    ("\\n"), and both tokens are part of the target vocabulary.

    Args:
        path (str): Path to the training data file.

    Returns:
        tuple: A tuple containing:
            - num_encoder_tokens (int): Number of unique input characters.
            - num_decoder_tokens (int): Number of unique target characters.
            - input_token_index (dict): Mapping from input characters to indices.
            - target_token_index (dict): Mapping from target characters to indices.
            - reverse_input_token_index (dict): Mapping from indices to input characters.
            - reverse_target_token_index (dict): Mapping from indices to target characters.
            - max_encoder_seq_length (int): Maximum length among all input sequences.
            - max_decoder_seq_length (int): Maximum length among all target sequences
              (including the start and end tokens).

    Raises:
        ValueError: If the file contains no records, or a non-blank line has
            fewer than two tab-separated columns.
    """
    input_texts = []
    target_texts = []
    input_characters = set()
    target_characters = set()

    with open(path, "r", encoding="utf-8") as file:
        lines = file.read().split("\n")

    for line in lines:
        # Skip blank lines instead of assuming exactly one trailing empty line;
        # this keeps the final record when the file has no trailing newline.
        if not line.strip():
            continue
        target_text, input_text = line.split("\t")[:2]
        input_texts.append(input_text)
        target_texts.append("\t" + target_text + "\n")

        input_characters.update(input_text)
        target_characters.update(target_text)

    if not input_texts:
        raise ValueError(f"No records found in {path!r}")

    target_characters.update(["\t", "\n"])

    input_characters = sorted(input_characters)
    target_characters = sorted(target_characters)

    input_token_index = {char: i for i, char in enumerate(input_characters)}
    target_token_index = {char: i for i, char in enumerate(target_characters)}
    reverse_input_token_index = {i: char for char, i in input_token_index.items()}
    reverse_target_token_index = {i: char for char, i in target_token_index.items()}

    num_encoder_tokens = len(input_characters)
    num_decoder_tokens = len(target_characters)
    max_encoder_seq_length = max(len(txt) for txt in input_texts)
    max_decoder_seq_length = max(len(txt) for txt in target_texts)

    return (
        num_encoder_tokens,
        num_decoder_tokens,
        input_token_index,
        target_token_index,
        reverse_input_token_index,
        reverse_target_token_index,
        max_encoder_seq_length,
        max_decoder_seq_length
    )


def extract_data(path, max_encoder_seq_length, max_decoder_seq_length, num_decoder_tokens,
                 input_token_index=None, target_token_index=None):
    """
    Converts raw text data into padded and tokenized input/output arrays for model training.

    Each input sequence is converted into a sequence of character indices.
    Decoder input and target sequences are prepared in parallel, where
    the target is a one-hot encoded representation shifted by one timestep.
    Positions beyond the end of a sequence are left as zeros.

    The file is tab-separated with the target text in the first column and the
    input text in the second (further columns are ignored). Blank lines are
    skipped, so the file may or may not end with a newline.

    Args:
        path (str): Path to the input data file.
        max_encoder_seq_length (int): Maximum encoder sequence length for padding.
        max_decoder_seq_length (int): Maximum decoder sequence length for padding.
        num_decoder_tokens (int): Size of the decoder vocabulary.
        input_token_index (dict, optional): Mapping from input characters to
            indices. Defaults to the notebook-level `input_token` mapping.
        target_token_index (dict, optional): Mapping from target characters to
            indices. Defaults to the notebook-level `target_token` mapping.

    Returns:
        tuple: A tuple containing:
            - input_texts (list): Raw input sequences.
            - target_texts (list): Raw target sequences (with start and end tokens).
            - encoder_input_data (np.ndarray): Padded token indices for encoder input.
            - decoder_input_data (np.ndarray): Padded token indices for decoder input.
            - decoder_target_data (np.ndarray): One-hot encoded decoder targets.

    Raises:
        KeyError: If a character is missing from the corresponding token mapping.
        IndexError: If a sequence is longer than the given maximum length.
        ValueError: If a non-blank line has fewer than two tab-separated columns.
    """
    # Fall back to the mappings created by the vocabulary cell so existing
    # four-argument calls keep working.
    if input_token_index is None:
        input_token_index = input_token
    if target_token_index is None:
        target_token_index = target_token

    input_texts = []
    target_texts = []

    with open(path, "r", encoding="utf-8") as file:
        lines = file.read().split("\n")

    for line in lines:
        # Skip blank lines instead of assuming exactly one trailing empty line;
        # this keeps the final record when the file has no trailing newline.
        if not line.strip():
            continue
        target_text, input_text = line.split("\t")[:2]
        input_texts.append(input_text)
        target_texts.append("\t" + target_text + "\n")

    encoder_input_data = np.zeros((len(input_texts), max_encoder_seq_length), dtype="float32")
    decoder_input_data = np.zeros((len(input_texts), max_decoder_seq_length), dtype="float32")
    decoder_target_data = np.zeros((len(input_texts), max_decoder_seq_length, num_decoder_tokens), dtype="float32")

    for i, (input_text, target_text) in enumerate(zip(input_texts, target_texts)):
        for t, char in enumerate(input_text):
            encoder_input_data[i, t] = input_token_index[char]
        for t, char in enumerate(target_text):
            token_id = target_token_index[char]
            decoder_input_data[i, t] = token_id
            if t > 0:
                # The target is the decoder input shifted one step to the left,
                # so the start token never appears as a target.
                decoder_target_data[i, t - 1, token_id] = 1.0

    return input_texts, target_texts, encoder_input_data, decoder_input_data, decoder_target_data


In [21]:
# Build the vocabularies and length limits from the training split, then bind
# each element of the returned tuple to its notebook-level name.
d = training_extract(train)
print(d)
(
    num_encoder_tokens,
    num_decoder_tokens,
    input_token,
    target_token,
    reverse_input_token,
    reverse_target_token,
    max_encoder_seq_length,
    max_decoder_seq_length,
) = d

(26, 65, {'a': 0, 'b': 1, 'c': 2, 'd': 3, 'e': 4, 'f': 5, 'g': 6, 'h': 7, 'i': 8, 'j': 9, 'k': 10, 'l': 11, 'm': 12, 'n': 13, 'o': 14, 'p': 15, 'q': 16, 'r': 17, 's': 18, 't': 19, 'u': 20, 'v': 21, 'w': 22, 'x': 23, 'y': 24, 'z': 25}, {'\t': 0, '\n': 1, 'ँ': 2, 'ं': 3, 'ः': 4, 'अ': 5, 'आ': 6, 'इ': 7, 'ई': 8, 'उ': 9, 'ऊ': 10, 'ऋ': 11, 'ए': 12, 'ऐ': 13, 'ऑ': 14, 'ओ': 15, 'औ': 16, 'क': 17, 'ख': 18, 'ग': 19, 'घ': 20, 'ङ': 21, 'च': 22, 'छ': 23, 'ज': 24, 'झ': 25, 'ञ': 26, 'ट': 27, 'ठ': 28, 'ड': 29, 'ढ': 30, 'ण': 31, 'त': 32, 'थ': 33, 'द': 34, 'ध': 35, 'न': 36, 'प': 37, 'फ': 38, 'ब': 39, 'भ': 40, 'म': 41, 'य': 42, 'र': 43, 'ल': 44, 'व': 45, 'श': 46, 'ष': 47, 'स': 48, 'ह': 49, '़': 50, 'ा': 51, 'ि': 52, 'ी': 53, 'ु': 54, 'ू': 55, 'ृ': 56, 'ॅ': 57, 'े': 58, 'ै': 59, 'ॉ': 60, 'ो': 61, 'ौ': 62, '्': 63, 'ॐ': 64}, {0: 'a', 1: 'b', 2: 'c', 3: 'd', 4: 'e', 5: 'f', 6: 'g', 7: 'h', 8: 'i', 9: 'j', 10: 'k', 11: 'l', 12: 'm', 13: 'n', 14: 'o', 15: 'p', 16: 'q', 17: 'r', 18: 's', 19: 't', 20: 'u', 21: 'v

In [22]:
# Build model-ready arrays for the train, validation and test splits.
# extract_data returns, in this order:
#   input_texts          raw encoder-side text sequences
#   target_texts         raw decoder-side text sequences (with start '\t' and end '\n' tokens)
#   encoder_input_data   padded and indexed input sequences
#   decoder_input_data   padded and indexed decoder input sequences
#   decoder_target_data  one-hot encoded shifted decoder output
# Every split is padded to the limits computed from the training split.
_extract_args = (max_encoder_seq_length, max_decoder_seq_length, num_decoder_tokens)

(train_input_texts, train_target_texts,
 encoder_input_train, decoder_input_train, decoder_target_train) = extract_data(train, *_extract_args)

(val_input_texts, val_target_texts,
 encoder_input_val, decoder_input_val, decoder_target_val) = extract_data(val, *_extract_args)

(test_input_texts, test_target_texts,
 encoder_input_test, decoder_input_test, decoder_target_test) = extract_data(test, *_extract_args)

# Report the encoder array shape of each split as a quick sanity check.
for _label, _encoder_array in (
    ("Train data shape      :", encoder_input_train),
    ("Validation data shape :", encoder_input_val),
    ("Test data shape       :", encoder_input_test),
):
    print(_label, _encoder_array.shape)


Train data shape      : (44204, 20)
Validation data shape : (4358, 20)
Test data shape       : (4502, 20)


In [23]:
def beam_search_decoder(prob_distributions, beam_width):
    """
    Performs beam search decoding on a sequence of probability distributions.

    The score of a sequence is its cumulative negative log-likelihood, so lower
    scores are better. A token with probability 0 (or any non-positive value)
    is given an infinite cost instead of raising a math domain error.

    Args:
        prob_distributions (list of list of float):
            A list where each element is a list of probabilities over the vocabulary at that time step.
        beam_width (int):
            The number of top sequences to keep at each step (beam width).

    Returns:
        list of list:
            Top `beam_width` sequences, best first, along with their cumulative
            negative log-likelihood scores.
            Each element is a list: [sequence (list of token indices), score (float)]
    """
    # Initialize with an empty sequence and zero score
    decoded_sequences = [[[], 0.0]]

    # Iterate over each time step's probability distribution
    for timestep_probs in prob_distributions:
        # The per-token cost is the same for every sequence in the beam, so
        # compute it once per timestep.
        token_costs = [
            -log(token_prob) if token_prob > 0 else float("inf")
            for token_prob in timestep_probs
        ]

        # Expand each sequence in the current beam with every token
        all_candidates = [
            [seq + [token_idx], score + cost]
            for seq, score in decoded_sequences
            for token_idx, cost in enumerate(token_costs)
        ]

        # Keep only the top 'beam_width' sequences (lowest cost first)
        ordered = sorted(all_candidates, key=lambda tup: tup[1])
        decoded_sequences = ordered[:beam_width]

    return decoded_sequences


def translate(token_sequence):
    """
    Map target-vocabulary token indices back to characters and join them.

    The lookup uses the notebook-level `reverse_target_token` mapping.

    Args:
        token_sequence (list of int):
            Token indices, one per character, in output order.

    Returns:
        str:
            The characters for the given indices concatenated into one string.
    """
    return "".join(reverse_target_token[token_id] for token_id in token_sequence)



In [24]:
class WordAccuracyCallback(tf.keras.callbacks.Callback):
    """
    Custom Keras Callback to compute word-level accuracy on the validation set
    at the end of each epoch using beam search decoding.

    A validation word counts as correct when any of the beam candidates,
    truncated to the reference length, equals the reference exactly. The
    validation arrays (`encoder_input_val`, `decoder_input_val`) and the
    reference strings (`val_target_texts`) are read from the notebook namespace.

    Args:
        beam_size (int): Number of candidate sequences to consider during beam search.
    """

    def __init__(self, beam_size):
        super(WordAccuracyCallback, self).__init__()
        self.beam_size = beam_size

    def on_epoch_end(self, epoch, logs=None):
        """
        Called at the end of each epoch. Computes and logs word-level accuracy.

        The result is stored under logs["WordAccuracy"] and printed.
        """
        if logs is None:
            logs = {}

        # Predict decoder outputs from the model
        predictions = self.model.predict([encoder_input_val, decoder_input_val])
        num_samples = predictions.shape[0]
        correct_predictions = 0

        # Iterate through each predicted sample
        for i in range(num_samples):
            reference = val_target_texts[i]

            # Apply beam search decoding to get top candidate sequences
            candidate_sequences = beam_search_decoder(predictions[i], self.beam_size)

            # Compare each returned candidate against the actual target text.
            # The reference starts with the "\t" start token, which the model
            # never predicts, so candidates are cut to len(reference) - 1 tokens.
            for candidate_tokens, _score in candidate_sequences[:self.beam_size]:
                decoded_sequence = translate(candidate_tokens[:len(reference) - 1])
                if "\t" + decoded_sequence == reference:
                    correct_predictions += 1
                    break  # Found a correct candidate, move to next sample

        # Calculate and store accuracy (truncated to 4 decimal places);
        # an empty validation set yields 0.0 rather than a division error.
        accuracy = correct_predictions / num_samples if num_samples else 0.0
        logs["WordAccuracy"] = math.trunc(accuracy * 10000) / 10000

        print("- wordAccuracy:", logs["WordAccuracy"])


In [25]:
from wandb.integration.keras import WandbMetricsLogger

class RNN_Model:
    """
    A customizable RNN-based sequence-to-sequence model with support for
    LSTM, GRU, and SimpleRNN cells, attention mechanism, and WordAccuracy tracking.

    Parameters:
    - embed_size (int): Dimension of the embedding layer.
    - no_of_encoder_layers (int): Number of RNN layers in the encoder.
    - no_of_decoder_layers (int): Number of RNN layers in the decoder.
    - latent_dimension (int): Number of units in each RNN layer.
    - dropout (float): Dropout rate for RNN layers.
    - recurrent_dropout (float): Recurrent dropout rate.
    - cell_type (str): Type of RNN cell to use ('LSTM', 'GRU', or 'RNN').
    - beam_size (int): Beam width for beam search during validation.
    """

    def __init__(self, embed_size, no_of_encoder_layers, no_of_decoder_layers,
                 latent_dimension, dropout, recurrent_dropout, cell_type, beam_size):
        self.embed_size = embed_size
        self.no_of_encoder_layers = no_of_encoder_layers
        self.no_of_decoder_layers = no_of_decoder_layers
        self.latent_dimension = latent_dimension
        self.dropout = dropout
        self.recurrent_dropout = recurrent_dropout
        self.cell_type = cell_type
        self.beam_size = beam_size

        self.model = None
        self.input_layers = []    # Encoder layers
        self.output_layers = []   # Decoder layers
        self.encoder_model = None
        self.decoder_model = None

    def _make_rnn_layer(self):
        """Create one recurrent layer of the configured cell type.

        Raises:
            ValueError: If `cell_type` is not 'LSTM', 'GRU' or 'RNN'.
        """
        layer_classes = {'LSTM': LSTM, 'GRU': GRU, 'RNN': SimpleRNN}
        if self.cell_type not in layer_classes:
            raise ValueError(
                "cell_type must be one of 'LSTM', 'GRU' or 'RNN', got %r" % (self.cell_type,)
            )
        return layer_classes[self.cell_type](
            self.latent_dimension,
            return_sequences=True,
            return_state=True,
            dropout=self.dropout,
            recurrent_dropout=self.recurrent_dropout
        )

    def BUILD_FIT_MODEL(self, en_ip_tr_data, de_ip_tr_data, de_op_tr_data,
                        epochs, batch_size, max_encoder_seq_length, num_encoder_tokens,
                        max_decoder_seq_length, num_decoder_tokens):
        """
        Builds and trains the encoder-decoder model with attention.

        Decoder layer i is initialised from the final state of encoder layer i.
        When the decoder has more layers than the encoder, the extra decoder
        layers reuse the state of the last encoder layer.

        Args:
            en_ip_tr_data (ndarray): Encoder input training data.
            de_ip_tr_data (ndarray): Decoder input training data.
            de_op_tr_data (ndarray): Decoder output training data (one-hot).
            epochs (int): Number of training epochs.
            batch_size (int): Batch size for training.
            max_encoder_seq_length (int): Max length of encoder input sequences.
            num_encoder_tokens (int): Size of encoder vocabulary.
            max_decoder_seq_length (int): Max length of decoder input sequences.
            num_decoder_tokens (int): Size of decoder vocabulary.

        Raises:
            ValueError: If `cell_type` is not 'LSTM', 'GRU' or 'RNN'.
        """
        # Start from empty layer lists so a repeated build does not accumulate
        # layers from a previous call.
        self.input_layers = []
        self.output_layers = []

        # Encoder
        # The sequence length is already fixed by the Input shape, so the
        # Embedding layers are not given a separate length argument.
        encoder_inputs = Input(shape=(max_encoder_seq_length,))
        x = Embedding(input_dim=num_encoder_tokens,
                      output_dim=self.embed_size,
                      name='enc_embd_layer')(encoder_inputs)

        encoder_states = []

        for _ in range(self.no_of_encoder_layers):
            rnn = self._make_rnn_layer()
            self.input_layers.append(rnn)
            outputs = rnn(x)
            # outputs[0] is the full sequence; the remaining entries are the
            # final state(s): one tensor for GRU/SimpleRNN, two for LSTM.
            x = outputs[0]
            encoder_states.append(list(outputs[1:]))

        encoder_outputs = x

        # Decoder
        decoder_inputs = Input(shape=(max_decoder_seq_length,))
        y = Embedding(input_dim=num_decoder_tokens,
                      output_dim=self.embed_size,
                      name='dec_embd_layer')(decoder_inputs)

        for i in range(self.no_of_decoder_layers):
            if encoder_states:
                # Clamp the index so a deeper decoder reuses the last encoder state.
                initial_state = encoder_states[min(i, len(encoder_states) - 1)]
            else:
                initial_state = None
            rnn = self._make_rnn_layer()
            self.output_layers.append(rnn)
            # Index instead of unpacking: LSTM returns three values
            # (sequence, h, c) while GRU/SimpleRNN return two.
            y = rnn(y, initial_state=initial_state)[0]

        decoder_outputs = y

        # Attention Layer
        attn_layer = AttentionLayer(name='attention_layer')
        attn_out, _ = attn_layer([encoder_outputs, decoder_outputs])

        # Concatenate attention output with decoder RNN output
        decoder_concat_input = Concatenate(axis=-1, name='concat_layer')([decoder_outputs, attn_out])

        # Output Dense Layer
        dense = Dense(num_decoder_tokens, activation='softmax', name='dense_layer')
        decoder_pred = TimeDistributed(dense, name='time_distributed_layer')(decoder_concat_input)

        # Final model
        self.model = tf.keras.Model([encoder_inputs, decoder_inputs], decoder_pred)
        self.model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

        # Train the model
        self.model.fit(
            [en_ip_tr_data, de_ip_tr_data],
            de_op_tr_data,
            batch_size=batch_size,
            epochs=epochs,
            shuffle=True,
            callbacks=[
                WordAccuracyCallback(self.beam_size),
                WandbMetricsLogger(log_freq="epoch")
            ],
            verbose=1
        )


In [26]:
# Install the Weights & Biases client and authenticate this session.
# NOTE(review): the model-definition cell earlier in the notebook already
# imports `wandb.integration.keras`, so on a fresh kernel this install has to
# run before that cell.
!pip install wandb
import wandb
wandb.login()



True

Sweep Configuration

In [27]:
# Hyperparameter sweep definition for the attention model.
sweep_config = {
    'name': 'RNN',
    'method': 'bayes',
    # WandbMetricsLogger reports per-epoch Keras metrics with an "epoch/"
    # prefix, so the sweep has to optimise the prefixed key.
    # NOTE(review): the custom word-level score is presumably reported as
    # 'epoch/WordAccuracy' - confirm in the run logs before switching to it.
    'metric': {
        'name': 'epoch/accuracy',
        'goal': 'maximize'
    },

    'early_terminate': {
        'type': 'hyperband',
        'min_iter': 3,
        'max_iter': 20,
        's': 2
    },

    'parameters': {
        'epochs': {
            'values': [10, 20, 30]
        },
        'batch_size': {
            'values': [32, 64, 128]
        },
        'encoder_layers': {
            'values': [1, 2, 3]
        },
        'decoder_layers': {
            'values': [1, 2, 3]
        },
        # Used for both the embedding size and the RNN unit count in train().
        'hidden_layer_size': {
            'values': [16, 32, 64, 256]
        },
        'cell_type': {
            'values': ['GRU', 'LSTM', 'RNN']
        },
        'dropout': {
            'values': [0, 0.2, 0.3]
        },
        'recurrent_dropout': {
            'values': [0, 0.2, 0.3]
        },
        'beam_size': {
            'values': [1, 3, 5]
        }
    }
}



In [28]:
# Register the sweep defined by `sweep_config` with W&B; the returned id is
# what the agent cell below passes to wandb.agent.
sweep_id = wandb.sweep(sweep_config, project = 'DL_Assignment3 with attention', entity = 'me21b118-iit-madras' )

Create sweep with ID: 9ybcvf5q
Sweep URL: https://wandb.ai/me21b118-iit-madras/DL_Assignment3%20with%20attention/sweeps/9ybcvf5q


In [29]:
def train():
    """
    Trains the RNN model with attention using parameters from a WandB run.

    This function is designed to be the target function for a WandB sweep.
    It initializes a new WandB run, configures the model based on the run's
    hyperparameters, builds and trains the model using the preprocessed data,
    and logs metrics to WandB via the WandbMetricsLogger callback.

    The run is used as a context manager so it is closed even when model
    building or training raises.

    NOTE(review): defining this function rebinds the notebook-level name
    `train`, which an earlier cell uses for the training-file path. The
    data-loading cells therefore have to be executed before this one.
    """
    # Start a new WandB run; leaving the `with` block finishes it.
    with wandb.init() as run:
        # Get hyperparameters from the WandB run configuration
        config = run.config

        # Initialize the RNN_Model with hyperparameters from the run config
        rnn_model_with_attention = RNN_Model(
            embed_size=config.hidden_layer_size,  # Using hidden_layer_size for embedding size
            no_of_encoder_layers=config.encoder_layers,
            no_of_decoder_layers=config.decoder_layers,
            latent_dimension=config.hidden_layer_size,
            dropout=config.dropout,
            recurrent_dropout=config.recurrent_dropout,
            cell_type=config.cell_type,
            beam_size=config.beam_size
        )

        # Build and train the model
        rnn_model_with_attention.BUILD_FIT_MODEL(
            encoder_input_train,
            decoder_input_train,
            decoder_target_train,
            epochs=config.epochs,
            batch_size=config.batch_size,
            max_encoder_seq_length=max_encoder_seq_length,
            num_encoder_tokens=num_encoder_tokens,
            max_decoder_seq_length=max_decoder_seq_length,
            num_decoder_tokens=num_decoder_tokens
        )


In [None]:
# Kick off the sweep agent: it calls train() once per hyperparameter set drawn
# from the sweep, for at most `count` runs.
wandb.agent(sweep_id=sweep_id, function=train, count=7)