In [7]:
import torch
from transformers import BartTokenizer, BartForConditionalGeneration
from transformers.modeling_outputs import BaseModelOutput
import numpy as np
import logging
from typing import Tuple, Optional

# Send INFO-and-above records from the root logger to the console as "LEVEL: message",
# so the progress messages emitted by the helper functions below are visible.
_LOG_FORMAT = '%(levelname)s: %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)


In [8]:
# --- Global Configuration ---
# Checkpoint to load, and the default device used by every helper in this notebook.
MODEL_NAME = "facebook/bart-base"
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Initialize Tokenizer and Model.
# Failures are logged and then re-raised so a broken load halts the run instead of
# leaving `tokenizer` / `model` undefined for later cells.
try:
    logging.info(f"Loading BART model '{MODEL_NAME}' to device: {DEVICE}")
    tokenizer = BartTokenizer.from_pretrained(MODEL_NAME)
    model = (
        BartForConditionalGeneration
        .from_pretrained(MODEL_NAME)
        .to(DEVICE)
    )
except Exception as load_error:
    logging.error(f"Error loading model: {load_error}")
    raise


INFO: Loading BART model 'facebook/bart-base' to device: cpu


In [9]:
# ---------------------------------------------------------------------------
# 1) TEXT → ENCODER CONTEXT
# ---------------------------------------------------------------------------
def _empty_encoder_context(device: str) -> Tuple[BaseModelOutput, torch.Tensor]:
    """
    Builds the all-zero placeholder context returned when encoding cannot proceed.

    Returns:
        Tuple[BaseModelOutput, torch.Tensor]:
            A BaseModelOutput whose last_hidden_state is zeros of shape
            (1, 1, model.config.hidden_size), and a (1, 1) all-zero attention mask.
            The mask is torch.long so it has the same integer dtype as the masks the
            tokenizer produces on the normal path.
    """
    placeholder = BaseModelOutput(
        last_hidden_state=torch.zeros((1, 1, model.config.hidden_size), device=device)
    )
    return placeholder, torch.zeros((1, 1), dtype=torch.long, device=device)


def encode_text_to_context(
    text_sentence: str,
    *,
    max_length: int = 1024,
    device: Optional[str] = None
) -> Tuple[BaseModelOutput, torch.Tensor]:
    """
    Encodes an input sentence with the BART encoder and returns the full encoder context.

    Args:
        text_sentence (str):
            Input English text that will be encoded by the BART encoder.

        max_length (int, optional):
            Maximum token length for the tokenizer; longer input is truncated.
            Defaults to 1024.

        device (str, optional):
            Device the tokenized inputs are moved to ("cpu" or "cuda").
            If None, uses global DEVICE. It should match the device the global
            `model` lives on, otherwise the encoder call fails and the
            placeholder below is returned.

    Returns:
        Tuple[BaseModelOutput, torch.Tensor]:
            encoder_output:
                The full hidden-state sequence from the encoder.
            attention_mask:
                Attention mask corresponding to the input sequence.

    Notes:
        - If the input is empty, None, or whitespace-only, or if encoding raises, a
          zero placeholder context (shape (1, 1, hidden_size)) with an all-zero
          long mask is returned instead of raising.
        - This function only performs encoding; no pooling or decoding.
    """
    device = device or DEVICE

    # Whitespace-only strings carry no content; treat them like "" instead of
    # encoding a sequence of bare space tokens.
    is_blank = not text_sentence or (
        isinstance(text_sentence, str) and not text_sentence.strip()
    )
    if is_blank:
        logging.warning("Input sentence is empty. Returning zero tensors.")
        return _empty_encoder_context(device)

    try:
        inputs = tokenizer(
            text_sentence,
            return_tensors="pt",
            max_length=max_length,
            truncation=True,
        ).to(device)

        # Inference only: no gradients needed.
        with torch.no_grad():
            encoder_output = model.model.encoder(**inputs)

        logging.info(
            f"Encoded text into sequence context: {encoder_output.last_hidden_state.shape}"
        )

        return encoder_output, inputs["attention_mask"]

    except Exception as e:
        # Best-effort by design: keep the notebook running, but preserve the
        # traceback in the log so the failure can be diagnosed.
        logging.error(f"Error during encoding: {e}", exc_info=True)
        return _empty_encoder_context(device)
    

In [10]:
# ---------------------------------------------------------------------------
# 2) ENCODER CONTEXT → SONAR VECTOR
# ---------------------------------------------------------------------------
def calculate_sonar_vector(
    encoder_output: "BaseModelOutput",
    *,
    pooling: str = "mean",
    attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Converts the encoder's hidden-state sequence into one fixed-size vector per input.

    Args:
        encoder_output (BaseModelOutput):
            The full sequence output from the BART encoder; only
            `last_hidden_state` (indexed as (batch, seq, hidden)) is read.

        pooling (str, optional):
            How to compress the sequence into one vector:
                "mean" → mean pool across sequence length.
                "cls"  → use token at position 0.
            Default: "mean"

        attention_mask (torch.Tensor, optional):
            (batch, seq) mask with 1 for real tokens and 0 for padding. When given
            with "mean" pooling, padding positions are excluded from the average.
            When None (default), every position is averaged, as before.

    Returns:
        torch.Tensor:
            A (batch, hidden_size) tensor; hidden_size is 768 for bart-base.

    Raises:
        ValueError: If `pooling` is not "mean" or "cls".
    """
    hidden = encoder_output.last_hidden_state

    if pooling == "cls":
        return hidden[:, 0, :]
    if pooling != "mean":
        raise ValueError(
            f"Unsupported pooling mode {pooling!r}; expected 'mean' or 'cls'."
        )

    if attention_mask is None:
        return torch.mean(hidden, dim=1)

    # Zero out padding positions, then divide by the count of real tokens.
    # clamp(min=1) keeps an all-zero mask (e.g. the empty-input placeholder)
    # from dividing by zero.
    weights = attention_mask.unsqueeze(-1).to(dtype=hidden.dtype, device=hidden.device)
    summed = (hidden * weights).sum(dim=1)
    counts = weights.sum(dim=1).clamp(min=1.0)
    return summed / counts

In [11]:
# ---------------------------------------------------------------------------
# 3) ENCODER CONTEXT → DECODED TEXT
# ---------------------------------------------------------------------------
def decode_context_to_text(
    encoder_output: BaseModelOutput,
    attention_mask: torch.Tensor,
    *,
    max_new_tokens: int = 40,
    num_beams: int = 4,
    device: Optional[str] = None
) -> str:
    """
    Decodes text directly from a BART encoder output using model.generate().

    Args:
        encoder_output (BaseModelOutput):
            Encoder hidden-state sequence to be decoded. It is not modified.

        attention_mask (torch.Tensor):
            Attention mask for the original input sequence.

        max_new_tokens (int, optional):
            Maximum number of newly generated tokens. Defaults to 40.

        num_beams (int, optional):
            Number of beams for beam-search decoding. Defaults to 4.

        device (str, optional):
            Device the encoder states and mask are moved to before generation.
            Default uses global DEVICE.

    Returns:
        str:
            The decoded text of the first sequence in the batch, or "" if
            generation fails (the error is logged).

    Notes:
        - The caller *must* ensure the encoder_output and attention_mask come
          from the same input batch.
        - This function **expects valid encoder_output**, not latent vectors.
    """
    device = device or DEVICE

    try:
        # generate() may expand encoder outputs in place (repeat per beam) in some
        # transformers versions; wrapping the tensor in a fresh container keeps the
        # caller's encoder_output at its original shape. [0] is last_hidden_state.
        context = BaseModelOutput(last_hidden_state=encoder_output[0].to(device))
        mask = attention_mask.to(device) if attention_mask is not None else None

        generated_ids = model.generate(
            encoder_outputs=context,
            attention_mask=mask,
            max_new_tokens=max_new_tokens,
            num_beams=num_beams,
            early_stopping=True,
        )

        # Take the first row explicitly: squeeze() would turn a one-token result
        # into a 0-d tensor and leaves a 2-D tensor for batch > 1.
        decoded_text = tokenizer.decode(
            generated_ids[0],
            skip_special_tokens=True
        )

        logging.info(f"Decoded text: '{decoded_text}'")
        return decoded_text

    except Exception as e:
        # Best-effort by design: log with traceback and return an empty string.
        logging.error(f"Error during decoding: {e}", exc_info=True)
        return ""
    

In [12]:
# ---------------------------------------------------------------------------
# Example usage
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    # Round trip: text -> encoder context -> pooled vector, then context -> text.
    print(f"--- SONAR Latent Context Test ({DEVICE}) ---")

    input_sentence = (
        "The team will evaluate the performance of models "
        "using the SuperGLU benchmark for NLP tasks."
    )
    print(f'\n[INPUT TEXT]: "{input_sentence}"')

    # Stage 1: encode, then pool the per-token states into a single vector.
    print("\n--- Running Encoder (Text -> Full Latent Context) ---")
    encoder_output, attention_mask = encode_text_to_context(input_sentence)
    sonar_vector = calculate_sonar_vector(encoder_output)

    for report_line in (
        f"Latent Sequence Shape: {encoder_output.last_hidden_state.shape}",
        f"SONAR Vector Shape: {sonar_vector.shape}",
        f"SONAR Vector (first 5): {sonar_vector[0, :5].cpu().numpy()}",
    ):
        print(report_line)

    # Stage 2: generate text back from the full (unpooled) encoder context.
    print("\n--- Running Decoder (Context -> Text) ---")
    decoded_output = decode_context_to_text(encoder_output, attention_mask)

    print(f'\n[DECODED TEXT]: "{decoded_output}"')
    print("\n--- Test Complete ---")



INFO: Encoded text into sequence context: torch.Size([1, 21, 768])


--- SONAR Latent Context Test (cpu) ---

[INPUT TEXT]: "The team will evaluate the performance of models using the SuperGLU benchmark for NLP tasks."

--- Running Encoder (Text -> Full Latent Context) ---
Latent Sequence Shape: torch.Size([1, 21, 768])
SONAR Vector Shape: torch.Size([1, 768])
SONAR Vector (first 5): [ 0.22488703 -0.03958526  0.01248108  0.06425868 -0.03248952]

--- Running Decoder (Context -> Text) ---


INFO: Decoded text: 'The team will evaluate the performance of models using the SuperGLU benchmark for NLP tasks.'



[DECODED TEXT]: "The team will evaluate the performance of models using the SuperGLU benchmark for NLP tasks."

--- Test Complete ---
