# LM VAE - VAE in text

Let's go through the evolution of language Variational Autoencoders, from their foundational concepts to modern, scalable frameworks. We will explore:

- High-level background of VAE

- Traditional VAE for Text Generation: Understanding the initial approach and its inherent challenges.


- LM-VAE (Optimus): How VAEs were scaled up using large pre-trained language models.


- LangVAE and LangSpace: Modern, modular frameworks for building and probing LM-VAEs.

## Part 0: High-Level Background of VAEs


A Variational Autoencoder (VAE) is a type of generative neural network that learns to create new data similar to what it was trained on.

![](./VAE.png)

The input data, denoted as **x**, is first passed through a probabilistic encoder, which is a neural network that compresses the data into a lower-dimensional representation. Instead of outputting a single vector, the encoder produces two vectors: a mean vector **μ** and a standard deviation vector **σ**. These vectors parameterize a probability distribution in the latent space, from which a latent vector **z** is sampled using a random variable **ε** in a process called the reparameterization trick. 



This sampled latent vector **z** is then fed into a probabilistic decoder, another neural network, which aims to reconstruct the original input. Training minimizes two terms together: the reconstruction error between the output **x'** and the original input **x**, and the KL divergence between the learned latent distribution and a standard normal distribution.
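As a rough illustration of the two training terms, here is a minimal PyTorch sketch. It assumes a diagonal Gaussian posterior parameterized by a mean and a log-variance, and a categorical decoder output; the function name `vae_loss` and the toy tensors are made up for this example.

```python
import torch
import torch.nn.functional as F


def vae_loss(logits, targets, mean, logvar):
    """Return (reconstruction, kl), each summed over the batch.

    logits:  [batch, num_classes] decoder scores
    targets: [batch] integer labels standing in for the original input
    mean, logvar: [batch, latent_dim] encoder outputs
    """
    # Reconstruction error: negative log-likelihood of the input under the decoder.
    reconstruction = F.cross_entropy(logits, targets, reduction="sum")
    # Closed-form KL( N(mean, sigma^2) || N(0, I) ) for a diagonal Gaussian.
    kl = 0.5 * torch.sum(mean.pow(2) + logvar.exp() - 1.0 - logvar)
    return reconstruction, kl


# Toy inputs: a posterior that is exactly the standard normal prior.
mean = torch.zeros(2, 4)
logvar = torch.zeros(2, 4)
logits = torch.randn(2, 10)
targets = torch.tensor([3, 7])
reconstruction, kl = vae_loss(logits, targets, mean, logvar)
print(float(reconstruction), float(kl))  # the KL term is zero for this posterior
```

The total loss is the sum of the two terms; moving the posterior away from the prior (for example, a non-zero mean) makes the KL term grow.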


For generative inference after training, one can discard the encoder, sample a random vector **z** from the latent space, and use the decoder to synthesize entirely new data points.


## Part 1: Generating Sentences from a Continuous Space – The Genesis of Language VAEs

![Generating Sentences from a Continuous Space](./1stPaper.png)


- Problem: Standard recurrent neural network language models (RNNLMs) generate text word-by-word, lacking an explicit, global representation of the entire sentence. The model is like someone writing a story word by word, not knowing where it's going, similar to telling a joke without knowing the punchline.


- Solution: Introduce a Variational Autoencoder (VAE) for sentences to explicitly model holistic properties like style, topic, and high-level syntactic features.



- VAE Benefits: 

![Standard encoder-decoder VS VAE in text interpolation.](./StandardVSvae.png)


![](./disentanglement.png)

By imposing a prior distribution (typically a standard Gaussian) on the latent space, **the model learns a smooth and structured representation of sentences**. 
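One practical consequence of a smooth latent space is that points between two sentence codes are also expected to decode to plausible sentences. A minimal sketch of linear interpolation between two latent vectors follows; the helper name `interpolate_latents` is hypothetical, and the decoding step is left out.

```python
import torch


def interpolate_latents(z_start, z_end, steps=5):
    """Return `steps` latent vectors evenly spaced on the line from z_start to z_end."""
    weights = torch.linspace(0.0, 1.0, steps).unsqueeze(1)  # [steps, 1]
    return (1.0 - weights) * z_start + weights * z_end      # [steps, latent_dim]


# Toy latent codes standing in for the encodings of two sentences.
z_a = torch.zeros(8)
z_b = torch.ones(8)
path = interpolate_latents(z_a, z_b, steps=5)
print(path.shape)
```

Each row of `path` would then be passed to the decoder to obtain one intermediate sentence.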

### The VAE Model Architecture

The model consists of two main components: an **Encoder** and a **Decoder**, both of which are Recurrent Neural Networks (RNNs); the paper uses LSTMs.


![Figure 1](./paper1figure1.png)



#### Posterior collapse

The decoder is autoregressive: it generates one word per step, and each new word is predicted from all the words that came before it.

Sometimes the decoder becomes so effective at predicting the next word from the previous words alone that it **learns to completely ignore the latent code z**.

Since the decoder ignores z, the **encoder has no incentive to encode useful information from the input sentence x**.

The model therefore settles on a trivial solution: it drives the KL divergence term to zero by making the posterior distribution q(z|x) identical to the prior p(z) for every input sentence x.

The result is a "collapsed" latent space. **The model fails to learn a meaningful, continuous representation of sentences**.
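One way to spot a collapsed latent space is to look at the KL term per latent dimension: dimensions whose average KL is close to zero carry no information about the input. The sketch below is a simple heuristic under that assumption; the function names and the `threshold` value are arbitrary choices for illustration.

```python
import torch


def kl_per_dimension(mean, logvar):
    """Average KL(q(z|x) || N(0, I)) over the batch, separately for each latent dimension."""
    kl = 0.5 * (mean.pow(2) + logvar.exp() - 1.0 - logvar)  # [batch, latent_dim]
    return kl.mean(dim=0)


def count_active_dimensions(mean, logvar, threshold=0.01):
    """Count dimensions whose average KL exceeds `threshold` (an arbitrary cutoff)."""
    return int((kl_per_dimension(mean, logvar) > threshold).sum())


# A fully collapsed posterior: mean 0 and variance 1 for every input.
collapsed_mean = torch.zeros(32, 16)
collapsed_logvar = torch.zeros(32, 16)

# A posterior whose mean still varies with the input in its first 4 dimensions.
informative_mean = torch.zeros(32, 16)
informative_mean[:, :4] = torch.linspace(-2.0, 2.0, 32).unsqueeze(1)

print(count_active_dimensions(collapsed_mean, collapsed_logvar))
print(count_active_dimensions(informative_mean, collapsed_logvar))
```

In a training loop, the same check could be run on the `mean` and `logv` tensors returned by the encoder for a validation batch.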


#### Solutions proposed in the paper

![KL Cost Annealing](./KLannealing.png)

* **KL Cost Annealing**: Training starts with the KL divergence weight β set to **0 and gradually increases it to 1**. At first, with no KL penalty, the encoder is pushed to encode the necessary information into z, so the **decoder can reconstruct the input and learns to rely on z for that**. The slow increase of β towards 1 then introduces the regularization pressure, **encouraging the model to organize the latent space to resemble the prior**.

* **Word Dropout**: In training we randomly replace the ground truth previous words in the decoder's input with an `<unk>` token. This weakens the decoder, **forcing it to rely more on the latent vector `z`** to make accurate predictions.
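The annealing weight can be sketched as a simple function of the training step. The two schedules below (linear and logistic) are common choices; the function names and the `warmup_steps`, `midpoint` and `steepness` values are assumptions for illustration, not settings taken from the paper.

```python
import math


def linear_kl_weight(step, warmup_steps):
    """Beta rises linearly from 0 to 1 over `warmup_steps`, then stays at 1."""
    return min(1.0, step / warmup_steps)


def logistic_kl_weight(step, midpoint, steepness=0.0025):
    """Beta follows a sigmoid centred on `midpoint`."""
    return 1.0 / (1.0 + math.exp(-steepness * (step - midpoint)))


for step in (0, 2500, 5000, 10000):
    beta_linear = linear_kl_weight(step, warmup_steps=5000)
    beta_logistic = logistic_kl_weight(step, midpoint=2500)
    print(step, round(beta_linear, 3), round(beta_logistic, 3))
```

At every step the loss is then computed as `reconstruction + beta * kl`, so the KL term only reaches its full weight once the decoder has had time to start using z.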

![](./elbo2.png)

### Model PyTorch Implementation

In [102]:
import torch
import torch.nn as nn
import torch.nn.utils.rnn as rnn_utils
from TraditionalVAEcheckPoints.utils import to_var


class SentenceVAE(nn.Module):
    def __init__(self, vocab_size, embedding_size, rnn_type, hidden_size, word_dropout, embedding_dropout, latent_size,
                sos_idx, eos_idx, pad_idx, unk_idx, max_sequence_length, num_layers=1, bidirectional=False):

        super().__init__()
        self.tensor = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.Tensor
        
        # Store all the hyperparameters and the important token indices
        self.max_sequence_length = max_sequence_length
        self.sos_idx = sos_idx
        self.eos_idx = eos_idx
        self.pad_idx = pad_idx
        self.unk_idx = unk_idx

        self.latent_size = latent_size
        self.rnn_type = rnn_type
        self.bidirectional = bidirectional
        self.num_layers = num_layers
        self.hidden_size = hidden_size

        # Embedding layer: turns words into vectors
        self.embedding = nn.Embedding(vocab_size, embedding_size)
        self.word_dropout_rate = word_dropout
        self.embedding_dropout = nn.Dropout(p=embedding_dropout)

        # Core of the model: the RNN layers
        if rnn_type == 'rnn':
            rnn = nn.RNN
        elif rnn_type == 'gru':
            rnn = nn.GRU
        elif rnn_type == 'lstm':
            rnn = nn.LSTM
        else:
            raise ValueError(f"Unsupported rnn_type: {rnn_type!r}")

        self.encoder_rnn = rnn(embedding_size, hidden_size, num_layers=num_layers, bidirectional=self.bidirectional,
                               batch_first=True)
        self.decoder_rnn = rnn(embedding_size, hidden_size, num_layers=num_layers, bidirectional=self.bidirectional,
                               batch_first=True)
        self.hidden_factor = (2 if bidirectional else 1) * num_layers


        # The link between the encoder and the decoder in the VAE
        self.hidden2mean = nn.Linear(hidden_size * self.hidden_factor, latent_size)
        self.hidden2logv = nn.Linear(hidden_size * self.hidden_factor, latent_size)
        self.latent2hidden = nn.Linear(latent_size, hidden_size * self.hidden_factor)

        # Output layer
        self.outputs2vocab = nn.Linear(hidden_size * (2 if bidirectional else 1), vocab_size)

    def forward(self, input_sequence, length):

        batch_size = input_sequence.size(0)

        # Sort the batch by length first, so that the embeddings, the lengths and
        # the decoder inputs all share the same (sorted) order.
        sorted_lengths, sorted_idx = torch.sort(length, descending=True)
        input_sequence = input_sequence[sorted_idx]

        input_embedding = self.embedding(input_sequence)

        # Pack the sequences so the RNN can ignore the padding tokens
        packed_input = rnn_utils.pack_padded_sequence(input_embedding, sorted_lengths.data.tolist(), batch_first=True)

        _, hidden = self.encoder_rnn(packed_input)

        if self.bidirectional or self.num_layers > 1:
            # flatten hidden state
            hidden = hidden.view(batch_size, self.hidden_size*self.hidden_factor)
        else:
            # drop only the leading layer dimension, so a batch of size 1 keeps its batch axis
            hidden = hidden.squeeze(0)

        # REPARAMETERIZATION TRICK
        mean = self.hidden2mean(hidden)
        logv = self.hidden2logv(hidden)
        std = torch.exp(0.5 * logv)
        epsilon = to_var(torch.randn([batch_size, self.latent_size]))
        z = epsilon * std + mean

        
        hidden = self.latent2hidden(z)

        if self.bidirectional or self.num_layers > 1:
            # unflatten hidden state
            hidden = hidden.view(self.hidden_factor, batch_size, self.hidden_size)
        else:
            hidden = hidden.unsqueeze(0)

        # Word Dropout (for the posterior issue)
        if self.word_dropout_rate > 0:
            # randomly replace decoder input with <unk>
            prob = torch.rand(input_sequence.size())
            if torch.cuda.is_available():
                prob=prob.cuda()
            prob[(input_sequence.data - self.sos_idx) * (input_sequence.data - self.pad_idx) == 0] = 1
            decoder_input_sequence = input_sequence.clone()
            decoder_input_sequence[prob < self.word_dropout_rate] = self.unk_idx
            input_embedding = self.embedding(decoder_input_sequence)
            
        input_embedding = self.embedding_dropout(input_embedding)
        packed_input = rnn_utils.pack_padded_sequence(input_embedding, sorted_lengths.data.tolist(), batch_first=True)

        # decoder forward pass
        outputs, _ = self.decoder_rnn(packed_input, hidden)
        padded_outputs = rnn_utils.pad_packed_sequence(outputs, batch_first=True)[0]
        padded_outputs = padded_outputs.contiguous()
        _,reversed_idx = torch.sort(sorted_idx)
        padded_outputs = padded_outputs[reversed_idx]
        b,s,_ = padded_outputs.size()

        # project outputs to distribution over vocabulary
        logp = nn.functional.log_softmax(self.outputs2vocab(padded_outputs.view(-1, padded_outputs.size(2))), dim=-1)
        logp = logp.view(b, s, self.embedding.num_embeddings)

        return logp, mean, logv, z

    def inference(self, n=4, z=None):

        if z is None:
            batch_size = n
            z = to_var(torch.randn([batch_size, self.latent_size]))
        else:
            batch_size = z.size(0)

        # Create hidden state from latent variable z (for decoder)
        hidden = self.latent2hidden(z)
        if self.bidirectional or self.num_layers > 1:
            # unflatten hidden state
            hidden = hidden.view(self.hidden_factor, batch_size, self.hidden_size)
        else:
            # add the leading layer dimension expected by the RNN
            hidden = hidden.unsqueeze(0)

        # The variables below track which sequences are still running (dynamic stopping)
        sequence_idx = torch.arange(0, batch_size, out=self.tensor()).long()
        sequence_running = torch.arange(0, batch_size, out=self.tensor()).long()
        sequence_mask = torch.ones(batch_size, out=self.tensor()).bool()
        running_seqs = torch.arange(0, batch_size, out=self.tensor()).long()

        generations = self.tensor(batch_size, self.max_sequence_length).fill_(self.pad_idx).long()

        t = 0
        while t < self.max_sequence_length and len(running_seqs) > 0:
            # First token is SOS
            if t == 0:
                input_sequence = to_var(torch.Tensor(batch_size).fill_(self.sos_idx).long())
            
            input_sequence = input_sequence.unsqueeze(1)
            input_embedding = self.embedding(input_sequence)
            output, hidden = self.decoder_rnn(input_embedding, hidden)
            logits = self.outputs2vocab(output)
            input_sequence = self._sample(logits)

            # save the word at position t
            generations = self._save_sample(generations, input_sequence, sequence_running, t)


            # update the dynamic stopping variables
            sequence_mask[sequence_running] = (input_sequence != self.eos_idx)
            sequence_running = sequence_idx.masked_select(sequence_mask)
            running_mask = (input_sequence != self.eos_idx).data
            running_seqs = running_seqs.masked_select(running_mask)
            if len(running_seqs) > 0:
                input_sequence = input_sequence[running_seqs]
                hidden = hidden[:, running_seqs]
                running_seqs = torch.arange(0, len(running_seqs), out=self.tensor()).long()

            t += 1

        return generations, z

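    def _sample(self, dist, mode='greedy'):

        if mode == 'greedy':
            # pick the highest-scoring token at each position
            _, sample = torch.topk(dist, 1, dim=-1)
        else:
            raise ValueError(f"Unsupported sampling mode: {mode!r}")
        sample = sample.reshape(-1)

        return sample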

    def _save_sample(self, save_to, sample, running_seqs, t):
        # select only still running
        running_latest = save_to[running_seqs]
        # update token at position t
        running_latest[:,t] = sample.data
        # save back
        save_to[running_seqs] = running_latest

        return save_to


#### Reparameterization Trick
![](./Reparameterization_Trick.png)
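The trick can be isolated in a few lines. This is a minimal sketch that mirrors the `mean`, `logv` and `epsilon` step of the model above; the standalone function `reparameterize` and the toy tensors are assumptions made for this example. It shows that gradients reach the encoder outputs even though z is sampled.

```python
import torch


def reparameterize(mean, logvar):
    """Sample z = mean + sigma * epsilon with epsilon ~ N(0, I)."""
    std = torch.exp(0.5 * logvar)
    epsilon = torch.randn_like(std)  # the only source of randomness
    return mean + std * epsilon


mean = torch.zeros(3, 4, requires_grad=True)
logvar = torch.zeros(3, 4, requires_grad=True)
z = reparameterize(mean, logvar)

# Because the noise is an input rather than a sampling node, backpropagation works.
z.sum().backward()
print(mean.grad)    # gradient of z with respect to the mean is 1 everywhere
print(logvar.grad)  # depends on the sampled noise
```

Sampling z directly from the distribution would break this gradient path, which is why the randomness is moved into `epsilon`.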

## LM-VAE: Optimus


![OPTIMUS: Organizing Sentences via Pre-trained Modeling of a Latent Space](./paper2.png)

### Introduction

**What happens if we scale up a VAE and use it as a new pre-trained language model (PLM)?**

![](./deep_latent.png)

Optimus was introduced as the first large-scale deep latent variable model for natural language.

### Architecture 

![Architecture](./optimus_architecture.png)

The encoder is initialized with BERT and the decoder with GPT-2.


![BERT Architecture](./Bert_arch.png)

Reminder: $T_{i}$ is the output embedding of the token $tok_{i}$, that is, a numerical representation (vector) of the token's meaning in its specific context.

The final-layer output for the [CLS] token in BERT is a rich, context-aware summary of the entire input sentence x, and is therefore used to obtain the latent variable z.
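A minimal sketch of this step is shown below, assuming a single linear layer that maps the [CLS] feature to the mean and log-variance of q(z|x). The class name `ClsToLatent` is hypothetical, and a random tensor stands in for BERT's final hidden states so the example runs without downloading a model.

```python
import torch
import torch.nn as nn


class ClsToLatent(nn.Module):
    """Map the [CLS] feature to the mean and log-variance of q(z|x) (hypothetical helper)."""

    def __init__(self, hidden_size, latent_size):
        super().__init__()
        # One projection that outputs both the mean and the log-variance.
        self.linear = nn.Linear(hidden_size, 2 * latent_size, bias=False)

    def forward(self, last_hidden_state):
        cls_feature = last_hidden_state[:, 0]  # [CLS] is the first token
        mean, logvar = self.linear(cls_feature).chunk(2, dim=-1)
        return mean, logvar


# Random tensor standing in for BERT's output: [batch, seq_len, hidden_size].
fake_bert_output = torch.randn(2, 10, 768)
mean, logvar = ClsToLatent(hidden_size=768, latent_size=32)(fake_bert_output)
print(mean.shape, logvar.shape)
```

The latent vector z is then sampled from these two outputs with the reparameterization trick.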


![Latent vector injection](./latent_integration.png)

To make GPT-2 use z during decoding without re-training its weights from scratch, two schemes can be used:
1) Memory scheme
2) Embedding scheme


Memory scheme:

![Causal self-attention mechanism](./causal_self_attention.png)

![](./transformer_memory.jpeg)
![](./self_attention_matrix_calculation.png)

(a) GPT-2's "memory" is an inherent part of its causal self-attention mechanism: it is the context window of previous tokens that the model is allowed to "look back" at for each new word it generates.

For every self-attention layer $l$, a memory vector $h^{l}_{mem}$ is computed as $zW_{M}$, where z is the latent vector and $W_{M}$ is a newly learned matrix.

$h^{l}_{mem}$ is then added as a new column to K and a new row to V at every step of the generation, so every generated token can attend to it.
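The memory scheme can be sketched for a single attention head as follows. This is a simplified illustration, not the Optimus implementation: the class name, the single-head layout and the choice to use the same memory vector as both the extra key and the extra value are assumptions.

```python
import math
import torch
import torch.nn as nn


class MemoryInjectedAttention(nn.Module):
    """One causal self-attention layer with an extra key/value slot derived from z."""

    def __init__(self, hidden_size, latent_size):
        super().__init__()
        self.query = nn.Linear(hidden_size, hidden_size)
        self.key = nn.Linear(hidden_size, hidden_size)
        self.value = nn.Linear(hidden_size, hidden_size)
        self.latent_to_memory = nn.Linear(latent_size, hidden_size, bias=False)  # plays the role of W_M

    def forward(self, hidden_states, z):
        seq_len = hidden_states.size(1)
        memory = self.latent_to_memory(z).unsqueeze(1)           # [batch, 1, hidden]
        q = self.query(hidden_states)                            # [batch, seq, hidden]
        k = torch.cat([memory, self.key(hidden_states)], dim=1)  # [batch, 1 + seq, hidden]
        v = torch.cat([memory, self.value(hidden_states)], dim=1)
        scores = q @ k.transpose(1, 2) / math.sqrt(q.size(-1))   # [batch, seq, 1 + seq]
        # Causal mask over tokens; the memory slot (column 0) is visible from every position.
        causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
        visible = torch.cat([torch.ones(seq_len, 1, dtype=torch.bool), causal], dim=1)
        scores = scores.masked_fill(~visible, float("-inf"))
        return torch.softmax(scores, dim=-1) @ v                 # [batch, seq, hidden]


torch.manual_seed(0)
layer = MemoryInjectedAttention(hidden_size=16, latent_size=8)
hidden_states = torch.randn(2, 5, 16)
z = torch.randn(2, 8)
output = layer(hidden_states, z)
print(output.shape)
```

Because the memory slot is never masked, even the first generated token is conditioned on z, while the usual causal masking between tokens is unchanged.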

![](./bert_embedding.png)

(b) The latent vector z is first transformed by a new weight matrix $W_{D}$; the resulting vector is then added element-wise to the standard transformer input embedding (token/BPE embedding + positional embedding).
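The embedding scheme is simpler to sketch. The module below is an illustrative stand-in for GPT-2's input embedding layer; the class name and the sizes are assumptions, and `latent_to_embedding` plays the role of $W_{D}$.

```python
import torch
import torch.nn as nn


class LatentConditionedEmbedding(nn.Module):
    """Token + positional embedding with a projected latent vector added at every position."""

    def __init__(self, vocab_size, max_positions, hidden_size, latent_size):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, hidden_size)
        self.position_embedding = nn.Embedding(max_positions, hidden_size)
        self.latent_to_embedding = nn.Linear(latent_size, hidden_size, bias=False)  # plays the role of W_D

    def forward(self, input_ids, z):
        positions = torch.arange(input_ids.size(1), device=input_ids.device)
        base = self.token_embedding(input_ids) + self.position_embedding(positions)
        # The same projected latent vector is broadcast over all positions.
        return base + self.latent_to_embedding(z).unsqueeze(1)


torch.manual_seed(0)
embedding = LatentConditionedEmbedding(vocab_size=100, max_positions=16, hidden_size=8, latent_size=4)
input_ids = torch.randint(0, 100, (2, 5))
z = torch.randn(2, 4)
conditioned = embedding(input_ids, z)
print(conditioned.shape)
```

Here z only enters at the input layer, which matches the "Integration Point" row of the comparison table that follows.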


**Comparison:**

| Feature              | Memory Injection                                           | Embedding Injection                                |
|----------------------|------------------------------------------------------------|----------------------------------------------------|
| Concept              | $\mathbf{z}$ as memory vector for attention                | $\mathbf{z}$ added to token embeddings             |
| Integration Point    | All layers via attention                                   | Input embedding layer                              |
| Strength             | Persistent global conditioning                             | Early influence on generation                      |
| Efficacy             | More effective empirically                                 | Less effective                                     |



## LangVAE and LangSpace: A Modern, Modular Framework

![](./langvae_paper.png)

*LangVAE* is a novel framework for building Variational Autoencoders (VAEs) on top of pre-trained LLMs. 

*LangVAE* offers a flexible and modular way to combine different encoder and decoder models.


The resulting latent representations can then be analyzed and manipulated using its companion framework, *LangSpace*.

### LangVAE Novel Solution

Standard Optimus training is end-to-end (E2E), meaning that the newly added layers,

1) the encoder projection layer
2) the memory and embedding layers

are trained jointly with the base encoder and decoder to "weld" the new layers to the pre-trained models.

LangVAE takes a different approach: it uses a KV cache injection mechanism that is compatible with a wide range of models.

![](./KVcache.png)

A key advantage of this approach is that it doesn't require modifying the decoder's architecture. This allows the weights of the PLM encoder and decoder to be kept frozen during training.

As a result, training is much faster and needs fewer resources. The LangVAE authors report a parameter reduction of over 95%.
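The effect of freezing can be illustrated with plain PyTorch. This is a toy sketch, not LangVAE code: the helper names are hypothetical, and two small linear layers stand in for a large pre-trained model and a new projection layer.

```python
import torch.nn as nn


def freeze(module):
    """Exclude all parameters of `module` from gradient updates."""
    for parameter in module.parameters():
        parameter.requires_grad = False


def trainable_fraction(model):
    """Share of the model's parameters that will still be updated during training."""
    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return trainable / total


# Toy stand-ins: a large "pre-trained" block and a small new projection layer.
pretrained = nn.Linear(512, 512)
projection = nn.Linear(512, 16)
model = nn.ModuleDict({"pretrained": pretrained, "projection": projection})

freeze(pretrained)
print(f"trainable fraction: {trainable_fraction(model):.3f}")
```

Only the parameters that still have `requires_grad=True` need gradients and optimizer state, which is where the savings in time and memory come from.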

### Training with LangVAE

In [50]:
import os
import torch
from datasets import load_dataset
from pythae.models.vae import VAEConfig
from langvae import LangVAE
from langvae.encoders import SentenceEncoder
from langvae.decoders import SentenceDecoder
from langvae.data_conversion.tokenization import TokenizedDataSet
from langvae.pipelines import LanguageTrainingPipeline
from langvae.trainers import CyclicalScheduleKLThresholdTrainerConfig
from langvae.trainers.training_callbacks import TensorBoardCallback
import torch._dynamo

torch._dynamo.config.suppress_errors = True

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
LATENT_SIZE = 128
MAX_SENT_LEN = 16

# Load pre-trained sentence encoder and decoder models.
decoder = SentenceDecoder("gpt2", LATENT_SIZE, MAX_SENT_LEN, device=DEVICE, device_map="auto")
encoder = SentenceEncoder("bert-base-cased", LATENT_SIZE, decoder.tokenizer, caching=True, device=DEVICE)


# Load the ag_news dataset
raw_datasets = load_dataset("ag_news", split='train').train_test_split(test_size=0.1)

# Extract the text sentences for training and evaluation
train_texts = sorted(raw_datasets['train']['text'], key=len, reverse=True)
eval_texts = sorted(raw_datasets['test']['text'], key=len, reverse=True)

# Set training and evaluation datasets with auto tokenization.
train_dataset = TokenizedDataSet(train_texts,
                                 decoder.tokenizer, decoder.max_len, caching=True,
                                 cache_persistence=f"ag_news_train_tok-gpt2_cache.jsonl")
eval_dataset = TokenizedDataSet(eval_texts,
                                decoder.tokenizer, decoder.max_len, caching=True,
                                cache_persistence=f"ag_news_eval_tok-gpt2_cache.jsonl")

# Define VAE model configuration
model_config = VAEConfig(latent_dim=LATENT_SIZE)

# Initialize LangVAE model
model = LangVAE(model_config, encoder, decoder)

exp_label = f"ag_news-langvae-bert-gpt2-{LATENT_SIZE}"

# Configure checkpoint saving.
# The `output_dir` parameter tells the trainer where to save checkpoints.
# `steps_saving` defines how often (in training steps) to save a checkpoint.
training_config = CyclicalScheduleKLThresholdTrainerConfig(
    output_dir=exp_label,  # Checkpoints will be saved in this directory
    num_epochs=5,
    learning_rate=1e-3,
    per_device_train_batch_size=50,
    per_device_eval_batch_size=50,
    steps_saving=500,  # Save a checkpoint every 500 steps
    optimizer_cls="AdamW",
    scheduler_cls="ReduceLROnPlateau",
    scheduler_params={"patience": 5, "factor": 0.5},
    max_beta=1.0,
    n_cycles=16,
    target_kl=2.0,
    keep_best_on_train=True
)

pipeline = LanguageTrainingPipeline(
    training_config=training_config,
    model=model
)

# Monitor the training progress with `tensorboard --logdir=runs &`
tb_callback = TensorBoardCallback(exp_label)

# The pipeline will now automatically save checkpoints to the 'output_dir'
print("--- Starting Training ---")
pipeline(
    train_data=train_dataset,
    eval_data=eval_dataset,
    callbacks=[tb_callback]
)
print("--- Training Finished ---")

Loading dataset cache at ag_news_train_tok-gpt2_cache.jsonl: 66258it [00:00, 81418.15it/s]
Loading dataset cache at ag_news_eval_tok-gpt2_cache.jsonl: 10it [00:00, 9181.93it/s]
Checking train dataset...
Checking eval dataset...
Using Base Trainer



--- Starting Training ---


Model passed sanity check !
Ready for training.

Created ag_news-langvae-bert-gpt2-128/VAE_training_2025-06-20_03-24-10. 
Training config, checkpoints and final model will be saved here.

Training params:
 - max_epochs: 5
 - per_device_train_batch_size: 50
 - per_device_eval_batch_size: 50
 - checkpoint saving every: 500
Optimizer: AdamW (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: None
    lr: 0.001
    maximize: False
    weight_decay: 0.01
)
Scheduler: <torch.optim.lr_scheduler.ReduceLROnPlateau object at 0x46672f410>

Successfully launched training !



Training of epoch 1/5:   0%|          | 0/2160 [00:00<?, ?batch/s]

Eval of epoch 1/5:   0%|          | 0/240 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 1718.6916
Eval loss: 1451.1828
--------------------------------------------------------------------------


Training of epoch 2/5:   0%|          | 0/2160 [00:00<?, ?batch/s]

Eval of epoch 2/5:   0%|          | 0/240 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 1439.8797
Eval loss: 1458.5375
--------------------------------------------------------------------------


Training of epoch 3/5:   0%|          | 0/2160 [00:00<?, ?batch/s]

Eval of epoch 3/5:   0%|          | 0/240 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 1311.1592
Eval loss: 1498.7332
--------------------------------------------------------------------------


Training of epoch 4/5:   0%|          | 0/2160 [00:00<?, ?batch/s]

Eval of epoch 4/5:   0%|          | 0/240 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 1202.0847
Eval loss: 1517.0797
--------------------------------------------------------------------------


Training of epoch 5/5:   0%|          | 0/2160 [00:00<?, ?batch/s]

Eval of epoch 5/5:   0%|          | 0/240 [00:00<?, ?batch/s]

--------------------------------------------------------------------------
Train loss: 1111.3661
Eval loss: 1559.536
--------------------------------------------------------------------------
Training ended!
Saved final model in ag_news-langvae-bert-gpt2-128/VAE_training_2025-06-20_03-24-10/final_model


--- Training Finished ---


## LangSpace

In [4]:
import torch
import nltk
from langvae import LangVAE
from saf_datasets import EntailmentBankDataSet
from langspace.probe import DisentanglementProbe
from langspace.metrics.disentanglement import DisentanglementMetric as Metric
from langspace.probe import InterpolationProbe
from langspace.metrics.interpolation import InterpolationMetric as InterpMetric
from saf.importers import ListImporter

# Load annotated data from saf_datasets.
dataset_example = EntailmentBankDataSet.from_resource("pos+lemma+ctag+dep+srl#expl_only-noreps")
annotations = {"srl_f": dataset_example.annotations["srl"]}

# The 'srl' annotation contains a list with the role of a single token in each phrase in the sentence.
# 'srl_f' will contain the first non-empty srl annotation for each token.
for sent in dataset_example:
    for token in sent.tokens:
        srl = token.annotations["srl"]
        token_annot = [lbl for lbl in srl if (lbl != "O")][0] if (len(set(srl)) > 1) else srl[0]
        token.annotations["srl_f"] = token_annot
        


### Load a pre-trained LangVAE model

In [7]:
# Load explanation LM-VAE for generation.
model = LangVAE.load_from_hf_hub("neuro-symbolic-ai/eb-langcvae-bert-base-cased-gpt2-srl-l128") # Loads model from HuggingFace Hub.
model.eval()

Downloading LangVAE files for rebuilding...
Successfully downloaded LangVAE model!


LangVAE(
  (decoder): SentenceDecoder(
    (context_hidden): ModuleList(
      (0-11): 12 x LazyLinear(in_features=0, out_features=50688, bias=True)
    )
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): SentenceEncoder(
    (linear): LazyLinear(in_features=0, out_features=256, bias=False)
  )
)

#### Interpolation

In [9]:
if (torch.cuda.is_available()):
  model.encoder.to("cuda")
  model.decoder.to("cuda")
  model.encoder.init_pretrained_model()
  model.decoder.init_pretrained_model()

# Probing latent interpolation
nltk.download('punkt_tab')

sentences = [
    ("humans require freshwater for survival", "B-ARG0 B-V B-ARG1 B-ARGM-PRP I-ARGM-PRP"),
    ("animals require food to survive", "B-ARG0 B-V B-ARG1 B-ARGM-PRP I-ARGM-PRP"),
    ("the sun is in the northern hemisphere", "B-ARG0 I-ARG0 B-V B-ARGM-LOC I-ARGM-LOC I-ARGM-LOC I-ARGM-LOC"),
    ("food is a source of energy for animals / plants", "B-ARG0 B-V B-ARG2 I-ARG2 I-ARG2 I-ARG2 B-ARGM-PRP I-ARGM-PRP")
]
sentences_ds = ListImporter(annotations=["srl_f"])([[(tok, lbl) for tok, lbl in zip(sent[0].split(), sent[1].split())] for sent in sentences]).sentences


interp_dataset = [(sentences_ds[0], sentences_ds[1]), (sentences_ds[2], sentences_ds[3])]

interp_report = InterpolationProbe(model, interp_dataset, eval=[InterpMetric.SMOOTHNESS], annotations=annotations).report()



In [10]:
for idx, row in interp_report.iterrows():
    print(f"########################")
    for col, val in row.items():
        if col == "generate":
            print(col)
            print(f"{val}")
            continue
        print(f"  {col}: {val}")
    print()  # Blank line between rows

print(f"########################")

########################
  source: humans require freshwater for survival
  target: animals require food to survive
  distance: 0.9457203847853638
generate
 humans require to water
 humans require food water
 humans require food water
 humans require food for survive
 humans require food to survive
 animals require food to survive
 animals require food to survive
 animals require food to survive
 animals require food to survive
 animals require food to survive
 animals require food to survive

########################
  source: the sun is in the northern hemisphere
  target: food is a source of energy for animals
  distance: 0.49064225131560396
generate
 the sun is located in the northern sun
 the sun is located in the northern hemisphere
 the sun is located in the northern hemisphere
 the sun is located in the solar system
 the sun is located in the solar system
 the sun is a source in the solar
 the sun is a source of energy
 the sun is a source of energy
 food is a source of energy 

#### Arithmetic operations

In [13]:
from langspace.probe import ArithmeticProbe
from langspace.probe.arithmetic import ArithmeticOps


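# Hand-written example pairs; note that this list is immediately overwritten
# by the dataset-based pairs defined right after it.
op_dataset = [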
    ("animals require food for survival", "animals require warmth for survival"),
    ("water vapor is invisible", "the water is warm")
]
op_dataset = [(dataset_example[i], dataset_example[i+1]) for i in range(0, 50, 2)]
arith_report = ArithmeticProbe(model, op_dataset, ops=list(ArithmeticOps), annotations=annotations).report()
print(arith_report)
arith_report.to_csv("arithm.csv")



                                               source  \
0                      Earth revolves around the sun.   
1   the earth revolving around the sun causes star...   
2   Its position appears to shift relative to the ...   
3   stars appear to move relative to the horizon d...   
4   the earth rotating on its axis causes stars to...   
..                                                ...   
70                          earth is a kind of planet   
71  a complete rotation of the earth on earth 's a...   
72                                 season of the year   
73                           Earth turns on its axis.   
74  summer is when a hemisphere is tilted towards ...   

                                               target   op  \
0                      leo is a kind of constellation  sum   
1                      a constellation contains stars  sum   
2                 earth is a kind of celestial object  sum   
3   a star is a kind of celestial object / celesti...  sum   
4   a

In [25]:
sample = arith_report.iloc[0]
for col, val in sample.items():
    print(f"{col}: {val}")
    print()

source: Earth revolves around the sun.

target: leo is a kind of constellation

op: sum

generate:  Mars is a kind of planet

