In [None]:
# @title === Imports & Configuration ===

import os
import time
import json
import psutil
import threading
import subprocess
import numpy as np
import pandas as pd
import tensorflow as tf
from keras import layers
from tensorflow import keras
from datetime import datetime
import sklearn.model_selection
import matplotlib.pyplot as plt
from collections import Counter
from keras.src import regularizers
from keras.src import constraints

# Fix the TensorFlow and NumPy global seeds for reproducibility.
tf.random.set_seed(42)
np.random.seed(42)

# Mount Google Drive so that results and models persist between sessions.
from google.colab import drive
drive.mount('/content/drive')
# Root folder for everything this notebook writes, and the folder holding the input query files.
BASE_PATH = "/content/drive/MyDrive/MGU"
INPUT_DIR = "/content/drive/MyDrive/MGU/aol_processed/processed_files"

Mounted at /content/drive


In [None]:
!cat /proc/cpuinfo | grep 'model name' | head -1

model name	: Intel(R) Xeon(R) CPU @ 2.20GHz


In [None]:
# @title === MGU Model Implementation ===

def build_mgu_model(vocab_size, embedding_dim, mgu_units, max_encoder_length, max_decoder_length):
    """Build the MGU encoder-decoder training model and its two inference models.

    The three returned models share layers: the inference encoder reuses the
    training encoder graph, and the inference decoder reuses the decoder
    embedding layer, the decoder `MGUCell` instance and the output `Dense`
    layer, so weights learned through the training model are the ones used
    at inference time.

    Args:
        vocab_size: Size of the token vocabulary; used as the embedding input
            dimension and as the width of the softmax output.
        embedding_dim: Width of the encoder and decoder embeddings.
        mgu_units: Hidden-state size of both MGU cells.
        max_encoder_length: Length of the (padded) encoder input sequences.
        max_decoder_length: Length of the (padded) decoder input sequences.

    Returns:
        Tuple `(training_model, encoder_inf_model, decoder_inf_model)`.
        `training_model` is compiled with Adam and sparse categorical
        cross-entropy; the two inference models are not compiled.
    """
    print("\nBuilding Training Model...")

    # Input layer for encoder
    # The input has already been tokenized to integer ids, so the dtype is int32
    encoder_inputs = keras.Input(shape=(max_encoder_length,), dtype='int32', name='encoder_inputs')

    # Embedding layer for encoder (mask_zero=True: id 0 is treated as padding and masked)
    encoder_embedding_layer = layers.Embedding(vocab_size, embedding_dim, mask_zero=True, name='encoder_embedding')
    encoder_embeddings = encoder_embedding_layer(encoder_inputs)

    # MGU layer for encoder
    # With teacher forcing, the encoder only hands over its final state; the decoder
    # starts from that state and is fed the target sequence.
    encoder_cell = MGUCell(mgu_units, dropout=0.2, recurrent_dropout=0.2)
    encoder_mgu_layer = layers.RNN(encoder_cell, return_state=True, name='encoder_mgu')
    _, encoder_state = encoder_mgu_layer(encoder_embeddings)

    # Input Layer for decoder
    decoder_inputs = keras.Input(shape=(max_decoder_length,), dtype='int32', name='decoder_inputs')

    # Embedding layer for decoder (separate weights from the encoder embedding)
    decoder_embedding_layer= layers.Embedding(vocab_size, embedding_dim, mask_zero=True, name='decoder_embedding')
    decoder_embeddings = decoder_embedding_layer(decoder_inputs)

    # MGU layer for decoder: returns the output at every timestep, initialised from the encoder state
    decoder_cell = MGUCell(mgu_units, dropout=0.2, recurrent_dropout=0.2)
    decoder_mgu_layer = layers.RNN(decoder_cell, return_sequences=True, return_state=True, name='decoder_mgu')
    decoder_outputs, _ = decoder_mgu_layer(decoder_embeddings, initial_state=encoder_state)

    # Dense layer turning each decoder output into a distribution over the vocabulary
    decoder_dense = layers.Dense(vocab_size, activation='softmax', name='decoder_dense')
    decoder_predictions = decoder_dense(decoder_outputs)

    # Takes encoder_inputs and decoder_inputs, and outputs decoder predictions
    training_model = keras.Model([encoder_inputs, decoder_inputs], decoder_predictions, name='seq2seq_training_mgu')
    training_model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])


    print("\nBuilding Inference Models...")

    # Inference Encoder model
    # Reuses the encoder_inputs (and therefore the encoder layers) from the training model
    encoder_inf_model = keras.Model(encoder_inputs, encoder_state, name='encoder_inference')

    # Inference Decoder model
    # Input layer that handles 1 token at a time
    decoder_input_single = keras.Input(shape=(1,), dtype='int32', name='decoder_inf_input_single')

    # Single-step embedding that reuses the decoder embedding layer from the training model
    single_step_decoder_embeddings = decoder_embedding_layer(decoder_input_single)

    # Input layer for the decoder state carried over from the previous step
    decoder_state_input = keras.Input(shape=(mgu_units,), name='decoder_inf_state_input')

    # Inference decoder MGU layer: a new RNN wrapper around the SAME decoder_cell,
    # so it runs with the weights of the training decoder
    decoder_inf_mgu_layer = layers.RNN(decoder_cell, return_state=True, name='decoder_inf_mgu')
    decoder_outputs_single, decoder_state_output = decoder_inf_mgu_layer(
        single_step_decoder_embeddings, initial_state=decoder_state_input
    )

    # Reusing the decoder_dense layer from the training model
    decoder_predictions_single = decoder_dense(decoder_outputs_single)

    # Takes current token + previous state, outputs prediction + new state
    decoder_inf_model = keras.Model(
        [decoder_input_single, decoder_state_input],
        [decoder_predictions_single, decoder_state_output],
        name='decoder_inference'
    )

    return training_model, encoder_inf_model, decoder_inf_model


class MGUCell(layers.Layer):
    """Minimal Gated Unit (MGU) recurrent cell, for use inside `layers.RNN`.

    Adapted from the layout of Keras' GRU cell. Implements the equations:

        f_t  = σ(W_f[h_{t-1}, x_t] + b_f)
        h̃_t = tanh(W_h[f_t ⊙ h_{t-1}, x_t] + b_h)
        h_t  = (1 - f_t) ⊙ h_{t-1} + f_t ⊙ h̃_t

    Args:
        units: Positive integer, size of the hidden state / output.
        activation: Activation for the candidate state (default tanh).
        recurrent_activation: Activation for the forget gate (default sigmoid).
        use_bias: Whether to add the bias vector `[b_f, b_h]`.
        kernel_initializer, recurrent_initializer, bias_initializer:
            Initializers for the three weights.
        kernel_regularizer, recurrent_regularizer, bias_regularizer:
            Optional regularizers for the three weights.
        kernel_constraint, recurrent_constraint, bias_constraint:
            Optional constraints for the three weights.
        dropout: Drop rate applied to the inputs during training (clamped to [0, 1]).
        recurrent_dropout: Drop rate applied to the previous hidden state where it
            feeds the gate/candidate projections during training (clamped to [0, 1]).
        seed: Seed passed to the dropout ops.
        **kwargs: Forwarded to `layers.Layer`; an `implementation` key is ignored.

    Raises:
        ValueError: If `units` is not positive.
    """

    def __init__(self,
                 units,
                 activation='tanh',
                 recurrent_activation='sigmoid',
                 use_bias=True,
                 kernel_initializer='glorot_uniform',
                 recurrent_initializer='orthogonal',
                 bias_initializer='zeros',
                 kernel_regularizer=None,
                 recurrent_regularizer=None,
                 bias_regularizer=None,
                 kernel_constraint=None,
                 recurrent_constraint=None,
                 bias_constraint=None,
                 dropout=0.0,
                 recurrent_dropout=0.0,
                 seed=42,
                 **kwargs):
        if units <= 0:
            raise ValueError(
                "Received an invalid value for argument `units`, "
                f"expected a positive integer, got {units}."
            )
        # Accepted for signature compatibility with GRU-style configs, but unused.
        kwargs.pop("implementation", None)
        super().__init__(**kwargs)
        self.units = units
        self.activation = tf.keras.activations.get(activation)
        self.recurrent_activation = tf.keras.activations.get(recurrent_activation)
        self.use_bias = use_bias

        self.kernel_initializer = tf.keras.initializers.get(kernel_initializer)
        self.recurrent_initializer = tf.keras.initializers.get(recurrent_initializer)
        self.bias_initializer = tf.keras.initializers.get(bias_initializer)

        self.kernel_regularizer = regularizers.get(kernel_regularizer)
        self.recurrent_regularizer = regularizers.get(recurrent_regularizer)
        self.bias_regularizer = regularizers.get(bias_regularizer)

        self.kernel_constraint = constraints.get(kernel_constraint)
        self.recurrent_constraint = constraints.get(recurrent_constraint)
        self.bias_constraint = constraints.get(bias_constraint)

        self.dropout = min(1.0, max(0.0, dropout))
        self.recurrent_dropout = min(1.0, max(0.0, recurrent_dropout))
        self.seed = seed

        # Attributes read by `layers.RNN` to size the carried state and the output.
        self.state_size = self.units
        self.output_size = self.units

    def build(self, input_shape):
        """Create the fused input kernel, recurrent kernel and (optional) bias."""
        super().build(input_shape)
        input_dim = input_shape[-1]

        # For computational efficiency the gate and candidate weight matrices are
        # stored side by side, following the layout of Keras' GRU implementation:
        #   "kernel"           maps the input to [forget | candidate],
        #   "recurrent_kernel" maps the previous hidden state to [forget | candidate].

        # Input weights for the forget gate [W_f] and candidate hidden state [W_h]
        self.kernel = self.add_weight(
            shape=(input_dim, self.units * 2),  # * 2 for [forget, candidate hidden]
            name='kernel',
            initializer=self.kernel_initializer,
            regularizer=self.kernel_regularizer,
            constraint=self.kernel_constraint,
        )

        # Hidden state weights for the forget gate [W_f] and candidate hidden state [W_h]
        self.recurrent_kernel = self.add_weight(
            shape=(self.units, self.units * 2),  # * 2 for [forget, candidate hidden]
            name='recurrent_kernel',
            initializer=self.recurrent_initializer,
            regularizer=self.recurrent_regularizer,
            constraint=self.recurrent_constraint,
        )

        if self.use_bias:
            self.bias = self.add_weight(
                shape=(self.units * 2,),  # [forget_bias, candidate_bias]
                name='bias',
                initializer=self.bias_initializer,
                regularizer=self.bias_regularizer,
                constraint=self.bias_constraint,
            )
        else:
            self.bias = None

        self.built = True

    def call(self, inputs, states, training=False):
        """Run one timestep.

        Args:
            inputs: Input tensor for the current timestep.
            states: Sequence whose first element is the previous hidden state.
            training: When true, input and recurrent dropout are applied.

        Returns:
            Tuple `(h, [h])`: the new hidden state as the step output and as the
            single-element state list passed to the next timestep.
        """
        h_prev = states[0]  # Previous hidden state

        # `h_prev_rec` is the copy of the previous state that feeds the matrix
        # products. Recurrent dropout is applied to this copy only: the untouched
        # `h_prev` must be used for the (1 - f) ⊙ h_{t-1} carry below, otherwise the
        # carried state itself would be zeroed / rescaled by 1/(1-rate) at every step.
        h_prev_rec = h_prev

        if training and 0.0 < self.dropout < 1.0:
            inputs = tf.nn.dropout(inputs, rate=self.dropout, seed=self.seed)

        if training and 0.0 < self.recurrent_dropout < 1.0:
            h_prev_rec = tf.nn.dropout(h_prev, rate=self.recurrent_dropout, seed=self.seed)

        # Input projections for both parts in one product:
        #   matrix_x = [W_f·x_t | W_h·x_t]
        matrix_x = tf.matmul(inputs, self.kernel)

        # bias = [b_f, b_h], added before the split so neither part needs it later
        if self.use_bias:
            matrix_x = tf.nn.bias_add(matrix_x, self.bias)

        # x_f: forget-gate input part, x_h: candidate input part
        x_f, x_h = tf.split(matrix_x, 2, axis=-1)

        # Recurrent projection for the forget gate only. The candidate's recurrent
        # projection depends on f_t, so it cannot be computed in the same product;
        # using just the first half of the kernel avoids computing a discarded half.
        h_f = tf.matmul(h_prev_rec, self.recurrent_kernel[:, :self.units])

        # Forget gate: f_t = σ(W_f·x_t + W_f·h_{t-1} + b_f)
        f = self.recurrent_activation(x_f + h_f)

        # Candidate recurrent part: (f_t ⊙ h_{t-1}) projected by the second half of
        # the recurrent kernel (previous hidden state -> candidate hidden state).
        f_h_prev_h = tf.matmul(
            f * h_prev_rec, self.recurrent_kernel[:, self.units:])

        # Candidate state: h̃_t = tanh(W_h[f_t⊙h_{t-1}, x_t] + b_h)
        h_tilde = self.activation(x_h + f_h_prev_h)

        # New hidden state: h_t = (1-f_t)⊙h_{t-1} + f_t⊙h̃_t
        h = (1 - f) * h_prev + f * h_tilde

        # Returns the output and a list of states to be passed to the next timestep.
        return h, [h]

    def get_config(self):
        """Return every constructor argument so the cell round-trips through save/load."""
        config = {
            "units": self.units,
            "activation": tf.keras.activations.serialize(self.activation),
            "recurrent_activation": tf.keras.activations.serialize(
                self.recurrent_activation
            ),
            "use_bias": self.use_bias,
            "kernel_initializer": tf.keras.initializers.serialize(
                self.kernel_initializer
            ),
            "recurrent_initializer": tf.keras.initializers.serialize(
                self.recurrent_initializer
            ),
            "bias_initializer": tf.keras.initializers.serialize(self.bias_initializer),
            "kernel_regularizer": regularizers.serialize(self.kernel_regularizer),
            "recurrent_regularizer": regularizers.serialize(self.recurrent_regularizer),
            "bias_regularizer": regularizers.serialize(self.bias_regularizer),
            "kernel_constraint": constraints.serialize(self.kernel_constraint),
            "recurrent_constraint": constraints.serialize(self.recurrent_constraint),
            "bias_constraint": constraints.serialize(self.bias_constraint),
            "dropout": self.dropout,
            "recurrent_dropout": self.recurrent_dropout,
            "seed": self.seed,
        }
        base_config = super().get_config()
        return {**base_config, **config}

    @classmethod
    def from_config(cls, config):
        """Rebuild a cell from the dictionary produced by `get_config`."""
        return cls(**config)

# Register MGUCell in Keras' global custom-object table so that saved models
# referring to the cell by name can be saved and loaded properly.
tf.keras.utils.get_custom_objects().update({'MGUCell': MGUCell})

In [None]:
# @title === Data Processing Functions ===

def create_vocabulary(input_dir, vocab_size=45000):
    """Count tokens in every `.txt` file of `input_dir` and build the vocabulary.

    Files are read as tab-separated `userID`/`query` rows. Each query is
    tokenized by putting spaces around every '.' and splitting on whitespace.
    The vocabulary is the five special tokens (ids 0-4) followed by the most
    frequent corpus tokens, up to `vocab_size` entries in total.

    Args:
        input_dir: Directory containing the `.txt` query files.
        vocab_size: Requested total vocabulary size, special tokens included.

    Returns:
        Tuple `(vocab_dict, vocab_stats)`: the token -> id mapping and a
        JSON-serialisable dictionary of summary statistics.
    """
    special_tokens = ['<PAD>', '<OOV>', '<START>', '<SEP>', '<END>']
    word_counts = Counter()

    print("Building vocabulary using modified tokenization (split on '.' and whitespace)...")
    # Sorted so that insertion order into the Counter - and therefore the order of
    # equally frequent tokens at the vocabulary cut-off - does not depend on the
    # arbitrary order in which the filesystem lists the files.
    for filename in sorted(os.listdir(input_dir)):
        if not filename.endswith('.txt'):
            continue

        file_path = os.path.join(input_dir, filename)
        print(f"Counting tokens from: {filename}")

        for chunk in pd.read_csv(file_path, sep='\t', names=['userID', 'query'], chunksize=100000):
            for query in chunk['query'].astype(str):
                word_counts.update(query.replace('.', ' . ').split())

    total_tokens_counted = sum(word_counts.values())

    # A corpus token that is literally e.g. "<START>" must not be added a second
    # time: in the dict below it would overwrite the reserved id of the special
    # token. Ask for a few extra candidates so the filtered list can still be full.
    reserved = set(special_tokens)
    word_budget = max(0, vocab_size - len(special_tokens))
    candidates = word_counts.most_common(word_budget + len(special_tokens))
    most_common_words = [word for word, _ in candidates if word not in reserved][:word_budget]

    vocabulary = special_tokens + most_common_words
    vocab_dict = {word: idx for idx, word in enumerate(vocabulary)}
    actual_vocab_size = len(vocab_dict)

    covered_tokens_count = sum(word_counts[word] for word in most_common_words)
    coverage_percentage = (covered_tokens_count / total_tokens_counted) * 100 if total_tokens_counted > 0 else 0

    vocab_stats = {
        "Requested_Vocabulary_Size": vocab_size,
        "Actual_Vocabulary_Size": actual_vocab_size,
        "Total_Tokens_Found": total_tokens_counted,
        "Total_Unique_Tokens_Found": len(word_counts),
        "Coverage_Percentage_Of_Top_Tokens": round(coverage_percentage, 2),
        "Special_Tokens": special_tokens
    }
    print("Vocabulary Stats:")
    print(json.dumps(vocab_stats, indent=4))

    return vocab_dict, vocab_stats

def get_vocabulary(vocab_size):
    """Return the vocabulary for `vocab_size`, loading it from disk when possible.

    The vocabulary and its statistics are cached as `vocab_dict.json` and
    `vocab_stats.json` under `BASE_PATH`. The cache is reused when it was built
    for the same requested size; otherwise a new vocabulary is generated from
    `INPUT_DIR` and written back to the cache.

    Args:
        vocab_size: Requested total vocabulary size, special tokens included.

    Returns:
        Tuple `(vocab_dict, vocab_stats)`.
    """
    vocab_path = f"{BASE_PATH}/vocab_dict.json"
    vocab_stats_path = f"{BASE_PATH}/vocab_stats.json"

    # Load existing vocabulary
    if os.path.exists(vocab_path) and os.path.exists(vocab_stats_path):
        try:
            with open(vocab_path, 'r') as vocab_file:
                vocab_dict = json.load(vocab_file)
            with open(vocab_stats_path, 'r') as stats_file:
                vocab_stats = json.load(stats_file)
        except (OSError, json.JSONDecodeError) as err:
            # e.g. a file truncated by an interrupted write: rebuild instead of
            # failing on every later run until the file is deleted by hand.
            print(f"Could not read cached vocabulary ({err}); it will be regenerated.")
        else:
            # The vocabulary can be smaller than requested when the corpus has
            # fewer unique tokens, so the recorded requested size is accepted
            # as well as the actual length.
            requested_size = (
                vocab_stats.get("Requested_Vocabulary_Size")
                if isinstance(vocab_stats, dict) else None
            )
            if len(vocab_dict) == vocab_size or requested_size == vocab_size:
                print("Loading existing vocabulary")
                return vocab_dict, vocab_stats

    # Or generate new one
    print("Creating new vocabulary")
    vocab_dict, vocab_stats = create_vocabulary(INPUT_DIR, vocab_size)
    os.makedirs(BASE_PATH, exist_ok=True)
    with open(vocab_path, 'w') as f1, open(vocab_stats_path, 'w') as f2:
        json.dump(vocab_dict, f1)
        json.dump(vocab_stats, f2)

    print(f"Vocabulary size: {len(vocab_dict)}")
    return vocab_dict, vocab_stats


def tokenize_text(text, vocab_dict):
    """Convert `text` into a list of vocabulary ids.

    The value is stringified, every '.' is isolated as its own token, and the
    result is split on whitespace. Tokens missing from `vocab_dict` map to the
    id stored under '<OOV>'.
    """
    unknown_id = vocab_dict['<OOV>']
    pieces = str(text).replace('.', ' . ').split()
    return [vocab_dict.get(piece, unknown_id) for piece in pieces]


def prepare_training_data(input_dir, vocab_dict, context_length, test_split, max_encoder_length, max_decoder_length, batch_size):
    """Build batched train/validation `tf.data` datasets from the query files.

    For every user, each window of `context_length` consecutive queries becomes
    an encoder input (queries joined by '<SEP>') and the query that follows it
    becomes the decoder target (terminated by '<END>'); the decoder input is the
    same target shifted right behind '<START>'.

    Note: queries are grouped by user within each 50,000-row chunk, so a user
    whose rows straddle a chunk boundary is treated as two separate histories.

    Args:
        input_dir: Directory containing tab-separated `.txt` files (`userID`, `query`).
        vocab_dict: Token -> id mapping containing the special tokens.
        context_length: Number of preceding queries used as context.
        test_split: Fraction of examples held out for validation.
        max_encoder_length: Encoder sequences are truncated/padded to this length.
        max_decoder_length: Decoder sequences are truncated/padded to this length.
        batch_size: Batch size of the returned datasets.

    Returns:
        Tuple `(train_dataset, val_dataset)` whose elements are
        `((encoder_input, decoder_input), decoder_target)` batches.

    Raises:
        ValueError: If no example could be built from the input files.
    """
    start_token = vocab_dict['<START>']
    pad_token = vocab_dict['<PAD>']
    sep_token = vocab_dict['<SEP>']
    end_token = vocab_dict['<END>']

    print("Preparing training data...")

    encoder_inputs = []
    decoder_inputs = []
    decoder_targets = []

    # Sorted so the example order - and with it the seeded train/validation
    # split below - does not depend on the filesystem's listing order.
    for filename in sorted(os.listdir(input_dir)):
        if not filename.endswith('.txt'):
            continue

        file_path = os.path.join(input_dir, filename)
        print(f"Processing {filename}...")

        for chunk in pd.read_csv(file_path, sep='\t', names=['userID', 'query'], chunksize=50000):
            chunk['query'] = chunk['query'].astype(str)
            queries_per_user = chunk.groupby('userID')['query'].apply(list)

            for queries in queries_per_user:
                if len(queries) < context_length + 1:
                    continue

                for i in range(len(queries) - context_length):
                    # Get context and target
                    context_queries = queries[i:i+context_length]
                    target_query = queries[i+context_length]

                    target_tokens = tokenize_text(target_query, vocab_dict)
                    # Must be checked before '<END>' is appended; afterwards the
                    # list can never be empty and empty targets would slip through.
                    if not target_tokens:
                        continue

                    context_tokens_list = []
                    for query in context_queries:
                        query_tokens = tokenize_text(query, vocab_dict)
                        # Separator only between queries, never in front of the first
                        if context_tokens_list:
                            context_tokens_list.append(sep_token)
                        context_tokens_list.extend(query_tokens)

                    if not context_tokens_list:
                        continue

                    encoder_inputs.append(context_tokens_list)
                    # Teacher forcing: input is the target shifted right by one
                    decoder_inputs.append([start_token] + target_tokens)
                    decoder_targets.append(target_tokens + [end_token])


    print(f"\nFinished processing data. Found {len(encoder_inputs)} valid sequences.")
    if not encoder_inputs:
        raise ValueError("No valid sequences found after processing data. Check input data and tokenization.")

    print(f"Truncating sequences to max_encoder_length={max_encoder_length}, max_decoder_length={max_decoder_length}")
    encoder_inputs = [seq[:max_encoder_length] for seq in encoder_inputs]
    decoder_inputs = [seq[:max_decoder_length] for seq in decoder_inputs]
    decoder_targets = [seq[:max_decoder_length] for seq in decoder_targets]


    print("Padding sequences...")
    encoder_inputs_padded = keras.preprocessing.sequence.pad_sequences(
          encoder_inputs, maxlen=max_encoder_length, padding='post', value=pad_token)
    decoder_inputs_padded = keras.preprocessing.sequence.pad_sequences(
          decoder_inputs, maxlen=max_decoder_length, padding='post', value=pad_token)
    decoder_targets_padded = keras.preprocessing.sequence.pad_sequences(
          decoder_targets, maxlen=max_decoder_length, padding='post', value=pad_token)

    print("Splitting data into train/validation sets...")
    (encoder_train, encoder_val,
     decoder_input_train, decoder_input_val,
     decoder_target_train, decoder_target_val) = sklearn.model_selection.train_test_split(
        encoder_inputs_padded, decoder_inputs_padded, decoder_targets_padded,
        test_size=test_split, random_state=42
    )
    print("Data splitting complete.")

    print("Creating tf.data Datasets...")
    train_dataset = tf.data.Dataset.from_tensor_slices(
        ((encoder_train, decoder_input_train), decoder_target_train)
    ).batch(batch_size).prefetch(tf.data.AUTOTUNE)

    val_dataset = tf.data.Dataset.from_tensor_slices(
        ((encoder_val, decoder_input_val), decoder_target_val)
    ).batch(batch_size).prefetch(tf.data.AUTOTUNE)
    print("Dataset creation complete.")

    return train_dataset, val_dataset

def load_or_create_datasets(input_dir, vocab_dict, context_length, test_split, max_encoder_length, max_decoder_length, dataset_dir, batch_size, force_regenerate=False):
    """Return the train/validation datasets, using the on-disk copy when present.

    Saved datasets live in `dataset_dir` under names keyed by `context_length`.
    If both exist and `force_regenerate` is false they are loaded; if loading
    fails, or nothing is cached, the datasets are rebuilt with
    `prepare_training_data` and saved (a failed save is reported, not raised).

    Returns:
        Tuple `(train_dataset, val_dataset)`.
    """
    train_save_path = os.path.join(dataset_dir, f"training_set_CL{context_length}")
    val_save_path = os.path.join(dataset_dir, f"validation_set_CL{context_length}")

    cached_copy_exists = os.path.exists(train_save_path) and os.path.exists(val_save_path)

    if cached_copy_exists and not force_regenerate:
        print(f"Loading datasets from {dataset_dir}...")
        load_started = time.time()
        try:
            loaded_train = tf.data.Dataset.load(train_save_path)
            loaded_val = tf.data.Dataset.load(val_save_path)
        except Exception as e:
            # Fall through to regeneration below.
            print(f"Error loading datasets from {dataset_dir}: {e}")
            print("Will attempt to regenerate the datasets.")
        else:
            print(f"Datasets loaded successfully in {time.time() - load_started:.2f} seconds.")
            return loaded_train, loaded_val

    print("Generating new datasets...")
    generation_started = time.time()
    fresh_train, fresh_val = prepare_training_data(
        input_dir=input_dir,
        vocab_dict=vocab_dict,
        context_length=context_length,
        test_split=test_split,
        max_encoder_length=max_encoder_length,
        max_decoder_length=max_decoder_length,
        batch_size=batch_size
    )
    print(f"Dataset generation finished in {time.time() - generation_started:.2f} seconds.")

    os.makedirs(dataset_dir, exist_ok=True)

    print(f"Saving datasets to {dataset_dir}...")
    save_started = time.time()
    try:
        fresh_train.save(train_save_path)
        fresh_val.save(val_save_path)
        print(f"Datasets saved successfully in {time.time() - save_started:.2f} seconds.")
    except Exception as e:
        # Saving is best-effort: the freshly built datasets are still returned.
        print(f"Error saving datasets to {dataset_dir}: {e}")

    return fresh_train, fresh_val

In [None]:
# @title === Prediction & Evaluation Functions ===

def calculate_mrr(predictions, actual_tokens):
    """Reciprocal rank of the first prediction equal to `actual_tokens`.

    `predictions` is scanned in order with ranks starting at 1; the result is
    `1 / rank` of the first exact match, or 0.0 when nothing matches.
    """
    matching_ranks = (
        position
        for position, candidate in enumerate(predictions, 1)
        if candidate == actual_tokens
    )
    first_match = next(matching_ranks, None)
    return 0.0 if first_match is None else 1.0 / first_match

def reconstruct_text(token_ids, index_to_word):
    """Turn a sequence of token ids back into text.

    Ids missing from `index_to_word` are rendered as '<OOV>'. Words are joined
    with single spaces, except that no space is placed on either side of a
    '.' token. An empty sequence yields an empty string.
    """
    words = [index_to_word.get(token_id, '<OOV>') for token_id in token_ids]
    if not words:
        return ""

    pieces = [words[0]]
    for left, right in zip(words, words[1:]):
        # A space separates two neighbours only when neither of them is '.'
        if left != '.' and right != '.':
            pieces.append(" ")
        pieces.append(right)
    return "".join(pieces)



def get_top_k_predictions(inf_encoder_model, inf_decoder_model, input_seq, vocab_dict, top_k=5, max_length=20):
    """Generate `top_k` candidate token sequences for one encoder input.

    The k most likely first tokens after '<START>' are taken, and each is
    extended by greedy decoding until '<END>' or '<PAD>' is predicted or the
    sequence holds `max_length` tokens.

    Args:
        inf_encoder_model: Model mapping the encoder input to the initial decoder state.
        inf_decoder_model: Model mapping `[token, state]` to `[probabilities, new_state]`.
        input_seq: Encoder input with a leading batch dimension of 1.
        vocab_dict: Token -> id mapping containing the special tokens.
        top_k: Number of candidates to generate.
        max_length: Maximum number of tokens per candidate.

    Returns:
        List of `top_k` candidates, each a list of Python ints, ordered by the
        probability of their first token. A candidate whose first token is
        '<END>' or '<PAD>' is the empty list.
    """
    start_token = vocab_dict['<START>']
    end_token = vocab_dict['<END>']
    pad_token = vocab_dict['<PAD>']
    stop_tokens = (end_token, pad_token)
    predictions = []

    # Prepare initial input (<START> token) and state for the decoder
    initial_decoder_input = tf.constant([[start_token]], dtype=tf.int32)
    initial_decoder_state = inf_encoder_model.predict(input_seq, verbose=0)

    # Predict the first step
    first_step_predictions, first_step_state = inf_decoder_model.predict(
        [initial_decoder_input, initial_decoder_state],
        verbose=0
    )

    # Get the top-k first tokens
    first_token_probabilities = first_step_predictions[0]
    _, top_k_indices = tf.math.top_k(first_token_probabilities, k=top_k)

    # Generate a full query prediction for each top-k starting token
    for first_token_id in top_k_indices.numpy():
        # Plain ints keep the returned ids uniform and JSON-serialisable
        first_token_id = int(first_token_id)

        if first_token_id in stop_tokens:
            # Empty candidate: kept as an empty *list* so every element of the
            # result has the same type and can compare equal to an empty target.
            predictions.append([])
            continue

        generated_token_ids = [first_token_id]
        current_input = tf.constant([[first_token_id]], dtype=tf.int32)
        # Every candidate restarts from the state produced by the <START> step
        current_state = first_step_state

        # One token is already present, so at most max_length - 1 further steps
        # are needed; the loop bound itself enforces the length limit.
        for _ in range(max_length - 1):
            output_tokens, next_state = inf_decoder_model.predict(
                [current_input, current_state],
                verbose=0
            )

            next_token = int(np.argmax(output_tokens[0]))

            if next_token in stop_tokens:
                break

            generated_token_ids.append(next_token)

            current_input = tf.constant([[next_token]], dtype=tf.int32)
            current_state = next_state

        predictions.append(generated_token_ids)

    return predictions


def batch_get_predictions(inf_encoder_model, inf_decoder_model, val_dataset, vocab_dict, top_k=5, max_examples=1000, max_length=20):
    """Run `get_top_k_predictions` on examples of `val_dataset`, one at a time.

    Iterates the batched dataset and stops once `max_examples` examples were
    processed (`None` means no limit). For every example the padding is removed
    from the encoder input, and padding and '<END>' from the target.

    Returns:
        List of dicts with keys "input", "actual" and "predictions".
    """
    pad_token_id = vocab_dict['<PAD>']
    start_token_id = vocab_dict['<START>']  # looked up as in the original; not used below
    end_token_id = vocab_dict['<END>']

    collected = []
    generation_start_time = time.time()

    def limit_reached():
        # len(collected) is the number of examples handled so far
        return max_examples is not None and len(collected) >= max_examples

    def without(ids, excluded):
        return [token for token in ids if token not in excluded]

    for (encoder_inputs, _), targets_batch in val_dataset:
        if limit_reached():
            break

        examples_in_batch = tf.shape(encoder_inputs)[0].numpy()

        for row in range(examples_in_batch):
            if limit_reached():
                break

            # Slice rather than index so the batch dimension is kept
            single_input = encoder_inputs[row:row+1]

            input_ids = without(single_input[0].numpy(), (pad_token_id,))
            target_ids = without(targets_batch[row].numpy(), (pad_token_id, end_token_id))

            candidate_sequences = get_top_k_predictions(
                inf_encoder_model,
                inf_decoder_model,
                single_input,
                vocab_dict,
                top_k=top_k,
                max_length=max_length
            )

            processed_so_far = len(collected) + 1
            if processed_so_far % 100 == 0:
                current_time = time.time()
                print(f"... Generated predictions for {processed_so_far} examples (Total time: {current_time - generation_start_time:.2f}s) ...")

            collected.append({
                "input": input_ids,
                "actual": target_ids,
                "predictions": candidate_sequences
            })

    total_gen_time = time.time() - generation_start_time
    print(f"\nFinished generating predictions for {len(collected)} examples in {total_gen_time:.2f} seconds.")

    return collected


def evaluate_model(inf_encoder_model, inf_decoder_model, val_dataset, index_to_word, vocab_dict, context_length, top_k=5, max_examples=1000, max_length=30):
    """Evaluate the inference models on `val_dataset` with MRR and top-1 accuracy.

    Predictions are generated with `batch_get_predictions`; each example is
    scored with `calculate_mrr`, and the first ten examples are printed as text.

    Args:
        inf_encoder_model: Inference encoder model.
        inf_decoder_model: Inference decoder model.
        val_dataset: Batched dataset of `((encoder_input, decoder_input), target)`.
        index_to_word: id -> token mapping used only for printing examples.
        vocab_dict: Token -> id mapping containing the special tokens.
        context_length: Context length of the run; used for log output only.
        top_k: Number of candidates generated per example.
        max_examples: Maximum number of examples to evaluate (`None` for all).
        max_length: Maximum number of tokens per generated candidate.

    Returns:
        Dict with keys "Mean_MRR", "Accuracy", "Total_Examples" and
        "Correct_Predictions". Both scores are 0.0 when no example was evaluated.
    """
    print(f"\nEvaluating model using inference models for context length {context_length}...")
    overall_start_time = time.time()

    print("\n Getting predictions in batches...")
    results = batch_get_predictions(
        inf_encoder_model,
        inf_decoder_model,
        val_dataset,
        vocab_dict,
        top_k=top_k,
        max_examples=max_examples,
        max_length=max_length
    )

    print("\nCalculating MRR metrics...")
    metrics_start_time = time.time()

    mrr_scores = []
    correct_predictions = 0
    total_examples_processed = len(results)

    for i, result in enumerate(results):
        input_ids = result["input"]
        actual_ids = result["actual"]
        predictions = result["predictions"]

        mrr = calculate_mrr(predictions, actual_ids)
        mrr_scores.append(mrr)

        # A reciprocal rank of exactly 1.0 means the first candidate matched (top-1 hit)
        if mrr == 1.0:
            correct_predictions += 1

        if i < 10:
            input_text = reconstruct_text(input_ids, index_to_word)
            actual_text = reconstruct_text(actual_ids, index_to_word)
            predictions_text = [reconstruct_text(pred, index_to_word) for pred in predictions]

            print(f"\nExample {i+1}:")
            print(f"Context Length : {context_length}")
            print(f"Input : {input_text}")
            print(f"Actual: {actual_text}")
            print(f"Preds :")
            for rank, pred in enumerate(predictions_text, 1):
                print(f"\t{rank}. {pred} {'<<< MATCH!' if pred == actual_text else ''}")
            print(f"MRR   : {mrr:.4f}")
        elif (i + 1) % 100 == 0:
            current_time = time.time()
            print(f"... Calculated MRR for {i + 1}/{total_examples_processed} examples (Time: {current_time - metrics_start_time:.2f}s) ...")

    # Calculate overall / global MRR for the whole model.
    # An empty dataset (or max_examples=0) yields no results; report zeros
    # instead of dividing by zero / averaging an empty list.
    if total_examples_processed:
        mean_mrr = np.mean(mrr_scores)
        accuracy = correct_predictions / total_examples_processed
    else:
        mean_mrr = 0.0
        accuracy = 0.0
    eval_time = time.time() - overall_start_time

    print("\n" + "="*50)
    print("Evaluation Results:")
    print(f"Evaluated {total_examples_processed} examples in {eval_time:.2f} seconds")
    print(f"Mean Reciprocal Rank (MRR): {mean_mrr:.4f}")
    print(f"Top-1 Accuracy: {accuracy:.4f} ({correct_predictions}/{total_examples_processed})")
    print("="*50)

    return {
        "Mean_MRR": float(mean_mrr),
        "Accuracy": float(accuracy), # Top-1 accuracy
        "Total_Examples": total_examples_processed,
        "Correct_Predictions": correct_predictions
    }

In [None]:
# @title === Plotting ===

def plot_history(history, context_length):
    """Save side-by-side accuracy and loss curves of a training run as a PDF.

    Args:
        history: Object whose `.history` dict holds the per-epoch lists
            'accuracy', 'val_accuracy', 'loss' and 'val_loss'.
        context_length: Context length of the run; part of the output file name.

    The figure is written to
    `{BASE_PATH}/results/training_history_plots_CL{context_length}.pdf`
    and closed afterwards; nothing is returned.
    """
    results_dir = f"{BASE_PATH}/results"
    # savefig does not create missing directories
    os.makedirs(results_dir, exist_ok=True)

    fig, (accuracy_ax, loss_ax) = plt.subplots(1, 2, figsize=(10, 5))
    try:
        panels = (
            (accuracy_ax, 'accuracy', 'Accuracy'),
            (loss_ax, 'loss', 'Loss'),
        )
        for ax, metric, label in panels:
            ax.plot(history.history[metric])
            ax.plot(history.history[f'val_{metric}'])
            ax.set_title(f'Model {label}')
            ax.set_ylabel(label)
            ax.set_xlabel('Epoch')
            ax.legend(['Train', 'Validation'])

        fig.tight_layout()
        fig.savefig(f"{results_dir}/training_history_plots_CL{context_length}.pdf")
    finally:
        # Close even when a metric is missing, so no figure is left open
        plt.close(fig)

In [None]:
# @title === System Resource Logging ===

class ResourceLogger:
    """Background sampler that records CPU, memory, disk and GPU usage as JSON.

    Between ``start()`` and ``stop()`` a daemon thread takes one sample roughly
    every ``interval`` seconds. Samples are kept in memory in ``self.logs`` and
    written to ``output_path`` as a JSON list when ``stop()`` is called.

    GPU data comes from the ``nvidia-smi`` command line tool. When that tool is
    missing, fails or times out, the sample is still recorded with an empty
    ``'gpus'`` list instead of being dropped.
    """

    # Upper bound (seconds) on one nvidia-smi call so a hung call cannot stall
    # the sampling thread indefinitely.
    GPU_QUERY_TIMEOUT = 10

    def __init__(self, output_path, interval=300):
        """Prepare the logger and reset the output file to an empty JSON list.

        Args:
            output_path: Path of the JSON file to write; parent directories
                are created if needed. Any existing file is overwritten.
            interval: Target number of seconds between two samples.
        """
        self.output_path = os.path.abspath(output_path)
        self.output_dir = os.path.dirname(self.output_path)
        os.makedirs(self.output_dir, exist_ok=True)

        self.interval = interval
        self.stop_event = threading.Event()  # set by stop() to end the sampling loop
        self.logs = []                       # samples collected since the last start()
        self.thread = None
        self.start_time = None               # wall-clock time of the last start()

        with open(self.output_path, 'w') as f:
            f.write('[]')

    @staticmethod
    def _parse_gpu_output(gpu_output):
        """Turn ``nvidia-smi --format=csv,noheader,nounits`` text into dicts.

        Each row is expected to hold six comma-separated fields: index, name,
        utilization, total memory, used memory, free memory. Blank rows and
        rows that do not match that shape (wrong field count or non-numeric
        values) are skipped rather than raising.

        Args:
            gpu_output: Raw standard output of the nvidia-smi query.

        Returns:
            A list with one dict per successfully parsed GPU row.
        """
        gpu_resources = []
        for line in gpu_output.strip().split('\n'):
            fields = [field.strip() for field in line.split(',')]
            if len(fields) != 6:
                continue  # blank line or unexpected row layout

            idx_str, name, util_str, mem_total_str, mem_used_str, mem_free_str = fields
            try:
                gpu_resources.append({
                    'gpu_id': int(idx_str),
                    'gpu_name': name,
                    'gpu_load': float(util_str),  # Utilization %
                    'gpu_memory_total': int(mem_total_str),
                    'gpu_memory_used': int(mem_used_str),
                    'gpu_memory_free': int(mem_free_str)
                })
            except ValueError:
                # A field that is not a number; skip this GPU only.
                continue
        return gpu_resources

    def _query_gpus(self):
        """Query nvidia-smi once and return the parsed per-GPU readings.

        Returns:
            A list of per-GPU dicts, or an empty list when nvidia-smi is not
            installed, exits with an error, or exceeds GPU_QUERY_TIMEOUT.
        """
        command = [
            "nvidia-smi",
            "--query-gpu=index,name,utilization.gpu,memory.total,memory.used,memory.free",
            "--format=csv,noheader,nounits"
        ]

        try:
            result = subprocess.run(
                command,
                capture_output=True,
                text=True,
                check=True,
                encoding='utf-8',
                timeout=self.GPU_QUERY_TIMEOUT
            )
        except (OSError, subprocess.SubprocessError) as e:
            # OSError covers a missing executable; SubprocessError covers a
            # non-zero exit status (check=True) and the timeout.
            print(f"GPU query failed, recording no GPU data for this sample: {e}")
            return []

        return self._parse_gpu_output(result.stdout)

    def _collect_resources(self):
        """Take one sample of CPU, memory, disk and GPU usage.

        Blocks for about one second because ``psutil.cpu_percent`` is called
        with ``interval=1``.

        Returns:
            A JSON-serialisable dict with the keys 'elapsed_seconds',
            'cpu_percent', 'memory', 'disk' and 'gpus'.
        """
        elapsed_seconds = time.time() - self.start_time

        cpu_percent = psutil.cpu_percent(interval=1)
        memory = psutil.virtual_memory()
        disk = psutil.disk_usage('/')

        return {
            'elapsed_seconds': elapsed_seconds,
            'cpu_percent': cpu_percent,
            'memory': {
                'total': memory.total,
                'available': memory.available,
                'used': memory.used,
                'percent': memory.percent
            },
            'disk': {
                'total': disk.total,
                'used': disk.used,
                'free': disk.free,
                'percent': disk.percent
            },
            'gpus': self._query_gpus()
        }

    def _logging_thread(self):
        """Sampling loop run by the background thread until stop_event is set."""
        while not self.stop_event.is_set():
            next_log_time = time.time() + self.interval
            try:
                resource_entry = self._collect_resources()
                self.logs.append(resource_entry)
                print(f"Logged resource entry at elapsed time: {resource_entry['elapsed_seconds']:.2f}s")
            except Exception as e:
                # Best effort: a failed sample must not end the logging loop.
                print(f"Error during resource collection in logging thread: {e}")

            # Sleep for the rest of the interval, waking early if stop() is called.
            wait_time = max(0, next_log_time - time.time())
            self.stop_event.wait(timeout=wait_time)

    def start(self):
        """Start the background sampling thread; no-op if it is already running.

        Clears previously collected samples and restarts the elapsed-time clock.
        """
        if self.thread is not None and self.thread.is_alive():
            print("Logger thread already running.")
            return

        print(f"Starting resource logger. Interval: {self.interval}s. Output: {self.output_path}")
        self.stop_event.clear()
        self.logs = []
        self.start_time = time.time()

        # Daemon thread so an unfinished sample never keeps the process alive.
        self.thread = threading.Thread(target=self._logging_thread, daemon=True)
        self.thread.start()

    def stop(self):
        """Stop the sampling thread and write the collected samples to disk.

        Does nothing (apart from a message) when the thread is not running.
        Waits up to five seconds for the thread to finish before saving.
        """
        if self.thread is None or not self.thread.is_alive():
            print("Logger thread not running or already stopped.")
            return

        print("Stopping resource logger...")
        self.stop_event.set()

        self.thread.join(timeout=5.0)

        if self.thread.is_alive():
            print("Warning: Logger thread did not stop gracefully.")

        self._save_logs()
        self.thread = None

    def _save_logs(self):
        """Write all collected samples to ``output_path`` as indented JSON."""
        print(f"Saving {len(self.logs)} resource log entries...")

        with open(self.output_path, 'w') as f:
            json.dump(self.logs, f, indent=2)
        print(f"Resource logs saved successfully to {self.output_path}")
        print(f"Number of log entries: {len(self.logs)}")

In [None]:
# @title === Training And Evaluation of Models ===

# Experiment settings
CONTEXT_LENGTHS = [1, 2, 3, 4, 5]  # train_models/evaluate_models run once per value; one model set is saved per context length
MAX_EXAMPLES = 1000  # passed to evaluate_model as max_examples (presumably caps how many validation examples are scored — confirm)
TOP_K = 5  # passed to evaluate_model as top_k

# Model hyperparameters
EMBEDDING_DIM = 128  # embedding_dim argument of build_mgu_model
HIDDEN_DIM = 256  # mgu_units argument of build_mgu_model
BATCH_SIZE = 512  # passed to load_or_create_datasets and to training_model.fit
EPOCHS = 10  # upper bound on training epochs; the EarlyStopping callback in train_models may end training sooner
MAX_DECODER_LENGTH = 30  # decoder sequence length; the encoder length is this value times the context length
VOCABULARY_SIZE = 45000  # passed to get_vocabulary and as vocab_size to build_mgu_model
TEST_SPLIT = 0.2  # passed to load_or_create_datasets as test_split (presumably the validation fraction — confirm)

def train_models():
    """Train (or load) the MGU model trio for every configured context length.

    For each value in CONTEXT_LENGTHS the datasets are requested from
    load_or_create_datasets with ``force_regenerate=True``. If the training,
    inference-encoder and inference-decoder model files all exist they are
    loaded; otherwise new models are built, their summaries and architecture
    plots are written to ``BASE_PATH/results``, the training model is fitted
    with checkpointing and early stopping, and all three models plus the
    training history are saved. A ResourceLogger samples system usage for the
    duration of each context length and is always stopped afterwards.
    """
    for directory in (f"{BASE_PATH}/models", f"{BASE_PATH}/results", f"{BASE_PATH}/data_cache"):
        os.makedirs(directory, exist_ok=True)

    print(f"Training with: Batch={BATCH_SIZE}, Embedding={EMBEDDING_DIM}, Hidden={HIDDEN_DIM}, Epochs={EPOCHS}")
    print(f"Context lengths to evaluate: {CONTEXT_LENGTHS}")

    vocab_dict, vocab_stats = get_vocabulary(VOCABULARY_SIZE)

    print("Vocabulary Statistics:")
    print(json.dumps(vocab_stats, indent=4))

    for context_length in CONTEXT_LENGTHS:
        print(f"\n{'='*40}")
        print(f"Current context length: {context_length}")

        # The encoder input is allotted one decoder-length slot per context item.
        encoder_length = MAX_DECODER_LENGTH * context_length

        resource_logger = ResourceLogger(
            f"{BASE_PATH}/results/system_resources_training_CL{context_length}.json",
            interval=150,  # Log every 2.5 minutes
        )

        training_model_path = f"{BASE_PATH}/models/MGU_Training_CL{context_length}.keras"
        inf_encoder_model_path = f"{BASE_PATH}/models/MGU_Inference_Encoder_CL{context_length}.keras"
        inf_decoder_model_path = f"{BASE_PATH}/models/MGU_Inference_Decoder_CL{context_length}.keras"
        saved_model_paths = (training_model_path, inf_encoder_model_path, inf_decoder_model_path)

        try:
            resource_logger.start()
            prep_started = time.time()

            train_dataset, val_dataset = load_or_create_datasets(
                input_dir=INPUT_DIR,
                vocab_dict=vocab_dict,
                context_length=context_length,
                test_split=TEST_SPLIT,
                max_encoder_length=encoder_length,
                max_decoder_length=MAX_DECODER_LENGTH,
                dataset_dir=f"{BASE_PATH}/data_cache",
                batch_size=BATCH_SIZE,
                force_regenerate=True
            )

            data_prep_time = time.time() - prep_started
            print(f"Preparation of dataset completed in {data_prep_time:.2f} seconds")

            # Guard: when every model file is already on disk, load them and skip training.
            if all(os.path.exists(path) for path in saved_model_paths):
                print("Loading existing models:")
                print(f"  - Training model: {training_model_path}")
                print(f"  - Inference encoder: {inf_encoder_model_path}")
                print(f"  - Inference decoder: {inf_decoder_model_path}")

                training_model, inf_encoder_model, inf_decoder_model = [
                    tf.keras.models.load_model(path, custom_objects={'MGUCell': MGUCell})
                    for path in saved_model_paths
                ]
                continue  # the finally clause below still stops the logger

            print("No existing models found. Building and training new models...")

            training_model, inf_encoder_model, inf_decoder_model = build_mgu_model(
                VOCABULARY_SIZE, EMBEDDING_DIM, HIDDEN_DIM, encoder_length, MAX_DECODER_LENGTH)

            # (model, architecture plot file, text summary file) for each of the three models.
            artifacts = [
                (
                    training_model,
                    f'{BASE_PATH}/results/training_model_plot_CL{context_length}.pdf',
                    f'{BASE_PATH}/results/training_model_summary_CL{context_length}.txt',
                ),
                (
                    inf_encoder_model,
                    f'{BASE_PATH}/results/inf_encoder_model_plot_CL{context_length}.pdf',
                    f'{BASE_PATH}/results/inf_encoder_model_summary_CL{context_length}.txt',
                ),
                (
                    inf_decoder_model,
                    f'{BASE_PATH}/results/inf_decoder_model_plot_CL{context_length}.pdf',
                    f'{BASE_PATH}/results/inf_decoder_model_summary_CL{context_length}.txt',
                ),
            ]

            # Console summaries first, then plots, then summaries on disk.
            for model, _, _ in artifacts:
                model.summary()

            for model, plot_file, _ in artifacts:
                keras.utils.plot_model(
                    model,
                    show_shapes=True,
                    show_layer_names=True,
                    show_layer_activations=True,
                    to_file=plot_file
                )

            for model, _, summary_file in artifacts:
                with open(summary_file, 'w') as handle:
                    model.summary(print_fn=lambda text: handle.write(text + '\n'))

            # Keep the best validation-loss checkpoint on disk while training.
            checkpoint = tf.keras.callbacks.ModelCheckpoint(
                training_model_path,
                save_best_only=True,
                monitor='val_loss',
                mode='min'
            )
            # Stop after three epochs without validation-loss improvement.
            early_stopping = tf.keras.callbacks.EarlyStopping(
                monitor='val_loss',
                mode='min',
                patience=3,
                verbose=1,
                restore_best_weights=True
            )

            print("Training model...")
            fit_started = time.time()
            history = training_model.fit(
                train_dataset,
                validation_data=val_dataset,
                epochs=EPOCHS,
                batch_size=BATCH_SIZE,
                callbacks=[checkpoint, early_stopping]
            )
            training_time = time.time() - fit_started
            print(f"Model training finished in {training_time:.2f} seconds")

            with open(f'{BASE_PATH}/results/training_history_CL{context_length}.json', 'w') as handle:
                json.dump(history.history, handle, indent=4)

            for model, destination in zip(
                (training_model, inf_encoder_model, inf_decoder_model), saved_model_paths
            ):
                model.save(destination)

            plot_history(history, context_length)

        finally:
            resource_logger.stop()
            print("End of train models block reached")



def evaluate_models():
    """Evaluate the saved inference models for every configured context length.

    For each value in CONTEXT_LENGTHS the inference encoder and decoder are
    loaded from ``BASE_PATH/models``, the validation dataset is obtained from
    load_or_create_datasets (``force_regenerate=False``), evaluate_model is
    run on it, and the returned results are written to
    ``BASE_PATH/results/evaluation_results_CL{context_length}.json``. A
    ResourceLogger samples system usage while evaluate_model runs.
    """
    os.makedirs(f"{BASE_PATH}/results", exist_ok=True)

    print(f"Context lengths to evaluate: {CONTEXT_LENGTHS}")

    vocab_dict, vocab_stats = get_vocabulary(VOCABULARY_SIZE)
    # Reverse lookup (index -> word) handed to evaluate_model.
    index_to_word = {token_id: token for token, token_id in vocab_dict.items()}

    print("Vocabulary Statistics:")
    print(json.dumps(vocab_stats, indent=4))

    for context_length in CONTEXT_LENGTHS:
        print(f"\n{'='*40}")
        print(f"Evaluating models with context length: {context_length}")

        encoder_length = MAX_DECODER_LENGTH * context_length

        resource_logger = ResourceLogger(
            f"{BASE_PATH}/results/system_resources_inference_CL{context_length}.json",
            interval=150,  # Log every 2.5 minutes
        )

        inf_encoder_model_path = f"{BASE_PATH}/models/MGU_Inference_Encoder_CL{context_length}.keras"
        inf_decoder_model_path = f"{BASE_PATH}/models/MGU_Inference_Decoder_CL{context_length}.keras"

        print("Loading models:")
        print(f"  - Inference encoder: {inf_encoder_model_path}")
        print(f"  - Inference decoder: {inf_decoder_model_path}")

        inf_encoder_model, inf_decoder_model = [
            tf.keras.models.load_model(path, custom_objects={'MGUCell': MGUCell})
            for path in (inf_encoder_model_path, inf_decoder_model_path)
        ]

        print("Loading validation dataset...")
        prep_started = time.time()
        # Only the validation split is needed here; the training split is discarded.
        _, val_dataset = load_or_create_datasets(
            input_dir=INPUT_DIR,
            vocab_dict=vocab_dict,
            context_length=context_length,
            test_split=TEST_SPLIT,
            max_encoder_length=encoder_length,
            max_decoder_length=MAX_DECODER_LENGTH,
            dataset_dir=f"{BASE_PATH}/data_cache",
            batch_size=BATCH_SIZE,
            force_regenerate=False
        )
        data_prep_time = time.time() - prep_started
        print(f"Preparation of validation dataset completed in {data_prep_time:.2f} seconds")

        try:
            print("Evaluating model...")
            resource_logger.start()
            eval_started = time.time()
            results = evaluate_model(
                inf_encoder_model,
                inf_decoder_model,
                val_dataset,
                index_to_word,
                vocab_dict,
                context_length,
                top_k=TOP_K,
                max_examples=MAX_EXAMPLES,
                max_length=MAX_DECODER_LENGTH
            )
            eval_time = time.time() - eval_started
            print(f"Model evaluation finished in {eval_time:.2f} seconds")

            results_file = f'{BASE_PATH}/results/evaluation_results_CL{context_length}.json'
            with open(results_file, 'w') as handle:
                json.dump(results, handle, indent=4)

        finally:
            # Always stop the logger so its samples are written even if evaluation fails.
            resource_logger.stop()
            print("End of evaluate models block reached")

if __name__ == "__main__":
    # Run the whole experiment and time it end to end: train (or load) the
    # models for every context length first, then evaluate each saved model.
    start_time = time.time()
    train_models()
    evaluate_models()
    total_time = time.time() - start_time

    # Final summary; the per-context-length artifacts are written by the two calls above.
    print(f"\nExperiment completed in {total_time:.2f} seconds")
    print(f"Results saved to {BASE_PATH}/results/")

Creating new vocabulary
Building vocabulary using modified tokenization (split on '.' and whitespace)...
Counting tokens from: user-ct-test-collection-01.txt
Vocabulary Stats:
{
    "Requested_Vocabulary_Size": 45000,
    "Actual_Vocabulary_Size": 45000,
    "Total_Tokens_Found": 6464796,
    "Total_Unique_Tokens_Found": 467801,
    "Coverage_Percentage_Of_Top_Tokens": 90.04,
    "Special_Tokens": [
        "<PAD>",
        "<OOV>",
        "<START>",
        "<SEP>",
        "<END>"
    ]
}
Vocabulary size: 45000

Experiment completed in 5.26 seconds
Results saved to /content/drive/MyDrive/MGU/results/
